diff --git a/.gitattributes b/.gitattributes index 5f2263daae1414ee1d3642c5d63702870c770711..93e8d3139da7635730826f24ab0ddf6c4bf1b2e4 100644 --- a/.gitattributes +++ b/.gitattributes @@ -17,3 +17,4 @@ build/torch210-cxx11-cu130-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter= build/torch29-cxx11-cu126-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text build/torch29-cxx11-cu128-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text build/torch29-cxx11-cu129-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu130-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_contiguous.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_contiguous.h new file mode 100644 index 0000000000000000000000000000000000000000..914bbddda9227d1f1772d8e8171b06280b7a5f61 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_contiguous.h @@ -0,0 +1,606 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Kernel performing a reduction over one or more ranks of an affine tensor +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/fast_math.h" +#include "cutlass/numeric_types.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/reduction/thread/reduction_operators.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reduction { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Parameters structure +template < + int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) + int ReducedRank, ///< Rank of reduced tensor (i.e. number of outer ranks) + typename ElementOutput, ///< Data type of output tensor + typename ElementSource, ///< Data type of source tensor + typename ReductionOp, ///< Reduction operator + int VectorLength = 1, ///< Vector length for memory + typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation + int Threads = 256, ///< Number of participating threads + int BatchSize = 4 ///< Number of elements to load per batch +> +struct TensorReductionAffineContiguousParams { + + static int const kRank = Rank; + static int const kReducedRank = ReducedRank; + static int const kVectorLength = VectorLength; + static int const kInnerRank = kRank - kReducedRank; + static int const kThreads = Threads; + static int const kBatchSize = BatchSize; + + Coord extent; /// Extent of source tensor + FastDivmodU64 divmod[kRank - 1]; /// FastDivmod by each strided rank + int64_t dst_stride[kReducedRank]; /// stride (units of bytes) - I, J + int64_t src_stride[kRank - 1]; /// stride (units of bytes) - I, J, K + int64_t workspace_stride; /// stride (units of bytes) between workspace + int workspace_count; /// number of workspaces + + uint64_t inner_count; /// Number of elements in reduced index space + uint64_t outer_count; /// Number of elements in outer index space + + ElementOutput * destination; /// Pointer to output tensor of rank kReducedRank + ElementSource const * source; /// Pointer to source pointer of rank kRank + ReductionOp reduction_op; /// Reduction operator + ElementCompute reduction_identity; /// Identity element used by reduction operator + ElementCompute *device_workspace; /// Pointer to device workspace for inter-CTA reductions + + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + TensorReductionAffineContiguousParams() { + + } + + /// Ctor + TensorReductionAffineContiguousParams( + Coord extent_, ///< Extent of source tensor + ElementOutput * dst_ptr_, ///< Output tensor data + int64_t dst_stride_[], ///< Stride (units of elements) + ElementSource const * src_ptr_, ///< Source tensor data + int64_t src_stride_[], ///< Stride (units of elements) + ElementCompute *device_workspace_, ///< Pointer to device workspace for inter-CTA reductions + int64_t workspace_stride_, ///< Stride between workspaces + int workspace_count_, ///< Number of workspaces + ReductionOp reduction_op_, ///< Reduction operator + ElementCompute reduction_identity_ = ElementCompute() ///< Identity element used by reduction operator + ): + extent(extent_), + inner_count(1), + outer_count(1), + destination(dst_ptr_), + source(src_ptr_), + device_workspace(device_workspace_), + workspace_stride(workspace_stride_), + workspace_count(workspace_count_), + reduction_op(reduction_op_), + reduction_identity(reduction_identity_) { + + // Initialize divisors for fast div-mod + for (int p = 1; p < kRank; ++p) { + divmod[p - 1] = FastDivmodU64(uint64_t(extent[p])); + } + + int input_size_bits = sizeof_bits::value; + int output_size_bits = sizeof_bits::value; + + // Compute strides in units of bytes + for (int p = 0; p < kReducedRank; ++p) { + dst_stride[p] = dst_stride_[p] * output_size_bits / 8; + } + + for (int p = 0; p < kRank - 1; ++p) { + src_stride[p] = src_stride_[p] * input_size_bits / 8; + } + + // Compute number of elements in strided ranks + for (int p = 0; p < kReducedRank; ++p) { + outer_count *= uint64_t(extent[p]); + } + + for (int p = 0; p < kInnerRank; ++p) { + inner_count *= uint64_t(extent[kRank - 1 - p]); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Kernel to reduce a tensor with affine layout over a set of ranks *INCLUDING* the contiguous +/// rank. This leads to favorable vectorized memory accesses over the contiguous rank. +template < + int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) + int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2) + typename ElementOutput, ///< Data type of output tensor + typename ElementSource, ///< Data type of source tensor + typename ReductionOp, ///< Reduction operator + int VectorLength = 1, ///< Vector length for memory + typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation + int Threads = 256, ///< Number of participating threads + int BatchSize = 4 ///< Number of elements to load per batch +> +class TensorReductionAffineContiguous { +public: + + static int const kRank = Rank; + static int const kReducedRank = ReducedRank; + static int const kVectorLength = VectorLength; + static int const kInnerRank = kRank - kReducedRank; + static int const kThreads = Threads; + static int const kBatchSize = BatchSize; + using ComputeFragment = Array; + using SourceFragment = AlignedArray; + using OutputFragment = AlignedArray; + + /// Shared memory allocation used for reduction within the CTA + struct SharedStorage { + Array workspace; + }; + + /// Parameters structure + using Params = TensorReductionAffineContiguousParams< + Rank, + ReducedRank, + ElementOutput, + ElementSource, + ReductionOp, + VectorLength, + ElementCompute, + Threads, + BatchSize + >; + +private: + + /// Computes the coordinate and offset of a given linear index + CUTLASS_DEVICE + void compute_inner_coord_and_offset_( + Params const ¶ms, + Coord & coord, + int64_t &src_offset, + uint64_t linear_idx) const { + + // Decompose into a coordinate of rank + coord = CoordinateDecomposition(linear_idx, ¶ms.divmod[kRank - kInnerRank]); + + // Compute an offset using the souce stride + src_offset = 0; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kInnerRank - 1; ++i) { + src_offset += coord[i] * params.src_stride[kReducedRank + i]; + } + src_offset += coord[kInnerRank - 1] * sizeof_bits::value / 8; + } + + /// Computes the coordinate and offset of a given linear index + CUTLASS_DEVICE + void compute_outer_coord_and_offset_( + Params const ¶ms, + Coord & coord, + int64_t &dst_offset, + int64_t &src_offset, + uint64_t linear_idx) const { + + // Decompose into coordinate of rank + coord = CoordinateDecomposition(linear_idx, params.divmod); + + // Compute offsets using destination and source strides + dst_offset = 0; + src_offset = 0; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kReducedRank; ++i) { + dst_offset += params.dst_stride[i] * coord[i]; + src_offset += params.src_stride[i] * coord[i]; + } + } + + /// Reduces over the reduction indices yielding a single element + CUTLASS_DEVICE + ElementCompute reduce_indices_( + Params const ¶ms, + ElementCompute *threadblock_workspace, + char const *src_byte_ptr, + int coord_c) { + + NumericArrayConverter convert_source; + ReductionOp reduction_op(params.reduction_op); + + // + // Early exit or initialize to identity element + // + if (!params.inner_count) { + return params.reduction_identity; + } + + ComputeFragment accumulator; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < int(accumulator.size()); ++i) { + accumulator[i] = params.reduction_identity; + } + + // Compute the coordinate of the first access + int64_t src_byte_offset = 0; + Coord coord; + + uint64_t linear_idx = (threadIdx.x + blockDim.x * threadIdx.z + blockDim.x * blockIdx.z * blockDim.z) * kVectorLength; + compute_inner_coord_and_offset_(params, coord, src_byte_offset, linear_idx); + + // Load the first vector + SourceFragment source_fragment[kBatchSize]; + + bool not_done = true; + + // Iterate over vectors in a linearized reduction index space + while (not_done) { + + bool guards[kBatchSize]; + + // Issue a batch of loads + CUTLASS_PRAGMA_UNROLL + for (int b = 0; b < kBatchSize; ++b) { + + if (linear_idx < params.inner_count) { + source_fragment[b] = *reinterpret_cast(src_byte_ptr + src_byte_offset); + guards[b] = true; + } + else { + guards[b] = false; + not_done = false; + } + + linear_idx += (blockDim.z * gridDim.z * blockDim.x) * kVectorLength; + compute_inner_coord_and_offset_(params, coord, src_byte_offset, linear_idx); + } + + // Perform a batch of reduction operations + CUTLASS_PRAGMA_UNROLL + for (int b = 0; b < kBatchSize; ++b) { + if (guards[b]) { + auto cvt = convert_source(source_fragment[b]); + + accumulator = cutlass::reduction::thread::detail::ApplyArrayOperator( + reduction_op, + accumulator, + cvt); + } + } + }; + + // + // Reduction of vectors to scalar + // + + ElementCompute reduced_accumulator = accumulator[0]; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kVectorLength; ++i) { + reduced_accumulator = reduction_op(reduced_accumulator, accumulator[i]); + } + + // + // Reduction within CTA across threadIdx.xz => threadIdx{.x = 0, .z = 0} + // + // This re-arranges data so threadIdx.y is effectively a row index and threadIdx.xz is a column + // + + int thread_count = blockDim.x * blockDim.z; + int thread_j = threadIdx.x + blockDim.x * threadIdx.z; + int thread_i = threadIdx.y; + + ElementCompute *frag_ptr = reinterpret_cast(threadblock_workspace) + thread_i * thread_count; + + frag_ptr[thread_j] = reduced_accumulator; + + // + // Reduce + // + CUTLASS_PRAGMA_NO_UNROLL + while (thread_count > 1) { + thread_count /= 2; + + __syncthreads(); + + if (thread_j < thread_count) { + ElementCompute other = frag_ptr[thread_j + thread_count]; + + reduced_accumulator = reduction_op(reduced_accumulator, other); + + frag_ptr[thread_j] = reduced_accumulator; + } + + __syncthreads(); + } + + + return reduced_accumulator; + } + +public: + + /// Perform a reduction + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + int coord_c = (blockIdx.x * blockDim.x + threadIdx.x) * kVectorLength; + + char const * src_byte_ptr = reinterpret_cast(params.source); + char * dst_byte_ptr = nullptr; + + // If performing a reduction across CTAs, redirect output to device workspace + if (gridDim.z == 1) { + dst_byte_ptr = reinterpret_cast(params.destination); + } + else { + dst_byte_ptr = reinterpret_cast(params.device_workspace); + } + + uint64_t idx_linear = blockIdx.y * blockDim.y + threadIdx.y; + + // Use modulo division to compute location + Coord outer_coord; + int64_t dst_byte_offset; + int64_t src_byte_offset; + + compute_outer_coord_and_offset_( + params, + outer_coord, + dst_byte_offset, + src_byte_offset, + idx_linear); + + if (gridDim.z == 1) { + + /// Complete the reduction with no workspace + while (idx_linear < params.outer_count) { + + ElementCompute result = reduce_indices_( + params, + shared_storage.workspace.data(), + src_byte_ptr + src_byte_offset, + coord_c); + + // Store the result after possible final reduction within the CTA + if (threadIdx.z == 0 && threadIdx.x == 0) { + + // Convert to output type and store + NumericConverter convert_output; + ElementOutput cvt = convert_output(result); + + *reinterpret_cast(dst_byte_ptr + dst_byte_offset) = cvt; + } + + __syncthreads(); + + // Update indices and pointers + idx_linear += gridDim.y * blockDim.y; + + compute_outer_coord_and_offset_( + params, + outer_coord, + dst_byte_offset, + src_byte_offset, + idx_linear); + + } // while + } + else { + + /// Complete the reduction with workspace + while (idx_linear < params.outer_count) { + + ElementCompute result = reduce_indices_( + params, + shared_storage.workspace.data(), + src_byte_ptr + src_byte_offset, + coord_c); + + int64_t byte_offset = + blockIdx.z * params.workspace_stride + idx_linear * sizeof_bits::value / 8; + + // Store the result for final reduction + if (threadIdx.z == 0 && threadIdx.x == 0) { + *reinterpret_cast(dst_byte_ptr + byte_offset) = result; + } + + __syncthreads(); + + // Update indices and pointers + idx_linear += gridDim.y * blockDim.y; + + compute_outer_coord_and_offset_( + params, + outer_coord, + dst_byte_offset, + src_byte_offset, + idx_linear); + } // while + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Kernel to perform final reduction +template < + int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) + int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2) + typename ElementOutput, ///< Data type of output tensor + typename ElementSource, ///< Data type of source tensor + typename ReductionOp, ///< Reduction operator + int VectorLength = 1, ///< Vector length for memory + typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation + int Threads = 256, ///< Number of participating threads + int BatchSize = 4 ///< Number of elements to load per batch +> +class TensorReductionAffineContiguousFinal { +public: + + static int const kRank = Rank; + static int const kReducedRank = ReducedRank; + static int const kVectorLength = VectorLength; + static int const kInnerRank = kRank - kReducedRank; + static int const kThreads = Threads; + static int const kBatchSize = BatchSize; + + /// Shared memory + struct SharedStorage { }; + + /// Parameters structure + using Params = TensorReductionAffineContiguousParams< + Rank, + ReducedRank, + ElementOutput, + ElementSource, + ReductionOp, + VectorLength, + ElementCompute, + Threads, + BatchSize + >; + +private: + + /// Computes the coordinate and offset of a given linear index + CUTLASS_DEVICE + void compute_outer_coord_and_offset_( + Params const ¶ms, + Coord & coord, + int64_t &dst_offset, + uint64_t linear_idx) const { + + // Decompose into coordinate of rank + coord = CoordinateDecomposition(linear_idx, params.divmod); + + // Compute offsets using destination and source strides + dst_offset = 0; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kReducedRank; ++i) { + dst_offset += params.dst_stride[i] * coord[i]; + } + } + + /// Reduces over the reduction indices + CUTLASS_DEVICE + ElementCompute reduce_indices_( + Params const ¶ms, + ElementCompute const *device_workspace) { + + ReductionOp reduction_op(params.reduction_op); + char const *src_byte_ptr = reinterpret_cast(device_workspace); + + // Accumulated output + ElementCompute accumulator = params.reduction_identity; + + for (int iter = 0; iter < params.workspace_count; ++iter) { + ElementCompute workspace_item = *reinterpret_cast(src_byte_ptr); + + accumulator = reduction_op(accumulator, workspace_item); + + src_byte_ptr += params.workspace_stride; + } + + return accumulator; + } + +public: + + // + // Methods + // + + /// Perform a reduction + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + uint64_t idx_linear = blockIdx.x * blockDim.x + threadIdx.x; + + char * dst_byte_ptr = reinterpret_cast(params.destination); + + // Use modulo division to compute location + Coord outer_coord; + int64_t dst_byte_offset; + + compute_outer_coord_and_offset_( + params, + outer_coord, + dst_byte_offset, + idx_linear); + + /// Complete the reduction + while (idx_linear < params.outer_count) { + + ElementCompute result = reduce_indices_(params, params.device_workspace + idx_linear); + + // Convert to output type and store + NumericConverter convert_output; + + *reinterpret_cast(dst_byte_ptr + dst_byte_offset) = convert_output(result); + + // Update indices and pointers + idx_linear += gridDim.x * blockDim.x; + + compute_outer_coord_and_offset_( + params, + outer_coord, + dst_byte_offset, + idx_linear); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace reduction +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_strided.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_strided.h new file mode 100644 index 0000000000000000000000000000000000000000..0538184f3886b53207cc28a46a9fb8b04d3e8c5e --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_strided.h @@ -0,0 +1,641 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Kernel performing a reduction over one or more ranks of an affine tensor +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/fast_math.h" +#include "cutlass/numeric_types.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/reduction/thread/reduction_operators.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reduction { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +/// Parameters structure +template < + int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) + int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2) + typename ElementOutput, ///< Data type of output tensor + typename ElementSource, ///< Data type of source tensor + typename ReductionOp, ///< Reduction operator + int VectorLength = 1, ///< Vector length for memory + typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation + int Threads = 256, ///< Number of participating threads + int BatchSize = 4 ///< Number of elements to load per batch +> +struct TensorReductionAffineStridedParams { + + static int const kRank = Rank; + static int const kReducedRank = ReducedRank; + static int const kVectorLength = VectorLength; + static int const kInnerRank = kRank - kReducedRank; + static int const kThreads = Threads; + static int const kBatchSize = BatchSize; + + Coord extent; /// Extent of source tensor + FastDivmodU64 divmod[kRank - 1]; /// FastDivmod by each strided rank + int64_t dst_stride[kReducedRank - 1]; /// stride (units of bytes) - I, J + int64_t src_stride[kRank - 1]; /// stride (units of bytes) - I, J, K + int64_t workspace_stride; /// stride (units of bytes) between workspace + int64_t workspace_outer_stride; /// stride (units of bytes) between 'rows' of the workspace + int workspace_count; /// number of workspaces + + uint64_t inner_count; /// Number of elements in reduced index space + uint64_t outer_count; /// Number of elements in outer index space + + ElementOutput * destination; /// Pointer to output tensor of rank kReducedRank + ElementSource const * source; /// Pointer to source pointer of rank kRank + ReductionOp reduction_op; /// Reduction operator + ElementCompute reduction_identity; /// Identity element for reduction operator + ElementCompute *device_workspace; /// Pointer to device workspace for inter-CTA reductions + + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + TensorReductionAffineStridedParams() { + + } + + /// Ctor + TensorReductionAffineStridedParams( + Coord extent_, ///< Extent of source tensor + ElementOutput * dst_ptr_, ///< Output tensor data + int64_t dst_stride_[], ///< Stride (units of elements) + ElementSource const * src_ptr_, ///< Source tensor data + int64_t src_stride_[], ///< Stride (units of elements) + ElementCompute *device_workspace_, ///< Pointer to device workspace for inter-CTA reductions + int64_t workspace_stride_, ///< Stride between workspaces + int workspace_count_, ///< Number of workspaces + ReductionOp reduction_op_, ///< Reduction operator + ElementCompute reduction_identity_ = ElementCompute() ///< Identity element for reduction operator + ): + extent(extent_), + inner_count(1), + outer_count(1), + destination(dst_ptr_), + source(src_ptr_), + device_workspace(device_workspace_), + workspace_outer_stride(0), + workspace_stride(workspace_stride_), + workspace_count(workspace_count_), + reduction_op(reduction_op_), + reduction_identity(reduction_identity_) { + + // Initialize divisors for fast div-mod + for (int p = 1; p < kRank; ++p) { + divmod[p - 1] = FastDivmodU64(uint64_t(extent[p])); + } + + int input_size_bits = sizeof_bits::value; + int output_size_bits = sizeof_bits::value; + + workspace_outer_stride = workspace_stride * workspace_count; + + // Compute strides in units of bytes + for (int p = 0; p < kReducedRank - 1; ++p) { + dst_stride[p] = dst_stride_[p] * output_size_bits / 8; + } + + for (int p = 0; p < kRank - 1; ++p) { + src_stride[p] = src_stride_[p] * input_size_bits / 8; + } + + // Compute number of elements in strided ranks + for (int p = 0; p < kReducedRank - 1; ++p) { + outer_count *= uint64_t(extent[p]); + } + + for (int p = 0; p < kInnerRank; ++p) { + inner_count *= uint64_t(extent[kReducedRank + p - 1]); + } + } +}; + +/// Kernel to reduce a tensor with affine layout over a set of ranks *EXCLUDING* the contiguous +/// rank. This leads to favorable vectorized memory accesses over the contiguous rank. +template < + int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) + int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2) + typename ElementOutput, ///< Data type of output tensor + typename ElementSource, ///< Data type of source tensor + typename ReductionOp, ///< Reduction operator + int VectorLength = 1, ///< Vector length for memory + typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation + int Threads = 256, ///< Number of participating threads + int BatchSize = 4 ///< Number of elements to load per batch +> +class TensorReductionAffineStrided { +public: + + static int const kRank = Rank; + static int const kReducedRank = ReducedRank; + static int const kVectorLength = VectorLength; + static int const kInnerRank = kRank - kReducedRank; + static int const kThreads = Threads; + static int const kBatchSize = BatchSize; + using ComputeFragment = Array; + using SourceFragment = AlignedArray; + using OutputFragment = AlignedArray; + + /// Shared memory allocation used for reduction within the CTA + struct SharedStorage { + Array workspace; + }; + + /// Parameters structure + using Params = TensorReductionAffineStridedParams< + Rank, + ReducedRank, + ElementOutput, + ElementSource, + ReductionOp, + VectorLength, + ElementCompute, + Threads, + BatchSize + >; + +private: + + /// Computes the coordinate and offset of a given linear index + CUTLASS_DEVICE + void compute_inner_coord_and_offset_( + Params const ¶ms, + Coord & coord, + int64_t &src_offset, + uint64_t linear_idx) const { + + // Decompose into coordinate + coord = CoordinateDecomposition(linear_idx, ¶ms.divmod[kReducedRank - 1]); + + // Compute linear offset + src_offset = 0; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kInnerRank; ++i) { + src_offset += params.src_stride[kReducedRank + i - 1] * coord[i]; + } + } + + /// Computes the coordinate and offset of a given linear index + CUTLASS_DEVICE + void compute_outer_coord_and_offset_( + Params const ¶ms, + Coord & coord, + int64_t &dst_offset, + int64_t &src_offset, + uint64_t linear_idx) const { + + // Decompose linear coordinate + coord = CoordinateDecomposition(linear_idx, params.divmod); + + // Compute offset into tensors + dst_offset = 0; + src_offset = 0; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kReducedRank - 1; ++i) { + dst_offset += params.dst_stride[i] * coord[i]; + src_offset += params.src_stride[i] * coord[i]; + } + } + + /// Reduces over the reduction indices + CUTLASS_DEVICE + ComputeFragment reduce_indices_( + Params const ¶ms, + ElementCompute *threadblock_workspace, + char const *src_byte_ptr) { + + NumericArrayConverter convert_source; + ReductionOp reduction_op(params.reduction_op); + + // Accumulated output + ComputeFragment identity_frag; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < int(identity_frag.size()); ++i) { + identity_frag[i] = params.reduction_identity; + } + + if (!params.inner_count) { + return identity_frag; + } + + ComputeFragment accumulator = identity_frag; + + // Compute the coordinate of the first access + int64_t src_byte_offset = 0; + Coord coord; + + uint64_t linear_idx = threadIdx.z + blockIdx.z * blockDim.z; + compute_inner_coord_and_offset_(params, coord, src_byte_offset, linear_idx); + + // Load the first vector + SourceFragment source_fragment[kBatchSize]; + + bool not_done = true; + + // Iterate over vectors in a linearized reduction index space + while (not_done) { + + bool guards[kBatchSize]; + + // Issue a batch of loads + CUTLASS_PRAGMA_UNROLL + for (int b = 0; b < kBatchSize; ++b) { + + if (linear_idx < params.inner_count) { + source_fragment[b] = *reinterpret_cast(src_byte_ptr + src_byte_offset); + guards[b] = true; + } + else { + guards[b] = false; + not_done = false; + } + + linear_idx += blockDim.z * gridDim.z; + compute_inner_coord_and_offset_(params, coord, src_byte_offset, linear_idx); + } + + // Perform a batch of reduction operations + CUTLASS_PRAGMA_UNROLL + for (int b = 0; b < kBatchSize; ++b) { + if (guards[b]) { + + auto cvt = convert_source(source_fragment[b]); + + accumulator = cutlass::reduction::thread::detail::ApplyArrayOperator( + reduction_op, + accumulator, + cvt); + } + } + }; + + // Optional reduction within a CTA + if (blockDim.z > 1) { + + // Linearized thread ID + int thread_idx = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z); + + // all threads store to workspace + ComputeFragment *frag_ptr = reinterpret_cast(threadblock_workspace); + + frag_ptr[thread_idx] = accumulator; + + __syncthreads(); + + if (threadIdx.z == 0) { + // Load all additional block indices + for (int z = 1; z < blockDim.z; ++z) { + ComputeFragment frag = frag_ptr[thread_idx + z * blockDim.x * blockDim.y]; + + accumulator = cutlass::reduction::thread::detail::ApplyArrayOperator( + reduction_op, + accumulator, + frag); + } + } + + __syncthreads(); + } + + return accumulator; + } + +public: + + /// Perform a reduction + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + int coord_c = (blockIdx.x * blockDim.x + threadIdx.x) * kVectorLength; + + char const * src_byte_ptr = reinterpret_cast(params.source + coord_c); + char * dst_byte_ptr = nullptr; + + // If performing a reduction across CTAs, redirect output to device workspace + if (gridDim.z == 1) { + dst_byte_ptr = reinterpret_cast(params.destination + coord_c); + } + else { + dst_byte_ptr = reinterpret_cast(params.device_workspace + coord_c); + } + + // If the C index is out of bounds, exit + if (coord_c >= params.extent[kRank - 1]) { + return; + } + + int64_t idx_linear = blockIdx.y * blockDim.y + threadIdx.y; + + // Use modulo division to compute location + Coord outer_coord; + int64_t dst_byte_offset; + int64_t src_byte_offset; + + compute_outer_coord_and_offset_( + params, + outer_coord, + dst_byte_offset, + src_byte_offset, + idx_linear); + + if (gridDim.z == 1) { + + /// Complete the reduction with no workspace + while (idx_linear < params.outer_count) { + + ComputeFragment result; + + result = reduce_indices_( + params, + shared_storage.workspace.data(), + src_byte_ptr + src_byte_offset); + + // Store the result after possible final reduction within the CTA + if (threadIdx.z == 0) { + + // Convert to output type and store + NumericArrayConverter convert_output; + auto cvt = convert_output(result); + + *reinterpret_cast(dst_byte_ptr + dst_byte_offset) = + reinterpret_cast(cvt); + } + + // Update indices and pointers + idx_linear += gridDim.y * blockDim.y; + + compute_outer_coord_and_offset_( + params, + outer_coord, + dst_byte_offset, + src_byte_offset, + idx_linear); + + } // while + } + else { + + /// Complete the reduction with a device workspace + while (idx_linear < params.outer_count) { + + ComputeFragment result; + + result = reduce_indices_( + params, + shared_storage.workspace.data(), + src_byte_ptr + src_byte_offset); + + // Store the result after possible final reduction within the CTA + if (threadIdx.z == 0) { + + int64_t byte_offset = + blockIdx.z * params.workspace_stride + idx_linear * params.workspace_outer_stride; + + // No conversion - store in compute type + *reinterpret_cast(dst_byte_ptr + byte_offset) = + reinterpret_cast(result); + } + + // Update indices and pointers + idx_linear += gridDim.y * blockDim.y; + + compute_outer_coord_and_offset_( + params, + outer_coord, + dst_byte_offset, + src_byte_offset, + idx_linear); + + } // while (outer index) + } // if () + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Kernel to perform final reduction +template < + int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) + int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2) + typename ElementOutput, ///< Data type of output tensor + typename ElementSource, ///< Data type of source tensor + typename ReductionOp, ///< Reduction operator + int VectorLength = 1, ///< Vector length for memory + typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation + int Threads = 256, ///< Number of participating threads + int BatchSize = 4 ///< Number of elements to load per batch +> +class TensorReductionAffineStridedFinal { +public: + + static int const kRank = Rank; + static int const kReducedRank = ReducedRank; + static int const kVectorLength = VectorLength; + static int const kInnerRank = kRank - kReducedRank; + static int const kThreads = Threads; + static int const kBatchSize = BatchSize; + using ComputeFragment = Array; + using SourceFragment = AlignedArray; + using OutputFragment = AlignedArray; + + /// Shared memory + struct SharedStorage { }; + + /// Parameters structure + using Params = TensorReductionAffineStridedParams< + Rank, + ReducedRank, + ElementOutput, + ElementSource, + ReductionOp, + VectorLength, + ElementCompute, + Threads, + BatchSize + >; + +private: + + /// Computes the coordinate and offset of a given linear index + CUTLASS_DEVICE + void compute_outer_coord_and_offset_( + Params const ¶ms, + Coord & coord, + int64_t &dst_offset, + uint64_t linear_idx) const { + + // Decompose linear index + coord = CoordinateDecomposition(linear_idx, params.divmod); + + // Compute tensor offset + dst_offset = 0; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kReducedRank - 1; ++i) { + dst_offset += params.dst_stride[i] * coord[i]; + } + } + + /// Reduces over the reduction indices + CUTLASS_DEVICE + ComputeFragment reduce_indices_( + Params const ¶ms, + char *src_byte_ptr) { + + ReductionOp reduction_op(params.reduction_op); + + // Accumulated output + ComputeFragment identity_frag; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < int(identity_frag.size()); ++i) { + identity_frag[i] = params.reduction_identity; + } + + ComputeFragment accumulator = identity_frag; + ComputeFragment workspace_fragments[kBatchSize]; + + // Partially unrolled loop + for (int idx = 0; idx < params.workspace_count; idx += kBatchSize) { + + // Issue a batch of loads + CUTLASS_PRAGMA_UNROLL + for (int b = 0; b < kBatchSize; ++b) { + if (idx + b < params.workspace_count) { + workspace_fragments[b] = + *reinterpret_cast(src_byte_ptr); + } + else { + workspace_fragments[b] = identity_frag; + } + src_byte_ptr += + params.workspace_stride; + } + + // Perform a reduction + CUTLASS_PRAGMA_UNROLL + for (int b = 0; b < kBatchSize; ++b) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kVectorLength; ++i) { + accumulator[i] = reduction_op(accumulator[i], workspace_fragments[b][i]); + } + } + } + + return accumulator; + } + +public: + + // + // Methods + // + + /// Perform a reduction + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + int coord_c = (blockIdx.x * blockDim.x + threadIdx.x) * kVectorLength; + + char * src_byte_ptr = reinterpret_cast(params.device_workspace + coord_c); + char * dst_byte_ptr = reinterpret_cast(params.destination + coord_c); + + // If the C index is out of bounds, exit + if (coord_c >= params.extent[kRank - 1]) { + return; + } + + int64_t idx_linear = blockIdx.y * blockDim.y + threadIdx.y; + + // Use modulo division to compute location + Coord outer_coord; + int64_t dst_byte_offset; + + compute_outer_coord_and_offset_( + params, + outer_coord, + dst_byte_offset, + idx_linear); + + /// Complete the reduction + while (idx_linear < params.outer_count) { + + int64_t src_byte_offset = idx_linear * params.workspace_outer_stride; + + ComputeFragment result = reduce_indices_( + params, + src_byte_ptr + src_byte_offset); + + // Convert to output type and store + NumericArrayConverter convert_output; + auto cvt = convert_output(result); + + *reinterpret_cast(dst_byte_ptr + dst_byte_offset) = + reinterpret_cast(cvt); + + // Update indices and pointers + idx_linear += gridDim.y * blockDim.y; + + compute_outer_coord_and_offset_( + params, + outer_coord, + dst_byte_offset, + idx_linear); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace reduction +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/thread/reduce.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/thread/reduce.h new file mode 100644 index 0000000000000000000000000000000000000000..cc354df56a0fd83f0315370138fca729a2236d79 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/thread/reduce.h @@ -0,0 +1,234 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines basic thread level reduction with specializations for Array. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/half.h" +#include "cutlass/functional.h" + +namespace cutlass { +namespace reduction { +namespace thread { + +/// Structure to compute the thread level reduction +template +struct Reduce; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial Specialization of Reduce for "plus" (a functional operator) +template +struct Reduce< plus, T > { + + CUTLASS_HOST_DEVICE + T operator()(T lhs, T const &rhs) const { + plus _op; + return _op(lhs, rhs); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization of Reduce for Array +template +struct Reduce < plus, Array> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &in) const { + + Array result; + Reduce< plus, T > scalar_reduce; + result.clear(); + + CUTLASS_PRAGMA_UNROLL + for (auto i = 0; i < N; ++i) { + result[0] = scalar_reduce(result[0], in[i]); + } + + return result; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specializations of Reduce for Array +template +struct Reduce < plus, Array > { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &input) { + + Array result; + + // If there is only 1 element - there is nothing to reduce + if( N ==1 ){ + + result[0] = input.front(); + + } else { + + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600) + + __half result_d; + Array const *in_ptr_half = reinterpret_cast const *>(&input); + Array const *in_ptr_half2 = reinterpret_cast const *>(&input); + __half2 const *x_in_half2 = reinterpret_cast<__half2 const *>(in_ptr_half2); + + // Set initial result = first half2, in case N==2 + __half2 tmp_result = x_in_half2[0]; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < N/2; ++i) { + + tmp_result = __hadd2(x_in_half2[i], tmp_result); + + } + + result_d = __hadd(__low2half(tmp_result), __high2half(tmp_result)); + + // One final step is needed for odd "N" (to add the (N-1)th element) + if( N%2 ){ + + __half last_element; + Array tmp_last; + Array *tmp_last_ptr = &tmp_last; + tmp_last_ptr[0] = in_ptr_half[N-1]; + last_element = reinterpret_cast<__half const &>(tmp_last); + + result_d = __hadd(result_d, last_element); + + } + + Array *result_ptr = &result; + *result_ptr = reinterpret_cast &>(result_d); + + #else + + Reduce< plus, half_t > scalar_reduce; + result.clear(); + + CUTLASS_PRAGMA_UNROLL + for (auto i = 0; i < N; ++i) { + + result[0] = scalar_reduce(result[0], input[i]); + + } + + #endif + } + + return result; + + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specializations of Reduce for AlignedArray +template +struct Reduce < plus, AlignedArray > { + + CUTLASS_HOST_DEVICE + Array operator()(AlignedArray const &input) { + + Array result; + + // If there is only 1 element - there is nothing to reduce + if( N ==1 ){ + + result[0] = input.front(); + + } else { + + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600) + + __half result_d; + AlignedArray const *in_ptr_half = reinterpret_cast const *>(&input); + AlignedArray const *in_ptr_half2 = reinterpret_cast const *>(&input); + __half2 const *x_in_half2 = reinterpret_cast<__half2 const *>(in_ptr_half2); + + // Set initial result = first half2, in case N==2 + __half2 tmp_result = x_in_half2[0]; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < N/2; ++i) { + + tmp_result = __hadd2(x_in_half2[i], tmp_result); + + } + + result_d = __hadd(__low2half(tmp_result), __high2half(tmp_result)); + + // One final step is needed for odd "N" (to add the (N-1)th element) + if( N%2 ){ + + __half last_element; + AlignedArray tmp_last; + AlignedArray *tmp_last_ptr = &tmp_last; + tmp_last_ptr[0] = in_ptr_half[N-1]; + last_element = reinterpret_cast<__half const &>(tmp_last); + + result_d = __hadd(result_d, last_element); + + } + + Array *result_ptr = &result; + *result_ptr = reinterpret_cast &>(result_d); + + #else + + Reduce< plus, half_t > scalar_reduce; + result.clear(); + + CUTLASS_PRAGMA_UNROLL + for (auto i = 0; i < N; ++i) { + + result[0] = scalar_reduce(result[0], input[i]); + + } + + #endif + } + + return result; + + } +}; +} +} +} diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/thread/reduction_operators.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/thread/reduction_operators.h new file mode 100644 index 0000000000000000000000000000000000000000..3792d332de65f19a1d30ba311d34073201176a3b --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/thread/reduction_operators.h @@ -0,0 +1,235 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Kernel performing a reduction over densely packed tensors in global memory +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reduction { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Mixed-precision reduction +template < + typename ElementAccumulator_, + typename Element_, + int Count = 1 +> +struct ReduceAdd { + + // + // Type definitions + // + + using ElementAccumulator = ElementAccumulator_; + using Element = Element_; + static int const kCount = Count; + + using FragmentAccumulator = cutlass::Array; + using FragmentElement = cutlass::Array; + + struct Params { }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + /// Constructor + CUTLASS_HOST_DEVICE + ReduceAdd(Params params_ = Params()): params(params_) { } + + /// Operator + CUTLASS_HOST_DEVICE + FragmentAccumulator operator()( + FragmentAccumulator accumulator, + FragmentElement element) const { + + plus op; + + NumericArrayConverter< + ElementAccumulator, + Element, + kCount, + PreferredRoundingMode::kRound> converter; + + return op(accumulator, converter(element)); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Special handling for binary operators +template +struct VectorizeArrayOperation { + + using ValueType = Array; + + CUTLASS_HOST_DEVICE + ValueType operator()( + ReductionOp const &reduction_op, + ValueType const &lhs, + ValueType const &rhs) const { + + ValueType result; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = reduction_op(lhs[i], rhs[i]); + } + + return result; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct ReduceArrayOperation { + + using ArrayType = Array; + + CUTLASS_HOST_DEVICE + Element operator()( + ReductionOp const &reduction_op, + ArrayType const &array) const { + + Element item = reduction_op(array[0], array[1]); + + CUTLASS_PRAGMA_UNROLL + for (int i = 2; i < N; ++i) { + item = reduction_op(item, array[i]); + } + + return item; + } +}; + +template +struct ReduceArrayOperation, uint1b_t, N> { + + using ArrayType = Array; + + CUTLASS_HOST_DEVICE + uint1b_t operator()( + logical_and const &reduction_op, + ArrayType const &array) const { + + uint8_t const *ptr = reinterpret_cast(&array); + bool item = false; + + CUTLASS_PRAGMA_UNROLL + for (int byte = 0; byte < (N + 7) / 8; ++byte) { + uint8_t bits = ptr[byte]; + item = (item || !bits); + } + + return uint1b_t{!item}; + } +}; + +template +struct ReduceArrayOperation, uint1b_t, N> { + + using ArrayType = Array; + + CUTLASS_HOST_DEVICE + uint1b_t operator()( + logical_and const &reduction_op, + ArrayType const &array) const { + + uint8_t const *ptr = reinterpret_cast(&array); + bool item = true; + + CUTLASS_PRAGMA_UNROLL + for (int byte = 0; byte < (N + 7) / 8; ++byte) { + uint8_t bits = ptr[byte]; + item = (item || bits); + } + + return uint1b_t{item}; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper function to infer template argument types +template +CUTLASS_HOST_DEVICE +Array ApplyArrayOperator( + ReductionOp const &reduction_op, + Array const &lhs, + Array const &rhs) { + + VectorizeArrayOperation vectorize_op; + + return vectorize_op(reduction_op, lhs, rhs); +} + +/// Helper to reduce an array +template +Element ReduceArray(ReductionOp const &reduction_op, Array const &array) { + ReduceArrayOperation reduce_array_op; + + return reduce_array_op(reduction_op, array); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace reduction +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/threadblock_swizzle.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/threadblock_swizzle.h new file mode 100644 index 0000000000000000000000000000000000000000..bbabaed2736cac7043671f10e9813a9a48b1916c --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/threadblock_swizzle.h @@ -0,0 +1,67 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +* +**************************************************************************************************/ +/*! \file +\brief Defies functors for mapping blockIdx to partitions of the batched reduction computation. +*/ +#pragma once +#include "cutlass/coord.h" + +namespace cutlass { +namespace reduction { +struct DefaultBlockSwizzle { + /// Ctor + CUTLASS_HOST_DEVICE DefaultBlockSwizzle() {} + + /// Swizzle the block index. + CUTLASS_DEVICE dim3 swizzle() { return blockIdx; } + + /// + CUTLASS_HOST_DEVICE dim3 get_grid_layout(Coord<3> const &problem_size, + Coord<3> const &OutputTile) { + assert(OutputTile[0] == 1 && OutputTile[1] == 1); + assert((problem_size[0] * problem_size[1] * problem_size[2]) % OutputTile[2] == 0); + dim3 grid; + grid.x = problem_size[0] * problem_size[1] * problem_size[2] + / OutputTile[2] ; + return grid; + } + + /// + CUTLASS_DEVICE Coord<3> get_threadblock_offset(Coord<3> const &SubTile) { + assert(SubTile[0] == 1 && SubTile[1] == 1); + dim3 block = swizzle(); + Coord<3> threadblock_offset = + make_Coord(0, 0, block.x * SubTile[2]); + return threadblock_offset; + } +}; +} // namespace reduction +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/relatively_equal.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/relatively_equal.h new file mode 100644 index 0000000000000000000000000000000000000000..68bdb26e38b1a54843eb4883833ad6b8708f0aff --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/relatively_equal.h @@ -0,0 +1,305 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Performs comparison between two elements with support for floating-point comparisons. +*/ + +#pragma once + +#include "numeric_types.h" +#include "complex.h" + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_HOST_DEVICE +bool relatively_equal(T a, T b, U epsilon, U nonzero_floor); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +// This floating-point comparison function implements the method described in +// +// https://floating-point-gui.de/errors/comparison/ +// +template +CUTLASS_HOST_DEVICE +bool relatively_equal_float(T a, T b, T epsilon, T nonzero_floor) { + +#if defined(__CUDACC_RTC__) + using cuda::std::abs; +#else + using std::abs; +#endif + + T abs_A = abs(a); + T abs_B = abs(b); + T diff = abs(a - b); + T zero = T(0); + + if (a == b) { + return true; + } + else if (a == zero || b == zero || (abs_A + abs_B) < nonzero_floor) { + return diff < epsilon * nonzero_floor; + } + + return diff < epsilon * (abs_A + abs_B); +} + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(bool a, bool b, bool, bool) { + return (a == b); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(uint1b_t a, uint1b_t b, uint1b_t, uint1b_t) { + return (a == b); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(int2b_t a, int2b_t b, int2b_t, int2b_t) { + return (a == b); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(uint2b_t a, uint2b_t b, uint2b_t, uint2b_t) { + return (a == b); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(int4b_t a, int4b_t b, int4b_t, int4b_t) { + return (a == b); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(uint4b_t a, uint4b_t b, uint4b_t, uint4b_t) { + return (a == b); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(int8_t a, int8_t b, int8_t, int8_t) { + return (a == b); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(uint8_t a, uint8_t b, uint8_t, uint8_t) { + return (a == b); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(int16_t a, int16_t b, int16_t, int16_t) { + return (a == b); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(uint16_t a, uint16_t b, uint16_t, uint16_t) { + return (a == b); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(int32_t a, int32_t b, int32_t, int32_t) { + return (a == b); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(uint32_t a, uint32_t b, uint32_t, uint32_t) { + return (a == b); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(int64_t a, int64_t b, int64_t, int64_t) { + return (a == b); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(uint64_t a, uint64_t b, uint64_t, uint64_t) { + return (a == b); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(float_e4m3_t a, float_e4m3_t b, float_e4m3_t epsilon, float_e4m3_t nonzero_floor) { + return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(float_e5m2_t a, float_e5m2_t b, float_e5m2_t epsilon, float_e5m2_t nonzero_floor) { + return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(half_t a, half_t b, half_t epsilon, half_t nonzero_floor) { + return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal( + bfloat16_t a, + bfloat16_t b, + bfloat16_t epsilon, + bfloat16_t nonzero_floor) { + + return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal( + tfloat32_t a, + tfloat32_t b, + tfloat32_t epsilon, + tfloat32_t nonzero_floor) { + + return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(float a, float b, float epsilon, float nonzero_floor) { + return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); +} + + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(double a, double b, double epsilon, double nonzero_floor) { + return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); +} + +template +CUTLASS_HOST_DEVICE +bool relatively_equal(complex a, complex b, T epsilon, T nonzero_floor) { +#if defined(__CUDACC_RTC__) + using cuda::std::abs; +#else + using std::abs; +#endif + + T abs_A = abs(a); + T abs_B = abs(b); + T diff = abs(a - b); + complex zero = complex{T{}, T{}}; + + if (a == b) { + return true; + } + else if (a == zero || b == zero || diff < nonzero_floor) { + return diff < epsilon * nonzero_floor; + } + + return diff < epsilon * (abs_A + abs_B); +} + +template +CUTLASS_HOST_DEVICE +bool relatively_equal(complex a, complex b, complex epsilon, complex nonzero_floor) { +#if defined(__CUDACC_RTC__) + using cuda::std::abs; +#else + using std::abs; +#endif + + T abs_A = abs(a); + T abs_B = abs(b); + complex diff = a - b; + T abs_diff = abs(diff); + complex zero = complex{T{}, T{}}; + + if (a == b) { + return true; + } + else if (a == zero || b == zero || abs_diff < abs(nonzero_floor)) { + return abs_diff < abs(epsilon * nonzero_floor); + } + + return abs_diff < abs(epsilon) * (abs_A + abs_B); +} + + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(float_e2m3_t a, float_e2m3_t b, float_e2m3_t epsilon, float_e2m3_t nonzero_floor) { + return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(float_e3m2_t a, float_e3m2_t b, float_e3m2_t epsilon, float_e3m2_t nonzero_floor) { + return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(float_e2m1_t a, float_e2m1_t b, float_e2m1_t epsilon, float_e2m1_t nonzero_floor) { + return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); +} +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(float_ue8m0_t a, float_ue8m0_t b, float_ue8m0_t epsilon, float_ue8m0_t nonzero_floor) { + return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(float_ue4m3_t a, float_ue4m3_t b, float_ue4m3_t epsilon, float_ue4m3_t nonzero_floor) { + return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/semaphore.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/semaphore.h new file mode 100644 index 0000000000000000000000000000000000000000..09a0a1a4572775bbdbdba63a160952e35fef2c20 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/semaphore.h @@ -0,0 +1,118 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implementation of a CTA-wide semaphore for inter-CTA synchronization. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/array.h" + +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/gemm/gemm.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// CTA-wide semaphore for inter-CTA synchronization. +class Semaphore { +public: + + int *lock; + bool wait_thread; + int state; + +public: + + /// Implements a semaphore to wait for a flag to reach a given value + CUTLASS_HOST_DEVICE + Semaphore(int *lock_, int thread_id): + lock(lock_), + wait_thread(thread_id < 0 || thread_id == 0), + state(-1) { + + } + + /// Permit fetching the synchronization mechanism early + CUTLASS_DEVICE + void fetch() { + if (wait_thread) { + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); + #else + asm volatile ("ld.global.cg.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); + #endif + } + } + + /// Gets the internal state + CUTLASS_DEVICE + int get_state() const { + return state; + } + + /// Waits until the semaphore is equal to the given value + CUTLASS_DEVICE + void wait(int status = 0) { + while( __syncthreads_and(state != status) ) { + fetch(); + } + + __syncthreads(); + } + + /// Updates the lock with the given result + CUTLASS_DEVICE + void release(int status = 0) { + __syncthreads(); + + if (wait_thread) { + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + asm volatile ("st.global.release.gpu.b32 [%0], %1;\n" : : "l"(lock), "r"(status)); + #else + asm volatile ("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status)); + #endif + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/subbyte_reference.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/subbyte_reference.h new file mode 100644 index 0000000000000000000000000000000000000000..6e98cdc3886b06626ea7d003122d62078f7767b9 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/subbyte_reference.h @@ -0,0 +1,1388 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Provides a mechanism for packing and unpacking elements smaller than one byte +*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/integer_subbyte.h" +#include "cutlass/fast_math.h" + +namespace cutlass { + +namespace detail { +// This is an implementation detail of cutlass::SubbyteReference and. +// cutlass::HostTensor. For a given logical element type Element, +// and its corresponding storage (physical) element type StorageUnit, +// it computes quantities that help with managing allocations. +// +// CUTLASS uses a hidden "ContainerUnitType" or StorageUnit type to support +// packed arrays of subbyte types such as int4. Element is the "logical" type +// for computations, while CUTLASS uses StorageUnit as the element type +// of a packed array of Element. If Element is not a subbyte type, +// then the corresponding StorageUnit type is just Element itself. +// +// The ContainerType is always calculated as an array StorageUnit type (the StorageUnit +// is always a byte for subbyte types), +// and its number of bits is the lcm of the subbyte type's number of bits and 8. +// Below are some examples for different subbyte types. +// +// * Subbyte Type=int2, ContainerType=StorageUnit[1] (StorageUnit=uint8_t) +// * Subbyte Type=int4, ContainerType=StorageUnit[1] (StorageUnit=uint8_t) +template +struct StorageContainerCalculator { + // kContainerTypeNumBits: The number of bits needed for ContainerType + static constexpr int kContainerTypeNumBits = (sizeof_bits::value < 8) ? cutlass::lcm_cxx11(sizeof_bits::value, sizeof_bits::value) : sizeof_bits::value; + static_assert(kContainerTypeNumBits % sizeof_bits::value == 0, "The bits of ContainerType should be divisible by the element's number of bits"); + // kContainerTypeNumLogicalElements: The number of logical Element instance(s) that can be stored per ContainerType instance + static constexpr int kContainerTypeNumLogicalElements = kContainerTypeNumBits / sizeof_bits::value; + /// 3. kContainerTypeNumBytes: The number of bytes per ContainerType instance + static constexpr int kContainerTypeNumBytes = kContainerTypeNumBits / 8; + /// 4. kContainerTypeNumBytes: The number of base StorageUnit in the ContainerType + static constexpr int kContainerTypeNumStorageUnit = kContainerTypeNumBits / sizeof_bits::value; + + static_assert(kContainerTypeNumBits != 0, "kContainerTypeNumBits can not be zero"); + static_assert(kContainerTypeNumLogicalElements != 0, "kContainerTypeNumLogicalElements can not be zero"); + static_assert(kContainerTypeNumBytes != 0, "kContainerTypeNumBytes can not be zero"); +}; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// This class provides a mechanism for packing and unpacking elements smaller than one byte. It +/// assumes these sub-byte elements are packed in a traditional C++ numeric type. +/// +/// The intended application is to provide a mechanism to indirectly reference elements in +/// memory or Array<> objects whose addresses cannot otherwise be taken since they are smaller +/// than one byte. +/// +/// Supports basic pointer arithmetic: +/// +/// Example: +/// +/// int4b_t *ptr = ...; +/// +/// SubbyteReference ref = ptr; +/// ref += 15; +/// +/// int4b_t x = ref; // load an int4b_t +/// ref = x + 2_s4; // perform arithmetic on int4b_t and then store +/// +template < + typename Element_, /// CUTLASS numeric element type. + typename Storage_ = uint8_t, /// Underlying storage type. Must be able to hold an integer + /// number of objects of type Element. + class = void +> +class ConstSubbyteReference { +public: + + using Element = Element_; + using Storage = Storage_; + using StoragePointer = Storage const *; + + static_assert(sizeof_bits::value <= sizeof_bits::value, + "Size of Element must not be greater than Storage."); + + static_assert(!(sizeof_bits::value % sizeof_bits::value), + "Storage must be divisible by Element"); + +private: + + ///! Number of elements per storage vector + int const kElementsPerVector = sizeof_bits::value / sizeof_bits::value; + + ///! Bit mask + Storage const kMask = + ((sizeof_bits::value < sizeof_bits::value) ? + (Storage(1) << sizeof_bits::value) - Storage(1) : + ~Storage(0)); + +private: + + /// Pointer to array containing element + StoragePointer ptr_; + + /// Offset (in units of elements) from pointer. + /// + /// Invariant: must always be in range [0, kElementsPerVector) + int offset_; + +public: + + CUTLASS_HOST_DEVICE + ConstSubbyteReference(): ptr_(nullptr), offset_(0) { } + + /// Constructor + CUTLASS_HOST_DEVICE + ConstSubbyteReference( + Element const *ptr, /// pointer to memory + int64_t offset /// logical offset in units of Element + ): + ptr_(reinterpret_cast(ptr)), + offset_(0) { + + int64_t offset_in_vectors = offset / kElementsPerVector; + int64_t offset_in_elements = offset % kElementsPerVector; + + ptr_ += offset_in_vectors; + offset_ = int(offset_in_elements); + } + + /// Constructor + CUTLASS_HOST_DEVICE + ConstSubbyteReference( + Element *ptr = nullptr + ): ConstSubbyteReference(ptr, 0) { } + + /// Gets storage pointer + CUTLASS_HOST_DEVICE + StoragePointer storage_pointer() const { + return ptr_; + } + + /// Gets element offset within storage vector + CUTLASS_HOST_DEVICE + int element_offset() const { + return offset_; + } + + /// Unpacks an element from memory + CUTLASS_HOST_DEVICE + Element get() const { + Storage item = Storage((*ptr_ >> (offset_ * sizeof_bits::value)) & kMask); + return reinterpret_cast(item); + } + + /// Unpacks an element from memory + CUTLASS_HOST_DEVICE + operator Element() const { + return get(); + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference &operator+=(int offset) { + + offset += offset_; + + int offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = offset % kElementsPerVector; + + ptr_ += offset_in_vectors; + offset_ = offset_in_elements; + + return *this; + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference &operator+=(long long offset) { + + offset += offset_; + + long long offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = int(offset % kElementsPerVector); + + ptr_ += offset_in_vectors; + offset_ = offset_in_elements; + + return *this; + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference &operator-=(int offset) { + + int offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = offset % kElementsPerVector; + + ptr_ -= offset_in_vectors; + offset_ -= offset_in_elements; + + if (offset_ < 0) { + offset_ += kElementsPerVector; + --ptr_; + } + + return *this; + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference &operator-=(long long offset) { + + long long offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = int(offset % kElementsPerVector); + + ptr_ -= offset_in_vectors; + offset_ -= offset_in_elements; + + if (offset_ < 0) { + offset_ += kElementsPerVector; + --ptr_; + } + + return *this; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference operator+(int offset) const { + + ConstSubbyteReference ref(ptr_, offset_); + ref += offset; + + return ref; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference operator+(long long offset) const { + + ConstSubbyteReference ref(ptr_, offset_); + ref += offset; + + return ref; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference operator-(int offset) const { + + ConstSubbyteReference ref(ptr_, offset_); + ref -= offset; + + return ref; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference operator-=(long long offset) const { + + ConstSubbyteReference ref(ptr_, offset_); + ref -= offset; + + return ref; + } + + /// Computes the difference in elements between references + CUTLASS_HOST_DEVICE + ptrdiff_t operator-(ConstSubbyteReference ref) const { + return (ptr_ - ref.ptr_) * kElementsPerVector + (offset_ - ref.offset_); + } + + /// Explicit cast to int + CUTLASS_HOST_DEVICE + explicit operator int() const { + return int(get()); + } + + /// Explicit cast to signed 64-bit integer + CUTLASS_HOST_DEVICE + explicit operator int64_t() const { + return int64_t(get()); + } + + /// Explicit cast to unsigned 64-bit integer + CUTLASS_HOST_DEVICE + explicit operator uint64_t() const { + return uint64_t(get()); + } + + /// Explicit cast to float + CUTLASS_HOST_DEVICE + explicit operator float() const { + return float(get()); + } + + /// Explicit cast to double + CUTLASS_HOST_DEVICE + explicit operator double() const { + return double(get()); + } +}; + +template < + typename Element_, /// CUTLASS numeric element type. + typename Storage_ = /// Underlying storage type. Must be able to hold an integer + /// number of objects of type Element. + +#if defined(__CUDA_ARCH__) /// Default size depends on width of atomicCas() overloads. + #if (__CUDA_ARCH__ >= 700) /// + uint16_t + #else + uint32_t + #endif +#else + uint8_t +#endif + , + class = void +> +class SubbyteReference { +public: + + using Element = Element_; + using Storage = Storage_; + using StoragePointer = Storage *; + + static_assert(sizeof_bits::value <= sizeof_bits::value, + "Size of Element must not be greater than Storage."); + + static_assert(!(sizeof_bits::value % sizeof_bits::value), + "Storage must be divisible by Element"); + +private: + + ///! Number of elements per storage vector + int const kElementsPerVector = sizeof_bits::value / sizeof_bits::value; + + ///! Bit mask + Storage const kMask = + ((sizeof_bits::value < sizeof_bits::value) ? + (Storage(1) << sizeof_bits::value) - Storage(1) : + ~Storage(0)); + +private: + + /// Pointer to array containing element + StoragePointer ptr_; + + /// Offset (in units of elements) from pointer. + /// + /// Invariant: must always be in range [0, kElementsPerVector) + int offset_; + +public: + + CUTLASS_HOST_DEVICE + SubbyteReference(): ptr_(nullptr), offset_(0) { } + + /// Constructor + CUTLASS_HOST_DEVICE + SubbyteReference( + Element *ptr, /// pointer to memory + int64_t offset /// logical offset in units of Element + ): + ptr_(reinterpret_cast(ptr)), + offset_(0) { + + int64_t offset_in_vectors = offset / kElementsPerVector; + int64_t offset_in_elements = offset % kElementsPerVector; + + ptr_ += offset_in_vectors; + offset_ = int(offset_in_elements); + } + + /// Constructor + CUTLASS_HOST_DEVICE + SubbyteReference( + Element *ptr = nullptr + ): SubbyteReference(ptr, 0) { } + + /// Gets storage pointer + CUTLASS_HOST_DEVICE + StoragePointer storage_pointer() const { + return ptr_; + } + + /// Gets storage pointer + CUTLASS_HOST_DEVICE + Element * operator&() const { + return reinterpret_cast(ptr_); + } + + /// Gets element offset within storage vector + CUTLASS_HOST_DEVICE + int element_offset() const { + return offset_; + } + + /// Unpacks an element from memory + CUTLASS_HOST_DEVICE + Element get() const { + uint8_t const* byte_ptr = reinterpret_cast(ptr_); + // Convert offset in elements to offset in bytes + constexpr int elements_per_byte = cutlass::sizeof_bits::value / cutlass::sizeof_bits::value; + byte_ptr += offset_ / elements_per_byte; + // Offset of element within a byte + int byte_offset = offset_ % elements_per_byte; + uint8_t item = uint8_t((*byte_ptr >> (byte_offset * cutlass::sizeof_bits::value)) & kMask); + return reinterpret_cast(item); + } + + /// Stores an element to memory + CUTLASS_HOST_DEVICE + SubbyteReference & set(Element const &x) { + + Storage item = (reinterpret_cast(x) & kMask); + Storage kUpdateMask = Storage(~(kMask << (offset_ * cutlass::sizeof_bits::value))); + Storage new_bits = Storage(item << (offset_ * cutlass::sizeof_bits::value)); + +#if defined(__CUDA_ARCH__) + + // + // Homebrew read-modify-write + // + Storage original; + Storage updated; + + do { + + original = (*ptr_); + + updated = Storage((original & kUpdateMask) | new_bits); + + original = atomicCAS(ptr_, original, updated); + + } while (updated != original); + +#else + + Storage original = (*ptr_); + Storage updated = Storage((original & kUpdateMask) | new_bits); + *ptr_ = updated; + +#endif + + return *this; + } + + //// + + /// Unpacks an element from memory + CUTLASS_HOST_DEVICE + operator Element() const { + return get(); + } + + /// Stores an element to memory + CUTLASS_HOST_DEVICE + SubbyteReference &operator=(Element const & x) { + return set(x); + } + + /// Stores an element to memory + CUTLASS_HOST_DEVICE + SubbyteReference &operator=(SubbyteReference const & x) { + return set(x.get()); + } + + /// Stores an element to memory + CUTLASS_HOST_DEVICE + SubbyteReference &operator=( + ConstSubbyteReference const &x) { + return set(x.get()); + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + SubbyteReference &operator+=(int offset) { + + offset += offset_; + + int offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = offset % kElementsPerVector; + + ptr_ += offset_in_vectors; + offset_ = offset_in_elements; + + return *this; + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + SubbyteReference &operator+=(long long offset) { + + offset += offset_; + + long long offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = int(offset % kElementsPerVector); + + ptr_ += offset_in_vectors; + offset_ = offset_in_elements; + + return *this; + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + SubbyteReference &operator-=(int offset) { + + int offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = offset % kElementsPerVector; + + ptr_ -= offset_in_vectors; + offset_ -= offset_in_elements; + + if (offset_ < 0) { + offset_ += kElementsPerVector; + --ptr_; + } + + return *this; + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + SubbyteReference &operator-=(long long offset) { + + long long offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = int(offset % kElementsPerVector); + + ptr_ -= offset_in_vectors; + offset_ -= offset_in_elements; + + if (offset_ < 0) { + offset_ += kElementsPerVector; + --ptr_; + } + + return *this; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + SubbyteReference operator+(int offset) const { + + SubbyteReference ref(ptr_, offset_); + ref += offset; + + return ref; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + SubbyteReference operator+(long long offset) const { + + SubbyteReference ref(ptr_, offset_); + ref += offset; + + return ref; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + SubbyteReference operator-(int offset) const { + + SubbyteReference ref(ptr_, offset_); + ref -= offset; + + return ref; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + SubbyteReference operator-=(long long offset) const { + + SubbyteReference ref(ptr_, offset_); + ref -= offset; + + return ref; + } + + /// Computes the difference in elements between references + CUTLASS_HOST_DEVICE + ptrdiff_t operator-(SubbyteReference ref) const { + return (ptr_ - ref.ptr_) * kElementsPerVector + (offset_ - ref.offset_); + } + + /// Explicit cast to int + CUTLASS_HOST_DEVICE + explicit operator int() const { + return int(get()); + } + + /// Explicit cast to signed 64-bit integer + CUTLASS_HOST_DEVICE + explicit operator int64_t() const { + return int64_t(get()); + } + + /// Explicit cast to unsigned 64-bit integer + CUTLASS_HOST_DEVICE + explicit operator uint64_t() const { + return uint64_t(get()); + } + + /// Explicit cast to float + CUTLASS_HOST_DEVICE + explicit operator float() const { + return float(get()); + } + + /// Explicit cast to double + CUTLASS_HOST_DEVICE + explicit operator double() const { + return double(get()); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template using _war = T; +template < + typename Element_, /// CUTLASS numeric element type. + typename Storage_ /// Underlying basic storage type. +> +class SubbyteReference::value % sizeof_bits::value != 0>::type> { +public: + + using Element = Element_; + /// Note: It's possible that StorageUnit is not divisible by Element. + /// For example, an Element instance might be stored across 2 StorageUnit instances. + /// Thus, CUTLASS needs a storage vector to hold an integer number of Element instances. + + using StorageUnit = Storage_; +private: + using StorageContainerCalculator = cutlass::detail::StorageContainerCalculator; +public: + static int const kBitsStoredVec = StorageContainerCalculator::kContainerTypeNumBits; + static int const kNumStorageUnitPerStoredVec = StorageContainerCalculator::kContainerTypeNumStorageUnit; + + using StorageVec = StorageUnit[kNumStorageUnitPerStoredVec]; + using StorageVecPointer = StorageVec *; + + using CudaAtomicType = typename platform::conditional< + sizeof_bits::value == 16, + uint32_t, + uint64_t + >::type; + + static_assert(sizeof_bits::value <= sizeof_bits::value, + "Size of Element must not be greater than StorageVec."); + + static_assert(!(sizeof_bits::value % sizeof_bits::value), + "StorageVec must be divisible by Element"); + +private: + + ///! Number of elements per storage vector + int const kElementsPerVector = sizeof_bits::value / sizeof_bits::value; + + ///! Bit mask for storage unit. + StorageUnit const kMask = (StorageUnit(1) << sizeof_bits::value) - StorageUnit(1); + + /// Pointer to array containing element + _war ptr_; + + /// Offset (in units of elements) from pointer. + /// + /// Invariant: must always be in range [0, kElementsPerVector) + int offset_; + + /// Element may be stored across 2 storage unit. + /// Low storage unit index in StorageVec + /// High storage unit index in StorageVec + int low_storage_unit_idx_; + int high_storage_unit_idx_; + + /// Full Mask to extract the entire element + uint64_t full_element_mask_; + + /// Mask to extract the Element from Low storage unit and High storage unit. + StorageUnit low_storage_mask_; + StorageUnit high_storage_mask_; + + /// Start bit index inside the storage unit. + int start_bit_idx_; + +private: + + CUTLASS_HOST_DEVICE + void update_element_status() { + int num_bits = offset_ * sizeof_bits::value; + + start_bit_idx_ = num_bits % sizeof_bits::value; + + low_storage_unit_idx_ = num_bits / sizeof_bits::value; + high_storage_unit_idx_ = sizeof_bits::value - (start_bit_idx_) < sizeof_bits::value + ? low_storage_unit_idx_ + 1 : low_storage_unit_idx_; + + full_element_mask_ = uint64_t(kMask) << start_bit_idx_; + low_storage_mask_ = StorageUnit(full_element_mask_ & ~StorageUnit(0)); + high_storage_mask_ = StorageUnit((full_element_mask_ >> sizeof_bits::value) & ~StorageUnit(0)); + } + +public: + + CUTLASS_HOST_DEVICE + SubbyteReference(): ptr_(nullptr), offset_(0) { } + + /// Constructor + CUTLASS_HOST_DEVICE + SubbyteReference( + Element *ptr, /// pointer to memory + int64_t offset /// logical offset in units of Element + ): + ptr_(reinterpret_cast(ptr)), + offset_(0) { + int64_t offset_in_vectors = offset / kElementsPerVector; + int64_t offset_in_elements = offset % kElementsPerVector; + + ptr_ += offset_in_vectors; + offset_ = int(offset_in_elements); + + update_element_status(); + } + + /// Constructor + CUTLASS_HOST_DEVICE + SubbyteReference( + Element *ptr = nullptr + ): SubbyteReference(ptr, 0) { } + + /// Gets StorageVec pointer + CUTLASS_HOST_DEVICE + StorageVecPointer storage_pointer() const { + return ptr_; + } + + /// Gets StorageVec pointer + CUTLASS_HOST_DEVICE + Element * operator&() const { + return reinterpret_cast(ptr_); + } + + /// Gets element offset within StorageVec vector + CUTLASS_HOST_DEVICE + int element_offset() const { + return offset_; + } + + /// Unpacks an element from memory + CUTLASS_HOST_DEVICE + Element get() const { + StorageUnit low_bits = (*ptr_)[low_storage_unit_idx_] & low_storage_mask_; + StorageUnit high_bits = low_storage_unit_idx_ != high_storage_unit_idx_ ? (*ptr_)[high_storage_unit_idx_] & high_storage_mask_ : 0; + + uint64_t full_item = ((uint64_t)high_bits << sizeof_bits::value) | low_bits; + uint8_t result = uint8_t(full_item >> start_bit_idx_); + + return reinterpret_cast(result); + } + + /// Stores an element to memory + CUTLASS_HOST_DEVICE + SubbyteReference & set(Element const &x) { + + uint64_t item = static_cast((reinterpret_cast(x) & kMask)) << start_bit_idx_; + + StorageUnit low_new_bits = StorageUnit(item & ~StorageUnit(0)); + StorageUnit high_new_bits = StorageUnit(item >> sizeof_bits::value); + + StorageUnit const kLowUpdateMask = StorageUnit((~full_element_mask_) & (~StorageUnit(0))); + StorageUnit const kHighUpdateMask = StorageUnit(((~full_element_mask_) >> sizeof_bits::value) & (~StorageUnit(0))); + +#if defined(__CUDA_ARCH__) + // + // Homebrew read-modify-write + // + if(high_storage_unit_idx_ != low_storage_unit_idx_){ + /// Only need update 2 storage unit at once. + /// consider misaligned address issue, we need to do atomicCAS twice + StorageUnit original_low_bits, original_high_bits, update_low_bits, update_high_bits; + do { + original_low_bits = ((*ptr_)[low_storage_unit_idx_]); + update_low_bits = (original_low_bits & kLowUpdateMask) | low_new_bits; + original_low_bits = atomicCAS(&((*ptr_)[low_storage_unit_idx_]), original_low_bits, update_low_bits); + } while (update_low_bits != original_low_bits); + do { + original_high_bits = ((*ptr_)[high_storage_unit_idx_]); + update_high_bits = (original_high_bits & kHighUpdateMask) | high_new_bits; + original_high_bits = atomicCAS(&((*ptr_)[high_storage_unit_idx_]), original_high_bits, update_high_bits); + } while (update_high_bits != original_high_bits); + } + else { + /// Only need update 1 storage unit. + StorageUnit original, updated; + do { + original = ((*ptr_)[low_storage_unit_idx_]); + + updated = (original & kLowUpdateMask) | low_new_bits; + + original = atomicCAS(&((*ptr_)[low_storage_unit_idx_]), original, updated); + + } while (updated != original); + } +#else + + + StorageUnit update_low_bits = ((*ptr_)[low_storage_unit_idx_] & kLowUpdateMask) | low_new_bits; + StorageUnit update_high_bits = ((*ptr_)[high_storage_unit_idx_] & kHighUpdateMask) | high_new_bits; + + (*ptr_)[low_storage_unit_idx_] = update_low_bits; + + if(low_storage_unit_idx_ != high_storage_unit_idx_) + (*ptr_)[high_storage_unit_idx_] = update_high_bits; +#endif + + return *this; + } + + //// + + /// Unpacks an element from memory + CUTLASS_HOST_DEVICE + operator Element() const { + return get(); + } + + /// Stores an element to memory + CUTLASS_HOST_DEVICE + SubbyteReference &operator=(Element const & x) { + return set(x); + } + + /// Stores an element to memory + CUTLASS_HOST_DEVICE + SubbyteReference &operator=(SubbyteReference const & x) { + return set(x.get()); + } + + /// Stores an element to memory + CUTLASS_HOST_DEVICE + SubbyteReference &operator=( + ConstSubbyteReference const &x) { + return set(x.get()); + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + SubbyteReference &operator+=(int offset) { + + offset += offset_; + + int offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = offset % kElementsPerVector; + + ptr_ += offset_in_vectors; + offset_ = offset_in_elements; + + update_element_status(); + + return *this; + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + SubbyteReference &operator+=(long long offset) { + + offset += offset_; + + long long offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = int(offset % kElementsPerVector); + + ptr_ += offset_in_vectors; + offset_ = offset_in_elements; + + update_element_status(); + + return *this; + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + SubbyteReference &operator-=(int offset) { + + int offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = offset % kElementsPerVector; + + ptr_ -= offset_in_vectors; + offset_ -= offset_in_elements; + + if (offset_ < 0) { + offset_ += kElementsPerVector; + --ptr_; + } + + update_element_status(); + return *this; + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + SubbyteReference &operator-=(long long offset) { + + long long offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = int(offset % kElementsPerVector); + + ptr_ -= offset_in_vectors; + offset_ -= offset_in_elements; + + if (offset_ < 0) { + offset_ += kElementsPerVector; + --ptr_; + } + + update_element_status(); + return *this; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + SubbyteReference operator+(int offset) const { + + SubbyteReference ref(ptr_, offset_); + ref += offset; + + return ref; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + SubbyteReference operator+(long long offset) const { + + SubbyteReference ref(ptr_, offset_); + ref += offset; + + return ref; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + SubbyteReference operator-(int offset) const { + + SubbyteReference ref(ptr_, offset_); + ref -= offset; + + return ref; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + SubbyteReference operator-=(long long offset) const { + + SubbyteReference ref(ptr_, offset_); + ref -= offset; + + return ref; + } + + /// Computes the difference in elements between references + CUTLASS_HOST_DEVICE + ptrdiff_t operator-(SubbyteReference ref) const { + return (ptr_ - ref.ptr_) * kElementsPerVector + (offset_ - ref.offset_); + } + + /// Explicit cast to int + CUTLASS_HOST_DEVICE + explicit operator int() const { + return int(get()); + } + + /// Explicit cast to signed 64-bit integer + CUTLASS_HOST_DEVICE + explicit operator int64_t() const { + return int64_t(get()); + } + + /// Explicit cast to unsigned 64-bit integer + CUTLASS_HOST_DEVICE + explicit operator uint64_t() const { + return uint64_t(get()); + } + + /// Explicit cast to float + CUTLASS_HOST_DEVICE + explicit operator float() const { + return float(get()); + } + + /// Explicit cast to double + CUTLASS_HOST_DEVICE + explicit operator double() const { + return double(get()); + } +}; + +template using _war = T; +template < + typename Element_, /// CUTLASS numeric element type. + typename Storage_ /// Underlying storage type. Must be able to hold an integer +> +class ConstSubbyteReference::value % sizeof_bits::value != 0>::type> { +public: + + using Element = Element_; + ///! Note: Storage unit could not be divisibale by Element, + /// Type element may be stored across 2 storage units, so need a storage vector to hold integer + /// number of objects of type Element. + using StorageUnit = Storage_; + static int const kBitsStoredVec = cutlass::lcm_cxx11(sizeof_bits::value, sizeof_bits::value); + static int const kNumStorageUnitPerStoredVec = kBitsStoredVec / sizeof_bits::value; + + using StorageVec = StorageUnit[kNumStorageUnitPerStoredVec]; + using StorageVecPointer = StorageVec const *; + + using CudaAtomicType = typename platform::conditional< + sizeof_bits::value == 16, + uint32_t, + uint64_t + >::type; + + static_assert(sizeof_bits::value <= sizeof_bits::value, + "Size of Element must not be greater than StorageVec."); + + static_assert(!(sizeof_bits::value % sizeof_bits::value), + "StorageVec must be divisible by Element"); + +private: + + ///! Number of elements per storage vector + int const kElementsPerVector = sizeof_bits::value / sizeof_bits::value; + + ///! Bit mask for storage unit. + StorageUnit const kMask = (StorageUnit(1) << sizeof_bits::value) - StorageUnit(1); + + /// Pointer to array containing element + _war ptr_; + + /// Offset (in units of elements) from pointer. + /// + /// Invariant: must always be in range [0, kElementsPerVector) + int offset_; + + /// Element may be stored across 2 storage unit. + /// Low storage unit index in StorageVec + /// High storage unit index in StorageVec + int low_storage_unit_idx_; + int high_storage_unit_idx_; + + /// Full Mask to extract the entire element + uint64_t full_element_mask_; + + /// Mask to extract the Element from Low storage unit and High storage unit. + StorageUnit low_storage_mask_; + StorageUnit high_storage_mask_; + + /// Start bit index inside the storage unit. + int start_bit_idx_; + +private: + + CUTLASS_HOST_DEVICE + void update_element_status() { + int num_bits = offset_ * sizeof_bits::value; + + start_bit_idx_ = num_bits % sizeof_bits::value; + + low_storage_unit_idx_ = num_bits / sizeof_bits::value; + high_storage_unit_idx_ = sizeof_bits::value - (start_bit_idx_) < sizeof_bits::value + ? low_storage_unit_idx_ + 1 : low_storage_unit_idx_; + + full_element_mask_ = uint64_t(kMask) << start_bit_idx_; + low_storage_mask_ = StorageUnit(full_element_mask_ & ~StorageUnit(0)); + high_storage_mask_ = StorageUnit((full_element_mask_ >> sizeof_bits::value) & ~StorageUnit(0)); + } + +public: + + CUTLASS_HOST_DEVICE + ConstSubbyteReference(): ptr_(nullptr), offset_(0) { } + + /// Constructor + CUTLASS_HOST_DEVICE + ConstSubbyteReference( + Element const *ptr, /// pointer to memory + int64_t offset /// logical offset in units of Element + ): + ptr_(reinterpret_cast(ptr)), + offset_(0) { + + int64_t offset_in_vectors = offset / kElementsPerVector; + int64_t offset_in_elements = offset % kElementsPerVector; + + ptr_ += offset_in_vectors; + offset_ = int(offset_in_elements); + + update_element_status(); + } + + /// Constructor + CUTLASS_HOST_DEVICE + ConstSubbyteReference( + Element *ptr = nullptr + ): ConstSubbyteReference(ptr, 0) { } + + /// Gets storage pointer + CUTLASS_HOST_DEVICE + StorageVecPointer storage_pointer() const { + return ptr_; + } + + /// Gets element offset within storage vector + CUTLASS_HOST_DEVICE + int element_offset() const { + return offset_; + } + + /// Unpacks an element from memory + CUTLASS_HOST_DEVICE + Element get() const { + StorageUnit low_bits = (*ptr_)[low_storage_unit_idx_] & low_storage_mask_; + StorageUnit high_bits = low_storage_unit_idx_ != high_storage_unit_idx_ ? (*ptr_)[high_storage_unit_idx_] & high_storage_mask_ : 0; + + uint64_t full_item = ((uint64_t)high_bits << sizeof_bits::value) | low_bits; + uint8_t result = uint8_t(full_item >> start_bit_idx_); + + return reinterpret_cast(result); + } + + /// Unpacks an element from memory + CUTLASS_HOST_DEVICE + operator Element() const { + return get(); + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference &operator+=(int offset) { + + offset += offset_; + + int offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = offset % kElementsPerVector; + + ptr_ += offset_in_vectors; + offset_ = offset_in_elements; + + update_element_status(); + + return *this; + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference &operator+=(long long offset) { + + offset += offset_; + + long long offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = int(offset % kElementsPerVector); + + ptr_ += offset_in_vectors; + offset_ = offset_in_elements; + + update_element_status(); + + return *this; + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference &operator-=(int offset) { + + int offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = offset % kElementsPerVector; + + ptr_ -= offset_in_vectors; + offset_ -= offset_in_elements; + + if (offset_ < 0) { + offset_ += kElementsPerVector; + --ptr_; + } + + update_element_status(); + + return *this; + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference &operator-=(long long offset) { + + long long offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = int(offset % kElementsPerVector); + + ptr_ -= offset_in_vectors; + offset_ -= offset_in_elements; + + if (offset_ < 0) { + offset_ += kElementsPerVector; + --ptr_; + } + + update_element_status(); + + return *this; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference operator+(int offset) const { + + ConstSubbyteReference ref(ptr_, offset_); + ref += offset; + + return ref; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference operator+(long long offset) const { + + ConstSubbyteReference ref(ptr_, offset_); + ref += offset; + + return ref; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference operator-(int offset) const { + + ConstSubbyteReference ref(ptr_, offset_); + ref -= offset; + + return ref; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference operator-=(long long offset) const { + + ConstSubbyteReference ref(ptr_, offset_); + ref -= offset; + + return ref; + } + + /// Computes the difference in elements between references + CUTLASS_HOST_DEVICE + ptrdiff_t operator-(ConstSubbyteReference ref) const { + return (ptr_ - ref.ptr_) * kElementsPerVector + (offset_ - ref.offset_); + } + + /// Explicit cast to int + CUTLASS_HOST_DEVICE + explicit operator int() const { + return int(get()); + } + + /// Explicit cast to signed 64-bit integer + CUTLASS_HOST_DEVICE + explicit operator int64_t() const { + return int64_t(get()); + } + + /// Explicit cast to unsigned 64-bit integer + CUTLASS_HOST_DEVICE + explicit operator uint64_t() const { + return uint64_t(get()); + } + + /// Explicit cast to float + CUTLASS_HOST_DEVICE + explicit operator float() const { + return float(get()); + } + + /// Explicit cast to double + CUTLASS_HOST_DEVICE + explicit operator double() const { + return double(get()); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template ::value < 8)> +struct ReferenceFactory; + +template +struct ReferenceFactory { + + ///! Number of elements per storage vector + static int const kElementsPerVector = 1; + + CUTLASS_HOST_DEVICE + static Element &get(Element *ptr, int64_t offset) { + return ptr[offset]; + } + + CUTLASS_HOST_DEVICE + static Element const &get(Element const *ptr, int64_t offset) { + return ptr[offset]; + } + + CUTLASS_HOST_DEVICE + static Element *add_pointer_offset(Element *ptr, int64_t offset) { + return ptr + offset; + } + + CUTLASS_HOST_DEVICE + static Element const *add_pointer_offset(Element const *ptr, int64_t offset) { + return ptr + offset; + } +}; + +template +struct ReferenceFactory { + + // + // Static methods + // + + CUTLASS_HOST_DEVICE + static SubbyteReference get(Element *ptr, int64_t offset) { + return SubbyteReference(ptr, offset); + } + + CUTLASS_HOST_DEVICE + static ConstSubbyteReference get(Element const *ptr, + int64_t offset) { + return ConstSubbyteReference(ptr, offset); + } + + /// Helper to add an offset in number of elements, assuming this offset is divisible + /// by the vector size. + CUTLASS_HOST_DEVICE + static Element *add_pointer_offset(Element *ptr, int64_t offset_in_elements) { + return &SubbyteReference(ptr, offset_in_elements); + } + + /// Helper to add an offset in number of elements, assuming this offset is divisible + /// by the vector size. + CUTLASS_HOST_DEVICE + static Element const *add_pointer_offset(Element const *ptr, int64_t offset_in_elements) { + return &ConstSubbyteReference(ptr, offset_in_elements); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tensor_coord.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tensor_coord.h new file mode 100644 index 0000000000000000000000000000000000000000..a124d395cf2222331e0ceb160271b1621688fd6f --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tensor_coord.h @@ -0,0 +1,326 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines a canonical coordinate for rank=4 tensors offering named indices. +*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/coord.h" + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a canonical 4D coordinate used by tensor operations. +struct Tensor4DCoord : public Coord<4> { + + /// Base class + using Base = Coord<4>; + + /// Index type + using Index = typename Base::Index; + + /// LongIndex type + using LongIndex = typename Base::LongIndex; + + /// Batch dimension + static int const kN = 0; + + /// Height dimension + static int const kH = 1; + + /// Width dimension + static int const kW = 2; + + /// Channels dimension + static int const kC = 3; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Tensor4DCoord() { } + + /// Constructs from Coord<4> + CUTLASS_HOST_DEVICE + Tensor4DCoord(Coord<4> const &coord): Base(coord) { } + + /// Helper to construct from N, H, W, and C. + CUTLASS_HOST_DEVICE + Tensor4DCoord(Index n, Index h, Index w, Index c): Base(make_Coord(n, h, w, c)) { } + + /// Helper to construct from N, H, W, and C, which are LongIndex type + CUTLASS_HOST_DEVICE + Tensor4DCoord(LongIndex n, LongIndex h, LongIndex w, LongIndex c) + : Base(make_Coord(Index(n), Index(h), Index(w), Index(c))) { } + + /// Returns the batch of the coordinate + CUTLASS_HOST_DEVICE + Index const & n() const { return this->at(kN); } + + /// Returns the batch of the coordinate + CUTLASS_HOST_DEVICE + Index & n() { return this->at(kN); } + + /// Returns the row of the coordinate + CUTLASS_HOST_DEVICE + Index const & h() const { return this->at(kH); } + + /// Returns the row of the coordinate + CUTLASS_HOST_DEVICE + Index & h() { return this->at(kH); } + + /// Returns the column of the coordinate + CUTLASS_HOST_DEVICE + Index const & w() const { return this->at(kW); } + + /// Returns the column of the coordinate + CUTLASS_HOST_DEVICE + Index & w() { return this->at(kW); } + + /// Returns the channel of the coordinate + CUTLASS_HOST_DEVICE + Index const & c() const { return this->at(kC); } + + /// Returns the channel of the coordinate + CUTLASS_HOST_DEVICE + Index & c() { return this->at(kC); } + + // + // Coord operators + // + + /// Element-wise addition + CUTLASS_HOST_DEVICE + Tensor4DCoord operator+(Base const& b) const { + return Tensor4DCoord(Base::operator+(b)); + } + + /// Element-wise subtraction + CUTLASS_HOST_DEVICE + Tensor4DCoord operator-(Base const& b) const { + return Tensor4DCoord(Base::operator-(b)); + } + + /// Element-wise multiplication + CUTLASS_HOST_DEVICE + Tensor4DCoord operator*(Base const& b) const { + return Tensor4DCoord(Base::operator*(b)); + } + + /// Element-wise division + CUTLASS_HOST_DEVICE + Tensor4DCoord operator/(Base const& b) const { + return Tensor4DCoord(Base::operator/(b)); + } + + /// In-place addition + CUTLASS_HOST_DEVICE + Tensor4DCoord& operator+=(Base const& b) { + Base::operator+=(b); + return *this; + } + + /// In-place subtraction + CUTLASS_HOST_DEVICE + Tensor4DCoord& operator-=(Base const& b) { + Base::operator-=(b); + return *this; + } + + /// In-place multiplication + CUTLASS_HOST_DEVICE + Tensor4DCoord& operator*=(Base const& b) { + Base::operator*=(b); + return *this; + } + + /// In-place division + CUTLASS_HOST_DEVICE + Tensor4DCoord& operator/=(Base const& b) { + Base::operator/=(b); + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a canonical 5D coordinate used by tensor operations. +struct Tensor5DCoord : public Coord<5> { + + /// Base class + using Base = Coord<5>; + + /// Index type + using Index = typename Base::Index; + + /// LongIndex type + using LongIndex = typename Base::LongIndex; + + /// Batch dimension + static int const kN = 0; + + /// Depth dimension + static int const kD = 1; + + /// Height dimension + static int const kH = 2; + + /// Width dimension + static int const kW = 3; + + /// Channels dimension + static int const kC = 4; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Tensor5DCoord() { } + + /// Constructs from Coord<5> + CUTLASS_HOST_DEVICE + Tensor5DCoord(Coord<5> const &coord): Base(coord) { } + + /// Helper to construct from N, D, H, W, and C. + CUTLASS_HOST_DEVICE + Tensor5DCoord(Index n, Index d, Index h, Index w, Index c): Base(make_Coord(n, d, h, w, c)) { } + + /// Helper to construct from N, D, H, W, and C, which are LongIndex type + CUTLASS_HOST_DEVICE + Tensor5DCoord(LongIndex n, LongIndex d, LongIndex h, LongIndex w, LongIndex c) + : Base(make_Coord(Index(n), Index(d), Index(h), Index(w), Index(c))) { } + + /// Returns the batch of the coordinate + CUTLASS_HOST_DEVICE + Index const & n() const { return this->at(kN); } + + /// Returns the batch of the coordinate + CUTLASS_HOST_DEVICE + Index & n() { return this->at(kN); } + + /// Returns the batch of the coordinate + CUTLASS_HOST_DEVICE + Index const & d() const { return this->at(kD); } + + /// Returns the batch of the coordinate + CUTLASS_HOST_DEVICE + Index & d() { return this->at(kD); } + + /// Returns the row of the coordinate + CUTLASS_HOST_DEVICE + Index const & h() const { return this->at(kH); } + + /// Returns the row of the coordinate + CUTLASS_HOST_DEVICE + Index & h() { return this->at(kH); } + + /// Returns the column of the coordinate + CUTLASS_HOST_DEVICE + Index const & w() const { return this->at(kW); } + + /// Returns the column of the coordinate + CUTLASS_HOST_DEVICE + Index & w() { return this->at(kW); } + + /// Returns the channel of the coordinate + CUTLASS_HOST_DEVICE + Index const & c() const { return this->at(kC); } + + /// Returns the channel of the coordinate + CUTLASS_HOST_DEVICE + Index & c() { return this->at(kC); } + + // + // Coord operators + // + + /// Element-wise addition + CUTLASS_HOST_DEVICE + Tensor5DCoord operator+(Base const& b) const { + return Tensor5DCoord(Base::operator+(b)); + } + + /// Element-wise subtraction + CUTLASS_HOST_DEVICE + Tensor5DCoord operator-(Base const& b) const { + return Tensor5DCoord(Base::operator-(b)); + } + + /// Element-wise multiplication + CUTLASS_HOST_DEVICE + Tensor5DCoord operator*(Base const& b) const { + return Tensor5DCoord(Base::operator*(b)); + } + + /// Element-wise division + CUTLASS_HOST_DEVICE + Tensor5DCoord operator/(Base const& b) const { + return Tensor5DCoord(Base::operator/(b)); + } + + /// In-place addition + CUTLASS_HOST_DEVICE + Tensor5DCoord& operator+=(Base const& b) { + Base::operator+=(b); + return *this; + } + + /// In-place subtraction + CUTLASS_HOST_DEVICE + Tensor5DCoord& operator-=(Base const& b) { + Base::operator-=(b); + return *this; + } + + /// In-place multiplication + CUTLASS_HOST_DEVICE + Tensor5DCoord& operator*=(Base const& b) { + Base::operator*=(b); + return *this; + } + + /// In-place division + CUTLASS_HOST_DEVICE + Tensor5DCoord& operator/=(Base const& b) { + Base::operator/=(b); + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tensor_ref.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tensor_ref.h new file mode 100644 index 0000000000000000000000000000000000000000..fc467499996a00645b0a936efe741ece2092fb90 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tensor_ref.h @@ -0,0 +1,419 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines a structure containing strides, bounds, and a pointer to tensor data. +*/ +#pragma once + + +#include "cutlass/cutlass.h" +#include "cutlass/coord.h" +#include "cutlass/platform/platform.h" +#include "cutlass/subbyte_reference.h" + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Default layout function from coordinates in a tensor's index space into the n-D array held +/// in memory. +/// +/// All layout functions must define at least the members shown in IdentityTensorLayout<>. +template +class IdentityTensorLayout { +public: + /// Logical rank of tensor + static int const kRank = Rank; + + /// Rank of stride vector + static int const kStrideRank = Rank; + + /// Index type used for coordinates + using Index = int32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate + using TensorCoord = Coord; + + /// Stride vector + using Stride = Coord; + +private: + + // + // Data members + // + + /// Stride data member + Stride stride_; + +public: + + // + // Methods + // + + CUTLASS_HOST_DEVICE + IdentityTensorLayout(Stride const &stride = Stride()): stride_(stride) { } + + /// Returns the offset of a coordinate in linear memory + CUTLASS_HOST_DEVICE + LongIndex operator()(Coord const &coord) const { + return coord.dot(stride_); + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride stride() const { + return stride_; + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride & stride() { + return stride_; + } + + /// Compute the number of contiguous elements needed to store a tensor with the given size + CUTLASS_HOST_DEVICE + LongIndex capacity(TensorCoord const &size) const { + int idx = stride_.max_dim_index(); + return stride_[idx] * size[idx]; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/* \brief TensorRef is a template for objects pointing to the start of tensors of arbitrary rank + and layout within memory. A TensorRef combines a pointer and a Layout concept + + Examples: + + (These examples use helpers for matrix layouts defined in cutlass/layout/matrix.h) + + 1. Column-major matrix may be represented as a rank=2 tensor: + + TensorRef A(ptr_A, ldm); + + 2. Row-major matrix may be represented as a rank=2 tensor: + + TensorRef B(ptr_A, ldm); + + 3. An interleaved matrix may be represented as a rank=2 tensor: + + TensorRef > C; + + 4. A helper exists to define a TensorRef for a contiguous matrix whose layout + is not known at compile time. + + int ldm; // leading dimension + layout::Matrix kind; // Could be layout::Matrix::kRowMajor or layout::Matrix::kColumnMajor + + + TensorRef E(ptr_E, {ldm, kind}); + +*/ +template < + /// Data type of element stored within tensor (concept: NumericType) + typename Element_, + /// Defines a mapping from logical coordinate to linear memory (concept: Layout) + typename Layout_ +> +class TensorRef { + public: + /// Data type of individual access + using Element = Element_; + + /// Mapping function from logical coordinate to linear memory + using Layout = Layout_; + + /// Reference type to an element + using Reference = typename platform::conditional< + sizeof_bits::value >= 8, + Element &, + SubbyteReference + >::type; + + /// Logical rank of tensor index space + static int const kRank = Layout::kRank; + + /// Index type + using Index = typename Layout::Index; + + /// Long index used for pointer offsets + using LongIndex = typename Layout::LongIndex; + + /// Coordinate in logical tensor space + using TensorCoord = typename Layout::TensorCoord; + + /// Layout's stride vector + using Stride = typename Layout::Stride; + + /// TensorRef to constant data + using ConstTensorRef = TensorRef< + typename platform::remove_const::type const, + Layout>; + + /// TensorRef to non-constant data + using NonConstTensorRef = TensorRef< + typename platform::remove_const::type, + Layout>; + + /// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a + /// scalar, but degenerate cases such as these are difficult to accommodate without + /// extensive C++ metaprogramming or support for zero-length arrays. + static_assert(kRank > 0, "Cannot define a zero-rank TensorRef"); + + private: + + /// Pointer + Element* ptr_; + + /// Layout object maps logical coordinates to linear offsets + Layout layout_; + + public: + + // + // Methods + // + + /// Constructs a TensorRef with a pointer and layout object. + CUTLASS_HOST_DEVICE + TensorRef(): ptr_(nullptr) { + + } + + /// Constructs a TensorRef with a pointer and layout object. + CUTLASS_HOST_DEVICE + TensorRef( + Element *ptr, ///< pointer to start of tensor + Layout const &layout ///< layout object containing stride and mapping function + ): + ptr_(ptr), layout_(layout) { + + } + + /// Converting constructor from TensorRef to non-constant data. + template + CUTLASS_HOST_DEVICE + TensorRef( + NonConstTensorRef const &ref, ///< TensorRef to non-const data + ///SFINAE trick to avoid creating a copy-constructor when Element_ is already non-const + _Magic magic = (typename platform::enable_if< ! platform::is_same >::value, _Magic>::type)0 + ): + ptr_(ref.data()), layout_(ref.layout()) { } + + /// Returns a reference to constant-valued tensor. + CUTLASS_HOST_DEVICE + ConstTensorRef const_ref() const { + return ConstTensorRef(ptr_, layout_); + } + + CUTLASS_HOST_DEVICE + NonConstTensorRef non_const_ref() const { + return NonConstTensorRef(const_cast::type *>(ptr_), layout_); + } + + /// Updates only the pointer + CUTLASS_HOST_DEVICE + void reset(Element* ptr = nullptr) { + ptr_ = ptr; + } + + /// Updates the pointer and layout object + CUTLASS_HOST_DEVICE + void reset(Element* ptr, Layout const &layout) { + ptr_ = ptr; + layout_ = layout; + } + + /// Returns true if the TensorRef is non-null + CUTLASS_HOST_DEVICE + bool good() const { + return ptr_ != nullptr; + } + + /// Returns the pointer to referenced data + CUTLASS_HOST_DEVICE + Element * data() const { return ptr_; } + + /// Returns a reference to the element at a given linear index + CUTLASS_HOST_DEVICE + Reference data(LongIndex idx) const { + return ReferenceFactory::type, + (sizeof_bits::value < 8)>::get(ptr_, idx); + } + + /// Returns the layout object + CUTLASS_HOST_DEVICE + Layout & layout() { + return layout_; + } + + /// Returns the layout object + CUTLASS_HOST_DEVICE + Layout layout() const { + return layout_; + } + + /// Returns the layout object's stride vector + CUTLASS_HOST_DEVICE + Stride stride() const { + return layout_.stride(); + } + + /// Returns the layout object's stride vector + CUTLASS_HOST_DEVICE + Stride & stride() { + return layout_.stride(); + } + + /// Returns the layout object's stride in a given physical dimension + CUTLASS_HOST_DEVICE + typename Layout::Stride::Index stride(int dim) const { + return layout_.stride().at(dim); + } + + /// Returns the layout object's stride in a given physical dimension + CUTLASS_HOST_DEVICE + typename Layout::Stride::Index & stride(int dim) { + return layout_.stride().at(dim); + } + + /// Computes the offset of an index from the origin of the tensor + CUTLASS_HOST_DEVICE + LongIndex offset(TensorCoord const& coord) const { + return layout_(coord); + } + + /// Returns a reference to the element at a given Coord + CUTLASS_HOST_DEVICE + Reference at(TensorCoord const& coord) const { + return data(offset(coord)); + } + + /// Returns a reference to the element at a given Coord + CUTLASS_HOST_DEVICE + Reference operator[](TensorCoord const& coord) const { + return data(offset(coord)); + } + + /// Adds an offset to each pointer + CUTLASS_HOST_DEVICE + TensorRef & add_pointer_offset(LongIndex offset_) { + ptr_ = ReferenceFactory::type, + (sizeof_bits::value < 8)>::add_pointer_offset(ptr_, offset_); + return *this; + } + + /// Adds an offset to each pointer + CUTLASS_HOST_DEVICE + TensorRef & add_coord_offset(TensorCoord const &coord) { + add_pointer_offset(offset(coord)); + return *this; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorRef operator+(TensorCoord const& b) const { + TensorRef result(*this); + result.add_coord_offset(b); + return result; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorRef & operator+=(TensorCoord const& b) { + add_coord_offset(b); + return *this; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorRef operator-(TensorCoord const& b) const { + TensorRef result(*this); + result.add_pointer_offset(-offset(b)); + return result; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorRef & operator-=(TensorCoord const& b) { + add_pointer_offset(-offset(b)); + return *this; + } +}; + +/// Constructs a TensorRef, deducing types from arguments. +template < + typename Element, + typename Layout +> +CUTLASS_HOST_DEVICE +TensorRef make_TensorRef(Element *ptr, Layout const &layout) { + return TensorRef(ptr, layout); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Partial specializations to handle degenerate and sub-byte cases. +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Element, + typename Layout +> +CUTLASS_HOST_DEVICE +bool TensorRef_aligned(TensorRef const &ref, int alignment) { + + int const kStrideRank = Layout::kStrideRank; + + if (reinterpret_cast(ref.data()) % alignment) { + return false; + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kStrideRank; ++i) { + if (ref.stride(i) % alignment) { + return false; + } + } + + return true; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tensor_ref_planar_complex.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tensor_ref_planar_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..9ba3a2308081e8c4b11d18cb8125ec7943e534f0 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tensor_ref_planar_complex.h @@ -0,0 +1,374 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines a structure containing strides, bounds, and a pointer to tensor data. +*/ +#pragma once + +#include +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/tensor_ref.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct PlanarComplexReference { + + // + // Type definitions + // + + using Element = Element_; + using ComplexElement = complex; + + // + // Data members + // + + Element *real; + Element *imag; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + PlanarComplexReference( + Element *real_ = nullptr, + Element *imag_ = nullptr + ): + real(real_), imag(imag_) { } + + /// Loads the complex element + CUTLASS_HOST_DEVICE + operator complex() const { + return complex{*real, *imag}; + } + + /// Stores a complex element to the location pointed to by the reference + CUTLASS_HOST_DEVICE + PlanarComplexReference &operator=(complex const &rhs) { + *real = rhs.real(); + *imag = rhs.imag(); + return *this; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/* \brief TensorRef is a template for objects pointing to the start of tensors of arbitrary rank + and layout within memory. A TensorRef combines a pointer and a Layout concept + +*/ +template < + /// Data type of element stored within tensor (concept: NumericType) + typename Element_, + /// Defines a mapping from logical coordinate to linear memory (concept: Layout) + typename Layout_ +> +class TensorRefPlanarComplex { + public: + /// Data type of individual access + using Element = Element_; + + /// Complex element type + using ComplexElement = complex; + + /// Mapping function from logical coordinate to linear memory + using Layout = Layout_; + + static_assert(sizeof_bits::value >= 8, + "Planar complex not suitable for subbyte elements at this time"); + + /// Reference type to an element + using Reference = PlanarComplexReference; + + /// Logical rank of tensor index space + static int const kRank = Layout::kRank; + + /// Index type + using Index = typename Layout::Index; + + /// Long index used for pointer offsets + using LongIndex = typename Layout::LongIndex; + + /// Coordinate in logical tensor space + using TensorCoord = typename Layout::TensorCoord; + + /// Layout's stride vector + using Stride = typename Layout::Stride; + + /// TensorRef to constant data + using ConstTensorRef = TensorRefPlanarComplex< + typename platform::remove_const::type const, + Layout>; + + /// TensorRef to non-constant data + using NonConstTensorRef = TensorRefPlanarComplex< + typename platform::remove_const::type, + Layout>; + + /// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a + /// scalar, but degenerate cases such as these are difficult to accommodate without + /// extensive C++ metaprogramming or support for zero-length arrays. + static_assert(kRank > 0, "Cannot define a zero-rank TensorRef"); + + private: + + /// Pointer + Element* ptr_; + + /// Layout object maps logical coordinates to linear offsets + Layout layout_; + + /// Offset to imaginary part + LongIndex imaginary_stride_; + + public: + + // + // Methods + // + + /// Constructs a TensorRef with a pointer and layout object. + CUTLASS_HOST_DEVICE + TensorRefPlanarComplex( + Element *ptr = nullptr, ///< pointer to start of tensor + Layout const &layout = Layout(), ///< layout object containing stride and mapping function + LongIndex imaginary_stride = 0 + ): + ptr_(ptr), layout_(layout), imaginary_stride_(imaginary_stride) { + + } + + /// Converting constructor from TensorRef to non-constant data. + CUTLASS_HOST_DEVICE + TensorRefPlanarComplex( + NonConstTensorRef const &ref ///< TensorRef to non-const data + ): + ptr_(ref.data()), layout_(ref.layout()), imaginary_stride_(ref.imaginary_stride_) { } + + /// Returns a reference to constant-valued tensor. + CUTLASS_HOST_DEVICE + ConstTensorRef const_ref() const { + return ConstTensorRef(ptr_, layout_, imaginary_stride_); + } + + CUTLASS_HOST_DEVICE + NonConstTensorRef non_const_ref() const { + return NonConstTensorRef( + const_cast::type *>(ptr_), + layout_, + imaginary_stride_); + } + + /// Updates only the pointer + CUTLASS_HOST_DEVICE + void reset(Element* ptr = nullptr, LongIndex imaginary_stride = 0) { + ptr_ = ptr; + imaginary_stride_ = imaginary_stride; + } + + /// Updates the pointer and layout object + CUTLASS_HOST_DEVICE + void reset(Element* ptr, Layout const &layout, LongIndex imaginary_stride) { + ptr_ = ptr; + layout_ = layout; + imaginary_stride_ = imaginary_stride; + } + + /// Returns true if the TensorRef is non-null + CUTLASS_HOST_DEVICE + bool good() const { + return ptr_ != nullptr; + } + + /// Returns the pointer to referenced data + CUTLASS_HOST_DEVICE + Element * data() const { return ptr_; } + + /// Returns the pointer to referenced data + CUTLASS_HOST_DEVICE + Element * imaginary_data() const { return ptr_ + imaginary_stride_; } + + /// Returns a reference to the element at a given linear index + CUTLASS_HOST_DEVICE + Reference data(LongIndex idx) const { + return Reference(ptr_ + idx, ptr_ + idx + imaginary_stride_); + } + + /// Returns the layout object + CUTLASS_HOST_DEVICE + Layout & layout() { + return layout_; + } + + /// Returns the layout object + CUTLASS_HOST_DEVICE + Layout layout() const { + return layout_; + } + + /// Gets the stride to an imaginary element + LongIndex imaginary_stride() const { + return imaginary_stride_; + } + + /// Gets the stride to an imaginary element + LongIndex &imaginary_stride() { + return imaginary_stride_; + } + + /// Returns the layout object's stride vector + CUTLASS_HOST_DEVICE + Stride stride() const { + return layout_.stride(); + } + + /// Returns the layout object's stride vector + CUTLASS_HOST_DEVICE + Stride & stride() { + return layout_.stride(); + } + + /// Returns the layout object's stride in a given physical dimension + CUTLASS_HOST_DEVICE + Index stride(int dim) const { + return layout_.stride().at(dim); + } + + /// Returns the layout object's stride in a given physical dimension + CUTLASS_HOST_DEVICE + Index & stride(int dim) { + return layout_.stride().at(dim); + } + + /// Computes the offset of an index from the origin of the tensor + CUTLASS_HOST_DEVICE + LongIndex offset(TensorCoord const& coord) const { + return layout_(coord); + } + + /// Returns a reference to the element at a given Coord + CUTLASS_HOST_DEVICE + Reference at(TensorCoord const& coord) const { + return data(offset(coord)); + } + + /// Returns a reference to the element at a given Coord + CUTLASS_HOST_DEVICE + Reference operator[](TensorCoord const& coord) const { + return data(offset(coord)); + } + + /// Adds an offset to each pointer + CUTLASS_HOST_DEVICE + TensorRefPlanarComplex & add_pointer_offset(LongIndex offset_) { + ptr_ += offset_; + return *this; + } + + /// Adds an offset to each pointer + CUTLASS_HOST_DEVICE + TensorRefPlanarComplex & add_coord_offset(TensorCoord const &coord) { + add_pointer_offset(offset(coord)); + return *this; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorRefPlanarComplex operator+(TensorCoord const& b) const { + TensorRefPlanarComplex result(*this); + result.add_coord_offset(b); + return result; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorRefPlanarComplex & operator+=(TensorCoord const& b) { + add_coord_offset(b); + return *this; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorRefPlanarComplex operator-(TensorCoord const& b) const { + TensorRefPlanarComplex result(*this); + result.add_pointer_offset(-offset(b)); + return result; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorRefPlanarComplex & operator-=(TensorCoord const& b) { + add_pointer_offset(-offset(b)); + return *this; + } + + /// TensorRef to real-valued tensor + CUTLASS_HOST_DEVICE + cutlass::TensorRef ref_real() const { + return cutlass::TensorRef(data(), layout()); + } + + /// TensorRef to real-valued tensor + CUTLASS_HOST_DEVICE + cutlass::TensorRef ref_imag() const { + return cutlass::TensorRef(imaginary_data(), layout()); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Constructs a TensorRef, deducing types from arguments. +template < + typename Element, + typename Layout +> +CUTLASS_HOST_DEVICE +TensorRefPlanarComplex make_TensorRefPlanarComplex( + Element *ptr, + Layout const &layout, + int64_t imaginary_stride) { + + return TensorRefPlanarComplex(ptr, layout, imaginary_stride); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tensor_view.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tensor_view.h new file mode 100644 index 0000000000000000000000000000000000000000..d669443abd8b5b246a9d2aaf2ce4dd91f782f948 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tensor_view.h @@ -0,0 +1,297 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines a structure containing strides and a pointer to tensor data. + + TensorView is derived from TensorRef and contributes bounds to the tensor's index space. Thus, + it is a complete mathematical object and may be used in tensor algorithms. It is decoupled from + data storage and is therefore lightweight and may be embedded in larger tensor objects or + memory structures. + + See cutlass/tensor_ref.h for more details about the mapping of the logical tensor index space to + linear memory. +*/ + +#pragma once + +#if !defined(__CUDACC_RTC__) +#include +#endif + +#include "cutlass/cutlass.h" +#include "cutlass/tensor_ref.h" + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Data type of element stored within tensor + typename Element_, + /// Maps a Coord in the logical tensor index space to the internal n-D array + typename Layout_ +> +class TensorView : public TensorRef { + public: + + /// Base tensor reference + using Base = cutlass::TensorRef; + + /// Mapping function from logical coordinate to internal n-D array + using Layout = Layout_; + + /// TensorRef pointing to constant memory + using ConstTensorRef = typename Base::ConstTensorRef; + + /// Underlying TensorRef type + using TensorRef = Base; + + /// Data type of individual access + using Element = Element_; + + /// Reference type to an element + using Reference = Element &; + + /// Logical rank of tensor index space + static int const kRank = Layout::kRank; + + /// Index type + using Index = typename Layout::Index; + + /// Long index used for pointer offsets + using LongIndex = typename Layout::LongIndex; + + /// Coordinate in logical tensor space + using TensorCoord = typename Layout::TensorCoord; + + /// Coordinate in storage n-D array + using Stride = typename Layout::Stride; + + /// TensorView pointing to constant memory + using ConstTensorView = TensorView< + typename platform::remove_const::type const, + Layout>; + + /// TensorView pointing to non-constant memory + using NonConstTensorView = TensorView< + typename platform::remove_const::type, + Layout>; + + /// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a + /// scalar, but degenerate cases such as these are difficult to accommodate without + /// extensive C++ metaprogramming or support for zero-length arrays. + static_assert(kRank > 0, "Cannot define a zero-rank TensorRef"); + + private: + + /// View extent + TensorCoord extent_; + + public: + + // + // Methods + // + + /// Constructs a TensorView object + CUTLASS_HOST_DEVICE + TensorView() { } + + /// Constructs a TensorView object + CUTLASS_HOST_DEVICE + TensorView( + Element *ptr, ///< pointer to start of tensor + Layout const &layout, ///< layout object containing stride and mapping function + TensorCoord const &extent ///< size of the view in logical coordinates + ): + Base(ptr, layout), extent_(extent) { + + } + + /// Constructs a TensorView object + CUTLASS_HOST_DEVICE + TensorView( + TensorRef const &ref, ///< pointer and layout object referencing a tensor + TensorCoord const &extent ///< logical size of tensor + ): + Base(ref), extent_(extent) { + + } + + /// Converting constructor from TensorRef to non-constant data. + CUTLASS_HOST_DEVICE + TensorView( + NonConstTensorView const &view ///< TensorView to non-const data + ): + Base(view), extent_(view.extent_) { } + + /// Updates the pointer and layout object + CUTLASS_HOST_DEVICE + void reset(Element* ptr, Layout const &layout, TensorCoord const &extent) { + Base::reset(ptr, layout); + this->resize(extent); + } + + /// Updates the pointer + CUTLASS_HOST_DEVICE + void reset(Element* ptr) { + Base::reset(ptr); + } + + /// Changes the size of the view without affecting pointer or layout + CUTLASS_HOST_DEVICE + void resize(TensorCoord const &extent) { + this->extent_ = extent; + } + + /// Returns the extent of the view (the size along each logical dimension). + CUTLASS_HOST_DEVICE + TensorCoord const& extent() const { return extent_; } + + /// Returns the extent along a particular logical dimension. + CUTLASS_HOST_DEVICE + Index extent(int dim) const { return extent_.at(dim); } + + /// Returns the number of logical elements + CUTLASS_HOST_DEVICE + LongIndex size() const { + return extent_.product(); + } + + /// Determines whether a location is within a tensor + CUTLASS_HOST_DEVICE + bool contains(TensorCoord const& coord) const { + CUTLASS_PRAGMA_UNROLL + for (int dim = 0; dim < kRank; ++dim) { + if (!(coord[dim] >= 0 && coord[dim] < extent(dim))) { + return false; + } + } + return true; + } + + /// Returns a TensorRef pointing to the first element of the tensor. + CUTLASS_HOST_DEVICE + TensorRef ref() const { + return TensorRef(this->data(), this->layout()); + } + + /// Returns a TensorRef pointing to the first element of the tensor. + CUTLASS_HOST_DEVICE + ConstTensorRef const_ref() const { + return ConstTensorRef(this->data(), this->layout()); + } + + /// Returns a TensorView to const data + CUTLASS_HOST_DEVICE + ConstTensorView const_view() const { + return ConstTensorView(const_ref(), extent_); + } + + /// Returns a Tensor_view given location and size quantities + CUTLASS_HOST_DEVICE + TensorView subview( + TensorCoord extent, ///< extent of the resulting view + TensorCoord const& location = TensorCoord() ///< resulting view's origin within the old view + ) const { + + TensorView result(this->ref(), extent.clamp(extent_ - location)); + result.add_coord_offset(location); + return result; + } + + /// Returns the number of scalar elements needed to store tensor. + CUTLASS_HOST_DEVICE + size_t capacity() const { + return Base::layout().capacity(extent_); + } + + /// Returns a TensorView offset by a given amount + CUTLASS_HOST_DEVICE + TensorView operator+( + TensorCoord const& b ///< offset in the logical coordinate space of the tensor + ) const { + + TensorView result(*this); + result.add_pointer_offset(this->offset(b)); + return result; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorView& operator+=( + TensorCoord const& b ///< offset in the logical coordinate space of the tensor + ) { + + this->add_pointer_offset(this->offset(b)); + return *this; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorView operator-( + TensorCoord const& b ///< offset in the logical coordinate space of the tensor + ) const { + + TensorRef result(*this); + result.add_pointer_offset(-this->offset(b)); + return result; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorView& operator-=( + TensorCoord const& b ///< offset in the logical coordinate space of the tensor + ) { + + this->add_pointer_offset(-this->offset(b)); + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Constructs a TensorRef, deducing types from arguments. +template < + typename Element, + typename Layout +> +CUTLASS_HOST_DEVICE TensorView make_TensorView( + Element *ptr, + Layout const &layout, + typename Layout::TensorCoord const &extent) { + + return TensorView(ptr, layout, extent); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tensor_view_planar_complex.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tensor_view_planar_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..6b8f7b47c49d75f0b000d134031ea169fcc6d2a6 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tensor_view_planar_complex.h @@ -0,0 +1,302 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines a structure containing strides and a pointer to tensor data. + + TensorView is derived from TensorRef and contributes bounds to the tensor's index space. Thus, + it is a complete mathematical object and may be used in tensor algorithms. It is decoupled from + data storage and is therefore lightweight and may be embedded in larger tensor objects or + memory structures. + + See cutlass/tensor_ref.h for more details about the mapping of the logical tensor index space to + linear memory. +*/ + +#pragma once + +#if !defined(__CUDACC_RTC__) +#include +#endif + +#include "cutlass/cutlass.h" +#include "cutlass/tensor_ref_planar_complex.h" +#include "cutlass/tensor_view.h" // cutlass::TensorView + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Data type of element stored within tensor + typename Element_, + /// Maps a Coord in the logical tensor index space to the internal n-D array + typename Layout_ +> +class TensorViewPlanarComplex : public TensorRefPlanarComplex { + public: + + /// Base tensor reference + using Base = cutlass::TensorRefPlanarComplex; + + /// Mapping function from logical coordinate to internal n-D array + using Layout = Layout_; + + /// TensorRef pointing to constant memory + using ConstTensorRef = typename Base::ConstTensorRef; + + /// Underlying TensorRef type + using TensorRef = Base; + + /// Data type of individual access + using Element = Element_; + + /// Reference type to an element + using Reference = Element &; + + /// Logical rank of tensor index space + static int const kRank = Layout::kRank; + + /// Index type + using Index = typename Layout::Index; + + /// Long index used for pointer offsets + using LongIndex = typename Layout::LongIndex; + + /// Coordinate in logical tensor space + using TensorCoord = typename Layout::TensorCoord; + + /// Coordinate in storage n-D array + using Stride = typename Layout::Stride; + + /// TensorView pointing to constant memory + using ConstTensorView = TensorViewPlanarComplex< + typename platform::remove_const::type const, + Layout>; + + /// TensorView pointing to non-constant memory + using NonConstTensorView = TensorViewPlanarComplex< + typename platform::remove_const::type, + Layout>; + + /// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a + /// scalar, but degenerate cases such as these are difficult to accommodate without + /// extensive C++ metaprogramming or support for zero-length arrays. + static_assert(kRank > 0, "Cannot define a zero-rank TensorRef"); + + private: + + /// View extent + TensorCoord extent_; + + public: + + // + // Methods + // + + /// Constructs a TensorView object + CUTLASS_HOST_DEVICE + TensorViewPlanarComplex(TensorCoord const &extent = TensorCoord()): extent_(extent) { + + } + + /// Constructs a TensorView object + CUTLASS_HOST_DEVICE + TensorViewPlanarComplex( + Element *ptr, ///< pointer to start of tensor + Layout const &layout, ///< layout object containing stride and mapping function + LongIndex imaginary_stride, ///< stride between real and imaginary part + TensorCoord const &extent ///< size of the view in logical coordinates + ): + Base(ptr, layout, imaginary_stride), extent_(extent) { + + } + + /// Constructs a TensorView object + CUTLASS_HOST_DEVICE + TensorViewPlanarComplex( + TensorRef const &ref, ///< pointer and layout object referencing a tensor + TensorCoord const &extent ///< logical size of tensor + ): + Base(ref), extent_(extent) { + + } + + /// Converting constructor from TensorRef to non-constant data. + CUTLASS_HOST_DEVICE + TensorViewPlanarComplex( + NonConstTensorView const &view ///< TensorView to non-const data + ): + Base(view), extent_(view.extent_) { } + + /// Updates the pointer and layout object + CUTLASS_HOST_DEVICE + void reset(Element* ptr, Layout const &layout, LongIndex imaginary_stride, TensorCoord size) { + Base::reset(ptr, layout, imaginary_stride); + this->resize(extent_); + } + + /// Changes the size of the view without affecting pointer or layout + CUTLASS_HOST_DEVICE + void resize(TensorCoord extent) { + this->extent_ = extent; + } + + /// Returns the extent of the view (the size along each logical dimension). + CUTLASS_HOST_DEVICE + TensorCoord const& extent() const { return extent_; } + + /// Returns the extent along a particular logical dimension. + CUTLASS_HOST_DEVICE + Index extent(int dim) const { return extent_.at(dim); } + + /// Determines whether a location is within a tensor + CUTLASS_HOST_DEVICE + bool contains(TensorCoord const& coord) const { + CUTLASS_PRAGMA_UNROLL + for (int dim = 0; dim < kRank; ++dim) { + if (!(coord[dim] >= 0 && coord[dim] < extent(dim))) { + return false; + } + } + return true; + } + + /// Returns a TensorRef pointing to the first element of the tensor. + CUTLASS_HOST_DEVICE + Base ref() const { + return Base(this->data(), this->layout(), this->imaginary_stride()); + } + + /// Returns a TensorRef pointing to the first element of the tensor. + CUTLASS_HOST_DEVICE + ConstTensorRef const_ref() const { + return ConstTensorRef(this->data(), this->layout()); + } + + /// Returns a TensorView to const data + CUTLASS_HOST_DEVICE + ConstTensorView const_view() const { + return ConstTensorView(const_ref(), extent_); + } + + /// Returns a Tensor_view given location and size quantities + CUTLASS_HOST_DEVICE + TensorViewPlanarComplex subview( + TensorCoord extent, ///< extent of the resulting view + TensorCoord const& location = TensorCoord() ///< resulting view's origin within the old view + ) const { + + TensorViewPlanarComplex result(this->ref(), extent.clamp(extent_ - location)); + result.add_coord_offset(location); + return result; + } + + /// Returns the number of scalar elements needed to store tensor. + CUTLASS_HOST_DEVICE + size_t capacity() const { + return Base::layout().capacity(extent_); + } + + /// Returns a TensorView offset by a given amount + CUTLASS_HOST_DEVICE + TensorViewPlanarComplex operator+( + TensorCoord const& b ///< offset in the logical coordinate space of the tensor + ) const { + + TensorViewPlanarComplex result(*this); + result.add_pointer_offset(this->offset(b)); + return result; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorViewPlanarComplex& operator+=( + TensorCoord const& b ///< offset in the logical coordinate space of the tensor + ) { + + this->add_pointer_offset(this->offset(b)); + return *this; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorViewPlanarComplex operator-( + TensorCoord const& b ///< offset in the logical coordinate space of the tensor + ) const { + + TensorRef result(*this); + result.add_pointer_offset(-this->offset(b)); + return result; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorViewPlanarComplex& operator-=( + TensorCoord const& b ///< offset in the logical coordinate space of the tensor + ) { + + this->add_pointer_offset(-this->offset(b)); + return *this; + } + + /// TensorRef to real-valued tensor + CUTLASS_HOST_DEVICE + cutlass::TensorView view_real() const { + return cutlass::TensorView(this->data(), this->layout(), extent_); + } + + /// TensorRef to real-valued tensor + CUTLASS_HOST_DEVICE + cutlass::TensorView view_imag() const { + return cutlass::TensorView(this->imaginary_data(), this->layout(), extent_); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Constructs a TensorRef, deducing types from arguments. +template < + typename Element, + typename Layout +> +CUTLASS_HOST_DEVICE TensorViewPlanarComplex make_TensorViewPlanarComplex( + Element *ptr, + Layout const &layout, + typename Layout::LongIndex imaginary_stride, + typename Layout::TensorCoord const &extent) { + + return TensorViewPlanarComplex(ptr, layout, imaginary_stride, extent); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tfloat32.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tfloat32.h new file mode 100644 index 0000000000000000000000000000000000000000..7bc13e177f1d027fbba789367ac3f2ee5b748877 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tfloat32.h @@ -0,0 +1,479 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Defines a proxy class for storing Tensor Float 32 data type. +*/ +#pragma once + +#if defined(__CUDACC_RTC__) +#include "cutlass/floating_point_nvrtc.h" +#else +#include +#include +#include +#include // std::memcpy +#endif + +#include "cutlass/cutlass.h" + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tensor Float 32 data type +struct alignas(4) tfloat32_t { + + // + // Data members + // + + /// Storage type + uint32_t storage; + + // + // Methods + // + private: + CUTLASS_HOST_DEVICE + static uint32_t float_to_storage(float s) { + #if defined(__CUDA_ARCH__) + uint32_t result = reinterpret_cast(s); + #else + uint32_t result; + std::memcpy(&result, &s, sizeof(float)); + #endif + return result; + } + + public: + /// Constructs from an unsigned int + CUTLASS_HOST_DEVICE + static tfloat32_t bitcast(uint32_t x) { + tfloat32_t h; + h.storage = x; + return h; + } + + /// Emulated rounding is fast in device code + CUTLASS_HOST_DEVICE + static tfloat32_t round_half_ulp_truncate(float const &s) { + uint32_t x = float_to_storage(s); + + #if defined(__CUDA_ARCH__) + if (::isfinite(s)) { + x += 0x1000u; + } + #else + if (std::isfinite(s)) { + x += 0x1000u; + } + #endif + + return tfloat32_t::bitcast(x); + } + + tfloat32_t() = default; + + /// Floating-point conversion - round toward nearest even + CUTLASS_HOST_DEVICE + explicit tfloat32_t(float x): storage(round_half_ulp_truncate(x).raw()) { } + + // Conversion from double (this rounds twice) + CUTLASS_HOST_DEVICE + explicit tfloat32_t(double x): tfloat32_t(float(x)) { } + + /// Integer conversion - round toward zero + CUTLASS_HOST_DEVICE + explicit tfloat32_t(int x) { + float flt = static_cast(x); + #if defined(__CUDA_ARCH__) + storage = reinterpret_cast(flt); + #else + std::memcpy(&storage, &flt, sizeof(storage)); + #endif + } + + // Conversion to float + CUTLASS_HOST_DEVICE + operator float() const { + + // Conversions to IEEE single-precision requires clearing dont-care bits + // of the mantissa. + unsigned bits = (storage & ~0x1fffu); + + #if defined(__CUDA_ARCH__) + return reinterpret_cast(bits); + #else + float flt; + std::memcpy(&flt, &bits, sizeof(flt)); + return flt; + #endif + } + + /// Converts to double + CUTLASS_HOST_DEVICE + explicit operator double() const { + return double(float(*this)); + } + + /// Converts to int + CUTLASS_HOST_DEVICE + explicit operator int() const { + return int(float(*this)); + } + + /// Casts to bool + CUTLASS_HOST_DEVICE + explicit operator bool() const { + return (float(*this) != 0.0f); + } + + /// Obtains raw bits + CUTLASS_HOST_DEVICE + uint32_t raw() const { + return storage; + } + + /// Returns the sign bit + CUTLASS_HOST_DEVICE + bool signbit() const { + return ((raw() & 0x80000000) != 0); + } + + /// Returns the biased exponent + CUTLASS_HOST_DEVICE + int exponent_biased() const { + return int((raw() >> 23) & 0x0ff); + } + + /// Returns the unbiased exponent + CUTLASS_HOST_DEVICE + int exponent() const { + return exponent_biased() - 127; + } + + /// Returns the mantissa + CUTLASS_HOST_DEVICE + int mantissa() const { + return int(raw() & 0x7fffff); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +CUTLASS_HOST_DEVICE +bool signbit(cutlass::tfloat32_t const& h) { + return h.signbit(); +} + +CUTLASS_HOST_DEVICE +cutlass::tfloat32_t abs(cutlass::tfloat32_t const& h) { + return cutlass::tfloat32_t::bitcast(h.raw() & 0x7fffffff); +} + +CUTLASS_HOST_DEVICE +bool isnan(cutlass::tfloat32_t const& h) { + return (h.exponent_biased() == 0x0ff) && h.mantissa(); +} + +CUTLASS_HOST_DEVICE +bool isfinite(cutlass::tfloat32_t const& h) { + return (h.exponent_biased() != 0x0ff); +} + +CUTLASS_HOST_DEVICE +cutlass::tfloat32_t nan_tf32(const char*) { + // NVIDIA canonical NaN + return cutlass::tfloat32_t::bitcast(0x7fffffff); +} + +CUTLASS_HOST_DEVICE +bool isinf(cutlass::tfloat32_t const& h) { + return (h.exponent_biased() == 0x0ff) && !h.mantissa(); +} + +CUTLASS_HOST_DEVICE +bool isnormal(cutlass::tfloat32_t const& h) { + return h.exponent_biased() && h.exponent_biased() != 0x0ff; +} + +CUTLASS_HOST_DEVICE +int fpclassify(cutlass::tfloat32_t const& h) { + int exp = h.exponent_biased(); + int mantissa = h.mantissa(); + if (exp == 0x0ff) { + if (mantissa) { + return FP_NAN; + } + else { + return FP_INFINITE; + } + } + else if (!exp) { + if (mantissa) { + return FP_SUBNORMAL; + } + else { + return FP_ZERO; + } + } + return FP_NORMAL; +} + +CUTLASS_HOST_DEVICE +cutlass::tfloat32_t sqrt(cutlass::tfloat32_t const& h) { +#if defined(__CUDACC_RTC__) + return cutlass::tfloat32_t(sqrtf(float(h))); +#else + return cutlass::tfloat32_t(std::sqrt(float(h))); +#endif +} + +CUTLASS_HOST_DEVICE +tfloat32_t copysign(tfloat32_t const& a, tfloat32_t const& b) { + + uint32_t a_mag = (a.raw() & 0x7fffffff); + uint32_t b_sign = (b.raw() & 0x80000000); + uint32_t result = (a_mag | b_sign); + + return tfloat32_t::bitcast(result); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Standard Library operations and definitions +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace std { + +#if !defined(__CUDACC_RTC__) +/// Numeric limits +template <> +struct numeric_limits { + static bool const is_specialized = true; + static bool const is_signed = true; + static bool const is_integer = false; + static bool const is_exact = false; + static bool const has_infinity = true; + static bool const has_quiet_NaN = true; + static bool const has_signaling_NaN = false; + static std::float_denorm_style const has_denorm = std::denorm_present; + static bool const has_denorm_loss = true; + static std::float_round_style const round_style = std::round_to_nearest; + static bool const is_iec559 = false; + static bool const is_bounded = true; + static bool const is_modulo = false; + static int const digits = 19; + + /// Least positive value + static cutlass::tfloat32_t min() { return cutlass::tfloat32_t::bitcast(0x01); } + + /// Minimum finite value + static cutlass::tfloat32_t lowest() { return cutlass::tfloat32_t::bitcast(0xff7fffff); } + + /// Maximum finite value + static cutlass::tfloat32_t max() { return cutlass::tfloat32_t::bitcast(0x7f7fffff); } + + /// Returns smallest finite value + static cutlass::tfloat32_t epsilon() { return cutlass::tfloat32_t::bitcast(0x1000); } + + /// Returns smallest finite value + static cutlass::tfloat32_t round_error() { return cutlass::tfloat32_t(0.5f); } + + /// Returns smallest finite value + static cutlass::tfloat32_t infinity() { return cutlass::tfloat32_t::bitcast(0x7f800000); } + + /// Returns smallest finite value + static cutlass::tfloat32_t quiet_NaN() { return cutlass::tfloat32_t::bitcast(0x7fffffff); } + + /// Returns smallest finite value + static cutlass::tfloat32_t signaling_NaN() { return cutlass::tfloat32_t::bitcast(0x7fffffff); } + + /// Returns smallest finite value + static cutlass::tfloat32_t denorm_min() { return cutlass::tfloat32_t::bitcast(0x1); } +}; +#endif + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace std + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Arithmetic operators +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +CUTLASS_HOST_DEVICE +bool operator==(tfloat32_t const& lhs, tfloat32_t const& rhs) { + return float(lhs) == float(rhs); +} + +CUTLASS_HOST_DEVICE +bool operator!=(tfloat32_t const& lhs, tfloat32_t const& rhs) { + return float(lhs) != float(rhs); +} + +CUTLASS_HOST_DEVICE +bool operator<(tfloat32_t const& lhs, tfloat32_t const& rhs) { + return float(lhs) < float(rhs); +} + +CUTLASS_HOST_DEVICE +bool operator<=(tfloat32_t const& lhs, tfloat32_t const& rhs) { + return float(lhs) <= float(rhs); +} + +CUTLASS_HOST_DEVICE +bool operator>(tfloat32_t const& lhs, tfloat32_t const& rhs) { + return float(lhs) > float(rhs); +} + +CUTLASS_HOST_DEVICE +bool operator>=(tfloat32_t const& lhs, tfloat32_t const& rhs) { + return float(lhs) >= float(rhs); +} + +CUTLASS_HOST_DEVICE +tfloat32_t operator+(tfloat32_t const& lhs, tfloat32_t const& rhs) { + return tfloat32_t(float(lhs) + float(rhs)); +} + + +CUTLASS_HOST_DEVICE +tfloat32_t operator-(tfloat32_t const& lhs) { + return tfloat32_t::bitcast(0x80000000 ^ lhs.raw()); +} + +CUTLASS_HOST_DEVICE +tfloat32_t operator-(tfloat32_t const& lhs, tfloat32_t const& rhs) { + return tfloat32_t(float(lhs) - float(rhs)); +} + +CUTLASS_HOST_DEVICE +tfloat32_t operator*(tfloat32_t const& lhs, tfloat32_t const& rhs) { + return tfloat32_t(float(lhs) * float(rhs)); +} + +CUTLASS_HOST_DEVICE +tfloat32_t operator/(tfloat32_t const& lhs, tfloat32_t const& rhs) { + return tfloat32_t(float(lhs) / float(rhs)); +} + +CUTLASS_HOST_DEVICE +tfloat32_t& operator+=(tfloat32_t & lhs, tfloat32_t const& rhs) { + lhs = tfloat32_t(float(lhs) + float(rhs)); + return lhs; +} + +CUTLASS_HOST_DEVICE +tfloat32_t& operator-=(tfloat32_t & lhs, tfloat32_t const& rhs) { + lhs = tfloat32_t(float(lhs) - float(rhs)); + return lhs; +} + +CUTLASS_HOST_DEVICE +tfloat32_t& operator*=(tfloat32_t & lhs, tfloat32_t const& rhs) { + lhs = tfloat32_t(float(lhs) * float(rhs)); + return lhs; +} + +CUTLASS_HOST_DEVICE +tfloat32_t& operator/=(tfloat32_t & lhs, tfloat32_t const& rhs) { + lhs = tfloat32_t(float(lhs) / float(rhs)); + return lhs; +} + +CUTLASS_HOST_DEVICE +tfloat32_t& operator++(tfloat32_t & lhs) { + float tmp(lhs); + ++tmp; + lhs = tfloat32_t(tmp); + return lhs; +} + +CUTLASS_HOST_DEVICE +tfloat32_t& operator--(tfloat32_t & lhs) { + float tmp(lhs); + --tmp; + lhs = tfloat32_t(tmp); + return lhs; +} + +CUTLASS_HOST_DEVICE +tfloat32_t operator++(tfloat32_t & lhs, int) { + tfloat32_t ret(lhs); + float tmp(lhs); + tmp++; + lhs = tfloat32_t(tmp); + return ret; +} + +CUTLASS_HOST_DEVICE +tfloat32_t operator--(tfloat32_t & lhs, int) { + tfloat32_t ret(lhs); + float tmp(lhs); + tmp--; + lhs = tfloat32_t(tmp); + return ret; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// User-defined literals +// + +CUTLASS_HOST_DEVICE +cutlass::tfloat32_t operator "" _tf32(long double x) { + return cutlass::tfloat32_t(float(x)); +} + +CUTLASS_HOST_DEVICE +cutlass::tfloat32_t operator "" _tf32(unsigned long long int x) { + return cutlass::tfloat32_t(int(x)); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/thread/matrix.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/thread/matrix.h new file mode 100644 index 0000000000000000000000000000000000000000..c338306132b9d9b2e42ff26759f7d1b3a7bc1ae3 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/thread/matrix.h @@ -0,0 +1,198 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines a matrix object intended for storing data in registers and operations within + a CUDA thread. +*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/matrix_coord.h" + +namespace cutlass { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Per-thread matrix object storing a packed matrix +template < + typename Element, + int Rows, + int Columns, + typename Layout = layout::RowMajor +> +class Matrix : public Array { +public: + + // Verify layout refers to a rank=2 matrix. + static_assert( + Layout::kRank == 2, + "Layout type must refer to a rank=2 matrix"); + + /// Base type + using Base = Array; + + /// Element type + using Element = Element_; + + /// Number of rows + static int const kRows = Rows; + + /// Number of columns + static int const kColumns = Columns; + + /// Layout within the array + using Layout = Layout_; + + /// Reference type to an element + using Reference = Element &; + + /// Logical rank of tensor index space + static int const kRank = 2; + + /// Index type + using Index = typename Layout::Index; + + /// Long index used for pointer offsets + using LongIndex = typename Layout::LongIndex; + + /// Coordinate in logical tensor space + using TensorCoord = typename Layout::TensorCoord; + + /// Stride type + using Stride = typename Layout::Stride; + + /// TensorRef to matrix object + using TensorRef = TensorRef; + + /// TensorRef to constant matrix object + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + /// TensorRef to matrix object + using TensorView = TensorView; + + /// TensorRef to constant matrix object + using ConstTensorView = typename TensorView::ConstTensorView; + + /// Diagonal vector + using Diagonal = Vector; + +private: + + +public: + + // + // Methods + // + + /// Returns the size of the object + CUTLASS_HOST_DEVICE + static MatrixCoord extent() { + return make_Coord(kRows, kColumns); + } + + /// Returns the layout object + CUTLASS_HOST_DEVICE + static Layout layout() { + return Layout::packed(extent()); + } + + /// Ctor + CUTLASS_HOST_DEVICE + Matrix() { } + + /// Ctor + CUTLASS_HOST_DEVICE + Matrix(Diagonal const &diag) { + } + + /// Returns a TensorRef pointing to the first element of the tensor. + CUTLASS_HOST_DEVICE + TensorRef ref() { + return TensorRef(this->data(), layout()); + } + + /// Returns a TensorRef pointing to the first element of the tensor. + CUTLASS_HOST_DEVICE + ConstTensorRef const_ref() const { + return ConstTensorRef(this->data(), layout()); + } + + /// Returns a TensorRef pointing to the first element of the tensor. + CUTLASS_HOST_DEVICE + TensorView view() { + return TensorView(ref(), extent()); + } + + /// Returns a TensorView to const data + CUTLASS_HOST_DEVICE + ConstTensorView const_view() const { + return ConstTensorView(const_ref(), extent()); + } + + /// Returns a reference to the element at a given Coord + CUTLASS_HOST_DEVICE + Reference at(MatrixCoord const& coord) const { + typename Base::size_type offset_(layout().offset(coord)); + return Base::at(offset_); + } + + /// Returns the number of scalar elements needed to store tensor. + CUTLASS_HOST_DEVICE + LongIndex capacity() const { + return LongIndex(Base::size()); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Column vector defined as a matrix with exactly one column +template < + typename Element, + int Rows, + typename Layout = layout::ColumnMajor +> +using ColumnVector = Matrix; + +/// Row vector defined as a matrix with exactly one row +template < + typename Element, + int Columns, + typename Layout = layout::RowMajor +> +using RowVector = Matrix; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/trace.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/trace.h new file mode 100644 index 0000000000000000000000000000000000000000..803c72eca35a4cc3ee0712981942016f987f5b44 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/trace.h @@ -0,0 +1,59 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Helpers for optionally tracing through code when debugging. + + This file is to be included after all other headers. +*/ + +#pragma once + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Tracing options +#ifndef CUTLASS_DEBUG_TRACE_LEVEL +#define CUTLASS_DEBUG_TRACE_LEVEL 0 +#endif + +#if CUTLASS_DEBUG_TRACE_LEVEL +#include +#include "cutlass/core_io.h" +#if defined(__CUDA_ARCH__) +#define CUTLASS_TRACE_HOST(x) +#else +#define CUTLASS_TRACE_HOST(x) { std::cout << __FILE__ << ":" << __LINE__ << " " << x << std::endl; } +#endif +#else +#define CUTLASS_TRACE_HOST(x) +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/collective/sm90_wgmma_transpose.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/collective/sm90_wgmma_transpose.hpp new file mode 100644 index 0000000000000000000000000000000000000000..41bc4786c7a8d148340a23bf1ce1db66f04f10b4 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/collective/sm90_wgmma_transpose.hpp @@ -0,0 +1,754 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing how threads are mapped to a given tile. +*/ + +#pragma once + +#include "cute/arch/mma_sm90_gmma.hpp" +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { +using namespace cute; + +template +constexpr auto +gmma_smem_transpose_or_passthrough() { + if constexpr (Transpose) { + if constexpr (cute::is_same_v, SmemLayoutAtom>) { + return GMMA::Layout_K_SW128_Atom{}; + } + else if constexpr (cute::is_same_v, SmemLayoutAtom>) { + return GMMA::Layout_K_SW64_Atom{}; + } + else if constexpr (cute::is_same_v, SmemLayoutAtom>) { + return GMMA::Layout_K_SW32_Atom{}; + } + else if constexpr (cute::is_same_v, SmemLayoutAtom>) { + return GMMA::Layout_K_INTER_Atom{}; + } + else { + static_assert(cutlass::detail::dependent_false, "Unsupported Layout_SW_Atom for B SMEM transposition"); + } + } + else { + return SmemLayoutAtom{}; + } +} + +template +constexpr auto +use_universal_transposition() { + if constexpr (sizeof(ElementType) == 1) { + return !cute::is_same_v, SmemCopyAtom>; + } + else if constexpr (sizeof(ElementType) == 4){ + // Only universal transposition can handle SW64 and Non swizzle SMEM layout + if constexpr (cute::is_same_v, SmemCopyAtom> || + cute::is_same_v, SmemCopyAtom>) { + return true; + } + else { + return false; + } + } + else { + static_assert(cutlass::detail::dependent_false, "Unsupported ElementType for B SMEM transposition"); + } +} + +template< + class TiledMma_, + class SmemLayoutB_, + class SmemLayoutAtomB_, + class ElementB_> +class NoTranspositionOperandB { +public: + using TiledMma = TiledMma_; + using SmemLayoutB = SmemLayoutB_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using ElementB = ElementB_; + + constexpr CUTLASS_HOST_DEVICE + NoTranspositionOperandB( + int, + int, + TiledMma, + SmemLayoutB, + SmemLayoutAtomB, + ElementB) { } + + template < + class TensorSmemB, + class TensorTransposedSmemB> + CUTLASS_DEVICE void operator()( + TensorSmemB const&, + TensorTransposedSmemB const&, + int, int) { } + + CUTLASS_DEVICE void synchronize(int) { } + + CUTLASS_DEVICE void synchronize() { } + + template < + class TensorSmemB, + class TensorTransposedSmemB> + CUTLASS_DEVICE void transpose( + TensorSmemB const&, + TensorTransposedSmemB const&, + int) { } +}; + +template< + class TiledMma_, + class SmemLayoutB_, + class SmemLayoutAtomB_, + class ElementB_> +class UniversalTranspositionOperandB { +public: + using TiledMma = TiledMma_; + using SmemLayoutB = SmemLayoutB_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using ElementB = ElementB_; + + constexpr CUTLASS_HOST_DEVICE + UniversalTranspositionOperandB( + int warp_idx_, + int warp_group_thread_idx_, + TiledMma, + SmemLayoutB, + SmemLayoutAtomB, + ElementB) + : warp_idx(warp_idx_) + , warp_group_thread_idx(warp_group_thread_idx_) { } + + template < + class TensorSmemB, + class TensorTransposedSmemB> + CUTLASS_DEVICE void operator()( + TensorSmemB const& sB, + TensorTransposedSmemB const& gmma_sB, + int read_stage, int current_step) { + if (current_step > 0) { + return; + } + + constexpr int NumMathWarpGroup = CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup; + static_assert(NumMathWarpGroup == 1 || + (!detail::use_universal_transposition() && NumMathWarpGroup == 2), + "Wrong math warp group number for TransposeB"); + constexpr int WarpgroupTileSize = size<1>(SmemLayoutB{}); // A warp group tile would process entire Smem K. + + constexpr int BytesPerSmemSwizzleUnit = 16; + constexpr int WarpThreadShapeN = BytesPerSmemSwizzleUnit / sizeof(ElementB); + ////////////////////////////////////////////////////////////////////////////////////////////////////////////// + /// Universal transposition, need warp_group sync between load and store. + /// The number of reg used depends on the input elementB. + ////////////////////////////////////////////////////////////////////////////////////////////////////////////// + /* + In one copy step, a warp group would load WarpgroupTileSize * WarpgroupTileSize tile then store to transposed location. + In warp_group_tile, each warp holds Four WarpTileSize x WarpTileSize elements: + K + ------------ + | W0 W1 W2 W3 --- + | W0 W1 W2 W3 | + | W0 W1 W2 W3 | --> Copy Step 0 + | W0 W1 W2 W3 --- + .... + | W0 W1 W2 W3 --- + | W0 W1 W2 W3 | + | W0 W1 W2 W3 | --> Copy Step n + | W0 W1 W2 W3 --- + */ + static_assert((NumThreadsPerWarpGroup % WarpThreadShapeN == 0), "Unsupported warp thread layout."); + constexpr auto WarpgroupThreadLayout = make_layout(make_shape(Int{}, Int{})); + + // Get copy tile and partition to each thread + auto sB_tiled_copy = make_tiled_copy( + Copy_Atom{}, + WarpgroupThreadLayout, // thr_layout + Layout<_1>{} // val_layout + ); + static_assert(size(sB_tiled_copy) == size(TiledMma{}), "Wrong thread number in TiledCopy."); + + auto sB_thr_copy = sB_tiled_copy.get_thread_slice(warp_group_thread_idx); + Tensor tCsB = sB_thr_copy.partition_S( sB(_,_,read_stage)); // (CPY, CPY_N, CPY_K) + Tensor tCsB_transposed = sB_thr_copy.partition_D(gmma_sB(_,_,read_stage)); // (CPY, CPY_N, CPY_K) + + // Divide partitioned tile to limit register usage + constexpr int CopySteps = size<0>(SmemLayoutB{}) / WarpgroupTileSize; + constexpr auto CopyTileShape = make_shape(size<0>(tCsB), Int< size<1>(tCsB) / CopySteps >{}, size<2>(tCsB)); + static_assert(size<1>(tCsB) % CopySteps == 0, "CopySteps must evenly divide rank 1 size of partitioned SMEM."); + + Tensor tCsB_copy_tile = zipped_divide(tCsB, CopyTileShape); + Tensor tCsB_copy_tile_transposed = zipped_divide(tCsB_transposed, CopyTileShape); + auto transpose_fragment = make_fragment_like(tCsB_copy_tile(_,_0{})); + + CUTLASS_PRAGMA_NO_UNROLL + for (int step = 0; step < CopySteps; ++step) { + copy(sB_tiled_copy, tCsB_copy_tile(_,step), transpose_fragment); + + // Make sure all elements are read before being overwritten + __syncthreads(); + + copy(sB_tiled_copy, transpose_fragment, tCsB_copy_tile_transposed(_,step)); + } + } + + CUTLASS_DEVICE void synchronize(int step) { + if (step == 0) { + // SMEM fence to make sure B is transposed before math + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::TransposeBarrier); + } + } + + CUTLASS_DEVICE void synchronize() { + // SMEM fence to make sure B is transposed before math + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::TransposeBarrier); + } + + template < + class TensorSmemB, + class TensorTransposedSmemB> + CUTLASS_DEVICE void transpose( + TensorSmemB const& sB, + TensorTransposedSmemB const& gmma_sB, + int read_stage) { + + this->operator()(sB, gmma_sB, read_stage, 0); + synchronize(); + + } + +private: + const int warp_idx; + const int warp_group_thread_idx; +}; + +template< + class TiledMma_, + class SmemLayoutB_, + class SmemLayoutAtomB_, + class ElementB_> +class AsyncTranspositionOperandB { +public: + + using TiledMma = TiledMma_; + using SmemLayoutB = SmemLayoutB_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using ElementB = ElementB_; + + static constexpr int Steps = 2; + static constexpr int NumMathWarpGroup = CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup; + static constexpr int StepsPerWarpGroup = Steps / NumMathWarpGroup; + static_assert(NumMathWarpGroup <= 2, + "Wrong math warp group number for TransposeB"); + static constexpr int WarpgroupTileSize = size<1>(SmemLayoutB{}); // A warp group tile would process entire Smem K. + static constexpr int NumWarpsPerWarpGroup = NumThreadsPerWarpGroup / NumThreadsPerWarp; + + static constexpr int BytesPerSmemSwizzleUnit = 16; + static constexpr int WarpThreadShapeN = BytesPerSmemSwizzleUnit / sizeof(ElementB); + static constexpr int WarpThreadShapeK = NumThreadsPerWarp / WarpThreadShapeN; + static constexpr int NumWarpTilePerWarpgroupTile = NumWarpsPerWarpGroup * (Steps == 8 ? 2 : 1); + + static constexpr int WarpTileSize = WarpgroupTileSize / NumWarpTilePerWarpgroupTile; + static_assert(WarpTileSize >= WarpThreadShapeN && WarpTileSize >= WarpThreadShapeK, "Invalid warp thread shape." ); + static constexpr int TilesPerWarp = 2; // Each Warp would process 2 warp_tiles in one step. + static constexpr int64_t WarpTileNCoordLUT = 06723763275316420; + static constexpr int64_t WarpTileKCoordLUT = 05410541064206420; + static constexpr int NumStepsEncoded = 4; // Only encoding first 4 steps into LUT. + static constexpr int MaskPerStep = 07; // Each step is encoded into 3bits, + static constexpr int NumBitsPerStep = 3; + static constexpr int MaskPerWarp = 07777; // Each warp has 4 steps(12 bits) + static constexpr int NumBitsPerWarp = 12; + // Number of warp_group_tiles + static_assert(size<0>(SmemLayoutB{}) % WarpgroupTileSize == 0, + "Copy size must evenly divide SMEM tile."); + static constexpr int WarpgroupTileNum = size<0>(SmemLayoutB{}) / WarpgroupTileSize; + + static_assert(size<2>(typename TiledMma::AtomShape_MNK{}) <= WarpThreadShapeK, + "Need to be able to transpose first k-block in the first step"); + + constexpr CUTLASS_HOST_DEVICE + AsyncTranspositionOperandB( + int warp_idx_, + int warp_group_thread_idx_, + TiledMma, + SmemLayoutB, + SmemLayoutAtomB, + ElementB) + : warp_idx(warp_idx_) + , warp_group_thread_idx(warp_group_thread_idx_) + , warp_idx_in_warp_group(warp_idx_ % NumWarpsPerWarpGroup) + , current_warp_tile_n_coord_LUT((WarpTileNCoordLUT >> ((warp_idx_ + % NumWarpsPerWarpGroup) * NumBitsPerWarp)) & MaskPerWarp) + , current_warp_tile_k_coord_LUT((WarpTileKCoordLUT >> ((warp_idx_ + % NumWarpsPerWarpGroup) * NumBitsPerWarp)) & MaskPerWarp) { } + + template < + class TensorSmemB, + class TensorTransposedSmemB> + CUTLASS_DEVICE void operator()( + TensorSmemB const& sB, + TensorTransposedSmemB const& gmma_sB, + int read_stage, int current_step) + { + if (current_step >= StepsPerWarpGroup) { + return; + } + + static constexpr auto WarpThreadLayout = make_layout(make_shape(Int{}, Int{})); + ////////////////////////////////////////////////////////////////////////////////////////////////////////////// + /// A warp group uses 2 steps to transpose the whole WarpgroupTileSize x WarpgroupTileSize. + /// In each step, one warp would hold two warp_tiles. + /// Step 0: Step 1: + /// W0 W1 W2 W3 -- -- -- -- + /// W1 W0 -- -- -- -- W3 W2 + /// W2 -- -- -- -- W3 W0 W1 + /// W3 -- -- -- -- W2 W1 W0 + /// + ///////////////////////////////////////////////////////////////////////////////////////////////////////////// + /// + /// Fully static coord LUT to avoid extra register use. + /// [warp_id][step][warp_tile][n / k] + /// Step 0 Step 1 Step 2 Step 3 Step 4 Step 5 Step 6 Step 7 + /// {{{0,0}, {1,1}}, {{2,2}, {3,3}}, {{4,4}, {5,5}}, {{6,6}, {7,7}}, {{4,0}, {0,4}}, {{4,1}, {1,4}}, {{4,2}, {2,4}}, {{4,3}, {3,4}}}, // W0 + /// {{{1,0}, {0,1}}, {{3,2}, {2,3}}, {{5,4}, {4,5}}, {{7,6}, {6,7}}, {{5,0}, {0,5}}, {{5,1}, {1,5}}, {{5,2}, {2,5}}, {{5,3}, {3,5}}}, // W1 + /// {{{2,0}, {0,2}}, {{3,1}, {1,3}}, {{6,4}, {4,6}}, {{7,5}, {5,7}}, {{6,0}, {0,6}}, {{6,1}, {1,6}}, {{6,2}, {2,6}}, {{6,3}, {3,6}}}, // W2 + /// {{{3,0}, {0,3}}, {{2,1}, {1,2}}, {{7,4}, {4,7}}, {{6,5}, {5,6}}, {{7,0}, {0,7}}, {{7,1}, {1,7}}, {{7,2}, {2,7}}, {{7,3}, {3,7}}}, // W3 + /// + /// Encoding the coord of warp tile0 into two int64_t values. + /// Only encoding Step 0 ~ Step 4, since Step 5 ~ Step 7 have a straightforward pattern. + /// Only encoding warp tile0, since the coords of warp tile1 could be easily deduced from warp tile0. + /// The 2-step transposition and the 8-step transposition share the same encoding. + /// + ////////////////////////////////////////////////////////////////////////////////////////////////////////////// + + // Divide entire SMEM to multiple warp_tiles + constexpr auto WarpTileShape = make_shape(Int(), Int()); + Tensor s_tile = zipped_divide( sB(_,_,read_stage), WarpTileShape); + Tensor s_tile_transposed = zipped_divide(gmma_sB(_,_,read_stage), WarpTileShape); + + // Get copy tile + auto sB_tiled_copy = make_tiled_copy( + Copy_Atom{}, + WarpThreadLayout, // thr_layout + Layout<_1>{} // val_layout + ); + + static_assert(size(sB_tiled_copy) * NumWarpsPerWarpGroup == size(TiledMma{}) / NumMathWarpGroup, "Wrong thread number in TiledCopy."); + auto sB_thr_copy = sB_tiled_copy.get_thread_slice(warp_group_thread_idx % NumThreadsPerWarp); // slice based on lane_idx + + // Construct fragments for transposition + Tensor tmp_tCsB = sB_thr_copy.partition_S(flatten(s_tile(_, make_coord(_0{}, _0{})))); + decltype(make_fragment_like(tmp_tCsB)) transpose_fragments[TilesPerWarp] = { + make_fragment_like(tmp_tCsB), + make_fragment_like(tmp_tCsB) + }; + + [[maybe_unused]] int step = current_step * NumMathWarpGroup; + if constexpr (NumMathWarpGroup == 2) { + // For 2 math warpgroup, warp idx4~7 is 1st warp group and 8~9 is 2nd, so decide if 2nd warpgroup need warp idx divide 8. + step += warp_idx / (NumWarpsPerWarpGroup * 2); + } + + int tmp_warp_tile_n_coord_LUT = current_warp_tile_n_coord_LUT >> (NumBitsPerStep * current_step); + int tmp_warp_tile_k_coord_LUT = current_warp_tile_k_coord_LUT >> (NumBitsPerStep * current_step); + + if constexpr (NumMathWarpGroup == 2) { + tmp_warp_tile_n_coord_LUT >>= NumBitsPerStep * (warp_idx / (NumWarpsPerWarpGroup * 2)); + tmp_warp_tile_k_coord_LUT >>= NumBitsPerStep * (warp_idx / (NumWarpsPerWarpGroup * 2)); + } + + // decoding the warp tile coord. + int warp_tile0_n, warp_tile0_k; + if constexpr (StepsPerWarpGroup <= NumStepsEncoded) { + warp_tile0_n = tmp_warp_tile_n_coord_LUT & MaskPerStep; + warp_tile0_k = tmp_warp_tile_k_coord_LUT & MaskPerStep; + } else { + warp_tile0_n = step < NumStepsEncoded ? (tmp_warp_tile_n_coord_LUT & MaskPerStep) : 4 + warp_idx_in_warp_group; + warp_tile0_k = step < NumStepsEncoded ? (tmp_warp_tile_k_coord_LUT & MaskPerStep) : step - 4; + } + + int warp_tile1_n = warp_tile0_n == warp_tile0_k ? warp_tile0_n + 1 : warp_tile0_k; + int warp_tile1_k = warp_tile0_n == warp_tile0_k ? warp_tile0_k + 1 : warp_tile0_n; + + CUTLASS_PRAGMA_UNROLL + for (int warp_group_tile = 0; warp_group_tile < WarpgroupTileNum; ++warp_group_tile) { + + static_assert(TilesPerWarp == 2); + + // [warp_tile][n/k] + const int warp_tile_coord[TilesPerWarp][2] = { + // n k + {warp_group_tile * NumWarpTilePerWarpgroupTile + warp_tile0_n, warp_tile0_k}, // warp_tile 0 + {warp_group_tile * NumWarpTilePerWarpgroupTile + warp_tile1_n, warp_tile1_k} // warp_tile 1 + }; + + CUTLASS_PRAGMA_UNROLL + for (int warp_tile = 0; warp_tile < TilesPerWarp; ++warp_tile) { + Tensor tCsB = sB_thr_copy.partition_S( + flatten(s_tile(_, make_coord(warp_tile_coord[warp_tile][0], warp_tile_coord[warp_tile][1]))) + ); // (CPY, CPY_N, CPY_K) + + copy(sB_tiled_copy, tCsB, transpose_fragments[warp_tile]); + } + + // Make sure elements in two 8x8 warp tiles are all consumed + __syncwarp(); + + CUTLASS_PRAGMA_UNROLL + for (int warp_tile = 0; warp_tile < TilesPerWarp; ++warp_tile) { + Tensor tCsB_transposed = sB_thr_copy.partition_D( + flatten(s_tile_transposed(_, make_coord(warp_tile_coord[warp_tile][0], warp_tile_coord[warp_tile][1]))) + ); // (CPY, CPY_N, CPY_K) + copy(sB_tiled_copy, transpose_fragments[warp_tile], tCsB_transposed); + } + + } // loop warp_group_tile + } + + CUTLASS_DEVICE void synchronize(int step) { + if (step < StepsPerWarpGroup) { + // SMEM fence to make sure B is transposed before math + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::TransposeBarrier); + } + } + + CUTLASS_DEVICE void synchronize() { + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::TransposeBarrier); + } + + template < + class TensorSmemB, + class TensorTransposedSmemB> + CUTLASS_DEVICE void transpose( + TensorSmemB const& sB, + TensorTransposedSmemB const& gmma_sB, + int read_stage) { + + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < StepsPerWarpGroup; ++i) { + this->operator()(sB, gmma_sB, read_stage, i); + } + synchronize(); + + } +private: + const int warp_idx; + const int warp_group_thread_idx; + const int warp_idx_in_warp_group; + const int current_warp_tile_n_coord_LUT; + const int current_warp_tile_k_coord_LUT; +}; + +template< + class TiledMma_, + class SmemLayoutB_, + class SmemLayoutAtomB_, + class ElementB_> +class AsyncTranspositionOperandB_1BElementB { +public: + + static_assert(sizeof(ElementB_) == 1); + + using TiledMma = TiledMma_; + using SmemLayoutB = SmemLayoutB_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using ElementB = ElementB_; + + static constexpr int Steps = 8; + static constexpr int NumMathWarpGroup = CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup; + static constexpr int StepsPerWarpGroup = Steps / NumMathWarpGroup; + static_assert(NumMathWarpGroup <= 2, + "Wrong math warp group number for TransposeB"); + static constexpr int WarpgroupTileSize = size<1>(SmemLayoutB{}); // A warp group tile would process entire Smem K. + static constexpr int NumWarpsPerWarpGroup = NumThreadsPerWarpGroup / NumThreadsPerWarp; + + static constexpr int BytesPerSmemSwizzleUnit = 16; + static constexpr int WarpThreadShapeN = BytesPerSmemSwizzleUnit / sizeof(ElementB); + static constexpr int WarpThreadShapeK = NumThreadsPerWarp / WarpThreadShapeN; + static constexpr int NumWarpTilePerWarpgroupTile = NumWarpsPerWarpGroup * (Steps == 8 ? 2 : 1); + + static constexpr int WarpTileSize = WarpgroupTileSize / NumWarpTilePerWarpgroupTile; + static_assert(WarpTileSize >= WarpThreadShapeN && WarpTileSize >= WarpThreadShapeK, "Invalid warp thread shape." ); + static constexpr int TilesPerWarp = 2; // Each Warp would process 2 warp_tiles in one step. + static constexpr int64_t WarpTileNCoordLUT = 06723763275316420; + static constexpr int64_t WarpTileKCoordLUT = 05410541064206420; + static constexpr int NumStepsEncoded = 4; // Only encoding first 4 steps into LUT. + static constexpr int MaskPerStep = 07; // Each step is encoded into 3bits, + static constexpr int NumBitsPerStep = 3; + static constexpr int MaskPerWarp = 07777; // Each warp has 4 steps(12 bits) + static constexpr int NumBitsPerWarp = 12; + // Number of warp_group_tiles + static_assert(size<0>(SmemLayoutB{}) % WarpgroupTileSize == 0, + "Copy size must evenly divide SMEM tile."); + static constexpr int WarpgroupTileNum = size<0>(SmemLayoutB{}) / WarpgroupTileSize; + + constexpr CUTLASS_HOST_DEVICE + AsyncTranspositionOperandB_1BElementB( + int warp_idx_, + int warp_group_thread_idx_, + TiledMma, + SmemLayoutB, + SmemLayoutAtomB, + ElementB) + : warp_idx(warp_idx_) + , warp_group_thread_idx(warp_group_thread_idx_) + , warp_idx_in_warp_group(warp_idx_ % NumWarpsPerWarpGroup) + , current_warp_tile_n_coord_LUT((WarpTileNCoordLUT >> ((warp_idx_ + % NumWarpsPerWarpGroup) * NumBitsPerWarp)) & MaskPerWarp) + , current_warp_tile_k_coord_LUT((WarpTileKCoordLUT >> ((warp_idx_ + % NumWarpsPerWarpGroup) * NumBitsPerWarp)) & MaskPerWarp) { } + + template < + class TensorSmemB, + class TensorTransposedSmemB> + CUTLASS_DEVICE void operator()( + TensorSmemB const& sB, + TensorTransposedSmemB const& gmma_sB, + int read_stage, int current_step) + { + if (current_step > 0) { + return; + } + + constexpr auto WarpThreadLayout = make_layout(make_shape(Int{}, Int{})); + ////////////////////////////////////////////////////////////////////////////////////////////////////////////// + /// A warp group uses 8 steps to transpose the whole WarpgroupTileSize x WarpgroupTileSize. + /// Divide a warp_group_tile into 8x8 warp_tiles to further reduce the reg usage. + /// Step 0: Step 1: Step 2: Step 3: + /// W0 W1 W2 W3 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- + /// W1 W0 -- -- -- -- -- -- -- -- W3 W2 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- + /// W2 -- -- -- -- -- -- -- -- W3 W0 W1 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- + /// W3 -- -- -- -- -- -- -- -- W2 W1 W0 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- + /// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- W0 W1 W2 W3 -- -- -- -- -- -- -- -- + /// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- W1 W0 -- -- -- -- -- -- -- -- W3 W2 + /// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- W2 -- -- -- -- -- -- -- -- W3 W0 W1 + /// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- W3 -- -- -- -- -- -- -- -- W2 W1 W0 + /// + /// Step 4: Step 5: Step 6: Step 7: + /// -- -- -- -- W0 W1 W2 W3 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- + /// -- -- -- -- -- -- -- -- -- -- -- -- W0 W1 W2 W3 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- + /// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- W0 W1 W2 W3 -- -- -- -- -- -- -- -- + /// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- W0 W1 W2 W3 + /// W0 -- -- -- -- -- -- -- -- W0 -- -- -- -- -- -- -- -- W0 -- -- -- -- -- -- -- -- W0 -- -- -- -- + /// W1 -- -- -- -- -- -- -- -- W1 -- -- -- -- -- -- -- -- W1 -- -- -- -- -- -- -- -- W1 -- -- -- -- + /// W2 -- -- -- -- -- -- -- -- W2 -- -- -- -- -- -- -- -- W2 -- -- -- -- -- -- -- -- W2 -- -- -- -- + /// W3 -- -- -- -- -- -- -- -- W3 -- -- -- -- -- -- -- -- W3 -- -- -- -- -- -- -- -- W3 -- -- -- -- + /// + ///////////////////////////////////////////////////////////////////////////////////////////////////////////// + /// + /// Fully static coord LUT to avoid extra register use. + /// [warp_id][step][warp_tile][n / k] + /// Step 0 Step 1 Step 2 Step 3 Step 4 Step 5 Step 6 Step 7 + /// {{{0,0}, {1,1}}, {{2,2}, {3,3}}, {{4,4}, {5,5}}, {{6,6}, {7,7}}, {{4,0}, {0,4}}, {{4,1}, {1,4}}, {{4,2}, {2,4}}, {{4,3}, {3,4}}}, // W0 + /// {{{1,0}, {0,1}}, {{3,2}, {2,3}}, {{5,4}, {4,5}}, {{7,6}, {6,7}}, {{5,0}, {0,5}}, {{5,1}, {1,5}}, {{5,2}, {2,5}}, {{5,3}, {3,5}}}, // W1 + /// {{{2,0}, {0,2}}, {{3,1}, {1,3}}, {{6,4}, {4,6}}, {{7,5}, {5,7}}, {{6,0}, {0,6}}, {{6,1}, {1,6}}, {{6,2}, {2,6}}, {{6,3}, {3,6}}}, // W2 + /// {{{3,0}, {0,3}}, {{2,1}, {1,2}}, {{7,4}, {4,7}}, {{6,5}, {5,6}}, {{7,0}, {0,7}}, {{7,1}, {1,7}}, {{7,2}, {2,7}}, {{7,3}, {3,7}}}, // W3 + /// + /// Encoding the coord of warp tile0 into two int64_t values. + /// Only encoding Step 0 ~ Step 4, since Step 5 ~ Step 7 have a straightforward pattern. + /// Only encoding warp tile0, since the coords of warp tile1 could be easily deduced from warp tile0. + /// The 2-step transposition and the 8-step transposition share the same encoding. + /// + ////////////////////////////////////////////////////////////////////////////////////////////////////////////// + + // Divide entire SMEM to multiple warp_tiles + constexpr auto WarpTileShape = make_shape(Int(), Int()); + Tensor s_tile = zipped_divide( sB(_,_,read_stage), WarpTileShape); + Tensor s_tile_transposed = zipped_divide(gmma_sB(_,_,read_stage), WarpTileShape); + + // Get copy tile + auto sB_tiled_copy = make_tiled_copy( + Copy_Atom{}, + WarpThreadLayout, // thr_layout + Layout<_1>{} // val_layout + ); + static_assert(size(sB_tiled_copy) * NumWarpsPerWarpGroup == size(TiledMma{}) / NumMathWarpGroup, "Wrong thread number in TiledCopy."); + auto sB_thr_copy = sB_tiled_copy.get_thread_slice(warp_group_thread_idx % NumThreadsPerWarp); // slice based on lane_idx + + // Construct fragments for transposition + Tensor tmp_tCsB = sB_thr_copy.partition_S(flatten(s_tile(_, make_coord(_0{}, _0{})))); + decltype(make_fragment_like(tmp_tCsB)) transpose_fragments[TilesPerWarp] = { + make_fragment_like(tmp_tCsB), + make_fragment_like(tmp_tCsB) + }; + + CUTLASS_PRAGMA_NO_UNROLL + for (int warp_group_tile = 0; warp_group_tile < WarpgroupTileNum; ++warp_group_tile) { + int tmp_warp_tile_n_coord_LUT = current_warp_tile_n_coord_LUT; + int tmp_warp_tile_k_coord_LUT = current_warp_tile_k_coord_LUT; + constexpr int StepsPerWarpGroup = Steps / NumMathWarpGroup; + + if constexpr (NumMathWarpGroup == 2) { + tmp_warp_tile_n_coord_LUT >>= NumBitsPerStep * (warp_idx / (NumWarpsPerWarpGroup * 2)); + tmp_warp_tile_k_coord_LUT >>= NumBitsPerStep * (warp_idx / (NumWarpsPerWarpGroup * 2)); + } + + CUTLASS_PRAGMA_NO_UNROLL + for (int step_per_warp_group = 0; step_per_warp_group < StepsPerWarpGroup; ++step_per_warp_group) { + // For 2 math warpgroup, warp idx4~7 is 1st warp group and 8~9 is 2nd, so decide if 2nd warpgroup need warp idx divide 8. + int step = step_per_warp_group * NumMathWarpGroup + warp_idx / (NumWarpsPerWarpGroup * 2); + // decoding the warp tile coord. + int warp_tile0_n = step < NumStepsEncoded ? (tmp_warp_tile_n_coord_LUT & MaskPerStep) : 4 + warp_idx_in_warp_group; + int warp_tile0_k = step < NumStepsEncoded ? (tmp_warp_tile_k_coord_LUT & MaskPerStep) : step - 4; + int warp_tile1_n = warp_tile0_n == warp_tile0_k ? warp_tile0_n + 1 : warp_tile0_k; + int warp_tile1_k = warp_tile0_n == warp_tile0_k ? warp_tile0_k + 1 : warp_tile0_n; + + tmp_warp_tile_n_coord_LUT >>= NumBitsPerStep; + tmp_warp_tile_k_coord_LUT >>= NumBitsPerStep; + + static_assert(TilesPerWarp == 2); + + // [warp_tile][n/k] + const int warp_tile_coord[TilesPerWarp][2] = { + // n k + {warp_group_tile * NumWarpTilePerWarpgroupTile + warp_tile0_n, warp_tile0_k}, // warp_tile 0 + {warp_group_tile * NumWarpTilePerWarpgroupTile + warp_tile1_n, warp_tile1_k} // warp_tile 1 + }; + + CUTLASS_PRAGMA_UNROLL + for (int warp_tile = 0; warp_tile < TilesPerWarp; ++warp_tile) { + Tensor tCsB = sB_thr_copy.partition_S( + flatten(s_tile(_, make_coord(warp_tile_coord[warp_tile][0], warp_tile_coord[warp_tile][1]))) + ); // (CPY, CPY_N, CPY_K) + + copy(sB_tiled_copy, tCsB, transpose_fragments[warp_tile]); + } + + // Make sure elements in two 8x8 warp tiles are all consumed + __syncwarp(); + + CUTLASS_PRAGMA_UNROLL + for (int warp_tile = 0; warp_tile < TilesPerWarp; ++warp_tile) { + Tensor tCsB_transposed = sB_thr_copy.partition_D( + flatten(s_tile_transposed(_, make_coord(warp_tile_coord[warp_tile][0], warp_tile_coord[warp_tile][1]))) + ); // (CPY, CPY_N, CPY_K) + copy(sB_tiled_copy, transpose_fragments[warp_tile], tCsB_transposed); + } + } // lock step + } // loop warp_group_tile + } + + CUTLASS_DEVICE void synchronize(int step) { + if (step == 0) { + // SMEM fence to make sure B is transposed before math + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::TransposeBarrier); + } + } + + CUTLASS_DEVICE void synchronize() { + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::TransposeBarrier); + } + + template < + class TensorSmemB, + class TensorTransposedSmemB> + CUTLASS_DEVICE void transpose( + TensorSmemB const& sB, + TensorTransposedSmemB const& gmma_sB, + int read_stage) { + this->operator()(sB, gmma_sB, read_stage, 0); + synchronize(); + } + +private: + const int warp_idx; + const int warp_group_thread_idx; + const int warp_idx_in_warp_group; + const int current_warp_tile_n_coord_LUT; + const int current_warp_tile_k_coord_LUT; +}; + + +template< + class TiledMma, + class SmemLayoutB, + class SmemLayoutAtomB, + class ElementB, + bool TransposeB +> +constexpr CUTLASS_HOST_DEVICE +auto +make_transpose_operand_b( + int warp_idx, + int warp_group_thread_idx, + TiledMma, + SmemLayoutB, + SmemLayoutAtomB, + ElementB, + cute::bool_constant) +{ + if constexpr (!TransposeB) { + return NoTranspositionOperandB( + warp_idx, warp_group_thread_idx, TiledMma{}, + SmemLayoutB{}, SmemLayoutAtomB{}, ElementB{}); + } + else if constexpr (use_universal_transposition()) { + return UniversalTranspositionOperandB( + warp_idx, warp_group_thread_idx, TiledMma{}, + SmemLayoutB{}, SmemLayoutAtomB{}, ElementB{}); + } + else if constexpr (sizeof(ElementB) == 1) { + return AsyncTranspositionOperandB_1BElementB( + warp_idx, warp_group_thread_idx, TiledMma{}, + SmemLayoutB{}, SmemLayoutAtomB{}, ElementB{}); + } + else { + return AsyncTranspositionOperandB( + warp_idx, warp_group_thread_idx, TiledMma{}, + SmemLayoutB{}, SmemLayoutAtomB{}, ElementB{}); + } +} + +}; // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace transform +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/device/transform_universal_adapter.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/device/transform_universal_adapter.hpp new file mode 100644 index 0000000000000000000000000000000000000000..265d2fe4367180b0c5c76f22df7d00f01dfb170e --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/device/transform_universal_adapter.hpp @@ -0,0 +1,303 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Transform Kernel Universal adapter +*/ + +#pragma once + +// common +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/layout.hpp" +#include "cutlass/detail/mma.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +#include "cutlass/kernel_launch.h" +#if !defined(__CUDACC_RTC__) +#include "cutlass/cluster_launch.hpp" +#include "cutlass/trace.h" +#endif // !defined(__CUDACC_RTC__) + + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::transform::device { + +//////////////////////////////////////////////////////////////////////////////// + +template +class TransformUniversalAdapter +{ +public: + using TransformKernel = GetUnderlyingKernel_t; + using Arguments = typename TransformKernel::Arguments; + using Params = typename TransformKernel::Params; + static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER; + + +private: + + /// Kernel API parameters object + Params params_; + +public: + + /// Access the Params structure + Params const& params() const { + return params_; + } + + /// Determines whether the GEMM can execute the given problem. + static Status + can_implement(Arguments const& args) { + return TransformKernel::can_implement(args); + } + + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_bytes = 0; + workspace_bytes += TransformKernel::get_workspace_size(args); + + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + + return workspace_bytes; + } + + /// Computes the grid shape + static dim3 + get_grid_shape(Arguments const& args, void* workspace = nullptr) { + auto tmp_params = TransformKernel::to_underlying_arguments(args, workspace); + return TransformKernel::get_grid_shape(tmp_params); + } + + /// Computes the grid shape + static dim3 + get_grid_shape(Params const& params) { + return TransformKernel::get_grid_shape(params); + } + + + /// Initializes GEMM state from arguments. + Status + initialize( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + + CUTLASS_TRACE_HOST("TransformUniversalAdapter::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null") + << ", EnableCudaHostAdapter: " << (kEnableCudaHostAdapter ? "True" : "false")); + + // Initialize the workspace + Status status = TransformKernel::initialize_workspace(args, workspace, stream, cuda_adapter); + if (status != Status::kSuccess) { + return status; + } + // Initialize the Params structure + params_ = TransformKernel::to_underlying_arguments(args, workspace); + // Don't set the function attributes - require the CudaHostAdapter to set it. + if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + return Status::kSuccess; + } + else { + // + // Account for dynamic smem capacity if needed + // + int smem_size = TransformKernel::SharedStorageSize; + + CUTLASS_ASSERT(cuda_adapter == nullptr); + + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + cudaError_t result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + } + return Status::kSuccess; + } + + static Status + run(Params& params, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr, + int32_t kernel_index = 0, + bool launch_with_pdl = false) { + CUTLASS_TRACE_HOST("TransformUniversalAdapter::run()"); + dim3 const block = TransformKernel::get_block_shape(); + dim3 const grid = get_grid_shape(params); + + // configure smem size and carveout + int smem_size = TransformKernel::SharedStorageSize; + + Status launch_result{ Status::kSuccess }; + // Use extended launch API only for mainloops that use it + if constexpr (TransformKernel::ArchTag::kMinComputeCapability >= 90) { + // Currently only support 1x1x1 for transform kernel. + dim3 const cluster = {1,1,1}; + void* kernel_params[] = {¶ms}; + + if constexpr (kEnableCudaHostAdapter) { + // + // Use the cuda host adapter + // + CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + + if (launch_with_pdl) { + CUTLASS_TRACE_HOST( + "TransformUniversalAdapter::run() does not support launching with PDL and a custom cuda adapter."); + return Status::kErrorInternal; + } + launch_result = cuda_adapter->launch(grid, + cluster, + block, + smem_size, + stream, + kernel_params, + kernel_index); + CUTLASS_TRACE_HOST("Kernel Launch Result" << cutlassGetStatusString(launch_result)); + } + else { + return Status::kErrorInternal; + } + } + else { + CUTLASS_ASSERT(cuda_adapter == nullptr); + void const* kernel = (void const*) device_kernel; + if constexpr (TransformKernel::ArchTag::kMinComputeCapability == 90) { + launch_result = ClusterLauncher::launch( + grid, cluster, block, smem_size, stream, kernel, kernel_params, launch_with_pdl); + } + } + } + else { + launch_result = Status::kSuccess; + cutlass::arch::synclog_setup(); + + if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + void* kernel_params[] = {¶ms}; + + launch_result = cuda_adapter->launch( + grid, block, smem_size, stream, kernel_params, 0 + ); + + } + else { + return Status::kErrorInternal; + } + } + else { + CUTLASS_ASSERT(cuda_adapter == nullptr); + cutlass::kernel_launch(grid, block, smem_size, stream, params, launch_with_pdl); + } + } + + cudaError_t result = cudaGetLastError(); + if (cudaSuccess == result && Status::kSuccess == launch_result) { + return Status::kSuccess; + } + else if (cudaSuccess != result) { + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << cudaGetErrorString(result)); + } + else if (Status::kSuccess != launch_result) { + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << cutlassGetStatusString(launch_result)); + } + return Status::kErrorInternal; + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + run( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr, + int32_t kernel_index = 0, + bool launch_with_pdl = false + ) { + Status status = initialize(args, workspace, stream, cuda_adapter); + + if (Status::kSuccess == status) { + status = run(params_, stream, cuda_adapter, kernel_index, launch_with_pdl); + } + return status; + } + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + operator()( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr, + bool launch_with_pdl = false) { + return run(args, workspace, stream, cuda_adapter, 0 /*kernel_index*/, launch_with_pdl); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + run( + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr, + bool launch_with_pdl = false) { + return run(params_, stream, cuda_adapter, 0 /*kernel_index*/, launch_with_pdl); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + operator()(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, bool launch_with_pdl = false) { + return run(params_, stream, cuda_adapter, 0 /*kernel_index*/, launch_with_pdl); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::transform::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/filter_format_transformer.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/filter_format_transformer.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9c9d7589a309ebe6276bb564ac76a9e036bdd50a --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/filter_format_transformer.hpp @@ -0,0 +1,223 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/* \file + \brief Convolution filter format transformation kernel. +*/ + +#pragma once + +#include +#include + +#include "cutlass/coord.h" +#include "cutlass/arch/arch.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/cuda_host_adapter.hpp" + +#include "cute/int_tuple.hpp" +#include "cute/tensor.hpp" +#include "cute/config.hpp" + +namespace cutlass::transform::kernel { + +using namespace cute; + +enum class FilterFormat { + CKTRS, + CTRSK, + KTRSC +}; + +template < + FilterFormat SrcFormat, + FilterFormat DstFormat, + int NumDimensions, + class Element_, + int AlignmentBytes = 16 +> +struct ConvFilterFormatTransformer { + + using Element = Element_; + static_assert(SrcFormat == FilterFormat::CKTRS, "Currently only source format of CKTRS is supported"); + static_assert(DstFormat == FilterFormat::CTRSK || DstFormat == FilterFormat::KTRSC, "Currently only destination format of CTRSK/KTRSC is supported"); + static_assert(AlignmentBytes > 0 && AlignmentBytes % static_cast(sizeof(Element)) == 0, "Invalid alignment setting"); + + // In ktrsc order. + using FilterExtent = array; + + // Default cta tile shape: 32x32 + static constexpr auto CTATileShape = make_shape(Int<4 * AlignmentBytes / static_cast(sizeof(Element))>{}, Int<32>{}); + // Default thread layout: (4, 32) + static constexpr auto ThreadLayout = make_layout(make_shape(Int<4>{}, Int<32>{})); + + static constexpr uint32_t MaxThreadsPerBlock = 128; + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + using ArchTag = arch::Sm90; + + // Default ctor + CUTLASS_HOST_DEVICE + ConvFilterFormatTransformer() {} + + struct Arguments { + const void *src_ptr; + void *dst_ptr; + FilterExtent filter_extent; + }; + + struct Params { + using TensorSrc = decltype(make_tensor(make_gmem_ptr(recast_ptr(nullptr)), make_layout(take<0,NumDimensions>(FilterExtent{})))); + using TensorDst = decltype(make_tensor(make_gmem_ptr(recast_ptr(nullptr)), make_layout(make_shape(int32_t(0), int32_t(0))))); + + TensorSrc src; + TensorDst dst; + }; + + struct SharedStorage { + /* empty, no smem needed */ + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + static Status + can_implement(Arguments const& args) { + bool implementable = true; + // alignment rule + { + int contiguous_dim = DstFormat == FilterFormat::CTRSK ? args.filter_extent[0] : args.filter_extent[NumDimensions - 1]; + int align_element = AlignmentBytes / static_cast(sizeof(Element)); + + implementable &= (contiguous_dim % align_element == 0); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Alignment setting is invalid.\n"); + return Status::kInvalid; + } + } + + return Status::kSuccess; + } + + static size_t + get_workspace_size(Arguments const& args) { + return 0; + } + + static dim3 + get_block_shape() { + return dim3(size(shape(ThreadLayout)), 1, 1); + } + + static dim3 + get_grid_shape(Params const& params) { + auto dim_m = ceil_div(size<0>(shape(params.dst)), get<0>(CTATileShape)); + auto dim_n = ceil_div(size<1>(shape(params.dst)), get<1>(CTATileShape)); + + return dim3(dim_m, dim_n, 1); + } + + static cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + return Status::kSuccess; + } + + static Params + to_underlying_arguments(Arguments const& args, void* workspace) { + auto k = args.filter_extent[0]; + auto c = args.filter_extent[NumDimensions - 1]; + auto srt = reverse(take<1,NumDimensions - 1>(args.filter_extent)); + + // source shape (s,r,t,k,c) + auto shape_src = flatten(make_shape(srt, k, c)); + auto shape_dst = DstFormat == FilterFormat::CTRSK ? make_shape(k, c * product(srt)) : make_shape(c, k * product(srt)); + + auto src = make_tensor(make_gmem_ptr(recast_ptr(args.src_ptr)), make_layout(shape_src)); + auto dst = make_tensor(make_gmem_ptr(recast_ptr(args.dst_ptr)), make_layout(shape_dst)); + + return Params{src, dst}; + } + + CUTLASS_DEVICE + void operator()(Params const& params, char *smem_buf) { + // Tile the input tensor into blocks + auto block_coord = make_coord(blockIdx.x, blockIdx.y); + auto block_shape = make_shape(Int<4 * AlignmentBytes / static_cast(sizeof(Element))>{}, Int<32>{}); + // Default thread layout: (4, 32) + auto thread_layout = make_layout(make_shape(Int<4>{}, Int<32>{})); + auto vec_layout = make_layout(make_shape(Int(sizeof(Element))>{}, Int<1>{})); + + Tensor tile_D = local_tile(params.dst, block_shape, block_coord); + + // Construct tiled copy + using AccessType = cutlass::AlignedArray; + using Atom = Copy_Atom, Element>; + + auto tiled_copy = make_tiled_copy(Atom{}, thread_layout, vec_layout); + auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x); + Tensor thr_tile_D = thr_copy.partition_D(tile_D); + + // shape (s, r, t) + auto shape_trs = take<0, NumDimensions - 2>(shape(params.src)); + // strided_c = c for format CTRSK, strided_c = k for format KTRSC + auto strided_c = DstFormat == FilterFormat::CTRSK ? get(shape(params.src)) : get(shape(params.src)); + // shape (s, r, t, c) for format CTRSK and shape (s, r, t, k) for format KTRSC + auto shape_ctrs = append(shape_trs, strided_c); + auto srtc_coord = idx2crd(int(blockIdx.y * get<1>(block_shape) + threadIdx.x / size<0>(thread_layout)), shape_ctrs); + // index of k for format CTRSK and index of c for format KTRSC + auto n_layout = make_layout(make_shape(gridDim.x, size<0>(thread_layout)), make_stride(size<0>(block_shape), size<0>(vec_layout))); + int n_idx = n_layout(make_coord(blockIdx.x, threadIdx.x % size<0>(thread_layout))); + + // Fragment to load from S and store to D + auto frag = make_fragment_like(thr_tile_D); + // Predicate tensor. + Tensor thr_tile_P = make_tensor(shape(thr_tile_D)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(frag); ++i) { + auto srt_coord = take<0, NumDimensions - 2>(srtc_coord); + auto kc_coord = DstFormat == FilterFormat::CTRSK ? + make_coord(n_idx+i, get(srtc_coord)) : + make_coord(get(srtc_coord), n_idx+i); + auto coord = flatten(make_coord(srt_coord, kc_coord)); + thr_tile_P(i) = elem_less(coord, shape(params.src)); + if (thr_tile_P(i)) { + frag(i) = params.src(coord); + } + } + + // Copy from RMEM to GMEM + copy_if(tiled_copy, thr_tile_P, frag, thr_tile_D); + } +}; + +} // namespace cutlass::transform::kernel diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp new file mode 100644 index 0000000000000000000000000000000000000000..577c68c341c5c7a3d26c7209b2c40e309c65abee --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp @@ -0,0 +1,603 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Compress utils specific for SM90 structure sparse kernels +*/ + +#pragma once + +#include "cute/container/bit_field.hpp" // cute::bit_field +#include "cute/numeric/numeric_types.hpp" // cute::sizeof_bits_v, cute::uint_bit_t +#include "cute/tensor.hpp" // cute::Tensor, cute::make_tensor +#include "cute/algorithm/cooperative_copy.hpp" // cute::cooperative_copy +#include "cutlass/arch/arch.h" // cutlass::arch::Sm90 +#include "cutlass/cuda_host_adapter.hpp" // cutlass::CudaHostAdapter +#include "cutlass/cutlass.h" // cutlass::Status +#include "cutlass/gemm/gemm.h" // cutlass::TagToStrideA_t +#include "cutlass/fast_math.h" // cutlass::ceil_div, cutlass::round_up +#include "cutlass/kernel_hardware_info.h" // cutlass::KernelHardwareInfo +#include "cutlass/numeric_size.h" // cutlass::bits_to_bytes +#include "cutlass/numeric_types.h" // cutlass::has_negative_zero_v +#include "cutlass/cuda_host_adapter.hpp" // cutlass::CudaHostAdapter + +namespace cutlass::transform::kernel { + +using namespace cute; + +template< + class ProblemShape_, + class ElementA_, + class LayoutATag_, + class SparseConfig_ +> +class SM90StructuredSparseCompressor { +public: + using SparseConfig = SparseConfig_; + using ProblemShape = ProblemShape_; + + // * EltA + using ElementA = ElementA_; + using ElementAUint = cute::uint_bit_t>; + using ElementAMma = typename SparseConfig::ElementAMma; + using ElementAMmaRaw = typename SparseConfig::ElementAMmaRaw; + using ElementAMmaRawUnit = cute::uint_bit_t>; + using ElementASparsity = typename SparseConfig::ElementASparsity; + using ElementAMmaSparsity = typename SparseConfig::ElementAMmaSparsity; + using ElementAUintCompressed = cute::sparse_elem; + using LayoutATag = LayoutATag_; + using LayoutA = LayoutATag; + using StrideA = cutlass::gemm::TagToStrideA_t; + + // * EltE + using ElementEMma = typename SparseConfig::ElementEMma; + using ElementEMmaRaw = typename SparseConfig::ElementEMmaRaw; + using ElementEMmaSparsity = typename SparseConfig::ElementEMmaSparsity; + // Data Type for storing one chunk's metadata + static constexpr int ElementEBitsPerChunk = typename SparseConfig::ElementEBitsPerChunk{}; + CUTE_STATIC_ASSERT(ElementEBitsPerChunk == 4, "ElementEBitsPerChunk is 4 for SM90"); + using ElementEChunk = cute::uint_bit_t; + CUTE_STATIC_ASSERT(cute::is_same_v, "ElementEChunk is uint4_t for SM90"); + using ElementESparsityPerChunk = Int / ElementEBitsPerChunk)>; + + // AtomE + using TensorEAtom = typename SparseConfig::TensorEAtom; + using TensorEAtomK = typename SparseConfig::TensorEAtomK; + using TensorEAtomM = typename SparseConfig::TensorEAtomM; + + static constexpr int ElemsARawPerElementAMmaRaw = typename SparseConfig::ElemsARawPerElementAMmaRaw{}; + static constexpr int LogicalElemsAPerChunk = typename SparseConfig::LogicalElemsAPerChunk{}; + static constexpr int PhysicalElemsAPerChunk = typename SparseConfig::PhysicalElemsAPerChunk{}; + static constexpr int LogicalElemsAMmaRawPerChunk = cutlass::ceil_div(LogicalElemsAPerChunk, ElemsARawPerElementAMmaRaw); + static constexpr int PhysicalElemsAMmaRawPerChunk = cutlass::ceil_div(PhysicalElemsAPerChunk, ElemsARawPerElementAMmaRaw); + + // * Alignment + static constexpr int TensorEAlignmentM = typename SparseConfig::TensorEAlignmentM{}; + static constexpr int TensorEAlignmentK = typename SparseConfig::TensorEAlignmentK{}; + static constexpr int TensorAAlignmentK = typename SparseConfig::TensorAAlignmentK{}; + static constexpr int TensorAAlignmentM = typename SparseConfig::TensorAAlignmentM{}; + + // Required by `device_kernel` + static constexpr int MaxThreadsPerBlock = TensorEAtomM{}; + static constexpr int MinBlocksPerMultiprocessor = 1; + using ArchTag = arch::Sm90; + + struct SharedStorage { + ElementEMma cEsE[cute::size(TensorEAtom{})]; + ElementAUintCompressed cACsAC[cute::size(TensorEAtom{})]; + ElementAUint cAsA[cute::size(TensorEAtom{})]; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + struct TransformArguments { + void const* ptr_A{nullptr}; + StrideA dA{}; + void* ptr_ACompress{nullptr}; + void* ptr_E{nullptr}; + }; + + using TransformParams = TransformArguments; + + struct Arguments { + ProblemShape problem_shape{}; + TransformArguments transform{}; + KernelHardwareInfo hw_info{}; + }; + + struct Params { + ProblemShape problem_shape{}; + TransformParams transform{}; + KernelHardwareInfo hw_info{}; + void* workspace = nullptr; + }; + +public: + static Params + to_underlying_arguments(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("SM90StructuredSparseCompressor::to_underlying_arguments()"); + return Params{{args.problem_shape}, + {args.transform.ptr_A, args.transform.dA, args.transform.ptr_ACompress, args.transform.ptr_E}, + {args.hw_info}, + workspace}; + } + + static Status + can_implement(Arguments const& args) { + auto [M, N, K, L] = args.problem_shape; + if (K % LogicalElemsAPerChunk != 0) { + CUTLASS_TRACE_HOST("SM90 Sparse Compressor CAN NOT IMPLEMENT: GemmK not multiplier of logical chunk size"); + return Status::kErrorInvalidProblem; + } + CUTLASS_TRACE_HOST("SM90StructuredSparseCompressor::can_implement() (True)"); + return Status::kSuccess; + } + + static size_t + get_workspace_size(Arguments const& args) { + CUTLASS_UNUSED(args); + // Backward compatible with host compressor + CUTLASS_TRACE_HOST("SM90StructuredSparseCompressor::get_workspace_size() (" << SharedStorageSize << ")"); + return SharedStorageSize; + } + + static Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + CUTLASS_UNUSED(args); + CUTLASS_UNUSED(workspace); + CUTLASS_UNUSED(stream); + CUTLASS_UNUSED(cuda_adapter); + CUTLASS_TRACE_HOST("SM90StructuredSparseCompressor::initialize_workspace()"); + return Status::kSuccess; + } + + static dim3 + get_grid_shape(Params const& params) { + constexpr int MaxAlignmentM = cutlass::const_max(TensorEAlignmentM, TensorAAlignmentM); + constexpr int MaxAlignmentK = cutlass::const_max(TensorEAlignmentK, TensorAAlignmentK); + const auto [GemmM, GemmN, GemmK, GemmL] = params.problem_shape; + + const int GemmMAlignedMax = cutlass::round_up(GemmM, MaxAlignmentM); + const int GemmKAlignedMax = cutlass::round_up(GemmK, MaxAlignmentK); + + const int gridDim_X = cutlass::ceil_div(GemmMAlignedMax, TensorEAtomM{}); + const int gridDim_Y = cutlass::ceil_div(GemmKAlignedMax, TensorEAtomK{}); + const int gridDim_Z = GemmL; + + CUTLASS_TRACE_HOST("SM90StructuredSparseCompressor::get_grid_shape() (" + << gridDim_X << ", " + << gridDim_Y << ", " + << gridDim_Z << ")"); + return dim3(gridDim_X, gridDim_Y, gridDim_Z); + } + + static dim3 + get_block_shape() { + CUTLASS_TRACE_HOST("SM90StructuredSparseCompressor::get_block_shape() (" + << MaxThreadsPerBlock << ", " + << 1 << ", " + << 1 << ")"); + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTE_DEVICE + void + operator()(Params params, void* smem_buf = nullptr) { + run(params, smem_buf); + } + + CUTE_DEVICE + static void + run(Params params, void* smem_buf = nullptr) { + structure_sparse_compress(params, smem_buf); + } + +private: + + struct MetadataOneChunk1to2 { + + CUTE_DEVICE + void set_metadata_bits(int elt_log_idx, int elt_phy_idx) { + auto metadata_bits = [&]() -> uint8_t { + CUTLASS_ASSERT(elt_log_idx >= 0 && elt_log_idx < 2); + switch (elt_log_idx) { + case 0: + return 0b0100; + case 1: + return 0b1110; + default: + CUTE_GCC_UNREACHABLE; + } + }; + + storage_ |= (metadata_bits() << (4 * elt_phy_idx)); + } + + + CUTE_DEVICE + ElementEChunk storage() const { + return ElementEChunk{storage_}; + } + + private: + uint8_t storage_ = 0b0000; + }; + + struct MetadataOneChunk2to4{ + + CUTE_DEVICE + void set_metadata_bits(int elt_log_idx, int elt_phy_idx) { + auto metadata_bits = [&]() -> uint8_t { + CUTLASS_ASSERT(elt_log_idx >= 0 && elt_log_idx < 4); + switch (elt_log_idx) { + case 0: + return 0b00; + case 1: + return 0b01; + case 2: + return 0b10; + case 3: + return 0b11; + default: + CUTLASS_ASSERT(false); + CUTE_GCC_UNREACHABLE; + return 0b00; + } + }; + + storage_ |= (metadata_bits() << (2 * elt_phy_idx)); + } + + CUTE_DEVICE + ElementEChunk storage() const { + return ElementEChunk{storage_}; + } + + private: + uint8_t storage_ = 0b0000; + }; + + using MetadataOneChunk = cute::conditional_t; + +private: + + CUTE_DEVICE + static void + structure_sparse_compress(Params params, void* smem_buf) { + // * Input Params + auto [GemmM, GemmN, GemmK, GemmL] = params.problem_shape; + auto [ptr_A, dA, ptr_ACompress, ptr_E] = params.transform; + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + [[maybe_unused]] const int gridDim_X = gridDim.x; + [[maybe_unused]] const int gridDim_Y = gridDim.y; + [[maybe_unused]] const int gridDim_Z = gridDim.z; + [[maybe_unused]] const int blockDim_X = blockDim.x; + + // * Global Tensor Layout + const cute::Layout layout_gA = make_layout(make_shape(GemmM, GemmK, GemmL), dA); + const cute::Layout layout_gAC = SparseConfig::fill_layoutA(params.problem_shape); + const cute::Layout layout_gE = SparseConfig::fill_layoutE(params.problem_shape); + + // * Construct Global Tensor + const cute::Tensor gA = make_tensor(make_gmem_ptr(cute::recast_ptr(ptr_A)), layout_gA); + cute::Tensor gAC_sparse = make_tensor(make_gmem_ptr(cute::recast_ptr(ptr_ACompress)), layout_gAC ); + cute::Tensor gAC = cute::recast(gAC_sparse); + cute::Tensor gE_sparse = make_tensor(make_gmem_ptr(cute::recast_ptr(ptr_E)), layout_gE); + cute::Tensor gE = cute::recast(gE_sparse); + + // * CTA Tensor Layout + using cAsA_layout_row = decltype(make_layout(make_shape(TensorEAtomM{}, TensorEAtomK{}), LayoutRight{})); + using cAsA_layout_col = decltype(make_layout(make_shape(TensorEAtomM{}, TensorEAtomK{}), LayoutLeft{})); + using cAsA_layout = cute::conditional_t, cAsA_layout_row, cAsA_layout_col>; + using cACsAC_layout = decltype(make_layout(make_shape(TensorEAtomM{}, TensorEAtomK{} / ElementASparsity{}), LayoutRight{})); + using cEsE_layout = decltype(make_layout(make_shape(TensorEAtomM{}, TensorEAtomK{} / ElementEMmaSparsity{}), LayoutRight{})); + + CUTE_STATIC_ASSERT(cute::is_static_v, "TensorEAtom needs to be static"); + CUTE_STATIC_ASSERT(cute::is_static_v, "cAsA_layout needs to be static"); + CUTE_STATIC_ASSERT(cute::is_static_v, "cACsAC_layout needs to be static"); + CUTE_STATIC_ASSERT(cute::is_static_v, "cEsE_layout needs to be static"); + + const int blockIdx_X = blockIdx.x; + const int blockIdx_Y = blockIdx.y; + const int blockIdx_Z = blockIdx.z; + const int threadIdx_X = threadIdx.x; + + // * Construct CTA Tensor + const auto cta_coord = make_coord(blockIdx_X, blockIdx_Y, blockIdx_Z); + cute::Tensor cAgA = cute::recast(local_tile(gA, shape(cAsA_layout{}), cta_coord)); + cute::Tensor cACgAC = cute::recast(local_tile(gAC, shape(cACsAC_layout{}), cta_coord)); + cute::Tensor cEgE = local_tile(gE, shape(cEsE_layout{}), cta_coord); + + cute::Tensor cAsA = cute::recast(make_tensor(make_smem_ptr(cute::recast_ptr(shared_storage.cAsA)), cAsA_layout{})); + cute::Tensor cACsAC = cute::recast(make_tensor(make_smem_ptr(cute::recast_ptr(shared_storage.cACsAC)), cACsAC_layout{})); + cute::Tensor cEsE = make_tensor(make_smem_ptr(cute::recast_ptr(shared_storage.cEsE)), cEsE_layout{}); + cute::Tensor cEsE_chunk = cute::recast(cEsE); + + // * Handle in unit of Chunk when compress + using OneChunkSizeA = Int; + using OneChunkSizeAC = Int; + using OneChunkSizeE = Int; + using NumOneChunkK = Int; + + cute::Tensor cAsA_log_chunk = logical_divide(cAsA, make_shape(_, OneChunkSizeA{})); + cute::Tensor cACsAC_log_chunk = logical_divide(cACsAC, make_shape(_, OneChunkSizeAC{})); + cute::Tensor cEsE_log_chunk = logical_divide(cEsE_chunk, make_shape(_, OneChunkSizeE{})); + + // * Corner Case Handle + const auto GemmM_within_Cta = (GemmM - blockIdx_X * TensorEAtomM{} > TensorEAtomM{}) ? TensorEAtomM{} : GemmM - blockIdx_X * TensorEAtomM{}; + const auto GemmK_within_Cta = ( (GemmK - blockIdx_Y * TensorEAtomK{} > TensorEAtomK{}) ? TensorEAtomK{} : GemmK - blockIdx_Y * TensorEAtomK{} ) / ElemsARawPerElementAMmaRaw; + const auto GemmK_NumOneChunk_within_Cta = GemmK_within_Cta / LogicalElemsAMmaRawPerChunk; + + const auto GemmMAlignedAC = cutlass::round_up(GemmM, TensorAAlignmentM); + const auto GemmKAlignedAC = cutlass::round_up(GemmK, TensorAAlignmentK); + const auto GemmMAlignedAC_within_Cta = (GemmMAlignedAC - blockIdx_X * TensorEAtomM{} > TensorEAtomM{}) ? TensorEAtomM{} : GemmMAlignedAC - blockIdx_X * TensorEAtomM{}; + const auto GemmKAlignedAC_within_Cta = ( (GemmKAlignedAC - blockIdx_Y * TensorEAtomK{} > TensorEAtomK{}) ? TensorEAtomK{} : GemmKAlignedAC - blockIdx_Y * TensorEAtomK{} ) / ElemsARawPerElementAMmaRaw; + + // * Clear CTA Smem Tensor + cooperative_clear(threadIdx_X, cACsAC); + cooperative_clear(threadIdx_X, cEsE); + + // * Input CTA Tensor G to S + if (GemmM_within_Cta == TensorEAtomM{} && GemmK_within_Cta == TensorEAtomK{}) { + copy_vec_pred(cAgA, cAsA, threadIdx_X, GemmM_within_Cta, GemmK_within_Cta); + } + else { + copy_vec_pred(cAgA, cAsA, threadIdx_X, GemmM_within_Cta, GemmK_within_Cta); + } + + // Construct a sign bit mask for handling negative zeros + ElementAMmaRawUnit sign_mask = ElementAMmaRawUnit{ 0 }; + if constexpr (has_negative_zero_v) { + ElementAMmaRawUnit one_sign_mask = static_cast(~(ElementAMmaRawUnit{ 1 } << (cute::sizeof_bits_v - 1))); + for (int i = 0; i < sizeof(ElementAMmaRawUnit) / sizeof(ElementAUint); ++i) { + sign_mask = static_cast((int32_t)sign_mask | (int32_t)one_sign_mask << (i * cute::sizeof_bits_v)); + } + } + + // * Compress + // cACsAC is always row major order + // TensorEAtomM threads perform the compression, each thread compress one row + const int row_i = threadIdx_X; + if (row_i < GemmM_within_Cta) { + + CUTE_UNROLL + for (int col_chunk_i = 0; col_chunk_i < NumOneChunkK{}; ++col_chunk_i) { + if (col_chunk_i < GemmK_NumOneChunk_within_Cta) { + // Compress is handled in unit of ElementAMmaRawUnit + cute::Tensor tAsA = cAsA_log_chunk(row_i, make_coord(_, col_chunk_i)); + cute::Tensor tACsAC = cACsAC_log_chunk(row_i, make_coord(_, col_chunk_i)); + cute::Tensor tEsE = cEsE_log_chunk(row_i, make_coord(_, col_chunk_i)); + + int non_zero_cnt = 0; + // None zero element indx + // e.g. + // 2:4 sparsity [x 0 0 x] + // non_zero_elt_log_idx = [0, 3] + int non_zero_elt_log_idx[OneChunkSizeAC{}] = { 0 }; + + // * Find None Zero Element Idx within Chunk + CUTE_UNROLL + for (int elt_log_idx = 0; elt_log_idx < OneChunkSizeA{}; ++elt_log_idx) { + ElementAMmaRawUnit elem_A = tAsA[elt_log_idx]; + + // Handle negative 0 + ElementAMmaRawUnit masked_elem_A = elem_A; + if constexpr (has_negative_zero_v) { + masked_elem_A = elem_A & sign_mask; + } + + if (masked_elem_A != ElementAMmaRawUnit{0}) { + non_zero_elt_log_idx[non_zero_cnt] = elt_log_idx; + tACsAC[non_zero_cnt] = elem_A; + non_zero_cnt++; + } + } + + // * Corner Case for 2:4 sparsity + if constexpr (cute::sizeof_bits_v < 32) { + // i.e. [0 0 0 x] -> [(0) 0 0 x] + if (non_zero_cnt == 1 && non_zero_elt_log_idx[0] == 3) { + tACsAC[1] = tACsAC[0]; + tACsAC[0] = ElementAMmaRawUnit{0}; + non_zero_elt_log_idx[0] = 0; + non_zero_elt_log_idx[1] = 3; + } + // i.e. [0 0 x 0] -> [0 0 x (0)] + // i.e. [0 x 0 0] -> [0 x 0 (0)] + // i.e. [x 0 0 0] -> [x 0 0 (0)] + else if (non_zero_cnt == 1) { + tACsAC[1] = ElementAMmaRawUnit{0}; + non_zero_elt_log_idx[1] = 3; + } + } + + // * Set Metadata Bits + MetadataOneChunk metadata_one_chunk; + CUTE_UNROLL + for (int elt_phy_idx = 0; elt_phy_idx < OneChunkSizeAC{}; elt_phy_idx++) { + metadata_one_chunk.set_metadata_bits(non_zero_elt_log_idx[elt_phy_idx], elt_phy_idx); + } + tEsE[0] = metadata_one_chunk.storage(); + + } + else { + break; + } + } + } + + // * Sync after Compress + __syncthreads(); + + // * Output Cta Tensor S to G + if (GemmM_within_Cta > 0 && GemmK_within_Cta > 0) { + constexpr int MaxVecBits = 128; // STG.128 + cute::cooperative_copy(threadIdx_X, cEsE, cEgE); + } + + if (GemmMAlignedAC_within_Cta == TensorEAtomM{} && GemmKAlignedAC_within_Cta == TensorEAtomK{}) { + copy_vec_pred(cACsAC, cACgAC, threadIdx_X, GemmMAlignedAC_within_Cta, (GemmKAlignedAC_within_Cta / ElementASparsity::value)); + } + else { + copy_vec_pred(cACsAC, cACgAC, threadIdx_X, GemmMAlignedAC_within_Cta, (GemmKAlignedAC_within_Cta / ElementASparsity::value)); + } + + } // end of structure_sparse_compress() + + template + CUTE_DEVICE + static void + cooperative_clear( + uint32_t const& tid, + TensorSrc dSrc) { + + auto dSrctSrc = local_partition(dSrc, make_layout(make_shape(NumThreads, _1{})), tid); + cute::clear(dSrctSrc); + + // Sync all thread data access + __syncthreads(); + } + + template + CUTE_DEVICE + static void + copy_vec_pred( + TensorSrc dSrc, + TensorDst dDst, + int threadIdx_X, + int valid_rows, + int valid_cols) { + + constexpr bool IsRowMajor = cute::is_same_v; + using Element = typename TensorSrc::element_type; + constexpr bool IsQmmaF6 = cute::sizeof_bits_v == 6; + + CUTE_STATIC_ASSERT(cute::is_static_v, "shape(dSrc) needs to be static"); + CUTE_STATIC_ASSERT(cute::is_static_v, "shape(dDst) needs to be static"); + CUTE_STATIC_ASSERT(cute::sizeof_bits_v == cute::sizeof_bits_v, + "dSrc and dDst need to have same element bit width"); + CUTE_STATIC_ASSERT(cute::size(dSrc) == cute::size(dDst), "dSrc and dDst need to have same size"); + + // ValueShape + using ValueShape = + cute::conditional_t, Int<1>>, + cute::conditional_t, Int<128 / sizeof_bits_v>>, + Shape>, Int<1>>> + >; + + constexpr int ValueShapeRows = shape<0>(ValueShape{}); + constexpr int ValueShapeCols = shape<1>(ValueShape{}); + + // ThreadShape + using ThreadShape = + cute::conditional_t, Int<1>>, + Shape, Int>>, + cute::conditional_t(dSrc) / ValueShapeCols)>, Int< (shape<1>(dSrc) / ValueShapeCols)>>, + Shape(dSrc) / ValueShapeRows)>, Int(dSrc) / ValueShapeRows)>>> + >; + + constexpr int ThreadShapeRows = shape<0>(ThreadShape{}); + constexpr int ThreadShapeCols = shape<1>(ThreadShape{}); + + const int threadIdx_X_row = threadIdx_X / ThreadShapeCols; + const int threadIdx_X_col = threadIdx_X % ThreadShapeCols; + + // Row Major + if constexpr (IsRowMajor) { + CUTE_UNROLL + for (int iter_row_blk = 0; iter_row_blk < cutlass::ceil_div(shape<0>(dSrc), ThreadShapeRows * ValueShapeRows); ++iter_row_blk) { + CUTE_UNROLL + for (int col_chunk_i = 0; col_chunk_i < cutlass::ceil_div(shape<1>(dSrc) , ThreadShapeCols * ValueShapeCols); ++col_chunk_i) { + CUTE_UNROLL + for (int iter_row_thr = 0; iter_row_thr < ValueShapeRows; ++iter_row_thr) { + CUTE_UNROLL + for (int iter_col_thr = 0; iter_col_thr < ValueShapeCols; ++iter_col_thr) { + const int row_i = (iter_row_blk * ThreadShapeRows + threadIdx_X_row) * ValueShapeRows + iter_row_thr; + const int col_i = (col_chunk_i * ThreadShapeCols + threadIdx_X_col) * ValueShapeCols + iter_col_thr; + if constexpr ( (not pred) and (not IsQmmaF6) ) { + dDst(row_i, col_i) = dSrc(row_i, col_i); + } + else { + if (row_i < valid_rows && col_i < valid_cols) { + dDst(row_i, col_i) = dSrc(row_i, col_i); + } + } + } + } + } + } + } + // Col Major + else { + CUTE_UNROLL + for (int col_chunk_i = 0; col_chunk_i < cutlass::ceil_div(shape<1>(dSrc) , ThreadShapeCols * ValueShapeCols); ++col_chunk_i) { + CUTE_UNROLL + for (int iter_row_blk = 0; iter_row_blk < cutlass::ceil_div(shape<0>(dSrc), ThreadShapeRows * ValueShapeRows); ++iter_row_blk) { + CUTE_UNROLL + for (int iter_col_thr = 0; iter_col_thr < ValueShapeCols; ++iter_col_thr) { + CUTE_UNROLL + for (int iter_row_thr = 0; iter_row_thr < ValueShapeRows; ++iter_row_thr) { + const int row_i = (iter_row_blk * ThreadShapeRows + threadIdx_X_row) * ValueShapeRows + iter_row_thr; + const int col_i = (col_chunk_i * ThreadShapeCols + threadIdx_X_col) * ValueShapeCols + iter_col_thr; + if constexpr ( (not pred) and (not IsQmmaF6) ) { + dDst(row_i, col_i) = dSrc(row_i, col_i); + } + else { + if (row_i < valid_rows && col_i < valid_cols) { + dDst(row_i, col_i) = dSrc(row_i, col_i); + } + } + } + } + } + } + } + + // Sync all thread data access + __syncthreads(); + } // end of copy_vec_pred() + +}; + +} // namespace cutlass::transform::kernel diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/sparse_gemm_compressor.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/sparse_gemm_compressor.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9f23535fea5df8df728b7c806d65f75f28c36aa3 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/sparse_gemm_compressor.hpp @@ -0,0 +1,325 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Compress utils for structured sparse kernels +*/ + +#pragma once + +#include // std::fill +#include // std::array +#include // std::mt19937 + +#include "cute/numeric/numeric_types.hpp" // cute::sizeof_bits_v +#include "cute/tensor.hpp" // cute::Tensor, cute::make_tensor +#include "cutlass/arch/arch.h" // cutlass::arch::SmXY +#include "cutlass/detail/dependent_false.hpp" // cutlass::detail::dependent_false +#include "cutlass/gemm/gemm.h" // cutlass::TagToStrideA_t +#include "cutlass/fast_math.h" // cutlass::ceil_div, cutlass::round_up +#include "cutlass/numeric_size.h" // cutlass::bits_to_bytes + +#include "cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp" + +namespace cutlass::transform::kernel { + +template< + class ProblemShape_, + class ElementA_, + class LayoutATag_, + class SparseConfig_ +> +class StructuredSparseCompressorUtility { +public: + using SparseConfig = SparseConfig_; + using ProblemShape = ProblemShape_; + + //* EltA + using ElementA = ElementA_; + using LayoutATag = LayoutATag_; + using StrideA = cutlass::gemm::TagToStrideA_t; + using ElementAMmaRaw = typename SparseConfig::ElementAMmaRaw; + using ElementASparsity = typename SparseConfig::ElementASparsity; + using ElementAMmaSparsity = typename SparseConfig::ElementAMmaSparsity; + + //* EltE + using ElementEMmaRaw = typename SparseConfig::ElementEMmaRaw; + using ElementEMmaSparsity = typename SparseConfig::ElementEMmaSparsity; + + //* AtomE + using TensorEAtom = typename SparseConfig::TensorEAtom; + using TensorEAtomK = typename SparseConfig::TensorEAtomK; + using TensorEAtomM = typename SparseConfig::TensorEAtomM; + + static constexpr int ElemsARawPerElementAMmaRaw = typename SparseConfig::ElemsARawPerElementAMmaRaw{}; + static constexpr int LogicalElemsAPerChunk = typename SparseConfig::LogicalElemsAPerChunk{}; + static constexpr int PhysicalElemsAPerChunk = typename SparseConfig::PhysicalElemsAPerChunk{}; + static constexpr int LogicalElemsAMmaRawPerChunk = cutlass::ceil_div(LogicalElemsAPerChunk, ElemsARawPerElementAMmaRaw); + static constexpr int PhysicalElemsAMmaRawPerChunk = cutlass::ceil_div(PhysicalElemsAPerChunk, ElemsARawPerElementAMmaRaw); + + //* Alignment + static constexpr int TensorEAlignmentM = typename SparseConfig::TensorEAlignmentM{}; + static constexpr int TensorEAlignmentK = typename SparseConfig::TensorEAlignmentK{}; + static constexpr int TensorAAlignmentK = typename SparseConfig::TensorAAlignmentK{}; + static constexpr int TensorAAlignmentM = typename SparseConfig::TensorAAlignmentM{}; + + StructuredSparseCompressorUtility() = default; + + StructuredSparseCompressorUtility(ProblemShape problem, StrideA dA) { + set_problem_size(problem, dA); + } + + void set_problem_size(ProblemShape problem, StrideA dA_) { + M = cute::size<0>(problem); + K = cute::size<2>(problem); + L = cute::size<3>(problem); + + // The following three vars are logical elem count! + K_alignedA = round_up(K, TensorAAlignmentK); + M_alignedA = round_up(M, TensorAAlignmentM); + K_alignedE = round_up(K, TensorEAlignmentK); + M_alignedE = round_up(M, TensorEAlignmentM); + + dA = dA_; + } + + /** + * @brief Get the TensorE number of ElementE along K after alignment requirement + * + * @return int : number of ElementE (uint8_t) along K-dim + */ + int get_metadata_m_physical() const { + return M_alignedE; + } + + /** + * @brief Get the TensorE number of ElementE along M after alignment requirement + * + * @return int : number of ElementE (uint8_t) along M-dim + */ + int get_metadata_k_physical() const { + return K_alignedE / ElementEMmaSparsity{}; + } + + /** + * @brief Get the TensorACompressed number of ElementA along K after alignment requirement + * + * @return int : number of ElementA along K-dim + */ + int get_tensorA_k_physical() const { + return K_alignedA / ElementASparsity{}; + } + + /** + * @brief Get the TensorACompressed number of ElementA along M after alignment requirement + * + * @return int : number of ElementA along M-dim + */ + int get_tensorA_m_physical() const { + return M_alignedA; + } + + /** + * @brief Get the TensorACompressed Bytes + * + * @return uint64_t bytes + */ + uint64_t get_compressed_tensor_A_bytes() const { + const auto tensor_a_comp_num_elt_a = get_tensorA_m_physical() * get_tensorA_k_physical() * L; + const auto tensor_a_comp_bytes = cutlass::bits_to_bytes(tensor_a_comp_num_elt_a * cute::sizeof_bits_v); + return tensor_a_comp_bytes; + } + + /** + * @brief Get the TensorA Bytes + * + * @return uint64_t bytes + */ + uint64_t get_raw_tensor_A_bytes() const { + const auto tensor_a_num_elt_a = uint64_t(M) * uint64_t(K) * uint64_t(L); + const auto tensor_a_bytes = cutlass::bits_to_bytes(tensor_a_num_elt_a * cute::sizeof_bits_v); + return tensor_a_bytes; + } + + /** + * @brief Get the TensorE Bytes + * + * @return uint64_t bytes + */ + uint64_t get_tensor_E_bytes() const { + const auto tensor_e_num_elt_a = uint64_t(get_metadata_m_physical()) * uint64_t(get_metadata_k_physical()) * uint64_t(L); + const auto tensor_e_bytes = cutlass::bits_to_bytes(tensor_e_num_elt_a * cute::sizeof_bits_v); + return tensor_e_bytes; + } + + constexpr auto fill_layoutA_from_compressor() const { + return SparseConfig::fill_layoutA(cute::make_tuple(M,_1{},K,L)); + } + + constexpr auto fill_layoutE_from_compressor() const { + return SparseConfig::fill_layoutE(cute::make_tuple(M,_1{},K,L)); + } + + void structure_sparse_zero_mask_fill(void* host_a_ptr, uint64_t seed) { + + constexpr int ChunkSize = LogicalElemsAMmaRawPerChunk; + using ChunkElement = cute::uint_bit_t>; + + cute::Tensor gA_eltA = cute::make_tensor( + cute::recast_ptr(host_a_ptr), + cute::make_layout(make_shape(M, K, L), dA)); + + // Input TensorA is handled in unit of ElementAMmaRaw instead of ElementA + cute::Tensor gA = cute::recast(gA_eltA); + + // Extract out the Chunk from K-mode + Tensor gA_chunk = cute::zipped_divide(gA, cute::Shape<_1,cute::Int>{}); // (Chunk, Rest) + + // Half of the data is zero to indicate sparsityA = 2 + std::array nnzb_indicator{}; + for (size_t i = 1; i < nnzb_indicator.size(); i += 2) { + nnzb_indicator.at(i) = 1; + } + + std::mt19937 rng(seed); + auto rest_shape = cute::shape<1>(gA_chunk); + for (auto iter = cute::make_coord_iterator(rest_shape); iter != cute::ForwardCoordIteratorSentinel{}; ++iter) { + std::shuffle(nnzb_indicator.begin(), nnzb_indicator.end(), rng); + for (int c = 0; c < size<0>(gA_chunk); ++c) { // for each elem within chunk + if (nnzb_indicator[c] == 0) { + gA_chunk(c, *iter) = ChunkElement{0}; + } + } // end of within chunk + } // end of chunk_idx + } + + int M{-1}; + int K{-1}; + int L{-1}; + StrideA dA{}; + +private: + int K_alignedA{-1}; + int M_alignedA{-1}; + int K_alignedE{-1}; + int M_alignedE{-1}; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class ElementA, + class LayoutATag, + class SparseConfig, + class ArchTag +> +struct StructuredSparseCompressorSelector { + static_assert(cutlass::detail::dependent_false, + "Could not select a structured sparse compressor for given parameters."); +}; + +template< + class ProblemShape, + class ElementA, + class LayoutATag, + class SparseConfig +> +struct StructuredSparseCompressorSelector< + ProblemShape, + ElementA, + LayoutATag, + SparseConfig, + arch::Sm90> { + using Compressor = SM90StructuredSparseCompressor< + ProblemShape, + ElementA, + LayoutATag, + SparseConfig + >; +}; + +template< + class ProblemShape, + class ElementA, + class LayoutATag, + class SparseConfig +> +struct StructuredSparseCompressorSelector< + ProblemShape, + ElementA, + LayoutATag, + SparseConfig, + arch::Sm100> { + using Compressor = SM90StructuredSparseCompressor< + ProblemShape, + ElementA, + LayoutATag, + SparseConfig + >; +}; + +template< + class ProblemShape, + class ElementA, + class LayoutATag, + class SparseConfig +> +struct StructuredSparseCompressorSelector< + ProblemShape, + ElementA, + LayoutATag, + SparseConfig, + arch::Sm120> { + using Compressor = SM90StructuredSparseCompressor< + ProblemShape, + ElementA, + LayoutATag, + SparseConfig + >; +}; + +template< + class ProblemShape, + class ElementA, + class LayoutATag, + class SparseConfig, + class ArchTag +> +using StructuredSparseCompressor = typename StructuredSparseCompressorSelector< + ProblemShape, + ElementA, + LayoutATag, + SparseConfig, + ArchTag +>::Compressor; + +} // End namespace cutlass::transform::kernel diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/pitch_linear_thread_map.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/pitch_linear_thread_map.h new file mode 100644 index 0000000000000000000000000000000000000000..ef553aab2043775758c2a87d422456dc5cca2426 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/pitch_linear_thread_map.h @@ -0,0 +1,926 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing how threads are mapped to a given tile. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { + +//////////////////////////////////////////////////////////////////////////////// + +/// Strip-mines a pitch-linear tile among a given number of threads, first along +/// the contiguous dimension then along the strided dimension. +/// +/// The tile must be divisible by the thread count such that all threads may +/// execute the same number of iterations with the same delta to exhaustively +/// cover the tile. +/// +/// This class satisfies the "RegularThreadMapping" concept. +/// +/// This ThreadMap is used by SIMT kernels and operand E of the sparse tensor +/// kernels. +template < + typename Shape_, + int Threads, + int ElementsPerAccess = 1 +> +struct PitchLinearStripminedThreadMap { + + /// Tensor coordinate + using TensorCoord = layout::PitchLinearCoord; + + /// Tile shape + using Shape = Shape_; + + /// Number of threads total + static int const kThreads = Threads; + + /// Extract vector length from Layout + static int const kElementsPerAccess = ElementsPerAccess; + + /// Shape of access by each thread + using ThreadAccessShape = layout::PitchLinearShape; + + /// Internal implementation details + struct Detail { + + static_assert(!(Shape::kContiguous % kElementsPerAccess), ""); + + /// Shape of the tile in units of vectors + using ShapeVec = layout::PitchLinearShape< + Shape::kContiguous / kElementsPerAccess, + Shape::kStrided + >; + + static_assert((Threads < ShapeVec::kContiguous && !(ShapeVec::kContiguous % kThreads)) || + (!(kThreads % ShapeVec::kContiguous)), + "Shape must be divisible by number of iterations of each thread."); + }; + + /// Number of iterations by each thread + using Iterations = typename platform::conditional< + Threads >= Detail::ShapeVec::kContiguous, + layout::PitchLinearShape< + 1, + // Redo the comparison here to work around divide by zero compiler + // error. The compiler evaluates both path of platform::conditional. + (Threads >= Detail::ShapeVec::kContiguous + ? (Detail::ShapeVec::kStrided + (kThreads / Detail::ShapeVec::kContiguous - 1)) / + (kThreads / Detail::ShapeVec::kContiguous) + : 0)>, + layout::PitchLinearShape>::type; + + + /// Interval between accesses along each dimension of the tensor's logical coordinate space + /// (in units of Elements) + using Delta = typename platform::conditional< + Threads >= Detail::ShapeVec::kContiguous, + layout::PitchLinearShape< + 1, + kThreads / Detail::ShapeVec::kContiguous + >, + layout::PitchLinearShape< + kThreads * kElementsPerAccess, + 1 + > + >::type; + + /// Shape of the tile in units of vectors + using StorageShape = typename platform::conditional< + Threads >= Detail::ShapeVec::kContiguous, + layout::PitchLinearShape, + layout::PitchLinearShape>::type; + + /// Maps thread ID to a coordinate offset within the tensor's logical coordinate space + /// (in units of Elements) + CUTLASS_HOST_DEVICE + static TensorCoord initial_offset(int thread_id) { + return TensorCoord( + (thread_id % Detail::ShapeVec::kContiguous) * kElementsPerAccess, + thread_id / Detail::ShapeVec::kContiguous); + } +}; + +/// This ThreadMap is used by GEMV +template < + typename Shape, + int Threads, + int ElementsPerAccess = 1 +> +struct PitchLinearTilePolicyStripminedThreadContiguous +{ + static_assert((Shape::kContiguous % (Threads * ElementsPerAccess)) == 0, + "Contiguous shape must divide number of threads"); + + using TensorCoord = layout::PitchLinearCoord; + + static int const kThreads = Threads; + static int const kElementsPerAccess = ElementsPerAccess; + + using Iterations = layout::PitchLinearShape< + Shape::kContiguous / (kThreads * kElementsPerAccess), + Shape::kStrided>; + + using Delta = layout::PitchLinearShape<1, 1>; + + CUTLASS_HOST_DEVICE + static TensorCoord initial_offset(int thread_id) + { + return TensorCoord(thread_id * Iterations::kContiguous * kElementsPerAccess, 0); + } +}; + +template < + typename Shape, + int Threads, + int ElementsPerAccess = 1 +> +struct PitchLinearTilePolicyStripminedThreadStrided +{ + static_assert((Shape::kStrided % Threads == 0), + "Strided shape must divide number of threads"); + + using TensorCoord = layout::PitchLinearCoord; + + static int const kThreads = Threads; + static int const kElementsPerAccess = ElementsPerAccess; + + using Iterations = layout::PitchLinearShape< + Shape::kContiguous / kElementsPerAccess, + Shape::kStrided / kThreads>; + + using Delta = layout::PitchLinearShape<1, 1>; + + using ShapeVec = Shape; + + CUTLASS_HOST_DEVICE + static TensorCoord initial_offset(int thread_id) + { + + return TensorCoord(0, thread_id * Iterations::kStrided); + } +}; + + +//////////////////////////////////////////////////////////////////////////////// + +/// Policy defining a warp-raked arrangement in which a shape is partitioned into contiguous +/// elements. +/// +/// This ThreadMap is used by tensor core kernels. +template < + typename Shape_, + int Threads, + typename WarpThreadArrangement_, + int ElementsPerAccess = 1 +> +struct PitchLinearWarpRakedThreadMap { + + /// Tensor coordinate + using TensorCoord = layout::PitchLinearCoord; + + /// Tile shape + using Shape = Shape_; + + /// Number of threads total + static int const kThreads = Threads; + + /// Extract vector length from Layout + static int const kElementsPerAccess = ElementsPerAccess; + + /// Shape of access by each thread + using ThreadAccessShape = layout::PitchLinearShape; + + /// Internal details made public to facilitate introspection + struct Detail { + + /// Fixed arrangement of threads within a warp (units of threads). + using WarpThreadArrangement = WarpThreadArrangement_; + + /// Number of threads per warp + static int const kWarpSize = WarpThreadArrangement::kCount; + + /// Number of participating warps + static int const kWarpCount = kThreads / kWarpSize; + + static_assert( + !(Shape::kContiguous % kElementsPerAccess), + "Shape must be divisible by vector length."); + + /// Compute the 'shape' of the overall tile in units of vectors + using ShapeInAccesses = layout::PitchLinearShape< + Shape::kContiguous / kElementsPerAccess, + Shape::kStrided + >; + + static_assert( + !(ShapeInAccesses::kContiguous % WarpThreadArrangement::kContiguous), + "ShapeInAccesses must be divisible by WarpThreadArrangement."); + + static_assert( + !(ShapeInAccesses::kStrided % WarpThreadArrangement::kStrided), + "ShapeInAccesses must be divisible by WarpThreadArrangement."); + + // compute number of warp-level accesses total + using WarpAccessIterations = layout::PitchLinearShape< + ShapeInAccesses::kContiguous / WarpThreadArrangement::kContiguous, + ShapeInAccesses::kStrided / WarpThreadArrangement::kStrided + >; + + // Divide it into the number of warps, first partitioning the strided dimension then the + // contiguous. + static int const kWarpsStrided = + (WarpAccessIterations::kStrided >= kWarpCount + ? kWarpCount + : WarpAccessIterations::kStrided); + + static int const kWarpsContiguous = + (kWarpCount > WarpAccessIterations::kStrided + ? kWarpCount / kWarpsStrided + : 1); + + /// Arrangement of warps within a threadblock-scoped tile + using WarpArrangement = layout::PitchLinearShape< + kWarpsContiguous, kWarpsStrided + >; + }; + + ///< Iterations along each dimension (concept: PitchLinearShape) + using Iterations = layout::PitchLinearShape< + Detail::WarpAccessIterations::kContiguous / Detail::kWarpsContiguous, + Detail::WarpAccessIterations::kStrided / Detail::kWarpsStrided + >; + + static_assert(Iterations::kCount, + "Number of iterations must be non-zero"); + + ///< Delta between accesses (units of elements, concept: PitchLinearShape) + using Delta = layout::PitchLinearShape< + Detail::WarpThreadArrangement::kContiguous * kElementsPerAccess, + Detail::WarpThreadArrangement::kStrided + >; + + /// Maps thread ID to a coordinate offset within the tensor's logical coordinate space + CUTLASS_HOST_DEVICE + static TensorCoord initial_offset(int thread_id) { + + int warp_id = (thread_id / Detail::kWarpSize); + int lane_id = (thread_id % Detail::kWarpSize); + + // + // compute warp-level offset + // + + // This is the shape of the entire area covered by a warp's memory access (in units of vectors) + layout::PitchLinearCoord warp_footprint{ + Detail::WarpThreadArrangement::kContiguous * Iterations::kContiguous, + Detail::WarpThreadArrangement::kStrided * Iterations::kStrided + }; + + // This is the offset of a specific warp (in units of vectors) + layout::PitchLinearCoord warp_offset{ + (warp_id % Detail::kWarpsContiguous), + (warp_id / Detail::kWarpsContiguous) + }; + + // This is the offset of a specific thread within a warp (units of vectors) + layout::PitchLinearCoord thread_offset_in_warp{ + lane_id % Detail::WarpThreadArrangement::kContiguous, + lane_id / Detail::WarpThreadArrangement::kContiguous + }; + + // This is the offset of a thread within a threadblock tile (units of vectors) + layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec = + warp_footprint * warp_offset + thread_offset_in_warp; + + // This is the offset of a thread within a threadblock tile (units of elements) + layout::PitchLinearCoord thread_offset_in_threadblock_tile_base{ + thread_offset_in_threadblock_tile_vec.contiguous() * kElementsPerAccess, + thread_offset_in_threadblock_tile_vec.strided() + }; + + return thread_offset_in_threadblock_tile_base; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Policy defining a warp-raked arrangement in which a shape is partitioned into contiguous +/// elements. Warps are arranged based on a stride. +/// +/// This ThreadMap is used by tensor core kernels for NCxHWx layout. +template < + typename Shape_, + int Threads, + typename WarpThreadArrangement_, + int ElementsPerAccess = 1 +> +struct PitchLinearStridedWarpRakedThreadMap { + + /// Tensor coordinate + using TensorCoord = layout::PitchLinearCoord; + + /// Tile shape + using Shape = Shape_; + + /// Number of threads total + static int const kThreads = Threads; + + using WarpThreadArrangement = WarpThreadArrangement_; + + /// Extract vector length from Layout + static int const kElementsPerAccess = ElementsPerAccess; + + /// Base ThreadMap + using BaseThreadMap = PitchLinearWarpRakedThreadMap< + Shape, + kThreads, + WarpThreadArrangement, + kElementsPerAccess + >; + + /// Shape of access by each thread + using ThreadAccessShape = typename BaseThreadMap::ThreadAccessShape; + + + struct Detail { + + using WarpThreadArrangement = WarpThreadArrangement_; + + using WarpAccessIterations = typename BaseThreadMap::Detail::WarpAccessIterations; + + static int const kWarpSize = BaseThreadMap::Detail::kWarpSize; + + static int const kWarpCount = BaseThreadMap::Detail::kWarpCount; + + using ShapeInAccesses = typename BaseThreadMap::Detail::ShapeInAccesses; + + // Divide it into the number of warps, first partitioning the contiguous dimension then the + // stride. + static int const kWarpsContiguous = + (WarpAccessIterations::kContiguous >= kWarpCount + ? kWarpCount + : WarpAccessIterations::kContiguous); + + static int const kWarpsStrided = + (kWarpCount > WarpAccessIterations::kContiguous + ? kWarpCount / kWarpsContiguous + : 1); + + /// Arrangement of warps within a threadblock-scoped tile + using WarpArrangement = layout::PitchLinearShape< + kWarpsContiguous, kWarpsStrided + >; + + }; + + ///< Iterations along each dimension (concept: PitchLinearShape) + using Iterations = layout::PitchLinearShape< + Detail::WarpAccessIterations::kContiguous / Detail::kWarpsContiguous, + Detail::WarpAccessIterations::kStrided / Detail::kWarpsStrided + >; + + static_assert(Iterations::kCount, + "Number of iterations must be non-zero"); + + ///< Delta between accesses (units of elements, concept: PitchLinearShape) + using Delta = typename BaseThreadMap::Delta; + + /// Maps thread ID to a coordinate offset within the tensor's logical coordinate space + CUTLASS_HOST_DEVICE + static TensorCoord initial_offset(int thread_id) { + + int warp_id = (thread_id / Detail::kWarpSize); + int lane_id = (thread_id % Detail::kWarpSize); + + // + // compute warp-level offset + // + + // This is the shape of the entire area covered by a warp's memory access (in units of vectors) + layout::PitchLinearCoord warp_footprint{ + Detail::WarpThreadArrangement::kContiguous * Iterations::kContiguous, + Detail::WarpThreadArrangement::kStrided * Iterations::kStrided + }; + + // This is the offset of a specific warp (in units of vectors) + layout::PitchLinearCoord warp_offset{ + (warp_id % Detail::kWarpsContiguous), + (warp_id / Detail::kWarpsContiguous) + }; + + // This is the offset of a specific thread within a warp (units of vectors) + layout::PitchLinearCoord thread_offset_in_warp{ + lane_id % Detail::WarpThreadArrangement::kContiguous, + lane_id / Detail::WarpThreadArrangement::kContiguous + }; + + // This is the offset of a thread within a threadblock tile (units of vectors) + layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec = + warp_footprint * warp_offset + thread_offset_in_warp; + + // This is the offset of a thread within a threadblock tile (units of elements) + layout::PitchLinearCoord thread_offset_in_threadblock_tile_base{ + thread_offset_in_threadblock_tile_vec.contiguous() * kElementsPerAccess, + thread_offset_in_threadblock_tile_vec.strided() + }; + + return thread_offset_in_threadblock_tile_base; + } + + +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Transpose the existing ThreadMap. For example, interleaved layout is like +/// congruous in the global memory and crosswise in the shared memory. We need +/// to transpose the coordinates between two. + +template +struct TransposePitchLinearThreadMap { + /// Underlying ThreadMap + using ThreadMap = ThreadMap_; + + /// Tensor coordinate + using TensorCoord = typename ThreadMap::TensorCoord; + + /// Tile shape + using Shape = typename ThreadMap::Shape; + + /// Number of threads total + static int const kThreads = ThreadMap::kThreads; + + /// Extract vector length from Layout + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + + /// Shape of access by each thread + using ThreadAccessShape = layout::PitchLinearShape; + + /// Internal details made public to facilitate introspection + struct Detail { + /// Fixed arrangement of threads within a warp (units of threads). + using WarpThreadArrangement = WarpThreadArrangement_; + + /// Number of threads per warp + static int const kWarpSize = WarpThreadArrangement::kCount; + + /// Number of participating warps + static int const kWarpCount = kThreads / kWarpSize; + + static_assert(!(Shape::kContiguous % kElementsPerAccess), + "Shape must be divisible by vector length."); + + /// Arrangement of warps within a threadblock-scoped tile + using WarpArrangement = + layout::PitchLinearShape; + }; + + ///< Iterations along each dimension (concept: PitchLinearShape) + using Iterations = + layout::PitchLinearShape; + + static_assert(Iterations::kContiguous == 1, + "Contiguous iteration has to be one to reuse the same shared store function with those that don't need transpose"); + + static_assert(Iterations::kCount, "Number of iterations must be non-zero"); + + ///< Delta between accesses (units of elements, concept: PitchLinearShape) + using Delta = + layout::PitchLinearShape; + + /// Maps thread ID to a coordinate offset within the tensor's logical + /// coordinate space Note this is slightly different from the one of + /// PitchLinearWarpRakedThreadMap. + CUTLASS_HOST_DEVICE + static TensorCoord initial_offset(int thread_id) { + + int warp_id = (thread_id / Detail::kWarpSize); + int lane_id = (thread_id % Detail::kWarpSize); + + // + // compute warp-level offset + // + + // This is the shape of the entire area covered by a warp's memory access + // (in units of vectors) + layout::PitchLinearCoord warp_footprint{ + Detail::WarpThreadArrangement::kContiguous * Iterations::kContiguous, + Detail::WarpThreadArrangement::kStrided * Iterations::kStrided}; + + // This is the offset of a specific warp (in units of vectors) + // Note the order of / and %. Also the 2nd operand is kStrided. + layout::PitchLinearCoord warp_offset{ + (warp_id / Detail::WarpArrangement::kStrided), + (warp_id % Detail::WarpArrangement::kStrided)}; + + // This is the offset of a specific thread within a warp (units of vectors) + layout::PitchLinearCoord thread_offset_in_warp{ + lane_id % Detail::WarpThreadArrangement::kContiguous, + lane_id / Detail::WarpThreadArrangement::kContiguous}; + + // This is the offset of a thread within a threadblock tile (units of + // vectors) + layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec = + warp_footprint * warp_offset + thread_offset_in_warp; + + // This is the offset of a thread within a threadblock tile (units of + // elements) + layout::PitchLinearCoord thread_offset_in_threadblock_tile_base{ + thread_offset_in_threadblock_tile_vec.contiguous() * kElementsPerAccess, + thread_offset_in_threadblock_tile_vec.strided()}; + + return thread_offset_in_threadblock_tile_base; + } +}; + +template +struct TransposePitchLinearThreadMapSimt { + /// Underlying ThreadMap + using ThreadMap = ThreadMap_; + + /// Tensor coordinate + using TensorCoord = typename ThreadMap::TensorCoord; + + /// Tile shape + using Shape = typename ThreadMap::Shape; + + /// Number of threads total + static int const kThreads = ThreadMap::kThreads; + + /// Extract vector length from Layout + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + + static_assert(kElementsPerAccess == 1 , "Simt transpose requires elements per access to be 1"); + ///< Iterations along each dimension (concept: PitchLinearShape) + using Iterations = + layout::PitchLinearShape; + + static_assert(Iterations::kCount, "Number of iterations must be non-zero"); + + static_assert(Iterations::kStrided == 1, + "Strided iteration has to be one to reuse the same shared store function with those that don't need transpose"); + + /// Shape of access by each thread + using ThreadAccessShape = typename ThreadMap::ThreadAccessShape; + + ///< Delta between accesses (units of elements, concept: PitchLinearShape) + using Delta = + layout::PitchLinearShape; + + + /// Maps thread ID to a coordinate offset within the tensor's logical + /// coordinate space Note this is slightly different from the one of + /// PitchLinearWarpRakedThreadMap. + CUTLASS_HOST_DEVICE + static TensorCoord initial_offset(int thread_id) { + + TensorCoord coord = ThreadMap::initial_offset(thread_id); + + return TensorCoord( + coord.strided(), + coord.contiguous() + ); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + + +/// Policy defining a warp-striped arrangement. This partitions a tile into vectorized memory +/// accesses performed by each warp then distributes warps across them. Warps are striped in the +/// strided dimension and raked across the contiguous dimension. +template < + typename Shape_, /// Overall shape to partition in units of elements + int Threads, /// Number of partiticipation threads + typename WarpThreadArrangement_, /// Describes the shape of one memory access per warp + int ElementsPerAccess = 1 /// Number of elements accessed by each thread per memory operation (i.e. vector size) +> +struct PitchLinearWarpStripedThreadMap { + + /// Tensor coordinate + using TensorCoord = layout::PitchLinearCoord; + + /// Tile shape + using Shape = Shape_; + + /// Number of threads total + static int const kThreads = Threads; + + /// Extract vector length from Layout + static int const kElementsPerAccess = ElementsPerAccess; + + /// Shape of access by each thread + using ThreadAccessShape = layout::PitchLinearShape; + + /// Internal details made public to facilitate introspection + struct Detail { + + /// Fixed arrangement of threads within a warp (units of threads). + using WarpThreadArrangement = WarpThreadArrangement_; + + /// Number of threads per warp + static int const kWarpSize = WarpThreadArrangement::kCount; + + /// Number of participating warps + static int const kWarpCount = kThreads / kWarpSize; + + static_assert( + !(Shape::kContiguous % kElementsPerAccess), + "Shape must be divisible by vector length."); + + /// Compute the 'shape' of the overall tile in units of vectors + using ShapeInAccesses = layout::PitchLinearShape< + Shape::kContiguous / kElementsPerAccess, + Shape::kStrided + >; + + // compute number of warp-level accesses total + using WarpAccessIterations = layout::PitchLinearShape< + ShapeInAccesses::kContiguous / WarpThreadArrangement::kContiguous, + ShapeInAccesses::kStrided / WarpThreadArrangement::kStrided + >; + + // Divide it into the number of warps, first partitioning the strided dimension then the + // contiguous. + static int const kWarpsStrided = + (WarpAccessIterations::kStrided >= kWarpCount + ? kWarpCount : (kWarpCount / WarpAccessIterations::kStrided)); + + static int const kWarpsContiguous = + (kWarpCount > WarpAccessIterations::kStrided ? + WarpAccessIterations::kContiguous / kWarpsStrided : 1); + + /// Arrangement of warps within a threadblock-scoped tile + using WarpArrangement = layout::PitchLinearShape< + kWarpsContiguous, kWarpsStrided + >; + }; + + ///< Iterations along each dimension (concept: PitchLinearShape) + using Iterations = layout::PitchLinearShape< + Detail::WarpAccessIterations::kContiguous / Detail::kWarpsContiguous, + Detail::WarpAccessIterations::kStrided / Detail::kWarpsStrided + >; + + static_assert(Iterations::kCount, + "Number of iterations must be non-zero"); + + ///< Delta between accesses (units of elements, concept: PitchLinearShape) + using Delta = layout::PitchLinearShape< + Detail::WarpThreadArrangement::kContiguous * kElementsPerAccess, + Detail::WarpThreadArrangement::kStrided * Detail::WarpArrangement::kStrided + >; + + /// Maps thread ID to a coordinate offset within the tensor's logical coordinate space + CUTLASS_HOST_DEVICE + static TensorCoord initial_offset(int thread_id) { + + int warp_id = (thread_id / Detail::kWarpSize); + int lane_id = (thread_id % Detail::kWarpSize); + + // + // compute warp-level offset + // + + // This is the shape of the entire area covered by a warp's memory access (in units of vectors) + layout::PitchLinearCoord warp_footprint{ + Detail::WarpThreadArrangement::kContiguous * Iterations::kContiguous, + Detail::WarpThreadArrangement::kStrided + }; + + // This is the offset of a specific warp (in units of vectors) + layout::PitchLinearCoord warp_offset{ + (warp_id % Detail::kWarpsContiguous), + (warp_id / Detail::kWarpsContiguous) + }; + + // This is the offset of a specific thread within a warp (units of vectors) + layout::PitchLinearCoord thread_offset_in_warp{ + lane_id % Detail::WarpThreadArrangement::kContiguous, + lane_id / Detail::WarpThreadArrangement::kContiguous + }; + + // This is the offset of a thread within a threadblock tile (units of vectors) + layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec = + warp_footprint * warp_offset + thread_offset_in_warp; + + // This is the offset of a thread within a threadblock tile (units of elements) + layout::PitchLinearCoord thread_offset_in_threadblock_tile_base{ + thread_offset_in_threadblock_tile_vec.contiguous() * kElementsPerAccess, + thread_offset_in_threadblock_tile_vec.strided() + }; + + return thread_offset_in_threadblock_tile_base; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Strip-mines a pitch-linear tile among a given number of threads, first along the contiguous +/// dimension then along the strided dimension, while each thread access a 2D thread-tile. +/// +/// The tile must be divisible by the thread count such that all threads may execute the same +/// number of iterations with the same delta to exhaustively cover the tile. +/// +/// This class satisfies the "RegularThreadMapping" concept. +template < + typename Shape_, + int Threads, + typename ThreadTileShape +> +struct PitchLinear2DThreadTileStripminedThreadMap; + + +template < + typename Shape_, + int Threads +> +struct PitchLinear2DThreadTileStripminedThreadMap >{ + + /// Tensor coordinate + using TensorCoord = layout::PitchLinearCoord; + + /// Tile shape + using Shape = Shape_; + + /// Access Shape of each thread + using ThreadAccessShape = cutlass::layout::PitchLinearShape<4, 4>; + //using ThreadAccessShape = ThreadTileShape; + + /// Number of threads total + static int const kThreads = Threads; + + /// Extract length of each access from Layout + static int const kElementsPerAccess = ThreadAccessShape::kContiguous; + + static_assert(!(kElementsPerAccess % 4) , "kElementsPerAccess, needs to be multiple of 4 (32bits)"); + + /// Internal implementation details + struct Detail { + + static_assert(!(ThreadAccessShape::kContiguous % 4), "ThreadAccessShape, needs to be multiple of 4"); + + static_assert(!(Shape::kContiguous % ThreadAccessShape::kContiguous), ""); + + static_assert(!((Shape::kContiguous * Shape::kStrided) % (kThreads * ThreadAccessShape::kCount)), + "Shape must be divisible thread count * accesses per thread."); + + /// Shape of the tile in units of vectors + using ShapeVec = layout::PitchLinearShape< + Shape::kContiguous / ThreadAccessShape::kContiguous, + Shape::kStrided / ThreadAccessShape::kStrided + >; + + static_assert( + (Threads < ShapeVec::kContiguous && !(ShapeVec::kContiguous % kThreads)) || + (!(kThreads % ShapeVec::kContiguous) && !(ShapeVec::kStrided % (kThreads / ShapeVec::kContiguous))), + "Shape must be divisible by number of iterations of each thread." + ); + }; + + /// Number of iterations by each thread + using Iterations = typename platform::conditional< + Threads >= Detail::ShapeVec::kContiguous, + layout::PitchLinearShape< + 1, + // Redo the comparison here to work around divide by zero compiler + // error. The compiler evaluates both path of platform::conditional. + (Threads >= Detail::ShapeVec::kContiguous + ? Detail::ShapeVec::kStrided / + (kThreads / Detail::ShapeVec::kContiguous) + : 0)>, + layout::PitchLinearShape>::type; + + /// Interval between accesses along each dimension of the tensor's logical coordinate space + /// (in units of Elements) + using Delta = typename platform::conditional< + Threads >= Detail::ShapeVec::kContiguous, + layout::PitchLinearShape< + Shape::kContiguous, + kThreads * ThreadAccessShape::kStrided / Detail::ShapeVec::kContiguous + >, + layout::PitchLinearShape< + kThreads * ThreadAccessShape::kContiguous, + 1 + > + >::type; + + /// Maps thread ID to a coordinate offset within the tensor's logical coordinate space + /// (in units of Elements) + CUTLASS_HOST_DEVICE + static TensorCoord initial_offset(int thread_id) { + + return TensorCoord( + (thread_id % Detail::ShapeVec::kContiguous) * ThreadAccessShape::kContiguous, + (thread_id / Detail::ShapeVec::kContiguous) * ThreadAccessShape::kStrided); + } +}; + +/// Thread Mapping a 2D threadtiled mapping as a transposed Pitchlinear2DThreadTile mapping +template +struct TransposePitchLinearThreadMap2DThreadTile { + /// Underlying ThreadMap + using ThreadMap = ThreadMap_; + + /// Tensor coordinate + using TensorCoord = typename ThreadMap::TensorCoord; + + /// Tile shape + using Shape = typename ThreadMap::Shape; + + /// Number of threads total + static int const kThreads = ThreadMap::kThreads; + + /// Extract vector length from Layout + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + + + static_assert(kElementsPerAccess > 1 , "Simt transpose requires elements per access to be 1"); + ///< Iterations along each dimension (concept: PitchLinearShape) + using Iterations = + layout::PitchLinearShape; + + static_assert(Iterations::kCount, "Number of iterations must be non-zero"); + + /// Shape of access by each thread + using ThreadAccessShape = typename ThreadMap::ThreadAccessShape; + + ///< Delta between accesses (units of elements, concept: PitchLinearShape) + using Delta = + layout::PitchLinearShape; + + + /// Maps thread ID to a coordinate offset within the tensor's logical + /// coordinate space Note this is slightly different from the one of + /// PitchLinearWarpRakedThreadMap. + CUTLASS_HOST_DEVICE + static TensorCoord initial_offset(int thread_id) { + + TensorCoord coord = ThreadMap::initial_offset(thread_id); + return TensorCoord( + coord.strided(), + coord.contiguous() + ); + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/thread/transpose.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/thread/transpose.h new file mode 100644 index 0000000000000000000000000000000000000000..508cad846e6d6b819c26570e5dcae9844f712089 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/thread/transpose.h @@ -0,0 +1,107 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Basic copy routines for tensor views +*/ + +#pragma once + +namespace cutlass { +namespace transform { +namespace thread { + +/// Transforms a fragment by doing a transpose +template < + int ElementCount, + typename TransposeShape, + typename Element +> struct Transpose; + +/// Specialization for int8_t 4x4 transpose +template +struct Transpose , int8_t> { + + static const int kElementCount = ElementCount_; + using TransposeShape = layout::PitchLinearShape<4,4>; + using Element = int8_t; + using Fragment = cutlass::Array; + + static_assert(!(kElementCount % TransposeShape::kCount), "Shape needs to be multiple of 16 elements to do a 4x4 transpose"); + + CUTLASS_DEVICE + void transform(Fragment& dst, Fragment& src) { + + // Expose src/dst as int arrays. + int* src_int = reinterpret_cast(&src); + int* dst_int = reinterpret_cast(&dst); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kElementCount / TransposeShape::kCount; i++){ + + int const i0 = 4 * i + 0; + int const i1 = 4 * i + 1; + int const i2 = 4 * i + 2; + int const i3 = 4 * i + 3; + + int a0 = src_int[i0]; + int a1 = src_int[i1]; + int a2 = src_int[i2]; + int a3 = src_int[i3]; + + int b0, b1, b2, b3, c0; + b0 = __byte_perm(a0, a1, 0x0040); + c0 = __byte_perm(a2, a3, 0x0040); + b0 = __byte_perm(b0, c0, 0x5410); + + b1 = __byte_perm(a0, a1, 0x0051); + c0 = __byte_perm(a2, a3, 0x0051); + b1 = __byte_perm(b1, c0, 0x5410); + + b2 = __byte_perm(a0, a1, 0x0062); + c0 = __byte_perm(a2, a3, 0x0062); + b2 = __byte_perm(b2, c0, 0x5410); + + b3 = __byte_perm(a0, a1, 0x0073); + c0 = __byte_perm(a2, a3, 0x0073); + b3 = __byte_perm(b3, c0, 0x5410); + + dst_int[i0] = b0; + dst_int[i1] = b1; + dst_int[i2] = b2; + dst_int[i3] = b3; + } + } +}; + +} // namespace thread +} // namespace layout +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/thread/unary_op.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/thread/unary_op.h new file mode 100644 index 0000000000000000000000000000000000000000..3977af529124dc3db34610046b72145c2a14bf00 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/thread/unary_op.h @@ -0,0 +1,105 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" + +namespace cutlass { +namespace transform { +namespace thread { + +namespace UnaryTransform { + struct Identity; ///< None (i.e., identity) + struct Conjugate; ///< Complex conjugate +} + +/// Element-wise unary operator that transforms one element of a fragment at a time +template< + typename FragmentIn, ///< Input Fragment + typename FragmentOut,///< Output Fragment + typename Transform> ///< Unary transform operator +class UnaryOp +{ + public: + CUTLASS_DEVICE + static FragmentOut execute(FragmentIn &in) + { + static_assert(FragmentIn::kElements == FragmentOut::kElements, "Number of elements must match."); + static_assert(platform::is_same::value || + platform::is_same::value, + "Unary Operator not supported."); + + FragmentOut out; + if (platform::is_same::value ) + { + CUTLASS_PRAGMA_UNROLL + for (int i=0; i < FragmentIn::kElements; ++i){ + out[i] = static_cast(in[i]); + } + } + else if (platform::is_same::value ) + { + for (int i=0; i < FragmentIn::kElements; ++i){ + out[i] = conj(static_cast(in[i])); + } + } + return out; + } +}; + +template +class UnaryOp +{ + public: + CUTLASS_DEVICE + static FragmentIn execute(FragmentIn &in) + { + static_assert(platform::is_same::value || + platform::is_same::value, + "Unary Operator not supported."); + + if (platform::is_same::value ) + { + return in; + } + else if (platform::is_same::value ) + { + for(int i=0; i < FragmentIn::kElements; ++i){ + in[i] = conj(in[i]); + } + } + return in; + } + }; + } + } +} diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_iterator.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..bd717d678f8234b9fd39f7d22c4de5c231da4c42 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_iterator.h @@ -0,0 +1,199 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Ell iterator for matrix of indices (ellColInd matrix) +*/ + +#pragma once + +namespace cutlass { +namespace transform { +namespace threadblock { + +namespace ell{ + +constexpr unsigned int SmemPow = 8; +constexpr unsigned int SmemStages = 2; +constexpr unsigned int SmemSize = 1 << SmemPow; +constexpr unsigned int SmemMask = (SmemSize*SmemStages-1); + +class SharedStorage{ + public: + Array array; +}; + +class Iterator{ + public: + using Layout = layout::PitchLinear; + using LongIndex = typename Layout::LongIndex; + + private: + const int *gmem_col_idx_; + int *smem_col_idx_; + const int block_size_; + const int base_idx_; + const int k_shape_; + const int ell_increment_; + const int array_length_; + int col_idx_base_; + int residue_; + int counter_; + + int pow2_; + int residue_shape_; + + int smem_offset_; + int smem_stage_; + int gmem_offset_; + + int lane_; + + bool is_pow2_; + bool is_residue_tile_; + + public: + CUTLASS_DEVICE + void load_ell_indices(){ + for(int i=threadIdx.x; i= 0) ? gmem_col_idx : -1; + } + gmem_offset_ += SmemSize; + smem_stage_ ^= 1; + } + + CUTLASS_DEVICE + Iterator( + SharedStorage& shared_storage_base, + const int* col_idx, + const int& block_size, + const int& base_idx, + const int k_shape, + const int& problem_size_k, + const int& ell_stride, + const int& thread_idx) + : residue_(0), + counter_(0), + smem_offset_(0), + smem_stage_(0), + gmem_offset_(0), + block_size_(block_size), + base_idx_(base_idx), + k_shape_(k_shape), + ell_increment_(ell_stride * block_size), + array_length_((problem_size_k + block_size_ - 1) / block_size_), + residue_shape_(problem_size_k % k_shape_), + is_residue_tile_(residue_shape_ != 0), + smem_col_idx_(reinterpret_cast(&shared_storage_base.array)), + gmem_col_idx_(const_cast(col_idx)), + lane_(thread_idx % 32) { + + load_ell_indices(); + __syncthreads(); + + is_pow2_ = ((block_size_ & (block_size_ - 1)) == 0); + if( is_pow2_ && k_shape <= block_size_ ) lane_ = 0; + + col_idx_base_ = smem_col_idx_[(smem_offset_ + lane_) & SmemMask] * ell_increment_; + + pow2_ = 0; + while(block_size_ >> (pow2_ + 1)) ++pow2_; + } + + CUTLASS_DEVICE + int get_blocksize(){ + return block_size_; + } + + CUTLASS_DEVICE + Iterator &operator++(){ + if(is_residue_tile_){ + residue_ += residue_shape_; + is_residue_tile_ = false; + } else { + residue_ += k_shape_; + } + + if(residue_ < block_size_){ + return *this; + } + + if((array_length_ > SmemSize) && (((smem_offset_ >> SmemPow) & 1) != smem_stage_)) + load_ell_indices(); + + if(residue_ == block_size_){ + ++smem_offset_; + counter_ += ell_increment_; + residue_ = 0; + col_idx_base_ = smem_col_idx_[(smem_offset_ + lane_) & SmemMask] * ell_increment_ - counter_; + return *this; + } + + if(is_pow2_){ + smem_offset_ += residue_ >> pow2_; + counter_ += (residue_ >> pow2_) * ell_increment_; + residue_ = residue_ & ((1 << pow2_) - 1); + } + else { + smem_offset_ += residue_ / block_size_; + counter_ += (residue_ / block_size_) * ell_increment_; + residue_ %= block_size_; + } + + col_idx_base_ = smem_col_idx_[(smem_offset_ + lane_) & SmemMask] * ell_increment_ - counter_; + + return *this; + } + + CUTLASS_DEVICE + LongIndex get_offset(const int& idx) { + int num_jump_tiles; + if(is_pow2_) + num_jump_tiles = (idx + residue_) >> pow2_; + else + num_jump_tiles = (idx + residue_) / block_size_; + + int tmp = __shfl_sync(0xffffffff, col_idx_base_, num_jump_tiles); + return tmp - num_jump_tiles * ell_increment_; + } + + CUTLASS_DEVICE + LongIndex get_offset_fast() { + return col_idx_base_; + } +}; + +} +} +} +} diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_access_iterator.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_access_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..3676c2339067f9eaad667e11e0d798ae3f4d5c95 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_access_iterator.h @@ -0,0 +1,1350 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Ell iterator for Blocked-Ell matrix (ellValue matrix) used with EllMmaMultistage +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// EllPredicatedTileAccessIterator +/// +template +class EllPredicatedTileAccessIterator; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of EllPredicatedTileAccessIterator for pitch-linear data. +/// +template +class EllPredicatedTileAccessIterator { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + static int const kPredicatesPerByte = 4; + static int const kPredicatesPerWord = 4 * kPredicatesPerByte; + + static int const kPredicateCount = ThreadMap::Iterations::kCount * kAccessesPerVector; + + /// Number of 32b words containing predicates + static int const kPredicateByteCount = + (kPredicateCount + kPredicatesPerByte - 1) / kPredicatesPerByte; + static int const kPredicateWordCount = (kPredicateByteCount + 3) / 4; + + static unsigned const kPredicateMask = (1u << kPredicatesPerByte) - 1u; + + static_assert(kPredicateWordCount <= 4, "Too many predicates."); + + /// Predicate vector stores mask to guard accesses + using Mask = Array; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + friend EllPredicatedTileAccessIterator; + + private: + /// stride of pitch-linear layout (units of Element) + LongIndex stride_; + /// amount (in byte) to increment pointer to move to next access along + /// strided dimension + LongIndex inc_strided_; + /// amount (in byte) to increment pointer from last access to first access + /// of next tile + LongIndex inc_next_; + /// amount (in byte) to increment pointer from first access of current tile + /// to first access of next tile + LongIndex inc_advance_; + + public: + + // Default ctor + CUTLASS_HOST_DEVICE + Params(): stride_(0), inc_strided_(0), inc_next_(0), inc_advance_(0) { } + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) : stride_(layout.stride(0)) { + inc_strided_ = (LongIndex(stride_) * ThreadMap::Delta::kStrided) * + sizeof_bits::value / 8; + + if (kAdvanceRank) { + // advance along strided dimension + inc_advance_ = + Shape::kStrided * LongIndex(stride_) * sizeof_bits::value / 8; + } else { + // advance along contiguous dimension + inc_advance_ = Shape::kContiguous * sizeof_bits::value / 8; + } + + inc_next_ = inc_advance_ - LongIndex(ThreadMap::Iterations::kStrided - 1) * + ThreadMap::Delta::kStrided * LongIndex(stride_) * + sizeof_bits::value / 8; + }; + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char *; + + private: + // + // Data members + // + + /// Parameters object with precomputed internal state + Params const ¶ms_; + + /// Internal pointer to first access of tile + BytePointer pointer_; + + /// Guard predicates + uint32_t predicates_[kPredicateWordCount]; + + /// Size of tensor + TensorCoord extent_; + + /// Initial offset for each thread + TensorCoord thread_offset_; + + /// Offset to the first steady-state tile + TensorCoord residue_offset_; + + /// Initial offset to define ELL block + TensorCoord ell_offset_; + + /// Used for out-of-order visitation + bool is_residue_tile_; + + /// Iteration along vectors implied by the thread map + int iteration_vector_; + + /// Iteration in the contiguous dimension + int iteration_contiguous_; + + /// Iteration in the strided dimension + int iteration_strided_; + + public: + /// Computes predicates based on internally tracked per-thread offset. + CUTLASS_DEVICE + void compute_predicates_( + /// Extent of the matrix window + TensorCoord extent, + /// optionally, simplify predicate calculation during 'steady state' phase + bool is_steady_state = false) { + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + predicates_[i] = 0u; + } + + CUTLASS_PRAGMA_UNROLL + for (int access_idx = 0; access_idx < ThreadMap::Iterations::kCount * kAccessesPerVector; ++access_idx) { + + int s = access_idx / (ThreadMap::Iterations::kContiguous * kAccessesPerVector); + + int access_residual = access_idx % (ThreadMap::Iterations::kContiguous * kAccessesPerVector); + + int c = access_residual / kAccessesPerVector; + int v = access_residual % kAccessesPerVector; + + TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous + v * AccessType::kElements, + s * ThreadMap::Delta::kStrided); + + TensorCoord coord = thread_offset_ + iteration_coord; + + bool guard; + + if (is_steady_state) { + if (kAdvanceRank == 0) { + guard = (coord.strided() < extent.strided()); + } else { + guard = (coord.contiguous() < extent.contiguous()); + } + } else { + guard = (coord.strided() < extent.strided() && + coord.contiguous() < extent.contiguous()); + } + + int pred_idx = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s); + + int word_idx = pred_idx / kPredicatesPerWord; + int residual = pred_idx % kPredicatesPerWord; + int byte_idx = residual / kPredicatesPerByte; + int bit_idx = residual % kPredicatesPerByte; + + predicates_[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx)); + + } + + } + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : params_(params), + pointer_(reinterpret_cast( + const_cast(pointer))), + extent_(extent), + is_residue_tile_(true) { + + TensorCoord residue_extent; + if (kAdvanceRank) { + + typename TensorCoord::Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.strided()) % Shape::kStrided; + if (!residue_size) { + residue_size = Shape::kStrided; + } + + residue_offset_ = make_Coord(0, residue_size); + residue_extent = make_Coord( + extent_.contiguous(), + min(threadblock_offset.strided() + residue_size, extent_.strided()) + ); + } else { + + typename TensorCoord::Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.contiguous()) % Shape::kContiguous; + if (!residue_size) { + residue_size = Shape::kContiguous; + } + + residue_offset_ = make_Coord(residue_size, 0); + + residue_extent = make_Coord( + min(extent_.contiguous(), threadblock_offset.contiguous() + residue_size), + extent_.strided() + ); + } + + // Per-thread offset in logical coordinates of tensor + ell_offset_ = ThreadMap::initial_offset(thread_id); + thread_offset_ = threadblock_offset + ThreadMap::initial_offset(thread_id); + + // update internal pointers + Layout layout(params_.stride_); + add_pointer_offset(layout(thread_offset_)); + + compute_predicates_(residue_extent, false); + + set_iteration_index(0); + } + + /// Construct a EllPredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id) + : EllPredicatedTileAccessIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += sizeof_bits::value * pointer_offset / 8; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_DEVICE + void add_tile_offset( + TensorCoord const &tile_offset) { + if (is_residue_tile_) { + + thread_offset_ += residue_offset_; + + Layout layout(params_.stride_); + add_pointer_offset(layout(residue_offset_)); + + compute_predicates_(extent_, true); + + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided() - 1); + pointer_ += Shape::kContiguous * tile_offset.contiguous(); + } else { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous() - 1); + pointer_ += Shape::kStrided * tile_offset.strided(); + } + } else { + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided()); + pointer_ += Shape::kContiguous * tile_offset.contiguous(); + } else { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous()); + pointer_ += Shape::kStrided * tile_offset.strided(); + } + } + is_residue_tile_ = false; + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast( + pointer_ + + iteration_contiguous_ * (ThreadMap::Delta::kContiguous * sizeof_bits::value) / 8) + iteration_vector_; + } + + /// Returns a k_location + CUTLASS_HOST_DEVICE + int get_k() const { + if(kAdvanceRank){ //strided + return ell_offset_.strided() + iteration_strided_ * ThreadMap::Delta::kStrided; + }else{ + return ell_offset_.contiguous() + iteration_contiguous_ * ThreadMap::Delta::kContiguous + iteration_vector_ * AccessType::kElements; + } + } + + CUTLASS_HOST_DEVICE + int get_stride() const { + if(kAdvanceRank) + return params_.stride_; + else + return 1; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator &operator++() { + + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + + iteration_vector_ = 0; + ++iteration_contiguous_; + + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + iteration_contiguous_ = 0; + ++iteration_strided_; + + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + pointer_ += params_.inc_strided_; + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + iteration_strided_ = 0; + + // advance to next tile + pointer_ += params_.inc_next_; + + // now return to start tile - if the iterator is subsequently advanced, this + // subtraction as well as the subsequent integer addition are both elided by + // the compiler. + pointer_ -= params_.inc_advance_; + + return *this; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator operator++(int) { + EllPredicatedTileAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + predicates_[i] = enable ? 0u : predicates_[i]; + } + + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + predicates_[i] = 0xffffffff; + } + + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + predicates_[i] = mask[i]; + } + + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + mask[i] = predicates_[i]; + } + } + + /// add mask for small tiles in ELL + CUTLASS_DEVICE + void ell_add_mask(int blocksize) { + + Mask mask; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + mask[i] = 0u; + } + + CUTLASS_PRAGMA_UNROLL + for (int access_idx = 0; access_idx < ThreadMap::Iterations::kCount * kAccessesPerVector; ++access_idx) { + + int s = access_idx / (ThreadMap::Iterations::kContiguous * kAccessesPerVector); + + int access_residual = access_idx % (ThreadMap::Iterations::kContiguous * kAccessesPerVector); + + int c = access_residual / kAccessesPerVector; + int v = access_residual % kAccessesPerVector; + + TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous + v * AccessType::kElements, + s * ThreadMap::Delta::kStrided); + + TensorCoord coord = ell_offset_ + iteration_coord; + + bool guard; + + if (kAdvanceRank == 0) { + guard = (coord.strided() < blocksize); + } else { + guard = (coord.contiguous() < blocksize); + } + + int pred_idx = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s); + + int word_idx = pred_idx / kPredicatesPerWord; + int residual = pred_idx % kPredicatesPerWord; + int byte_idx = residual / kPredicatesPerByte; + int bit_idx = residual % kPredicatesPerByte; + + mask[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx)); + + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + mask[i] &= predicates_[i]; + } + set_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + + int pred_idx = + iteration_vector_ + kAccessesPerVector * (iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous); + + int word_idx = pred_idx / kPredicatesPerWord; + int residual = pred_idx % kPredicatesPerWord; + int byte_idx = residual / kPredicatesPerByte; + int bit_idx = residual % kPredicatesPerByte; + + bool pred = (predicates_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0; + return pred; + + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of EllPredicatedTileAccessIterator for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class EllPredicatedTileAccessIterator { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = EllPredicatedTileAccessIterator< + layout::PitchLinearShape, Element, + layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessType>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend EllPredicatedTileAccessIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::PitchLinear(layout.stride(0))){}; + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator( + ///< Precomputed parameters object + Params const ¶ms, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const &threadblock_offset) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.row(), + threadblock_offset.column())) {} + + /// Construct a EllPredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : EllPredicatedTileAccessIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + CUTLASS_HOST_DEVICE + int get_k() const { + return iterator_.get_k(); + } + + CUTLASS_HOST_DEVICE + int get_stride() const { + return iterator_.get_stride(); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator operator++(int) { + EllPredicatedTileAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// add mask for small tiles in ELL + CUTLASS_DEVICE + void ell_add_mask(int blocksize) { + iterator_.ell_add_mask(blocksize); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of EllPredicatedTileAccessIterator for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class EllPredicatedTileAccessIterator { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = EllPredicatedTileAccessIterator< + layout::PitchLinearShape, Element, + layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend EllPredicatedTileAccessIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::PitchLinear(layout.stride(0))){}; + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator( + ///< Precomputed parameters object + Params const ¶ms, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const &threadblock_offset) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), + threadblock_offset.row())) {} + + /// Construct a EllPredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : EllPredicatedTileAccessIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + CUTLASS_HOST_DEVICE + int get_k() const { + return iterator_.get_k(); + } + + CUTLASS_HOST_DEVICE + int get_stride() const { + return iterator_.get_stride(); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator operator++(int) { + EllPredicatedTileAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// add mask for small tiles in ELL + CUTLASS_DEVICE + void ell_add_mask(int blocksize) { + iterator_.ell_add_mask(blocksize); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of EllPredicatedTileAccessIterator for column-major interleaved data. +/// It is mapped to the congruous layout. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// + +template +class EllPredicatedTileAccessIterator, + AdvanceRank, ThreadMap_, AccessType_> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::ColumnMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = EllPredicatedTileAccessIterator< + layout::PitchLinearShape, + Element, layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend EllPredicatedTileAccessIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.row() * kInterleavedK, + extent.column() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row() * kInterleavedK, + threadblock_offset.column() / kInterleavedK)) {} + + /// Construct a EllPredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : EllPredicatedTileAccessIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + CUTLASS_HOST_DEVICE + int get_k() const { + return iterator_.get_k(); + } + + CUTLASS_HOST_DEVICE + int get_stride() const { + return iterator_.get_stride(); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator operator++(int) { + EllPredicatedTileAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// add mask for small tiles in ELL + CUTLASS_DEVICE + void ell_add_mask(int blocksize) { + iterator_.ell_add_mask(blocksize); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { return iterator_.valid(); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of EllPredicatedTileAccessIterator for row-major interleaved data. +/// It is mapped to the congruous layout. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class EllPredicatedTileAccessIterator, + AdvanceRank, ThreadMap_, AccessType_> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::RowMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = EllPredicatedTileAccessIterator< + layout::PitchLinearShape, + Element, layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, + AccessType>; + + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend EllPredicatedTileAccessIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.column() * kInterleavedK, + extent.row() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column() * kInterleavedK, + threadblock_offset.row() / kInterleavedK)) {} + + /// Construct a EllPredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : EllPredicatedTileAccessIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + CUTLASS_HOST_DEVICE + int get_k() const { + return iterator_.get_k(); + } + + CUTLASS_HOST_DEVICE + int get_stride() const { + return iterator_.get_stride(); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator operator++(int) { + EllPredicatedTileAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// add mask for small tiles in ELL + CUTLASS_DEVICE + void ell_add_mask(int blocksize) { + iterator_.ell_add_mask(blocksize); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { return iterator_.valid(); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_iterator.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..e377bba4c454267737bffda73b1dff7572174ee7 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_iterator.h @@ -0,0 +1,1315 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Ell iterator for Blocked-Ell matrix (ellValue matrix) used with EllMmaPipelined +*/ + +#pragma once + +#include "cutlass/arch/memory.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" + +#include "cutlass/transform/threadblock/ell_predicated_tile_access_iterator.h" +#include "cutlass/transform/threadblock/ell_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// EllPredicatedTileIterator +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +/// Regular tile iterator using a precomputed control structure to minimize register liveness +/// and integer arithmetic. +/// +/// Layout is assumed to be invariant at the time the precomputed "Params" object is constructed. +/// +/// Base pointer and tensor extents may be specified at the time the iterator is constructed. +/// Subsequently, they are assumed to be immutable. +/// +/// Adding a logical coordinate offset may be performed at the time the iterator is constructed. +/// Subsequent additions to logical coordinate offset may be performed but are relatively expensive. +/// +/// Visitation order is intended to first visit a "residual" tile that may be partially full in +/// both the advance dimension and the steady-state dimension. This is assumed to be the last +/// tile in the iteration sequence. Advancing an iterator that has just been constructed moves to +/// the first tile that is full in the advance dimension and recomputes predicates. Subsequent +/// accesses may be performed without updating internal predicates and are efficient in terms of +/// live register state and pointer arithmetic instructions. +/// +/// To be efficient, this assumes the iterator will be dereferenced and advanced at least once +/// outside any looping structure to minimize integer arithmetic. +/// +/// Accesses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing +/// the iterator. +/// +/// +/// Example: +/// +/// An efficient pipeline structure may be constructed as follows: +/// +// template +// __global__ void kernel( +// typename Iterator::Params params, +// typename Iterator::Element *ptr, +// TensorCoord extent) { +// +// typename Iterator::Fragment fragment; +// +// TensorCoord threadblock_offset(0, 0); +// +// Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offsets); +// +// +// fragment = *iter; // load "residue" tile first +// ++iter; // advance to first "steady state" tile and update internal masks +// +// +// #pragma unroll +// for (int i = Remaining - 1; i >= 0; --i) { +// +// f(fragment); +// +// if (!i) { +// iter.clear_mask(); // light-weight operation to clear masks - subsequent loads become NO-OPs. +// } +// +// fragment = *iter; // load tile during "steady state" phase +// ++iter; // advance to next tile - lightweight due to steady-state masks +// } +// } +// +// void host(TensorView view) { +// +// using Iterator = transform::threadblock::EllPredicatedTileIterator; +// +// typename Iterator::Params params(view.layout()); +// +// kernel(params, view.data()); +// } +/// +/// +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + int AccessSize = ThreadMap::kElementsPerAccess +> +class EllPredicatedTileIterator; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of EllPredicatedTileIterator for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class EllPredicatedTileIterator { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + /// Type used for internal memory accesses + using AccessType = AlignedArray::value / 8)>; + + /// Underlying iterator to compute the addresses + using TileAccessIterator = + EllPredicatedTileAccessIterator; + + static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename TileAccessIterator::Mask; + + /// Iterator for ELL storage + using EllIterator = typename cutlass::transform::threadblock::ell::Iterator; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + friend EllPredicatedTileIterator; + + private: + /// Parameters object + typename TileAccessIterator::Params params_; + + public: + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) : params_(layout) { } + + CUTLASS_HOST_DEVICE + Params() { } + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char *; + + private: + // + // Data members + // + + /// Data member to the tile access iterator + TileAccessIterator address_iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : address_iterator_(params.params_, pointer, extent, thread_id, + threadblock_offset) {} + + /// Construct a EllPredicatedTileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : EllPredicatedTileIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + address_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator &operator++() { + if (kAdvanceRank) + address_iterator_.add_tile_offset({0, 1}); + else + address_iterator_.add_tile_offset({1, 0}); + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator operator++(int) { + EllPredicatedTileIterator self(*this); + operator++(); + return self; + } + + /// Returns a stride + CUTLASS_HOST_DEVICE + int get_stride() const { return address_iterator_.get_stride(); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { address_iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { address_iterator_.get_mask(mask); } + + /// add mask for small tiles in ELL + CUTLASS_HOST_DEVICE + void ell_add_mask(int blocksize) { address_iterator_.ell_add_mask(blocksize); } + + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + load_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + + int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + address_iterator_.set_iteration_index(idx); + char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; + + AccessType const *access_ptr = reinterpret_cast(byte_ptr); + + cutlass::arch::global_load( + frag_ptr[idx], access_ptr, address_iterator_.valid()); + + ++address_iterator_; + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { load_with_byte_offset(frag, 0); } + + CUTLASS_DEVICE + void load_with_ell_index(Fragment &frag, EllIterator &ell_iter) { + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + + int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + address_iterator_.set_iteration_index(idx); + LongIndex ell_offset = 0; + + int k_offset = address_iterator_.get_k(); + ell_offset = ell_iter.get_offset(k_offset) * sizeof(Element); + + char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + ell_offset; + + AccessType const *access_ptr = reinterpret_cast(byte_ptr); + + bool is_valid = address_iterator_.valid(); + is_valid = is_valid && (ell_offset >= 0); + + cutlass::arch::global_load( + frag_ptr[idx], access_ptr, is_valid); + + ++address_iterator_; + } + } + } + } + + CUTLASS_DEVICE + void load_with_ell_index_fast(Fragment &frag, EllIterator &ell_iter) { + + LongIndex ell_offset = ell_iter.get_offset_fast() * sizeof(Element); + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + + int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + address_iterator_.set_iteration_index(idx); + char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + ell_offset; + + AccessType const *access_ptr = reinterpret_cast(byte_ptr); + + bool is_valid = address_iterator_.valid(); + is_valid = is_valid && (ell_offset >= 0); + + cutlass::arch::global_load( + frag_ptr[idx], access_ptr, is_valid); + + ++address_iterator_; + } + } + } + } + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + store_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { + address_iterator_.set_iteration_index(0); + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + + int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + char *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; + AccessType *access_ptr = reinterpret_cast(byte_ptr); + + if (address_iterator_.valid()) { + *access_ptr = frag_ptr[idx]; + } + ++address_iterator_; + } + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { store_with_byte_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of EllPredicatedTileIterator for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize +> +class EllPredicatedTileIterator { +public: + + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = EllPredicatedTileIterator< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessSize + >; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Iterator for ELL storage + using EllIterator = typename cutlass::transform::threadblock::ell::Iterator; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + + friend EllPredicatedTileIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + CUTLASS_HOST_DEVICE + Params() { } + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) { + + } + }; + + +private: + + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + +public: + + /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const &threadblock_offset ///< Initial offset of threadblock + ): + iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column()) + ) { } + + /// Construct a EllPredicatedTileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ): EllPredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator operator++(int) { + EllPredicatedTileIterator self(*this); + operator++(); + return self; + } + + /// Returns a stride + CUTLASS_HOST_DEVICE + int get_stride() const { return iterator_.get_stride(); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { + iterator_.get_mask(mask); + } + + /// add mask for small tiles in ELL + CUTLASS_HOST_DEVICE + void ell_add_mask(int blocksize) { + iterator_.ell_add_mask(blocksize); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + CUTLASS_DEVICE + void load_with_ell_index(Fragment &frag, EllIterator& ell_iter) { + iterator_.load_with_ell_index(frag, ell_iter); + } + + CUTLASS_DEVICE + void load_with_ell_index_fast(Fragment &frag, EllIterator& ell_iter) { + iterator_.load_with_ell_index_fast(frag, ell_iter); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of EllPredicatedTileIterator for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize +> +class EllPredicatedTileIterator { +public: + + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = EllPredicatedTileIterator< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessSize + >; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Iterator for ELL storage + using EllIterator = typename cutlass::transform::threadblock::ell::Iterator; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + + friend EllPredicatedTileIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + CUTLASS_HOST_DEVICE + Params() { } + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) { + + }; + }; + + +private: + + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + +public: + + /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const &threadblock_offset ///< Initial offset of threadblock + ): + iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row()) + ) { } + + /// Construct a EllPredicatedTileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ): EllPredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator operator++(int) { + EllPredicatedTileIterator self(*this); + operator++(); + return self; + } + + /// Returns a stride + CUTLASS_HOST_DEVICE + int get_stride() const { return iterator_.get_stride(); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { + iterator_.get_mask(mask); + } + + /// add mask for small tiles in ELL + CUTLASS_HOST_DEVICE + void ell_add_mask(int blocksize) { + iterator_.ell_add_mask(blocksize); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + CUTLASS_DEVICE + void load_with_ell_index(Fragment &frag, EllIterator& ell_iter) { + iterator_.load_with_ell_index(frag, ell_iter); + } + + CUTLASS_DEVICE + void load_with_ell_index_fast(Fragment &frag, EllIterator& ell_iter) { + iterator_.load_with_ell_index_fast(frag, ell_iter); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of EllPredicatedTileIterator for interleaved data. It is mapped +/// to the congruous layout. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// + +template +class EllPredicatedTileIterator, + AdvanceRank, ThreadMap_, AccessSize> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::ColumnMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = EllPredicatedTileIterator< + layout::PitchLinearShape, + Element, layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessSize>; + + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Iterator for ELL storage + using EllIterator = typename cutlass::transform::threadblock::ell::Iterator; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend EllPredicatedTileIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.row() * kInterleavedK, + extent.column() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row() * kInterleavedK, + threadblock_offset.column() / kInterleavedK)) {} + + /// Construct a EllPredicatedTileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : EllPredicatedTileIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator operator++(int) { + EllPredicatedTileIterator self(*this); + operator++(); + return self; + } + + /// Returns a stride + CUTLASS_HOST_DEVICE + int get_stride() const { return iterator_.get_stride(); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// add mask for small tiles in ELL + CUTLASS_HOST_DEVICE + void ell_add_mask(int blocksize) { iterator_.ell_add_mask(blocksize); } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + CUTLASS_DEVICE + void load_with_ell_index(Fragment &frag, EllIterator& ell_iter) { + iterator_.load_with_ell_index(frag, ell_iter); + } + + CUTLASS_DEVICE + void load_with_ell_index_fast(Fragment &frag, EllIterator& ell_iter) { + iterator_.load_with_ell_index_fast(frag, ell_iter); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of EllPredicatedTileIterator for interleaved-32 data. It is +/// mapped to the congruous layout. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class EllPredicatedTileIterator, + AdvanceRank, ThreadMap_, AccessSize> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::RowMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = EllPredicatedTileIterator< + layout::PitchLinearShape, + Element, layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessSize>; + + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend EllPredicatedTileIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.column() * kInterleavedK, + extent.row() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column() * kInterleavedK, + threadblock_offset.row() / kInterleavedK)) {} + + /// Construct a EllPredicatedTileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : EllPredicatedTileIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator operator++(int) { + EllPredicatedTileIterator self(*this); + operator++(); + return self; + } + + /// Returns a stride + CUTLASS_HOST_DEVICE + int get_stride() const { return iterator_.get_stride(); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// add mask for small tiles in ELL + CUTLASS_HOST_DEVICE + void ell_add_mask(int blocksize) { iterator_.ell_add_mask(blocksize); } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_access_iterator.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_access_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..dab597c835ced1a4f070858b26da3007d268c04e --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_access_iterator.h @@ -0,0 +1,375 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Templates calculating the address and predicates to the load of scale and bias vectors. + + This iterator uses masks to guard out-of-bounds accesses. + + It can be used to load the gamma and beta vectors of layernorm which is loop variant. + + A precomputed "Params" object minimizes the amount of state that must be + stored in registers, and integer addition is used to advance the pointer + through memory. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/conv/threadblock/conv2d_params.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedScaleBiasVectorAccessIterator +/// +template +class PredicatedScaleBiasVectorAccessIterator; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for fprop pitch-linear data. +/// +template +class PredicatedScaleBiasVectorAccessIterator { + public: + + using ThreadblockShape = ThreadblockShape_; + using Element = Element_; + using Layout = layout::PitchLinear; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using ConstPointer = const Element *; + using NonConstPointer = typename platform::remove_const::type *; + + static int const kElementsPerAccess = 128 / sizeof_bits::value; + static int const kThreads = ThreadblockShape::kContiguous / kElementsPerAccess; + + using AccessType = AlignedArray; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char *; + + private: + // + // Data members + // + + /// Internal pointer to first access of tile + BytePointer pointer_; + + TensorCoord thread_offset_; + + int problem_size_k_; + + /// Used for out-of-order visitation + bool is_residue_tile_; + + bool guard_; + + TensorCoord::Index residue_size_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator( + /// Extent of tensor + int problem_size_k, + /// Pointer to the start of the scale vector + ConstPointer scale_pointer, + /// Pointer to the start of the bias vector + ConstPointer bias_pointer, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) { + pointer_ = (thread_id < kThreads) + ? reinterpret_cast( + const_cast(scale_pointer)) + : reinterpret_cast( + const_cast(bias_pointer)); + + // Per-thread offset in logical coordinates of tensor + int thread_base = (thread_id < kThreads) ? 0 : kThreads; + + problem_size_k_ = problem_size_k; + + is_residue_tile_ = true; + + residue_size_ = (problem_size_k_ - threadblock_offset.contiguous()) % ThreadblockShape::kContiguous; + + if (residue_size_ == 0) { + residue_size_ = ThreadblockShape::kContiguous; + } + + guard_ = ((thread_id - thread_base) * kElementsPerAccess) < residue_size_; + + thread_offset_ = + threadblock_offset + + TensorCoord((thread_id - thread_base) * kElementsPerAccess, 0); + + set_iteration_index(0); + } + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator( + /// Extent of tensor + int problem_size_k, + /// Pointer to start of scale vector + ConstPointer scale_pointer, + /// Pointer to start of scale vector + ConstPointer bias_pointer, + ///< ID of each participating thread + int thread_id) + : PredicatedScaleBiasVectorAccessIterator(problem_size_k, + scale_pointer, bias_pointer, + thread_id, make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) {} + + /// Advances an iterator along logical dimensions of matrix in units of whole threadblock tiles + CUTLASS_DEVICE + void add_tile_offset( + TensorCoord const &tile_offset) { + + guard_ = threadIdx.x < kThreads * 2; + + TensorCoord offset = is_residue_tile_ ? + TensorCoord(residue_size_ + ThreadblockShape::kContiguous * (tile_offset.contiguous() - 1), 0) + : TensorCoord(ThreadblockShape::kContiguous * tile_offset.contiguous(), 0); + + thread_offset_ = + thread_offset_ + + offset; + + is_residue_tile_ = false; + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + + return reinterpret_cast( + pointer_ + + (thread_offset_.contiguous() * sizeof_bits::value / 8)); + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator &operator++() { + return *this; + } + + /// Increment and return an instance to self. + CUTLASS_DEVICE + PredicatedScaleBiasVectorAccessIterator operator++(int) { + PredicatedScaleBiasVectorAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + guard_ &= (!enable); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return guard_; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for row-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedScaleBiasVectorAccessIterator { + public: + + using ThreadblockShape = ThreadblockShape_; + using Element = Element_; + using Layout = layout::RowMajor; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using ConstPointer = const Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedScaleBiasVectorAccessIterator< + layout::PitchLinearShape, + Element, + layout::PitchLinear>; + + using AccessType = typename UnderlyingIterator::AccessType; + static int const kElementsPerAccess = UnderlyingIterator::kElementsPerAccess; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator( + ///< Extent of tensor + int problem_size_k, + ///< Pointer to the start of the scale vector + ConstPointer scale_pointer, + ///< Pointer to the start of the bias vector + ConstPointer bias_pointer, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const &threadblock_offset) + : iterator_(problem_size_k, scale_pointer, bias_pointer, + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), + threadblock_offset.row())) {} + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator( + int problem_size_k, ///< Extent of tensor + ConstPointer scale_pointer, ///< Pointer to the start of the scale vector + ConstPointer bias_pointer, ///< Pointer to the start of the bias vector + int thread_id ///< ID of each participating thread + ) + : PredicatedScaleBiasVectorAccessIterator(problem_size_k, + scale_pointer, bias_pointer, + thread_id, make_Coord(0, 0)) {} + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// threadblock tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator operator++(int) { + PredicatedScaleBiasVectorAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..e5d9e70d73bfcbdc27ab78bbedea1278c3b25950 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h @@ -0,0 +1,328 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Templates calculating the address and predicates to the load of scale and bias vectors. + + This iterator uses masks to guard out-of-bounds accesses. + + This can be used to load var and mean vectors in layernorm which is loop invariant. + + A precomputed "Params" object minimizes the amount of state that must be + stored in registers, and integer addition is used to advance the pointer + through memory. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedScaleBiasVectorIterator +/// +template +class PredicatedScaleBiasVectorIterator; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIterator for wgrad pitch-linear data. +/// +template +class PredicatedScaleBiasVectorIterator { + public: + + using WarpShape = WarpShape_; + using Element = Element_; + using Layout = layout::PitchLinear; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using ConstPointer = const Element *; + using NonConstPointer = typename platform::remove_const::type *; + + static int const kElementsPerAccess = 1; + + using AccessType = AlignedArray; + + static int const kIterations = WarpShape::kContiguous / 8; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array<__half2, 2 * kIterations * kElementsPerAccess>; + + private: + // + // Data members + // + + /// Internal pointer to first access of tile + ConstPointer scale_pointer_; + ConstPointer bias_pointer_; + + /// Size of tensor + int problem_size_; + + int32_t thread_offset_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorIterator( + /// Extent of tensor + int problem_size, + /// Pointer to the start of the scale vector + ConstPointer scale_pointer, + /// Pointer to the start of the bias vector + ConstPointer bias_pointer, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : problem_size_(problem_size), + scale_pointer_(scale_pointer), + bias_pointer_(bias_pointer) { + + thread_offset_ = threadblock_offset.contiguous() + (thread_id % 32) / 4; + } + + /// Construct a PredicatedTileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorIterator( + /// Extent of tensor + int problem_size, + /// Pointer to start of scale vector + ConstPointer scale_pointer, + /// Pointer to start of scale vector + ConstPointer bias_pointer, + ///< ID of each participating thread + int thread_id) + : PredicatedScaleBiasVectorIterator(problem_size, + scale_pointer, bias_pointer, + thread_id, make_Coord(0, 0)) {} + + /// Advances an iterator along logical dimensions of matrix in units of whole warp tiles + CUTLASS_DEVICE + void add_tile_offset( + TensorCoord const &tile_offset) { + + thread_offset_ += (WarpShape::kContiguous * tile_offset.contiguous()); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + + frag.fill(__float2half2_rn(0.0f)); + __half2 *frag_ptr = reinterpret_cast<__half2 *>(&frag); + + // load scale + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < kIterations; ++c) { + + cutlass::arch::global_load< + __half, + sizeof(AccessType) + >( + frag_ptr[c * 2].x, + scale_pointer_ + thread_offset_ + c * 8, + (thread_offset_ + c * 8) < problem_size_ + ); + } + + // load bias + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < kIterations; ++c) { + + cutlass::arch::global_load< + __half, + sizeof(AccessType) + >( + frag_ptr[c * 2 + 1].x, + bias_pointer_ + thread_offset_ + c * 8, + (thread_offset_ + c * 8) < problem_size_ + ); + } + + // duplicate scale + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < kIterations; ++c) { + frag_ptr[c * 2].y = frag_ptr[c * 2].x; + } + + // duplicate bias + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < kIterations; ++c) { + frag_ptr[c * 2 + 1].y = frag_ptr[c * 2 + 1].x; + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIterator for row-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedScaleBiasVectorIterator { + public: + + using WarpShape = WarpShape_; + using Element = Element_; + using Layout = layout::RowMajor; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using ConstPointer = const Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedScaleBiasVectorIterator< + layout::PitchLinearShape, + Element, + layout::PitchLinear>; + + using AccessType = typename UnderlyingIterator::AccessType; + static int const kElementsPerAccess = UnderlyingIterator::kElementsPerAccess; + using Fragment = typename UnderlyingIterator::Fragment; + + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorIterator( + ///< Extent of tensor + int problem_size, + ///< Pointer to the start of the scale vector + ConstPointer scale_pointer, + ///< Pointer to the start of the bias vector + ConstPointer bias_pointer, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const &threadblock_offset) + : iterator_(problem_size, scale_pointer, bias_pointer, + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), + threadblock_offset.row())) {} + + /// Construct a PredicatedTileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorIterator( + int problem_size, ///< Extent of tensor + ConstPointer scale_pointer, ///< Pointer to the start of the scale vector + ConstPointer bias_pointer, ///< Pointer to the start of the bias vector + int thread_id ///< ID of each participating thread + ) + : PredicatedScaleBiasVectorIterator(problem_size, + scale_pointer, bias_pointer, + thread_id, make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// threadblock tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + iterator_.load(frag); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..3640709868602584f93e3409a251c0baff19d18d --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h @@ -0,0 +1,2118 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates calculating the address and predicates to the load of tiles + from pitch-linear rank=2 tensors. + + This iterator uses masks to guard out-of-bounds accesses. The first tile this + iterator visits maybe partial, then the remaining tiles are complete. So, we + only need to compute the predicates twice, once before the first tile and + once for the remaining full tiles which can share the same predicates. + + A precomputed "Params" object minimizes the amount of state that must be + stored in registers, and integer addition is used to advance the pointer + through memory. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/permute.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h" + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedTileAccessIteratorPredicates +/// +template +class PredicatedTileAccessIteratorPredicates { + public: + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorCoord = typename Layout::TensorCoord; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + static int const kPredicatesPerByte = 4; + static int const kPredicatesPerWord = 4 * kPredicatesPerByte; + + static int const kPredicateCount = ThreadMap::Iterations::kCount * kAccessesPerVector; + + /// Number of 32b words containing predicates + static int const kPredicateByteCount = + (kPredicateCount + kPredicatesPerByte - 1) / kPredicatesPerByte; + static int const kPredicateWordCount = (kPredicateByteCount + 3) / 4; + + static unsigned const kPredicateMask = (1u << kPredicatesPerByte) - 1u; + + static_assert(kPredicateWordCount <= 4, "Too many predicates."); + + /// Predicate vector stores mask to guard accesses + using Mask = Array; + +// private: + /// Guard predicates + uint32_t predicates_[kPredicateWordCount]; + + /// Size of tensor + TensorCoord extent_; + + /// Initial offset for each thread + TensorCoord thread_offset_; + + /// Offset to the first steady-state tile + TensorCoord residue_offset_; + + /// Iteration along vectors implied by the thread map + int iteration_vector_; + + /// Iteration in the contiguous dimension + int iteration_contiguous_; + + /// Iteration in the strided dimension + int iteration_strided_; + + public: + /// Computes predicates based on internally tracked per-thread offset. + CUTLASS_DEVICE + void compute_predicates_( + /// Extent of the matrix window + TensorCoord extent, + /// optionally, simplify predicate calculation during 'steady state' phase + bool is_steady_state = false) { + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + predicates_[i] = 0u; + } + + CUTLASS_PRAGMA_UNROLL + for (int access_idx = 0; access_idx < ThreadMap::Iterations::kCount * kAccessesPerVector; ++access_idx) { + + int s = access_idx / (ThreadMap::Iterations::kContiguous * kAccessesPerVector); + + int access_residual = access_idx % (ThreadMap::Iterations::kContiguous * kAccessesPerVector); + + int c = access_residual / kAccessesPerVector; + int v = access_residual % kAccessesPerVector; + + TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous + v * AccessType::kElements, + s * ThreadMap::Delta::kStrided); + + TensorCoord coord = thread_offset_ + iteration_coord; + + bool guard; + + if (is_steady_state) { + if (kAdvanceRank == 0) { + guard = (coord.strided() < extent.strided()); + } else { + guard = (coord.contiguous() < extent.contiguous()); + } + } else { + guard = (coord.strided() < extent.strided() && + coord.contiguous() < extent.contiguous()); + } + + int pred_idx = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s); + + int word_idx = pred_idx / kPredicatesPerWord; + int residual = pred_idx % kPredicatesPerWord; + int byte_idx = residual / kPredicatesPerByte; + int bit_idx = residual % kPredicatesPerByte; + + predicates_[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx)); + + } + + } + + CUTLASS_HOST_DEVICE + void set_predicates(int thread_id, TensorCoord const &threadblock_offset) { + + TensorCoord residue_extent; + if (kAdvanceRank) { + + typename TensorCoord::Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.strided()) % Shape::kStrided; + if (!residue_size) { + residue_size = Shape::kStrided; + } + + residue_offset_ = make_Coord(0, residue_size); + residue_extent = make_Coord( + extent_.contiguous(), + min(threadblock_offset.strided() + residue_size, extent_.strided()) + ); + } else { + + typename TensorCoord::Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.contiguous()) % Shape::kContiguous; + if (!residue_size) { + residue_size = Shape::kContiguous; + } + + residue_offset_ = make_Coord(residue_size, 0); + + residue_extent = make_Coord( + min(extent_.contiguous(), threadblock_offset.contiguous() + residue_size), + extent_.strided() + ); + } + + // Per-thread offset in logical coordinates of tensor + thread_offset_ = threadblock_offset + ThreadMap::initial_offset(thread_id); + + compute_predicates_(residue_extent, false); + + set_iteration_index(0); + } + + /// Default constructor + PredicatedTileAccessIteratorPredicates() = default; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorPredicates( + /// Extent of tensor + TensorCoord extent) + : extent_(extent) { + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorPredicates &operator++() { + + return *this; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + predicates_[i] = enable ? 0u : predicates_[i]; + } + + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + predicates_[i] = 0xffffffff; + } + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + predicates_[i] = mask[i]; + } + + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + mask[i] = predicates_[i]; + } + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() const { + + + int pred_idx = + iteration_vector_ + kAccessesPerVector * (iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous); + + int word_idx = pred_idx / kPredicatesPerWord; + int residual = pred_idx % kPredicatesPerWord; + int byte_idx = residual / kPredicatesPerByte; + int bit_idx = residual % kPredicatesPerByte; + + bool pred = (predicates_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0; + return pred; + + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedTileAccessIterator +/// +template +class PredicatedTileAccessIterator; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for pitch-linear data. +/// +template +class PredicatedTileAccessIterator { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates< + Shape, Element, Layout, AdvanceRank, ThreadMap, AccessType>; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + static bool constexpr Permute = !platform::is_same::value + && !platform::is_same>::value; + + using Mask = typename UnderlyingPredicates::Mask; + + /// Uses a non-template class + struct Params : PredicatedTileAccessIteratorParams { + + using Base = PredicatedTileAccessIteratorParams; + + /// Default constructor + Params() = default; + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) : + Base(layout.stride(0), + MakePredicatedTileAccessIteratorDesc()() + ) { } + + CUTLASS_HOST_DEVICE + Params(Base const &base) : + Base(base) { } + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char *; + + private: + // + // Data members + // + + UnderlyingPredicates the_predicates; + + /// Parameters object with precomputed internal state + Params params_; + + /// Internal pointer to first access of tile + BytePointer pointer_; + + /// Used for out-of-order visitation + bool is_residue_tile_; + + /// Below is used when Gather is turned on. We need to record strided_offset + /// and contiguous_offset separated to compute the offset by using + /// + /// offset = contiguous_offset + indices[strided_offset] + + /// Gather indices + int const *indices_; + + /// Function to perform layout permutation and offset computation + PermuteLayout permute_layout_; + + /// Tracks thread's coordinate offset in the matrix for current tile. + /// This is only used in the following cases: + /// - when Gather is true, strided coordinate needed to access indices (contiguous offset is tracked via pointer_) + /// - when Permute is true, both coordinates are needed as input into permutation function (pointer_ is fixed) + TensorCoord coord_offset_; + + private: + /// Computes predicates based on internally tracked per-thread offset. + CUTLASS_DEVICE + void compute_predicates_( + /// Extent of the matrix window + TensorCoord extent, + /// optionally, simplify predicate calculation during 'steady state' phase + bool is_steady_state = false) { + the_predicates.compute_predicates_(extent, is_steady_state); + } + + public: + + /// Default constructor + PredicatedTileAccessIterator() = default; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset, + /// Gather indices + int const *indices = nullptr) + : params_(params), + pointer_(reinterpret_cast( + const_cast(pointer))), + the_predicates(extent), + is_residue_tile_(true), + indices_(indices), + permute_layout_(TensorCoord(extent.contiguous(), extent.strided()), params.stride_) { + + the_predicates.set_predicates(thread_id, threadblock_offset); + + if (Gather) { + assert(indices_); + } + + // update internal pointers + Layout layout(params_.stride_); + + if (!Gather && !Permute) { + add_pointer_offset(layout(the_predicates.thread_offset_)); + } else { + coord_offset_ = the_predicates.thread_offset_; + if (!Permute) { + add_pointer_offset(layout(make_Coord(coord_offset_.contiguous(), 0))); + } + } + } + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id) + : PredicatedTileAccessIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + the_predicates.set_iteration_index(index); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += sizeof_bits::value * pointer_offset / 8; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_DEVICE + void add_tile_offset( + TensorCoord const &tile_offset) { + if (is_residue_tile_) { + + the_predicates.thread_offset_ += the_predicates.residue_offset_; + + the_predicates.compute_predicates_(the_predicates.extent_, true); + + Layout layout(params_.stride_); + + if (!Gather && !Permute) { + add_pointer_offset(layout(the_predicates.residue_offset_)); + + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided() - 1); + pointer_ += Shape::kContiguous * tile_offset.contiguous() * sizeof_bits::value / 8; + } else { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous() - 1); + pointer_ += Shape::kStrided * tile_offset.strided() * sizeof_bits::value / 8; + } + } else { + coord_offset_.strided() = the_predicates.thread_offset_.strided() + Shape::kStrided * (tile_offset.strided() - kAdvanceRank); + if (!Permute) { + add_pointer_offset(layout(make_Coord(the_predicates.residue_offset_.contiguous(), 0))); + add_pointer_offset(Shape::kContiguous * (tile_offset.contiguous() - (1 - kAdvanceRank))); + } else { + coord_offset_.contiguous() = the_predicates.thread_offset_.contiguous() + Shape::kContiguous * (tile_offset.contiguous() - (1 - kAdvanceRank)); + } + } + } else { + if (!Gather && !Permute) { + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided()); + pointer_ += Shape::kContiguous * tile_offset.contiguous(); + } else { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous()); + pointer_ += Shape::kStrided * tile_offset.strided(); + } + } else { + coord_offset_.strided() += Shape::kStrided * tile_offset.strided(); + if (!Permute) { + add_pointer_offset(Shape::kContiguous * tile_offset.contiguous()); + } else { + coord_offset_.contiguous() += Shape::kContiguous * tile_offset.contiguous(); + } + } + } + + is_residue_tile_ = false; + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + + if (Gather || Permute) + { + if (!valid()) { + return nullptr; + } + + Index coord_contig = (Permute ? coord_offset_.contiguous() : 0) + the_predicates.iteration_contiguous_ * ThreadMap::Delta::kContiguous + the_predicates.iteration_vector_ * AccessType::kElements; + Index coord_strided = coord_offset_.strided() + the_predicates.iteration_strided_ * ThreadMap::Delta::kStrided; + if (Gather) { + coord_strided = indices_[coord_strided]; + } + + LongIndex offset = Permute ? permute_layout_(TensorCoord(coord_contig, coord_strided)) : (coord_strided * LongIndex(params_.stride_) + coord_contig); + return reinterpret_cast(pointer_ + OffsetBytes(offset)); + } + + return reinterpret_cast( + pointer_ + + the_predicates.iteration_contiguous_ * (ThreadMap::Delta::kContiguous * sizeof_bits::value) / 8) + the_predicates.iteration_vector_; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator &operator++() { + + the_predicates.operator++(); + + ++the_predicates.iteration_vector_; + if (the_predicates.iteration_vector_ < kAccessesPerVector) { + return *this; + } + + the_predicates.iteration_vector_ = 0; + ++the_predicates.iteration_contiguous_; + + if (the_predicates.iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + + // Enter here only if (iteration_contiguous_ == ThreadMap::Iteration::kContiguous) + the_predicates.iteration_contiguous_ = 0; + ++the_predicates.iteration_strided_; + + if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) { + if (!Gather && !Permute) { + pointer_ += params_.inc_strided_; + } + + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + the_predicates.iteration_strided_ = 0; + + if (!Gather && !Permute) { + // advance to next tile + pointer_ += params_.inc_next_; + + // now return to start tile - if the iterator is subsequently advanced, this + // subtraction as well as the subsequent integer addition are both elided by + // the compiler. + pointer_ -= params_.inc_advance_; + } + + return *this; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator operator++(int) { + PredicatedTileAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + the_predicates.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + the_predicates.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { + the_predicates.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { + the_predicates.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() const { + return the_predicates.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for column-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileAccessIterator { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedTileAccessIterator< + layout::PitchLinearShape, Element, + layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessType, + Gather, PermuteLayout>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + /// Default constructor + Params() = default; + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::PitchLinear(layout.stride(0))){}; + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const &base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + + /// Default constructor + PredicatedTileAccessIterator() = default; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + ///< Precomputed parameters object + Params const ¶ms, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const &threadblock_offset, + int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization + ) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.row(), + threadblock_offset.column()), + indices) {} + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator operator++(int) { + PredicatedTileAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for row-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileAccessIterator { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedTileAccessIterator< + layout::PitchLinearShape, Element, + layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessType, + Gather, PermuteLayout>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + /// Default constructor + Params() = default; + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::PitchLinear(layout.stride(0))){}; + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const &base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + + /// Default constructor + PredicatedTileAccessIterator() = default; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + ///< Precomputed parameters object + Params const ¶ms, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const &threadblock_offset, + /// Gather indices + int const *indices = nullptr) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), + threadblock_offset.row()), + indices) {} + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator operator++(int) { + PredicatedTileAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for affine rank 2 data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileAccessIterator, + AdvanceRank, ThreadMap_, AccessType_, false, + layout::NoPermute> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRankN<2>; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates< + Shape, Element, layout::PitchLinear, AdvanceRank, ThreadMap, AccessType>; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingPredicates::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + friend PredicatedTileAccessIterator; + + private: + /// stride of pitch-linear layout (units of Element) + Coord stride_; + /// amount (in byte) to increment pointer to move to next access along + /// contiguous dimension + LongIndex inc_contiguous_; + /// amount (in byte) to increment pointer from first access of current + /// contiguous dimension to first access of next one. + LongIndex inc_strided_; + /// amount (in byte) to increment pointer from last access of current + /// contiguous dimension to first access of next one. + LongIndex inc_next_strided_; + /// amount (in byte) to increment pointer from last access to first access + /// of next tile + LongIndex inc_next_; + /// amount (in byte) to increment pointer from first access of current tile + /// to first access of next tile + LongIndex inc_advance_; + + public: + + // Default ctor + CUTLASS_HOST_DEVICE + Params(): stride_(0), inc_contiguous_(0), inc_strided_(0), inc_next_(0), inc_advance_(0) { } + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) : stride_({layout.stride(0), layout.stride(1)}) { + inc_contiguous_ = (LongIndex(stride_[0]) * ThreadMap::Delta::kContiguous) * + sizeof_bits::value / 8; + + inc_strided_ = (LongIndex(stride_[1]) * ThreadMap::Delta::kStrided) * + sizeof_bits::value / 8; + + inc_next_strided_ = inc_strided_ - LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_; + + if (kAdvanceRank) { + // advance along strided dimension + inc_advance_ = + Shape::kStrided * LongIndex(stride_[1]) * sizeof_bits::value / 8; + } else { + // advance along contiguous dimension + inc_advance_ = Shape::kContiguous * stride_[0] * sizeof_bits::value / 8; + } + + inc_next_ = inc_advance_ - LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_ - LongIndex(ThreadMap::Iterations::kStrided - 1) * inc_strided_; + }; + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char *; + + // + // Data members + // + + /// Parameters object with precomputed internal state + Params params_; + + /// Internal pointer to first access of tile + BytePointer pointer_; + + UnderlyingPredicates the_predicates; + + /// Used for out-of-order visitation + bool is_residue_tile_; + + private: + /// Computes predicates based on internally tracked per-thread offset. + CUTLASS_DEVICE + void compute_predicates_( + /// Extent of the matrix window + TensorCoord extent, + /// optionally, simplify predicate calculation during 'steady state' phase + bool is_steady_state = false) { + the_predicates.compute_predicates_(extent, is_steady_state); + } + + public: + + /// Default constructor + PredicatedTileAccessIterator() = default; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + ///< Precomputed parameters object + Params const ¶ms, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const &threadblock_offset, + int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization + ) + : params_(params), + pointer_(reinterpret_cast( + const_cast(pointer))), + the_predicates(extent), + is_residue_tile_(true) { + + the_predicates.set_predicates(thread_id, threadblock_offset); + + // update internal pointers + Layout layout(params_.stride_); + add_pointer_offset(layout(the_predicates.thread_offset_)); + } + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { the_predicates.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += sizeof_bits::value * pointer_offset / 8; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + if (is_residue_tile_) { + + the_predicates.thread_offset_ += the_predicates.residue_offset_; + + Layout layout(params_.stride_); + add_pointer_offset(layout(the_predicates.residue_offset_)); + + the_predicates.compute_predicates_(the_predicates.extent_, true); + + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset[1] - 1); + pointer_ += Shape::kContiguous * tile_offset[0]; + } else { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset[0] - 1); + pointer_ += Shape::kStrided * tile_offset[1]; + } + } else { + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset[1]); + pointer_ += Shape::kContiguous * tile_offset[0]; + } else { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset[0]); + pointer_ += Shape::kStrided * tile_offset[1]; + } + } + is_residue_tile_ = false; + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(pointer_) + the_predicates.iteration_vector_; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator &operator++() { + the_predicates.operator++(); + ++the_predicates.iteration_vector_; + if (the_predicates.iteration_vector_ < kAccessesPerVector) { + return *this; + } + + the_predicates.iteration_vector_ = 0; + ++the_predicates.iteration_contiguous_; + + if (the_predicates.iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + pointer_ += params_.inc_contiguous_; + return *this; + } + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + the_predicates.iteration_contiguous_ = 0; + ++the_predicates.iteration_strided_; + + if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) { + pointer_ += params_.inc_next_strided_; + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + the_predicates.iteration_strided_ = 0; + + // advance to next tile + pointer_ += params_.inc_next_; + + // now return to start tile - if the iterator is subsequently advanced, this + // subtraction as well as the subsequent integer addition are both elided by + // the compiler. + pointer_ -= params_.inc_advance_; + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator operator++(int) { + PredicatedTileAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { the_predicates.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { the_predicates.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { the_predicates.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { the_predicates.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return the_predicates.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for affine rank 2 column-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileAccessIterator { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileAccessIterator< + layout::PitchLinearShape, Element, + layout::AffineRankN<2>, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + /// Default constructor + Params() = default; + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))){}; + }; + + private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + + public: + + /// Default constructor + PredicatedTileAccessIterator() = default; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + ///< Precomputed parameters object + Params const ¶ms, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const &threadblock_offset, + int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization + ) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.row(), + threadblock_offset.column())) {} + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset(make_Coord(tile_offset.row(), tile_offset.column())); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator operator++(int) { + PredicatedTileAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for affine rank-2 row-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileAccessIterator { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileAccessIterator< + layout::PitchLinearShape, Element, + layout::AffineRankN<2>, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + /// Default constructor + Params() = default; + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))){}; + }; + + private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + + public: + + /// Default constructor + PredicatedTileAccessIterator() = default; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + ///< Precomputed parameters object + Params const ¶ms, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const &threadblock_offset, + int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization + ) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), + threadblock_offset.row())) {} + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset(make_Coord(tile_offset.column(), tile_offset.row())); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator operator++(int) { + PredicatedTileAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for column-major interleaved data. +/// It is mapped to the congruous layout. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// + +template +class PredicatedTileAccessIterator, + AdvanceRank, ThreadMap_, AccessType_, false, + layout::NoPermute> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::ColumnMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedTileAccessIterator< + layout::PitchLinearShape, + Element, layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + /// Default constructor + Params() = default; + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const &base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + + /// Default constructor + PredicatedTileAccessIterator() = default; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset, + int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization + ) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.row() * kInterleavedK, + extent.column() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row() * kInterleavedK, + threadblock_offset.column() / kInterleavedK)) {} + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator operator++(int) { + PredicatedTileAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { return iterator_.valid(); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for row-major interleaved data. +// It is mapped to the congruous layout. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileAccessIterator, + AdvanceRank, ThreadMap_, AccessType_, false, + layout::NoPermute> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::RowMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedTileAccessIterator< + layout::PitchLinearShape, + Element, layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, + AccessType>; + + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + /// Default constructor + Params() = default; + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const &base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + + /// Default constructor + PredicatedTileAccessIterator() = default; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset, + int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization + ) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.column() * kInterleavedK, + extent.row() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column() * kInterleavedK, + threadblock_offset.row() / kInterleavedK)) {} + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator operator++(int) { + PredicatedTileAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { return iterator_.valid(); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_2dthreadtile.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_2dthreadtile.h new file mode 100644 index 0000000000000000000000000000000000000000..93eac72e40ddf6b0f3d268957873417e5d5a442f --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_2dthreadtile.h @@ -0,0 +1,834 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates calculating the address and predicates to the load of tiles + from pitch-linear rank=2 tensors. + + This iterator uses masks to guard out-of-bounds accesses and visits the last + "residue" tile first, with the objective of minimizing predicate mask updates + during steady-state operation. + + A precomputed "Params" object minimizes the amount of state that must be + stored in registers, and integer addition is used to advance the pointer + through memory. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h" + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedTileAccessIterator2dThreadTile +/// +template +class PredicatedTileAccessIterator2dThreadTile; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator2dThreadTile for pitch-linear data. +/// +template +class PredicatedTileAccessIterator2dThreadTile { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + static int const kPredicatesPerByte = 4; + static int const kPredicatesPerWord = 4 * kPredicatesPerByte; + + /// Number of 32b words containing predicates + static int const kPredicateByteCount = (ThreadMap::Iterations::kCount * ThreadMap::ThreadAccessShape::kStrided + kPredicatesPerByte - 1) / kPredicatesPerByte; + static int const kPredicateWordCount = (kPredicateByteCount + 3) / 4; + + static unsigned const kPredicateMask = (1u << kPredicatesPerByte) - 1u; + + static_assert(kPredicateWordCount <= 4, "Too many predicates."); + + /// Predicate vector stores mask to guard accesses + using Mask = Array; + + /// Uses a non-template class + struct Params : PredicatedTileAccessIteratorParams { + + public: + friend PredicatedTileAccessIterator2dThreadTile; + + using Base = PredicatedTileAccessIteratorParams; + + // Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) : + Base(layout.stride(0), + MakePredicatedTileAccessIteratorDesc()() + ) { } + + CUTLASS_HOST_DEVICE + Params(Base const &base) : + Base(base) { } + }; + + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char *; + + private: + // + // Data members + // + + /// Parameters object with precomputed internal state + Params const ¶ms_; + + /// Internal pointer to first access of tile + BytePointer pointer_; + + /// Guard predicates + uint32_t predicates_[kPredicateWordCount]; + + /// Size of tensor + TensorCoord extent_; + + /// Initial offset for each thread + TensorCoord thread_offset_; + + /// Index of residue tile + int residue_tile_idx_; + + /// Used for out-of-order visitation + bool is_residue_tile_; + + /// Iteration in the contiguous dimension + int iteration_contiguous_; + + /// Iteration in the strided dimension + int iteration_strided_; + + /// Tracks iterations within the thread loop + int iteration_thread_; + + private: + /// Computes predicates based on internally tracked per-thread offset. + CUTLASS_HOST_DEVICE + void compute_predicates_( + /// optionally, simplify predicate calculation during 'steady state' phase + bool is_steady_state = false) { + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + predicates_[i] = 0u; + } + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int ts = 0; ts < ThreadMap::ThreadAccessShape::kStrided; ts++) { + + TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous, + ts + s * ThreadMap::Delta::kStrided); + + TensorCoord coord = thread_offset_ + iteration_coord; + + bool guard; + + if (is_steady_state) { + if (kAdvanceRank == 0) { + guard = (coord.strided() < extent_.strided()); + } else { + guard = (coord.contiguous() < extent_.contiguous()); + } + } else { + guard = (coord.strided() < extent_.strided() && + coord.contiguous() < extent_.contiguous()); + } + + int pred_idx = ts + c * ThreadMap::ThreadAccessShape::kStrided + s * ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided; + int word_idx = pred_idx / kPredicatesPerWord; + int residual = pred_idx % kPredicatesPerWord; + int byte_idx = residual / kPredicatesPerByte; + int bit_idx = residual % kPredicatesPerByte; + + predicates_[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx)); + + } + } + } + + } + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator2dThreadTile( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : params_(params), + pointer_(reinterpret_cast( + const_cast(pointer))), + extent_(extent), + is_residue_tile_(true) { + + + TensorCoord residue_offset; + if (kAdvanceRank) { + residue_tile_idx_ = + (extent_[kAdvanceRank] - threadblock_offset[kAdvanceRank] - 1) / + Shape::kStrided; + residue_offset = make_Coord(0, residue_tile_idx_ * Shape::kStrided); + } else { + residue_tile_idx_ = + (extent_[kAdvanceRank] - threadblock_offset[kAdvanceRank] - 1) / + Shape::kContiguous; + residue_offset = make_Coord(residue_tile_idx_ * Shape::kContiguous, 0); + } + + // Per-thread offset in logical coordinates of tensor + thread_offset_ = threadblock_offset + residue_offset + + ThreadMap::initial_offset(thread_id); + + // update internal pointers + Layout layout(params_.stride_); + add_pointer_offset(layout(thread_offset_)); + + compute_predicates_(false); + + set_iteration_index(0); + } + + /// Construct a PredicatedTileAccessIterator2dThreadTile with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator2dThreadTile( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id) + : PredicatedTileAccessIterator2dThreadTile(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + + int residual = index % (ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided); + iteration_strided_ = index / (ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided); + + iteration_contiguous_ = residual / ThreadMap::ThreadAccessShape::kStrided; + iteration_thread_ = residual % ThreadMap::ThreadAccessShape::kStrided; + + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += int(sizeof(Element)) * pointer_offset; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_DEVICE + void add_tile_offset( + TensorCoord const &tile_offset) { + if (is_residue_tile_) { + TensorCoord residue_offset; + if (kAdvanceRank) { + residue_offset = TensorCoord(0, residue_tile_idx_ * Shape::kStrided); + } else { + residue_offset = TensorCoord(residue_tile_idx_ * Shape::kContiguous, 0); + } + + thread_offset_ -= residue_offset; + + Layout layout(params_.stride_); + add_pointer_offset(-layout(residue_offset)); + + compute_predicates_(true); + + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * (tile_offset.strided() - 1); + pointer_ += Shape::kContiguous * tile_offset.contiguous(); + } else { + pointer_ += params_.inc_advance_ * (tile_offset.contiguous() - 1); + pointer_ += Shape::kStrided * tile_offset.strided(); + } + } else { + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * tile_offset.strided(); + pointer_ += Shape::kContiguous * tile_offset.contiguous(); + } else { + pointer_ += params_.inc_advance_ * tile_offset.contiguous(); + pointer_ += Shape::kStrided * tile_offset.strided(); + } + } + is_residue_tile_ = false; + } + + CUTLASS_HOST_DEVICE + AccessType *get() const { + + AccessType *ret_val = reinterpret_cast( + pointer_ + (iteration_thread_ * params_.stride_ + iteration_contiguous_ * ThreadMap::Delta::kContiguous) * int(sizeof(Element))); + + return ret_val; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator2dThreadTile &operator++() { + + iteration_thread_++; + + if (iteration_thread_ < ThreadMap::ThreadAccessShape::kStrided) + return *this; + + iteration_thread_ = 0; + + ++iteration_contiguous_; + + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) + return *this; + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + iteration_contiguous_ = 0; + ++iteration_strided_; + + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + pointer_ += params_.inc_strided_; + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + iteration_strided_ = 0; + + // advance to next tile + pointer_ += params_.inc_next_; + + // now return to start tile - if the iterator is subsequently advanced, this + // subtraction as well as the subsequent integer addition are both elided by + // the compiler. + pointer_ -= params_.inc_advance_; + + return *this; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator2dThreadTile operator++(int) { + PredicatedTileAccessIterator2dThreadTile self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + predicates_[i] = enable ? 0u : predicates_[i]; + } + + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + predicates_[i] = 0xffffffff; + } + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + predicates_[i] = mask[i]; + } + + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + mask[i] = predicates_[i]; + } + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + + int pred_idx = + iteration_thread_ + + iteration_contiguous_ * ThreadMap::ThreadAccessShape::kStrided + + iteration_strided_ * ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided; + + int word_idx = pred_idx / kPredicatesPerWord; + int residual = pred_idx % kPredicatesPerWord; + int byte_idx = residual / kPredicatesPerByte; + int bit_idx = residual % kPredicatesPerByte; + + bool pred = (predicates_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0; + + return pred; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator2dThreadTile for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileAccessIterator2dThreadTile { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedTileAccessIterator2dThreadTile< + layout::PitchLinearShape, Element, + layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessType>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIterator2dThreadTile; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::PitchLinear(layout.stride(0))){} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const &base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator2dThreadTile( + ///< Precomputed parameters object + Params const ¶ms, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const &threadblock_offset) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.row(), + threadblock_offset.column())) {} + + /// Construct a PredicatedTileAccessIterator2dThreadTile with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator2dThreadTile( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIterator2dThreadTile(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator2dThreadTile &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator2dThreadTile operator++(int) { + PredicatedTileAccessIterator2dThreadTile self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator2dThreadTile for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileAccessIterator2dThreadTile { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedTileAccessIterator2dThreadTile< + layout::PitchLinearShape, Element, + layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessType>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIterator2dThreadTile; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::PitchLinear(layout.stride(0))){} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const &base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator2dThreadTile( + ///< Precomputed parameters object + Params const ¶ms, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const &threadblock_offset) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), + threadblock_offset.row())) {} + + /// Construct a PredicatedTileAccessIterator2dThreadTile with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator2dThreadTile( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIterator2dThreadTile(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator2dThreadTile &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator2dThreadTile operator++(int) { + PredicatedTileAccessIterator2dThreadTile self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_params.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_params.h new file mode 100644 index 0000000000000000000000000000000000000000..5e509a344e955438ea4eabe6806ed2ab79343d36 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_params.h @@ -0,0 +1,290 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/detail/helper_macros.hpp" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Predicated tile access iterator descriptor object containing template dependent state +struct PredicatedTileAccessIteratorDesc { + + int element_size_bits = -1; + int advance_rank = -1; + layout::PitchLinearCoord threadblock_shape; + layout::PitchLinearCoord threadmap_iterations; + layout::PitchLinearCoord threadmap_delta; + + // + // Methods + // + + PredicatedTileAccessIteratorDesc() = default; + + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorDesc( + int element_size_bits_, + int advance_rank_, + layout::PitchLinearCoord threadblock_shape_, + layout::PitchLinearCoord threadmap_iterations_, + layout::PitchLinearCoord threadmap_delta_ + ): + element_size_bits(element_size_bits_), + advance_rank(advance_rank_), + threadblock_shape(threadblock_shape_), + threadmap_iterations(threadmap_iterations_), + threadmap_delta(threadmap_delta_) + { + #if 0 + printf("PredicatedTileAccessIteratorDesc(%d, %d, {%d, %d}, {%d, %d}, {%d, %d}})\n", + element_size_bits, + advance_rank, + threadblock_shape.contiguous(), threadblock_shape.strided(), + threadmap_iterations.contiguous(), threadmap_iterations.strided(), + threadmap_delta.contiguous(), threadmap_delta.strided()); + #endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Helper template to construct an PredicatedTileAccessIteratorDesc from a template +// dependent state +template < + typename Shape, typename Element, typename Layout, + int AdvanceRank, typename ThreadMap> + struct MakePredicatedTileAccessIteratorDesc; +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for pitch-linear data. +template < + typename Shape, typename Element, int AdvanceRank, + typename ThreadMap> +struct MakePredicatedTileAccessIteratorDesc < + Shape, Element, layout::PitchLinear, AdvanceRank, ThreadMap> { + + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorDesc operator()() { + + return PredicatedTileAccessIteratorDesc( + sizeof_bits::value, + AdvanceRank, + {Shape::kContiguous, Shape::kStrided}, + {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, + {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided} + ); +} + +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for column-major data. +template < + typename Shape, typename Element, int AdvanceRank, + typename ThreadMap> +struct MakePredicatedTileAccessIteratorDesc < + Shape, Element, layout::ColumnMajor, AdvanceRank, ThreadMap> { + + static int const kAdvanceRank = AdvanceRank; + + using UnderlyingMakeOperator = MakePredicatedTileAccessIteratorDesc< + layout::PitchLinearShape, Element, + layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap>; + + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorDesc operator()() { + + return UnderlyingMakeOperator()(); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for row-major data. +template < + typename Shape, typename Element, int AdvanceRank, + typename ThreadMap> +struct MakePredicatedTileAccessIteratorDesc < + Shape, Element, layout::RowMajor, AdvanceRank, ThreadMap> { + + static int const kAdvanceRank = AdvanceRank; + + using UnderlyingMakeOperator = MakePredicatedTileAccessIteratorDesc< + layout::PitchLinearShape, Element, + layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap>; + + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorDesc operator()() { + + return UnderlyingMakeOperator()(); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for column-major interleaved data. +template < + typename Shape, typename Element, int AdvanceRank, + typename ThreadMap, int InterleavedK> +struct MakePredicatedTileAccessIteratorDesc < + Shape, Element, layout::ColumnMajorInterleaved, AdvanceRank, ThreadMap> { + + static int const kAdvanceRank = AdvanceRank; + static int const kInterleavedK = InterleavedK; + + using UnderlyingMakeOperator = MakePredicatedTileAccessIteratorDesc< + layout::PitchLinearShape, Element, + layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap>; + + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorDesc operator()() { + + return UnderlyingMakeOperator()(); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for roww-major interleaved data. +template < + typename Shape, typename Element, int AdvanceRank, + typename ThreadMap, int InterleavedK> +struct MakePredicatedTileAccessIteratorDesc < + Shape, Element, layout::RowMajorInterleaved, AdvanceRank, ThreadMap> { + + static int const kAdvanceRank = AdvanceRank; + static int const kInterleavedK = InterleavedK; + + using UnderlyingMakeOperator = MakePredicatedTileAccessIteratorDesc< + layout::PitchLinearShape, Element, + layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap>; + + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorDesc operator()() { + + return UnderlyingMakeOperator()(); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Parameters struct +// + +struct PredicatedTileAccessIteratorParams { + + using Index = int32_t; + using LongIndex = int64_t; + + // + // Data members + // + /// stride of pitch-linear layout (units of Element) + LongIndex stride_ = 0; + /// amount (in byte) to increment pointer to move to next access along + /// strided dimension + LongIndex inc_strided_ = 0; + /// amount (in byte) to increment pointer from last access to first access + /// of next tile + LongIndex inc_next_ = 0; + /// amount (in byte) to increment pointer from first access of current tile + /// to first access of next tile + LongIndex inc_advance_ = 0; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Status initialize(LongIndex stride, PredicatedTileAccessIteratorDesc desc) { + CUTLASS_ASSERT(desc.element_size_bits > 0); + CUTLASS_ASSERT(desc.advance_rank == 0 || desc.advance_rank == 1); + + stride_ = stride; + + inc_strided_ = (LongIndex(stride_) * desc.threadmap_delta.strided()) * + desc.element_size_bits / 8; + + if (desc.advance_rank) { + // advance along strided dimension + inc_advance_ = + desc.threadblock_shape.strided() * LongIndex(stride_) * desc.element_size_bits / 8; + } else { + // advance along contiguous dimension + inc_advance_ = desc.threadblock_shape.contiguous() * desc.element_size_bits / 8; + } + + inc_next_ = inc_advance_ - LongIndex(desc.threadmap_iterations.strided() - 1) * + desc.threadmap_delta.strided() * LongIndex(stride_) * + desc.element_size_bits / 8; + + return Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Status initialize(Index stride, PredicatedTileAccessIteratorDesc desc) { + return initialize(LongIndex(stride), desc); + } + + PredicatedTileAccessIteratorParams() = default; + + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorParams(Index stride, PredicatedTileAccessIteratorDesc desc) { + initialize(stride, desc); + } + + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorParams(LongIndex stride, PredicatedTileAccessIteratorDesc desc) { + initialize(stride, desc); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_triangular_matrix.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_triangular_matrix.h new file mode 100644 index 0000000000000000000000000000000000000000..f657fe25813567b47156047f6ef023b678ac097f --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_triangular_matrix.h @@ -0,0 +1,892 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates calculating the address and predicates to the load of tiles + from pitch-linear rank=2 tensors. + + This iterator uses masks to guard out-of-bounds accesses and visits the last + "residue" tile first, with the objective of minimizing predicate mask updates + during steady-state operation. + + A precomputed "Params" object minimizes the amount of state that must be + stored in registers, and integer addition is used to advance the pointer + through memory. + + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedTileAccessIteratorTriangularMatrix +/// +template +class PredicatedTileAccessIteratorTriangularMatrix; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorTriangularMatrix for pitch-linear data. +/// +template +class PredicatedTileAccessIteratorTriangularMatrix { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + using CompareOp = typename TrMatrixCompareOp::Type; + + static_assert( kFillMode == FillMode::kFull || + ((kFillMode == FillMode::kLower || kFillMode == FillMode::kUpper) && AccessType::kElements == 1), + "BLAS3 iterator for the triangular/symmetric matrix must use AccessType::kElements as 1"); + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + static int const kPredicatesPerByte = 4; + static int const kPredicatesPerWord = 4 * kPredicatesPerByte; + + static int const kPredicateCount = ThreadMap::Iterations::kCount * kAccessesPerVector; + + /// Number of 32b words containing predicates + static int const kPredicateByteCount = + (kPredicateCount + kPredicatesPerByte - 1) / kPredicatesPerByte; + static int const kPredicateWordCount = (kPredicateByteCount + 3) / 4; + + static unsigned const kPredicateMask = (1u << kPredicatesPerByte) - 1u; + + static_assert(kPredicateWordCount <= 4, "Too many predicates."); + + /// Predicate vector stores mask to guard accesses + using Mask = Array; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + friend PredicatedTileAccessIteratorTriangularMatrix; + + private: + /// stride of pitch-linear layout (units of Element) + StrideIndex stride_; + /// (true) pitch-linear layout is mapped to row-major matrix + /// (false) pitch-linear layout is mapped to column-major matrix + bool is_row_major_; + /// for vectorized access across the diagonal boundary guard condition is + /// checked for the element on the boundary + int access_diagonal_boundary_; + /// amount (in byte) to increment pointer to move to next access along + /// strided dimension + LongIndex inc_strided_; + /// amount (in byte) to increment pointer from last access to first access + /// of next tile + LongIndex inc_next_; + /// amount (in byte) to increment pointer from first access of current tile + /// to first access of next tile + LongIndex inc_advance_; + + public: + + // Default ctor + CUTLASS_HOST_DEVICE + Params(): stride_(0), inc_strided_(0), inc_next_(0), inc_advance_(0), is_row_major_(false), access_diagonal_boundary_(0) { } + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout, bool is_row_major, int access_diagonal_boundary) : + stride_(layout.stride(0)), is_row_major_(is_row_major), access_diagonal_boundary_(access_diagonal_boundary) { + + inc_strided_ = (LongIndex(stride_) * ThreadMap::Delta::kStrided) * + sizeof_bits::value / 8; + + if (kAdvanceRank) { + // advance along strided dimension + inc_advance_ = + Shape::kStrided * LongIndex(stride_) * sizeof_bits::value / 8; + } else { + // advance along contiguous dimension + inc_advance_ = Shape::kContiguous * sizeof_bits::value / 8; + } + + inc_next_ = inc_advance_ - LongIndex(ThreadMap::Iterations::kStrided - 1) * + ThreadMap::Delta::kStrided * LongIndex(stride_) * + sizeof_bits::value / 8; + + }; + + + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char *; + + private: + // + // Data members + // + + /// Parameters object with precomputed internal state + Params const ¶ms_; + + /// Internal pointer to first access of tile + BytePointer pointer_; + + /// Guard predicates + uint32_t predicates_[kPredicateWordCount]; + + /// Track global memory addresses on the diagonal + /// To ignore imag part for diagonal elements of hermitian matrices + uint32_t predicates_onDiag_[kPredicateWordCount]; + + /// Size of tensor + TensorCoord extent_; + + /// Initial offset for each thread + TensorCoord thread_offset_; + + /// Iteration along vectors implied by the thread map + int iteration_vector_; + + /// Iteration in the contiguous dimension + int iteration_contiguous_; + + /// Iteration in the strided dimension + int iteration_strided_; + + private: + /// Computes predicates based on internally tracked per-thread offset. + CUTLASS_DEVICE + void compute_predicates_( + /// Extent of the matrix window + TensorCoord extent) { + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + predicates_[i] = 0u; + predicates_onDiag_[i] = 0u; + } + + CompareOp compare_op; + + CUTLASS_PRAGMA_UNROLL + for (int access_idx = 0; access_idx < ThreadMap::Iterations::kCount * kAccessesPerVector; ++access_idx) { + + int s = access_idx / (ThreadMap::Iterations::kContiguous * kAccessesPerVector); + + int access_residual = access_idx % (ThreadMap::Iterations::kContiguous * kAccessesPerVector); + + int c = access_residual / kAccessesPerVector; + int v = access_residual % kAccessesPerVector; + + TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous + v * AccessType::kElements, + s * ThreadMap::Delta::kStrided); + + TensorCoord coord = thread_offset_ + iteration_coord; + + bool guard; + bool onDiag = false; + + guard = ((coord.strided() < extent.strided()) && + (coord.contiguous() < extent.contiguous())); + + + // guard access on the wrong side of the triagular matrix diagonal + if (kFillMode == FillMode::kLower || kFillMode == FillMode::kUpper) { + coord += TensorCoord{params_.access_diagonal_boundary_, 0}; + + bool triagular_guard_row_major = compare_op(coord.strided(), coord.contiguous()) | !params_.is_row_major_; + bool triagular_guard_col_major = compare_op(coord.contiguous(), coord.strided()) | params_.is_row_major_; + + guard = guard && triagular_guard_row_major && triagular_guard_col_major; + + if (kDiagType == DiagType::kUnit) { + onDiag = (guard && coord.strided() == coord.contiguous()) ? true : false; + } + } + + int pred_idx_onDiag = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s); + int word_idx_onDiag = pred_idx_onDiag / kPredicatesPerWord; + int residual_onDiag = pred_idx_onDiag % kPredicatesPerWord; + int byte_idx_onDiag = residual_onDiag / kPredicatesPerByte; + int bit_idx_onDiag = residual_onDiag % kPredicatesPerByte; + + predicates_onDiag_[word_idx_onDiag] |= (unsigned(onDiag) << (byte_idx_onDiag * 8 + bit_idx_onDiag)); + + int pred_idx = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s); + + int word_idx = pred_idx / kPredicatesPerWord; + int residual = pred_idx % kPredicatesPerWord; + int byte_idx = residual / kPredicatesPerByte; + int bit_idx = residual % kPredicatesPerByte; + + predicates_[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx)); + + } + + } + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorTriangularMatrix( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : params_(params), + pointer_(reinterpret_cast(const_cast(pointer))), + extent_(extent) { + + + // Per-thread offset in logical coordinates of tensor + thread_offset_ = threadblock_offset + ThreadMap::initial_offset(thread_id); + + // update internal pointers + Layout layout(params_.stride_); + add_pointer_offset(layout(thread_offset_)); + + compute_predicates_(extent_); + + set_iteration_index(0); + } + + /// Construct a PredicatedTileAccessIteratorTriangularMatrix with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorTriangularMatrix( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id) + : PredicatedTileAccessIteratorTriangularMatrix(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += sizeof_bits::value * pointer_offset / 8; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided()); + pointer_ += Shape::kContiguous * tile_offset.contiguous(); + thread_offset_ += TensorCoord{0, Shape::kStrided * tile_offset.strided()}; + } else { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous()); + pointer_ += Shape::kStrided * tile_offset.strided(); + thread_offset_ += TensorCoord{Shape::kContiguous * tile_offset.contiguous(), 0}; + } + + compute_predicates_(extent_); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast( + pointer_ + + iteration_contiguous_ * (ThreadMap::Delta::kContiguous * sizeof_bits::value) / 8) + iteration_vector_; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorTriangularMatrix &operator++() { + + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + + iteration_vector_ = 0; + ++iteration_contiguous_; + + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + iteration_contiguous_ = 0; + ++iteration_strided_; + + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + pointer_ += params_.inc_strided_; + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + iteration_strided_ = 0; + + // advance to next tile + pointer_ += params_.inc_next_; + + // now return to start tile - if the iterator is subsequently advanced, this + // subtraction as well as the subsequent integer addition are both elided by + // the compiler. + pointer_ -= params_.inc_advance_; + + return *this; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorTriangularMatrix operator++(int) { + PredicatedTileAccessIteratorTriangularMatrix self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + predicates_[i] = enable ? 0u : predicates_[i]; + } + + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + predicates_[i] = 0xffffffff; + } + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + predicates_[i] = mask[i]; + } + + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + mask[i] = predicates_[i]; + } + } + + /// Return if the address in on the diagonal + CUTLASS_HOST_DEVICE + bool getOnDiag() { + int pred_idx = + iteration_vector_ + kAccessesPerVector * (iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous); + + int word_idx = pred_idx / kPredicatesPerWord; + int residual = pred_idx % kPredicatesPerWord; + int byte_idx = residual / kPredicatesPerByte; + int bit_idx = residual % kPredicatesPerByte; + + bool pred = (predicates_onDiag_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0; + return pred; + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + + + int pred_idx = + iteration_vector_ + kAccessesPerVector * (iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous); + + int word_idx = pred_idx / kPredicatesPerWord; + int residual = pred_idx % kPredicatesPerWord; + int byte_idx = residual / kPredicatesPerByte; + int bit_idx = residual % kPredicatesPerByte; + + bool pred = (predicates_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0; + return pred; + + + //return true; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorTriangularMatrix for column-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileAccessIteratorTriangularMatrix { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedTileAccessIteratorTriangularMatrix< + layout::PitchLinearShape, Element, + layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, + kSideMode, kFillMode, kDiagType, AccessType>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + static int const kAccessDiagonalBoundary = + (kFillMode == FillMode::kLower) ? (AccessType::kElements - 1) : 0; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorTriangularMatrix; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::PitchLinear(layout.stride(0)), false, kAccessDiagonalBoundary){}; + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorTriangularMatrix( + ///< Precomputed parameters object + Params const ¶ms, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const &threadblock_offset) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.row(), + threadblock_offset.column())) {} + + /// Construct a PredicatedTileAccessIteratorTriangularMatrix with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorTriangularMatrix( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorTriangularMatrix(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorTriangularMatrix &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorTriangularMatrix operator++(int) { + PredicatedTileAccessIteratorTriangularMatrix self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// Return if the address in on the diagonal + CUTLASS_HOST_DEVICE + bool getOnDiag() { + return iterator_.getOnDiag(); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorTriangularMatrix for row-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileAccessIteratorTriangularMatrix { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedTileAccessIteratorTriangularMatrix< + layout::PitchLinearShape, Element, + layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, + kSideMode, kFillMode, kDiagType, AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + static int const kAccessDiagonalBoundary = + (kFillMode == FillMode::kUpper) ? (AccessType::kElements - 1) : 0; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorTriangularMatrix; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::PitchLinear(layout.stride(0)), true, kAccessDiagonalBoundary){}; + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorTriangularMatrix( + ///< Precomputed parameters object + Params const ¶ms, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const &threadblock_offset) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), + threadblock_offset.row())) {} + + /// Construct a PredicatedTileAccessIteratorTriangularMatrix with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorTriangularMatrix( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorTriangularMatrix(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorTriangularMatrix &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorTriangularMatrix operator++(int) { + PredicatedTileAccessIteratorTriangularMatrix self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// Return if the address in on the diagonal + CUTLASS_HOST_DEVICE + bool getOnDiag() { + return iterator_.getOnDiag(); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..43c4cbd1a5758e0288f82babbe7043d22f83c009 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator.h @@ -0,0 +1,1887 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of tiles from pitch-linear rank=2 tensors. + + This iterator uses masks to guard out-of-bounds accesses. The first tile this + iterator visits maybe partial, then the remaining tiles are complete. So, we + only need to compute the predicates twice, once before the first tile and + once for the remaining full tiles which can share the same predicates. + + A precomputed "Params" object minimizes the amount of state that must be stored in registers, + and integer addition is used to advance the pointer through memory. +*/ + +#pragma once + +#include "cutlass/arch/memory.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedTileIterator +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +/// Regular tile iterator using a precomputed control structure to minimize register liveness +/// and integer arithmetic. +/// +/// Layout is assumed to be invariant at the time the precomputed "Params" object is constructed. +/// +/// Base pointer and tensor extents may be specified at the time the iterator is constructed. +/// Subsequently, they are assumed to be immutable. +/// +/// Adding a logical coordinate offset may be performed at the time the iterator is constructed. +/// Subsequent additions to logical coordinate offset may be performed but are relatively expensive. +/// +/// Visitation order is intended to first visit a "residual" tile that may be partially full in +/// both the advance dimension and the steady-state dimension. This is assumed to be the last +/// tile in the iteration sequence. Advancing an iterator that has just been constructed moves to +/// the first tile that is full in the advance dimension and recomputes predicates. Subsequent +/// accesses may be performed without updating internal predicates and are efficient in terms of +/// live register state and pointer arithmetic instructions. +/// +/// To be efficient, this assumes the iterator will be dereferenced and advanced at least once +/// outside any looping structure to minimize integer arithmetic. +/// +/// Accesses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing +/// the iterator. +/// +/// +/// Example: +/// +/// An efficient pipeline structure may be constructed as follows: +/// +// template +// __global__ void kernel( +// typename Iterator::Params params, +// typename Iterator::Element *ptr, +// TensorCoord extent) { +// +// typename Iterator::Fragment fragment; +// +// TensorCoord threadblock_offset(0, 0); +// +// Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offsets); +// +// +// fragment = *iter; // load "residue" tile first +// ++iter; // advance to first "steady state" tile and update internal masks +// +// +// #pragma unroll +// for (int i = Remaining - 1; i >= 0; --i) { +// +// f(fragment); +// +// if (!i) { +// iter.clear_mask(); // light-weight operation to clear masks - subsequent loads become NO-OPs. +// } +// +// fragment = *iter; // load tile during "steady state" phase +// ++iter; // advance to next tile - lightweight due to steady-state masks +// } +// } +// +// void host(TensorView view) { +// +// using Iterator = transform::threadblock::PredicatedTileIterator; +// +// typename Iterator::Params params(view.layout()); +// +// kernel(params, view.data()); +// } +/// +/// +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + int AccessSize = ThreadMap::kElementsPerAccess, + bool Gather = false, + typename PermuteLayout = layout::NoPermute +> +class PredicatedTileIterator; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIterator for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileIterator { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + /// Type used for internal memory accesses + using AccessType = AlignedArray::value / 8)>; + + /// Underlying iterator to compute the addresses + using TileAccessIterator = + PredicatedTileAccessIterator; + + static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename TileAccessIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + using Base = typename TileAccessIterator::Params::Base; + + friend PredicatedTileIterator; + + private: + /// Parameters object + typename TileAccessIterator::Params params_; + + public: + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) : params_(layout) {} + + /// Default constructor + Params() = default; + + CUTLASS_HOST_DEVICE + Params(Base const &base) + : params_(base) {} + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char *; + + private: + // + // Data members + // + + /// Data member to the tile access iterator + TileAccessIterator address_iterator_; + + public: + + /// Default constructor + PredicatedTileIterator() = default; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset, + /// Gather indices + int const *indices = nullptr) + : address_iterator_(params.params_, pointer, extent, thread_id, + threadblock_offset, indices) {} + + /// Construct a PredicatedTileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + address_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIterator &operator++() { + if (kAdvanceRank) + address_iterator_.add_tile_offset({0, 1}); + else + address_iterator_.add_tile_offset({1, 0}); + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIterator operator++(int) { + PredicatedTileIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { address_iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { address_iterator_.get_mask(mask); } + + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + load_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + + int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + address_iterator_.set_iteration_index(idx); + char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; + + AccessType const *access_ptr = reinterpret_cast(byte_ptr); + + cutlass::arch::global_load( + frag_ptr[idx], access_ptr, address_iterator_.valid()); + + ++address_iterator_; + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { load_with_byte_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + store_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { + address_iterator_.set_iteration_index(0); + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + + int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + char *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; + AccessType *access_ptr = reinterpret_cast(byte_ptr); + + if (address_iterator_.valid()) { + *access_ptr = frag_ptr[idx]; + } + ++address_iterator_; + } + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { store_with_byte_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIterator for column-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize, + bool Gather, + typename PermuteLayout +> +class PredicatedTileIterator { +public: + + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedTileIterator< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessSize, + Gather, + PermuteLayout + >; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + + friend PredicatedTileIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + /// Default constructor + Params() = default; + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) + {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const &base) + : params_(base) {} + }; + + +private: + + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + +public: + + /// Default constructor + PredicatedTileIterator() = default; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const &threadblock_offset, ///< Initial offset of threadblock + int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization + ): + iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column()), + indices) + { } + + /// Construct a PredicatedTileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ): PredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIterator operator++(int) { + PredicatedTileIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIterator for row-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize, + bool Gather, + typename PermuteLayout +> +class PredicatedTileIterator { +public: + + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedTileIterator< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessSize, + Gather, + PermuteLayout + >; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + + friend PredicatedTileIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + /// Default constructor + Params() = default; + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const &base) + : params_(base) {} + + }; + +private: + + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + +public: + + /// Default constructor + PredicatedTileIterator() = default; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const &threadblock_offset, ///< Initial offset of threadblock + int const *indices = nullptr ///< Gather indices + ): + iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row()), + indices + ) { } + + /// Construct a PredicatedTileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ): PredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIterator operator++(int) { + PredicatedTileIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIterator for affine rank-2 data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileIterator, AdvanceRank, + ThreadMap_, AccessSize, false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRankN<2>; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + /// Type used for internal memory accesses + using AccessType = AlignedArray::value / 8)>; + + /// Underlying iterator to compute the addresses + using TileAccessIterator = + PredicatedTileAccessIterator; + + static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename TileAccessIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + + friend PredicatedTileIterator; + + private: + /// Parameters object + typename TileAccessIterator::Params params_; + + public: + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) : params_(layout) {} + + /// Default constructor + Params() = default; + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char *; + + private: + // + // Data members + // + + /// Data member to the tile access iterator + TileAccessIterator address_iterator_; + + public: + + /// Default constructor + PredicatedTileIterator() = default; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset, + int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization + ) + : address_iterator_(params.params_, pointer, extent, thread_id, + threadblock_offset) {} + + /// Construct a PredicatedTileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + address_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIterator &operator++() { + if (kAdvanceRank) + address_iterator_.add_tile_offset(make_Coord(0, 1)); + else + address_iterator_.add_tile_offset(make_Coord(1, 0)); + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIterator operator++(int) { + PredicatedTileIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { address_iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { address_iterator_.get_mask(mask); } + + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + load_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + + int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + address_iterator_.set_iteration_index(idx); + char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; + + AccessType const *access_ptr = reinterpret_cast(byte_ptr); + + cutlass::arch::global_load( + frag_ptr[idx], access_ptr, address_iterator_.valid()); + + ++address_iterator_; + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { load_with_byte_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + store_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { + address_iterator_.set_iteration_index(0); + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + + int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + char *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; + AccessType *access_ptr = reinterpret_cast(byte_ptr); + + if (address_iterator_.valid()) { + *access_ptr = frag_ptr[idx]; + } + ++address_iterator_; + } + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { store_with_byte_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIterator for affine rank 2 column-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize +> +class PredicatedTileIterator { +public: + + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileIterator< + layout::PitchLinearShape, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessSize + >; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + + friend PredicatedTileIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + /// Default constructor + Params() = default; + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout): params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))) + {} + }; + +private: + + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + +public: + + /// Default constructor + PredicatedTileIterator() = default; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const &threadblock_offset, ///< Initial offset of threadblock + int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization + ): + iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column()) + ) { } + + /// Construct a PredicatedTileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ): PredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIterator operator++(int) { + PredicatedTileIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIterator for affine rank 2 row-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize +> +class PredicatedTileIterator { +public: + + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileIterator< + layout::PitchLinearShape, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessSize + >; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + + friend PredicatedTileIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + /// Default constructor + Params() = default; + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout): params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))) {} + }; + + +private: + + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + +public: + + /// Default constructor + PredicatedTileIterator() = default; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const &threadblock_offset, ///< Initial offset of threadblock + int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization + ): + iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row()) + ) { } + + /// Construct a PredicatedTileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ): PredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIterator operator++(int) { + PredicatedTileIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIterator for interleaved data. It is mapped +/// to the congruous layout. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// + +template +class PredicatedTileIterator, + AdvanceRank, ThreadMap_, AccessSize, false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::ColumnMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedTileIterator< + layout::PitchLinearShape, + Element, layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessSize>; + + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + /// Default constructor + Params() = default; + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const &base) + : params_(base) {} + + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + + /// Default constructor + PredicatedTileIterator() = default; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset, + int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization + ) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.row() * kInterleavedK, + extent.column() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row() * kInterleavedK, + threadblock_offset.column() / kInterleavedK)) {} + + /// Construct a PredicatedTileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIterator operator++(int) { + PredicatedTileIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIterator for interleaved-32 data. It is +/// mapped to the congruous layout. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileIterator, + AdvanceRank, ThreadMap_, AccessSize, false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::RowMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedTileIterator< + layout::PitchLinearShape, + Element, layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessSize>; + + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + /// Default constructor + Params() = default; + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const &base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + + /// Default constructor + PredicatedTileIterator() = default; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset, + int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization + ) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.column() * kInterleavedK, + extent.row() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column() * kInterleavedK, + threadblock_offset.row() / kInterleavedK)) {} + + /// Construct a PredicatedTileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIterator operator++(int) { + PredicatedTileIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h new file mode 100644 index 0000000000000000000000000000000000000000..cbe48df6e7dc1c66c9e55b8eab14aa1fb53bc14b --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h @@ -0,0 +1,787 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of tiles from pitch-linear rank=2 tensors. + + This iterator uses masks to guard out-of-bounds accesses and visits the last "residue" tile + first, with the objective of minimizing predicate mask updates during steady-state operation. + + A precomputed "Params" object minimizes the amount of state that must be stored in registers, + and integer addition is used to advance the pointer through memory. +*/ + +#pragma once + +#include "cutlass/transform/threadblock/predicated_tile_access_iterator_2dthreadtile.h" +#include "cutlass/transform/thread/transpose.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedTileIterator2dThreadTile +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +/// Regular tile iterator using a precomputed control structure to minimize register liveness +/// and integer arithmetic. +/// +/// Layout is assumed to be invariant at the time the precomputed "Params" object is constructed. +/// +/// Base pointer and tensor extents may be specified at the time the iterator is constructed. +/// Subsequently, they are assumed to be immutable. +/// +/// Adding a logical coordinate offset may be performed at the time the iterator is constructed. +/// Subsequent additions to logical coordinate offset may be performed but are relatively expensive. +/// +/// Vistitation order is intended to first visit a "residual" tile that may be partially full in +/// both the advance dimension and the steady-state dimension. This is assumed to be the last +/// tile in the iteration sequence. Advancing an iterator that has just been constructed moves to +/// the first tile that is full in the advance dimension and recomputes predicates. Subsequent +/// accesses may be performed without updating internal predicates and are efficient in terms of +/// live register state and pointer arithmetic instructions. +/// +/// To be efficient, this assumes the iterator will be dereferenced and advanced at least once +/// outside any looping structure to minimize integer arithmetic. +/// +/// Accesses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing +/// the iterator. +/// +/// +/// Example: +/// +/// An efficient pipeline structure may be constructed as follows: +/// +// template +// __global__ void kernel( +// typename Iterator::Params params, +// typename Iterator::Element *ptr, +// TensorCoord extent) { +// +// typename Iterator::Fragment fragment; +// +// TensorCoord threadblock_offset(0, 0); +// +// Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offsets); +// +// +// fragment = *iter; // load "residue" tile first +// ++iter; // advance to first "steady state" tile and update internal masks +// +// +// #pragma unroll +// for (int i = Remaining - 1; i >= 0; --i) { +// +// f(fragment); +// +// if (!i) { +// iter.clear_mask(); // light-weight operation to clear masks - subsequent loads become NO-OPs. +// } +// +// fragment = *iter; // load tile during "steady state" phase +// ++iter; // advance to next tile - lightweight due to steady-state masks +// } +// } +// +// void host(TensorView view) { +// +// using Iterator = transform::threadblock::PredicatedTileIterator2dThreadTile; +// +// typename Iterator::Params params(view.layout()); +// +// kernel(params, view.data()); +// } +/// +/// +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + bool Transpose = false +> +class PredicatedTileIterator2dThreadTile; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIterator2dThreadTile for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileIterator2dThreadTile { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + /// Type used for internal memory accesses + /// extra set of parenthesis is needed for VS compiler + struct alignas((ThreadMap::kElementsPerAccess * sizeof_bits::value / + 8)) AccessType { + + Array storage; + + static int const kElements = ThreadMap::kElementsPerAccess; + }; + + /// Optionally this fragment can be 4x4 transposed + using Transform = thread::Transpose< ThreadMap::Iterations::kCount * ThreadMap::ThreadAccessShape::kCount , layout::PitchLinearShape<4,4>, Element>; + static bool const transpose = Transpose_; + + /// Underlying iterator to compute the addresses + using TileAccessIterator = + PredicatedTileAccessIterator2dThreadTile; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename TileAccessIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + using Base = typename TileAccessIterator::Params::Base; + + friend PredicatedTileIterator2dThreadTile; + + private: + /// Parameters object + typename TileAccessIterator::Params params_; + + public: + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) : params_(layout) { } + + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params(Base const &base) + : params_(base) {} + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char *; + + private: + // + // Data members + // + + /// Data member to the tile access iterator + TileAccessIterator address_iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIterator2dThreadTile( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset, + int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization + ) + : address_iterator_(params.params_, pointer, extent, thread_id, + threadblock_offset) {} + + /// Construct a PredicatedTileIterator2dThreadTile with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileIterator2dThreadTile( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIterator2dThreadTile(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + address_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIterator2dThreadTile &operator++() { + if (kAdvanceRank) + address_iterator_.add_tile_offset({0, 1}); + else + address_iterator_.add_tile_offset({1, 0}); + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIterator2dThreadTile operator++(int) { + PredicatedTileIterator2dThreadTile self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { address_iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { address_iterator_.get_mask(mask); } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int ts = 0; ts < ThreadMap::ThreadAccessShape::kStrided; ts++){ + + int access_idx = ts + c * ThreadMap::ThreadAccessShape::kStrided + \ + s * ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided; + + address_iterator_.set_iteration_index(access_idx); + if (address_iterator_.valid()) { + + frag_ptr[access_idx] = + *(address_iterator_.get() + pointer_offset); + } + + ++address_iterator_; + } + } + } + + if (transpose) { + Transform t; + t.transform(frag, frag); + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int ts = 0; ts < ThreadMap::ThreadAccessShape::kStrided; ts++){ + + int access_idx = ts + c * ThreadMap::ThreadAccessShape::kStrided + \ + s * ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided; + + address_iterator_.set_iteration_index(access_idx); + if (address_iterator_.valid()) { + *(address_iterator_.get() + pointer_offset) = frag_ptr[access_idx]; + } + ++address_iterator_; + } + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIterator2dThreadTile for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + bool Transpose_ +> +class PredicatedTileIterator2dThreadTile { +public: + + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + static bool const Transpose = Transpose_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedTileIterator2dThreadTile< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + Transpose + >; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + + friend PredicatedTileIterator2dThreadTile; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + CUTLASS_HOST_DEVICE + Params() { } + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const &base) + : params_(base) {} + }; + + +private: + + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + +public: + + /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIterator2dThreadTile( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const &threadblock_offset, ///< Initial offset of threadblock + int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization + ): + iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column()) + ) { } + + /// Construct a PredicatedTileIterator2dThreadTile with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileIterator2dThreadTile( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ): PredicatedTileIterator2dThreadTile(params, pointer, extent, thread_id, make_Coord(0, 0)) { } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIterator2dThreadTile &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIterator2dThreadTile operator++(int) { + PredicatedTileIterator2dThreadTile self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIterator2dThreadTile for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + bool Transpose_ +> +class PredicatedTileIterator2dThreadTile { +public: + + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + static bool const Transpose = Transpose_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedTileIterator2dThreadTile< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + Transpose + >; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + + friend PredicatedTileIterator2dThreadTile; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + CUTLASS_HOST_DEVICE + Params() { } + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) { } + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const &base) + : params_(base) {} + }; + + +private: + + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + +public: + + /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIterator2dThreadTile( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const &threadblock_offset, ///< Initial offset of threadblock + int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization + ): + iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row()) + ) { } + + /// Construct a PredicatedTileIterator2dThreadTile with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileIterator2dThreadTile( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ): PredicatedTileIterator2dThreadTile(params, pointer, extent, thread_id, make_Coord(0, 0)) { } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIterator2dThreadTile &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIterator2dThreadTile operator++(int) { + PredicatedTileIterator2dThreadTile self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_triangular_matrix.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_triangular_matrix.h new file mode 100644 index 0000000000000000000000000000000000000000..9bf5e8586675c11bb52e2db5346ff19f489461af --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_triangular_matrix.h @@ -0,0 +1,818 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of tiles from pitch-linear rank=2 tensors. + + This iterator uses masks to guard out-of-bounds accesses and visits the last "residue" tile + first, with the objective of minimizing predicate mask updates during steady-state operation. + + A precomputed "Params" object minimizes the amount of state that must be stored in registers, + and integer addition is used to advance the pointer through memory. +*/ + +#pragma once + +#include "cutlass/arch/memory.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator_triangular_matrix.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedTileIteratorTriangularMatrix +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +/// Regular tile iterator using a precomputed control structure to minimize register liveness +/// and integer arithmetic. +/// +/// Layout is assumed to be invariant at the time the precomputed "Params" object is constructed. +/// +/// Base pointer and tensor extents may be specified at the time the iterator is constructed. +/// Subsequently, they are assumed to be immutable. +/// +/// Adding a logical coordinate offset may be performed at the time the iterator is constructed. +/// Subsequent additions to logical coordinate offset may be performed but are relatively expensive. +/// +/// Vistitation order is intended to first visit a "residual" tile that may be partially full in +/// both the advance dimension and the steady-state dimension. This is assumed to be the last +/// tile in the iteration sequence. Advancing an iterator that has just been constructed moves to +/// the first tile that is full in the advance dimension and recomputes predicates. Subsequent +/// accesses may be performed without updating internal predicates and are efficient in terms of +/// live register state and pointer arithmetic instructions. +/// +/// To be efficient, this assumes the iterator will be dereferenced and advanced at least once +/// outside any looping structure to minimize integer arithmetic. +/// +/// Accesses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing +/// the iterator. +/// +/// +/// Example: +/// +/// An efficient pipeline structure may be constructed as follows: +/// +// template +// __global__ void kernel( +// typename Iterator::Params params, +// typename Iterator::Element *ptr, +// TensorCoord extent) { +// +// typename Iterator::Fragment fragment; +// +// TensorCoord threadblock_offset(0, 0); +// +// Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offsets); +// +// +// fragment = *iter; // load "residue" tile first +// ++iter; // advance to first "steady state" tile and update internal masks +// +// +// #pragma unroll +// for (int i = Remaining - 1; i >= 0; --i) { +// +// f(fragment); +// +// if (!i) { +// iter.clear_mask(); // light-weight operation to clear masks - subsequent loads become NO-OPs. +// } +// +// fragment = *iter; // load tile during "steady state" phase +// ++iter; // advance to next tile - lightweight due to steady-state masks +// } +// } +// +// void host(TensorView view) { +// +// using Iterator = transform::threadblock::PredicatedTileIteratorTriangularMatrix; +// +// typename Iterator::Params params(view.layout()); +// +// kernel(params, view.data()); +// } +/// +/// +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + SideMode kSideMode, + FillMode kFillMode, + DiagType kDiagType, + int AccessSize = ThreadMap::kElementsPerAccess +> +class PredicatedTileIteratorTriangularMatrix; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorTriangularMatrix for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileIteratorTriangularMatrix { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + /// Type used for internal memory accesses + using AccessType = AlignedArray::value / 8)>; + + /// Underlying iterator to compute the addresses + using TileAccessIterator = + PredicatedTileAccessIteratorTriangularMatrix; + + static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename TileAccessIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + friend PredicatedTileIteratorTriangularMatrix; + + private: + /// Parameters object + typename TileAccessIterator::Params params_; + + public: + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) : params_(layout) { } + + CUTLASS_HOST_DEVICE + Params() { } + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char *; + + private: + // + // Data members + // + + /// Data member to the tile access iterator + TileAccessIterator address_iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorTriangularMatrix( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : address_iterator_(params.params_, pointer, extent, thread_id, + threadblock_offset) {} + + /// Construct a PredicatedTileIteratorTriangularMatrix with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorTriangularMatrix( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorTriangularMatrix(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + address_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorTriangularMatrix &operator++() { + if (kAdvanceRank) + address_iterator_.add_tile_offset({0, 1}); + else + address_iterator_.add_tile_offset({1, 0}); + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorTriangularMatrix operator++(int) { + PredicatedTileIteratorTriangularMatrix self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { address_iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { address_iterator_.get_mask(mask); } + + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + load_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + + int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + address_iterator_.set_iteration_index(idx); + char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; + + AccessType const *access_ptr = reinterpret_cast(byte_ptr); + + cutlass::arch::global_load( + frag_ptr[idx], access_ptr, address_iterator_.valid()); + + ++address_iterator_; + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { load_with_byte_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + store_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { + address_iterator_.set_iteration_index(0); + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + + int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + char *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; + AccessType *access_ptr = reinterpret_cast(byte_ptr); + + if (address_iterator_.valid()) { + *access_ptr = frag_ptr[idx]; + } + ++address_iterator_; + } + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { store_with_byte_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorTriangularMatrix for column-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + SideMode kSideMode, + FillMode kFillMode, + DiagType kDiagType, + int AccessSize +> +class PredicatedTileIteratorTriangularMatrix { +public: + + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedTileIteratorTriangularMatrix< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + kSideMode, + kFillMode, + kDiagType, + AccessSize + >; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + + friend PredicatedTileIteratorTriangularMatrix; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + CUTLASS_HOST_DEVICE + Params() { } + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) { + + } + }; + + +private: + + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + +public: + + /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorTriangularMatrix( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const &threadblock_offset ///< Initial offset of threadblock + ): + iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column()) + ) { } + + /// Construct a PredicatedTileIteratorTriangularMatrix with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorTriangularMatrix( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ): PredicatedTileIteratorTriangularMatrix(params, pointer, extent, thread_id, make_Coord(0, 0)) { } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorTriangularMatrix &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorTriangularMatrix operator++(int) { + PredicatedTileIteratorTriangularMatrix self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorTriangularMatrix for row-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + SideMode kSideMode, + FillMode kFillMode, + DiagType kDiagType, + int AccessSize +> +class PredicatedTileIteratorTriangularMatrix { +public: + + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedTileIteratorTriangularMatrix< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + kSideMode, + kFillMode, + kDiagType, + AccessSize + >; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + + friend PredicatedTileIteratorTriangularMatrix; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + CUTLASS_HOST_DEVICE + Params() { } + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) { + + }; + }; + + +private: + + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + +public: + + /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorTriangularMatrix( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const &threadblock_offset ///< Initial offset of threadblock + ): + iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row()) + ) { } + + /// Construct a PredicatedTileIteratorTriangularMatrix with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorTriangularMatrix( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ): PredicatedTileIteratorTriangularMatrix(params, pointer, extent, thread_id, make_Coord(0, 0)) { } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorTriangularMatrix &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorTriangularMatrix operator++(int) { + PredicatedTileIteratorTriangularMatrix self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_vector_access_iterator.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_vector_access_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..df551c13f52834bfa6258104f99c7ed008342279 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_vector_access_iterator.h @@ -0,0 +1,417 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Templates implementing computing the addresses of loading small + vectors from the global memory. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/tensor_ref.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedVectorAccessIterator +/// +template < + /// Shape of the vector accessed by the entire threadblock + typename Shape, + /// Shape of the vector accessed by the warp + typename WarpShape, + /// Type of Element + typename Element, + /// Layout of the vector + typename Layout, + /// Number of elements for each access + int ElementsPerAccess, + /// Support residual tile + bool EnableResidualAccess = false +> +class PredicatedVectorAccessIterator; + +//////////////////////////////////////////////////////////////////////////////// + +/// Vector access iterator specialized for vectors, e.g. scale and bias +/// Thread arrangements are for TensorOps +/// +template < + typename Shape_, + typename WarpShape_, + typename Element_, + int ElementsPerAccess, + bool EnableResidualAccess +> +class PredicatedVectorAccessIterator < + Shape_, + WarpShape_, + Element_, + layout::PitchLinear, + ElementsPerAccess, + EnableResidualAccess +> { + public: + + using Shape = Shape_; + using WarpShape = WarpShape_; + using Element = Element_; + using Layout = layout::PitchLinear; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using ConstPointer = const Element *; + using NonConstPointer = typename platform::remove_const::type *; + +// static int const kElementsPerAccess = 128 / sizeof_bits::value; + static int const kElementsPerAccess = ElementsPerAccess; + static int const kThreads = 32; + static int const kRowsPerIteration = 8; + static int const kThreadsPerRow = kThreads / kRowsPerIteration; + static int const kThreadsPerRowMask = 0x3; + static int const kIterations = WarpShape::kContiguous / (kThreadsPerRow * kElementsPerAccess); + static int const kWarpCountStrided = Shape::kStrided / WarpShape::kStrided; + + using AccessType = AlignedArray; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char *; + + private: + // + // Data members + // + + /// Internal pointer to first access of tile + BytePointer pointer_; + + /// Extent of tensor + TensorCoord extent_; + + /// pointer offset of each thread + TensorCoord thread_offset_; + + /// iteration index + LongIndex iteration_; + + /// residual access + bool is_residual_; + + /// residual offset of each thread + TensorCoord residual_offset_; + + public: + /// Constructs a vector access iterator + CUTLASS_HOST_DEVICE + PredicatedVectorAccessIterator( + /// Pointer to the start of the vector + ConstPointer pointer, + /// Extent of vector + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// ID of each participating warp + int warp_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : pointer_(reinterpret_cast( + const_cast(pointer))), + extent_(extent), + is_residual_(false) { + + + int warp_offset = (warp_id / kWarpCountStrided) * WarpShape::kContiguous; + + // Per-thread offset in logical coordinates of tensor + + thread_offset_ = threadblock_offset + TensorCoord(warp_offset, 0) + + TensorCoord((thread_id & kThreadsPerRowMask) * kElementsPerAccess, 0); + + set_iteration_index(0); + + if(EnableResidualAccess) { + // compute residual offset + typename TensorCoord::Index residual_size = extent_.contiguous() % WarpShape::kContiguous; + if (residual_size) { + is_residual_ = true; + residual_offset_ = make_Coord(residual_size, 0); + } + } + } + + /// Construct a PredicatedVectorAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedVectorAccessIterator( + /// Pointer to start of vector + ConstPointer pointer, + /// Extent of vector + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + /// ID of each participating warp + int warp_id) + : PredicatedVectorAccessIterator(pointer, extent, thread_id, warp_id, + make_Coord(0, 0)) {} + + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iteration_ = index; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_DEVICE + void add_tile_offset( + TensorCoord const &tile_offset) { + + thread_offset_ = + thread_offset_ + + TensorCoord(WarpShape::kContiguous * tile_offset.contiguous(), 0); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + + return reinterpret_cast( + pointer_ + + ((thread_offset_.contiguous() + iteration_ * kThreadsPerRow * kElementsPerAccess) + * sizeof_bits::value / 8)); + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + PredicatedVectorAccessIterator &operator++() { + ++iteration_; + if(iteration_ >= kIterations) + iteration_ = 0; + + return *this; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + void advance() { + if(EnableResidualAccess && is_residual_) { + is_residual_ = false; + thread_offset_ += residual_offset_; + } + else + add_tile_offset(TensorCoord(1, 0)); + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + PredicatedVectorAccessIterator operator++(int) { + PredicatedVectorAccessIterator self(*this); + operator++(); + return self; + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return ((thread_offset_.contiguous() + + iteration_ * kThreadsPerRow * kElementsPerAccess) < extent_.contiguous()); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedVectorAccessIterator for row-major data. +/// +template < + typename Shape_, + typename WarpShape_, + typename Element_, + int ElementsPerAccess, + bool EnableResidualAccess +> +class PredicatedVectorAccessIterator< + Shape_, + WarpShape_, + Element_, + layout::RowMajor, + ElementsPerAccess, + EnableResidualAccess +> { + public: + + using Shape = Shape_; + using WarpShape = WarpShape_; + using Element = Element_; + using Layout = layout::RowMajor; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using ConstPointer = const Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedVectorAccessIterator< + layout::PitchLinearShape, + layout::PitchLinearShape, + Element, + layout::PitchLinear, + ElementsPerAccess, + EnableResidualAccess>; + + using AccessType = typename UnderlyingIterator::AccessType; + static int const kElementsPerAccess = UnderlyingIterator::kElementsPerAccess; + static int const kRowsPerIteration = UnderlyingIterator::kRowsPerIteration; + static int const kThreads = UnderlyingIterator::kThreads; + static int const kIterations = UnderlyingIterator::kIterations; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedVectorAccessIterator( + ///< Pointer to the start of the vector + ConstPointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< ID of each participating warp + int warp_id, + ///< Initial offset of threadblock + TensorCoord const &threadblock_offset) + : iterator_(pointer, layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, warp_id, + layout::PitchLinearCoord(threadblock_offset.column(), + threadblock_offset.row())) {} + + /// Construct a PredicatedVectorAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedVectorAccessIterator( + ConstPointer pointer, ///< Pointer to the start of the vector + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + int warp_id ///< ID of each participating warp + ) + : PredicatedVectorAccessIterator(pointer, extent, thread_id, warp_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedVectorAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedVectorAccessIterator operator++(int) { + PredicatedVectorAccessIterator self(*this); + operator++(); + return self; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + void advance() { + iterator_.advance(); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..1aae46988418c72a9322b7e6b47e1dfe4fadff8d --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h @@ -0,0 +1,253 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Templates implementing computing the addresses of storing of small + scale and bias vectors in the shared memory. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/tensor_ref.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// RegularScaleBiasVectorAccessIterator +/// +template +class RegularScaleBiasVectorAccessIterator; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for congruous arrangements for TensorOps +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularScaleBiasVectorAccessIterator { + public: + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + /// Element type per access + static int const kElementsPerAccess = 128 / sizeof_bits::value; + static int const kThreads = Shape::kContiguous / kElementsPerAccess; + using AccessType = Array; + + private: + // + // Data members + // + + /// Internal pointer + AccessType *pointer_; + + /// Internal byte offset + Index byte_offset_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularScaleBiasVectorAccessIterator( + TensorRef scale_bias_ref, ///< Pointer to the start of the scale and bias + ///< vector + int thread_id ///< ID of each participating thread + ) + : byte_offset_(0) { + // Per-thread offset in logical coordinates of tensor + int thread_offset = thread_id * kElementsPerAccess; + + // initialize pointer + pointer_ = + reinterpret_cast(scale_bias_ref.data() + thread_offset); + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_offset_ += pointer_offset * sizeof(Element); + } + + /// Returns a pointer + CUTLASS_DEVICE + AccessType *get() const { + + char *access_byte_ptr = + reinterpret_cast(pointer_); + + return reinterpret_cast(access_byte_ptr + byte_offset_); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularScaleBiasVectorAccessIterator &operator++() { return *this; } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularScaleBiasVectorAccessIterator operator++(int) { + RegularScaleBiasVectorAccessIterator prev(*this); + this->operator++(); + + return prev; + } + + /// Adds a tile offset in the unit of tile. + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + // Multiply by 2 because we store scale and bias belong to the same stage + // next to each other. + add_pointer_offset(coord.contiguous() * Shape::kContiguous * 2); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for row major layouts +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularScaleBiasVectorAccessIterator< + Shape_, Element_, + layout::RowMajor> { + public: + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + /// Underlying iterator type + using UnderlyingIterator = RegularScaleBiasVectorAccessIterator< + layout::PitchLinearShape, Element, + layout::PitchLinear>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularScaleBiasVectorAccessIterator( + TensorRef scale_bias_ref, ///< Pointer to the start of the scale and bias + ///< vector + int thread_id ///< ID of each participating thread + ) + : iterator_({scale_bias_ref.data(), scale_bias_ref.stride()}, thread_id) { + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.column(), coord.row()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularScaleBiasVectorAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularScaleBiasVectorAccessIterator operator++(int) { + RegularScaleBiasVectorAccessIterator prev(*this); + ++iterator_; + + return prev; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..cfb491b5a4b5f4e1b757f99110f6a9fd28675088 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator.h @@ -0,0 +1,58 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing the address computation of storing of tiles + from pitch-linear rank=2 tensors. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template ::value* ThreadMap::kElementsPerAccess / 8> +class RegularTileAccessIterator; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h new file mode 100644 index 0000000000000000000000000000000000000000..adda9339b87865799c56baba4c3f8df580e26ac5 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h @@ -0,0 +1,408 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing computing the addresses of storing of tiles + from pitch-linear rank=2 tensors. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/transform/threadblock/regular_tile_access_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for congruous arrangements for TensorOps +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::PitchLinear, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Element type per access + using AccessType = Array; + + private: + // + // Data members + // + + /// Stride value + StrideIndex stride_; + + /// Internal pointer to first access of tile + AccessType *pointer_; + + /// Internal byte offset + Index byte_offset_; + + /// Iteration in the contiguous dimension + int iteration_contiguous_; + + /// Iteration in the strided dimension + int iteration_strided_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : stride_(ref.stride(0) / ThreadMap::kElementsPerAccess), + byte_offset_(0) { + + layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); + + // initialize pointer + pointer_ = reinterpret_cast(ref.data() + ref.offset(thread_offset_base)); + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_offset_ += pointer_offset * sizeof(Element); + } + + /// Returns a pointer + CUTLASS_DEVICE + AccessType *get() const { + + AccessType *access_ptr = pointer_; + + int access_offset = iteration_strided_ * ThreadMap::Delta::kStrided * stride_ + + iteration_contiguous_ * ThreadMap::Delta::kContiguous / + ThreadMap::kElementsPerAccess; + + char *access_byte_ptr = + reinterpret_cast(access_ptr + access_offset); + + return reinterpret_cast(access_byte_ptr + byte_offset_); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iteration_contiguous_; + + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) + return *this; + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + iteration_contiguous_ = 0; + ++iteration_strided_; + + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + iteration_strided_ = 0; + + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + this->operator++(); + + return prev; + } + + /// Adds a tile offset in the unit of tile. + /// In GEMM/Conv implementation, this is used to move in the k dimension in the shared memory. + /// Below layouts are the shared memory layouts. Current SM50 SIMT kernels only use col major A and row major B. + /// For row major A operand, k dimension is contiguous dimension; + /// For col major A operand, k dimension is strided dimension; + /// For row major B operand, k dimension is strided dimension; + /// For col major B operand, k dimension is contiguous dimension. + /// Below two classes map col/row major to the pitch linear coordinates used + /// in this base class. + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + add_pointer_offset(coord.contiguous() * Shape::kContiguous + + coord.strided() * Shape::kStrided * stride_ * + ThreadMap::kElementsPerAccess); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for column major layouts +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::ColumnMajor, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIterator< + layout::PitchLinearShape, Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap_>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.row(), coord.column()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + ++iterator_; + + return prev; + } +}; + + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for row major layouts +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::RowMajor, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIterator< + layout::PitchLinearShape, Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap_>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.column(), coord.row()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + ++iterator_; + + return prev; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear_direct_conv.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear_direct_conv.h new file mode 100644 index 0000000000000000000000000000000000000000..71c89686a71995b45f9d4cf0fd1f0fba12ca7d8a --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear_direct_conv.h @@ -0,0 +1,587 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing computing the addresses of storing of tiles + from pitch-linear rank=2 tensors. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/transform/threadblock/regular_tile_access_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + + +//////////////////////////////////////////////////////////////////////////////// + +template ::value* ThreadMap::kElementsPerAccess / 8 + > +class RegularTileAccessIteratorDirectConv; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for congruous arrangements for TensorOps with dynamic_iterations OFF +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIteratorDirectConv< + Shape_, Element_, + layout::PitchLinear, + AdvanceRank, ThreadMap_, false, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Element type per access + using AccessType = Array; + + private: + // + // Data members + // + + /// Stride value + StrideIndex stride_; + + /// Internal pointer to first access of tile + AccessType *pointer_; + + /// Internal byte offset + Index byte_offset_; + + /// Iteration in the contiguous dimension + int iteration_contiguous_; + + /// Iteration in the strided dimension + int iteration_strided_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIteratorDirectConv(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : stride_(ref.stride(0) / ThreadMap::kElementsPerAccess), + byte_offset_(0) { + + layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); + + // initialize pointer + pointer_ = reinterpret_cast(ref.data() + ref.offset(thread_offset_base)); + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_num(int num) { + //Do nothing + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_offset_ += pointer_offset * sizeof(Element); + } + + /// Returns a pointer + CUTLASS_DEVICE + AccessType *get() const { + + AccessType *access_ptr = pointer_; + + int access_offset = iteration_strided_ * ThreadMap::Delta::kStrided * stride_ + + iteration_contiguous_ * ThreadMap::Delta::kContiguous / + ThreadMap::kElementsPerAccess; + + char *access_byte_ptr = + reinterpret_cast(access_ptr + access_offset); + + return reinterpret_cast(access_byte_ptr + byte_offset_); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIteratorDirectConv &operator++() { + ++iteration_contiguous_; + + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) + return *this; + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + iteration_contiguous_ = 0; + ++iteration_strided_; + + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + iteration_strided_ = 0; + + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIteratorDirectConv operator++(int) { + RegularTileAccessIteratorDirectConv prev(*this); + this->operator++(); + + return prev; + } + + /// Adds a tile offset in the unit of tile. + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + add_pointer_offset(coord.contiguous() * Shape::kContiguous + + coord.strided() * ThreadMap::Iterations::kStrided * + ThreadMap::Delta::kStrided * stride_ * ThreadMap::kElementsPerAccess); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for congruous arrangements for TensorOps with dynamic_iterations ON +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIteratorDirectConv< + Shape_, Element_, + layout::PitchLinear, + AdvanceRank, ThreadMap_,true, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Element type per access + using AccessType = Array; + + private: + // + // Data members + // + + /// Stride value + StrideIndex stride_; + + /// Internal pointer to first access of tile + AccessType *pointer_; + + /// Internal byte offset + Index byte_offset_; + + /// Iteration in the contiguous dimension + int iteration_contiguous_; + + /// Iteration in the strided dimension + int iteration_strided_; + + /// Total iterattions in the strided dimension: Dynamic value + int total_iteration_strided_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIteratorDirectConv(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : stride_(ref.stride(0) / ThreadMap::kElementsPerAccess), + byte_offset_(0) { + + layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); + + // initialize pointer + pointer_ = reinterpret_cast(ref.data() + ref.offset(thread_offset_base)); + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_num(int num) { + total_iteration_strided_ = num; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_offset_ += pointer_offset * sizeof(Element); + } + + /// Returns a pointer + CUTLASS_DEVICE + AccessType *get() const { + + AccessType *access_ptr = pointer_; + + int access_offset = iteration_strided_ * ThreadMap::Delta::kStrided * stride_ + + iteration_contiguous_ * ThreadMap::Delta::kContiguous / + ThreadMap::kElementsPerAccess; + + char *access_byte_ptr = + reinterpret_cast(access_ptr + access_offset); + + return reinterpret_cast(access_byte_ptr + byte_offset_); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIteratorDirectConv &operator++() { + ++iteration_contiguous_; + + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) + return *this; + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + iteration_contiguous_ = 0; + ++iteration_strided_; + + if (iteration_strided_ < total_iteration_strided_) { + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + iteration_strided_ = 0; + + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIteratorDirectConv operator++(int) { + RegularTileAccessIteratorDirectConv prev(*this); + this->operator++(); + + return prev; + } + + /// Adds a tile offset in the unit of tile. + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + add_pointer_offset(coord.contiguous() * Shape::kContiguous + + coord.strided() * total_iteration_strided_ * ThreadMap::Delta::kStrided * stride_ * + ThreadMap::kElementsPerAccess); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for column major layouts +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIteratorDirectConv< + Shape_, Element_, + layout::ColumnMajor, + AdvanceRank, ThreadMap_, Dynamic_iterations , Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIteratorDirectConv< + layout::PitchLinearShape, Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap_, + Dynamic_iterations>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIteratorDirectConv(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_num(int num) { + iterator_.set_iteration_num(num); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.row(), coord.column()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIteratorDirectConv &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIteratorDirectConv operator++(int) { + RegularTileAccessIteratorDirectConv prev(*this); + ++iterator_; + + return prev; + } +}; + + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for row major layouts +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIteratorDirectConv< + Shape_, Element_, + layout::RowMajor, + AdvanceRank, ThreadMap_, Dynamic_iterations, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIteratorDirectConv< + layout::PitchLinearShape, Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap_, + Dynamic_iterations>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIteratorDirectConv(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_num(int num) { + iterator_.set_iteration_num(num); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.column(), coord.row()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIteratorDirectConv &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIteratorDirectConv operator++(int) { + RegularTileAccessIteratorDirectConv prev(*this); + ++iterator_; + + return prev; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h new file mode 100644 index 0000000000000000000000000000000000000000..e172447fa96b02e11246f5f397911841c52eff4c --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h @@ -0,0 +1,821 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing computing the addresses of storing of tiles + from pitch-linear rank=2 tensors. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor_op_multiplicand_sm75.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for congruous arrangements for TensorOps +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::TensorOpMultiplicandCongruous::value, + Crosswise>, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = + layout::TensorOpMultiplicandCongruous::value, + Crosswise>; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + static int const kCrosswise = Crosswise; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Internal details made public to facilitate introspection + struct Detail { + /// This iterator is specialized for an access size that is 128 bits in + /// length. + static int const kAccessSizeInBits = 128; + + static_assert(sizeof_bits::value * + ThreadMap::kElementsPerAccess == + kAccessSizeInBits, + "This iterator requires a policy whose access size is 128bs"); + + ///< Number of pointers + static int const kPointerCount = + (ThreadMap::Iterations::kStrided > 1 ? 2 : 1); + }; + + /// Element type per access + using AccessType = Array; + + private: + // + // Data members + // + + /// Stride value + StrideIndex stride_; + + /// Internal pointer to first access of tile + AccessType *pointer_[Detail::kPointerCount]; + + /// Internal byte offset + Index byte_offset_; + + /// Iteration in the contiguous dimension + int iteration_contiguous_; + + /// Iteration in the strided dimension + int iteration_strided_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : stride_(ref.stride(0) * Layout::kFactor / Layout::kElementsPerAccess), + byte_offset_(0) { + layout::PitchLinearCoord thread_offset_base = + ThreadMap::initial_offset(thread_id); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Detail::kPointerCount; ++i) { + // This is the offset of a thread within a threadblock tile for a specific + // pointer (units of elements) + layout::PitchLinearCoord thread_offset_in_threadblock_tile = + thread_offset_base + + layout::PitchLinearCoord{ + 0, ThreadMap::Detail::WarpThreadArrangement::kStrided * i}; + + // initialize pointer + pointer_[i] = reinterpret_cast( + ref.data() + ref.offset(thread_offset_in_threadblock_tile)); + } + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_offset_ += pointer_offset * sizeof(Element); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + AccessType *access_ptr = pointer_[iteration_strided_ & 1]; + int stride_idx = (iteration_strided_ & ~1); + + int access_offset = stride_idx * ThreadMap::Delta::kStrided * stride_ / Layout::kFactor + + iteration_contiguous_ * ThreadMap::Delta::kContiguous / + ThreadMap::kElementsPerAccess; + + char *access_byte_ptr = + reinterpret_cast(access_ptr + access_offset); + return reinterpret_cast(access_byte_ptr + byte_offset_); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iteration_contiguous_; + + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) + return *this; + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + iteration_contiguous_ = 0; + ++iteration_strided_; + + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + + // Enter here only if (iteration_strided_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + iteration_strided_ = 0; + + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + this->operator++(); + + return prev; + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + add_pointer_offset(coord.contiguous() * Shape::kContiguous * Layout::kFactor + + coord.strided() * Shape::kStrided * stride_ * + Layout::kElementsPerAccess / Layout::kFactor); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for column-major congruous TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::ColumnMajorTensorOpMultiplicandCongruous< + sizeof_bits::value, Crosswise>, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for column-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajorTensorOpMultiplicandCongruous< + sizeof_bits::value, Crosswise>; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicandCongruous::value, + Crosswise>, + (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.row(), coord.column()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + ++iterator_; + + return prev; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for row-major congruous TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::RowMajorTensorOpMultiplicandCongruous::value, + Crosswise>, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for row-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajorTensorOpMultiplicandCongruous< + sizeof_bits::value, Crosswise>; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicandCongruous::value, + Crosswise>, + (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.column(), coord.row()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + ++iterator_; + + return prev; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for crosswise arrangements for TensorOps +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator::value, Crosswise>, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = + layout::TensorOpMultiplicandCrosswise::value, + Crosswise>; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + static int const kCrosswise = Crosswise; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + static_assert(!(ThreadMap::Delta::kContiguous % kCrosswise), + "kCrosswise is the smallest unit in the contiguous dimension " + "for shared memory swizzling."); + + /// Internal details made public to facilitate introspection + struct Detail { + /// This iterator is specialized for an access size that is 128 bits in + /// length. + static int const kAccessSizeInBits = 128; + + static_assert(sizeof_bits::value * + ThreadMap::kElementsPerAccess == + kAccessSizeInBits, + "This iterator requires a policy whose access size is 128bs"); + + /// Number of pointers + /// + /// Note:TN kblock32 layouts only needs 1 pointer, but strangely + /// reducing pointer count hurts perfomrnace + static int const kPointerCount = + (ThreadMap::Iterations::kStrided > 1 ? 2 : 1); + }; + + /// Element type per access + using AccessType = Array; + + private: + // + // Data members + // + + /// Total number of sections. The memory is divided into stages. One stage + /// can store one tile. Stage is divided into sections. Interleaved layout + /// can have multiple sections in a stage. The rest layout only has one section + /// in a stage. + int sections_; + + /// Sections that a stage has + int sections_per_stage_; + + /// Stride value + StrideIndex stride_; + + /// Internal pointer to first access of tile + AccessType *pointer_[Detail::kPointerCount]; + + /// Internal byte offset + Index byte_offset_; + + /// Iteration in the contiguous dimension + int iteration_contiguous_; + + /// Iteration in the strided dimension + int iteration_strided_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : sections_(ref.stride(0) / kCrosswise), + sections_per_stage_(Shape::kContiguous / kCrosswise), + // stride_ = kCrosswise x sections_ x kFactor + stride_(ref.stride(0) * Layout::kFactor / Layout::kElementsPerAccess), + byte_offset_(0) { + layout::PitchLinearCoord thread_offset_base = + ThreadMap::initial_offset(thread_id); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Detail::kPointerCount; ++i) { + // This is the offset of a thread within a threadblock tile for a specific + // pointer (units of elements) + layout::PitchLinearCoord thread_offset_in_threadblock_tile = + thread_offset_base + + layout::PitchLinearCoord{ + 0, ThreadMap::Detail::WarpThreadArrangement::kStrided * i}; + // initialize pointer + pointer_[i] = reinterpret_cast(ref.data()) + + ref.offset(thread_offset_in_threadblock_tile) / + Layout::kElementsPerAccess; + } + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_offset_ += pointer_offset * sizeof_bits::value / 8; + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + AccessType *access_ptr = pointer_[iteration_strided_ & 1]; + int stride_idx = (iteration_strided_ & ~1); + + int access_offset = + stride_idx * ThreadMap::Delta::kStrided * stride_ / Layout::kFactor + + // kCrosswise elements in the contiguous dimension would span to a + // shared memory cache line. + iteration_contiguous_ * (ThreadMap::Delta::kContiguous / kCrosswise) * + Layout::TileShape::kContiguous; + char *access_byte_ptr = + reinterpret_cast(access_ptr + access_offset); + return reinterpret_cast(access_byte_ptr + byte_offset_); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iteration_contiguous_; + + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) + return *this; + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + iteration_contiguous_ = 0; + ++iteration_strided_; + + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + + // Enter here only if (iteration_strided_ == ThreadMap::Iteration::kStrided) + // which means we enter the next section. + iteration_strided_ = 0; + + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + this->operator++(); + + return prev; + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + add_pointer_offset(coord.contiguous() * sections_per_stage_ * stride_ * + ThreadMap::kElementsPerAccess / sections_ + + coord.strided() * Shape::kStrided * stride_ * + Layout::kElementsPerAccess / Layout::kFactor); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for column-major crosswise TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::ColumnMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, Crosswise>, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for column-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, Crosswise>; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicandCrosswise::value, + Crosswise>, + (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.row(), coord.column()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + ++iterator_; + + return prev; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for row-major crosswise TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator::value, Crosswise>, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for row-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, Crosswise>; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicandCrosswise::value, + Crosswise>, + (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.column(), coord.row()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + ++iterator_; + + return prev; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h new file mode 100644 index 0000000000000000000000000000000000000000..b55f841eee2e09aec8af5c8ec945a1997705c9f6 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h @@ -0,0 +1,1532 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing computing the addresses of storing of tiles + from pitch-linear rank=2 tensors. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor_op_multiplicand_sm75.h" +#include "cutlass/layout/tensor_op_multiplicand_sm80.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for congruous arrangements for TensorOps +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::TensorOpMultiplicandCongruous64b, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorOpMultiplicandCongruous64b; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + static_assert(ThreadMap::kThreads / 32 > 1, + "This tile iterator requires at least two warps."); + + /// Internal details made public to facilitate introspection + struct Detail { + /// This iterator is specialized for an access size that is 128 bits in + /// length. + static int const kAccessSizeInBits = 64; + + static_assert(sizeof_bits::value * + ThreadMap::kElementsPerAccess == + kAccessSizeInBits, + "This iterator requires a policy whose access size is 64b"); + + ///< Number of pointers + static int const kPointerCount = 1; + }; + + /// Element type per access + using AccessType = Array; + + private: + // + // Data members + // + + /// Stride value + StrideIndex stride_; + + /// Internal pointer to first access of tile + AccessType *pointer_; + + /// Internal byte offset + Index byte_offset_; + + /// Iteration in the contiguous dimension + int iteration_contiguous_; + + /// Iteration in the strided dimension + int iteration_strided_; + + public: + + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator( + TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ): + stride_(ref.stride(0) / Layout::kElementsPerAccess), + byte_offset_(0) { + + layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); + + // This is the offset of a thread within a threadblock tile for a specific + // pointer (units of elements) + layout::PitchLinearCoord thread_offset_in_threadblock_tile = thread_offset_base; + + // initialize pointer + pointer_ = reinterpret_cast(ref.data() + ref.offset(thread_offset_in_threadblock_tile)); + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + + byte_offset_ += pointer_offset * sizeof(Element); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + + AccessType *access_ptr = pointer_; + + int access_offset = iteration_strided_ * ThreadMap::Delta::kStrided * stride_ + + iteration_contiguous_ * ThreadMap::Delta::kContiguous / + ThreadMap::kElementsPerAccess; + + char *access_byte_ptr = + reinterpret_cast(access_ptr + access_offset); + + return reinterpret_cast(access_byte_ptr + byte_offset_); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iteration_contiguous_; + + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) + return *this; + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + iteration_contiguous_ = 0; + ++iteration_strided_; + + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + iteration_strided_ = 0; + + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + + RegularTileAccessIterator prev(*this); + + this->operator++(); + + return prev; + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + + add_pointer_offset( + coord.contiguous() * Shape::kContiguous + + coord.strided() * Shape::kStrided * stride_ * Layout::kElementsPerAccess); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for column-major congruous TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::ColumnMajorTensorOpMultiplicandCongruous64b, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for column-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajorTensorOpMultiplicandCongruous64b; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicandCongruous64b, + (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.row(), coord.column()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + ++iterator_; + + return prev; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for row-major congruous TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for row-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajorTensorOpMultiplicandCongruous64b; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicandCongruous64b, + (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.column(), coord.row()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + ++iterator_; + + return prev; + } +}; + +//////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for crosswise arrangements for TensorOps +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::TensorOpMultiplicand64bCrosswise, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorOpMultiplicand64bCrosswise; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + static_assert(ThreadMap::kThreads / 32 > 1, + "This tile iterator requires at least two warps."); + + /// Internal details made public to facilitate introspection + struct Detail { + /// This iterator is specialized for an access size that is 128 bits in + /// length. + static int const kAccessSizeInBits = 64; + + static_assert(sizeof_bits::value * + ThreadMap::kElementsPerAccess == + kAccessSizeInBits, + "This iterator requires a policy whose access size is 64b"); + + ///< Number of pointers - two pointers are needed if making more than 4 iterations along + ///< strided dimension + static int const kPointerCount = (ThreadMap::Iterations::kStrided > 4 ? 2 : 1); + }; + + /// Element type per access + using AccessType = Array; + + private: + // + // Data members + // + + /// Stride value + StrideIndex stride_; + + /// Internal pointer to first access of tile + AccessType *pointer_; + + /// Internal byte offset + Index byte_offset_[Detail::kPointerCount]; + + /// Iteration in the contiguous dimension + int iteration_contiguous_; + + /// Iteration in the strided dimension + int iteration_strided_; + + public: + + /// Construct a TileIterator with zero threadblock offset + CUTLASS_DEVICE + RegularTileAccessIterator( + TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ): + stride_(ref.stride(0) / ThreadMap::kElementsPerAccess) { + + layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); + + // This is the offset of a thread within a threadblock tile for a specific + // pointer (units of elements) + layout::PitchLinearCoord thread_offset_in_threadblock_tile = thread_offset_base; + + // initialize pointer + pointer_ = reinterpret_cast(ref.data()); + + byte_offset_[0] = ref.offset(thread_offset_in_threadblock_tile) * sizeof(Element); + + if (Detail::kPointerCount == 2) { + byte_offset_[1] = byte_offset_[0] ^ 8; + } + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + + pointer_ += pointer_offset / ThreadMap::kElementsPerAccess; + } + + /// Returns a pointer + CUTLASS_DEVICE + AccessType *get() const { + + // Map the logical contiguous and strided access to the internal swizzled structure. + int uniform_offset = (iteration_strided_ & 0x3) * stride_ + (iteration_strided_ >> 3) * 16 + stride_ * ThreadMap::Delta::kContiguous * iteration_contiguous_; + + char *access_byte_ptr = reinterpret_cast(pointer_ + uniform_offset); + + int byte_offset; + + // This iterator may require two byte offsets if it must load more than 8 rows (or 2 iterations) + // in the strided dimension + if (Detail::kPointerCount == 2 && (iteration_strided_ & 0x4)) { + byte_offset = byte_offset_[1]; + } + else { + byte_offset = byte_offset_[0]; + } + + return reinterpret_cast(access_byte_ptr + byte_offset); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iteration_contiguous_; + + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) + return *this; + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + iteration_contiguous_ = 0; + ++iteration_strided_; + + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + iteration_strided_ = 0; + + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + + RegularTileAccessIterator prev(*this); + + this->operator++(); + + return prev; + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + + add_pointer_offset(coord.strided() * Shape::kStrided + coord.contiguous() * Shape::kContiguous * stride_); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for column-major crosswise TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::ColumnMajorTensorOpMultiplicand64bCrosswise, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for column-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajorTensorOpMultiplicand64bCrosswise; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicand64bCrosswise, + (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.row(), coord.column()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + ++iterator_; + + return prev; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for row-major crosswise TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for row-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajorTensorOpMultiplicand64bCrosswise; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicand64bCrosswise, + (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.column(), coord.row()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + ++iterator_; + + return prev; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for congruous arrangements for TensorOps +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::TensorOpMultiplicandCongruous128b, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorOpMultiplicandCongruous128b; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + static_assert(ThreadMap::kThreads / 32 > 1, + "This tile iterator requires at least two warps."); + + /// Internal details made public to facilitate introspection + struct Detail { + /// This iterator is specialized for an access size that is 128 bits in + /// length. + static int const kAccessSizeInBits = 128; + + static_assert(sizeof_bits::value * + ThreadMap::kElementsPerAccess == + kAccessSizeInBits, + "This iterator requires a policy whose access size is 128b"); + + ///< Number of pointers + static int const kPointerCount = 1; + }; + + /// Element type per access + using AccessType = Array; + + private: + // + // Data members + // + + /// Stride value + StrideIndex stride_; + + /// Internal pointer to first access of tile + AccessType *pointer_; + + /// Internal byte offset + Index byte_offset_; + + /// Iteration in the contiguous dimension + int iteration_contiguous_; + + /// Iteration in the strided dimension + int iteration_strided_; + + public: + + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator( + TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ): + stride_(ref.stride(0) / Layout::kElementsPerAccess), + byte_offset_(0) { + + layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); + + // This is the offset of a thread within a threadblock tile for a specific + // pointer (units of elements) + layout::PitchLinearCoord thread_offset_in_threadblock_tile = thread_offset_base; + + // initialize pointer + pointer_ = reinterpret_cast(ref.data() + ref.offset(thread_offset_in_threadblock_tile)); + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + + byte_offset_ += pointer_offset * sizeof(Element); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + + AccessType *access_ptr = pointer_; + + int access_offset = iteration_strided_ * ThreadMap::Delta::kStrided * stride_ + + iteration_contiguous_ * ThreadMap::Delta::kContiguous / + ThreadMap::kElementsPerAccess; + + char *access_byte_ptr = + reinterpret_cast(access_ptr + access_offset); + + return reinterpret_cast(access_byte_ptr + byte_offset_); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iteration_contiguous_; + + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) + return *this; + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + iteration_contiguous_ = 0; + ++iteration_strided_; + + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + iteration_strided_ = 0; + + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + + RegularTileAccessIterator prev(*this); + + this->operator++(); + + return prev; + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + + add_pointer_offset( + coord.contiguous() * Shape::kContiguous + + coord.strided() * Shape::kStrided * stride_ * Layout::kElementsPerAccess); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for column-major congruous TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::ColumnMajorTensorOpMultiplicandCongruous128b, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for column-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajorTensorOpMultiplicandCongruous128b; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicandCongruous128b, + (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.row(), coord.column()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + ++iterator_; + + return prev; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for row-major congruous TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for row-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajorTensorOpMultiplicandCongruous128b; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicandCongruous128b, + (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator( + TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ): + iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.column(), coord.row()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + ++iterator_; + + return prev; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for congruous arrangements for TensorOps +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::TensorOpMultiplicandCrosswise128x4, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorOpMultiplicandCrosswise128x4; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + static_assert(ThreadMap::kThreads / 32 > 1, + "This tile iterator requires at least two warps."); + + /// Internal details made public to facilitate introspection + struct Detail { + /// This iterator is specialized for an access size that is 128 bits in + /// length. + static int const kAccessSizeInBits = 128; + + static_assert(sizeof_bits::value * + ThreadMap::kElementsPerAccess == + kAccessSizeInBits, + "This iterator requires a policy whose access size is 128b"); + + ///< Number of pointers + static int const kPointerCount = 1; + }; + + + static_assert(!(ThreadMap::Iterations::kStrided % 2), "This iterator requires at least two iterations along the strided dimension"); + + /// Element type per access + using AccessType = Array; + + private: + // + // Data members + // + + /// Stride value + StrideIndex stride_; + + /// Internal pointer to first access of tile + AccessType *pointer_; + + /// Internal byte offset + Index byte_offset_; + + /// Iteration in the contiguous dimension + int iteration_contiguous_; + + /// Iteration in the strided dimension + int iteration_strided_; + + public: + + /// Construct a TileIterator with zero threadblock offset + CUTLASS_DEVICE + RegularTileAccessIterator( + TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ): + stride_(ref.stride(0) / Layout::kElementsPerAccess), + byte_offset_(0) { + + layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); + + // This is the offset of a thread within a threadblock tile for a specific + // pointer (units of elements) + layout::PitchLinearCoord thread_offset_in_threadblock_tile = thread_offset_base; + + // initialize pointer + pointer_ = reinterpret_cast(ref.data() + ref.offset(thread_offset_in_threadblock_tile)); + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + + byte_offset_ += pointer_offset * sizeof(Element); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + + AccessType *access_ptr = pointer_; + + int offset_c = (iteration_contiguous_ * ThreadMap::Delta::kContiguous + (iteration_strided_ & 1) * 2); + int offset_s = (iteration_strided_ / 2) * 8; + + int access_offset = offset_c * stride_ + offset_s; + + char *access_byte_ptr = + reinterpret_cast(access_ptr + access_offset); + + return reinterpret_cast(access_byte_ptr + byte_offset_); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iteration_contiguous_; + + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) + return *this; + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + iteration_contiguous_ = 0; + ++iteration_strided_; + + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + iteration_strided_ = 0; + + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + + RegularTileAccessIterator prev(*this); + + this->operator++(); + + return prev; + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + + add_pointer_offset( + coord.contiguous() * Shape::kContiguous * stride_ + + coord.strided() * Shape::kStrided * Layout::kElementsPerAccess); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for column-major congruous TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::ColumnMajorTensorOpMultiplicandCrosswise128x4, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for column-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicandCrosswise128x4, + (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.row(), coord.column()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + ++iterator_; + + return prev; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for row-major congruous TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for row-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajorTensorOpMultiplicandCrosswise128x4; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicandCrosswise128x4, + (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator( + TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ): + iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.column(), coord.row()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + ++iterator_; + + return prev; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..be07e43f6f45132f79d95afb95714c4392149b66 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator.h @@ -0,0 +1,62 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing storing of tiles from pitch-linear rank=2 tensors. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + int Alignment = sizeof_bits::value * ThreadMap::kElementsPerAccess / 8 +> +class RegularTileIterator; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h new file mode 100644 index 0000000000000000000000000000000000000000..6c186ce3fe0650c3f8927d84f1983916d9d1867f --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h @@ -0,0 +1,552 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of tiles from pitch-linear rank=2 tensors. + + This iterator uses masks to guard out-of-bounds accesses and visits the last "residue" tile + first, with the objective of minimizing predicate mask updates during steady-state operation. + + A precomputed "Params" object minimizes the amount of state that must be stored in registers, + and integer addition is used to advance the pointer through memory. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" + +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Regular tile iterator specialized for pitch-linear. This one is used by 2-stage SIMT kernels +/// and sparse tensor core meta data. +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int Alignment +> +class RegularTileIterator { +public: + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using Fragment = Array; + + using AccessType = AlignedArray; + + static_assert(kAdvanceRank == 0 || kAdvanceRank == 1, + "Advance rank may only be along the contiguous or strided dimensions."); + +private: + + // + // Types + // + + // + // Data members + // + + /// Pointer to memory + uint8_t *pointer_; + + /// Stride quantity + StrideIndex stride_; + + /// Amount to increment pointer along strided dimension + Index increment_strided_; + + /// Amount to advance pointer between tiles + Index increment_advance_; + +public: + + CUTLASS_DEVICE + RegularTileIterator(): pointer_(nullptr), increment_strided_(0), increment_advance_(0) { } + + CUTLASS_DEVICE + RegularTileIterator( + TensorRef const &ref, + int thread_idx + ): + pointer_(reinterpret_cast(ref.data()) + (ref.offset(ThreadMap::initial_offset(thread_idx)) * sizeof_bits::value / 8)) { + + stride_ = ref.stride()[0]; + increment_strided_ = (ref.stride()[0] * sizeof_bits::value) * ThreadMap::Delta::kStrided / 8; + + increment_advance_ = + (kAdvanceRank == 0 ? + Shape::kContiguous * sizeof_bits::value / 8 : + Shape::kStrided * (ref.stride()[0] * sizeof_bits::value / 8)); + } + + /// Loads a fragment + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + + AccessType *frag_ptr = reinterpret_cast(&frag); + uint8_t const *byte_pointer = pointer_ + pointer_offset * sizeof_bits::value / 8; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + + AccessType const *access_ptr = reinterpret_cast(byte_pointer); + + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + int idx = c + s * ThreadMap::Iterations::kContiguous; + frag_ptr[idx] = access_ptr[c * ThreadMap::Delta::kContiguous / + ThreadMap::kElementsPerAccess]; + } + + if (s + 1 < ThreadMap::Iterations::kStrided) { + byte_pointer += increment_strided_; + } + } + } + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load(Fragment &frag, TensorCoord const & tile_offset) { + load_with_pointer_offset( + frag, + tile_offset.contiguous() * Shape::kContiguous / ThreadMap::kElementsPerAccess + + tile_offset.strided() * Shape::kStrided * stride_ + ); + } + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + /// Stores a fragment + CUTLASS_HOST_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + + AccessType const *frag_ptr = reinterpret_cast(&frag); + uint8_t *byte_pointer = pointer_ + pointer_offset * sizeof_bits::value / 8; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + + AccessType *access_ptr = reinterpret_cast(byte_pointer); + + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + int idx = c + s * ThreadMap::Iterations::kContiguous; + access_ptr[c * ThreadMap::Delta::kContiguous / + ThreadMap::kElementsPerAccess] = frag_ptr[idx]; + } + + if (s + 1 < ThreadMap::Iterations::kStrided) { + byte_pointer += increment_strided_; + } + } + } + + /// Stores a fragment + CUTLASS_HOST_DEVICE + void store(Fragment const &frag, TensorCoord const & tile_offset) { + store_with_pointer_offset( + frag, + tile_offset.contiguous() * Shape::kContiguous + tile_offset.strided() * Shape::kStrided * stride_ + ); + } + + /// Stores a fragment + CUTLASS_HOST_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } + + /// Advances the pointer + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + pointer_ += increment_advance_; + return *this; + } + + /// Advances the pointer + CUTLASS_HOST_DEVICE + RegularTileIterator &operator--() { + pointer_ -= increment_advance_; + return *this; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset; + } + + /// Adds a tile offset in the unit of tile. + /// In GEMM/Conv implementation, this is used to move in the k dimension in the shared memory. + /// Below layouts are the shared memory layouts. Current SM50 SIMT kernels only use col major A and row major B. + /// For row major A operand, k dimension is contiguous dimension; + /// For col major A operand, k dimension is strided dimension; + /// For row major B operand, k dimension is strided dimension; + /// For col major B operand, k dimension is contiguous dimension. + /// Below two classes map col/row major to the pitch linear coordinates used + /// in this base class. + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + int offset = sizeof_bits::value * + (coord.contiguous() * Shape::kContiguous + coord.strided() * Shape::kStrided * stride_) / 8; + add_pointer_offset(offset); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { +#if 0 + AccessType *access_ptr = pointer_[iteration_strided_ & 1]; + int stride_idx = (iteration_strided_ & ~1); + + int access_offset = stride_idx * ThreadMap::Delta::kStrided * stride_ + + iteration_contiguous_ * ThreadMap::Delta::kContiguous / + ThreadMap::kElementsPerAccess; + + char *access_byte_ptr = + reinterpret_cast(access_ptr + access_offset); + return reinterpret_cast(access_byte_ptr + byte_offset_); +#endif + return reinterpret_cast(pointer_); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Regular tile iterator specialized for row major +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int Alignment +> +class RegularTileIterator { +public: + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using Fragment = Array; + + using Underlying = RegularTileIterator< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + kAlignment + >; + + using AccessType = typename Underlying::AccessType; + + static_assert(kAdvanceRank == 0 || kAdvanceRank == 1, + "Advance rank may only be along the row or column dimensions."); + +private: + + Underlying iterator_; + +public: + + CUTLASS_DEVICE + RegularTileIterator() { } + + CUTLASS_DEVICE + RegularTileIterator( + TensorRef const &ref, + int thread_idx + ): + iterator_({ref.data(), ref.stride()}, thread_idx) { + + } + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load(Fragment &frag, TensorCoord const & tile_offset) { + iterator_.load_with_pointer_offset(frag, {tile_offset.column(), tile_offset.row()}); + } + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load(Fragment &frag) { + iterator_.load_with_pointer_offset(frag, 0); + } + + /// Stores a fragment + CUTLASS_HOST_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Stores a fragment + CUTLASS_HOST_DEVICE + void store(Fragment const &frag, TensorCoord const & tile_offset) { + iterator_.store_with_pointer_offset(frag, {tile_offset.column(), tile_offset.row()}); + } + + /// Stores a fragment + CUTLASS_HOST_DEVICE + void store(Fragment const &frag) { + iterator_.store_with_pointer_offset(frag, 0); + } + + /// Advances the pointer + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances the pointer + CUTLASS_HOST_DEVICE + RegularTileIterator &operator--() { + --iterator_; + return *this; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.column(), coord.row()}); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return iterator_.get(); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Regular tile iterator specialized for pitch-linear +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int Alignment +> +class RegularTileIterator { +public: + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using Fragment = Array; + + using Underlying = RegularTileIterator< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap + >; + + using AccessType = typename Underlying::AccessType; + + static_assert(kAdvanceRank == 0 || kAdvanceRank == 1, + "Advance rank may only be along the row or column dimensions."); + +private: + + Underlying iterator_; + +public: + + CUTLASS_DEVICE + RegularTileIterator() { } + + CUTLASS_DEVICE + RegularTileIterator( + TensorRef const &ref, + int thread_idx + ): + iterator_({ref.data(), ref.stride()}, thread_idx) { + + } + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load(Fragment &frag, TensorCoord const & tile_offset) { + iterator_.load_with_pointer_offset(frag, {tile_offset.row(), tile_offset.column()}); + } + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load(Fragment &frag) { + iterator_.load_with_pointer_offset(frag, 0); + } + + /// Stores a fragment + CUTLASS_HOST_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Stores a fragment + CUTLASS_HOST_DEVICE + void store(Fragment const &frag, TensorCoord const & tile_offset) { + iterator_.store_with_pointer_offset(frag, {tile_offset.row(), tile_offset.column()}); + } + + /// Stores a fragment + CUTLASS_HOST_DEVICE + void store(Fragment const &frag) { + iterator_.store_with_pointer_offset(frag, 0); + } + + /// Advances the pointer + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances the pointer + CUTLASS_HOST_DEVICE + RegularTileIterator &operator--() { + --iterator_; + return *this; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.row(), coord.column()}); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return iterator_.get(); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear_2dthreadtile.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear_2dthreadtile.h new file mode 100644 index 0000000000000000000000000000000000000000..5ed2e7fdd08ceafe772c97ab90f915c2268cabbb --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear_2dthreadtile.h @@ -0,0 +1,509 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of tiles from pitch-linear rank=2 tensors. + + This iterator uses masks to guard out-of-bounds accesses and visits the last "residue" tile + first, with the objective of minimizing predicate mask updates during steady-state operation. + + A precomputed "Params" object minimizes the amount of state that must be stored in registers, + and integer addition is used to advance the pointer through memory. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" + +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + int Alignment = sizeof_bits::value * ThreadMap::kElementsPerAccess / 8 +> +class RegularTileIterator2dThreadTile; + + +/// Regular tile iterator specialized for pitch-linear + 2d thread-tiled threadmapping +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int Alignment +> +class RegularTileIterator2dThreadTile { +public: + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using Fragment = Array; + + static_assert(kAdvanceRank == 0 || kAdvanceRank == 1, + "Advance rank may only be along the contiguous or strided dimensions."); + +private: + + // + // Types + // + + using AccessType = AlignedArray; + + // + // Data members + // + + /// Pointer to memory + uint8_t *pointer_; + + /// Stride quantity + StrideIndex stride_; + + /// Amount to increment pointer along strided dimension + LongIndex increment_strided_; + + /// Amount to advance pointer between tiles + LongIndex increment_advance_; + +public: + + CUTLASS_DEVICE + RegularTileIterator2dThreadTile(): pointer_(nullptr), increment_strided_(0), increment_advance_(0) { } + + CUTLASS_DEVICE + RegularTileIterator2dThreadTile( + TensorRef const &ref, + int thread_idx, + int interleave + ){ + + TensorCoord t = ThreadMap::initial_offset(thread_idx); + long int offset = t[0] * interleave + t[1] * ref.stride()[0]/interleave; + pointer_ = reinterpret_cast(ref.data() + offset); + + stride_ = ref.stride()[0] / interleave; + increment_strided_ = (ref.stride()[0] * sizeof_bits::value / 8) * ThreadMap::Delta::kStrided / interleave; + + increment_advance_ = + (kAdvanceRank == 0 ? + Shape::kContiguous * sizeof_bits::value / 8 : + Shape::kStrided * (ref.stride()[0] * sizeof_bits::value / 8) / interleave); + } + + /// Loads a fragment + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + + AccessType *frag_ptr = reinterpret_cast(&frag); + uint8_t const *byte_pointer = pointer_ + pointer_offset * sizeof_bits::value / 8; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + + AccessType const *access_ptr = reinterpret_cast(byte_pointer); + + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + int idx = c + s * ThreadMap::Iterations::kContiguous; + frag_ptr[idx] = access_ptr[c * ThreadMap::Delta::kContiguous / ThreadMap::ThreadAccessShape::kStrided]; + } + + if (s + 1 < ThreadMap::Iterations::kStrided) { + byte_pointer += increment_strided_; + } + } + } + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load(Fragment &frag, TensorCoord const & tile_offset) { + load_with_pointer_offset( + frag, + tile_offset.contiguous() * Shape::kContiguous / ThreadMap::kElementsPerAccess + + tile_offset.strided() * Shape::kStrided * stride_ + ); + } + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + /// Stores a fragment + CUTLASS_HOST_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + + AccessType const *frag_ptr = reinterpret_cast(&frag); + uint8_t *byte_pointer = pointer_ + pointer_offset * sizeof_bits::value / 8; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + + AccessType *access_ptr = reinterpret_cast(byte_pointer); + + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + int idx = c + s * ThreadMap::Iterations::kContiguous; + access_ptr[c * ThreadMap::Delta::kContiguous / ThreadMap::ThreadAccessShape::kStrided] = frag_ptr[idx]; + } + + if (s + 1 < ThreadMap::Iterations::kStrided) { + byte_pointer += increment_strided_; + } + } + } + + /// Stores a fragment + CUTLASS_HOST_DEVICE + void store(Fragment const &frag, TensorCoord const & tile_offset) { + store_with_pointer_offset( + frag, + tile_offset.contiguous() * Shape::kContiguous + tile_offset.strided() * Shape::kStrided * stride_ + ); + } + + /// Stores a fragment + CUTLASS_HOST_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } + + /// Advances the pointer + CUTLASS_HOST_DEVICE + RegularTileIterator2dThreadTile &operator++() { + pointer_ += increment_advance_; + return *this; + } + + /// Advances the pointer + CUTLASS_HOST_DEVICE + RegularTileIterator2dThreadTile &operator--() { + pointer_ -= increment_advance_; + return *this; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset; + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + int offset = sizeof_bits::value * + (coord.contiguous() * Shape::kContiguous + coord.strided() * Shape::kStrided * stride_) / 8; + add_pointer_offset(offset); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Regular tile iterator specialized for interleaved layout + 2d thread-tiled threadmapping +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int Alignment +> +class RegularTileIterator2dThreadTile, AdvanceRank, ThreadMap_, Alignment> { +public: + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajorInterleaved<4>; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using Fragment = Array; + + using Underlying = RegularTileIterator2dThreadTile< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + kAlignment + >; + + static_assert(kAdvanceRank == 0 || kAdvanceRank == 1, + "Advance rank may only be along the row or column dimensions."); + +private: + + Underlying iterator_; + +public: + + CUTLASS_DEVICE + RegularTileIterator2dThreadTile() { } + + CUTLASS_DEVICE + RegularTileIterator2dThreadTile( + TensorRef const &ref, + int thread_idx + ): + iterator_({ref.data(), ref.stride()}, thread_idx, 4) { + + } + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load(Fragment &frag, TensorCoord const & tile_offset) { + iterator_.load_with_pointer_offset(frag, {tile_offset.column(), tile_offset.row()}); + } + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load(Fragment &frag) { + iterator_.load_with_pointer_offset(frag, 0); + } + + /// Stores a fragment + CUTLASS_HOST_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Stores a fragment + CUTLASS_HOST_DEVICE + void store(Fragment const &frag, TensorCoord const & tile_offset) { + iterator_.store_with_pointer_offset(frag, {tile_offset.column(), tile_offset.row()}); + } + + /// Stores a fragment + CUTLASS_HOST_DEVICE + void store(Fragment const &frag) { + iterator_.store_with_pointer_offset(frag, 0); + } + + /// Advances the pointer + CUTLASS_HOST_DEVICE + RegularTileIterator2dThreadTile &operator++() { + ++iterator_; + return *this; + } + + /// Advances the pointer + CUTLASS_HOST_DEVICE + RegularTileIterator2dThreadTile &operator--() { + --iterator_; + return *this; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.column(), coord.row()}); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Regular tile iterator specialized for interleaved layout + 2d thread-tiled threadmapping +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int Alignment +> +class RegularTileIterator2dThreadTile, AdvanceRank, ThreadMap_, Alignment> { +public: + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajorInterleaved<4>; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using Fragment = Array; + using PitchLinearThreadMap = PitchLinearStripminedThreadMap< layout::PitchLinearShape, + ThreadMap::kThreads, ThreadMap::ThreadAccessShape::kCount >; + + + using Underlying = RegularTileIterator2dThreadTile< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap + >; + + static_assert(kAdvanceRank == 0 || kAdvanceRank == 1, + "Advance rank may only be along the row or column dimensions."); + +private: + + Underlying iterator_; + +public: + + CUTLASS_DEVICE + RegularTileIterator2dThreadTile() { } + + CUTLASS_DEVICE + RegularTileIterator2dThreadTile( + TensorRef const &ref, + int thread_idx + ): + iterator_({ref.data(), ref.stride()}, thread_idx, 4) { + + } + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load(Fragment &frag, TensorCoord const & tile_offset) { + iterator_.load_with_pointer_offset(frag, {tile_offset.row(), tile_offset.column()}); + } + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load(Fragment &frag) { + iterator_.load_with_pointer_offset(frag, 0); + } + + /// Stores a fragment + CUTLASS_HOST_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Stores a fragment + CUTLASS_HOST_DEVICE + void store(Fragment const &frag, TensorCoord const & tile_offset) { + iterator_.store_with_pointer_offset(frag, {tile_offset.row(), tile_offset.column()}); + } + + /// Stores a fragment + CUTLASS_HOST_DEVICE + void store(Fragment const &frag) { + iterator_.store_with_pointer_offset(frag, 0); + } + + /// Advances the pointer + CUTLASS_HOST_DEVICE + RegularTileIterator2dThreadTile &operator++() { + ++iterator_; + return *this; + } + + /// Advances the pointer + CUTLASS_HOST_DEVICE + RegularTileIterator2dThreadTile &operator--() { + --iterator_; + return *this; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.row(), coord.column()}); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op.h new file mode 100644 index 0000000000000000000000000000000000000000..723f328d976fc170d198282823e3da6876ec1ba6 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op.h @@ -0,0 +1,1107 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing storing of tiles from pitch-linear rank=2 tensors. +*/ + +#pragma once + +#include "cutlass/transform/threadblock/regular_tile_iterator.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for congruous arrangements for TensorOps +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileIterator< + Shape_, Element_, + layout::TensorOpMultiplicandCongruous::value, + Crosswise>, + AdvanceRank, ThreadMap_, Alignment> { + public: + + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = + layout::TensorOpMultiplicandCongruous::value, + Crosswise>; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Internal details made public to facilitate introspection + struct Detail { + + /// This iterator is specialized for an access size that is 128 bits in length. + static int const kAccessSizeInBits = 128; + + static_assert( + sizeof_bits::value * ThreadMap::kElementsPerAccess == kAccessSizeInBits, + "This iterator requires a policy whose access size is 128bs"); + }; + +private: + + /// Element type per access + using AccessType = Array; + +public: + + /// Fragment object to be loaded or stored + using Fragment = Array; + + /// Underlying iterator to compute the addresses + using TileAccessIterator = RegularTileAccessIterator; + +private: + + // + // Data members + // + + /// Data member to the tile access iterator + TileAccessIterator address_iterator_; + +public: + + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : address_iterator_(ref, thread_id) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + address_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + address_iterator_.add_tile_offset({0, 1}); + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator operator++(int) { + RegularTileIterator prev(*this); + this->operator++(); + + return prev; + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + address_iterator_.add_tile_offset(coord); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + load_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, Index byte_offset) { + address_iterator_.set_iteration_index(0); + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + int access_idx = c + s * ThreadMap::Iterations::kContiguous; + + char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; + AccessType const *access_ptr = reinterpret_cast(byte_ptr); + + frag_ptr[access_idx] = *access_ptr; + ++address_iterator_; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + store_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, Index byte_offset) { + address_iterator_.set_iteration_index(0); + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + int access_idx = c + s * ThreadMap::Iterations::kContiguous; + + char *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; + AccessType *access_ptr = reinterpret_cast(byte_ptr); + + *access_ptr = frag_ptr[access_idx]; + ++address_iterator_; + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_byte_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for column-major congruous TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileIterator< + Shape_, Element_, + layout::ColumnMajorTensorOpMultiplicandCongruous< + sizeof_bits::value, Crosswise>, + AdvanceRank, ThreadMap_, Alignment> { + public: + + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for column-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajorTensorOpMultiplicandCongruous< + sizeof_bits::value, Crosswise>; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicandCongruous::value, + Crosswise>, + (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; + + public: + + /// Fragment object to be loaded or stored + using Fragment = Array; + +private: + + /// Underlying iterator + UnderlyingIterator iterator_; + +public: + + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileIterator( + TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ): iterator_({ref.data(), ref.stride()}, thread_id) { + + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.row(), coord.column()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator operator++(int) { + RegularTileIterator prev(*this); + ++iterator_; + + return prev; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset( + Fragment const &frag, + Index pointer_offset) { + + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for row-major congruous TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileIterator< + Shape_, Element_, + layout::RowMajorTensorOpMultiplicandCongruous::value, + Crosswise>, + AdvanceRank, ThreadMap_, Alignment> { + public: + + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for row-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajorTensorOpMultiplicandCongruous< + sizeof_bits::value, Crosswise>; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicandCongruous::value, + Crosswise>, + (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; + + public: + + /// Fragment object to be loaded or stored + using Fragment = Array; + +private: + + /// Underlying iterator + UnderlyingIterator iterator_; + +public: + + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileIterator( + TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ): iterator_({ref.data(), ref.stride()}, thread_id) { + + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.column(), coord.row()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator operator++(int) { + + RegularTileIterator prev(*this); + ++iterator_; + + return prev; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset( + Fragment const &frag, + Index pointer_offset) { + + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for crosswise arrangements for TensorOps +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileIterator::value, Crosswise>, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = + layout::TensorOpMultiplicandCrosswise::value, + Crosswise>; + + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Internal details made public to facilitate introspection + struct Detail { + /// This iterator is specialized for an access size that is 128 bits in + /// length. + static int const kAccessSizeInBits = 128; + + static_assert(sizeof_bits::value * ThreadMap::kElementsPerAccess == + kAccessSizeInBits, + "This iterator requires a policy whose access size is 128bs"); + }; + + private: + /// Element type per access + using AccessType = Array; + + public: + /// Fragment object to be loaded or stored + using Fragment = + Array; + + /// Underlying iterator to compute the addresses + using TileAccessIterator = RegularTileAccessIterator; + + private: + // + // Data members + // + + /// Data member to the tile access iterator + TileAccessIterator address_iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : address_iterator_(ref, thread_id) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + address_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + address_iterator_.add_tile_offset({1, 0}); + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator operator++(int) { + RegularTileIterator prev(*this); + this->operator++(); + + return prev; + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + address_iterator_.add_tile_offset(coord); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + address_iterator_.set_iteration_index(0); + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + int access_idx = c + s * ThreadMap::Iterations::kContiguous; + frag_ptr[access_idx] = *(address_iterator_.get() + pointer_offset); + ++address_iterator_; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + store_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, Index byte_offset) { + address_iterator_.set_iteration_index(0); + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + int access_idx = c + s * ThreadMap::Iterations::kContiguous; + + char *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; + AccessType *access_ptr = reinterpret_cast(byte_ptr); + + *access_ptr = frag_ptr[access_idx]; + ++address_iterator_; + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for column-major crosswise TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileIterator::value, Crosswise>, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for column-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, Crosswise>; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicandCrosswise::value, + Crosswise>, + (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; + + public: + /// Fragment object to be loaded or stored + using Fragment = Array; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.row(), coord.column()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator operator++(int) { + RegularTileIterator prev(*this); + ++iterator_; + + return prev; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for row-major crosswise TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileIterator::value, Crosswise>, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for row-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, Crosswise>; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicandCrosswise::value, + Crosswise>, + (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; + + public: + /// Fragment object to be loaded or stored + using Fragment = Array; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.column(), coord.row()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator operator++(int) { + RegularTileIterator prev(*this); + ++iterator_; + + return prev; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for k interleaved arrangements for TensorOps +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileIterator< + Shape_, Element_, + layout::TensorOpMultiplicandRowMajorInterleaved::value, + InterleavedK>, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = + layout::TensorOpMultiplicandRowMajorInterleaved::value, + InterleavedK>; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Internal details made public to facilitate introspection + struct Detail { + /// This iterator is specialized for an access size that is 128 bits in + /// length. + static int const kAccessSizeInBits = 128; + + static_assert(sizeof_bits::value * ThreadMap::kElementsPerAccess == + kAccessSizeInBits, + "This iterator requires a policy whose access size is 128bs"); + }; + + private: + + /// Element type per access + using AccessType = Array; + + public: + /// Fragment object to be loaded or stored + using Fragment = + Array; + + /// Underlying iterator to compute the addresses + using TileAccessIterator = RegularTileAccessIterator; + + private: + // + // Data members + // + + /// Data member to the tile access iterator + TileAccessIterator address_iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : address_iterator_(ref, thread_id) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + address_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + address_iterator_.add_pointer_offset(Shape::kCount); + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator operator++(int) { + RegularTileIterator prev(*this); + this->operator++(); + + return prev; + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + address_iterator_.add_pointer_offset(coord.contiguous() * Shape::kCount); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + address_iterator_.set_iteration_index(0); + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + int access_idx = c + s * ThreadMap::Iterations::kContiguous; + frag_ptr[access_idx] = *(address_iterator_.get() + pointer_offset); + ++address_iterator_; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + int access_idx = c + s * ThreadMap::Iterations::kContiguous; + *(address_iterator_.get() + pointer_offset) = frag_ptr[access_idx]; + ++address_iterator_; + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for k interleaved arrangements for TensorOps +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// + +template +class RegularTileIterator< + Shape_, Element_, + layout::TensorOpMultiplicandColumnMajorInterleaved::value, + InterleavedK>, + AdvanceRank, ThreadMap_, Alignment> { + + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = + layout::TensorOpMultiplicandColumnMajorInterleaved::value, + InterleavedK>; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileIterator< + cutlass::MatrixShape, + Element, + layout::TensorOpMultiplicandRowMajorInterleaved::value, InterleavedK>, + (kAdvanceRank == 1 ? 0 : 1), + ThreadMap + >; + + public: + /// Fragment object to be loaded or stored + using Fragment = Array; + + private: + + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator operator++(int) { + RegularTileIterator prev(*this); + ++iterator_; + + return prev; + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.strided(), coord.contiguous()}); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op_sm70.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op_sm70.h new file mode 100644 index 0000000000000000000000000000000000000000..53121c6114cc3675e4d97f9da65d3ecb58e46d62 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op_sm70.h @@ -0,0 +1,1460 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of tiles from pitch-linear rank=2 tensors. + + This iterator uses masks to guard out-of-bounds accesses and visits the last "residue" tile + first, with the objective of minimizing predicate mask updates during steady-state operation. + + A precomputed "Params" object minimizes the amount of state that must be stored in registers, + and integer addition is used to advance the pointer through memory. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor_op_multiplicand_sm70.h" + +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for congruous arrangements for TensorOps +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int Alignment +> +class RegularTileIterator< + Shape_, + Element_, + layout::VoltaTensorOpMultiplicandCongruous::value>, + AdvanceRank, + ThreadMap_, + Alignment> { +public: + + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::VoltaTensorOpMultiplicandCongruous::value>; + static int const kAdvanceRank = AdvanceRank; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Internal details made public to facilitate introspection + struct Detail { + + /// This iterator is specialized for an access size that is 128 bits in length. + static int const kAccessSizeInBits = 128; + + static_assert( + sizeof_bits::value * ThreadMap::kElementsPerAccess == kAccessSizeInBits, + "This iterator requires a policy whose access size is 128bs"); + + ///< Number of pointers + static int const kPointerCount = (ThreadMap::Iterations::kStrided > 1 ? 2 : 1); + }; + + +private: + + /// Element type per access + using AccessType = Array; + +public: + + /// Fragment object to be loaded or stored + using Fragment = Array; + +private: + + // + // Data members + // + + /// Stride value + StrideIndex stride_; + + /// Internal pointer to first access of tile + AccessType * pointer_[Detail::kPointerCount]; + + /// Internal byte offset + Index byte_offset_; + +public: + + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileIterator( + TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ): stride_(ref.stride(0) / Layout::kElementsPerAccess), byte_offset_(0) { + + layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Detail::kPointerCount; ++i) { + + // This is the offset of a thread within a threadblock tile for a specific pointer + // (units of elements) + layout::PitchLinearCoord thread_offset_in_threadblock_tile = + thread_offset_base + layout::PitchLinearCoord{0, ThreadMap::Detail::WarpThreadArrangement::kStrided * i}; + + // initialize pointer + pointer_[i] = reinterpret_cast(ref.data() + ref.offset(thread_offset_in_threadblock_tile)); + } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + + byte_offset_ += pointer_offset * sizeof(Element); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + + add_pointer_offset((kAdvanceRank ? Shape::kStrided * stride_ * Layout::kElementsPerAccess : Shape::kContiguous)); + + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator operator++(int) { + + RegularTileIterator prev(*this); + this->operator++(); + + return prev; + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + add_pointer_offset( + coord.contiguous() * Shape::kContiguous / ThreadMap::kElementsPerAccess + + coord.strided() * Shape::kStrided * stride_ * Layout::kElementsPerAccess + ); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + + AccessType *frag_ptr = reinterpret_cast(&frag); + + Index vec_pointer_offset = pointer_offset / ThreadMap::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + + AccessType *access_ptr = pointer_[s & 1]; + int stride_idx = (s & ~1); + + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + int access_offset = stride_idx * ThreadMap::Delta::kStrided * stride_ + + c * ThreadMap::Delta::kContiguous / ThreadMap::kElementsPerAccess + + vec_pointer_offset; + + int access_idx = c + s * ThreadMap::Iterations::kContiguous; + + char const *access_byte_ptr = reinterpret_cast(access_ptr + access_offset); + + frag_ptr[access_idx] = *reinterpret_cast(access_byte_ptr + byte_offset_); + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset( + Fragment const &frag, + Index pointer_offset) { + + AccessType const *frag_ptr = reinterpret_cast(&frag); + + Index vec_pointer_offset = pointer_offset / ThreadMap::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + + AccessType *access_ptr = pointer_[s & 1]; + int stride_idx = (s & ~1); + + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + int access_offset = stride_idx * ThreadMap::Delta::kStrided * stride_ + + c * ThreadMap::Delta::kContiguous / ThreadMap::kElementsPerAccess + + vec_pointer_offset; + + int access_idx = c + s * ThreadMap::Iterations::kContiguous; + + char *access_byte_ptr = reinterpret_cast(access_ptr + access_offset); + + *reinterpret_cast(access_byte_ptr + byte_offset_) = frag_ptr[access_idx]; + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Tile Iterator specialized for column-major congruous TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int Alignment +> +class RegularTileIterator< + Shape_, + Element_, + layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>, + AdvanceRank, + ThreadMap_, + Alignment> { +public: + + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for column-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; + static int const kAdvanceRank = AdvanceRank; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileIterator< + layout::PitchLinearShape, + Element, + layout::VoltaTensorOpMultiplicandCongruous::value>, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap_>; + +public: + + /// Fragment object to be loaded or stored + using Fragment = Array; + +private: + + /// Underlying iterator + UnderlyingIterator iterator_; + +public: + + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileIterator( + TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ): iterator_({ref.data(), ref.stride()}, thread_id) { + + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.row(), coord.column()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator operator++(int) { + + RegularTileIterator prev(*this); + ++iterator_; + + return prev; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset( + Fragment const &frag, + Index pointer_offset) { + + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for row-major congruous TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int Alignment +> +class RegularTileIterator< + Shape_, + Element_, + layout::RowMajorVoltaTensorOpMultiplicandCongruous::value>, + AdvanceRank, + ThreadMap_, + Alignment> { +public: + + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for row-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajorVoltaTensorOpMultiplicandCongruous::value>; + static int const kAdvanceRank = AdvanceRank; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileIterator< + layout::PitchLinearShape, + Element, + layout::VoltaTensorOpMultiplicandCongruous::value>, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap_>; + +public: + + /// Fragment object to be loaded or stored + using Fragment = Array; + +private: + + /// Underlying iterator + UnderlyingIterator iterator_; + +public: + + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileIterator( + TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ): iterator_({ref.data(), ref.stride()}, thread_id) { + + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.column(), coord.row()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator operator++(int) { + + RegularTileIterator prev(*this); + ++iterator_; + + return prev; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset( + Fragment const &frag, + Index pointer_offset) { + + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; +/// Tile iterator specialized for congruous arrangements for TensorOps +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int Alignment +> +class RegularTileIterator< + Shape_, + Element_, + layout::VoltaTensorOpMultiplicandBCongruous::value>, + AdvanceRank, + ThreadMap_, + Alignment> { +public: + + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::VoltaTensorOpMultiplicandBCongruous::value>; + static int const kAdvanceRank = AdvanceRank; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Internal details made public to facilitate introspection + struct Detail { + + /// This iterator is specialized for an access size that is 128 bits in length. + static int const kAccessSizeInBits = 128; + + static_assert( + sizeof_bits::value * ThreadMap::kElementsPerAccess == kAccessSizeInBits, + "This iterator requires a policy whose access size is 128bs"); + + ///< Number of pointers + static int const kPointerCount = (ThreadMap::Iterations::kStrided > 1 ? 2 : 1); + }; + + +private: + + /// Element type per access + using AccessType = Array; + +public: + + /// Fragment object to be loaded or stored + using Fragment = Array; + +private: + + // + // Data members + // + + /// Stride value + StrideIndex stride_; + + /// Internal pointer to first access of tile + AccessType * pointer_[Detail::kPointerCount]; + + /// Internal byte offset + Index byte_offset_; + +public: + + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileIterator( + TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ): stride_(ref.stride(0) / Layout::kElementsPerAccess), byte_offset_(0) { + + layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Detail::kPointerCount; ++i) { + + // This is the offset of a thread within a threadblock tile for a specific pointer + // (units of elements) + layout::PitchLinearCoord thread_offset_in_threadblock_tile = + thread_offset_base + layout::PitchLinearCoord{0, ThreadMap::Detail::WarpThreadArrangement::kStrided * i}; + + // initialize pointer + pointer_[i] = reinterpret_cast(ref.data() + ref.offset(thread_offset_in_threadblock_tile)); + } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + + byte_offset_ += pointer_offset * sizeof(Element); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + + add_pointer_offset((kAdvanceRank ? Shape::kStrided * stride_ * Layout::kElementsPerAccess : Shape::kContiguous)); + + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator operator++(int) { + + RegularTileIterator prev(*this); + this->operator++(); + + return prev; + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + add_pointer_offset( + coord.contiguous() * Shape::kContiguous / ThreadMap::kElementsPerAccess + + coord.strided() * Shape::kStrided * stride_ * Layout::kElementsPerAccess + ); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + + AccessType *frag_ptr = reinterpret_cast(&frag); + + Index vec_pointer_offset = pointer_offset / ThreadMap::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + + AccessType *access_ptr = pointer_[s & 1]; + int stride_idx = (s & ~1); + + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + int access_offset = stride_idx * ThreadMap::Delta::kStrided * stride_ + + c * ThreadMap::Delta::kContiguous / ThreadMap::kElementsPerAccess + + vec_pointer_offset; + + int access_idx = c + s * ThreadMap::Iterations::kContiguous; + + char const *access_byte_ptr = reinterpret_cast(access_ptr + access_offset); + + frag_ptr[access_idx] = *reinterpret_cast(access_byte_ptr + byte_offset_); + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset( + Fragment const &frag, + Index pointer_offset) { + + AccessType const *frag_ptr = reinterpret_cast(&frag); + + Index vec_pointer_offset = pointer_offset / ThreadMap::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + + AccessType *access_ptr = pointer_[s & 1]; + int stride_idx = (s & ~1); + + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + int access_offset = stride_idx * ThreadMap::Delta::kStrided * stride_ + + c * ThreadMap::Delta::kContiguous / ThreadMap::kElementsPerAccess + + vec_pointer_offset; + + int access_idx = c + s * ThreadMap::Iterations::kContiguous; + + char *access_byte_ptr = reinterpret_cast(access_ptr + access_offset); + + *reinterpret_cast(access_byte_ptr + byte_offset_) = frag_ptr[access_idx]; + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for column-major congruous TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int Alignment +> +class RegularTileIterator< + Shape_, + Element_, + layout::ColumnMajorVoltaTensorOpMultiplicandBCongruous::value>, + AdvanceRank, + ThreadMap_, + Alignment> { +public: + + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for column-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajorVoltaTensorOpMultiplicandBCongruous::value>; + static int const kAdvanceRank = AdvanceRank; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileIterator< + layout::PitchLinearShape, + Element, + layout::VoltaTensorOpMultiplicandBCongruous::value>, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap_>; + +public: + + /// Fragment object to be loaded or stored + using Fragment = Array; + +private: + + /// Underlying iterator + UnderlyingIterator iterator_; + +public: + + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileIterator( + TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ): iterator_({ref.data(), ref.stride()}, thread_id) { + + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.row(), coord.column()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator operator++(int) { + + RegularTileIterator prev(*this); + ++iterator_; + + return prev; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset( + Fragment const &frag, + Index pointer_offset) { + + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for row-major congruous TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int Alignment +> +class RegularTileIterator< + Shape_, + Element_, + layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>, + AdvanceRank, + ThreadMap_, + Alignment> { +public: + + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for row-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; + static int const kAdvanceRank = AdvanceRank; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileIterator< + layout::PitchLinearShape, + Element, + layout::VoltaTensorOpMultiplicandBCongruous::value>, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap_>; + +public: + + /// Fragment object to be loaded or stored + using Fragment = Array; + +private: + + /// Underlying iterator + UnderlyingIterator iterator_; + +public: + + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileIterator( + TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ): iterator_({ref.data(), ref.stride()}, thread_id) { + + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.column(), coord.row()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator operator++(int) { + + RegularTileIterator prev(*this); + ++iterator_; + + return prev; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset( + Fragment const &frag, + Index pointer_offset) { + + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; + + +/// Tile iterator specialized for crosswise arrangements for TensorOps. +/// +/// Volta TN SMEM layout is a little diffrent: +/// Crosseised elements will be stored in a line, while contiguous elements +/// sre stored in line-by-line. +/// Padding is used to reduce SMEM bank conflicts. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int Alignment +> +class RegularTileIterator< + Shape_, Element_, + layout::VoltaTensorOpMultiplicandCrosswise::value, + Shape_::kContiguous>, + AdvanceRank, ThreadMap_, Alignment> { + + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = + layout::VoltaTensorOpMultiplicandCrosswise::value, + Shape::kContiguous>; + static int const kAdvanceRank = AdvanceRank; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Internal details made public to facilitate introspection + struct Detail { + + ///< Number of pointers + static int const kPointerCount = (ThreadMap::Iterations::kStrided > 1 ? 2 : 1); + + /// Iterations for the kElementsPerAccess of ThreadMap + static int const kIterarionsPerAccess = + ThreadMap::kElementsPerAccess / Layout::kElementsPerAccess; + + /// Contiguous elements per line + static int const kContiguousElementsPerLine = 4; + }; + + private: + /// Element type per access + using AccessType = Array; + + public: + /// Fragment object to be loaded or stored + using Fragment = + Array; + + private: + // + // Data members + // + + /// The crosswised elements will be stored in a line. + /// line_size is size of crosswised dimension plus padding. + /// in units of AccessType + Index line_size; + + /// Internal pointer to first access of tile + AccessType *pointer_[Detail::kPointerCount]; + + /// Internal byte offset + Index byte_offset_; + + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : line_size(ref.stride(0) * Detail::kContiguousElementsPerLine / Layout::kElementsPerAccess), + byte_offset_(0) { + + layout::PitchLinearCoord thread_offset_base = + ThreadMap::initial_offset(thread_id); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Detail::kPointerCount; ++i) { + // This is the offset of a thread within a threadblock tile for a specific + // pointer (units of elements) + layout::PitchLinearCoord thread_offset_in_threadblock_tile = + thread_offset_base + + layout::PitchLinearCoord{ + 0, ThreadMap::Detail::WarpThreadArrangement::kStrided * i}; + + // initialize pointer + pointer_[i] = reinterpret_cast( + ref.data() + ref.offset(thread_offset_in_threadblock_tile)); + } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_offset_ += pointer_offset * sizeof(Element); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + // (Shape::kContiguous/Layout::kElementsPerAccess)* + // line_size * Layout::kElementsPerAccess + add_pointer_offset(Shape::kContiguous * line_size); + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator operator++(int) { + RegularTileIterator prev(*this); + this->operator++(); + + return prev; + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + add_pointer_offset((coord.contiguous() * (Shape::kContiguous / Layout::kElementsPerAccess) * + line_size + coord.strided() * Shape::kStrided) * + Layout::kElementsPerAccess); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + AccessType *frag_ptr = reinterpret_cast(&frag); + + Index vec_pointer_offset = pointer_offset / Layout::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + AccessType *access_ptr = pointer_[(s & 1) ^ (s / 2)]; + + access_ptr += 16 * (s / 2); + + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < Detail::kIterarionsPerAccess; ++i) { + + int access_offset = + c * ThreadMap::Delta::kContiguous / Detail::kContiguousElementsPerLine * line_size + + vec_pointer_offset + i * line_size; + + int access_idx = (c + s * ThreadMap::Iterations::kContiguous) * + Detail::kIterarionsPerAccess + i; + + char const *access_byte_ptr = reinterpret_cast(access_ptr + access_offset); + + frag_ptr[access_idx] = *reinterpret_cast( + access_byte_ptr + byte_offset_); + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + AccessType const *frag_ptr = reinterpret_cast(&frag); + + Index vec_pointer_offset = pointer_offset / Layout::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + + AccessType *access_ptr = pointer_[(s & 1) ^ ((s >> 1) & 1)]; + + access_ptr += 16 * (s / 2) + vec_pointer_offset; + + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < Detail::kIterarionsPerAccess; ++i) { + + int access_offset = + c * ThreadMap::Delta::kContiguous / Detail::kContiguousElementsPerLine * line_size + i * line_size; + + int access_idx = (c + s * ThreadMap::Iterations::kContiguous) * + Detail::kIterarionsPerAccess + i; + + char *access_byte_ptr = reinterpret_cast(access_ptr + access_offset); + + *reinterpret_cast(access_byte_ptr + byte_offset_) = + frag_ptr[access_idx]; + } + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for column-major crosswise TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int Alignment +> +class RegularTileIterator::value, Shape_::kRow>, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for column-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajorVoltaTensorOpMultiplicandCrosswise< + sizeof_bits::value, Shape::kRow>; + static int const kAdvanceRank = AdvanceRank; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileIterator< + layout::PitchLinearShape, Element, + layout::VoltaTensorOpMultiplicandCrosswise::value, + Shape::kRow>, + (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; + + public: + /// Fragment object to be loaded or stored + using Fragment = Array; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.row(), coord.column()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator operator++(int) { + RegularTileIterator prev(*this); + ++iterator_; + + return prev; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for row-major crosswise TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int Alignment +> +class RegularTileIterator::value, Shape_::kColumn>, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for row-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajorVoltaTensorOpMultiplicandCrosswise< + sizeof_bits::value, Shape::kColumn>; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileIterator< + layout::PitchLinearShape, Element, + layout::VoltaTensorOpMultiplicandCrosswise::value, + Shape::kColumn>, + (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; + + public: + /// Fragment object to be loaded or stored + using Fragment = Array; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.column(), coord.row()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator operator++(int) { + RegularTileIterator prev(*this); + ++iterator_; + + return prev; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/vector_iterator.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/vector_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..8e5d181c177b2ad6627c927ae4ad3fb9c99a96d3 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/vector_iterator.h @@ -0,0 +1,149 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template wraps the vector access iterator concept to load whole vector from tensors in + memory. This is typically used for per-channel scale and bias in convolution kernels. +*/ + +#pragma once + +#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class VectorIterator { +public: + using VectorAccessIterator = VectorAccessIterator_; + + using Shape = typename VectorAccessIterator::Shape; + using Element = typename VectorAccessIterator::Element; + using Layout = typename VectorAccessIterator::Layout; + using TensorCoord = typename Layout::TensorCoord; + using AccessType = typename VectorAccessIterator::AccessType; + using TensorRef = typename VectorAccessIterator::TensorRef; + using Index = typename VectorAccessIterator::Index; + using LongIndex = typename VectorAccessIterator::LongIndex; + + static int const kElementsPerAccess = VectorAccessIterator::kElementsPerAccess; + static int const kRowsPerIteration = VectorAccessIterator::kRowsPerIteration; + static int const kThreads = VectorAccessIterator::kThreads; + static int const kIterations = VectorAccessIterator::kIterations; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, kElementsPerAccess * kIterations>; + +private: + + /// Internal state + VectorAccessIterator vector_access_iterator_; + +public: + + /// Constructor + CUTLASS_HOST_DEVICE + VectorIterator( + Element const *ptr, + TensorCoord extent, + int thread_idx, + int warp_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + vector_access_iterator_(ptr, extent, thread_idx, warp_idx, threadblock_offset) { } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + VectorIterator &operator++() { + vector_access_iterator_.advance(); + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + VectorIterator operator++(int) { + VectorIterator self(*this); + operator++(); + return self; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + + frag.clear(); + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < kIterations; ++c) { + + cutlass::arch::global_load< + AccessType, + sizeof(AccessType) + >( + frag_ptr[c], + vector_access_iterator_.get() + pointer_offset, + vector_access_iterator_.valid() + ); + + ++vector_access_iterator_; + } +// } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + vector_access_iterator_.set_iteration_index(0); + load_with_pointer_offset(frag, 0); + } + + CUTLASS_DEVICE + void advance() { + vector_access_iterator_.advance(); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/warp/vector_fragment_iterator.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/warp/vector_fragment_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..b27b77f9b697476ed54a019cd94120561371ebd1 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/warp/vector_fragment_iterator.h @@ -0,0 +1,283 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +/*! \file + \brief This defines a "fragment" iterator for visiting the fragments of a warp vector + that participate in one warp-level mma operation. + + Typically, this is used to access the scale/bias fragment of a warp-level mma operation. + The scale/bias vector is then partitioned into smaller fragments that can be fed into + next warp-level mma operation. + + This iterator is necessary to accomplish warp-level mma fusion where the scale/bias vector is + applied to the multiplicand for the next mma. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/array.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_conversion.h" + +namespace cutlass { +namespace transform { +namespace warp { + + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Size of the input fragment tile shape (concept: MatrixShape) + typename Shape_, + /// Element type + typename Element_, + /// Layout of operand in memory + typename Layout_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_, + //// Number of elements per access when loading fragment + int ElementsPerAccess> +class VectorFragmentIterator; + + +// Partial specialization for PitchLinear layout tile + +template < + /// Size of the input fragment vector shape (concept: MatrixShape) + typename Shape_, + /// Element type + typename Element_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_, + //// Number of elements per access when loading fragment + int ElementsPerAccess> +class VectorFragmentIterator { + public: + + /// Size of the input threadblock tile shape (concept: MatrixShape) + using Shape = Shape_; + + /// Element type + using Element = Element_; + + /// Layout of source tile + using Layout = cutlass::layout::PitchLinear; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = InstructionShape_; + + /// Number of participating threads + static int const kThreads = 32; + + static int const kElementsPerAccess = ElementsPerAccess; + static int const kRowsPerIteration = 8; + static int const kColumnsPerAccess = 8; + static int const kElementsPerIteration = kRowsPerIteration * InstructionShape::kK / kThreads; + static int const kAccessPerIteration = kElementsPerIteration / kElementsPerAccess; + + /// Number of iterations + using Iterations = MatrixShape; + +public: + + // + // Derived quantities + // + // All fragments have kElementsPerAccess scale followed by bias + + /// Fragment object holding a thread's part of a tile + /// This is the fragment size produced by one iteration of the iterator. + using Fragment = Array; + + /// Input threadblock fragment tile + using ThreadblockFragment = Array; + +private: + + /// Internal access type + using AccessType = Array; + +private: + // + // Data members + // + + /// Input threadblock fragment tile + AccessType const *iterator_; + + /// Internal index + int index_; + +public: + /// Constructs an iterator + CUTLASS_HOST_DEVICE + VectorFragmentIterator(ThreadblockFragment const &threadblock_frag) + : iterator_(reinterpret_cast(&threadblock_frag)), + index_(0) {} + + /// Add offset + CUTLASS_HOST_DEVICE + void add_offset(int index_offset) { + index_ += index_offset; + + if(index_ >= Iterations::kColumn) + index_ = 0; + } + + /// Increments + CUTLASS_HOST_DEVICE + VectorFragmentIterator &operator++() { + add_offset(1); + return *this; + } + + CUTLASS_HOST_DEVICE + void set_index(int idx) { + index_ = idx; + } + + /// Loads a fragment from the referenced part of the accumulator tile + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int r = 0; r < Iterations::kRow; r++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kAccessPerIteration; i++) { + + frag_ptr[i * Iterations::kRow + r].clear(); + frag_ptr[i * Iterations::kRow + r] = iterator_[index_ * kAccessPerIteration + i]; + } + } + } + +}; + +// Partial specialization for Row-Major layout tile + +template < + /// Size of the input fragment tile shape (concept: MatrixShape) + typename Shape_, + /// Element type + typename Element_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_, + //// Number of elements per access when loading fragment + int ElementsPerAccess> +class VectorFragmentIterator { + public: + + /// Size of the input threadblock tile shape (concept: MatrixShape) + using Shape = Shape_; + + /// Element type + using Element = Element_; + + /// Layout of source tile + using Layout = cutlass::layout::RowMajor; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = InstructionShape_; + + /// Underlying iterator + using Base = VectorFragmentIterator< + layout::PitchLinearShape, Element, + layout::PitchLinear, InstructionShape, ElementsPerAccess>; + + + public: + + // + // Derived quantities + // + /// Fragment object holding a thread's part of a tile + /// This is the fragment size produced by one iteration of the iterator. + using Fragment = typename Base::Fragment; + + /// Input threadblock fragment tile + using ThreadblockFragment = typename Base::ThreadblockFragment; + + private: + /// Underlying iterator + Base iterator_; + +public: + /// Constructs an iterator + CUTLASS_HOST_DEVICE + VectorFragmentIterator(ThreadblockFragment const &threadblock_frag) + : iterator_(threadblock_frag) {} + + /// Add offset + CUTLASS_HOST_DEVICE + void add_offset(int index_offset) { + iterator_.add_offset(index_offset); + } + + /// Increments + CUTLASS_HOST_DEVICE + VectorFragmentIterator &operator++() { + add_offset(1); + return *this; + } + + CUTLASS_HOST_DEVICE + void set_index(int idx) { + iterator_.set_index(idx); + } + + /// Loads a fragment from the referenced part of the accumulator tile + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + iterator_.load(frag); + } + +}; + + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace conv +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/uint128.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/uint128.h new file mode 100644 index 0000000000000000000000000000000000000000..68896d6b60767221fd41421a0d3fdf75392c3604 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/uint128.h @@ -0,0 +1,269 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Defines an unsigned 128b integer with several operators to support 64-bit integer division. +*/ +#pragma once +#include "cutlass/cutlass.h" +#if defined(__CUDACC_RTC__) +#include CUDA_STD_HEADER(cstdint) +#else +#include +#include +#include +#include +#include +#endif + + +/// Optionally enable GCC's built-in type +#if (defined(__x86_64) || defined (__aarch64__)) && !(defined(__CUDA_ARCH__) && ((__CUDACC_VER_MAJOR__ <= 10) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ <= 4)))) && defined(__GNUC__) +#define CUTLASS_UINT128_NATIVE +#elif !defined(__CUDA_ARCH__) +// No custom support for 128b arithmetic on device +#if defined(_MSC_VER) && defined(_M_AMD64) +#define CUTLASS_INT128_ARITHMETIC +#include +#if _MSC_VER >= 1920 && !defined(__CUDA_ARCH__) +#define CUTLASS_INT128_ARITHMETIC_DIV +#include +#endif +#endif +#endif + +namespace cutlass { + +///! Unsigned 128b integer type +struct alignas(16) uint128_t +{ + /// Size of one part of the uint's storage in bits + static constexpr int storage_bits_ = 64; + + struct hilo + { + uint64_t lo; + uint64_t hi; + }; + + // Use a union to store either low and high parts or, if present, a built-in 128b integer type. + union { + struct hilo hilo_; + +#if defined(CUTLASS_UINT128_NATIVE) + unsigned __int128 native; +#endif // defined(CUTLASS_UINT128_NATIVE) + }; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + uint128_t() : hilo_{0, 0} {} + + /// Constructor from uint64 + CUTLASS_HOST_DEVICE + uint128_t(uint64_t lo_) : hilo_{lo_, 0} {} + + /// Constructor from two 64b unsigned integers + CUTLASS_HOST_DEVICE + uint128_t(uint64_t lo_, uint64_t hi_) : hilo_{lo_, hi_} {} + + /// Optional constructor from native value +#if defined(CUTLASS_UINT128_NATIVE) + uint128_t(unsigned __int128 value) : native(value) { } +#endif + + /// Lossily cast to uint64 + CUTLASS_HOST_DEVICE + explicit operator uint64_t() const + { + return hilo_.lo; + } + + CUTLASS_HOST_DEVICE + static void exception() + { +#if defined(__CUDA_ARCH__) + asm volatile (" brkpt;\n"); +#else + // throw std::runtime_error("Not yet implemented."); + abort(); +#endif + } + + /// Add + CUTLASS_HOST_DEVICE + uint128_t operator+(uint128_t const& rhs) const + { + uint128_t y{}; +#if defined(CUTLASS_UINT128_NATIVE) + y.native = native + rhs.native; +#else + y.hilo_.lo = hilo_.lo + rhs.hilo_.lo; + y.hilo_.hi = hilo_.hi + rhs.hilo_.hi + (y.hilo_.lo < hilo_.lo); +#endif + return y; + } + + /// Subtract + CUTLASS_HOST_DEVICE + uint128_t operator-(uint128_t const& rhs) const + { + uint128_t y{}; +#if defined(CUTLASS_UINT128_NATIVE) + y.native = native - rhs.native; +#else + y.hilo_.lo = hilo_.lo - rhs.hilo_.lo; + y.hilo_.hi = hilo_.hi - rhs.hilo_.hi - (rhs.hilo_.lo && y.hilo_.lo > hilo_.lo); +#endif + return y; + } + + /// Multiply by unsigned 64b integer yielding 128b integer + CUTLASS_HOST_DEVICE + uint128_t operator*(uint64_t const& rhs) const + { + uint128_t y{}; +#if defined(CUTLASS_UINT128_NATIVE) + y.native = native * rhs; +#elif defined(CUTLASS_INT128_ARITHMETIC) + // Multiply by the low part + y.hilo_.lo = _umul128(hilo_.lo, rhs, &y.hilo_.hi); + + // Add the high part and ignore the overflow + uint64_t overflow{0}; + y.hilo_.hi += _umul128(hilo_.hi, rhs, &overflow); +#else + CUTLASS_UNUSED(rhs); + exception(); +#endif + return y; + } + + /// Divide 128b operation by 64b operation yielding a 64b quotient + CUTLASS_HOST_DEVICE + uint64_t operator/(uint64_t const& divisor) const + { + uint64_t quotient{0}; +#if defined(CUTLASS_UINT128_NATIVE) + quotient = uint64_t(native / divisor); +#elif defined(CUTLASS_INT128_ARITHMETIC_DIV) + // implemented using MSVC's arithmetic intrinsics + uint64_t remainder{0}; + quotient = _udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); +#else + CUTLASS_UNUSED(divisor); + exception(); +#endif + return quotient; + } + + /// Divide 128b operation by 64b operation yielding a 64b quotient + CUTLASS_HOST_DEVICE + uint64_t operator%(uint64_t const& divisor) const + { + uint64_t remainder{0}; +#if defined(CUTLASS_UINT128_NATIVE) + remainder = uint64_t(native % divisor); +#elif defined(CUTLASS_INT128_ARITHMETIC_DIV) + // implemented using MSVC's arithmetic intrinsics + (void)_udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); +#else + CUTLASS_UNUSED(divisor); + exception(); +#endif + return remainder; + } + + /// Computes the quotient and remainder in a single method. + CUTLASS_HOST_DEVICE + uint64_t divmod(uint64_t &remainder, uint64_t divisor) const + { + uint64_t quotient{0}; +#if defined(CUTLASS_UINT128_NATIVE) + quotient = uint64_t(native / divisor); + remainder = uint64_t(native % divisor); +#elif defined(CUTLASS_INT128_ARITHMETIC_DIV) + // implemented using MSVC's arithmetic intrinsics + quotient = _udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); +#else + CUTLASS_UNUSED(remainder); + CUTLASS_UNUSED(divisor); + exception(); +#endif + return quotient; + } + + /// Left-shifts a 128b unsigned integer + CUTLASS_HOST_DEVICE + uint128_t operator<<(int sh) const + { + if (sh == 0) { + return *this; + } + else if (sh >= storage_bits_) { + return uint128_t(0, hilo_.lo << (sh - storage_bits_)); + } + else { + return uint128_t( + (hilo_.lo << sh), + (hilo_.hi << sh) | uint64_t(hilo_.lo >> (storage_bits_ - sh)) + ); + } + } + + /// Right-shifts a 128b unsigned integer + CUTLASS_HOST_DEVICE + uint128_t operator>>(int sh) const + { + if (sh == 0) { + return *this; + } + else if (sh >= storage_bits_) { + return uint128_t((hilo_.hi >> (sh - storage_bits_)), 0); + } + else { + return uint128_t( + (hilo_.lo >> sh) | (hilo_.hi << (storage_bits_ - sh)), + (hilo_.hi >> sh) + ); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/uint256.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/uint256.h new file mode 100644 index 0000000000000000000000000000000000000000..3657853557ebccfd6be63ce6ba0fa4d69880d649 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/uint256.h @@ -0,0 +1,93 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Defines an unsigned 256b integer. +*/ + +#pragma once +#include "cutlass/cutlass.h" +#if defined(__CUDACC_RTC__) +#include CUDA_STD_HEADER(cstdint) +#else +#include +#include +#include +#include +#include +#endif +#include "cutlass/uint128.h" + +namespace cutlass { + +///! Unsigned 256b integer type +struct alignas(32) uint256_t { + /// Size of one part of the uint's storage in bits + static constexpr int storage_bits_ = 128; + + struct hilo { + uint128_t lo; + uint128_t hi; + }; + + // Use a union to store either low and high parts. + union { + struct hilo hilo_; + }; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + uint256_t() : hilo_{uint128_t{}, uint128_t{}} {} + + /// Constructor from uint128 + CUTLASS_HOST_DEVICE + uint256_t(uint128_t lo_) : hilo_{lo_, uint128_t{}} {} + + /// Constructor from two 128b unsigned integers + CUTLASS_HOST_DEVICE + uint256_t(uint128_t lo_, uint128_t hi_) : hilo_{lo_, hi_} {} + + /// Lossily cast to uint128_t + CUTLASS_HOST_DEVICE + explicit operator uint128_t() const { + return hilo_.lo; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/version.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/version.h new file mode 100644 index 0000000000000000000000000000000000000000..57a73a5fbb41a22ed5e44743c84fa1bbbe0b0075 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/version.h @@ -0,0 +1,80 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include + +#define CUTLASS_MAJOR 4 +#define CUTLASS_MINOR 2 +#define CUTLASS_PATCH 1 + +#ifdef CUTLASS_VERSIONS_GENERATED +#include "cutlass/version_extended.h" +#else +#define CUTLASS_BUILD 0 +#define CUTLASS_REVISION "" +#endif + +#define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH) + +namespace cutlass { + + inline constexpr uint32_t getVersion() { + return CUTLASS_VERSION; + } + inline constexpr uint32_t getVersionMajor() { + return CUTLASS_MAJOR; + } + inline constexpr uint32_t getVersionMinor() { + return CUTLASS_MINOR; + } + inline constexpr uint32_t getVersionPatch() { + return CUTLASS_PATCH; + } + inline constexpr uint32_t getVersionBuild() { + return CUTLASS_BUILD + 0; + } + + inline std::string getVersionString() { + std::string version = "@CUTLASS_VERSION@"; + if (getVersionBuild()) { + version += "." + std::to_string(getVersionBuild()); + } + return version; + } + + inline std::string getGitRevision() { + return "@CUTLASS_REVISION@"; + } + +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/wmma_array.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/wmma_array.h new file mode 100644 index 0000000000000000000000000000000000000000..77929f60f73dc07ea2a8e47de1cfb95b5f8859f0 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/wmma_array.h @@ -0,0 +1,133 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Statically sized array of elements that accommodates all CUTLASS-supported numeric types + and is safe to use in a union. +*/ + +#pragma once + +#include "cutlass/arch/wmma.h" + +#if defined(CUTLASS_ARCH_WMMA_ENABLED) + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/functional.h" + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Wmma array type (WmmaFragmentArray holds elements of type nvcuda::wmma::fragment) +template < + /// Element type + typename T, + /// Number of elements in the array + int N, + /// Whether the element type of T is half_t or __half + bool IsHalfType = (platform::is_same::value || + platform::is_same::value) +> +class WmmaFragmentArray: public Array { +public: + + /// Efficient clear method (override Array::clear()) + CUTLASS_HOST_DEVICE + void clear() + { + for(int i = 0; i < Array::kElements; i++) + { + nvcuda::wmma::fill_fragment((*this)[i], (typename T::element_type)0); + } + } + + CUTLASS_HOST_DEVICE + WmmaFragmentArray& operator+=(const WmmaFragmentArray& rhs) + { + using element_type = typename T::element_type; + plus add; + + for (int i = 0; i < Array::kElements; i++) + { + (*this)[i] = add((*this)[i], rhs[i]); + } + + return *this; + } +}; + +/// Partial specialization for the case in which T::element_type is +/// half_t or __half. This is needed because the cast (typename T::element_type)0 +/// in the primary template flags as an error when __CUDA_NO_HALF_CONVERSIONS__ +/// is set. +template < + /// Element type + typename T, + /// Number of elements in the array + int N +> +class WmmaFragmentArray: public Array { +public: + + /// Efficient clear method (override Array::clear()) + CUTLASS_HOST_DEVICE + void clear() + { + for(int i = 0; i < Array::kElements; i++) + { + nvcuda::wmma::fill_fragment((*this)[i], __float2half(0.f)); + } + } + + CUTLASS_HOST_DEVICE + WmmaFragmentArray& operator+=(const WmmaFragmentArray& rhs) + { + using element_type = typename T::element_type; + plus add; + + for (int i = 0; i < Array::kElements; i++) + { + (*this)[i] = add((*this)[i], rhs[i]); + } + + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // if defined(CUTLASS_ARCH_WMMA_ENABLED) + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/workspace.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/workspace.h new file mode 100644 index 0000000000000000000000000000000000000000..485ebbe3ae27af7ddc05bc1e36f32b1a4ee65901 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/workspace.h @@ -0,0 +1,154 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Utilities for initializing workspaces +*/ + +#pragma once + +#if !defined(__CUDACC_RTC__) +#include "cuda.h" +#include "cuda_runtime.h" + +#include "cutlass/trace.h" +#endif + +#include "cutlass.h" +#include "cutlass/cuda_host_adapter.hpp" + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +static constexpr int MinWorkspaceAlignment = 16; + +#if !defined(__CUDACC_RTC__) +static Status +zero_workspace( + void* workspace, + size_t workspace_size, + cudaStream_t stream = nullptr, + [[maybe_unused]] CudaHostAdapter *cuda_adapter = nullptr) { + if (workspace_size > 0) { + if (workspace == nullptr) { + CUTLASS_TRACE_HOST(" error: device workspace must not be null"); + return Status::kErrorWorkspaceNull; + } + + CUTLASS_TRACE_HOST(" clearing workspace"); + +#if defined(CUTLASS_ENABLE_CUDA_HOST_ADAPTER) && CUTLASS_ENABLE_CUDA_HOST_ADAPTER + // + // Use the cuda host adapter + // + CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + if (Status::kSuccess != cuda_adapter->memsetDevice(workspace, static_cast(0), workspace_size, stream)) { + return Status::kErrorInternal; + } + } + else { + return Status::kErrorInternal; + } +#else + cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_size, stream); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } +#endif + } + + return Status::kSuccess; +} +#endif + +#if !defined(__CUDACC_RTC__) +template +Status +fill_workspace(void* workspace, T fill_value, size_t fill_count, cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) { + static_assert(sizeof(T) == 4 || sizeof(T) == 2 || sizeof(T) == 1, "Unsupported fill type"); + if (fill_count > 0) { + if (workspace == nullptr) { + CUTLASS_TRACE_HOST(" error: device workspace must not be null"); + return Status::kErrorWorkspaceNull; + } + + CUTLASS_TRACE_HOST(" filling workspace"); + +#if defined(CUTLASS_ENABLE_CUDA_HOST_ADAPTER) && CUTLASS_ENABLE_CUDA_HOST_ADAPTER + // + // Use the cuda host adapter + // + CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + if (Status::kSuccess != cuda_adapter->memsetDevice(workspace, fill_value, fill_count, stream)) { + return Status::kErrorInternal; + } + } + else { + return Status::kErrorInternal; + } +#else + CUdeviceptr d_workspace = reinterpret_cast(workspace); + CUresult result = CUDA_SUCCESS; + if (sizeof(T) == 4) { + result = cuMemsetD32Async(d_workspace, reinterpret_cast(fill_value), fill_count, stream); + } + else if (sizeof(T) == 2) { + result = cuMemsetD16Async(d_workspace, reinterpret_cast(fill_value), fill_count, stream); + } + else if (sizeof(T) == 1) { + result = cuMemsetD8Async(d_workspace, reinterpret_cast(fill_value), fill_count, stream); + } + + if (CUDA_SUCCESS != result) { + const char** error_string_ptr = nullptr; + (void) cuGetErrorString(result, error_string_ptr); + if (error_string_ptr != nullptr) { + CUTLASS_TRACE_HOST(" cuMemsetD" << sizeof(T) * 8 << "Async() returned error " << *error_string_ptr); + } + else { + CUTLASS_TRACE_HOST(" cuMemsetD" << sizeof(T) * 8 << "Async() returned unrecognized error"); + } + return Status::kErrorInternal; + } +#endif + } + + return Status::kSuccess; +} +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cbb617dc20d35f6dd352a84c3964a58fa9bc687e --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/__init__.py @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +# Local module imports +from .dsl import * +from .runtime import * +from ._mlir_helpers import lru_cache_ir +from .env_manager import get_str_env_var, detect_gpu_arch + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..607a24d032c6ef899b586a41d2bb771c381406b0 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/__init__.py @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +""" +This module provides MLIR Dialect helper functions +""" + +from . import arith +from .lru_cache_ir import lru_cache_ir + + +__all__ = ["arith", "lru_cache_ir"] + +try: + from . import gpu + + __all__.extend(["gpu"]) +except ImportError: + pass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/arith.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/arith.py new file mode 100644 index 0000000000000000000000000000000000000000..60cc8db31fd7369d721f3d7c64c5bb8fb03502a8 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/arith.py @@ -0,0 +1,691 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +""" +This module provides MLIR Arith Dialect helper functions +""" + +import array +import numpy as np + +from ..common import * +from ..._mlir import ir # type: ignore +from ..._mlir.extras import types as T # type: ignore +from ..._mlir.dialects import arith, nvgpu, math, builtin # type: ignore + +from .lru_cache_ir import lru_cache_ir + +# ============================================================================= +# Arith Dialect Helper functions +# ============================================================================= + + +def recast_type(src_type, res_elem_type) -> ir.Type: + if isinstance(src_type, T.VectorType): + if src_type.scalable: + res_type = T.vector( + *src_type.shape, + res_elem_type, + scalable=src_type.scalable, + scalable_dims=src_type.scalable_dims, + ) + else: + res_type = T.vector(*src_type.shape, res_elem_type) + elif isinstance(src_type, T.RankedTensorType): + res_type = T.RankedTensorType.get( + element_type=res_elem_type, shape=src_type.shape, strides=src_type.strides + ) + elif isinstance(src_type, T.UnrankedTensorType): + res_type = T.UnrankedTensorType.get(element_type=res_elem_type) + elif isinstance(src_type, T.MemRefType): + res_type = T.MemRefType.get( + element_type=res_elem_type, shape=src_type.shape, strides=src_type.strides + ) + else: + res_type = res_elem_type + return res_type + + +def is_scalar(ty) -> bool: + return not isinstance( + ty, (T.VectorType, T.RankedTensorType, T.UnrankedTensorType, T.MemRefType) + ) + + +def element_type(ty) -> ir.Type: + if not is_scalar(ty): + return ty.element_type + else: + return ty + + +def is_narrow_precision(ty) -> bool: + narrow_types = { + T.f8E8M0FNU(), + T.f8E4M3FN(), + T.f8E4M3(), + T.f8E5M2(), + T.f8E4M3B11FNUZ(), + T.f4E2M1FN(), + T.f6E3M2FN(), + T.f6E2M3FN(), + } + return ty in narrow_types + + +def is_float_type(ty) -> bool: + return ( + arith._is_float_type(ty) + # TODO-upstream: prediction is not correct. Patch here and fix in upstream later + or is_narrow_precision(ty) + or ty in (T.bf16(), T.tf32()) + ) + + +def truncf_to_narrow(res_ty, src, loc, ip): + res_elem_ty = element_type(res_ty) + if res_elem_ty == T.f8E8M0FNU(): + rnd = nvgpu.RoundingMode.RP + else: + rnd = nvgpu.RoundingMode.RN + return nvgpu.cvt_fptrunc(res_ty, src, rnd=rnd, loc=loc, ip=ip) + + +def extf_from_narrow(res_ty, src, loc, ip): + src_elem_ty = element_type(src.type) + + # When source type is E8M0, temporary element type has to be bf16 + tmp_elem_ty = T.bf16() if src_elem_ty == T.f8E8M0FNU() else T.f16() + tmp_ty = recast_type(src.type, tmp_elem_ty) + + # narrow -> bf16/f16 -> target type + tmp = nvgpu.cvt_fpext(tmp_ty, src, loc=loc, ip=ip) + return arith.extf(res_ty, tmp, loc=loc, ip=ip) + + +def bitcast(src, res_elem_type, *, loc=None, ip=None): + res_type = recast_type(src.type, res_elem_type) + return arith.bitcast(res_type, src, loc=loc, ip=ip) + + +def cvtf(src, res_elem_type, *, loc=None, ip=None): + src_elem_type = element_type(src.type) + + if res_elem_type == src_elem_type: + return src + + res_type = recast_type(src.type, res_elem_type) + + # Treat TF32 as F32 and use i32 as intermediate data + # TODO-upstream: update arith to support tf32 <-> f32 conversion + if src_elem_type == T.tf32(): + # tf32 -> i32 + tmp_type = recast_type(src.type, T.i32()) + src = builtin.unrealized_conversion_cast([tmp_type], [src], loc=loc, ip=ip) + # i32 -> f32 + src = bitcast(src, T.f32(), loc=loc, ip=ip) + # f32 -> X with `cvtf` recursively + return cvtf(src, res_elem_type, loc=loc, ip=ip) + + if res_elem_type == T.tf32(): + # X -> f32 with `cvtf`` recursively + tmp = cvtf(src, T.f32(), loc=loc, ip=ip) + # f32 -> i32 + tmp = bitcast(tmp, T.i32(), loc=loc, ip=ip) + # i32 -> tf32 + return builtin.unrealized_conversion_cast([res_type], [tmp], loc=loc, ip=ip) + + if res_elem_type.width > src_elem_type.width: + if is_narrow_precision(src_elem_type): + return extf_from_narrow(res_type, src, loc, ip) + else: + return arith.extf(res_type, src, loc=loc, ip=ip) + else: + tmp_mlir_type = recast_type(src.type, T.f32()) + + # f16 -- extf -> f32 -- truncf -> bf16 + # TODO-upstream: update arith to support bf16 <-> f16 conversion? + if (src_elem_type == T.f16() and res_elem_type == T.bf16()) or ( + src_elem_type == T.bf16() and res_elem_type == T.f16() + ): + tmp = arith.extf(tmp_mlir_type, src, loc=loc, ip=ip) + return arith.truncf(res_type, tmp, loc=loc, ip=ip) + + # {f8, f6, f4} -> f16, f32, ... + elif is_narrow_precision(res_elem_type): + return truncf_to_narrow(res_type, src, loc, ip) + else: + return arith.truncf(res_type, src, loc=loc, ip=ip) + + +def fptoi(src, signed: Union[bool, None], res_elem_type, *, loc=None, ip=None): + res_type = recast_type(src.type, res_elem_type) + # TODO-upstream: update arith to support this kind of conversion + if element_type(src.type) in (T.tf32(), T.bf16()): + src = cvtf(src, T.f32(), loc=loc, ip=ip) + + if signed: + return arith.fptosi(res_type, src, loc=loc, ip=ip) + else: + return arith.fptoui(res_type, src, loc=loc, ip=ip) + + +def itofp(src, signed: Union[bool, None], res_elem_type, *, loc=None, ip=None): + res_type = recast_type(src.type, res_elem_type) + + orig_res_type = res_type + # TODO-upstream: update arith to support this kind of conversion + if res_elem_type in (T.tf32(), T.bf16()): + res_type = recast_type(src.type, T.f32()) + + if signed and element_type(src.type).width > 1: + res = arith.sitofp(res_type, src, loc=loc, ip=ip) + else: + res = arith.uitofp(res_type, src, loc=loc, ip=ip) + + if orig_res_type == res_type: + return res + + return cvtf(res, element_type(orig_res_type), loc=loc, ip=ip) + + +def int_to_int(a, dst_elem_type, *, loc=None, ip=None): + src_signed = a.signed + dst_signed = dst_elem_type.signed + src_width = element_type(a.type).width + dst_width = dst_elem_type.width + + dst_mlir_type = recast_type(a.type, dst_elem_type.mlir_type) + + if dst_width == src_width: + return a + elif src_signed != False and not dst_signed: + # Signed -> Unsigned + if dst_width > src_width: + return arith.extui(dst_mlir_type, a, loc=loc, ip=ip) + else: + return arith.trunci(dst_mlir_type, a, loc=loc, ip=ip) + elif src_signed == dst_signed: + # Same signedness + if dst_width > src_width: + if src_signed != False and src_width > 1: + return arith.extsi(dst_mlir_type, a, loc=loc, ip=ip) + else: + return arith.extui(dst_mlir_type, a, loc=loc, ip=ip) + else: + return arith.trunci(dst_mlir_type, a, loc=loc, ip=ip) + else: + # Unsigned -> Signed + if dst_width > src_width: + return arith.extui(dst_mlir_type, a, loc=loc, ip=ip) + else: + # For truncation from unsigned to signed, we need to handle overflow + # First truncate to the target width + trunc = arith.trunci(dst_mlir_type, a, loc=loc, ip=ip) + # Then reinterpret as signed + if dst_signed: + return arith.bitcast(dst_mlir_type, trunc, loc=loc, ip=ip) + return trunc + + +# ============================================================================= +# Arith Ops Emitter Helpers +# - assuming type of lhs and rhs match each other +# - op name matches python module operator +# ============================================================================= + + +def _cast(res_elem_ty, src, is_signed=None, *, loc=None, ip=None): + """ + This function provides simplified interface to upstream op builder + arith.truncf(T.vector(shape, new_type), src) + + is simplified as because it's element-wise op which can't change shape + arith.truncf(new_type, src) + """ + if isinstance(src, ir.Value): + src_ty = src.type + else: + src_ty = type(src).mlir_type + src = src.ir_value() + + src_elem_ty = element_type(src_ty) + + if src_elem_ty == res_elem_ty: + return src + elif is_float_type(src_elem_ty) and is_float_type(res_elem_ty): + # float-to-float + return cvtf(src, res_elem_ty, loc=loc, ip=ip) + elif arith._is_integer_like_type(src_elem_ty) and arith._is_integer_like_type( + res_elem_ty + ): + if src_elem_ty.width >= res_elem_ty.width: + cast_op = arith.trunci + else: + if is_signed: + cast_op = arith.extsi + else: + cast_op = arith.extui + + res_ty = recast_type(src_ty, res_elem_ty) + return cast_op(res_ty, src, loc=loc, ip=ip) + elif is_float_type(src_elem_ty) and arith._is_integer_like_type(res_elem_ty): + return fptoi(src, is_signed, res_elem_ty, loc=loc, ip=ip) + elif arith._is_integer_like_type(src_elem_ty) and is_float_type(res_elem_ty): + return itofp(src, is_signed, res_elem_ty, loc=loc, ip=ip) + else: + raise DSLRuntimeError( + f"cast from {src_elem_ty} to {res_elem_ty} is not supported" + ) + + +@lru_cache_ir() +def const(value, ty=None, *, loc=None, ip=None): + """ + Generates dynamic expression for constant values. + """ + from ..typing import Numeric, NumericMeta + from ..dsl import is_dynamic_expression, _numpy_type_to_mlir_type + + if isinstance(value, Numeric): + value = value.value + + # Early return + if is_dynamic_expression(value) and ( + value.type.isinstance(value.type) or T.bool().isinstance(value.type) + ): + return value + + # Assume type + if ty is None: + if isinstance(value, float): + ty = T.f32() + elif isinstance(value, bool): + ty = T.bool() + elif isinstance(value, int): + ty = T.i32() + elif isinstance(value, np.ndarray): + ty = T.vector(*value.shape, _numpy_type_to_mlir_type(value.dtype)) + value = array.array(value.dtype.kind, value.flatten().tolist()) + else: + raise DSLNotImplemented(f"{type(value)} is not supported") + elif isinstance(ty, NumericMeta): + ty = ty.mlir_type + elif isinstance(ty, ir.Type): + if ir.RankedTensorType.isinstance(ty) or ir.VectorType.isinstance(ty): + elem_ty = ty.element_type + if isinstance(elem_ty, ir.IntegerType): + attr = ir.IntegerAttr.get(elem_ty, value) + else: + attr = ir.FloatAttr.get(elem_ty, value) + value = ir.DenseElementsAttr.get_splat(ty, attr) + elif arith._is_float_type(ty) and isinstance(value, (bool, int)): + value = float(value) + elif arith._is_integer_like_type(ty) and isinstance(value, float): + value = int(value) + else: + raise DSLNotImplemented(f"type {ty} is not supported") + + return arith.constant(ty, value, loc=loc, ip=ip) + + +def _dispatch_to_rhs_r_op(op): + """Decorator that dispatches to the right-hand-side's reverse operation. + + If the other operand is not an ArithValue or is a subclass (more specific) + of ArithValue, this allows proper method resolution for binary operations. + """ + + def wrapper(self, other, **kwargs): + if not isinstance(other, ArithValue): + if not isinstance(other, (int, float, bool)): + # allows to call other.__rmul__ + return NotImplemented + + return op(self, other, **kwargs) + + return wrapper + + +def _binary_op(op): + """ + Decorator to check if the 'other' argument is an ArithValue. + If not, returns NotImplemented. + """ + + def wrapper(self, other, **kwargs): + # When reach this point, `self` must be cast to base `ArithValue` type + if isinstance(other, (int, float, bool)): + other = const(other, self.type).with_signedness(self.signed) + + # Call the original function + # If sub-class doesn't implement overloaded arithmetic, cast to base class + return op(self, other, **kwargs) + + return wrapper + + +# Operator overloading +@ir.register_value_caster(ir.Float4E2M1FNType.static_typeid) +@ir.register_value_caster(ir.Float6E2M3FNType.static_typeid) +@ir.register_value_caster(ir.Float6E3M2FNType.static_typeid) +@ir.register_value_caster(ir.Float8E4M3FNType.static_typeid) +@ir.register_value_caster(ir.Float8E4M3B11FNUZType.static_typeid) +@ir.register_value_caster(ir.Float8E5M2Type.static_typeid) +@ir.register_value_caster(ir.Float8E4M3Type.static_typeid) +@ir.register_value_caster(ir.Float8E8M0FNUType.static_typeid) +@ir.register_value_caster(ir.BF16Type.static_typeid) +@ir.register_value_caster(ir.F16Type.static_typeid) +@ir.register_value_caster(ir.FloatTF32Type.static_typeid) +@ir.register_value_caster(ir.F32Type.static_typeid) +@ir.register_value_caster(ir.F64Type.static_typeid) +@ir.register_value_caster(ir.IntegerType.static_typeid) +@ir.register_value_caster(ir.VectorType.static_typeid) +@ir.register_value_caster(ir.RankedTensorType.static_typeid) +class ArithValue(ir.Value): + """Overloads operators for MLIR's Arith dialects binary operations.""" + + def __init__(self, v, signed: Union[bool, None] = None): + if isinstance(v, int): + v = arith.constant(self.type, v) + super().__init__(v) + + elem_ty = element_type(self.type) + self.is_float = arith._is_float_type(elem_ty) + # arith dialect consider `1` in `i1` as `-1`, treat it as unsigned for DSL + self.signed = signed and elem_ty.width > 1 + + def with_signedness(self, signed: Union[bool, None]): + return type(self)(self, signed) + + def __neg__(self, *, loc=None, ip=None): + if self.type == T.bool(): + raise TypeError( + "Negation, the operator `-` is not supported for boolean type" + ) + + if self.is_float: + return arith.negf(self, loc=loc, ip=ip) + else: + c0 = arith.constant(self.type, 0, loc=loc, ip=ip) + return arith.subi(c0, self, loc=loc, ip=ip) + + @_binary_op + def __pow__(self, other, *, loc=None, ip=None) -> "ArithValue": + if self.is_float and other.is_float: + return math.powf(self, other, loc=loc, ip=ip) + elif self.is_float and not other.is_float: + return math.fpowi(self, other, loc=loc, ip=ip) + elif not self.is_float and other.is_float: + lhs = itofp(self, self.signed, T.f32(), loc=loc, ip=ip) + rhs = cvtf(other, T.f32(), loc=loc, ip=ip) + return math.powf(lhs, rhs, loc=loc, ip=ip) + elif not self.is_float and not other.is_float: + return math.ipowi(self, other, loc=loc, ip=ip) + else: + raise DSLNotImplemented(f"Unsupported '{self} ** {other}'") + + @_binary_op + def __rpow__(self, other, *, loc=None, ip=None) -> "ArithValue": + return other.__pow__(self, loc=loc, ip=ip) + + # arith operators + + @_dispatch_to_rhs_r_op + @_binary_op + def __add__(self, other, *, loc=None, ip=None) -> "ArithValue": + if self.is_float: + return arith.addf(self, other, loc=loc, ip=ip) + else: + return arith.addi(self, other, loc=loc, ip=ip) + + @_dispatch_to_rhs_r_op + @_binary_op + def __sub__(self, other, *, loc=None, ip=None) -> "ArithValue": + if self.is_float: + return arith.subf(self, other, loc=loc, ip=ip) + else: + return arith.subi(self, other, loc=loc, ip=ip) + + @_dispatch_to_rhs_r_op + @_binary_op + def __mul__(self, other, *, loc=None, ip=None) -> "ArithValue": + if self.is_float: + return arith.mulf(self, other, loc=loc, ip=ip) + else: + return arith.muli(self, other, loc=loc, ip=ip) + + @_dispatch_to_rhs_r_op + @_binary_op + def __truediv__(self, other, *, loc=None, ip=None) -> "ArithValue": + if self.is_float: + return arith.divf(self, other, loc=loc, ip=ip) + else: + lhs = itofp(self, self.signed, T.f32(), loc=loc, ip=ip) + rhs = itofp(other, other.signed, T.f32(), loc=loc, ip=ip) + return arith.divf(lhs, rhs, loc=loc, ip=ip) + + @_dispatch_to_rhs_r_op + @_binary_op + def __floordiv__(self, other, *, loc=None, ip=None) -> "ArithValue": + if self.is_float: + q = arith.divf(self, other, loc=loc, ip=ip) + return math.floor(q, loc=loc, ip=ip) + elif self.signed != False: + return arith.floordivsi(self, other, loc=loc, ip=ip) + else: + return arith.divui(self, other, loc=loc, ip=ip) + + @_dispatch_to_rhs_r_op + @_binary_op + def __mod__(self, other, *, loc=None, ip=None) -> "ArithValue": + if self.is_float: + return arith.remf(self, other, loc=loc, ip=ip) + elif self.signed != False: + return arith.remsi(self, other, loc=loc, ip=ip) + else: + return arith.remui(self, other, loc=loc, ip=ip) + + @_binary_op + def __radd__(self, other, *, loc=None, ip=None) -> "ArithValue": + return other.__add__(self, loc=loc, ip=ip) + + @_binary_op + def __rsub__(self, other, *, loc=None, ip=None) -> "ArithValue": + return other.__sub__(self, loc=loc, ip=ip) + + @_binary_op + def __rmul__(self, other, *, loc=None, ip=None) -> "ArithValue": + return other.__mul__(self, loc=loc, ip=ip) + + @_binary_op + def __rtruediv__(self, other, *, loc=None, ip=None) -> "ArithValue": + return other.__truediv__(self, loc=loc, ip=ip) + + @_binary_op + def __rfloordiv__(self, other, *, loc=None, ip=None) -> "ArithValue": + return other.__floordiv__(self, loc=loc, ip=ip) + + @_binary_op + def __rmod__(self, other, *, loc=None, ip=None) -> "ArithValue": + return other.__mod__(self, loc=loc, ip=ip) + + # Comparison operators (comparison doesn't have right-hand-side variants) + @_dispatch_to_rhs_r_op + @_binary_op + def __lt__(self, other, *, loc=None, ip=None) -> "ArithValue": + if self.is_float: + return arith.cmpf(arith.CmpFPredicate.OLT, self, other, loc=loc, ip=ip) + elif self.signed != False: + return arith.cmpi(arith.CmpIPredicate.slt, self, other, loc=loc, ip=ip) + else: + return arith.cmpi(arith.CmpIPredicate.ult, self, other, loc=loc, ip=ip) + + @_dispatch_to_rhs_r_op + @_binary_op + def __le__(self, other, *, loc=None, ip=None) -> "ArithValue": + if self.is_float: + return arith.cmpf(arith.CmpFPredicate.OLE, self, other, loc=loc, ip=ip) + elif self.signed != False: + return arith.cmpi(arith.CmpIPredicate.sle, self, other, loc=loc, ip=ip) + else: + return arith.cmpi(arith.CmpIPredicate.ule, self, other, loc=loc, ip=ip) + + @_dispatch_to_rhs_r_op + @_binary_op + def __eq__(self, other, *, loc=None, ip=None) -> "ArithValue": + if self.is_float: + return arith.cmpf(arith.CmpFPredicate.OEQ, self, other, loc=loc, ip=ip) + else: + return arith.cmpi(arith.CmpIPredicate.eq, self, other, loc=loc, ip=ip) + + @_dispatch_to_rhs_r_op + @_binary_op + def __ne__(self, other, *, loc=None, ip=None) -> "ArithValue": + if self.is_float: + # In Python, bool(float("nan")) is True, so use unordered comparison here + return arith.cmpf(arith.CmpFPredicate.UNE, self, other, loc=loc, ip=ip) + else: + return arith.cmpi(arith.CmpIPredicate.ne, self, other, loc=loc, ip=ip) + + @_dispatch_to_rhs_r_op + @_binary_op + def __gt__(self, other, *, loc=None, ip=None) -> "ArithValue": + if self.is_float: + return arith.cmpf(arith.CmpFPredicate.OGT, self, other, loc=loc, ip=ip) + elif self.signed != False: + return arith.cmpi(arith.CmpIPredicate.sgt, self, other, loc=loc, ip=ip) + else: + return arith.cmpi(arith.CmpIPredicate.ugt, self, other, loc=loc, ip=ip) + + @_dispatch_to_rhs_r_op + @_binary_op + def __ge__(self, other, *, loc=None, ip=None) -> "ArithValue": + if self.is_float: + return arith.cmpf(arith.CmpFPredicate.OGE, self, other, loc=loc, ip=ip) + elif self.signed != False: + return arith.cmpi(arith.CmpIPredicate.sge, self, other, loc=loc, ip=ip) + else: + return arith.cmpi(arith.CmpIPredicate.uge, self, other, loc=loc, ip=ip) + + # Unary operators + def __invert__(self, *, loc=None, ip=None) -> "ArithValue": + return arith.xori(self, arith.constant(self.type, -1)) + + # Bitwise operations + @_dispatch_to_rhs_r_op + @_binary_op + def __and__(self, other, *, loc=None, ip=None) -> "ArithValue": + return arith.andi(self, other, loc=loc, ip=ip) + + @_dispatch_to_rhs_r_op + @_binary_op + def __or__(self, other, *, loc=None, ip=None) -> "ArithValue": + return arith.ori(self, other, loc=loc, ip=ip) + + @_dispatch_to_rhs_r_op + @_binary_op + def __xor__(self, other, *, loc=None, ip=None) -> "ArithValue": + return arith.xori(self, other, loc=loc, ip=ip) + + @_dispatch_to_rhs_r_op + @_binary_op + def __rshift__(self, other, *, loc=None, ip=None) -> "ArithValue": + if self.signed != False: + return arith.shrsi(self, other, loc=loc, ip=ip) + else: + return arith.shrui(self, other, loc=loc, ip=ip) + + @_dispatch_to_rhs_r_op + @_binary_op + def __lshift__(self, other, *, loc=None, ip=None) -> "ArithValue": + return arith.shli(self, other, loc=loc, ip=ip) + + @_binary_op + def __rand__(self, other, *, loc=None, ip=None) -> "ArithValue": + return arith.andi(other, self, loc=loc, ip=ip) + + @_binary_op + def __ror__(self, other, *, loc=None, ip=None) -> "ArithValue": + return arith.ori(other, self, loc=loc, ip=ip) + + @_binary_op + def __rxor__(self, other, *, loc=None, ip=None) -> "ArithValue": + return arith.xori(other, self, loc=loc, ip=ip) + + @_binary_op + def __rrshift__(self, other, *, loc=None, ip=None) -> "ArithValue": + return other.__rshift__(self, loc=loc, ip=ip) + + @_binary_op + def __rlshift__(self, other, *, loc=None, ip=None) -> "ArithValue": + return other.__lshift__(self, loc=loc, ip=ip) + + def __hash__(self): + return super().__hash__() + + def __str__(self): + return "?" + + def __repr__(self): + return self.__str__() + + +def _min(lhs, rhs, *, loc=None, ip=None): + """ + This function provides a unified interface for building arith min + + Assuming the operands have the same type + """ + from ..dsl import is_dynamic_expression + + if not is_dynamic_expression(lhs): + if not is_dynamic_expression(rhs): + return min(lhs, rhs) + else: + lhs = arith.constant(rhs.type, lhs, loc=loc, ip=ip) + else: + if not is_dynamic_expression(rhs): + rhs = arith.constant(lhs.type, rhs, loc=loc, ip=ip) + + if arith._is_integer_like_type(lhs.type): + if lhs.signed != False: + return arith.minsi(lhs, rhs, loc=loc, ip=ip) + else: + return arith.minui(lhs, rhs, loc=loc, ip=ip) + else: + return arith.minimumf(lhs, rhs, loc=loc, ip=ip) + + +def _max(lhs, rhs, *, loc=None, ip=None): + """ + This function provides a unified interface for building arith max + + Assuming the operands have the same type + """ + from ..dsl import is_dynamic_expression + + if not is_dynamic_expression(lhs): + if not is_dynamic_expression(rhs): + return max(lhs, rhs) + else: + lhs = arith.constant(rhs.type, lhs, loc=loc, ip=ip) + else: + if not is_dynamic_expression(rhs): + rhs = arith.constant(lhs.type, rhs, loc=loc, ip=ip) + + if arith._is_integer_like_type(lhs.type): + if lhs.signed != False: + return arith.maxsi(lhs, rhs, loc=loc, ip=ip) + else: + return arith.maxui(lhs, rhs, loc=loc, ip=ip) + else: + return arith.maximumf(lhs, rhs, loc=loc, ip=ip) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/gpu.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/gpu.py new file mode 100644 index 0000000000000000000000000000000000000000..a0b0d0500824f3c5ffc9ae51c7218f40c64b780c --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/gpu.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +""" +This module provides MLIR GPU Dialect helper functions +""" + + +from ..._mlir import ir +from ..._mlir.dialects import gpu, arith, scf +from ..._mlir.extras import types as T + +from ..common import * + +# ============================================================================= +# GPU Dialect Helper functions +# ============================================================================= + + +def create_async_token(): + token_ty = gpu.AsyncTokenType.get() + token = gpu.wait(token_ty, []) + return token + + +def printf(fmt, *args, threadNumber=-1): + """Generate gpu.printf OP predicated on threadNumber""" + type_formats = [] + for arg in args: + ty_format = None + if ir.IndexType.isinstance(arg.type): + ty_format = "%llu" + if ir.IntegerType.isinstance(arg.type): + width = ir.IntegerType(arg.type).width + if width == 64: + ty_format = "%llu" + elif width == 32: + ty_format = "%d" + elif width == 1: + ty_format = "%i" + if ir.F32Type.isinstance(arg.type): + ty_format = "%f" + if ty_format is None: + raise DSLNotImplemented(arg.type) + type_formats.append(ty_format) + if threadNumber == -1: + gpu.printf(fmt.format(*type_formats) + "\n", args) + if threadNumber != -1: + tidx = gpu.thread_id(gpu.Dimension.x) + predicate = arith.cmpi( + arith.CmpIPredicate.eq, tidx, arith.constant(_T.index(), threadNumber) + ) + if_op = scf.IfOp(predicate) + with ir.InsertionPoint(if_op.then_block): + gpu.printf(fmt.format(*type_formats) + "\n", args) + scf.yield_([]) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/lru_cache_ir.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/lru_cache_ir.py new file mode 100644 index 0000000000000000000000000000000000000000..57d717b42f94cfab678e70eceb5cc4d30dd10a45 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/lru_cache_ir.py @@ -0,0 +1,76 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +""" +This module provides @lru_cache_ir +It extends functools.lru_cache with IR Context awareness. + +Example usage: +from cutlass import ir +from lru_cache_ir import lru_cache_ir + +@lru_cache_ir(ir, maxsize=128, typed=False) +def make_layout(...): +... + +""" + + +from functools import lru_cache, wraps + +from ..._mlir import ir # type: ignore + + +def get_ir_context(func): + """ + Return the context for given func called under ir. + Currently the context includes MLIRContext and InsertionPoint. + """ + try: + if ir: + return (ir.Context.current, ir.InsertionPoint.current) + else: + return None + except ValueError: + return None + + +def lru_cache_ir(maxsize=128, typed=True): + """ + Applies an LRU cache to a given function, with awareness of IR context. + + Usage is similar to functools.lru_cache while taking `ir` as required argument. + + :param ir: The IR object from which to derive the context by `get_ir_context` + :param maxsize: Max cache size, same as functools.lru_cache + :param typed: Whether params are type-sensitive, default to True as IR is type-sensitive + """ + + def decorator(func): + # Use functools.lru_cache with a custom wrapper to control the key generation + @lru_cache(maxsize=maxsize, typed=typed) + def cached_func(context, *args, **kwargs): + return func(*args, **kwargs) + + @wraps(func) + def wrapper(*args, **kwargs): + try: + # Call the cached function with the context + return cached_func(get_ir_context(func), *args, **kwargs) + except (RuntimeError, TypeError): + return func(*args, **kwargs) + + # Expose cache-related methods for introspection + wrapper.cache_clear = cached_func.cache_clear + wrapper.cache_info = cached_func.cache_info + return wrapper + + return decorator diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/op.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/op.py new file mode 100644 index 0000000000000000000000000000000000000000..3989c75e5462d11d5ca229b757f4e5b45c7ee013 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/op.py @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +""" +This module provides MLIR's OP helper functions +""" + + +import inspect +from functools import wraps + +from ..._mlir import ir + + +def dsl_user_op(opFunc): + @wraps(opFunc) + def wrapper(*args, **kwargs): + loc = kwargs.pop("loc", None) + if loc is None: + frame = inspect.currentframe().f_back + file_loc = ir.Location.file(frame.f_code.co_filename, frame.f_lineno, 0) + loc = ir.Location.name(frame.f_code.co_name, childLoc=file_loc) + res_or_list = opFunc(*args, **kwargs, loc=loc) + return res_or_list + + return wrapper diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/ast_helpers.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/ast_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..7b11474c6b5b4fd30fb1feb6fae792fc9e059686 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/ast_helpers.py @@ -0,0 +1,616 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +""" +This module provides helper functions that are generated by the preprocessor. +The preprocessor read through python's ast and changes the input code. +""" + +from typing import Callable, Iterator, Optional, overload +from typing_extensions import deprecated +import warnings +import inspect +from types import BuiltinFunctionType +from functools import lru_cache +from inspect import getmembers + +from .utils.logger import log +from .common import * + +from ._mlir_helpers.arith import ArithValue + + +class Executor: + """ + The Executor class handles dynamic and compile-time (constexpr) execution + of "for" loops and "if-else-elif" statements. + + Methods: + set_functions: Assigns the functions for checking loop bounds and + conditional evaluation. + + for_execute: Generates MLIR for OP + while_execute: Generates MLIR while OP + if_execute: generate MLIR if OP + """ + + def __init__(self): + self._is_dynamic_expression = None + self._loop_execute_range_dynamic = None + self._if_dynamic = None + self._while_dynamic = None + self._compare_executor = None + self._any_executor = None + self._all_executor = None + self._builtin_redirector = None + + def set_functions( + self, + *, + is_dynamic_expression: Callable, + loop_execute_range_dynamic: Callable, + if_dynamic: Callable, + while_dynamic: Callable, + compare_executor: Callable, + any_executor: Callable = None, + all_executor: Callable = None, + builtin_redirector: Callable = None, + ): + self._is_dynamic_expression = is_dynamic_expression + self._loop_execute_range_dynamic = loop_execute_range_dynamic + self._if_dynamic = if_dynamic + self._while_dynamic = while_dynamic + self._compare_executor = compare_executor + self._any_executor = any_executor + self._all_executor = all_executor + self._builtin_redirector = builtin_redirector + + @staticmethod + def convert_to_list(x): + """This function is used to convert x to a list. + If x is None, return an empty list. + If x is not a list, return a list containing x. + Otherwise, return x itself. + """ + if x is None: + return [] + if not isinstance(x, list): + return [x] + return x + + @staticmethod + def converge_ret_val(res): + """This function is used to converge res (the return value) of the function. + If res is None, return None. + If res is a list and has only one element, return the element. + Otherwise, return res itself. + """ + if res is None: + return res + elif isinstance(res, list) and len(res) == 1: + return res[0] + return res + + def for_execute( + self, + func, + start, + stop, + step, + write_args=[], + full_write_args_count=0, + write_args_names=[], + unroll=-1, + unroll_full=False, + prefetch_stages=None, + ): + assert ( + self._loop_execute_range_dynamic + ), "Functions must be set before execution." + log().debug("start [%s] stop [%s] step [%s]", start, stop, step) + + return self._loop_execute_range_dynamic( + func, + start, + stop, + step, + write_args, + full_write_args_count, + write_args_names, + unroll, + unroll_full, + prefetch_stages, + ) + + def if_execute( + self, + pred, + then_block: Callable, + else_block: Optional[Callable] = None, + write_args=[], + full_write_args_count=0, + write_args_names=[], + ): + assert self._if_dynamic, "Functions must be set before execution." + + # MLIR generation + return self._if_dynamic( + pred, + then_block, + else_block, + write_args, + full_write_args_count, + write_args_names, + ) + + def while_execute( + self, + pred, + while_before_block: Callable, + while_after_block: Callable, + write_args=[], + full_write_args_count=0, + write_args_names=[], + ): + assert self._while_dynamic, "Functions must be set before execution." + + # MLIR generation + return self._while_dynamic( + while_before_block, + while_after_block, + write_args, + full_write_args_count, + write_args_names, + ) + + +# ============================================================================= +# Decorator +# ============================================================================= + +executor = Executor() + + +def loop_selector( + start, + stop, + step, + *, + write_args=[], + full_write_args_count=0, + write_args_names=[], + unroll=-1, + unroll_full=False, + prefetch_stages=None, +): + log().debug( + "start [%s] stop [%s] step [%s] write_args [%s] full_write_args_count [%s] write_args_names [%s] unroll [%s] unroll_full [%s] prefetch_stages [%s]", + start, + stop, + step, + write_args, + full_write_args_count, + write_args_names, + unroll, + unroll_full, + prefetch_stages, + ) + from .typing import Integer, Numeric + + def _maybe_upcast(value): + if isinstance(value, Integer): + value = value.ir_value() + + return value + + start = _maybe_upcast(start) + stop = _maybe_upcast(stop) + step = _maybe_upcast(step) + + def ir_loop(func): + return executor.for_execute( + func, + start, + stop, + step, + write_args, + full_write_args_count, + write_args_names, + unroll, + unroll_full, + prefetch_stages, + ) + + return ir_loop + + +def if_selector(pred, write_args=[]): + log().debug("pred [%s] write_args [%s]", pred, write_args) + # Handle Numeric types here? + + from .typing import Numeric + + if isinstance(pred, Numeric): + pred = pred.value + + def ir_loop(func): + return func(pred, *write_args) + + return ir_loop + + +def while_selector(pred, write_args=[]): + def ir_while_loop(func): + return func(pred, *write_args) + + return ir_while_loop + + +def while_executor( + pred, + while_before_block: Callable, + while_after_block: Callable, + write_args=[], + full_write_args_count=0, + write_args_names=[], +): + return executor.while_execute( + pred, + while_before_block, + while_after_block, + write_args, + full_write_args_count, + write_args_names, + ) + + +def if_executor( + pred, + then_block: Callable, + else_block: Optional[Callable] = None, + write_args=[], + full_write_args_count=0, + write_args_names=[], +): + return executor.if_execute( + pred, + then_block, + else_block, + write_args, + full_write_args_count, + write_args_names, + ) + + +# ============================================================================= +# Range +# ============================================================================= + + +class range: + """ + A range-like object for dynamic loop iteration in the DSL. + + This class provides a range interface similar to Python's built-in range, + but is designed to be preprocessed into constructs for dynamic + loop execution. + + The class supports both single-argument (stop) and three-argument + (start, stop, step) constructors with additional parameters for loop + optimization: + + - unroll: Number of iterations to unroll (0 or 1 = no unrolling) + - unroll_full: Whether to fully unroll the loop + - prefetch_stages: Number of prefetch stages to generate + """ + + @overload + def __new__(cls, stop, unroll=0, unroll_full=False, prefetch_stages=None): + pass + + @overload + def __new__( + cls, start, stop, step, unroll=0, unroll_full=False, prefetch_stages=None + ): + pass + + def __new__(cls, *args, **kwargs): + raise DSLRuntimeError("dynamic range should be always preprocessed to IR") + + def __iter__(self) -> Iterator[int]: + raise DSLRuntimeError("dynamic range should be always preprocessed to IR") + + +@deprecated( + "range_dynamic is deprecated and will be removed in the future, please remove it." +) +def range_dynamic(*args, **kwargs): + raise DSLRuntimeError("range_dynamic should be always preprocessed to IR") + + +def range_constexpr(*args): + raise DSLRuntimeError("range_constexpr should be preprocessed by preprocessor.") + + +# ============================================================================= +# If expressions +# ============================================================================= + + +def const_expr(expression): + """ + This function is used to check if the expression is a python value. + If the expression is a python value, return the boolean value of the expression. + If the expression is a dynamic expression, raise an error. + """ + from .typing import Numeric + + failed = False + + if isinstance(expression, Numeric): + if isinstance(expression.value, (int, float, bool)): + return expression.value + else: + failed = True + elif executor._is_dynamic_expression(expression): + failed = True + + if failed: + raise DSLRuntimeError( + f"The function `const_expr({expression})` received a dynamic expression (non compile-time constant).", + context={ + "If your expression depends on dynamic values": "Remove `const_expr()`", + }, + ) + return expression + + +@deprecated( + "dynamic_expr is deprecated and will be removed in the future, please remove it." +) +def dynamic_expr(expression): + return expression + + +# ============================================================================= +# Assertion & casting +# ============================================================================= + + +def assert_executor(test, msg=None): + from .typing import Numeric + + fail = False + # Implicit convert dynamic expression to bool is not allowed + # So here explicitly do a None check + if test is not None and executor._is_dynamic_expression(test): + if isinstance(test, Numeric): + try: + test = test.to(bool) + except: + fail = True + else: + fail = True + + if not fail: + assert test, msg + else: + raise DSLRuntimeError( + "Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.", + suggestion="Please replace with runtime assert.", + ) + + +def bool_cast(value): + if executor._is_dynamic_expression(value): + raise DSLRuntimeError( + "Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.", + suggestion="Please explicitly convert to boolean with expressions like comparision.", + ) + return bool(value) + + +def compare_executor(left, comparators, ops): + """ + Executes comparison operations with a left operand and a list of comparators. + + Args: + left: The leftmost value in the comparison chain + comparators: A list of values to compare against + ops: A list of comparison operators to apply + + Returns: + The result of the comparison chain + + Raises: + AssertionError: If the executor function is not set before execution + """ + assert ( + executor._compare_executor is not None + ), "Function must be set before execution." + return executor._compare_executor(left, comparators, ops) + + +def any_executor(iterable): + """Executes the 'any' operation on an iterable, handling both dynamic and static expressions. + + :param iterable: An iterable to check if any elements evaluate to True + :type iterable: Iterable + :return: boolean of Python value or IR value + :rtype: bool or cutlass.Boolean + + """ + if executor._any_executor and executor._is_dynamic_expression(iterable): + return executor._any_executor(iterable) + else: + return any(iterable) + + +def all_executor(iterable): + """Executes the 'all' operation on an iterable, handling both dynamic and static expressions. + + :param iterable: An iterable to check if all elements evaluate to True + :type iterable: Iterable + :return: boolean of Python value or IR value + :rtype: bool or cutlass.Boolean + """ + if executor._all_executor and executor._is_dynamic_expression(iterable): + return executor._all_executor(iterable) + else: + return all(iterable) + + +# ============================================================================= +# Control flow checks +# ============================================================================= +class DSLOptimizationWarning(Warning): + """ + This warning is used to warn the user about the optimization related issues in DSL. + """ + + def __init__(self, message): + self.message = message + super().__init__() + + def __str__(self): + return self.message + + +def range_value_check(*args): + """ + Ensure all `range_constexpr` bounds are compile-time constants (Python ints). + """ + try: + args = tuple(arg.__index__() for arg in args) + + # Compute range size and warn if it's too large + start = 0 + end = 0 + step = 1 + if len(args) == 1: + end = args[0] + elif len(args) == 2: + start = args[0] + end = args[1] + elif len(args) == 3: + start = args[0] + end = args[1] + step = args[2] + + range_length = (abs(end - start) - 1) // abs(step) + 1 + if range_length >= 64: + warnings.warn( + f"This static loop has {range_length} iterations, which may be very slow to compile, consider using `cutlass.range(..., unroll_full=True)` instead.", + category=DSLOptimizationWarning, + stacklevel=2, + ) + + return (start, end, step) + except: + raise DSLRuntimeError( + "`range_constexpr` requires constexpr (compile-time constant) for all arguments.", + suggestion="Use `range` instead of `range_constexpr`.", + ) + + +def range_perf_warning(filename, lineno, *args): + has_dynamic_expr = False + for arg in args: + if executor._is_dynamic_expression(arg): + has_dynamic_expr = True + break + if not has_dynamic_expr: + warnings.warn_explicit( + ( + "This loop is no longer unrolled and may cause performance regression. " + "Use `range(..., unroll_full=True)` for full unrolling, or switch to `range_constexpr` when bounds are compile-time constants." + ), + category=DSLOptimizationWarning, + filename=filename, + lineno=lineno, + ) + + +@lru_cache(maxsize=1) +def _get_self_module(): + """ + This function is used to get the owning module of this function. + """ + return inspect.getmodule(_get_self_module) + + +def cf_symbol_check(symbol): + """ + Check if the symbol is control flow symbol from current module. + """ + + failed = False + name = symbol.__name__ + self_module = _get_self_module() + if inspect.ismodule(symbol): + name = "range" + if not self_module.__name__.startswith(symbol.__name__): + failed = True + else: + owning_module = inspect.getmodule(symbol) + if owning_module != self_module: + failed = True + + if failed: + raise DSLRuntimeError( + f"Incorrect {symbol.__name__} is used.", + suggestion=f"Please avoid overriding `{symbol.__name__}` from DSL package.", + ) + + +def redirect_builtin_function(fcn): + """ + This function is used to redirect built-in function call + to the function defined in DSL package. + """ + # Only redirect if it's a built-in + if isinstance(fcn, BuiltinFunctionType) and executor._builtin_redirector: + return executor._builtin_redirector(fcn) + return fcn + + +def copy_members(dest, src): + """ + Copies all non-callable, non-dunder members from src to dest if they exist in src. + Skips members that are callables or have names starting with double underscores. + """ + if id(dest) == id(src): + return + + members = getmembers(dest) + for name, value in members: + if ( + name.startswith("__") + or isinstance(value, Callable) + or not hasattr(src, name) + ): + continue + setattr(dest, name, getattr(src, name)) + + +def get_locals_or_none(locals, symbols): + """ + Given a locals() dictionary and a list of symbol names, return a list of their values + in the same order as the symbols list. If a symbol is not present in locals, None is returned + for that symbol. + """ + variables = [] + for symbol in symbols: + if symbol in locals: + variables.append(locals[symbol]) + else: + variables.append(None) + return variables diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/ast_preprocessor.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/ast_preprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..11f2d1ae84405a13f7fffd241c6e6bdd6e167010 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/ast_preprocessor.py @@ -0,0 +1,1958 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +""" +This module defines the `DSLPreprocessor` class, which acts as a Python preprocessor. +It uses Python's AST and rewrites specific Python statements such as `for` and `if-else`. + +The preprocessor operates on the following constructs: + - `for` loops: + - Rewrites `for` loops with the `@loop_selector` decorator. + - Supports `range`, `range_dynamic` for loop iteration. + - `if-elif-else` statements: + - Rewrites conditional statements with the `@if_selector` decorator. + - Supports `dynamic_expr` and `const_expr` in the condition expressions. + +Additionally, both `for` loops and `if-else` statements require `yield` +operation generation. The preprocessor handles this by: + - Using a `ScopeManager` to track symbols across different scopes during AST traversal. + - Identifying read-only, read-write, and active variables for DSL constructs. + - Generating `yield` operations for symbols that are classified as read-write or write. + +It is designed to be generic and can handle `for` and `if` constructs from other dialects. +In such cases, the user's DSL should implement `@loop_selector` and `@if_selector` +to generate dialect-specific operations for `for` and `if` statements. +""" + +import ast +import importlib +import inspect +import textwrap +import warnings +from dataclasses import dataclass +from typing import List, Set, Dict, Any, Callable, Optional +from types import ModuleType +from collections import OrderedDict +from copy import deepcopy + +from .common import * +from .utils.logger import log + + +class OrderedSet: + """ + A deterministic set implementation for ordered operations. + """ + + def __init__(self, iterable=None): + self._dict = dict.fromkeys(iterable or []) + + def add(self, item): + self._dict[item] = None + + def __iter__(self): + return iter(self._dict) + + def __and__(self, other): + return OrderedSet(key for key in self._dict if key in other) + + def __or__(self, other): + new_dict = self._dict.copy() + new_dict.update(dict.fromkeys(other)) + return OrderedSet(new_dict) + + def __sub__(self, other): + return OrderedSet(key for key in self._dict if key not in other) + + def intersections(self, others): + """Compute the intersection of this set with multiple other sets. + + :param others: A list of sets to compute intersections with + :type others: List[Set[str]] + :return: A new ordered set containing elements that appear in this set + and at least one of the other sets + """ + result = OrderedSet() + for key in self._dict: + for other in reversed(others): + if key in other: + result.add(key) + break + return result + + +@dataclass +class ImportInfo: + """ + Information about an import expression. + """ + module_path: str + attr_name: Optional[str] + alias_name: str + + +@dataclass +class ScopeManager: + """ + Manages symbol scopes during AST traversal. + Manage nested scopes during transformations. + """ + + scopes: List[Set[str]] + + @classmethod + def create(cls) -> "ScopeManager": + return cls([]) + + def add_to_scope(self, name: str) -> None: + if name == "_": + return + self.scopes[-1].add(name) + + def get_active_symbols(self) -> List[Set[str]]: + return self.scopes.copy() + + def __enter__(self) -> "ScopeManager": + self.scopes.append(set()) + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + self.scopes.pop() + + +class DSLPreprocessor(ast.NodeTransformer): + """ + A preprocessor for transforming Python ASTs. It supports: + + - Rewriting `for` loops with the `@loop_selector` decorator. + - Rewriting `if-elif-else` statements with the `@if_selector` decorator. + - Generating `yield` operations for read-write or write symbols. + """ + + DECORATOR_FOR_STATEMENT = "loop_selector" + DECORATOR_IF_STATEMENT = "if_selector" + DECORATOR_WHILE_STATEMENT = "while_selector" + IF_EXECUTOR = "if_executor" + WHILE_EXECUTOR = "while_executor" + ASSERT_EXECUTOR = "assert_executor" + BOOL_CAST = "bool_cast" + IMPLICIT_DOWNCAST_NUMERIC_TYPE = "implicitDowncastNumericType" + SUPPORTED_FOR_RANGE_STATEMENTS = {"range", "range_dynamic", "range_constexpr"} + COMPARE_EXECUTOR = "compare_executor" + ANY_EXECUTOR = "any_executor" + ALL_EXECUTOR = "all_executor" + + def __init__(self, client_module_name): + super().__init__() + self.counter = 0 # Unique function names for multiple loops + self.scope_manager = ScopeManager.create() + self.processed_functions = set() + self.function_counter = 0 + self.function_name = "" + self.class_name = None + self.file_name = "" + self.function_depth = 0 + self.local_closures = set() + self.function_globals = None + self.client_module_name = client_module_name + self.import_top_module = False + + def _create_module_attribute( + self, + func_name, + *, + top_module_name="_dsl_", + submodule_name="ast_helpers", + lineno=None, + col_offset=None, + ): + # If we simply copy location from origin node, it contains a way to wide range, which cause location in traceback to be wrong. + def set_location(node, lineno, col_offset): + if lineno and col_offset: + node.lineno = lineno + node.end_lineno = lineno + node.col_offset = col_offset + node.end_col_offset = col_offset + + base = ast.Name(id=top_module_name, ctx=ast.Load()) + set_location(base, lineno, col_offset) + if submodule_name: + base = ast.Attribute(value=base, attr=submodule_name, ctx=ast.Load()) + set_location(base, lineno, col_offset) + node = ast.Attribute(value=base, attr=func_name, ctx=ast.Load()) + set_location(node, lineno, col_offset) + return node + + def _get_module_imports(self, decorated_func): + """Extract imports from the module containing the decorated function""" + imports = [] + + # Get the module containing the decorated function + if module := inspect.getmodule(decorated_func): + try: + # Get the module source code + source = inspect.getsource(module) + module_ast = ast.parse(source) + + # Extract imports from the full module + alias = lambda n: n.asname if n.asname else n.name + for node in ast.walk(module_ast): + if isinstance(node, ast.Import): + for name in node.names: + imports.append( + ImportInfo( + module_path=name.name, + attr_name=None, + alias_name=alias(name), + ) + ) + elif isinstance(node, ast.ImportFrom): + module_name = node.module + if node.level > 0: + # Handle relative imports + package_name = module.__package__.rsplit( + ".", node.level - 1 + )[0] + module_name = f"{package_name}.{module_name}" + for name in node.names: + imports.append( + ImportInfo( + module_path=module_name, + attr_name=name.name, + alias_name=alias(name), + ) + ) + except (IOError, TypeError): + pass + + return imports + + def exec(self, function_name, original_function, code_object, exec_globals): + # Get imports from the original module + module_imports = self._get_module_imports(original_function) + + # Import all required modules + for import_info in module_imports: + module_path, attr_name, alias_name = ( + import_info.module_path, + import_info.attr_name, + import_info.alias_name, + ) + try: + module = importlib.import_module(module_path) + if attr_name: + if attr_name == "*": + if hasattr(module, "__all__"): + attrs = module.__all__ + else: + attrs = [ + name for name in dir(module) if not name.startswith("_") + ] + else: + attrs = [attr_name] + + for attr in attrs: + alias = attr if attr_name == "*" else alias_name + exec_globals[alias] = getattr(module, attr) + else: + exec_globals[alias_name] = module + except (ImportError, AttributeError) as e: + raise ImportError(f"Failed to import {module_path}: {str(e)}") + + # Execute the transformed code + log().info( + "ASTPreprocessor Executing transformed code for function [%s]", + function_name, + ) + exec(code_object, exec_globals) + return exec_globals.get(function_name) + + @staticmethod + def print_ast(transformed_tree=None): + print("#", "-" * 40, "Transformed AST", "-" * 40) + unparsed_code = ast.unparse(transformed_tree) + print(unparsed_code) + print("#", "-" * 40, "End Transformed AST", "-" * 40) + + def make_func_param_name(self, base_name, used_names): + """Generate a unique parameter name that doesn't collide with existing names.""" + if base_name not in used_names: + return base_name + + i = 0 + while f"{base_name}_{i}" in used_names: + i += 1 + return f"{base_name}_{i}" + + def transform_function(self, func_name, function_pointer): + """ + Transforms a function. + """ + # Skip if the function has already been processed + if function_pointer in self.processed_functions: + log().info( + "ASTPreprocessor Skipping already processed function [%s]", func_name + ) + return [] + + # Step 1. Parse the given function + file_name = inspect.getsourcefile(function_pointer) + lines, start_line = inspect.getsourcelines(function_pointer) + dedented_source = textwrap.dedent("".join(lines)) + tree = ast.parse(dedented_source, filename=file_name) + # Bump the line numbers so they match the real source file + ast.increment_lineno(tree, start_line - 1) + + # Step 1.2 Check the decorator + if not self.check_decorator(tree.body[0]): + log().info( + "[%s] - Skipping function due to missing decorator", + func_name, + ) + return [] + + self.processed_functions.add(function_pointer) + log().info("ASTPreprocessor Transforming function [%s]", func_name) + + # Step 2. Transform the function + transformed_tree = self.visit(tree) + + # Step 3. Import cutlass and base_dsl + top_module_name = ".".join(self.client_module_name) + import_stmts = [] + if self.import_top_module: + import_stmts.append(ast.Import(names=[ast.alias(name=top_module_name)])) + import_stmts.append( + ast.Import( + names=[ast.alias(name=f"{top_module_name}.base_dsl", asname="_dsl_")] + ) + ) + transformed_tree.body = import_stmts + transformed_tree.body + + # Step 4. Import cutlass and base_dsl + ast.fix_missing_locations(transformed_tree) + combined_body = transformed_tree.body + + # Step 5. Return the transformed tree + return combined_body + + def check_early_exit(self, tree, kind): + """ + Checks if a given region or scope in the provided Python code has early exits. + """ + + class EarlyExitChecker(ast.NodeVisitor): + def __init__(self, kind): + self.has_early_exit = False + self.early_exit_node = None + self.early_exit_type = None + self.kind = kind + self.loop_nest_level = 0 + + # Early exit is not allowed in any level of dynamic control flow + def visit_Return(self, node): + self.has_early_exit = True + self.early_exit_node = node + self.early_exit_type = "return" + + def visit_Raise(self, node): + self.has_early_exit = True + self.early_exit_node = node + self.early_exit_type = "raise" + + def visit_Break(self, node): + # For break/continue in inner loops, we don't consider it as early exit + if self.loop_nest_level == 0 and self.kind != "if": + self.has_early_exit = True + self.early_exit_node = node + self.early_exit_type = "break" + + def visit_Continue(self, node): + if self.loop_nest_level == 0 and self.kind != "if": + self.has_early_exit = True + self.early_exit_node = node + self.early_exit_type = "continue" + + def visit_For(self, node): + self.loop_nest_level += 1 + self.generic_visit(node) + self.loop_nest_level -= 1 + + def visit_While(self, node): + self.loop_nest_level += 1 + self.generic_visit(node) + self.loop_nest_level -= 1 + + checker = EarlyExitChecker(kind) + checker.generic_visit(tree) + if not checker.has_early_exit: + return + raise DSLAstPreprocessorError( + message=f"Early exit ({checker.early_exit_type}) is not allowed in `{self.function_name}`" + + (f" in `{self.class_name}`" if self.class_name else ""), + filename=self.file_name, + snippet=ast.unparse(tree), + suggestion=( + "If predicates are constant expression, write like " + "`if const_expr(...)` or `for ... in range_constexpr(...)`. " + "In that case, early exit will be executed by Python " + "interpreter, so it's supported." + ), + ) + + def is_node_constexpr(self, node) -> bool: + """ + Determines if the node is a constexpr. + Supported nodes are if, while statements. + """ + if isinstance(node, ast.If) or isinstance(node, ast.While): + if isinstance(node.test, ast.Call): + func = node.test.func + + if isinstance(func, ast.Attribute) and func.attr == "const_expr": + return True + + elif isinstance(func, ast.Name) and func.id == "const_expr": + return True + return False + + def _get_range_kind(self, iter_node): + """ + Return "range", "range_dynamic", "range_constexpr" or None for the iterable + """ + if isinstance(iter_node, ast.Call): + func = iter_node.func + if ( + isinstance(func, ast.Name) + and func.id in self.SUPPORTED_FOR_RANGE_STATEMENTS + ): + return func.id, True, len(iter_node.keywords) != 0 + if ( + isinstance(func, ast.Attribute) + and func.attr in self.SUPPORTED_FOR_RANGE_STATEMENTS + ): + return func.attr, False, len(iter_node.keywords) != 0 + return None, None, None + + def transform(self, original_function, exec_globals): + """ + Transforms the provided function using the preprocessor. + """ + self.file_name = inspect.getsourcefile(original_function) + self.function_globals = exec_globals + transformed_tree = self.transform_function( + original_function.__name__, original_function + ) + self.function_globals = None + unified_tree = ast.Module(body=transformed_tree, type_ignores=[]) + unified_tree = ast.fix_missing_locations(unified_tree) + + return unified_tree + + def analyze_region_variables( + self, node: Union[ast.For, ast.If], active_symbols: List[Set[str]] + ): + """ + Analyze variables in different code regions to identify read-only, write-only, + and active variables for DSL constructs. + """ + + # we need orderedset to keep the insertion order the same. otherwise generated IR is different each time + write_args = OrderedSet() + invoked_args = OrderedSet() + local_closure = self.local_closures + file_name = self.file_name + region_node = node + + class RegionAnalyzer(ast.NodeVisitor): + force_store = False + + def visit_Name(self, node): + """ + Mark every store as write. + """ + if isinstance(node.ctx, ast.Store) or self.force_store: + write_args.add(node.id) + + def visit_Subscript(self, node): + # When subscript occurs on the lhs of an assignment, the `Name` is still a load, but `Subscript` is marked as `Store`. + # We need to force the store for the `Name` to be marked as write. + if isinstance(node.ctx, ast.Store): + self.force_store = True + self.visit(node.value) + self.force_store = False + self.visit(node.slice) + else: + self.generic_visit(node) + + def visit_Assign(self, node): + self.force_store = True + [self.visit(target) for target in node.targets] + self.force_store = False + self.visit(node.value) + + def visit_AugAssign(self, node): + self.force_store = True + self.visit(node.target) + self.force_store = False + self.visit(node.value) + + @staticmethod + def get_call_base(func_node): + if isinstance(func_node, ast.Attribute): + # If the .value is another Attribute, keep digging + if isinstance(func_node.value, ast.Attribute): + return RegionAnalyzer.get_call_base(func_node.value) + # If the .value is a Name, that's our base + elif isinstance(func_node.value, ast.Name): + return func_node.value.id + else: + # Could be something else (lambda, call, etc.) + return None + elif isinstance(func_node, ast.Name): + return None + return None + + @staticmethod + def get_function_name(func_node: ast.Call): + if isinstance(func_node.func, ast.Name): + function_name = func_node.func.id + # Check if it's a method or attribute call + elif isinstance(func_node.func, ast.Attribute): + function_name = func_node.func.attr + else: + function_name = None + return function_name + + def visit_Call(self, node): + base_name = RegionAnalyzer.get_call_base(node.func) + + if isinstance(node.func, ast.Name): + func_name = node.func.id + if func_name in local_closure: + raise DSLAstPreprocessorError( + f"Function `{func_name}` is a closure and is not supported in for/if statements", + filename=file_name, + snippet=ast.unparse(region_node), + ) + + # Classes are mutable by default. Mark them as write. If they are + # dataclass(frozen=True), treat them as read in runtime. + if base_name is not None and base_name not in ("self"): + invoked_args.add(base_name) + + self.generic_visit(node) + + analyzer = RegionAnalyzer() + analyzer.visit(ast.Module(body=node)) + + # If arg is both write and invoke, remove from invoked_args + invoked_args = invoked_args - write_args + + write_args = list(write_args.intersections(active_symbols)) + invoked_args = list(invoked_args.intersections(active_symbols)) + + return write_args + invoked_args, len(write_args) + + def extract_range_args(self, iter_node): + args = iter_node.args + if len(args) == 1: + return ( + self.visit(ast.Constant(value=0)), + self.visit(args[0]), + self.visit(ast.Constant(value=1)), + False, + ) + elif len(args) == 2: + return ( + self.visit(args[0]), + self.visit(args[1]), + self.visit(ast.Constant(value=1)), + False, + ) + elif len(args) == 3: + return self.visit(args[0]), self.visit(args[1]), self.visit(args[2]), True + else: + raise DSLAstPreprocessorError( + "Unsupported number of arguments in range", filename=self.file_name + ) + + def extract_unroll_args(self, iter_node): + keywords = {kw.arg: kw.value for kw in iter_node.keywords} + return ( + keywords.get("unroll", ast.Constant(value=-1)), + keywords.get("unroll_full", ast.Constant(value=False)), + ) + + def issue_deprecation_warning(self, *, message, category, filename, lineno): + warnings.simplefilter("always", category) # turn off filter + warnings.warn_explicit( + message, category=category, filename=filename, lineno=lineno + ) + warnings.simplefilter("default", category) # reset filter + + def extract_prefetch_stages_args(self, iter_node): + keywords = {kw.arg: kw.value for kw in iter_node.keywords} + if "pipelining" in keywords: + self.issue_deprecation_warning( + message="pipelining is deprecated, use prefetch_stages instead", + category=DeprecationWarning, + filename=self.file_name, + lineno=iter_node.lineno, + ) + return keywords.get("pipelining", ast.Constant(value=None)) + return keywords.get("prefetch_stages", ast.Constant(value=None)) + + def create_loop_function( + self, + func_name, + node, + start, + stop, + step, + unroll, + unroll_full, + prefetch_stages, + write_args, + full_write_args_count, + ): + """ + Creates a loop body function with the `loop_selector` decorator. + """ + + func_args = [ast.arg(arg=node.target.id, annotation=None)] + func_args += [ast.arg(arg=var, annotation=None) for var in write_args] + + # Create the loop body + transformed_body = [] + for stmt in node.body: + transformed_stmt = self.visit(stmt) # Recursively visit inner statements + if isinstance(transformed_stmt, list): + transformed_body.extend(transformed_stmt) + else: + transformed_body.append(transformed_stmt) + + # Handle the return for a single iterated argument correctly + if len(write_args) == 0: + transformed_body.append(ast.Return()) + else: + transformed_body.append( + ast.Return( + value=ast.List( + elts=[ast.Name(id=var, ctx=ast.Load()) for var in write_args], + ctx=ast.Load(), + ) + ) + ) + + # Define the decorator with parameters + decorator = ast.copy_location( + ast.Call( + func=self._create_module_attribute( + self.DECORATOR_FOR_STATEMENT, + lineno=node.lineno, + col_offset=node.col_offset, + ), + args=[start, stop, step], + keywords=[ + ast.keyword(arg="unroll", value=unroll), + ast.keyword(arg="unroll_full", value=unroll_full), + ast.keyword(arg="prefetch_stages", value=prefetch_stages), + ast.keyword( + arg="write_args", + value=self.generate_get_locals_or_none_call(write_args), + ), + ast.keyword( + arg="full_write_args_count", + value=ast.Constant(value=full_write_args_count), + ), + ast.keyword( + arg="write_args_names", + value=ast.List( + elts=[ast.Constant(value=arg) for arg in write_args], + ctx=ast.Load(), + ), + ), + ], + ), + node, + ) + + return ast.copy_location( + ast.FunctionDef( + name=func_name, + args=ast.arguments( + posonlyargs=[], + args=func_args, + kwonlyargs=[], + kw_defaults=[], + defaults=[], + ), + body=transformed_body, + decorator_list=[decorator], + ), + node, + ) + + def visit_BoolOp(self, node): + # Visit child nodes first + self.generic_visit(node) + + # It is necessary to expand short circuit evaluation explicit here + # Although we do not support inline if-else for IR generation, this is actually evaluated in Python + # So it's fine here + # Transform "and" to "and_" + if isinstance(node.op, ast.And): + # Create an if-else statement in AST form + # if type(lhs) == bool and lhs == False: + # return lhs + # else + # return and_(lhs, rhs) + short_circuit_value = ast.Constant(value=False) + helper_func = self._create_module_attribute( + "and_", + top_module_name="cutlass", + submodule_name=None, + lineno=node.lineno, + col_offset=node.col_offset, + ) + self.import_top_module = True + # Transform "or" to "or_" + elif isinstance(node.op, ast.Or): + # Create an if-else statement in AST form + # if type(lhs) == bool and lhs == True: + # return lhs + # else + # return or_(lhs, rhs) + short_circuit_value = ast.Constant(value=True) + helper_func = self._create_module_attribute( + "or_", + top_module_name="cutlass", + submodule_name=None, + lineno=node.lineno, + col_offset=node.col_offset, + ) + self.import_top_module = True + else: + # BoolOp should be either And or Or + raise DSLAstPreprocessorError( + f"Unsupported boolean operation: {node.op}", + filename=self.file_name, + snippet=ast.unparse(node), + ) + + def short_circuit_eval(value, short_circuit_value): + return ast.BoolOp( + op=ast.And(), + values=[ + ast.Compare( + left=ast.Call( + func=ast.Name(id="type", ctx=ast.Load()), + args=[value], + keywords=[], + ), + ops=[ast.Eq()], + comparators=[ast.Name(id="bool", ctx=ast.Load())], + ), + ast.Compare( + left=value, + ops=[ast.Eq()], + comparators=[short_circuit_value], + ), + ], + ) + + lhs = node.values[0] + + for i in range(1, len(node.values)): + test = short_circuit_eval(lhs, short_circuit_value) + lhs = ast.IfExp( + test=test, + body=lhs, + orelse=ast.Call( + func=helper_func, + args=[lhs, node.values[i]], + keywords=[], + ), + ) + + return ast.copy_location(lhs, node) + + def visit_UnaryOp(self, node): + # Visit child nodes first + self.generic_visit(node) + + # Transform "not" to "~" as we overload __invert__ + if isinstance(node.op, ast.Not): + func_name = self._create_module_attribute( + "not_", + top_module_name="cutlass", + submodule_name=None, + lineno=node.lineno, + col_offset=node.col_offset, + ) + self.import_top_module = True + return ast.copy_location( + ast.Call(func=func_name, args=[node.operand], keywords=[]), node + ) + + return node + + def _insert_range_value_check(self, node): + """ + Insert a check for range arguments + """ + range_inputs = node.iter.args + check_call = ast.copy_location( + ast.Call( + func=self._create_module_attribute( + "range_value_check", lineno=node.lineno, col_offset=node.col_offset + ), + args=range_inputs, + keywords=[], + ), + node.iter, + ) + node.iter = ast.copy_location( + ast.Call( + func=ast.Name(id="range", ctx=ast.Load()), + args=[ast.Starred(value=check_call, ctx=ast.Load())], + keywords=[], + ), + node.iter, + ) + + def _insert_cf_symbol_check(self, func): + """ + Insert a check for range symbol + """ + check_call = ast.copy_location( + ast.Call( + func=self._create_module_attribute( + "cf_symbol_check", lineno=func.lineno, col_offset=func.col_offset + ), + args=[deepcopy(func)], + keywords=[], + ), + func, + ) + return ast.Expr(check_call) + + def visit_For(self, node): + # For static for loop (for with range_constexpr or not range based for), preprocessor keeps the loop. + range_kind, is_builtin_range, has_keyword = self._get_range_kind(node.iter) + if range_kind == "range_constexpr" or range_kind == None: + self.generic_visit(node) + if range_kind == "range_constexpr": + check_call = self._insert_cf_symbol_check(node.iter.func) + # Rewrite range_constexpr to range + node.iter.func = ast.Name(id="range", ctx=ast.Load()) + self._insert_range_value_check(node) + return [check_call, node] + return node + + active_symbols = self.scope_manager.get_active_symbols() + + with self.scope_manager: + if isinstance(node.target, ast.Name): + self.scope_manager.add_to_scope(node.target.id) + + if range_kind == "range_dynamic": + # Generate a warning + self.issue_deprecation_warning( + message="range_dynamic is deprecated and will be removed in the future, please remove it.", + category=DeprecationWarning, + filename=self.file_name, + lineno=node.iter.lineno, + ) + + warning_call = None + if range_kind == "range" and is_builtin_range and not has_keyword: + # Warn about possible performance regression due to behavior change + warning_call = ast.Expr( + ast.Call( + func=self._create_module_attribute( + "range_perf_warning", + lineno=node.lineno, + col_offset=node.col_offset, + ), + args=[ + ast.Constant(value=self.file_name), + ast.Constant(value=node.iter.lineno), + ] + + node.iter.args, + keywords=[], + ) + ) + ast.copy_location(warning_call, node.iter) + + is_prefixed_range = range_kind == "range" and not is_builtin_range + check_call = None + if range_kind == "range_dynamic" or is_prefixed_range: + # Insert a check for range symbol + if not is_prefixed_range: + check_call = self._insert_cf_symbol_check(node.iter.func) + else: + # Get toplevel module + check_call = self._insert_cf_symbol_check(node.iter.func.value) + + new_for_node = self.transform_for_loop(node, active_symbols) + if check_call is not None: + new_for_node = [check_call] + new_for_node + + return new_for_node if warning_call is None else [warning_call] + new_for_node + + @staticmethod + def _hoist_expr_to_assignments(expr, name): + return ast.copy_location( + ast.Assign(targets=[ast.Name(id=name, ctx=ast.Store())], value=expr), expr + ) + + def _build_select_and_assign(self, *, name, test, body, orelse, location): + node = ast.copy_location( + ast.Assign( + targets=[ast.Name(id=name, ctx=ast.Store())], + value=ast.IfExp( + test=test, + body=body, + orelse=orelse, + ), + ), + location, + ) + self.generic_visit(node) + return node + + def _handle_negative_step(self, node, start_expr, stop_expr, step_expr): + # hoist start, stop, step to assignments + start_ori_name = f"start_ori_{self.counter}" + start = self._hoist_expr_to_assignments(start_expr, start_ori_name) + stop_ori_name = f"stop_ori_{self.counter}" + stop = self._hoist_expr_to_assignments(stop_expr, stop_ori_name) + step_ori_name = f"step_ori_{self.counter}" + step = self._hoist_expr_to_assignments(step_expr, step_ori_name) + + extra_exprs = [start, stop, step] + + # Handle possible negative step, generates the following code in Python: + # isNegative = step < 0 + isNegative_name = f"isNegative_{self.counter}" + isNegative = ast.copy_location( + ast.Assign( + targets=[ast.Name(id=isNegative_name, ctx=ast.Store())], + value=ast.Compare( + left=ast.Name(id=step_ori_name, ctx=ast.Load()), + ops=[ast.Lt()], + comparators=[ast.Constant(value=0)], + ), + ), + step, + ) + + # start = stop if isNegative else start + start_name = f"start_{self.counter}" + start = self._build_select_and_assign( + name=start_name, + test=ast.Name(id=isNegative_name, ctx=ast.Load()), + body=ast.Name(id=stop_ori_name, ctx=ast.Load()), + orelse=ast.Name(id=start_ori_name, ctx=ast.Load()), + location=start, + ) + + # stop = start if isNegative else stop + stop_name = f"stop_{self.counter}" + stop = self._build_select_and_assign( + name=stop_name, + test=ast.Name(id=isNegative_name, ctx=ast.Load()), + body=ast.Name(id=start_ori_name, ctx=ast.Load()), + orelse=ast.Name(id=stop_ori_name, ctx=ast.Load()), + location=stop, + ) + + # step = -step if isNegative else step + step_name = f"step_{self.counter}" + step = self._build_select_and_assign( + name=step_name, + test=ast.Name(id=isNegative_name, ctx=ast.Load()), + body=ast.UnaryOp( + op=ast.USub(), operand=ast.Name(id=step_ori_name, ctx=ast.Load()) + ), + orelse=ast.Name(id=step_ori_name, ctx=ast.Load()), + location=step, + ) + + # offset = start + stop if isNegative else 0 + offset_name = f"offset_{self.counter}" + offset = self._build_select_and_assign( + name=offset_name, + test=ast.Name(id=isNegative_name, ctx=ast.Load()), + body=ast.BinOp( + op=ast.Add(), + left=ast.Name(id=start_name, ctx=ast.Load()), + right=ast.Name(id=stop_name, ctx=ast.Load()), + ), + orelse=ast.Constant(value=0), + location=node, + ) + + extra_exprs.append(isNegative) + extra_exprs.append(start) + extra_exprs.append(stop) + extra_exprs.append(step) + extra_exprs.append(offset) + + # Add this to begining of loop body + # for i in range(start, stop, step): + # i = offset - i if isNegative else i + assert isinstance(node.target, ast.Name) + + target_name = node.target.id + target = self._build_select_and_assign( + name=target_name, + test=ast.Name(id=isNegative_name, ctx=ast.Load()), + body=ast.BinOp( + op=ast.Sub(), + left=ast.Name(id=offset_name, ctx=ast.Load()), + right=ast.Name(id=target_name, ctx=ast.Load()), + ), + orelse=ast.Name(id=target_name, ctx=ast.Load()), + location=node.target, + ) + + node.body.insert(0, target) + + return ( + ast.Name(id=start_name, ctx=ast.Load()), + ast.Name(id=stop_name, ctx=ast.Load()), + ast.Name(id=step_name, ctx=ast.Load()), + extra_exprs, + ) + + def transform_for_loop(self, node, active_symbols): + # Check for early exit and raise exception + self.check_early_exit(node, "for") + if node.orelse: + raise DSLAstPreprocessorError( + "dynamic for loop with else is not supported", + filename=self.file_name, + snippet=ast.unparse(node), + ) + + # Get loop target variable name + target_var_name = None + target_var_is_active_before_loop = False + if isinstance(node.target, ast.Name): + target_var_name = node.target.id + for active_symbol in active_symbols: + if target_var_name in active_symbol: + target_var_is_active_before_loop = True + active_symbols.remove(active_symbol) + break + + # Add necessary exprs to handle this + if target_var_is_active_before_loop: + # Initialize an extra loop carried variable + loop_carried_var_name = f"loop_carried_var_{self.counter}" + pre_loop_expr = ast.copy_location( + ast.Assign( + targets=[ast.Name(id=loop_carried_var_name, ctx=ast.Store())], + value=ast.Name(id=target_var_name, ctx=ast.Load()), + ), + node, + ) + # append an extra assignment to the loop carried variable + node.body.append( + ast.copy_location( + ast.Assign( + targets=[ast.Name(id=loop_carried_var_name, ctx=ast.Store())], + value=ast.Name(id=target_var_name, ctx=ast.Load()), + ), + node, + ) + ) + active_symbols.append({loop_carried_var_name}) + + start_expr, stop_expr, step_expr, has_step = self.extract_range_args(node.iter) + unroll, unroll_full = self.extract_unroll_args(node.iter) + prefetch_stages = self.extract_prefetch_stages_args(node.iter) + write_args, full_write_args_count = self.analyze_region_variables( + node, active_symbols + ) + + if has_step and self.client_module_name[0] == "cutlass": + start, stop, step, exprs = self._handle_negative_step( + node, start_expr, stop_expr, step_expr + ) + else: + start, stop, step, exprs = start_expr, stop_expr, step_expr, [] + + if target_var_is_active_before_loop: + exprs.append(pre_loop_expr) + + func_name = f"loop_body_{self.counter}" + self.counter += 1 + + func_def = self.create_loop_function( + func_name, + node, + start, + stop, + step, + unroll, + unroll_full, + prefetch_stages, + write_args, + full_write_args_count, + ) + + assign = self.create_cf_call(func_name, write_args, node) + + # This should work fine as it modifies the AST structure + exprs = exprs + [func_def] + assign + + if target_var_is_active_before_loop: + # Create a new assignment to the target variable + exprs.append( + ast.copy_location( + ast.Assign( + targets=[ast.Name(id=target_var_name, ctx=ast.Store())], + value=ast.Name(id=loop_carried_var_name, ctx=ast.Load()), + ), + node, + ) + ) + + return exprs + + def visit_Assert(self, node): + test = self.visit(node.test) + + args = [ast.keyword(arg="test", value=test)] + if node.msg: + msg = self.visit(node.msg) + args.append(ast.keyword(arg="msg", value=msg)) + + # Rewrite to assert_executor(test, msg) + new_node = ast.Expr( + ast.Call( + func=self._create_module_attribute( + self.ASSERT_EXECUTOR, lineno=node.lineno, col_offset=node.col_offset + ), + args=[], + keywords=args, + ) + ) + + # Propagate line number from original node to new node + ast.copy_location(new_node, node) + return new_node + + def visit_Call(self, node): + func = node.func + # Visit args and kwargs + node.args = [self.visit(arg) for arg in node.args] + node.keywords = [self.visit(kwarg) for kwarg in node.keywords] + + # Rewrite call to some built-in functions + if isinstance(func, ast.Name): + # Check if the function is 'bool' + if func.id == "bool": + return ast.copy_location( + ast.Call( + func=self._create_module_attribute( + self.BOOL_CAST, + lineno=node.lineno, + col_offset=node.col_offset, + ), + args=[node.args[0]], + keywords=[], + ), + node, + ) + elif func.id in ["any", "all"]: + helper_func = ( + self.ANY_EXECUTOR if func.id == "any" else self.ALL_EXECUTOR + ) + return ast.copy_location( + ast.Call( + func=self._create_module_attribute( + helper_func, lineno=node.lineno, col_offset=node.col_offset + ), + args=[node.args[0]], + keywords=[], + ), + node, + ) + elif func.id in ["min", "max"]: + return ast.copy_location( + ast.Call( + func=self._create_module_attribute( + func.id, + top_module_name="cutlass", + submodule_name=None, + lineno=node.lineno, + col_offset=node.col_offset, + ), + args=[node.args[0], node.args[1]], + keywords=[], + ), + node, + ) + elif isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name): + def create_downcast_call(arg): + return ast.copy_location( + ast.Call( + func=self._create_module_attribute( + self.IMPLICIT_DOWNCAST_NUMERIC_TYPE, + submodule_name="typing", + lineno=node.lineno, + col_offset=node.col_offset, + ), + args=[arg], + keywords=[], + ), + arg, + ) + module = self.function_globals.get(func.value.id) + if isinstance(module, ModuleType) and module.__package__.endswith( + "._mlir.dialects" + ): + # Check if argument is Numeric, if so, call ir_value() + args = [] + for arg in node.args: + args.append(create_downcast_call(arg)) + kwargs = [] + for kwarg in node.keywords: + kwargs.append( + ast.copy_location( + ast.keyword( + arg=kwarg.arg, + value=create_downcast_call(kwarg.value), + ), + kwarg, + ) + ) + return ast.copy_location( + ast.Call(func=func, args=args, keywords=kwargs), node + ) + else: + node.func = self.visit(node.func) + + return node + + def visit_ClassDef(self, node): + self.class_name = node.name + self.generic_visit(node) + self.class_name = None + return node + + def _visit_target(self, target): + if isinstance(target, ast.Name): + self.scope_manager.add_to_scope(target.id) + elif isinstance(target, ast.Tuple): + for t in target.elts: + if isinstance(t, ast.Name): + self.scope_manager.add_to_scope(t.id) + + def visit_Assign(self, node): + for target in node.targets: + self._visit_target(target) + self.generic_visit(node) + return node + + def visit_AugAssign(self, node): + self._visit_target(node.target) + self.generic_visit(node) + return node + + def visit_Name(self, node): + isLoad = isinstance(node.ctx, ast.Load) + if node.id in ["max", "min", "any", "all"] and isLoad: + return ast.copy_location( + ast.Call( + func=self._create_module_attribute( + "redirect_builtin_function", + lineno=node.lineno, + col_offset=node.col_offset, + ), + args=[node], + keywords=[], + ), + node, + ) + elif node.id == "_" and isLoad: + raise DSLAstPreprocessorError("Read '_' is not allowed") + else: + self.generic_visit(node) + return node + + def check_decorator(self, node: ast.AST) -> bool: + """ + Check if the function has the correct decorator for preprocessing. + """ + if not isinstance(node, ast.FunctionDef): + return False + decorator_list = node.decorator_list + if len(decorator_list) == 0: + return False + + for d in decorator_list: + if isinstance(d, ast.Call): + if isinstance(d.func, ast.Attribute): + if d.func.attr in ["jit", "kernel"]: + if d.keywords == []: + return True + for keyword in d.keywords: + if keyword.arg == "preprocess": + try: + if isinstance(keyword.value, ast.Constant): + return keyword.value.value + else: + return ast.literal_eval(keyword.value) + except: + pass + + elif isinstance(d, ast.Attribute): + if d.attr in ["jit", "kernel"]: + return True + + return False + + def remove_dsl_decorator(self, decorator_list): + """ + Remove .jit and .kernel decorators + The decorator can be in two forms: + - @jit(...) + - @jit + """ + new_decorator_list = [] + decorator_names = ["jit", "kernel"] + for d in decorator_list: + is_jit_or_kernel = False + if isinstance(d, ast.Call): + if isinstance(d.func, ast.Attribute): + if d.func.attr in decorator_names: + is_jit_or_kernel = True + elif isinstance(d, ast.Attribute): + if d.attr in decorator_names: + is_jit_or_kernel = True + + if not is_jit_or_kernel: + new_decorator_list.append(d) + return new_decorator_list + + def visit_FunctionDef(self, node): + with self.scope_manager: + self.function_counter += 1 + self.function_name = node.name + if self.function_depth > 0: + self.local_closures.add(node.name) + + self.function_depth += 1 + + # Add function name and arguments + self.scope_manager.add_to_scope(node.name) + for arg in node.args.args: + self.scope_manager.add_to_scope(arg.arg) + + self.generic_visit(node) + + self.function_depth -= 1 + + # Remove .jit and .kernel decorators + node.decorator_list = self.remove_dsl_decorator(node.decorator_list) + return node + + def visit_With(self, node): + with self.scope_manager: + for item in node.items: + if isinstance(item.optional_vars, ast.Name): + self.scope_manager.add_to_scope(item.optional_vars.id) + self.generic_visit(node) + + return node + + def visit_While(self, node): + # Constexpr doesn't get preprocessed + if self.is_node_constexpr(node): + self.generic_visit(node) + check = self._insert_cf_symbol_check(node.test.func) + return [check, node] + + active_symbols = self.scope_manager.get_active_symbols() + + with self.scope_manager: + # Check for early exit and raise exception + self.check_early_exit(node, "while") + + write_args, full_write_args_count = self.analyze_region_variables( + node, active_symbols + ) + func_name = f"while_region_{self.counter}" + self.counter += 1 + + func_def = self.create_while_function( + func_name, node, write_args, full_write_args_count + ) + assign = self.create_cf_call(func_name, write_args, node) + + return [func_def] + assign + + def visit_Try(self, node): + with self.scope_manager: + self.generic_visit(node) + return node + + def visit_ExceptHandler(self, node): + with self.scope_manager: + if node.name: # Exception variable + self.scope_manager.add_to_scope(node.name) + self.generic_visit(node) + return node + + def create_cf_call(self, func_name, yield_args, node): + """Creates the assignment statement for the if function call""" + if not yield_args: + return [ + ast.copy_location( + ast.Expr(value=ast.Name(id=func_name, ctx=ast.Load())), node + ) + ] + has_self = False + for i, arg in enumerate(yield_args): + if arg == "self": + has_self = True + yield_args[i] = "yield_self" + break + if len(yield_args) == 1: + assign = ast.Assign( + targets=[ast.Name(id=yield_args[0], ctx=ast.Store())], + value=ast.Name(id=func_name, ctx=ast.Load()), + ) + else: + assign = ast.Assign( + targets=[ + ast.Tuple( + elts=[ast.Name(id=var, ctx=ast.Store()) for var in yield_args], + ctx=ast.Store(), + ) + ], + value=ast.Name(id=func_name, ctx=ast.Load()), + ) + + if has_self: + fix_self = ast.Expr( + value=ast.Call( + func=self._create_module_attribute( + "copy_members", lineno=node.lineno, col_offset=node.col_offset + ), + args=[ + ast.Name(id="self", ctx=ast.Load()), + ast.Name(id="yield_self", ctx=ast.Load()), + ], + keywords=[], + ) + ) + return [ast.copy_location(assign, node), ast.copy_location(fix_self, node)] + else: + return [ast.copy_location(assign, node)] + + def visit_IfExp(self, node): + """ + Visits an inline if-else expression (ternary operator). + This is the Python equivalent of `x if condition else y`. + """ + self.generic_visit(node) + # Emit + # node if type(pred) == bool else select_(pred, body, orelse) + # so if pred is a python bool, use python to short-circuit and avoid emit arith.select + self.import_top_module = True + return ast.copy_location( + ast.IfExp( + test=ast.Compare( + left=ast.Call( + func=ast.Name(id="type", ctx=ast.Load()), + args=[node.test], + keywords=[], + ), + ops=[ast.Eq()], + comparators=[ast.Name(id="bool", ctx=ast.Load())], + ), + body=node, # Original ternary expression + orelse=ast.Call( + func=self._create_module_attribute( + "select_", top_module_name="cutlass", submodule_name=None + ), + args=[ + node.test, + node.body, + node.orelse, + ], + keywords=[], + ), + ), + node, + ) + + cmpops = { + "Eq": "==", + "NotEq": "!=", + "Lt": "<", + "LtE": "<=", + "Gt": ">", + "GtE": ">=", + "Is": "is", + "IsNot": "is not", + "In": "in", + "NotIn": "not in", + } + def compare_ops_to_str(self, node): + names = [ + ast.Constant(value=self.cmpops[op.__class__.__name__]) for op in node.ops + ] + return ast.List(elts=names, ctx=ast.Load()) + + def visit_Compare(self, node): + self.generic_visit(node) + + comparator_strs = self.compare_ops_to_str(node) + + keywords = [ + ast.keyword(arg="left", value=node.left), + ast.keyword( + arg="comparators", value=ast.List(elts=node.comparators, ctx=ast.Load()) + ), + ast.keyword(arg="ops", value=comparator_strs), + ] + + call = ast.copy_location( + ast.Call( + func=self._create_module_attribute(self.COMPARE_EXECUTOR), + args=[], + keywords=keywords, + ), + node, + ) + + return call + + def visit_If(self, node): + # const_expr doesn't get preprocessed + if self.is_node_constexpr(node): + self.generic_visit(node) + check = self._insert_cf_symbol_check(node.test.func) + return [check, node] + + active_symbols = self.scope_manager.get_active_symbols() + with self.scope_manager: + # Check for early exit and raise exception + self.check_early_exit(node, "if") + + yield_args, full_write_args_count = self.analyze_region_variables( + node, active_symbols + ) + func_name = f"if_region_{self.counter}" + self.counter += 1 + + func_def = self.create_if_function( + func_name, node, yield_args, full_write_args_count + ) + assign = self.create_cf_call(func_name, yield_args, node) + + return [func_def] + assign + + def generate_get_locals_or_none_call(self, write_args): + return ast.Call( + func=self._create_module_attribute("get_locals_or_none"), + args=[ + ast.Call( + func=ast.Name(id="locals", ctx=ast.Load()), args=[], keywords=[] + ), + ast.List( + elts=[ast.Constant(value=arg) for arg in write_args], + ctx=ast.Load(), + ), + ], + keywords=[], + ) + + def create_if_function(self, func_name, node, write_args, full_write_args_count): + test_expr = self.visit(node.test) + pred_name = self.make_func_param_name("pred", write_args) + func_args = [ast.arg(arg=pred_name, annotation=None)] + func_args += [ast.arg(arg=var, annotation=None) for var in write_args] + func_args_then_else = [ast.arg(arg=var, annotation=None) for var in write_args] + + then_body = [] + for stmt in node.body: + transformed_stmt = self.visit(stmt) # Recursively visit inner statements + if isinstance(transformed_stmt, list): + then_body.extend(transformed_stmt) + else: + then_body.append(transformed_stmt) + + # Create common return list for all blocks + return_list = ast.List( + elts=[ast.Name(id=var, ctx=ast.Load()) for var in write_args], + ctx=ast.Load(), + ) + + # Create common function arguments + func_decorator_arguments = ast.arguments( + posonlyargs=[], args=func_args, kwonlyargs=[], kw_defaults=[], defaults=[] + ) + func_then_else_arguments = ast.arguments( + posonlyargs=[], + args=func_args_then_else, + kwonlyargs=[], + kw_defaults=[], + defaults=[], + ) + + then_block_name = f"then_block_{self.counter}" + else_block_name = f"else_block_{self.counter}" + elif_region_name = f"elif_region_{self.counter}" + self.counter += 1 + + # Create then block + then_block = ast.copy_location( + ast.FunctionDef( + name=then_block_name, + args=func_then_else_arguments, + body=then_body + [ast.Return(value=return_list)], + decorator_list=[], + ), + node, + ) + + # Decorator keywords + decorator_keywords = [ + ast.keyword( + arg="pred", value=test_expr + ), # ast.Name(id="pred", ctx=ast.Load()) + ast.keyword( + arg="write_args", + value=self.generate_get_locals_or_none_call(write_args), + ), + ] + + # Create decorator + decorator = ast.copy_location( + ast.Call( + func=self._create_module_attribute( + self.DECORATOR_IF_STATEMENT, + lineno=node.lineno, + col_offset=node.col_offset, + ), + args=[], + keywords=decorator_keywords, + ), + node, + ) + + # Executor keywords + execute_keywords = [ + ast.keyword(arg="pred", value=ast.Name(id=pred_name, ctx=ast.Load())), + ast.keyword( + arg="write_args", + value=ast.List( + elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in write_args], + ctx=ast.Load(), + ), + ), + ast.keyword( + arg="full_write_args_count", + value=ast.Constant(value=full_write_args_count), + ), + ast.keyword( + arg="write_args_names", + value=ast.List( + elts=[ast.Constant(value=arg) for arg in write_args], + ctx=ast.Load(), + ), + ), + ast.keyword( + arg="then_block", value=ast.Name(id=then_block_name, ctx=ast.Load()) + ), + ] + + # Handle different cases + if not write_args and node.orelse == []: + # No write_args case - only then_block needed + execute_call = ast.copy_location( + ast.Call( + func=self._create_module_attribute( + self.IF_EXECUTOR, lineno=node.lineno, col_offset=node.col_offset + ), + args=[], + keywords=execute_keywords, + ), + node, + ) + func_body = [then_block, ast.Return(value=execute_call)] + else: + # Create else block based on node.orelse + if node.orelse: + if len(node.orelse) == 1 and isinstance(node.orelse[0], ast.If): + # Handle elif case + elif_node = node.orelse[0] + nested_if_name = elif_region_name + # Recursion for nested elif + nested_if = self.create_if_function( + nested_if_name, elif_node, write_args, full_write_args_count + ) + else_block = ast.FunctionDef( + name=else_block_name, + args=func_then_else_arguments, + body=[ + nested_if, + ast.Return( + value=ast.Name(id=nested_if_name, ctx=ast.Load()) + ), + ], + decorator_list=[], + ) + else: + + else_body = [] + for stmt in node.orelse: + transformed_stmt = self.visit( + stmt + ) # Recursively visit inner statements + if isinstance(transformed_stmt, list): + else_body.extend(transformed_stmt) + else: + else_body.append(transformed_stmt) + + # Regular else block + else_block = ast.FunctionDef( + name=else_block_name, + args=func_then_else_arguments, + body=else_body + [ast.Return(value=return_list)], + decorator_list=[], + ) + else: + # Default else block + else_block = ast.FunctionDef( + name=else_block_name, + args=func_then_else_arguments, + body=[ast.Return(value=return_list)], + decorator_list=[], + ) + + # Add else_block to execute keywords + execute_keywords.append( + ast.keyword( + arg="else_block", value=ast.Name(id=else_block_name, ctx=ast.Load()) + ) + ) + + execute_call = ast.copy_location( + ast.Call( + func=self._create_module_attribute( + self.IF_EXECUTOR, lineno=node.lineno, col_offset=node.col_offset + ), + args=[], + keywords=execute_keywords, + ), + node, + ) + func_body = [ + then_block, + ast.copy_location(else_block, node), + ast.Return(value=execute_call), + ] + + return ast.copy_location( + ast.FunctionDef( + name=func_name, + args=func_decorator_arguments, + body=func_body, + decorator_list=[decorator], + ), + node, + ) + + def create_while_function(self, func_name, node, write_args, full_write_args_count): + """Create a while function that looks like: + + @while_selector(pred, write_args=[]) + def while_region(pred, write_args): + def while_before_block(*write_args): + # Note that during eval of pred can possibly alter yield_args + return *pred, write_args + def while_after_block(*write_args): + ...loop_body_transformed... + return write_args + return self.while_executor(pred, write_args, + while_before_block, while_after_block, constexpr) + write_args = while_region(pred, write_args) + + Which will later be executed as psuedo-code: + + # Dynamic mode: + scf.WhileOp(types(write_args), write_args) + with InsertionPoint(before_block): + cond, write_args = while_before_block(*write_args) + scf.ConditionOp(cond, write_args) + with InsertionPoint(after_block): + write_args = while_after_block(write_args) + scf.YieldOp(write_args) + return while_op.results_ + + # Const mode: + cond, write_args = while_before_block(write_args) + while pred: + write_args = body_block(write_args) + cond, write_args = while_before_block(write_args) + return write_args + """ + test_expr = self.visit(node.test) + pred_name = self.make_func_param_name("pred", write_args) + + # Section: decorator construction + decorator_keywords = [ + ast.keyword(arg="pred", value=test_expr), + ast.keyword( + arg="write_args", + value=self.generate_get_locals_or_none_call(write_args), + ), + ] + decorator = ast.copy_location( + ast.Call( + func=self._create_module_attribute( + self.DECORATOR_WHILE_STATEMENT, + lineno=node.lineno, + col_offset=node.col_offset, + ), + args=[], + keywords=decorator_keywords, + ), + node, + ) + + # Section: Shared initialization for before and after blocks + while_before_block_name = f"while_before_block_{self.counter}" + while_after_block_name = f"while_after_block_{self.counter}" + self.counter += 1 + block_args_args = [ast.arg(arg=var, annotation=None) for var in write_args] + block_args = ast.arguments( + posonlyargs=[], + args=block_args_args, + kwonlyargs=[], + kw_defaults=[], + defaults=[], + ) + + yield_args_ast_name_list = ast.List( + elts=[ast.Name(id=var, ctx=ast.Load()) for var in write_args], + ctx=ast.Load(), + ) + + # Section: while_before_block FunctionDef, which contains condition + while_before_return_list = ast.List( + elts=[test_expr, yield_args_ast_name_list], + ctx=ast.Load(), + ) + while_before_stmts = [ast.Return(value=while_before_return_list)] + while_before_block = ast.copy_location( + ast.FunctionDef( + name=while_before_block_name, + args=block_args, + body=while_before_stmts, + decorator_list=[], + ), + test_expr, + ) + + # Section: while_after_block FunctionDef, which contains loop body + while_after_stmts = [] + for stmt in node.body: + transformed_stmt = self.visit(stmt) # Recursively visit inner statements + if isinstance(transformed_stmt, list): + while_after_stmts.extend(transformed_stmt) + else: + while_after_stmts.append(transformed_stmt) + while_after_stmts.append(ast.Return(value=yield_args_ast_name_list)) + + while_after_block = ast.copy_location( + ast.FunctionDef( + name=while_after_block_name, + args=block_args, + body=while_after_stmts, + decorator_list=[], + ), + node, + ) + + # Section: Execute via executor + execute_keywords = [ + ast.keyword(arg="pred", value=ast.Name(id=pred_name, ctx=ast.Load())), + ast.keyword( + arg="write_args", + value=ast.List( + elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in write_args], + ctx=ast.Load(), + ), + ), + ast.keyword( + arg="full_write_args_count", + value=ast.Constant(value=full_write_args_count), + ), + ast.keyword( + arg="while_before_block", + value=ast.Name(id=while_before_block_name, ctx=ast.Load()), + ), + ast.keyword( + arg="while_after_block", + value=ast.Name(id=while_after_block_name, ctx=ast.Load()), + ), + ast.keyword( + arg="write_args_names", + value=ast.List( + elts=[ast.Constant(value=arg) for arg in write_args], + ctx=ast.Load(), + ), + ), + ] + + execute_call = ast.Call( + func=self._create_module_attribute( + self.WHILE_EXECUTOR, lineno=node.lineno, col_offset=node.col_offset + ), + args=[], + keywords=execute_keywords, + ) + + # Putting everything together, FunctionDef for while_region + func_args_args = [ast.arg(arg=pred_name, annotation=None)] + func_args_args += [ast.arg(arg=var, annotation=None) for var in write_args] + func_args = ast.arguments( + posonlyargs=[], + args=func_args_args, + kwonlyargs=[], + kw_defaults=[], + defaults=[], + ) + + return ast.copy_location( + ast.FunctionDef( + name=func_name, + args=func_args, + body=[ + while_before_block, + while_after_block, + ast.Return(value=execute_call), + ], + decorator_list=[decorator], + ), + node, + ) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/cache_helpers.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/cache_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..5d9234f2fe760ba0026a63c139b8535dd777f621 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/cache_helpers.py @@ -0,0 +1,153 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +""" +This module provides jit cache load/dump helper functions +""" + +import os +import uuid +import random +import tempfile +import pwd +import time +from pathlib import Path +import hashlib + +from .utils.logger import log +from .jit_executor import JitExecutor + +from .._mlir import ir + +# ============================================================================= +# Jit Cache Helper functions +# ============================================================================= + + +def get_current_user(): + # Try to get the user from the environment variable first + user = os.getenv("USER") or os.getenv("USERNAME") + if not user: + # Fallback for Unix-like systems + user = pwd.getpwuid(os.getuid()).pw_name + return user + + +try: + default_generated_ir_path = f"/tmp/{get_current_user()}/cutlass_python_cache/" +except Exception as e: + # If all else fails, provide a default fallback path + default_generated_ir_path = "/tmp/cutlass_python_cache/" + print(f"Could not determine user, using default path. Error: {e}") + + +def load_ir(file, asBytecode=False): + """Load generated IR from a file.""" + assert "mlir" in file + func_name = file.split(".mlir")[0].split("dsl_")[-1] + with ir.Context() as ctx: + with open(file, "rb" if asBytecode else "r") as f: + module = ir.Module.parse(f.read()) + + return func_name, module + + +def make_unique_filename(fpath: Path, new_ext: str = None) -> Path: + """Generate a unique filename with an optional new extension.""" + random_part = random.randint(0, 999999) + timestamp = time.time() + hash_input = f"{fpath}_{timestamp}_{random_part}".encode() + hash_code = hashlib.md5(hash_input).hexdigest()[:16] # Shorter hash for readability + stem_with_hash = f"{fpath.stem}_{hash_code}" + return fpath.with_name(stem_with_hash).with_suffix(new_ext or fpath.suffix) + + +def save_ir( + dsl_name: str, + module: object, + fname: str, + isTemp: bool = False, + asBytecode: bool = False, +) -> str: + """Save generated IR to a file.""" + initial_name = f"{dsl_name.lower()}_{fname}.mlir" + save_path = Path(tempfile.gettempdir() if isTemp else os.getcwd()) + save_fname = save_path / initial_name + # Random ID to avoid any collisions + rnd_id = str(uuid.uuid4()) + pid = os.getpid() + # use temp dir to be robust against program interruptions + temp_dir = os.path.join(save_path, f"tmp.pid_{pid}_{rnd_id}") + # If the process exits abnormally, may leave a temporary folder. Needs to be removed manually. + os.makedirs(temp_dir, exist_ok=False) + temp_fname = os.path.join(temp_dir, initial_name) + + if asBytecode: + with open(temp_fname, "wb") as f: + module.operation.write_bytecode(f) + else: + with open(temp_fname, "w") as f: + print(module, file=f) + # os.replace is guaranteed to be atomic on POSIX systems if it succeeds + # so filepath cannot see a partial write + os.replace(temp_fname, save_fname) + os.removedirs(temp_dir) + log().debug("Generated IR saved into %s", save_fname) + return save_fname + + +def check_func_name(jit_cache, func_name): + if not func_name in jit_cache: + jit_cache[func_name] = JitExecutor(None, None, None, None, None, None) + return jit_cache + + +def load_cache_from_path(dsl_name, cache_limit, path=default_generated_ir_path): + """Load cache from a directory path.""" + if not os.path.exists(path): + return dict() + files = os.listdir(path) + jit_cache = dict() + try: + for idx, file in enumerate(files): + if idx >= int(cache_limit): + break + # identify dsl prefix + if not file.startswith(f"{dsl_name.lower()}"): + continue + if ".mlir" in file: + func_name, ir_module = load_ir( + os.path.join(path, file), asBytecode=True + ) + jit_cache = check_func_name(jit_cache, func_name) + jit_cache[func_name].ir_module = ir_module + except Exception as e: + print(f"{dsl_name} failed with loading generated IR cache.", e) + jit_cache = dict() + return jit_cache + + +def dump_cache_to_path( + dsl_name, jit_cache, cache_limit, path=default_generated_ir_path +): + log().info("JIT cache : dumping [%s] items=[%s]", dsl_name, len(jit_cache)) + os.makedirs(path, exist_ok=True) + original_path = os.getcwd() + try: + os.chdir(path) + for idx, [key, value] in enumerate(jit_cache.items()): + if idx >= int(cache_limit): + break + save_ir(dsl_name, value.ir_module, key, asBytecode=True) + except Exception as e: + print(f"{dsl_name} failed with caching generated IR", e) + finally: + os.chdir(original_path) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/common.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/common.py new file mode 100644 index 0000000000000000000000000000000000000000..3cf413ed5018f99ae748cb2eb1883992f27a87b9 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/common.py @@ -0,0 +1,268 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +import os +from typing import Any, Dict, Iterable, Optional, Union + +""" +This module provides a Exception classes DSL class for any Dialect. +""" + + +# Add color codes at the top of the file after imports +class Colors: + """ANSI color codes for error messages""" + + RED = "\033[91m" + YELLOW = "\033[93m" + BLUE = "\033[94m" + GREEN = "\033[92m" + BOLD = "\033[1m" + RESET = "\033[0m" + + +# ============================================================================= +# DSL Exceptions +# ============================================================================= + + +class DSLBaseError(Exception): + """ + Base exception for DSL-related errors. + Provides optional contextual metadata to aid in debugging. + """ + + def __init__( + self, + message: str, + line: Optional[int] = None, + snippet: Optional[str] = None, + filename: Optional[str] = None, + error_code: Optional[Union[str, int]] = None, + context: Optional[Union[Dict[str, Any], str]] = None, + suggestion: Optional[str] = None, + cause: Optional[BaseException] = None, + ) -> None: + self.message = message + self.line = line + self.filename = filename + self.snippet = snippet + self.error_code = error_code + self.context = context + self.suggestion = suggestion + self.cause = cause + + super().__init__(self._format_message()) + + def _format_message(self): + """ + Formats the complete error message with available metadata. + Override this in subclasses if you want to change formatting logic. + """ + parts = [f"{self.__class__.__name__}: {self.message}"] + + if self.error_code is not None: + parts.append(f"{Colors.BOLD}Error Code:{Colors.RESET} {self.error_code}\n") + + if self.line is not None: + parts.append(f" Line: {self.line}") + + if self.filename is not None: + parts.append(f" File: {self.filename}") + + if self.snippet: + # Optionally truncate long snippets for readability + parts.append(f" Snippet: \n {self.snippet}") + + if self.cause: + parts.append(f" Caused exception: {self.cause}") + + if self.context: + if isinstance(self.context, dict): + parts.append(f"{Colors.BLUE}🔍 Additional Context:{Colors.RESET}\n") + for key, value in self.context.items(): + parts.append(f" {key}: {value}") + else: + parts.append( + f"{Colors.BLUE}🔍 Additional Context:{Colors.RESET} {self.context}" + ) + + if self.suggestion: + parts.append(f"{Colors.GREEN}💡 Suggestions:{Colors.RESET}") + if isinstance(self.suggestion, (list, tuple)): + for suggestion in self.suggestion: + parts.append(f" {Colors.GREEN}{suggestion}{Colors.RESET}") + else: + parts.append(f" {self.suggestion}") + + return "\n".join(parts) + + +class DSLRuntimeError(DSLBaseError): + """ + Raised when an error occurs during JIT-time code generation in the DSL. + """ + + # Inherits all logic from DSLBaseError; override methods if you need + # specialized behavior or formatting for runtime errors. + pass + + +def _get_friendly_cuda_error_message(error_code, error_name): + # Avoid circular dependency + from .runtime.cuda import get_device_info + + """Get a user-friendly error message for common CUDA errors.""" + # Strip the byte string markers if present + if isinstance(error_name, bytes): + error_name = error_name.decode("utf-8") + elif ( + isinstance(error_name, str) + and error_name.startswith("b'") + and error_name.endswith("'") + ): + error_name = error_name[2:-1] + + # Add target architecture info + target_arch = os.getenv("CUTE_DSL_ARCH", "unknown") + + error_messages = { + "CUDA_ERROR_INVALID_SOURCE": ( + f"{Colors.RED}❌ Failed to load CUDA kernel - likely architecture mismatch.{Colors.RESET}\n\n" + ), + "CUDA_ERROR_NO_BINARY_FOR_GPU": ( + f"{Colors.RED}❌ CUDA kernel not compatible with your GPU.{Colors.RESET}\n\n" + ), + "CUDA_ERROR_OUT_OF_MEMORY": ( + f"{Colors.RED}💾 CUDA out of memory error.{Colors.RESET}\n\n" + ), + "CUDA_ERROR_INVALID_DEVICE": ( + f"{Colors.RED}❌ Invalid CUDA device.{Colors.RESET}\n\n" + ), + "CUDA_ERROR_NOT_INITIALIZED": ( + f"{Colors.RED}❌ CUDA context not initialized.{Colors.RESET}\n\n" + ), + "CUDA_ERROR_INVALID_VALUE": ( + f"{Colors.RED}⚠️ Invalid parameter passed to CUDA operation.{Colors.RESET}\n\n" + f"{Colors.YELLOW}This is likely a bug - please report it with:{Colors.RESET}" + ), + } + + error_suggestions = { + "CUDA_ERROR_INVALID_SOURCE": ( + f"1. Ensure env CUTE_DSL_ARCH matches your GPU architecture", + f"2. Clear the compilation cache and regenerate the kernel", + f"3. Check CUDA toolkit installation", + ), + "CUDA_ERROR_NO_BINARY_FOR_GPU": ( + f"Set env CUTE_DSL_ARCH to match your GPU architecture", + ), + "CUDA_ERROR_OUT_OF_MEMORY": ( + f"1. Reduce batch size", + f"2. Reduce model size", + f"3. Free unused GPU memory", + ), + "CUDA_ERROR_INVALID_DEVICE": ( + f"1. Check if CUDA device is properly initialized", + f"2. Verify GPU is detected: nvidia-smi", + f"3. Check CUDA_VISIBLE_DEVICES environment variable", + ), + "CUDA_ERROR_NOT_INITIALIZED": ( + f"1. Check CUDA driver installation", + f"2. call `cuda.cuInit(0)` before any other CUDA operation", + f"3. Run nvidia-smi to confirm GPU status", + ), + "CUDA_ERROR_INVALID_VALUE": ( + f"1. Your GPU model", + f"2. SM ARCH setting", + f"3. Steps to reproduce", + ), + } + + message = error_messages.get( + error_name, f"{Colors.RED}Unknown CUDA error{Colors.RESET}" + ) + + # Add debug information + debug_info = f"\n- {Colors.BOLD}Error name: {error_name}\n" + debug_info += f"- CUDA_TOOLKIT_PATH: {os.getenv('CUDA_TOOLKIT_PATH', 'not set')}\n" + debug_info += ( + f"- Target SM ARCH: {os.getenv('CUTE_DSL_ARCH', 'not set')}{Colors.RESET}\n" + ) + + try: + # Get GPU information using CUDA Python API + debug_info += f"\n{Colors.BLUE}📊 GPU Information:{Colors.RESET}\n" + gpu_info = get_device_info() + debug_info += gpu_info.pretty_str() + + if target_arch and gpu_info.compatible_archs: + debug_info += f"\n{Colors.BOLD}Compatibility Check:{Colors.RESET}\n" + + if target_arch not in gpu_info.compatible_archs: + debug_info += ( + f"{Colors.RED}❌ Error: Target SM ARCH {target_arch} is not compatible\n" + f"💡 Please use one of SM ARCHs: " + f"{Colors.GREEN}{', '.join(gpu_info.compatible_archs or [])}{Colors.RESET}\n" + ) + elif target_arch != gpu_info.sm_arch: + debug_info += ( + f"{Colors.YELLOW}⚠️ Warning: Using compatible but non-optimal architecture\n" + f"• Current: {target_arch}\n" + f"• Recommended: {Colors.GREEN}{gpu_info.sm_arch}{Colors.RESET} (native)\n" + ) + else: + debug_info += f"{Colors.GREEN}✓ Using optimal architecture: {gpu_info.sm_arch}{Colors.RESET}\n" + + except Exception as e: + debug_info += ( + f"\n{Colors.YELLOW}ℹ️ Could not retrieve GPU info: {str(e)}{Colors.RESET}" + ) + + return message, debug_info, error_suggestions.get(error_name, "") + + +class DSLCudaRuntimeError(DSLBaseError): + """ + Raised when an error occurs during CUDA runtime code generation in the DSL. + """ + + # Inherits all logic from DSLRuntimeError; override methods if you need + # specialized behavior or formatting for runtime errors. + def __init__(self, error_code, error_name) -> None: + self._error_code = error_code + self._error_name = error_name + message, debug_info, suggestion = _get_friendly_cuda_error_message( + error_code, error_name + ) + + super().__init__( + message, error_code=error_code, context=debug_info, suggestion=suggestion + ) + + +class DSLAstPreprocessorError(DSLBaseError): + """ + Raised when an error occurs during AST preprocessing or visiting in the DSL. + """ + + # Same approach: You could override _format_message if you want + # to emphasize AST node details or anything specific to preprocessing. + pass + + +class DSLNotImplemented(DSLBaseError): + """ + Raised when a feature of the DSL is not implemented yet. + """ + + # Useful for stubs in your DSL that you plan to implement in the future. + pass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/compiler.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/compiler.py new file mode 100644 index 0000000000000000000000000000000000000000..f8b2da07ac9ac104f56c16a5cfcbbf01f01ee786 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/compiler.py @@ -0,0 +1,288 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +""" +This module provides a class that compiles generated IR using MLIR's PassManager +and executes it using MLIR's ExecutionEngine. + +""" + +from typing import Sequence, Optional, Tuple +import os +import sys +import inspect +import argparse +from .common import DSLRuntimeError +from .utils.logger import log + +_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(_SCRIPT_PATH) + +from .._mlir import ir + + +# ============================================================================= +# Compiler Class +# ============================================================================= + + +class CompilationError(RuntimeError): + """Custom error class for compilation failures""" + + # Add ANSI color codes + RED = "\033[91m" + YELLOW = "\033[93m" + BLUE = "\033[94m" + GREEN = "\033[92m" + BOLD = "\033[1m" + RESET = "\033[0m" + + def __init__( + self, + message: str, + nvvm_error: Optional[str] = None, + ir_context: Optional[str] = None, + cuda_toolkit: Optional[str] = None, + arch: Optional[str] = None, + ): + self.nvvm_error = nvvm_error + self.ir_context = ir_context + self.cuda_toolkit = cuda_toolkit + self.arch = arch + # Call parent with formatted error to avoid showing class name + super().__init__("") # Empty string to avoid class name + # Store formatted error for str() representation + self._formatted_error = self._format_error() + + def __str__(self) -> str: + """Override string representation to avoid showing class name""" + return self._formatted_error + + def __repr__(self) -> str: + """Override repr representation to avoid showing class name""" + return self._formatted_error + + def _format_error(self) -> str: + if not self.nvvm_error: + return str(self.args[0]) + + return f"""NVVM Compilation Error: +---------------------- + +{self.BLUE}⚙️ Current Settings:{self.RESET} +{self.BOLD}- CUDA Toolkit Path: {self.cuda_toolkit or "Not Set"} +- Target Architecture: {self.arch}{self.RESET} + +IR Context (truncated): +{self.ir_context} + +{self.YELLOW}💡 Possible Solutions:{self.RESET} +{self.GREEN}1. Check if CUDA_TOOLKIT_PATH is set correctly +2. Verify target architecture ({self.arch}) is supported by your CUDA toolkit +3. Make sure CUDA toolkit version matches the target architecture requirements{self.RESET}""" + + +class Compiler: + """Compiler class for compiling and building MLIR modules.""" + + def __init__(self, passmanager, execution_engine): + self.passmanager = passmanager + self.execution_engine = execution_engine + + def __call__(self, module): + """Convenience application method.""" + self.compile(module) + + def _process_error(self, error_msg: str) -> Tuple[Optional[str], Optional[str]]: + """Process error message to extract NVVM error and IR context""" + nvvm_error = None + ir_msg = "" + + if "NVVM_ERROR" in error_msg: + # Extract the specific NVVM error + nvvm_error = ( + error_msg.split("libNVVM extra log:")[1].strip() + if "libNVVM extra log:" in error_msg + else error_msg + ) + + # Extract IR context + if "see current operation:" in error_msg: + # Get the IR section + ir_section = error_msg.split("see current operation:")[1].strip() + # Remove duplicate IR section + ir_section = ir_section.split("error: unknown: Failed translating")[ + 0 + ].strip() + + # Get first few lines and last few lines of the IR + ir_lines = ir_section.split("\n") + if len(ir_lines) > 10: + ir_msg = "\n".join(ir_lines[:5] + [" ..."] + ir_lines[-5:]) + else: + ir_msg = ir_section + + return nvvm_error, ir_msg + + def compile( + self, + module, + pipeline: str, + cuda_toolkit: str = "", + arch: str = "", + enable_verifier=False, + ): + """Compiles the module by invoking the pipeline.""" + try: + pm = self.passmanager.PassManager.parse(pipeline) + pm.enable_verifier(enable_verifier) + pm.run(module.operation) + except Exception as e: + error_msg = str(e) + nvvm_error, ir_msg = self._process_error(error_msg) + + if nvvm_error: + raise CompilationError( + error_msg, + nvvm_error=nvvm_error, + ir_context=ir_msg, + cuda_toolkit=cuda_toolkit, + arch=arch, + ) from e + raise e + + def jit(self, module, opt_level: int = 2, shared_libs: Sequence[str] = ()): + """Wraps the module in a JIT execution engine.""" + return self.execution_engine.ExecutionEngine( + module, opt_level=opt_level, shared_libs=shared_libs + ) + + def compile_and_jit( + self, + module, + pipeline: str, + shared_libs: Sequence[str] = (), + opt_level: int = 2, + cuda_toolkit: str = "", + arch: str = "", + ): + """Compiles and jits the module.""" + self.compile( + module, + pipeline, + cuda_toolkit, + arch, + ) + return self.jit(module, opt_level, shared_libs) + + +class CompileOptions: + def __init__(self, options: str = ""): + """ + This class encapsulates all compilation options relevant to function compilation. + It provides a convenient way to manage and pass compilation options, + particularly for controlling compilation settings. + By centralizing these options, it ensures consistent and flexible configuration of + compilation parameters such as optimization level, debugging control, etc. + + :param options: The options for the function. Will be parsed by argparse. + :type options: str + """ + if not isinstance(options, str): + raise DSLRuntimeError( + f"Invalid compilation `options`: {options}, it should be a string" + ) + self._parser = argparse.ArgumentParser() + self._parser.add_argument("--opt-level", nargs="?", type=int, default=3) + self._parser.add_argument( + "--enable-device-assertions", action="store_true", default=False + ) + self._parser.add_argument("--link-libraries", type=str, default="") + + try: + self._options = self._parser.parse_args(options.split()) + except SystemExit as e: + # catch argparse error and raise as DSLRuntimeError + raise DSLRuntimeError( + f"Invalid compile options: '{options}'. Please check the option values and format." + ) + log().info("`cute.compile` CompileOptions: options=" + options) + + def to_str(self): + """ + Generate a string representation of all compilation options + which will be used in pipeline options. + """ + option_strings = [] + for key, value in vars(self._options).items(): + hyphen_key = key.replace("_", "-") + if isinstance(value, bool): + formatted_value = "true" if value else "false" + else: + formatted_value = str(value) + option_strings.append(f"{hyphen_key}={formatted_value}") + + return " ".join(option_strings) + + +def compile(func, *args, **kwargs): + """ + This function is used to compile a `cute.jit` decorated function. + It will process the compile options and input parameters, do explicit compilation and return the jit executor. + + :param func: The function to compile. It can be a regular function, a method or a class instance. + :param args: The arguments to pass to the function. + :param kwargs: The keyword arguments to pass to the function. It can contain `options` like + `opt_level` to control the compilation flags. + + :return: The jit executor. + + :raises: DSLRuntimeError if the function is not decorated with `cute.jit` or is not callable. + """ + if func is None: + raise DSLRuntimeError("Function is not set or invalid.") + + if not callable(func): + raise DSLRuntimeError("Object is not callable.") + + kwargs["compile_only"] = True + kwargs["no_cache"] = True + + if inspect.isfunction(func): + # regular function + pass + elif inspect.ismethod(func): + # if it's a method, add the instance to the first argument + args = [func.__self__] + list(args) + func = func.__func__ + elif inspect.isclass(type(func)) and hasattr(func, "__call__"): + # If it's a class instance, get the class's __call__ method + args = [func] + list(args) + # Get the actual function from the class definition + func = func.__call__.__func__ + else: + raise DSLRuntimeError( + "Invalid function type, only function, method and module are supported, but got", + func, + ) + + # If it's a wrapped function created by jit decorator, get the original function + if hasattr(func, "__wrapped__"): + func = func.__wrapped__ + + if not hasattr(func, "_dsl_object"): + raise DSLRuntimeError("Function is not decorated with jit decorator.") + + # process compile options, extract the options and remove them from the kwargs + options = kwargs.pop("options", "") + func._dsl_object.compile_options = CompileOptions(options) + fcn_ptr = func._dsl_object._preprocess_and_execute(func) + return func._dsl_object._func(fcn_ptr, *args, **kwargs) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/dsl.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/dsl.py new file mode 100644 index 0000000000000000000000000000000000000000..2b17d22b1e6d7157a7f14334b0f29f1386c58c15 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/dsl.py @@ -0,0 +1,1686 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +""" +This module provides a main DSL class for any Dialect. +The DSL should be inherited as a new class, and its initialization requires dialects. +It handles most of the mechanics for the DSL in an agnostic way, +for example, it can handle various dialect-specific tasks. +""" + + +# Standard library imports +from dataclasses import dataclass, field +import atexit +import os +import io +import sys +import errno +import ctypes +import re +import inspect +import argparse +import hashlib +from functools import lru_cache, wraps +from collections import namedtuple +from abc import ABC, abstractmethod +from typing import Any, Union, Tuple, get_origin, get_args, List +from types import FunctionType, SimpleNamespace +import warnings + +from . import typing as t +from .env_manager import EnvironmentVarManager +from .compiler import CompileOptions +from .ast_helpers import DSLOptimizationWarning + +# ============================================================================= +# CUDA Python +# ============================================================================= + +from ..base_dsl._mlir_helpers.arith import const + +# ============================================================================= +# Local module imports +# ============================================================================= + +from .cache_helpers import * +from .jit_executor import JitExecutor +from .utils.timer import timer +from .utils.logger import setup_log, log +from .utils.stacktrace import filter_exception, walk_to_top_module, filter_stackframe +from .runtime.jit_arg_adapters import is_argument_constexpr, JitArgAdapterRegistry + +from .ast_preprocessor import DSLPreprocessor +from .common import * +from .typing import ( + get_c_pointers, + get_mlir_types, +) + +# ============================================================================= +# MLIR modules +# ============================================================================= + +from .._mlir import ir +from .._mlir import runtime as rt +from .._mlir.extras import types as T +from .._mlir.dialects import arith, math, func + +# ============================================================================= +# Global Variables +# ============================================================================= + +MLIR_DYNAMIC = -9223372036854775808 + +# ============================================================================= +# Codegen Utils +# ============================================================================= + + +def _numpy_type_to_mlir_type(dtype): + if dtype == np.float64: + return T.f64() + if dtype == np.float16: + return T.f16() + if dtype == np.float32: + return T.f32() + if dtype == np.int64: + return T.i64() + if dtype == np.int32: + return T.i32() + if dtype == np.int16: + return T.i16() + if dtype == np.int8: + return T.i8() + if dtype == np.uint64: + return T.ui64() + if dtype == np.uint32: + return T.ui32() + if dtype == np.uint16: + return T.ui16() + if dtype == np.uint8: + return T.ui8() + if dtype == np.bool_: + return T.bool() + if dtype == f8E5M2: + return T.f8E5M2() + if dtype == f8E4M3FN: + return T.f8E4M3FN() + if dtype == f8E8M0FNU: + return T.f8E8M0FNU() + if dtype == f6E3M2FN: + return T.f6E3M2FN() + if dtype == f6E2M3FN: + return T.f6E2M3FN() + if dtype == f4E2M1FN: + return T.f4E2M1FN() + assert False, f"Unknown type {type}" + + +def _mlir_type_to_numpy_type(type): + if type == T.f64(): + return np.float64 + if type == T.f16(): + return np.float16 + if type == T.f32(): + return np.float32 + if type == T.i64(): + return np.int64 + if type == T.i32(): + return np.int32 + if type == T.i16(): + return np.int16 + if type == T.i8(): + return np.int8 + if type == T.ui64(): + return np.uint64 + if type == T.ui32(): + return np.uint32 + if type == T.ui16(): + return np.uint16 + if type == T.ui8(): + return np.uint8 + if type == T.bool(): + return np.bool_ + assert False, f"Unknown type {type}" + + +# ============================================================================= +# Main DSL Class +# ============================================================================= + + +def is_dynamic_expression(value): + """ + Given the `value`, check if itself is an IR value or recursively go through it to check if it contains IR value + """ + if isinstance(value, (tuple, list)): + for x in value: + if is_dynamic_expression(x): + return True + elif isinstance(value, (ir.Value, ir.BlockArgumentList)) or hasattr( + value, "__extract_mlir_values__" + ): + return True + return False + + +def extract_mlir_values(obj): + """ + Given the `obj`, recursively go through it to extract all contained IR values as list of MLIR values + """ + res = [] + if hasattr(obj, "__extract_mlir_values__"): + res = obj.__extract_mlir_values__() + elif isinstance(obj, (tuple, list)): + res = sum((extract_mlir_values(x) for x in obj), []) + elif isinstance(obj, SimpleNamespace): + res = [] + for k, v in obj.__dict__.items(): + res.extend(extract_mlir_values(v)) + # Can't call is_dynamic_expression as _is_dynamic_expression depends on extract_mlir_values + elif isinstance(obj, set): + raise DSLRuntimeError( + "Sets are not supported in extract_mlir_values to ensure order preservation", + context="The DSL attempted to generate JIT function argument(s) for an argument of type set but failed.", + suggestion="Consider using a list or tuple instead", + ) + elif isinstance(obj, ir.Value): + res = [obj] + elif isinstance(obj, ir.BlockArgumentList): + res = list(obj) # type: ignore + + return res + + +def new_from_mlir_values(obj, values): + """ + Create a new python object by populating containing MLIR values with list of new values + """ + if hasattr(obj, "__new_from_mlir_values__"): + return obj.__new_from_mlir_values__(values) + elif isinstance(obj, (tuple, list)): + res = [] + for x in obj: + n_items = len(get_mlir_types(x)) + res.append(new_from_mlir_values(x, values[:n_items])) + values = values[n_items:] + obj_ty = type(obj) + return obj_ty(res) + elif isinstance(obj, SimpleNamespace): + res = SimpleNamespace() + for k, v in obj.__dict__.items(): + n_items = len(get_mlir_types(v)) + res.__dict__[k] = new_from_mlir_values(v, values[:n_items]) + values = values[n_items:] + return res + elif isinstance(obj, set): + raise DSLRuntimeError( + "Sets are not supported in new_from_mlir_values to ensure order preservation", + context="The DSL attempted to generate JIT function argument(s) for an argument of type set but failed.", + suggestion="Consider using a list or tuple instead", + ) + elif is_dynamic_expression(obj): + + if len(values) == 0: + return obj + + assert len(values) == 1 + return values[0] + else: + assert len(values) == 0, f"{obj} expects 0 values, but got {values}" + return obj + + +class DSLCallable: + """ + Wrapper class for a callable object used within the DSL. + + DSLCallable is designed to wrap a function and provide additional + introspection utilities such as retrieving the argument specification + and signature. It ensures that the wrapped function can only be called + once, after which the reference to the function is cleared to prevent + further invocations. This is useful in scenarios where a function should + only be executed a single time within the DSL's execution model. + + Attributes: + func (callable): The function to be wrapped and managed. + + Methods: + __call__(*args, **kwargs): Calls the wrapped function and clears it. + """ + + def __init__(self, func): + self.func = func + + def __call__(self, *args, **kwargs): + ret = self.__func__(*args, **kwargs) + self.func = None + return ret + + @property + def __func__(self): + assert self.func is not None, "DSLCallable is already called" + return self.func + + @property + def __signature__(self): + return inspect.signature(self.__func__) + + @property + def __name__(self): + return self.__func__.__name__ + + +class BaseDSL: + gpu_module = None + + def __init__( + self, + *, + name: str, + dsl_package_name: List[str], + compiler_provider: Any, + pass_sm_arch_name: str, + device_compilation_only=False, + preprocess=False, + ): + """ + Constructor for initializing the class with required providers and environment settings. + + Parameters: + - name (str): Name of DSL, used for environment variables and logging. + - package_name (str): Name of the package, used for the preprocessor. + - compiler_provider (MLIR dialect): Provider for compiler. + - pass_sm_arch_name (str): The keyword name of the SM. + - device_compilation_only (bool) : Only device code, and call it via cuda driver + - preprocess (bool): Enable AST transformation. + + This constructs a DSL instance and sets up environment management, + warning configurations, and logging functionalities. It reads + environment variables using `EnvironmentVarManager` and configures + a logger with settings from the environment. If environment warnings + are detected, they are escalated to errors to ensure strict handling. + """ + # Enforcing initialization of instance variables + if not all([name, compiler_provider, pass_sm_arch_name]): + raise DSLRuntimeError( + "All required parameters must be provided and non-empty" + ) + + self.name = name + self.compiler_provider = compiler_provider + self.pass_sm_arch_name = pass_sm_arch_name + self.frame = None + self.no_cache = False + self.device_compilation_only = device_compilation_only + self.num_kernels = 0 + # Read environment variables + self.envar = EnvironmentVarManager(self.name) + self.enable_preprocessor = preprocess + # This cache uses hash of original ir and env as key, allows dump/load to/from file. Enabled by default + self.jit_cache = ( + dict() + if self.envar.disable_file_caching + else load_cache_from_path(self.name, self.envar.file_caching_capacity) + ) + self.host_jit_decorator_name = f"@{BaseDSL.jit.__name__}" + self.device_jit_decorator_name = f"@{BaseDSL.kernel.__name__}" + + # set warning + if not self.envar.enable_optimization_warnings: + # By default, optimization warnings are disabled + warnings.filterwarnings("ignore", category=DSLOptimizationWarning) + if self.envar.warnings_as_errors: + warnings.filterwarnings("error") + if self.envar.warnings_ignore: + warnings.filterwarnings("ignore") + + # Initialize logger + if self.envar.log_to_console == False and self.envar.jitTimeProfiling: + self.envar.log_to_console = True + self.envar.log_level = 20 # info level + setup_log( + self.name, + self.envar.log_to_console, + self.envar.log_to_file, + f"{self.name}.log", + self.envar.log_level, + ) + + # kernel symbols are temporary symbol string variables, their values are valid until the compilation is done. + self.kernel_symbols = [] + # used to generate unique name for gpu.launch + self.launch_inner_count = 0 + # initialize default compile options + self.compile_options = CompileOptions() + + if preprocess: + self.preprocessor = DSLPreprocessor(dsl_package_name) + log().info(f"Initializing {name} DSL") + log().debug(f"Logger initialized for {self.name}") + + # Hook excepthook + if self.envar.filterStacktrace: + origin_excepthook = sys.excepthook + module_dir = walk_to_top_module(os.path.dirname(os.path.abspath(__file__))) + + def excepthook(excep_type, value, traceback): + filter_exception(value, module_dir) + if hasattr(value, "__traceback__"): + origin_excepthook(excep_type, value, value.__traceback__) + else: + origin_excepthook( + excep_type, value, filter_stackframe(traceback, module_dir) + ) + + sys.excepthook = excepthook + + # Restore original excepthook + def restore_excepthook(hook): + sys.excepthook = hook + + atexit.register(restore_excepthook, origin_excepthook) + + def dump_cache(self): + if not self.envar.disable_file_caching: + dump_cache_to_path( + self.name, self.jit_cache, self.envar.file_caching_capacity + ) + + @lru_cache(maxsize=1) + def print_warning_once(self, message): + log().warning(f"Warning: {message}") + warnings.warn(message, UserWarning) + + def print_warning(self, message): + log().warning(f"Warning: {message}") + warnings.warn(message, UserWarning) + + @classmethod + @lru_cache(maxsize=1) + def _get_dsl(cls): + # Instantiate the DSL Class once + main_dsl = cls() + if not main_dsl.no_cache: + # register atexit callback + atexit.register(main_dsl.dump_cache) + return main_dsl + + @staticmethod + def _can_preprocess(**dkwargs): + """ + Check if AST transformation is enabled or not for `jit` and `kernel` decorators. + """ + return dkwargs.pop("preprocess", True) + + @staticmethod + def _get_original_function(fcn_ptr, name): + """ + Get the original function from the decorated function + """ + while fcn_ptr.__name__ != name: + # If the function is wrapped with functools, get from __wrapped__ + if hasattr(fcn_ptr, "__wrapped__"): + fcn_ptr = fcn_ptr.__wrapped__ + # If the function is wrapped manually, it's the first in clousure + elif callable(fcn_ptr.__closure__[0].cell_contents): + fcn_ptr = fcn_ptr.__closure__[0].cell_contents + else: + raise DSLRuntimeError( + f"Cannot find the original function {name} in the closure chain" + ) + return fcn_ptr + + @staticmethod + def _preprocess_and_execute(func): + """ + Run ast transformation and return the materialized function pointer + """ + if hasattr(func, "_transformed_ast"): + # If the function ptr is already materialized, use the existing one + func._dsl_object.frame = func._decorator_frame + if func._transformed_ast is None: + func._transformed_ast = func._dsl_object.run_preprocessor(func) + if func._transformed_ast is None: + del func._transformed_ast + func._dsl_object.frame = None + return func + + fcn_ptr = func._dsl_object.get_function_ptr(func) + # If the function is decorated, de-decorate it + fcn_ptr = BaseDSL._get_original_function(fcn_ptr, func.__name__) + func._dsl_object.frame = None + return DSLCallable(fcn_ptr) + return func + + def jit_runner(self, executor, frame, *dargs, **dkwargs): + """ + Decorator to mark a function for JIT compilation. + """ + log().info("jit_runner") + + def jit_runner_decorator(func): + func._dsl_object = self + # Run preprocessor that alters AST + if self.enable_preprocessor and BaseDSL._can_preprocess(**dkwargs): + # For an annotated function, add some DSL attributes + # When materializing the AST, we need decorator's frame + func._decorator_frame = frame + # No transformed ast at this point + func._transformed_ast = None + + @wraps(func) + def jit_wrapper(*args, **kwargs): + func_ptr = BaseDSL._preprocess_and_execute(func) + return executor(func_ptr, *args, **kwargs) + + return jit_wrapper + + if len(dargs) == 1 and callable(dargs[0]): + return jit_runner_decorator(dargs[0]) + else: + return jit_runner_decorator + + @classmethod + def jit(cls, *dargs, **dkwargs): + """ + Decorator to mark a function for JIT compilation for Host code. + """ + frame = inspect.currentframe().f_back + # Instantiate the DSL Class + main_dsl = cls._get_dsl() + return main_dsl.jit_runner(main_dsl._func, frame, *dargs, **dkwargs) + + @classmethod + def kernel(cls, *dargs, **dkwargs): + """ + Decorator to mark a function for JIT compilation for GPU. + """ + frame = inspect.currentframe().f_back + # Instantiate the DSL Class + main_dsl = cls._get_dsl() + return main_dsl.jit_runner(main_dsl._kernel_helper, frame, *dargs, **dkwargs) + + @abstractmethod + def _kernel_helper(self, func, *args, **kwargs): + """ + Helper function to handle kernel generation logic + """ + pass + + @abstractmethod + def _build_gpu_module(self, attrs): + """ + Build the module op that contains the kernels. + """ + pass + + @abstractmethod + def _get_pipeline(self, pipeline): + """ + Get the pipeline from the other configuration options. + """ + if pipeline != None: + return pipeline + return None + + @staticmethod + def log_additions(func_type, operands=None, types=None, arg_attrs=None): + if operands is not None and operands != []: + log().debug( + f"Added {func_type} operands: [%s]", ", ".join(map(str, operands)) + ) + if types is not None: + log().debug( + f"Added {func_type} arg_types: [%s]", ", ".join(map(str, types)) + ) + if arg_attrs is not None: + log().debug( + f"Added {func_type} arg_attrs: [%s]", ", ".join(map(str, arg_attrs)) + ) + + def mangle_name(self, function_name, args, args_spec: inspect.FullArgSpec): + """Does simple name mangling""" + + for spec_arg, arg in zip(args_spec.args, args): + spec_ty = args_spec.annotations.get(spec_arg, None) + if spec_ty != None: + if issubclass(type(spec_ty), (t.IRValue, t.IRVariadic)): + continue + if isinstance(spec_ty, (ir.Type, ir.Value)): + continue + if isinstance(arg, (ir.Type, ir.Value, ir.OpResult)): + continue + if isinstance(type(arg), (ir.Type, ir.Value, ir.OpResult)): + continue + if self._is_tensor_descriptor(arg): + continue + if inspect.isclass(spec_ty): + class_name = str(arg).replace("class", "") + class_name = class_name.replace(" ", "") + function_name = f"{function_name}_{class_name}" + elif isinstance(arg, (list, tuple)): + function_name = f"{function_name}_{'_'.join(map(str, arg))}" + else: + function_name = f"{function_name}_{arg}" + # we would need a dedicated MR to follow up + unwanted_chars = r"'-![]#,.<>()\":{}=%?@;" + translation_table = str.maketrans("", "", unwanted_chars) + function_name = function_name.translate(translation_table) + # identify address and drop + function_name = re.sub(r"0x[a-f0-9]{8,16}", "", function_name) + function_name = re.sub(r"\s+", " ", function_name) + function_name = function_name.replace(" ", "_") + function_name = function_name.replace("\n", "_") + # max fname is 256 character, leave space + function_name = function_name[:180] + log().info(f"Final mangled function name: {function_name}") + return function_name + + def _generate_execution_arguments_for_known_types( + self, arg, arg_spec, arg_name, i, fop_args, iv_block_args + ): + """ + Generate MLIR arguments for known types. + + Sub-DSLs can override this method to handle types that are not + natively supported by the Base DSL. + """ + ir_arg = [] + if is_argument_constexpr(arg, arg_spec, arg_name, i, func): + ir_arg.append(arg) + + return ir_arg, iv_block_args + + def generate_execution_arguments( + self, + args, + kwargs, + fop, + args_spec: inspect.FullArgSpec, + ): + """Create list of arguments that will be passed to MLIR's func.func op""" + + def gen_exec_args(input_args, arg_names, annotations, fop_args): + assert len(input_args) == len(arg_names) + + ir_args = [] + iv_block_args = 0 + for i, arg in enumerate(input_args): + arg_name = arg_names[i] + arg_spec = annotations.get(arg_name, None) + log().debug("Processing [%d] Argument [%s : %s]", i, arg_name, arg_spec) + + # Implicit cast to NumericMeta + if isinstance(arg_spec, t.NumericMeta) and not isinstance( + arg, arg_spec + ): + arg = t.cast(arg, arg_spec) + + ir_arg, iv_block_args = ( + self._generate_execution_arguments_for_known_types( + arg, arg_spec, arg_name, i, fop_args, iv_block_args + ) + ) + + if not ir_arg: + # If it's not a known type, try JIT argument adapter + # to convert the argument if possible + adapter = JitArgAdapterRegistry.get_registered_adapter(type(arg)) + arg = adapter(arg) if adapter else arg + + n_args = len(get_mlir_types(arg)) + blk_args = fop_args[iv_block_args : iv_block_args + n_args] + ir_arg.append(new_from_mlir_values(arg, blk_args)) + iv_block_args += n_args + + self.log_additions(ir_arg) + ir_args.extend(ir_arg) + + return ir_args, iv_block_args + + fop_args = list(fop.regions[0].blocks[0].arguments) + ir_args, iv_block_args = gen_exec_args( + args, args_spec.args, args_spec.annotations, fop_args + ) + ir_kwargs, _ = gen_exec_args( + [kwargs[arg] for arg in args_spec.kwonlyargs], + args_spec.kwonlyargs, + args_spec.annotations, + fop_args[iv_block_args:], + ) + ir_kwargs = {k: v for k, v in zip(args_spec.kwonlyargs, ir_kwargs)} + + log().debug("execution args: %s", ", ".join(map(str, ir_args))) + log().debug("execution kwargs: %s", ", ".join(map(str, ir_kwargs))) + return ir_args, ir_kwargs + + @abstractmethod + def _generate_mlir_type_for_tensor_descriptor(self, tensor): + """ + Generate MLIR type for the tensor descriptor. + """ + pass + + @abstractmethod + def _generate_executable_arg_for_tensor_descriptor( + self, mlir_value=None, ptr_tensor_ty=None, tensor=None + ): + """ + Generates executable value for the given tensor descriptor. + """ + pass + + def _get_globals(self): + """ + Combines global and local variables from the current context and the + caller's frame comes. This includes the current module's globals, the + global variables from the caller's frame, and the local variables from + the caller's frame. + + "self.frame" is used to fetch the caller's frame. + + AST preprocessor generates a new python code, so the resulting globals + dictionary is used to execute the python code. + """ + all_globals = {} + if self.frame: + all_globals.update(self.frame.f_globals) + all_globals.update(self.frame.f_locals) + return all_globals + + @abstractmethod + def _is_tensor_descriptor(self, maybe_tensor_descriptor) -> bool: + pass + + @abstractmethod + def _handle_tensor_descriptor( + self, maybe_tensor, arg_name: str, need_gpu_memory: bool + ) -> Any: + pass + + def _validate_arg(self, arg, arg_index, arg_name, arg_spec): + """ + Validates if the arg is really of the annotated type for type safety. + + The default implementation is empty. Subclasses can override this method to add more validation logic. + Returns None if validation passes, otherwise returns an error derived from DSLBaseError. + """ + pass + + def _generate_jit_func_args_for_known_types( + self, + func, + arg, + arg_name, + arg_spec, + arg_index, + *, + is_host=True, + ): + """ + Generate JIT function arguments for known types. + + Sub-DSLs can override this method to handle types that are not + natively supported by the Base DSL. + """ + + jit_arg_type, jit_arg_attr, jit_exec_arg = [], [], [] + default_attr = ir.DictAttr.get({}) + + if is_argument_constexpr(arg, arg_spec, arg_name, arg_index, func): + jit_exec_arg = jit_arg_type = jit_arg_attr = None + + return jit_exec_arg, jit_arg_type, jit_arg_attr + + def _generate_jit_func_args( + self, + func, + function_name, + args, + kwargs, + args_spec: inspect.FullArgSpec, + *, + is_host=True, + ): + """Generate JIT function arguments.""" + + assert len(args) == len(args_spec.args) and len(kwargs) == len( + args_spec.kwonlyargs + ), ( + f"Input args {len(args)=} and kwargs {len(kwargs)=} must match arg_spec.args " + f"{len(args_spec.args)=} and arg_spec.kwonlyargs {len(args_spec.kwonlyargs)=}" + ) + + jit_arg_types, jit_arg_attrs, jit_exec_args = [], [], [] + jit_adapted_args = [] + default_attr = ir.DictAttr.get({}) + + input_args = [*args, *kwargs.values()] + input_arg_names = [*args_spec.args, *args_spec.kwonlyargs] + for i, (arg_name, arg) in enumerate(zip(input_arg_names, input_args)): + spec_ty = args_spec.annotations.get(arg_name, None) + log().debug("Processing [%d] Argument [%s : %s]", i, arg_name, spec_ty) + + # Implicitly convert into Numeric type if possible + if isinstance(spec_ty, t.NumericMeta) and not isinstance(arg, spec_ty): + arg = t.cast(arg, spec_ty) + + # Type safety check + if spec_ty is not None: + err = self._validate_arg(arg, i, arg_name, spec_ty) + if err is not None: + raise err + + jit_exec_arg, jit_arg_type, jit_arg_attr = ( + self._generate_jit_func_args_for_known_types( + func, + arg, + arg_name, + spec_ty, + i, + is_host=is_host, + ) + ) + + if jit_arg_type is not None and len(jit_arg_type) == 0: + # If not any known type, try JIT argument adapter + # to convert the argument + adapter = JitArgAdapterRegistry.get_registered_adapter(type(arg)) + if adapter: + arg = adapter(arg) + jit_adapted_args.append(arg) + + if is_host: + jit_exec_arg.extend(get_c_pointers(arg)) + jit_arg_type.extend(get_mlir_types(arg)) + else: + dyn_vals = extract_mlir_values(arg) + jit_exec_arg.extend(dyn_vals) + jit_arg_type.extend([v.type for v in dyn_vals]) + + if not jit_arg_type or not jit_exec_arg: + if (is_host and hasattr(arg, "__c_pointers__")) or ( + not is_host + and hasattr(arg, "__extract_mlir_values__") + and hasattr(arg, "__new_from_mlir_values__") + ): + pass + else: + raise DSLRuntimeError( + f"failed to generate argument #{i+1} ({arg_name}) for JIT function '{function_name}'.", + context={ + f"Argument {arg_name}": "The DSL attempted to convert it into Dynamic Expression (aka MLIR values) but failed.", + f"Call-site argument value": arg, + f"Call-site argument type": type(arg), + }, + suggestion=f"Consider annotating the argument with `{arg_name} : Constexpr` " + "if it's a value known at compile-time. " + f"Otherwise, implement the {'`JitArgument`' if is_host else '`DynamicExpression`'} " + f"protocol or register a custom JIT argument adapter for type `{type(arg)}` to " + "enable dynamic value conversion at runtime.", + ) + + jit_arg_attr.extend([default_attr] * len(jit_arg_type)) + + if jit_arg_type is not None: + jit_exec_args.extend(jit_exec_arg) + jit_arg_types.extend(jit_arg_type) + jit_arg_attrs.extend(jit_arg_attr) + + return jit_exec_args, jit_arg_types, jit_arg_attrs, jit_adapted_args + + def generate_mlir_function_types( + self, func, function_name, input_args, kwargs, args_spec: inspect.FullArgSpec + ): + """Convert input arguments to MLIR function signature also convert numpy arrays to memref.""" + + exe_args, types, attrs, adapted_args = self._generate_jit_func_args( + func, function_name, input_args, kwargs, args_spec, is_host=True + ) + + log().debug("Execution Arguments: %s", ", ".join(map(str, exe_args))) + log().debug("Types: %s", ", ".join(map(str, types))) + + assert len(exe_args) == len( + types + ), "expects the same number of arguments and function parameters" + + return exe_args, types, adapted_args + + @dataclass + class LaunchConfig: + cluster: list = None + grid: list = field(default_factory=lambda: [1, 1, 1]) + block: list = field(default_factory=lambda: [1, 1, 1]) + smem: int = None + async_deps: list = field(default_factory=list) + has_cluster: bool = False + min_blocks_per_mp: int = 0 + auto_smem: bool = False + + def __post_init__(self): + if len(self.grid) != 3: + raise DSLRuntimeError(f"Expect 3d grid!") + if len(self.block) != 3: + raise DSLRuntimeError(f"Expect 3d block!") + + if self.smem is None: + self.smem = 0 + self.auto_smem = True + + self.has_cluster = self.cluster is not None + if self.cluster is None: + self.cluster = [None, None, None] + elif len(self.cluster) != 3: + raise DSLRuntimeError(f"Expect 3d cluster!") + + def diagnostic(self): + """Check command line parameters and enables diagnostic""" + # Check command line arguments "-diagnostic" + parser = argparse.ArgumentParser(description="Process diagnostic status.") + parser.add_argument( + "-diagnostic", + nargs="?", + const="all", + choices=["all", "fail", "success", "info", "suggestion"], + help="Set diagnostic status (fail, success, info, suggestion).", + ) + + args, _ = parser.parse_known_args() + ctx = ir.Context.current + + def callback(d): + print(f" [{self.name} Diagnostic] : {d.message}") + + ctx.attach_diagnostic_handler(callback) + + # Early return, don't enable diagnostics + if args.diagnostic is None: + return + + # Enable MLIR Flags + ctx.emit_error_diagnostics = True + ir._GlobalDebug.flag = True + if args.diagnostic == "all": + ir._GlobalDebug.set_types("diagnostic") + else: + ir._GlobalDebug.set_types(f"diagnostic-{args.diagnostic}") + + def get_location(self): + """ + Get python location information and generate MLIR location + """ + + if self.frame is None: + log().debug("Frame is None") + return None + + file_loc = ir.Location.file( + self.frame.f_code.co_filename, self.frame.f_lineno, 0 + ) + + loc = ir.Location.name(self.frame.f_code.co_name, childLoc=file_loc) + return loc + + def compile_and_jit(self, module, pipeline, shared_libs, function_name=""): + """ + Compile and JIT an MLIR module. + """ + + try: + self.diagnostic() + + orig_stdout = sys.stdout + orig_stderr = sys.stderr + sys.stderr = redirect_stderr = io.StringIO() + sys.stdout = redirect_stdout = io.StringIO() + + try: + kernel = self.compiler_provider.compile_and_jit( + module, + pipeline, + shared_libs=shared_libs, + cuda_toolkit=self.envar.cuda_toolkit, + arch=self.envar.arch, + ) + + finally: + sys.stdout = orig_stdout + sys.stderr = orig_stderr + ir._GlobalDebug.flag = False + + # Print captured output. + print(redirect_stdout.getvalue(), file=sys.stdout, end="") + print(redirect_stderr.getvalue(), file=sys.stderr, end="") + + return kernel + + except Exception as e: + raise DSLRuntimeError("🧊🧊🧊 ICE 🧊🧊🧊", cause=e) + finally: + pass + + def preprocess_pipeline(self, pipeline, arch) -> str: + + if self.envar.cuda_toolkit is None: + self.print_warning( + "CUDA_TOOLKIT_PATH environment variable is not set. Cannot set toolkitPath." + ) + + options = { + "toolkitPath": self.envar.cuda_toolkit if self.envar.cuda_toolkit else None, + self.pass_sm_arch_name: arch, + } + + opt_str = "" + for k, v in options.items(): + if v: + opt_str += f"{k}={v} " + + if opt_str: + # Automatically append the pipeline options if any is specified through env var + pattern = re.compile(r"{(.+)}") + match = pattern.search(pipeline) + if match: + opt_str = f"{{{match[1]} {opt_str}}}" + pipeline = re.sub(r"{.+}", opt_str, pipeline) + else: + pipeline = pipeline.rstrip(")") + f"{{{opt_str}}})" + log().debug(f"Using pipeline = {pipeline}") + return pipeline + + def get_shared_libs(self) -> list: + shared_libs = [] + support_libs = self.envar.shared_libs + if support_libs is not None: + _libs = support_libs.split(":") + for lib in _libs: + if not os.path.exists(lib): + raise FileNotFoundError( + errno.ENOENT, os.strerror(errno.ENOENT), lib + ) + shared_libs.append(lib) + else: + self.print_warning(f"{self.name}_LIBS environment variable is not set") + + return shared_libs + + @lru_cache(maxsize=1) + def get_version(self): + version_hash = hashlib.sha256() + + return version_hash + + def get_module_hash(self, module, function_name): + s = io.BytesIO() + module.operation.write_bytecode(s) + for attr, value in self.envar.__dict__.items(): + if value is not None: + s.write(str(value).encode()) + # Add compile options to the hash + s.write(self.compile_options.to_str().encode()) + module_hash = self.get_version().copy() + module_hash.update(s.getvalue()) + module_hash = module_hash.hexdigest() + + log().debug("Bytecode=[%s]", s.getvalue().hex()) + log().debug("Version=[%s]", self.get_version().hexdigest()) + log().info( + "Function=[%s] Computed module_hash=[%s]", function_name, module_hash + ) + return module_hash + + def build_module(self, module, function_name: str): + """ + Build the MLIR module, verify and return the module + """ + + # Save IR in a file + if self.envar.keepIR: + save_ir(self.name, module, function_name) + + if self.envar.printIR: + print("\n//===--- ------ Generated IR ------ ---====\n") + module.operation.print( + enable_debug_info=self.envar.generate_source_location + ) + print("\n//===--- --- End of Generated IR -- ---====\n") + + # Verify the module + try: + module.operation.verify() + except Exception as e: + raise DSLRuntimeError(f"🧊🧊🧊 ICE IR Verification Failed 🧊🧊🧊", cause=e) + + return module + + def generate_original_ir( + self, + ir, + func, + funcBody, + kwargs, + function_name, + func_types, + gpu_module_attrs, + args, + args_spec, + ): + # This location is set to None for now; otherwise, calls to the same + # function on different lines would produce different line numbers, + # which would break the cache. + loc = None # self.get_location() + + def build_ir_module(): + module = ir.Module.create(loc=loc) + unit_attr = ir.UnitAttr.get() + module.operation.attributes["gpu.container_module"] = unit_attr + + with ir.InsertionPoint(module.body): + # Always generate gpu module. It's canonicalized by the compiler when it's not used. + self._build_gpu_module(gpu_module_attrs) + + fop = func.FuncOp(function_name, (func_types, []), loc=loc) + fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + log().debug("Generated Function OP [%s]", fop) + with ir.InsertionPoint(fop.add_entry_block()): + ir_args, ir_kwargs = self.generate_execution_arguments( + args, kwargs, fop, args_spec + ) + # Call user function body + try: + result = funcBody(*ir_args, **ir_kwargs) + func.ReturnOp([]) + except NameError as name_error: + raise DSLRuntimeError( + f"💥💥💥 Error during runtime code generation for function `{funcBody.__name__}` 💥💥💥", + cause=name_error, + suggestion="Using variables defined in dynamic control flow is not supported. Please give an initial value before control flow.", + ) + except DSLRuntimeError as dsl_error: + # Throw it's already a DSL error + raise dsl_error + return module, result + + # Build IR module + profiler = timer(enable=self.envar.jitTimeProfiling) + module, result = profiler(build_ir_module)() + module_hash = self.get_module_hash(module, function_name) + + module = self.build_module(module, function_name) + + return module, module_hash, result + + def compile_and_cache( + self, module, module_hash, function_name, pipeline, args_spec, no_cache + ): + arch = self.envar.arch + pipeline = self.preprocess_pipeline(self._get_pipeline(pipeline), arch) + shared_libs = self.get_shared_libs() + profiler = timer(enable=self.envar.jitTimeProfiling) + if ( + no_cache + or module_hash not in self.jit_cache + or self.jit_cache[module_hash].ir_module is None + ): + log().info( + "JIT cache miss function=[%s] module_hash=[%s]", + function_name, + module_hash, + ) + # Compile and JIT MLIR module + engine = profiler(self.compile_and_jit)( + module, pipeline, shared_libs, function_name=function_name + ) + else: + log().info( + "JIT cache hit IN-FILE function=[%s] module_hash=[%s]", + function_name, + module_hash, + ) + module = self.jit_cache[module_hash].ir_module + engine = self.compiler_provider.jit(module, shared_libs=shared_libs) + capi_func = profiler(engine.lookup)(function_name) + jit_executor = JitExecutor( + self, + engine, + capi_func, + module, + args_spec, + function_name, + jit_time_profiling=self.envar.jitTimeProfiling, + ) + jit_executor = jit_executor.update_jit_cuda_modules(self.kernel_symbols) + + if not no_cache: + # module stored in cache is compiled. + self.jit_cache[module_hash] = jit_executor + + return jit_executor + + def post_compilation_cleanup(self): + """Clean up some internal state after one compilation is completed.""" + # clear the kernel symbols after the compilation is done. + self.kernel_symbols = [] + self.launch_inner_count = 0 + # reset num_kernels to 0 for next compilation. + self.num_kernels = 0 + # reset the compile options after the compilation is done. + self.compile_options = CompileOptions() + + def generate_mlir( + self, + funcBody, + kwargs, + function_name, + gpu_module_attrs, + args, + args_spec, + pipeline, + no_cache, + compile_only, + loc=None, + ): + """Generate MLIR module and compile iself.T_provider.""" + with ir.Context(), ir.Location.unknown(): + # Convert input arguments to MLIR arguments + exe_args, func_types, adapted_args = self.generate_mlir_function_types( + funcBody, function_name, args, kwargs, args_spec + ) + + # Generate original ir module and its hash value. + module, module_hash, result = self.generate_original_ir( + ir, + func, + funcBody, + kwargs, + function_name, + func_types, + gpu_module_attrs, + args, + args_spec, + ) + + # dryrun is used to only generate IR + if self.envar.dryrun: + return result + + if ( + no_cache + or module_hash not in self.jit_cache + or self.jit_cache[module_hash].capi_func is None + ): + # no cache or cache miss, do ir generation/compilation/jit engine + jit_executor = self.compile_and_cache( + module, module_hash, function_name, pipeline, args_spec, no_cache + ) + else: + # cache hit + log().info( + "JIT cache hit IN-MEMORY function=[%s] module_hash=[%s]", + function_name, + module_hash, + ) + jit_executor = self.jit_cache[module_hash] + + self.post_compilation_cleanup() + # If compile_only is set, bypass execution return the jit_executor directly + if compile_only: + return jit_executor + # Run the compiled program + jit_executor.run_compiled_program(exe_args) + + return result + + def run_preprocessor(self, funcBody): + if not hasattr(funcBody, "_preprocessed"): + function_name = funcBody.__name__ + self.funcBody = funcBody + log().info("Started preprocessing [%s]", function_name) + exec_globals = self._get_globals() + transformed_ast = self.preprocessor.transform(funcBody, exec_globals) + if self.envar.print_after_preprocessor: + log().info( + f"# Printing unparsed AST after preprocess of func=`{function_name}` id=`{id(funcBody)}`" + ) + DSLPreprocessor.print_ast(transformed_ast) + funcBody._preprocessed = True + return transformed_ast + return None + + def get_function_ptr(self, original_function): + file_name = inspect.getsourcefile(original_function) + code_object = compile( + original_function._transformed_ast, filename=file_name, mode="exec" + ) + return self.preprocessor.exec( + original_function.__name__, + original_function, + code_object, + self._get_globals(), + ) + + def _get_function_bound_args(self, sig, func_name, *args, **kwargs): + """ + Binds provided arguments to a function's signature and applies default values. + + E.g. given a function signature `def foo(a, b=2, c=3)`, and at call-site if we do + `foo(a=1, c=4)`, the returned BoundArguments object will have args = `[1]` + and kwargs = `{'b': 2, 'c': 4}` + + An exception will be raised if binding fails. + """ + try: + bound_args = sig.bind_partial(*args, **kwargs) + bound_args.apply_defaults() + except Exception as e: + raise DSLRuntimeError( + f"Failed to bind arguments to function `{func_name}` with signature `{sig}`", + cause=e, + ) + return bound_args + + def _canonicalize_args(self, sig, *args, **kwargs): + """ + Canonicalize the input arguments so that returned args only contain + positional arguments and kwargs only contain keyword arguments. + """ + function_name = self.funcBody.__name__ + bound_args = self._get_function_bound_args(sig, function_name, *args, **kwargs) + canonicalized_args = bound_args.args + canonicalized_kwargs = bound_args.kwargs + return canonicalized_args, canonicalized_kwargs + + def _check_arg_count(self, *args, **kwargs): + if not self.funcBody: + raise DSLRuntimeError("Function body is not set.") + + # Pass the actual function object to inspect.signature to get the signature. + sig = inspect.signature(self.funcBody) + + function_name = self.funcBody.__name__ + + bound_args = self._get_function_bound_args(sig, function_name, *args, **kwargs) + + # Check if all non-default arguments are provided + for param in sig.parameters.values(): + if ( + param.default is inspect.Parameter.empty + and param.name not in bound_args.arguments + ): + raise DSLRuntimeError( + f"Missing required argument in `{function_name}`: '{param.name}'" + ) + + return sig + + def _func(self, funcBody, *args, **kwargs): + """Decorator for MLIR functions. + It cuts the boilerplate code, does the following: + 1. Generates `func.func` + 2. Types translation (numpy arrays -> cute.memref, float -> , etc.) + 3. Compiles and JITs the MLIR module + 4. Invokes the generated function + 5. Operator overloading (a + b --> arith.addi a, b) + 6. Generates GPU kernel function with GPU module and kernel attributes baked + """ + if ir.Context.current is None: + pass + elif ir.InsertionPoint.current is not None: + return funcBody(*args, **kwargs) + + function_name = funcBody.__name__ + self.funcBody = funcBody + + pipeline = kwargs.pop("pipeline", None) + gpu_module_attrs = kwargs.pop("gpu_module_attrs", {}) + + # Disable cache + no_cache = kwargs.pop("no_cache", False) + + # Always compile(disable cache) and return the result jit_executor + compile_only = kwargs.pop("compile_only", False) + + if not no_cache and compile_only: + no_cache = True + self.print_warning("Cache is disabled as user wants to compile only.") + + # Check the number of arguments + sig = self._check_arg_count(*args, **kwargs) + + args_spec = inspect.getfullargspec(funcBody) + + # Canonicalize the input arguments + canonicalized_args, canonicalized_kwargs = self._canonicalize_args( + sig, *args, **kwargs + ) + + # Simple name mangling + function_name = self.mangle_name(function_name, canonicalized_args, args_spec) + + # Generate MLIR Context and start generating IR + log().debug(f"Generating MLIR for function '{function_name}'") + result = self.generate_mlir( + funcBody, + canonicalized_kwargs, + function_name, + gpu_module_attrs, + canonicalized_args, + args_spec, + pipeline, + no_cache, + compile_only, + ) + + return result + + class _KernelGenHelper(ABC): + def __init__(self): + self.func_op = None + self.func_type = None + + @abstractmethod + def generate_func_op(self, arg_types, arg_attrs, kernel_name, loc=None): + assert arg_types is not None, "Invalid arg_types!" + assert kernel_name is not None, "kernel name is empty" + pass + + @abstractmethod + def generate_func_ret_op(self): + pass + + @abstractmethod + def generate_launch_op(self, *args, **kwargs): + pass + + @abstractmethod + def get_func_body_start(self): + pass + + @abstractmethod + def enter_gpu_module(module): + """Compute the insertion point into the given module.""" + pass + + @lru_cache(maxsize=1) + def _get_default_stream(self): + """Returns the default stream 0""" + from .runtime import cuda as cuda_helpers + + return cuda_helpers.stream_create() + + def _execute_cuda( + self, fname_cubin, kernel_name, grid_size, block_size, smem_size, stream=None + ): + """ + Executes a specified CUDA kernel from a cubin file, handling module loading, + kernel retrieval, stream creation, kernel launch, and synchronization. + """ + from .runtime import cuda as cuda_helpers + + # Step 1. Load CUDA Module + module = cuda_helpers.load_cubin_module(fname_cubin) + # Step 2. Find CUDA function + kernel_ptr = cuda_helpers.get_kernel_function(module, kernel_name) + + sync_execution_default = False + if stream is None: + stream = self._get_default_stream() + sync_execution_default = True + + # Step 4. Launch the kernel + cuda_helpers.launch_kernel( + kernel_ptr, + grid_size, + block_size, + stream, + smem_size=smem_size, + kernel_args=self.exe_args, + ) + + if sync_execution_default: + # Step 5. Optional Sync cuda stream + cuda_helpers.stream_sync(stream) + + def _execute_by_cuda_driver( + self, + kernel_generator, + generate_cubin, + grid_size, + block_size, + smem_size, + stream=None, + ): + """ + This function builds IR and execute the module using cuda driver. + It doesn't use mlir's cuda runtime + """ + ret = None + + # Step 1. Build IR + with ir.Context(), ir.Location.unknown(): + loc = self.get_location() + module = ir.Module.create(loc=loc) + unit_attr = ir.UnitAttr.get() + module.operation.attributes["gpu.container_module"] = unit_attr + with ir.InsertionPoint(module.body): + self._build_gpu_module() + ret, kernel_name = kernel_generator() + log().debug( + f"Kernel generator returned: ret={ret}, kernel_name={kernel_name}" + ) + + module = self.build_module(module, kernel_name) + + # dryrun is used to only generate IR + if self.envar.dryrun: + return ret + + # Generate cubin + fname_cubin = generate_cubin(module, kernel_name) + + # Execute a cuda kernel from cubin + self._execute_cuda( + fname_cubin, kernel_name, grid_size, block_size, smem_size, stream + ) + + return ret + + def generate_kernel_operands_and_types( + self, kernel_func, kernel_name, args_spec, args, kwargs + ): + """ + Generate the operands and types for the kernel function + """ + + kernel_operands, kernel_arg_types, kernel_arg_attrs = [], [], [] + + log().debug( + "Processing GPU kernel call in [%s] mode", + ( + f"Only {self.device_jit_decorator_name}" + if self.device_compilation_only + else f"{self.host_jit_decorator_name} + {self.device_jit_decorator_name}" + ), + ) + + if self.device_compilation_only: + return kernel_operands, kernel_arg_types, kernel_arg_attrs + + kernel_operands, kernel_arg_types, kernel_arg_attrs, _ = ( + self._generate_jit_func_args( + kernel_func, kernel_name, args, kwargs, args_spec, is_host=False + ) + ) + + log().debug("Final kernel_operands: %s", ", ".join(map(str, kernel_operands))) + log().debug("Final kernel_arg_types: %s", ", ".join(map(str, kernel_arg_types))) + log().debug("Final kernel_arg_attrs: %s", ", ".join(map(str, kernel_arg_attrs))) + + assert ( + len(kernel_operands) == len(kernel_arg_types) == len(kernel_arg_attrs) + ), "Size of kernel_operands, kernel_arg_types and kernel_arg_attrs must be equal" + + return kernel_operands, kernel_arg_types, kernel_arg_attrs + + def kernel_launcher(self, *dargs, **dkwargs): + def decorator(funcBody): + @wraps(funcBody) + def kernel_wrapper(*args, **kwargs): + """ + Base decorator for generating kernel function + + This decorator provides a template for kernel function generation + including kernel function header/body and kernel launch op at call site + + Optional arguments (with default value in <>): + - requiredArgs <[]>: specifies the mandatory arguments that must present in kernel function signature + the args will be validated and collected as a namedtuple + - optionalArgs <[]>: specifies the optional arguments that might present in kernel function signature + the args will be collected (if present) as a namedtuple + - unitAttrNames <[]>: specifies the name(s) of ir.UnitAttr to be set for kernel function op + - valueAttrDict <{}>: specifies the name(s) and value(s) of ir.Attribute to be set for kernel function op + - kernelGenHelper : specifies the mandatory customized kernel generation helper class (derived from _KernelGenHelper) + + Return value: + A namedtuple "KernelReturns" is returned with following fields: + - kernel_func_ret: the return of the kernel function + - launch_op_ret: the return of the launch op + """ + + requiredArgs = dkwargs.get("requiredArgs", []) + optionalArgs = dkwargs.get("optionalArgs", []) + unitAttrNames = dkwargs.get("unitAttrNames", []) + valueAttrDict = dkwargs.get("valueAttrDict", {}) + kernelGenHelper = dkwargs.get("kernelGenHelper", None) + + kernel_name = funcBody.__name__ + args_spec = inspect.getfullargspec(funcBody) + self.funcBody = funcBody + + # Give each kernel a unique name. (The same kernel may be + # called multiple times, resulting in multiple kernel traces.) + # The mangled name of Python function is part of the name to + # improve readability. + kernel_name = f"kernel_{self.mangle_name(kernel_name, args, args_spec)}_{self.num_kernels}" + self.num_kernels += 1 + + # Step 0. Preprocess the arguments + def extract_args(argNames, assertIfNone=False) -> list: + extracted = [] + for name in argNames: + value = kwargs.pop(name, None) + if assertIfNone and value is None: + raise DSLRuntimeError( + f"{name} is required for {kernel_name}" + ) + extracted.append(value) + + return extracted + + RequiredArgs = namedtuple("RequiredArgs", requiredArgs) + req_args = ( + RequiredArgs._make(extract_args(requiredArgs, assertIfNone=True)) + if requiredArgs + else None + ) + OptionalArgs = namedtuple("OptionalArgs", optionalArgs) + opt_args = ( + OptionalArgs._make(extract_args(optionalArgs)) + if optionalArgs + else None + ) + assert ( + kernelGenHelper is not None + ), "kernelGenHelper should be explicitly specified!" + + # check arguments + sig = self._check_arg_count(*args, **kwargs) + + # Canonicalize the input arguments + canonicalized_args, canonicalized_kwargs = self._canonicalize_args( + sig, *args, **kwargs + ) + + kernel_operands, kernel_types, kernel_arg_attrs = ( + self.generate_kernel_operands_and_types( + funcBody, + kernel_name, + args_spec, + canonicalized_args, + canonicalized_kwargs, + ) + ) + + with self._enter_gpu_module(): + log().debug("Generating device kernel") + if self.device_compilation_only: + log().debug("Generating cuda-python arguments") + # Convert input arguments to MLIR arguments + self.exe_args, kernel_types, _ = ( + self.generate_mlir_function_types( + funcBody, + kernel_name, + canonicalized_args, + canonicalized_kwargs, + args_spec, + ) + ) + + helper = kernelGenHelper() + loc = self.get_location() + fop = helper.generate_func_op( + kernel_types, kernel_arg_attrs, kernel_name, loc + ) + log().debug(f"Kernel function op: {fop}") + for attr in unitAttrNames: + fop.attributes[attr] = ir.UnitAttr.get() + for key, val in valueAttrDict.items(): + fop.attributes[key] = val + + fop.sym_visibility = ir.StringAttr.get("public") + with ir.InsertionPoint(helper.get_func_body_start()): + ir_args, ir_kwargs = self.generate_execution_arguments( + canonicalized_args, canonicalized_kwargs, fop, args_spec + ) + log().debug( + f"IR arguments - args: {ir_args} ; kwargs: {ir_kwargs}" + ) + # Call user function body + kernel_ret = funcBody(*ir_args, **ir_kwargs) + helper.generate_func_ret_op() + + # Step 3. Generate call site `launch_func` + kernel_sym = ir.SymbolRefAttr.get(["kernels", kernel_name]) + launch_ret = helper.generate_launch_op( + kernelSym=kernel_sym, + kernelOperands=kernel_operands, + requiredArgs=req_args, + optionalArgs=opt_args, + ) + + KernelReturns = namedtuple( + "KernelReturns", ["kernel_func_ret", "launch_op_ret"] + ) + result = KernelReturns( + kernel_func_ret=kernel_ret, launch_op_ret=launch_ret + ) + log().debug(f"Kernel result: {result}, kernel name: {kernel_name}") + return result, kernel_name + + return kernel_wrapper + + if len(dargs) == 1 and callable(dargs[0]): + return decorator(dargs[0]) + else: + return decorator diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/env_manager.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/env_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..fa683477f3fb5b18f5459e19bdd468432590b952 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/env_manager.py @@ -0,0 +1,320 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +""" +This module provides utilities for the environment variables setup. + +It provides an EnvironmentVarManager, which reads environment variables for the DSL +and caches them for efficient access. + +It also provides utilities to automatically setup a subset of environment variables +based on heuristics. +""" + +import os +import sys +import shutil +import glob +from pathlib import Path +from functools import lru_cache +from typing import Any + +from ..base_dsl.runtime.cuda import get_compute_capability_major_minor +from .utils.logger import log + +IS_WINDOWS = sys.platform == "win32" +CLIB_EXT = ".dll" if IS_WINDOWS else ".so" + +# ============================================================================= +# Environment Variable Helpers +# ============================================================================= + + +@lru_cache(maxsize=None) +def get_str_env_var(var_name, default_value=None): + value = os.getenv(var_name) + return value if value is not None else default_value + + +@lru_cache(maxsize=None) +def get_bool_env_var(var_name, default_value=False): + value = get_str_env_var(var_name) + if value is None: + return default_value + return value not in {"False", "0", ""} + + +@lru_cache(maxsize=None) +def get_int_env_var(var_name, default_value=0): + value = get_str_env_var(var_name) + return int(value) if value and value.isdigit() else default_value + + +@lru_cache(maxsize=None) +def has_env_var(var_name): + return os.getenv(var_name) is not None + + +def detect_gpu_arch(prefix): + """ + Attempts to detect the machine's GPU architecture. + + Returns: + A string representing the GPU architecture (e.g. "70" for compute capability 7.0), + or a default value(e.g. "sm_100") if the GPU architecture cannot be determined. + """ + arch = (None, None) + try: + arch = get_compute_capability_major_minor() + except Exception as e: + log().info(f"Failed to get CUDA compute capability: {e}") + + if arch == (None, None): + # default to sm_100 + arch = (10, 0) + + major, minor = arch + suffix = "" + if major >= 9: + suffix = "a" + + return f"sm_{major}{minor}{suffix}" + + +def find_libs_in_ancestors(start, target_libs, lib_folder_guesses): + """ + Search ancestor directories for a candidate library folder containing all required libraries. + + Starting from the given path, this function traverses up through each parent directory. + For every ancestor, it checks candidate subdirectories (specified by lib_folder_guesses) + for files that match the required library extension (CLIB_EXT). Library file names are + canonicalized by removing the "lib" prefix from their stem. If a candidate directory contains + all of the required libraries (as specified in target_libs), the function returns a list of + absolute paths to these library files. + + Parameters: + start (str or Path): The starting directory from which to begin the search. + target_libs (iterable of str): A collection of required library names (without the "lib" prefix). + lib_folder_guesses (iterable of str): Relative paths from an ancestor directory that may contain the libraries. + + Returns: + list[str] or None: A list of resolved paths to the required library files if found; otherwise, None. + """ + # Traverse through all parent directories of the resolved starting path. + for ancestor in Path(start).resolve().parents: + # Iterate over each candidate relative directory path. + for rel_path in lib_folder_guesses: + target_dir = ancestor / rel_path + # Skip if the candidate directory does not exist. + if not target_dir.is_dir(): + continue + + # Initialize a list to hold the resolved paths of matching library files. + libs_cand = [] + # Create a set of the remaining libraries we need to find. + remaining_libs = set(target_libs) + + # Iterate over all items in the candidate directory. + for p in target_dir.iterdir(): + # Consider only files with the expected library extension. + if p.suffix == CLIB_EXT: + # Canonicalize the library name by removing the "lib" prefix. + lib_name = p.stem.removeprefix("lib") + # If this library is required, add its resolved path and mark it as found. + if lib_name in remaining_libs: + libs_cand.append(str(p.resolve())) + remaining_libs.remove(lib_name) + + # If all required libraries have been found, return the list of library paths. + if len(remaining_libs) == 0: + return libs_cand + + # Return None if no candidate directory contains all required libraries. + return None + + +def _find_cuda_home(): + """Find the CUDA installation path using a series of heuristic methods. + Methods below are checked in order, and the function returns on first match: + 1. Checking the environment variables CUDA_HOME and CUDA_PATH. + 2. Searching for the 'nvcc' compiler in the system PATH and deriving the path of cuda. + 3. Scanning common installation directories based on the operating system. + - On Windows systems (when IS_WINDOWS is True), it searches in: + C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.* + - On Unix-like systems, it searches in: + /usr/local/cuda* + + Returns: + Optional[str]: The absolute CUDA installation path if found; otherwise, None. + + Note: + The variable IS_WINDOWS is defined in the module scope. + """ + # Guess #1 + cuda_home = get_str_env_var("CUDA_HOME") or get_str_env_var("CUDA_PATH") + if cuda_home is None: + # Guess #2 + nvcc_path = shutil.which("nvcc") + if nvcc_path is not None: + cuda_home = os.path.dirname(os.path.dirname(nvcc_path)) + else: + # Guess #3 + if IS_WINDOWS: + glob_pat = "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*" + else: + glob_pat = "/usr/local/cuda*" + cuda_homes = glob.glob(glob_pat) + if len(cuda_homes) == 0: + cuda_home = "" + else: + cuda_home = cuda_homes[0] + if not os.path.exists(cuda_home): + cuda_home = None + return cuda_home + + +def get_cuda_toolkit_path(): + """ + Get cuda_toolkit_path. It returns get_str_env_var('CUDA_TOOLKIT_PATH') if + set. Otherwise, attempts to discover a valid CUDA toolkit location and + return. If not found, return None. + """ + # Check if the environment variable is already set, if so, return it immediately. + try: + cuda_toolkit_path_existing = get_str_env_var("CUDA_TOOLKIT_PATH") + if cuda_toolkit_path_existing: + return cuda_toolkit_path_existing + + found_cuda_home = _find_cuda_home() + if found_cuda_home: + return found_cuda_home + except Exception as e: + log().info("default_env: exception on get_cuda_toolkit_path", e) + return None + + +def get_prefix_dsl_libs(prefix: str): + """ + Returns get_str_env_var('{prefix}_LIBS') if set. + Otherwise, attempts to discover libs based on heuristics and return + If not found, return None. + """ + # Check if the environment variable is already set, if so, return it immediately. + try: + prefix_libs_existing = get_str_env_var(f"{prefix}_LIBS") + if prefix_libs_existing: + return prefix_libs_existing + + def get_libs_cand(start): + target_libs = { + "mlir_c_runner_utils", + "mlir_runner_utils", + "mlir_cuda_runtime", + } + lib_folder_guesses = [ + "lib", + ] + + libs_cand = find_libs_in_ancestors(start, target_libs, lib_folder_guesses) + if libs_cand: + dsl_libs = ":".join(libs_cand) + return dsl_libs + + return None + + # find from install folder + dsl_libs = get_libs_cand(__file__) + + if not dsl_libs: + # try to find from build folder structure + dsl_libs = get_libs_cand(Path(__file__).parent.parent.resolve()) + + return dsl_libs + + except Exception as e: + log().info(f"default_env: exception on get_prefix_dsl_libs", e) + return None + + +class EnvironmentVarManager: + """Manages environment variables for configuration options. + + Printing options: + - [DSL_NAME]_LOG_TO_CONSOLE: Print logging to stderr (default: False) + - [DSL_NAME]_PRINT_AFTER_PREPROCESSOR: Print after preprocess (default: False) + - [DSL_NAME]_PRINT_IR: Print generated IR (default: False) + - [DSL_NAME]_FILTER_STACKTRACE: Filter internal stacktrace (default: True) + File options: + - [DSL_NAME]_KEEP_IR: Save generated IR in a file (default: False) + - [DSL_NAME]_LOG_TO_FILE: Store all logging into a file, excluding COMPILE_LOGS (default: False) + Other options: + - [DSL_NAME]_LOG_LEVEL: Logging level to set, for LOG_TO_CONSOLE or LOG_TO_FILE (default: 1). + - [DSL_NAME]_DRYRUN: Generates IR only (default: False) + - [DSL_NAME]_ARCH: GPU architecture (default: "sm_100") + - [DSL_NAME]_WARNINGS_AS_ERRORS: Enable warnings as error (default: False) + - [DSL_NAME]_WARNINGS_IGNORE: Ignore warnings (default: False) + - [DSL_NAME]_ENABLE_OPTIMIZATION_WARNINGS: Enable warnings of optimization warnings (default: False) + - [DSL_NAME]_JIT_TIME_PROFILING: Whether or not to profile the IR generation/compilation/execution time (default: False) + - [DSL_NAME]_DISABLE_FILE_CACHING: Disable file caching (default: False) + - [DSL_NAME]_FILE_CACHING_CAPACITY: Limits the number of the cache save/load files (default: 1000) + - [DSL_NAME]_LIBS: Path to dependent shared libraries (default: None) + - [DSL_NAME]_NO_SOURCE_LOCATION: Generate source location (default: False) + """ + + def __init__(self, prefix="DSL"): + self.prefix = prefix # change if needed + + # Printing options + self.print_after_preprocessor = get_bool_env_var( + f"{prefix}_PRINT_AFTER_PREPROCESSOR", False + ) + self.printIR = get_bool_env_var(f"{prefix}_PRINT_IR", False) + self.filterStacktrace = get_bool_env_var(f"{prefix}_FILTER_STACKTRACE", True) + # File options + self.keepIR = get_bool_env_var(f"{prefix}_KEEP_IR", False) + # Logging options + self.log_to_console = get_bool_env_var(f"{prefix}_LOG_TO_CONSOLE", False) + self.log_to_file = get_bool_env_var(f"{prefix}_LOG_TO_FILE", False) + if ( + has_env_var(f"{prefix}_LOG_LEVEL") + and not self.log_to_console + and not self.log_to_file + ): + log().warning( + f"Log level was set, but neither logging to file ({prefix}_LOG_TO_FILE) nor logging to console ({prefix}_LOG_TO_CONSOLE) is enabled!" + ) + self.log_level = get_int_env_var(f"{prefix}_LOG_LEVEL", 1) + + # Other options + self.dryrun = get_bool_env_var(f"{prefix}_DRYRUN", False) + self.arch = get_str_env_var(f"{prefix}_ARCH", detect_gpu_arch(prefix)) + self.warnings_as_errors = get_bool_env_var( + f"{prefix}_WARNINGS_AS_ERRORS", False + ) + self.warnings_ignore = get_bool_env_var(f"{prefix}_WARNINGS_IGNORE", False) + self.enable_optimization_warnings = get_bool_env_var( + f"{prefix}_ENABLE_OPTIMIZATION_WARNINGS", False + ) + self.jitTimeProfiling = get_bool_env_var(f"{prefix}_JIT_TIME_PROFILING", False) + self.disable_file_caching = get_bool_env_var( + f"{prefix}_DISABLE_FILE_CACHING", False + ) + self.file_caching_capacity = get_int_env_var( + f"{prefix}_FILE_CACHING_CAPACITY", 1000 + ) + self.generate_source_location = not get_bool_env_var( + f"{prefix}_NO_SOURCE_LOCATION", False + ) + # set cuda + self.cuda_toolkit = get_cuda_toolkit_path() + + # set mlir shared libraries + self.shared_libs = get_prefix_dsl_libs(prefix) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/jit_executor.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/jit_executor.py new file mode 100644 index 0000000000000000000000000000000000000000..83268009c85ef64967d6a81ab886ebeb704f140d --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/jit_executor.py @@ -0,0 +1,357 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +""" +This module provides jit executor related classes +""" +import ctypes +import inspect +import io +from typing import get_origin + +import numpy as np + +# MLIR modules imports +from .._mlir import ir + +# Local modules imports +from . import typing as t +from .common import DSLRuntimeError +from .runtime import cuda as cuda_helpers +from .runtime.jit_arg_adapters import JitArgAdapterRegistry, is_arg_spec_constexpr +from .typing import get_c_pointers +from .utils.logger import log +from .utils.timer import timer + + +class CudaSingleModule: + def __init__(self, cuda_module, kernel_ptr): + self.cuda_module = cuda_module + self.kernel_ptr = kernel_ptr + + +class CudaModules: + def __init__(self, modules, args): + # list of CudaSingleModule + self.modules = modules + # extra kernel ptr arguments for launch + self.args = args + + +class JitExecutor: + def __init__( + self, + dsl, + engine, + capi_func, + ir_module, + args_spec, + function_name, + cuda_modules: CudaModules = None, + jit_time_profiling=False, + ): + self.dsl = dsl + self.engine = engine + self.capi_func = capi_func + self.ir_module = ir_module + self.args_spec = args_spec + self.function_name = function_name + if args_spec is not None: + self.original_args_spec = args_spec + self.args_spec = self.filter_runtime_arg_spec(args_spec) + # cuda kernels + self.cuda_modules = cuda_modules + self.jit_time_profiling = jit_time_profiling + + def filter_runtime_arg_spec(self, arg_spec: inspect.FullArgSpec): + runtime_args = [] + runtime_annotations = {} + runtime_defaults = [] + + # Calculate the offset where defaults start in the original args + if arg_spec.defaults: + defaults_start_idx = len(arg_spec.args) - len(arg_spec.defaults) + else: + defaults_start_idx = len(arg_spec.args) + + # Filter arguments and maintain their properties + for i, arg_name in enumerate(arg_spec.args): + arg_type = arg_spec.annotations.get(arg_name, None) + + # Skip compile-time arguments + if is_arg_spec_constexpr(arg_type, arg_name, i, self.function_name): + continue + + # Keep runtime arguments + runtime_args.append(arg_name) + if arg_name in arg_spec.annotations: + runtime_annotations[arg_name] = arg_type + + # Keep corresponding default if it exists + if i >= defaults_start_idx: + default_idx = i - defaults_start_idx + runtime_defaults.append(arg_spec.defaults[default_idx]) + + # Filter kwonlyargs and their defaults + runtime_kwonlyargs = [] + runtime_kwonlydefaults = {} + + if arg_spec.kwonlyargs: + for kwarg in arg_spec.kwonlyargs: + arg_type = arg_spec.annotations.get(kwarg, None) + + # Apply same filtering logic + if is_arg_spec_constexpr(arg_type, kwarg, i, self.function_name): + continue + + runtime_kwonlyargs.append(kwarg) + if kwarg in arg_spec.annotations: + runtime_annotations[kwarg] = arg_type + if arg_spec.kwonlydefaults and kwarg in arg_spec.kwonlydefaults: + runtime_kwonlydefaults[kwarg] = arg_spec.kwonlydefaults[kwarg] + + # Convert runtime_defaults to tuple if not empty (as expected by FullArgSpec) + runtime_defaults = tuple(runtime_defaults) if runtime_defaults else None + + return inspect.FullArgSpec( + args=runtime_args, + varargs=arg_spec.varargs, # Keep original varargs + varkw=arg_spec.varkw, # Keep original varkw + defaults=runtime_defaults, + kwonlyargs=runtime_kwonlyargs, + kwonlydefaults=runtime_kwonlydefaults if runtime_kwonlydefaults else None, + annotations=runtime_annotations, + ) + + def __del__(self): + if self.cuda_modules: + cuda_modules = [module.cuda_module for module in self.cuda_modules.modules] + for module in set(cuda_modules): + cuda_helpers.unload_cubin_module(module) + + def get_constexpr_args(self) -> list[dict[str, int | str]]: + """ + This function returns the constexpr args that have been pruned from the original function signature. + The return type is a list of dicts, each dict contains the argument index (argument_index) and argument name (argument_name). + + :return: list of dicts, each dict contains the argument index (argument_index) and argument name (argument_name). + :rtype: list[dict[str, int | str]] + """ + if self.original_args_spec is None: + return list() + constexpr_args = list() + for i, arg_name in enumerate(self.original_args_spec.args): + if arg_name not in self.args_spec.args: + constexpr_args.append({"argument_index": i, "argument_name": arg_name}) + + if self.original_args_spec.kwonlyargs: + for kwarg in self.original_args_spec.kwonlyargs: + if kwarg not in self.args_spec.kwonlyargs: + constexpr_args.append( + {"argument_index": None, "argument_name": kwarg} + ) + return constexpr_args + + def generate_execution_args(self, args, kwargs, args_spec: inspect.FullArgSpec): + """ + This function is the prune version of `generate_mlir_function_types` which only generates execution args + to get rid of mlir context. + """ + + # Process positional arguments with defaults + rectified_args = list(args) + if args_spec.defaults and len(args) < len(args_spec.args): + rectified_args.extend(args_spec.defaults[len(args) - len(args_spec.args) :]) + for k, v in kwargs.items(): + if k in args_spec.args: + idx = args_spec.args.index(k) + if idx < len(rectified_args): + rectified_args[idx] = v + else: + rectified_args.append(v) + + # Process keyword arguments + rectified_kwargs = {k: v for k, v in kwargs.items() if k not in args_spec.args} + if args_spec.kwonlydefaults and len(rectified_kwargs) < len( + args_spec.kwonlyargs + ): + rectified_kwargs.update(args_spec.kwonlydefaults) + + # args/kwargs must match arg_specs + if len(rectified_args) != len(args_spec.args) or len(rectified_kwargs) != len( + args_spec.kwonlyargs + ): + raise DSLRuntimeError( + "input args/kwargs length does not match runtime function signature!", + context={ + "input args length": len(rectified_args), + "input kwargs length": len(rectified_kwargs), + "function signature args length": len(args_spec.args), + "function signature kwonlyargs length": len(args_spec.kwonlyargs), + }, + ) + + exe_args = [] + adapted_args = [] + input_args = rectified_args + list(rectified_kwargs.values()) + input_arg_names = args_spec.args + args_spec.kwonlyargs + for arg, arg_name in zip(input_args, input_arg_names): + # short-cut for args already converted + if hasattr(arg, "__c_pointers__"): + exe_args.extend(arg.__c_pointers__()) + continue + + arg_type = args_spec.annotations.get(arg_name, None) + + # Implicit cast to NumericMeta + if isinstance(arg_type, t.NumericMeta): + arg = t.cast(arg, arg_type) + else: + # If not any known type, try registered adapter to do the conversion + adapter = JitArgAdapterRegistry.get_registered_adapter(type(arg)) + if adapter: + arg = adapter(arg) + adapted_args.append(arg) + + exe_args.extend(get_c_pointers(arg)) + + return exe_args, adapted_args + + def __call__(self, *args, **kwargs): + exe_args, adapted_args = self.generate_execution_args( + args, kwargs, self.args_spec + ) + + self.run_compiled_program(exe_args) + + # Assume each execution args has type `c_void_p` to reduce the overhead of `ctypes.cast`. + def get_invoke_packed_args(self, exe_args): + if self.cuda_modules: + exe_args += self.cuda_modules.args + packed_args = (ctypes.c_void_p * len(exe_args))() + for argNum in range(len(exe_args)): + packed_args[argNum] = exe_args[argNum] + return packed_args + + def run_compiled_program(self, exe_args): + if self.jit_time_profiling: + profiler = timer(enable=True) + try: + packed_args = profiler(self.get_invoke_packed_args)(exe_args) + profiler(self.capi_func)(packed_args) + except Exception as e: + raise DSLRuntimeError(f"💥💥💥 Runtime Crash 💥💥💥", cause=e) + else: + try: + packed_args = self.get_invoke_packed_args(exe_args) + self.capi_func(packed_args) + except Exception as e: + raise DSLRuntimeError(f"💥💥💥 Runtime Crash 💥💥💥", cause=e) + + def update_jit_cuda_modules(self, kernel_symbols): + # preload cuda module from compiled cubin in ir and store to jit_executor.kernels. + if len(kernel_symbols) > 0: + extra_args = [] + module = self.ir_module + cuda_kernel_cache = dict() + cuda_driver_version = cuda_helpers.get_driver_version() + for sym in kernel_symbols: + if sym not in cuda_kernel_cache: + log().debug(f"Loading CUDA module for symbol: {sym}") + + # load cuda module/get function pointer from module and cache + def walk_callback(sym, func_sym, cubin_data): + cubin_module = cuda_helpers.load_cubin_module_data(cubin_data) + kernel_ptr = cuda_helpers.get_kernel_function( + cubin_module, func_sym + ) + # Enable non-portable cluster size for CUDA version 11.8 or higher. + if cuda_driver_version >= 11080: + cuda_helpers.set_kernel_attribute( + kernel_ptr, + cuda_helpers.cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, + 1, + ) + cuda_kernel_cache[sym] = CudaSingleModule( + cubin_module, kernel_ptr + ) + + self.walk_module_and_get_cubin_data(module, sym, walk_callback) + else: + log().debug(f"Symbol {sym} already in cache") + # check if kernel is empty. + if sym in cuda_kernel_cache: + extra_args.append( + ctypes.c_void_p(cuda_kernel_cache[sym].kernel_ptr.getPtr()) + ) + # store to the jit result if jit result is cached. + self.cuda_modules = CudaModules(cuda_kernel_cache.values(), extra_args) + + return self + + def _get_escaped_cubin_bytes(self, cubin_data): + """This function escapes cubin data from mlir raw bytecode to executable binary bytes""" + + def ishex(inp): + return ( + inp in range(0x30, 0x3A) + or inp in range(0x61, 0x67) + or inp in range(0x41, 0x47) + ) + + converted = bytearray() + idx = 0 + while idx < len(cubin_data): + # escape the original bytes + if cubin_data[idx] == 0x5C: + # if data of idx is b'\\' + if ishex(cubin_data[idx + 1]) and ishex(cubin_data[idx + 2]): + converted += bytearray.fromhex( + cubin_data[idx + 1 : idx + 3].decode() + ) + idx += 3 + elif cubin_data[idx + 1] == 0x5C: + converted.append(cubin_data[idx]) + idx += 2 + else: + # no escape, directly write + converted.append(cubin_data[idx]) + idx += 1 + return bytes(converted) + + def walk_module_and_get_cubin_data(self, module, sym, callback): + """This function is used to walk gpu binary op, extract the cubin inside, and process cubin data with callback.""" + + def walk_gpu_binary_op(op): + if op.name != "gpu.binary": + return ir.WalkResult.ADVANCE + s = io.BytesIO() + op.write_bytecode(s) + cubin_data = s.getvalue() + if sym.encode() not in cubin_data: + return ir.WalkResult.ADVANCE + + if ( + "kernels" != op.opview.sym_name.value + and sym != op.opview.sym_name.value + ): + return ir.WalkResult.ADVANCE + # function symbol of kernel(gpu.launch_func) is equal to sym name in mlir + func_sym = sym + if sym == op.opview.sym_name.value and not sym.endswith("_kernel"): + func_sym = sym.rsplit("_", 1)[0] + + cubin_data = cubin_data.split(b'bin = "')[1].split(b'">')[0] + cubin_data = self._get_escaped_cubin_bytes(cubin_data) + callback(sym, func_sym, cubin_data) + return ir.WalkResult.ADVANCE + + module.operation.walk(walk_gpu_binary_op) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ccc475fdda59450f07c35ae244d6223446470c6d --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/__init__.py @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +""" +This module provides a runtime utility functions that are needed for +the DSL. +""" + +from . import dlpack_types +from . import cuda +from . import jit_arg_adapters + +__all__ = [ + "dlpack_types", + "cuda", + "jit_arg_adapters", +] diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/cuda.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/cuda.py new file mode 100644 index 0000000000000000000000000000000000000000..97ae778c0cd5ae19d20fac8e045e2021832f5bbc --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/cuda.py @@ -0,0 +1,476 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +""" +This module provides CUDA Python helper functions +""" + + +from functools import lru_cache +from dataclasses import dataclass +from typing import List, Optional +import numpy as np +import os +import ctypes + +import cuda.bindings.driver as cuda +import cuda.bindings.nvrtc as nvrtc + +# MLIR imports +from ..._mlir import ir +from ..._mlir.dialects import gpu + +# Local module imports +from ..utils.logger import log as _log +from ..common import * +from .jit_arg_adapters import JitArgAdapterRegistry + + +# ============================================================================= +# Utils +# ============================================================================= + + +def _cudaGetErrorEnum(error): + if isinstance(error, cuda.CUresult): + err, name = cuda.cuGetErrorName(error) + return name if err == cuda.CUresult.CUDA_SUCCESS else "" + elif isinstance(error, nvrtc.nvrtcResult): + return nvrtc.nvrtcGetErrorString(error)[1] + else: + raise DSLRuntimeError("Unknown error type: {}".format(error)) + + +def _get_gpu_arch_info(major, minor): + """Get GPU architecture information and compatibility details.""" + gpu_arch_map = { + (7, 0): ("Volta", "sm_70", ["sm_70"]), # V100 + (7, 5): ("Turing", "sm_75", ["sm_75"]), # RTX 20 Series, Quadro RTX + (8, 0): ("Ampere", "sm_80", ["sm_80"]), # A100 + (8, 6): ("Ampere", "sm_86", ["sm_86", "sm_80"]), # RTX 30 Series + (8, 9): ("Ada", "sm_89", ["sm_89", "sm_86"]), # RTX 40 Series + (8, 7): ("Ampere", "sm_87", ["sm_87", "sm_86", "sm_80"]), # A10, A40 + (9, 0): ("Hopper", "sm_90a", ["sm_90a"]), # H100 + (10, 0): ("Blackwell", "sm_100a", ["sm_100a"]), # B200 + } + return gpu_arch_map.get( + (major, minor), ("Unknown", f"sm_{major}{minor}", [f"sm_{major}{minor}"]) + ) + + +def get_compute_capability_major_minor(device_id: int = 0): + """ + Returns the compute capability of the CUDA device as a tuple of (major, minor). + For example: (8, 0) for Ampere, (9, 0) for Hopper, (10, 0) for Blackwell. + Returns None on failure. + """ + try: + checkCudaErrors(cuda.cuInit(0)) + device = checkCudaErrors(cuda.cuDeviceGet(device_id)) + major = checkCudaErrors( + cuda.cuDeviceGetAttribute( + cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, + device, + ) + ) + minor = checkCudaErrors( + cuda.cuDeviceGetAttribute( + cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, + device, + ) + ) + return major, minor + except RuntimeError as e: + _log().info(f"Failed to get CUDA compute capability: {e}") + return None, None + + +@dataclass +class DeviceInfo: + """Data class to store CUDA device information.""" + + device_count: int = 0 + current_device: int = 0 + device_name: Optional[str] = None + major_version: Optional[int] = None + minor_version: Optional[int] = None + arch_name: Optional[str] = None + sm_arch: Optional[str] = None + compatible_archs: Optional[List[str]] = None + memory_gb: Optional[float] = None + target_arch: Optional[str] = None + error_message: Optional[str] = None + initialization_failed: bool = False + + def pretty_str(self) -> str: + """ + Convert DeviceInfo to a formatted string for display. + """ + info = "" + + if self.initialization_failed: + return f"{Colors.BOLD}- CUDA initialization failed{Colors.RESET}" + + if self.error_message: + return f"{Colors.BOLD}- Failed to get GPU info: {self.error_message}{Colors.RESET}" + + if self.device_count > 0: + info += f"{Colors.BOLD}- CUDA devices available: {self.device_count} (current: {self.current_device})\n" + + if self.major_version is not None and self.minor_version is not None: + info += f"- Architecture: {Colors.BLUE}{self.arch_name}{Colors.RESET} ({Colors.GREEN}{self.sm_arch}{Colors.RESET})\n" + info += f"- Compatible SM archs: {Colors.GREEN}{', '.join(self.compatible_archs or [])}{Colors.RESET}\n" + + if self.memory_gb is not None: + info += f"- Total Memory: {Colors.BLUE}{self.memory_gb:.2f} GB{Colors.RESET}\n" + + else: + info += f"- Compute capability: unknown\n" + info += f"- SM arch: unknown{Colors.RESET}\n" + else: + info += f"- No devices available\n" + + return info + + +def get_device_info() -> DeviceInfo: + """ + Get detailed information about CUDA devices. + Returns a DeviceInfo dataclass with device information. + """ + device_info = DeviceInfo() + + # Initialize CUDA if not already initialized + try: + result = cuda.cuInit(0) + if result[0].value: # Check for error + device_info.initialization_failed = True + return device_info + except: + pass + + try: + # Get device count + result = cuda.cuDeviceGetCount() + device_info.device_count = result[1] if result[0].value == 0 else 0 + + if device_info.device_count > 0: + # Get current device + try: + result = cuda.cuCtxGetDevice() + if result[0].value == 0: + device_info.current_device = result[1] + except: + pass + + # Get device name + try: + name_result = cuda.cuDeviceGetName(100, device_info.current_device) + if name_result[0].value == 0: + device_info.device_name = name_result[1] + except: + pass + + # Get compute capability and architecture info + try: + major, minor = get_compute_capability_major_minor( + device_info.current_device + ) + + # Check if we successfully got the compute capability + if major is not None and minor is not None: + device_info.major_version = major + device_info.minor_version = minor + + arch_name, sm_arch, compatible_archs = _get_gpu_arch_info( + device_info.major_version, device_info.minor_version + ) + + device_info.arch_name = arch_name + device_info.sm_arch = sm_arch + device_info.compatible_archs = compatible_archs + + # Get memory info + try: + total_mem = cuda.cuDeviceGetAttribute( + cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_TOTAL_MEMORY, + device_info.current_device, + ) + if total_mem[0].value == 0: + device_info.memory_gb = total_mem[1] / ( + 1024 * 1024 * 1024 + ) # Convert to GB + except: + pass + + except Exception as e: + pass # Compute capability info will remain None + + except Exception as e: + device_info.error_message = str(e) + + return device_info + + +def checkCudaErrors(result): + """Check CUDA errors and provide detailed error messages.""" + if result[0].value: + error_code = result[0].value + error_name = _cudaGetErrorEnum(result[0]) + + raise DSLCudaRuntimeError(error_code, error_name) + + if len(result) == 1: + return None + elif len(result) == 2: + return result[1] + else: + return result[1:] + + +# ============================================================================= +# Driver Helpers +# ============================================================================= + + +@lru_cache(maxsize=1) +def initialize_cuda_context(device_id: int = 0, flags: int = 0): + """ + Initializes the CUDA context for a specified device. + """ + # Initialize CUDA Driver API + _log().info(f"cuInit {flags}") + checkCudaErrors(cuda.cuInit(flags)) + # Retrieve handle for device + _log().info(f"cuDeviceGet {device_id}") + cuDevice = checkCudaErrors(cuda.cuDeviceGet(device_id)) + _log().info(f"{cuDevice} <-- cuDeviceGet") + # Create context + _log().info(f"cuCtxCreate {0} {cuDevice}") + if cuda.CUDA_VERSION >= 13000: + # Use cuCtxCreate_v4 API with explicit CUctxCreateParams None, since v2 + # and v3 API has been removed from CTK 13. + # See https://github.com/NVIDIA/cuda-python/pull/792 + context = checkCudaErrors(cuda.cuCtxCreate(None, 0, cuDevice)) + else: + context = checkCudaErrors(cuda.cuCtxCreate(0, cuDevice)) + _log().info(f"{context} <-- cuCtxCreate") + + return context + + +def load_cubin_module(cubin_file): + """ + Loads a CUBIN file and returns the module. + """ + # Load CUBIN file as binary data + _log().info(f"read cubin {cubin_file}") + with open(cubin_file, "rb") as f: + cubin_data = f.read() + # Load module data + _log().info(f"cuModuleLoadData {np.char.array(cubin_data).ctypes.data}") + module = checkCudaErrors( + cuda.cuModuleLoadData(np.char.array(cubin_data).ctypes.data) + ) + return module + + +def unload_cubin_module(module): + """ + Unloads a CUBIN module. + """ + _log().info(f"cuModuleUnload {module}") + checkCudaErrors(cuda.cuModuleUnload(module)) + + +def load_cubin_module_data(cubin_data): + """ + Loads a CUBIN from data and returns the module. + """ + # Load module data + _log().info(f"cuModuleLoadData {np.char.array(cubin_data).ctypes.data}") + module = checkCudaErrors( + cuda.cuModuleLoadData(np.char.array(cubin_data).ctypes.data) + ) + return module + + +def get_kernel_function(module, kernel_name): + """ + Retrieves the kernel function from the module. + """ + _log().info(f"cuModuleGetFunction {module} {kernel_name}") + kernel = checkCudaErrors( + cuda.cuModuleGetFunction(module, bytes(kernel_name, "utf-8")) + ) + _log().info(f"{kernel} <-- cuModuleGetFunction") + return kernel + + +def launch_kernel(kernel, grid_dims, block_dims, stream, smem_size, kernel_args=None): + """ + Launches the CUDA kernel. + """ + _log().info( + f"cuLaunchKernel {kernel} grid={grid_dims} blocks={block_dims} smem_size={smem_size} stream={stream} {kernel_args}" + ) + checkCudaErrors( + cuda.cuLaunchKernel( + kernel, + grid_dims[0], + grid_dims[1], + grid_dims[2], + block_dims[0], + block_dims[1], + block_dims[2], + smem_size, # Shared memory size + stream, + kernel_args, + 0, # Extra parameters + ) + ) + + +def stream_sync(stream): + """ + Synchronizes the CUDA stream. + """ + _log().info(f"cuStreamSynchronize {stream}") + checkCudaErrors(cuda.cuStreamSynchronize(stream)) + + +def stream_create(id=0): + """ + Creates the CUDA stream. + """ + _log().info(f"cuStreamCreate {id}") + stream = checkCudaErrors(cuda.cuStreamCreate(id)) + _log().info(f"{stream} <-- cuStreamCreate") + return stream + + +def stream_destroy(stream): + """ + Destroys the CUDA stream. + """ + _log().info(f"cuStreamDestroy {stream}") + checkCudaErrors(cuda.cuStreamDestroy(stream)) + + +def context_destroy(context): + """ + Destroys the CUDA context. + """ + _log().info(f"cuCtxDestroy {context}") + checkCudaErrors(cuda.cuCtxDestroy(context)) + + +def allocate(size_in_bytes: int, stream=None): + """ + Allocate device memory based on numpy host array size. + """ + _log().info("Allocate size_in_bytes=[%s] stream=[%s]", size_in_bytes, stream) + if stream is None: + device_memory = checkCudaErrors(cuda.cuMemAlloc(size_in_bytes)) + else: + device_memory = checkCudaErrors(cuda.cuMemAllocAsync(size_in_bytes, stream)) + _log().info("Allocated [%s]", device_memory) + return device_memory + + +def deallocate(device_pointer, stream=None): + """ + Deallocate the specified device memory pointer. + """ + _log().info( + "Deallocate device_pointer=[%s] stream=[%s]", hex(int(device_pointer)), stream + ) + if stream is None: + checkCudaErrors(cuda.cuMemFree(device_pointer)) + else: + checkCudaErrors(cuda.cuMemFreeAsync(device_pointer, stream)) + + +def memcpy_h2d(host_pointer, device_pointer, size_in_bytes, stream=None): + """ + Copy data from host to device memory. + """ + _log().info( + "Copy host-to-device host_pointer[%s] device_ptr=[%s] size_in_bytes=[%s] stream=[%s]", + hex(host_pointer), + hex(int(device_pointer)), + size_in_bytes, + stream, + ) + if stream is None: + checkCudaErrors(cuda.cuMemcpyHtoD(device_pointer, host_pointer, size_in_bytes)) + else: + checkCudaErrors( + cuda.cuMemcpyHtoDAsync(device_pointer, host_pointer, size_in_bytes, stream) + ) + + +def memcpy_d2h(host_pointer, device_pointer, size_in_bytes, stream=None): + """ + Copy data from device to host memory. + """ + _log().info( + "Copy device-host-to device_pointer=[%s] host_pointer[%s] size_in_bytes=[%s] stream=[%s]", + hex(int(device_pointer)), + hex(host_pointer), + size_in_bytes, + stream, + ) + if stream is None: + checkCudaErrors(cuda.cuMemcpyDtoH(host_pointer, device_pointer, size_in_bytes)) + else: + checkCudaErrors( + cuda.cuMemcpyDtoHAsync(host_pointer, device_pointer, size_in_bytes, stream) + ) + + +def default_stream(): + return cuda.CUstream(0) + + +def get_driver_version(): + """ + Returns the CUDA driver version. + """ + return checkCudaErrors(cuda.cuDriverGetVersion()) + + +def set_kernel_attribute(kernel, attribute, value): + """ + Sets a CUDA kernel attribute. + """ + return checkCudaErrors(cuda.cuFuncSetAttribute(kernel, attribute, value)) + + +@JitArgAdapterRegistry.register_jit_arg_adapter(cuda.CUstream) +class StreamAdapter: + """ + Convert a CUDA stream to a stream representation for JIT arg generation. + """ + + def __init__(self, arg): + self._arg = arg + self._c_pointer = self._arg.getPtr() + + def __new_from_mlir_values__(self, values): + assert len(values) == 1 + return values[0] + + def __c_pointers__(self): + return [self._c_pointer] + + def __get_mlir_types__(self): + return [gpu.AsyncTokenType.get()] diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/device_tensor.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/device_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..5addb275b12f2b18e109b0592a87f3044d2fe595 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/device_tensor.py @@ -0,0 +1,121 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +import copy + +from . import cuda as cuda_helpers +from .tensor_descriptor import * +from ..common import * + + +def allocate(tensor: TensorDescriptor, stream=None): + """ + Allocates GPU memory + """ + if tensor._check_is_managed_by_framework(): + raise DSLRuntimeError( + "GPU tensors are managed by the framework and cannot be modified." + ) + if not tensor.device_pointer is None: + raise DSLRuntimeError("Tensor is already allocated on the device.") + + tensor.device_pointer = cuda_helpers.allocate(tensor.size_in_bytes, stream) + + log().info("Allocate done tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer) + + +def deallocate(tensor: TensorDescriptor, stream=None): + """ + Deallocates GPU memory + """ + if tensor._check_is_managed_by_framework(): + raise DSLRuntimeError( + "GPU tensors are managed by the framework and cannot be modified." + ) + if tensor.device_pointer is None: + raise DSLRuntimeError("Tensor is not allocated on the device.") + + log().info( + "Deallocating done tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer + ) + + cuda_helpers.deallocate(tensor.device_pointer, stream) + tensor.device_pointer = None + + +def copy_to_gpu(tensor: TensorDescriptor, do_allocate=True, stream=None): + """ + Copies data from host memory to the GPU memory. + If do_allocate is True, it first calls allocate + """ + log().info("copyin tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer) + if do_allocate: + allocate(tensor, stream) + cuda_helpers.memcpy_h2d( + tensor.data_ptr, tensor.device_pointer, tensor.size_in_bytes, stream + ) + log().info("copyin done tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer) + return tensor + + +def copy_from_gpu(tensor: TensorDescriptor, do_deallocate=True, stream=None): + """ + Copies data from GPU memory back to the host. + If do_deallocate is True, it calls deallocate + """ + log().info("copyout tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer) + if tensor._check_is_managed_by_framework(): + raise DSLRuntimeError( + "GPU tensors are managed by the framework and cannot be modified." + ) + if tensor.device_pointer is None: + raise DSLRuntimeError("Tensor is not allocated on the device.") + + cuda_helpers.memcpy_d2h( + tensor.data_ptr, tensor.device_pointer, tensor.size_in_bytes, stream + ) + if do_deallocate: + deallocate(tensor, stream) + log().info("copyout done tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer) + + +def to_gpu(tensor, stream=None) -> TensorDescriptor: + """ + Copies the tensor to the GPU memory from Host memory + """ + if isinstance(tensor, TensorDescriptor): + new_tensor = copy.copy(tensor) + copy_to_gpu(new_tensor, stream=stream) + return new_tensor + + if TensorDescriptor.can_transformed_to_dlpack(tensor): + new_tensor = TensorDescriptor(tensor) + copy_to_gpu(new_tensor, stream=stream) + return new_tensor + + raise DSLRuntimeError("Unsupported type") + + +def from_gpu(tensor, stream=None) -> TensorDescriptor: + """ + Copies the tensor to the GPU memory from Host memory + """ + if isinstance(tensor, TensorDescriptor): + new_tensor = copy.copy(tensor) + copy_from_gpu(new_tensor, stream=stream) + return new_tensor + + if TensorDescriptor.can_transformed_to_dlpack(tensor): + new_tensor = TensorDescriptor(tensor) + copy_from_gpu(new_tensor, stream=stream) + return new_tensor + + raise DSLRuntimeError("Unsupported type") diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/dlpack_types.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/dlpack_types.py new file mode 100644 index 0000000000000000000000000000000000000000..168c2a9953f74b45cadfcbb6562f89d1bb35cd6d --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/dlpack_types.py @@ -0,0 +1,76 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +""" +This module provides helper structs for dlpack. +DLPack is an open standard for in-memory tensor structures, enabling +seamless sharing of tensors across different frameworks. +Learn more at: https://github.com/dmlc/dlpack +""" + +import ctypes +import enum + + +class DLDeviceType(enum.IntEnum): + """Enums for device types based on the DLPack specification.""" + + kDLCPU = 1 + kDLGPU = 2 + kDLCPUPinned = 3 + + +class DLDataTypeCode: + """Enums for data type codes based on the DLPack specification. + + see https://github.com/dmlc/dlpack/blob/main/include/dlpack/dlpack.h + """ + + kDLInt = 0 + kDLUInt = 1 + kDLFloat = 2 + kDLOpaqueHandle = 3 + kDLBfloat = 4 + kDLComplex = 5 + kDLBool = 6 + + +class DLDevice(ctypes.Structure): + """Structure representing the device information in DLPack.""" + + _fields_ = [ + ("device_type", ctypes.c_int), # kDLCPU, kDLGPU, etc. + ("device_id", ctypes.c_int), # Device ID (e.g., GPU ID) + ] + + +class DLDataType(ctypes.Structure): + """Structure representing the data type in DLPack.""" + + _fields_ = [ + ("code", ctypes.c_uint8), # Data type code (e.g., kDLFloat) + ("bits", ctypes.c_uint8), # Number of bits per value + ("lanes", ctypes.c_uint16), # Number of lanes + ] + + +class DLTensor(ctypes.Structure): + """Structure representing the DLTensor in DLPack.""" + + _fields_ = [ + ("data", ctypes.c_void_p), # Pointer to tensor data + ("device", DLDevice), # Device info + ("ndim", ctypes.c_int), # Number of dimensions + ("dtype", DLDataType), # Data type + ("shape", ctypes.POINTER(ctypes.c_int64)), # Shape of tensor + ("strides", ctypes.POINTER(ctypes.c_int64)), # Strides of tensor + ("byte_offset", ctypes.c_uint64), # Byte offset to tensor data + ] diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/jit_arg_adapters.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/jit_arg_adapters.py new file mode 100644 index 0000000000000000000000000000000000000000..eb998d16d8fb4bcf592f17ce0f23a81d6e11bff6 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/jit_arg_adapters.py @@ -0,0 +1,188 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +""" +This module provides runtime utilities for JIT argument conversion in DSL. +""" + +from functools import wraps +from typing import get_origin + +# Local modules imports +from ..common import DSLRuntimeError +from ..typing import ( + Constexpr, + Int32, + Float32, + Boolean, +) + + +def is_arg_spec_constexpr(arg_spec, arg_name, arg_index, owning_func): + """ + Check if the argument spec is a constexpr. + """ + + def _is_reserved_python_func_arg(arg_index, arg_name, func): + """ + Check if the argument is a reserved python function argument. + """ + + if arg_index != 0: + return False + + if arg_name == "self": + return True + + is_classmethod = isinstance(func, classmethod) or ( + hasattr(func, "__func__") and isinstance(func.__func__, classmethod) + ) + return arg_name == "cls" and is_classmethod + + return ( + _is_reserved_python_func_arg(arg_index, arg_name, owning_func) + or (isinstance(arg_spec, type) and issubclass(arg_spec, Constexpr)) + or (get_origin(arg_spec) is Constexpr) + ) + + +def is_argument_constexpr(arg, arg_spec, arg_name, arg_index, owning_func): + """ + Check if the argument is a constexpr. + """ + + def _is_type_argument(arg, arg_annotation): + """ + Check if the argument is a type argument like Type[X] + """ + + return isinstance(arg, type) and ( + arg_annotation is None or get_origin(arg_annotation) is type + ) + + return ( + is_arg_spec_constexpr(arg_spec, arg_name, arg_index, owning_func) + or _is_type_argument(arg, arg_spec) + or arg is None + ) + + +class JitArgAdapterRegistry: + """ + A registry to keep track of the JIT argument adapters. + + An adapter is a callable that converts a Python type to a type with following protocols supported: + - JitArgument + - DynamicExpression + The converted type can then be further processed by DSL to generate arguments for JIT functions. + """ + + # A dictionary with key=type and value=callable + jit_arg_adapter_registry = {} + + @classmethod + def register_jit_arg_adapter(cls, *dargs, **dkwargs): + """ + Register a JIT argument adapter callable + + This can be used as a decorator on any callable like: + + @register_jit_arg_adapter(my_py_type) + def my_adapter_for_my_py_type(arg): + ... + + @register_jit_arg_adapter(my_py_type) + class MyAdapterForMyPythonType: + ... + + The adapters are registered per type. If a type is already registerd, an error will be raised. + """ + + def decorator(*dargs, **dkwargs): + darg_python_ty = dargs[0] + + @wraps(darg_python_ty) + def wrapper(*args, **kwargs): + if len(args) != 1 or not callable(args[0]): + raise DSLRuntimeError( + "a callable must be provided for registering JIT argument adapter" + ) + adapter = args[0] + + if darg_python_ty in cls.jit_arg_adapter_registry: + raise DSLRuntimeError( + f"JIT argument adapter for {darg_python_ty} is already registered!", + context={ + "Registered adapter": cls.jit_arg_adapter_registry[ + darg_python_ty + ], + "Adapter to be registered": adapter, + }, + ) + cls.jit_arg_adapter_registry[darg_python_ty] = adapter + return adapter + + return wrapper + + if len(dargs) > 0: + return decorator(*dargs, **dkwargs) + else: + raise DSLRuntimeError( + "a Python type must be provided for registering JIT argument adapter" + ) + + @classmethod + def get_registered_adapter(cls, ty): + """ + Get the registered JIT argument adapter for the given type. + """ + return cls.jit_arg_adapter_registry.get(ty, None) + + +# ============================================================================= +# JIT Argument Adapters +# ============================================================================= + + +@JitArgAdapterRegistry.register_jit_arg_adapter(int) +@JitArgAdapterRegistry.register_jit_arg_adapter(float) +@JitArgAdapterRegistry.register_jit_arg_adapter(bool) +def _convert_python_scalar(arg): + """ + Convert a Python scalar to a DSL type. + """ + conversion_map = { + int: Int32, + float: Float32, + bool: Boolean, + } + return conversion_map.get(type(arg))(arg) + + +@JitArgAdapterRegistry.register_jit_arg_adapter(tuple) +@JitArgAdapterRegistry.register_jit_arg_adapter(list) +def _convert_python_sequence(arg): + """ + Go through each element in the sequence and convert it to a type that can be + further processed by DSL to generate the corresponding JIT argument(s). + """ + adapted_arg = [] + for elem in arg: + adapter = JitArgAdapterRegistry.get_registered_adapter(type(elem)) + if adapter is not None: + converted_elem = adapter(elem) + adapted_arg.append(converted_elem) + else: + # If no registered adapter is found, just return the original element + adapted_arg.append(elem) + + assert len(adapted_arg) == len(arg) + return type(arg)(adapted_arg) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/tensor_descriptor.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/tensor_descriptor.py new file mode 100644 index 0000000000000000000000000000000000000000..1a992ef68293d6f969ab551b6321c3696c961037 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/tensor_descriptor.py @@ -0,0 +1,201 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +# Helpers +import itertools, operator +import ctypes +from . import dlpack_types as _dpack +from .dlpack_runtime import ( + dlpack_to_tensor_desc, + get_tensor_desc_data_ptr, + get_tensor_desc_is_in_device, + get_tensor_desc_element_type, + get_tensor_desc_shape, + get_tensor_desc_stride, + get_tensor_desc_element_size_in_bytes, + get_tensor_desc_ndim, + get_tensor_desc_dtype_code, + get_tensor_desc_dtype_bits, + get_tensor_desc_device_type, + get_tensor_desc_device_id, +) + +from ..utils.logger import log +from ..common import * +from ..typing import ( + Boolean, + Float8E5M2, + Int64, + Int32, + Int16, + Int8, + Uint64, + Uint32, + Uint16, + Uint8, + Float64, + Float32, + Float16, + BFloat16, +) + + +class TensorDescriptor: + def __init__(self, tensor): + """Initialize with a tensor that supports the DLPack protocol. + + Args: + tensor: Any tensor object that implements __dlpack__ and __dlpack_device__ + """ + + self.tensor = tensor + self._capsule = dlpack_to_tensor_desc(tensor) + + self.data_ptr = get_tensor_desc_data_ptr(self._capsule) + self.device_type = get_tensor_desc_device_type(self._capsule) + self.device_type = _dpack.DLDeviceType(self.device_type) + + if self.device_type == _dpack.DLDeviceType.kDLGPU: + self.device_pointer = self.data_ptr + elif self.device_type == _dpack.DLDeviceType.kDLCPU: + self.device_pointer = None + else: + raise DSLRuntimeError( + f"DLPack device type is not supported {self.dl_tensor.device.device_type}" + ) + + log().info("TensorDescriptor is created = [%s]", self) + + @staticmethod + def can_transformed_to_dlpack(dl_tensor): + if not hasattr(dl_tensor, "__dlpack__") or not hasattr( + dl_tensor, "__dlpack_device__" + ): + return False + return True + + @property + def is_in_device(self): + """Check if the tensor is stored on a device.""" + return not self.device_pointer is None + + @property + def device_id(self): + """Return device id where tensor resides.""" + if self.is_in_device: + return get_tensor_desc_device_id(self._capsule) + return -1 + + @property + def element_type(self): + """Return the corresponding Python type based on DLPack dtype metadata.""" + str_element_type = get_tensor_desc_element_type(self._capsule) + dtype_map = { + # bool is 8bit from numpy and torch + "Bool": Boolean, + "Int64": Int64, + "Int32": Int32, + "Int16": Int16, + "Int8": Int8, + "UInt64": Uint64, + "UInt32": Uint32, + "UInt16": Uint16, + "UInt8": Uint8, + "Float64": Float64, + "Float32": Float32, + "Float16": Float16, + "BFloat16": BFloat16, + "Float8E5M2": Float8E5M2, + } + + if str_element_type not in dtype_map: + raise KeyError( + f"Unsupported element type in dlpack: '{str_element_type}'. Supported types are: {list(dtype_map.keys())}" + ) + + return dtype_map[str_element_type] + + @property + def shape(self): + """Return the shape of the tensor.""" + return get_tensor_desc_shape(self._capsule) + + @property + def rank(self): + """Return the rank of the tensor.""" + return get_tensor_desc_ndim(self._capsule) + + @property + def strides(self): + """Return the rank of the tensor.""" + return get_tensor_desc_stride(self._capsule) + + @property + def element_size_in_bytes(self): + """Calculate the element size in bytes of the DLPack tensor.""" + return get_tensor_desc_element_size_in_bytes(self._capsule) + + @property + def size_in_bytes(self): + """Calculate the total size in bytes of the DLPack tensor.""" + # Calculate the number of elements using the shape + ndim = get_tensor_desc_ndim(self._capsule) + shape = get_tensor_desc_shape(self._capsule) + num_elements = 1 + for i in range(ndim): + num_elements *= shape[i] + + # Total bytes + total_bytes = self.element_size_in_bytes * num_elements + return total_bytes + + def __str__(self): + """Return a compact string representation of the device_tensor with a tensor prefix.""" + # Extract shape + shape = "x".join(map(str, self.shape)) + + # Extract dtype + dtype_code = get_tensor_desc_dtype_code(self._capsule) + dtype_bits = get_tensor_desc_dtype_bits(self._capsule) + dtype = ( + f"i{dtype_bits}" + if dtype_code == _dpack.DLDataTypeCode.kDLInt + else f"f{dtype_bits}" + ) + + # Extract device + device_type = "cpu" if not self.is_in_device else "gpu" + + return f"tensor<{shape}x{dtype}>_{device_type}" + + def _check_is_managed_by_framework(self): + """ + Ensure the tensor is not managed by the framework (e.g., GPU tensor). + Raises an exception if the tensor is framework-managed. + """ + return self.device_type == _dpack.DLDeviceType.kDLGPU + + @staticmethod + def is_compatible(maybe_tensor_descriptor) -> bool: + """Check if the object is a TensorDescriptor or can be converted to one.""" + return isinstance( + maybe_tensor_descriptor, TensorDescriptor + ) or TensorDescriptor.can_transformed_to_dlpack(maybe_tensor_descriptor) + + +def from_tensor(tensor) -> TensorDescriptor: + """Create a TensorDescriptor from a tensor object.""" + return TensorDescriptor(tensor) + + +def to_tensor(tensor_descriptor: TensorDescriptor): + """Return tensor object from tensor descriptor.""" + return tensor_descriptor.tensor diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/typing.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/typing.py new file mode 100644 index 0000000000000000000000000000000000000000..b46cff6de8176217f38af05b8604716c34aae009 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/typing.py @@ -0,0 +1,1962 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +import ctypes +import numpy as np +import operator +from typing_extensions import deprecated +from functools import reduce +from typing import ( + Generic, + Protocol, + Union, + Any, + List, + Type, + TypeVar, + overload, + runtime_checkable, + get_origin, +) +from types import FunctionType +from dataclasses import dataclass +from abc import ABC, abstractmethod + +from .common import * +from .ast_helpers import const_expr +from ._mlir_helpers import arith as arith_helper, lru_cache_ir +from ._mlir_helpers.arith import ArithValue + +from .._mlir import ir +from .._mlir.extras import types as T +from .._mlir.dialects import arith, math + +# ============================================================================= +# Dynamic Expression Protocol +# ============================================================================= + + +@runtime_checkable +class DynamicExpression(Protocol): + """Protocol defining the interface for object holding dynamic values in the DSL. + + This protocol enables classes to represent dynamic values in the DSL. Classes implementing + this protocol can be used in JIT-compiled functions and dynamic value generation. + + It is required for custom data types to work correctly with following JIT features: + * as function argument to call another JIT function from JIT function + * as return value from JIT function + * for constructions like if-else, while-loop, etc. + + :param value: The MLIR operation result value to initialize the object with + :type value: ir.Value + + **Required Methods** + + * ``__extract_mlir_values__``: Extract MLIR values from the object + * ``__new_from_mlir_values__``: Create new instance from MLIR values + + **Implementation Example** + + To implement a custom data type that works with the DSL: + + .. code-block:: python + + class CustomData(metaclass=DslType): + def __init__(self, int_value): + self.int_value = int_value + + def __extract_mlir_values__(self): + return [self.int_value] + + def __new_from_mlir_values__(self, values): + return CustomData(values[0]) + + **Usage in JIT Functions** + + When used in JIT-compiled functions, the DSL automatically extracts MLIR values: + + .. code-block:: python + + @jit + def caller(): + x = CustomData(1) + return foo(x) + + This generates MLIR like: + + .. code-block:: mlir + + func @caller() -> i32 { + %0 = func.call @foo(%arg0) : (i32) -> i32 + return %0 : i32 + } + """ + + def __extract_mlir_values__(self): + """Extract MLIR values from this object. + + :return: List of MLIR values representing this object's data + :rtype: List[ir.Value] + """ + raise NotImplementedError + + def __new_from_mlir_values__(self, values): + """Create a new instance from MLIR values. + + :param values: List of MLIR values to construct the object from + :type values: List[ir.Value] + :return: New instance of the implementing class + :rtype: Any + """ + raise NotImplementedError + + +@runtime_checkable +class JitArgument(Protocol): + """ + Protocol class defining the interface for JIT function argument generation. + + This protocol enables classes to provide the necessary information for generating + JIT function arguments and allow the DSL JIT executor to call JIT compiled functions. + + **Required Methods** + + * ``__c_pointers__``: Returns ctypes pointers for runtime execution + * ``__get_mlir_types__``: Returns MLIR types for function definition + * ``__new_from_mlir_values__``: Creates new instances from MLIR values + + **Example** + + .. code-block:: python + + class CustomData: + def __init__(self, int_value, ...): + self.int_value = int_value + ... + + def __c_pointers__(self): + return [ctypes.pointer(ctypes.c_int32(self.int_value)), ...] + + def __get_mlir_types__(self): + return [ir.IntegerType.get(32), ...] + + def __new_from_mlir_values__(self, values): + return CustomData(values[0], ...) + + @jit + def foo(x: CustomData): + a = x.int_value + 1 + ... + + # `CustomData` is an argument of `foo` + foo(CustomData(1, ...)) + + When called like ``y = foo(x)``, the following steps occur: + + 1. JIT compiler generates MLIR function definition using ``__get_mlir_types__`` + + .. code-block:: mlir + + func.func @foo(%arg0: i32, ...) { + ... + + return + } + + 2. JIT function can't use values from Python, so it needs to reconstruct the object from + MLIR values, a.k.a `%arg0`, with ``__new_from_mlir_values__`` and pass it to `foo`. + + Following code demonstrates how JIT compiler reconstructs the object and pass to Python. + + .. code-block:: python + + # Implementation of IR tracing + new_x = CustomData(ir.Value(%arg0), ...) + y = foo(new_x) + # `x.int_value` is %arg0 rather than `c1` defined by Python. + + 3. For Python runtime execution, JIT engine invokes compiled function using ``__c_pointers__`` + pointing to the underlying data object passing to JIT compiled function. + + .. code-block:: python + + jit_engine.invoke(compiled_foo, concat([x.__c_pointers__(), ...])) + """ + + def __c_pointers__(self): + """ + Generate a list of ctypes pointers for the current object. + + :return: List of ctypes pointers + :rtype: List[ctypes.c_void_p] + """ + raise NotImplementedError + + def __get_mlir_types__(self): + """ + Generate a list of MLIR types for the current object. + + :return: List of MLIR types + :rtype: List[ir.Type] + """ + raise NotImplementedError + + def __new_from_mlir_values__(self, values): + """ + Create a new object from MLIR values. + + :param values: List of MLIR values + :type values: List[ir.Value] + :return: A new object that represents the given MLIR values + :rtype: Any + """ + raise NotImplementedError + + +def get_c_pointers(obj): + """ + Given the `obj`, recursively go through it to extract all contained C pointers + """ + if hasattr(obj, "__c_pointers__"): + return obj.__c_pointers__() + elif isinstance(obj, (tuple, list)): + return sum((get_c_pointers(x) for x in obj), []) + elif isinstance(obj, set): + raise DSLRuntimeError( + "Sets are not supported in get_c_pointers to ensure order preservation", + context="The DSL attempted to generate JIT function argument(s) for an argument of type set but failed.", + suggestion="Consider using a list or tuple instead", + ) + return [] + + +def get_mlir_types(obj): + """ + Given the `obj`, recursively go through it to extract all contained MLIR types + """ + if hasattr(obj, "__get_mlir_types__"): + return obj.__get_mlir_types__() + elif hasattr(obj, "__extract_mlir_values__"): + return [v.type for v in obj.__extract_mlir_values__()] + elif isinstance(obj, ir.Value): + return [obj.type] + elif isinstance(obj, (tuple, list)): + return sum((get_mlir_types(x) for x in obj), []) + elif isinstance(obj, set): + raise DSLRuntimeError( + "Sets are not supported in get_mlir_types to ensure order preservation", + context="The DSL attempted to generate JIT function argument(s) for an argument of type set but failed.", + suggestion="Consider using a list or tuple instead", + ) + return [] + + +class DslType(type): + """Metaclass for all DSL types in the system. + + This metaclass provides type system infrastructure for DSL types, handling MLIR + type mappings and NumPy type conversions. + + All data types in DSL must provide the following methods: + + :param mlir_type: Corresponding MLIR type for this DSL type + :type mlir_type: Any, optional + :param is_abstract: Whether this type is abstract, defaults to False + :type is_abstract: bool, optional + + **Required Methods** + + * ``__str__`` (classmethod): Return string representation of the type + * ``__c_pointers__`` (optional): Return list of ctypes pointers of data used to invoke JIT function + * ``__get_mlir_types__``: Return list of MLIR types of the MLIR values contained in the instance + * ``__extract_mlir_values__``: Return list of MLIR values contained in the instance + * ``__new_from_mlir_values__``: Return a new instance from list of MLIR values + + **Attributes** + + :ivar _ir: MLIR provider + :vartype _ir: Any + :ivar _T: MLIR Type system provider + :vartype _T: Any + + **Properties** + + :property mlir_type: Returns the corresponding MLIR type for this DSL type + :type mlir_type: Any + + """ + + _is_abstract: bool + + def __new__(cls, name, bases, attrs, is_abstract=False, **kwargs): + new_cls = super().__new__(cls, name, bases, attrs) + + new_cls._is_abstract = is_abstract + + return new_cls + + @property + def is_abstract(cls): + return cls._is_abstract + + +class NumericMeta(DslType): + """Metaclass for numeric types providing width and numpy dtype information. + + :param width: Bit width of the numeric type, defaults to 8 + :type width: int + :param np_dtype: Corresponding NumPy dtype + :type np_dtype: numpy.dtype, optional + :param mlir_type: Corresponding MLIR type + :type mlir_type: Any, optional + :param is_abstract: Whether the type is abstract, defaults to False + :type is_abstract: bool, optional + + :ivar width: Bit width of the numeric type + :type width: int + :ivar _np_dtype: Corresponding NumPy dtype + :type _np_dtype: Union[numpy.dtype, None] + + :property numpy_dtype: Returns the corresponding NumPy dtype + :rtype numpy_dtype: numpy.dtype + """ + + width: int + + # Placeholder type + _mlir_type = Any + _np_dtype: Union[np.dtype, None] + + def __new__( + cls, + name, + bases, + attrs, + width=8, + np_dtype=None, + mlir_type=None, + is_abstract=False, + **kwargs, + ): + def _extract_mlir_values(self): + return [self.ir_value()] + + def _new_from_mlir_values(self, values: list) -> "Numeric": + res_ty = type(self) + return res_ty(values[0]) + + new_attrs = { + "__extract_mlir_values__": _extract_mlir_values, + "__new_from_mlir_values__": _new_from_mlir_values, + } + new_cls = super().__new__( + cls, + name, + bases, + new_attrs | attrs, + is_abstract=is_abstract, + **kwargs, + ) + + if mlir_type is not None: + new_cls._mlir_type = staticmethod(mlir_type) + + new_cls.width = width + new_cls._np_dtype = np_dtype + return new_cls + + @property + def numpy_dtype(cls): + return cls._np_dtype + + @property + def is_integer(cls) -> bool: ... + + @property + def is_float(cls) -> bool: ... + + def is_same_kind(cls, other: Type) -> bool: + return cls.is_integer == other.is_integer or cls.is_float == other.is_float + + @staticmethod + def from_python(value: Any) -> Type["Numeric"]: + """ + Deduce the DSL type from a Python value. + """ + if isinstance(value, int): + return Int32 + elif isinstance(value, float): + return Float32 + elif isinstance(value, bool): + return Boolean + raise DSLRuntimeError( + f"Could not deduce Type[Numeric] from python value: {value} :{type(value)}" + ) + + @property + def mlir_type(cls): + return cls._mlir_type() # type: ignore + + +Value = TypeVar("Value") + + +def cast(obj: Union[bool, int, float, Value], type_: Type["Numeric"]) -> "Numeric": + """Cast an object to the specified numeric type. + + :param obj: Object to be cast + :type obj: Union[bool, int, float, Value] + :param type_: Target numeric type + :type type_: Type[Numeric] + :raises TypeError: If casting to an abstract type or unsupported type conversion + :return: Object cast to the target numeric type + :rtype: Numeric + + Example:: + >>> x = cast(5, Int32) # Cast integer to Int32 + >>> y = cast(3.14, Float32) # Cast float to Float32 + """ + if type_.is_abstract: + if not isinstance(obj, type_): + raise TypeError( + f"can't cast {obj} to {type_}. Pass in concrete type instead, " + "e.g. Int32, Float32, etc." + ) + # If target_type is abstract, and value is instance of target_type, + # then we can return value as is + else: + # Implicit cast based on using annotation type + obj = type_(obj) + return obj + + +# Option 1: use ir.Value as base +# class IntegerMeta(DslType, type(ir.Value)): +class IntegerMeta(NumericMeta): + """Metaclass for integer types providing signedness information. + + :param width: Bit width of the integer type, defaults to 32 + :type width: int + :param signed: Whether the integer type is signed, defaults to True + :type signed: bool + :param mlir_type: Corresponding MLIR type, defaults to None + :type mlir_type: Any, optional + + :ivar signed: Whether the integer type is signed + :vartype signed: bool + :ivar arith: Arithmetic operations interface + :vartype arith: Any + """ + + signed: bool + + def __new__( + cls, + name, + bases, + attrs, + width=32, + signed=True, + mlir_type=None, + is_abstract=False, + ): + if width == 1: + np_dtype = np.bool_ + elif width == 128: + np_dtype = None + elif signed: + np_dtype = getattr(np, f"int{width}") + else: + np_dtype = getattr(np, f"uint{width}") + + def _c_pointers(self): + if width == 1: + c_value = ctypes.c_bool(self.value) + elif signed: + c_value = getattr(ctypes, f"c_int{width}")(self.value) + else: + c_value = getattr(ctypes, f"c_uint{width}")(self.value) + + return [ctypes.cast(ctypes.pointer(c_value), ctypes.c_void_p)] + + new_attrs = { + "__c_pointers__": _c_pointers, + } + new_cls = super().__new__( + cls, name, bases, attrs | new_attrs, width, np_dtype, mlir_type, is_abstract + ) + new_cls.signed = signed + return new_cls + + def __str__(cls): + return f"{cls.__name__}" + + @property + def is_integer(cls) -> bool: + return True + + @property + def is_float(cls) -> bool: + return False + + @property + def zero(cls) -> int: + return 0 + + @property + def min(cls) -> int: + if cls.signed: + return -(2 ** (cls.width - 1)) + else: + return 0 + + @property + def max(cls) -> int: + if cls.signed: + return 2 ** (cls.width - 1) - 1 + else: + return 2**cls.width - 1 + + def recast_width(cls, width): + type_map = { + 8: Int8, + 16: Int16, + 32: Int32, + 64: Int64, + 128: Int128, + } + if width not in type_map: + raise TypeError(f"Unsupported width: {width}") + return type_map[width] + + +class FloatMeta(NumericMeta): + """Metaclass for floating-point types. + + This metaclass provides type system infrastructure for floating-point types in the DSL, + handling MLIR type mappings and NumPy type conversions. + + :param width: Bit width of the float type, defaults to 32 + :type width: int + :param mlir_type: Corresponding MLIR type, defaults to None + :type mlir_type: Any, optional + :param is_abstract: Whether this is an abstract base class, defaults to False + :type is_abstract: bool, optional + + :ivar _arith: Arithmetic operations interface + :vartype _arith: Any + """ + + _exponent_width: int + _mantissa_width: int + + def __new__(cls, name, bases, attrs, width=32, mlir_type=None, is_abstract=False): + np_dtype = getattr(np, name.lower(), None) + new_cls = super().__new__( + cls, name, bases, attrs, width, np_dtype, mlir_type, is_abstract + ) + # Extract exponent and mantissa bits from class name if it follows Float pattern + # For example: Float8E4M3 -> exponent_width=4, mantissa_width=3 + import re + + if not is_abstract: + match = re.match(r"Float(\d+)E(\d+)M(\d+)(?:.*)", name) + if match: + exp_bits = int(match.group(2)) + mant_bits = int(match.group(3)) + + # Store extracted values as class attributes + new_cls._exponent_width = exp_bits + new_cls._mantissa_width = mant_bits + # Don't have 1-to-1 mapping of narrow precision types like bfloat16, tfloat32, etc. + return new_cls + + def __str__(cls): + return f"{cls.__name__}" + + @property + def is_integer(cls) -> bool: + return False + + @property + def is_float(cls) -> bool: + return True + + @property + def zero(cls) -> float: + return 0.0 + + @property + def inf(cls) -> float: + return float("inf") + + @property + def nan(cls) -> float: + return float("nan") + + @property + def exponent_width(cls) -> int: + return cls._exponent_width + + @property + def mantissa_width(cls) -> int: + return cls._mantissa_width + + def recast_width(cls, width): + type_map = { + 16: Float16, + 32: Float32, + 64: Float64, + } + if width not in type_map: + raise TypeError(f"Unsupported width: {width}") + return type_map[width] + + +def _arith_signless_to_int(a, target_type): + # is_signed: sign of result type + if target_type.width > a.type.width: + # arith dialect consider `1` in `i1` as `-1`, treat it as unsigned for DSL + if target_type.signed and a.type.width > 1: + return arith.extsi(target_type.mlir_type, a) + else: + return arith.extui(target_type.mlir_type, a) + elif target_type.width < a.type.width: + return arith.trunci(target_type.mlir_type, a) + else: + return a + + +def _binary_op_type_promote(a, b, promote_bool: bool = False): + """Promote two numeric operands following type promotion rules. + + :param a: First numeric operand + :type a: Numeric + :param b: Second numeric operand + :type b: Numeric + :param promote_bool: Whether to promote boolean types to Int32 for arithmetic operations, defaults to False + :type promote_bool: bool, optional + :raises ValueError: If implicit float promotion is not supported between the given types + :return: Tuple containing promoted operands and their resulting type + :rtype: tuple[Numeric, Numeric, Type[Numeric]] + + Type promotion rules: + 1. If operands are same type and not bools needing promotion: + - No promotion needed, return original types + 2. If either operand is float: + a. If one is float and one is int: + - Convert int to the float type + b. If both are float: + - Promote to higher precision float if width >= 16 + - For same width, promote to more general type (Float32 over TFloat32) + - Otherwise raise ValueError for unsupported promotion + 3. Otherwise, both operands are integers. Integer promotion rules: + a. If promote_bool is True and either operand is bool: + - Promote bool to Int32 for arithmetic operations + + Exceptions for numpy dtype casting: + - array(dtype=np.bool_) + array(dtype=np.bool_) -> array(dtype=np.bool_) + + What is not supported: + - promotion with narrow precision float types which requires explicit cast by user + """ + a_type = a.dtype + b_type = b.dtype + + # Early return for same types (except when they're bools that need promotion) + if a_type == b_type and not (promote_bool and a_type is Boolean): + return a, b, a_type + + # Handle floating point promotions + if a_type.is_float or b_type.is_float: + # Get highest precision float type based on bitwidth + a_width = getattr(a_type, "width", 0) + b_width = getattr(b_type, "width", 0) + + # If one type is integer, convert it to the float type + if a_type.is_float and not b_type.is_float: + b_type = a_type.recast_width(max(a_width, b_width)) + elif b_type.is_float and not a_type.is_float: + a_type = b_type.recast_width(max(a_width, b_width)) + + # Both are float types - handle precision promotion + if a_width > b_width and a_width >= 16: + res_type = a_type + elif b_width > a_width and b_width >= 16: + res_type = b_type + elif a_width == b_width: + # Same bitwidth - handle special cases like TFloat32 -> Float32 and BFloat16 -> Float16 + if a_type is Float64 or b_type is Float64: + res_type = Float64 + elif a_type is Float32 or b_type is Float32: + res_type = Float32 + elif a_type is Float16 or b_type is Float16: + res_type = Float16 + else: + raise ValueError( + f"implicit float promotion of {a_type} or {b_type} is not supported, cast explicitly" + ) + else: + raise ValueError( + f"implicit float promotion of {a_type} or {b_type} is not supported, cast explicitly" + ) + + # Only convert if type is different + new_a = a.to(res_type) if a.dtype != res_type else a + new_b = b.to(res_type) if b.dtype != res_type else b + return new_a, new_b, res_type + + # Handle bool promotion for arithmetic operations + if promote_bool: + if a_type is Boolean and b_type is Boolean: + # Only promote to Int32 when both are bool + a = a.to(Int32) + b = b.to(Int32) + a_type = b_type = a.dtype + + # If both were bools, they're now same type (Int32) + if a_type == b_type: + return a, b, a_type + + # Same type, no promotion needed + if a_type == b_type: + return a, b, a_type + + a_signed = a_type.signed + b_signed = b_type.signed + a_width = a_type.width + b_width = b_type.width + + # Mixed signedness case + if a_signed != b_signed: + unsigned_type = a_type if not a_signed else b_type + signed_type = a_type if a_signed else b_type + unsigned_width = a_width if not a_signed else b_width + + if unsigned_width >= signed_type.width: + # Promote both to unsigned of larger width + res_type = unsigned_type + else: + # Promote both to signed of larger width + res_type = signed_type + + new_a = a.to(res_type) if a.dtype != res_type else a + new_b = b.to(res_type) if b.dtype != res_type else b + return new_a, new_b, res_type + + # Same signedness, different width - promote to larger width + if a_width >= b_width: + return a, b.to(a.dtype), a.dtype + else: + return a.to(b.dtype), b, b.dtype + + +def _binary_op(op, promote_operand=True, promote_bool=False, flip=False): + """Wrapper for binary operations on Numeric types. + + This wrapper handles type promotion, operation execution, and result type determination + for binary operations between Numeric types. + + :param op: The binary operation to perform (e.g., operator.add, operator.sub) + :type op: callable + :param emitter: Function that emits the MLIR operation for dynamic values + :type emitter: callable + :param promote_operand: Whether to promote operands to the same type, defaults to True + :type promote_operand: bool, optional + :param promote_bool: Whether to promote boolean results to Boolean type, defaults to False + :type promote_bool: bool, optional + :param flip: Whether to flip the operands when calling the operation, defaults to False + :type flip: bool, optional + + :raises TypeError: When an unsupported operation is attempted on specific numeric types + + .. note:: + Not all operations are supported for all numeric types. In particular: + + - Subtraction is not fully supported for Integer types + - Multiplication, floor division, and modulo operations may have limited support + - Division (truediv) with integer types is not fully supported and converts to Float32 + """ + + def wrapper(lhs, rhs, *, loc=None, ip=None): + orig_lhs_type = type(lhs) + orig_rhs_type = type(rhs) + + # When called directly with self and other + ty = type(lhs) + # Canonicalize to Numeric type for promotion + if not isinstance(rhs, Numeric): + if not isinstance(rhs, (ArithValue, int, float, bool)): + # This allows rhs class to implement __rmul__ + return NotImplemented + + if isinstance(rhs, ArithValue): + if isinstance(rhs.type, ir.VectorType): + return NotImplemented + + rhs = as_numeric(rhs) + + # default result type to left-hand-side + res_type = ty + + if promote_operand: + lhs, rhs, res_type = _binary_op_type_promote(lhs, rhs, promote_bool) + else: + rhs = ty(rhs) + + if op in ( + operator.lt, + operator.le, + operator.gt, + operator.ge, + operator.eq, + operator.ne, + ): + res_type = Boolean + elif op == operator.truediv and isinstance(lhs, Integer): + res_type = Float32 + elif promote_bool and orig_lhs_type == Boolean and orig_rhs_type == Boolean: + res_type = Boolean + + if isinstance(lhs.value, ArithValue) and isinstance(lhs, Integer): + lhs_val = lhs.value.with_signedness(lhs.signed) + else: + lhs_val = lhs.value + + if isinstance(rhs.value, ArithValue) and isinstance(rhs, Integer): + rhs_val = rhs.value.with_signedness(rhs.signed) + else: + rhs_val = rhs.value + + if flip: + lhs_val, rhs_val = rhs_val, lhs_val + + # Check if the operation is supported by the operands + res_val = op(lhs_val, rhs_val) + return res_type(res_val, loc=loc, ip=ip) + + return wrapper + + +class Numeric(metaclass=NumericMeta, is_abstract=True): + """Base class for all numeric types in the DSL. + + This class provides the foundation for both Integer and Float types, + implementing basic arithmetic operations. + + :param value: The value to store in the numeric type + :type value: Union[bool, int, float, Value] + + :ivar value: The stored numeric value + :vartype value: Union[bool, int, float, Value] + """ + + def __init__(self, value: Union[bool, int, float, Value], *, loc=None, ip=None): + self.value = value + + def __str__(self) -> str: + # Use member's pretty-str method if member object has method. + # This can be extended in future to have better support for IDE, jupyter notebook, etc. + pretty_str = getattr(self.value, "pretty_str", None) + if pretty_str is not None: + return pretty_str() + else: + return "?" + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({repr(self.value)})" + + def __hash__(self): + return hash(type(self).__class__) ^ hash(self.value) + + @property + def dtype(self) -> Type["Numeric"]: + return type(self) + + @overload + def to(self, dtype: Type["Numeric"], *, loc=None, ip=None) -> "Numeric": ... + + @overload + def to(self, dtype: Type[int], *, loc=None, ip=None) -> int: ... + + @overload + def to(self, dtype: Type[float], *, loc=None, ip=None) -> float: ... + + @overload + def to(self, dtype: Type[bool], *, loc=None, ip=None) -> bool: ... + + @overload + def to(self, dtype: Type[ir.Value], *, loc=None, ip=None) -> ir.Value: ... + + def to(self, dtype: Type, *, loc=None, ip=None): + """Convert this numeric value to another numeric type. + + If the target type is the same as the current type, returns self. + Otherwise, creates a new instance of the target type with the same value. + + :param dtype: The target numeric type to convert to + :type dtype: Union[Type["Numeric"], Type[int], Type[float], Type[bool]] + :return: A new instance of the target type, or self if types match + :rtype: Numeric + :raises TypeError: If trying to convert an MLIR value to a static Python type + :raises TypeError: If trying to convert to unsupported float types like Float8E4M3, + Float8E4M3B11FNUZ, Float4E2M1FN, Float6E3M2FN, or Float6E2M3FN + + .. note:: + + Unsupported destination float types: + - Float8E4M3 + - Float8E4M3B11FNUZ + - Float4E2M1FN + - Float6E3M2FN + - Float6E2M3FN + + Example:: + + .. code-block:: python + + # Convert between DSL numeric types + x = Int32(5) + y = x.to(Float32) # Converts to Float32(5.0) + + # Convert to Python primitive types + # They are considered as static values at JIT time + z = x.to(int) # Returns Python int 5 + w = y.to(float) # Returns Python float 5.0 + + # This will raise a ValueError + mlir_val = arith.constant(T.i32(), 42) + num = Int32(mlir_val) + num.to(int) # ValueError: unable to convert MLIR value to static type: + """ + if dtype in _unsupported_dst_float_types: + raise TypeError(f"Unsupported destination float type: {dtype}") + + if isinstance(dtype, type(self)): + return self + elif isinstance(dtype, NumericMeta): + return dtype(self) + elif dtype is ir.Value: + if isinstance(self.value, (int, float, bool)): + res = arith_helper.const( + self.value, self.dtype.mlir_type, loc=loc, ip=ip + ) + elif isinstance(self.value, ir.Value): + res = self.value + else: + raise ValueError( + f"cannot convert {type(self)} to {dtype}, " + f"self.value is {self.value.type}" + ) + + if not isinstance(res, ArithValue): + raise ValueError(f"Expected ArithValue, got {type(res)} as {res.type}") + + return res.with_signedness(getattr(type(self), "signed", None)) + elif dtype in (int, float, bool): + if isinstance(self.value, ir.Value): + raise ValueError( + f"unable to convert {self.value} to static type: {dtype}" + ) + return dtype(self.value) + else: + raise ValueError(f"unable to convert {type(self)} to {dtype}") + + def ir_value(self, *, loc=None, ip=None) -> ir.Value: + return self.to(ir.Value, loc=loc, ip=ip) + + @property + def zero(self) -> "Numeric": ... + + def __dsl_not__(self, *, loc=None, ip=None): + """DSL implementation of Python's `not` operator. + + Returns True if the value is equal to zero, False otherwise. + This matches Python's behavior where any non-zero number is considered True. + + :param loc: The source location information, defaults to None + :type loc: Optional[Location] + :param ip: The insertion point for the operation, defaults to None + :type ip: Optional[InsertionPoint] + :return: The result of the logical not operation + :rtype: Boolean + """ + if isinstance(self.value, (int, float, bool)): + return not self.value + else: + ty = type(self) + zero_val = arith.constant(ty.mlir_type, ty.zero) + return self.__eq__(ty(zero_val), loc=loc, ip=ip) + + def __dsl_and__(self, other, *, loc=None, ip=None): + """DSL implementation of Python's `and` operator. + + Returns the second operand if the first is truthy, otherwise returns the first operand. + A numeric value is considered truthy if it is non-zero. + + :param other: The right-hand operand + :type other: Numeric + :param loc: The source location information, defaults to None + :type loc: Optional[Location] + :param ip: The insertion point for the operation, defaults to None + :type ip: Optional[InsertionPoint] + :return: The result of the logical and operation + :rtype: Boolean + + Example:: + + 5 and 3 -> 3 + 0 and 3 -> 0 + 3 and 0 and ... -> 0 + """ + is_true = self.__dsl_bool__(loc=loc, ip=ip) + + def and_op(lhs, rhs): + if isinstance(lhs, (int, float, bool)): + if isinstance(rhs, (int, float, bool)): + return lhs and rhs + else: + lhs = arith.constant(rhs.type, lhs) + return arith.select(is_true.ir_value(), rhs, lhs, loc=loc, ip=ip) + else: + if isinstance(rhs, (int, float, bool)): + rhs = arith.constant(lhs.type, rhs) + return arith.select(is_true.ir_value(), rhs, lhs, loc=loc, ip=ip) + else: + return arith.select(is_true.ir_value(), rhs, lhs, loc=loc, ip=ip) + + return _binary_op(and_op, promote_bool=True)(self, other, loc=loc, ip=ip) + + def __dsl_or__(self, other, *, loc=None, ip=None): + """DSL implementation of Python's `or` operator. + + Returns the first operand if it is truthy, otherwise returns the second operand. + A numeric value is considered truthy if it is non-zero. + + :param other: The right-hand operand + :type other: Numeric + :param loc: The source location information, defaults to None + :type loc: Optional[Location] + :param ip: The insertion point for the operation, defaults to None + :type ip: Optional[InsertionPoint] + :return: The result of the logical or operation + :rtype: Boolean + + Example:: + + 5 or 3 -> 5 + 0 or 3 -> 3 + 3 or 0 -> 3 + """ + is_true = self.__dsl_bool__(loc=loc, ip=ip) + + def or_op(lhs, rhs): + if isinstance(lhs, (int, float, bool)): + if isinstance(rhs, (int, float, bool)): + return lhs or rhs + else: + lhs = arith.constant(rhs.type, lhs) + return arith.select(is_true.ir_value(), lhs, rhs, loc=loc, ip=ip) + else: + if isinstance(rhs, (int, float, bool)): + rhs = arith.constant(lhs.type, rhs) + return arith.select(is_true.ir_value(), lhs, rhs, loc=loc, ip=ip) + else: + return arith.select(is_true.ir_value(), lhs, rhs, loc=loc, ip=ip) + + return _binary_op(or_op, promote_bool=True)(self, other, loc=loc, ip=ip) + + def __dsl_bool__(self, *, loc=None, ip=None) -> "Boolean": + """DSL implementation of Python's __bool__ method. + + Returns a Boolean indicating whether this value is considered truthy. + For numeric types, returns True if the value is non-zero. + + :param loc: The source location information, defaults to None + :type loc: Optional[Location] + :param ip: The insertion point for the operation, defaults to None + :type ip: Optional[InsertionPoint] + :return: True if this value is truthy (non-zero), False otherwise + :rtype: Boolean + """ + zero = type(self).zero + return self.__ne__(zero, loc=loc, ip=ip) + + def __bool__(self): + if isinstance(self.value, (int, float, bool)): + return bool(self.value) + else: + raise DSLRuntimeError( + f"Unable to convert dynamic `{type(self).__name__}` value to bool at compile time.", + suggestion=[ + "Decorate the parent function with `jit` decorator and with `preprocess` enabled.", + "Ensure not using patterns that DSL does not support.", + "Otherwise, please file a bug report.", + ], + ) + + def __index__(self): + if isinstance(self.value, (int, float, bool)): + return self.value + else: + raise DSLRuntimeError( + f"'{type(self.value)}' object cannot be interpreted as an integer", + suggestion="Mark the loop as dynamic with `dynamic_expr` or `range_dynamic` and decorate the parent function with `jit` decorator", + ) + + def __neg__(self, *, loc=None, ip=None): + if isinstance(self, (bool, int, float)): + return type(self)(-self.value) # type: ignore + else: + return type(self)(-self.value, loc=loc, ip=ip) # type: ignore + + @staticmethod + def _from_python_value(value): + if isinstance(value, Numeric): + return value + + if isinstance(value, bool): + res_type = Boolean + elif isinstance(value, int): + res_type = Int32 + elif isinstance(value, float): + res_type = Float32 + elif isinstance(value, ArithValue): + res_type = Numeric.from_mlir_type(value.type) + else: + raise ValueError( + f"unable to convert {value} in type {type(value)} to Numeric" + ) + return res_type(value) + + def __add__(self, other, *, loc=None, ip=None) -> "Numeric": + return _binary_op(operator.add, promote_bool=True)(self, other, loc=loc, ip=ip) + + def __sub__(self, other, *, loc=None, ip=None) -> "Numeric": + return _binary_op(operator.sub, promote_bool=True)(self, other, loc=loc, ip=ip) + + def __mul__(self, other, *, loc=None, ip=None) -> "Numeric": + return _binary_op(operator.mul, promote_bool=True)(self, other, loc=loc, ip=ip) + + def __floordiv__(self, other, *, loc=None, ip=None) -> "Numeric": + return _binary_op(operator.floordiv, promote_bool=True)( + self, other, loc=loc, ip=ip + ) + + def __truediv__(self, other, *, loc=None, ip=None) -> "Numeric": + return _binary_op(operator.truediv, promote_bool=True)( + self, other, loc=loc, ip=ip + ) + + def __mod__(self, other, *, loc=None, ip=None) -> "Numeric": + return _binary_op(operator.mod, promote_bool=True)(self, other, loc=loc, ip=ip) + + def __radd__(self, other, *, loc=None, ip=None) -> "Numeric": + return self.__add__(other, loc=loc, ip=ip) + + def __rsub__(self, other, *, loc=None, ip=None) -> "Numeric": + return _binary_op(operator.sub, promote_bool=True, flip=True)( + self, other, loc=loc, ip=ip + ) + + def __rmul__(self, other, *, loc=None, ip=None) -> "Numeric": + return self.__mul__(other, loc=loc, ip=ip) + + def __rfloordiv__(self, other, *, loc=None, ip=None) -> "Numeric": + return _binary_op(operator.floordiv, promote_bool=True, flip=True)( + self, other, loc=loc, ip=ip + ) + + def __rtruediv__(self, other, *, loc=None, ip=None) -> "Numeric": + return _binary_op(operator.truediv, promote_bool=True, flip=True)( + self, other, loc=loc, ip=ip + ) + + def __rmod__(self, other, *, loc=None, ip=None) -> "Numeric": + return _binary_op(operator.mod, promote_bool=True, flip=True)( + self, other, loc=loc, ip=ip + ) + + def __eq__(self, other, *, loc=None, ip=None) -> "Boolean": + return _binary_op(operator.eq)(self, other, loc=loc, ip=ip) # type: ignore + + def __ne__(self, other, *, loc=None, ip=None) -> "Boolean": + return _binary_op(operator.ne)(self, other, loc=loc, ip=ip) # type: ignore + + def __lt__(self, other, *, loc=None, ip=None) -> "Boolean": + return _binary_op(operator.lt)(self, other, loc=loc, ip=ip) # type: ignore + + def __le__(self, other, *, loc=None, ip=None) -> "Boolean": + return _binary_op(operator.le)(self, other, loc=loc, ip=ip) # type: ignore + + def __gt__(self, other, *, loc=None, ip=None) -> "Boolean": + return _binary_op(operator.gt)(self, other, loc=loc, ip=ip) # type: ignore + + def __ge__(self, other, *, loc=None, ip=None) -> "Boolean": + return _binary_op(operator.ge)(self, other, loc=loc, ip=ip) # type: ignore + + def __pow__(self, other, *, loc=None, ip=None) -> "Numeric": + return _binary_op(operator.pow)(self, other, loc=loc, ip=ip) # type: ignore + + def __c_pointers__(self): + raise ValueError( + f"only support built-in types: bool, (u)int{8, 16, 32, 64}, float{32, 64}, but got {type(self)}" + ) + + def __get_mlir_types__(self): + return [type(self).mlir_type] + + @staticmethod + def from_mlir_type(mlir_type): + type_map = { + T.bool(): Boolean, + T.f64(): Float64, + T.f32(): Float32, + T.tf32(): TFloat32, + T.f16(): Float16, + T.bf16(): BFloat16, + T.i(128): Int128, + T.i64(): Int64, + T.i32(): Int32, + T.i16(): Int16, + T.i8(): Int8, + T.si(128): Int128, + T.si64(): Int64, + T.si32(): Int32, + T.si16(): Int16, + T.si8(): Int8, + T.ui(128): Uint128, + T.ui64(): Uint64, + T.ui32(): Uint32, + T.ui16(): Uint16, + T.ui8(): Uint8, + T.f8E5M2(): Float8E5M2, + T.f8E4M3(): Float8E4M3, + T.f8E4M3FN(): Float8E4M3FN, + T.f8E4M3B11FNUZ(): Float8E4M3B11FNUZ, + T.f4E2M1FN(): Float4E2M1FN, + T.f6E2M3FN(): Float6E2M3FN, + T.f6E3M2FN(): Float6E3M2FN, + T.f8E8M0FNU(): Float8E8M0FNU, + } + if mlir_type not in type_map: + raise DSLRuntimeError(f"Unsupported DSL type: {mlir_type}") + return type_map[mlir_type] + + +def as_numeric(obj: Union[bool, int, float, ir.Value, Numeric]) -> Numeric: + """Convert a Python primitive value to a Numeric type. + + :param obj: Python primitive value to convert + :type obj: Union[bool, int, float] + :return: The converted Numeric object + :rtype: Numeric + + Example:: + + .. code-block:: python + + x = as_numeric(5) # Converts to Int32 + y = as_numeric(3.14) # Converts to Float32 + z = as_numeric(True) # Converts to Boolean + """ + if isinstance(obj, Numeric): + return obj + return Numeric._from_python_value(obj) + + +class Integer(Numeric, metaclass=IntegerMeta, mlir_type=T.i32, is_abstract=True): + """A class representing integer values with specific width and signedness. + + This class provides functionality to create and manipulate integer values with + configurable width and signedness. It supports conversion from various input types + including Python scalars, MLIR Values, and other numeric types. + + :param x: The input value to convert to this integer type + :type x: Union[bool, int, float, ir.Value, Integer, Float] + + :return: A new Integer instance with the converted value + :rtype: Integer + + :raises AssertionError: If the type's numpy_dtype is None + :raises NotImplementedError: If converting between different Integer types + :raises ValueError: If the input type is not supported for conversion + :raises OverflowError: If converting float infinity to integer + + Type conversion behavior: + + * Python scalars (bool, int, float): + * Converted through numpy dtype casting + * NaN and infinity values are rejected + * Example: Int8(256) -> -256 (overflow behavior) + + * MLIR Value with IntegerType: + * Width differences handled by signless to signed/unsigned conversion + * Example: i8 -> i8/ui8 depending on target type + + * MLIR Value with FloatType: + * Uses MLIR float-to-int conversion + * NaN and infinity values is undefined behavior + * Example: f32 -> i32/ui32 depending on target type + + * Integer: + * Uses MLIR float-to-int conversion or numpy dtype casting + * Example: Int32(Int32(5)) => 5 + + * Float: + * Uses MLIR float-to-int conversion + * Example: Int32(Float(5.7)) -> 5 + + Example usage: + + .. code-block:: python + + x = Int32(5) # From integer + y = Int32(True) # From boolean + z = Int32(3.7) # From float (truncates) + w = Int32(x) # From same Integer type + c5 = arith.constant(5, T.i32()) + a = Int32(c5) # Treat c5 as int32 bitwise + """ + + def __init__(self, x, *, loc=None, ip=None): + ty = type(self) + + if isinstance(x, (bool, int, float)): + # Add check for NaN before numpy conversion + if isinstance(x, float): + if np.isnan(x): + raise ValueError("Cannot convert float NaN to integer") + elif np.isinf(x): + raise OverflowError("Cannot convert float infinity to integer") + + np_dtype = ty.numpy_dtype + assert np_dtype is not None, f"expects numpy.dtype, but got {np_dtype}" + x_val = int(np.array(x).astype(np_dtype)) + elif type(x) == ty: + x_val = x.value + elif isinstance(x, ir.Value): # type: ignore + x_val = x + if isinstance(x.type, ir.IntegerType): # type: ignore + if x.type.width != ty.width: + # signless -> (u)int + x_val = _arith_signless_to_int(x, ty) + elif isinstance(x.type, ir.FloatType): # type: ignore + # float -> (u)int + x_val = arith_helper.fptoi(x, ty.signed, ty.mlir_type, loc=loc, ip=ip) + elif isinstance(x, Integer): + if isinstance(x.value, ir.Value): + x_val = arith_helper.int_to_int(x.ir_value(), ty) + else: + # For non-MLIR values, use numpy casting + src_val = np.array(x.value, dtype=type(x).numpy_dtype) + x_val = int(src_val.astype(ty.numpy_dtype)) + elif isinstance(x, Float): + # float -> int is handled by Integer.__init__ recursively + Integer.__init__(self, x.value) + return + else: + raise DSLRuntimeError(f"{x} to integer conversion is not supported") + + super().__init__(x_val) + + def __invert__(self, *, loc=None, ip=None): + res_type = type(self) + return res_type(self.ir_value(loc=loc, ip=ip).__invert__(loc=loc, ip=ip)) + + def __lshift__(self, other, *, loc=None, ip=None): + return _binary_op(operator.lshift)(self, other, loc=loc, ip=ip) + + def __rlshift__(self, other, *, loc=None, ip=None): + other_ = as_numeric(other) + if not isinstance(other_, Integer): + raise ValueError(f"Cannot left shift {other_} with {self}") + return other_.__lshift__(self, loc=loc, ip=ip) + + def __rshift__(self, other, *, loc=None, ip=None): + return _binary_op(operator.rshift)(self, other, loc=loc, ip=ip) + + def __rrshift__(self, other, *, loc=None, ip=None): + other_ = as_numeric(other) + if not isinstance(other_, Integer): + raise ValueError(f"Cannot right shift {other_} with {self}") + return other_.__rshift__(self, loc=loc, ip=ip) + + def __and__(self, other, *, loc=None, ip=None): + return _binary_op(operator.and_)(self, other, loc=loc, ip=ip) + + def __rand__(self, other, *, loc=None, ip=None): + return self.__and__(other, loc=loc, ip=ip) + + def __or__(self, other, *, loc=None, ip=None): + return _binary_op(operator.or_)(self, other, loc=loc, ip=ip) + + def __ror__(self, other, *, loc=None, ip=None): + return self.__or__(other, loc=loc, ip=ip) + + def __xor__(self, other, *, loc=None, ip=None): + return _binary_op(operator.xor)(self, other, loc=loc, ip=ip) + + def __rxor__(self, other, *, loc=None, ip=None): + return self.__xor__(other, loc=loc, ip=ip) + + +class Float(Numeric, metaclass=FloatMeta, mlir_type=T.f32, is_abstract=True): + """A class representing floating-point values. + + :param x: The input value to convert to this float type. + :type x: Union[bool, int, float, ir.Value, Integer, Float] + + Type conversion behavior: + + 1. Python scalars (bool, int, float): + - Converted through numpy dtype casting + - Example: Float32(1.7) -> 1.7 + + 2. MLIR Value with FloatType: + - If width differs: converts between float types + - Example: f16 -> f32 + + 3. MLIR Value with IntegerType: + - Not supported, raises ValueError + + 4. Integer: + - Converts using MLIR int-to-float operation + - Example: Float32(Int32(5)) -> 5.0 + + 5. Float: + - Direct conversion between float types + - Example: Float32(Float32(1.5)) -> 1.5 + + .. note:: + The following narrow precision types are only supported in device code: + + 8-bit float types: + - Float8E5M2 + - Float8E4M3 + - Float8E4M3FN + - Float8E8M0FNU + - Float8E4M3B11FNUZ + + 6-bit float types: + - Float6E3M2FN + - Float6E2M3FN + + 4-bit float types: + - Float4E2M1FN + + Narrow precision types and special floating-point formats support matrix on device: + + :raises AssertionError: If the type's numpy_dtype is None + :raises ValueError: If conversion from the input type is not supported + """ + + def __init__(self, x, *, loc=None, ip=None): + ty = type(self) + + if isinstance(x, (bool, int, float)): # type: ignore + # Why we need to convert x to with numpy? + # np_dtype = ty.numpy_dtype + # assert np_dtype is not None, f"expects numpy.dtype, but got {np_dtype}" + # x = float(np.array(x).astype(np_dtype)) + super().__init__(float(x)) + elif isinstance(x, ir.Value): # type: ignore + if isinstance(x.type, ir.IntegerType): # type: ignore + raise DSLRuntimeError("signless to float conversion is not implemented") + elif isinstance(x.type, ir.FloatType): # type: ignore + if x.type != ty.mlir_type: + x = arith_helper.cvtf(x, ty.mlir_type, loc=loc, ip=ip) + super().__init__(x) + elif isinstance(x, Integer): + if isinstance(x.value, ir.Value): # type: ignore + x = arith_helper.itofp( + x.value, type(x).signed, ty.mlir_type, loc=loc, ip=ip + ) + else: + x = float(x.value) + super().__init__(x) + elif isinstance(x, Float): + Float.__init__(self, x.value) + else: + raise DSLRuntimeError(f"{x} to Float conversion is not supported") + + +class Boolean(Integer, metaclass=IntegerMeta, width=1, signed=True, mlir_type=T.bool): + """Boolean type representation in the DSL. + + This class represents boolean values in the DSL, with a width of 1 bit. + It supports conversion from various types to boolean values. + + :param a: Value to convert to Boolean + :type a: Union[bool, int, float, "Value", Numeric] + :param loc: Source location information, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point for MLIR operations, defaults to None + :type ip: Optional[InsertionPoint], optional + :raises DSLRuntimeError: If the input value cannot be converted to Boolean + + Conversion rules: + + 1. Python bool/int/float: + - Converted using Python's bool() function + - Example: Boolean(1) -> True, Boolean(0) -> False + + 2. Numeric: + - Uses the Numeric.value to construct Boolean recursively + + 3. MLIR Value with IntegerType: + - If width is 1: Direct assignment + - Otherwise: Compares with 0 using arith.cmpi + + 4. MLIR Value with FloatType: + - Compares with 0.0 using arith.cmpf + - Uses unordered comparison to handle NaN values + """ + + def __init__( + self, a: Union[bool, int, float, ir.Value, Numeric], *, loc=None, ip=None + ): + value = None + if isinstance(a, (bool, int, float)): + value = bool(a) + elif isinstance(a, Numeric): + Boolean.__init__(self, a.value, loc=loc, ip=ip) + return + elif isinstance(a, ArithValue): + if a.type == T.bool(): + value = a + else: + value = a != arith_helper.const(0, a.type, loc=loc, ip=ip) + if value is None: + raise DSLRuntimeError(f"Cannot convert {a} to Boolean") + super().__init__(value, loc=loc, ip=ip) + self._value_int8 = None + + def ir_value_int8(self, *, loc=None, ip=None): + """ + Returns int8 ir value of Boolean. + When we need to store Boolean tensor element, use ir_value_int8(). + + :param loc: Source location information, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point for MLIR operations, defaults to None + :type ip: Optional[InsertionPoint], optional + :return: The int8 value of this Boolean + :rtype: ir.Value + """ + if self._value_int8 is not None: + return self._value_int8 + self._value_int8 = Int8(self.value, loc=loc, ip=ip).ir_value() + return self._value_int8 + + def __neg__(self, *, loc=None, ip=None): + """Negation operator is not supported for boolean type. + + :param loc: Source location information, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point for MLIR operations, defaults to None + :type ip: Optional[InsertionPoint], optional + :raises TypeError: Always raises this error as negation is not supported + """ + raise TypeError("Negation, the operator `-` is not supported for boolean type") + + +class Int8(Integer, metaclass=IntegerMeta, width=8, signed=True, mlir_type=T.i8): ... + + +class Int16(Integer, metaclass=IntegerMeta, width=16, signed=True, mlir_type=T.i16): ... + + +class Int32(Integer, metaclass=IntegerMeta, width=32, signed=True, mlir_type=T.i32): ... + + +class Int64(Integer, metaclass=IntegerMeta, width=64, signed=True, mlir_type=T.i64): ... + + +class Int128( + Integer, metaclass=IntegerMeta, width=128, signed=True, mlir_type=lambda: T.i(128) +): ... + + +class Uint8(Integer, metaclass=IntegerMeta, width=8, signed=False, mlir_type=T.i8): ... + + +class Uint16( + Integer, metaclass=IntegerMeta, width=16, signed=False, mlir_type=T.i16 +): ... + + +class Uint32( + Integer, metaclass=IntegerMeta, width=32, signed=False, mlir_type=T.i32 +): ... + + +class Uint64( + Integer, metaclass=IntegerMeta, width=64, signed=False, mlir_type=T.i64 +): ... + + +class Uint128( + Integer, metaclass=IntegerMeta, width=128, signed=False, mlir_type=lambda: T.i(128) +): ... + + +class Float64(Float, metaclass=FloatMeta, width=64, mlir_type=T.f64): + def __c_pointers__(self): + if not isinstance(self.value, float): + raise ValueError("only float is supported") + + return [ + ctypes.cast(ctypes.pointer(ctypes.c_double(self.value)), ctypes.c_void_p) + ] + + +class Float32(Float, metaclass=FloatMeta, width=32, mlir_type=T.f32): + @staticmethod + def _get_c_pointer(value: float): + return ctypes.cast(ctypes.pointer(ctypes.c_float(value)), ctypes.c_void_p) + + def __c_pointers__(self): + if not isinstance(self.value, float): + raise ValueError("only float is supported") + + return [Float32._get_c_pointer(self.value)] + + +class TFloat32(Float, metaclass=FloatMeta, width=32, mlir_type=T.tf32): + def __c_pointers__(self): + if not isinstance(self.value, float): + raise ValueError("only float is supported") + return [Float32._get_c_pointer(self.value)] + + +class Float16(Float, metaclass=FloatMeta, width=16, mlir_type=T.f16): + @staticmethod + def _get_c_pointer(value: float): + # Convert float to float16 binary representation + # First convert to numpy float16 to handle the conversion + f16_val = np.float16(value) + # Get the raw bits as a 16-bit integer + bits = f16_val.view(np.uint16) + # Create a short (16-bit int) with those bits + c_val = ctypes.c_short(bits) + return ctypes.cast(ctypes.pointer(c_val), ctypes.c_void_p) + + def __c_pointers__(self): + if not isinstance(self.value, float): + raise ValueError("only float is supported") + return [Float16._get_c_pointer(self.value)] + + +class BFloat16(Float, metaclass=FloatMeta, width=16, mlir_type=T.bf16): + def __c_pointers__(self): + if not isinstance(self.value, float): + raise ValueError("only float is supported") + + return Float.__c_pointers__(self) + + +class Float8E5M2(Float, metaclass=FloatMeta, width=8, mlir_type=T.f8E5M2): ... + + +class Float8E4M3FN(Float, metaclass=FloatMeta, width=8, mlir_type=T.f8E4M3FN): ... + + +class Float8E4M3B11FNUZ( + Float, metaclass=FloatMeta, width=8, mlir_type=T.f8E4M3B11FNUZ +): ... + + + +# Added missing float types +class Float8E4M3(Float, metaclass=FloatMeta, width=8, mlir_type=T.f8E4M3): ... + + +class Float8E8M0FNU(Float, metaclass=FloatMeta, width=8, mlir_type=T.f8E8M0FNU): ... + + +class Float4E2M1FN(Float, metaclass=FloatMeta, width=4, mlir_type=T.f4E2M1FN): ... + + +class Float6E3M2FN(Float, metaclass=FloatMeta, width=6, mlir_type=T.f6E3M2FN): ... + + +class Float6E2M3FN(Float, metaclass=FloatMeta, width=6, mlir_type=T.f6E2M3FN): ... + + +_unsupported_dst_float_types = [ + Float8E4M3, + Float8E4M3B11FNUZ, + Float4E2M1FN, + Float6E3M2FN, + Float6E2M3FN, +] + + +ALL_DTYPES = { + Int8, + Int16, + Int32, + Int64, + Int128, + Uint8, + Uint16, + Uint32, + Uint64, + Uint128, + BFloat16, + Float16, + Float32, + TFloat32, + Float64, + Float8E5M2, + Float8E4M3, + Float8E4M3FN, + Float8E8M0FNU, + Float8E4M3B11FNUZ, + Float4E2M1FN, + Float6E2M3FN, + Float6E3M2FN, +} +__STR_TO_DTYPE__ = {dt.__name__: dt for dt in ALL_DTYPES} + + +def dtype(dtype_) -> Type[Numeric]: + t = None + if const_expr(isinstance(dtype_, str) and dtype_ in __STR_TO_DTYPE__): + t = __STR_TO_DTYPE__[dtype_] + else: + raise TypeError(f"can't interpret {dtype_} as data type") + + return t + + +############################################################## +# Tensor +############################################################## + + +class TensorMeta(DslType): + _element_type = Any + _shape = Any + + """ + Examples: + >>> Tensor[Int32, (3,)] + >>> Tensor[Float32, (3, 4)] + >>> T = TypeVar("T") + >>> Tensor[T, (3, 4, 5)] + """ + + def __new__(cls, name, bases, attrs, element_type=Any, shape=Any): + new_cls = super().__new__(cls, name, bases, attrs) + new_cls._element_type = element_type + new_cls._shape = shape + return new_cls + + +# Generic type +TY = TypeVar("TY") + + +class Constexpr(Generic[TY]): + """Value is passed and computed by python interpreter""" + + pass + + +class align: + def __init__(self, value: int): + if value <= 0 or (value & (value - 1)) != 0: + raise DSLRuntimeError("expects align be power of 2 as positive value") + self._value = value + + def __str__(self): + return f"align({self._value})" + + +class PointerMeta(DslType): + def __new__(cls, name, bases, attrs, value_type=Int32, align_=align(1)): + new_cls = super().__new__( + cls, + name, + bases, + attrs, + mlir_type=lambda: getattr(ir, "UnrankedMemRefType").get( + value_type.mlir_type, getattr(ir, "Attribute").parse("0") + ), + ) + new_cls._value_type = value_type + new_cls._align = align_ + return new_cls + + def __eq__(cls, other): + if not isinstance(other, PointerMeta): + return False + return ( + cls._value_type == other._value_type + and cls._align._value == other._align._value + ) # Compare alignment values + + def __hash__(cls): + return hash((cls._value_type, cls._align._value)) # Hash alignment value + + def __getitem__(cls, params) -> Type["Pointer"]: + value_type, align_ = params + + if not isinstance(align_, align): + raise DSLRuntimeError(f"expects align but got {align_}") + + # Create new class with proper name and parameters + new_cls = type( + f"Pointer[{value_type.__name__}, {align_}]", + (Pointer,), + {}, + value_type=value_type, + align_=align_, # Pass alignment to __new__ + ) + return new_cls + + def __str__(cls): + return f"ptr<{cls._value_type}, {cls._align}>" + + +class Pointer(metaclass=PointerMeta): + """ + A pointer to a memory location. + + Examples: + + def foo(a : Pointer[Int32, align=8]): + ... + + """ + + def __init__(self, value): + self.value = value + + def __str__(self): + return f"{self.value} : {type(self)}" + + +class IRConst(Generic[TY]): + """Value is passed as MLIR constant value for (arith.constant).""" + + def __init__(self, ty: TY): + self.ty = ty + + +class IRValue(Generic[TY]): + """Value is passed as MLIR dynamic value.""" + + def __init__(self, ty: TY): + self.ty = ty + + +class IRVariadic: + """ + A helper class to pass a variadic number of arguments to a function. + """ + + def __init__(self, operands): + """ + Create a list of variadic operands. `operands` must be dynamic values. + """ + self.operands = operands + + def block_arg_types(self): + """ + Return the list of block args types. + """ + return [operand.type for operand in self.operands] + + def set_func_args(self, block_args): + """ + This function is called after entering a function. `block_args` are the + block arguments that correspond to the passed operands. Derived classes + may implement this function to provide convenience getters for block + arguments. + """ + pass + + def __len__(self): + """ + Return the length of variadic operands. + """ + return len(self.operands) + + +class FuncArgWithAttr(IRValue): + """ + This derived class is specifically for func op arg with attr + """ + + def __init__(self, ty, attr_name, attr_ty, attr_value=None): + super().__init__(ty) + assert attr_name is not None and ( + attr_ty is not None or attr_value is not None + ), "Invalid attr_name and/or attr_ty and/or attr_value for FuncArgWithAttr" + self.attr_name = attr_name + self.attr_ty = attr_ty + self.attr_value = attr_value + + + +def implicitDowncastNumericType(value): + if isinstance(value, Numeric): + return value.ir_value() + return value + + +__all__ = [ + "DslType", + "Numeric", + "NumericMeta", + "IntegerMeta", + "FloatMeta", + "Boolean", + "Integer", + "Int16", + "Int32", + "Int64", + "Int128", + "Int8", + "Uint8", + "Uint16", + "Uint32", + "Uint64", + "Uint128", + "Float", + "Float16", + "BFloat16", + "TFloat32", + "Float32", + "Float64", + "Float8E5M2", + "Float8E4M3", + "Float8E4M3FN", + "Float8E4M3B11FNUZ", + "Float8E4M3", + "Float8E8M0FNU", + "Float4E2M1FN", + "Float6E2M3FN", + "Float6E3M2FN", + "as_numeric", + "align", + "Pointer", + "dtype", + "Constexpr", + "IRConst", + "IRValue", + "IRVariadic", + "implicitDowncastNumericType", +] diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c4bfb2b7d91ee72b04a89de59e7dfbdec2be646c --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/__init__.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from . import stacktrace +from . import logger +from . import timer +__all__ = [ + "logger", + "timer", + "stacktrace", +] diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/logger.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..d4e4b4edf359ec86b6b5806cb0b2296f9cb918f6 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/logger.py @@ -0,0 +1,81 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +""" +This module provides logging helper functions +""" + +import logging + +logger = None + + +def log(): + return logger + + +def setup_log( + name, log_to_console=False, log_to_file=False, log_file_path=None, log_level=1 +): + """Set up and configure a logger with console and/or file handlers. + + :param name: Name of the logger to create + :type name: str + :param log_to_console: Whether to enable logging to console, defaults to False + :type log_to_console: bool, optional + :param log_to_file: Whether to enable logging to file, defaults to False + :type log_to_file: bool, optional + :param log_file_path: Path to the log file, required if log_to_file is True + :type log_file_path: str, optional + :param log_level: Logging level to set, defaults to 1 + :type log_level: int, optional + :raises ValueError: If log_to_file is True but log_file_path is not provided + :return: Configured logger instance + :rtype: logging.Logger + """ + # Create a custom logger + global logger + logger = logging.getLogger(name) + if log_to_console or log_to_file: + logger.setLevel(log_level) + else: + # Makes sure logging is OFF + logger.setLevel(logging.CRITICAL + 1) + + # Clear existing handlers to prevent duplicate logs + if logger.hasHandlers(): + logger.handlers.clear() + + # Define formatter + formatter = logging.Formatter( + f"%(asctime)s - %(name)s - %(levelname)s - [%(funcName)s] - %(message)s" + ) + + # Add console handler if enabled + if log_to_console: + console_handler = logging.StreamHandler() + console_handler.setLevel(log_level) + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) + + # Add file handler if enabled + if log_to_file: + if not log_file_path: + raise ValueError("log_file_path must be provided when enable_file is True") + file_handler = logging.FileHandler(log_file_path) + file_handler.setLevel(log_level) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + return logger + + +logger = setup_log("generic") diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/stacktrace.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/stacktrace.py new file mode 100644 index 0000000000000000000000000000000000000000..d2091098c173e8a941ed7958802dfbdee24199bc --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/stacktrace.py @@ -0,0 +1,165 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +""" + This module provides stacktrace helper functions +""" + +import os +import re + + +def walk_to_top_module(start_path): + """ + Walk up from the start_path to find the top-level Python module. + + :param start_path: The path to start from. + :return: The path of the top-level module. + """ + current_path = start_path + + while True: + # Check if we are at the root directory + if os.path.dirname(current_path) == current_path: + break + + # Check for __init__.py + init_file_path = os.path.join(current_path, "__init__.py") + if os.path.isfile(init_file_path): + # If __init__.py exists, move up one level + current_path = os.path.dirname(current_path) + else: + # If no __init__.py, we are not in a module; stop + break + + # If we reached the root without finding a module, return None + if os.path.dirname(current_path) == current_path and not os.path.isfile( + os.path.join(current_path, "__init__.py") + ): + return None + + # Return the path of the top-level module + return current_path + + +def _filter_internal_frames(traceback, internal_path): + """ + Filter out stack frames from the traceback that belong to the specified module path. + + This function removes stack frames from the traceback whose file paths start with + the given prefix_path, effectively hiding internal implementation details from + the error traceback shown to users. + """ + iter_prev = None + iter_tb = traceback + while iter_tb is not None: + if os.path.abspath(iter_tb.tb_frame.f_code.co_filename).startswith( + internal_path + ): + if iter_tb.tb_next: + if iter_prev: + iter_prev.tb_next = iter_tb.tb_next + else: + traceback = iter_tb.tb_next + else: + iter_prev = iter_tb + iter_tb = iter_tb.tb_next + return traceback + + +_generated_function_names = re.compile( + r"^(loop_body|while_region|while_before_block|while_after_block|if_region|then_block|else_block|elif_region)_\d+$" +) + + +def _filter_duplicated_frames(traceback): + """ + Filter out duplicated stack frames from the traceback. + The function filters out consecutive frames that are in the same file and have the same line number. + In a sequence of consecutive frames, the logic prefers to keep the non-generated frame or the last frame. + """ + iter_prev = None + iter_tb = traceback + while iter_tb is not None: + skip_current = False + skip_next = False + if iter_tb.tb_next: + current_filename = os.path.abspath(iter_tb.tb_frame.f_code.co_filename) + next_filename = os.path.abspath(iter_tb.tb_next.tb_frame.f_code.co_filename) + # if in the same file, check if the line number is the same + if current_filename == next_filename: + current_lineno = iter_tb.tb_lineno + next_lineno = iter_tb.tb_next.tb_lineno + if current_lineno == next_lineno: + # Same file and line number, check name, if current is generated, skip current, otherwise skip next + name = iter_tb.tb_frame.f_code.co_name + is_generated = bool(_generated_function_names.match(name)) + if is_generated: + # Skip current + skip_current = True + else: + # Skip next if it's generated, otherwise keep both + next_name = iter_tb.tb_next.tb_frame.f_code.co_name + skip_next = bool(_generated_function_names.match(next_name)) + if skip_current: + if iter_prev: + iter_prev.tb_next = iter_tb.tb_next + else: + traceback = iter_tb.tb_next + elif skip_next: + # if next is last frame, don't skip + if iter_tb.tb_next.tb_next: + iter_tb.tb_next = iter_tb.tb_next.tb_next + iter_prev = iter_tb + else: + iter_prev = iter_tb + iter_tb = iter_tb.tb_next + + return traceback + + +def filter_stackframe(traceback, prefix_path): + """ + Filter out stack frames from the traceback that belong to the specified module path. + + This function removes stack frames from the traceback whose file paths start with + the given prefix_path, effectively hiding internal implementation details from + the error traceback shown to users. + + :param traceback: The traceback object to filter. + :param prefix_path: The path prefix to filter out from the traceback. + :return: The filtered traceback with internal frames removed. + """ + # Step 1: filter internal frames + traceback = _filter_internal_frames(traceback, prefix_path) + + # Step 2: consolidate duplicated frames + return _filter_duplicated_frames(traceback) + + +def filter_exception(value, module_dir): + """ + Filter out internal implementation details from exception traceback. + + This function recursively processes an exception and its cause chain, + removing stack frames that belong to the specified module directory. + This helps to present cleaner error messages to users by hiding + implementation details. + + :param value: The exception object to filter. + :param module_dir: The module directory path to filter out from tracebacks. + :return: The filtered exception with internal frames removed. + """ + if hasattr(value, "__cause__") and value.__cause__: + filter_exception(value.__cause__, module_dir) + + if hasattr(value, "__traceback__"): + filter_stackframe(value.__traceback__, module_dir) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/timer.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/timer.py new file mode 100644 index 0000000000000000000000000000000000000000..f41d3f7410c0227ff1b1f8df4b8ce14557cf649b --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/timer.py @@ -0,0 +1,56 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +""" +This module provides a timing helper functions +""" +from functools import wraps + +from .logger import log + + +# TODO: revisit this part when mlir timing manager is ready for pybind. +def timer(*dargs, **kwargs): + enable = kwargs.get("enable", True) + + def decorator(func): + @wraps(func) + def func_wrapper(*args, **kwargs): + if not enable: + return func(*args, **kwargs) + from time import time + + start = time() + result = func(*args, **kwargs) + end = time() + + # Convert time from seconds to us + spend_us = (end - start) * 1e6 + + # Determine the function type and format the log message + if hasattr(func, "__name__"): + func_name = func.__name__ + log_message = f"[JIT-TIMER] Function: {func_name} | Execution Time: {spend_us:.2f} µs" + elif "CFunctionType" in str(type(func)): + log_message = f"[JIT-TIMER] C API Function: {str(func)} | Execution Time: {spend_us:.2f} µs" + else: + log_message = f"[JIT-TIMER] Anonymous Function | Execution Time: {spend_us:.2f} µs" + + log().info(log_message) + + return result + + return func_wrapper + + if len(dargs) == 1 and callable(dargs[0]): + return decorator(dargs[0]) + else: + return decorator diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f2c7ed2607675990ad9579fa06b25935b2ccb46e --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/__init__.py @@ -0,0 +1,59 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from .cutlass_dsl import ( + Constexpr, + as_numeric, + min, + max, + and_, + or_, + all_, + any_, + not_, + all_, + any_, + select_, + # Control-flow without AST pre-processor + if_generate, + for_generate, + LoopUnroll, + while_generate, + yield_out, + # Control-flow with AST pre-processor + range_constexpr, + range_dynamic, + const_expr, + dynamic_expr, + # Data types + dtype, # Provides conversions to types inheriting from NumericType + DSLRuntimeError, + JitArgAdapterRegistry, + # Construction utilities for user-defined classes + extract_mlir_values, + new_from_mlir_values, +) + +from .cute.typing import * + +# Utilities not belonging to CuTe +from . import utils as utils + +# Used as internal symbol +from . import cutlass_dsl as _dsl + +# Aliases +LaunchConfig = _dsl.BaseDSL.LaunchConfig +register_jit_arg_adapter = _dsl.JitArgAdapterRegistry.register_jit_arg_adapter +gpu = _dsl.cutlass_gpu +cuda = _dsl.cuda_helpers + +CACHE_FILE = "compiled_cache.db" diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8702ed9163837925057b48f9aafd11cffbb26a7e --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/__init__.py @@ -0,0 +1,319 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +# Use the auto-generated enum AddressSpace +from cutlass._mlir.dialects.cute import AddressSpace + +# Explicitly import types that might be directly used by other modules. +# This is a fix for using Sphinx to generate documentation +# Because Sphinx processes each module in isolation, it won't be able to rely +# on re-exported symbols via wildcard imports (from .typing import *) in the +# same way that Python does at runtime. +from .typing import ( + Shape, + Stride, + IntTuple, + Coord, + Tile, + XTuple, + Tiler, + Layout, + Pointer, + Tensor, +) + +# Import everything else +from .typing import * + +from .core import ( + assume, + is_integer, + is_int_tuple, + is_static, + size, + has_underscore, + slice_, + make_ptr, + make_layout, + recast_layout, + make_fragment_like, + depth, + rank, + flatten_to_tuple, + flatten, + unflatten, + product, + product_like, + shape, + size_in_bytes, + make_identity_layout, + make_ordered_layout, + make_composed_layout, + make_layout_tv, + make_swizzle, + recast_ptr, + make_tensor, + make_identity_tensor, + make_fragment, + recast_tensor, + get, + select, + front, + is_major, + leading_dim, + find, + find_if, + coalesce, + group_modes, + cosize, + dice, + product_each, + prepend, + append, + prepend_ones, + append_ones, + ceil_div, + slice_and_offset, + crd2idx, + domain_offset, + elem_less, + transform_leaf, + filter_zeros, + filter, + tile_to_shape, + shape_div, + composition, + complement, + right_inverse, + left_inverse, + max_common_layout, + max_common_vector, + logical_product, + zipped_product, + tiled_product, + flat_product, + raked_product, + blocked_product, + flat_divide, + logical_divide, + zipped_divide, + tiled_divide, + local_partition, + local_tile, + printf, + print_tensor, + # tiled mma/tiled copy + make_mma_atom, + make_tiled_mma, + make_copy_atom, + make_tiled_copy_tv, + make_tiled_copy, + make_tiled_copy_S, + make_tiled_copy_D, + make_tiled_copy_A, + make_tiled_copy_B, + make_tiled_copy_C, + make_tiled_copy_C_atom, + basic_copy, + basic_copy_if, + autovec_copy, + copy, + copy_atom_call, + gemm, + # Wrapper classes + ComposedLayout, + Swizzle, + E, + Atom, + MmaAtom, + CopyAtom, + TiledCopy, + TiledMma, + TensorSSA, + ReductionOp, + full, + full_like, + empty_like, + ones_like, + zeros_like, + where, + any_, + all_, + # User defined struct + struct, + pretty_str, + make_layout_image_mask, + repeat_like, + round_up, + is_congruent, + is_weakly_congruent, + ScaledBasis, + get_divisibility, + Ratio, +) + +from . import arch +from . import nvgpu +from . import testing +from . import runtime + +# Export all math ops without "math." +from .math import * + +# Used as internal symbol +from .. import cutlass_dsl as _dsl + +# Aliases +jit = _dsl.CuTeDSL.jit +kernel = _dsl.CuTeDSL.kernel +register_jit_arg_adapter = _dsl.JitArgAdapterRegistry.register_jit_arg_adapter +compile = _dsl.compile + +# Explicitly export all symbols for documentation generation +__all__ = [ + # Core types + "AddressSpace", + "Tensor", + "Layout", + "ComposedLayout", + "Swizzle", + "E", + "Atom", + "MmaAtom", + "CopyAtom", + "TiledCopy", + "TiledMma", + "TensorSSA", + # Basic utility functions + "assume", + "is_integer", + "is_int_tuple", + "is_static", + "size", + "has_underscore", + "slice_", + "depth", + "rank", + "shape", + "printf", + "print_tensor", + "pretty_str", + # Layout functions + "make_layout", + "recast_layout", + "make_identity_layout", + "make_ordered_layout", + "make_composed_layout", + "make_layout_tv", + "make_layout_image_mask", + # Tensor functions + "make_ptr", + "make_tensor", + "make_identity_tensor", + "make_fragment", + "make_fragment_like", + "recast_ptr", + "recast_tensor", + # Tensor manipulation + "get", + "select", + "front", + "is_major", + "leading_dim", + "find", + "find_if", + "coalesce", + "group_modes", + "cosize", + "size_in_bytes", + # Tuple operations + "flatten_to_tuple", + "flatten", + "product", + "product_like", + "product_each", + "prepend", + "append", + "prepend_ones", + "append_ones", + # Math operations + "ceil_div", + "round_up", + # Layout operations + "slice_and_offset", + "crd2idx", + "domain_offset", + "elem_less", + "filter_zeros", + "filter", + "tile_to_shape", + "shape_div", + "dice", + # Layout algebra + "composition", + "complement", + "right_inverse", + "left_inverse", + "max_common_layout", + "max_common_vector", + "is_congruent", + "is_weakly_congruent", + # Product operations + "logical_product", + "zipped_product", + "tiled_product", + "flat_product", + "raked_product", + "blocked_product", + # Division operations + "flat_divide", + "logical_divide", + "zipped_divide", + "tiled_divide", + "local_partition", + "local_tile", + # MMA and Copy operations + "make_mma_atom", + "make_tiled_mma", + "make_copy_atom", + "make_tiled_copy_tv", + "make_tiled_copy", + "make_tiled_copy_C_atom", + "basic_copy", + "basic_copy_if", + "autovec_copy", + "copy", + "copy_atom_call", + "gemm", + # Tensor creation + "full", + "full_like", + "empty_like", + "ones_like", + "zeros_like", + "where", + "any_", + "all_", + "repeat_like", + "ScaledBasis", + # User defined struct + "struct", + # Modules + "arch", + "nvgpu", + "testing", + "runtime", + # Decorators and code generation + "jit", + "kernel", + "register_jit_arg_adapter", + "compile", +] diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..01198215f74b07f224b1d5e53ff37075775bb201 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/__init__.py @@ -0,0 +1,101 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from .elect import * +from .mbar import * +from .nvvm_wrappers import * +from .smem import * +from .tmem import * + +# __all__ is required here for documentation generation +__all__ = [ + # + # elect.py + # + "make_warp_uniform", + "elect_one", + # + # mbar.py + # + "mbarrier_init", + "mbarrier_init_fence", + "mbarrier_arrive_and_expect_tx", + "mbarrier_expect_tx", + "mbarrier_wait", + "mbarrier_try_wait", + "mbarrier_conditional_try_wait", + "mbarrier_arrive", + # + # nvvm_wrappers.py + # + "lane_idx", + "warp_idx", + "thread_idx", + "block_dim", + "block_idx", + "grid_dim", + "cluster_idx", + "cluster_dim", + "block_in_cluster_idx", + "block_in_cluster_dim", + "block_idx_in_cluster", + "shuffle_sync", + "shuffle_sync_up", + "shuffle_sync_down", + "shuffle_sync_bfly", + "barrier", + "barrier_arrive", + "sync_threads", + "sync_warp", + "fence_acq_rel_cta", + "fence_acq_rel_cluster", + "fence_acq_rel_gpu", + "fence_acq_rel_sys", + "cp_async_commit_group", + "cp_async_wait_group", + "cp_async_bulk_commit_group", + "cp_async_bulk_wait_group", + "cluster_wait", + "cluster_arrive", + "cluster_arrive_relaxed", + "fence_proxy", + "vote_ballot_sync", + "popc", + "fence_view_async_tmem_load", + "fence_view_async_tmem_store", + "warpgroup_reg_alloc", + "warpgroup_reg_dealloc", + "fma_packed_f32x2", + "mul_packed_f32x2", + "add_packed_f32x2", + "fmax", + "rcp_approx", + "exp2", + # Constants + "WARP_SIZE", + # Forward from auto-generated nvvm python + "ProxyKind", + "SharedSpace", + "RoundingModeKind", + # + # smem.py + # + "alloc_smem", + "get_dyn_smem", + "get_dyn_smem_size", + # + # tmem.py + # + "retrieve_tmem_ptr", + "alloc_tmem", + "relinquish_tmem_alloc_permit", + "dealloc_tmem", +] diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/elect.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/elect.py new file mode 100644 index 0000000000000000000000000000000000000000..ead552afab7de50a62f95eee7b4d8a2d9b4dfca9 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/elect.py @@ -0,0 +1,84 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from cutlass.cutlass_dsl import CuTeDSL, T, dsl_user_op + +import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir +from cutlass._mlir.dialects import nvvm, scf +from cutlass._mlir import ir + +from ..typing import Int, Int32 +from ...impl_utils import check_value_in + + +@dsl_user_op +def make_warp_uniform(value: Int, *, loc=None, ip=None) -> Int32: + """ + Creates a warp-uniform value from the given integer input. + + :param value: The integer to make warp uniform. + :type value: Int + :return: The warp-uniform value equal to the input. + :rtype: Int32 + """ + return Int32( + _cute_nvgpu_ir.arch_make_warp_uniform( + Int32(value).ir_value(loc=loc, ip=ip), loc=loc, ip=ip + ) + ) + + +class IfOpRegion: + """ + A context manager for if Op. + Automatically inserts `scf.yield([])` when exiting the context. + """ + + def __init__(self, block, *, loc=None, ip=None): + self.block = block + self.insert_point = ir.InsertionPoint(self.block) + self.loc = loc + self.ip = ip + + def __enter__(self): + self.insert_point.__enter__() + return self.block.arguments + + def __exit__(self, exc_type, exc_value, traceback): + scf.yield_([], loc=self.loc, ip=self.ip) + self.insert_point.__exit__(exc_type, exc_value, traceback) + + +@dsl_user_op +def elect_one(*, loc=None, ip=None) -> IfOpRegion: + """ + Elects one thread within a warp. + + .. code-block:: python + + with elect_one(): + # Only one thread in the warp executes the code in this context + pass + """ + arch = CuTeDSL._get_dsl().envar.arch + check_value_in( + arch, + [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ], + "arch", + ) + is_thread_leader = nvvm.elect_sync(T.bool()) + if_op = scf.IfOp(is_thread_leader, loc=loc, ip=ip) + return IfOpRegion(if_op.then_block, loc=loc, ip=ip) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/mbar.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/mbar.py new file mode 100644 index 0000000000000000000000000000000000000000..80cb7b0b5fc6e226a39d68197382cbde2e32861d --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/mbar.py @@ -0,0 +1,349 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. +from typing import Optional + +from cutlass.cutlass_dsl import CuTeDSL, T, if_generate, dsl_user_op + +from cutlass._mlir.dialects import nvvm +from cutlass._mlir import ir + +from ..typing import Pointer, Int, Boolean, Int32 +from ...impl_utils import check_value_in + + +#################################################################################################### +# +# Mbarrier management utilities +# +#################################################################################################### + + +@dsl_user_op +def mbarrier_init(mbar_ptr: Pointer, cnt: Int, *, loc=None, ip=None) -> None: + """ + Initializes a mbarrier with the specified thread arrival count. + + :param mbar_ptr: A pointer to the mbarrier in SMEM + :type mbar_ptr: Pointer + :param cnt: The arrival count of the mbarrier + :type cnt: Int + """ + nvvm.mbarrier_init_shared( + mbar_ptr.llvm_ptr, Int32(cnt).ir_value(loc=loc, ip=ip), loc=loc, ip=ip + ) + + +@dsl_user_op +def mbarrier_init_fence(*, loc=None, ip=None) -> None: + """ + A fence operation that applies to the mbarrier initializations. + """ + arch = CuTeDSL._get_dsl().envar.arch + check_value_in( + arch, + [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ], + "arch", + ) + nvvm.fence_mbarrier_init(loc=loc, ip=ip) + + +@dsl_user_op +def mbarrier_arrive_and_expect_tx( + mbar_ptr: Pointer, bytes: Int, peer_cta_rank_in_cluster=None, *, loc=None, ip=None +) -> None: + """ + Arrives on a mbarrier and expects a specified number of transaction bytes. + + :param mbar_ptr: A pointer to the mbarrier in SMEM + :type mbar_ptr: Pointer + :param bytes: The number of transaction bytes + :type bytes: Int + :param peer_cta_rank_in_cluster: An optional CTA rank in cluster. If provided, the pointer to + the mbarrier is converted to a remote address in the peer CTA's + SMEM. + """ + arch = CuTeDSL._get_dsl().envar.arch + check_value_in( + arch, + [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ], + "arch", + ) + + mbar_llvm_ptr = mbar_ptr.llvm_ptr + if peer_cta_rank_in_cluster is not None: + mbar_llvm_ptr = nvvm.mapa_shared_cluster( + mbar_llvm_ptr.type, + mbar_llvm_ptr, + Int32(peer_cta_rank_in_cluster).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + space = nvvm.MBarrierSpaceKind.CLUSTER + else: + space = nvvm.MBarrierSpaceKind.CTA + + nvvm.mbarrier_txn( + mbar_llvm_ptr, + Int32(bytes).ir_value(loc=loc, ip=ip), + kind=nvvm.MBarrierTxnKind.ARRIVE_EXPECT_TX, + space=space, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def mbarrier_expect_tx( + mbar_ptr: Pointer, bytes: Int, peer_cta_rank_in_cluster=None, *, loc=None, ip=None +) -> None: + """ + Expects a specified number of transaction bytes without an arrive. + + :param mbar_ptr: A pointer to the mbarrier in SMEM + :type mbar_ptr: Pointer + :param bytes: The number of transaction bytes + :type bytes: Int + :param peer_cta_rank_in_cluster: An optional CTA rank in cluster. If provided, the pointer to + the mbarrier is converted to a remote address in the peer CTA's + SMEM. + """ + arch = CuTeDSL._get_dsl().envar.arch + check_value_in( + arch, + [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ], + "arch", + ) + + mbar_llvm_ptr = mbar_ptr.llvm_ptr + if peer_cta_rank_in_cluster is not None: + mbar_llvm_ptr = nvvm.mapa( + mbar_llvm_ptr.type, + mbar_llvm_ptr, + Int32(peer_cta_rank_in_cluster).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + space = nvvm.MBarrierSpaceKind.CLUSTER + else: + space = nvvm.MBarrierSpaceKind.CTA + + nvvm.mbarrier_txn( + mbar_llvm_ptr, + Int32(bytes).ir_value(loc=loc, ip=ip), + kind=nvvm.MBarrierTxnKind.EXPECT_TX, + space=space, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def mbarrier_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> None: + """ + Waits on a mbarrier with a specified phase. + + :param mbar_ptr: A pointer to the mbarrier in SMEM + :type mbar_ptr: Pointer + :param phase: The phase to wait for (either 0 or 1) + :type phase: Int + """ + arch = CuTeDSL._get_dsl().envar.arch + check_value_in( + arch, + [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ], + "arch", + ) + + timeout_ns = 10000000 + # This NVVM Op is a spin-loop wrapping the mbarrier.try_wait.parity.shared.b64 PTX + # The timeout in ns only applies to the latter and this call is truly blocking + nvvm.mbarrier_try_wait_parity_shared( + mbar_ptr.llvm_ptr, + Int32(phase).ir_value(loc=loc, ip=ip), + Int32(timeout_ns).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def mbarrier_try_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> Boolean: + """ + Attempts to wait on a mbarrier with a specified phase in a non-blocking fashion. + + :param mbar_ptr: A pointer to the mbarrier in SMEM + :type mbar_ptr: Pointer + :param phase: The phase to wait for (either 0 or 1) + :type phase: Int + :return: A boolean value indicating whether the wait operation was successful + :rtype: Boolean + """ + arch = CuTeDSL._get_dsl().envar.arch + check_value_in( + arch, + [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ], + "arch", + ) + + return Boolean( + nvvm.mbarrier_wait_parity( + T.bool(), + mbar_ptr.llvm_ptr, + Int32(phase).ir_value(loc=loc, ip=ip), + nvvm.MBarrierWaitKind.TRY, + loc=loc, + ip=ip, + ) + ) + + +@dsl_user_op +def mbarrier_conditional_try_wait( + cond, mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None +) -> Boolean: + """ + Conditionally attempts to wait on a mbarrier with a specified phase in a non-blocking fashion. + + :param cond: A boolean predicate + :param mbar_ptr: A pointer to the mbarrier in SMEM + :type mbar_ptr: Pointer + :param phase: The phase to wait for (either 0 or 1) + :type phase: Int + :return: A boolean value indicating whether the wait operation was successful + :rtype: Boolean + """ + arch = CuTeDSL._get_dsl().envar.arch + check_value_in( + arch, + [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ], + "arch", + ) + return if_generate( + cond, + lambda: mbarrier_try_wait(mbar_ptr, phase, loc=loc, ip=ip), + lambda: Boolean(True).ir_value(loc=loc, ip=ip), + None, + [Boolean], + ) + + +@dsl_user_op +def mbarrier_arrive( + mbar_ptr: Pointer, + peer_cta_rank_in_cluster: Optional[Int] = None, + *, + loc=None, + ip=None, +) -> None: + """ + Arrives on an mbarrier. + + :param mbar_ptr: A pointer to the mbarrier in SMEM + :type mbar_ptr: Pointer + :param peer_cta_rank_in_cluster: An optional CTA rank in cluster. If provided, the pointer to + the mbarrier is converted to a remote address in the peer CTA's + SMEM. + """ + mbar_llvm_ptr = mbar_ptr.llvm_ptr + if peer_cta_rank_in_cluster is not None: + arch = CuTeDSL._get_dsl().envar.arch + check_value_in( + arch, + [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ], + "arch", + ) + + mbar_llvm_ptr = nvvm.mapa_shared_cluster( + mbar_llvm_ptr.type, + mbar_llvm_ptr, + Int32(peer_cta_rank_in_cluster).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + space = nvvm.MBarrierSpaceKind.CLUSTER + else: + space = nvvm.MBarrierSpaceKind.CTA + + nvvm.mbarrier_txn( + mbar_llvm_ptr, + Int32(1).ir_value(loc=loc, ip=ip), + kind=nvvm.MBarrierTxnKind.ARRIVE, + space=space, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def cp_async_mbarrier_arrive_noinc(mbar_ptr: Pointer, *, loc=None, ip=None) -> None: + """ + Arrives on an mbarrier for async load **without incrementing** the arrival count + (`cp.async.mbarrier.arrive.shared ..., noinc=1`). + Used in the warp-specialized kernel when the non-TMA load warp(producer) is not the same + as the math/epilogue warp(consumer). + + :param mbar_ptr: A pointer to the mbarrier in SMEM + :type mbar_ptr: Pointer + """ + arch = CuTeDSL._get_dsl().envar.arch + check_value_in( + arch, + [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ], + "arch", + ) + + mbar_llvm_ptr = mbar_ptr.llvm_ptr + nvvm.cp_async_mbarrier_arrive_shared( + mbar_llvm_ptr, + noinc=True, + loc=loc, + ip=ip, + ) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/nvvm_wrappers.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/nvvm_wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..69e3b8acb1fd0d1bc6615cd835235c0bbd62027b --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/nvvm_wrappers.py @@ -0,0 +1,681 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from functools import partial +from typing import Optional, Tuple, Union, Callable +from typing_extensions import deprecated + +from cutlass.cutlass_dsl import T, dsl_user_op + +from cutlass._mlir import ir +from cutlass._mlir.dialects import llvm, nvvm, vector + +# Forward nvvm enums +from cutlass._mlir.dialects.nvvm import ( + ProxyKind, + SharedSpace, + Tcgen05WaitKind, + SetMaxRegisterAction, + RoundingModeKind, +) + +from ..typing import ( + Int, + Boolean, + Int16, + Uint16, + Int32, + Uint32, + Int64, + Float32, + BFloat16, + Numeric, + as_numeric, +) + +WARP_SIZE = 32 +FULL_MASK = 0xFFFFFFFF + + +@dsl_user_op +def lane_idx(*, loc=None, ip=None) -> Int32: + """ + Returns the lane index of the current thread within the warp. + """ + return Int32(nvvm.read_ptx_sreg_laneid(T.i32(), loc=loc, ip=ip)) + + +@dsl_user_op +def warp_idx(*, loc=None, ip=None) -> Int32: + """ + Returns the warp index within a CTA. + """ + warp_size = 32 + tid_x = Int32(nvvm.read_ptx_sreg_tid_x(T.i32(), loc=loc, ip=ip)) + tid_y = Int32(nvvm.read_ptx_sreg_tid_y(T.i32(), loc=loc, ip=ip)) + tid_z = Int32(nvvm.read_ptx_sreg_tid_z(T.i32(), loc=loc, ip=ip)) + ntid_x = Int32(nvvm.read_ptx_sreg_ntid_x(T.i32(), loc=loc, ip=ip)) + ntid_y = Int32(nvvm.read_ptx_sreg_ntid_y(T.i32(), loc=loc, ip=ip)) + tid = tid_x + tid_y * ntid_x + tid_z * ntid_x * ntid_y + return tid // warp_size + + +@dsl_user_op +def thread_idx(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: + """ + Returns the thread index within a CTA. + """ + return ( + Int32(nvvm.read_ptx_sreg_tid_x(T.i32(), loc=loc, ip=ip)), + Int32(nvvm.read_ptx_sreg_tid_y(T.i32(), loc=loc, ip=ip)), + Int32(nvvm.read_ptx_sreg_tid_z(T.i32(), loc=loc, ip=ip)), + ) + + +@dsl_user_op +def block_dim(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: + """ + Returns the number of threads in each dimension of the CTA. + """ + return ( + Int32(nvvm.read_ptx_sreg_ntid_x(T.i32(), loc=loc, ip=ip)), + Int32(nvvm.read_ptx_sreg_ntid_y(T.i32(), loc=loc, ip=ip)), + Int32(nvvm.read_ptx_sreg_ntid_z(T.i32(), loc=loc, ip=ip)), + ) + + +@dsl_user_op +def block_idx(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: + """ + Returns the CTA identifier within a grid. + """ + return ( + Int32(nvvm.read_ptx_sreg_ctaid_x(T.i32(), loc=loc, ip=ip)), + Int32(nvvm.read_ptx_sreg_ctaid_y(T.i32(), loc=loc, ip=ip)), + Int32(nvvm.read_ptx_sreg_ctaid_z(T.i32(), loc=loc, ip=ip)), + ) + + +@dsl_user_op +def grid_dim(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: + """ + Returns the number of CTAs in each dimension of the grid. + """ + return ( + Int32(nvvm.read_ptx_sreg_nctaid_x(T.i32(), loc=loc, ip=ip)), + Int32(nvvm.read_ptx_sreg_nctaid_y(T.i32(), loc=loc, ip=ip)), + Int32(nvvm.read_ptx_sreg_nctaid_z(T.i32(), loc=loc, ip=ip)), + ) + + +@dsl_user_op +def cluster_idx(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: + """ + Returns the cluster identifier within a grid. + """ + return ( + Int32(nvvm.read_ptx_sreg_clusterid_x(T.i32(), loc=loc, ip=ip)), + Int32(nvvm.read_ptx_sreg_clusterid_y(T.i32(), loc=loc, ip=ip)), + Int32(nvvm.read_ptx_sreg_clusterid_z(T.i32(), loc=loc, ip=ip)), + ) + + +@dsl_user_op +def cluster_dim(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: + """ + Returns the number of clusters in each dimension of the grid. + """ + return ( + Int32(nvvm.read_ptx_sreg_nclusterid_x(T.i32(), loc=loc, ip=ip)), + Int32(nvvm.read_ptx_sreg_nclusterid_y(T.i32(), loc=loc, ip=ip)), + Int32(nvvm.read_ptx_sreg_nclusterid_z(T.i32(), loc=loc, ip=ip)), + ) + + +@dsl_user_op +def block_in_cluster_idx(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: + """ + Returns the CTA index within a cluster across all dimensions. + """ + return ( + Int32(nvvm.read_ptx_sreg_cluster_ctaid_x(T.i32(), loc=loc, ip=ip)), + Int32(nvvm.read_ptx_sreg_cluster_ctaid_y(T.i32(), loc=loc, ip=ip)), + Int32(nvvm.read_ptx_sreg_cluster_ctaid_z(T.i32(), loc=loc, ip=ip)), + ) + + +@dsl_user_op +def block_in_cluster_dim(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: + """ + Returns the dimensions of the cluster. + """ + return ( + Int32(nvvm.read_ptx_sreg_cluster_nctaid_x(T.i32(), loc=loc, ip=ip)), + Int32(nvvm.read_ptx_sreg_cluster_nctaid_y(T.i32(), loc=loc, ip=ip)), + Int32(nvvm.read_ptx_sreg_cluster_nctaid_z(T.i32(), loc=loc, ip=ip)), + ) + + +@dsl_user_op +def block_idx_in_cluster(*, loc=None, ip=None) -> Int32: + """ + Returns the linearized identifier of the CTA within the cluster. + """ + return Int32(nvvm.read_ptx_sreg_cluster_ctarank(T.i32(), loc=loc, ip=ip)) + + +@dsl_user_op +def shuffle_sync_op( + value: Numeric, + offset: Int, + mask: Int = FULL_MASK, + mask_and_clamp: Int = WARP_SIZE - 1, + kind: nvvm.ShflKind = nvvm.ShflKind.idx, + *, + loc=None, + ip=None, +) -> Numeric: + """ + Shuffles a value within the threads of a warp. + + :param value: The value to shuffle + :type value: Numeric + :param mask: A mask describing the threads participating in this operation + :type mask: Int + :param offset: A source lane or a source lane offset depending on kind + :type offset: Int + :param mask_and_clamp: An integer containing two packed values specifying a mask for logically + splitting warps into sub-segments and an upper bound for clamping the + source lane index. + :type mask_and_clamp: Int + :param kind: The kind of shuffle, can be idx, up, down, or bfly + :type kind: ShflKind + :return: The shuffled value + :rtype: Numeric + """ + if not isinstance(value, Numeric): + value = as_numeric(value) + if value.width > 64: + raise ValueError("shuffle_sync only supports values up to 64 bits") + + orig_type = type(value) + if value.width < 32: + if value.dtype.is_float: + value = value.to(Float32) + else: + if value.signed: + value = value.to(Int32) + else: + value = value.to(Uint32) + return orig_type( + nvvm.shfl_sync( + type(value).mlir_type, + Int32(mask).ir_value(loc=loc, ip=ip), + value.ir_value(loc=loc, ip=ip), + Int32(offset).ir_value(loc=loc, ip=ip), + Int32(mask_and_clamp).ir_value(loc=loc, ip=ip), + kind, + loc=loc, + ip=ip, + ) + ) + elif value.width == 32: + return orig_type( + nvvm.shfl_sync( + type(value).mlir_type, + Int32(mask).ir_value(loc=loc, ip=ip), + value.ir_value(loc=loc, ip=ip), + Int32(offset).ir_value(loc=loc, ip=ip), + Int32(mask_and_clamp).ir_value(loc=loc, ip=ip), + kind, + loc=loc, + ip=ip, + ) + ) + else: + if value.width != 64: + raise ValueError( + "shuffle_sync only supports 64 bits values when the bit width is larger than 32" + ) + value = llvm.bitcast( + T.i64(), value.to(ir.Value, loc=loc, ip=ip), loc=loc, ip=ip + ) + # extract low 32 bits + low_32_bits = llvm.trunc( + T.i32(), value, llvm.IntegerOverflowFlags.none, loc=loc, ip=ip + ) + # extract high 32 bits + high_32_bits = llvm.lshr( + value, Int64(32).ir_value(loc=loc, ip=ip), loc=loc, ip=ip + ) + high_32_bits = llvm.trunc( + T.i32(), high_32_bits, llvm.IntegerOverflowFlags.none, loc=loc, ip=ip + ) + + low_32_bits_shfl = nvvm.shfl_sync( + T.i32(), + Int32(mask).ir_value(loc=loc, ip=ip), + low_32_bits, + Int32(offset).ir_value(loc=loc, ip=ip), + Int32(mask_and_clamp).ir_value(loc=loc, ip=ip), + kind, + loc=loc, + ip=ip, + ) + high_32_bits_shfl = nvvm.shfl_sync( + T.i32(), + Int32(mask).ir_value(loc=loc, ip=ip), + high_32_bits, + Int32(offset).ir_value(loc=loc, ip=ip), + Int32(mask_and_clamp).ir_value(loc=loc, ip=ip), + kind, + loc=loc, + ip=ip, + ) + + # combine low and high 32 bits + low_64_bit = llvm.zext(T.i64(), low_32_bits_shfl, loc=loc, ip=ip) + high_64_bit = llvm.zext(T.i64(), high_32_bits_shfl, loc=loc, ip=ip) + shlf_res = llvm.shl( + high_64_bit, + Int64(32).ir_value(loc=loc, ip=ip), + llvm.IntegerOverflowFlags.none, + loc=loc, + ip=ip, + ) + shlf_res = llvm.or_(shlf_res, low_64_bit, loc=loc, ip=ip) + shlf_res = llvm.bitcast(orig_type.mlir_type, shlf_res, loc=loc, ip=ip) + return orig_type(shlf_res) + +shuffle_sync = partial(shuffle_sync_op, kind=nvvm.ShflKind.idx) +shuffle_sync_up = partial(shuffle_sync_op, kind=nvvm.ShflKind.up) +shuffle_sync_down = partial(shuffle_sync_op, kind=nvvm.ShflKind.down) +shuffle_sync_bfly = partial(shuffle_sync_op, kind=nvvm.ShflKind.bfly) + + +@dsl_user_op +def barrier(*, barrier_id=None, number_of_threads=None, loc=None, ip=None) -> None: + """ + Creates a barrier, optionally named. + """ + if barrier_id is not None: + barrier_id = Int32(barrier_id).ir_value(loc=loc, ip=ip) + + if number_of_threads is not None: + number_of_threads = Int32(number_of_threads).ir_value(loc=loc, ip=ip) + + nvvm.barrier( + barrier_id=barrier_id, number_of_threads=number_of_threads, loc=loc, ip=ip + ) + + +@dsl_user_op +def barrier_arrive( + *, barrier_id=None, number_of_threads=None, loc=None, ip=None +) -> None: + if barrier_id is not None: + barrier_id = Int32(barrier_id).ir_value(loc=loc, ip=ip) + + if number_of_threads is None: + raise ValueError( + "barrier_arrive needs pass number_of_threads to arrive the barrier", + ) + number_of_threads = Int32(number_of_threads).ir_value(loc=loc, ip=ip) + + nvvm.barrier_arrive( + barrier_id=barrier_id, number_of_threads=number_of_threads, loc=loc, ip=ip + ) + + +@dsl_user_op +def sync_threads(*, loc=None, ip=None) -> None: + """ + Synchronizes all threads within a CTA. + """ + nvvm.barrier(loc=loc, ip=ip) + + +@dsl_user_op +def sync_warp(mask: Int = FULL_MASK, *, loc=None, ip=None) -> None: + """ + Performs a warp-wide sync with an optional mask. + """ + nvvm.bar_warp_sync(Int32(mask).ir_value(loc=loc, ip=ip), loc=loc, ip=ip) + + +@dsl_user_op +def fence_acq_rel_cta(*, loc=None, ip=None) -> None: + """ + Fence operation with acquire-release semantics. + + See the `PTX documentation `__. + """ + nvvm.fence_acq_rel_cta(loc=loc, ip=ip) + + +@dsl_user_op +def fence_acq_rel_cluster(*, loc=None, ip=None) -> None: + """ + Fence operation with acquire-release semantics. + + See the `PTX documentation `__. + """ + nvvm.fence_acq_rel_cluster(loc=loc, ip=ip) + + +@dsl_user_op +def fence_acq_rel_gpu(*, loc=None, ip=None) -> None: + """ + Fence operation with acquire-release semantics. + + See the `PTX documentation `__. + """ + nvvm.fence_acq_rel_gpu(loc=loc, ip=ip) + + +@dsl_user_op +def fence_acq_rel_sys(*, loc=None, ip=None) -> None: + """ + Fence operation with acquire-release semantics. + + See the `PTX documentation `__. + """ + nvvm.fence_acq_rel_sys(loc=loc, ip=ip) + + +@dsl_user_op +def cp_async_commit_group(*, loc=None, ip=None) -> None: + """ + Commits all prior initiated but uncommitted cp.async instructions. + + See the `PTX documentation `__. + """ + nvvm.cp_async_commit_group(loc=loc, ip=ip) + + +@dsl_user_op +def cp_async_wait_group(n, *, loc=None, ip=None) -> None: + """ + Waits till only a specified numbers of cp.async groups are pending. + + See the `PTX documentation `__. + """ + nvvm.cp_async_wait_group(n, loc=loc, ip=ip) + + +@dsl_user_op +def cp_async_bulk_commit_group(*, loc=None, ip=None) -> None: + """ + Commits all prior initiated but uncommitted cp.async.bulk instructions. + + See the `PTX documentation `__. + """ + nvvm.cp_async_bulk_commit_group(loc=loc, ip=ip) + + +@dsl_user_op +def cp_async_bulk_wait_group(group, *, read=None, loc=None, ip=None) -> None: + """ + Waits till only a specified numbers of cp.async.bulk groups are pending. + + See the `PTX documentation `__. + """ + nvvm.cp_async_bulk_wait_group(group, read=read, loc=loc, ip=ip) + + +@dsl_user_op +def cluster_wait(*, loc=None, ip=None) -> None: + """ + A cluster-wide wait operation. + """ + nvvm.cluster_wait(loc=loc, ip=ip) + + +@dsl_user_op +def cluster_arrive(*, aligned=None, loc=None, ip=None) -> None: + """ + A cluster-wide arrive operation. + """ + nvvm.cluster_arrive(aligned=aligned, loc=loc, ip=ip) + + +@dsl_user_op +def cluster_arrive_relaxed(*, aligned=None, loc=None, ip=None) -> None: + """ + A cluster-wide arrive operation with relaxed semantics. + """ + nvvm.cluster_arrive_relaxed(aligned=aligned, loc=loc, ip=ip) + + +@dsl_user_op +def fence_proxy( + kind: ProxyKind, + *, + space: Optional[SharedSpace] = None, + use_intrinsic=None, + loc=None, + ip=None, +) -> None: + nvvm.fence_proxy( + kind=kind, space=space, use_intrinsic=use_intrinsic, loc=loc, ip=ip + ) + + +@dsl_user_op +def vote_ballot_sync( + pred: Boolean, mask: Int = FULL_MASK, *, loc=None, ip=None +) -> Int32: + """ + Performs a ballot operation across the warp. + """ + return Int32( + nvvm.vote_ballot_sync( + T.i32(), + Int32(mask).ir_value(loc=loc, ip=ip), + Boolean(pred).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + ) + + +@dsl_user_op +def popc(value: Numeric, *, loc=None, ip=None) -> Numeric: + """ + Performs a population count operation. + """ + if not isinstance(value, Numeric): + value = as_numeric(value) + return type(value)(llvm.intr_ctpop(value.ir_value(), loc=loc, ip=ip)) + + +@dsl_user_op +def fence_view_async_tmem_op( + kind: Tcgen05WaitKind, + *, + loc=None, + ip=None, +) -> None: + """ + Perform a fence operation on the async TMEM load or store. + + .. note:: + This function is only available on sm_100a and above. + The fence is required to synchronize the TMEM load/store + and let the pipeline release or commit the buffer. + + Take a mma2acc pipeline as an example of LOAD fence, the ACC tensor is from TMEM. + ``` + # Start to copy ACC from TMEM to register + cute.copy(tmem_load, tACC, rACC) + fence_view_async_tmem_load() + # After fence, we can ensure the TMEM buffer is consumed totally. + # Release the buffer to let the MMA know it can overwrite the buffer. + mma2accum_pipeline.consumer_release(curr_consumer_state) + ``` + Take a TS GEMM kernel as an example of STORE fence, the A tensor is from TMEM. + ``` + # Start to copy A from register to TMEM + cute.copy(tmem_store, rA, tA) + fence_view_async_tmem_store() + # After fence, we can ensure the TMEM buffer is ready. + # Commit the buffer to let the MMA know it can start to load A. + tmem_mma_pipeline.producer_commit(curr_producer_state) + ``` + + + :param kind: The kind of fence operation to perform including LOAD and STORE. + :type kind: Tcgen05WaitKind + """ + nvvm.tcgen05_wait(kind, loc=loc, ip=ip) + + +fence_view_async_tmem_load = partial( + fence_view_async_tmem_op, kind=Tcgen05WaitKind.LOAD +) +fence_view_async_tmem_store = partial( + fence_view_async_tmem_op, kind=Tcgen05WaitKind.STORE +) + + +@dsl_user_op +def warpgroup_reg_realloc_op( + reg_count: int, + kind: SetMaxRegisterAction, + *, + loc=None, + ip=None, +) -> None: + nvvm.setmaxregister(reg_count, kind, loc=loc, ip=ip) + + +warpgroup_reg_alloc = partial( + warpgroup_reg_realloc_op, kind=SetMaxRegisterAction.increase +) +warpgroup_reg_dealloc = partial( + warpgroup_reg_realloc_op, kind=SetMaxRegisterAction.decrease +) + + +@dsl_user_op +def calc_packed_f32x2_op( + src_a: Tuple[Float32, Float32], + src_b: Tuple[Float32, Float32], + src_c: Tuple[Float32, Float32] | None, + calc_func: Callable, + *, + rnd=RoundingModeKind.RZ, + ftz=True, + loc=None, + ip=None, +) -> Tuple[Float32, Float32]: + vec_type = ir.VectorType.get([2], Float32.mlir_type, loc=loc) + vec_src_a = vector.from_elements( + vec_type, tuple(as_numeric(a).ir_value() for a in src_a), loc=loc, ip=ip + ) + vec_src_b = vector.from_elements( + vec_type, tuple(as_numeric(b).ir_value() for b in src_b), loc=loc, ip=ip + ) + if src_c is not None: + vec_src_c = vector.from_elements( + vec_type, tuple(as_numeric(c).ir_value() for c in src_c), loc=loc, ip=ip + ) + vec_res = calc_func( + vec_type, vec_src_a, vec_src_b, vec_src_c, rnd=rnd, ftz=ftz, loc=loc, ip=ip + ) + else: + vec_res = calc_func( + vec_type, vec_src_a, vec_src_b, rnd=rnd, ftz=ftz, loc=loc, ip=ip + ) + + res0 = Float32( + vector.extract( + vec_res, dynamic_position=[], static_position=[0], loc=loc, ip=ip + ) + ) + res1 = Float32( + vector.extract( + vec_res, dynamic_position=[], static_position=[1], loc=loc, ip=ip + ) + ) + return res0, res1 + + +fma_packed_f32x2 = partial(calc_packed_f32x2_op, calc_func=nvvm.fma_packed_f32x2) +mul_packed_f32x2 = partial( + calc_packed_f32x2_op, src_c=None, calc_func=nvvm.mul_packed_f32x2 +) +add_packed_f32x2 = partial( + calc_packed_f32x2_op, src_c=None, calc_func=nvvm.add_packed_f32x2 +) + + +@dsl_user_op +def fmax( + a: Union[float, Float32], b: Union[float, Float32], *, loc=None, ip=None +) -> Float32: + return Float32( + nvvm.fmax( + T.f32(), + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + ) + + +@dsl_user_op +def rcp_approx(a: Union[float, Float32], *, loc=None, ip=None): + return Float32( + nvvm.rcp_approx_ftz_f( + T.f32(), Float32(a).ir_value(loc=loc, ip=ip), loc=loc, ip=ip + ) + ) + + +@dsl_user_op +@deprecated( + "cute.arch.exp2 is deprecated, use cute.math.exp2 with `fastmath=True` instead" +) +def exp2(a: Union[float, Float32], *, loc=None, ip=None) -> Float32: + return Float32( + llvm.inline_asm( + T.f32(), + [Float32(a).ir_value(loc=loc, ip=ip)], + "ex2.approx.ftz.f32 $0, $1;", + "=f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +@deprecated( + "cute.arch.exp is deprecated, use cute.math.exp with `fastmath=True` instead" +) +def exp(a: Union[float, Float32], *, loc=None, ip=None) -> Float32: + LOG2_E = 1.4426950408889634 + return exp2(a * LOG2_E, loc=loc, ip=ip) + + +@dsl_user_op +@deprecated( + "cute.arch.exp_packed_f32x2 is deprecated, use cute.arch.mul_packed_f32x2 and cute.math.exp2 with `fastmath=True` instead" +) +def exp_packed_f32x2( + a: Tuple[Float32, Float32], *, loc=None, ip=None +) -> Tuple[Float32, Float32]: + LOG2_E = Float32(1.4426950408889634) + b = mul_packed_f32x2(a, (LOG2_E, LOG2_E), loc=loc, ip=ip) + return exp2(b[0], loc=loc, ip=ip), exp2(b[1], loc=loc, ip=ip) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/smem.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/smem.py new file mode 100644 index 0000000000000000000000000000000000000000..37f87ea64d7f7482f3b2f464be6a0ee1a2e3494f --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/smem.py @@ -0,0 +1,108 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from typing import Optional, Type + +from cutlass.cutlass_dsl import T, dsl_user_op + +import cutlass._mlir.dialects.cute as _cute_ir +import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir +from cutlass._mlir import ir + +from ..typing import Pointer, Numeric, NumericMeta + + +@dsl_user_op +def alloc_smem( + element_type: Type[Numeric], + size_in_elems: int, + alignment: Optional[int] = None, + *, + loc=None, + ip=None, +) -> Pointer: + """ + Statically allocates SMEM. + + :param element_type: The pointee type of the pointer. + :type element_type: Type[Numeric] + :param size_in_elems: The size of the allocation in terms of number of elements of the + pointee type + :type size_in_elems: int + :param alignment: An optional pointer alignment for the allocation + :type alignment: int + :return: A pointer to the start of the allocation + :rtype: Pointer + """ + if not isinstance(element_type, NumericMeta): + raise TypeError( + f"element_type must be a type of Numeric, but got {element_type}" + ) + + if alignment is None: + # Default alignment based on the element type's width + alignment = element_type.width // 8 + ptr_ty = _cute_ir.PtrType.get( + element_type.mlir_type, _cute_ir.AddressSpace.smem, alignment + ) + return _cute_nvgpu_ir.arch_alloc_smem( + ptr=ptr_ty, + input=ir.IntegerAttr.get(T.i32(), size_in_elems), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def get_dyn_smem( + element_type: Type[Numeric], + alignment: Optional[int] = None, + *, + loc=None, + ip=None, +) -> Pointer: + """ + Retrieves a pointer to a dynamic SMEM allocation. + + :param element_type: The pointee type of the pointer. + :type element_type: Type[Numeric] + :param alignment: An optional pointer alignment, the result pointer is offset appropriately + :type alignment: int + :return: A pointer to the start of the dynamic SMEM allocation with a correct + alignement + :rtype: Pointer + """ + if not isinstance(element_type, NumericMeta): + raise TypeError( + f"element_type must be a type of Numeric, but got {element_type}" + ) + + if alignment is None: + # Default alignment based on the element type's width + alignment = element_type.width // 8 + ptr_ty = _cute_ir.PtrType.get( + element_type.mlir_type, + _cute_ir.AddressSpace.smem, + alignment, + ) + return _cute_nvgpu_ir.arch_get_dyn_smem(ptr=ptr_ty, loc=loc, ip=ip) + + +@dsl_user_op +def get_dyn_smem_size(*, loc=None, ip=None) -> int: + """ + Gets the size in bytes of the dynamic shared memory that was specified at kernel launch time. + This can be used for bounds checking during shared memory allocation. + + :return: The size of dynamic shared memory in bytes + :rtype: int + """ + return _cute_nvgpu_ir.arch_get_dyn_smem_size(loc=loc, ip=ip) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/tmem.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/tmem.py new file mode 100644 index 0000000000000000000000000000000000000000..302616d20b34ccfe1d3194e48bf94114eeafeaec --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/tmem.py @@ -0,0 +1,142 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from typing import Type + +from cutlass.cutlass_dsl import dsl_user_op + +import cutlass._mlir.dialects.cute as _cute_ir +import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir + +from ..typing import Pointer, Int, Int32, Numeric, NumericMeta + + +SM100_TMEM_CAPACITY_COLUMNS = 512 +SM100_TMEM_MIN_ALLOC_COLUMNS = 32 + + +@dsl_user_op +def retrieve_tmem_ptr( + element_type: Type[Numeric], + alignment: int, + ptr_to_buffer_holding_addr: Pointer, + *, + loc=None, + ip=None, +) -> Pointer: + """ + Retrieves a pointer to TMEM with the provided element type and alignment. + + :param element_type: The pointee type of the pointer. + :type element_type: Type[Numeric] + :param alignment: The alignment of the result pointer + :type alignment: int + :param ptr_to_buffer_holding_addr: A pointer to a SMEM buffer holding the TMEM address of the + start of the allocation allocation + :type ptr_to_buffer_holding_addr: Pointer + :return: A pointer to TMEM + :rtype: Pointer + """ + if not isinstance(element_type, NumericMeta): + raise TypeError( + f"element_type must be a type of Numeric, but got {element_type}" + ) + + res_ty = _cute_ir.PtrType.get( + element_type.mlir_type, _cute_ir.AddressSpace.tmem, alignment + ) + return _cute_nvgpu_ir.arch_sm100_retrieve_tmem_ptr( + res_ty, ptr_to_buffer_holding_addr.value, loc=loc, ip=ip + ) + + +@dsl_user_op +def alloc_tmem( + num_columns: Int, + smem_ptr_to_write_address: Pointer, + is_two_cta=None, + *, + loc=None, + ip=None, +) -> None: + """ + Allocates TMEM. + + :param num_columns: The number of TMEM columns to allocate + :type num_columns: Int + :param smem_ptr_to_write_address: A pointer to a SMEM buffer where the TMEM address is written + to + :type smem_ptr_to_write_address: Pointer + :param is_two_cta: Optional boolean parameter for 2-CTA MMAs + """ + if isinstance(num_columns, int): + if ( + num_columns < SM100_TMEM_MIN_ALLOC_COLUMNS + or num_columns > SM100_TMEM_CAPACITY_COLUMNS + or not (num_columns & (num_columns - 1) == 0) + ): + raise ValueError( + f"num_columns must be between 32 and 512, and must be pow of 2, but got {num_columns}" + ) + _cute_nvgpu_ir.arch_sm100_alloc_tmem( + Int32(num_columns).ir_value(loc=loc, ip=ip), + smem_ptr_to_write_address.value, + is_two_cta=is_two_cta, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def relinquish_tmem_alloc_permit(is_two_cta=None, *, loc=None, ip=None) -> None: + """ + Relinquishes the right to allocate TMEM so that other CTAs potentially in a different grid can + allocate. + """ + _cute_nvgpu_ir.arch_sm100_relinquish_tmem_alloc_permit( + is_two_cta=is_two_cta, loc=loc, ip=ip + ) + + +@dsl_user_op +def dealloc_tmem( + tmem_ptr: Pointer, + num_columns: Int, + is_two_cta=None, + *, + loc=None, + ip=None, +) -> None: + """ + Deallocates TMEM using the provided pointer and number of columns. + + :param tmem_ptr: A pointer to the TMEM allocation to de-allocate + :type tmem_ptr: Pointer + :param num_columns: The number of columns in the TMEM allocation + :type num_columns: Int + :param is_two_cta: Optional boolean parameter for 2-CTA MMAs + """ + if isinstance(num_columns, int): + if ( + num_columns < SM100_TMEM_MIN_ALLOC_COLUMNS + or num_columns > SM100_TMEM_CAPACITY_COLUMNS + or not (num_columns & (num_columns - 1) == 0) + ): + raise ValueError( + f"num_columns must be between 32 and 512, and must be pow of 2, but got {num_columns}" + ) + _cute_nvgpu_ir.arch_sm100_dealloc_tmem( + tmem_ptr.value, + Int32(num_columns).ir_value(loc=loc, ip=ip), + is_two_cta=is_two_cta, + loc=loc, + ip=ip, + ) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/core.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/core.py new file mode 100644 index 0000000000000000000000000000000000000000..12d5e4221a3e6007656a9400966e84d8b9a25a79 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/core.py @@ -0,0 +1,7070 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +import copy as py_copy +from dataclasses import dataclass +import inspect +import math +import operator +from abc import ABC, abstractmethod +from functools import lru_cache, partial, reduce +from inspect import isclass +from itertools import chain +from typing import ( + Callable, + Iterable, + overload, + List, + Tuple, + Union, + Type, + Any, + Dict, + Optional, +) +from enum import Enum, auto + +from cutlass.cutlass_dsl import ( + const, + T, + lru_cache_ir, + is_dynamic_expression, + for_generate, + yield_out, + if_generate, + extract_mlir_values, + new_from_mlir_values, + _binary_op_type_promote, + not_, + cutlass_arith, + dsl_user_op, +) + +from cutlass._mlir import ir +from cutlass._mlir.dialects._ods_common import get_op_result_or_op_results +from cutlass._mlir.dialects import cute as _cute_ir +from cutlass._mlir.dialects.cute import ( + ScaledBasis as _ScaledBasis, + Ratio as _Ratio, +) + +from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir +from cutlass._mlir.dialects import llvm, builtin, vector, arith + +from .typing import ( + Numeric, + Integer, + NumericMeta, + Boolean, + Int32, + Int8, + Int16, + Int32, + Int64, + Float32, + TFloat32, + Int, + IntTuple, + Shape, + Stride, + Coord, + Layout, + Tile, + Tiler, + XTuple, + Tensor, + Pointer, + AddressSpace, + as_numeric, +) + + +#################################################################################################### +# +# Internal IntTuple helpers +# +#################################################################################################### + + +def _get_typed_value(x): + if isinstance(x, Integer): + return ( + x.value.get_typed_value() if isinstance(x.value, IntValue) else x.ir_value() + ) + else: + return x + + +def _pack_x(x, packer, op, *, loc=None, ip=None) -> ir.Value: + x = transform_leaf(_get_typed_value, x) + res_ty, dyn_elems = packer(x) + # <"0"> is deduced from type inference which should be removed for make_... operations + dyn_elems = [t for t in dyn_elems if not is_static(t)] + return op(res_ty, dyn_elems, loc=loc, ip=ip).result + + +def _pack_shape(shape: Shape, *, loc=None, ip=None) -> ir.Value: + _check_shape(shape) + return _pack_x(shape, _cute_ir.pack_shape, _cute_ir.MakeShapeOp, loc=loc, ip=ip) + + +def _pack_stride(stride: Stride, *, loc=None, ip=None) -> ir.Value: + _check_stride(stride) + # Convert basis elements to the base class before _pack_x + stride = transform_leaf( + lambda x: x.to(_cute_ir.ScaledBasis) if isinstance(x, ScaledBasis) else x, + stride, + ) + return _pack_x(stride, _cute_ir.pack_stride, _cute_ir.MakeStrideOp, loc=loc, ip=ip) + + +def _pack_coord(coord: Coord, *, loc=None, ip=None) -> ir.Value: + _check_coord(coord) + return _pack_x(coord, _cute_ir.pack_coord, _cute_ir.MakeCoordOp, loc=loc, ip=ip) + + +def _pack_int_tuple(int_tuple: IntTuple, *, loc=None, ip=None) -> ir.Value: + _check_int_tuple(int_tuple) + return _pack_x( + int_tuple, _cute_ir.pack_int_tuple, _cute_ir.MakeIntTupleOp, loc=loc, ip=ip + ) + + +def _pack_tile(tile: Tile, *, loc=None, ip=None) -> ir.Value: + _check_tile(tile) + + def expand_leaves(tile) -> list: + leaves = [] + for e in tile: + if isinstance(e, _Layout): + leaves.extend(list(flatten_to_tuple(e.shape))) + leaves.extend(list(flatten_to_tuple(e.stride))) + else: + leaves.append(e) + return leaves + + layout_leaves = flatten_to_tuple(tile) + dyn_elems = expand_leaves(layout_leaves) + dyn_elems = [ + _get_typed_value(x) for x in dyn_elems if isinstance(x, (Integer, ir.Value)) + ] + + res_ty = _cute_ir.pack_tile(tile) + return _cute_ir.make_tile(res_ty, dyn_elems, loc=loc, ip=ip) + + +def _unpack_x_tuple(t: Union[ir.Type, ir.Value], *, loc=None, ip=None) -> XTuple: + # If t is an MLIR type, make sure it's static and make a Value + if isinstance(t, ir.Type): + if not _cute_ir.is_static(t): + raise ValueError() + t = _cute_ir.static(t) + + if isinstance(t, ir.Value): + input_ty = t.type + if t.type.rank == 0: + # Handle this case separately, _cute_ir.get_leaves will return an Op in this case + vals = [] + else: + vals = _cute_ir.get_leaves(t, loc=loc, ip=ip) + if not isinstance(vals, list): + vals = [vals] + else: + raise TypeError(f"expects static type or value, but got {t}") + + # CuTe IR only supports Int32 for now. Need to support detection of other types + res = _cute_ir.unpack_x_tuple(input_ty, vals) + + def post_process(x): + if isinstance(x, _cute_ir.ScaledBasis): + return ScaledBasis(post_process(x.get_value()), x.get_mode()) + elif isinstance(x, _cute_ir.Ratio): + return Ratio(x.numerator, x.denominator) + else: + return x + + return transform_leaf(post_process, res) + + +#################################################################################################### +# Validation helpers +#################################################################################################### + + +def _check_shape(shape: Shape) -> None: + if is_integer(shape): + if isinstance(shape, int): + if shape <= 0: + raise ValueError( + f"Expected size in shape to be strictly positive, but got {shape}" + ) + elif isinstance(shape, Integer): + pass + else: + raise TypeError(f"Expected size be int or Integer, but got {type(shape)}") + elif isinstance(shape, tuple): + for s in shape: + _check_shape(s) + else: + raise ValueError( + f"Expected Shape, which is a positive integer or tuple of Shapes, but got {shape}" + ) + + +def _check_coord(coord: Coord) -> None: + flat_coord = flatten_to_tuple(coord) + if not all(is_integer(c) or c is None for c in flat_coord): + raise ValueError( + f"Expected Coord, whose leaves are integers or None, but got {coord}" + ) + + +def _check_stride(stride: Stride) -> None: + flat_stride = flatten_to_tuple(stride) + if not all(is_integer(s) or isinstance(s, ScaledBasis) for s in flat_stride): + raise ValueError( + f"Expected Stride, whose leaves are integers or ScaledBasis, but got {stride}" + ) + + +def _check_int_tuple(int_tuple: IntTuple) -> None: + flat_int_tuple = flatten_to_tuple(int_tuple) + if not all(is_integer(d) for d in flat_int_tuple): + raise ValueError( + f"Expected IntTuple, whose leaves are integers, but got {int_tuple}" + ) + + +def _check_tile(tile: Tile) -> None: + flat_tile = flatten_to_tuple(tile) + if not all(is_integer(t) or isinstance(t, _Layout) or t is None for t in flat_tile): + raise ValueError( + f"Expected Tile, whose leaves are integers or Layout or None, but got {tile}" + ) + + +#################################################################################################### +# +# Core types +# +#################################################################################################### + + +class IntValue(cutlass_arith.ArithValue): + """Internal representation of constrained integer types with divisibility information. + + IntValue serves as a proxy for constrained integer types in the CuTe IR. Rather than + directly storing values of IntTupleType with depth=0, it stores the result of the + `cute.get_scalars` operation applied to such values. + + This class represents the following sequence of operations in the IR: + %0 = ... : (...) -> !cute.int_tuple<"?"> + %1 = cute.get_scalars(%0) : (!cute.int_tuple<"?">) -> i32 + + where the first operation produces a `cute.int_tuple<"?">` with depth=0 and rank=1. It + automatically emit `cute.get_scalars` and track it. + + IntValue inherits behavior from ArithValue with the following extensions: + * Overloaded operations that accept IntTupleType values to propagate divisibility information + * Support for CuTe operations that utilize divisibility constraints + + API for interacting with IntValue: + * get_typed_value() - Returns the value as an IntTupleType + * get_divisibility() - Returns the divisibility constraint of the value + """ + + def __init__(self, v, signed=True): + # Cute Constrained Int Type is always signed + if isinstance(v, int): + v = _pack_int_tuple(v) + + if isinstance(v.type, _cute_ir.IntTupleType): + scalar_val = _cute_ir.get_scalars(v) + super().__init__(scalar_val, True) + else: + super().__init__(v, True) + + def get_typed_value(self): + if isinstance(self.type, ir.IntegerType): + def_op = self.owner.operation + if def_op.name == "cute.get_scalars": + return def_op.operands[0] + + assert not isinstance(self.type, _cute_ir.IntTupleType) + + return _pack_int_tuple(self) + + @property + def divisibility(self): + if isinstance(self.get_typed_value().type, _cute_ir.IntTupleType): + return self.get_typed_value().type.get_divisibility([0]) + else: + return 1 + + def __str__(self): + if self.divisibility == 1: + return f"?" + else: + return f"?{{div={self.divisibility}}}" + + def __repr__(self): + parent_name = cutlass_arith.ArithValue.__name__ + return super().__str__().replace(parent_name, IntValue.__name__) + + def pretty_str(self): + return self.__str__() + + @staticmethod + def _binary_op(op): + def wrapper(self, other, **kwargs): + if isinstance(other, IntValue): + other_val = other.get_typed_value() + elif isinstance(other, ir.Value) and isinstance( + other.type, _cute_ir.IntTupleType + ): + other_val = other + elif isinstance(other, ir.Value) and isinstance(other.type, ir.IntegerType): + other = cutlass_arith.int_to_int(other, Int32, **kwargs) + other_val = _pack_int_tuple(other) + elif isinstance(other, (int, bool)): + other_val = _pack_int_tuple(int(other)) + else: + # Dispatch to `__rmul__` of `other` + return NotImplemented + + return IntValue(op(self, other_val, **kwargs)) + + return wrapper + + @dsl_user_op + @_binary_op + def __add__(self, other, *, loc=None, ip=None): + return _cute_ir.add_offset(self.get_typed_value(), other, loc=loc, ip=ip) + + @dsl_user_op + @_binary_op + def __sub__(self, other, *, loc=None, ip=None): + return _cute_ir.tuple_sub(self.get_typed_value(), other, loc=loc, ip=ip) + + @dsl_user_op + @_binary_op + def __mul__(self, other, *, loc=None, ip=None): + return _cute_ir.tuple_mul(self.get_typed_value(), other, loc=loc, ip=ip) + + @dsl_user_op + @_binary_op + def __floordiv__(self, other, *, loc=None, ip=None) -> "IntValue": + return _cute_ir.tuple_div(self.get_typed_value(), other, loc=loc, ip=ip) + + @dsl_user_op + @_binary_op + def __mod__(self, other, *, loc=None, ip=None) -> cutlass_arith.ArithValue: + return _cute_ir.tuple_mod(self.get_typed_value(), other, loc=loc, ip=ip) + + @dsl_user_op + @_binary_op + def __radd__(self, other, *, loc=None, ip=None) -> "IntValue": + return _cute_ir.add_offset(other, self.get_typed_value(), loc=loc, ip=ip) + + @dsl_user_op + @_binary_op + def __rsub__(self, other, *, loc=None, ip=None) -> "IntValue": + return _cute_ir.tuple_sub(other, self.get_typed_value(), loc=loc, ip=ip) + + @dsl_user_op + @_binary_op + def __rmul__(self, other, *, loc=None, ip=None): + return _cute_ir.tuple_mul(other, self.get_typed_value(), loc=loc, ip=ip) + + @dsl_user_op + @_binary_op + def __rfloordiv__(self, other, *, loc=None, ip=None) -> "IntValue": + return _cute_ir.tuple_div(other, self.get_typed_value(), loc=loc, ip=ip) + + @dsl_user_op + @_binary_op + def __rmod__(self, other, *, loc=None, ip=None) -> "IntValue": + return _cute_ir.tuple_mod(other, self.get_typed_value(), loc=loc, ip=ip) + + +class Ratio(_Ratio): + """A class representing a rational number as a ratio of two integers. + + Ratio is used in CuTe to represent exact fractional values that arise in + tensor layout operations, particularly in composition operations where + divisibility conditions may not be satisfied. + + :param numerator: The numerator of the ratio + :type numerator: int + :param denominator: The denominator of the ratio + :type denominator: int + :raises TypeError: If numerator or denominator are not integers + """ + + def __init__(self, numerator: int, denominator: int): + if not isinstance(numerator, int) or not isinstance(denominator, int): + raise TypeError( + f"numerator and denominator must be integers, but got {numerator} and {denominator}" + ) + super().__init__(numerator, denominator) + + def is_integral(self) -> bool: + """Check if the ratio represents an integer value. + + :return: True if the numerator is divisible by the denominator + :rtype: bool + """ + return super().is_integral() + + def reduced(self) -> "Ratio": + """Return a new Ratio with the numerator and denominator reduced to lowest terms. + + :return: A new Ratio in reduced form + :rtype: Ratio + """ + res = super().reduced() + return Ratio(res.numerator, res.denominator) + + def __mul__(self, other): + """Multiply this ratio by another ratio or an integer. + + :param other: The value to multiply by + :type other: Union[Ratio, int] + :return: A new ratio representing the product + :rtype: Ratio + :raises TypeError: If other is not a Ratio or int + """ + if isinstance(other, Ratio): + return Ratio( + self.numerator * other.numerator, + self.denominator * other.denominator, + ) + elif isinstance(other, int): + return Ratio(self.numerator * other, self.denominator) + else: + raise TypeError(f"Cannot multiply Ratio with {type(other)}") + + def __rmul__(self, other): + """Right multiplication operation. + + :param other: The value to multiply by + :type other: Union[Ratio, int] + :return: A new ratio representing the product + :rtype: Ratio + """ + return self.__mul__(other) + + def __str__(self): + """String representation of the ratio. + + :return: String in the format "numerator/denominator" + :rtype: str + """ + return super().__str__() + + def to(self, dtype): + """Convert the ratio to another type. + + :param dtype: The target type for conversion + :type dtype: type + :return: The ratio converted to the specified type + :raises TypeError: If conversion to the specified type is not supported + """ + if dtype is Ratio: + return self + elif dtype is float: + return self.numerator / self.denominator + elif dtype is int: + return self.numerator // self.denominator + elif issubclass(dtype, _Ratio): + return self + else: + raise TypeError(f"Cannot convert Ratio to {dtype}") + + +class ScaledBasis: + """A class representing a scaled basis element in CuTe's layout algebra. + + ScaledBasis is used to represent elements in the layout algebra, particularly + in the context of composition operations. It consists of a value (scale) and + a mode that identifies mode of the basis element. + + :param value: The scale value + :type value: Union[int, Integer, Ratio, ir.Value] + :param mode: The mode identifying the basis element + :type mode: Union[int, List[int]] + :raises TypeError: If mode is not an integer or list of integers + + **Examples:** + + .. code-block:: python + + # Create a scaled basis with integer scale and mode + sb1 = ScaledBasis(2, 0) # 2 * E(0) + + # Create a scaled basis with a Ratio scale + sb2 = ScaledBasis(Ratio(1, 2), 1) # (1/2) * E(1) + + # Create a scaled basis with a list of modes + sb3 = ScaledBasis(4, [0, 1]) # 4 * E([0, 1]) + + # Scaled basis elements are commonly used in layout strides + layout = make_layout((4, 8), stride=(ScaledBasis(2, 0), ScaledBasis(1, 1))) + + # This creates a layout with strides (2@0, 1@1) representing + # a coordinate system where each dimension has its own basis + + # Example: Mapping coordinates to indices using the layout + coord = (2, 3) + idx = crd2idx(coord, layout) # Maps (2, 3) to (4, 3) + """ + + def __init__(self, value, mode) -> None: + if isinstance(mode, int): + self._mode = [mode] + else: + if any(not isinstance(x, int) for x in mode): + raise TypeError("Mode must be a list of integers") + self._mode = mode + + self._value = value + + def is_static(self) -> bool: + """Check if the value is statically known. + + :return: True if the value is not a dynamic expression + :rtype: bool + """ + return not is_dynamic_expression(self._value) + + def to(self, dtype): + """Convert to another type. + + :param dtype: The target type for conversion + :type dtype: type + :return: The ScaledBasis converted to the specified type + :raises TypeError: If conversion to the specified type is not supported + """ + if dtype is ScaledBasis: + return self + elif dtype is _ScaledBasis: + if isinstance(self._value, Ratio): + scale = self._value + elif isinstance(self._value, Integer): + scale = self._value.ir_value() + else: + scale = self._value + + if isinstance(scale, IntValue): + return _ScaledBasis(scale.get_typed_value(), self._mode) + else: + return _ScaledBasis(scale, self._mode) + else: + raise TypeError(f"Cannot convert ScaledBasis to {dtype}") + + def __str__(self): + return f"{self.to(_ScaledBasis).__str__()}" + + def __hash__(self): + if isinstance(self.mode, list): + return hash((self.value, tuple(self.mode))) + else: + return hash((self.value, self.mode)) + + @property + def value(self): + """Get the scale value. + + :return: The scale value + """ + return self._value + + @property + def mode(self) -> List[int]: + """Get the mode identifying the basis element. + + :return: The mode as a list of integers + :rtype: List[int] + """ + return self._mode + + def __eq__(self, other): + if isinstance(other, ScaledBasis): + return self.value == other.value and self.mode == other.mode + else: + return False + + def __rmul__(self, scale: Union[Int, ir.Value, Ratio]) -> "ScaledBasis": + """Right multiplication by a scale factor. + + This operation is used in layout algebra to scale basis elements, + which is essential for operations like composition and partitioning. + + :param scale: The scale factor + :type scale: Union[Int, ir.Value, Ratio] + :return: A new scaled basis element + :rtype: ScaledBasis + :raises TypeError: If scale is not of a supported type + :raises NotImplementedError: If scaling a basis element with a ratio value + """ + if not isinstance(scale, (int, Integer, Ratio, ir.Value)): + raise TypeError( + f"scale must be an integer or a ratio, but got {type(scale)}" + ) + if isinstance(self.value, Ratio): + raise NotImplementedError( + "scaling a basis element having a ratio is not supported" + ) + + value = self.value + + if not isinstance(value, (Integer, Ratio, int, cutlass_arith.ArithValue)): + raise TypeError(f"Don't support {type(value)} for ScaledBasis") + + # Lift to IntValue type to preserve type info as much as possible + if isinstance(scale, cutlass_arith.ArithValue): + scale = IntValue(_pack_int_tuple(cutlass_arith.int_to_int(scale, Int32))) + + if isinstance(value, cutlass_arith.ArithValue): + value = IntValue(_pack_int_tuple(cutlass_arith.int_to_int(value, Int32))) + elif isinstance(value, Integer): + value = value.ir_value() + + return ScaledBasis(scale * value, self.mode) # type: ignore + + +def E(mode: Union[int, List[int]]) -> ScaledBasis: + """Create a unit ScaledBasis element with the specified mode. + + This function creates a ScaledBasis with value 1 and the given mode. + The mode represents the coordinate axis or dimension in the layout. + + :param mode: The mode (dimension) for the basis element, either a single integer or a list of integers + :type mode: Union[int, List[int]] + :return: A ScaledBasis with value 1 and the specified mode + :rtype: ScaledBasis + :raises TypeError: If mode is not an integer or a list + + **Examples:** + + .. code-block:: python + + # Create a basis element for the first dimension (mode 0) + e0 = E(0) + + # Create a basis element for the second dimension (mode 1) + e1 = E(1) + + # Create a basis element for a hierarchical dimension + e_hier = E([0, 1]) + """ + if isinstance(mode, int): + mode = [mode] + + if not isinstance(mode, list): + raise TypeError(f"expects a list, got {type(mode)}") + + if not mode: + return 1 + + return ScaledBasis(1, mode) + + +def get_divisibility(x: Union[int, Integer]) -> int: + if isinstance(x, int): + return x + + if isinstance(x, Integer): + x = x.value + + if isinstance(x, IntValue): + return x.divisibility + else: + return 1 + + +@ir.register_value_caster(_cute_ir.SwizzleType.get_static_typeid(), replace=True) +class Swizzle(ir.Value): + """ + Swizzle is a transformation that permutes the elements of a layout. + + Swizzles are used to rearrange data elements to improve memory access patterns + and computational efficiency. + + Swizzle is defined by three parameters: + - MBase: The number of least-significant bits to keep constant + - BBits: The number of bits in the mask + - SShift: The distance to shift the mask + + The mask is applied to the least-significant bits of the layout. + + .. code-block:: + + 0bxxxxxxxxxxxxxxxYYYxxxxxxxZZZxxxx + ^--^ MBase is the number of least-sig bits to keep constant + ^-^ ^-^ BBits is the number of bits in the mask + ^---------^ SShift is the distance to shift the YYY mask + (pos shifts YYY to the right, neg shifts YYY to the left) + + e.g. Given + 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxZZxxx + + the result is + 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxAAxxx where AA = ZZ `xor` YY + + """ + + def __str__(self): + # Cut off the MLIR type's string for making pretty_str more concise + return self.type.__str__()[15 : 15 + 8] + + +@ir.register_value_caster(_cute_ir.LayoutType.get_static_typeid(), replace=True) +class _Layout(Layout): + """Layout is CuTe's core abstraction for representing tensor layouts. + + A Layout maps from a logical coordinate space to an index space, defined by a + pair of (Shape, Stride). The Shape defines the abstract dimensions of the Layout, + while the Stride defines how coordinates within the Shape map to linear indices. + + Layouts present a common interface to multidimensional array access that abstracts + away the details of how array elements are organized in memory. This allows algorithms + to be written generically, so that layouts can change without requiring code changes. + + CuTe layouts are inherently hierarchical, constructed from smaller, nested layouts + that can represent complex mappings required by GPU tensor instructions. They support + a rich algebra of operations including concatenation, coalescence, composition, + complement, and inversion. + + :ivar shape: An IntTuple representing the dimensions of the layout. + :ivar stride: An IntTuple representing the strides of the layout. + :ivar max_alignment: The maximum alignment of the layout. + + **Examples:** + + .. code-block:: python + + # Creating a layout with shape (4,8) and default stride (layout left / "column major") + layout = cute.make_layout((4, 8)) + + # Creating a layout with explicit shape and stride + layout = cute.make_layout((4, 8), stride=(8, 1)) + + # Accessing a specific coordinate: (2, 3) -> 2 * 8 + 3 * 1 = 19 + idx = cute.crd2idx((2, 3), layout) + """ + + def __init__(self, op_result) -> None: + """Initialize a Layout object. + + :param op_result: The operation result value to wrap. + """ + super().__init__(op_result) + + def __str__(self) -> str: + """Return a string representation of the layout. + + :return: A string in the format "shape:stride". + """ + return f"{pretty_str(self.shape)}:{pretty_str(self.stride)}" + + @property + def shape(self, *, loc=None, ip=None) -> Shape: + """Get the shape of the layout. + + The shape defines the dimensions and structure of the layout's + coordinate space. + + :param loc: Optional location information for debugging. + :param ip: Optional insertion point for IR generation. + :return: The hierarchical shape of the layout. + """ + return _unpack_x_tuple(_cute_ir.get_shape(self, loc=loc, ip=ip), loc=loc, ip=ip) + + @property + def stride(self, *, loc=None, ip=None) -> Stride: + """Get the stride of the layout. + + The stride defines how coordinates map to linear indices in memory. + + :param loc: Optional location information for debugging. + :param ip: Optional insertion point for IR generation. + :return: The hierarchical stride of the layout. + """ + return _unpack_x_tuple( + _cute_ir.get_stride(self, loc=loc, ip=ip), loc=loc, ip=ip + ) + + @property + def max_alignment(self) -> int: + """Get the maximum alignment of the layout. + + :return: The maximum alignment in bytes. + """ + return self.type.max_alignment + + def __eq__(self, other) -> Union[bool, Boolean]: + """Check if this layout is equal to another layout. + + Two layouts are equal if they have the same shape and stride. + + :param other: The layout to compare with. + :return: True if layouts are equal, False otherwise. + May return an IR value for dynamic layouts. + """ + if isinstance(other, Layout): + if is_static(self.type) and is_static(other.type): + return self.type == other.type + return Boolean(_cute_ir.equal(self, other)) + else: + return False + + def __req__(self, other) -> Union[bool, Boolean]: + """Reflected equality check. + + :param other: The layout to compare with. + :return: Result of other.__eq__(self). + """ + if isinstance(other, Layout): + return other.__eq__(self) + return False + + def __ne__(self, other) -> Union[bool, Boolean]: + """Check if this layout is not equal to another layout. + + :param other: The layout to compare with. + :return: True if layouts are not equal, False otherwise. + """ + if isinstance(other, Layout): + if is_static(self.type) and is_static(other.type): + return self.type != other.type + return Boolean(not_(_cute_ir.equal(self, other))) + else: + return True + + def __rne__(self, other) -> Union[bool, Boolean]: + """Reflected inequality check. + + :param other: The layout to compare with. + :return: Result of other.__ne__(self). + """ + if isinstance(other, Layout): + return other.__ne__(self) + return False + + def __getitem__(self, idx: int) -> Layout: + """ + Top-level `get` to provide a syntax similar to `tuple`. + """ + return get(self, mode=[idx]) + + @dsl_user_op + def __call__(self, coord: Coord, loc=None, ip=None) -> IntTuple: + return crd2idx(coord, self, loc=loc, ip=ip) + + @dsl_user_op + def get_hier_coord(self, idx, *, loc=None, ip=None) -> Coord: + """Get the hierarchical coordinate corresponding to a linear index. + + This method maps from a linear index back to the logical coordinate + in the layout's coordinate space. + + :param idx: The linear index to convert. + :return: The hierarchical coordinate corresponding to the index. + + **Examples:** + + .. code-block:: python + + layout = make_layout((4, 8), stride=(8, 1)) + + # map linear index back to coordinate: 5 -> (1, 1) + coord = get_hier_coord(5, layout) + """ + idx_val = Int32(idx).ir_value() + crd = _cute_ir.get_hier_coord(idx_val, self, loc=loc, ip=ip) + return _unpack_x_tuple(crd) + + @dsl_user_op + def get_flat_coord(self, idx, *, loc=None, ip=None) -> Coord: + idx_val = Int32(idx).ir_value() + res = _cute_ir.get_flat_coord(idx_val, self, loc=loc, ip=ip) + return _unpack_x_tuple(res, loc=loc, ip=ip) + + +@ir.register_value_caster(_cute_ir.ComposedLayoutType.get_static_typeid(), replace=True) +class ComposedLayout(ir.Value): + r"""ComposedLayout represents the functional composition of layouts in CuTe. + + A ComposedLayout is formed by the composition of three components: + inner o offset o outer, where: + + - inner: The inner layout or swizzle that is applied last + - offset: An integer tuple representing a coordinate offset + - outer: The outer layout that is applied first + + ComposedLayout implements the functional composition operation where: + + .. math:: + + R(c) := (inner \\circ offset \\circ outer)(c) := inner(offset + outer(c)) + + This composition allows for complex transformations of coordinates and indices, + enabling operations like tiling, partitioning, and reshaping of data. + + :ivar inner: The inner layout or swizzle component + :ivar offset: The coordinate offset applied between inner and outer layouts + :ivar outer: The outer layout component + :ivar max_alignment: The maximum alignment of the composed layout + + **Examples:** + + .. code-block:: python + + # Create a composed layout with inner layout, offset, and outer layout + + # inner layout: (4, 8):(1, 4) + inner_layout = make_layout((4, 8)) + + offset = (0, 0) + + # outer layout: (2, 2):(1@0, 1@1) + outer_layout = make_layout((2, 2), stride=(1 * E(0), 1 * E(1))) + + # composed layout: (inner o offset o outer) + composed = make_composed_layout(inner_layout, offset, outer_layout) + + # Accessing components of the composed layout + inner = composed.inner + offset = composed.offset + outer = composed.outer + + # map coordinate (0, 1) to linear index + # - outer(0, 1) = (0, 1) + # - offset + outer(0, 1) = (0, 1) + # - inner(0, 1) = 0 * 1 + 1 * 4 = 4 + idx = crd2idx((0, 1), composed) + + # Composition is used in many tiling operations + # For example, in logical_product, raked_product, and blocked_product + """ + + def __init__(self, value) -> None: + """Initialize a ComposedLayout object. + + :param value: The operation result value to wrap. + """ + super().__init__(value) + + def __str__(self) -> str: + return f"{pretty_str(self.inner)} o {pretty_str(self.offset)} o {pretty_str(self.outer)}" + + @property + def inner(self, *, loc=None, ip=None) -> Union[Swizzle, Layout]: + return _cute_ir.composed_get_inner(self, loc=loc, ip=ip) + + @property + def offset(self, *, loc=None, ip=None) -> IntTuple: + return _unpack_x_tuple(_cute_ir.composed_get_offset(self, loc=loc, ip=ip)) + + @property + def outer(self, *, loc=None, ip=None) -> Layout: + return _cute_ir.composed_get_outer(self, loc=loc, ip=ip) + + @property + def shape(self, *, loc=None, ip=None) -> Shape: + return _unpack_x_tuple(_cute_ir.get_shape(self, loc=loc, ip=ip), loc=loc, ip=ip) + + @property + def max_alignment(self) -> int: + return self.type.max_alignment + + def __eq__(self, other) -> Union[bool, Boolean]: + if isinstance(other, ComposedLayout): + if is_static(self.type) and is_static(other.type): + return self.type == other.type + else: + raise NotImplementedError( + f"runtime comparison of composed layouts is not supported, got `{self}` and `{other}`" + ) + else: + return False + + def __req__(self, other) -> Union[bool, Boolean]: + if isinstance(other, ComposedLayout): + return Boolean(other.__eq__(self)) + return False + + def __ne__(self, other) -> Union[bool, Boolean]: + return not self.__eq__(other) + + def __rne__(self, other) -> Union[bool, Boolean]: + if isinstance(other, ComposedLayout): + return other.__ne__(self) + return False + + def __getitem__(self, idx: int) -> "ComposedLayout": + """ + Top-level `get` to provide a syntax similar to `tuple`. + """ + return get(self, mode=[idx]) + + @dsl_user_op + def __call__(self, coord: Coord, loc=None, ip=None) -> IntTuple: + return crd2idx(coord, self, loc=loc, ip=ip) + + +@ir.register_value_caster(_cute_ir.PtrType.get_static_typeid(), replace=True) +class _Pointer(Pointer): + """ + A pointer class representing a memory address with specific properties. + + Pointers are a fundamental type of iterator/engine that support random-access operations. + They can be offset by elements of a layout's codomain and dereferenced to produce values. + + :param value: The MLIR operation result value to initialize the pointer with + :type value: ir.Value + + :ivar type: The MLIR type of the pointer + :vartype type: Type + :ivar value_type: The type of value this pointer points to + :vartype value_type: Type + :ivar memspace: The memory space where the pointer data resides (e.g., gmem, smem, rmem) + :vartype memspace: AddressSpace + + :note: When composed with a layout, a pointer forms a tensor: T = E ∘ L, where E is the pointer + and L is the layout. The tensor evaluates the layout by mapping a coordinate c to the + codomain, offsets the pointer accordingly, and dereferences the result: + T(c) = (E ∘ L)(c) = *(E + L(c)) + """ + + def __init__(self, value) -> None: + assert isinstance(value, ir.Value) + self.value = ir.Value(value) + + def __str__(self) -> str: + # Cut off the MLIR type's string for making pretty_str more concise + return self.type.__str__()[6:] + + def __get_mlir_types__(self): + return [self.value.type] + + def __extract_mlir_values__(self): + return [self.value] + + def __new_from_mlir_values__(self, values): + # Only expecting single value of _Pointer instance or ir.Value + # In this context, a _Pointer instance is an encapsulated ir.Value which is automatically created + # by value caster for cute.ptr typed values + assert len(values) == 1, f"Expected 1 value, but got {len(values)}" + assert isinstance( + values[0], (_Pointer, ir.Value) + ), f"Expected _Pointer or ir.Value, but got {type(values[0])}" + return _Pointer( + values[0] if isinstance(values[0], ir.Value) else values[0].value + ) + + @property + @lru_cache_ir() + def dtype(self) -> Type[Numeric]: + return Numeric.from_mlir_type(self.value.type.value_type) + + @property + def alignment(self) -> int: + return self.type.alignment + + @property + def max_alignment(self) -> int: + return self.type.max_alignment + + @property + @lru_cache_ir() + def memspace(self) -> AddressSpace: + return AddressSpace(self.type.address_space) + + # Make it behave as if it inherited from ir.Value + @property + @lru_cache_ir() + def type(self) -> ir.Type: + return self.value.type + + # Only use if you absolutely need to get the LLVM pointer Value + @property + @lru_cache_ir() + def llvm_ptr(self, *, loc=None, ip=None) -> ir.Value: + """ + Get the LLVM pointer representation of this pointer. + + :param loc: The source location for the operation, defaults to None + :type loc: Location, optional + :param ip: The insertion point for the operation, defaults to None + :type ip: InsertionPoint, optional + :return: The LLVM pointer representation + :rtype: ir.Value + """ + llvm_ptr_ty = llvm.PointerType.get(self.memspace.value) + return builtin.unrealized_conversion_cast( + [llvm_ptr_ty], [self.value], loc=loc, ip=ip + ) + + def __add__(self, offset: IntTuple) -> Pointer: + """ + Offset the pointer by elements of a layout's codomain. + + :param offset: The offset to add to the pointer + :type offset: IntTuple + :return: A new pointer offset by the specified amount + :rtype: ir.Value + """ + offset = _pack_int_tuple(offset) + return _cute_ir.add_offset(self.value, offset=offset) + + @dsl_user_op + def toint(self, *, loc=None, ip=None): + if self.memspace in (AddressSpace.gmem, AddressSpace.generic): + res_type = Int64 + else: + res_type = Int32 + + return res_type( + _cute_ir.ptrtoint(res_type.mlir_type, self.value, loc=loc, ip=ip) + ) + + @dsl_user_op + def align(self, min_align: int, *, loc=None, ip=None) -> Pointer: + """ + Align a pointer to a specified byte alignment. + + :param min_align: The minimum byte alignment requirement. Must be a power of 2. + :type min_align: int + :param loc: The source location for the operation, defaults to None + :type loc: Location, optional + :param ip: The insertion point for the operation, defaults to None + :type ip: InsertionPoint, optional + :return: The aligned new pointer that satisfies alignment request. + :rtype: Pointer + :raises ValueError: If the alignment is not a power of 2. + :raises TypeError: If pointer is in tmem address space. + """ + + if (min_align & (min_align - 1)) != 0: + raise ValueError("Alignment must be a power of 2") + + assert isinstance(self.type, _cute_ir.PtrType) + if self.memspace is AddressSpace.tmem: + raise ValueError("aligning a TMEM pointer is not supported") + + if min_align <= self.alignment: + return self + + dtype = Numeric.from_mlir_type(self.type.value_type) + # Convert pointer to integer + address_int = self.toint(loc=loc, ip=ip) + # Align the address + aligned_address = (address_int + min_align - 1) & ~(min_align - 1) + + return make_ptr( + dtype, + aligned_address, + self.memspace, + assumed_align=min_align, + loc=loc, + ip=ip, + ) + + +@ir.register_value_caster(_cute_ir.MemRefType.get_static_typeid(), replace=True) +@ir.register_value_caster(_cute_ir.CoordTensorType.get_static_typeid(), replace=True) +@ir.register_value_caster( + _cute_nvgpu_ir.SmemDescViewType.get_static_typeid(), replace=True +) +class _Tensor(Tensor): + """A tensor class representing the composition of an iterator (engine) with a layout. + + A tensor evaluates the layout by mapping a coordinate to the codomain, offsets the + iterator accordingly, and dereferences the result to obtain the tensor's value. + Formally: T(c) = (E ∘ L)(c) = *(E + L(c)), where E is the iterator/engine and L is the layout. + + :param value: The MLIR operation result value to initialize the tensor with + :type value: ir.Value + :param dtype: The user specified data type of the tensor elements. It could be \ + different from the underlying dtype in the iterator. The default is None. + :type dtype: Type[Numeric], optional + + Attributes: + iterator: The pointer or iterator (engine) component of the tensor + layout: The layout component defining the mapping from coordinates to offsets + shape: The shape of the tensor, inherited from the layout + stride: The stride of the tensor, inherited from the layout + element_type: The data type of the tensor elements + memspace: The memory space where the tensor data resides + + Notes: + - The tensor supports both direct element access via coordinates and slicing operations + - Load/store operations are only supported for specific memory spaces (rmem, smem, gmem, generic) + - For composed layouts, stride information is not directly accessible + - Dynamic layouts do not support vector load/store operations + + **Examples:** + + .. code-block:: python + + # Create a tensor with shape (4,8) in row-major layout + tensor = make_tensor(ptr, make_layout(shape=(4,8), stride=(8,1))) + + # Access individual element + val = tensor[0, 0] # or val = tensor[(0, 0)] + + # Slice operation - get first column + subtensor = tensor[None, 0] # or subtensor = tensor[(None, 0)] + """ + + def __init__(self, value, dtype: Optional[Type[Numeric]] = None): + self._dtype = dtype + if isinstance(value, ir.Value): + self.value = value + elif isinstance(value, _Tensor): + self.value = value.value + else: + raise TypeError(f"Expected ir.Value or core._Tensor, got {type(value)}") + + # Set iterator + iter_val = _cute_ir.get_iter(self.value) + if isinstance(iter_val, Pointer): + self._iterator = iter_val + elif isinstance(iter_val.type, _cute_ir.IntTupleType): + self._iterator = _unpack_x_tuple(iter_val) + elif isinstance(iter_val, ir.Value): + # Example: SMEM descriptor iterator, not well supported today + self._iterator = iter_val + else: + raise TypeError(f"unsupported iterator type, got {type(iter_val)}") + + # Set dtype + if self._dtype is None: + if is_int_tuple(self.iterator): + self._dtype = IntTuple + elif isinstance(self.iterator, Pointer): + self._dtype = self.iterator.value_type + elif isinstance(self.type, _cute_nvgpu_ir.SmemDescViewType): + # SmemDescViewType do not need dtype + self._dtype = None + else: + raise TypeError(f"unsupported iterator type, got {type(self.iterator)}") + + def __str__(self): + return f"tensor<{pretty_str(self.iterator)} o {pretty_str(self.layout)}>" + + def __extract_mlir_values__(self): + return [self.value] + + def __new_from_mlir_values__(self, values): + # Only expecting single value of _Tensor or ir.Value + # In this context, a _Tensor instance is an encapsulated ir.Value which is automatically created + # by value caster for MemRef/CoordTensor/SmemDescView typed values + assert len(values) == 1, f"Expected 1 value, but got {len(values)}" + assert isinstance( + values[0], (_Tensor, ir.Value) + ), f"Expected _Tensor or ir.Value, but got {type(values[0])}" + return _Tensor( + values[0] if isinstance(values[0], ir.Value) else values[0].value, + dtype=self.element_type, + ) + + # Cheat to let `Type(_Tensor())` to return cute.Tensor + @property + def __class__(self) -> Type[Tensor]: + return Tensor + + # Make it behave as if it inherited from ir.Value + @property + @lru_cache_ir() + def type(self) -> ir.Type: + return self.value.type + + @dsl_user_op + def __getitem__( + self, crd: Coord, *, loc=None, ip=None + ) -> Union[Tensor, Numeric, IntTuple]: + """Access or slice tensor elements using coordinates. + + This method implements + * tensor evaluation T(c) = *(E + L(c)) when `c` is a coordinate without slicing, or + * tensor slicing operations T(c) = make_tensor(E + L(c), slice(L, c)) + where E is the iterator/engine and L is the layout + + :param crd: Coordinate or slice specification for accessing tensor elements + :type crd: Coord + :param loc: Source location for MLIR operation tracking, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for MLIR operation, defaults to None + :type ip: Optional[InsertionPoint] + :return: Tensor element value or sliced subtensor + :rtype: Union[Tensor, ir.Value, IntTuple] + + :raises ValueError: If coordinate access is invalid for the tensor layout + + **Examples:** + + .. code-block:: python + + # Create a tensor with pointer iterator + ptr = make_ptr(cutlass.Float32, 0, cutlass.AddressSpace.gmem) + layout = make_layout((64, 128)) # leftmost mode is major + tensor = make_tensor(ptr, layout) # Tensor using pointer iterator + + # Direct element access loads from memory + val = tensor[0] # Loads element at offset 0 + val = tensor[1] # Loads element at offset 4 (4bytes per Float32) + val = tensor[(0, 1)] # Loads element at offset 64 + + # Create a coord tensor + layout = make_layout((64, 128), stride=(1 * E(0), 1 * E(1))) + tensor = make_tensor((128, 128), layout) + + # Direct element access + val = tensor[0] # Returns (128, 128) + val = tensor[(0, 1)] # Returns (128, 129) + + # Slice access + sliced = view[(3, None)] # Returns tensor slice + + .. note:: + Sub-byte types like Float4E2M1FN and Float6E3M2FN are not supported for scalar + dereference operations. Attempting to set individual elements of tensors with + these element types will result in errors. + + **Examples:** + + .. code-block:: python + + # Unsupported operations with sub-byte types: + ptr = make_ptr(cutlass.Float4E2M1FN, 0, cutlass.AddressSpace.gmem) + tensor = make_tensor(ptr, layout) + # The following will raise an error: + val = tensor[0] # Error: sub-byte scalar dereference not supported + + # Similarly for other sub-byte types: + ptr = make_ptr(cutlass.Float6E3M2FN, 0, cutlass.AddressSpace.gmem) + tensor = make_tensor(ptr, layout) + val = tensor[0] # Error: sub-byte scalar dereference not supported + """ + if has_underscore(crd): + return slice_(self.value, crd) + elif isinstance(self.type, _cute_ir.CoordTensorType): + res = _cute_ir.get_iter(slice_(self, crd).value, loc=loc, ip=ip) + return _unpack_x_tuple(res) + else: + self._check_can_load_store() + self._check_can_dereference() + + crd_val = _pack_coord(crd, loc=loc, ip=ip) + data_val = _cute_ir.memref_load(self.value, crd_val, loc=loc, ip=ip) + return self.element_type(data_val) + + def _cvt_to_dest(self, data: Union["TensorSSA", Numeric], *, loc=None, ip=None): + orig_dtype = data.dtype + # Implicit upcast to wider type + if ( + data.dtype.is_same_kind(self.element_type) + and self.element_type.width >= data.dtype.width + ): + data = data.to(self.element_type, loc=loc, ip=ip) # type: ignore + + if data.dtype.width != self.element_type.width: + raise ValueError( + f"Type mismatch, store {orig_dtype} (-> {data.dtype}) " + f"to Tensor with element type {self.element_type}" + ) + + if data.dtype is Boolean and self.element_type is Boolean: + # Boolean Numeric and Boolean TensorSSA both hold i1 value, but we need int8 value store to memory + val = data.ir_value_int8() + else: + val = data.ir_value() + return val + + @dsl_user_op + def __setitem__( + self, + crd: Coord, + data: Union[int, float, ir.Value, Numeric, "TensorSSA"], + *, + loc=None, + ip=None, + ) -> None: + """Set tensor elements at specified coordinates. + + Assigns values to tensor elements through direct coordinate access or slice assignment. + For slice assignment, the value must be a TensorSSA with matching shape. + + :param crd: Coordinate or slice specification for tensor element assignment + :type crd: Coord + :param data: Value to assign - can be scalar or TensorSSA for slice assignment + :type data: Union[int, float, ir.Value, Numeric, TensorSSA] + :param loc: Source location for MLIR operation tracking, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for MLIR operation, defaults to None + :type ip: Optional[InsertionPoint] + + :raises ValueError: If tensor type doesn't support load/store operations + :raises ValueError: If slice assignment value is not a TensorSSA + :raises ValueError: If value type doesn't match tensor element type + :raises NotImplementedError: If value type is not supported + + .. note:: + Sub-byte types like Float4E2M1FN and Float6E3M2FN are not supported for scalar + dereference operations. Attempting to set individual elements of tensors with + these element types will result in errors. + + **Examples:** + + .. code-block:: python + + # Unsupported operations with sub-byte types: + ptr = make_ptr(cutlass.Float4E2M1FN, 0, cutlass.AddressSpace.gmem) + tensor = make_tensor(ptr, layout) + # The following will raise an error: + tensor[0] = 1.0 # Error: sub-byte scalar dereference not supported + + # Similarly for other sub-byte types: + ptr = make_ptr(cutlass.Float6E3M2FN, 0, cutlass.AddressSpace.gmem) + tensor = make_tensor(ptr, layout) + tensor[0] = 0.5 # Error: sub-byte scalar dereference not supported + """ + self._check_can_load_store() + + # convert scalar type + if not has_underscore(crd): + self._check_can_dereference() + # First, convert ir.Value to Numeric + if isinstance(data, ir.Value): + data = as_numeric(data) + elif isinstance(data, (int, float, bool)): + data = as_numeric(data) + + if not isinstance(data, Numeric): + raise ValueError(f"unsupported data type: {type(data)}") + + # Implicit upcast to wider type + val = self._cvt_to_dest(data, loc=loc, ip=ip) + if val.type != self.type.value_type: + raise ValueError( + f"type mismatch, store {val.type} to {self.element_type}" + ) + + crd_val = _pack_coord(crd, loc=loc, ip=ip) + _cute_ir.memref_store(self.value, crd_val, val, loc=loc, ip=ip) + else: + if not isinstance(data, TensorSSA): + raise ValueError(f"expects TensorSSA, but got {data}") + + self.__getitem__(crd).store(data, loc=loc, ip=ip) # type: ignore + + @property + def __class__(self) -> Type[Tensor]: + return Tensor + + # Make it behave as if it inherited from ir.Value + @property + @lru_cache_ir() + def type(self) -> ir.Type: + return self.value.type + + @property + def iterator(self) -> Union[Pointer, IntTuple]: + return self._iterator + + @property + def layout(self) -> Layout: + return _cute_ir.get_layout(self.value) + + @property + def shape(self) -> Shape: + return self.layout.shape + + @property + def stride(self) -> Stride: + if isinstance(self.type, _cute_ir.ComposedLayoutType): + raise ValueError(f"can't get stride from composed layout") + return self.layout.stride + + @property + def leading_dim(self) -> Union[int, Tuple[int], None]: + """Get the leading dimension of this Tensor. + + :return: The index or indices of the first mode (from left to right) with stride 1 + :rtype: Union[int, Tuple[int], None] + :returns: + - int: Single leading dimension index if found + - Tuple[int]: Tuple of indices for nested leading dimensions + - None: If no leading dimension is found + + :postcondition: ``get(self.stride(), mode=self.leading_dim()) == 1 if self.leading_dim() != None else True`` + """ + return leading_dim(self.shape, self.stride) + + @property + @lru_cache_ir() + def element_type(self) -> Union[Type[Numeric], Type[IntTuple]]: + return self._dtype + + @property + @lru_cache_ir() + def memspace(self) -> AddressSpace: + if isinstance(self.iterator, Pointer): + return self.iterator.memspace + + raise ValueError(f"{self} doesn't have memspace") + + @dsl_user_op + def load(self, *, loc=None, ip=None) -> "TensorSSA": + """Load tensor elements as a vector. + + Loads all elements of the tensor into a vector representation, assuming the tensor + has a static shape and is in a memory space that supports load operations. + + :param loc: Source location for MLIR operation tracking, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for MLIR operation, defaults to None + :type ip: Optional[InsertionPoint] + :return: Vector representation of tensor elements + :rtype: TensorSSA + + :raises ValueError: If tensor has dynamic layout + :raises ValueError: If tensor memory space doesn't support load operations + """ + if not is_static(self.shape): + raise ValueError("dynamic layout doesn't support load") + + self._check_can_load_store() + + res_vect = _cute_ir.memref_load_vec(self.value, row_major=True, loc=loc, ip=ip) + if self.element_type is Boolean: + assert ( + res_vect.type.element_type == T.i8() + ), f"Boolean tensor must be stored as i8 in memory, but got {res_vect.type.element_type}" + zeros = full_like(self, 0, Int8, loc=loc, ip=ip) + res_vect = arith.cmpi( + arith.CmpIPredicate.ne, res_vect, zeros, loc=loc, ip=ip + ) + return TensorSSA(res_vect, self.shape, self.element_type) + + @dsl_user_op + def store(self, data: "TensorSSA", *, loc=None, ip=None): + """Store vector data into tensor. + + Stores vector data into the tensor, assuming matching shapes and a memory space + that supports store operations. + + :param data: Vector data to store into tensor + :type data: TensorSSA + :param loc: Source location for MLIR operation tracking, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for MLIR operation, defaults to None + :type ip: Optional[InsertionPoint] + + :raises ValueError: If tensor has dynamic layout + :raises ValueError: If tensor memory space doesn't support store operations + :raises ValueError: If data shape doesn't match tensor shape + """ + if not isinstance(data, TensorSSA): + raise ValueError(f"Expects TensorSSA, but got {type(data)}") + + if not is_static(self.shape): + raise ValueError("Dynamic layout doesn't support vectorized store") + + self._check_can_load_store() + + n_elems = size(self.shape, loc=loc, ip=ip) + if n_elems != size(data.shape, loc=loc, ip=ip): + raise ValueError( + f"lhs and rhs must have the same shape, but got {self.shape} and {data.shape}" + ) + + elem_mlir_type = cutlass_arith.element_type(data.dtype.mlir_type) + if cutlass_arith.is_narrow_precision(elem_mlir_type): + if elem_mlir_type.width * n_elems % 32 != 0: + raise ValueError( + f"narrow precision type must be 32-bit aligned vector, but got {elem_mlir_type} with {n_elems} elements" + ) + + # Implicit upcast to wider type + new_data = self._cvt_to_dest(data, loc=loc, ip=ip) + + return _cute_ir.memref_store_vec( + new_data, self.value, row_major=True, loc=loc, ip=ip + ) + + @dsl_user_op + def fill(self, value: Numeric, *, loc=None, ip=None) -> None: + """Fill tensor with a constant value. + + Fills all elements of the tensor with the specified value, assuming static size + and supported memory space. + + :param value: Value to fill tensor with + :type value: Union[int, float] + :param loc: Source location for MLIR operation tracking, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for MLIR operation, defaults to None + :type ip: Optional[InsertionPoint] + + :raises NotImplementedError: If tensor has dynamic size + + **Examples:** + + .. code-block:: python + + # Create tensor from numpy array + b = np.random.randn(4, 8).astype(np.float32) + tensor = from_dlpack(b) + + # Fill tensor with constant value + tensor.fill(0.5) # All elements become 0.5 + """ + self._check_can_load_store() + + sz = size(self, loc=loc, ip=ip) + if type(sz) is not int: + raise NotImplementedError(f"dynamic size is not supported: {self.type}") + + # Should we cast to destination type even with narrow cast? + dst_type = self.element_type + value = dst_type(value) + + self[None] = full(self.shape, fill_value=value, dtype=dst_type, loc=loc, ip=ip) + + def _check_can_load_store(self): + if not isinstance(self.type, _cute_ir.MemRefType) or not self.memspace in ( + AddressSpace.rmem, + AddressSpace.smem, + AddressSpace.gmem, + AddressSpace.generic, + ): + raise ValueError(f"{self} doesn't support load and store") + + def _check_can_dereference(self): + # Check for sub-byte types and raise error if needed + if self.element_type.width % 8 != 0 and self.element_type is not Boolean: + raise ValueError( + f"Sub-byte scalar dereference not supported for type {self.element_type}" + ) + + +@dsl_user_op +def print_tensor( + tensor: Union[Tensor, "TensorSSA"], *, verbose: bool = False, loc=None, ip=None +): + """Print content of the tensor in human readable format. + + Outputs the tensor data in a structured format showing both metadata + and the actual data values. The output includes tensor type information, + layout details, and a formatted array representation of the values. + + :param tensor: The tensor to print + :type tensor: Tensor + :param verbose: If True, includes additional debug information in the output + :type verbose: bool + :param loc: Source location where it's called, defaults to None + :type loc: source location, optional + :param ip: Insertion pointer for IR generation, defaults to None + :type ip: insertion pointer, optional + :raises NotImplementedError: If the tensor type doesn't support trivial dereferencing + + **Example output:** + + .. code-block:: text + + tensor(raw_ptr<@..., Float32, generic, align(4)> o (8,5):(5,1), data= + [[-0.4326, -0.5434, 0.1238, 0.7132, 0.8042], + [-0.8462, 0.9871, 0.4389, 0.7298, 0.6948], + [ 0.3426, 0.5856, 0.1541, 0.2923, 0.6976], + [-0.1649, 0.8811, 0.1788, 0.1404, 0.2568], + [-0.2944, 0.8593, 0.4171, 0.8998, 0.1766], + [ 0.8814, 0.7919, 0.7390, 0.4566, 0.1576], + [ 0.9159, 0.7577, 0.6918, 0.0754, 0.0591], + [ 0.6551, 0.1626, 0.1189, 0.0292, 0.8655]]) + """ + if isinstance(tensor, TensorSSA): + tmp = make_fragment(tensor.shape, tensor.dtype) + tmp.store(tensor) + tensor = tmp + + if not isinstance(tensor.type, _cute_ir.MemRefType): + raise NotImplementedError( + f"printing {tensor} is not supported because it doesn't support trivial dereferencing. " + f"Coordinate Tensor will be supported in the future." + ) + + tensor._check_can_load_store() # type: ignore + + if tensor.element_type.is_integer: + signed = tensor.element_type.signed + else: + signed = False + + _cute_ir.print_view(tensor.value, verbose=verbose, is_signed=signed, loc=loc, ip=ip) + + +#################################################################################################### +# +# Core API +# +#################################################################################################### + + +# +# Utilties +# + + +@lru_cache_ir() +def is_integer(a) -> bool: + """Check if an object is static integer or dynamic integer""" + return isinstance(a, (int, Integer)) or ( + isinstance(a, ir.Value) + and isinstance(a.type, (ir.IntegerType, _cute_ir.ConstrainedIntType)) + ) + + +def is_valid_leaf(a) -> bool: + """ + Returns whether `a` has a type that is valid for a CuTe tuple's leaf. + """ + return ( + is_integer(a) + or (a is None) + or isinstance(a, (ScaledBasis, Layout, ComposedLayout)) + ) + + +def is_int_tuple(a) -> bool: + if isinstance(a, tuple): + return all([is_int_tuple(x) for x in a]) + else: + return is_integer(a) + + +def is_static(x: Union[ir.Type, ir.Value, XTuple]) -> bool: + """Check if a value is statically known at compile time. + + In CuTe, static values are those whose values are known at compile time, + as opposed to dynamic values which are only known at runtime. + + :param x: The value to check + :type x: Union[ir.Type, ir.Value, XTuple] + :return: True if the value is static, False otherwise + :rtype: bool + :raises TypeError: If an unsupported type is provided + """ + if isinstance(x, ir.Type): + return _cute_ir.is_static(x) + elif isinstance(x, tuple): + return all(is_static(a) for a in x) + # Can it be a static int? + elif isinstance(x, Numeric): + return False + elif is_dynamic_expression(x): + return _cute_ir.is_static(x.type) + elif isinstance(x, (bool, int, float)) or x is None: + return True + elif isinstance(x, ScaledBasis): + return x.is_static() + else: + raise TypeError(f"unsupported type {x}") + + +def has_underscore(a: XTuple) -> bool: + if type(a) is tuple: + return any([has_underscore(x) for x in a]) + else: + return a is None + + +def has_scaled_basis(a: XTuple) -> bool: + """Check if a tuple or its nested elements contain ScaledBasis objects. + + ScaledBasis objects are fundamental components in CuTe layouts, + representing the basis vectors of coordinate systems. + + :param a: The tuple to check + :type a: XTuple + :return: True if the tuple contains ScaledBasis objects, False otherwise + :rtype: bool + """ + if type(a) is tuple: + return any([has_scaled_basis(x) for x in a]) + else: + return isinstance(a, ScaledBasis) + + +def _tuple_str(t: tuple) -> str: + """ + Constructs a string representation of a python tuple without calling __repr__ on its elements. + """ + + def construct_inner_str(t) -> str: + if not isinstance(t, tuple): + return pretty_str(t) + res = "" + l = len(t) + for i in range(l): + res += pretty_str(t[i]) + if i < l - 1: + res += "," + return res + + res = "(" + construct_inner_str(t) + ")" + return res + + +def pretty_str(arg) -> str: + """ + Constructs a concise readable pretty string. + """ + if isinstance(arg, tuple): + # _tuple_str for tuples + return _tuple_str(arg) + elif arg is None: + # We interpret None as underscores for slicers + return "_" + else: + # Fallback to __str__ + return arg.__str__() + + +@dsl_user_op +def printf(*args, loc=None, ip=None) -> None: + """ + Print a value or a list of values. + + It supports c-style printf format as well: + + .. code-block:: python + + a = cute.make_layout(shape=(10, 10), stride=(10, 1)) + b = cutlass.Float32(1.234) + cute.printf(a, b) + cute.printf("a={}, b={}", a, b) + cute.printf("a={}, b=%.2f", a, b) + + :param args: List of values to print + :type args: list + :param loc: Source location where it's called, defaults to None + :type loc: source location, optional + :param ip: Insertion pointer, defaults to None + :type ip: insertion pointer, optional + :raises ValueError: If no arguments are provided or if an unsupported argument type is passed + """ + + if len(args) == 0: + raise ValueError("expects at least one argument to print") + + if isinstance(args[0], str): + fmt = args[0] + "\n" + args = args[1:] + else: + fmt = "{}" + ", {}" * (len(args) - 1) + "\n" + + def process_arg(arg): + arg0 = arg.value if isinstance(arg, Numeric) else arg + + if isinstance(arg0, ir.Value): + return arg0 + elif isinstance(arg0, bool): + return const(arg0, Boolean) + elif isinstance(arg0, int): + return const(arg0, Int32) + elif isinstance(arg0, float): + return const(arg0, Float32) + elif has_underscore(arg0): + # Assume it's a coordinate + return _pack_coord(arg0) + elif has_scaled_basis(arg0): + # Assume it's a stride + return _pack_stride(arg0) + elif isinstance(arg0, tuple): + # Assume it's an int_tuple + return _pack_int_tuple(arg0) + elif isinstance(arg0, (_Tensor, _Pointer)): + return arg0.value + else: + raise TypeError(f"unsupported argument type in printf, got {type(arg)}") + + args = [process_arg(a) for a in args] + _cute_ir.print_(args, fmt=fmt, loc=loc, ip=ip) + + +@dsl_user_op +def front(input, *, loc=None, ip=None): + """Recursively get the first element of input. + + This function traverses a hierarchical structure (like a layout or tensor) + and returns the first element at the deepest level. It's particularly useful + for accessing the first stride value in a layout to determine properties like + majorness. + + :param input: The hierarchical structure to traverse + :type input: Union[Tensor, Layout, Stride] + :param loc: Source location where it's called, defaults to None + :type loc: source location, optional + :param ip: Insertion pointer for IR generation, defaults to None + :type ip: insertion pointer, optional + :return: The first element at the deepest level of the input structure + :rtype: Union[int, float, bool, ir.Value] + """ + if rank(input) == 1 and depth(input) == 0: + return input + else: + return front(get(input, mode=[0], loc=loc, ip=ip), loc=loc, ip=ip) + + +@dsl_user_op +def is_major(mode, stride: Stride, *, loc=None, ip=None) -> bool: + """ + Check whether a mode in stride is the major mode. + """ + first_stride = front(get(stride, mode=[mode], loc=loc, ip=ip), loc=loc, ip=ip) + if is_dynamic_expression(first_stride): + return False + return True if first_stride == 1 else False + + +def leading_dim(shape: Shape, stride: Stride) -> Union[int, Tuple[int, ...], None]: + """ + Find the leading dimension of a shape and stride. + + :param shape: The shape of the tensor or layout + :type shape: Shape + :param stride: The stride of the tensor or layout + :type stride: Stride + :return: The leading dimension index or indices + :rtype: Union[int, Tuple[int, ...], None] + + The return value depends on the stride pattern: + + * If a single leading dimension is found, returns an integer index + * If nested leading dimensions are found, returns a tuple of indices + * If no leading dimension is found, returns None + """ + + def pred_fn(val, pos): + # skip dynamic values which can't be compared + # find the candidate target val, stride at this position is 1 + if (not is_dynamic_expression(val)) and (val == 1): + # extract the shape at this position + mode = [pos] if isinstance(pos, int) else list(pos) + s = get(shape, mode) + if is_dynamic_expression(s) or s != 1: + # shape at this position is dynamic value or not 1 + # we found the leading dimension + return True + return False + + return find_if(stride, pred_fn=pred_fn) + + +@dsl_user_op +def find_if( + t: Union[tuple, ir.Value, int], + pred_fn: Callable[[int, Tuple[int, ...]], bool], + *, + loc=None, + ip=None, +) -> Union[int, Tuple[int, ...], None]: + """Find the first position in t where pred_fn(val, pos) returns True. + + :param t: The search space + :type t: Union[tuple, ir.Value, int] + :param pred_fn: A callable object (lambda, function, etc.) that predicates the value and position in t. + It takes the current leaf value and position, returns True if the value or position is satisfied. + :type pred_fn: Callable[[int, Tuple[int, ...]], bool] + :return: Index if found at top level, tuple of indices showing nested position, or None if not found + :rtype: Union[int, Tuple[int, ...], None] + + **Examples:** + + .. code-block:: python + + # Find the first position of x in t + t = (3, 4) + find_if(t, pred_fn=lambda val, pos: val == x) + + .. code-block:: python + + # find the leading dimension + shape = (3, 4) + stride = (4, 1) + # Find value 1 in stride where the corresponding shape is not 1 + def pred_fn(val, pos): + mode = [pos] if isinstance(pos, int) else list(pos) + return val == 1 and get(shape, mode) != 1 + find_if(stride, pred_fn=pred_fn) + """ + + def _find_if_impl(curr, pos, *, loc=None, ip=None): + if isinstance(curr, tuple): + # Recursively search nested tuple + for i in range(rank(curr)): + sub_curr = get(curr, mode=[i], loc=loc, ip=ip) + sub_pos = (pos, i) if isinstance(pos, int) else pos + (i,) + res_pos = _find_if_impl(sub_curr, sub_pos, loc=loc, ip=ip) + if res_pos is not None: + return res_pos + else: + # For leaf values, check if it matches x + if pred_fn(curr, pos): + return pos + return None + + def _check_pred_fn(): + if not callable(pred_fn): + raise TypeError(f"pred_fn must be callable, but got {type(pred_fn)}") + signature = inspect.signature(pred_fn) + if len(signature.parameters) != 2: + raise ValueError( + f"pred_fn must have two parameters (value, pos), but got {len(signature.parameters)}" + ) + + _check_pred_fn() + + for i in range(rank(t)): + curr = get(t, mode=[i], loc=loc, ip=ip) + res_pos = _find_if_impl(curr, i, loc=loc, ip=ip) + if res_pos is not None: + return res_pos + return None + + +@dsl_user_op +def find( + t: Union[tuple, ir.Value, int], + x: int, + *, + loc=None, + ip=None, +) -> Union[int, Tuple[int, ...], None]: + """Find the first position of a value ``x`` in a hierarchical structure ``t``. + + Searches for the first occurrence of x in t, optionally excluding positions + where a comparison value matches. The search can traverse nested structures + and returns either a single index or a tuple of indices for nested positions. + + :param t: The search space + :type t: Union[tuple, ir.Value, int] + :param x: The static integer x to search for + :type x: int + :return: Index if found at top level, tuple of indices showing nested position, or None if not found + :rtype: Union[int, Tuple[int, ...], None] + """ + if not isinstance(x, int): + raise TypeError(f"find() requires a static x to search for, but got {x}") + + def pred_fn(val, pos): + # Skip dynamic values which can't be compared + return not is_dynamic_expression(val) and val == x + + return find_if(t, pred_fn=pred_fn, loc=loc, ip=ip) + + +def transform_leaf(f, *args): + """ + Apply a function to the leaf nodes of nested tuple structures. + + This function traverses nested tuple structures in parallel and applies the function f + to corresponding leaf nodes. All input tuples must have the same nested structure. + + :param f: Function to apply to leaf nodes + :type f: Callable + :param args: One or more nested tuple structures with matching profiles + :return: A new nested tuple with the same structure as the inputs, but with leaf values transformed by f + :raises TypeError: If the input tuples have different nested structures + + Example: + + .. code-block:: python + + >>> transform_leaf(lambda x: x + 1, (1, 2)) + (2, 3) + >>> transform_leaf(lambda x, y: x + y, (1, 2), (3, 4)) + (4, 6) + >>> transform_leaf(lambda x: x * 2, ((1, 2), (3, 4))) + ((2, 4), (6, 8)) + """ + if all(isinstance(t, tuple) for t in args): + return tuple(transform_leaf(f, *_args) for _args in zip(*args)) + elif all(not isinstance(t, tuple) for t in args): + return f(*args) + else: + raise TypeError(f"profile of input tuples doesn't match: {args}") + + +@dsl_user_op +def assume(src, divby=None, *, loc=None, ip=None): + if divby is None: + return src + + if isinstance(src, Integer): + width = type(src).width + src_val = src.ir_value() + else: + width = src.type.width + src_val = src + + res_ty = _cute_ir.ConstrainedIntType.get(divby, width) + assumed_val = _cute_ir.assume(res_ty, src_val, loc=loc, ip=ip) + return type(src)(IntValue(_pack_int_tuple(assumed_val, loc=loc, ip=ip))) + + +@dsl_user_op +def make_swizzle(b, m, s, *, loc=None, ip=None): + # canonicalize to <0, 4, 3> for identity swizzle (as compiler assumes <0, 4, 3>) + if b == 0: + m, s = 4, 3 + ty = ir.Type.parse(f'!cute.swizzle<"S<{b},{m},{s}>">') + return Swizzle(_cute_ir.static(ty, loc=loc, ip=ip)) + + +# +# Tuple API (also used by layouts and tensors) +# + + +def depth(a: Union[XTuple, Layout, "ComposedLayout"]) -> int: + """Returns the depth (nesting level) of a tuple, layout, or tensor. + + The depth of a tuple is the maximum depth of its elements plus 1. + For an empty tuple, the depth is 1. For layouts and tensors, the depth + is determined by the depth of their shape. For non-tuple values (e.g., integers), + the depth is considered 0. + + :param a: The object whose depth is to be determined + :type a: Union[XTuple, Layout, ComposedLayout, Tensor, Any] + :return: The depth of the input object + :rtype: int + + Example: + + .. code-block:: python + + >>> depth(1) + 0 + >>> depth((1, 2)) + 1 + >>> depth(((1, 2), (3, 4))) + 2 + """ + if type(a) is tuple: + if not a: + return 1 + return max(depth(x) for x in a) + 1 + elif isinstance(a, (Layout, ComposedLayout, Tensor)): + return depth(a.shape) + else: + return 0 + + +@lru_cache_ir() +def rank(a: Union[XTuple, Layout, "ComposedLayout"]) -> int: + """Returns the rank (dimensionality) of a tuple, layout, or tensor. + + The rank of a tuple is its length. For layouts and tensors, the rank is + determined by the rank of their shape. For non-tuple values (e.g., integers), + the rank is considered 1 for convenience. + + :param a: The object whose rank is to be determined + :type a: Union[XTuple, Layout, ComposedLayout, Tensor, Any] + :return: The rank of the input object + :rtype: int + + This function is used in layout algebra to determine the dimensionality + of tensors and layouts for operations like slicing and evaluation. + """ + if isinstance(a, tuple): + return len(a) + elif isinstance(a, (Layout, ComposedLayout, Tensor)): + return rank(a.shape) + elif depth(a) == 0: + return 1 + else: + raise TypeError(f"unsupported type in rank, got {type(a)}") + + +def is_congruent( + a: Union[XTuple, Layout, ComposedLayout, Tensor], + b: Union[XTuple, Layout, ComposedLayout, Tensor], +) -> bool: + """ + Returns whether a is congruent to b. + + Congruence is an equivalence relation between hierarchical structures. + + Two objects are congruent if: + * They have the same rank, AND + * They are both non-tuple values, OR + * They are both tuples AND all corresponding elements are congruent. + + Congruence requires type matching at each level -- scalar values match with + scalar values, and tuples match with tuples of the same rank. + + :param a: First object to compare + :type a: Union[XTuple, Layout, ComposedLayout, Tensor] + :param b: Second object to compare + :type b: Union[XTuple, Layout, ComposedLayout, Tensor] + :return: True if a and b are congruent, False otherwise + :rtype: bool + """ + if isinstance(a, (Layout, ComposedLayout, Tensor)): + a = a.shape + if isinstance(b, (Layout, ComposedLayout, Tensor)): + b = b.shape + if isinstance(a, tuple) and isinstance(b, tuple): + return (len(a) == len(b)) and all(is_congruent(x, y) for x, y in zip(a, b)) + if isinstance(a, tuple) or isinstance(b, tuple): + return False + return True + + +def is_weakly_congruent( + a: Union[XTuple, Layout, ComposedLayout, Tensor], + b: Union[XTuple, Layout, ComposedLayout, Tensor], +) -> bool: + """ + Returns whether a is weakly congruent to b. + + Weak congruence is a partial order on hierarchical structures. + + Object X is weakly congruent to object Y if: + * X is a non-tuple value, OR + * X and Y are both tuples of the same rank AND all corresponding elements are weakly congruent. + + Weak congruence allows scalar values to match with tuples, making it useful + for determining whether an object has a hierarchical structure "up to" another. + + :param a: First object to compare + :type a: Union[XTuple, Layout, ComposedLayout, Tensor] + :param b: Second object to compare + :type b: Union[XTuple, Layout, ComposedLayout, Tensor] + :return: True if a and b are weakly congruent, False otherwise + :rtype: bool + """ + if isinstance(a, (Layout, ComposedLayout, Tensor)): + a = a.shape + if isinstance(b, (Layout, ComposedLayout, Tensor)): + b = b.shape + if not isinstance(a, tuple): + return True + if isinstance(a, tuple) and isinstance(b, tuple): + return (len(a) == len(b)) and all( + is_weakly_congruent(x, y) for x, y in zip(a, b) + ) + if isinstance(a, tuple) or isinstance(b, tuple): + return False + return True + + +@overload +def get(input: Shape, mode, *, loc=None, ip=None) -> Shape: ... +@overload +def get(input: Stride, mode, *, loc=None, ip=None) -> Stride: ... +@overload +def get(input: Coord, mode, *, loc=None, ip=None) -> Coord: ... +@overload +def get(input: IntTuple, mode, *, loc=None, ip=None) -> IntTuple: ... +@overload +def get(input: Tile, mode, *, loc=None, ip=None) -> Tile: ... +@overload +def get(input: Layout, mode, *, loc=None, ip=None) -> Layout: ... +@overload +def get(input: ComposedLayout, mode, *, loc=None, ip=None) -> ComposedLayout: ... + + +@dsl_user_op +def get(input, mode: List[int], *, loc=None, ip=None): + """Extract a specific element or sub-layout from a layout or tuple. + + This function recursively traverses the input according to the mode indices, + extracting the element at the specified path. For layouts, this operation + corresponds to extracting a specific sub-layout. + + :param input: The input layout or tuple to extract from + :type input: Layout, ComposedLayout, tuple + :param mode: Indices specifying the path to traverse for extraction + :type mode: List[int] + :param loc: Source location for MLIR, defaults to None + :type loc: optional + :param ip: Insertion point, defaults to None + :type ip: optional + :return: The extracted element or sub-layout + :rtype: Layout, ComposedLayout, or element type + :raises ValueError: If any index in mode is out of range + :raises TypeError: If mode contains non-integer elements or if input has unsupported type + + :postcondition: ``get(t, mode=find(x,t)) == x if find(x,t) != None else True`` + + **Examples:** + + .. code-block:: python + + layout = make_layout(((4, 8), (16, 1), 8), stride=((1, 4), (32, 0), 512)) + sub_layout = get(layout, mode=[0, 1]) # 8:4 + sub_layout = get(layout, mode=[1]) # (16, 1):(32, 0) + """ + # Empty mode returns input and terminates the recursive call + if not mode: + return input + + if rank(input) <= mode[0]: + raise ValueError( + f"elements in mode must be less than rank({input}), got {mode}" + ) + + if depth(input) == 0: + return input + elif isinstance(input, tuple): + if not isinstance(mode[0], int): + raise TypeError( + f"invalid element in mode, expects int, got {type(mode[0])}" + ) + return get(input[mode[0]], mode=mode[1:]) + else: + if not isinstance(input, (Layout, ComposedLayout)): + raise TypeError(f"unsupported type of input, got {type(input)}") + return _cute_ir.get( + input.type.get_op_res_type(mode=mode), input, mode=mode, loc=loc, ip=ip + ) + + +@overload +def select(input: Shape, mode, *, loc=None, ip=None) -> Shape: ... +@overload +def select(input: Stride, mode, *, loc=None, ip=None) -> Stride: ... +@overload +def select(input: Coord, mode, *, loc=None, ip=None) -> Coord: ... +@overload +def select(input: IntTuple, mode, *, loc=None, ip=None) -> IntTuple: ... +@overload +def select(input: Tile, mode, *, loc=None, ip=None) -> Tile: ... +@overload +def select(input: Layout, mode, *, loc=None, ip=None) -> Layout: ... +@overload +def select(input: ComposedLayout, mode, *, loc=None, ip=None) -> ComposedLayout: ... + + +@dsl_user_op +def select(input, mode: List[int], *, loc=None, ip=None): + """Select modes from input. + + :param input: Input to select from + :type input: Layout, ComposedLayout, tuple + :param mode: Indices specifying which dimensions or elements to select + :type mode: List[int] + :param loc: Source location for MLIR, defaults to None + :type loc: optional + :param ip: Insertion point, defaults to None + :type ip: optional + :return: A new instance with selected dimensions/elements + :rtype: Layout, ComposedLayout, tuple + :raises ValueError: If any index in mode is out of range + :raises TypeError: If the input type is invalid + + **Examples:** + + .. code-block:: python + + # Select specific dimensions from a layout + layout = make_layout((4, 8, 16), stride=(32, 4, 1)) + selected = select(layout, mode=[0, 2]) # Select mode 0 and mode 2 + # Result: (4, 16):(32, 1) + + # Select elements from a tuple + t = (1, 2, 3, 4, 5) + selected = select(t, mode=[0, 2, 4]) # Select mode 0, mode 2, and mode 4 + # Result: (1, 3, 5) + """ + if any((not isinstance(i, int)) or (i >= rank(input)) for i in mode): + raise ValueError( + f"invalid mode element for input of rank {rank(input)}, got {mode=}" + ) + + if isinstance(input, tuple): + return tuple(input[i] for i in mode) + + if not isinstance(input, (Layout, ComposedLayout)): + raise TypeError(f"unsupported type of input, got {type(input)}") + + return _cute_ir.select(input, mode=mode, loc=loc, ip=ip) + + +@overload +def group_modes(input: Shape, begin: int, end: int, *, loc=None, ip=None) -> Shape: ... +@overload +def group_modes( + input: Stride, begin: int, end: int, *, loc=None, ip=None +) -> Stride: ... +@overload +def group_modes(input: Coord, begin: int, end: int, *, loc=None, ip=None) -> Coord: ... +@overload +def group_modes( + input: IntTuple, begin: int, end: int, *, loc=None, ip=None +) -> IntTuple: ... +@overload +def group_modes(input: Tile, begin: int, end: int, *, loc=None, ip=None) -> Tile: ... +@overload +def group_modes( + input: Layout, begin: int, end: int, *, loc=None, ip=None +) -> Layout: ... +@overload +def group_modes( + input: ComposedLayout, begin: int, end: int, *, loc=None, ip=None +) -> ComposedLayout: ... +@overload +def group_modes( + input: Tensor, begin: int, end: int, *, loc=None, ip=None +) -> Tensor: ... + + +@dsl_user_op +def group_modes(input, begin: int, end: int = -1, *, loc=None, ip=None): + """Group modes of a hierarchical tuple or layout into a single mode. + + This function groups a range of modes from the input object into a single mode, + creating a hierarchical structure. For tuples, it creates a nested tuple containing + the specified range of elements. For layouts and other CuTe objects, it creates + a hierarchical representation where the specified modes are grouped together. + + :param input: Input object to group modes from (layout, tuple, etc.) + :type input: Layout, ComposedLayout, tuple, Shape, Stride, etc. + :param beg: Beginning index of the range to group (inclusive) + :type beg: int + :param end: Ending index of the range to group (exclusive) + :type end: int + :param loc: Source location for MLIR, defaults to None + :type loc: optional + :param ip: Insertion point, defaults to None + :type ip: optional + :return: A new object with the specified modes grouped + :rtype: Same type as input with modified structure + + **Examples:** + + .. code-block:: python + + # Group modes in a tuple + t = (2, 3, 4, 5) + grouped = group_modes(t, 1, 3) # (2, (3, 4), 5) + + # Group modes in a layout + layout = make_layout((2, 3, 4, 5)) + grouped_layout = group_modes(layout, 1, 3) # Layout with shape (2, (3, 4), 5) + + # Group modes in a shape + shape = make_shape(2, 3, 4, 5) + grouped_shape = group_modes(shape, 0, 2) # Shape ((2, 3), 4, 5) + """ + if depth(input) == 0 and is_integer(input): + return (input,) + if isinstance(input, tuple): + return (*input[:begin], (input[begin:end]), *input[end:]) + return _cute_ir.group_modes( + input.value if isinstance(input, Tensor) else input, begin, end, loc=loc, ip=ip + ) + + +@overload +def slice_(src: Shape, coord: Coord, *, loc=None, ip=None) -> Shape: ... +@overload +def slice_(src: Stride, coord: Coord, *, loc=None, ip=None) -> Stride: ... +@overload +def slice_(src: Coord, coord: Coord, *, loc=None, ip=None) -> Coord: ... +@overload +def slice_(src: IntTuple, coord: Coord, *, loc=None, ip=None) -> IntTuple: ... +@overload +def slice_(src: Tile, coord: Coord, *, loc=None, ip=None) -> Tile: ... +@overload +def slice_(src: Layout, coord: Coord, *, loc=None, ip=None) -> Layout: ... +@overload +def slice_( + src: ComposedLayout, coord: Coord, *, loc=None, ip=None +) -> ComposedLayout: ... +@overload +def slice_(src: Tensor, coord: Coord, *, loc=None, ip=None) -> Tensor: ... + + +@dsl_user_op +def slice_(src, coord: Coord, *, loc=None, ip=None): + """Perform a slice operation on a source object using the given coordinate. + + This function implements CuTe's slicing operation which extracts a subset of elements + from a source object (tensor, layout, etc.) based on a coordinate pattern. The slice + operation preserves the structure of the source while selecting specific elements. + + :param src: Source object to be sliced (tensor, layout, tuple, etc.) + :type src: Union[Tensor, Layout, IntTuple, Value] + :param coord: Coordinate pattern specifying which elements to select + :type coord: Coord + :param loc: Source location information, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for IR generation, defaults to None + :type ip: Optional[InsertionPoint] + :return: A new object containing the sliced elements + :rtype: Union[Tensor, Layout, IntTuple, tuple] + :raises ValueError: If the coordinate pattern is incompatible with source + + **Examples:** + + .. code-block:: python + + # Layout slicing + layout = make_layout((4,4)) + + # Select 1st index of first mode and keep all elements in second mode + sub_layout = slice_(layout, (1, None)) + + .. code-block:: python + + # Basic tensor slicing + tensor = make_tensor(...) # Create a 2D tensor + + # Select 1st index of first mode and keep all elements in second mode + sliced = slice_(tensor, (1, None)) + + .. code-block:: python + + # Select 2nd index of second mode and keep all elements in first mode + sliced = slice_(tensor, (None, 2)) + + Note: + - `None` represents keeping all elements in that mode + - Slicing preserves the layout/structure of the original object + - Can be used for: + * Extracting sub-tensors/sub-layouts + * Creating views into data + * Selecting specific patterns of elements + """ + + def lift_slice(a, b): + if isinstance(a, tuple): + if (not isinstance(b, tuple)) or (len(a) != len(b)): + raise ValueError("coord must be weakly congruent to src in slice_") + return reduce( + lambda p, q: p + q, (lift_slice(x, y) for x, y in zip(a, b)), () + ) + elif a is None: + return (b,) + else: + return () + + if is_integer(src) or isinstance(src, tuple): + if isinstance(coord, tuple): + if (not isinstance(src, tuple)) or (len(coord) != len(src)): + raise ValueError("coord must be weakly congruent to src in slice_") + return reduce( + lambda p, q: p + q, (lift_slice(x, y) for x, y in zip(coord, src)), () + ) + elif coord is None: + return src + else: + return () + + res_type = None + if isinstance(src, Tensor): + res_type = src.element_type + src = src.value + coord_val = _pack_coord(coord, loc=loc, ip=ip) + res = _cute_ir.slice(input=src, coord=coord_val, loc=loc, ip=ip) + return _Tensor(res, dtype=res_type) if isinstance(res, _Tensor) else res + + +@overload +def dice(src: Shape, coord: Coord, *, loc=None, ip=None) -> Shape: ... +@overload +def dice(src: Stride, coord: Coord, *, loc=None, ip=None) -> Stride: ... +@overload +def dice(src: Coord, coord: Coord, *, loc=None, ip=None) -> Coord: ... +@overload +def dice(src: IntTuple, coord: Coord, *, loc=None, ip=None) -> IntTuple: ... +@overload +def dice(src: Tile, coord: Coord, *, loc=None, ip=None) -> Tile: ... +@overload +def dice(src: Layout, coord: Coord, *, loc=None, ip=None) -> Layout: ... +@overload +def dice(src: ComposedLayout, coord: Coord, *, loc=None, ip=None) -> ComposedLayout: ... + + +@dsl_user_op +@lru_cache_ir() +def dice(src, dicer, *, loc=None, ip=None): + """Keep modes in input when it is paired with an integer in dicer. + + This function performs dicing operation on the input based on the dicer coordinate. + Dicing is a fundamental operation in CuTe that allows selecting specific modes from + a tensor or layout based on a coordinate pattern. + + :param dicer: A static coordinate indicating how to dice the input + :type dicer: Coord + :param input: The operand to be diced on + :type input: Union[IntTuple, Shape, Stride, Coord, Layout, ComposedLayout] + :param loc: Source location information, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for IR generation, defaults to None + :type ip: Optional[InsertionPoint] + :return: The diced result with selected modes from the input + :rtype: Union[IntTuple, Shape, Stride, Coord, Layout, ComposedLayout] + :raises TypeError: If dicer has an unsupported type + :raises ValueError: If input is not provided + + **Examples:** + + .. code-block:: python + + # Basic dicing of a layout + layout = make_layout((32,16,8)) + + # Keep only first and last modes + diced = dice((1,None,1), layout) + + Note: + - The dicer coordinate must be static + - Use underscore (_) to remove a mode + """ + if not is_static(dicer): + raise ValueError(f"expects dicer to be static, but got {dicer}") + + def lift_dice(a, b): + if isinstance(a, tuple): + if (not isinstance(b, tuple)) or (len(a) != len(b)): + raise ValueError("dicer must be weakly congruent to input in dice") + return reduce( + lambda p, q: p + q, (lift_dice(x, y) for x, y in zip(a, b)), () + ) + elif a is None: + return () + else: + return (b,) + + if is_integer(src) or isinstance(src, tuple): + if isinstance(dicer, tuple): + if (not isinstance(src, tuple)) or (len(dicer) != len(src)): + raise ValueError("dicer must be weakly congruent to src in dice") + return reduce( + lambda p, q: p + q, (lift_dice(x, y) for x, y in zip(dicer, src)), () + ) + elif dicer is None: + return () + else: + return src + + dicer_val = _pack_coord(dicer, loc=loc, ip=ip) + return _cute_ir.dice(src, dicer_val.type.attribute, loc=loc, ip=ip) + + +def wrap(x) -> tuple: + """ + Wraps the input into a tuple if not a tuple. + """ + if isinstance(x, tuple): + return x + return (x,) + + +def _extend(func, input, elem, up_to_rank, loc, ip): + if input is None: + raise ValueError(f"No input provided for input") + + if isinstance(input, (Layout, ComposedLayout)): + if elem is None: + elem = make_layout(1) + elif not isinstance(elem, Layout): + raise TypeError(f"Input type of elem ({type(elem)}) is not accepted!") + N = rank(input) + 1 if up_to_rank is None else up_to_rank + return func(N, input, elem, loc=loc, ip=ip) + + if is_valid_leaf(input) or isinstance(input, tuple): + if elem is None: + elem = 1 + if (not isinstance(elem, tuple)) and (not is_valid_leaf(elem)): + raise TypeError(f"Input type of elem ({type(elem)}) is not accepted!") + + input = wrap(input) + repeat_cnt = 1 if up_to_rank is None else up_to_rank - rank(input) + if repeat_cnt == 0: + return input + elif repeat_cnt < 0: + raise ValueError(f"up_to_rank must be >= rank(input)") + else: + if func is _cute_ir.prepend_to_rank: + return (elem,) * repeat_cnt + input + else: + return input + (elem,) * repeat_cnt + + raise TypeError(f"invalid type for input, got {type(input)}") + + +@overload +def prepend( + input: Shape, elem: Shape, up_to_rank=None, *, loc=None, ip=None +) -> Shape: ... +@overload +def prepend( + input: Stride, elem: Stride, up_to_rank=None, *, loc=None, ip=None +) -> Stride: ... +@overload +def prepend( + input: Coord, elem: Coord, up_to_rank=None, *, loc=None, ip=None +) -> Coord: ... +@overload +def prepend( + input: IntTuple, elem: IntTuple, up_to_rank=None, *, loc=None, ip=None +) -> IntTuple: ... +@overload +def prepend(input: Tile, elem: Tile, up_to_rank=None, *, loc=None, ip=None) -> Tile: ... +@overload +def prepend( + input: Layout, elem: Layout, up_to_rank=None, *, loc=None, ip=None +) -> Layout: ... +@overload +def prepend( + input: ComposedLayout, elem: Layout, up_to_rank=None, *, loc=None, ip=None +) -> ComposedLayout: ... + + +@dsl_user_op +def prepend(input, elem, up_to_rank: Union[None, int] = None, *, loc=None, ip=None): + """Extend input to rank up_to_rank by prepending elem in front of input. + + This function extends the input object by prepending elements to reach a desired rank. + It supports various CuTe types including shapes, layouts, tensors etc. + + :param input: Source to be prepended to + :type input: Union[Shape, Stride, Coord, IntTuple, Tile, Layout, ComposedLayout, Tensor] + :param elem: Element to prepend to input + :type elem: Union[Shape, Stride, Coord, IntTuple, Tile, Layout] + :param up_to_rank: The target rank after extension, defaults to None + :type up_to_rank: Union[None, int], optional + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint] + :return: The extended result with prepended elements + :rtype: Union[Shape, Stride, Coord, IntTuple, Tile, Layout, ComposedLayout, Tensor] + :raises ValueError: If up_to_rank is less than input's current rank + :raises TypeError: If input or elem has unsupported type + + **Examples:** + + .. code-block:: python + + # Prepend to a Shape + shape = (4,4) + prepend(shape, 2) # Returns (2,4,4) + + # Prepend to a Layout + layout = make_layout((8,8)) + prepend(layout, make_layout((2,))) # Returns (2,8,8):(1,1,8) + + # Prepend with target rank + coord = (1,1) + prepend(coord, 0, up_to_rank=4) # Returns (0,0,1,1) + """ + return _extend(_cute_ir.prepend_to_rank, input, elem, up_to_rank, loc=loc, ip=ip) + + +@overload +def append( + input: Shape, elem: Shape, up_to_rank=None, *, loc=None, ip=None +) -> Shape: ... +@overload +def append( + input: Stride, elem: Stride, up_to_rank=None, *, loc=None, ip=None +) -> Stride: ... +@overload +def append( + input: Coord, elem: Coord, up_to_rank=None, *, loc=None, ip=None +) -> Coord: ... +@overload +def append( + input: IntTuple, elem: IntTuple, up_to_rank=None, *, loc=None, ip=None +) -> IntTuple: ... +@overload +def append(input: Tile, elem: Tile, up_to_rank=None, *, loc=None, ip=None) -> Tile: ... +@overload +def append( + input: Layout, elem: Layout, up_to_rank=None, *, loc=None, ip=None +) -> Layout: ... +@overload +def append( + input: ComposedLayout, elem: Layout, up_to_rank=None, *, loc=None, ip=None +) -> ComposedLayout: ... + + +@dsl_user_op +def append(input, elem, up_to_rank: Union[None, int] = None, *, loc=None, ip=None): + """Extend input to rank up_to_rank by appending elem to the end of input. + + This function extends the input object by appending elements to reach a desired rank. + It supports various CuTe types including shapes, layouts, tensors etc. + + :param input: Source to be appended to + :type input: Union[Shape, Stride, Coord, IntTuple, Tile, Layout, ComposedLayout, Tensor] + :param elem: Element to append to input + :type elem: Union[Shape, Stride, Coord, IntTuple, Tile, Layout] + :param up_to_rank: The target rank after extension, defaults to None + :type up_to_rank: Union[None, int], optional + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint] + :return: The extended result with appended elements + :rtype: Union[Shape, Stride, Coord, IntTuple, Tile, Layout, ComposedLayout, Tensor] + :raises ValueError: If up_to_rank is less than input's current rank + :raises TypeError: If input or elem has unsupported type + + **Examples:** + + .. code-block:: python + + # Append to a Shape + shape = (4,4) + append(shape, 2) # Returns (4,4,2) + + # Append to a Layout + layout = make_layout((8,8)) + append(layout, make_layout((2,))) # Returns (8,8,2):(1,8,1) + + # Append with target rank + coord = (1,1) + append(coord, 0, up_to_rank=4) # Returns (1,1,0,0) + + Note: + - The function preserves the structure of the input while extending it + - Can be used to extend tensors, layouts, shapes and other CuTe types + - When up_to_rank is specified, fills remaining positions with elem + - Useful for tensor reshaping and layout transformations + """ + return _extend(_cute_ir.append_to_rank, input, elem, up_to_rank, loc=loc, ip=ip) + + +@dsl_user_op +def prepend_ones( + t: Tensor, up_to_rank: Union[None, int] = None, *, loc=None, ip=None +) -> Tensor: + return make_tensor( + t.iterator, prepend(t.layout, make_layout(1), up_to_rank), loc=loc, ip=ip + ) + + +@dsl_user_op +def append_ones( + t: Tensor, up_to_rank: Union[None, int] = None, *, loc=None, ip=None +) -> Tensor: + return make_tensor( + t.iterator, append(t.layout, make_layout(1), up_to_rank), loc=loc, ip=ip + ) + + +def repeat_like(x, target): + """Creates an object congruent to target and filled with x. + + This function recursively creates a nested tuple structure that matches the structure + of the target, with each leaf node filled with the value x. + + :param x: The value to fill the resulting structure with + :type x: Any + :param target: The structure to mimic + :type target: Union[tuple, Any] + :return: A structure matching target but filled with x + :rtype: Union[tuple, Any] + + **Examples:** + + .. code-block:: python + + repeat_like(0, (1, 2, 3)) # Returns (0, 0, 0) + repeat_like(1, ((1, 2), 3)) # Returns ((1, 1), 1) + repeat_like(2, 5) # Returns 2 + """ + if not isinstance(target, tuple): + return x + if not target: + return () + if len(target) == 1: + return (repeat_like(x, target[0]),) + return tuple(repeat_like(x, t) for t in target) + + +def flatten_to_tuple(a: Union[IntTuple, Coord, Shape, Stride]) -> tuple: + """Flattens a potentially nested tuple structure into a flat tuple. + + This function recursively traverses the input structure and flattens it into + a single-level tuple, preserving the order of elements. + + :param a: The structure to flatten + :type a: Union[IntTuple, Coord, Shape, Stride] + :return: A flattened tuple containing all elements from the input + :rtype: tuple + + **Examples:** + + .. code-block:: python + + flatten_to_tuple((1, 2, 3)) # Returns (1, 2, 3) + flatten_to_tuple(((1, 2), 3)) # Returns (1, 2, 3) + flatten_to_tuple((1, (2, (3,)))) # Returns (1, 2, 3) + """ + if not isinstance(a, tuple): + return wrap(a) + else: + return tuple(chain.from_iterable(tuple(flatten_to_tuple(x) for x in a))) + + +@overload +def flatten(a: Union[IntTuple, Coord, Shape, Stride]) -> IntTuple: ... +@overload +def flatten(a: Tensor) -> Tensor: ... +@overload +def flatten(a: Layout) -> Layout: ... + + +def flatten(a): + """Flattens a CuTe data structure into a simpler form. + + For tuples, this function flattens the structure into a single-level tuple. + For layouts, it returns a new layout with flattened shape and stride. + For tensors, it returns a new tensor with flattened layout. + For other types, it returns the input unchanged. + + :param a: The structure to flatten + :type a: Union[IntTuple, Coord, Shape, Stride, Layout, Tensor] + :return: The flattened structure + :rtype: Union[tuple, Any] + + **Examples:** + + .. code-block:: python + + flatten((1, 2, 3)) # Returns (1, 2, 3) + flatten(((1, 2), (3, 4))) # Returns (1, 2, 3, 4) + flatten(5) # Returns 5 + flatten(Layout(shape, stride)) # Returns Layout(flatten(shape), flatten(stride)) + flatten(Tensor(layout)) # Returns Tensor(flatten(layout)) + + """ + if isinstance(a, Tensor): + return make_tensor(a.iterator, flatten(a.layout)) + elif isinstance(a, Layout): + return make_layout(flatten(a.shape), stride=flatten(a.stride)) + elif isinstance(a, tuple): + return flatten_to_tuple(a) + else: + return a + + +def unflatten( + sequence: Union[Tuple[Any, ...], List[Any], Iterable[Any]], profile: XTuple +) -> XTuple: + """Unflatten a flat tuple into a nested tuple structure according to a profile. + + This function transforms a flat sequence of elements into a nested tuple structure + that matches the structure defined by the profile parameter. It traverses the profile + structure and populates it with elements from the sequence. + + sequence must be long enough to fill the profile. Raises RuntimeError if it is not. + + :param sequence: A flat sequence of elements to be restructured + :type sequence: Union[Tuple[Any, ...], List[Any], Iterable[Any]] + :param profile: A nested tuple structure that defines the shape of the output + :type profile: XTuple + :return: A nested tuple with the same structure as profile but containing elements from sequence + :rtype: XTuple + + Example: + >>> unflatten([1, 2, 3, 4], ((0, 0), (0, 0))) + ((1, 2), (3, 4)) + """ + + def _make_generator(): + for element in sequence: + yield element + + xs = _make_generator() + return transform_leaf(lambda _: next(xs), profile) + + +@dsl_user_op +def elem_less( + lhs: Union[Shape, IntTuple, Coord], + rhs: Union[Shape, IntTuple, Coord], + *, + loc=None, + ip=None, +): + lhs_val = _pack_coord(lhs, loc=loc, ip=ip) + rhs_val = _pack_coord(rhs, loc=loc, ip=ip) + return Boolean(_cute_ir.elem_less(lhs_val, rhs_val, loc=loc, ip=ip)) + + +@overload +def filter_zeros( + input: Layout, *, target_profile=None, loc=None, ip=None +) -> Layout: ... +@overload +def filter_zeros( + input: Tensor, *, target_profile=None, loc=None, ip=None +) -> Tensor: ... + + +@dsl_user_op +def filter_zeros(input, *, target_profile=None, loc=None, ip=None): + """Filter out zeros from a layout or tensor. + + This function removes zero-stride dimensions from a layout or tensor. + Refer to https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/cute/02_layout_algebra.md + for more layout algebra operations. + + :param input: The input layout or tensor to filter + :type input: Layout or Tensor + :param target_profile: Target profile for the filtered result, defaults to None + :type target_profile: optional + :param loc: Source location for MLIR, defaults to None + :type loc: optional + :param ip: Insertion point, defaults to None + :type ip: optional + :return: The filtered layout or tensor with zeros removed + :rtype: Layout or Tensor + :raises TypeError: If input is not a Layout or Tensor + """ + if not isinstance(input, (Layout, Tensor)): + raise TypeError(f"Expect layout or tensor as input but got {type(input)=}") + if isinstance(input, Tensor): + input = input.value + return _cute_ir.filter_zeros(input, target_profile=target_profile, loc=loc, ip=ip) + + +@dsl_user_op +def filter(input: Union[Layout, Tensor], *, loc=None, ip=None): + """Filter a layout or tensor. + + This function filters a layout or tensor according to CuTe's filtering rules. + + :param input: The input layout or tensor to filter + :type input: Layout or Tensor + :param loc: Source location for MLIR, defaults to None + :type loc: optional + :param ip: Insertion point, defaults to None + :type ip: optional + :return: The filtered layout or tensor + :rtype: Layout or Tensor + :raises TypeError: If input is not a Layout or Tensor + """ + if not isinstance(input, (Layout, Tensor)): + raise TypeError(f"Expect layout or tensor as input but got {type(input)=}") + if isinstance(input, _Tensor): + input = input.value + return _cute_ir.filter(input, loc=loc, ip=ip) + + +@dsl_user_op +def product(a: Union[IntTuple, Shape], *, loc=None, ip=None): + """Return product of the given IntTuple or Shape. + + Computes the product of all elements in the input tuple or shape. + Returns static value if type is static. + + :param a: The input tuple or shape + :type a: IntTuple or Shape + :param loc: Source location for MLIR, defaults to None + :type loc: optional + :param ip: Insertion point, defaults to None + :type ip: optional + :return: Static product of IntTuple or Shape if static, otherwise a Value + :rtype: int or Value + :raises TypeError: If input is not an IntTuple or Shape + """ + if is_integer(a): + return a + if isinstance(a, tuple): + a_val = _pack_int_tuple(a, loc=loc, ip=ip) + res = _cute_ir.tuple_product(a_val, loc=loc, ip=ip) + return _unpack_x_tuple(res, loc=loc, ip=ip) + else: + raise TypeError(f"expects IntTuple or Shape, but got {type(a)}") + + +@overload +def product_like( + a: IntTuple, target_profile: XTuple, *, loc=None, ip=None +) -> IntTuple: ... +@overload +def product_like(a: Shape, target_profile: XTuple, *, loc=None, ip=None) -> Shape: ... + + +@dsl_user_op +def product_like( + a: Union[IntTuple, Shape], target_profile: XTuple, *, loc=None, ip=None +): + """Return product of the given IntTuple or Shape at leaves of `target_profile`. + + This function computes products according to the structure defined by target_profile. + + :param a: The input tuple or shape + :type a: IntTuple or Shape + :param target_profile: The profile that guides how products are computed + :type target_profile: XTuple + :param loc: Source location for MLIR, defaults to None + :type loc: optional + :param ip: Insertion point, defaults to None + :type ip: optional + :return: The resulting tuple with products computed according to target_profile + :rtype: IntTuple or Shape + :raises TypeError: If inputs have incompatible types + :raises ValueError: If inputs have incompatible shapes + """ + # Perform product at leaf of `target_profile` + if not isinstance(target_profile, tuple): + return product(a, loc=loc, ip=ip) + else: + if not isinstance(a, tuple): + raise TypeError(f"expects `a` tuple but got {a}") + + if len(a) != len(target_profile): + raise ValueError(f"expects `a` and `guide` have the same rank") + + return tuple( + product_like(x, g, loc=loc, ip=ip) for x, g in zip(a, target_profile) + ) + + +@overload +def product_each(a: IntTuple, *, loc=None, ip=None) -> IntTuple: ... +@overload +def product_each(a: Shape, *, loc=None, ip=None) -> Shape: ... + + +@dsl_user_op +def product_each(a, *, loc=None, ip=None): + """Compute products for each component of the input. + + Returns a rank(a) tuple `result` such that get(result, mode=[i]) == product(get(a, mode=[i])) + + :param a: The input tuple or shape + :type a: IntTuple or Shape + :param loc: Source location for MLIR, defaults to None + :type loc: optional + :param ip: Insertion point, defaults to None + :type ip: optional + :return: A tuple containing products for each component + :rtype: tuple + :raises TypeError: If input is not an IntTuple or Shape + """ + if is_integer(a): + return a + if isinstance(a, tuple): + if not a: + return 1 + else: + a_val = _pack_int_tuple(a, loc=loc, ip=ip) + res = _cute_ir.tuple_product_each(a_val, loc=loc, ip=ip) + return _unpack_x_tuple(res, loc=loc, ip=ip) + else: + raise TypeError(f"expects IntTuple or Shape, but got {type(a)}") + + +@dsl_user_op +def size( + a: Union[IntTuple, Shape, Layout, ComposedLayout, Tensor], + mode: List[int] = [], + *, + loc=None, + ip=None, +) -> Int: + """Return size of domain of layout or tensor. + + Computes the size (number of elements) in the domain of a layout or tensor. + For layouts, this corresponds to the shape of the coordinate space. + See https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/cute/01_layout.md + for more details on layout domains. + + :param a: The input object whose size to compute + :type a: IntTuple, Shape, Layout, ComposedLayout or Tensor + :param mode: List of mode(s) for size calculation. If empty, computes total size, defaults to [] + :type mode: list of int, optional + :param loc: Source location for MLIR, defaults to None + :type loc: optional + :param ip: Insertion point, defaults to None + :type ip: optional + :return: Static size of layout or tensor if static, otherwise a Value + :rtype: int or Value + :raises ValueError: If mode contains non-integer elements + """ + if any(not isinstance(m, int) for m in mode): + raise ValueError(f"expects integer elements in mode, but got {mode}") + + if isinstance(a, (TiledMma, TiledCopy)): + return a.size + a_val = None + if not isinstance(a, (Layout, ComposedLayout, Tensor)): + a_val = _pack_int_tuple(a, loc=loc, ip=ip) + elif isinstance(a, Tensor): + a_val = a.value + else: + a_val = a + + res = _cute_ir.size(a_val, mode=mode, loc=loc, ip=ip) + return _unpack_x_tuple(res, loc=loc, ip=ip) # type: ignore + + +@dsl_user_op +def shape_div(lhs: Shape, rhs: Shape, *, loc=None, ip=None) -> Shape: + """Perform element-wise division of shapes. + + This function performs element-wise division between two shapes. + + :param lhs: Left-hand side shape + :type lhs: Shape + :param rhs: Right-hand side shape + :type rhs: Shape + :param loc: Source location for MLIR, defaults to None + :type loc: optional + :param ip: Insertion point, defaults to None + :type ip: optional + :return: The result of element-wise division + :rtype: Shape + """ + lhs = _pack_shape(lhs, loc=loc, ip=ip) + rhs = _pack_shape(rhs, loc=loc, ip=ip) + res = _cute_ir.shape_div(lhs, rhs, loc=loc, ip=ip) + return _unpack_x_tuple(res, loc=loc, ip=ip) + + +@dsl_user_op +def ceil_div(input: Shape, tiler: Tiler, *, loc=None, ip=None) -> Shape: + """ + Compute the ceiling division of a target shape by a tiling specification. + + This function computes the number of tiles required to cover the target domain. + It is equivalent to the second mode of `zipped_divide(input, tiler)`. + + :param input: A tuple of integers representing the dimensions of the target domain. + :type input: Shape + :param tiler: The tiling specification. + :type tiler: Union[Layout, Shape, Tile] + :param loc: Optional location information for IR diagnostics. + :type loc: optional + :param ip: Optional instruction pointer or context for underlying IR functions. + :type ip: optional + :return: A tuple of integers representing the number of tiles required along each dimension, + i.e. the result of the ceiling division of the input dimensions by the tiler dimensions. + :rtype: Shape + + Example: + + .. code-block:: python + + import cutlass.cute as cute + @cute.jit + def foo(): + input = (10, 6) + tiler = (3, 4) + result = cute.ceil_div(input, tiler) + print(result) # Outputs: (4, 2) + """ + input_val = _pack_shape(input, loc=loc, ip=ip) + tiler_val = _pack_tile(tiler, loc=loc, ip=ip) + res = _cute_ir.ceil_div(input=input_val, tiler=tiler_val, loc=loc, ip=ip) + return _unpack_x_tuple(res, loc=loc, ip=ip) + + +def round_up(a: IntTuple, b: IntTuple) -> IntTuple: + """ + Rounds up elements of a using elements of b. + """ + if isinstance(a, tuple): + if not a: + raise ValueError(f"inputs cannot be empty") + if not isinstance(b, tuple): + raise TypeError( + f"expects both inputs to be tuple, but got {type(a)} and {type(b)}" + ) + if rank(a) < rank(b): + raise ValueError( + f"expects rank(a) to be greater or equal than rank(b), but got {a}, {b}" + ) + b = append(b, 1, rank(a)) + return tuple(round_up(x, y) for x, y in zip(a, b)) + return ((a + b - 1) // b) * b + + +# +# Layout API (also used by tensors) +# + + +@dsl_user_op +def make_layout( + shape: Shape, *, stride: Union[Stride, None] = None, loc=None, ip=None +) -> Layout: + """Create a CuTe Layout object from shape and optional stride information. + + A Layout in CuTe represents the mapping between logical and physical coordinates of a tensor. + This function creates a Layout object that defines how tensor elements are arranged in memory. + + :param shape: Shape of the layout defining the size of each mode + :type shape: Shape + :param stride: Optional stride values for each mode, defaults to None + :type stride: Union[Stride, None] + :param loc: Source location information, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for IR generation, defaults to None + :type ip: Optional[InsertionPoint] + :return: A new Layout object with the specified shape and stride + :rtype: Layout + + **Examples:** + + .. code-block:: python + + # Create a 2D compact left-most layout with shape (4,4) + layout = make_layout((4,4)) # compact left-most layout + + # Create a left-most layout with custom strides + layout = make_layout((4,4), stride=(1,4)) # left-most layout with strides (1,4) + + # Create a layout for a 3D tensor + layout = make_layout((32,16,8)) # left-most layout + + # Create a layout with custom strides + layout = make_layout((2,2,2), stride=(4,1,2)) # layout with strides (4,1,2) + + Note: + - If stride is not provided, a default compact left-most stride is computed based on the shape + - The resulting layout maps logical coordinates to physical memory locations + - The layout object can be used for tensor creation and memory access patterns + - Strides can be used to implement: + * Row-major vs column-major layouts + * Padding and alignment + * Blocked/tiled memory arrangements + * Interleaved data formats + - Stride is keyword only argument to improve readability, e.g. + * make_layout((3,4), (1,4)) can be confusing with make_layout(((3,4), (1,4))) + * make_layout((3,4), stride=(1,4)) is more readable + """ + if stride is not None and not is_congruent(shape, stride): + raise ValueError(f"shape and stride must be congruent") + + shape_val = _pack_shape(shape, loc=loc, ip=ip) + if stride is not None: + stride_val = _pack_stride(stride, loc=loc, ip=ip) + layout_ty = _cute_ir.LayoutType.get(shape_val, stride_val) + else: + stride_val = None + layout_ty = _cute_ir.LayoutType.get(shape_val) + + return _cute_ir.make_layout( + layout_ty, shape=shape_val, stride=stride_val, loc=loc, ip=ip + ) + + +@dsl_user_op +def make_identity_layout(shape: Shape, *, loc=None, ip=None) -> Layout: + """Create an identity layout with the given shape. + + An identity layout maps logical coordinates directly to themselves without any transformation. + This is equivalent to a layout with stride (1@0,1@1,...,1@(N-1)). + + :param shape: The shape of the layout + :type shape: Shape + :param loc: Source location information, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for IR generation, defaults to None + :type ip: Optional[InsertionPoint] + :return: A new identity Layout object with the specified shape + :rtype: Layout + + **Examples:** + + .. code-block:: python + + # Create a 2D identity layout with shape (4,4) + layout = make_identity_layout((4,4)) # stride=(1@0,1@1) + + # Create a 3D identity layout + layout = make_identity_layout((32,16,8)) # stride=(1@0,1@1,1@2) + + Note: + - An identity layout is a special case where each coordinate maps to itself + - Useful for direct coordinate mapping without any transformation + """ + if not is_int_tuple(shape): + raise TypeError(f"expects a shape input, got {type(shape)}") + shape_val = _pack_shape(shape, loc=loc, ip=ip) + return _cute_ir.make_identity_layout(shape_val, loc=loc, ip=ip) + + +@dsl_user_op +def make_ordered_layout(shape: Shape, order: Shape, *, loc=None, ip=None) -> Layout: + """Create a layout with a specific ordering of dimensions. + + This function creates a layout where the dimensions are ordered according to the + specified order parameter, allowing for custom dimension ordering in the layout. + + :param shape: The shape of the layout + :type shape: Shape + :param order: The ordering of dimensions + :type order: Shape + :param loc: Source location information, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for IR generation, defaults to None + :type ip: Optional[InsertionPoint] + :return: A new Layout object with the specified shape and dimension ordering + :rtype: Layout + + **Examples:** + + .. code-block:: python + + # Create a row-major layout + layout = make_ordered_layout((4,4), order=(1,0)) + + # Create a column-major layout + layout = make_ordered_layout((4,4), order=(0,1)) # stride=(1,4) + + # Create a layout with custom dimension ordering for a 3D tensor + layout = make_ordered_layout((32,16,8), order=(2,0,1)) # stride=(128,1,16) + + Note: + - The order parameter specifies the ordering of dimensions from fastest-varying to slowest-varying + - For a 2D tensor, (0,1) creates a column-major layout, while (1,0) creates a row-major layout + - The length of order must match the rank of the shape + """ + shape_val = _pack_shape(shape, loc=loc, ip=ip) + order_val = _pack_int_tuple(order, loc=loc, ip=ip) + return _cute_ir.make_ordered_layout( + shape=shape_val, order=order_val, loc=loc, ip=ip + ) + + +@dsl_user_op +def make_composed_layout( + inner, offset: IntTuple, outer: Layout, *, loc=None, ip=None +) -> ComposedLayout: + """Create a composed layout by composing an inner transformation with an outer layout. + + A composed layout applies a sequence of transformations + to coordinates. The composition is defined as (inner ∘ offset ∘ outer), where the operations + are applied from right to left. + + :param inner: The inner transformation (can be a Layout or Swizzle) + :type inner: Union[Layout, Swizzle] + :param offset: An integral offset applied between transformations + :type offset: IntTuple + :param outer: The outer (right-most) layout that is applied first + :type outer: Layout + :param loc: Source location information, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for IR generation, defaults to None + :type ip: Optional[InsertionPoint] + :return: A new ComposedLayout representing the composition + :rtype: ComposedLayout + + **Examples:** + + .. code-block:: python + + # Create a basic layout + inner = make_layout(...) + outer = make_layout((4,4), stride=(E(0), E(1))) + + # Create a composed layout with an offset + composed = make_composed_layout(inner, (2,0), outer) + + Note: + - The composition applies transformations in the order: outer → offset → inner + - The stride divisibility condition must be satisfied for valid composition + - Certain compositions (like Swizzle with scaled basis) are invalid and will raise errors + - Composed layouts inherit many properties from the outer layout + """ + if not isinstance(outer, Layout): + raise TypeError( + f"expects the outer (or right-most or effectively visible) layout to be an affine layout, but got {outer}" + ) + if isinstance(inner, Swizzle) and has_scaled_basis(outer.stride): + raise TypeError(f"invalid composition {inner} o {offset} o {outer}") + offset_val = _pack_int_tuple(offset, loc=loc, ip=ip) + return _cute_ir.make_composed_layout(inner, offset_val, outer, loc=loc, ip=ip) + + +@dsl_user_op +def cosize( + a: Union[Layout, ComposedLayout, Tensor], mode: List[int] = [], *, loc=None, ip=None +): + """Return size of codomain of layout or tensor. Return static value if type is static. + + :param a: Layout, ComposedLayout, or Tensor object + :type a: Union[Layout, ComposedLayout, Tensor] + :param mode: List of mode(s) for cosize calculation + :type mode: List[int], optional + :param loc: Location information for diagnostics, defaults to None + :type loc: optional + :param ip: Instruction pointer for diagnostics, defaults to None + :type ip: optional + :return: Static size of layout or tensor (fast fold) if static, or a dynamic Value + :rtype: Union[int, Value] + """ + if any(not is_static(m) for m in mode): + raise ValueError(f"expects static mode, but got {mode}") + + if isinstance(a, _Tensor): + a = a.value + res = _cute_ir.cosize(a, mode=mode, loc=loc, ip=ip) + return _unpack_x_tuple(res, loc=loc, ip=ip) + + +@dsl_user_op +def size_in_bytes( + dtype: Type[Numeric], layout: Union[Layout, ComposedLayout], *, loc=None, ip=None +): + """Calculate the size in bytes based on its data type and layout. + + :param dtype: The DSL numeric data type + :type dtype: Type[Numeric] + :param layout: The layout of the elements. If None, the function returns 0 + :type layout: Layout, optional + :param loc: Location information for diagnostics, defaults to None + :type loc: optional + :param ip: Instruction pointer for diagnostics, defaults to None + :type ip: optional + :return: The total size in bytes. Returns 0 if the layout is None + :rtype: int + """ + if not isinstance(dtype, NumericMeta): + raise TypeError(f"dtype must be a Numeric, but got {dtype}") + + if layout is None: + return 0 + elif isinstance(layout, ComposedLayout): + if not isinstance(layout.inner, Swizzle): + raise TypeError( + f"invalid composed layout {layout}, inner must be a Swizzle" + ) + else: + return cosize(layout.outer, loc=loc, ip=ip) * dtype.width // 8 + else: + return cosize(layout, loc=loc, ip=ip) * dtype.width // 8 + + +@dsl_user_op +def coalesce(input, *, target_profile: Coord = None, loc=None, ip=None): + if target_profile: + profile_val = _pack_coord(target_profile, loc=loc, ip=ip) + return _cute_ir.coalesce(input, target_profile=profile_val, loc=loc, ip=ip) + else: + return _cute_ir.coalesce(input, loc=loc, ip=ip) + + +@dsl_user_op +def crd2idx(coord: Coord, layout, *, loc=None, ip=None): + """ + Convert a multi-dimensional coordinate into a value using the specified layout. + + This function computes the inner product of the flattened coordinate and stride: + + index = sum(flatten(coord)[i] * flatten(stride)[i] for i in range(len(coord))) + + :param coord: A tuple or list representing the multi-dimensional coordinate + (e.g., (i, j) for a 2D layout). + :type coord: Coord + :param layout: A layout object that defines the memory storage layout, including shape and stride, + used to compute the inner product. + :type layout: Layout or ComposedLayout + :param loc: Optional location information for IR diagnostics. + :type loc: optional + :param ip: Optional instruction pointer or context for underlying IR functions. + :type ip: optional + :returns: The result of applying the layout transformation to the provided coordinate. + :rtype: Any type that the layout maps to + + Example: + + .. code-block:: python + + import cutlass.cute as cute + @cute.jit + def foo(): + L = cute.make_layout((5, 4), stride=(4, 1)) + idx = cute.crd2idx((2, 3), L) + # Computed as: 2 * 4 + 3 = 11 + print(idx) + foo() # Expected output: 11 + """ + coord_val = _pack_coord(coord, loc=loc, ip=ip) + if isinstance(layout, (tuple, int)): + layout = make_layout(layout, loc=loc, ip=ip) + + res = _cute_ir.crd2idx(coord_val, layout, loc=loc, ip=ip) + return _unpack_x_tuple(res, loc=loc, ip=ip) + + +@dsl_user_op +def recast_layout(new_type_bits, old_type_bits, src_layout, *, loc=None, ip=None): + return _cute_ir.recast_layout( + new_type_bits, old_type_bits, src_layout, loc=loc, ip=ip + ) + + +@dsl_user_op +def slice_and_offset(coord, src, *, loc=None, ip=None): + layout = slice_(src, coord, loc=loc, ip=ip) + offset = crd2idx(coord, src, loc=loc, ip=ip) + return layout, offset + + +@dsl_user_op +@lru_cache_ir() +def shape( + input: Union[Shape, Tensor, Layout, Tile], *, mode=None, loc=None, ip=None +) -> Shape: + """Returns the shape of a tensor, layout or tiler. + + For shapes, this function is identical to get. + + This function extracts the shape information from the input object. For tensors and layouts, + it returns their internal shape property. For tilers, it unpacks the shape from the tile + representation. + + :param input: The object to extract shape from + :type input: Union[Tensor, Layout, Tile] + :param mode: Optional mode selector to extract specific dimensions from the shape + :type mode: Optional[int] + :param loc: Source location for MLIR operation tracking + :type loc: Optional[Location] + :param ip: Insertion point for MLIR operation + :type ip: Optional[InsertionPoint] + :return: The shape of the input object, optionally filtered by mode + :rtype: Shape + + Example: + + .. code-block:: python + + # Get shape of a layout + l0 = cute.make_layout((2, 3, 4)) + s0 = cute.shape(l0) # => (2, 3, 4) + + # Get shape of a hierarchical tiler + l1 = cute.make_layout(1) + s1 = cute.shape((l0, l1)) # => ((2, 3, 4), 1) + + # Get specific mode from a shape + s2 = cute.shape(l0, mode=0) # => 2 + """ + if is_int_tuple(input): + return get(input, mode=mode) + + if isinstance(input, (Tensor, Layout)): + shp = input.shape + else: + val = _cute_ir.get_shape(_pack_tile(input, loc=loc, ip=ip)) + shp = _unpack_x_tuple(val, loc=loc, ip=ip) + return get(shp, mode=mode) + + +# +# Pointer API +# + + +@dsl_user_op +def recast_ptr( + ptr: Pointer, + swizzle_=None, + dtype: Optional[Type[Numeric]] = None, + loc=None, + ip=None, +) -> Pointer: + if dtype is not None: + if not isclass(dtype) or not issubclass(dtype, Numeric): + raise TypeError(f"dtype must be a type of Numeric, but got {dtype}") + dtype = dtype.mlir_type + + value_type = ptr.type.value_type if dtype is None else dtype + swizzle = swizzle_.type.attribute if swizzle_ is not None else None + res_ty = _cute_ir.PtrType.get(value_type, ptr.memspace, ptr.alignment, swizzle) + return _cute_ir.recast_iter(res_ty, ptr.value, loc=loc, ip=ip) + + +@dsl_user_op +def make_ptr( + dtype: Union[Type[Numeric], None], + value, + mem_space: AddressSpace = AddressSpace.generic, + *, + assumed_align=None, + loc=None, + ip=None, +) -> Pointer: + if dtype is None or not isinstance(dtype, NumericMeta): + raise TypeError(f"expects dtype to be a type of Numeric, but got {dtype}") + + if not isinstance(mem_space, AddressSpace): + raise TypeError(f"expects mem_space to be an AddressSpace, but got {mem_space}") + + if isinstance(value, ir.Value) and llvm.PointerType.isinstance(value.type): + value = llvm.ptrtoint(T.i64(), value) + + if not is_integer(value): + raise TypeError(f"expects integer value, but got {type(value)}") + value = Int32(value) if mem_space == AddressSpace.tmem else Int64(value) + + bytes_per_elt = max(1, dtype.width // 8) + if assumed_align is None: + assumed_align = bytes_per_elt + + if bytes_per_elt % assumed_align != 0 and assumed_align % bytes_per_elt != 0: + raise ValueError( + f"{bytes_per_elt=} is not a multiple of {assumed_align=} and vice versa." + ) + + aligned_ty = _cute_ir.ConstrainedIntType.get(assumed_align, type(value).width) + aligned_intptr = _cute_ir.assume(aligned_ty, value.ir_value(), loc=loc, ip=ip) + + data_ty = T.i8() if dtype is None else dtype.mlir_type + ptr_ty = _cute_ir.PtrType.get(data_ty, mem_space, assumed_align) + return _cute_ir.inttoptr(ptr_ty, aligned_intptr, loc=loc, ip=ip) + + +# +# Tensor API +# + + +@dsl_user_op +def make_tensor( + iterator, layout: Union[Shape, Layout, ComposedLayout], *, loc=None, ip=None +) -> Tensor: + """Creates a tensor by composing an engine (iterator/pointer) with a layout. + + A tensor is defined as T = E ∘ L, where E is an engine (array, pointer, or counting iterator) + and L is a layout that maps logical coordinates to physical offsets. The tensor + evaluates coordinates by applying the layout mapping and dereferencing the engine + at the resulting offset. + + :param iterator: Engine component (pointer, iterator, or counting iterator) that provides + data access capabilities + :type iterator: Union[Pointer, IntTuple] + :param layout: Layout component that defines the mapping from logical coordinates to + physical offsets + :type layout: Union[Shape, Layout, ComposedLayout] + :param loc: Source location for MLIR operation tracking, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for MLIR operation, defaults to None + :type ip: Optional[InsertionPoint] + :return: A tensor object representing the composition E ∘ L + :rtype: Tensor + + :raises ValueError: If iterator type is not supported + + **Examples:** + + .. code-block:: python + + # Create a tensor with row-major layout + layout = make_layout((64, 128), stride=(128, 1)) + tensor = make_tensor(ptr, layout) + + # Create a tensor with hierarchical layout + layout = make_layout(((128, 8), (1, 4, 1)), stride=((32, 1), (0, 8, 4096))) + tensor = make_tensor(smem_ptr, layout) + + # Create a coord tensor + layout = make_layout(2, stride=16 * E(0)) + tensor = make_tensor(5, layout) + + Notes: + - The engine (iterator) must support random access operations + - Common engine types include raw pointers, arrays, and random-access iterators + - The layout defines both the shape (logical dimensions) and stride (physical mapping) + - Supports both direct coordinate evaluation T(c) and partial evaluation (slicing) + """ + if not isinstance(layout, (Layout, ComposedLayout)): + layout = make_layout(layout, loc=loc, ip=ip) + elif isinstance(layout, ComposedLayout) and layout.type.is_normal_layout: + layout = layout.outer + + ty = None + if is_integer(iterator) or isinstance(iterator, tuple): + iterator = _pack_int_tuple(iterator, loc=loc, ip=ip) + ty = _cute_ir.CoordTensorType.get(iterator.type, layout.type) + elif isinstance(iterator, Pointer): + iterator = iterator.value + ty = _cute_ir.MemRefType.get(iterator.type, layout.type) + else: + raise TypeError(f"unsupported iterator type, got {type(iterator)}") + + return _cute_ir.make_view(result=ty, iter=iterator, layout=layout, loc=loc, ip=ip) + + +@dsl_user_op +def make_identity_tensor(shape: Shape, *, loc=None, ip=None) -> Tensor: + """Creates an identity tensor with the given shape. + + An identity tensor maps each coordinate to itself, effectively creating a counting + sequence within the shape's bounds. This is useful for generating coordinate indices + or creating reference tensors for layout transformations. + + :param shape: The shape defining the tensor's dimensions. Can be a simple integer + sequence or a hierarchical structure ((m,n),(p,q)) + :type shape: Shape + :param loc: Source location for MLIR operation tracking, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for MLIR operation, defaults to None + :type ip: Optional[InsertionPoint] + :return: A tensor that maps each coordinate to itself + :rtype: Tensor + + **Examples:** + + .. code-block:: python + + # Create a simple 1D coord tensor + tensor = make_identity_tensor(6) # [0,1,2,3,4,5] + + # Create a 2D coord tensor + tensor = make_identity_tensor((3,2)) # [(0,0),(1,0),(2,0),(0,1),(1,1),(2,1)] + + # Create hierarchical coord tensor + tensor = make_identity_tensor(((2,1),3)) + # [((0,0),0),((1,0),0),((0,0),1),((1,0),1),((0,0),2),((1,0),2)] + + Notes: + - The shape parameter follows CuTe's IntTuple concept + - Coordinates are ordered colexicographically + - Useful for generating reference coordinates in layout transformations + """ + shape_val = _pack_shape(shape, loc=loc, ip=ip) + return _cute_ir.make_identity_tensor(shape_val, loc=loc, ip=ip) + + +@dsl_user_op +def make_fragment( + layout_or_shape: Union[Layout, Shape], + dtype: Type[Numeric], + *, + loc=None, + ip=None, +) -> Tensor: + if not issubclass(dtype, Numeric): + raise TypeError(f"value_type must be a type of Numeric, but got {type(dtype)}") + elem_ty = dtype.mlir_type if dtype is not Boolean else T.i8() + + # Alignment for register memory is useless(?), pick-up large enough number + # to allow .128 (> 16B) load store + alignment = 32 + layout = None + if not isinstance(layout_or_shape, Layout): + layout = make_layout(layout_or_shape, loc=loc, ip=ip) + else: + layout = layout_or_shape + + ptr_ty = _cute_ir.PtrType.get(elem_ty, AddressSpace.rmem, alignment) + res_ty = _cute_ir.MemRefType.get(ptr_ty, layout.type) + tensor = _cute_ir.memref_alloca(res_ty, layout=layout, loc=loc, ip=ip) + return _Tensor(tensor.value, dtype) + + +@overload +def make_fragment_like( + src: Tensor, dtype: Optional[Type[Numeric]], *, loc=None, ip=None +) -> Tensor: ... + + +@overload +def make_fragment_like(src: Layout, *, loc=None, ip=None) -> Layout: ... + + +@overload +def make_fragment_like(src: ComposedLayout, *, loc=None, ip=None) -> ComposedLayout: ... + + +@dsl_user_op +def make_fragment_like(src, dtype=None, *, loc=None, ip=None): + """Create tensor with a compact layout in the same shape as the source on stack. + + This function either creates a fragment tensor with compact layout in + same shape as the source layout or a new layout with the same shape as the source. + The strides of the new layout follow the order induced by the source's strides, with a + special handling of the 0th mode: it is always stride-1 and generated in column-major order + (LayoutLeft). + + :param src: The source layout or tensor whose shape will be matched + :type src: Union[Layout, ComposedLayout, Tensor] + :param dtype: The element type for the fragment tensor, defaults to None + :type dtype: Type[Numeric], optional + :param loc: Source location for MLIR operations, defaults to None + :type loc: Location, optional + :param ip: Insertion point for MLIR operations, defaults to None + :type ip: InsertionPoint, optional + + :return: A new layout or fragment tensor with matching shape + :rtype: Union[Layout, Tensor] + + **Examples:** + + Creating a rmem tensor from a tensor: + + .. code-block:: python + + smem_tensor = cute.make_tensor(smem_ptr, layout) + frag_tensor = cute.make_fragment_like(smem_tensor, cutlass.Float32) + # frag_tensor will be a register-backed tensor with the same shape + + Creating a fragment with a different element type: + + .. code-block:: python + + tensor = cute.make_tensor(gmem_ptr, layout) + bool_frag = cute.make_fragment_like(tensor, cutlass.Boolean) + # bool_frag will be a register-backed tensor with Boolean elements + + **Notes** + + - When used with a Tensor, if a type is provided, it will create a new + fragment tensor with that element type. + - For layouts with ScaledBasis strides, the function creates a fragment + from the shape only. + - This function is commonly used in GEMM and other tensor operations to + create register storage for intermediate results. + + """ + if isinstance(src, (Layout, ComposedLayout)): + new_layout = None + # Create base fragment layout + if isinstance(src, Layout) and has_scaled_basis(src.stride): + # For scaled basis strides, create fragment from shape only + new_layout = _cute_ir.make_fragment_like( + make_layout(src.shape), loc=loc, ip=ip + ) + else: + # Otherwise use full source layout + new_layout = _cute_ir.make_fragment_like(src, loc=loc, ip=ip) + if dtype is not None: + # call make_fragment to convert layout to tensor + return make_fragment(new_layout, dtype, loc=loc, ip=ip) + else: + return new_layout + elif isinstance(src, Tensor): + if isinstance(src.type, _cute_ir.CoordTensorType): + if dtype is None: + raise ValueError( + "dtype must be provided when src is a coordinate tensor" + ) + + new_layout = _cute_ir.make_fragment_like( + make_layout(src.shape), loc=loc, ip=ip + ) + return make_fragment(new_layout, dtype, loc=loc, ip=ip) + else: + dtype = src.element_type if dtype is None else dtype + ty = dtype.mlir_type if dtype is not Boolean else T.i8() + new_tensor = _cute_ir.make_fragment_like( + src.value, elem_type=ty, loc=loc, ip=ip + ) + return _Tensor(new_tensor.value, dtype) + else: + raise TypeError( + f"src must be a Layout or ComposedLayout or tensor, got {type(src)}" + ) + + +@dsl_user_op +def recast_tensor( + src: Tensor, dtype: Type[Numeric], swizzle_=None, *, loc=None, ip=None +): + if not isclass(dtype) or not issubclass(dtype, Numeric): + raise TypeError(f"dtype must be a type of Numeric, but got {dtype}") + + if dtype is Boolean: + dst_width = 8 + else: + dst_width = dtype.width + + if src.element_type is Boolean: + src_width = 8 + else: + src_width = src.element_type.width + + src_iter = recast_ptr(src.iterator, dtype=dtype, loc=loc, ip=ip) + src_layout = recast_layout(dst_width, src_width, src.layout, loc=loc, ip=ip) + return make_tensor(src_iter, src_layout, loc=loc, ip=ip) + + +@dsl_user_op +def domain_offset(coord: Coord, tensor: Tensor, *, loc=None, ip=None) -> Tensor: + offset = crd2idx(coord, tensor.layout, loc=loc, ip=ip) + if isinstance(tensor.iterator, Pointer): + return make_tensor(tensor.iterator + offset, tensor.layout) + elif is_integer(tensor.iterator) or isinstance(tensor.iterator, tuple): + new_iter = _cute_ir.add_offset( + _pack_int_tuple(tensor.iterator), _pack_int_tuple(offset) + ) + return make_tensor(_unpack_x_tuple(new_iter), tensor.layout) + else: + raise ValueError(f"unsupported tensor for domain_offset, got {tensor}") + + +# +# Layout algebra +# + + +@overload +def composition( + lhs: Layout, rhs: Union[Layout, Shape, Tile], *, loc=None, ip=None +) -> Layout: ... + + +@overload +def composition( + lhs: Tensor, rhs: Union[Layout, Shape, Tile], *, loc=None, ip=None +) -> Tensor: ... + + +@dsl_user_op +def composition(lhs, rhs: Union[Layout, Shape, Tile], *, loc=None, ip=None): + """ + Compose two layout representations using the CuTe layout algebra. + + Compose a left-hand layout (or tensor) with a right-hand operand into a new layout R, such that + for every coordinate c in the domain of the right-hand operand, the composed layout satisfies: + + R(c) = A(B(c)) + + where A is the left-hand operand provided as ``lhs`` and B is the right-hand operand provided as + ``rhs``. In this formulation, B defines the coordinate domain while A applies its transformation to + B's output, and the resulting layout R inherits the stride and shape adjustments from A. + + Satisfies: + cute.shape(cute.composition(lhs, rhs)) is compatible with cute.shape(rhs) + + :param lhs: The left-hand operand representing the transformation to be applied. + :type lhs: Layout or Tensor + :param rhs: The right-hand operand defining the coordinate domain. If provided as an int or tuple, + it will be converted to a tile layout. + :type rhs: Layout, Shape, or Tile, or int or tuple + :param loc: Optional location information for IR diagnostics. + :type loc: optional + :param ip: Optional instruction pointer or context for underlying IR functions. + :type ip: optional + :returns: A new composed layout R, such that for all coordinates c in the domain of ``rhs``, + R(c) = lhs(rhs(c)). + :rtype: Layout or Tensor + + Example: + + .. code-block:: python + + import cutlass.cute as cute + @cute.jit + def foo(): + # Create a layout that maps (i,j) to i*4 + j + L1 = cute.make_layout((2, 3), stride=(4, 1)) + # Create a layout that maps (i,j) to i*3 + j + L2 = cute.make_layout((3, 4), stride=(3, 1)) + # Compose L1 and L2 + L3 = cute.composition(L1, L2) + # L3 now maps coordinates through L2 then L1 + """ + rhs_val = rhs + if not isinstance(rhs, Layout) and isinstance(rhs, (int, tuple)): + rhs_val = _pack_tile(rhs, loc=loc, ip=ip) + if isinstance(lhs, _Tensor): + lhs = lhs.value + return _cute_ir.composition(lhs, rhs_val, loc=loc, ip=ip) + + +@dsl_user_op +def complement( + input: Layout, cotarget: Union[Layout, Shape], *, loc=None, ip=None +) -> Layout: + """ + Compute the complement layout of the input layout with respect to the cotarget. + + The complement of a layout A with respect to cotarget n is a layout A* such that + for every k in Z_n and c in the domain of A, there exists a unique c* in the domain + of A* where k = A(c) + A*(c*). + + This operation is useful for creating layouts that partition a space in complementary ways, + such as row and column layouts that together cover a matrix. + + :param input: The layout to compute the complement of + :type input: Layout + :param cotarget: The target layout or shape that defines the codomain + :type cotarget: Union[Layout, Shape] + :param loc: Optional location information for IR diagnostics + :type loc: optional + :param ip: Optional instruction pointer or context for underlying IR functions + :type ip: optional + :returns: The complement layout + :rtype: Layout + + Example: + + .. code-block:: python + + import cutlass.cute as cute + @cute.jit + def foo(): + # Create a right-major layout for a 4x4 matrix + row_layout = cute.make_layout((4, 4), stride=(4, 1)) + # Create a left-major layout that complements the row layout + col_layout = cute.complement(row_layout, 16) + # The two layouts are complementary under 16 + """ + if isinstance(cotarget, Layout): + return _cute_ir.complement(input, cotarget=cotarget, loc=loc, ip=ip) + else: + cotarget_val = _pack_shape(cotarget, loc=loc, ip=ip) + return _cute_ir.complement(input, cotarget=cotarget_val, loc=loc, ip=ip) + + +@dsl_user_op +def right_inverse(input: Layout, *, loc=None, ip=None) -> Layout: + if not isinstance(input, Layout): + raise TypeError(f"expects input of type Layout, but got {type(input)}") + return _cute_ir.right_inverse(input=input, loc=loc, ip=ip) + + +@dsl_user_op +def left_inverse(input: Layout, *, loc=None, ip=None) -> Layout: + if not isinstance(input, Layout): + raise TypeError(f"expects input of type Layout, but got {type(input)}") + return _cute_ir.left_inverse(input=input, loc=loc, ip=ip) + + +@overload +def logical_product(block: Layout, tiler: Layout, *, loc=None, ip=None) -> Layout: ... +@overload +def logical_product( + block: ComposedLayout, tiler: Layout, *, loc=None, ip=None +) -> ComposedLayout: ... + + +@dsl_user_op +def logical_product(block, tiler: Layout, *, loc=None, ip=None): + return _cute_ir.logical_product(input=block, tiler=tiler, loc=loc, ip=ip) + + +@overload +def zipped_product(block: Layout, tiler: Layout, *, loc=None, ip=None) -> Layout: ... +@overload +def zipped_product( + block: ComposedLayout, tiler: Layout, *, loc=None, ip=None +) -> ComposedLayout: ... + + +@dsl_user_op +def zipped_product(block, tiler: Layout, *, loc=None, ip=None): + return _cute_ir.zipped_product(input=block, tiler=tiler, loc=loc, ip=ip) + + +@overload +def tiled_product(block: Layout, tiler: Layout, *, loc=None, ip=None) -> Layout: ... +@overload +def tiled_product( + block: ComposedLayout, tiler: Layout, *, loc=None, ip=None +) -> ComposedLayout: ... + + +@dsl_user_op +def tiled_product(block, tiler: Layout, *, loc=None, ip=None): + return _cute_ir.tiled_product(input=block, tiler=tiler, loc=loc, ip=ip) + + +@overload +def flat_product(block: Layout, tiler: Layout, *, loc=None, ip=None) -> Layout: ... +@overload +def flat_product( + block: ComposedLayout, tiler: Layout, *, loc=None, ip=None +) -> ComposedLayout: ... + + +@dsl_user_op +def flat_product(block, tiler: Layout, *, loc=None, ip=None): + return _cute_ir.flat_product(input=block, tiler=tiler, loc=loc, ip=ip) + + +@overload +def raked_product(block: Layout, tiler: Layout, *, loc=None, ip=None) -> Layout: ... +@overload +def raked_product( + block: ComposedLayout, tiler: Layout, *, loc=None, ip=None +) -> ComposedLayout: ... + + +@dsl_user_op +def raked_product(block, tiler: Layout, *, loc=None, ip=None): + return _cute_ir.raked_product(input=block, tiler=tiler, loc=loc, ip=ip) + + +@overload +def blocked_product(block: Layout, tiler: Layout, *, loc=None, ip=None) -> Layout: ... +@overload +def blocked_product( + block: ComposedLayout, tiler: Layout, *, loc=None, ip=None +) -> ComposedLayout: ... + + +@dsl_user_op +def blocked_product(block, tiler: Layout, *, loc=None, ip=None): + return _cute_ir.blocked_product(input=block, tiler=tiler, loc=loc, ip=ip) + + +@overload +def logical_divide(target: Layout, tiler: Tiler, *, loc=None, ip=None) -> Layout: ... +@overload +def logical_divide(target: Tensor, tiler: Tiler, *, loc=None, ip=None) -> Tensor: ... + + +@dsl_user_op +def logical_divide(target, tiler: Tiler, *, loc=None, ip=None): + res_type = None + if isinstance(target, _Tensor): + res_type = target.element_type + target = target.value + if isinstance(tiler, tuple): + tiler = _pack_tile(tiler, loc=loc, ip=ip) + res = _cute_ir.logical_divide(input=target, tiler=tiler, loc=loc, ip=ip) + return _Tensor(res, dtype=res_type) if isinstance(res, _Tensor) else res + + +@overload +def zipped_divide(target: Layout, tiler: Tiler, *, loc=None, ip=None) -> Layout: ... +@overload +def zipped_divide(target: Tensor, tiler: Tiler, *, loc=None, ip=None) -> Tensor: ... + + +@dsl_user_op +def zipped_divide(target, tiler: Tiler, *, loc=None, ip=None): + res_type = None + if isinstance(target, _Tensor): + res_type = target.element_type + target = target.value + if isinstance(tiler, tuple): + tiler = _pack_tile(tiler, loc=loc, ip=ip) + res = _cute_ir.zipped_divide(input=target, tiler=tiler, loc=loc, ip=ip) + return _Tensor(res, dtype=res_type) if isinstance(res, _Tensor) else res + + +@overload +def tiled_divide(target: Layout, tiler: Tiler, *, loc=None, ip=None) -> Layout: ... +@overload +def tiled_divide(target: Tensor, tiler: Tiler, *, loc=None, ip=None) -> Tensor: ... + + +@dsl_user_op +def tiled_divide(target, tiler: Tiler, *, loc=None, ip=None): + res_type = None + if isinstance(target, _Tensor): + res_type = target.element_type + target = target.value + if isinstance(tiler, tuple): + tiler = _pack_tile(tiler, loc=loc, ip=ip) + res = _cute_ir.tiled_divide(input=target, tiler=tiler, loc=loc, ip=ip) + return _Tensor(res, dtype=res_type) if isinstance(res, _Tensor) else res + + +@overload +def flat_divide(target: Layout, tiler: Tiler, *, loc=None, ip=None) -> Layout: ... +@overload +def flat_divide(target: Tensor, tiler: Tiler, *, loc=None, ip=None) -> Tensor: ... + + +@dsl_user_op +def flat_divide(target, tiler: Tiler, *, loc=None, ip=None): + res_type = None + if isinstance(target, _Tensor): + res_type = target.element_type + target = target.value + if isinstance(tiler, tuple): + tiler = _pack_tile(tiler, loc=loc, ip=ip) + res = _cute_ir.flat_divide(input=target, tiler=tiler, loc=loc, ip=ip) + return _Tensor(res, dtype=res_type) if isinstance(res, _Tensor) else res + + +# +# Higher-level utilties +# + + +@dsl_user_op +def max_common_layout( + a: Union[Layout, Tensor], b: Union[Layout, Tensor], *, loc=None, ip=None +) -> Layout: + a_layout = a.layout if isinstance(a, _Tensor) else a + b_layout = b.layout if isinstance(b, _Tensor) else b + + inv_b = right_inverse(b_layout, loc=loc, ip=ip) + common = coalesce(composition(a_layout, inv_b, loc=loc, ip=ip), loc=loc, ip=ip) + + # some_ir_value == 1 generates a new IR Value which evaluates to True! + s = get(common.shape, mode=[0], loc=loc, ip=ip) + d = get(common.stride, mode=[0], loc=loc, ip=ip) + # Keep only the static identity component of the common layout + if isinstance(s, int) and isinstance(d, int) and d == 1: + # Truncate to the size of the contiguous vector (static stride-1 mode) + return composition(inv_b, get(common, mode=[0], loc=loc, ip=ip), loc=loc, ip=ip) + else: + return make_layout(1, stride=0, loc=loc, ip=ip) + + +@dsl_user_op +def max_common_vector( + a: Union[Layout, Tensor], b: Union[Layout, Tensor], *, loc=None, ip=None +) -> int: + a_layout = a.layout if isinstance(a, _Tensor) else a + b_layout = b.layout if isinstance(b, _Tensor) else b + + inv_b = right_inverse(b_layout, loc=loc, ip=ip) + common = coalesce(composition(a_layout, inv_b, loc=loc, ip=ip), loc=loc, ip=ip) + + # Keep only the static identity component of the common layout + if ( + is_static(get(common.shape, mode=[0], loc=loc, ip=ip)) + and get(common.stride, mode=[0], loc=loc, ip=ip) == 1 + ): + # Truncate to the size of the contiguous vector (static stride-1 mode) + return get(common.shape, mode=[0], loc=loc, ip=ip) + else: + return 1 + + +@dsl_user_op +def tile_to_shape( + atom: Union[Layout, ComposedLayout], + trg_shape: Shape, + order: Shape, + *, + loc=None, + ip=None, +) -> Union[Layout, ComposedLayout]: + trg_shape = _pack_shape(shape(trg_shape), loc=loc, ip=ip) + order = _pack_int_tuple(order, loc=loc, ip=ip) + return _cute_ir.tile_to_shape(atom, trg_shape, order, loc=loc, ip=ip) + + +@dsl_user_op +def local_partition( + target: Tensor, + tiler: Union[Layout, Shape], + index: Union[int, Numeric], + proj: XTuple = 1, + *, + loc=None, + ip=None, +) -> Tensor: + if isinstance(index, cutlass_arith.ArithValue): + index_val = index + else: + index_val = index.ir_value() + if index_val.type.width > 32: + raise NotImplementedError( + f"Index value should be 32-bit or smaller integer type, but got {index_val.type}" + ) + return _cute_ir.local_partition( + input=target.value, tiler=dice(tiler, proj), index=index_val, loc=loc, ip=ip + ) + + +@dsl_user_op +def local_tile( + input: Tensor, + tiler: Union[Layout, Shape], + coord: Coord, + proj: XTuple = None, + *, + loc=None, + ip=None, +) -> Tensor: + tiler_val = _pack_shape(tiler, loc=loc, ip=ip) + coord_val = _pack_coord(coord, loc=loc, ip=ip) + if proj is not None: + if not isinstance(proj, tuple): + raise TypeError(f"Expects tuple for proj, but got {type(proj)}") + proj_val = _pack_coord(proj, loc=loc, ip=ip) + proj = proj_val.type.attribute + + return _cute_ir.local_tile( + input=input.value, + tile=tiler_val, + static_tile=None, + coord=coord_val, + static_coord=None, + proj=proj, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_layout_image_mask( + lay: Layout, coord: Coord, mode: int, *, loc=None, ip=None +) -> Int16: + """ + Makes a 16-bit integer mask of the image of a layout sliced at a given mode + and accounting for the offset given by the input coordinate for the other modes. + """ + if not is_static(lay): + raise ValueError( + f"make_layout_image_mask requires the layout to be static, but got {pretty_str(lay)}" + ) + r = rank(lay) + if rank(coord) != r: + raise ValueError( + f"the rank of the coordinate must be equal to the one of the layout, but got {pretty_str(coord)}" + ) + if mode > r or mode < 0: + raise ValueError(f"expects `mode` to be in [0,rank(lay)), but got {mode}") + # Given that we require the layout to be static, we can check that the mask fits in 16 bits + # This might be too conservative but safe + if cosize(lay) > 16: + raise ValueError("the mask may not fit into a 16-bit integer") + + # Replace the mode to keep with _ in the coordinate + slicer = tuple(None if idx == mode else x for idx, x in enumerate(coord)) + # Slice the layout with the slicer above and keep track of the offset + sliced_lay, offset = slice_and_offset(slicer, lay, loc=loc, ip=ip) + # Given that we replace only one mode with _, the rank of the slice should be 1 + assert rank(sliced_lay) == 1 + + # Create the mask of the image + mcast_mask = Int16(0) + for i in range(size(sliced_lay)): + mcast_mask = mcast_mask | (1 << sliced_lay(i)) + mcast_mask <<= offset + return Int16(mcast_mask) + + +#################################################################################################### +# +# Atom +# +#################################################################################################### + + +class Op(ABC): + """ + Operation abstract base class. + """ + + pass + + +class MmaOp(Op): + """ + MMA Operation abstract base class. + """ + + @abstractmethod + def _make_trait(self, *, loc=None, ip=None, **kwargs): + pass + + +class CopyOp(Op): + """ + Copy Operation abstract base class. + """ + + @abstractmethod + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ): + pass + + +class Trait(ABC): + """ + Trait abstract base class. + + Traits are internal-only classes used by Atoms that wrap the underlying IR Value. The Python + user should only interact with Ops and Atoms. + """ + + def __init__(self, value: ir.Value) -> None: + self.value = value + + def __extract_mlir_values__(self): + return [self.value] + + def __new_from_mlir_values__(self, values): + return self.__class__(values[0]) + + def set(self, field, value, *, loc=None, ip=None) -> None: + raise NotImplementedError( + "set not implemented, the requesting Atom has likely no runtime state" + ) + + def unpack(self, *, loc=None, ip=None, **kwargs) -> ir.Value: + return self.value + + +class Atom(ABC): + """ + Atom base class. + + An Atom is the composition of + + - a MMA or Copy Operation; + - an internal MMA or Copy Trait. + + An Operation is a pure Python class that is used to model a specific MMA or Copy instruction. + The Trait wraps the underlying IR Value and provides access to the metadata of the instruction + encoded using CuTe Layouts. When the Trait can be constructed straighforwardly from an + Operation, the ``make_mma_atom`` or ``make_copy_atom`` API should be used. There are cases where + constructing the metadata is not trivial and requires more information, for example to determine + the number of bytes copied per TMA instruction ("the TMA vector length"). In such cases, + dedicated helper functions are provided with an appropriate API such that the Atom is + constructed internally in an optimal fashion for the user. + """ + + def __init__(self, op: Op, trait: Trait) -> None: + self._op = op + self._trait = trait + + def __extract_mlir_values__(self): + return extract_mlir_values(self._trait) + + def __new_from_mlir_values__(self, values): + return self.__class__(self.op, new_from_mlir_values(self._trait, values)) + + @property + def op(self) -> Op: + return self._op + + @property + def type(self): + return self._trait.value.type + + @dsl_user_op + def set(self, modifier, value, *, loc=None, ip=None) -> None: + """ + Sets runtime fields of the Atom. + + Some Atoms have runtime state, for example a tcgen05 MMA Atom + + + .. code-block:: python + + tiled_mma = cute.make_tiled_mma(some_tcgen05_mma_op) + tiled_mma.set(cute.nvgpu.tcgen05.Field.ACCUMULATE, True) + + The ``set`` method provides a way to the user to modify such runtime state. Modifiable + fields are provided by arch-specific enumerations, for example ``tcgen05.Field``. The Atom + instance internally validates the field as well as the value provided by the user to set + the field to. + """ + self._trait.set(modifier, value, loc=loc, ip=ip) + + def _unpack(self, *, loc=None, ip=None, **kwargs) -> ir.Value: + return self._trait.unpack(loc=loc, ip=ip, **kwargs) + + +#################################################################################################### +# +# MMA Atoms, TiledMma, and ThrMma +# +#################################################################################################### + + +class MmaAtom(Atom): + """ + The MMA Atom class. + """ + + def __str__(self) -> str: + res = "MMA Atom\n" + res += " ThrID: " + pretty_str(self.thr_id) + "\n" + res += " Shape MNK: " + pretty_str(self.shape_mnk) + "\n" + res += " TV Layout A: " + pretty_str(self.tv_layout_A) + "\n" + res += " TV Layout B: " + pretty_str(self.tv_layout_B) + "\n" + res += " TV Layout C: " + pretty_str(self.tv_layout_C) + return res + + # + # Properties + # + + @property + def thr_id(self) -> Layout: + return _cute_ir.static(self._trait.value.type.thr_id) + + @property + def shape_mnk(self) -> Shape: + return _unpack_x_tuple(self._trait.value.type.shape_mnk) + + @property + def tv_layout_A(self) -> Layout: + return _cute_ir.static(self._trait.value.type.layout_a_tv) + + @property + def tv_layout_B(self) -> Layout: + return _cute_ir.static(self._trait.value.type.layout_b_tv) + + @property + def tv_layout_C(self) -> Layout: + return _cute_ir.static(self._trait.value.type.layout_c_tv) + + # + # make_fragment + # + + @dsl_user_op + def make_fragment_A(self, input, *, loc=None, ip=None): + # input could be memref/shape/layout for tmem based fragment + if isinstance(input, _Tensor): + if self.op is not None: + self.op._verify_fragment_A(input, loc=loc, ip=ip) + input = input.value + if isinstance(input, tuple): + input = _pack_shape(input, loc=loc, ip=ip) + return _cute_ir.mma_make_fragment( + _cute_ir.MmaOperand.A, + self._trait.value, + input, + loc=loc, + ip=ip, + ) + + @dsl_user_op + def make_fragment_B(self, input, *, loc=None, ip=None): + if isinstance(input, _Tensor): + if self.op is not None: + self.op._verify_fragment_B(input, loc=loc, ip=ip) + input = input.value + return _cute_ir.mma_make_fragment( + _cute_ir.MmaOperand.B, + self._trait.value, + input, + loc=loc, + ip=ip, + ) + + @dsl_user_op + def make_fragment_C(self, input, *, loc=None, ip=None): + # input could be memref/shape/layout for tmem based fragment + if isinstance(input, _Tensor): + input = input.value + if isinstance(input, tuple): + input = _pack_shape(input, loc=loc, ip=ip) + return _cute_ir.mma_make_fragment( + _cute_ir.MmaOperand.C, + self._trait.value, + input, + loc=loc, + ip=ip, + ) + + +class TiledMma(MmaAtom): + """ + The tiled MMA class. + """ + + def __str__(self) -> str: + res = "Tiled MMA\n" + res += " Thr Layout VMNK: " + pretty_str(self.thr_layout_vmnk) + "\n" + res += " Permutation MNK: " + pretty_str(self.permutation_mnk) + "\n" + res += "MMA Atom\n" + res += " ThrID: " + pretty_str(self.thr_id) + "\n" + res += " Shape MNK: " + pretty_str(self.shape_mnk) + "\n" + res += " TV Layout A: " + pretty_str(self.tv_layout_A) + "\n" + res += " TV Layout B: " + pretty_str(self.tv_layout_B) + "\n" + res += " TV Layout C: " + pretty_str(self.tv_layout_C) + return res + + # + # Properties + # + + @property + def tv_layout_A_tiled(self) -> Layout: + return _cute_ir.static(self._trait.value.type.layout_a_tv_tiled) + + @property + def tv_layout_B_tiled(self) -> Layout: + return _cute_ir.static(self._trait.value.type.layout_b_tv_tiled) + + @property + def tv_layout_C_tiled(self) -> Layout: + return _cute_ir.static(self._trait.value.type.layout_c_tv_tiled) + + @property + def permutation_mnk(self) -> Tile: + return _unpack_x_tuple(self._trait.value.type.permutation_mnk) + + @property + def thr_layout_vmnk(self) -> Layout: + return _cute_ir.static(self._trait.value.type.thr_layout_vmnk) + + @property + def size(self) -> int: + return self._trait.value.type.size + + # + # Tiler + # + + def get_tile_size(self, mode_idx: int) -> Shape: + assert (mode_idx >= 0) and (mode_idx < 3) + perm_tile = self.permutation_mnk[mode_idx] + if perm_tile is None: + thr_layout_vmnk = self.thr_layout_vmnk + atom_shape_mnk = self.shape_mnk + return size(atom_shape_mnk, mode=[mode_idx]) * size( + thr_layout_vmnk, mode=[mode_idx + 1] + ) + else: + return size(perm_tile) + + # + # get_slice + # + + def get_slice(self, thr_idx: Union[int, Int32]) -> "ThrMma": + return ThrMma(self.op, self._trait, thr_idx) + + # + # partition_shape + # + + def _partition_shape(self, operand_id, shape, *, loc=None, ip=None): + shape = _pack_shape(shape, loc=loc, ip=ip) + return _unpack_x_tuple( + _cute_ir.tiled_mma_partition_shape( + operand_id, self._trait.value, shape, loc=loc, ip=ip + ), + loc=loc, + ip=ip, + ) + + @dsl_user_op + def partition_shape_A(self, shape_mk, *, loc=None, ip=None): + return self._partition_shape(_cute_ir.MmaOperand.A, shape_mk, loc=loc, ip=ip) + + @dsl_user_op + def partition_shape_B(self, shape_nk, *, loc=None, ip=None): + return self._partition_shape(_cute_ir.MmaOperand.B, shape_nk, loc=loc, ip=ip) + + @dsl_user_op + def partition_shape_C(self, shape_mn, *, loc=None, ip=None): + return self._partition_shape(_cute_ir.MmaOperand.C, shape_mn, loc=loc, ip=ip) + + # + # _thrfrg + # + + @overload + def _thrfrg(self, operand_id, input: Layout, *, loc=None, ip=None) -> Layout: ... + + @overload + def _thrfrg(self, operand_id, input: Tensor, *, loc=None, ip=None) -> Tensor: ... + + def _thrfrg(self, operand_id, input, *, loc=None, ip=None) -> Union[Tensor, Layout]: + if isinstance(input, Tensor): + return make_tensor( + input.iterator, + self._thrfrg(operand_id, input.layout, loc=loc, ip=ip), + ) + elif isinstance(input, Layout): + if not is_static(input.type): + raise ValueError(f"Expects a static layout but got {input.type}") + return _cute_ir.static( + self._trait.value.type.thrfrg(operand_id, input), loc=loc, ip=ip + ) + + raise ValueError( + f"Expects a layout or a tensor as input but got {type(input)=}" + ) + + def _thrfrg_A( + self, input: Union[Layout, Tensor], *, loc=None, ip=None + ) -> Union[Layout, Tensor]: + return self._thrfrg(_cute_ir.MmaOperand.A, input, loc=loc, ip=ip) + + def _thrfrg_B( + self, input: Union[Layout, Tensor], *, loc=None, ip=None + ) -> Union[Layout, Tensor]: + return self._thrfrg(_cute_ir.MmaOperand.B, input, loc=loc, ip=ip) + + def _thrfrg_C( + self, input: Union[Layout, Tensor], *, loc=None, ip=None + ) -> Union[Layout, Tensor]: + return self._thrfrg(_cute_ir.MmaOperand.C, input, loc=loc, ip=ip) + + +class ThrMma(TiledMma): + """ + The thread MMA class for modeling a thread-slice of a tiled MMA. + """ + + def __init__(self, op: Op, trait: Trait, thr_idx: Union[int, Int32]) -> None: + super().__init__(op, trait) + self._thr_idx = thr_idx + + def __new_from_mlir_values__(self, values): + return self.__class__( + self.op, new_from_mlir_values(self._trait, values), self.thr_idx + ) + + @property + def thr_idx(self): + return self._thr_idx + + @dsl_user_op + def partition_A(self, input_mk: Tensor, *, loc=None, ip=None) -> Tensor: + thr_idx = _pack_coord(self.thr_idx, loc=loc, ip=ip) + return _cute_ir.tiled_mma_partition( + _cute_ir.MmaOperand.A, + self._trait.value, + input_mk.value, + thr_idx, + loc=loc, + ip=ip, + ) + + @dsl_user_op + def partition_B(self, input_nk: Tensor, *, loc=None, ip=None) -> Tensor: + thr_idx = _pack_coord(self.thr_idx, loc=loc, ip=ip) + return _cute_ir.tiled_mma_partition( + _cute_ir.MmaOperand.B, + self._trait.value, + input_nk.value, + thr_idx, + loc=loc, + ip=ip, + ) + + @dsl_user_op + def partition_C(self, input_mn: Tensor, *, loc=None, ip=None) -> Tensor: + thr_idx = _pack_coord(self.thr_idx, loc=loc, ip=ip) + return _cute_ir.tiled_mma_partition( + _cute_ir.MmaOperand.C, + self._trait.value, + input_mn.value, + thr_idx, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_mma_atom(op: MmaOp, *, loc=None, ip=None, **kwargs) -> MmaAtom: + """ + Makes an MMA Atom from an MMA Operation. + + This function creates an MMA Atom from a given MMA Operation. Arbitrary kw arguments can be + provided for Op-specific additional parameters. They are not used as of today. + + :param op: The MMA Operation to construct an Atom for + :type op: MmaOp + :return: The MMA Atom + :rtype: MmaAtom + """ + trait = op._make_trait(loc=loc, ip=ip, **kwargs) + return MmaAtom(op, trait) + + +@dsl_user_op +def make_tiled_mma( + op_or_atom: Union[Op, MmaAtom], + atom_layout_mnk=(1, 1, 1), + permutation_mnk=None, + *, + loc=None, + ip=None, + **kwargs, +) -> TiledMma: + """ + Makes a tiled MMA from an MMA Operation or an MMA Atom. + + :param op_or_atom: The MMA Operation or Atom + :type op_or_atom: Union[Op, MmaAtom] + :param atom_layout_mnk: A Layout describing the tiling of Atom across threads + :type atom_layout_mnk: Layout + :param permutation_mnk: A permutation Tiler describing the tiling of Atom across values including any permutation of such tiling + :type permutation_mnk: Tiler + :return: The resulting tiled MMA + :rtype: TiledMma + """ + if isinstance(op_or_atom, Op): + op = op_or_atom + atom = make_mma_atom(op_or_atom, loc=loc, ip=ip, **kwargs) + elif isinstance(op_or_atom, MmaAtom): + op = op_or_atom.op + atom = op_or_atom + else: + raise TypeError( + f"expected an MMA Op or Atom, but got an instance of {type(op_or_atom)}" + ) + if isinstance(atom_layout_mnk, tuple): + atom_layout_mnk = make_layout(atom_layout_mnk, loc=loc, ip=ip) + if rank(atom_layout_mnk) != 3: + raise ValueError(f"expects rank-3 MNK atom layout, but got {atom_layout_mnk}") + permutation_mnk_ty = None + if permutation_mnk is not None: + permutation_mnk_ty = _pack_tile(permutation_mnk, loc=loc, ip=ip).type + ty = _cute_nvgpu_ir.TiledMmaType.get( + atom._trait.value.type, + atom_layout_mnk.type, + permutation_mnk_ty, + ) + val = _cute_ir.make_tiled_mma(ty, atom._trait.value, loc=loc, ip=ip) + # Instead of modifying atom which might have been provided by the user, create a brand new + # trait instance and replace the Atom ir.Value with the tiled one + trait = new_from_mlir_values(atom._trait, [val]) + return TiledMma(op, trait) + + +#################################################################################################### +# +# Copy Atoms, TiledCopy, and ThrCopy +# +#################################################################################################### + + +class CopyAtom(Atom): + """ + The Copy Atom class. + """ + + def __str__(self) -> str: + res = "Copy Atom\n" + res += " ThrID: " + str(self.thr_id) + "\n" + res += " TV Layout Src: " + str(self.layout_src_tv) + "\n" + res += " TV Layout Dst: " + str(self.layout_dst_tv) + "\n" + res += " Value type: " + str(self._trait.value.type.value_type) + return res + + # + # Properties + # + + @property + def value_type(self) -> Type[Numeric]: + return Numeric.from_mlir_type(self._trait.value.type.value_type) + + @property + def thr_id(self) -> Layout: + return _cute_ir.static(self._trait.value.type.thr_id) + + @property + def layout_src_tv(self) -> Layout: + return _cute_ir.static(self._trait.value.type.layout_src_tv) + + @property + def layout_dst_tv(self) -> Layout: + return _cute_ir.static(self._trait.value.type.layout_dst_tv) + + +class TiledCopy(CopyAtom): + """ + The tiled Copy class. + """ + + def __str__(self) -> str: + res = "Tiled Copy\n" + res += " Tiler MN: " + pretty_str(self.tiler_mn) + "\n" + res += " TV Layout tiled: " + str(self.layout_tv_tiled) + "\n" + res += "Copy Atom\n" + res += " ThrID: " + str(self.thr_id) + "\n" + res += " TV Layout Src: " + str(self.layout_src_tv) + "\n" + res += " TV Layout Dst: " + str(self.layout_dst_tv) + "\n" + res += " Value type: " + str(self._trait.value.type.value_type) + return res + + # + # Properties + # + + @property + def layout_tv_tiled(self) -> Layout: + return _cute_ir.static(self._trait.value.type.layout_tv_tiled) + + @property + def tiler_mn(self) -> Tile: + return _unpack_x_tuple(self._trait.value.type.tiler_mn) + + @property + def layout_src_tv_tiled(self) -> Layout: + return _cute_ir.static(self._trait.value.type.layout_src_tv_tiled) + + @property + def layout_dst_tv_tiled(self) -> Layout: + return _cute_ir.static(self._trait.value.type.layout_dst_tv_tiled) + + @property + def size(self) -> int: + return self._trait.value.type.size + + # + # get_slice and retile + # + + def get_slice(self, thr_idx: Union[int, Int32]) -> "ThrCopy": + return ThrCopy(self.op, self._trait, thr_idx) + + @dsl_user_op + def retile(self, src, *, loc=None, ip=None): + return _cute_ir.tiled_copy_retile( + tiled_copy=self._trait.value, input=src.value, loc=loc, ip=ip + ) + + +class ThrCopy(TiledCopy): + """ + The thread Copy class for modeling a thread-slice of a tiled Copy. + """ + + def __init__(self, op: Op, trait: Trait, thr_idx: Union[int, Int32]) -> None: + super().__init__(op, trait) + self._thr_idx = thr_idx + + def __new_from_mlir_values__(self, values): + return self.__class__( + self.op, new_from_mlir_values(self._trait, values), self.thr_idx + ) + + @property + def thr_idx(self): + return self._thr_idx + + @dsl_user_op + def partition_S(self, src: Tensor, *, loc=None, ip=None) -> Tensor: + thr_idx = _pack_coord(self.thr_idx, loc=loc, ip=ip) + return _cute_ir.tiled_copy_partition_S( + self._trait.value, src.value, thr_idx, loc=loc, ip=ip + ) + + @dsl_user_op + def partition_D(self, dst: Tensor, *, loc=None, ip=None) -> Tensor: + thr_idx = _pack_coord(self.thr_idx, loc=loc, ip=ip) + return _cute_ir.tiled_copy_partition_D( + self._trait.value, dst.value, thr_idx, loc=loc, ip=ip + ) + + +@dsl_user_op +def make_copy_atom( + op: CopyOp, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs +) -> CopyAtom: + """ + Makes a Copy Atom from a Copy Operation. + + This function creates a Copy Atom from a given Copy Operation. Arbitrary kw arguments can be + provided for Op-specific additional parameters. + + Example: + + .. code-block:: python + + op = cute.nvgpu.CopyUniversalOp() + atom = cute.make_copy_atom(op, tensor_dtype, num_bits_per_copy=64) + + :param op: The Copy Operation to construct an Atom for + :type op: CopyOp + :param copy_internal_type: An internal data type used to construct the source/destination layouts in unit of tensor elements + :type copy_internal_type: Type[Numeric] + :return: The Copy Atom + :rtype: CopyAtom + """ + trait = op._make_trait(copy_internal_type, loc=loc, ip=ip, **kwargs) + return CopyAtom(op, trait) + + +@dsl_user_op +def make_layout_tv( + thr_layout: Layout, val_layout: Layout, *, loc=None, ip=None +) -> Tuple[Shape, Layout]: + """Create a thread-value layout for partitioning data tensors. + + This function creates a thread-value layout that maps between ``(thread_idx, value_idx)`` + coordinates and logical ``(M,N)`` coordinates. The thread layout must be compact to ensure + proper partitioning. + + This implements the thread-value partitioning pattern shown in + Figure TVLayout, where data is partitioned across threads and values within each thread. + + :param thr_layout: Layout mapping from ``(TileM,TileN)`` coordinates to thread IDs (must be compact) + :type thr_layout: Layout + :param val_layout: Layout mapping from ``(ValueM,ValueN)`` coordinates to value IDs within each thread + :type val_layout: Layout + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional + + :return: A tuple containing ``tiler_mn`` and ``layout_tv`` + :rtype: Tuple[Shape, Layout] + + where: + * ``tiler_mn`` is tiler and ``shape(tiler_mn)`` is compatible with ``shape(zipped_divide(x, tiler_mn))[0]`` + * ``layout_tv``: Thread-value layout mapping (thread_idx, value_idx) -> (M,N) + + **Example:** + + .. code-block:: python + + tiler_mn, layout_tv = cute.make_layout_tv( + cute.make_layout((4, 8), stride=(8, 1)), cute.make_layout(2, stride=1) + ) + + Above code creates a TV layout that maps between thread/value coordinates + and the logical coordinates in a 8x8 matrix with: + + * thread block layout ``(4,8):(8,1)`` + * 2 elements per thread + """ + + if not isinstance(thr_layout, Layout): + raise TypeError(f"expected a Layout for thr_layout, but got {type(thr_layout)}") + if not isinstance(val_layout, Layout): + raise TypeError(f"expected a Layout for val_layout, but got {type(val_layout)}") + + # Take the raked_products to compute the Layout_MN + # (M,N) -> (thr_idx, val_idx) + layout_mn = raked_product(thr_layout, val_layout, loc=loc, ip=ip) + thr_size = size(thr_layout, loc=loc, ip=ip) + val_size = size(val_layout, loc=loc, ip=ip) + tmp = make_layout((thr_size, val_size), loc=loc, ip=ip) + # (thr_idx, val_idx) -> (M,N) + layout_tv = composition( + right_inverse(layout_mn, loc=loc, ip=ip), tmp, loc=loc, ip=ip + ) + + tiler_mn = product_each(layout_mn.shape, loc=loc, ip=ip) + + return (tiler_mn, layout_tv) + + +def _make_tiled_copy(atom, layout_tv, tiler_mn, *, loc=None, ip=None): + if type(tiler_mn) is tuple: + tiler_mn = _pack_tile(tiler_mn, loc=loc, ip=ip) + + assert isinstance(tiler_mn, ir.Value) and _cute_ir.TileType.isinstance( + tiler_mn.type + ), f"tiler_mn must be a Tile, but got {type(tiler_mn)}" + assert is_static(layout_tv.type) and is_static( + tiler_mn.type + ), "layout tv and tiler mn must be static" + tiled_copy_ty = _cute_nvgpu_ir.TiledCopyType.get( + atom.type, layout_tv.type, tiler_mn.type + ) + + val = _cute_ir.make_tiled_copy(tiled_copy_ty, atom._trait.value, loc=loc, ip=ip) + # Instead of modifying atom which might have been provided by the user, create a brand new + # trait instance and replace the Atom ir.Value with the tiled one + trait = new_from_mlir_values(atom._trait, [val]) + return TiledCopy(atom.op, trait) + + +def make_tiled_copy(atom, layout_tv, tiler_mn, *, loc=None, ip=None): + """Create a tiled type given a TV partitioner and tiler. + + :param atom: Copy atom, e.g. smit_copy and simt_async_copy, tma_load, etc. + :type atom: CopyAtom + :param layout_tv: Thread-value layout + :type layout_tv: Layout + :param tiler_mn: Tile size + :type tiler_mn: Tiler + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional + + :return: A tiled copy for the partitioner + :rtype: TiledCopy + """ + return _make_tiled_copy(atom, layout_tv, tiler_mn, loc=loc, ip=ip) + + +@dsl_user_op +def make_tiled_copy_tv( + atom: CopyAtom, thr_layout: Layout, val_layout: Layout, *, loc=None, ip=None +) -> TiledCopy: + """Create a tiled copy given separate thread and value layouts. + + A TV partitioner is inferred based on the input layouts. The input thread layout + must be compact. + + :param atom: Copy atom + :type atom: CopyAtom + :param thr_layout: Layout mapping from ``(TileM,TileN)`` coordinates to thread IDs (must be compact) + :type thr_layout: Layout + :param val_layout: Layout mapping from ``(ValueM,ValueN)`` coordinates to value IDs + :type val_layout: Layout + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional + + :return: A tiled copy for the partitioner + :rtype: TiledCopy + """ + + tiler_mn, layout_tv = make_layout_tv(thr_layout, val_layout, loc=loc, ip=ip) + tiler_mn = _pack_tile(product_each(tiler_mn, loc=loc, ip=ip), loc=loc, ip=ip) + return _make_tiled_copy(atom, layout_tv, tiler_mn, loc=loc, ip=ip) + + +@dsl_user_op +def make_tiled_copy_A(atom, tiled_mma, *, loc=None, ip=None): + """Create a tiled copy out of the copy_atom that matches the A-Layout of tiled_mma. + + :param atom: Copy atom + :type atom: CopyAtom + :param tiled_mma: Tiled MMA + :type tiled_mma: TiledMma + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional + + :return: A tiled copy for the partitioner + :rtype: TiledCopy + """ + + return _make_tiled_copy( + atom, + tiled_mma.tv_layout_A_tiled, + (tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(2)), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_tiled_copy_B(atom, tiled_mma, *, loc=None, ip=None): + """Create a tiled copy out of the copy_atom that matches the B-Layout of tiled_mma. + + :param atom: Copy atom + :type atom: CopyAtom + :param tiled_mma: Tiled MMA + :type tiled_mma: TiledMma + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional + + :return: A tiled copy for the partitioner + :rtype: TiledCopy + """ + + return _make_tiled_copy( + atom, + tiled_mma.tv_layout_B_tiled, + (tiled_mma.get_tile_size(1), tiled_mma.get_tile_size(2)), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_tiled_copy_C(atom, tiled_mma, *, loc=None, ip=None): + """Create a tiled copy out of the copy_atom that matches the C-Layout of tiled_mma. + + :param atom: Copy atom + :type atom: CopyAtom + :param tiled_mma: Tiled MMA + :type tiled_mma: TiledMma + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional + + :return: A tiled copy for the partitioner + :rtype: TiledCopy + """ + + return _make_tiled_copy( + atom, + tiled_mma.tv_layout_C_tiled, + (tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(1)), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_tiled_copy_S(atom, tiled_copy, *, loc=None, ip=None): + """Create a tiled copy out of the copy_atom that matches the Src-Layout of tiled_copy. + + :param atom: Copy atom + :type atom: CopyAtom + :param tiled_copy: Tiled copy + :type tiled_copy: TiledCopy + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional + + :return: A tiled copy for the partitioner + :rtype: TiledCopy + """ + + return _make_tiled_copy( + atom, tiled_copy.layout_src_tv_tiled, tiled_copy.tiler_mn, loc=loc, ip=ip + ) + + +@dsl_user_op +def make_tiled_copy_D(atom, tiled_copy, *, loc=None, ip=None): + """Create a tiled copy out of the copy_atom that matches the Dst-Layout of tiled_copy. + + :param atom: Copy atom + :type atom: CopyAtom + :param tiled_copy: Tiled copy + :type tiled_copy: TiledCopy + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional + + :return: A tiled copy for the partitioner + :rtype: TiledCopy + """ + + return _make_tiled_copy( + atom, tiled_copy.layout_dst_tv_tiled, tiled_copy.tiler_mn, loc=loc, ip=ip + ) + + +@dsl_user_op +def make_tiled_copy_C_atom(atom: CopyAtom, mma: TiledMma, *, loc=None, ip=None): + """Create the smallest tiled copy that can retile LayoutC_TV for use with pipelined epilogues with subtiled stores. + + :param atom: Copy atom + :type atom: CopyAtom + :param mma: Tiled MMA + :type mma: TiledMma + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional + + :return: A tiled copy for partitioner + :rtype: TiledCopy + + :raises ValueError: If the number value of CopyAtom's source layout is greater than the size of TiledMma's LayoutC_TV + """ + # Truncate the V-layout to just the Copy_Atom, keep the V-order + layoutC_tv = mma.tv_layout_C_tiled + val_layout_src = atom.layout_src_tv + num_val_src = size(val_layout_src, mode=[1], loc=loc, ip=ip) + num_val_layoutC_tv = size(layoutC_tv, mode=[1], loc=loc, ip=ip) + if num_val_src > num_val_layoutC_tv: + raise ValueError( + f"The number value of CopyAtom's source layout {num_val_src} " + f"is greater than the size of TiledMma's LayoutC_TV {num_val_layoutC_tv}" + ) + layout_TV = composition( + layoutC_tv, + make_layout( + (size(layoutC_tv, mode=[0], loc=loc, ip=ip), num_val_src), loc=loc, ip=ip + ), + loc=loc, + ip=ip, + ) + + # Recompute tiler and restride the TV layout for the new tiler + + # Tiler -- Find the active elements in the MMA tensor and generate a tiler to extract them + # Convert to the awkward by-mode tiler to preserve the modes of the tiled MMA + mma_tiler = (mma.get_tile_size(0), mma.get_tile_size(1)) + + tiler_0 = filter( + composition( + make_layout(mma_tiler, stride=(1, 0), loc=loc, ip=ip), + layout_TV, + loc=loc, + ip=ip, + ), + loc=loc, + ip=ip, + ) + tiler_1 = filter( + composition( + make_layout(mma_tiler, stride=(0, 1), loc=loc, ip=ip), + layout_TV, + loc=loc, + ip=ip, + ), + loc=loc, + ip=ip, + ) + tiler = (tiler_0, tiler_1) + + tile2mma = composition( + make_layout(mma_tiler, loc=loc, ip=ip), tiler, loc=loc, ip=ip + ) + layout_tv = composition( + left_inverse(tile2mma, loc=loc, ip=ip), layout_TV, loc=loc, ip=ip + ) + + tiler_mn = _pack_tile(tiler, loc=loc, ip=ip) + + return _make_tiled_copy(atom, layout_tv, tiler_mn, loc=loc, ip=ip) + + +#################################################################################################### +# +# cute.gemm and cute.copy +# +#################################################################################################### + + +@dsl_user_op +def gemm( + atom: MmaAtom, + d: Tensor, + a: Tensor, + b: Tensor, + c: Tensor, + *, + loc=None, + ip=None, + **kwargs, +) -> None: + """The GEMM algorithm. + + Computes ``D <- A * B + C`` where ``C`` and ``D`` can alias. Note that some MMA Atoms (e.g. + warpgroup-wide or tcgen05 MMAs) require manually setting an "accumulate" boolean field. + + All tensors must be partitioned according to the provided MMA Atom. + + For MMA Atoms that require single-threaded execution, the gemm op automatically handles thread + election internally. Manual thread selection is not required in such cases. + + Following dispatch rules are supported: + + - Dispatch [1]: (V) x (V) => (V) => (V,1,1) x (V,1,1) => (V,1,1) + - Dispatch [2]: (M) x (N) => (M,N) => (1,M,1) x (1,N,1) => (1,M,N) + - Dispatch [3]: (M,K) x (N,K) => (M,N) => (1,M,K) x (1,N,K) => (1,M,N) + - Dispatch [4]: (V,M) x (V,N) => (V,M,N) => (V,M,1) x (V,N,1) => (V,M,N) + - Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N) + + :param atom: MMA atom + :type atom: MmaAtom + :param d: Destination tensor + :type d: Tensor + :param a: First source tensor + :type a: Tensor + :param b: Second source tensor + :type b: Tensor + :param c: Third source tensor + :type c: Tensor + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point for MLIR, defaults to None + :type ip: Optional[InsertionPoint], optional + :param kwargs: Additional keyword arguments + :type kwargs: dict + :return: None + :rtype: None + """ + + a_rank = rank(a.shape) + b_rank = rank(b.shape) + c_rank = rank(c.shape) + d_rank = rank(d.shape) + + if a_rank != b_rank: + raise ValueError("`a` and `b` must have the same rank") + + if c_rank != d_rank: + raise ValueError("`c` and `d` must have the same rank") + + if a_rank == 1: + if c_rank > 2: + raise ValueError("`c` must have rank <= 2 when `a` has rank 1") + elif a_rank == 2: + if c_rank not in (2, 3): + raise ValueError("`c` must have rank 2 or 3 when `a` has rank 2") + elif a_rank == 3: + if c_rank != 3: + raise ValueError("`c` must have rank 3 when `a` has rank 3") + + value = atom._unpack(loc=loc, ip=ip, **kwargs) + return _cute_ir.gemm(value, d.value, a.value, b.value, c.value, loc=loc, ip=ip) + + +@dsl_user_op +def basic_copy(src: Tensor, dst: Tensor, *, loc=None, ip=None) -> None: + """Performs a basic element-wise copy. + + This functions **assumes** the following pre-conditions: + 1. `size(src) == size(dst)` + + When the `src` and `dst` shapes are static, the pre-conditions are actually verified and the + element-wise loop is fully unrolled. + + :param src: Source tensor + :type src: Tensor + :param dst: Destination tensor + :type dst: Tensor + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional + """ + + if is_static(src.shape) and is_static(dst.shape): + simt_copy_ty = _cute_nvgpu_ir.CopyAtomSIMTSyncCopyType.get( + src.element_type.mlir_type, src.element_type.width + ) + simt_copy = _cute_ir.atom(simt_copy_ty, loc=loc, ip=ip) + return _cute_ir.copy(simt_copy, src.value, dst.value, loc=loc, ip=ip) + + s = size(dst, loc=loc, ip=ip) + # Always generate an scf.for Op when one of the tensors is dynamic + for i in for_generate(0, s): + dst[i] = src[i] + yield_out() + + +@dsl_user_op +def basic_copy_if(pred: Tensor, src: Tensor, dst: Tensor, *, loc=None, ip=None) -> None: + """Performs a basic predicated element-wise copy. + + This functions **assumes** the following pre-conditions: + 1. `size(src) == size(dst)` + 2. `size(src) == size(pred)` + + When all shapes are static, the pre-conditions are actually verified and the element-wise loop + is fully unrolled. + + """ + if src.element_type.width != dst.element_type.width: + raise NotImplementedError( + "basic_copy_if currently only supports equal source and destination " + "element type bit width" + ) + + if is_static(src.shape) and is_static(dst.shape) and is_static(pred.shape): + return _basic_copy_if_static(pred, src, dst, loc=loc, ip=ip) + + s = size(dst, loc=loc, ip=ip) + # Always generate an scf.for Op when one of the tensors is dynamic + for i in for_generate(0, s): + if_generate(pred[i], lambda: dst.__setitem__(i, src[i])) + yield_out() + + +# Version of basic_copy_if when src and dst have static shapes +# - verify size(src) == size(dst) == size(prd) +# - fully unroll the loop for now +def _basic_copy_if_static( + pred: Tensor, src: Tensor, dst: Tensor, *, loc=None, ip=None +) -> None: + assert is_static(src.shape) and is_static(dst.shape) and is_static(pred.shape) + if size(src, loc=loc, ip=ip) != size(dst, loc=loc, ip=ip): + raise ValueError( + "basic_copy expects the size of source, destination, and predicate tensors to match" + ) + # Fully unrolled loop in the static case for now + for i in range(size(dst, loc=loc, ip=ip)): + if_generate(pred[i], lambda: dst.__setitem__(i, src[i])) + + +@dsl_user_op +def autovec_copy(src: Tensor, dst: Tensor, *, loc=None, ip=None) -> None: + """ + Auto-vectorizing SIMT copy policy. + + Given a source and destination tensors that are statically shaped, this policy figures out the + largest safe vector width that the copy instruction can take and performs the copy. + """ + if src.element_type.width != dst.element_type.width: + raise NotImplementedError( + "autovec_copy currently only supports equal source and destination " + "element type bit width" + ) + + # We are going to dispatch to copy-with-atom which requires shapes to be static + if not is_static(src.shape) or not is_static(dst.shape): + raise ValueError( + "autovec_copy expects source and destination tensors to be statically shaped" + ) + + vec_layout = max_common_layout(src, dst, loc=loc, ip=ip) + num_common_elements = size(vec_layout, loc=loc, ip=ip) + + # Next we construct an upper-bound on the number bits that can be vectorized by considering + # - the maximum alignment of the layouts + # - the maximum alignment of the pointers + + upper_bound = math.gcd(src.layout.max_alignment, dst.layout.max_alignment) + upper_bound = math.gcd(upper_bound, num_common_elements) + upper_bound *= src.element_type.width + + # For our instructions, the alignment of the pointer is an upper bound to the vector width + # max_alignment, as opposed to alignment, takes into account possible address swizzling + upper_bound = math.gcd(upper_bound, src.iterator.max_alignment * 8) + upper_bound = math.gcd(upper_bound, dst.iterator.max_alignment * 8) + + # Finally, we put a cap at 128b + num_bits_per_copy = math.gcd(upper_bound, 128) + + if (num_common_elements > 1) and (num_bits_per_copy % 8 == 0): + num_common_elements = num_bits_per_copy // src.element_type.width + + # 2 step logical divides ensuring that the divides are valid at every step + vec_src = logical_divide(src, vec_layout, loc=loc, ip=ip) + vec_dst = logical_divide(dst, vec_layout, loc=loc, ip=ip) + tiled_src = logical_divide( + vec_src, make_layout(num_common_elements, loc=loc, ip=ip), loc=loc, ip=ip + ) + tiled_dst = logical_divide( + vec_dst, make_layout(num_common_elements, loc=loc, ip=ip), loc=loc, ip=ip + ) + + # Dispatch to copy with atom + simt_type = _cute_nvgpu_ir.CopyAtomSIMTSyncCopyType.get( + src.element_type.mlir_type, num_bits_per_copy + ) + simt_copy = _cute_ir.atom(simt_type, loc=loc, ip=ip) + return _cute_ir.copy( + simt_copy, tiled_src.value, tiled_dst.value, loc=loc, ip=ip + ) + + # Failed to vectorize, use a basic copy + basic_copy(src, dst, loc=loc, ip=ip) + + +@dsl_user_op +def copy( + atom: CopyAtom, + src: Tensor, + dst: Tensor, + *, + pred: Optional[Tensor] = None, + loc=None, + ip=None, + **kwargs, +) -> None: + """ + The Copy algorithm. + + The "copy with Atom" expects source and destination tensors to be partitioned according to the + provided Copy Atom. Some Atoms require additional Op-specific kw arguments, for example TMA + copies: + + .. code-block:: python + + cute.copy(tma_atom, src, dst, tma_bar_ptr=mbar_ptr, mcast_mask=mask) + + An additional predication tensor can be provided. If the partitioned tensors have the following + logical profile ``((ATOM_V,ATOM_REST),REST_M,...)``, the predication tensor must have a profile + consistent with ``(ATOM_REST,REST_M,...)``. + + For Copy Atoms that require single-threaded execution, the copy op automatically handles thread + election internally. Manual thread selection is not required in such cases. + """ + if isinstance(src.type, _cute_ir.MemRefType) and isinstance( + dst.type, _cute_ir.MemRefType + ): + if src.element_type.width != dst.element_type.width: + raise TypeError( + "`copy` currently only supports equal source and destination " + "element type bit width" + ) + + value = atom._unpack(loc=loc, ip=ip, **kwargs) + if isinstance(pred, Tensor): + pred = pred.value + return _cute_ir.copy(value, src.value, dst.value, pred=pred, loc=loc, ip=ip) + + +@dsl_user_op +def copy_atom_call( + atom: CopyAtom, + src: Tensor, + dst: Tensor, + *, + pred: Optional[Tensor] = None, + loc=None, + ip=None, + **kwargs, +) -> None: + """ + Execute a single copy atom operation. + + The copy_atom_call operation executes a copy atom with the given operands. + Following src/dst layout of atom are valid: + * ((atom_v)) + * (atom_v) + + Note: The format ((atom_v, rest_v)) is NOT valid for copy_atom_call since it would + require multiple atom operations, which contradicts the definition of a single copy atom call. + + Examples: + + .. code-block:: python + + # Call a copy atom operation + cute.copy_atom_call(copy_atom, src_tensor, dst_tensor) + + An additional predication tensor can be provided. If the partitioned tensors have the following + logical profile ``((ATOM_V,ATOM_REST),REST_M,...)``, the predication tensor must have a profile + consistent with ``(ATOM_REST,REST_M,...)``. + """ + if isinstance(src.type, _cute_ir.MemRefType) and isinstance( + dst.type, _cute_ir.MemRefType + ): + if src.element_type.width != dst.element_type.width: + raise TypeError( + "`copy_atom_call` currently only supports equal source and destination " + "element type bit width" + ) + + value = atom._unpack(loc=loc, ip=ip, **kwargs) + if isinstance(pred, Tensor): + pred = pred.value + return _cute_ir.copy_atom_call( + value, src.value, dst.value, pred=pred, loc=loc, ip=ip + ) + + +def prefetch(atom: CopyAtom, src: Tensor, *, loc=None, ip=None) -> None: + """ + The Prefetch algorithm. + + The "prefetch" expects source tensors to be partitioned according to the provided Copy Atom. + Prefetch is used for loading tensors from global memory to L2. + + Prefetch accepts Copy Atom but not all are allowed. Currently, only support for tma load tensor prefetch. + + .. code-block:: python + + cute.prefetch(tma_atom, src) + + For Copy Atoms that require single-threaded execution, the copy op automatically handles thread + election internally. Manual thread selection is not required in such cases. + """ + dummy_tma_bar_ptr = make_ptr(Int64, 0, AddressSpace.smem, loc=loc, ip=ip) + value = atom._unpack(loc=loc, ip=ip, tma_bar_ptr=dummy_tma_bar_ptr) + return _cute_ir.prefetch(value, src.value, loc=loc, ip=ip) + +#################################################################################################### +# +# TensorSSA class (experimental) +# +#################################################################################################### + + +class ReductionOp(Enum): + ADD = auto() + MUL = auto() + MAX = auto() + MIN = auto() + INC = auto() + DEC = auto() + AND = auto() + OR = auto() + XOR = auto() + + def __str__(self): + return self.name.lower() + + +class TensorSSA(cutlass_arith.ArithValue): + """A class representing thread local data from CuTe Tensor in value semantic and immutable. + + :param value: Flatten vector as ir.Value holding logic data of SSA Tensor + :type value: ir.Value + :param shape: The nested shape in CuTe of the vector + :type shape: Shape + :param dtype: Data type of the tensor elements + :type dtype: Type[Numeric] + + :ivar _shape: The nested shape in CuTe of the vector + :ivar _dtype: Data type of the tensor elements + + :raises ValueError: If shape is not static + """ + + def __init__(self, value, shape: Shape, dtype: Type[Numeric]): + """Initialize a new TensorSSA object. + + :param value: Flatten vector as ir.Value holding logic data of SSA Tensor + :type value: ir.Value + :param shape: The nested shape in CuTe of the vector + :type shape: Shape + :param dtype: Data type of the tensor elements + :type dtype: Type[Numeric] + :raises ValueError: If shape is not static + """ + if not is_static(shape): + raise ValueError("dynamic shape is not supported") + + signed = dtype.signed if issubclass(dtype, Integer) else False + super().__init__(value, signed) + + self._shape = shape + self._dtype = dtype + self._layout = None + + @property + def dtype(self) -> Type[Numeric]: + return self._dtype + + @property + def element_type(self) -> Type[Numeric]: + return self._dtype + + @abstractmethod + def __extract_mlir_values__(self): + return [self] + + @abstractmethod + def __new_from_mlir_values__(self, values): + return TensorSSA(values[0], self.shape, self.dtype) + + def __str__(self): + return f"tensor_value<{self.type} o {self.shape}>" + + @property + def shape(self): + return self._shape + + @overload + def _apply_op(self, op, other: "TensorSSA", flip, *, loc, ip) -> "TensorSSA": ... + + @overload + def _apply_op( + self, op, other: cutlass_arith.ArithValue, flip, *, loc, ip + ) -> "TensorSSA": ... + + @overload + def _apply_op( + self, op, other: Union[int, float, bool], flip, *, loc, ip + ) -> "TensorSSA": ... + + def _apply_op(self, op, other, flip=False, *, loc=None, ip=None): + def get_attr_for_type(ty, value): + if isinstance(ty, ir.IntegerType): + return ir.IntegerAttr.get(ty, value) + elif isinstance(ty, ir.FloatType): + return ir.FloatAttr.get(ty, value) + else: + raise TypeError(f"unsupported type: {ty}") + + # Canonicalize into Numeric + if isinstance(other, (int, float, bool)) or ( + not isinstance(other, TensorSSA) + and isinstance(other, cutlass_arith.ArithValue) + ): + other = as_numeric(other) + + # Promote types + lhs, rhs, res_type = _binary_op_type_promote(self, other) + + # Promote scalar to vector + if not isinstance(rhs, TensorSSA): + if isinstance(rhs, Numeric): + vect_val = vector.broadcast(lhs.type, rhs.ir_value(loc=loc, ip=ip)) + else: + elem_attr = get_attr_for_type(lhs.type.element_type, rhs) + vect_attr = ir.DenseElementsAttr.get_splat(lhs.type, elem_attr) + vect_val = arith.constant(lhs.type, vect_attr, loc=loc, ip=ip) + rhs = TensorSSA(vect_val, lhs.shape, lhs.dtype) + + if flip: + lhs, rhs = rhs, lhs + + if op in ( + operator.lt, + operator.le, + operator.gt, + operator.ge, + operator.eq, + operator.ne, + ): + res_type = Boolean + + assert isinstance(rhs, TensorSSA), f"rhs must be TensorSSA but got {rhs}" + + def _broadcast(s, t): + if s == 1: + return t + elif t == 1: + return s + elif s == t: + return s + else: + raise ValueError(f"cannot broadcast {s} and {t}") + + max_rank = max(rank(lhs.shape), rank(rhs.shape)) + lhs_shape = append(lhs.shape, 1, up_to_rank=max_rank) + rhs_shape = append(rhs.shape, 1, up_to_rank=max_rank) + res_shape = transform_leaf(_broadcast, lhs_shape, rhs_shape) + + # broadcast to the same shape + lhs = lhs.broadcast_to(res_shape) + rhs = rhs.broadcast_to(res_shape) + + if ( + op in (operator.add, operator.sub) + and lhs.dtype == Boolean + and rhs.dtype == Boolean + ): + res = op(lhs.to(Int32), rhs.to(Int32)) + zero = zeros_like(res) + res = res.__ne__(zero).to(res_type) + else: + lhs_val = lhs.maybe_downcast() + rhs_val = rhs.maybe_downcast() + + if issubclass(lhs.dtype, Integer): + lhs_val = lhs_val.with_signedness(lhs.dtype.signed) + + if issubclass(rhs.dtype, Integer): + rhs_val = rhs_val.with_signedness(rhs.dtype.signed) + + res_vect = op(lhs_val, rhs_val) + res = TensorSSA(res_vect, lhs._shape, res_type) + + return res + + def broadcast_to(self, target_shape: Shape, *, loc=None, ip=None) -> "TensorSSA": + """ + Broadcast the tensor to the target shape. + """ + # pad source shape to the same rank + shape = append(self.shape, 1, up_to_rank=rank(target_shape)) + if shape == target_shape: + return self + + def _check_broadcast(s, t): + if s != t and s != 1: + raise ValueError( + f"src_shape and target_shape must be the same when src_shape is not 1, but got {s} and {t}" + ) + + transform_leaf(_check_broadcast, shape, target_shape) + + # reshape to flatten N-D vector + flat_shp = flatten_to_tuple(shape) + temp_ty = ir.VectorType.get(list(flat_shp), self.dtype.mlir_type) + temp_vect = vector.shape_cast(temp_ty, self, loc=loc, ip=ip) + + # broadcast to result N-D vector + flat_tgt_shp = flatten_to_tuple(target_shape) + temp_tgt_ty = ir.VectorType.get(list(flat_tgt_shp), self.dtype.mlir_type) + temp_tgt_vect = vector.broadcast(temp_tgt_ty, temp_vect, loc=loc, ip=ip) + + res_1d_ty = ir.VectorType.get([size(target_shape)], self.dtype.mlir_type) # type: ignore + res_1d_vect = vector.shape_cast(res_1d_ty, temp_tgt_vect, loc=loc, ip=ip) + + return TensorSSA(res_1d_vect, target_shape, self.dtype) + + def __pow__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the results of tensor^other. + + :param other: The other tensor for exponent. + :type other: TensorSSA + :return: The power of the tensor. + :rtype: TensorSSA + """ + return self._apply_op(operator.pow, other, loc=loc, ip=ip) + + def __rpow__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the results of other^tensor. + + :param other: The other tensor to compute power with. + :type other: TensorSSA + :return: The element-wise power of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.pow, other, flip=True, loc=loc, ip=ip) + + def __add__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the sum of the tensor and another tensor. + + :param other: The other tensor to add. + :type other: TensorSSA + :return: The sum of the two tensors with the same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.add, other, loc=loc, ip=ip) + + def __radd__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the sum of the tensor and another tensor (reverse add) + + :param other: The other tensor to add. + :type other: TensorSSA + :return: The sum of the two tensors with the same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.add, other, flip=True, loc=loc, ip=ip) + + def __sub__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the difference of the tensor and another tensor. + + :param other: The other tensor to subtract. + :type other: TensorSSA + :return: The subtraction of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.sub, other, loc=loc, ip=ip) + + def __rsub__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the difference of the tensor and another tensor (reverse subtract) + + :param other: The other tensor to subtract. + :type other: TensorSSA + :return: The subtraction of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.sub, other, flip=True, loc=loc, ip=ip) + + def __mul__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the multiplication of the tensor and another tensor. + + :param other: The other tensor to multiply. + :type other: TensorSSA + :return: The multiplication of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.mul, other, loc=loc, ip=ip) + + def __rmul__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the multiplication of the tensor and another tensor (reverse multiply) + + :param other: The other tensor to multiply. + :type other: TensorSSA + :return: The multiplication of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.mul, other, flip=True, loc=loc, ip=ip) + + def __mod__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the modulo of the tensor and another tensor. + + :param other: The other tensor to compute modulo with. + :type other: TensorSSA + :return: The element-wise modulo of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.mod, other, loc=loc, ip=ip) + + def __rmod__(self, other) -> "TensorSSA": + """ + Returns the modulo of the tensor and another tensor (reverse modulo) + + :param other: The other tensor to compute modulo with. + :type other: TensorSSA + :return: The element-wise modulo of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.mod, other, flip=True) + + def __floordiv__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the floordiv(//) of the tensor and another tensor. + + :param other: The other tensor to compute floordiv with. + :type other: TensorSSA + :return: The floordiv of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.floordiv, other, loc=loc, ip=ip) + + def __rfloordiv__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the floordiv(//) of the tensor and another tensor (reverse floordiv) + + :param other: The other tensor to compute floordiv with. + :type other: TensorSSA + :return: The floordiv of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.floordiv, other, flip=True, loc=loc, ip=ip) + + def __truediv__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the truediv(/) of the tensor and another tensor. + + :param other: The other tensor to compute truediv with. + :type other: TensorSSA + :return: The truediv of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.truediv, other, loc=loc, ip=ip) + + def __rtruediv__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the truediv(/) of the tensor and another tensor (reverse truediv) + + :param other: The other tensor to compute truediv with. + :type other: TensorSSA + :return: The truediv of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.truediv, other, flip=True, loc=loc, ip=ip) + + def __eq__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the comparison of the tensor and another tensor as mask + + :param other: The other tensor to compare. + :type other: TensorSSA + :return: The comparison of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.eq, other, loc=loc, ip=ip) + + def __ne__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the element-wise not equal comparison of the tensor and another tensor. + + :param other: The other tensor to compare. + :type other: TensorSSA + :return: A boolean tensor with same shape as inputs, True where self != other. + :rtype: TensorSSA + """ + return self._apply_op(operator.ne, other, loc=loc, ip=ip) + + def __lt__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the element-wise less than comparison of the tensor and another tensor. + + :param other: The other tensor to compare with. + :type other: TensorSSA + :return: A boolean tensor with same shape as inputs, True where self < other. + :rtype: TensorSSA + """ + return self._apply_op(operator.lt, other, loc=loc, ip=ip) + + def __le__(self, other) -> "TensorSSA": + """ + Returns the element-wise less than or equal comparison of the tensor and another tensor. + + :param other: The other tensor to compare with. + :type other: TensorSSA + :return: A boolean tensor with same shape as inputs, True where self <= other. + :rtype: TensorSSA + """ + return self._apply_op(operator.le, other) + + def __gt__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the element-wise greater than comparison of the tensor and another tensor. + + :param other: The other tensor to compare with. + :type other: TensorSSA + :return: A boolean tensor with same shape as inputs, True where self > other. + :rtype: TensorSSA + """ + return self._apply_op(operator.gt, other) + + def __ge__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the element-wise greater than or equal comparison of the tensor and another tensor. + + :param other: The other tensor to compare with. + :type other: TensorSSA + :return: A boolean tensor with same shape as inputs, True where self >= other. + :rtype: TensorSSA + """ + return self._apply_op(operator.ge, other, loc=loc, ip=ip) + + def __xor__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the element-wise XOR of the tensor and another tensor. + + :param other: The other tensor to perform XOR with. + :type other: TensorSSA + :return: The element-wise XOR of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.xor, other) + + def __rxor__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the bitwise XOR of the tensor and another tensor. + + :param other: The other tensor to compute XOR with. + :type other: TensorSSA + :return: The element-wise bitwise XOR of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.xor, other, flip=True, loc=loc, ip=ip) + + def __or__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the element-wise OR of the tensor and another tensor. + + :param other: The other tensor to perform OR with. + :type other: TensorSSA + :return: The element-wise OR of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.or_, other) + + def __ror__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the element-wise OR of the tensor and another tensor. + + :param other: The other tensor to perform OR with. + :type other: TensorSSA + :return: The element-wise OR of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.or_, other, flip=True) + + def __and__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the element-wise AND of the tensor and another tensor. + + :param other: The other tensor to perform AND with. + :type other: TensorSSA + :return: The element-wise AND of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.and_, other) + + def __rand__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the element-wise AND of the tensor and another tensor. + + :param other: The other tensor to perform AND with. + :type other: TensorSSA + :return: The element-wise AND of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.and_, other, flip=True, loc=loc, ip=ip) + + def __neg__(self, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the negation of the tensor. + + :return: The element-wise negation of the tensor + :rtype: TensorSSA + """ + + return self._apply_op(operator.sub, 0, flip=True, loc=loc, ip=ip) + + def _flatten_shape_and_coord(self, crd, *, loc=None, ip=None): + # Coalesce and flatten source layout at terminal of coordinate + # (N_0,(N_1,...), ...) -> (N_0,N_1,N_2,...) + crd_shp = product_like(self._shape, target_profile=crd, loc=loc, ip=ip) + + # Flatten coordinate + flat_shp = flatten(crd_shp) + assert isinstance(flat_shp, tuple) and is_static(flat_shp) + # (C_0,(C_1,...), ...) -> (C_0,C_1,C_2,...) + flat_crd = flatten(crd) + + assert isinstance(flat_crd, tuple) and is_static(flat_crd) + return flat_shp, flat_crd + + def _build_result(self, res_vect, res_shp, *, loc=None, ip=None): + if isinstance(res_shp, ir.Value): + raise ValueError( + f"expects static shape and coordinates, but got {self._shape} and {crd}" + ) + + # cast back to 1D vector + res_1d_ty = ir.VectorType.get([size(res_shp)], self.type.element_type) + res_1d_vect = vector.shape_cast(res_1d_ty, res_vect, loc=loc, ip=ip) + return TensorSSA(res_1d_vect, res_shp, self.dtype) + + @dsl_user_op + def __getitem__( + self, crd: Coord, *, loc=None, ip=None + ) -> Union["TensorSSA", Numeric]: + """Access or slice tensor elements using coordinates. + + This method implements tensor evaluation T(c) = *(E + L(c)) where E is the iterator/engine + and L is the layout. It supports both direct element access and slicing operations. + + :param crd: Coordinate or slice specification for accessing tensor elements + :type crd: Coord + :param loc: Source location for MLIR operation tracking, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for MLIR operation, defaults to None + :type ip: Optional[InsertionPoint] + :return: Tensor element value or sliced subtensor + :rtype: Union[TensorSSA, Numeric] + + :raises ValueError: If coordinate access is invalid for the tensor layout + + **Examples:** + + .. code-block:: python + + # Create a fragment from rmem as shape (8, 4) + layout = make_layout((8, 4)) + tensor = make_fragment(layout, Float32) + frg = tensor.load() + + # Direct element access + val = frg[0] # Returns first element of fragment + val = frg[(0, 1)] # Returns element at (0, 1) + + # Slice access + sliced = frg[(3, None)] # Returns fragment slice + """ + # short-cut to no-op + if crd is None: + return self + + if not has_underscore(crd): + if self._layout is None: + self._layout = make_layout(self._shape, loc=loc, ip=ip) + idx = crd2idx(crd, self._layout, loc=loc, ip=ip) + idx_val = as_numeric(idx).ir_value(loc=loc, ip=ip) + res_val = vector.extractelement(self, position=idx_val, loc=loc, ip=ip) + return self.dtype(res_val) + + if not is_static(crd): + raise ValueError("dynamic coordinate is not supported") + + flat_shp, flat_crd = self._flatten_shape_and_coord(crd) + + multi_dim_ty = ir.VectorType.get(list(flat_shp), self.type.element_type) + # vector -> vector + tmp_vect = vector.shape_cast(multi_dim_ty, self) + + # Slice and keep dims matching `_` or None + res_shp = slice_(self._shape, crd) + if isinstance(res_shp, ir.Value): + raise TypeError( + f"expects static shape and coordinates, but got {self._shape} and {crd}" + ) + + # Offsets is index of coordinates if NOT `_` otherwise 0 + offsets = [c if c is not None else 0 for c in flat_crd] + # Sizes is size of shapes if `_` otherwise 1 + sizes = [s if c is None else 1 for s, c in zip(flat_shp, flat_crd)] + # Logic stride to index vector. Only support stride-1 by vector + strides = [1] * rank(flat_shp) + + # Vector slice on N-D vector + res_ty = ir.VectorType.get(list(sizes), self.type.element_type) + res_vect = vector.extract_strided_slice( + res_ty, tmp_vect, offsets=offsets, sizes=sizes, strides=strides + ) + + # Slice and keep dims matching `_` or None + res_shp = slice_(self._shape, crd) + return self._build_result(res_vect, res_shp, loc=loc, ip=ip) + + @dsl_user_op + def to(self, dtype: Type[Numeric], *, loc=None, ip=None): + """Convert the tensor to a different numeric type. + + :param dtype: The target numeric type to cast to. + :type dtype: Type[Numeric] + :return: A new tensor with the same shape but with elements cast to the target type. + :rtype: TensorSSA + :raises TypeError: If dtype is not a subclass of Numeric. + :raises NotImplementedError: If dtype is an unsigned integer type. + """ + if dtype is ir.Value: + return self + + if not isclass(dtype) or not issubclass(dtype, Numeric): + raise TypeError(f"dtype must be a type of Numeric, but got {type(dtype)}") + + src_dtype = self.dtype + if src_dtype == dtype: + return self + + # maybe downcast can lose signedness + src = self.maybe_downcast().with_signedness(self.signed) + if src_dtype.is_float and dtype.is_float: + res_vect = cutlass_arith.cvtf(src, dtype.mlir_type, loc=loc, ip=ip) + elif src_dtype.is_float and issubclass(dtype, Integer): + res_vect = cutlass_arith.fptoi( + src, dtype.signed, dtype.mlir_type, loc=loc, ip=ip + ) + elif issubclass(src_dtype, Integer) and dtype.is_float: + res_vect = cutlass_arith.itofp( + src, src_dtype.signed, dtype.mlir_type, loc=loc, ip=ip + ) + else: + res_vect = cutlass_arith.int_to_int(src, dtype, loc=loc, ip=ip) + + return TensorSSA(res_vect, self._shape, dtype) + + def ir_value(self, *, loc=None, ip=None): + return self + + def ir_value_int8(self, *, loc=None, ip=None): + """ + Returns int8 ir value of Boolean tensor. + When we need to store Boolean tensor ssa, use ir_value_int8(). + + :param loc: Source location information, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point for MLIR operations, defaults to None + :type ip: Optional[InsertionPoint], optional + :return: The int8 value of this Boolean + :rtype: ir.Value + """ + assert ( + self.element_type is Boolean + ), f"Only boolean type needs to be converted to int8, got {self.element_type}" + + if not hasattr(self, "_value_int8"): + self._value_int8 = arith.extsi( + T.vector(self.type.shape[0], T.i8()), self, loc=loc, ip=ip + ) + return self._value_int8 + + def reduce(self, op, init_val, reduction_profile: Coord, *, loc=None, ip=None): + """ + Perform reduce on selected modes with given predefined reduction op. + + :param op: The reduction operator to use (operator.add or operator.mul) + :type op: operator + :param init_val: The initial value for the reduction + :type init_val: numeric + :param reduction_profile: Specifies which dimensions to reduce. Dimensions marked with `None` are kept. + :type reduction_profile: Coord + + :return: The reduced tensor + :rtype: TensorSSA + + **Examples:** + + .. code-block:: python + + reduce(f32 o (4,)) + => f32 + + reduce(f32 o (4, 5)) + => f32 + reduce(f32 o (4, (5, 4)), reduction_profile=(None, 1)) + => f32 o (4,) + reduce(f32 o (4, (5, 4)), reduction_profile=(None, (None, 1))) + => f32 o (4, (5,)) + """ + # short-cut to no-op + if reduction_profile is None: + return self + + if not is_weakly_congruent(reduction_profile, self.shape): + raise ValueError( + f"Expect reduction_profile be weakly congruent to the shape of the tensor, " + f"but got {reduction_profile} and {self.shape}" + ) + + if op is ReductionOp.ADD: + red_kind = vector.CombiningKind.ADD + elif op is ReductionOp.MUL: + red_kind = vector.CombiningKind.MUL + elif op is ReductionOp.MAX: + red_kind = vector.CombiningKind.MAXIMUMF + elif op is ReductionOp.MIN: + red_kind = vector.CombiningKind.MINIMUMF + else: + raise NotImplementedError( + f"{op} is not supported, expects one of " + f"{ReductionOp.ADD, ReductionOp.MUL, ReductionOp.MAX, ReductionOp.MIN}" + ) + + elem_ty = self.element_type + # Canonicalize to `Numeric` and convert into MLIR value + init_val = as_numeric(init_val).ir_value(loc=loc, ip=ip) + + if depth(reduction_profile) == 0: + return vector.reduction( + elem_ty.mlir_type, red_kind, self, acc=init_val, loc=loc, ip=ip + ) + + flat_shp, flat_prof = self._flatten_shape_and_coord( + reduction_profile, loc=loc, ip=ip + ) + assert depth(flat_shp) == 1 and depth(flat_prof) == 1 + assert rank(flat_shp) == rank(flat_prof) + + temp_ty = ir.VectorType.get(list(flat_shp), elem_ty.mlir_type) + temp_vect = vector.shape_cast(temp_ty, self, loc=loc, ip=ip) + + if isinstance(flat_prof, tuple): + red_dims = [i for i, x in enumerate(flat_prof) if x is not None] + else: + red_dims = [0] + + temp_acc_shp = slice_(flat_shp, flat_prof, loc=loc, ip=ip) + temp_acc_ty = ir.VectorType.get(list(temp_acc_shp), elem_ty.mlir_type) + + init_val = vector.broadcast(temp_acc_ty, init_val, loc=loc, ip=ip) + res_vect = vector.multi_reduction( + red_kind, temp_vect, acc=init_val, reduction_dims=red_dims, loc=loc, ip=ip + ) + + # Slice and keep dims matching `_` or None + res_shp = slice_(self.shape, reduction_profile, loc=loc, ip=ip) + return self._build_result(res_vect, res_shp, loc=loc, ip=ip) + + +@dsl_user_op +def full(shape, fill_value, dtype: Type[Numeric], *, loc=None, ip=None) -> TensorSSA: + """ + Return a new TensorSSA of given shape and type, filled with fill_value. + + :param shape: Shape of the new tensor. + :type shape: tuple + :param fill_value: Value to fill the tensor with. + :type fill_value: scalar + :param dtype: Data type of the tensor. + :type dtype: Type[Numeric] + :return: Tensor of fill_value with the specified shape and dtype. + :rtype: TensorSSA + """ + size = product(shape, loc=loc, ip=ip) + if not is_static(size): + raise ValueError("shape must be static") + + if isinstance(fill_value, (ir.Value, int, float, bool)): + fill_value = dtype(fill_value) + elif isinstance(fill_value, Numeric): + fill_value = fill_value.to(dtype, loc=loc, ip=ip) + else: + raise ValueError(f"Expected fill_value be numeric type, but got {fill_value}") + + res_ty = T.vector(size, dtype.mlir_type) + res_val = vector.splat(res_ty, fill_value.ir_value(loc=loc, ip=ip), loc=loc, ip=ip) + return TensorSSA(res_val, shape, dtype) + + +def full_like( + a: Union[TensorSSA, Tensor], + fill_value, + dtype: Union[None, Type[Numeric]] = None, + *, + loc=None, + ip=None, +) -> TensorSSA: + """ + Return a full TensorSSA with the same shape and type as a given array. + + :param a: The shape and data-type of `a` define these same attributes of the returned array. + :type a: array_like + :param fill_value: Fill value. + :type fill_value: array_like + :param dtype: Overrides the data type of the result, defaults to None + :type dtype: Union[None, Type[Numeric]], optional + :return: Tensor of `fill_value` with the same shape and type as `a`. + :rtype: TensorSSA + + .. seealso:: + :func:`empty_like`: Return an empty array with shape and type of input. + :func:`ones_like`: Return an array of ones with shape and type of input. + :func:`zeros_like`: Return an array of zeros with shape and type of input. + :func:`full`: Return a new array of given shape filled with value. + + **Examples:** + + .. code-block:: python + + frg = cute.make_fragment(Float32, (2, 3)) + a = frg.load() + b = cute.full_like(a, 1.0) + """ + if not hasattr(a, "shape"): + raise TypeError(f"Expect `a` be shaped type, but got {type(a)}") + + return full( + a.shape, fill_value, dtype if dtype is not None else a.dtype, loc=loc, ip=ip + ) + + +def empty_like(a, dtype=None): + """ + Return a new TensorSSA with the same shape and type as a given array, without initializing entries. + + :param a: The shape and data-type of `a` define these same attributes of the returned array. + :type a: TensorSSA + :param dtype: Overrides the data type of the result, defaults to None + :type dtype: Type[Numeric], optional + :return: Uninitialized tensor with the same shape and type (unless overridden) as `a`. + :rtype: TensorSSA + """ + return full_like(a, 0, dtype) + + +def ones_like(a, dtype=None): + """ + Return a TensorSSA of ones with the same shape and type as a given array. + + :param a: The shape and data-type of `a` define these same attributes of the returned array. + :type a: TensorSSA + :param dtype: Overrides the data type of the result, defaults to None + :type dtype: Type[Numeric], optional + :return: Tensor of ones with the same shape and type (unless overridden) as `a`. + :rtype: TensorSSA + """ + return full_like(a, 1, dtype) + + +def zeros_like(a, dtype=None, *, loc=None, ip=None): + """ + Return a TensorSSA of zeros with the same shape and type as a given array. + + :param a: The shape and data-type of `a` define these same attributes of the returned array. + :type a: TensorSSA + :param dtype: Overrides the data type of the result, defaults to None + :type dtype: Type[Numeric], optional + :return: Tensor of zeros with the same shape and type (unless overridden) as `a`. + :rtype: TensorSSA + """ + return full_like(a, 0, dtype, loc=loc, ip=ip) + + +def where( + cond: TensorSSA, x: TensorSSA, y: TensorSSA, *, loc=None, ip=None +) -> TensorSSA: + """ + Return elements chosen from x or y depending on condition. + + :param cond: Where True, yield x, where False, yield y. + :type cond: TensorSSA + :param x: Values from which to choose when condition is True. + :type x: TensorSSA + :param y: Values from which to choose when condition is False. + :type y: TensorSSA + :return: A tensor with elements from x where condition is True, and elements from y where condition is False. + :rtype: TensorSSA + """ + if x.dtype != y.dtype: + raise ValueError( + f"x and y must have the same dtype, but got {x.dtype} and {y.dtype}" + ) + + if cond.dtype != Boolean: + raise ValueError(f"cond must be Boolean type, but got {cond.dtype}") + + return TensorSSA( + arith.select(cond.ir_value(), x, y, loc=loc, ip=ip), x.shape, x.dtype + ) + + +def any_(x: TensorSSA, *, loc=None, ip=None) -> Boolean: + """ + Test whether any tensor element evaluates to True. + + :param x: Input tensor. + :type x: TensorSSA + :return: Returns a TensorSSA scalar containing True if any element of x is True, False otherwise. + :rtype: TensorSSA + """ + is_true = x != full_like(x, 0, x.dtype, loc=loc, ip=ip) + return Boolean( + vector.reduction(T.bool(), vector.CombiningKind.OR, is_true, loc=loc, ip=ip) + ) + + +def all_(x: TensorSSA, *, loc=None, ip=None) -> Boolean: + """ + Test whether all tensor elements evaluate to True. + + :param x: Input tensor. + :type x: TensorSSA + :return: Returns a TensorSSA scalar containing True if all elements of x are True, False otherwise. + :rtype: TensorSSA + """ + is_true = x != full_like(x, 0, x.dtype, loc=loc, ip=ip) + return Boolean( + vector.reduction(T.bool(), vector.CombiningKind.AND, is_true, loc=loc, ip=ip) + ) + + +############################################################################## +# User defined struct +############################################################################## + + +class struct: + """ + Decorator to abstract C structure in Python DSL. + + **Usage:** + + .. code-block:: python + + # Supports base_dsl scalar int/float elements, array and nested struct: + @cute.struct + class complex: + real : cutlass.Float32 + imag : cutlass.Float32 + + + @cute.struct + class StorageA: + mbarA : cute.struct.MemRange[cutlass.Int64, stage] + compA : complex + intA : cutlass.Int16 + + + # Supports aligment for its elements: + @cute.struct + class StorageB: + a: cute.struct.Align[ + cute.struct.MemRange[cutlass.Float32, size_a], 1024 + ] + b: cute.struct.Align[ + cute.struct.MemRange[cutlass.Float32, size_b], 1024 + ] + x: cute.struct.Align[cutlass.Int32, 16] + compA: cute.struct.Align[complex, 16] + + + # Statically get size and alignment: + size = StorageB.__sizeof__() + align = StorageB.__alignof__() + + # Allocate and referencing elements: + storage = allocator.allocate(StorageB) + + storage.a[0] ... + storage.x ... + storage.compA.real ... + + :param cls: The struct class with annotations. + :return: The decorated struct class. + """ + + # inner class for defining a continuous memory region + class _MemRangeMeta(type): + """ + A metaclass for creating MemRange classes. + + This metaclass is used to dynamically create MemRange classes with specific + data types and sizes. + + :ivar _dtype: The data type of the MemRange. + :ivar _size: The size of the MemRange. + """ + + _dtype = None + _size = None + + def __new__(cls, name, bases, dct): + new_cls = super().__new__(cls, name, bases, dct) + return new_cls + + def __getitem__(cls, params) -> Type["struct.MemRange"]: + # get params from syntax: struct.MemRange[dtype, size] + if len(params) == 2: + dtype, size = params + else: + raise TypeError("Invalid struct.MemRange Arguments") + + if not struct._is_scalar_type(dtype): + raise TypeError("MemRange only support dsl scalar type!") + + # Create new class with proper name and parameters + new_cls = type( + f"struct.MemRange[{dtype.__name__}, {size}]", + (struct.MemRange,), + {"_dtype": dtype, "_size": size}, + ) + return new_cls + + @property + def size(cls): + return cls._size + + @property + def elem_width(cls): + return cls._dtype.width + + @property + def size_in_bytes(cls): + return cls.size * cls.elem_width // 8 + + class MemRange(metaclass=_MemRangeMeta): + """ + Defines a range of memory by `MemRange[T, size]`. + """ + + pass + + class _MemRangeData: + """ + Represents a range of memory. + + :param dtype: The data type. + :param size: The size of the memory range in bytes. + :param base: The base address of the memory range. + """ + + def __init__(self, dtype, size, base): + """ + Initializes a new memory range. + + :param dtype: The data type. + :param size: Size of the memory range in bytes. A size of **0** is accepted, but in that + case the range can only be used for its address (e.g. as a partition marker). + :param base: The base address of the memory range. + """ + self._dtype = dtype + self._size = size + self._base = base + + def data_ptr(self): + """ + Returns start pointer to the data in this memory range. + + :return: A pointer to the start of the memory range. + :raises AssertionError: If the size of the memory range is negative. + """ + assert self._size >= 0 + return recast_ptr(self._base, dtype=self._dtype) + + def get_tensor(self, layout, swizzle=None, dtype=None): + """ + Creates a tensor from the memory range. + + :param layout: The layout of the tensor. + :param swizzle: Optional swizzle pattern. + :param dtype: Optional data type; defaults to the memory range's data type if not specified. + :return: A tensor representing the memory range. + :raises TypeError: If the layout is incompatible with the swizzle. + :raises AssertionError: If the size of the memory range is not greater than zero. + """ + assert self._size > 0 + # make tensor + if isinstance(layout, ComposedLayout) and (swizzle is not None): + raise TypeError(f"incompatible layout with swizzle") + elem_type = self._dtype if dtype is None else dtype + ptr = recast_ptr(self._base, swizzle, dtype=elem_type) + res = make_tensor(ptr, layout) + return res + + def __getitem__(self, index: int) -> Any: + """ + Returns the element at the specified index in the memory range. + + :param index: The index of the element to retrieve. + :return: The element at the specified index. + :raises AssertionError: If the index is out of range. + """ + assert (index >= 0) and (index < self._size) + return self.data_ptr() + index + + # inner class for aligning a member type + class _AlignMeta(type): + """ + Aligns the given object by setting its alignment attribute. + + :param v: The object to align. Must be a struct, MemRange, or a scalar type. + :param align: The alignment value to set. + :raises TypeError: If the object is not a struct, MemRange, or a scalar type. + + :ivar _dtype: The data type to be aligned. + :ivar _align: The alignment of the data type. + """ + + _dtype = None + _align = None + + def __new__(cls, name, bases, dct): + return super().__new__(cls, name, bases, dct) + + def __getitem__(cls, params) -> Any: + if len(params) == 2: + dtype, align = params + assert align > 0 + else: + raise TypeError("Invalid struct.Align Arguments") + + if not struct._is_scalar_type(dtype) and not isinstance( + dtype, (struct, struct._MemRangeMeta) + ): + raise TypeError( + "align only can be applied to struct/MemRange/base_dsl scalar" + ) + + # Create new class with alignment + new_cls = type( + f"struct.Align[{dtype.__name__}, {align}]", + (struct.Align,), + {"_dtype": dtype, "_align": align}, + ) + return new_cls + + @property + def dtype(cls): + return cls._dtype + + @property + def align(cls): + return cls._align + + class Align(metaclass=_AlignMeta): + """ + Aligns the given type by `Align[T, alignment]`. + """ + + pass + + # util func for base dsl scalar types + @staticmethod + def _is_scalar_type(dtype): + """ + Checks if the given type is a scalar numeric type. + + :param dtype: The type to check. + :return: True if the type is a subclass of Numeric, False otherwise. + """ + return isinstance(dtype, type) and issubclass(dtype, Numeric) + + # calculate size and alignment + def __init__(self, cls): + """ + Initializes a new struct decorator instance. + + :param cls: The class representing the structured data type. + :raises TypeError: If the struct is empty. + """ + self._cls = cls + self.__name__ = f"struct::{cls.__name__}" + # Get the class annotations + self._annotations = cls.__annotations__ + # Create a dictionary to store the offsets + self._offsets: Dict[str, int] = {} + + # Calculate the offsets and alignment + offset = 0 + alignment = 1 + if len(self._annotations) == 0: + raise TypeError("Empty struct is not supported!") + for name, object in self._annotations.items(): + # get alignment of object + sub_align = 1 + if isinstance(object, struct._AlignMeta): + sub_align = object.align + object = object.dtype + + # switch addition order to support dynamic size + def add_offset(val): + return val + offset if isinstance(val, ir.Value) else offset + val + + # size of scalar + if struct._is_scalar_type(object): + dtype_size = max(1, object.width // 8) + sub_align = max(dtype_size, sub_align) + offset = self.align_offset(offset, sub_align) + self._offsets[name] = offset + offset = add_offset(dtype_size) + # size of array is size_in_bytes, alignment is elem_size + elif isinstance(object, struct._MemRangeMeta): + # Allow empty array as a free marker-only struct member. + # Use max(sub_align, ) because we might have in the future some + # object.elem_width less than 8, such as fp4, bit and others, + # and align_offset() does not support an alignment of 0. + sub_align = max(object.elem_width // 8, sub_align) + offset = self.align_offset(offset, sub_align) + self._offsets[name] = offset + offset = add_offset(object.size_in_bytes) + # size of struct + elif isinstance(object, struct): + sub_align = max(object.__alignof__(), sub_align) + offset = self.align_offset(offset, sub_align) + self._offsets[name] = offset + offset = add_offset(object.__sizeof__()) + else: + raise TypeError( + f"Struct element only support struct/array/base_dsl scalar, " + f"but got {object}" + ) + # Total aligment determined by the strictest requirement + alignment = max(alignment, sub_align) + # Total size determined by alignment + self._align_of = alignment + self._size_of = self.align_offset(offset, alignment) + + # create the __init__ method for decorated struct + def __call__(self, base: Any) -> None: + """ + Creates a new instance of the decorated struct. + + :param base: The base address of the struct. + :return: An instance of the decorated struct. + :raises TypeError: If the base pointer is not byte-sized. + """ + if base.type.value_type.width != 8: + raise TypeError("struct base ptr value type must be byte sized.") + # make an new object of user-defined decorated struct + # otherwise it will override same self._cls when new instance created + cls = self._cls() + setattr(cls, "_base", base) + for name, off in self._offsets.items(): + obj = self._annotations[name] + if isinstance(obj, struct._AlignMeta): + obj = obj.dtype + if struct._is_scalar_type(obj): + new_obj = recast_ptr(base + off, dtype=obj) + setattr(cls, name, new_obj) + elif isinstance(obj, struct._MemRangeMeta): + new_obj = struct._MemRangeData(obj._dtype, obj._size, base + off) + setattr(cls, name, new_obj) + elif isinstance(obj, struct): + new_obj = obj(base + off) + setattr(cls, name, new_obj) + else: + raise TypeError( + f"Struct element only support struct/array/base_dsl scalar, " + f"but got {obj}" + ) + return cls + + # get size + def size_in_bytes(self) -> int: + """ + Returns the size of the struct in bytes. + + :return: The size of the struct. + """ + return self._size_of + + # get size + def __sizeof__(self) -> int: + return self._size_of + + # get alignment + def __alignof__(self) -> int: + return self._align_of + + # util func for aligning offset + @staticmethod + def align_offset(offset, align): + """ + Return the round-up offset up to the next multiple of align. + """ + assert align > 0 and not ( + align & (align - 1) + ), "align should be a strictly positive power of 2." + return (offset + (align - 1)) & ~(align - 1) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/math.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/math.py new file mode 100644 index 0000000000000000000000000000000000000000..daaa608262d00268ec1c47dfe32758c555f009b0 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/math.py @@ -0,0 +1,445 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from .core import TensorSSA +from .typing import Numeric +from cutlass._mlir.dialects import math, arith + +from typing import Callable, Union + + +def _math_op(func: Callable, fastmath: bool, *args, **kwargs): + """Dispatch the function to either a TensorSSA or a Numeric(Float). + + :param func: The function to dispatch + :param args: The input tensor or scalar + :param kwargs: The input tensor or scalar + """ + arg_type = type(args[0]) + for arg in args: + if not isinstance(arg, TensorSSA) and ( + not isinstance(arg, Numeric) or not type(arg).is_float + ): + raise TypeError( + f"Expected a TensorSSA or Numeric(Float), but got {type(arg)}" + ) + if not isinstance(arg, arg_type): + raise TypeError( + f"Expected all inputs to be of type {arg_type}, but got {type(arg)}" + ) + + fastmath_flag = arith.FastMathFlags.fast if fastmath else arith.FastMathFlags.none + if isinstance(args[0], TensorSSA): + return TensorSSA( + func(*args, fastmath=fastmath_flag), args[0].shape, args[0].dtype + ) + else: + args = [a.ir_value() for a in args] + return func(*args, fastmath=fastmath_flag) + + +def acos( + a: Union[TensorSSA, Numeric], fastmath: bool = False +) -> Union[TensorSSA, Numeric]: + """Compute element-wise arc cosine of the input tensor. + + :param a: Input tensor + :type a: Union[TensorSSA, Numeric] + :param fastmath: Enable fast math optimizations, defaults to False + :type fastmath: bool, optional + :return: Tensor containing the arc cosine of each element in input tensor + :rtype: Union[TensorSSA, Numeric] + + Example: + + .. code-block:: + + x = cute.make_fragment(layout) # Create tensor + y = x.load() # Load values + z = acos(y) # Compute arc cosine + """ + return _math_op(math.acos, fastmath, a) + + +def asin( + a: Union[TensorSSA, Numeric], fastmath: bool = False +) -> Union[TensorSSA, Numeric]: + """Compute element-wise arc sine of the input tensor. + + :param a: Input tensor + :type a: Union[TensorSSA, Numeric] + :param fastmath: Enable fast math optimizations, defaults to False + :type fastmath: bool, optional + :return: Tensor containing the arc sine of each element in input tensor + :rtype: Union[TensorSSA, Numeric] + + Example: + + .. code-block:: + + x = cute.make_fragment(layout) # Create tensor + y = x.load() # Load values + z = asin(y) # Compute arc sine + """ + return _math_op(math.asin, fastmath, a) + + +def atan( + a: Union[TensorSSA, Numeric], fastmath: bool = False +) -> Union[TensorSSA, Numeric]: + """Compute element-wise arc tangent of the input tensor. + + :param a: Input tensor + :type a: Union[TensorSSA, Numeric] + :param fastmath: Enable fast math optimizations, defaults to False + :type fastmath: bool, optional + :return: Tensor containing the arc tangent of each element in input tensor + :rtype: Union[TensorSSA, Numeric] + + Example: + + .. code-block:: + + x = cute.make_fragment(layout) # Create tensor + y = x.load() # Load values + z = atan(y) # Compute arc tangent + """ + raise NotImplementedError("atan is not implemented") + return _math_op(math.atan, fastmath, a) + + +def atan2( + a: Union[TensorSSA, Numeric], b: Union[TensorSSA, Numeric], fastmath: bool = False +) -> Union[TensorSSA, Numeric]: + """Compute element-wise arc tangent of two tensors. + + Computes atan2(a, b) element-wise. The function atan2(a, b) is the angle in radians + between the positive x-axis and the point given by the coordinates (b, a). + + :param a: First input tensor (y-coordinates) + :type a: Union[TensorSSA, Numeric] + :param b: Second input tensor (x-coordinates) + :type b: Union[TensorSSA, Numeric] + :param fastmath: Enable fast math optimizations, defaults to False + :type fastmath: bool, optional + :return: Tensor containing the arc tangent of a/b element-wise + :rtype: Union[TensorSSA, Numeric] + + Example: + + .. code-block:: + + y = cute.make_fragment(ptr1, layout).load() # y coordinates + x = cute.make_fragment(ptr2, layout).load() # x coordinates + theta = atan2(y, x) # Compute angles + """ + return _math_op(math.atan2, fastmath, a, b) + + +def cos( + a: Union[TensorSSA, Numeric], fastmath: bool = False +) -> Union[TensorSSA, Numeric]: + """Compute element-wise cosine of the input tensor. + + :param a: Input tensor (in radians) + :type a: Union[TensorSSA, Numeric] + :param fastmath: Enable fast math optimizations, defaults to False + :type fastmath: bool, optional + :return: Tensor containing the cosine of each element + :rtype: Union[TensorSSA, Numeric] + + Example: + + .. code-block:: + + x = cute.make_fragment(layout) # Create tensor + y = x.load() # Load values + z = cos(y) # Compute cosine + """ + return _math_op(math.cos, fastmath, a) + + +def erf( + a: Union[TensorSSA, Numeric], fastmath: bool = False +) -> Union[TensorSSA, Numeric]: + """Compute element-wise error function of the input tensor. + + The error function is defined as: + erf(x) = 2/√π ∫[0 to x] exp(-t²) dt + + :param a: Input tensor + :type a: Union[TensorSSA, Numeric] + :param fastmath: Enable fast math optimizations, defaults to False + :type fastmath: bool, optional + :return: Tensor containing the error function value for each element + :rtype: Union[TensorSSA, Numeric] + + Example: + + .. code-block:: + + x = cute.make_fragment(layout) # Create tensor + y = x.load() # Load values + z = erf(y) # Compute error function + """ + return _math_op(math.erf, fastmath, a) + + +def exp( + a: Union[TensorSSA, Numeric], fastmath: bool = False +) -> Union[TensorSSA, Numeric]: + """Compute element-wise exponential of the input tensor. + + :param a: Input tensor + :type a: Union[TensorSSA, Numeric] + :param fastmath: Enable fast math optimizations, defaults to False + :type fastmath: bool, optional + :return: Tensor containing the exponential of each element + :rtype: Union[TensorSSA, Numeric] + + Example: + + .. code-block:: + + x = cute.make_fragment(layout) # Create tensor + y = x.load() # Load values + z = exp(y) # Compute exponential + """ + return _math_op(math.exp, fastmath, a) + + +def exp2( + a: Union[TensorSSA, Numeric], fastmath: bool = False +) -> Union[TensorSSA, Numeric]: + """Compute element-wise base-2 exponential of the input tensor. + + :param a: Input tensor + :type a: Union[TensorSSA, Numeric] + :param fastmath: Enable fast math optimizations, defaults to False + :type fastmath: bool, optional + :return: Tensor containing 2 raised to the power of each element + :rtype: Union[TensorSSA, Numeric] + + Example: + + .. code-block:: + + x = cute.make_fragment(layout) # Create tensor + y = x.load() # Load values + z = exp2(y) # Compute 2^x + """ + return _math_op(math.exp2, fastmath, a) + + +def log( + a: Union[TensorSSA, Numeric], fastmath: bool = False +) -> Union[TensorSSA, Numeric]: + """Compute element-wise natural logarithm of the input tensor. + + :param a: Input tensor + :type a: Union[TensorSSA, Numeric] + :param fastmath: Enable fast math optimizations, defaults to False + :type fastmath: bool, optional + :return: Tensor containing the natural logarithm of each element + :rtype: Union[TensorSSA, Numeric] + + Example: + + .. code-block:: + + x = cute.make_fragment(layout) # Create tensor + y = x.load() # Load values + z = log(y) # Compute natural logarithm + """ + return _math_op(math.log, fastmath, a) + + +def log2( + a: Union[TensorSSA, Numeric], fastmath: bool = False +) -> Union[TensorSSA, Numeric]: + """Compute element-wise base-2 logarithm of the input tensor. + + :param a: Input tensor + :type a: Union[TensorSSA, Numeric] + :param fastmath: Enable fast math optimizations, defaults to False + :type fastmath: bool, optional + :return: Tensor containing the base-2 logarithm of each element + :rtype: Union[TensorSSA, Numeric] + + Example: + + .. code-block:: + + x = cute.make_fragment(layout) # Create tensor + y = x.load() # Load values + z = log2(y) # Compute log base 2 + """ + return _math_op(math.log2, fastmath, a) + + +def log10( + a: Union[TensorSSA, Numeric], fastmath: bool = False +) -> Union[TensorSSA, Numeric]: + """Compute element-wise base-10 logarithm of the input tensor. + + :param a: Input tensor + :type a: Union[TensorSSA, Numeric] + :param fastmath: Enable fast math optimizations, defaults to False + :type fastmath: bool, optional + :return: Tensor containing the base-10 logarithm of each element + :rtype: Union[TensorSSA, Numeric] + + Example: + + .. code-block:: + + x = cute.make_fragment(layout) # Create tensor + y = x.load() # Load values + z = log10(y) # Compute log base 10 + """ + return _math_op(math.log10, fastmath, a) + + +def rsqrt( + a: Union[TensorSSA, Numeric], fastmath: bool = False +) -> Union[TensorSSA, Numeric]: + """Compute element-wise reciprocal square root of the input tensor. + + Computes 1/√x element-wise. + + :param a: Input tensor + :type a: Union[TensorSSA, Numeric] + :param fastmath: Enable fast math optimizations, defaults to False + :type fastmath: bool, optional + :return: Tensor containing the reciprocal square root of each element + :rtype: Union[TensorSSA, Numeric] + + Example: + + .. code-block:: + + x = cute.make_fragment(layout) # Create tensor + y = x.load() # Load values + z = rsqrt(y) # Compute 1/√x + """ + return _math_op(math.rsqrt, fastmath, a) + + +def sin( + a: Union[TensorSSA, Numeric], fastmath: bool = False +) -> Union[TensorSSA, Numeric]: + """Compute element-wise sine of the input tensor. + + :param a: Input tensor (in radians) + :type a: Union[TensorSSA, Numeric] + :param fastmath: Enable fast math optimizations, defaults to False + :type fastmath: bool, optional + :return: Tensor containing the sine of each element + :rtype: Union[TensorSSA, Numeric] + + Example: + + .. code-block:: + + x = cute.make_fragment(layout) # Create tensor + y = x.load() # Load values + z = sin(y) # Compute sine + """ + return _math_op(math.sin, fastmath, a) + + +def sqrt( + a: Union[TensorSSA, Numeric], fastmath: bool = False +) -> Union[TensorSSA, Numeric]: + """Compute element-wise square root of the input tensor. + + :param a: Input tensor + :type a: Union[TensorSSA, Numeric] + :param fastmath: Enable fast math optimizations, defaults to False + :type fastmath: bool, optional + :return: Tensor containing the square root of each element + :rtype: Union[TensorSSA, Numeric] + + Example: + + .. code-block:: + + x = cute.make_fragment(layout) # Create tensor + y = x.load() # Load values + z = sqrt(y) # Compute square root + """ + return _math_op(math.sqrt, fastmath, a) + + +def tan( + a: Union[TensorSSA, Numeric], fastmath: bool = False +) -> Union[TensorSSA, Numeric]: + """Compute element-wise tangent of the input tensor. + + :param a: Input tensor (in radians) + :type a: Union[TensorSSA, Numeric] + :param fastmath: Enable fast math optimizations, defaults to False + :type fastmath: bool, optional + :return: Tensor containing the tangent of each element + :rtype: Union[TensorSSA, Numeric] + + Example: + + .. code-block:: + + x = cute.make_fragment(layout) # Create tensor + y = x.load() # Load values + z = tan(y) # Compute tangent + """ + return _math_op(math.tan, fastmath, a) + + +def tanh( + a: Union[TensorSSA, Numeric], fastmath: bool = False +) -> Union[TensorSSA, Numeric]: + """Compute element-wise hyperbolic tangent of the input tensor. + + :param a: Input tensor + :type a: Union[TensorSSA, Numeric] + :param fastmath: Enable fast math optimizations, defaults to False + :type fastmath: bool, optional + :return: Tensor containing the hyperbolic tangent of each element + :rtype: Union[TensorSSA, Numeric] + + Example: + + .. code-block:: + + x = cute.make_fragment(layout) # Create tensor + y = x.load() # Load values + z = tanh(y) # Compute hyperbolic tangent + """ + return _math_op(math.tanh, fastmath, a) + + +__all__ = [ + "acos", + "asin", + "atan", + "atan2", + "cos", + "erf", + "exp", + "exp2", + "log", + "log10", + "log2", + "rsqrt", + "sin", + "sqrt", + "tan", + "tanh", +] diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0655bb09c05ae84714656020127cb41a4f28fbf6 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/__init__.py @@ -0,0 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from . import warp +from . import cpasync +from . import warpgroup +from . import tcgen05 + +from .common import * +from .helpers import * + + +# __all__ is required here for documentation generation +__all__ = [ + "OpError", + "MmaUniversalOp", + "CopyUniversalOp", +] diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/common.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/common.py new file mode 100644 index 0000000000000000000000000000000000000000..1b0c4c82debcd55cd7f3d7df0e21920cda83ca18 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/common.py @@ -0,0 +1,189 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. +import enum +from dataclasses import dataclass +from typing import Type, Optional + +from cutlass.cutlass_dsl import DSLBaseError + +import cutlass._mlir.dialects.cute as _cute_ir +import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir +from cutlass._mlir import ir + +from .. import core +from ..typing import Float16, Float32, Float64, Numeric + + +class OpError(DSLBaseError): + """ + An exception class for Op construction errors. + """ + + def __init__( + self, op: core.Op, message: str, suggestion: Optional[str] = None + ) -> None: + if suggestion is None: + # Default suggestion + suggestion = "Check your Op construction code" + super().__init__( + message, + error_code=f"{op.__class__.__name__} error", + suggestion=suggestion, + ) + + +#################################################################################################### +# +# MMA Ops and Traits +# +#################################################################################################### + + +@dataclass(frozen=True) +class MmaUniversalOp(core.MmaOp): + """ + The universal MMA Operation. + + This Operation currently expects the A/B operands as well as the accumulator to share the same + data types. + + :param abacc_dtype: The data type for the A/B operands and the accumulator + :type abacc_dtype: Type[Numeric] + """ + + abacc_dtype: Type[Numeric] + + def __post_init__(self) -> None: + if self.abacc_dtype not in [Float16, Float32, Float64]: + raise OpError( + self, + f"expects the 'abacc_dtype' Op parameter to be one of Float16, Float32, or Float64", + ) + + def __str__(self) -> str: + return ( + "universal MMA Operation using FMA" + f"\n A/B/Accumulator data type = {self.abacc_dtype}" + ) + + def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaUniversalTrait": + shape_mnk_attr = ir.Attribute.parse(f'#cute.shape<"(1,1,1)">') + atom_ty = _cute_nvgpu_ir.UniversalFmaAtomType.get( + shape_mnk_attr, + self.abacc_dtype.mlir_type, + self.abacc_dtype.mlir_type, + self.abacc_dtype.mlir_type, + ) + return MmaUniversalTrait(_cute_ir.atom(atom_ty, loc=loc, ip=ip)) + + def _verify_fragment_A(self, input, *, loc=None, ip=None): + pass + + def _verify_fragment_B(self, input, *, loc=None, ip=None): + pass + +class MmaUniversalTrait(core.Trait): + pass + + +#################################################################################################### +# +# Copy Ops and Traits +# +#################################################################################################### + + +class MemoryOrder(enum.Enum): + WEAK = _cute_ir.MemOrderKind.WEAK + RELAXED = _cute_ir.MemOrderKind.RELAXED + ACQUIRE = _cute_ir.MemOrderKind.ACQUIRE + RELEASE = _cute_ir.MemOrderKind.RELEASE + ACQ_REL = _cute_ir.MemOrderKind.ACQ_REL + SC = _cute_ir.MemOrderKind.SC + MMIO = _cute_ir.MemOrderKind.MMIO + CONSTANT = _cute_ir.MemOrderKind.CONSTANT + VOLATILE = _cute_ir.MemOrderKind.VOLATILE + + def __str__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}.{self.name}>" + + def _to_ir(self) -> _cute_ir.MemOrderKind: + return self.value + + +class MemoryScope(enum.Enum): + CTA = _cute_ir.MemScopeKind.CTA + CLUSTER = _cute_ir.MemScopeKind.CLUSTER + GPU = _cute_ir.MemScopeKind.GPU + SYS = _cute_ir.MemScopeKind.SYS + + def __str__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}.{self.name}>" + + def _to_ir(self) -> _cute_ir.MemScopeKind: + return self.value + +@dataclass(frozen=True) +class CopyUniversalOp(core.CopyOp): + """ + The universal Copy Operation. + + When creating a Copy Atom out of this operation, the expected usage pattern is + + .. code-block:: python + + op = cute.nvgpu.CopyUniversalOp() + atom = cute.make_copy_atom(op, tensor_dtype, num_bits_per_copy=64) + + - ``tensor_dtype`` is the data type used to build the reference TV Layout (either the source \ + or the destination TV Layout) in unit of tensor elements and is used for partitioning by \ + ``TiledCopy`` for example + - ``num_bits_per_copy`` is a kw argument specifying the number of bits to copy per Atom \ + execution. This can be larger than the width of the above data type. When not provided, \ + the compiler will do a best effort at auto-vectorizing. + """ + + def __str__(self) -> str: + return "universal Copy Operation" + + def _make_trait( + self, + copy_internal_type: Type[Numeric], + *, + loc=None, + ip=None, + **kwargs, + ) -> "CopyUniversalTrait": + num_bits_per_copy = kwargs.get("num_bits_per_copy", 0) + memory_order = kwargs.get("memory_order", MemoryOrder.WEAK) + memory_scope = kwargs.get("memory_scope", MemoryScope.CTA) + if not isinstance(num_bits_per_copy, int) or (num_bits_per_copy < 0): + raise ValueError( + "expects a 'num_bits_per_copy' kw argument of type int that is non-negative " + f"when creating a copy Atom for {self.__class__.__name__}" + ) + ty = _cute_nvgpu_ir.CopyAtomSIMTSyncCopyType.get( + copy_internal_type.mlir_type, + num_bits_per_copy, + memory_order._to_ir(), + memory_scope._to_ir(), + ) + return CopyUniversalTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class CopyUniversalTrait(core.Trait): + pass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..246360c2eb43ed5c4ca45127c579bc9f496caa08 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/__init__.py @@ -0,0 +1,39 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from .copy import * +from .helpers import * + + +# __all__ is required here for documentation generation +__all__ = [ + # + # copy.py + # + "LoadCacheMode", + "CopyG2SOp", + "CopyBulkTensorTileG2SOp", + "CopyBulkTensorTileG2SMulticastOp", + "CopyBulkTensorTileS2GOp", + "CopyReduceBulkTensorTileS2GOp", + # + # helpers.py + # + "make_tiled_tma_atom", + "tma_partition", + "create_tma_multicast_mask", + "prefetch_descriptor", + "copy_tensormap", + "update_tma_descriptor", + "fence_tma_desc_acquire", + "cp_fence_tma_desc_release", + "fence_tma_desc_release", +] diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/copy.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/copy.py new file mode 100644 index 0000000000000000000000000000000000000000..a15495602304700d19803825d93004e0fa9fc509 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/copy.py @@ -0,0 +1,471 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +import enum +from dataclasses import dataclass +from typing import Optional, Type + +from cutlass.cutlass_dsl import CuTeDSL, t + +import cutlass._mlir.dialects.cute as _cute_ir +import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir +from cutlass._mlir import ir + +from ...core import CopyOp, Trait, ReductionOp +from ...typing import Int16, Pointer, Integer, Numeric +from ..common import OpError +from ..tcgen05.mma import CtaGroup + + +#################################################################################################### +# +# Aynchronous copies +# +#################################################################################################### + + +class LoadCacheMode(enum.Enum): + """ + An enumeration for the possible cache modes of a non-bulk ``cp.async`` instruction. + + See the `PTX documentation `__. + """ + + ALWAYS = _cute_nvgpu_ir.LoadCacheMode.always + GLOBAL = _cute_nvgpu_ir.LoadCacheMode.global_ + STREAMING = _cute_nvgpu_ir.LoadCacheMode.streaming + LAST_USE = _cute_nvgpu_ir.LoadCacheMode.last_use + NONE = _cute_nvgpu_ir.LoadCacheMode.none + + def __str__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}.{self.name}>" + + def _to_ir(self) -> _cute_nvgpu_ir.LoadCacheMode: + return self.value + + +@dataclass(frozen=True) +class CopyG2SOp(CopyOp): + """ + Non-bulk asynchronous GMEM to SMEM Copy Operation. + + See the `PTX documentation `__. + """ + + cache_mode: LoadCacheMode = LoadCacheMode.ALWAYS + + def __str__(self) -> str: + res = "cp.async GMEM -> SMEM copy Operation" + if self.cache_mode != LoadCacheMode.ALWAYS: + res += f"\n with cache mode = {self.cache_mode}" + return res + + def _make_trait( + self, + copy_internal_type: Type[t.Numeric], + *, + loc=None, + ip=None, + **kwargs, + ) -> "CopyG2STrait": + num_bits_per_copy = kwargs.get("num_bits_per_copy", None) + # Verify that the user provided enum values + if not isinstance(self.cache_mode, LoadCacheMode): + raise OpError( + self, + "expects the 'cache_mode' Op parameter to be a LoadCacheMode instance", + ) + if not isinstance(num_bits_per_copy, int) or (num_bits_per_copy <= 0): + raise ValueError( + "expects a 'num_bits_per_copy' kw argument of type int that is positive " + f"when creating a copy Atom for {self.__class__.__name__}" + ) + # Verify that the user provided enum values + if not isinstance(self.cache_mode, LoadCacheMode): + raise OpError( + self, + "expects the 'cache_mode' Op parameter to be a LoadCacheMode instance", + ) + ty = _cute_nvgpu_ir.CopyAtomSIMTAsyncCopyType.get( + copy_internal_type.mlir_type, self.cache_mode._to_ir(), num_bits_per_copy + ) + return CopyG2STrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class CopyG2STrait(Trait): + pass + + +#################################################################################################### +# +# Bulk tensor copies a.k.a TMA copies +# +#################################################################################################### + +TMA_MBAR_PTR_FIELD_NAME = "tma_bar" +TMA_MASK_FIELD_NAME = "mcast_mask" +TMA_DESC_PTR_FIELD_NAME = "tma_descriptor_ptr" + +# +# TMA GMEM -> SMEM copies +# + + +@dataclass(frozen=True) +class CopyBulkTensorTileG2SOp(CopyOp): + """ + Bulk tensor asynchrnous GMEM to SMEM Copy Operation using the TMA unit. + + See the `PTX documentation `__. + This Operation uses TMA in the ``.tile`` mode. + """ + + cta_group: CtaGroup = CtaGroup.ONE + + admissible_archs = [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ] + + def __post_init__(self) -> None: + if not isinstance(self.cta_group, CtaGroup): + raise OpError( + self, "expects the 'cta_group' parameter to be a CtaGroup instance" + ) + # Arch verification + arch = CuTeDSL._get_dsl().envar.arch + if arch not in self.admissible_archs: + raise OpError( + self, + f"expects arch to be one of {self.admissible_archs}, but got {arch}", + suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", + ) + if (self.cta_group == CtaGroup.TWO) and arch[:5] == "sm_90": + raise OpError( + self, + f"CTA group of 2 is tcgen05-specific and is not and is not compatible with {arch}", + suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", + ) + + def __str__(self) -> str: + res = "cp.async GMEM -> SMEM bulk tensor copy Operation" + if self.cta_group == CtaGroup.TWO: + res += f"\n CTA group = 2" + return res + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "CopyBulkTensorTileG2SNonExecTrait": + raise NotImplementedError( + "Use cpasync.make_tiled_tma_atom to obtain a copy Atom for TMA" + ) + + def _to_ir(self) -> _cute_nvgpu_ir.TiledTmaLoadEnum: + if self.cta_group == CtaGroup.ONE: + return _cute_nvgpu_ir.TiledTmaLoadEnum.sm_90 + elif self.cta_group == CtaGroup.TWO: + return _cute_nvgpu_ir.TiledTmaLoadEnum.sm_100_2sm + else: + assert False, "unrecognized self.cta_group" + + +class CopyBulkTensorTileG2SNonExecTrait(Trait): + # We allow kw args to be dropped so that the user can write common code for non-multicast + # and multicast loads. + def unpack( + self, + *, + loc=None, + ip=None, + tma_bar_ptr: Optional[Pointer] = None, + tma_desc_ptr: Optional[Pointer] = None, + **kwargs, + ): + """ + Custom implementation of unpack for non-executable TMAs. + + The non-multicast TMA load requires a `tma_bar_ptr` keyword argument to be provided when + using `cute.copy`. Any other kw arguments will be ignored instead of triggering an error. + """ + if not isinstance(tma_bar_ptr, Pointer): + raise ValueError( + "expects a pointer to an mbarrier to be provided via the tma_bar_ptr kw argument" + ) + exec_value = _cute_nvgpu_ir.atom_make_exec_tma(self.value, loc=loc, ip=ip) + attr_str = f"#cute_nvgpu.atom_copy_field_tmaload<{TMA_MBAR_PTR_FIELD_NAME}>" + attr = ir.Attribute.parse(attr_str) + exec_value = _cute_nvgpu_ir.atom_set_value( + exec_value, attr, tma_bar_ptr.value, loc=loc, ip=ip + ) + if isinstance(tma_desc_ptr, Pointer): + attr_str = f"#cute_nvgpu.atom_copy_field_tmaload<{TMA_DESC_PTR_FIELD_NAME}>" + attr = ir.Attribute.parse(attr_str) + exec_value = _cute_nvgpu_ir.atom_set_value( + exec_value, attr, tma_desc_ptr.value, loc=loc, ip=ip + ) + return exec_value + + +# +# TMA GMEM -> SMEM multicast copies +# + + +@dataclass(frozen=True) +class CopyBulkTensorTileG2SMulticastOp(CopyOp): + """ + Bulk tensor asynchrnous multicast GMEM to SMEM Copy Operation using the TMA unit. + + See the `PTX documentation `__. + This Operation uses TMA in the ``.tile`` mode. + """ + + cta_group: CtaGroup = CtaGroup.ONE + + admissible_archs = [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ] + + def __post_init__(self): + if not isinstance(self.cta_group, CtaGroup): + raise OpError( + self, "expects the 'cta_group' parameter to be a CtaGroup instance" + ) + # Arch verification + arch = CuTeDSL._get_dsl().envar.arch + if arch not in self.admissible_archs: + raise OpError( + self, + f"expects arch to be one of {self.admissible_archs}, but got {arch}", + suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", + ) + if (self.cta_group == CtaGroup.TWO) and arch[:5] == "sm_90": + raise OpError( + self, + f"CTA group of 2 is tcgen05-specific and is not and is not compatible with {arch}", + suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", + ) + + def __str__(self) -> str: + res = "cp.async GMEM -> SMEM bulk tensor multicast copy Operation" + if self.cta_group == CtaGroup.TWO: + res += f"\n CTA group = 2" + return res + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "CopyBulkTensorTileG2SMulticastNonExecTrait": + raise NotImplementedError( + "Use cpasync.make_tiled_tma_atom to obtain a copy Atom for TMA" + ) + + def _to_ir(self) -> _cute_nvgpu_ir.TiledTmaLoadEnum: + if self.cta_group == CtaGroup.ONE: + return _cute_nvgpu_ir.TiledTmaLoadEnum.sm_90_multicast + elif self.cta_group == CtaGroup.TWO: + return _cute_nvgpu_ir.TiledTmaLoadEnum.sm_100_2sm_multicast + else: + assert False, "unrecognized self.cta_group" + + +class CopyBulkTensorTileG2SMulticastNonExecTrait(Trait): + def unpack( + self, + *, + loc=None, + ip=None, + tma_bar_ptr: Optional[Pointer] = None, + mcast_mask=None, + tma_desc_ptr=None, + ): + """ + Custom implementation of unpack for non-executable TMAs. + + The multicast TMA load requires a `tma_bar_ptr` and a `mcast_mask` keyword arguments to be + provided when using `cute.copy`. + """ + if not isinstance(tma_bar_ptr, Pointer): + raise ValueError( + "expects a pointer to an mbarrier to be provided via the tma_bar_ptr kw argument" + ) + if not isinstance(mcast_mask, Integer): + raise ValueError( + "expects a multicast mask to be provided via the mcast_mask kw argument" + ) + exec_value = _cute_nvgpu_ir.atom_make_exec_tma(self.value, loc=loc, ip=ip) + attr_str = f"#cute_nvgpu.atom_copy_field_tmaload" + attr = ir.Attribute.parse(attr_str) + exec_value = _cute_nvgpu_ir.atom_set_value( + exec_value, attr, tma_bar_ptr.value, loc=loc, ip=ip + ) + attr_str = f"#cute_nvgpu.atom_copy_field_tmaload" + attr = ir.Attribute.parse(attr_str) + exec_value = _cute_nvgpu_ir.atom_set_value( + exec_value, attr, Int16(mcast_mask).ir_value(loc=loc, ip=ip), loc=loc, ip=ip + ) + if isinstance(tma_desc_ptr, Pointer): + attr_str = f"#cute_nvgpu.atom_copy_field_tmaload<{TMA_DESC_PTR_FIELD_NAME}>" + attr = ir.Attribute.parse(attr_str) + exec_value = _cute_nvgpu_ir.atom_set_value( + exec_value, attr, tma_desc_ptr.value, loc=loc, ip=ip + ) + return exec_value + + +# +# TMA SMEM -> GMEM copies +# + + +@dataclass(frozen=True) +class CopyBulkTensorTileS2GOp(CopyOp): + """ + Bulk tensor asynchronous SMEM to GMEM Copy Operation using the TMA unit. + + See the `PTX documentation `__. + This Operation uses TMA in the ``.tile`` mode. + """ + + admissible_archs = [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ] + + def __post_init__(self): + # Arch verification + arch = CuTeDSL._get_dsl().envar.arch + if arch not in self.admissible_archs: + raise OpError( + self, + f"expects arch to be one of {self.admissible_archs}, but got {arch}", + suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", + ) + + def __str__(self) -> str: + return "cp.async SMEM -> GMEM bulk tensor copy Operation" + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "CopyBulkTensorTileS2GTrait": + raise NotImplementedError( + "Use cpasync.make_tiled_tma_atom to obtain a copy Atom for TMA" + ) + + +class CopyBulkTensorTileS2GTrait(Trait): + def unpack(self, *, loc=None, ip=None, tma_desc_ptr: Optional[Pointer] = None): + """ + Custom implementation of unpack for non-executable TMAs. + """ + exec_value = _cute_nvgpu_ir.atom_make_exec_tma(self.value, loc=loc, ip=ip) + if isinstance(tma_desc_ptr, Pointer): + attr_str = ( + f"#cute_nvgpu.atom_copy_field_tmastore<{TMA_DESC_PTR_FIELD_NAME}>" + ) + attr = ir.Attribute.parse(attr_str) + exec_value = _cute_nvgpu_ir.atom_set_value( + exec_value, attr, tma_desc_ptr.value, loc=loc, ip=ip + ) + return exec_value + +@dataclass(frozen=True) +class CopyReduceBulkTensorTileS2GOp(CopyOp): + """ + Bulk tensor asynchronous SMEM to GMEM Reduction Operation using the TMA unit. + + See the `PTX documentation `__. + This Operation uses TMA in the ``.tile`` mode. + """ + + reduction_kind: ReductionOp = ReductionOp.ADD + + admissible_archs = [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ] + + def __post__init__(self): + # Arch verification + arch = CuTeDSL.__get_dsl().envar.arch + if arch not in self.admissible_archs: + raise OpError( + self, + f"expects arch to be one of {self.admissible_archs}, but got {arch}", + suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", + ) + + def __str__(self) -> str: + return "cp.async SMEM -> GMEM bulk tensor reduction Operation" + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "CopyReduceBulkTensorTileS2GTrait": + raise NotImplementedError( + "Use cpasync.make_tiled_tma_atom to obtain a copy Atom for TMA" + ) + + def _to_ir(self) -> _cute_nvgpu_ir.ReductionKind: + if self.reduction_kind == ReductionOp.ADD: + return _cute_nvgpu_ir.ReductionKind.ADD + elif self.reduction_kind == ReductionOp.MIN: + return _cute_nvgpu_ir.ReductionKind.MIN + elif self.reduction_kind == ReductionOp.MAX: + return _cute_nvgpu_ir.ReductionKind.MAX + elif self.reduction_kind == ReductionOp.INC: + return _cute_nvgpu_ir.ReductionKind.INC + elif self.reduction_kind == ReductionOp.DEC: + return _cute_nvgpu_ir.ReductionKind.DEC + elif self.reduction_kind == ReductionOp.AND: + return _cute_nvgpu_ir.ReductionKind.AND + elif self.reduction_kind == ReductionOp.OR: + return _cute_nvgpu_ir.ReductionKind.OR + elif self.reduction_kind == ReductionOp.XOR: + return _cute_nvgpu_ir.ReductionKind.XOR + else: + assert False, "unrecognized self.reduction_kind" + + +class CopyReduceBulkTensorTileS2GTrait(Trait): + def unpack(self, *, loc=None, ip=None, tma_desc_ptr: Optional[Pointer] = None): + """ + Custom implementation of unpack for non-executable TMAs. + """ + exec_value = _cute_nvgpu_ir.atom_make_exec_tma(self.value, loc=loc, ip=ip) + if isinstance(tma_desc_ptr, Pointer): + attr_str = ( + f"#cute_nvgpu.atom_copy_field_tmareduce<{TMA_DESC_PTR_FIELD_NAME}>" + ) + attr = ir.Attribute.parse(attr_str) + exec_value = _cute_nvgpu_ir.atom_set_value( + exec_value, attr, tma_desc_ptr.value, loc=loc, ip=ip + ) + return exec_value + +__all__ = [ + "LoadCacheMode", + "CopyG2SOp", + "CopyBulkTensorTileG2SOp", + "CopyBulkTensorTileG2SMulticastOp", + "CopyBulkTensorTileS2GOp", + "CopyReduceBulkTensorTileS2GOp", +] diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..f64f07f167501d1805096373e915017612de4387 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py @@ -0,0 +1,341 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from typing import Optional, Tuple, Type, Union + +from cutlass.cutlass_dsl import dsl_user_op + +import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir +from cutlass._mlir.dialects import llvm + +from ...typing import Coord, Layout, Tensor, Tiler, Pointer, Int16, Numeric, NumericMeta +from ... import core +from .copy import ( + CopyBulkTensorTileG2SOp, + CopyBulkTensorTileG2SMulticastOp, + CopyBulkTensorTileS2GOp, + CopyReduceBulkTensorTileS2GOp, + CopyBulkTensorTileG2SNonExecTrait, + CopyBulkTensorTileG2SMulticastNonExecTrait, + CopyBulkTensorTileS2GTrait, + CopyReduceBulkTensorTileS2GTrait, +) + + +@dsl_user_op +def make_tiled_tma_atom( + op: Union[ + CopyBulkTensorTileG2SOp, + CopyBulkTensorTileG2SMulticastOp, + CopyBulkTensorTileS2GOp, + CopyReduceBulkTensorTileS2GOp, + ], + gmem_tensor: Tensor, + smem_layout: Union[Layout, core.ComposedLayout], + cta_tiler: Tiler, + num_multicast: int = 1, + *, + internal_type: Optional[Type[Numeric]] = None, + loc=None, + ip=None, +) -> Tuple[core.CopyAtom, Tensor]: + """ + Makes a TMA Copy Atom in the ``.tile`` mode to copy tiles of a GMEM tensor to/from SMEM + buffer with the given Layout. + + Given + + - a GMEM tensor + - a SMEM layout + - a CTA-level Tiler + + this function figures out the bulk tensor asynchronous copy instruction to use with the maximum + "TMA vector length" to copy tiles of the GMEM tensor to/from an SMEM buffer with the provided + layout and consistent with the provided Tiler. + + This function returns two results: + + 1. the Copy Atom + 2. the so-called TMA tensor used to map logical coordinates of the GMEM tensor to coordinates \ + that the TMA unit can consume. TMA tensors have so-called basis stride elements so that the \ + associated layout can output coordinates. Otherwise, TMA tensors can be partitioned \ + similarly to any other CuTe tensors using the algebra. + + :param op: The Copy Operation to construct an Atom for + :type op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileS2GOp, CopyReduceBulkTensorTileS2GOp] + :param gmem_tensor: The GMEM tensor involved in the Copy + :type gmem_tensor: Tensor + :param smem_layout: The SMEM layout to construct the Copy Atom for + :type smem_layout: Union[Layout, core.ComposedLayout] + :param cta_tiler: The CTA Tiler to use + :type cta_tiler: Tiler + :param num_multicast: The multicast factor + :type num_multicast: int + :param internal_type: An optional parameter for the internal data type to use when the actual data type is not supported by the TMA unit + :type internal_type: Type[Numeric] + :return: A Copy Atom for this Operation and the associated TMA tensor + :rtype: Tuple[core.CopyAtom, Tensor] + """ + + if internal_type is not None: + if not isinstance(internal_type, NumericMeta): + raise TypeError(f"internal_type must be a Numeric, but got {internal_type}") + internal_type = internal_type.mlir_type + + cta_v_map = core.composition( + core.make_identity_layout(gmem_tensor.shape, loc=loc, ip=ip), + cta_tiler, + loc=loc, + ip=ip, + ) + + if isinstance(op, CopyBulkTensorTileG2SOp): + if num_multicast != 1: + raise ValueError( + f"expects num_multicast to be 1 for non multicast G2S copies, " + f"but got {num_multicast}" + ) + res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load( + gmem_tensor.value, + smem_layout, + cta_v_map, + op._to_ir(), + num_multicast=num_multicast, + internal_type=internal_type, + loc=loc, + ip=ip, + ) + return core.CopyAtom(op, CopyBulkTensorTileG2SNonExecTrait(res[0])), res[1] + elif isinstance(op, CopyBulkTensorTileG2SMulticastOp): + if num_multicast < 1: + raise ValueError( + f"expects num_multicast to be >= 1 for multicast G2S copies, " + f"but got {num_multicast}" + ) + res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load( + gmem_tensor.value, + smem_layout, + cta_v_map, + op._to_ir(), + num_multicast=num_multicast, + internal_type=internal_type, + loc=loc, + ip=ip, + ) + return ( + core.CopyAtom(op, CopyBulkTensorTileG2SMulticastNonExecTrait(res[0])), + res[1], + ) + elif isinstance(op, CopyBulkTensorTileS2GOp): + res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_store( + gmem_tensor.value, + smem_layout, + cta_v_map, + internal_type=internal_type, + loc=loc, + ip=ip, + ) + return core.CopyAtom(op, CopyBulkTensorTileS2GTrait(res[0])), res[1] + elif isinstance(op, CopyReduceBulkTensorTileS2GOp): + res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_reduce( + gmem_tensor.value, + smem_layout, + cta_v_map, + op._to_ir(), + internal_type=internal_type, + loc=loc, + ip=ip, + ) + return core.CopyAtom(op, CopyReduceBulkTensorTileS2GTrait(res[0])), res[1] + else: + raise ValueError(f"expects a bulk tensor (TMA) Copy Op, but got {op}") + + +@dsl_user_op +def tma_partition( + atom: core.CopyAtom, + cta_coord: Coord, + cta_layout: Layout, + smem_tensor: Tensor, + gmem_tensor: Tensor, + *, + loc=None, + ip=None, +) -> Tuple[Tensor, Tensor]: + """ + Tiles the GMEM and SMEM tensors for the provided TMA Copy Atom. + """ + cta_coord_val = core._pack_coord(cta_coord, loc=loc, ip=ip) + s, d = _cute_nvgpu_ir.atom_tma_partition( + atom._trait.value, + cta_coord=cta_coord_val, + cta_layout=cta_layout, + smem_tensor=smem_tensor.value, + gmem_tensor=gmem_tensor.value, + loc=loc, + ip=ip, + ) + return s, d + + +@dsl_user_op +def create_tma_multicast_mask( + cta_layout_vmnk: Layout, + cta_coord_vmnk: Coord, + mcast_mode: int, + *, + loc=None, + ip=None, +) -> Int16: + """ + Computes a multicast mask for a TMA load Copy. + + :param cta_layout_vmnk: The VMNK layout of the cluster + :type cta_layout_vmnk: Layout + :param cta_coord_vmnk: The VMNK coordinate of the current CTA + :type cta_coord_vmnk: Coord + :param mcast_mode: The tensor mode in which to multicast + :type mcast_mode: int + :return: The resulting mask + :rtype: Int16 + """ + if core.rank(cta_layout_vmnk) != 4: + raise ValueError( + f"cta_layout_vmnk must be rank 4, but got {core.pretty_str(cta_layout_vmnk)}" + ) + if core.rank(cta_coord_vmnk) != 4: + raise ValueError( + f"cta_coord_vmnk must be rank 4, but got {core.pretty_str(cta_coord_vmnk)}" + ) + return core.make_layout_image_mask( + cta_layout_vmnk, cta_coord_vmnk, mcast_mode, loc=loc, ip=ip + ) + + +@dsl_user_op +def prefetch_descriptor(tma_atom: core.CopyAtom, *, loc=None, ip=None) -> None: + """ + Prefetches the TMA descriptor associated with the TMA Atom. + """ + _cute_nvgpu_ir.prefetch_tma_desc(tma_atom._trait.value, loc=loc, ip=ip) + + +@dsl_user_op +def copy_tensormap( + tma_atom: core.CopyAtom, tensormap_ptr: Pointer, *, loc=None, ip=None +) -> None: + """ + Copies the tensormap held by a TMA Copy Atom to the memory location pointed to by the provided + pointer. + + :param tma_atom: The TMA Copy Atom + :type tma_atom: CopyAtom + :param tensormap_ptr: The pointer to the memory location to copy the tensormap to + :type tensormap_ptr: Pointer + """ + _cute_nvgpu_ir.copy_tma_desc( + tma_atom._trait.value, tensormap_ptr.value, loc=loc, ip=ip + ) + + +@dsl_user_op +def update_tma_descriptor( + tma_atom: core.CopyAtom, + gmem_tensor: Tensor, + tma_desc_ptr: Pointer, + *, + loc=None, + ip=None, +) -> None: + """ + Updates the TMA descriptor in the memory location pointed to by the provided pointer using + information from a TMA Copy Atom and the provided GMEM tensor. + + Specifically, the following fields of the TMA descriptor will be updated: + + 1. the GMEM tensor base address + 2. the GMEM tensor shape + 3. the GMEM tensor stride + + Other fields of the TMA descriptor are left unchanged. + + :param tma_atom: The TMA Copy Atom + :type tma_atom: CopyAtom + :param gmem_tensor: The GMEM tensor + :type gmem_tensor: Tensor + :param tensormap_ptr: The pointer to the memory location of the descriptor to udpate + :type tensormap_ptr: Pointer + """ + _cute_nvgpu_ir.update_tma_desc( + tma_atom._trait.value, gmem_tensor.value, tma_desc_ptr.value, loc=loc, ip=ip + ) + + +@dsl_user_op +def fence_tma_desc_acquire( + tma_desc_ptr: Pointer, + *, + loc=None, + ip=None, +) -> None: + """ + See the `PTX documentation `__. + """ + tma_desc_ptr_i64 = tma_desc_ptr.toint(loc=loc, ip=ip).ir_value() + llvm.inline_asm( + None, + [tma_desc_ptr_i64], + "fence.proxy.tensormap::generic.acquire.gpu [$0], 128;", + "l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@dsl_user_op +def cp_fence_tma_desc_release( + tma_desc_global_ptr: Pointer, + tma_desc_shared_ptr: Pointer, + *, + loc=None, + ip=None, +) -> None: + """ + See the `PTX documentation `__. + """ + tma_desc_global_ptr_i64 = tma_desc_global_ptr.toint(loc=loc, ip=ip).ir_value() + tma_desc_shared_ptr_i32 = tma_desc_shared_ptr.toint(loc=loc, ip=ip).ir_value() + llvm.inline_asm( + None, + [tma_desc_global_ptr_i64, tma_desc_shared_ptr_i32], + "tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned [$0], [$1], 128;", + "l,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@dsl_user_op +def fence_tma_desc_release(*, loc=None, ip=None) -> None: + """ + See the `PTX documentation `__. + """ + llvm.inline_asm( + None, + [], + "fence.proxy.tensormap::generic.release.gpu;", + "", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/helpers.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..9b4aa0dbb207dfad2832ddf7a80504c7cf591ff1 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/helpers.py @@ -0,0 +1,249 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from typing import Optional, Tuple, Type, Union + +from cutlass.cutlass_dsl import dsl_user_op + +import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir + +from .. import core +from ..typing import Shape, Layout, Tensor, Numeric, NumericMeta +from ...impl_utils import check_type_in +from .cpasync.copy import ( + CopyBulkTensorTileG2SOp, + CopyBulkTensorTileG2SNonExecTrait, + CopyBulkTensorTileG2SMulticastOp, + CopyBulkTensorTileG2SMulticastNonExecTrait, +) + + +#################################################################################################### +# +# TMA creation helpers for tcgen05 MMAs +# +#################################################################################################### + + +@dsl_user_op +def make_tiled_tma_atom_A( + op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp], + gmem_tensor: Tensor, + smem_layout: Union[Layout, core.ComposedLayout], + mma_tiler_mnk: Shape, + tiled_mma: core.TiledMma, + cluster_shape_vmnk: Shape, + *, + internal_type: Optional[Type[Numeric]] = None, + loc=None, + ip=None, +) -> Tuple[core.CopyAtom, Tensor]: + """ + Makes a TMA Copy atom mapping to ``.tile`` mode for ``cp.async.bulk.tensor`` PTX operation + accounting for the MK projections of the TiledMMA for A tensor loads. + + Given + + - a GMEM tensor + - a SMEM layout + - a MMA Tiler + - a TiledMma + - a Cluster-level shape + + this function figures out the bulk tensor asynchronous copy instruction to use with the maximum + "TMA vector length" to copy tiles of the GMEM tensor to an SMEM buffer with the provided + layout and consistent with the provided Tiler & tiled_mma (considering the M-mode & K-mode). + The Cluster-level shape is used to determine the multicast factor across the N-mode for A tensor loads. + + This function returns two results: + + 1. the Copy Atom + 2. the so-called TMA tensor used to map logical coordinates of the GMEM tensor to coordinates + that the TMA unit can consume. TMA tensors have so-called basis stride elements so that the + associated layout can output coordinates. Otherwise, TMA tensors can be partitioned + similarly to any other CuTe tensors using the algebra. + + :param op: The Copy Operation to construct an Atom for + :type op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp] + :param gmem_tensor: The GMEM tensor to be loaded by this copy atom + :type gmem_tensor: Tensor + :param smem_layout: Shared memory layout to load the tensor into (PDSL) + :type smem_layout: Union[Layout, core.ComposedLayout] + :param mma_tiler_mnk: The MMA Tiler shape (TILE_M, TILE_N, TILE_K) in MNK dimensions + :type mma_tiler_mnk: Shape + :param tiled_mma: The TiledMMA that will consume the load as operands + :type tiled_mma: core.TiledMma + :param cluster_shape_vmnk: The Cluster-level shape in VMNK dimensions + :type cluster_shape_vmnk: Shape + :param internal_type: An optional parameter for the internal data type to when element + type does not match the copy type + :type internal_type: Type[Numeric] + :return: A copy atom for this operation and the associated TMA coord tensor + :rtype: Tuple[core.CopyAtom, Tensor] + + """ + + if internal_type is not None: + if not isinstance(internal_type, NumericMeta): + raise TypeError(f"internal_type must be a Numeric, but got {internal_type}") + internal_type = internal_type.mlir_type + check_type_in( + op, + [CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp], + "op", + "make_tiled_tma_atom_A", + ) + + ident = core.make_identity_layout(gmem_tensor.shape, loc=loc, ip=ip) + mma_tiler_mk = (mma_tiler_mnk[0], *mma_tiler_mnk[2:]) + g_tile = core.composition(ident, mma_tiler_mk, loc=loc, ip=ip) + cta_v_map = tiled_mma._thrfrg_A(g_tile) + cta_v_map = core.get(cta_v_map, mode=[1]) + cta_v_map = core.dice(cta_v_map, (1, (1,) * core.rank(g_tile))) + + if isinstance(op, CopyBulkTensorTileG2SOp): + num_multicast = 1 + else: + assert isinstance(op, CopyBulkTensorTileG2SMulticastOp) + # multicast across the N-mode since those would share the same tile of A + num_multicast = core.size(cluster_shape_vmnk, mode=[2]) + + # res[0] = the IR Value for the non-executable atom instance + # res[1] = the IR Value for the associated TMA tensor + res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load( + gmem_tensor.value, + smem_layout, + cta_v_map, + op._to_ir(), + num_multicast=num_multicast, + internal_type=internal_type, + loc=loc, + ip=ip, + ) + if isinstance(op, CopyBulkTensorTileG2SOp): + return core.CopyAtom(op, CopyBulkTensorTileG2SNonExecTrait(res[0])), res[1] + else: + assert isinstance(op, CopyBulkTensorTileG2SMulticastOp) + return ( + core.CopyAtom(op, CopyBulkTensorTileG2SMulticastNonExecTrait(res[0])), + res[1], + ) + + +@dsl_user_op +def make_tiled_tma_atom_B( + op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp], + gmem_tensor: Tensor, + smem_layout: Union[Layout, core.ComposedLayout], + mma_tiler_mnk: Shape, + tiled_mma: core.TiledMma, + cluster_shape_vmnk: Shape, + *, + internal_type: Optional[Type[Numeric]] = None, + loc=None, + ip=None, +) -> Tuple[core.CopyAtom, Tensor]: + """ + Makes a TMA Copy atom mapping to ``.tile`` mode for ``cp.async.bulk.tensor`` PTX operation + accounting for the NK projections of the TiledMMA for B tensor loads. + + Given + + - a GMEM tensor + - a SMEM layout + - a MMA Tiler + - a TiledMma + - a Cluster-level shape + + this function figures out the bulk tensor asynchronous copy instruction to use with the maximum + "TMA vector length" to copy tiles of the GMEM tensor to an SMEM buffer with the provided + layout and consistent with the provided Tiler & tiled_mma (considering the N-mode & K-mode). + The Cluster-level shape is used to determine the multicast factor across the M-mode for B tensor loads. + + This function returns two results: + + 1. the Copy Atom + 2. the so-called TMA tensor used to map logical coordinates of the GMEM tensor to coordinates + that the TMA unit can consume. TMA tensors have so-called basis stride elements so that the + associated layout can output coordinates. Otherwise, TMA tensors can be partitioned + similarly to any other CuTe tensors using the algebra. + + :param op: The Copy Operation to construct an Atom for + :type op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp] + :param gmem_tensor: The GMEM tensor to be loaded by this copy atom + :type gmem_tensor: Tensor + :param smem_layout: Shared memory layout to load the tensor into (PDSL) + :type smem_layout: Union[Layout, core.ComposedLayout] + :param mma_tiler_mnk: The MMA Tiler shape (TILE_M, TILE_N, TILE_K) in MNK dimensions + :type mma_tiler_mnk: Shape + :param tiled_mma: The TiledMMA that will consume the load as operands + :type tiled_mma: core.TiledMma + :param cluster_shape_vmnk: The Cluster-level shape in VMNK dimensions + :type cluster_shape_vmnk: Shape + :param internal_type: An optional parameter for the internal data type to when element + type does not match the copy type + :type internal_type: Type[Numeric] + :return: A Copy Atom for this Operation and the associated TMA tensor + :rtype: Tuple[core.CopyAtom, Tensor] + + """ + + if internal_type is not None: + if not isinstance(internal_type, NumericMeta): + raise TypeError(f"internal_type must be a Numeric, but got {internal_type}") + internal_type = internal_type.mlir_type + check_type_in( + op, + [CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp], + "op", + "make_tiled_tma_atom_B", + ) + + ident = core.make_identity_layout(gmem_tensor.shape, loc=loc, ip=ip) + mma_tiler_nk = (mma_tiler_mnk[1], *mma_tiler_mnk[2:]) + g_tile = core.composition(ident, mma_tiler_nk, loc=loc, ip=ip) + cta_v_map = tiled_mma._thrfrg_B(g_tile) + cta_v_map = core.get(cta_v_map, mode=[1]) + cta_v_map = core.dice(cta_v_map, (1, (1,) * core.rank(g_tile))) + + if isinstance(op, CopyBulkTensorTileG2SOp): + num_multicast = 1 + else: + assert isinstance(op, CopyBulkTensorTileG2SMulticastOp) + # multicast across the M-mode since those would share the same tile of B + num_multicast = core.size(cluster_shape_vmnk, mode=[1]) + + # res[0] = the IR Value for the non-executable atom instance + # res[1] = the IR Value for the associated TMA tensor + res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load( + gmem_tensor.value, + smem_layout, + cta_v_map, + op._to_ir(), + num_multicast=num_multicast, + internal_type=internal_type, + loc=loc, + ip=ip, + ) + if isinstance(op, CopyBulkTensorTileG2SOp): + return core.CopyAtom(op, CopyBulkTensorTileG2SNonExecTrait(res[0])), res[1] + else: + assert isinstance(op, CopyBulkTensorTileG2SMulticastOp) + return ( + core.CopyAtom(op, CopyBulkTensorTileG2SMulticastNonExecTrait(res[0])), + res[1], + ) + + +__all__ = [ + "make_tiled_tma_atom_A", + "make_tiled_tma_atom_B", +] diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2831bec6039b86a2231a5f05bdd3d1b9e0d891b0 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/__init__.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from .copy import * +from .mma import * +from .helpers import * + +# __all__ is required here for documentation generation +__all__ = [ + # + # copy.py + # + "Repetition", + "Pack", + "Unpack", + "Ld16x64bOp", + "Ld16x128bOp", + "Ld16x256bOp", + "Ld16x32bx2Op", + "Ld32x32bOp", + "St16x64bOp", + "St16x128bOp", + "St16x256bOp", + "St16x32bx2Op", + "St32x32bOp", + # + # mma.py + # + "OperandMajorMode", + "OperandSource", + "CtaGroup", + "Field", + "MmaTF32Op", + "MmaF16BF16Op", + "MmaI8Op", + "MmaFP8Op", + "MmaMXF8Op", + "MmaMXF4Op", + "MmaMXF4NVF4Op", + "SmemLayoutAtomKind", + # + # helpers.py + # + "make_smem_layout_atom", + "tile_to_mma_shape", + "commit", + "is_tmem_load", + "is_tmem_store", + "get_tmem_copy_properties", + "find_tmem_tensor_col_offset", + "make_tmem_copy", + "make_s2t_copy", + "get_s2t_smem_desc_tensor", +] diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py new file mode 100644 index 0000000000000000000000000000000000000000..df954b09d5bcd30321df0dd65a9955fd30a0e811 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py @@ -0,0 +1,663 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +import enum +from dataclasses import dataclass +from typing import Type + +from cutlass.cutlass_dsl import CuTeDSL + +import cutlass._mlir.dialects.cute as _cute_ir +import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir +from cutlass._mlir import ir + +from ..common import OpError +from ...core import CopyOp, Trait +from ...typing import Numeric + +from .mma import CtaGroup + + +class Repetition(enum.Enum): + """ + An enumeration for the number of repetitions of a given TMEM copy within the instruction. + """ + + x1 = 1 + x2 = 2 + x4 = 4 + x8 = 8 + x16 = 16 + x32 = 32 + x64 = 64 + x128 = 128 + + def __str__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}.{self.name}>" + + @classmethod + def _missing_(cls, value): + if isinstance(value, int): + if value == 1: + return Repetition.x1 + elif value == 2: + return Repetition.x2 + elif value == 8: + return Repetition.x8 + elif value == 16: + return Repetition.x16 + elif value == 32: + return Repetition.x32 + elif value == 64: + return Repetition.x64 + elif value == 128: + return Repetition.x128 + + +class Pack(enum.Enum): + """ + An enumeration for the possible packing patterns for TMEM to RMEM copies. + """ + + NONE = enum.auto() + PACK_16b_IN_32b = enum.auto() + + def __str__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}.{self.name}>" + + +class Unpack(enum.Enum): + """ + An enumeration for the possible unpacking patterns for RMEM to TMEM copies. + """ + + NONE = enum.auto() + UNPACK_32b_IN_16b = enum.auto() + + def __str__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}.{self.name}>" + + +@dataclass(frozen=True) +class _LdBase(CopyOp): + repeat: Repetition = Repetition.x1 + pack: Pack = Pack.NONE + + admissible_archs = [ + "sm_100a", + "sm_100f", + ] + + def __post_init__(self) -> None: + # Arch verification + arch = CuTeDSL._get_dsl().envar.arch + if arch not in self.admissible_archs: + raise OpError( + self, + f"expects arch to be one of {self.admissible_archs}, but got {arch}", + suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", + ) + + if not isinstance(self.repeat, Repetition): + raise OpError( + self, + "expects the 'repeat' Op parameter to be a tcgen05.Repetition instance", + ) + if not isinstance(self.pack, Pack): + raise OpError( + self, + "expects the 'pack' Op parameter to be a tcgen05.Pack instance", + ) + + def __str__(self) -> str: + res = ( + f"tcgen05 {self.__class__.__name__[:-2]} Copy Operation" + + f"\n number of repetitions = {self.repeat.value}" + ) + if self.pack == Pack.PACK_16b_IN_32b: + res += f"\n with 2x 16-bit to 32b packing" + return res + + +@dataclass(frozen=True) +class Ld16x64bOp(_LdBase): + """ + 16x64b TMEM load Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.16x64b`` qualifier. + """ + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "Ld16x64bTrait": + ty = _cute_nvgpu_ir.CopyAtomSM100TmemLoadType.get( + copy_internal_type.mlir_type, + 16, + 64, + self.repeat.value, + ir.UnitAttr.get() if self.pack == Pack.PACK_16b_IN_32b else None, + ) + return Ld16x64bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class Ld16x64bTrait(Trait): + pass + + +@dataclass(frozen=True) +class Ld16x128bOp(_LdBase): + """ + 16x128b TMEM load Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.16x128b`` qualifier. + """ + + def __post_init__(self) -> None: + super().__post_init__() + if self.repeat == Repetition.x128: + raise OpError( + self, + "x128 repetition is not supported", + suggestion="choose one of x1, x2, x4, x8, x16, x32, x64", + ) + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "Ld16x128bTrait": + ty = _cute_nvgpu_ir.CopyAtomSM100TmemLoadType.get( + copy_internal_type.mlir_type, + 16, + 128, + self.repeat.value, + ir.UnitAttr.get() if self.pack == Pack.PACK_16b_IN_32b else None, + ) + return Ld16x128bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class Ld16x128bTrait(Trait): + pass + + +@dataclass(frozen=True) +class Ld16x256bOp(_LdBase): + """ + 16x256b TMEM load Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.16x256b`` qualifier. + """ + + def __post_init__(self) -> None: + super().__post_init__() + if self.repeat in (Repetition.x128, Repetition.x64): + raise OpError( + self, + "x64 and x128 repetition is not supported", + suggestion="choose one of x1, x2, x4, x8, x16, x32", + ) + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "Ld16x256bTrait": + ty = _cute_nvgpu_ir.CopyAtomSM100TmemLoadType.get( + copy_internal_type.mlir_type, + 16, + 256, + self.repeat.value, + ir.UnitAttr.get() if self.pack == Pack.PACK_16b_IN_32b else None, + ) + return Ld16x256bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class Ld16x256bTrait(Trait): + pass + + +@dataclass(frozen=True) +class Ld16x32bx2Op(_LdBase): + """ + 16x32bx2 TMEM load Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.16x32bx2`` qualifier. + """ + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "Ld16x32bx2Trait": + ty = _cute_nvgpu_ir.CopyAtomSM100TmemLoadType.get( + copy_internal_type.mlir_type, + 16, + 32, + self.repeat.value, + ir.UnitAttr.get() if self.pack == Pack.PACK_16b_IN_32b else None, + ) + return Ld16x32bx2Trait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class Ld16x32bx2Trait(Trait): + pass + + +@dataclass(frozen=True) +class Ld32x32bOp(_LdBase): + """ + 32x32b TMEM load Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.32x32`` qualifier. + """ + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "Ld32x32bTrait": + ty = _cute_nvgpu_ir.CopyAtomSM100TmemLoadType.get( + copy_internal_type.mlir_type, + 32, + 32, + self.repeat.value, + ir.UnitAttr.get() if self.pack == Pack.PACK_16b_IN_32b else None, + ) + return Ld32x32bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class Ld32x32bTrait(Trait): + pass + + +@dataclass(frozen=True) +class _StBase(CopyOp): + repeat: Repetition + unpack: Unpack = Unpack.NONE + + admissible_archs = [ + "sm_100a", + "sm_100f", + ] + + def __post_init__(self) -> None: + # Arch verification + arch = CuTeDSL._get_dsl().envar.arch + if arch not in self.admissible_archs: + raise OpError( + self, + f"expects arch to be one of {self.admissible_archs}, but got {arch}", + suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", + ) + + if not isinstance(self.repeat, Repetition): + raise OpError( + self, + "expects the 'repeat' Op parameter to be a tcgen05.Repetition instance", + ) + if not isinstance(self.unpack, Unpack): + raise OpError( + self, + "expects the 'pack' Op parameter to be a tcgen05.Unpack instance", + ) + + def __str__(self) -> str: + res = ( + f"tcgen05 {self.__class__.__name__[:-2]} Copy Operation" + + f"\n number of repetitions = {self.repeat.value}" + ) + if self.unpack == Unpack.UNPACK_32b_IN_16b: + res += f"\n with 32-bit to 2x 16b unpacking" + return res + + +@dataclass(frozen=True) +class St16x64bOp(_StBase): + """ + 16x64b TMEM store Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.16x64`` qualifier. + """ + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "St16x64bTrait": + ty = _cute_nvgpu_ir.CopyAtomSM100TmemStoreType.get( + copy_internal_type.mlir_type, + 16, + 64, + self.repeat.value, + ir.UnitAttr.get() if self.unpack == Unpack.UNPACK_32b_IN_16b else None, + ) + return St16x64bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class St16x64bTrait(Trait): + pass + + +@dataclass(frozen=True) +class St16x128bOp(_StBase): + """ + 16x128b TMEM store Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.16x128`` qualifier. + """ + + def __post_init__(self) -> None: + super().__post_init__() + if self.repeat == Repetition.x128: + raise OpError( + self, + "x128 repetition is not supported", + suggestion="choose one of x1, x2, x4, x8, x16, x32, x64", + ) + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "St16x128bTrait": + ty = _cute_nvgpu_ir.CopyAtomSM100TmemStoreType.get( + copy_internal_type.mlir_type, + 16, + 128, + self.repeat.value, + ir.UnitAttr.get() if self.unpack == Unpack.UNPACK_32b_IN_16b else None, + ) + return St16x128bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class St16x128bTrait(Trait): + pass + + +@dataclass(frozen=True) +class St16x256bOp(_StBase): + """ + 16x256b TMEM store Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.16x256`` qualifier. + """ + + def __post_init__(self) -> None: + super().__post_init__() + if self.repeat in (Repetition.x128, Repetition.x64): + raise OpError( + self, + "x64 and x128 repetition is not supported", + suggestion="choose one of x1, x2, x4, x8, x16, x32", + ) + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "St16x256bTrait": + ty = _cute_nvgpu_ir.CopyAtomSM100TmemStoreType.get( + copy_internal_type.mlir_type, + 16, + 256, + self.repeat.value, + ir.UnitAttr.get() if self.unpack == Unpack.UNPACK_32b_IN_16b else None, + ) + return St16x256bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class St16x256bTrait(Trait): + pass + + +@dataclass(frozen=True) +class St16x32bx2Op(_StBase): + """ + 16x32x2b TMEM store Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.16x32x2`` qualifier. + """ + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "St16x32bx2Trait": + ty = _cute_nvgpu_ir.CopyAtomSM100TmemStoreType.get( + copy_internal_type.mlir_type, + 16, + 32, + self.repeat.value, + ir.UnitAttr.get() if self.unpack == Unpack.UNPACK_32b_IN_16b else None, + ) + return St16x32bx2Trait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class St16x32bx2Trait(Trait): + pass + + +@dataclass(frozen=True) +class St32x32bOp(_StBase): + """ + 32x32b TMEM store Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.32x32`` qualifier. + """ + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "St32x32bTrait": + ty = _cute_nvgpu_ir.CopyAtomSM100TmemStoreType.get( + copy_internal_type.mlir_type, + 32, + 32, + self.repeat.value, + ir.UnitAttr.get() if self.unpack == Unpack.UNPACK_32b_IN_16b else None, + ) + return St32x32bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class St32x32bTrait(Trait): + pass + + +@dataclass(frozen=True) +class _S2TCopyBase(CopyOp): + cta_group: CtaGroup + + admissible_archs = [ + "sm_100a", + "sm_100f", + ] + + def __post_init__(self) -> None: + # Arch verification + arch = CuTeDSL._get_dsl().envar.arch + if arch not in self.admissible_archs: + raise OpError( + self, + f"expects arch to be one of {self.admissible_archs}, but got {arch}", + suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", + ) + # Verify that the user provided enum values + if not isinstance(self.cta_group, CtaGroup): + raise OpError( + self, + "expects the 'cta_group' Op parameter to be a tcgen05.CtaGroup instance", + ) + + def __str__(self) -> str: + res = ( + f"tcgen05 {self.__class__.__name__[:-2]} Copy Operation" + + f"\n CTA group = {self.cta_group}" + ) + + return res + + +@dataclass(frozen=True) +class Cp128x256bOp(_S2TCopyBase): + """ + 128x256b SMEM to TMEM Copy Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.128x256b`` qualifier. + """ + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "Cp128x256bTrait": + ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get( + copy_internal_type.mlir_type, + 128, + 256, + self.cta_group.value, + _cute_nvgpu_ir.CopyS2TBroadcast.none, + ) + return Cp128x256bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class Cp128x256bTrait(Trait): + pass + + +@dataclass(frozen=True) +class Cp128x128bOp(_S2TCopyBase): + """ + 128x128b SMEM to TMEM Copy Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.128x128b`` qualifier. + """ + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "Cp128x128bTrait": + ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get( + copy_internal_type.mlir_type, + 128, + 128, + self.cta_group.value, + _cute_nvgpu_ir.CopyS2TBroadcast.none, + ) + return Cp128x128bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class Cp128x128bTrait(Trait): + pass + + +@dataclass(frozen=True) +class Cp4x256bOp(_S2TCopyBase): + """ + 4x256b SMEM to TMEM Copy Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.4x256b`` qualifier. + """ + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "Cp4x256bTrait": + ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get( + copy_internal_type.mlir_type, + 4, + 256, + self.cta_group.value, + _cute_nvgpu_ir.CopyS2TBroadcast.none, + ) + return Cp4x256bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class Cp4x256bTrait(Trait): + pass + + +@dataclass(frozen=True) +class Cp4x32x128bOp(_S2TCopyBase): + """ + 32x128b SMEM to TMEM Copy Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.32x128b`` qualifier with ``warpx4`` broadcast qualifier enabled. + """ + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "Cp4x32x128bTrait": + ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get( + copy_internal_type.mlir_type, + 32, + 128, + self.cta_group.value, + _cute_nvgpu_ir.CopyS2TBroadcast.x4, + ) + return Cp4x32x128bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class Cp4x32x128bTrait(Trait): + pass + + +@dataclass(frozen=True) +class Cp2x64x128b0213Op(_S2TCopyBase): + """ + 64x128b SMEM to TMEM Copy Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.64x128b`` qualifier with ``.warpx2::02_13`` broadcast qualifier enabled. + """ + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "Cp2x64x128b0213Trait": + ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get( + copy_internal_type.mlir_type, + 64, + 128, + self.cta_group.value, + _cute_nvgpu_ir.CopyS2TBroadcast.lw_0213, + ) + return Cp2x64x128b0213Trait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class Cp2x64x128b0213Trait(Trait): + pass + + +@dataclass(frozen=True) +class Cp2x64x128b0123Op(_S2TCopyBase): + """ + 64x128b SMEM to TMEM Copy Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.64x128b`` qualifier with ``.warpx2::01_23`` broadcast qualifier enabled. + """ + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "Cp2x64x128b0123Trait": + ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get( + copy_internal_type.mlir_type, + 64, + 128, + self.cta_group.value, + _cute_nvgpu_ir.CopyS2TBroadcast.lw_0123, + ) + return Cp2x64x128b0123Trait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class Cp2x64x128b0123Trait(Trait): + pass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/helpers.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..0ad27e62962e874da6707ac8a36863d5ed8f98a4 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/helpers.py @@ -0,0 +1,328 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from typing import overload, Type, Tuple, Union + +from cutlass.cutlass_dsl import dsl_user_op + +import cutlass._mlir.dialects.cute as _cute_ir +import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir +from cutlass._mlir.dialects import nvvm + +from ...typing import ( + Shape, + IntTuple, + Layout, + Tensor, + Int, + Numeric, + NumericMeta, + Int16, + Int32, +) +from ... import core +from .mma import SmemLayoutAtomKind, CtaGroup +from .copy import ( + Pack, + Unpack, + Ld16x64bOp, + Ld16x128bOp, + Ld16x256bOp, + Ld16x32bx2Op, + Ld32x32bOp, + St16x64bOp, + St16x128bOp, + St16x256bOp, + St16x32bx2Op, + St32x32bOp, +) + + +#################################################################################################### +# +# Helper functions for MMA +# +#################################################################################################### + + +@dsl_user_op +def make_smem_layout_atom( + kind: SmemLayoutAtomKind, element_type: Type[Numeric], *, loc=None, ip=None +) -> core.ComposedLayout: + """ + Makes a SMEM layout Atom. + + This function creates a composed layout in unit of elements consistent with the requested layout + Atom kind and element data type. + + :param kind: The kind of layout Atom + :type kind: SmemLayoutAtomKind + :param element_type: The element data type to construct the layout for + :type element_type: Type[Numeric] + :return: The SMEM layout atom + :rtype: core.ComposedLayout + """ + if not isinstance(element_type, NumericMeta): + raise TypeError(f"element_type must be a Numeric, but got {element_type}") + + if kind in (SmemLayoutAtomKind.MN_INTER, SmemLayoutAtomKind.K_INTER): + num_contiguous_bits = 128 + sw = core.make_swizzle(0, 4, 3) + elif kind in (SmemLayoutAtomKind.MN_SW32, SmemLayoutAtomKind.K_SW32): + num_contiguous_bits = 256 + sw = core.make_swizzle(1, 4, 3) + elif kind in (SmemLayoutAtomKind.MN_SW64, SmemLayoutAtomKind.K_SW64): + num_contiguous_bits = 512 + sw = core.make_swizzle(2, 4, 3) + elif kind in (SmemLayoutAtomKind.MN_SW128, SmemLayoutAtomKind.K_SW128): + num_contiguous_bits = 1024 + sw = core.make_swizzle(3, 4, 3) + elif kind == SmemLayoutAtomKind.MN_SW128_32B: + num_contiguous_bits = 1024 + sw = core.make_swizzle(2, 5, 2) + else: + raise ValueError("unrecognized SMEM layout atom kind") + num_contiguous_elems = num_contiguous_bits // element_type.width + + if kind in ( + SmemLayoutAtomKind.MN_INTER, + SmemLayoutAtomKind.MN_SW32, + SmemLayoutAtomKind.MN_SW64, + SmemLayoutAtomKind.MN_SW128, + SmemLayoutAtomKind.MN_SW128_32B, + ): + # M/N-major layout + return core.make_composed_layout( + sw, + 0, + core.make_layout( + (num_contiguous_elems, 8), stride=(1, num_contiguous_elems) + ), + loc=loc, + ip=ip, + ) + else: + # K-major layout + return core.make_composed_layout( + sw, + 0, + core.make_layout( + (8, num_contiguous_elems), stride=(num_contiguous_elems, 1) + ), + loc=loc, + ip=ip, + ) + + +@overload +def tile_to_mma_shape( + atom: Layout, mma_tile_shape: Shape, order: IntTuple = None, *, loc=None, ip=None +) -> Layout: ... + + +@overload +def tile_to_mma_shape( + atom: core.ComposedLayout, + mma_tile_shape: Shape, + order: IntTuple = None, + *, + loc=None, + ip=None, +) -> core.ComposedLayout: ... + + +@dsl_user_op +def tile_to_mma_shape( + atom, mma_tile_shape: Shape, order: IntTuple = None, *, loc=None, ip=None +): + """ + Tiles a layout to an MMA shape. + """ + # Default order is colexicographical + if order is None: + order = tuple(range(core.rank(mma_tile_shape) - 1)) + if core.rank(order) != core.rank(mma_tile_shape) - 1: + raise ValueError( + f"rank(order)={core.rank(order)} must be equal to " + f"rank(mma_tile_shape)-1={core.rank(mma_tile_shape)-1}" + ) + order_val = core._pack_int_tuple(order, loc=loc, ip=ip) + mma_tile_shape_val = core._pack_shape(mma_tile_shape, loc=loc, ip=ip) + + if not ( + core.is_static(atom) + and core.is_static(mma_tile_shape_val) + and core.is_static(order_val) + ): + raise ValueError("tile_to_mma_shape only supports static inputs") + + res_ty = _cute_nvgpu_ir.tile_to_mma_shape(atom, mma_tile_shape_val, order_val) + return _cute_ir.static(res_ty, loc=loc, ip=ip) + + +@dsl_user_op +def commit( + mbar_ptr: core.Pointer, + mask=None, + cta_group: CtaGroup = CtaGroup.ONE, + *, + loc=None, + ip=None, +) -> None: + """ + Perform an arrive operation on a mbarrier upon completion of previous MMA operations. + + :param mbar_ptr: A pointer to the mbarrier in SMEM + :type mbar_ptr: Pointer + :param mask: An optional multicast mask for the CTAs in the cluster to signal arrival to + :type mask: Int + """ + if cta_group == CtaGroup.ONE: + group = nvvm.Tcgen05GroupKind.CTA_1 + else: + assert cta_group == CtaGroup.TWO + group = nvvm.Tcgen05GroupKind.CTA_2 + + mbar_ptr = mbar_ptr.llvm_ptr + if mask is not None: + mask = Int16(mask).ir_value(loc=loc, ip=ip) + nvvm.tcgen05_commit_arrive( + mbar_ptr, multicast_mask=mask, group=group, loc=loc, ip=ip + ) + else: + nvvm.tcgen05_commit_arrive(mbar_ptr, group=group, loc=loc, ip=ip) + return + + +#################################################################################################### +# +# Helper functions for Copies +# +#################################################################################################### + + +def is_tmem_load(atom: core.CopyAtom) -> bool: + """ + Returns whether a CopyAtom instance is a TMEM load. + """ + return isinstance( + atom.op, + ( + Ld16x64bOp, + Ld16x128bOp, + Ld16x256bOp, + Ld16x32bx2Op, + Ld32x32bOp, + ), + ) + + +def is_tmem_store(atom: core.CopyAtom) -> bool: + """ + Returns whether a CopyAtom instance is a TMEM store. + """ + return isinstance( + atom.op, + ( + St16x64bOp, + St16x128bOp, + St16x256bOp, + St16x32bx2Op, + St32x32bOp, + ), + ) + + +def get_tmem_copy_properties( + atom: core.CopyAtom, +) -> Tuple[int, int, int, Union[Pack, Unpack]]: + """ + Returns the properties of a TMEM copy atom (number of data paths, bits, repetitions, + and whether packing/unpacking is used). + """ + if isinstance(atom.op, (Ld16x64bOp, St16x64bOp)): + num_dp, num_bits = 16, 64 + elif isinstance(atom.op, (Ld16x128bOp, St16x128bOp)): + num_dp, num_bits = 16, 128 + elif isinstance(atom.op, (Ld16x256bOp, St16x256bOp)): + num_dp, num_bits = 16, 256 + elif isinstance(atom.op, (Ld16x32bx2Op, St16x32bx2Op)): + num_dp, num_bits = 16, 32 + elif isinstance(atom.op, (Ld32x32bOp, St32x32bOp)): + num_dp, num_bits = 32, 32 + else: + raise ValueError(f"expects 'atom' to be a TMEM copy, but got {atom}") + if is_tmem_load(atom): + return num_dp, num_bits, atom.op.repeat.value, atom.op.pack + else: + assert is_tmem_store(atom), "atom must be a TMEM store" + return num_dp, num_bits, atom.op.repeat.value, atom.op.unpack + + +@dsl_user_op +def find_tmem_tensor_col_offset(tmem_tensor: Tensor, *, loc=None, ip=None) -> Int: + """ + Computes the TMEM column offset given a TMEM tensor. + + :param tmem_tensor: The TMEM tensor to use to compute the columns offset + :type tmem_tensor: Tensor + :return: The columns offset + :rtype: Int + """ + tmem_col_mask = 0x0000FFFF + offset = ( + core.cosize(core.recast_tensor(tmem_tensor, Int32).layout, loc=loc, ip=ip) + & tmem_col_mask + ) + if isinstance(offset, int): + return offset + return Int32(offset, loc=loc, ip=ip) + + +@dsl_user_op +def make_tmem_copy( + atom: core.CopyAtom, tmem_tensor: Tensor, *, loc=None, ip=None +) -> core.TiledCopy: + """ + Makes a Tiled Copy instance from a TMEM Copy Atom and a TMEM tensor. + """ + tiled_copy_val = _cute_nvgpu_ir.atom_make_tmem_copy( + atom._trait.value, tmem_tensor.value, loc=loc, ip=ip + ) + new_trait = type(atom._trait)(tiled_copy_val) + return core.TiledCopy(atom.op, new_trait) + + +@dsl_user_op +def make_s2t_copy( + atom: core.CopyAtom, tmem_tensor: Tensor, *, loc=None, ip=None +) -> core.TiledCopy: + """ + Makes a Tiled Copy instance from a TMEM Copy Atom and a TMEM tensor. + """ + tiled_copy_val = _cute_nvgpu_ir.atom_make_s2t_copy( + atom._trait.value, tmem_tensor.value, loc=loc, ip=ip + ) + new_trait = type(atom._trait)(tiled_copy_val) + return core.TiledCopy(atom.op, new_trait) + + +@dsl_user_op +def get_s2t_smem_desc_tensor( + atom: core.CopyAtom, smem_tensor: Tensor, *, loc=None, ip=None +) -> Tensor: + """ + Returns the SMEM descriptor tensor from a S2T copy atom and a SMEM tensor. + """ + smem_desc_tensor = _cute_nvgpu_ir.atom_get_copy_s2t_smem_desc_view( + atom._trait.value, smem_tensor.value, loc=loc, ip=ip + ) + return smem_desc_tensor diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py new file mode 100644 index 0000000000000000000000000000000000000000..3a938523e130cf551c205669164e15e8bbd29132 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py @@ -0,0 +1,1041 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +import enum +from dataclasses import dataclass +from typing import Type + +from cutlass.cutlass_dsl import CuTeDSL, T + +import cutlass._mlir.dialects.cute as _cute_ir +import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir +from cutlass._mlir import ir + +from ..common import OpError +from ... import core +from ...core import Trait, _pack_shape, rank, depth, _Tensor +from ...typing import ( + Shape, + Float4E2M1FN, + Float8E8M0FNU, + Float8E5M2, + Float8E4M3FN, + Float16, + BFloat16, + Float32, + TFloat32, + Boolean, + Int8, + Uint8, + Int32, + Numeric, + AddressSpace, + Pointer, +) + + +#################################################################################################### +# +# MMA Ops and Traits +# +#################################################################################################### + + +class OperandMajorMode(enum.Enum): + """ + An enumeration for the majorness of the input operands of the MMA. + """ + + MN = _cute_ir.MajorMode.mn + K = _cute_ir.MajorMode.k + + def __str__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}.{self.name}>" + + @classmethod + def _missing_(cls, value): + if isinstance(value, str): + value = value.upper() + if value == "MN": + return OperandMajorMode.MN + elif value == "K": + return OperandMajorMode.K + + def _to_ir(self) -> _cute_ir.MajorMode: + return self.value + + +class OperandSource(enum.Enum): + """ + An enumeration for the source memory location of the A input operand of the MMA. + """ + + TMEM = _cute_ir.MmaFragKind.tmem + SMEM = _cute_ir.MmaFragKind.smem_desc + + def __str__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}.{self.name}>" + + def _to_ir(self) -> _cute_ir.MmaFragKind: + return self.value + + +class CtaGroup(enum.Enum): + """ + An enumeration for the ``cta_group`` qualifier of the MMA. + """ + + ONE = 1 + TWO = 2 + + def __str__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}.{self.name}>" + +class Field(enum.Enum): + """ + An enumeration for the fields of the MMA Atom that can be modified at runtime. + """ + + NEGATE_A = "neg_a" + NEGATE_B = "neg_b" + ACCUMULATE = "accum_c" + SFA = "sf_a" + SFB = "sf_b" + + def __str__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}.{self.name}>" + + def _to_ir_field_name(self) -> str: + return self.value + + +# Base class for all tcgen05 MMA Ops with syntax `tcgen05.mma.cta_group.kind` used to factor out some internal code +@dataclass(frozen=True) +class MmaOp(core.MmaOp): + a_dtype: Type[Numeric] + b_dtype: Type[Numeric] + acc_dtype: Type[Numeric] + shape_mnk: Shape + cta_group: CtaGroup + a_src: OperandSource + a_major_mode: OperandMajorMode + b_major_mode: OperandMajorMode + + admissible_archs = [ + "sm_100a", + "sm_100f", + ] + + def __post_init__(self) -> None: + # Verify arch + arch = CuTeDSL._get_dsl().envar.arch + if arch not in self.admissible_archs: + raise OpError( + self, + f"expects arch to be one of {self.admissible_archs}, but got {arch}", + suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", + ) + # Verify that the user provided enum values + if not isinstance(self.cta_group, CtaGroup): + raise OpError( + self, + "expects the 'cta_group' Op parameter to be a tcgen05.CtaGroup instance", + ) + if not isinstance(self.a_src, OperandSource): + raise OpError( + self, + "expects the 'a_src' Op parameter to be a tcgen05.OperandSource instance", + ) + if not isinstance(self.a_major_mode, OperandMajorMode): + raise OpError( + self, + "expects the 'a_major_mode' Op parameter to be a tcgen05.OperandMajorMode instance", + ) + if not isinstance(self.b_major_mode, OperandMajorMode): + raise OpError( + self, + "expects the 'b_major_mode' Op parameter to be a tcgen05.OperandMajorMode instance", + ) + # Verify the instruction shape + if (rank(self.shape_mnk) not in [2, 3]) or (depth(self.shape_mnk) != 1): + raise OpError( + self, + f"expected a flat rank 2 or 3 tuple for the 'shape_mnk' Op parameter, " + f"but got {self.shape_mnk}", + ) + m, n = self.shape_mnk[0], self.shape_mnk[1] + if self.cta_group == CtaGroup.ONE: + if m not in [64, 128]: + raise OpError(self, f"expects the M-mode to be 64 or 128, but got {m}") + if m == 64: + if (n < 8) or (n > 256) or (n % 8 != 0): + raise OpError( + self, + f"expects the N-mode to satisfy 8 <= N <= 256 and N % 8 == 0, but got {n}", + ) + elif m == 128: + if (n < 16) or (n > 256) or (n % 16 != 0): + raise OpError( + self, + f"expects the N-mode to satisfy 8 <= N <= 256 and N % 16 == 0, but got {n}", + ) + else: + if m not in [128, 256]: + raise OpError(self, f"expects the M-mode to be 128 or 256, but got {m}") + if (n < 32) or (n > 256) or (n % 32 != 0): + raise OpError( + self, + f"expects the N-mode to satisfy 32 <= N <= 256 and N % 32 == 0, but got {n}", + ) + + def __str__(self) -> str: + return ( + self.__class__.descriptive_name # type: ignore + + f"\n A data type = {self.a_dtype}" + + f"\n B data type = {self.b_dtype}" + + f"\n Accumulator data type = {self.acc_dtype}" + + f"\n CTA group = {self.cta_group}" + + f"\n A source location = {self.a_src}" + + f"\n A major mode = {self.a_major_mode}" + + f"\n B major mode = {self.b_major_mode}" + + f"\n Instruction shape MNK = {self.shape_mnk}" + ) + + def _verify_fragment_A(self, input: _Tensor, *, loc=None, ip=None): + if input.memspace == AddressSpace.smem and isinstance( + input.layout.type, _cute_ir.ComposedLayoutType + ): + raise OpError( + self, + f"Expected affine layout for {self._make_trait()}'s operand A, " + f"but got composed layout instead: {input.layout}" + f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr", + ) + return True + + def _verify_fragment_B(self, input: _Tensor, *, loc=None, ip=None): + if input.memspace == AddressSpace.smem and isinstance( + input.layout.type, _cute_ir.ComposedLayoutType + ): + raise OpError( + self, + f"Expected affine layout for {self._make_trait()}'s operand B, " + f"but got composed layout instead: {input.layout}" + f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr", + ) + return True + + +class MmaTrait(Trait): + admissible_fields = [Field.ACCUMULATE, Field.NEGATE_A, Field.NEGATE_B] + + def set(self, field, value, *, loc=None, ip=None) -> None: + if field not in self.admissible_fields: + raise ValueError( + f"expects field to be one of {self.admissible_fields}, but got {field}" + ) + field_name = f"#cute_nvgpu.atom_mma_field_sm100<{field._to_ir_field_name()}>" + attr = ir.Attribute.parse(field_name) + self.value = _cute_nvgpu_ir.atom_set_value( + self.value, attr, Boolean(value).ir_value(loc=loc, ip=ip), loc=loc, ip=ip + ) + + +# Base class for all tcgen05 BlockScaled MMA Ops with syntax `tcgen05.mma.cta_group.kind.block_scale` used to factor out some internal code +@dataclass(frozen=True) +class BlockScaledMmaOp(core.MmaOp): + a_dtype: Type[Numeric] + b_dtype: Type[Numeric] + acc_dtype: Float32 + sf_dtype: Type[Numeric] + sf_vec_size: int + shape_mnk: Shape + cta_group: CtaGroup + a_src: OperandSource + a_major_mode: OperandMajorMode + b_major_mode: OperandMajorMode + + admissible_archs = [ + "sm_100a", + ] + + def __post_init__(self) -> None: + # Verify arch + arch = CuTeDSL._get_dsl().envar.arch + if arch not in self.admissible_archs: + raise OpError( + self, + f"expects arch to be one of {self.admissible_archs}, but got {arch}", + suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", + ) + # Verify that the user provided enum values + if not isinstance(self.cta_group, CtaGroup): + raise OpError( + self, + "expects the 'cta_group' Op parameter to be a tcgen05.CtaGroup instance", + ) + if not isinstance(self.a_src, OperandSource): + raise OpError( + self, + "expects the 'a_src' Op parameter to be a tcgen05.OperandSource instance", + ) + if not isinstance(self.a_major_mode, OperandMajorMode): + raise OpError( + self, + "expects the 'a_major_mode' Op parameter to be a tcgen05.OperandMajorMode instance", + ) + if not isinstance(self.b_major_mode, OperandMajorMode): + raise OpError( + self, + "expects the 'b_major_mode' Op parameter to be a tcgen05.OperandMajorMode instance", + ) + # Verify the instruction shape + if (rank(self.shape_mnk) not in [2, 3]) or (depth(self.shape_mnk) != 1): + raise OpError( + self, + f"expected a flat rank 2 or 3 tuple for the 'shape_mnk' Op parameter, " + f"but got {self.shape_mnk}", + ) + m, n = self.shape_mnk[0], self.shape_mnk[1] + if self.cta_group == CtaGroup.ONE: + if m != 128: + raise OpError(self, f"expects the M-mode to be 128, but got {m}") + + if (n < 8) or (n > 256) or (n % 8 != 0): + raise OpError( + self, + f"expects the N-mode to satisfy 8 <= N <= 256 and N % 8 == 0, but got {n}", + ) + else: + if m not in [128, 256]: + raise OpError(self, f"expects the M-mode to be 128 or 256, but got {m}") + if (n < 16) or (n > 256) or (n % 16 != 0): + raise OpError( + self, + f"expects the N-mode to satisfy 16 <= N <= 256 and N % 16 == 0, but got {n}", + ) + if self.sf_vec_size not in [16, 32]: + raise OpError( + self, + f"expects the scale factor vector size to be 16 or 32, but got {self.sf_vec_size}", + ) + + def __str__(self) -> str: + return ( + self.__class__.descriptive_name # type: ignore + + f"\n A data type = {self.a_dtype}" + + f"\n B data type = {self.b_dtype}" + + f"\n Accumulator data type = {self.acc_dtype}" + + f"\n Scale factor data type = {self.sf_dtype}" + + f"\n Scale factor vector size = {self.sf_vec_size}" + + f"\n CTA group = {self.cta_group}" + + f"\n A source location = {self.a_src}" + + f"\n A major mode = {self.a_major_mode}" + + f"\n B major mode = {self.b_major_mode}" + + f"\n Instruction shape MNK = {self.shape_mnk}" + ) + + def _verify_fragment_A(self, input: _Tensor, *, loc=None, ip=None): + if input.memspace == AddressSpace.smem and isinstance( + input.layout.type, _cute_ir.ComposedLayoutType + ): + raise OpError( + self, + f"Expected affine layout for {self._make_trait()}'s operand A, " + f"but got composed layout instead: {input.layout}" + f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr", + ) + return True + + def _verify_fragment_B(self, input: _Tensor, *, loc=None, ip=None): + if input.memspace == AddressSpace.smem and isinstance( + input.layout.type, _cute_ir.ComposedLayoutType + ): + raise OpError( + self, + f"Expected affine layout for {self._make_trait()}'s operand B, " + f"but got composed layout instead: {input.layout}" + f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr", + ) + return True + + +class BlockScaledMmaTraits(Trait): + admissible_fields = [ + Field.ACCUMULATE, + Field.NEGATE_A, + Field.NEGATE_B, + Field.SFA, + Field.SFB, + ] + + def set(self, field, value, *, loc=None, ip=None) -> None: + if field not in self.admissible_fields: + raise ValueError( + f"expects field to be one of {self.admissible_fields}, but got {field}" + ) + if field in [Field.ACCUMULATE, Field.NEGATE_A, Field.NEGATE_B]: + value = Boolean(value).ir_value(loc=loc, ip=ip) + elif field in [Field.SFA, Field.SFB]: + if not isinstance(value, Pointer): + raise ValueError( + f"expects value to be a pointer for {field}, but got {type(value).__name__}" + ) + value = value.value + + field_name = f"#cute_nvgpu.atom_mma_field_sm100_block_scaled<{field._to_ir_field_name()}>" + attr = ir.Attribute.parse(field_name) + self.value = _cute_nvgpu_ir.atom_set_value( + self.value, attr, value, loc=loc, ip=ip + ) + + +# +# TF32 MMA +# + + +@dataclass(frozen=True) +class MmaTF32Op(MmaOp): + """ + TF32 tcgen05 MMA Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.kind::tf32`` qualifier. + """ + + descriptive_name = "tcgen05 TF32 MMA Operation" + + def __init__( + self, + instruction_shape: Shape, + cta_group: CtaGroup, + a_src: OperandSource, + a_major_mode: OperandMajorMode, + b_major_mode: OperandMajorMode, + ) -> None: + super().__init__( + TFloat32, + TFloat32, + Float32, + instruction_shape, + cta_group, + a_src, + a_major_mode, + b_major_mode, + ) + self._verify() + + def _verify(self) -> None: + # Verify the instruction shape + instruction_k = 8 + if rank(self.shape_mnk) == 2: + object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) + if self.shape_mnk[2] != instruction_k: + raise OpError( + self, + f"expects the instruction extent in the K-mode to be {instruction_k}, " + f"but got {self.shape_mnk[2]}", + ) + + def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaTF32Trait": + shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) + ty = _cute_nvgpu_ir.MmaAtomSM100UMMAType.get( + shape_mnk.type.attribute, + self.cta_group.value, + self.a_major_mode._to_ir(), + self.b_major_mode._to_ir(), + self.a_dtype.mlir_type, + self.b_dtype.mlir_type, + self.acc_dtype.mlir_type, + self.a_src._to_ir(), + 0, + ) + return MmaTF32Trait( + _cute_nvgpu_ir.make_sm100_mma( + ty, + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + ) + + +class MmaTF32Trait(MmaTrait): + pass + + +# +# F16/BF16 MMA +# + + +@dataclass(frozen=True) +class MmaF16BF16Op(MmaOp): + """ + F16/BF16 tcgen05 MMA Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.kind::f16`` qualifier. + """ + + descriptive_name = "tcgen05 F16/BF16 MMA Operation" + + def __init__( + self, + ab_dtype: Type[Numeric], + acc_dtype: Type[Numeric], + instruction_shape: Shape, + cta_group: CtaGroup, + a_src: OperandSource, + a_major_mode: OperandMajorMode, + b_major_mode: OperandMajorMode, + ) -> None: + super().__init__( + ab_dtype, + ab_dtype, + acc_dtype, + instruction_shape, + cta_group, + a_src, + a_major_mode, + b_major_mode, + ) + self._verify() + + def _verify(self) -> None: + # Input data type verification + if self.a_dtype not in [Float16, BFloat16]: + raise OpError( + self, + "expects the 'ab_dtype' Op parameter to be one of Float16 or BFloat16", + ) + assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same" + # Accumulator data type verification + if self.acc_dtype not in [Float16, Float32]: + raise OpError( + self, + "expects the 'acc_dtype' Op parameter to be one of Float16 or Float32", + ) + # Instruction shape verification + instruction_k = 16 + if rank(self.shape_mnk) == 2: + object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) + if self.shape_mnk[2] != instruction_k: + raise OpError( + self, + f"expects the instruction extent in the K-mode to be {instruction_k}, " + f"but got {self.shape_mnk[2]}", + ) + + def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaF16BF16Trait": + shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) + ty = _cute_nvgpu_ir.MmaAtomSM100UMMAType.get( + shape_mnk.type.attribute, + self.cta_group.value, + self.a_major_mode._to_ir(), + self.b_major_mode._to_ir(), + self.a_dtype.mlir_type, + self.b_dtype.mlir_type, + self.acc_dtype.mlir_type, + self.a_src._to_ir(), + 0, + ) + return MmaF16BF16Trait( + _cute_nvgpu_ir.make_sm100_mma( + ty, + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + ) + + +class MmaF16BF16Trait(MmaTrait): + pass + + +# +# I8 MMA +# + + +@dataclass(frozen=True) +class MmaI8Op(MmaOp): + """ + I8 tcgen05 MMA Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.kind::i8`` qualifier. + """ + + descriptive_name = "tcgen05 I8 MMA Operation" + + def __init__( + self, + ab_dtype: Type[Numeric], + instruction_shape: Shape, + cta_group: CtaGroup, + a_src: OperandSource, + a_major_mode: OperandMajorMode, + b_major_mode: OperandMajorMode, + ) -> None: + super().__init__( + ab_dtype, + ab_dtype, + Int32, + instruction_shape, + cta_group, + a_src, + a_major_mode, + b_major_mode, + ) + self._verify() + + def _verify(self) -> None: + # Input data type verification + if self.a_dtype not in [Int8, Uint8]: + raise OpError( + self, + "expects the 'ab_dtype' Op parameter to be one of Int8 or Uint8", + ) + assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same" + # Instruction shape verification + instruction_k = 32 + if rank(self.shape_mnk) == 2: + object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) + if self.shape_mnk[2] != instruction_k: + raise OpError( + self, + f"expects the instruction extent in the K-mode to be {instruction_k}, " + f"but got {self.shape_mnk[2]}", + ) + + def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaI8Trait": + shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) + ty = _cute_nvgpu_ir.MmaAtomSM100UMMAType.get( + shape_mnk.type.attribute, + self.cta_group.value, + self.a_major_mode._to_ir(), + self.b_major_mode._to_ir(), + (T.si8() if self.a_dtype.signed else T.ui8()), + (T.si8() if self.b_dtype.signed else T.ui8()), + T.si32(), + self.a_src._to_ir(), + 0, + ) + return MmaI8Trait( + _cute_nvgpu_ir.make_sm100_mma( + ty, + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + ) + + +class MmaI8Trait(MmaTrait): + pass + + +# +# F8F6F4 MMA +# + + +@dataclass(frozen=True) +class MmaFP8Op(MmaOp): + """ + F8 tcgen05 MMA Operation. + + See the `PTX documentation `__. + """ + + descriptive_name = "tcgen05 F8 MMA Operation" + + def __init__( + self, + ab_dtype: Type[Numeric], + acc_dtype: Type[Numeric], + instruction_shape: Shape, + cta_group: CtaGroup, + a_src: OperandSource, + a_major_mode: OperandMajorMode, + b_major_mode: OperandMajorMode, + ) -> None: + + super().__init__( + ab_dtype, + ab_dtype, + acc_dtype, + instruction_shape, + cta_group, + a_src, + a_major_mode, + b_major_mode, + ) + self._verify() + + def _verify(self) -> None: + # Input data type verification + if self.a_dtype not in [Float8E5M2, Float8E4M3FN]: + raise OpError( + self, + "expects the 'ab_dtype' Op parameter to be one of Float8E5M2 or Float8E4M3FN", + ) + assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same" + # Accumulator data type verification + if self.acc_dtype not in [Float16, Float32]: + raise OpError( + self, + "expects the 'acc_dtype' Op parameter to be one of Float16 or Float32", + ) + # Instruction shape verification + instruction_k = 32 + if rank(self.shape_mnk) == 2: + object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) + if self.shape_mnk[2] != instruction_k: + raise OpError( + self, + f"expects the instruction extent in the K-mode to be {instruction_k}, " + f"but got {self.shape_mnk[2]}", + ) + + def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaFP8Trait": + shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) + ty = _cute_nvgpu_ir.MmaAtomSM100UMMAType.get( + shape_mnk.type.attribute, + self.cta_group.value, + self.a_major_mode._to_ir(), + self.b_major_mode._to_ir(), + self.a_dtype.mlir_type, + self.b_dtype.mlir_type, + self.acc_dtype.mlir_type, + self.a_src._to_ir(), + 0, + ) + return MmaFP8Trait( + _cute_nvgpu_ir.make_sm100_mma( + ty, + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + ) + + +class MmaFP8Trait(MmaTrait): + pass + + +# +# MXF8F6F4 MMA +# + + +@dataclass(frozen=True) +class MmaMXF8Op(BlockScaledMmaOp): + """ + MXF8 tcgen05 BlockScaled MMA Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.kind::mxf8f6f4`` qualifier. + """ + + descriptive_name = "tcgen05 MXF8 BlockScaled MMA Operation" + + def __init__( + self, + ab_dtype: Type[Numeric], + instruction_shape: Shape, + cta_group: CtaGroup, + a_src: OperandSource, + a_major_mode: OperandMajorMode, + b_major_mode: OperandMajorMode, + ) -> None: + super().__init__( + ab_dtype, + ab_dtype, + Float32, + Float8E8M0FNU, + 32, + instruction_shape, + cta_group, + a_src, + a_major_mode, + b_major_mode, + ) + self._verify() + + def _verify(self) -> None: + # Input data type verification + if self.a_dtype not in [Float8E5M2, Float8E4M3FN]: + raise OpError( + self, + "expects the 'ab_dtype' Op parameter to be one of Float8E5M2 or Float8E4M3FN", + ) + assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same" + # Instruction shape verification + instruction_k = 32 + if rank(self.shape_mnk) == 2: + object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) + if self.shape_mnk[2] != instruction_k: + raise OpError( + self, + f"expects the instruction extent in the K-mode to be {instruction_k}, " + f"but got {self.shape_mnk[2]}", + ) + + def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaMXF8Trait": + shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) + ty = _cute_nvgpu_ir.MmaAtomSM100UMMABlockScaledType.get( + shape_mnk.type.attribute, + self.cta_group.value, + self.a_major_mode._to_ir(), + self.b_major_mode._to_ir(), + self.a_dtype.mlir_type, + self.b_dtype.mlir_type, + self.acc_dtype.mlir_type, + self.sf_dtype.mlir_type, + self.a_src._to_ir(), + self.sf_vec_size, + ) + return MmaMXF8Trait( + _cute_nvgpu_ir.make_sm100_mma_bs( + ty, + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value, + core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value, + loc=loc, + ip=ip, + ) + ) + + +class MmaMXF8Trait(BlockScaledMmaTraits): + pass + + +# +# MXF4 MMA +# + + +@dataclass(frozen=True) +class MmaMXF4Op(BlockScaledMmaOp): + """ + MXF4 tcgen05 BlockScaled MMA Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.kind::mxf4`` qualifier. + """ + + descriptive_name = "tcgen05 MXF4 BlockScaled MMA Operation" + + def __init__( + self, + instruction_shape: Shape, + cta_group: CtaGroup, + a_src: OperandSource, + ) -> None: + super().__init__( + Float4E2M1FN, + Float4E2M1FN, + Float32, + Float8E8M0FNU, + 32, + instruction_shape, + cta_group, + a_src, + OperandMajorMode.K, + OperandMajorMode.K, + ) + self._verify() + + def _verify(self) -> None: + # Instruction shape verification + instruction_k = 64 + if rank(self.shape_mnk) == 2: + object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) + if self.shape_mnk[2] != instruction_k: + raise OpError( + self, + f"expects the instruction extent in the K-mode to be {instruction_k}, " + f"but got {self.shape_mnk[2]}", + ) + + def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaMXF8Trait": + shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) + ty = _cute_nvgpu_ir.MmaAtomSM100UMMABlockScaledType.get( + shape_mnk.type.attribute, + self.cta_group.value, + self.a_major_mode._to_ir(), + self.b_major_mode._to_ir(), + self.a_dtype.mlir_type, + self.b_dtype.mlir_type, + self.acc_dtype.mlir_type, + self.sf_dtype.mlir_type, + self.a_src._to_ir(), + self.sf_vec_size, + ) + return MmaMXF4Trait( + _cute_nvgpu_ir.make_sm100_mma_bs( + ty, + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value, + core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value, + loc=loc, + ip=ip, + ) + ) + + +class MmaMXF4Trait(BlockScaledMmaTraits): + pass + + +# +# MXF4NVF4 MMA +# + + +@dataclass(frozen=True) +class MmaMXF4NVF4Op(BlockScaledMmaOp): + """ + MXF4NVF4 tcgen05 BlockScaled MMA Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.kind::mxf4nvf4`` qualifier. + """ + + descriptive_name = "tcgen05 MXF4NVF4 BlockScaled MMA Operation" + + def __init__( + self, + sf_dtype: Type[Numeric], + instruction_shape: Shape, + cta_group: CtaGroup, + a_src: OperandSource, + ) -> None: + super().__init__( + Float4E2M1FN, + Float4E2M1FN, + Float32, + sf_dtype, + 16, + instruction_shape, + cta_group, + a_src, + OperandMajorMode.K, + OperandMajorMode.K, + ) + self._verify() + + def _verify(self) -> None: + # Scale Factor data type verification + if self.sf_dtype not in [Float8E8M0FNU, Float8E4M3FN]: + raise OpError( + self, + "expects the 'sf_dtype' Op parameter to be one of Float8E8M0FNU", + ) + # Instruction shape verification + instruction_k = 64 + if rank(self.shape_mnk) == 2: + object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) + if self.shape_mnk[2] != instruction_k: + raise OpError( + self, + f"expects the instruction extent in the K-mode to be {instruction_k}, " + f"but got {self.shape_mnk[2]}", + ) + + def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaMXF8Trait": + shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) + ty = _cute_nvgpu_ir.MmaAtomSM100UMMABlockScaledType.get( + shape_mnk.type.attribute, + self.cta_group.value, + self.a_major_mode._to_ir(), + self.b_major_mode._to_ir(), + self.a_dtype.mlir_type, + self.b_dtype.mlir_type, + self.acc_dtype.mlir_type, + self.sf_dtype.mlir_type, + self.a_src._to_ir(), + self.sf_vec_size, + ) + return MmaMXF4NVF4Trait( + _cute_nvgpu_ir.make_sm100_mma_bs( + ty, + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value, + core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value, + loc=loc, + ip=ip, + ) + ) + + +class MmaMXF4NVF4Trait(BlockScaledMmaTraits): + pass + +#################################################################################################### +# +# SMEM layout atoms +# +#################################################################################################### + + +class SmemLayoutAtomKind(enum.Enum): + """ + Enum class for the kinds of SMEM layout atoms for SM100. + + Given a swizzle kind, an SMEM layout atom is the compact layout of smallest size that can be + used to construct an SMEM layout using blocked product for operand A or B such that the + resulting layout is legal for both TMA and UMMA. + + Note that there are other ways of creating legal layouts for operand A and B. + """ + + MN_INTER = enum.auto() + MN_SW32 = enum.auto() + MN_SW64 = enum.auto() + MN_SW128 = enum.auto() + MN_SW128_32B = enum.auto() + K_INTER = enum.auto() + K_SW32 = enum.auto() + K_SW64 = enum.auto() + K_SW128 = enum.auto() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c2b3f7cf5b0698752d7ea6c450782f17a3fee797 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/__init__.py @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from .copy import * +from .mma import * + + +# __all__ is required here for documentation generation +__all__ = [ + # mma.py + "MmaF16BF16Op", + # copy.py + "LdMatrix8x8x16bOp", + "LdMatrix16x16x8bOp", + "StMatrix8x8x16bOp", + "StMatrix16x8x8bOp", +] diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/copy.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/copy.py new file mode 100644 index 0000000000000000000000000000000000000000..a6ad4ca8f0e2dd05b6e779eaedec0b69cd47decf --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/copy.py @@ -0,0 +1,189 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from dataclasses import dataclass +from typing import Type + +import cutlass._mlir.dialects.cute as _cute_ir +import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir +from cutlass._mlir import ir + +from ..common import OpError +from ...core import CopyOp, Trait, _pack_shape +from ...typing import Numeric + + +@dataclass(frozen=True) +class BaseOp(CopyOp): + transpose: bool = False + num_matrices: int = 1 + + def __post_init__(self) -> None: + if not isinstance(self.transpose, bool): + raise OpError( + self, + "expects the 'transpose' Op parameter to be a bool instance", + ) + + def __str__(self) -> str: + res = ( + f"{self.__class__.__name__[:-2]} Copy Operation" + + f"\n number of matrices = {self.num_matrices}" + ) + if self.transpose: + res += f"\n transposed" + return res + + +@dataclass(frozen=True) +class LdMatrix8x8x16bOp(BaseOp): + """ + 8x8 ``ldmatrix`` Operation. + + See the `PTX documentation `__. + This operation corresponds to the ``.m8n8`` qualifier. + """ + + def __post_init__(self) -> None: + super().__post_init__() + if self.num_matrices not in [1, 2, 4]: + raise OpError( + self, + "expects the 'num_matrices' Op parameter to be one of [1,2,4]", + ) + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "LdMatrix8x8x16bTrait": + mode = _pack_shape((8, 8), loc=loc, ip=ip) + ty = _cute_nvgpu_ir.CopyAtomLdsmType.get( + copy_internal_type.mlir_type, + mode.type.attribute, + _cute_nvgpu_ir.LdsmSzPattern.u16, + self.num_matrices, + ir.UnitAttr.get() if self.transpose else None, + ) + return LdMatrix8x8x16bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class LdMatrix8x8x16bTrait(Trait): + pass + + +@dataclass(frozen=True) +class LdMatrix16x16x8bOp(BaseOp): + """ + 16x16 8-bit ``ldmatrix`` Operation. + + See the `PTX documentation `__. + This operation corresponds to the ``.m16n16`` and the ``.b16`` qualifiers. + """ + + def __init__(self, num_matrices: int) -> None: + super().__init__(transpose=True, num_matrices=num_matrices) + self._verify() + + def _verify(self): + assert self.transpose, "transpose must be True" + if self.num_matrices not in [1, 2]: + raise OpError( + self, + "expects the 'num_matrices' Op parameter to be one of [1,2]", + ) + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "LdMatrix16x16x8bTrait": + mode = _pack_shape((16, 16), loc=loc, ip=ip) + ty = _cute_nvgpu_ir.CopyAtomLdsmType.get( + copy_internal_type.mlir_type, + mode.type.attribute, + _cute_nvgpu_ir.LdsmSzPattern.u8, + self.num_matrices, + ir.UnitAttr.get(), + ) + return LdMatrix16x16x8bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class LdMatrix16x16x8bTrait(Trait): + pass + + +@dataclass(frozen=True) +class StMatrix8x8x16bOp(BaseOp): + """ + 8x8 ``stmatrix`` Operation. + + See the `PTX documentation `__. + This operation corresponds to the ``m8n8`` qualifier. + """ + + def __post_init__(self) -> None: + super().__post_init__() + if self.num_matrices not in [1, 2, 4]: + raise OpError( + self, + "expects the 'num_matrices' Op parameter to be one of [1,2,4]", + ) + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "StMatrix8x8x16bTrait": + mode = _pack_shape((8, 8), loc=loc, ip=ip) + ty = _cute_nvgpu_ir.CopyAtomStsmType.get( + copy_internal_type.mlir_type, + mode.type.attribute, + self.num_matrices, + ir.UnitAttr.get() if self.transpose else None, + ) + return StMatrix8x8x16bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class StMatrix8x8x16bTrait(Trait): + pass + + +@dataclass(frozen=True) +class StMatrix16x8x8bOp(BaseOp): + """ + 16x8 ``stmatrix`` Operation. + + See the `PTX documentation `__. + This operation corresponds to the ``m16n8`` qualifier. + """ + + def __init__(self, num_matrices: int) -> None: + super().__init__(transpose=True, num_matrices=num_matrices) + self._verify() + + def _verify(self): + if self.num_matrices not in [1, 2, 4]: + assert self.transpose, "transpose must be True" + raise OpError( + self, + "expects the 'num_matrices' Op parameter to be one of [1,2,4]", + ) + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "StMatrix16x8x8bTrait": + mode = _pack_shape((16, 8), loc=loc, ip=ip) + ty = _cute_nvgpu_ir.CopyAtomStsmType.get( + copy_internal_type.mlir_type, + mode.type.attribute, + self.num_matrices, + ir.UnitAttr.get(), + ) + return StMatrix16x8x8bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class StMatrix16x8x8bTrait(Trait): + pass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py new file mode 100644 index 0000000000000000000000000000000000000000..49df213b76f24f23ecfe5a75e36cf17d35aeb98b --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py @@ -0,0 +1,83 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from dataclasses import dataclass +from typing import Type + +import cutlass._mlir.dialects.cute as _cute_ir +import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir + +from ..common import OpError +from ...core import MmaOp, Trait, _pack_shape, _Tensor +from ...typing import Shape, Float16, BFloat16, Float32, Numeric, AddressSpace + + +@dataclass(frozen=True) +class MmaF16BF16Op(MmaOp): + """ + F16/BF16 tcgen05 MMA Operation. + + See the `PTX documentation `__. + This Operation covers the instructions using the ``.f16`` or ``.bf16`` qualifiers for the input operands. + """ + + ab_dtype: Type[Numeric] + acc_dtype: Type[Numeric] + shape_mnk: Shape + + def __post_init__(self) -> None: + if self.ab_dtype not in [Float16, BFloat16]: + raise OpError( + self, + "expects the 'ab_dtype' Op parameter to be one of Float16 or BFloat16", + ) + if self.acc_dtype not in [Float16, Float32]: + raise OpError( + self, + "expects the 'acc_dtype' Op parameter to be one of Float16 or Float32", + ) + if (self.ab_dtype == BFloat16) and (self.acc_dtype != Float32): + raise OpError( + self, + "expects the 'acc_dtype' Op parameter to be Float32 when 'ab_dtype' is BFloat16", + ) + if self.shape_mnk not in [(16, 8, 8), (16, 8, 16)]: + raise OpError( + self, + "expects the 'shape_mnk' Op parameter to be one of (16,8,8) or (16,8,16)", + ) + + def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaF16BF16Trait": + shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) + ty = _cute_nvgpu_ir.MmaAtomSM80Type.get( + shape_mnk.type.attribute, + self.ab_dtype.mlir_type, + self.ab_dtype.mlir_type, + self.acc_dtype.mlir_type, + ) + return MmaF16BF16Trait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + def __str__(self) -> str: + return ( + "warp-level F16/BF16 MMA Operation" + + f"\n A/B data type = {self.ab_dtype}" + + f"\n Accumulator data type = {self.acc_dtype}" + + f"\n Instruction shape MNK = {self.shape_mnk}" + ) + + def _verify_fragment_A(self, input: _Tensor, *, loc=None, ip=None): + pass + + def _verify_fragment_B(self, input: _Tensor, *, loc=None, ip=None): + pass + +class MmaF16BF16Trait(Trait): + pass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..49a40165033024c9c9b17acd298a1f8ba055649c --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/__init__.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from .mma import * +from .helpers import * + +# __all__ is required here for documentation generation +__all__ = [ + # mma.py + "OperandMajorMode", + "OperandSource", + "Field", + "MmaF16BF16Op", + "MmaF8Op", + "SmemLayoutAtomKind", + # helpers.py + "make_smem_layout_atom", + "fence", + "commit_group", + "wait_group", +] diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/helpers.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..f6284134933bec170ecec5eeb0bf9f829ef0dff0 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/helpers.py @@ -0,0 +1,109 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from typing import Type + +from cutlass.cutlass_dsl import dsl_user_op + +from cutlass._mlir.dialects import nvvm + +from ...typing import Numeric, NumericMeta +from ... import core +from .mma import SmemLayoutAtomKind + + +@dsl_user_op +def make_smem_layout_atom( + kind: SmemLayoutAtomKind, element_type: Type[Numeric], *, loc=None, ip=None +) -> core.ComposedLayout: + """ + Makes a SMEM layout Atom. + + This function creates a composed layout in unit of elements consistent with the requested layout + Atom kind and element data type. + + :param kind: The kind of layout Atom + :type kind: SmemLayoutAtomKind + :param element_type: The element data type to construct the layout for + :type element_type: Type[Numeric] + :return: The SMEM layout atom + :rtype: core.ComposedLayout + """ + if not isinstance(element_type, NumericMeta): + raise TypeError(f"element_type must be a Numeric, but got {element_type}") + + if kind in (SmemLayoutAtomKind.MN_INTER, SmemLayoutAtomKind.K_INTER): + num_contiguous_bits = 128 + sw = core.make_swizzle(0, 4, 3) + elif kind in (SmemLayoutAtomKind.MN_SW32, SmemLayoutAtomKind.K_SW32): + num_contiguous_bits = 256 + sw = core.make_swizzle(1, 4, 3) + elif kind in (SmemLayoutAtomKind.MN_SW64, SmemLayoutAtomKind.K_SW64): + num_contiguous_bits = 512 + sw = core.make_swizzle(2, 4, 3) + elif kind in (SmemLayoutAtomKind.MN_SW128, SmemLayoutAtomKind.K_SW128): + num_contiguous_bits = 1024 + sw = core.make_swizzle(3, 4, 3) + else: + raise ValueError("unrecognized SMEM layout atom kind") + num_contiguous_elems = num_contiguous_bits // element_type.width + + if kind in ( + SmemLayoutAtomKind.MN_INTER, + SmemLayoutAtomKind.MN_SW32, + SmemLayoutAtomKind.MN_SW64, + SmemLayoutAtomKind.MN_SW128, + ): + # M/N-major layout + return core.make_composed_layout( + sw, + 0, + core.make_layout( + (num_contiguous_elems, 8), stride=(1, num_contiguous_elems) + ), + loc=loc, + ip=ip, + ) + else: + # K-major layout + return core.make_composed_layout( + sw, + 0, + core.make_layout( + (8, num_contiguous_elems), stride=(num_contiguous_elems, 1) + ), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def fence(*, loc=None, ip=None) -> None: + """ + See the `PTX documentation `__. + """ + nvvm.wgmma_fence_aligned(loc=None, ip=None) + + +@dsl_user_op +def commit_group(*, loc=None, ip=None) -> None: + """ + See the `PTX documentation `__. + """ + nvvm.wgmma_commit_group_sync_aligned(loc=loc, ip=ip) + + +@dsl_user_op +def wait_group(group, *, loc=None, ip=None) -> None: + """ + See the `PTX documentation `__. + """ + nvvm.wgmma_wait_group_sync_aligned(group, loc=loc, ip=ip) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/mma.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/mma.py new file mode 100644 index 0000000000000000000000000000000000000000..275861f70cc3d6eca932cb263890aaaa4121445f --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/mma.py @@ -0,0 +1,405 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +import enum +from dataclasses import dataclass +from typing import Type + +from cutlass.cutlass_dsl import CuTeDSL + +import cutlass._mlir.dialects.cute as _cute_ir +import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir +from cutlass._mlir import ir + +from ..common import OpError +from ...core import MmaOp, Trait, _pack_shape, rank, depth, _Tensor +from ...typing import ( + Shape, + Float16, + BFloat16, + Float32, + Boolean, + Float8E5M2, + Float8E4M3FN, + Numeric, + AddressSpace, +) + + +#################################################################################################### +# +# MMA Ops and Traits +# +#################################################################################################### + + +class OperandMajorMode(enum.Enum): + """ + An enumeration for the majorness of the input operands of the MMA. + """ + + MN = _cute_ir.MajorMode.mn + K = _cute_ir.MajorMode.k + + def __str__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}.{self.name}>" + + @classmethod + def _missing_(cls, value): + if isinstance(value, str): + value = value.upper() + if value == "MN": + return OperandMajorMode.MN + elif value == "K": + return OperandMajorMode.K + + def _to_ir(self) -> _cute_ir.MajorMode: + return self.value + + +class OperandSource(enum.Enum): + """ + An enumeration for the source memory location of the A input operand of the MMA. + """ + + RMEM = _cute_ir.MmaFragKind.rmem + SMEM = _cute_ir.MmaFragKind.smem_desc + + def __str__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}.{self.name}>" + + def _to_ir(self) -> _cute_ir.MmaFragKind: + return self.value + + +class Field(enum.Enum): + """ + An enumeration for the fields of the MMA Atom that can be modified at runtime. + """ + + ACCUMULATE = "accum_c" + + def __str__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}.{self.name}>" + + def _to_ir_field_name(self) -> str: + return self.value + + +@dataclass(frozen=True) +class MmaOp(MmaOp): + a_dtype: Type[Numeric] + b_dtype: Type[Numeric] + acc_dtype: Type[Numeric] + shape_mnk: Shape + a_src: OperandSource + a_major_mode: OperandMajorMode + b_major_mode: OperandMajorMode + + admissible_archs = ["sm_90a"] + + def __post_init__(self) -> None: + # Verify arch + arch = CuTeDSL._get_dsl().envar.arch + if arch not in self.admissible_archs: + raise OpError( + self, + f"expects arch to be one of {self.admissible_archs}, but got {arch}", + suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", + ) + # Verify that the user provided enum values + if not isinstance(self.a_src, OperandSource): + raise OpError( + self, + "expects the 'a_src' Op parameter to be a warpgroup.OperandSource instance", + ) + if not isinstance(self.a_major_mode, OperandMajorMode): + raise OpError( + self, + "expects the 'a_major_mode' Op parameter to be a warpgroup.OperandMajorMode instance", + ) + if not isinstance(self.b_major_mode, OperandMajorMode): + raise OpError( + self, + "expects the 'b_major_mode' Op parameter to be a warpgroup.OperandMajorMode instance", + ) + # Verify instruction shape + if (rank(self.shape_mnk) not in [2, 3]) or (depth(self.shape_mnk) != 1): + raise OpError( + self, + f"expected a flat rank 2 or 3 tuple for the 'shape_mnk' Op parameter, " + f"but got {self.shape_mnk}", + ) + m, n = self.shape_mnk[0], self.shape_mnk[1] + if m != 64: + raise OpError(self, f"expects the M-mode to be 64, but got {m}") + if (n < 8) or (n > 256) or (n % 8 != 0): + raise OpError( + self, + f"expects the N-mode to satisfy 8 <= N <= 256 and N % 8 == 0. but got {n}", + ) + + def __str__(self) -> str: + return ( + self.__class__.descriptive_name # type: ignore + + f"\n A data type = {self.a_dtype}" + + f"\n B data type = {self.b_dtype}" + + f"\n Accumulator data type = {self.acc_dtype}" + + f"\n A source location = {self.a_src}" + + f"\n A major mode = {self.a_major_mode}" + + f"\n B major mode = {self.b_major_mode}" + + f"\n Instruction shape MNK = {self.shape_mnk}" + ) + + def _verify_fragment_A(self, input: _Tensor, *, loc=None, ip=None): + if input.memspace == AddressSpace.smem and isinstance( + input.layout.type, _cute_ir.ComposedLayoutType + ): + raise OpError( + self, + f"Expected affine layout for {self._make_trait()}'s operand A, " + f"but got composed layout instead: {input.layout}" + f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr", + ) + return True + + def _verify_fragment_B(self, input: _Tensor, *, loc=None, ip=None): + if input.memspace == AddressSpace.smem and isinstance( + input.layout.type, _cute_ir.ComposedLayoutType + ): + raise OpError( + self, + f"Expected affine layout for {self._make_trait()}'s operand B, " + f"but got composed layout instead: {input.layout}" + f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr", + ) + return True + + +class MmaTrait(Trait): + admissible_fields = [Field.ACCUMULATE] + + def set(self, field, value, *, loc=None, ip=None) -> None: + if field not in self.admissible_fields: + raise ValueError( + f"invalid field, must be {Field.ACCUMULATE}, but got {field}" + ) + field_name = f"#cute_nvgpu.atom_mma_field_sm90<{field._to_ir_field_name()}>" + attr = ir.Attribute.parse(field_name) + self.value = _cute_nvgpu_ir.atom_set_value( + self.value, attr, Boolean(value).ir_value(loc=loc, ip=ip), loc=loc, ip=ip + ) + + +@dataclass(frozen=True) +class MmaF16BF16Op(MmaOp): + """ + F16/BF16 warpgroup MMA Operation. + + See the `PTX documentation `__. + This Operation covers the instructions using the ``.f16`` or ``.bf16`` qualifiers for the input operands. + """ + + descriptive_name = "warpgroup F16/BF16 MMA Operation" + + def __init__( + self, + ab_dtype: Type[Numeric], + acc_dtype: Type[Numeric], + instruction_shape: Shape, + a_src: OperandSource, + a_major_mode: OperandMajorMode, + b_major_mode: OperandMajorMode, + ) -> None: + super().__init__( + ab_dtype, + ab_dtype, + acc_dtype, + instruction_shape, + a_src, + a_major_mode, + b_major_mode, + ) + self._verify() + + def _verify(self) -> None: + # Input data type verification + if self.a_dtype not in [Float16, BFloat16]: + raise OpError( + self, + "expects the 'ab_dtype' Op parameter to be one of Float16 or BFloat16", + ) + assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same" + # Accumulator data type verification + if self.acc_dtype not in [Float16, Float32]: + raise OpError( + self, + "expects the 'acc_dtype' Op parameter to be one of Float16 or Float32", + ) + if (self.a_dtype == BFloat16) and (self.acc_dtype != Float32): + raise OpError( + self, + "expects the 'acc_dtype' Op parameter to be Float32 when 'ab_dtype' is BFloat16", + ) + # Verify the instruction shape + instruction_k = 16 + if rank(self.shape_mnk) == 2: + object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) + if self.shape_mnk[2] != instruction_k: + raise OpError( + self, + f"expects the instruction extent in the K-mode to be {instruction_k}, " + f"but got {self.shape_mnk[2]}", + ) + + def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaF16BF16Trait": + shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) + ty = _cute_nvgpu_ir.MmaAtomSM90Type.get( + shape_mnk.type.attribute, + self.a_major_mode._to_ir(), + self.b_major_mode._to_ir(), + self.a_dtype.mlir_type, + self.b_dtype.mlir_type, + self.acc_dtype.mlir_type, + self.a_src._to_ir(), + ) + return MmaF16BF16Trait( + _cute_nvgpu_ir.make_sm90_mma( + ty, + Boolean(False).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + ) + + +class MmaF16BF16Trait(MmaTrait): + pass + + +@dataclass(frozen=True) +class MmaF8Op(MmaOp): + """ + F16/BF16 warpgroup MMA Operation. + + See the `PTX documentation `__. + This Operation covers the instructions using the ``.e4m3`` or ``.e5m2`` qualifiers for the input operands. + """ + + descriptive_name = "warpgroup F8 MMA Operation" + + def __init__( + self, + a_dtype: Type[Numeric], + b_dtype: Type[Numeric], + acc_dtype: Type[Numeric], + instruction_shape: Shape, + a_src: OperandSource, + a_major_mode: OperandMajorMode, + b_major_mode: OperandMajorMode, + ) -> None: + super().__init__( + a_dtype, + b_dtype, + acc_dtype, + instruction_shape, + a_src, + a_major_mode, + b_major_mode, + ) + self._verify() + + def _verify(self): + # Input data type verification + if self.a_dtype not in [Float8E5M2, Float8E4M3FN]: + raise OpError( + self, + "expects the 'a_dtype' Op parameter to be one of Float8E5M2 or Float8E4M3FN", + ) + if self.b_dtype not in [Float8E5M2, Float8E4M3FN]: + raise OpError( + self, + "expects the 'b_dtype' Op parameter to be one of Float8E5M2 or Float8E4M3FN", + ) + # Accumulator data type verification + if self.acc_dtype not in [Float16, Float32]: + raise OpError( + self, + "expects the 'acc_dtype' Op parameter to be one of Float16 or Float32", + ) + # Verify the instruction shape + instruction_k = 32 + if rank(self.shape_mnk) == 2: + object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) + if self.shape_mnk[2] != instruction_k: + raise OpError( + self, + f"expects the instruction extent in the K-mode to be {instruction_k}, " + f"but got {self.shape_mnk[2]}", + ) + + def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaF8Trait": + shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) + ty = _cute_nvgpu_ir.MmaAtomSM90Type.get( + shape_mnk.type.attribute, + self.a_major_mode._to_ir(), + self.b_major_mode._to_ir(), + self.a_dtype.mlir_type, + self.b_dtype.mlir_type, + self.acc_dtype.mlir_type, + self.a_src._to_ir(), + ) + return MmaF8Trait( + _cute_nvgpu_ir.make_sm90_mma( + ty, Boolean(False).ir_value(loc=loc, ip=ip), loc=loc, ip=ip + ) + ) + + +class MmaF8Trait(MmaTrait): + pass + + +#################################################################################################### +# +# SMEM layout atoms +# +#################################################################################################### + + +class SmemLayoutAtomKind(enum.Enum): + """ + Enum class for the kinds of SMEM layout atoms for SM90. + + Given a swizzle kind, an SMEM layout atom is the compact layout of smallest size that can + be used to construct an SMEM layout using blocked product for operand A or B such that the + resulting layout is legal for both TMA and UMMA. + + Note that there are other ways of creating legal layouts for operand A and B. + """ + + MN_INTER = enum.auto() + MN_SW32 = enum.auto() + MN_SW64 = enum.auto() + MN_SW128 = enum.auto() + K_INTER = enum.auto() + K_SW32 = enum.auto() + K_SW64 = enum.auto() + K_SW128 = enum.auto() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/runtime.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..9128c67a24a7202713c354fb99b2891542f0c887 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/runtime.py @@ -0,0 +1,510 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +import ctypes +from functools import lru_cache +import itertools +import operator +from time import time +from typing import Union + +# MLIR modules imports +from cutlass._mlir import ir +import cutlass._mlir.dialects.cute as _cute_ir + +from cutlass.base_dsl.dsl import is_dynamic_expression +from cutlass.cutlass_dsl import JitArgAdapterRegistry + +# Local modules imports +from .typing import ( + AddressSpace, + Tensor, + Type, + Pointer, + Boolean, + Numeric, + Float4E2M1FN, + Int64, + Int32, + Int16, + Int8, + Uint64, + Uint32, + Uint16, + Uint8, + Float64, + Float32, + Float16, + BFloat16, + Float8E5M2, +) +from . import core +from .core import _Tensor as CoreTensor + + +class _Pointer(Pointer): + """Runtime representation of a pointer that can inter-operate with various data structures, + including numpy arrays and device memory. + + :param pointer: The pointer to the data + :type pointer: int or pointer-like object + :param dtype: Data type of the elements pointed to + :type dtype: Type + :param mem_space: Memory space where the pointer resides, defaults to generic + :type mem_space: _cute_ir.AddressSpace, optional + :param assumed_align: Assumed alignment of input pointer in bytes, defaults to None + :type assumed_align: int, optional + + :ivar _pointer: The underlying pointer + :ivar _dtype: Data type of the elements + :ivar _addr_space: Memory space of the pointer + :ivar _assumed_align: Alignment of the pointer in bytes + :ivar _desc: C-type descriptor for the pointer + :ivar _c_pointer: C-compatible pointer representation + """ + + def __init__( + self, + pointer, + dtype, + mem_space: _cute_ir.AddressSpace = _cute_ir.AddressSpace.generic, + assumed_align=None, + ): + self._pointer = pointer + self._dtype = dtype + self._addr_space = mem_space + + if assumed_align is None: + self._assumed_align = dtype.width // 8 + else: + self._assumed_align = assumed_align + + self._c_pointer = None + assert ( + int(self._pointer) % self._assumed_align == 0 + ), f"pointer must be {self._assumed_align} bytes aligned" + + def size_in_bytes(self) -> int: + self._desc = ctypes.c_void_p(int(self._pointer)) + return ctypes.sizeof(self._desc) + + def __get_mlir_types__(self): + return [self.mlir_type] + + def __c_pointers__(self): + if self._c_pointer is None: + self._desc = ctypes.c_void_p(int(self._pointer)) + self._c_pointer = ctypes.addressof(self._desc) + return [self._c_pointer] + + def __new_from_mlir_values__(self, values): + assert len(values) == 1 + return values[0] + + def __extract_mlir_values__(self): + return [self._c_pointer] + + # Move mlir Type out of __init__ to decouple with mlir Context + @property + def mlir_type(self) -> ir.Type: + return _cute_ir.PtrType.get( + self._dtype.mlir_type, self._addr_space, self._assumed_align + ) + + @property + def dtype(self) -> Type[Numeric]: + return self._dtype + + @property + def memspace(self): + return self._addr_space + + def align(self, min_align: int, *, loc=None, ip=None) -> Pointer: + raise NotImplementedError("align is not supported in runtime") + + def verify(self, expected_py_type): + if expected_py_type is Pointer: + return True + elif isinstance(expected_py_type, ir.Value) and expected_py_type.ty is Pointer: + return True + + return False + + def __str__(self) -> str: + return f"Ptr<0x{int(self._pointer):016x}@{self._addr_space}>" + + def __repr__(self): + return self.__str__() + + +class _Tensor(Tensor): + def __init__( + self, + tensor, + assumed_align=None, + ): + # If tensor is already a DLPack object, use it directly + if hasattr(tensor, "__dlpack_device__") and not hasattr(tensor, "__dlpack__"): + self._dlpack_data = tensor + else: + self._dlpack_data = tensor.__dlpack__() + self._dltensor_wrapper = None + self._assumed_align = assumed_align + self._is_dynamic = False + self._memref_desc = None + self._dtype = None + + @property + def __class__(self) -> Type[Tensor]: + # Cheat to let `type(_Tensor())` to return cute.Tensor + return Tensor + + @staticmethod + def lazily_load_dltensor(func): + """Decorator to lazily load the DLTensorWrapper. + + This decorator loads the DLTensorWrapper when needed, + avoiding overhead in the critical path of calling JIT functions. + """ + + def wrapper(self, *args, **kwargs): + if self._dltensor_wrapper is None: + self._dltensor_wrapper = _cute_ir.DLTensorWrapper(self._dlpack_data) + return func(self, *args, **kwargs) + + return wrapper + + @lazily_load_dltensor + def mark_layout_dynamic(self, leading_dim: int | None = None): + """Marks the tensor layout as dynamic based on the leading dimension. + + :param leading_dim: The leading dimension of the layout, defaults to None + :type leading_dim: int, optional + + When ``leading_dim`` is None, automatically deduces the leading dimension from the tensor layout. + The layout can be deduced only when exactly one dimension has a stride of 1. Raises an error + if the layout cannot be automatically deduced. + + When ``leading_dim`` is explicitly specified, marks the layout as dynamic while setting the + stride at ``leading_dim`` to 1. Also validates that the specified ``leading_dim`` is consistent + with the existing layout by checking that the corresponding stride of that dimension is 1. + + Limitation: only support flat layout for now. Will work on supporting nested layout in the future. + + :return: The tensor with dynamic layout + :rtype: _Tensor + """ + self._dltensor_wrapper.mark_layout_dynamic(leading_dim) + return self + + @lazily_load_dltensor + def mark_compact_shape_dynamic( + self, + mode: int, + stride_order: tuple[int, ...] | None = None, + divisibility: int = 1, + ): + """Marks the tensor shape as dynamic and propagates dynamic and divisibility information to the corresponding strides. + + :param mode: The mode of the compact shape, defaults to 0 + :type mode: int + :param stride_order: Consistent with `torch.Tensor.dim_order`. Defaults to None. + Indicates the order of the modes (dimensions) if the current layout were converted to row-major order. + It starts from the outermost to the innermost dimension. + :type stride_order: tuple[int, ...], optional + :param divisibility: The divisibility constraint for the compact shape, defaults to 1 + :type divisibility: int, optional + :return: The tensor with dynamic compact shape + :rtype: _Tensor + + If ``stride_order`` is not provided, the stride ordering will be automatically deduced from the layout. + Automatic deduction is only possible when exactly one dimension has a stride of 1 (compact layout). + An error is raised if automatic deduction fails. + + If ``stride_order`` is explicitly specified, it does the consistency check with the layout. + + For example: + - Layout: (4,2):(1,4) has stride_order: (1,0) indicates the innermost dimension is 0(`4:1`), the outermost dimension is 1(`2:4`) + - Layout: (5,3,2,4):(3,1,15,30) has stride_order: (3,2,0,1) indicates the innermost dimension is 1(`3:1`), the outermost dimension is 3(`4:30`). + + Using `torch.Tensor.dim_order()` to get the stride order of the torch tensor. + .. code-block:: python + a = torch.empty(3, 4) + t = cute.runtime.from_dlpack(a) + t = t.mark_compact_shape_dynamic(mode=0, stride_order=a.dim_order()) + """ + self._dltensor_wrapper.mark_compact_shape_dynamic( + mode, stride_order, divisibility + ) + return self + + @property + @lazily_load_dltensor + def element_type(self) -> Type[Numeric]: + if self._dtype is None: + self._dtype = self._dltensor_wrapper.dtype + return self._dtype + + @element_type.setter + def element_type(self, new_type): + """Set the element type of the tensor. + + :warning: This API is added for narrow precision before we have a clean `recast_tensor` story. + + :note: It is only used for the case that frameworks don't natively support narrow precision but we get tensor + from frameworks with storage type like uint8. + + **Example**: + + .. code-block:: python + + # Create a tensor from a numpy array + import numpy as np + from cutlass.cute import from_dlpack + + # Create a tensor with Float32 elements + a = np.zeros(shape, dtype=np.uint8) + tensor = from_dlpack(a) + + # Change the element type to Float4E2M1FN even storage type is uint8 + tensor.element_type = cutlass.Float4E2M1FN + + src = from_dlpack(... data tensor ...) + # convert and initialize narrow precision tensor + cute.testing.convert(src, tensor) + """ + self._dtype = new_type + + @property + @lazily_load_dltensor + def memspace(self): + return self._dltensor_wrapper.address_space + + @property + @lazily_load_dltensor + def size_in_bytes(self) -> int: + return self._dltensor_wrapper.size_in_bytes() + + @property + @lazily_load_dltensor + def mlir_type(self) -> ir.Type: + return self._dltensor_wrapper.get_type( + self.element_type.mlir_type, self._assumed_align + ) + + @lazily_load_dltensor + def __str__(self) -> str: + return f"Tensor<0x{self._dltensor_wrapper.str}>" + + def __repr__(self): + return self.__str__() + + def __setitem__(self, crd, value): + raise TypeError(f"runtime._Tensor is not indexable") + + def __getitem__(self, crd): + raise TypeError(f"runtime._Tensor is not indexable") + + @property + @lazily_load_dltensor + def iterator(self): + return _Pointer( + self._dltensor_wrapper.data_ptr, + self.element_type, + self.memspace, + self._assumed_align, + ) + + @property + def layout(self): + raise NotImplementedError( + f"layout property is not supported in runtime, support in future" + ) + + @property + @lazily_load_dltensor + def shape(self): + return self._dltensor_wrapper.shape + + @property + @lazily_load_dltensor + def stride(self): + strides = self._dltensor_wrapper.stride + if strides is None: + strides = itertools.accumulate( + reversed(self.shape), func=operator.mul, initial=1 + ) + strides = tuple(reversed(list(strides)[:-1])) + + return strides + + @property + @lru_cache(maxsize=128, typed=True) + def leading_dim(self): + """Get the leading dimension of this Tensor. + + :return: The leading dimension index or indices + :rtype: int or tuple or None + + The return value depends on the tensor's stride pattern: + + * If a single leading dimension is found, returns an integer index + * If nested leading dimensions are found, returns a tuple of indices + * If no leading dimension is found, returns None + """ + return core.leading_dim(self.shape, self.stride) + + def fill(self, value: Numeric): + raise TypeError(f"fill function is not supported in runtime") + + @property + @lazily_load_dltensor + def data_ptr(self): + return self._dltensor_wrapper.data_ptr + + @lazily_load_dltensor + def __c_pointers__(self): + self._memref_desc = self._dltensor_wrapper.build_memref_desc( + self._assumed_align + ) + return [_cute_ir.pycapsule_get_pointer(self._memref_desc)] + + def __get_mlir_types__(self): + return [self.mlir_type] + + def __new_from_mlir_values__(self, values): + assert len(values) == 1 + assert isinstance(values[0], CoreTensor) + return CoreTensor(values[0].value, self._dtype) + + +def from_dlpack( + tensor_dlpack, + assumed_align=None, +) -> Tensor: + """Convert from tensor object supporting __dlpack__() to a CuTe Tensor. + + :param tensor_dlpack: Tensor object that supports the DLPack protocol + :type tensor_dlpack: object + :param assumed_align: Assumed alignment of the tensor (bytes), defaults to None, + if None, will use the element size bytes as the assumed alignment. + :type assumed_align: int, optional + :return: A CuTe Tensor object + :rtype: Tensor + + Examples: + .. code-block:: python + + import torch + from cutlass.cute.runtime import from_dlpack + x = torch.randn(100, 100) + y = from_dlpack(x) + y.shape + # (100, 100) + type(y) + # + """ + return _Tensor( + tensor_dlpack, + assumed_align=assumed_align, + ) + + +def make_ptr( + dtype: Type[Numeric], + value: Union[int, ctypes._Pointer], + mem_space: AddressSpace = AddressSpace.generic, + assumed_align=None, +) -> Pointer: + """Create a pointer from a memory address + + :param dtype: Data type of the pointer elements + :type dtype: Type[Numeric] + :param value: Memory address as integer or ctypes pointer + :type value: Union[int, ctypes._Pointer] + :param mem_space: Memory address space, defaults to AddressSpace.generic + :type mem_space: AddressSpace, optional + :param align_bytes: Alignment in bytes, defaults to None + :type align_bytes: int, optional + :return: A pointer object + :rtype: Pointer + + .. code-block:: python + + import numpy as np + import ctypes + + from cutlass import Float32 + from cutlass.cute.runtime import make_ptr + + # Create a numpy array + a = np.random.randn(16, 32).astype(np.float32) + + # Get pointer address as integer + ptr_address = a.ctypes.data_as(ctypes.POINTER(ctypes.c_float)) + + # Create pointer from address + y = make_ptr(cutlass.Float32, ptr_address) + + # Check properties + print(y.element_type) + print(type(y)) # + """ + # check if value is int or ctypes.POINTER + if isinstance(value, int): + address_value = value + elif isinstance(value, ctypes._Pointer): + # get address value + address_value = ctypes.cast(value, ctypes.c_void_p).value + assert address_value is not None, "Pointer address is None" + else: + raise TypeError( + f"Expect int or ctypes.POINTER for value but got {type(value)=}" + ) + + return _Pointer(address_value, dtype, mem_space, assumed_align=assumed_align) + + +class TensorAdapter: + """ + Convert a DLPack protocol supported tensor/array to a cute tensor. + """ + + def __init__(self, arg): + self._arg = from_dlpack(arg).mark_layout_dynamic() + + def __new_from_mlir_values__(self, values): + return self._arg.__new_from_mlir_values__(values) + + def __c_pointers__(self): + return self._arg.__c_pointers__() + + def __get_mlir_types__(self): + return self._arg.__get_mlir_types__() + + +# ------------------------------------------------------------------------- +# Try to register_jit_arg_adapter for TensorAdapter +# ------------------------------------------------------------------------- + +try: # Register for numpy.ndarray + import numpy + + JitArgAdapterRegistry.register_jit_arg_adapter(numpy.ndarray)(TensorAdapter) +except ImportError: + pass # silent attempt, suppress error + +try: # Register for torch.Tensor + import torch + + JitArgAdapterRegistry.register_jit_arg_adapter(torch.Tensor)(TensorAdapter) +except ImportError: + pass # silent attempt, suppress error diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/testing.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/testing.py new file mode 100644 index 0000000000000000000000000000000000000000..88e0da048fc951da5091bcc38a6e6c92164f6d04 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/testing.py @@ -0,0 +1,610 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +import functools +import inspect +import logging +import os +from enum import Enum +from inspect import isclass +from itertools import product +from time import time +from typing import Any, Callable, Dict, List, Optional, Type, Union + +import cuda.bindings.driver as cuda_driver +import cuda.bindings.runtime as cuda_runtime +import numpy as np + +import cutlass._mlir.ir as ir +import cutlass.base_dsl.jit_executor +import cutlass.cute as cute +from cutlass._mlir.dialects import builtin, cf, nvvm, vector +from cutlass.cute import core, nvgpu +from cutlass.cutlass_dsl import Constexpr, CuTeDSL, T, t, dsl_user_op + + +@dsl_user_op +def assert_(cond, msg=None, *, loc=None, ip=None): + cf.assert_(t.Boolean(cond).ir_value(), msg if msg else "", loc=loc, ip=ip) + + +def _maybe_recast_tensor_from_f4(src: core.Tensor, tv_layout: core.Layout): + if src.element_type.width == 4: + tv_layout = core.recast_layout(8, 4, tv_layout) + src = core.recast_tensor(src, dtype=t.Int8) + return src, tv_layout + + +def _maybe_recast_to_f4(input: core.TensorSSA, dtype: Type[core.Numeric]): + """Conditionally recasts the tensor to 4-bit type if the destination type is 4-bit. + + :param input: The input tensor to recast. + :param dtype: The target numeric type to potentially recast to. + :raises TypeError: If dtype is not a subclass of Numeric. + :return: A new tensor recast to 4-bit if dtype is 4-bit, otherwise returns self unchanged. + """ + if not isclass(dtype) or not issubclass(dtype, core.Numeric): + raise TypeError(f"dst_ty must be a type of Numeric, but got {dtype}") + + if dtype.width == 4: + recast_shape = core.recast_layout(4, 8, core.make_layout(input.shape)).shape + i4_vec = vector.bitcast( + T.vector(input.type.shape[0] * 2, T.i(4)), input.maybe_downcast() + ) + res_vect = builtin.unrealized_conversion_cast( + [T.vector(i4_vec.type.shape[0], dtype.mlir_type)], [i4_vec] + ) + return core.TensorSSA(res_vect, recast_shape, dtype) + return input + + +def _maybe_recast_from_f4(input: core.TensorSSA, src_dtype: Type[core.Numeric]): + """Conditionally recasts the tensor from 4-bit type if the source type is 4-bit. + + :param input: The input tensor to recast. + :param src_dtype: The source numeric type to potentially recast from. + :raises TypeError: If src_dtype is not a subclass of Numeric. + :return: A new tensor recast from 4-bit if src_dtype is 4-bit, otherwise returns self unchanged. + """ + if not isclass(src_dtype) or not issubclass(src_dtype, core.Numeric): + raise TypeError(f"src_ty must be a type of Numeric, but got {src_dtype}") + + if src_dtype.width == 4: + recast_shape = core.recast_layout(8, 4, core.make_layout(input.shape)).shape + i4_vec = builtin.unrealized_conversion_cast( + [T.vector(input.type.shape[0], T.i(4))], [input.maybe_downcast()] + ) + res_vect = vector.bitcast(T.vector(i4_vec.type.shape[0] // 2, T.i8()), i4_vec) + return core.TensorSSA(res_vect, recast_shape, core.Int8) + return input + + +@CuTeDSL.kernel +def _convert_kernel( + gSrc: core.Tensor, + gDst: core.Tensor, + cSrc: core.Tensor, + src_tv_layout: core.Layout, + dst_tv_layout: core.Layout, + src_shape: core.Shape, + src_ty, + dst_ty, +): + tidx = nvvm.read_ptx_sreg_tid_x(T.i32()) + bidx = nvvm.read_ptx_sreg_ctaid_x(T.i32()) + + cta_coord = (None, bidx) + # logical idx -> address + ctaSrc = gSrc[cta_coord] # (...,TileV,...) + ctaDst = gDst[cta_coord] # (...,TileV,...) + ctaCSrc = cSrc[cta_coord] # (...,TileV,...) + # print(f"ctaSrc = {ctaSrc.type}") + + # compose with CTA TV layout + # tid, vid -> address + tidfrgSrc = core.composition(ctaSrc, src_tv_layout) # (T,V) + tidfrgDst = core.composition(ctaDst, dst_tv_layout) # (T,V) + tidfrgCSrc = core.composition(ctaCSrc, src_tv_layout) # (T,V) + # print(f"tidfrgSrc = {tidfrgSrc.type}") + + # slice for threads + thr_coord = (tidx, None) + thrSrc = tidfrgSrc[thr_coord] # (V) + thrDst = tidfrgDst[thr_coord] # (V) + thrCSrc = tidfrgCSrc[thr_coord] # (V) + # print(f"thrSrc = {thrSrc.type}") + + # predicate + if core.elem_less(thrCSrc[0], src_shape): + # allocate fragments for gmem->rmem + frgSrc = core.make_fragment( + core.get(src_tv_layout, mode=[1]), gSrc.element_type + ) # (V) + frgDst = core.make_fragment( + core.get(dst_tv_layout, mode=[1]), gDst.element_type + ) # (V) + # print(f"frgSrc = {frgSrc.type}") + + # Move data to reg address space + copy_atom_load = core.make_copy_atom(nvgpu.CopyUniversalOp(), gSrc.element_type) + core.copy(copy_atom_load, thrSrc, frgSrc) + + vec_src = frgSrc.load() + vec_src = _maybe_recast_to_f4(vec_src, src_ty) + vec_dst = vec_src.to(dst_ty) + vec_dst = _maybe_recast_from_f4(vec_dst, dst_ty) + frgDst.store(vec_dst) + + # Copy the results back to c + copy_atom_stg = core.make_copy_atom(nvgpu.CopyUniversalOp(), gDst.element_type) + core.copy(copy_atom_stg, frgDst, thrDst) + + +@CuTeDSL.jit(preprocess=False) +def _convert( + src: core.Tensor, + dst: core.Tensor, + leading_mode: Constexpr, + elem_per_copy: Constexpr, +): + + # Step 1. figure proper tv_layout + src_ty = src.element_type + dst_ty = dst.element_type + + tv_layout = core.make_layout((128, elem_per_copy), stride=(elem_per_copy, 1)) + + # Step 2. maybe recast from f4 tensor + src, src_tv_layout = _maybe_recast_tensor_from_f4(src, tv_layout) + dst, dst_tv_layout = _maybe_recast_tensor_from_f4(dst, tv_layout) + src_shape = src.shape + # predicate tensor + idA = core.make_identity_tensor(src.shape) + + # Step 3. select a proper tiling pattern as (...,TileV, ...) + src_cta_tiler = [ + 1, + ] * core.rank(src.layout) + src_cta_tiler[leading_mode] = core.size(src_tv_layout) # (...,TileV,...) + dst_cta_tiler = [ + 1, + ] * core.rank(dst.layout) + dst_cta_tiler[leading_mode] = core.size(dst_tv_layout) # (...,TileV,...) + + # Step 4. partition input and output tensor by cta tiler. + gS = core.zipped_divide( + src, tuple(src_cta_tiler) + ) # ((...,TileV,...),(...,RestV,...)) + cS = core.zipped_divide( + idA, tuple(src_cta_tiler) + ) # ((...,TileV,...),(...,RestV,...)) + gD = core.zipped_divide( + dst, tuple(dst_cta_tiler) + ) # ((...,TileV,...),(...,RestV,...)) + # print(f"{gS.type=}") + + _convert_kernel( + gS, + gD, + cS, + src_tv_layout, + dst_tv_layout, + src_shape, + src_ty, + dst_ty, + ).launch( + grid=[core.size(gS, mode=[1]), 1, 1], + block=[core.size(src_tv_layout, mode=[0]), 1, 1], + ) + + +# Converts from src tensor to dst tensor, their logical shape are required to be the same. +# And when src or dst dtype is narrow precision(Float4E2M1FN/Float8E8M0FNU/Float8E4M3FN), the shape of +# their leading dimension should be 4(fp8)/8(fp4) element align. (nvgpu.cvt_fptrunc/cvt_fpext +# needs 32-bits aligned input/output) +def convert(src: core.Tensor, dst: core.Tensor): + assert len(src.shape) == len( + dst.shape + ), "Shape of src and dst tensors should be the same rank." + # find leading mode + leading_mode = [ + idx + for idx, (shape, stride) in enumerate(zip(src.shape, src.stride)) + if shape > 1 and stride == 1 + ] + if len(leading_mode) != 1: + raise ValueError(f"Leading mode should be unique, but got {leading_mode}") + leading_mode = leading_mode[0] + + elem_per_copy = 2 + + if src.element_type.width == 4 or dst.element_type.width == 4: + elem_per_copy = 8 + elif src.element_type.width == 8 or dst.element_type.width == 8: + elem_per_copy = 4 + assert ( + src.shape[leading_mode] % elem_per_copy == 0 + and dst.shape[leading_mode] % elem_per_copy == 0 + ) + _convert(src, dst, leading_mode, elem_per_copy) + + +######################################### +# Testing utilities +######################################### + + +def sample_pytest(rand_cfg=None): + """ + Decorator to randomly sample pytest parametrized tests. + rand_cfg: Tuple[int, float] - (random_seed, sample_ratio) + Sampling is disabled when: + - A specific test is selected (via -k or direct test path) + - Not running under pytest + """ + import functools + import os + import random + import sys + + import pytest + + seed, sample_ratio = rand_cfg + random.seed(seed) + + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if rand_cfg is not None and "PYTEST_CURRENT_TEST" in os.environ: + # Check if test was explicitly selected like ::test_name[param1-param2-...] + if "-k" in sys.argv or any(".py::" in arg for arg in sys.argv): + # Test was explicitly selected, don't skip + return func(*args, **kwargs) + + if random.uniform(0.0, 1.0) > sample_ratio: + pytest.skip(f"Randomly skipped (sampling ratio: {sample_ratio})") + return func(*args, **kwargs) + + return wrapper + + return decorator + + +######################################### +# Benchmarking utilities +######################################### + + +class JitArguments: + """ + A type to hold both args and kwargs for passing to a kernel while benchmarking. + """ + + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + + +def _cuda_success( + err: Union[tuple, cuda_runtime.cudaError_t, cuda_driver.CUresult], message: str +): + """ + Helper function to check CUDA API errors. + """ + if isinstance(err, tuple): + _cuda_success(err[0], message) + elif isinstance(err, cuda_runtime.cudaError_t): + error_message = cuda_runtime.cudaGetErrorString(err)[1].decode("utf-8") + if err != cuda_runtime.cudaError_t.cudaSuccess: + raise RuntimeError(f"{message} : {error_message}") + elif isinstance(err, cuda_driver.CUresult): + if err != cuda_driver.CUresult.CUDA_SUCCESS: + error_message = cuda_driver.cuGetErrorString(err)[1].decode("utf-8") + raise RuntimeError(f"{message} : {error_message}") + else: + raise TypeError( + f"{err} is an unexpected type : it should be a cudaError_t or CUresult" + ) + + +def _does_kernel_use_stream( + kernel: Callable, stream: cuda_driver.CUstream, *args, **kwargs +): + """ + This function checks if the kernel uses the provided non-default stream. + It does this by capturing the stream and then checking if any kernels were launched. + :param kernel: The kernel to check + :type kernel: Callable + :param stream: The stream to check + :type stream: cuda_driver.CUstream + :return: True if the kernel uses the stream, False otherwise + :rtype: bool + """ + + assert int(stream) != int( + cuda_driver.CUstream_flags.CU_STREAM_DEFAULT + ), "Stream must be a non-default stream" + + err = cuda_runtime.cudaStreamBeginCapture( + stream, cuda_runtime.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal + ) + _cuda_success(err, "Error on stream capture") + + kernel(*args, **kwargs) + + err, graph = cuda_runtime.cudaStreamEndCapture(stream) + _cuda_success(err, "Error on stream capture") + + # Get number of nodes in warmup graph to check it matches what is expected + err, _, num_nodes = cuda_runtime.cudaGraphGetNodes(graph) + _cuda_success(err, "Error on querying graph") + return num_nodes > 0 + + +def benchmark( + callable: Callable, + *, + warmup_iterations: int = 10, + iterations: int = 100, + stream: Optional[cuda_driver.CUstream] = None, + kernel_arguments: Optional[JitArguments] = None, + workspace_generator: Optional[Callable[[], JitArguments]] = None, + workspace_count: int = 1, + use_cuda_graphs: bool = False, +) -> float: + """Benchmarks a callable function with the specified parameters. + + For example, + .. code-block:: python + + from cutlass.cute.testing import benchmark + + @cute.jit + def user_function(a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda_driver.CUstream): + # contents of the function + pass + + time_us = benchmark(user_function, kernel_arguments=JitArguments(a, b, c, stream) + warmup_iterations=10, iterations=100 + stream=stream) + + To prevent skewing results by repeately accessing the L2 cache, use the workspace_count and workspace_generator + parameters to cycle through a number of different workspaces. + + .. code-block:: python + + from cutlass.cute.testing import benchmark + + @cute.jit + def user_function(a: cute.Tensor, b: cute.Tensor, c: cute.Tensor): + # contents of the function + pass + + def workspace_generator(): + # create a, b, and c + return JitArguments(a, b, c) + + time_us = benchmark(user_function, + workspace_generator=workspace_generator, + workspace_count=10, + warmup_iterations=10000, + iterations=1000) + + To benchmark you may always configure the function being profiled (callable), the warmup iterations, and + the number of profiling iterations. + + Whenever the kernel being benchmarked runs in a non-default stream, the stream must be provided through the stream parameter. + + To use CUDA graphs, the callable must be a compiled @cute.jit annotated function. + When using CUDA graphs, the kernel must be launched in a non-default stream. + + :param callable: The function to benchmark + :type callable: Callable + :param warmup_iterations: Number of warmup iterations, defaults to 10 + :type warmup_iterations: int, optional + :param iterations: Number of benchmark iterations, defaults to 100 + :type iterations: int, optional + :param stream: Stream kernel is launched in, defaults to CUDA stream default + :type stream: CUstream, None + :param kernel_arguments: Kernel arguments to launch callable with, defaults to None + :type kernel_arguments: JitArguments, None + :param workspace_generator: Function that returns kernel arguments, defaults to None + :type workspace_generator: Callable + :param workspace_count: Number of workspaces (arguments) to loop through, looping through enough workspaces will keep the L2 cache cold + :type workspace_count: int, optional + :param use_cuda_graphs: Whether to use cuda graphs, defaults to False + :type use_cuda_graphs: bool, optional + + :return: The benchmark time in microseconds + :rtype: float + """ + + if stream is None: + stream = cuda_driver.CUstream(cuda_driver.CUstream_flags.CU_STREAM_DEFAULT) + + if workspace_count < 1: + raise ValueError("workspace_count must be at least 1") + + time_us = float("nan") + if workspace_generator == None: + # If no workspace generator is provided, we need a single workspace + if workspace_count != 1: + raise ValueError("Need a single workspace if not providing a generator") + + # If no workspace generator is provided, we need a kernel_argument + if kernel_arguments == None: + raise ValueError( + "Please pass a kernel argument if not providing a generator" + ) + workspace_generator = lambda: kernel_arguments + + workspaces = [workspace_generator() for _ in range(workspace_count)] + + for workspace in workspaces: + if type(workspace) != JitArguments: + raise TypeError( + "workspace_generator and/or kernel_arguments should use JitArguments type" + ) + + def _loop_and_call_kernel(iterations: int, workspace_index: int = 0): + for _ in range(iterations): + current_workspace = workspaces[workspace_index] + callable(*current_workspace.args, **current_workspace.kwargs) + workspace_index = (workspace_index + 1) % workspace_count + return workspace_index + + # Create CUDA events for timing + err, start_event = cuda_driver.cuEventCreate( + cuda_driver.CUevent_flags.CU_EVENT_DEFAULT + ) + _cuda_success(err, "Error on creating event") + err, end_event = cuda_driver.cuEventCreate( + cuda_driver.CUevent_flags.CU_EVENT_DEFAULT + ) + _cuda_success(err, "Error on creating event") + + elapsed_time = float("nan") + + if use_cuda_graphs: + # Check if the callable is a JitExecutor + if not isinstance(callable, cutlass.base_dsl.jit_executor.JitExecutor): + raise TypeError("Function must be precompiled to be used with CUDA Graphs") + + # Check if the stream is a non-default stream + if int(stream) == int(cuda_driver.CUstream_flags.CU_STREAM_DEFAULT): + raise ValueError( + "Measuring with CUDA Graphs requires executing in a non-default stream" + ) + + workspace_index = 0 + + # Capture warmup graph + err = cuda_runtime.cudaStreamBeginCapture( + stream, cuda_runtime.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal + ) + _cuda_success(err, "Error on stream capture") + + workspace_index = _loop_and_call_kernel(warmup_iterations) + err, gwarm = cuda_runtime.cudaStreamEndCapture(stream) + _cuda_success(err, "Error on stream capture") + + # Get number of nodes in warmup graph to check it matches what is expected + err, _, num_nodes = cuda_runtime.cudaGraphGetNodes(gwarm) + _cuda_success(err, "Error on querying graph") + # Assertion is >= since we may launch multiple kernels in one host function + if num_nodes < warmup_iterations: + raise ValueError( + f"CUDA stream passed to benchmark does not match the stream the kernel was launched in" + ) + + # Capture profiling graph + err = cuda_runtime.cudaStreamBeginCapture( + stream, cuda_runtime.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal + ) + _cuda_success(err, "Error on stream capture") + _loop_and_call_kernel(iterations, workspace_index) + err, gprofile = cuda_runtime.cudaStreamEndCapture(stream) + _cuda_success(err, "Error on stream capture") + + # Instantiate graphs + err, gwarm = cuda_runtime.cudaGraphInstantiate(gwarm, 0) + _cuda_success(err, "Error on graph instantiation") + err, gprofile = cuda_runtime.cudaGraphInstantiate(gprofile, 0) + _cuda_success(err, "Error on graph instantiation") + + # Launch warmup graph + err = cuda_runtime.cudaGraphLaunch(gwarm, stream) + _cuda_success(err, "Error on graph launch") + + # Record start time + err = cuda_driver.cuEventRecord(start_event, stream) + _cuda_success(err, "Error on recording event") + + # Launch profiling graph + err = cuda_runtime.cudaGraphLaunch(gprofile, stream) + _cuda_success(err, "Error on graph launch") + + # Record end time + err = cuda_driver.cuEventRecord(end_event, stream) + _cuda_success(err, "Error on recording event") + err = cuda_driver.cuEventSynchronize(end_event) + _cuda_success(err, "Error on synchronizing event") + + # Get elapsed time + err, elapsed_time = cuda_driver.cuEventElapsedTime(start_event, end_event) + _cuda_success(err, "Error on querying event") + + # Destroy graphs + err = cuda_runtime.cudaGraphExecDestroy(gwarm) + _cuda_success(err, "Error on destroying graph") + err = cuda_runtime.cudaGraphExecDestroy(gprofile) + _cuda_success(err, "Error on destroying graph") + + else: + + if int(stream) != int( + cuda_driver.CUstream_flags.CU_STREAM_DEFAULT + ) and not _does_kernel_use_stream( + callable, stream, *workspaces[0].args, **workspaces[0].kwargs + ): + raise ValueError( + "CUDA stream passed to benchmark does not match the stream the kernel was launched in" + ) + + # Not using graphs + # Warmup + workspace_index = _loop_and_call_kernel(warmup_iterations) + # Record start event + err = cuda_driver.cuEventRecord(start_event, stream) + _cuda_success(err, "Error on recording event") + _loop_and_call_kernel(iterations, workspace_index) + # Record end event + err = cuda_driver.cuEventRecord(end_event, stream) + _cuda_success(err, "Error on recording event") + # Synchronize end event + err = cuda_driver.cuEventSynchronize(end_event) + _cuda_success(err, "Error on synchronizing event") + err, elapsed_time = cuda_driver.cuEventElapsedTime(start_event, end_event) + _cuda_success(err, "Error on querying event") + + # Destroy events + err = cuda_driver.cuEventDestroy(start_event) + _cuda_success(err, "Error on destroying event") + err = cuda_driver.cuEventDestroy(end_event) + _cuda_success(err, "Error on destroying event") + + return elapsed_time / iterations * 1e3 + + +def get_workspace_count( + one_workspace_bytes: int, warmup_iterations: int, iterations: int +) -> int: + """Calculate the number of workspaces needed to fill L2 cache. + + :param one_workspace_bytes: Size of one workspace in bytes + :type one_workspace_bytes: int + :param warmup_iterations: Number of warmup iterations + :type warmup_iterations: int + :param iterations: Number of iterations + :type iterations: int + :return: Number of workspaces needed + :rtype: int + """ + num_l2_cache_bytes = cutlass.utils.HardwareInfo().get_l2_cache_size_in_bytes() + return max( + 1, + min( + warmup_iterations + iterations, # Don't create more workspaces than needed + (num_l2_cache_bytes + one_workspace_bytes - 1) + // one_workspace_bytes, # Ceiling division + ), + ) + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/typing.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/typing.py new file mode 100644 index 0000000000000000000000000000000000000000..215e71d98fc39c192c784c99bb8ef14f6e2f55d9 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/typing.py @@ -0,0 +1,207 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from abc import ABC, abstractmethod +from typing import ForwardRef, Tuple, Union, Any, Type, List + +from cutlass.base_dsl.typing import * + +from cutlass._mlir import ir +import cutlass._mlir.extras.types as T +from cutlass._mlir.dialects.cute import AddressSpace + + +Int = Union[int, Integer] + + +ScaledBasis = ForwardRef("ScaledBasis") + + +IntTuple = Union[Int, Tuple["IntTuple", ...]] +Shape = Union[Int, Tuple["Shape", ...]] +Stride = Union[Int, ScaledBasis, Tuple["Stride", ...]] +Coord = Union[Int, None, Tuple["Coord", ...]] + + +class Layout(ir.Value): + def __init__(self, op_result): + super().__init__(op_result) + + def __str__(self): ... + + def get_hier_coord(self, idx) -> Coord: + """Return the (hierarchical) ND logical coordinate corresponding to the linear index""" + ... + + @property + def shape(self, *, loc=None, ip=None) -> Shape: ... + + @property + def stride(self, *, loc=None, ip=None) -> Stride: ... + + +Tile = Union[Int, None, Layout, Tuple["Tile", ...]] + +# XTuple is super set of above types +XTuple = Union[IntTuple, Shape, Stride, Coord, Tile] + +Tiler = Union[Shape, Layout, Tile] + + +class Pointer(ABC): + """ + Abstract base class for CuTe jit function and runtime _Pointer + """ + + @property + def value_type(self) -> Type[Numeric]: + return self.dtype + + @property + def dtype(self) -> Type[Numeric]: ... + + def align(self, min_align: int) -> "Pointer": ... + + def __get_mlir_types__(self) -> List[ir.Type]: ... + + def __extract_mlir_values__(self) -> List[ir.Value]: ... + + def __new_from_mlir_values__(self, values) -> "Pointer": ... + + +class Tensor(ABC): + """ + Abstract base class for CuTe jit function and runtime _Tensor + + A CuTe Tensor is iterator with layout + + :Examples: + + Create tensor from torch.tensor with Host Runtime: + + .. code-block:: python + + >>> import torch + >>> from cutlass.cute.runtime import from_dlpack + >>> mA = from_dlpack(torch.tensor([1, 3, 5], dtype=torch.int32)) + >>> mA.shape + (3,) + >>> mA.stride + (1,) + >>> mA.layout + (3,):(1,) + + Define JIT function: + + .. code-block:: python + + @cute.jit + def add(a: Tensor, b: Tensor, res: Tensor): ... + + Call JIT function from python: + + .. code-block:: python + + >>> import torch + >>> a = torch.tensor([1, 3, 5], dtype=torch.int32) + >>> b = torch.tensor([2, 4, 6], dtype=torch.int32) + >>> c = torch.zeros([3], dtype=torch.int32) + >>> mA = from_dlpack(a) + >>> mB = from_dlpack(b) + >>> mC = from_dlpack(c) + >>> add(mA, mB, mC) + >>> c + tensor([3, 7, 11], dtype=torch.int32) + """ + + def __str__(self): ... + + @abstractmethod + def __getitem__(self, idx) -> Union["Tensor", ir.Value, IntTuple]: ... + + @abstractmethod + def __setitem__(self, idx, value): ... + + @property + @abstractmethod + def element_type(self) -> Union[Type[Numeric], Type[IntTuple]]: ... + + @element_type.setter + def element_type(self, new_type): ... + + @property + @abstractmethod + def memspace(self) -> AddressSpace: ... + + @property + @abstractmethod + def iterator(self): ... + + @property + def layout(self) -> Union[Layout, "ComposedLayout"]: ... + + @property + def shape(self) -> Shape: ... + + def load(self, *, loc=None, ip=None) -> "TensorSSA": ... + + def store(self, data: "TensorSSA", *, loc=None, ip=None): ... + + def mark_layout_dynamic(self, leading_dim: int | None = None) -> "Tensor": ... + + def mark_compact_shape_dynamic( + self, + mode: int, + stride_order: tuple[int, ...] | None = None, + divisibility: int = 1, + ) -> "Tensor": ... + + @abstractmethod + def fill(self, value: Numeric) -> None: ... + + +__all__ = [ + "Coord", + "Numeric", + "Integer", + "Boolean", + "Int8", + "Int16", + "Int32", + "Int64", + "Uint8", + "Uint16", + "Uint32", + "Uint64", + "Float", + "Float16", + "BFloat16", + "TFloat32", + "Float32", + "Float64", + "Float8E5M2", + "Float8E4M3FN", + "Float8E4M3B11FNUZ", + "Float8E4M3", + "Float8E8M0FNU", + "Float4E2M1FN", + "Float6E2M3FN", + "Float6E3M2FN", + "IntTuple", + "Layout", + "Pointer", + "Shape", + "Stride", + "Tensor", + "Tile", + "Tiler", + "XTuple", +] diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/impl_utils.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/impl_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0bb9b5207144a11665449fac431fcbe2bd8f49bd --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/impl_utils.py @@ -0,0 +1,32 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + + +def check_value_in( + value, possible_values: list, value_description: str, prefix="" +) -> None: + if value not in possible_values: + err_msg = prefix + if err_msg != "": + err_msg += ": " + err_msg += f"invalid {value_description}, got {value}, must be one of {possible_values}" + raise ValueError(err_msg) + + +def check_type_in(ty, possible_types: list, type_description: str, prefix="") -> None: + if not isinstance(ty, type): + ty = type(ty) + if ty not in possible_types: + err_msg = prefix + if err_msg != "": + err_msg += ": " + err_msg += f"invalid type for {type_description}, got {ty}, must be one of {possible_types}" + raise TypeError(err_msg) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/pipeline/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/pipeline/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7df24dd6bb6a5e42ebf5bad0e785cf77589bbbc6 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/pipeline/__init__.py @@ -0,0 +1,68 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from .helpers import ( + Agent, + CooperativeGroup, + PipelineOp, + SyncObject, + MbarrierArray, + NamedBarrier, + TmaStoreFence, + PipelineUserType, + PipelineState, + make_pipeline_state, + pipeline_init_wait, + arrive, + arrive_unaligned, + wait, + wait_unaligned, + arrive_and_wait, + sync, +) + +from .sm90 import ( + PipelineAsync, + PipelineCpAsync, + PipelineTmaAsync, + PipelineTmaMultiConsumersAsync, + PipelineTmaStore, + PipelineProducer, + PipelineConsumer, +) + +from .sm100 import ( + PipelineTmaUmma, + PipelineAsyncUmma, + PipelineUmmaAsync, +) + +__all__ = [ + "Agent", + "CooperativeGroup", + "PipelineOp", + "SyncObject", + "MbarrierArray", + "NamedBarrier", + "TmaStoreFence", + "PipelineUserType", + "PipelineState", + "PipelineAsync", + "PipelineCpAsync", + "PipelineTmaAsync", + "PipelineTmaUmma", + "PipelineTmaMultiConsumersAsync", + "PipelineAsyncUmma", + "PipelineUmmaAsync", + "PipelineTmaStore", + "PipelineProducer", + "PipelineConsumer", +] diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/pipeline/helpers.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/pipeline/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..b5b94899435224ceda4bd152944e9a4b9bc2e911 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/pipeline/helpers.py @@ -0,0 +1,652 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +import enum +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional, Union +import warnings + +import cutlass.cute as cute +from cutlass.cutlass_dsl import Boolean, Int32, Int64, if_generate +from cutlass._mlir.dialects import llvm +import cutlass._mlir.dialects.cute as _cute_ir + + +############################################################################## +# Agent class +############################################################################## + + +class Agent(enum.Enum): + """ + Agent indicates what is participating in the pipeline synchronization. + """ + + # Arbitrary grouping of N threads + Thread = enum.auto() + # Same as AsyncThread, but includes all threads in the block + ThreadBlock = enum.auto() + # Same as AsyncThread, but includes all threads in the cluster + ThreadBlockCluster = enum.auto() + + +class CooperativeGroup: + """ + CooperativeGroup contains size and alignment restrictions for an Agent. + """ + + def __init__(self, agent: Agent, size: int = 1, alignment: int = 1): + if agent is Agent.Thread: + assert size > 0 + if size == 32: + assert ( + size == alignment + ), "Error: Alignment does not match number of threads in a warp." + elif size == 128: + assert ( + size == alignment + ), "Error: Alignment does not match number of threads in a warpgroup." + elif agent is Agent.ThreadBlock: + raise NotImplementedError("Error: Not yet supported.") + elif agent is Agent.ThreadBlockCluster: + raise NotImplementedError("Error: Not yet supported.") + else: + # Should never reach this state + size = 0 + + if size <= 0: + raise ValueError( + "Error: The number of threads in a CooperativeGroup must be more than 0." + ) + + # Size indicates how many threads are participating in this CooperativeGroup + self.size = size + # Agent indicates the type of thread group + self.agent = agent + + +class PipelineOp(enum.Enum): + """ + PipelineOp assigns an operation to an agent corresponding to a specific hardware feature. + """ + + # async-threads + AsyncThread = enum.auto() + # Blackwell (SM100a) MMA instruction + TCGen05Mma = enum.auto() + # Tensor Memory Accelerator load + TmaLoad = enum.auto() + # TMA Store consuming smem produced by AsyncThread + TmaStore = enum.auto() + # Composite of multiple PipelineOps + Composite = enum.auto() + # Async load without TMA + AsyncLoad = enum.auto() + + +def _get_pipeline_op(type_str): + return PipelineOp(type_str) + + +############################################################################## +# SyncObject class +############################################################################## + + +class SyncObject(ABC): + """Abstract base class for hardware synchronization primitives. + + This class defines the interface for different types of hardware synchronization + mechanisms including shared memory barriers, named barriers, and fences. + """ + + @abstractmethod + def arrive(self) -> None: + pass + + @abstractmethod + def wait(self) -> None: + pass + + @abstractmethod + def arrive_and_wait(self) -> None: + pass + + @abstractmethod + def arrive_and_drop(self) -> None: + pass + + @abstractmethod + def get_barrier(self) -> Union[cute.Pointer, int, None]: + pass + + @abstractmethod + def max(self) -> Union[int, None]: + pass + + +class MbarrierArray(SyncObject): + """ + MbarrierArray implements an abstraction for an array of smem barriers. + """ + + def __init__( + self, + barrier_storage: cute.Pointer, + num_stages: int, + agent: tuple[PipelineOp, CooperativeGroup], + tx_count: int = 0, + ) -> None: + self.barrier_storage = barrier_storage + self.tx_count = tx_count + self.num_stages = num_stages + self.op_type, self.cg = agent + self.arrive_count = self.cg.size + + if self.num_stages <= 0: + raise ValueError("Error: Mbarrier stage count must be greater than 0.") + if self.arrive_count <= 0: + raise ValueError("Error: Mbarrier arrive count must be greater than 0.") + if self.op_type is PipelineOp.TmaLoad and self.tx_count < 0: + raise ValueError( + "Error: Mbarrier tx count must not be less than 0 for TMA ops." + ) + + # Store mbarrier base pointer + self.mbarrier_base = self.barrier_storage + + # Mbarrier initialization in constructor + self.mbarrier_init() + + def recast_to_new_op_type(self, new_op_type: PipelineOp) -> "MbarrierArray": + """ + Creates a copy of MbarrierArray with a different op_type without re-initializing barriers + """ + # Create new instance without initialization + new_mbarrier_array = object.__new__(MbarrierArray) + + # Copy all attributes directly + new_mbarrier_array.barrier_storage = self.barrier_storage + new_mbarrier_array.op_type = new_op_type + new_mbarrier_array.cg = self.cg + new_mbarrier_array.num_stages = self.num_stages + new_mbarrier_array.tx_count = self.tx_count + new_mbarrier_array.arrive_count = self.arrive_count + new_mbarrier_array.mbarrier_base = self.mbarrier_base + return new_mbarrier_array + + # Mbarrier initialization + def mbarrier_init(self) -> None: + """ + Initializes an array of mbarriers using warp 0. + """ + + def then_body(): + for index in range(self.num_stages): + cute.arch.mbarrier_init(self.get_barrier(index), self.arrive_count) + + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + if_generate(warp_idx == 0, then_body) + + def arrive( + self, + index: int, + dst: int, + cta_group: Optional[cute.nvgpu.tcgen05.CtaGroup] = None, + ) -> None: + """Select the arrive corresponding to this MbarrierArray's PipelineOp. + + :param index: Index of the mbarrier in the array to arrive on + :type index: int + :param dst: Destination parameter for selective arrival, which can be either a mask or destination cta rank. + When None, both ``TCGen05Mma`` and ``AsyncThread`` will arrive on their local mbarrier. + - For ``TCGen05Mma``, ``dst`` serves as a multicast mask (e.g., 0b1011 allows arrive signal to be multicast to CTAs + in the cluster with rank = 0, 1, and 3). + - For ``AsyncThread``, ``dst`` serves as a destination cta rank (e.g., 3 means threads will arrive on + the mbarrier with rank = 3 in the cluster). + :type dst: int | None + :param cta_group: CTA group for ``TCGen05Mma``, defaults to None for other op types + :type cta_group: ``cute.nvgpu.tcgen05.CtaGroup``, optional + """ + if self.op_type is PipelineOp.AsyncThread: + self.arrive_mbarrier(index, dst) + elif self.op_type is PipelineOp.TCGen05Mma: + assert ( + cta_group is not None + ), "Error: CTA group must be provided for TCGen05Mma." + self.arrive_tcgen05mma(index, dst, cta_group) + elif self.op_type in [PipelineOp.TmaLoad]: + self.arrive_and_expect_tx(index, self.tx_count) + elif self.op_type is PipelineOp.AsyncLoad: + self.arrive_cp_async_mbarrier(index) + else: + assert ( + False + ), f"Error: MbarrierArray is not supported for PipelineOp: {_get_pipeline_op(self.op_type)}." + + def arrive_mbarrier(self, index: int, dst_rank: Optional[int] = None) -> None: + if dst_rank is None: + cute.arch.mbarrier_arrive(self.get_barrier(index)) + else: + cute.arch.mbarrier_arrive(self.get_barrier(index), dst_rank) + + def arrive_cp_async_mbarrier(self, index: int): + cute.arch.cp_async_mbarrier_arrive_noinc(self.get_barrier(index)) + + def arrive_tcgen05mma( + self, index: int, mask: Optional[int], cta_group: cute.nvgpu.tcgen05.CtaGroup + ) -> None: + if mask is None: + with cute.arch.elect_one(): + cute.nvgpu.tcgen05.commit(self.get_barrier(index)) + else: + with cute.arch.elect_one(): + cute.nvgpu.tcgen05.commit(self.get_barrier(index), mask, cta_group) + + def arrive_and_expect_tx(self, index: int, tx_count: int) -> None: + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx(self.get_barrier(index), tx_count) + + def try_wait(self, index: int, phase: int) -> Boolean: + return cute.arch.mbarrier_try_wait(self.get_barrier(index), phase) + + def wait(self, index: int, phase: int) -> None: + cute.arch.mbarrier_wait(self.get_barrier(index), phase) + + def arrive_and_wait( + self, + index: int, + phase: int, + dst: int, + cta_group: Optional[cute.nvgpu.tcgen05.CtaGroup] = None, + ) -> None: + arrive(index, dst, cta_group) + wait(index, phase) + + def arrive_and_drop(self) -> None: + raise NotImplementedError("Error: Not yet supported.") + + def get_barrier(self, index: int) -> cute.Pointer: + return self.mbarrier_base + index + + def max(self) -> int: + # Transaction barriers have a maximum arrive count of 511 (2^9 - 1). + # Non-transaction barriers have a maximum arrive count of 1,048,575 (2^20 - 1). + return 511 + + def __extract_mlir_values__(self): + return [self.barrier_storage] + + def __new_from_mlir_values__(self, values): + return MbarrierArray( + values[0], self.num_stages, (self.op_type, self.cg), self.tx_count + ) + + +@dataclass(frozen=True) +class NamedBarrier(SyncObject): + """ + NamedBarrier is an abstraction for named barriers managed by hardware. + There are 16 named barriers available, with barrier_ids 0-15. + + See the `PTX documentation `__. + """ + + barrier_id: int + num_threads: int + + def __post_init__(self) -> None: + if self.barrier_id < 0 or self.barrier_id >= 16: + raise ValueError("Error: NamedBarrier ID must be between 0 and 16.") + if self.barrier_id == 0: + warnings.warn( + "NamedBarrier ID 0 is by other driver APIs (i.e. sync_threads()) and should not be used." + ) + + def arrive(self) -> None: + """ + The aligned flavor of arrive is used when all threads in the CTA will execute the + same instruction. See PTX documentation. + """ + cute.arch.barrier_arrive( + barrier_id=self.barrier_id, number_of_threads=self.num_threads + ) + + def arrive_unaligned(self) -> None: + """ + The unaligned flavor of arrive can be used with an arbitrary number of threads in the CTA. + """ + llvm.inline_asm( + None, + [Int32(self.barrier_id).ir_value(), Int32(self.num_threads).ir_value()], + "barrier.arrive $0, $1;", + "r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + def wait(self) -> None: + """ + NamedBarriers do not have a standalone wait like mbarriers, only an arrive_and_wait. + If synchronizing two warps in a producer/consumer pairing, the arrive count would be + 32 using mbarriers but 64 using NamedBarriers. Only threads from either the producer + or consumer are counted for mbarriers, while all threads participating in the sync + are counted for NamedBarriers. + """ + warnings.warn( + "NamedBarrier wait also arrives on the barrier. Routing call to NamedBarrier.arrive_and_wait()." + ) + self.arrive_and_wait() + + def wait_unaligned(self) -> None: + warnings.warn( + "NamedBarrier wait also arrives on the barrier. Routing call to NamedBarrier.arrive_and_wait()." + ) + llvm.inline_asm( + None, + [Int32(self.barrier_id).ir_value(), Int32(self.num_threads).ir_value()], + "barrier.sync $0, $1;", + "r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + def arrive_and_wait(self) -> None: + cute.arch.barrier( + barrier_id=self.barrier_id, number_of_threads=self.num_threads + ) + + def arrive_and_drop(self) -> None: + raise NotImplementedError("Error: Not supported.") + + def sync(self) -> None: + cute.arch.barrier(barrier_id=self.barrier_id) + + def get_barrier(self) -> int: + return self.barrier_id + + def max(self) -> int: + # Transaction barriers have a maximum arrive count of 4095 (2^12 - 1). + return 4095 + + +class TmaStoreFence(SyncObject): + """ + TmaStoreFence is used for a multi-stage epilogue buffer. + """ + + def __init__(self, num_stages: int = 0) -> None: + if num_stages <= 0: + raise ValueError("Mbarrier stage count must be greater than 0.") + + self.num_stages = num_stages + + def arrive(self) -> None: + cute.arch.cp_async_bulk_commit_group() + + def wait(self) -> None: + cute.arch.cp_async_bulk_wait_group(self.num_stages - 1, read=True) + + def arrive_and_wait(self) -> None: + self.arrive() + self.wait() + + def arrive_and_drop(self) -> None: + raise NotImplementedError("Error: Not supported.") + + # TmaStoreFence doesn't have mbarriers + def get_barrier(self) -> None: + assert ( + False + ), "Error: TmaStoreFence doesn't use mbarriers and cannot return a barrier." + + def max(self) -> None: + raise NotImplementedError("Error: Not supported.") + + def tail(self) -> None: + cute.arch.cp_async_bulk_wait_group(0, read=True) + + +############################################################################## +# PipelineState class +############################################################################## + + +class PipelineUserType(enum.Enum): + Producer = enum.auto() + Consumer = enum.auto() + + +class PipelineState: + """ + Pipeline state contains an index and phase bit corresponding to the current position in the circular buffer. + """ + + def __init__(self, stages: int, count, index, phase): + self._stages = stages + self._count = count + self._index = index + self._phase = phase + + def clone(self) -> "PipelineState": + return PipelineState(self.stages, self._count, self.index, self.phase) + + @property + def index(self) -> Int32: + return self._index + + @property + def count(self) -> Int32: + return self._count + + @property + def stages(self) -> int: + return self._stages + + @property + def phase(self) -> Int32: + return self._phase + + def reset_count(self): + self._count = Int32(0) + + def advance(self): + self._index += 1 + self._count += 1 + + def then_body(index, phase): + new_index = Int32(0) + new_phase = phase ^ 1 + return new_index, new_phase + + def else_body(index, phase): + return index, phase + + self._index, self._phase = if_generate( + self._index == self.stages, + then_body, + else_body, + [self.index, self.phase], + [Int32, Int32], + ) + + def reverse(self): + self._index -= 1 + self._count -= 1 + + def then_body(index, phase): + new_index = Int32(self.stages - 1) + new_phase = phase ^ 1 + return new_index, new_phase + + def else_body(index, phase): + return index, phase + + self._index, self._phase = if_generate( + self._index == -1, + then_body, + else_body, + [self.index, self.phase], + [Int32, Int32], + ) + + def __get_mlir_types__(self): + return [self._count.type, self._index.type, self._phase.type] + + def __extract_mlir_values__(self): + count = self._count + index = self._index + phase = self._phase + return [count.ir_value(), index.ir_value(), phase.ir_value()] + + # This can be overridden by derived classes + def __new_from_mlir_values__(self, values): + return PipelineState( + self.stages, Int32(values[0]), Int32(values[1]), Int32(values[2]) + ) + + +def make_pipeline_state(type: PipelineUserType, stages: int): + """ + Creates a pipeline state. Producers are assumed to start with an empty buffer and have a flipped phase bit of 1. + """ + if type is PipelineUserType.Producer: + return PipelineState( + stages, + Int32(0), + Int32(0), + Int32(1), + ) + elif type is PipelineUserType.Consumer: + return PipelineState( + stages, + Int32(0), + Int32(0), + Int32(0), + ) + else: + assert ( + False + ), "Error: invalid PipelineUserType specified for make_pipeline_state." + + +############################################################################## +# Helper functions +############################################################################## + + +def pipeline_init_wait(cta_layout_vmnk: Optional[cute.Layout] = None): + """ + Fences the mbarrier init and syncs the threadblock or cluster + """ + cute.arch.mbarrier_init_fence() + + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1: + # If not using clusters, sync the threadblock + _sync(Agent.ThreadBlock) + else: + # If using clusters, sync the cluster + _sync(Agent.ThreadBlockCluster) + + +def _sync(group: Agent): + """ + Syncs all threads within an agent. + """ + if group is Agent.Thread: + raise NotImplementedError("Error: Not supported.") + elif group is Agent.ThreadBlock: + cute.arch.sync_threads() + elif group is Agent.ThreadBlockCluster: + cute.arch.cluster_arrive() + cute.arch.cluster_wait() + else: + assert ( + False + ), "Error: No explicit sync instruction exists. Please use barriers (named / mbarrier) instead." + + +def _mbarrier_i64_to_ptr(val: Int64) -> cute.Pointer: + """ + Converts a smem pointer of type Int64 to cute.Pointer with 8B alignment + """ + return cute.make_ptr( + Int64, + val.ir_value(), + mem_space=_cute_ir.AddressSpace.smem, + assumed_align=8, + ) + + +# NamedBarrier free functions +def arrive(barrier_id: int, num_threads: int): + """ + The aligned flavor of arrive is used when all threads in the CTA will execute the + same instruction. See PTX documentation. + """ + cute.arch.barrier_arrive(barrier_id=barrier_id, number_of_threads=num_threads) + + +def arrive_unaligned(barrier_id: int, num_threads: int): + """ + The unaligned flavor of arrive can be used with an arbitrary number of threads in the CTA. + """ + llvm.inline_asm( + None, + [Int32(barrier_id).ir_value(), Int32(num_threads).ir_value()], + "barrier.arrive $0, $1;", + "r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +def wait(barrier_id: int, num_threads: int): + """ + NamedBarriers do not have a standalone wait like mbarriers, only an arrive_and_wait. + If synchronizing two warps in a producer/consumer pairing, the arrive count would be + 32 using mbarriers but 64 using NamedBarriers. Only threads from either the producer + or consumer are counted for mbarriers, while all threads participating in the sync + are counted for NamedBarriers. + """ + warnings.warn( + "NamedBarrier wait also arrives on the barrier. Routing call to NamedBarrier.arrive_and_wait()." + ) + arrive_and_wait() + + +def wait_unaligned(barrier_id: int, num_threads: int): + warnings.warn( + "NamedBarrier wait also arrives on the barrier. Routing call to NamedBarrier.arrive_and_wait()." + ) + llvm.inline_asm( + None, + [Int32(barrier_id).ir_value(), Int32(num_threads).ir_value()], + "barrier.sync $0, $1;", + "r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +def arrive_and_wait(barrier_id: int, num_threads: int): + cute.arch.barrier(barrier_id=barrier_id, number_of_threads=num_threads) + + +def sync(barrier_id: int = 0): + cute.arch.barrier(barrier_id=barrier_id) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/pipeline/sm100.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/pipeline/sm100.py new file mode 100644 index 0000000000000000000000000000000000000000..2feed8cc0f1e702557f0c2b21b7582651a6405b8 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/pipeline/sm100.py @@ -0,0 +1,453 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +import enum +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional, Union +import warnings + +import cutlass.cute as cute +from cutlass.cutlass_dsl import Boolean, if_generate + +from cutlass.pipeline import ( + Agent, + CooperativeGroup, + PipelineOp, + PipelineState, + pipeline_init_wait, + PipelineAsync, +) + +############################################################################## +# Pipeline classes +############################################################################## + + +@dataclass(frozen=True) +class PipelineTmaUmma(PipelineAsync): + """ + PipelineTmaUmma is used for TMA producers and UMMA consumers (e.g. Blackwell mainloops). + """ + + is_leader_cta: bool + cta_group: cute.nvgpu.tcgen05.CtaGroup + + @staticmethod + def _compute_mcast_arrival_mask(cta_layout_vmnk: cute.Layout): + """ + Computes a mask for signaling arrivals to multicasting threadblocks. + """ + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + cta_in_cluster_coord_vmnk = cta_layout_vmnk.get_flat_coord(cta_rank_in_cluster) + + tma_mcast_mask_a = cute.nvgpu.cpasync.create_tma_multicast_mask( + cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=2 + ) + tma_mcast_mask_b = cute.nvgpu.cpasync.create_tma_multicast_mask( + cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=1 + ) + + block_in_cluster_coord_vmnk_peer = ( + cta_in_cluster_coord_vmnk[0] ^ 1, + *cta_in_cluster_coord_vmnk[1:], + ) + tma_mcast_mask_a_peer = cute.nvgpu.cpasync.create_tma_multicast_mask( + cta_layout_vmnk, block_in_cluster_coord_vmnk_peer, mcast_mode=2 + ) + tma_mcast_mask_b_peer = cute.nvgpu.cpasync.create_tma_multicast_mask( + cta_layout_vmnk, block_in_cluster_coord_vmnk_peer, mcast_mode=1 + ) + + return ( + tma_mcast_mask_a + | tma_mcast_mask_b + | tma_mcast_mask_a_peer + | tma_mcast_mask_b_peer + ) + + @staticmethod + def _compute_is_leader_cta(cta_layout_vmnk: cute.Layout): + """ + Computes leader threadblocks for 2CTA kernels. For 1CTA, all threadblocks are leaders. + """ + bidx, bidy, _ = cute.arch.block_idx() + + mma_coord_vmnk = ( + bidx % cute.size(cta_layout_vmnk, mode=[0]), + bidx // cute.size(cta_layout_vmnk, mode=[0]), + bidy, + None, + ) + return mma_coord_vmnk[0] == 0 + + @staticmethod + def create( + *, + num_stages: int, + producer_group: CooperativeGroup, + consumer_group: CooperativeGroup, + tx_count: int, + barrier_storage: cute.Pointer = None, + cta_layout_vmnk: Optional[cute.Layout] = None, + ): + """ + This helper function computes any necessary attributes and returns an instance of PipelineTmaUmma. + :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers + :type barrier_storage: cute.Pointer + :param num_stages: Number of buffer stages for this pipeline + :type num_stages: Int32 + :param producer_group: `CooperativeGroup` for the producer agent + :type producer_group: CooperativeGroup + :param consumer_group: `CooperativeGroup` for the consumer agent + :type consumer_group: CooperativeGroup + :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage + :type tx_count: int + :param cta_layout_vmnk: Layout of the cluster shape + :type cta_layout_vmnk: cute.Layout | None + """ + if not isinstance(barrier_storage, cute.Pointer): + raise ValueError( + f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}" + ) + + producer_type = PipelineOp.TmaLoad + consumer_type = PipelineOp.TCGen05Mma + + producer = (producer_type, producer_group) + consumer = (consumer_type, consumer_group) + + sync_object_full = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8), num_stages, producer, tx_count + ) + sync_object_empty = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8) + num_stages, num_stages, consumer + ) + + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1: + # No mcast mask if not using clusters + producer_mask = None + # All threadblocks are leaders if not using clusters + is_leader_cta = True + else: + producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(cta_layout_vmnk) + is_leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk) + + cta_group = ( + cute.nvgpu.tcgen05.CtaGroup.ONE + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1 + else cute.nvgpu.tcgen05.CtaGroup.TWO + ) + + consumer_mask = producer_mask + + pipeline_init_wait(cta_layout_vmnk) + + return PipelineTmaUmma( + sync_object_full, + sync_object_empty, + num_stages, + producer_mask, + consumer_mask, + is_leader_cta, + cta_group, + ) + + def consumer_release(self, state: PipelineState): + """ + UMMA consumer release buffer empty, cta_group needs to be provided. + """ + self.sync_object_empty.arrive(state.index, self.consumer_mask, self.cta_group) + + def producer_acquire( + self, state: PipelineState, try_acquire_token: Optional[Boolean] = None + ): + """ + TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks. + """ + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(state.index, state.phase), + ) + if_generate( + self.is_leader_cta, + lambda: self.sync_object_full.arrive(state.index, self.producer_mask), + ) + + def producer_commit(self, state: PipelineState): + """ + TMA producer commit is a noop since TMA instruction itself updates the transaction count. + """ + pass + + +@dataclass(frozen=True) +class PipelineAsyncUmma(PipelineAsync): + """ + PipelineAsyncUmma is used for AsyncThread producers and UMMA consumers (e.g. Blackwell input fusion pipelines). + """ + + cta_group: cute.nvgpu.tcgen05.CtaGroup + + @staticmethod + def _compute_leading_cta_rank(cta_v_size): + """ + Computes the leading CTA rank. + """ + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + return cta_rank_in_cluster // cta_v_size * cta_v_size + + @staticmethod + def _compute_is_leader_cta(cta_layout_vmnk: cute.Layout): + """ + Computes leader threadblocks for 2CTA kernels. For 1CTA, all threadblocks are leaders. + """ + bidx, bidy, _ = cute.arch.block_idx() + mma_coord_vmnk = ( + bidx % cute.size(cta_layout_vmnk, mode=[0]), + bidx // cute.size(cta_layout_vmnk, mode=[0]), + bidy, + None, + ) + return mma_coord_vmnk[0] == 0 + + @staticmethod + def _compute_peer_cta_mask(cta_layout_vmnk: cute.Layout): + """ + Computes a mask for signaling arrivals to multicasting threadblocks. + """ + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + cta_in_cluster_coord_vmnk = cta_layout_vmnk.get_flat_coord(cta_rank_in_cluster) + mask_self = cute.nvgpu.cpasync.create_tma_multicast_mask( + cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=0 + ) + block_in_cluster_coord_vmnk_peer = ( + cta_in_cluster_coord_vmnk[0] ^ 1, + *cta_in_cluster_coord_vmnk[1:], + ) + mask_peer = cute.nvgpu.cpasync.create_tma_multicast_mask( + cta_layout_vmnk, block_in_cluster_coord_vmnk_peer, mcast_mode=0 + ) + return mask_self | mask_peer + + @staticmethod + def create( + *, + num_stages: int, + producer_group: CooperativeGroup, + consumer_group: CooperativeGroup, + barrier_storage: cute.Pointer = None, + cta_layout_vmnk: Optional[cute.Layout] = None, + ): + """ + This helper function computes any necessary attributes and returns an instance of PipelineAsyncUmma. + :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers + :type barrier_storage: cute.Pointer + :param num_stages: Number of buffer stages for this pipeline + :type num_stages: Int32 + :param producer_group: `CooperativeGroup` for the producer agent + :type producer_group: CooperativeGroup + :param consumer_group: `CooperativeGroup` for the consumer agent + :type consumer_group: CooperativeGroup + :param cta_layout_vmnk: Layout of the cluster shape + :type cta_layout_vmnk: cute.Layout | None + """ + if not isinstance(barrier_storage, cute.Pointer): + raise ValueError( + f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}" + ) + + producer_type = PipelineOp.AsyncThread + consumer_type = PipelineOp.TCGen05Mma + + producer = (producer_type, producer_group) + consumer = (consumer_type, consumer_group) + + sync_object_full = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8), + num_stages, + producer, + ) + sync_object_empty = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8) + num_stages, num_stages, consumer + ) + + cta_v_size = ( + cute.size(cta_layout_vmnk, mode=[0]) if cta_layout_vmnk is not None else 1 + ) + cta_group = ( + cute.nvgpu.tcgen05.CtaGroup.ONE + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1 + else cute.nvgpu.tcgen05.CtaGroup.TWO + ) + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1: + # No mcast mask if we're not using 2CTA tcgen05 MMA + producer_mask = None + consumer_mask = None + else: + # If we're using 2CTA UMMAs, producer will arrive the mbar on leading CTA + # We need to get the target cta_rank + producer_mask = PipelineAsyncUmma._compute_leading_cta_rank(cta_v_size) + # consumer needs to get the mask to signal + consumer_mask = PipelineAsyncUmma._compute_peer_cta_mask(cta_layout_vmnk) + + pipeline_init_wait(cta_layout_vmnk) + + return PipelineAsyncUmma( + sync_object_full, + sync_object_empty, + num_stages, + producer_mask, + consumer_mask, + cta_group, + ) + + def consumer_release(self, state: PipelineState): + """ + UMMA consumer release buffer empty, cta_group needs to be provided. + """ + self.sync_object_empty.arrive(state.index, self.consumer_mask, self.cta_group) + + +@dataclass(frozen=True) +class PipelineUmmaAsync(PipelineAsync): + """ + PipelineUmmaAsync is used for UMMA producers and AsyncThread consumers (e.g. Blackwell accumulator pipelines). + """ + + cta_group: cute.nvgpu.tcgen05.CtaGroup + + @staticmethod + def _compute_tmem_sync_mask(cta_layout_vmnk: cute.Layout): + """ + Computes a mask to signal completion of tmem buffers for 2CTA kernels. + """ + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + cta_in_cluster_coord_vmnk = cta_layout_vmnk.get_flat_coord(cta_rank_in_cluster) + return cute.make_layout_image_mask( + cta_layout_vmnk, cta_in_cluster_coord_vmnk, mode=0 + ) + + @staticmethod + def _compute_peer_cta_rank(): + """ + Computes a mask to signal release of tmem buffers for 2CTA kernels. + """ + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + return cta_rank_in_cluster // 2 * 2 + + @staticmethod + def create( + *, + num_stages: int, + producer_group: CooperativeGroup, + consumer_group: CooperativeGroup, + barrier_storage: cute.Pointer = None, + cta_layout_vmnk: Optional[cute.Layout] = None, + ): + """ + This helper function computes any necessary attributes and returns an instance of PipelineUmmaAsync. + :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers + :type barrier_storage: cute.Pointer + :param num_stages: Number of buffer stages for this pipeline + :type num_stages: Int32 + :param producer_group: `CooperativeGroup` for the producer agent + :type producer_group: CooperativeGroup + :param consumer_group: `CooperativeGroup` for the consumer agent + :type consumer_group: CooperativeGroup + :param cta_layout_vmnk: Layout of the cluster shape + :type cta_layout_vmnk: cute.Layout | None + """ + if not isinstance(barrier_storage, cute.Pointer): + raise ValueError( + f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}" + ) + + producer_type = PipelineOp.TCGen05Mma + consumer_type = PipelineOp.AsyncThread + + producer = (producer_type, producer_group) + consumer = (consumer_type, consumer_group) + + sync_object_full = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8), num_stages, producer + ) + sync_object_empty = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8) + num_stages, num_stages, consumer + ) + + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1: + # Set mask to None if not using clusters (i.e. 1CTA kernels) + producer_mask = None + else: + producer_mask = PipelineUmmaAsync._compute_tmem_sync_mask(cta_layout_vmnk) + + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1: + # Set mask to None if not using 2CTA intructions + consumer_mask = None + else: + consumer_mask = PipelineUmmaAsync._compute_peer_cta_rank() + + cta_group = ( + cute.nvgpu.tcgen05.CtaGroup.ONE + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1 + else cute.nvgpu.tcgen05.CtaGroup.TWO + ) + + pipeline_init_wait(cta_layout_vmnk) + + return PipelineUmmaAsync( + sync_object_full, + sync_object_empty, + num_stages, + producer_mask, + consumer_mask, + cta_group, + ) + + def producer_commit(self, state: PipelineState): + """ + UMMA producer commit buffer full, cta_group needs to be provided. + """ + self.sync_object_full.arrive(state.index, self.producer_mask, self.cta_group) + + def producer_tail(self, state: PipelineState): + """ + Make sure the last used buffer empty signal is visible to producer. + Producer tail is usually executed by producer before exit, to avoid dangling + mbarrier arrive signals after kernel exit. + + :param state: The pipeline state that points to next useful buffer + :type state: PipelineState + """ + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + is_leader_cta = cta_rank_in_cluster % 2 == 0 + + def then_body(): + # Assume state contains that next useful buffer + # So we only need to advance to num_stages - 1 times to last used buffer + for i in range(self.num_stages - 1): + state.advance() + self.producer_acquire(state) + + if_generate(is_leader_cta, then_body) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/pipeline/sm90.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/pipeline/sm90.py new file mode 100644 index 0000000000000000000000000000000000000000..5fc19960c9b1ccca84dcc18bca002e2fa2a303ca --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/pipeline/sm90.py @@ -0,0 +1,985 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +import enum +from typing import Type, Tuple +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional, Union +import warnings + +import cutlass +import cutlass.cute as cute +from cutlass.cutlass_dsl import Boolean, Int32, if_generate + +from cutlass.pipeline import ( + Agent, + CooperativeGroup, + PipelineOp, + SyncObject, + MbarrierArray, + TmaStoreFence, + PipelineUserType, + PipelineState, + make_pipeline_state, + pipeline_init_wait, +) + +############################################################################## +# Pipeline classes +############################################################################## + + +@dataclass(frozen=True) +class PipelineAsync: + """PipelineAsync is a generic pipeline class where both the producer and consumer are + AsyncThreads. It also serves as a base class for specialized pipeline classes. + + This class implements a producer-consumer pipeline pattern where both sides operate + asynchronously. The pipeline maintains synchronization state using barrier objects + to coordinate between producer and consumer threads. + + The pipeline state transitions of one pipeline entry(mbarrier) can be represented as: + + .. table:: Pipeline State Transitions + :widths: auto + + +-----------+-----------+-----------+-----------+-----------+-----------+ + | Barrier | State | p.acquire | p.commit | c.wait | c.release | + +===========+===========+===========+===========+===========+===========+ + | empty_bar | empty | | n/a | n/a | - | + +-----------+-----------+-----------+-----------+-----------+-----------+ + | empty_bar | wait | | n/a | n/a | -> empty | + +-----------+-----------+-----------+-----------+-----------+-----------+ + | full_bar | wait | n/a | -> full | | n/a | + +-----------+-----------+-----------+-----------+-----------+-----------+ + | full_bar | full | n/a | - | | n/a | + +-----------+-----------+-----------+-----------+-----------+-----------+ + + Where: + + - p: producer + - c: consumer + - : This action is blocked until transition to a state allow it to proceed by other side + - e.g. ``p.acquire()`` is blocked until ``empty_bar`` transition to ``empty`` state by ``c.release()`` + + .. code-block:: text + + Array of mbarriers as circular buffer: + + Advance Direction + <------------------- + + Producer Consumer + | ^ + V | + +-----------------+ + --|X|X|W|D|D|D|D|R|X|<-. + / +-----------------+ \\ + | | + `------------------------' + + Where: + + - X: Empty buffer (initial state) + - W: Producer writing (producer is waiting for buffer to be empty) + - D: Data ready (producer has written data to buffer) + - R: Consumer reading (consumer is consuming data from buffer) + + **Example:** + + .. code-block:: python + + # Create pipeline with 5 stages + pipeline = PipelineAsync.create( + num_stages=5, # number of pipeline stages + producer_group=producer_warp, + consumer_group=consumer_warp + barrier_storage=smem_ptr, # smem pointer for array of mbarriers in shared memory + ) + + producer, consumer = pipeline.make_participants() + # Producer side + for i in range(num_iterations): + handle = producer.acquire_and_advance() # Wait for buffer to be empty & Move index to next stage + # Write data to pipeline buffer + handle.commit() # Signal buffer is full + + # Consumer side + for i in range(num_iterations): + handle = consumer.wait_and_advance() # Wait for buffer to be full & Move index to next stage + # Read data from pipeline buffer + handle.release() # Signal buffer is empty + """ + + sync_object_full: SyncObject + sync_object_empty: SyncObject + num_stages: int + producer_mask: Optional[Int32] + consumer_mask: Optional[Int32] + + @staticmethod + def _make_sync_object( + barrier_storage: cute.Pointer, + num_stages: int, + agent: tuple[PipelineOp, CooperativeGroup], + tx_count: int = 0, + ) -> SyncObject: + """ + Returns a SyncObject corresponding to an agent's PipelineOp. + """ + if agent[0] in [ + PipelineOp.AsyncThread, + PipelineOp.TmaLoad, + PipelineOp.TCGen05Mma, + PipelineOp.Composite, + PipelineOp.AsyncLoad, + ]: + return MbarrierArray( + barrier_storage=barrier_storage, + num_stages=num_stages, + agent=agent, + tx_count=tx_count, + ) + elif agent[0] is PipelineOp.TmaStore: + # Path taken for AsyncTmaStore + return TmaStoreFence(num_stages=num_stages) + else: + assert False, "Error: Invalid PipelineOp specified." + + @staticmethod + def create( + *, + num_stages: int, + producer_group: CooperativeGroup, + consumer_group: CooperativeGroup, + barrier_storage: cute.Pointer = None, + producer_mask: Int32 = None, + consumer_mask: Int32 = None, + ): + """Creates and initializes a new PipelineAsync instance. + + This helper function computes necessary attributes and returns an instance of PipelineAsync + with the specified configuration for producer and consumer synchronization. + + :param barrier_storage: Pointer to the shared memory address for this pipeline's mbarriers + :type barrier_storage: cute.Pointer + :param num_stages: Number of buffer stages for this pipeline + :type num_stages: int + :param producer_group: `CooperativeGroup` for the producer agent + :type producer_group: CooperativeGroup + :param consumer_group: `CooperativeGroup` for the consumer agent + :type consumer_group: CooperativeGroup + :param producer_mask: Mask for signaling arrives for the producer agent, defaults to ``None`` + :type producer_mask: Int32, optional + :param consumer_mask: Mask for signaling arrives for the consumer agent, defaults to ``None`` + :type consumer_mask: Int32, optional + :return: A new PipelineAsync instance + :rtype: PipelineAsync + :raises ValueError: If barrier_storage is not a cute.Pointer instance + """ + if not isinstance(barrier_storage, cute.Pointer): + raise ValueError( + f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}" + ) + + producer_type = PipelineOp.AsyncThread + consumer_type = PipelineOp.AsyncThread + + producer = (producer_type, producer_group) + consumer = (consumer_type, consumer_group) + + sync_object_full = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8), num_stages, producer + ) + sync_object_empty = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8) + num_stages, num_stages, consumer + ) + + pipeline_init_wait() + + return PipelineAsync( + sync_object_full, + sync_object_empty, + num_stages, + producer_mask, + consumer_mask, + ) + + def producer_acquire( + self, state: PipelineState, try_acquire_token: Optional[Boolean] = None + ): + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(state.index, state.phase), + ) + + def producer_try_acquire(self, state: PipelineState): + return self.sync_object_empty.try_wait(state.index, state.phase) + + def producer_commit(self, state: PipelineState): + self.sync_object_full.arrive(state.index, self.producer_mask) + + def consumer_wait( + self, state: PipelineState, try_wait_token: Optional[Boolean] = None + ): + if_generate( + try_wait_token is None or try_wait_token == 0, + lambda: self.sync_object_full.wait(state.index, state.phase), + ) + + def consumer_try_wait(self, state: PipelineState): + return self.sync_object_full.try_wait(state.index, state.phase) + + def consumer_release(self, state: PipelineState): + self.sync_object_empty.arrive(state.index, self.consumer_mask) + + def producer_get_barrier(self, state: PipelineState) -> cute.Pointer: + return self.sync_object_full.get_barrier(state.index) + + def producer_tail(self, state: PipelineState): + """ + Make sure the last used buffer empty signal is visible to producer. + Producer tail is usually executed by producer before exit, to avoid dangling + mbarrier arrive signals after kernel exit. + + :param state: The pipeline state that points to next useful buffer + :type state: PipelineState + """ + # Assume state contains that next useful buffer + # So we only need to advance to num_stages - 1 times to last used buffer + for i in range(self.num_stages - 1): + state.advance() + self.producer_acquire(state) + + # Util methods to manage produer and consumer + def make_producer(self): + state = make_pipeline_state(PipelineUserType.Producer, self.num_stages) + return PipelineProducer(self, state, self.sync_object_full.cg) + + def make_consumer(self): + state = make_pipeline_state(PipelineUserType.Consumer, self.num_stages) + return PipelineConsumer(self, state, self.sync_object_empty.cg) + + def make_participants(self): + return self.make_producer(), self.make_consumer() + + + +@dataclass(frozen=True) +class PipelineCpAsync(PipelineAsync): + """ + PipelineCpAsync is used for CpAsync producers and AsyncThread consumers (e.g. Hopper non-TMA mainloops). + """ + + @staticmethod + def create( + barrier_storage: cute.Pointer, + num_stages: Int32, + producer_group: CooperativeGroup, + consumer_group: CooperativeGroup, + producer_mask: Int32 = None, + consumer_mask: Int32 = None, + ): + """ + This helper function computes any necessary attributes and returns an instance of PipelineAsync. + :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers + :type barrier_storage: cute.Pointer + :param num_stages: Number of buffer stages for this pipeline + :type num_stages: Int32 + :param producer_group: CooperativeGroup for the producer agent + :type producer_group: CooperativeGroup + :param consumer_group: CooperativeGroup for the consumer agent + :type consumer_group: CooperativeGroup + :param producer_mask: Mask for signaling arrives for the producer agent + :type producer_mask: Int32 | None + :param consumer_mask: Mask for signaling arrives for the consumer agent + :type consumer_mask: Int32 | None + """ + producer_type = PipelineOp.AsyncLoad + consumer_type = PipelineOp.AsyncThread + + producer = (producer_type, producer_group) + consumer = (consumer_type, consumer_group) + + sync_object_array_full = PipelineCpAsync._make_sync_object( + barrier_storage.align(min_align=8), num_stages, producer + ) + sync_object_array_empty = PipelineCpAsync._make_sync_object( + barrier_storage.align(min_align=8) + num_stages, num_stages, consumer + ) + + pipeline_init_wait() + + return PipelineCpAsync( + sync_object_array_full, + sync_object_array_empty, + num_stages, + producer_mask, + consumer_mask, + ) + + +@dataclass(frozen=True) +class PipelineTmaAsync(PipelineAsync): + """ + PipelineTmaAsync is used for TMA producers and AsyncThread consumers (e.g. Hopper mainloops). + """ + + is_signalling_thread: Boolean + + @staticmethod + @cute.jit + def init_empty_barrier_arrive_signal(cta_layout_vmnk: cute.Layout, tidx: Int32): + """ + Initialize the empty barrier arrive signal + This function returns the destination cta rank and a boolean indicating if the signalling thread is the same as the current thread + """ + # Logic to optimally schedule Empty Arrives + cluster_shape_vmnk = cta_layout_vmnk.shape + + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + + tidx = tidx % 32 + is_signalling_thread = tidx < cute.size(cluster_shape_vmnk) + dst_rank = tidx % cute.size(cluster_shape_vmnk) + + dst_cta_coord = cta_layout_vmnk.get_hier_coord(dst_rank) + cur_cta_coord = cta_layout_vmnk.get_hier_coord(cta_rank_in_cluster) + + is_same_row = ( + dst_cta_coord[0] == cur_cta_coord[0] + and dst_cta_coord[1] == cur_cta_coord[1] + and dst_cta_coord[3] == cur_cta_coord[3] + ) + is_same_col = ( + dst_cta_coord[0] == cur_cta_coord[0] + and dst_cta_coord[2] == cur_cta_coord[2] + and dst_cta_coord[3] == cur_cta_coord[3] + ) + + is_same_row_or_col = is_same_row or is_same_col + is_signalling_thread_final = is_signalling_thread and is_same_row_or_col + + return dst_rank, is_signalling_thread_final + + @staticmethod + def create( + *, + num_stages: int, + producer_group: CooperativeGroup, + consumer_group: CooperativeGroup, + tx_count: int, + barrier_storage: cute.Pointer = None, + cta_layout_vmnk: Optional[cute.Layout] = None, + tidx: Optional[Int32] = None, + ): + """ + This helper function computes any necessary attributes and returns an instance of PipelineTmaAsync. + :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers + :type barrier_storage: cute.Pointer + :param num_stages: Number of buffer stages for this pipeline + :type num_stages: Int32 + :param producer_group: `CooperativeGroup` for the producer agent + :type producer_group: CooperativeGroup + :param consumer_group: `CooperativeGroup` for the consumer agent + :type consumer_group: CooperativeGroup + :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage + :type tx_count: int + :param cta_layout_vmnk: Layout of the cluster shape + :type cta_layout_vmnk: cute.Layout | None + :param tidx: thread index to consumer async threads + :type tidx: Int32 | None + """ + if not isinstance(barrier_storage, cute.Pointer): + raise ValueError( + f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}" + ) + + producer_type = PipelineOp.TmaLoad + consumer_type = PipelineOp.AsyncThread + + producer = (producer_type, producer_group) + consumer = (consumer_type, consumer_group) + + sync_object_full = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8), num_stages, producer, tx_count + ) + sync_object_empty = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8) + num_stages, num_stages, consumer + ) + if tidx is None: + tidx, _, _ = cute.arch.thread_idx() + if cta_layout_vmnk is None: + cta_layout_vmnk = cute.make_layout((1, 1, 1, 1)) + ( + dst_rank, + is_signalling_thread, + ) = PipelineTmaAsync.init_empty_barrier_arrive_signal(cta_layout_vmnk, tidx) + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1: + dst_rank = None + else: + dst_rank = dst_rank + + producer_mask = None + + pipeline_init_wait(cta_layout_vmnk) + + return PipelineTmaAsync( + sync_object_full, + sync_object_empty, + num_stages, + producer_mask, + dst_rank, + is_signalling_thread, + ) + + def producer_acquire( + self, state: PipelineState, try_acquire_token: Optional[Boolean] = None + ): + """ + TMA producer commit conditionally waits on buffer empty and sets the transaction barrier. + """ + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(state.index, state.phase), + ) + self.sync_object_full.arrive(state.index, self.producer_mask) + + def producer_commit(self, state: PipelineState): + """ + TMA producer commit is a noop since TMA instruction itself updates the transaction count. + """ + pass + + def consumer_release(self, state: PipelineState): + """ + TMA consumer release conditionally signals the empty buffer to the producer. + """ + if_generate( + self.is_signalling_thread, + lambda: self.sync_object_empty.arrive(state.index, self.consumer_mask), + ) + + +@dataclass(frozen=True) +class PipelineTmaMultiConsumersAsync(PipelineAsync): + """ + PipelineTmaMultiConsumersAsync is used for TMA producers and UMMA+Async consumers. + """ + + is_leader_cta: bool + sync_object_empty_umma: SyncObject + sync_object_empty_async: SyncObject + cta_group: cute.nvgpu.tcgen05.CtaGroup + + @staticmethod + def create( + *, + num_stages: int, + producer_group: CooperativeGroup, + consumer_group_umma: CooperativeGroup, + consumer_group_async: CooperativeGroup, + tx_count: int, + barrier_storage: cute.Pointer = None, + cta_layout_vmnk: Optional[cute.Layout] = None, + ): + """ + This helper function computes any necessary attributes and returns an instance of PipelineTmaMultiConsumersAsync. + :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers + :type barrier_storage: cute.Pointer + :param num_stages: Number of buffer stages for this pipeline + :type num_stages: Int32 + :param producer_group: `CooperativeGroup` for the producer agent + :type producer_group: CooperativeGroup + :param consumer_group_umma: `CooperativeGroup` for the UMMA consumer agent + :type consumer_group_umma: CooperativeGroup + :param consumer_group_async: `CooperativeGroup` for the AsyncThread consumer agent + :type consumer_group_async: CooperativeGroup + :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage + :type tx_count: int + :param cta_layout_vmnk: Layout of the cluster shape + :type cta_layout_vmnk: cute.Layout | None + """ + if not isinstance(barrier_storage, cute.Pointer): + raise ValueError( + f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}" + ) + + producer_type = PipelineOp.TmaLoad + consumer_type = PipelineOp.Composite + consumer_type_umma = PipelineOp.TCGen05Mma + consumer_type_async = PipelineOp.AsyncThread + + if consumer_group_umma.agent != consumer_group_async.agent: + raise ValueError( + "UMMA and AsyncThread consumer groups must be the same agent" + ) + + if cta_layout_vmnk is not None and cute.size(cta_layout_vmnk) != 1: + raise ValueError( + f"PipelineTmaMultiConsumersAsync is not verified for cta_layout_vmnk != 1, cta_layout_vmnk:{cta_layout_vmnk}" + ) + + consumer_group = CooperativeGroup( + consumer_group_umma.agent, + consumer_group_umma.size + consumer_group_async.size, + ) + + producer = (producer_type, producer_group) + consumer = (consumer_type, consumer_group) + + sync_object_full = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8), num_stages, producer, tx_count + ) + sync_object_empty = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8) + num_stages, num_stages, consumer + ) + sync_object_empty_umma = sync_object_empty.recast_to_new_op_type( + consumer_type_umma + ) + sync_object_empty_async = sync_object_empty.recast_to_new_op_type( + consumer_type_async + ) + + # No mcast mask if not using clusters + producer_mask = None + consumer_mask = None + # All threadblocks are leaders if not using clusters + is_leader_cta = True + cta_group = ( + cute.nvgpu.tcgen05.CtaGroup.ONE + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1 + else cute.nvgpu.tcgen05.CtaGroup.TWO + ) + + pipeline_init_wait(cta_layout_vmnk) + + return PipelineTmaMultiConsumersAsync( + sync_object_full, + sync_object_empty, + num_stages, + producer_mask, + consumer_mask, + is_leader_cta, + sync_object_empty_umma, + sync_object_empty_async, + cta_group, + ) + + def producer_acquire( + self, state: PipelineState, try_acquire_token: Optional[Boolean] = None + ): + """ + TMA producer acquire waits on buffer empty and sets the transaction barrier for leader threadblocks. + """ + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(state.index, state.phase), + ) + if_generate( + self.is_leader_cta, + lambda: self.sync_object_full.arrive(state.index, self.producer_mask), + ) + + def producer_commit(self, state: PipelineState): + """ + TMA producer commit is a noop since TMA instruction itself updates the transaction count. + """ + pass + + def consumer_release(self, state: PipelineState, op_type: PipelineOp): + if op_type == PipelineOp.TCGen05Mma: + self.sync_object_empty_umma.arrive( + state.index, self.consumer_mask, self.cta_group + ) + elif op_type == PipelineOp.AsyncThread: + self.sync_object_empty_async.arrive(state.index, self.consumer_mask) + else: + raise ValueError(f"Invalid PipelineOp specified. op_type:{op_type}") + + +@dataclass(frozen=True) +class PipelineTmaStore(PipelineAsync): + """ + PipelineTmaStore is used for synchronizing TMA stores in the epilogue. It does not use mbarriers. + """ + + @staticmethod + def create( + *, + num_stages: int, + producer_group: CooperativeGroup, + ): + """ + This helper function computes any necessary attributes and returns an instance of PipelineTmaStore. + :param num_stages: Number of buffer stages for this pipeline + :type num_stages: Int32 + :param producer_group: `CooperativeGroup` for the producer agent + :type producer_group: CooperativeGroup + """ + + producer_type = PipelineOp.TmaStore + + producer = (producer_type, producer_group) + + sync_object_full = PipelineAsync._make_sync_object(None, num_stages, producer) + + return PipelineTmaStore(sync_object_full, None, num_stages, None, None) + + def producer_acquire(self): + self.sync_object_full.wait() + + def producer_commit(self): + self.sync_object_full.arrive() + + def consumer_wait(self): + assert False, "Error: PipelineTmaStore does not have a consumer agent." + + def consumer_release(self): + assert False, "Error: PipelineTmaStore does not have a consumer agent." + + def producer_tail(self): + self.sync_object_full.tail() + + +################################################################# +# Utilities to help user of pipeline to simplify the workflow +################################################################# + + +class ImmutableResourceHandle: + __origin: PipelineAsync + __immutable_state: PipelineState + + def __init__(self, origin: PipelineAsync, immutable_state: PipelineState): + self.__origin = origin + self.__immutable_state = immutable_state + + @property + def index(self): + """Get the index of the current pipeline stage.""" + return self.__immutable_state.index + + @property + def count(self): + """Get the count of how many handles this producer has committed. + This is useful for tracking the number of blocks that have been loaded from gmem. + """ + return self.__immutable_state.count + + def get_origin(self): + """Get the original pipeline this resource handle belongs to.""" + return self.__origin + + def __extract_mlir_values__(self): + """Extract MLIR values from the current state. + + :return: List of MLIR values representing the current state + :rtype: list + """ + # TODO: need to handle pipeline as well + return self.__immutable_state.__extract_mlir_values__() + + def __new_from_mlir_values__(self, values): + """Create a new Producer instance from MLIR values. + + :param values: MLIR values to initialize the state + :type values: Any + :return: New Producer instance with state initialized from values + :rtype: Producer + """ + return self.__class__( + self.__origin, self.__immutable_state.__new_from_mlir_values__(values) + ) + +class PipelineProducer: + """A class representing a producer in an asynchronous pipeline. + + The Producer class manages the producer side of an asynchronous pipeline, handling + synchronization and state management for producing data. It provides methods for + acquiring, committing, and advancing through pipeline stages. + + :ivar __pipeline: The asynchronous pipeline this producer belongs to + :type __pipeline: PipelineAsync + :ivar __state: The current state of the producer in the pipeline + :type __state: PipelineState + :ivar __group: The cooperative group this producer operates in + :type __group: CooperativeGroup + + **Examples:** + + .. code-block:: python + + pipeline = PipelineAsync.create(...) + producer = pipeline.create_producer(producer_group, stages) + for i in range(iterations): + handle = producer.acquire_and_advance() # Wait for buffer to be empty + # Produce data + producer.commit(handle) # Signal data is ready + # An alternative way to do this is: + # handle.commit() # Signal data is ready + """ + + __pipeline: PipelineAsync + __state: PipelineState + __group: CooperativeGroup + + class ImmutableResourceHandle(ImmutableResourceHandle): + @property + def barrier(self): + """Get the barrier pointer for the current pipeline stage. + + :return: Pointer to the barrier for the current stage + :rtype: cute.Pointer + """ + return self.get_origin().producer_get_barrier( + self._ImmutableResourceHandle__immutable_state + ) + + def commit(self): + """Signal that data production is complete for the current stage. + This allows consumers to start processing the data. + """ + self.get_origin().producer_commit( + self._ImmutableResourceHandle__immutable_state + ) + + def __init__(self, pipeline, state, group: CooperativeGroup): + """Initialize a new Producer instance. + + :param pipeline: The pipeline this producer belongs to + :type pipeline: PipelineAsync + :param state: Initial pipeline state + :type state: PipelineState + :param group: The cooperative group for synchronization + :type group: CooperativeGroup + """ + self.__pipeline = pipeline + self.__state = state + self.__group = group + + def acquire( + self, + try_acquire_token: Optional[Boolean] = None, + ) -> ImmutableResourceHandle: + """Wait for the current buffer to be empty before producing data. + This is a blocking operation. + + :param try_acquire_token: Optional token to try to acquire the buffer + :type try_acquire_token: Optional[Boolean] + :return: A handle to the producer for committing the data + :rtype: ImmutableResourceHandle + """ + self.__pipeline.producer_acquire(self.__state, try_acquire_token) + handle = PipelineProducer.ImmutableResourceHandle( + self.__pipeline, self.__state.clone() + ) + return handle + + def advance(self): + """Move to the next pipeline stage.""" + self.__state.advance() + + def acquire_and_advance( + self, try_acquire_token: Optional[Boolean] = None + ) -> ImmutableResourceHandle: + """Wait for the current buffer to be empty before producing data. + Then advance to the next stage. + This is a blocking operation. + + :param try_acquire_token: Optional token to try to acquire the buffer + :type try_acquire_token: Optional[Boolean] + :return: A handle to the producer for committing the data + :rtype: ImmutableResourceHandle + """ + handle = self.acquire(try_acquire_token) + self.advance() + return handle + + def try_acquire(self) -> Boolean: + """Try to acquire the current buffer without blocking. + + :return: True if acquisition was successful, False otherwise + :rtype: Boolean + """ + return self.__pipeline.producer_try_acquire(self.__state) + + def commit(self, handle: Optional[ImmutableResourceHandle] = None): + """Signal that data production is complete for the current stage. + This allows consumers to start processing the data. + """ + if handle is not None: + assert ( + handle.get_origin() is self + ), "ResourceHandle does not belong to this PipelineProducer instance" + handle.commit() + else: + self.__pipeline.producer_commit(self.__state) + + def tail(self): + """Ensure all used buffers are properly synchronized before producer exit. + This should be called before the producer finishes to avoid dangling signals. + """ + self.__pipeline.producer_tail(self.__state) + + def __extract_mlir_values__(self): + """Extract MLIR values from the current state. + + :return: List of MLIR values representing the current state + :rtype: list + """ + # TODO: need to handle pipeline as well + return self.__state.__extract_mlir_values__() + + def __new_from_mlir_values__(self, values): + """Create a new Producer instance from MLIR values. + + :param values: MLIR values to initialize the state + :type values: Any + :return: New Producer instance with state initialized from values + :rtype: Producer + """ + return PipelineProducer( + self.__pipeline, self.__state.__new_from_mlir_values__(values), self.__group + ) + +class PipelineConsumer: + """A class representing a consumer in an asynchronous pipeline. + + The Consumer class manages the consumer side of an asynchronous pipeline, handling + synchronization and state management for consuming data. It provides methods for + waiting, releasing, and advancing through pipeline stages. + + :ivar __pipeline: The asynchronous pipeline this consumer belongs to + :type __pipeline: PipelineAsync + :ivar __state: The current state of the consumer in the pipeline + :type __state: PipelineState + :ivar __group: The cooperative group this consumer operates in + :type __group: CooperativeGroup + + **Examples:** + .. code-block:: python + + pipeline = PipelineAsync.create(...) + consumer = pipeline.create_consumer(consumer_group, stages) + for i in range(iterations): + handle = consumer.wait_and_advance() # Wait for data to be ready + # Consume data + consumer.release(handle) # Signal buffer is empty + # An alternative way to do this is: + # handle.release() # Signal buffer is empty + """ + + __pipeline: PipelineAsync + __state: PipelineState + __group: CooperativeGroup + + class ImmutableResourceHandle(ImmutableResourceHandle): + def release(self): + """Signal that data production is complete for the current stage. + This allows consumers to start processing the data. + """ + self.get_origin().consumer_release( + self._ImmutableResourceHandle__immutable_state + ) + + def __init__(self, pipeline, state: PipelineState, group: CooperativeGroup): + """Initialize a new Consumer instance. + + :param pipeline: The pipeline this consumer belongs to + :type pipeline: PipelineAsync + :param state: Initial pipeline state + :type state: PipelineState + :param group: The cooperative group for synchronization + :type group: CooperativeGroup + """ + self.__pipeline = pipeline + self.__group = group + self.__state = state + + def wait(self, try_wait_token: Optional[Boolean] = None) -> ImmutableResourceHandle: + """Wait for data to be ready in the current buffer. + This is a blocking operation. + + :param try_wait_token: Optional token to try to wait for the buffer + :type try_wait_token: Optional[Boolean] + :return: A handle to the consumer for releasing the data + :rtype: PipelineConsumerHandle + """ + self.__pipeline.consumer_wait(self.__state, try_wait_token) + handle = PipelineConsumer.ImmutableResourceHandle( + self.__pipeline, self.__state.clone() + ) + return handle + + def advance(self): + """Move to the next pipeline stage.""" + self.__state.advance() + + def wait_and_advance( + self, try_wait_token: Optional[Boolean] = None + ) -> ImmutableResourceHandle: + """Wait for data to be ready in the current buffer. + Then advance to the next stage. + This is a blocking operation. + + :param try_wait_token: Optional token to try to wait for the buffer + :type try_wait_token: Optional[Boolean] + :return: A handle to the consumer for releasing the data + :rtype: PipelineConsumerHandle + """ + handle = self.wait(try_wait_token) + self.advance() + return handle + + def try_wait(self) -> Boolean: + """Try to check if data is ready without blocking. + + :return: True if data is ready, False otherwise + :rtype: Boolean + """ + return self.__pipeline.consumer_try_wait(self.__state) + + def release(self, handle: Optional[ImmutableResourceHandle] = None): + """Signal that data consumption is complete for the current stage. + This allows producers to start producing new data. + """ + if handle is not None: + assert ( + handle.get_origin() is self + ), "ResourceHandle does not belong to this PipelineConsumer instance" + handle.release() + else: + self.__pipeline.consumer_release(self.__state) + + def __extract_mlir_values__(self): + """Extract MLIR values from the current state. + + :return: List of MLIR values representing the current state + :rtype: list + """ + return self.__state.__extract_mlir_values__() + + def __new_from_mlir_values__(self, values): + """Create a new Consumer instance from MLIR values. + + :param values: MLIR values to initialize the state + :type values: Any + :return: New Consumer instance with state initialized from values + :rtype: Consumer + """ + # TODO: need to call pipeline.__new_from_mlir_values__ recursively + return PipelineConsumer( + self.__pipeline, self.__state.__new_from_mlir_values__(values), self.__group + ) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/torch.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/torch.py new file mode 100644 index 0000000000000000000000000000000000000000..e5ee5777cad35487f30b8705ff19747405d11194 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/torch.py @@ -0,0 +1,311 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +import ctypes +from math import prod +from dataclasses import dataclass +from enum import Enum +from typing import Optional, Type, Union + +from cutlass.cute.typing import ( + Numeric, + Boolean, + Float, + Integer, + TFloat32, + Float8E4M3B11FNUZ, + Float8E4M3FN, + Float8E5M2, + Float8E8M0FNU, + Float4E2M1FN, + Tensor, +) +from cutlass.cute.runtime import from_dlpack +import cutlass.cute as cute +import torch +import cuda.bindings.driver as cuda + + +def dtype(ty: Type[Numeric]): + """ + Return the corresponding torch.dtype per the given DSL type + """ + torch_dtype = getattr(torch, ty.__name__.lower(), None) + + torch_type_map = { + Boolean: torch.bool, + # TFloat32 is just alias of float32 + TFloat32: torch.float32, + Float8E5M2: torch.float8_e5m2, + Float8E4M3FN: torch.float8_e4m3fn, + Float8E4M3B11FNUZ: torch.float8_e4m3fnuz, + } + if torch_dtype is None: + torch_dtype = torch_type_map.get(ty) + + if torch_dtype is None: + raise TypeError(f"{ty} is not supported by torch") + return torch_dtype + + +def as_tensor(pointer, shape, torch_type): + """Convert a pointer to a torch tensor""" + if torch_type.itemsize == 1: + cytype = ctypes.c_uint8 + elif torch_type.itemsize == 2: + cytype = ctypes.c_uint16 + elif torch_type.itemsize == 4: + cytype = ctypes.c_uint32 + elif torch_type.itemsize == 8: + cytype = ctypes.c_uint64 + else: + raise ValueError(f"Unsupported torch dtype: {torch_type}") + cpointer = ctypes.cast(pointer, ctypes.POINTER(cytype)) + arr = (cpointer._type_ * prod(shape)).from_address( + ctypes.addressof(cpointer.contents) + ) + return torch.frombuffer(arr, dtype=torch_type).view(*shape) + + +@dataclass +class ScalarInitConfig: + """Configuration for scalar initialization""" + + value: float = 0.0 + + +@dataclass +class RandomInitConfig: + """Configuration for random initialization""" + + min_val: int = -2 + max_val: int = 2 + + +@dataclass +class GaussianInitConfig: + """Configuration for Gaussian initialization""" + + mean: float = 0.0 + std: float = 1.0 + scale: float = 1.0 + + +class TensorInitType(Enum): + """Enumeration of tensor initialization types""" + + SKIP = "skip" + SCALAR = "scalar" + RANDOM = "random" + GAUSSIAN = "gaussian" + + +def create_and_permute_torch_tensor( + shape, + dtype: "torch.dtype", + permute_order=None, + init_type: TensorInitType = TensorInitType.RANDOM, + init_config: Optional[ + Union[RandomInitConfig, ScalarInitConfig, GaussianInitConfig] + ] = None, + device: Optional[torch.device] = None, +) -> "torch.Tensor": + """ + Create a torch tensor with specified shape and dtype. Optionally permute it and initialize it with specified init type and config + """ + init_dtype = torch.int32 if init_type == TensorInitType.RANDOM else torch.float32 + init_torch_tensor = torch.empty(*shape, dtype=init_dtype, device=device) + if init_type == TensorInitType.SKIP: + assert init_config is None + f32_torch_tensor = init_torch_tensor + elif init_type == TensorInitType.SCALAR: + if init_config is None: + init_config = ScalarInitConfig() + else: + if not isinstance(init_config, ScalarInitConfig): + raise ValueError("init_config must be ScalarInitConfig()") + f32_torch_tensor = init_torch_tensor.fill_(init_config.value) + elif init_type == TensorInitType.RANDOM: + if init_config is None: + init_config = RandomInitConfig() + else: + if not isinstance(init_config, RandomInitConfig): + raise ValueError("init_config must be RandomInitConfig()") + f32_torch_tensor = init_torch_tensor.random_( + init_config.min_val, init_config.max_val + ).to(dtype=torch.float32) + elif init_type == TensorInitType.GAUSSIAN: + if init_config is None: + init_config = GaussianInitConfig() + else: + if not isinstance(init_config, GaussianInitConfig): + raise ValueError("init_config must be GaussianInitConfig()") + f32_torch_tensor = init_torch_tensor.normal_(init_config.mean, init_config.std) + f32_torch_tensor = f32_torch_tensor * init_config.scale + else: + raise ValueError(f"Invalid init type: {init_type}") + + if permute_order is not None: + f32_torch_tensor = f32_torch_tensor.permute(permute_order) + + dtype_torch_tensor = f32_torch_tensor.to(dtype=dtype) + + return dtype_torch_tensor + + +def convert_cute_tensor( + f32_torch_tensor: "torch.Tensor", + cute_tensor: Tensor, + dtype: Type[Numeric], + is_dynamic_layout: bool = True, +) -> Tensor: + """ + Change the value of the cute tensor to make its value converted from a fp32 torch tensor. + Used for fp8 types tensor creatation now. + """ + # if torch_tensor is on cpu, create a gpu copy + if f32_torch_tensor.device.type == "cpu": + f32_torch_tensor = f32_torch_tensor.cuda() + + # Fp8 type need explicit type conversion + if dtype in { + Float8E5M2, + Float8E4M3FN, + Float8E8M0FNU, + Float4E2M1FN, + }: + fp32_cute_tensor = from_dlpack(f32_torch_tensor) + if is_dynamic_layout: + fp32_cute_tensor = fp32_cute_tensor.mark_layout_dynamic( + f32_torch_tensor.dim_order()[-1] + ) + # Copy and convert from f32 cute tensor to dtype cute tensor + cute.testing.convert(fp32_cute_tensor, cute_tensor) + return cute_tensor + + +def default_stream() -> cuda.CUstream: + """ + Get default CUstream from torch stream + """ + torch_stream = torch.cuda.default_stream() + stream = cuda.CUstream(torch_stream.cuda_stream) + return stream + + +def current_stream() -> cuda.CUstream: + """ + Get current CUstream from torch stream + """ + torch_stream = torch.cuda.current_stream() + stream = cuda.CUstream(torch_stream.cuda_stream) + return stream + + +def matrix( + l: int, + mode0: int, + mode1: int, + is_mode0_major: bool, + cutlass_dtype: Type[Numeric], + init_type: TensorInitType = TensorInitType.RANDOM, + init_config: Optional[ + Union[RandomInitConfig, ScalarInitConfig, GaussianInitConfig] + ] = None, + device: Optional[torch.device] = None, +) -> torch.Tensor: + """ + Create a torch tensor for matrix + + :param l: length of the matrix + :param mode0: mode0 of the matrix + :param mode1: mode1 of the matrix + :param is_mode0_major: whether the matrix is mode0 major + :param cutlass_dtype: cutlass dtype of the matrix + :param init_type: type of initialization + :param init_config: configuration for initialization + :param device: target torch device + """ + + shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1) + permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0) + + if cutlass_dtype.is_float and cutlass_dtype.width <= 8: + torch_dtype = torch.int8 + else: + torch_dtype = dtype(cutlass_dtype) + + if init_type == TensorInitType.RANDOM and init_config is None: + if torch_dtype.is_signed: + min_val = -2 + max_val = 2 + else: + min_val = 0 + max_val = 4 + init_config = RandomInitConfig(min_val=min_val, max_val=max_val) + + # Create dtype torch tensor + torch_tensor = create_and_permute_torch_tensor( + shape, + torch_dtype, + permute_order=permute_order, + init_type=init_type, + init_config=init_config, + device=device, + ) + + return torch_tensor + + +def cute_tensor_like( + data_ref: torch.Tensor, + cutlass_dtype: Type[Numeric], + is_dynamic_layout: bool, + assumed_align: Optional[int] = None, +) -> tuple[Tensor, torch.Tensor]: + """ + Create a cute tensor use a torch tensor as the data source + + :param data_ref: torch tensor as the data source + :param cutlass_dtype: cutlass dtype of the cute tensor + :param is_dynamic_layout: whether the cute tensor uses dynamic layout + :param assumed_align: assumed alignment of the cute tensor + """ + + # allocate device buffer for cute tensor + if cutlass_dtype.is_float and cutlass_dtype.width <= 8: + torch_dtype = torch.int8 + else: + torch_dtype = dtype(cutlass_dtype) + torch_tensor = torch.empty_like(data_ref, dtype=torch_dtype, device="cuda") + + # create cute tensor using the device buffer + cute_tensor = from_dlpack(torch_tensor, assumed_align=assumed_align) + cute_tensor.element_type = cutlass_dtype + if is_dynamic_layout: + for i, stride in enumerate(torch_tensor.stride()): + if stride == 1: + leading_dim = i + break + cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=leading_dim) + + # initialize the cute tensor data + if cutlass_dtype.is_float and cutlass_dtype.width <= 8: + cute_tensor = convert_cute_tensor( + data_ref.to(dtype=torch.float32), + cute_tensor, + cutlass_dtype, + is_dynamic_layout, + ) + else: + torch_tensor.copy_(data_ref.to(dtype=torch_dtype)) + + return cute_tensor, torch_tensor diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aec0a186d7a8fc18d65637e97905c7cd5702310d --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/__init__.py @@ -0,0 +1,93 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from .static_persistent_tile_scheduler import ( + WorkTileInfo, + PersistentTileSchedulerParams, + StaticPersistentTileScheduler, +) + +from .hardware_info import ( + HardwareInfo, +) + +from .blackwell_helpers import ( + compute_epilogue_tile_shape, + get_smem_store_op, + get_tmem_load_op, + get_num_tmem_alloc_cols, + make_smem_layout_a, + make_smem_layout_b, + make_smem_layout_epi, + make_trivial_tiled_mma, + make_blockscaled_trivial_tiled_mma, +) + +from .hopper_helpers import ( + sm90_get_smem_store_op, +) + +from .blockscaled_layout import ( + BlockScaledBasicChunk, + tile_atom_to_shape_SF, + make_smem_layout_sfa, + make_smem_layout_sfb, + make_tmem_layout_sfa, + make_tmem_layout_sfb, +) + +from .grouped_gemm_tile_scheduler_helper import ( + GroupSearchResult, + GroupedGemmGroupSearchState, + GroupedGemmTileSchedulerHelper, + create_initial_search_state, +) + +from .tensormap_manager import ( + TensorMapUpdateMode, + TensorMapManager, +) + +from .smem_allocator import SmemAllocator + +from .layout import LayoutEnum + +from .smem_capacity import ( + get_smem_capacity_in_bytes, +) + +from .distributed_helpers import ( + spin_lock_wait, + spin_lock_multimem_arrive, + multimem_ld_reduce_8xf16, + multimem_ld_reduce_4xf32, + multimem_ld_reduce_8xbf16, + multimem_ld_reduce_16xe4m3, + multimem_ld_reduce_16xe5m2, + multimem_st_4xb32, + sm_wise_inter_gpu_multimem_barrier, +) + +__all__ = [ + "get_smem_capacity_in_bytes", + "SmemAllocator", + "LayoutEnum", + "WorkTileInfo", + "PersistentTileSchedulerParams", + "StaticPersistentTileScheduler", + "TensorMapUpdateMode", + "TensorMapManager", + "GroupSearchResult", + "GroupedGemmGroupSearchState", + "create_initial_search_state", + "GroupedGemmTileSchedulerHelper", + "HardwareInfo", +] diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/ampere_helpers.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/ampere_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..1341756f3584f89b0c201631445beb91c34dc29e --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/ampere_helpers.py @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from enum import Enum +from typing_extensions import deprecated +import warnings + + +@deprecated("Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead") +class SmemCapacity(Enum): + SM80_SMEM_CAPACITY_BYTES = (164 - 1) * 1024 + SM86_SMEM_CAPACITY_BYTES = (100 - 1) * 1024 + SM89_SMEM_CAPACITY_BYTES = (100 - 1) * 1024 + + +warnings.warn( + "SMEM_CAPACITY is deprecated: Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead", + DeprecationWarning, + stacklevel=2, +) +# Dictionary to map compute capability to SMEM capacity +SMEM_CAPACITY = { + "sm80": SmemCapacity.SM80_SMEM_CAPACITY_BYTES.value, + "sm86": SmemCapacity.SM86_SMEM_CAPACITY_BYTES.value, + "sm89": SmemCapacity.SM89_SMEM_CAPACITY_BYTES.value, +} diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/blackwell_helpers.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/blackwell_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..6fb6bf4dbfa3e73f058037e79b0999697d720502 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/blackwell_helpers.py @@ -0,0 +1,1135 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from enum import Enum +from math import log2, ceil +from typing import List, Type, Union, Tuple +from typing_extensions import deprecated +import warnings + +from cutlass.cutlass_dsl import ( + Float16, + BFloat16, + TFloat32, + Float32, + Uint8, + Int8, + Float8E4M3FN, + Float8E5M2, + Float4E2M1FN, + Numeric, + NumericMeta, + dsl_user_op, +) +import cutlass.cute as cute +from cutlass.cute.nvgpu.common import CopyUniversalOp +from cutlass.cute.nvgpu.warp import StMatrix8x8x16bOp, StMatrix16x8x8bOp +from cutlass.cute.nvgpu.tcgen05 import ( + MmaF16BF16Op, + MmaTF32Op, + MmaI8Op, + MmaFP8Op, + MmaMXF8Op, + MmaMXF4Op, + MmaMXF4NVF4Op, + OperandSource, + OperandMajorMode, + CtaGroup, + Ld16x64bOp, + Ld16x128bOp, + Ld16x256bOp, + Ld16x32bx2Op, + Ld32x32bOp, + Repetition, + Pack, + find_tmem_tensor_col_offset, + SmemLayoutAtomKind, + make_smem_layout_atom, + tile_to_mma_shape, + is_tmem_load, + get_tmem_copy_properties, +) +from cutlass.cute.nvgpu.cpasync import ( + CopyBulkTensorTileG2SMulticastOp, + CopyBulkTensorTileG2SOp, +) +from cutlass.utils.layout import LayoutEnum + + +@deprecated("Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead") +class SmemCapacity(Enum): + SM100_SMEM_CAPACITY_BYTES = (228 - 1) * 1024 + SM120_SMEM_CAPACITY_BYTES = (100 - 1) * 1024 + + +warnings.warn( + "SMEM_CAPACITY is deprecated: Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead", + DeprecationWarning, + stacklevel=2, +) +# Dictionary to map compute capability to SMEM capacity +SMEM_CAPACITY = { + "sm100": SmemCapacity.SM100_SMEM_CAPACITY_BYTES.value, + "sm120": SmemCapacity.SM120_SMEM_CAPACITY_BYTES.value, +} + + +@dsl_user_op +def compute_epilogue_tile_shape( + cta_tile_shape: cute.Shape, + use_2cta_instrs: bool, + layout_d: LayoutEnum, + elem_ty_d: Type[Numeric], + *, + layout_c: LayoutEnum = None, + elem_ty_c: Union[Type[Numeric], None] = None, + loc=None, + ip=None, +) -> cute.Tile: + """Attempts to compute a reasonable epilogue tile based on block tile shape or allows the user to provide one. + + :param cta_tile_shape: A tuple or list representing the dimensions of the CTA tile, where + cta_tile_shape[0] corresponds to the height (M) and cta_tile_shape[1] + corresponds to the width (N) of the tile. + :type cta_tile_shape: cute.Shape + :param use_2cta_instrs: A flag indicating whether the configuration is for a 2SM setup. + :type use_2cta_instrs: bool + :param layout_d: The layout enum of the output tensor D. + :type layout_d: LayoutEnum + :param elem_ty_d: The element type of output tensor D. + :type elem_ty_d: Type[Numeric] + :param layout_c: The layout enum of the input tensor C. Defaults to None. + :type layout_c: LayoutEnum, optional + :param elem_ty_c: The element type for input tensor C. Defaults to None. + :type elem_ty_c: Union[Type[Numeric], None], optional + + :return: Returns epilog tiler, which is used in subsequent epilog partitions. + :rtype: cute.Tile + + :raises ValueError: If the computed tile cute.size does not meet minimum requirements based on CTA dimensions. + """ + + def validate_type(ty, ty_name): + if not isinstance(ty, NumericMeta): + raise TypeError(f"{ty_name} must be Numeric, but got {ty}") + + validate_type(elem_ty_d, "elem_ty_d") + if elem_ty_c is not None: + validate_type(elem_ty_c, "elem_ty_c") + + cta_m, cta_n = cta_tile_shape[:2] + (warp_m, warp_n) = (2, 2) if (cta_m == 64 and use_2cta_instrs) else (4, 1) + disable_source = elem_ty_c == None + max_bits = ( + elem_ty_d.width if disable_source else max(elem_ty_c.width, elem_ty_d.width) + ) + + dp_full = 32 + tile_m = min(cta_m, dp_full * warp_m) + n_perf = 0 + if disable_source: + if max_bits == 4: + compute_elts = 8192 + else: + compute_elts = 4096 + n_perf = compute_elts // tile_m + else: + if max_bits == 32: + n_perf = 16 if (cta_m > 64 and cta_n <= 128) else 32 + elif max_bits == 16: + n_perf = 32 if cta_n <= 128 else 64 + else: + n_perf = 64 + + d_is_m_major = layout_d.is_m_major_c() + c_is_m_major = True if layout_c is None else layout_c.is_m_major_c() + + n_min_d = ( + 8 * warp_n + if d_is_m_major + else (128 * warp_n if elem_ty_d.width == 6 else 128 // elem_ty_d.width * warp_n) + ) + n_min_c = ( + 8 * warp_n + if (c_is_m_major or disable_source) + else (128 * warp_n if elem_ty_c.width == 6 else 128 // elem_ty_c.width * warp_n) + ) + tile_n = min(cta_n, max(n_perf, n_min_c, n_min_d)) + + if cta_n < n_min_c or cta_n < n_min_d: + raise ValueError(f"CTA tile too small: {cta_tile_shape=}") + + # stride by tmem warp layout and return a by-mode tiler + tile_m_layout = cute.make_layout(tile_m, loc=loc, ip=ip) + tile_n_layout = cute.make_layout( + (tile_n // warp_n, warp_n), stride=(1, cta_n // warp_n), loc=loc, ip=ip + ) + return (tile_m_layout, cute.coalesce(tile_n_layout, loc=loc, ip=ip)) + + +@dsl_user_op +def get_smem_store_op( + layout_d: LayoutEnum, + elem_ty_d: Type[Numeric], + elem_ty_acc: Type[Numeric], + tiled_tmem_load: cute.TiledCopy, + *, + loc=None, + ip=None, +) -> cute.CopyAtom: + """Selects the largest vectorized smem store atom available subject to + constraint of gmem layout and chosen TMEM_LOAD's thread-value ownership. + + :param layout_d: The layout enum of the output tensor D. + :type layout_d: LayoutEnum + :param elem_ty_d: The element type for output tensor D. + :type elem_ty_d: Type[Numeric] + :param elem_ty_acc: The element type for accumulator. + :type elem_ty_acc: Type[Numeric] + :param tiled_tmem_load: An instance of TiledCopy that represents the tmem load operation. + :type tiled_tmem_load: cute.TiledCopy + + :return: Either SmemStoreMatrix or SimtSyncCopy, based on the input parameters. + :rtype: cute.CopyAtom + """ + + def validate_type(ty, ty_name): + if not isinstance(ty, NumericMeta): + raise TypeError(f"{ty_name} must be a Numeric, but got {ty}") + + validate_type(elem_ty_d, "elem_ty_d") + validate_type(elem_ty_acc, "elem_ty_acc") + + is_m_major = layout_d.is_m_major_c() + is_n_major = layout_d.is_n_major_c() + + if not is_tmem_load(tiled_tmem_load): + return cute.make_copy_atom(CopyUniversalOp(), elem_ty_d, loc=loc, ip=ip) + + num_dp, num_bits, num_rep, pack = get_tmem_copy_properties(tiled_tmem_load) + + use_stmatrix_m8n8_4x = ( + all( + [ + elem_ty_acc.width == 32, + elem_ty_d.width == 32, + is_n_major, + num_dp == 16, + num_bits == 128, + num_rep in (2, 4, 8, 16, 32, 64), + pack == Pack.NONE, + ] + ) + or all( + [ + elem_ty_acc.width == 32, + elem_ty_d.width == 16, + num_dp == 16, + num_bits == 256, + num_rep in (2, 4, 8, 16, 32), + pack == Pack.NONE, + ] + ) + or all( + [ + elem_ty_acc.width == 16, + elem_ty_d.width == 16, + num_dp == 16, + num_bits == 128, + num_rep in (2, 4, 8, 16, 32, 64), + pack == Pack.PACK_16b_IN_32b, + ] + ) + ) + use_stmatrix_m16n8_4x = all( + [ + elem_ty_acc.width == 32, + elem_ty_d.width == 8, + is_m_major, + num_dp == 16, + num_bits == 256, + num_rep in (4, 8, 16, 32), + pack == Pack.NONE, + ] + ) + use_stmatrix_m8n8_2x = ( + all( + [ + elem_ty_acc.width == 32, + elem_ty_d.width == 32, + is_n_major, + num_dp == 16, + num_bits == 128, + num_rep == 1, + pack == Pack.NONE, + ] + ) + or all( + [ + elem_ty_acc.width == 32, + elem_ty_d.width == 16, + num_dp == 16, + num_bits == 256, + num_rep == 1, + pack == Pack.NONE, + ] + ) + or all( + [ + elem_ty_acc.width == 16, + elem_ty_d.width == 16, + num_dp == 16, + num_bits == 128, + num_rep == 1, + pack == Pack.PACK_16b_IN_32b, + ] + ) + ) + use_stmatrix_m16n8_2x = all( + [ + elem_ty_acc.width == 32, + elem_ty_d.width == 8, + is_m_major, + num_dp == 16, + num_bits == 256, + num_rep == 2, + pack == Pack.NONE, + ] + ) + use_stmatrix_m16n8_1x = all( + [ + elem_ty_acc.width == 32, + elem_ty_d.width == 8, + is_m_major, + num_dp == 16, + num_bits == 256, + num_rep == 1, + pack == Pack.NONE, + ] + ) + + if use_stmatrix_m8n8_4x: + op = StMatrix8x8x16bOp(is_m_major, 4) + return cute.make_copy_atom(op, elem_ty_d, loc=loc, ip=ip) + elif use_stmatrix_m8n8_2x: + op = StMatrix8x8x16bOp(is_m_major, 2) + return cute.make_copy_atom(op, elem_ty_d, loc=loc, ip=ip) + elif use_stmatrix_m16n8_4x: + op = StMatrix16x8x8bOp(4) + return cute.make_copy_atom(op, elem_ty_d, loc=loc, ip=ip) + elif use_stmatrix_m16n8_2x: + op = StMatrix16x8x8bOp(2) + return cute.make_copy_atom(op, elem_ty_d, loc=loc, ip=ip) + elif use_stmatrix_m16n8_1x: + op = StMatrix16x8x8bOp(1) + return cute.make_copy_atom(op, elem_ty_d, loc=loc, ip=ip) + else: + op = CopyUniversalOp() + return cute.make_copy_atom(op, elem_ty_d, loc=loc, ip=ip) + + +@dsl_user_op +def get_tmem_load_op( + cta_tile_shape: cute.Shape, + layout_d: LayoutEnum, + elem_ty_d: Type[Numeric], + elem_ty_acc: Type[Numeric], + epi_tile: cute.Tile, + use_2cta_instrs: bool, + *, + loc=None, + ip=None, +) -> cute.CopyAtom: + """Finds a performant TMEM_LOAD copy op for the selected epilogue + tile (epi_tile), element types, and tcgen05.mma instruction used. + + :param cta_tile_shape: A tuple or list representing the dimensions of the CTA tile. + :type cta_tile_shape: cute.Shape + :param layout_d: The layout enum of the output tensor D. + :type layout_d: LayoutEnum + :param elem_ty_d: The element type for output tensor D. + :type elem_ty_d: Type[Numeric] + :param elem_ty_acc: The element type for accumulation. + :type elem_ty_acc: Type[Numeric] + :param epi_tile: The epilogue tile configuration. + :type epi_tile: cute.Tile + :param use_2cta_instrs: A flag indicating whether the configuration is for 2 SMs. + :type use_2cta_instrs: bool + + :return: An instance of Sm100TmemLoad with the computed configuration. + :rtype: cute.CopyAtom + + :raises ValueError: If the function cannot handle the given combination of accumulation + and dimension types, or if it cannot determine the appropriate configuration based on + the input parameters. + """ + is_m_major = layout_d.is_m_major_c() + + acc_bits = elem_ty_acc.width + d_bits = elem_ty_d.width + + tmem_warp_shape_mn = ( + (2, 2) if (cta_tile_shape[0] == 64 and use_2cta_instrs) else (4, 1) + ) + epilog_tile_shape_mn = cute.product_each( + cute.shape(epi_tile, loc=loc, ip=ip), loc=loc, ip=ip + ) + epilog_warp_tile_shape_mn = cute.shape_div( + epilog_tile_shape_mn, tmem_warp_shape_mn, loc=loc, ip=ip + ) + + num_dp = cute.size(epilog_warp_tile_shape_mn[0], loc=loc, ip=ip) + if num_dp not in {16, 32}: + raise ValueError("Cta tile and 2sm config does not generate correct num dp.") + + num_col_bits = cute.size(epilog_warp_tile_shape_mn[1], loc=loc, ip=ip) * acc_bits + + tmem_dp = 0 + tmem_bit = 0 + tmem_rep = 0 + tmem_pack16b = False + if acc_bits == 32 and d_bits == 32: + if num_dp == 16: + if is_m_major: + tmem_dp = 16 + tmem_bit = 256 + else: + tmem_dp = 16 + tmem_bit = 128 + else: + tmem_dp = 32 + tmem_bit = 32 + elif acc_bits == 32 and d_bits == 16: + if num_dp == 16: + if is_m_major: + tmem_dp = 16 + tmem_bit = 256 + else: + tmem_dp = 16 + tmem_bit = 256 + else: + if is_m_major: + tmem_dp = 16 + tmem_bit = 256 + else: + tmem_dp = 32 + tmem_bit = 32 + elif acc_bits == 32 and d_bits == 8: + if num_dp == 16: + if is_m_major: + tmem_dp = 16 + tmem_bit = 256 + else: + tmem_dp = 16 + tmem_bit = 32 + else: + if is_m_major: + tmem_dp = 16 + tmem_bit = 256 + else: + tmem_dp = 32 + tmem_bit = 32 + elif acc_bits == 16 and d_bits == 16: + tmem_pack16b = True + if num_dp == 16: + if is_m_major: + tmem_dp = 16 + tmem_bit = 128 + else: + tmem_dp = 16 + tmem_bit = 128 + else: + if is_m_major: + tmem_dp = 16 + tmem_bit = 128 + else: + tmem_dp = 32 + tmem_bit = 32 + elif acc_bits == 32 and d_bits == 6: + if not num_dp == 32: + raise ValueError("Num dp must be 32.") + tmem_dp = 32 + tmem_bit = 32 + elif acc_bits == 32 and d_bits == 4: + if not num_dp == 32: + raise ValueError("Num dp must be 32.") + tmem_dp = 32 + tmem_bit = 32 + else: + raise ValueError( + f"Can not handle acc/d type combination: {elem_ty_acc=}, {elem_ty_d=}" + ) + + num_bit_div = tmem_bit + if tmem_dp == 16 and tmem_bit == 32: + num_bit_div = 64 + + if (num_col_bits % (num_bit_div * 128) == 0) and ( + (tmem_dp == 16 and tmem_bit == 64) + or (tmem_dp == 16 and tmem_bit == 32) + or (tmem_dp == 32 and tmem_bit == 32) + ): + tmem_rep = 128 + elif (num_col_bits % (num_bit_div * 64) == 0) and ( + (tmem_dp == 16 and tmem_bit == 128) + or (tmem_dp == 16 and tmem_bit == 64) + or (tmem_dp == 16 and tmem_bit == 32) + or (tmem_dp == 32 and tmem_bit == 32) + ): + tmem_rep = 64 + elif num_col_bits % (num_bit_div * 32) == 0: + tmem_rep = 32 + elif num_col_bits % (num_bit_div * 16) == 0: + tmem_rep = 16 + elif num_col_bits % (num_bit_div * 8) == 0: + tmem_rep = 8 + elif num_col_bits % (num_bit_div * 4) == 0: + tmem_rep = 4 + elif num_col_bits % (num_bit_div * 2) == 0: + tmem_rep = 2 + elif num_col_bits % (num_bit_div * 1) == 0: + tmem_rep = 1 + else: + raise ValueError("Can not pick tmem_rep based on cta tile shape and tmem atom.") + + if tmem_dp == 16 and tmem_bit == 64: + op = Ld16x64bOp( + Repetition(tmem_rep), Pack.PACK_16b_IN_32b if tmem_pack16b else Pack.NONE + ) + return cute.make_copy_atom(op, elem_ty_acc, loc=loc, ip=ip) + elif tmem_dp == 16 and tmem_bit == 128: + op = Ld16x128bOp( + Repetition(tmem_rep), Pack.PACK_16b_IN_32b if tmem_pack16b else Pack.NONE + ) + return cute.make_copy_atom(op, elem_ty_acc, loc=loc, ip=ip) + elif tmem_dp == 16 and tmem_bit == 256: + op = Ld16x256bOp( + Repetition(tmem_rep), Pack.PACK_16b_IN_32b if tmem_pack16b else Pack.NONE + ) + return cute.make_copy_atom(op, elem_ty_acc, loc=loc, ip=ip) + elif tmem_dp == 16 and tmem_bit == 32: + op = Ld16x32bx2Op( + Repetition(tmem_rep), Pack.PACK_16b_IN_32b if tmem_pack16b else Pack.NONE + ) + return cute.make_copy_atom(op, elem_ty_acc, loc=loc, ip=ip) + + elif tmem_dp == 32 and tmem_bit == 32: + op = Ld32x32bOp( + Repetition(tmem_rep), Pack.PACK_16b_IN_32b if tmem_pack16b else Pack.NONE + ) + return cute.make_copy_atom(op, elem_ty_acc, loc=loc, ip=ip) + else: + raise ValueError() + + +def get_num_tmem_alloc_cols( + tmem_tensors: Union[cute.Tensor, List[cute.Tensor]], rounding=True +) -> int: + """Get the total number of TMEM allocation columns for the given TMEM tensors. + + :param tmem_tensors: The TMEM tensors to get the number of allocation columns for. + :type tmem_tensors: Union[cute.Tensor, List[cute.Tensor]] + :param rounding: Whether to round up the number of allocation columns to the nearest power of 2. + :type rounding: bool + + :return: The total number of TMEM allocation columns. + :rtype: int + + :raises ValueError: If the number of TMEM allocation columns exceeds the maximum capacity of 512 or is less than 32. + """ + # Turn tmem_tensors into a list + if isinstance(tmem_tensors, cute.Tensor): + tmem_tensors = [tmem_tensors] + + # For each tensor in tmem_tensors, find the tmem_tensor_col_offset + num_tmem_alloc_cols_per_tensor = [ + find_tmem_tensor_col_offset(t) for t in tmem_tensors + ] + + # Sum up the num_tmem_alloc_cols_per_tensor + num_tmem_alloc_cols = sum(num_tmem_alloc_cols_per_tensor) + + # Round up num_tmem_cols_total to the nearest power of 2 + if rounding: + num_tmem_alloc_cols = 1 << ceil(log2(num_tmem_alloc_cols)) + + # Validate the number of TMEM allocation columns + SM100_TMEM_CAPACITY_COLUMNS = 512 + SM100_TMEM_MIN_ALLOC_COLUMNS = 32 + if ( + num_tmem_alloc_cols > SM100_TMEM_CAPACITY_COLUMNS + or num_tmem_alloc_cols < SM100_TMEM_MIN_ALLOC_COLUMNS + ): + raise ValueError( + f"TMEM allocation columns {num_tmem_alloc_cols} exceeds the maximum capacity of {SM100_TMEM_CAPACITY_COLUMNS} or less than {SM100_TMEM_MIN_ALLOC_COLUMNS}" + ) + return num_tmem_alloc_cols + + +def get_smem_layout_atom_ab( + major_mode: OperandMajorMode, + element_type: Type[Numeric], + smem_shape_mn_k: Tuple[int, int], + *, + loc=None, + ip=None, +) -> SmemLayoutAtomKind: + """Simple heuristics to select the optimal SMEM layout atom based on the + majorness, the data type, and the major mode size. + + :param major_mode: The major mode for the SMEM tensor is K major. + :type major_mode: OperandMajorMode + :param element_type: The element type for the SMEM tensor. + :type element_type: Type[Numeric] + :param smem_shape_mn_k: The shape of the SMEM tensor. + :type smem_shape_mn_k: Tuple[int, int] + + :return: The SMEM layout atom kind + :rtype: SmemLayoutAtomKind + """ + is_k_major = major_mode == OperandMajorMode.K + major_mode_size = smem_shape_mn_k[1] if is_k_major else smem_shape_mn_k[0] + + assert major_mode_size % 8 == 0 + sw128_num_contiguous_bits = 1024 + sw64_num_contiguous_bits = 512 + sw32_num_contiguous_bits = 256 + inter_num_contiguous_bits = 128 + major_mode_size_bits = major_mode_size * element_type.width + assert major_mode_size_bits % inter_num_contiguous_bits == 0 + + if not is_k_major: + if (element_type.width == 32) and ( + major_mode_size_bits % sw128_num_contiguous_bits == 0 + ): + return SmemLayoutAtomKind.MN_SW128_32B + if major_mode_size_bits % sw128_num_contiguous_bits == 0: + return SmemLayoutAtomKind.MN_SW128 + if major_mode_size_bits % sw64_num_contiguous_bits == 0: + return SmemLayoutAtomKind.MN_SW64 + if major_mode_size_bits % sw32_num_contiguous_bits == 0: + return SmemLayoutAtomKind.MN_SW32 + return SmemLayoutAtomKind.MN_INTER + if major_mode_size_bits % sw128_num_contiguous_bits == 0: + return SmemLayoutAtomKind.K_SW128 + if major_mode_size_bits % sw64_num_contiguous_bits == 0: + return SmemLayoutAtomKind.K_SW64 + if major_mode_size_bits % sw32_num_contiguous_bits == 0: + return SmemLayoutAtomKind.K_SW32 + return SmemLayoutAtomKind.K_INTER + + +@dsl_user_op +def make_smem_layout_a( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: cute.Tile, + a_dtype: Type[Numeric], + num_stages: int, + *, + loc=None, + ip=None, +) -> Union[cute.Layout, cute.ComposedLayout]: + """This function helps with: + 1. Get the partitioned shape of the A tensor based on the tiled_mma & MMA tiler. + 2. Select the heuristic SMEM layout atom based on the A tensor's majorness, the data type, and the major mode size. + 3. cute.Tile the SMEM layout atom to the MMA tile shape. + 4. Stage the SMEM layout based on the number of stages. + + :param tiled_mma: The tiled MMA used to partition tensor A + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The MMA tile shape + :type mma_tiler_mnk: cute.cute.Tile + :param a_dtype: The element type for tensor A + :type a_dtype: Type[Numeric] + :param num_stages: The number of pipeline stages for tensor A + :type num_stages: int + + :return: SMEM layout for tensor A + :rtype: Union[cute.Layout, cute.ComposedLayout] + """ + + is_k_major = tiled_mma.op.a_major_mode == OperandMajorMode.K + a_smem_shape = tiled_mma.partition_shape_A( + cute.dice(mma_tiler_mnk, (1, None, 1), loc=loc, ip=ip) + ) + a_smem_shape_mn_k = ( + cute.size(a_smem_shape[0][0], loc=loc, ip=ip) * a_smem_shape[1], + cute.size(a_smem_shape[0][1], loc=loc, ip=ip) * a_smem_shape[2], + ) + a_smem_layout_atom = make_smem_layout_atom( + get_smem_layout_atom_ab( + tiled_mma.op.a_major_mode, + a_dtype, + a_smem_shape_mn_k, + loc=loc, + ip=ip, + ), + a_dtype, + loc=loc, + ip=ip, + ) + a_smem_layout_staged = tile_to_mma_shape( + a_smem_layout_atom, + cute.append(a_smem_shape, num_stages, loc=loc, ip=ip), + order=((1, 0, 2) if not is_k_major else (0, 1, 2)), + loc=loc, + ip=ip, + ) + return a_smem_layout_staged + + +@dsl_user_op +def make_smem_layout_b( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: cute.Tile, + b_dtype: Type[Numeric], + num_stages: int, + *, + loc=None, + ip=None, +) -> Union[cute.Layout, cute.ComposedLayout]: + """This function helps: + 1. Get the partitioned shape of the B tensor based on the tiled_mma & MMA tiler. + 2. Select the heuristic SMEM layout atom based on the B tensor's majorness, the data type, and the major mode size. + 3. cute.Tile the SMEM layout atom to the MMA tile shape. + 4. Stage the SMEM layout based on the number of stages. + + :param tiled_mma: The tiled MMA which is used to partition the B tensor. + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The MMA tile shape. + :type mma_tiler_mnk: cute.cute.Tile + :param b_dtype: The element type for the B tensor. + :type b_dtype: Type[Numeric] + :param num_stages: The stage of the B tensor. + :type num_stages: int + + :return: SMEM layout for the B tensor. + :rtype: Union[cute.Layout, cute.ComposedLayout] + """ + + is_k_major = tiled_mma.op.b_major_mode == OperandMajorMode.K + b_smem_shape = tiled_mma.partition_shape_B( + cute.dice(mma_tiler_mnk, (None, 1, 1), loc=loc, ip=ip) + ) + b_smem_shape_nk = ( + cute.size(b_smem_shape[0][0], loc=loc, ip=ip) * b_smem_shape[1], + cute.size(b_smem_shape[0][1], loc=loc, ip=ip) * b_smem_shape[2], + ) + b_smem_layout_atom = make_smem_layout_atom( + get_smem_layout_atom_ab( + tiled_mma.op.b_major_mode, + b_dtype, + b_smem_shape_nk, + loc=loc, + ip=ip, + ), + b_dtype, + loc=loc, + ip=ip, + ) + b_smem_layout_staged = tile_to_mma_shape( + b_smem_layout_atom, + cute.append(b_smem_shape, num_stages, loc=loc, ip=ip), + order=((1, 0, 2) if not is_k_major else (0, 1, 2)), + loc=loc, + ip=ip, + ) + + return b_smem_layout_staged + + +@dsl_user_op +def get_smem_layout_atom_epi( + layout: LayoutEnum, + element_type: Type[Numeric], + epi_tile: cute.Tile, + *, + loc=None, + ip=None, +) -> SmemLayoutAtomKind: + """Simple heuristics to select the optimal SMEM layout atom for epilog tensors. + + :param layout: The layout enum for the SMEM tensor. + :type layout: LayoutEnum + :param element_type: The element type for the SMEM tensor. + :type element_type: Type[Numeric] + :param epi_tile: The epilogue tile shape. + :type epi_tile: cute.Tile + + :return: The SMEM layout atom kind + :rtype: SmemLayoutAtomKind + """ + # Get the max contiguous tile usable by TMA + tma_shape = tuple( + ( + # assumes get<0>(epi_tile) is coalesced and unit stride + cute.coalesce(cute.right_inverse(x, loc=loc, ip=ip), loc=loc, ip=ip).shape + if isinstance(x, cute.Layout) + else x + ) + for x in epi_tile + ) + + if layout.is_m_major_c(): + # ColMajor C/D (M-major) + return get_smem_layout_atom_ab( + OperandMajorMode.MN, element_type, tma_shape, loc=loc, ip=ip + ) + else: + # RowMajor C/D (N-major) + return get_smem_layout_atom_ab( + OperandMajorMode.K, element_type, tma_shape, loc=loc, ip=ip + ) + + +@dsl_user_op +def make_smem_layout_epi( + epi_dtype: Type[Numeric], + epi_layout: LayoutEnum, + epi_tile: cute.Tile, + epi_stage: int, + *, + loc=None, + ip=None, +) -> Union[cute.Layout, cute.ComposedLayout]: + """This function helps: + 1. Select the heuristic SMEM layout atom based on the epilog tile shape, + the epilog tensor's majorness, and the element type. + 2. cute.Tile the SMEM layout atom to the epilog tile shape. + 3. Stage the SMEM layout based on the number of stages. + + :param epi_dtype: The element type for the epilog tensor. + :type epi_dtype: Type[Numeric] + :param epi_layout: The layout enum for the epilog tensor. + :type epi_layout: LayoutEnum + :param epi_tile: The epilogue tile shape. + :type epi_tile: cute.cute.Tile + :param epi_stage: The stage of the epilog tensor. + :type epi_stage: int + + :return: SMEM layout for epilog tensors (usually C & D which are processed in the epilog) + :rtype: Union[cute.Layout, cute.ComposedLayout] + """ + + epilog_shape = cute.product_each( + cute.shape(epi_tile, loc=loc, ip=ip), loc=loc, ip=ip + ) + + c_smem_layout_atom = make_smem_layout_atom( + get_smem_layout_atom_epi( + epi_layout, + epi_dtype, + epi_tile, + loc=loc, + ip=ip, + ), + epi_dtype, + loc=loc, + ip=ip, + ) + epi_smem_layout_staged = cute.tile_to_shape( + c_smem_layout_atom, + cute.append(epilog_shape, epi_stage, loc=loc, ip=ip), + order=((1, 0, 2) if not epi_layout.is_n_major_c() else (0, 1, 2)), + loc=loc, + ip=ip, + ) + + return epi_smem_layout_staged + + +@dsl_user_op +def make_trivial_tiled_mma( + ab_dtype: Type[Numeric], + a_leading_mode: OperandMajorMode, + b_leading_mode: OperandMajorMode, + acc_dtype: Type[Numeric], + cta_group: CtaGroup, + mma_tiler_mn: Tuple[int, int], + a_source: OperandSource = OperandSource.SMEM, + *, + loc=None, + ip=None, +) -> cute.TiledMma: + """Make a tiled MMA atom with given data type, leading dimension, cta group and mma tile shape. + By default, the MMA atom is created with SMEM operand source for A. + + :param ab_dtype: Data type of operands A and B. + :type ab_dtype: type[Numeric] + :param a_leading_mode: Leading dimension of operand A (1 for K, 0 for M/N). + :type a_leading_mode: tcgen05.OperandMajorMode + :param b_leading_mode: Leading dimension of operand B (1 for K, 0 for M/N). + :type b_leading_mode: tcgen05.OperandMajorMode + :param acc_dtype: Data type of the accumulator. + :type acc_dtype: type[Numeric] + :param cta_group: The CTA group to use. + :type cta_group: tcgen05.CtaGroup + :param mma_tiler_mn: The shape (M, N, K) of the MMA tiler. + :type mma_tiler_mn: Tuple[int, int] + :param a_source: The source of operand A (SMEM by default or TMEM). + :type a_source: OperandSource + + :return: A tiled MMA atom. + :rtype: cute.TiledMma + + :raises TypeError: If the data type is not supported. + """ + + if ab_dtype in {Float16, BFloat16}: + mma_op = MmaF16BF16Op( + ab_dtype, + acc_dtype, + (*mma_tiler_mn, 16), + cta_group, + a_source, + a_leading_mode, + b_leading_mode, + ) + elif ab_dtype in {TFloat32, Float32}: + mma_op = MmaTF32Op( + (*mma_tiler_mn, 8), + cta_group, + a_source, + a_leading_mode, + b_leading_mode, + ) + elif ab_dtype in { + Uint8, + Int8, + }: + mma_op = MmaI8Op( + ab_dtype, + (*mma_tiler_mn, 32), + cta_group, + a_source, + a_leading_mode, + b_leading_mode, + ) + elif ab_dtype in {Float8E4M3FN, Float8E5M2}: + mma_op = MmaFP8Op( + ab_dtype, + acc_dtype, + (*mma_tiler_mn, 32), + cta_group, + a_source, + a_leading_mode, + b_leading_mode, + ) + else: + raise TypeError(f"unsupported ab_dtype, got {ab_dtype}") + + return cute.make_tiled_mma(cute.make_mma_atom(mma_op)) + + +@dsl_user_op +def make_blockscaled_trivial_tiled_mma( + ab_dtype: Type[Numeric], + a_leading_mode: OperandMajorMode, + b_leading_mode: OperandMajorMode, + sf_dtype: Type[Numeric], + sf_vec_size: int, + cta_group: CtaGroup, + mma_tiler_mn: Tuple[int, int], + a_source: OperandSource = OperandSource.SMEM, + *, + loc=None, + ip=None, +) -> cute.TiledMma: + """Make a BlockScaled tiled MMA atom with given data type, leading dimension, cta group and mma tile shape. + By default, the MMA atom is created with SMEM operand source for A. + + :param ab_dtype: Data type of operands A and B. + :type ab_dtype: type[Numeric] + :param a_leading_mode: Leading dimension of operand A (1 for K, 0 for M/N). + :type a_leading_mode: tcgen05.OperandMajorMode + :param b_leading_mode: Leading dimension of operand B (1 for K, 0 for M/N). + :type b_leading_mode: tcgen05.OperandMajorMode + :param sf_dtype: Data type of the Scale Factor. + :type sf_dtype: type[Numeric] + :param sf_vec_size: The vector size of the Scale Factor. + :type sf_vec_size: int + :param cta_group: The CTA group to use. + :type cta_group: tcgen05.CtaGroup + :param mma_tiler_mn: The shape (M, N, K) of the MMA tiler. + :type mma_tiler_mn: Tuple[int, int] + :param a_source: The source of operand A (SMEM by default or TMEM). + :type a_source: OperandSource + + :return: A tiled MMA atom. + :rtype: cute.TiledMma + + :raises TypeError: If the data type is not supported. + """ + if ab_dtype in {Float8E4M3FN, Float8E5M2}: + mma_op = MmaMXF8Op( + ab_dtype, + (*mma_tiler_mn, 32), + cta_group, + a_source, + a_leading_mode, + b_leading_mode, + ) + elif ab_dtype == Float4E2M1FN: + if sf_vec_size == 32: + mma_op = MmaMXF4Op( + (*mma_tiler_mn, 64), + cta_group, + a_source, + ) + elif sf_vec_size == 16: + mma_op = MmaMXF4NVF4Op( + sf_dtype, + (*mma_tiler_mn, 64), + cta_group, + a_source, + ) + else: + raise ValueError(f"unsupported sf_vec_size, got {sf_vec_size}") + else: + raise TypeError(f"unsupported ab_dtype, got {ab_dtype}") + + return cute.make_tiled_mma(cute.make_mma_atom(mma_op)) + + +@dsl_user_op +def cluster_shape_to_tma_atom_A( + cluster_shape_mnk: cute.Shape, atom_thr_id: cute.Layout, *, loc=None, ip=None +) -> Union[CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileG2SOp]: + """ + Select the appropriate TMA copy atom for A based on the number of SMs and the multicast flag. + + :param cluster_shape_mnk: The shape of the cluster + :type cluster_shape_mnk: cute.Shape + :param atom_thr_id: The thread ID of the atom + :type atom_thr_id: cute.Layout + + :return: The appropriate TMA copy atom kind + :rtype: cpasync.CopyBulkTensorTileG2SMulticastOp or cpasync.CopyBulkTensorTileG2SOp + + :raise ValueError: If the atom_sm_cnt is invalid + :raise ValueError: If the cluster shape is not divisible by the atom SM count + """ + atom_sm_cnt = cute.size(atom_thr_id, loc=loc, ip=ip) + mcast = not (cute.size(cluster_shape_mnk, mode=[1], loc=loc, ip=ip) == 1) + cluster_size = cute.size(cluster_shape_mnk, loc=loc, ip=ip) + + if not isinstance(cluster_size, int) or not isinstance(atom_sm_cnt, int): + raise ValueError( + f"Dynamic cluster shape or atom SM count is not supported: {cluster_shape_mnk} and {atom_thr_id}" + ) + + if cute.size(cluster_shape_mnk, mode=[0], loc=loc, ip=ip) % atom_sm_cnt != 0: + raise ValueError( + f"Cluster shape not divisible by MMA size: {cluster_shape_mnk} and {atom_thr_id}" + ) + + if atom_sm_cnt == 2 and mcast: + return CopyBulkTensorTileG2SMulticastOp(CtaGroup.TWO) + elif atom_sm_cnt == 2 and not mcast: + return CopyBulkTensorTileG2SOp(CtaGroup.TWO) + elif atom_sm_cnt == 1 and mcast: + return CopyBulkTensorTileG2SMulticastOp(CtaGroup.ONE) + elif atom_sm_cnt == 1 and not mcast: + return CopyBulkTensorTileG2SOp(CtaGroup.ONE) + + raise ValueError( + f"Unsupported Configuration for SM100 TMA: {cluster_shape_mnk} and {atom_thr_id}" + ) + + +@dsl_user_op +def cluster_shape_to_tma_atom_B( + cluster_shape_mnk: cute.Shape, atom_thr_id: cute.Layout, *, loc=None, ip=None +) -> Union[CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileG2SOp]: + """ + Select the appropriate TMA copy atom for Bbased on the number of SMs and the multicast flag. + + :param cluster_shape_mnk: The shape of the cluster + :type cluster_shape_mnk: cute.Shape + :param atom_thr_id: The thread ID of the atom + :type atom_thr_id: cute.Layout + + :return: The appropriate TMA copy atom kind + :rtype: cpasync.CopyBulkTensorTileG2SMulticastOp or cpasync.CopyBulkTensorTileG2SOp + + :raise ValueError: If the atom_sm_cnt is invalid + :raise ValueError: If the cluster shape is not divisible by the atom SM count + """ + atom_sm_cnt = cute.size(atom_thr_id, loc=loc, ip=ip) + mcast = not (cute.size(cluster_shape_mnk, mode=[0], loc=loc, ip=ip) == atom_sm_cnt) + cluster_size = cute.size(cluster_shape_mnk, loc=loc, ip=ip) + + if not isinstance(cluster_size, int) or not isinstance(atom_sm_cnt, int): + raise ValueError( + f"Dynamic cluster shape or atom SM count is not supported: {cluster_shape_mnk} and {atom_thr_id}" + ) + + if cute.size(cluster_shape_mnk, mode=[0], loc=loc, ip=ip) % atom_sm_cnt != 0: + raise ValueError( + f"Cluster shape not divisible by MMA size: {cluster_shape_mnk} and {atom_thr_id}" + ) + + if atom_sm_cnt == 2 and mcast: + return CopyBulkTensorTileG2SMulticastOp(CtaGroup.TWO) + elif atom_sm_cnt == 2 and not mcast: + return CopyBulkTensorTileG2SOp(CtaGroup.TWO) + elif atom_sm_cnt == 1 and mcast: + return CopyBulkTensorTileG2SMulticastOp(CtaGroup.ONE) + elif atom_sm_cnt == 1 and not mcast: + return CopyBulkTensorTileG2SOp(CtaGroup.ONE) + + raise ValueError( + f"Unsupported Configuration for SM100 TMA: {cluster_shape_mnk} and {atom_thr_id}" + ) + + +@dsl_user_op +def cluster_shape_to_tma_atom_SFB( + cluster_shape_mnk: cute.Shape, atom_thr_id: cute.Layout, *, loc=None, ip=None +) -> Union[CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileG2SOp]: + """ + Select the appropriate TMA copy atom for SFB based on the number of SMs and the multicast flag. + + :param cluster_shape_mnk: The shape of the cluster + :type cluster_shape_mnk: cute.Shape + :param atom_thr_id: The thread ID of the atom + :type atom_thr_id: cute.Layout + + :return: The appropriate TMA copy atom kind + :rtype: cpasync.CopyBulkTensorTileG2SMulticastOp or cpasync.CopyBulkTensorTileG2SOp + + :raise ValueError: If the atom_sm_cnt is invalid + :raise ValueError: If the cluster shape is not divisible by the atom SM count + """ + atom_sm_cnt = cute.size(atom_thr_id, loc=loc, ip=ip) + mcast = not (cute.size(cluster_shape_mnk, mode=[0], loc=loc, ip=ip) == 1) + cluster_size = cute.size(cluster_shape_mnk, loc=loc, ip=ip) + + if not isinstance(cluster_size, int) or not isinstance(atom_sm_cnt, int): + raise ValueError( + f"Dynamic cluster shape or atom SM count is not supported: {cluster_shape_mnk} and {atom_thr_id}" + ) + + if cute.size(cluster_shape_mnk, mode=[0], loc=loc, ip=ip) % atom_sm_cnt != 0: + raise ValueError( + f"Cluster shape not divisible by MMA size: {cluster_shape_mnk} and {atom_thr_id}" + ) + + if atom_sm_cnt == 2: + return CopyBulkTensorTileG2SMulticastOp(CtaGroup.TWO) + elif atom_sm_cnt == 1 and mcast: + return CopyBulkTensorTileG2SMulticastOp(CtaGroup.ONE) + elif atom_sm_cnt == 1 and not mcast: + return CopyBulkTensorTileG2SOp(CtaGroup.ONE) + + raise ValueError( + f"Unsupported Configuration for SM100 TMA: {cluster_shape_mnk} and {atom_thr_id}" + ) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/blockscaled_layout.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/blockscaled_layout.py new file mode 100644 index 0000000000000000000000000000000000000000..fa1e2eb70e38236d73f435e001fdc160d301c47c --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/blockscaled_layout.py @@ -0,0 +1,287 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from dataclasses import dataclass, field +from typing import Union + +from cutlass.cutlass_dsl import dsl_user_op + +import cutlass.cute as cute +from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode + +import cutlass._mlir.dialects.cute as _cute_ir +import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir + + +@dataclass(frozen=True) +class BlockScaledBasicChunk: + """ + The basic scale factor atom layout decided by tcgen05 BlockScaled MMA Ops. + + This class represents the fixed layout pattern for scale factors used in + tcgen05 BlockScaled MMA Ops. The layout is determined by the + instruction specification and cannot be modified. + See `PTX documentation `. + """ + + sf_vec_size: int + major_mode: OperandMajorMode = OperandMajorMode.K + _layout: cute.Layout = field(init=False, repr=False) + + def __post_init__(self) -> None: + if self.major_mode == OperandMajorMode.K: + # K-major layout: (AtomMN, AtomK) + atom_shape = ((32, 4), (self.sf_vec_size, 4)) + atom_stride = ((16, 4), (0, 1)) + else: + # MN-major layout: (AtomK, AtomMN) + atom_shape = ((self.sf_vec_size, 4), (32, 4)) + atom_stride = ((0, 1), (16, 4)) + + object.__setattr__( + self, "_layout", cute.make_layout(atom_shape, stride=atom_stride) + ) + + @property + def layout(self) -> cute.Layout: + """ + Get the layout for this block scaled chunk. + + :return: The layout representing the scale factor atom + :rtype: cute.Layout + """ + return self._layout + + +@dsl_user_op +def tile_atom_to_shape_SF( + Shape: cute.Shape, + sf_vec_size: int, + *, + loc=None, + ip=None, +) -> cute.Layout: + """ + A helper function to get dynamic SFA/SFB layout by filling dynamic A/B shape to the scale factor atom layout. + + :param Shape: The shape of the A/B tensor + :param sf_vec_size: Scale factor vector size + + :return: The layout of the SFA/SFB tensor + :rtype: cute.Layout + """ + # ((Atom_MN, Rest_MN),(Atom_K, Rest_K),RestL) + sf_layout = cute.tile_to_shape( + BlockScaledBasicChunk(sf_vec_size).layout, Shape, (2, 1, 3) + ) + return sf_layout + + +@dsl_user_op +def make_smem_layout_sfa( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: cute.Tile, + sf_vec_size: int, + num_stages: int, + *, + loc=None, + ip=None, +) -> cute.Layout: + """ + Make smem layout for SFA based on: + 1. BlockScaledBasicChunk + 2. MMA tiler shape + 3. Scale factor vector size + 4. Number of stages + + :param tiled_mma: The tiled MMA + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The mma tiler shape + :type mma_tiler_mnk: cute.Tile + :param sf_vec_size: The scale factor vector size + :type sf_vec_size: int + :param num_stages: The number of stages + :type num_stages: int + + :return: Smem layout for SFA + :rtype: cute.Layout + """ + # (CTA_Tile_Shape_M, MMA_Tile_Shape_K) + sfa_tile_shape = ( + mma_tiler_mnk[0] // cute.size(tiled_mma.thr_id.shape), + mma_tiler_mnk[2], + ) + + # ((Atom_M, Rest_M),(Atom_K, Rest_K)) + smem_layout = cute.tile_to_shape( + BlockScaledBasicChunk(sf_vec_size).layout, + sfa_tile_shape, + (2, 1), + ) + + mma_tile_inst_k = 4 + # (CTA_Tile_Shape_M, MMA_Inst_Shape_K) + sfa_tile_shape = cute.shape_div(sfa_tile_shape, (1, mma_tile_inst_k)) + # ((Atom_Inst_M, Atom_Inst_K), MMA_M, MMA_K)) + smem_layout = cute.tiled_divide(smem_layout, sfa_tile_shape) + + atom_m = 128 + tiler_inst = ((atom_m, sf_vec_size),) + # (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K) + smem_layout = cute.logical_divide(smem_layout, tiler_inst) + + # (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K, STAGE) + sfa_smem_layout_staged = cute.append( + smem_layout, + cute.make_layout( + num_stages, stride=cute.cosize(cute.filter_zeros(smem_layout)) + ), + ) + + return sfa_smem_layout_staged + + +@dsl_user_op +def make_smem_layout_sfb( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: cute.Tile, + sf_vec_size: int, + num_stages: int, + *, + loc=None, + ip=None, +) -> cute.Layout: + """ + Make smem layout for SFB based on: + 1. BlockScaledBasicChunk + 2. MMA tiler shape + 3. Scale factor vector size + 4. Number of stages + + :param tiled_mma: The tiled MMA + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The mma tiler shape + :type mma_tiler_mnk: cute.Tile + :param sf_vec_size: The scale factor vector size + :type sf_vec_size: int + :param num_stages: The number of stages + :type num_stages: int + + :return: Smem layout for SFA + :rtype: cute.Layout + """ + # (Round_Up(CTA_Tile_Shape_N, 128), MMA_Tile_Shape_K) + sfb_tile_shape = ( + cute.round_up(mma_tiler_mnk[1], 128), + mma_tiler_mnk[2], + ) + + # ((Atom_N, Rest_N),(Atom_K, Rest_K)) + smem_layout = cute.tile_to_shape( + BlockScaledBasicChunk(sf_vec_size).layout, + sfb_tile_shape, + (2, 1), + ) + + mma_tile_inst_k = 4 + # (CTA_Tile_Shape_N, MMA_Inst_Shape_K) + sfb_tile_shape = cute.shape_div(sfb_tile_shape, (1, mma_tile_inst_k)) + # ((Atom_Inst_N, Atom_Inst_K), MMA_N, MMA_K) + smem_layout = cute.tiled_divide(smem_layout, sfb_tile_shape) + + atom_n = 128 + tiler_inst = ((atom_n, sf_vec_size),) + # (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K) + smem_layout = cute.logical_divide(smem_layout, tiler_inst) + + # (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K, STAGE) + sfb_smem_layout_staged = cute.append( + smem_layout, + cute.make_layout( + num_stages, stride=cute.cosize(cute.filter_zeros(smem_layout)) + ), + ) + + return sfb_smem_layout_staged + + +@dsl_user_op +def make_tmem_layout_sfa( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: cute.Tile, + sf_vec_size: int, + smem_layout: cute.Layout, + *, + loc=None, + ip=None, +) -> cute.Layout: + """Make tmem layout for SFA based on: + 1. SFA smem layout per stage + 2. Cta tile shape m + 3. tiled MMA atom thr size + 4. Scale factor vector size + + :param tiled_mma: The tiled MMA + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The mma tiler shape + :type mma_tiler_mnk: cute.Tile + :param sf_vec_size: The scale factor vector size + :type sf_vec_size: int + :param smem_layout: The smem layout of SFA per stage + :type smem_layout: cute.Layout + + :return: TMEM layout for SFA + :rtype: cute.Layout + """ + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + cta_tile_shape_m = mma_tiler_mnk[0] // atom_thr_size + + sfa_layout_ty = _cute_nvgpu_ir.make_tmem_layout_sfa( + smem_layout, cta_tile_shape_m, atom_thr_size, sf_vec_size + ) + return _cute_ir.static(sfa_layout_ty, loc=loc, ip=ip) + + +@dsl_user_op +def make_tmem_layout_sfb( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: cute.Tile, + sf_vec_size: int, + smem_layout: cute.Layout, + *, + loc=None, + ip=None, +) -> cute.Layout: + """Make tmem layout for SFB based on: + 1. SFB smem layout per stage + 2. Cta tile shape m + 3. tiled MMA atom thr size + 4. Scale factor vector size + + :param tiled_mma: The tiled MMA + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The mma tiler shape + :type mma_tiler_mnk: cute.Tile + :param sf_vec_size: The scale factor vector size + :type sf_vec_size: int + :param smem_layout: The smem layout of SFB per stage + :type smem_layout: cute.Layout + + :return: TMEM layout for SFB + :rtype: cute.Layout + """ + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + cta_tile_shape_m = mma_tiler_mnk[0] // atom_thr_size + + sfb_layout_ty = _cute_nvgpu_ir.make_tmem_layout_sfb( + smem_layout, cta_tile_shape_m, atom_thr_size, sf_vec_size + ) + return _cute_ir.static(sfb_layout_ty, loc=loc, ip=ip) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/distributed_helpers.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/distributed_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..5853c56c84f6fc02e911537147fa03b6b4566117 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/distributed_helpers.py @@ -0,0 +1,179 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from functools import partial +from typing import Tuple + +import cutlass.cute as cute +from cutlass.cutlass_dsl import T, dsl_user_op, while_generate + +from cutlass._mlir import ir +from cutlass._mlir.dialects import arith, llvm, nvvm, scf +from cutlass._mlir.dialects.nvvm import ( + MemOrderKind, + MemScopeKind, + AtomicOpKind, +) +from cutlass.cute.typing import Pointer, Int32, Boolean + + +@dsl_user_op +def atomicAdd(dst_ptr: Pointer, val: Int32, loc=None, ip=None) -> Int32: + return nvvm.atomicrmw( + T.i32(), + AtomicOpKind.ADD, + dst_ptr.llvm_ptr, + val.ir_value(loc=loc, ip=ip), + mem_order=MemOrderKind.RELAXED, + syncscope=MemScopeKind.SYS, + loc=loc, + ip=ip, + ) + + +@cute.jit +def ld_bypass(input_tensor: cute.Tensor): + fragment = cute.make_fragment(input_tensor.layout, input_tensor.element_type) + copy_atom_load = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + input_tensor.element_type, + memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE, + memory_scope=cute.nvgpu.common.MemoryScope.SYS, + ) + cute.copy(copy_atom_load, input_tensor, fragment) + vals = fragment.load() + return vals + +@cute.jit +def spin_lock_wait(lock_ptr: Pointer, expect_count: Int32, mem_order : str = "relaxed", mem_scope : str = "gpu", loc=None, ip=None) -> None: + """ + wait on a spin lock until the expected count is reached. + """ + res = 0 + while res != expect_count: + res = nvvm.atomicrmw( + T.i32(), + AtomicOpKind.CAS, + lock_ptr.llvm_ptr, + Int32(0).ir_value(loc=loc, ip=ip), + b=Int32(expect_count).ir_value(loc=loc, ip=ip), + mem_order=MemOrderKind.ACQUIRE if mem_order == "acquire" else MemOrderKind.RELAXED, + syncscope=MemScopeKind.GPU if mem_scope == "gpu" else MemScopeKind.SYS + ) + + +@dsl_user_op +def multimem_red_add_sys_release(mc_ptr: Pointer, loc=None, ip=None) -> None: + """ + add 1 to the multimem address + """ + llvm.inline_asm( + None, + [mc_ptr.toint().ir_value()], + "multimem.red.release.sys.global.add.u32 [$0], 1;", + "l", + has_side_effects=True, + asm_dialect=0, + loc=loc, + ip=ip, + ) + +@dsl_user_op +def multimem_red_add_gpu_relaxed(mc_ptr: Pointer, loc=None, ip=None) -> None: + """ + add 1 to the multimem address + """ + llvm.inline_asm( + None, + [mc_ptr.toint().ir_value()], + "multimem.red.relaxed.gpu.global.add.u32 [$0], 1;", + "l", + has_side_effects=True, + asm_dialect=0, + loc=loc, + ip=ip, + ) + + +def spin_lock_multimem_arrive(lock_ptr: Pointer, loc=None, ip=None) -> None: + """ + arrive a spin lock when the lock_ptr is a multimem address. + """ + multimem_red_add_gpu_relaxed(lock_ptr, loc=loc, ip=ip) + + +def sm_wise_inter_gpu_multimem_barrier(barrier : Pointer, barrier_mc : Pointer, num_ranks, loc=None, ip=None) -> None : + """ + barrier for inter-gpu sm-wise + """ + bidx, bidy, bidz = cute.arch.block_idx() + bdimx, bdimy, _ = cute.arch.grid_dim() + pid = bidx + bidy * bdimx + bidz * bdimx * bdimy + multimem_red_add_sys_release(barrier_mc + pid, loc=loc, ip=ip) + cute.arch.fence_proxy(cute.arch.ProxyKind.alias) + spin_lock_wait(barrier + pid, num_ranks, mem_order="acquire", mem_scope="sys", loc=loc, ip=ip) + + +@dsl_user_op +def multimem_ld_reduce_base( + mc_ptr: Pointer, + *, + ptx_string: str = "", + loc=None, + ip=None, +) -> Tuple[Int32, Int32, Int32, Int32]: + # ld reduce 8xf16 elts + mc_ptr_int = mc_ptr.toint(loc=loc, ip=ip).ir_value() + return_struct = llvm.inline_asm( + ir.Type.parse("!llvm.struct<(i32,i32,i32,i32)>"), + [mc_ptr_int], + ptx_string, + "=r,=r,=r,=r,l", + has_side_effects=True, + asm_dialect=0, + loc=loc, + ip=ip, + ) + return_regs = [llvm.extractvalue(T.i32(), return_struct, [i]) for i in range(4)] + return return_regs[0], return_regs[1], return_regs[2], return_regs[3] + + +multimem_ld_reduce_8xf16 = partial(multimem_ld_reduce_base, ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f32.v4.f16x2 {$0, $1, $2, $3}, [$4];") +multimem_ld_reduce_4xf32 = partial(multimem_ld_reduce_base, ptx_string="multimem.ld_reduce.sys.relaxed.global.add.v4.f32 {$0, $1, $2, $3}, [$4];") +multimem_ld_reduce_8xbf16 = partial(multimem_ld_reduce_base, ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f32.v4.bf16x2 {$0, $1, $2, $3}, [$4];") +multimem_ld_reduce_16xe4m3 = partial(multimem_ld_reduce_base, ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f16.v4.e4m3x4 {$0, $1, $2, $3}, [$4];") +multimem_ld_reduce_16xe5m2 = partial(multimem_ld_reduce_base, ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f16.v4.e5m2x4 {$0, $1, $2, $3}, [$4];") + + +@dsl_user_op +def multimem_st_4xb32( + mc_ptr: Pointer, + x: Int32, + y: Int32, + z: Int32, + w: Int32, + *, + loc=None, + ip=None, +) -> None: + # st 4x32 bits of data + mc_ptr_int = mc_ptr.toint(loc=loc, ip=ip).ir_value() + llvm.inline_asm( + T.i32(), + [mc_ptr_int, x, y, z, w], + "multimem.st.sys.relaxed.global.v4.f32 [$1], {$2, $3, $4, $5};", + "=r,l,r,r,r,r", + has_side_effects=True, + asm_dialect=0, + loc=loc, + ip=ip, + ) + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/grouped_gemm_tile_scheduler_helper.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/grouped_gemm_tile_scheduler_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..a51bae62963bd482fd590f824a4bc1c8564ece0e --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/grouped_gemm_tile_scheduler_helper.py @@ -0,0 +1,466 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from typing import List, Tuple + +import cutlass.cute as cute +from cutlass.cutlass_dsl import Int32, extract_mlir_values, new_from_mlir_values +from cutlass._mlir import ir + +from cutlass.utils.static_persistent_tile_scheduler import PersistentTileSchedulerParams + + +class GroupSearchResult: + """ + The result of the group search for grouped gemm. + + :param group_idx: The result group index + :type group_idx: Int32 + :param cta_tile_idx_m: CTA tile index along M dimension after rasterization + :type cta_tile_idx_m: Int32 + :param cta_tile_idx_n: CTA tile index along N dimension after rasterization + :type cta_tile_idx_n: Int32 + :param problem_shape_m: The M dimension of the gemm problem + :type problem_shape_m: Int32 + :param problem_shape_n: The N dimension of the gemm problem + :type problem_shape_n: Int32 + :param problem_shape_k: The K dimension of the gemm problem + :type problem_shape_k: Int32 + :param cta_tile_count_k: Number of tiles along K dimension + :type cta_tile_count_k: Int32 + """ + + def __init__( + self, + group_idx: Int32, + cta_tile_idx_m: Int32, + cta_tile_idx_n: Int32, + problem_shape_m: Int32, + problem_shape_n: Int32, + problem_shape_k: Int32, + cta_tile_count_k: Int32, + ) -> None: + self.group_idx = group_idx + self.cta_tile_idx_m = cta_tile_idx_m + self.cta_tile_idx_n = cta_tile_idx_n + self.problem_shape_m = problem_shape_m + self.problem_shape_n = problem_shape_n + self.problem_shape_k = problem_shape_k + self.cta_tile_count_k = cta_tile_count_k + + def __extract_mlir_values__(self) -> List[ir.Value]: + values = extract_mlir_values(self.group_idx) + values.extend(extract_mlir_values(self.cta_tile_idx_m)) + values.extend(extract_mlir_values(self.cta_tile_idx_n)) + values.extend(extract_mlir_values(self.problem_shape_m)) + values.extend(extract_mlir_values(self.problem_shape_n)) + values.extend(extract_mlir_values(self.problem_shape_k)) + values.extend(extract_mlir_values(self.cta_tile_count_k)) + return values + + def __new_from_mlir_values__(self, values: List[ir.Value]) -> "GroupSearchResult": + assert len(values) == 7 + return GroupSearchResult(*tuple(values)) + + +class GroupedGemmGroupSearchState: + """ + The state of group index search for grouped gemm. + + The state will be initialized once and updated in every round of group index search. + + :param start_group_idx: The group idx to start the search with + :type start_group_idx: Int32 + :param tile_count_prev_group: Number of tiles before the matched group + :type tile_count_prev_group: Int32 + :param tile_count_searched: Number of tiles we have searched. When the matched group is found, + it records the number of tiles including the matched group + :type tile_count_searched: Int32 + """ + + def __init__( + self, + start_group_idx: Int32, + tile_count_prev_group: Int32, + tile_count_searched: Int32, + ) -> None: + self.start_group_idx = start_group_idx + self.tile_count_prev_group = tile_count_prev_group + self.tile_count_searched = tile_count_searched + + def __extract_mlir_values__(self) -> List[ir.Value]: + values = extract_mlir_values(self.start_group_idx) + values.extend(extract_mlir_values(self.tile_count_prev_group)) + values.extend(extract_mlir_values(self.tile_count_searched)) + return values + + def __new_from_mlir_values__( + self, values: List[ir.Value] + ) -> "GroupedGemmGroupSearchState": + start_group_idx = new_from_mlir_values(self.start_group_idx, [values[0]]) + tile_count_prev_group = new_from_mlir_values( + self.tile_count_prev_group, [values[1]] + ) + tile_count_searched = new_from_mlir_values( + self.tile_count_searched, [values[2]] + ) + return GroupedGemmGroupSearchState( + start_group_idx, tile_count_prev_group, tile_count_searched + ) + + +def create_initial_search_state() -> GroupedGemmGroupSearchState: + """ + Create an initial search state for grouped gemm. + + :return: A new search state with initial values + :rtype: GroupedGemmGroupSearchState + """ + return GroupedGemmGroupSearchState( + start_group_idx=Int32(0), + tile_count_prev_group=Int32(0), + tile_count_searched=Int32(0), + ) + + +class GroupedGemmTileSchedulerHelper: + """ + A helper to translate the raw block index (x, y, z) from tile scheduler to real CTA tile index for grouped gemm. + + :param group_count: Number of groups in current grouped gemm problem + :type group_count: int + :param tile_sched_params: Parameter used to create the tile scheduler this helper works with + :type tile_sched_params: PersistentTileSchedulerParams + :param cluster_tile_shape_mnk: The shape of cluster tile as (m, n, k) + :type cluster_tile_shape_mnk: tuple[int, int, int] + :param search_state: The initial search state + :type search_state: GroupedGemmGroupSearchState + """ + + def __init__( + self, + group_count: int, + tile_sched_params: PersistentTileSchedulerParams, + cluster_tile_shape_mnk: tuple[int, int, int], + search_state: GroupedGemmGroupSearchState, + ) -> None: + self.tile_sched_params = tile_sched_params + self.group_count = group_count + self.lane_idx = cute.arch.lane_idx() + self.cluster_tile_shape_mnk = cluster_tile_shape_mnk + self.search_state = search_state + + def __extract_mlir_values__(self) -> List[ir.Value]: + values = extract_mlir_values(self.tile_sched_params) + values.extend(extract_mlir_values(self.search_state)) + return values + + def __new_from_mlir_values__( + self, values: List[ir.Value] + ) -> "GroupedGemmTileSchedulerHelper": + tile_sched_params = new_from_mlir_values(self.tile_sched_params, values) + search_state = new_from_mlir_values(self.search_state, values[1:]) + return GroupedGemmTileSchedulerHelper( + self.group_count, + tile_sched_params, + self.cluster_tile_shape_mnk, + search_state, + ) + + def delinearize_z( + self, + cta_tile_coord: tuple, + problem_shape_mnkl: cute.Tensor, + ) -> GroupSearchResult: + """ + Delinearize the linear z index and return GroupSearchResult. + + This function should be used by warps that need to know the CTA tile index on M and N dimensions. + + :param cta_tile_coord: The raw CTA coordinate from tile scheduler + :type cta_tile_coord: tuple of Int32 + :param problem_shape_mnkl: Tensor containing gemm problem size (M, N, K, L) for each group + :type problem_shape_mnkl: cute.Tensor + :return: The search result containing group index and tile coordinates + :rtype: GroupSearchResult + """ + # delinear the z coord + linear_idx = cta_tile_coord[2] + group_idx, problem_mnkl = self._group_search_and_load_problem_shape( + linear_idx, + problem_shape_mnkl, + self.search_state.start_group_idx, + self.search_state.tile_count_prev_group, + ) + # linear index local to current group + cluster_tile_idx_in_current_group = ( + linear_idx - self.search_state.tile_count_prev_group + ) + cluster_count_m, cluster_count_n, cluster_count_k = cute.ceil_div( + (problem_mnkl[0], problem_mnkl[1], problem_mnkl[2]), + ( + self.cluster_tile_shape_mnk[0], + self.cluster_tile_shape_mnk[1], + self.cluster_tile_shape_mnk[2], + ), + ) + # decompose to get indices on M and N + cta_tile_idx_m, cta_tile_idx_n = self._compute_cta_tile_coord( + cluster_tile_idx_in_current_group, + cta_tile_coord, + cluster_count_m, + cluster_count_n, + ) + return GroupSearchResult( + group_idx, + cta_tile_idx_m, + cta_tile_idx_n, + problem_mnkl[0], + problem_mnkl[1], + problem_mnkl[2], + cluster_count_k, + ) + + def search_cluster_tile_count_k( + self, + cta_tile_coord: tuple, + problem_shape_mnkl: cute.Tensor, + ) -> Tuple[Int32, Int32]: + """ + Search the matched group for given linear index and compute the number of tiles along K dimension for the matched group. + + This function should be used by warps that are only interested in the number of tiles along K dimension. + + :param cta_tile_coord: The raw CTA coordinate from tile scheduler + :type cta_tile_coord: tuple of Int32 + :param problem_shape_mnkl: Tensor containing gemm problem size (M, N, K, L) for all groups + :type problem_shape_mnkl: cute.Tensor + :return: A tuple containing cluster count along K dimension and the group index + :rtype: Tuple[Int32, Int32] + """ + group_idx, problem_mnk = self._group_search_and_load_problem_shape( + cta_tile_coord[2], + problem_shape_mnkl, + self.search_state.start_group_idx, + self.search_state.tile_count_prev_group, + ) + cluster_count_k = ( + problem_mnk[2] + self.cluster_tile_shape_mnk[2] - 1 + ) // self.cluster_tile_shape_mnk[2] + return cluster_count_k, group_idx + + @cute.jit + def _prefix_sum(self, value_per_thread: Int32) -> Int32: + """ + Perform prefix sum within a full warp. + + :param value_per_thread: The value for this thread to contribute to the prefix sum + :type value_per_thread: Int32 + :return: The prefix sum result for this thread + :rtype: Int32 + """ + clamp_value = 0 + idx = 1 + sum_per_thread = value_per_thread + while idx < cute.arch.WARP_SIZE: + value = cute.arch.shuffle_sync_up( + sum_per_thread, idx, mask_and_clamp=clamp_value + ) + if self.lane_idx >= idx: + sum_per_thread += value + idx = idx << 1 + return sum_per_thread + + def _get_problem_for_group( + self, problem_shape_mnkl: cute.Tensor, group_idx: Int32 + ) -> cute.Tensor: + """ + Load gemm problem (m,n,k,l) for the specified group from global memory to register. + + :param problem_shape_mnkl: Tensor in global memory with layout (group_count, 4):(4, 1) + :type problem_shape_mnkl: cute.Tensor + :param group_idx: The index of the group to load + :type group_idx: Int32 + :return: The problem shape tensor for the specified group + :rtype: cute.Tensor + """ + cur_problem_mnkl = cute.make_fragment( + cute.make_layout(4), problem_shape_mnkl.element_type + ) + cute.autovec_copy(problem_shape_mnkl[(group_idx, None)], cur_problem_mnkl) + return cur_problem_mnkl + + def _get_cluster_tile_count_mn(self, problem_shape: cute.Tensor) -> Int32: + """ + Compute total cluster count. + + :param problem_shape: Tensor containing problem shape (m, n, k, l) + :type problem_shape: cute.Tensor + :return: The total cluster tile count for M and N dimensions + :rtype: Int32 + """ + cur_ntile_m = ( + problem_shape[0] + self.cluster_tile_shape_mnk[0] - 1 + ) // self.cluster_tile_shape_mnk[0] + cur_ntile_n = ( + problem_shape[1] + self.cluster_tile_shape_mnk[1] - 1 + ) // self.cluster_tile_shape_mnk[1] + cur_ntile_mn = cur_ntile_m * cur_ntile_n + return cur_ntile_mn + + def _compute_cta_tile_coord( + self, + cluster_tile_idx: Int32, + cta_tile_coord_in_cluster: tuple, + cluster_tile_count_m: Int32, + cluster_tile_count_n: Int32, + ) -> tuple: + """ + Compute CTA tile indices along M and N dimensions based on the linear index within a group. + + It uses the AlongM mode to decompose the linear index onto M and N dimensions. + + :param cluster_tile_idx: The linear index within a group + :type cluster_tile_idx: Int32 + :param cta_tile_coord_in_cluster: CTA indices along M and N dimensions within a cluster + :type cta_tile_coord_in_cluster: tuple of Int32 + :param cluster_tile_count_m: The number of clusters along M dimension of the matched group + :type cluster_tile_count_m: Int32 + :param cluster_tile_count_n: The number of clusters along N dimension of the matched group + :type cluster_tile_count_n: Int32 + :return: A tuple containing CTA tile indices along M and N dimensions + :rtype: tuple of (Int32, Int32) + """ + cluster_layout_mn = cute.make_layout( + (cluster_tile_count_m, cluster_tile_count_n) + ) + (mi, ni) = cluster_layout_mn.get_hier_coord(cluster_tile_idx) + cta_tile_idx_m = ( + mi * self.tile_sched_params.cluster_shape_mn[0] + + cta_tile_coord_in_cluster[0] + ) + cta_tile_idx_n = ( + ni * self.tile_sched_params.cluster_shape_mn[1] + + cta_tile_coord_in_cluster[1] + ) + return (cta_tile_idx_m, cta_tile_idx_n) + + @cute.jit + def _group_search( + self, + linear_idx: Int32, + problem_shape_mnkl: cute.Tensor, + init_group_idx: Int32, + init_tile_count_searched: Int32, + ) -> GroupedGemmGroupSearchState: + """ + Search which group the linear index belongs to. + + :param linear_idx: The linear index to be decomposed + :type linear_idx: Int32 + :param problem_shape_mnkl: Tensor containing gemm problem size (M, N, K, L) for all groups + :type problem_shape_mnkl: cute.Tensor + :param init_group_idx: The group idx to start the search with + :type init_group_idx: Int32 + :param init_tile_count_searched: The number of tiles we have searched + :type init_tile_count_searched: Int32 + :return: The updated search state + :rtype: GroupedGemmGroupSearchState + """ + c_0 = Int32(0).ir_value() + last_lane_idx = cute.arch.WARP_SIZE - 1 + + tile_count_searched = init_tile_count_searched + start_group_idx = init_group_idx + not_found = linear_idx >= tile_count_searched + tile_count_prev_group = self.search_state.tile_count_prev_group + while not_found: + # get group to search for current lane + cur_group_idx = start_group_idx + self.lane_idx + # check if the group to be checked is out of range + inside_group_bound = cur_group_idx < self.group_count + cur_ntile_mn = c_0 + if inside_group_bound: + # get problem size of current group + cur_problem_mnkl = self._get_problem_for_group( + problem_shape_mnkl, cur_group_idx + ) + cur_ntile_mn = self._get_cluster_tile_count_mn(cur_problem_mnkl) + # compute tile count from beginning to current group(included) + total_cluster_tile_count_ps_per_thread = self._prefix_sum(cur_ntile_mn) + cluster_tile_count_end_per_thread = ( + total_cluster_tile_count_ps_per_thread + tile_count_searched + ) + + group_not_in_window = linear_idx >= cluster_tile_count_end_per_thread + hitted_group_idx_in_search_window = cute.arch.popc( + cute.arch.vote_ballot_sync(group_not_in_window) + ) + not_found = hitted_group_idx_in_search_window == cute.arch.WARP_SIZE + start_group_idx = hitted_group_idx_in_search_window + start_group_idx + hit_the_1st_problem_in_search_window = ( + hitted_group_idx_in_search_window == c_0 + ) + tile_count_prev_group = tile_count_searched + if hit_the_1st_problem_in_search_window == False: + tile_count_prev_group = cute.arch.shuffle_sync( + cluster_tile_count_end_per_thread, + hitted_group_idx_in_search_window - 1, + ) + + # If no matched group, then get new_cluster_tile_count_end from last lane + # Otherwise, get new_cluster_tile_count_end from the hitted group + lane_idx_for_cluster_tile_count_end = hitted_group_idx_in_search_window + if not_found: + lane_idx_for_cluster_tile_count_end = last_lane_idx + tile_count_searched = cute.arch.shuffle_sync( + cluster_tile_count_end_per_thread, + lane_idx_for_cluster_tile_count_end, + ) + + return GroupedGemmGroupSearchState( + start_group_idx, + tile_count_prev_group, + tile_count_searched, + ) + + def _group_search_and_load_problem_shape( + self, + linear_idx: Int32, + problem_shape_mnkl: cute.Tensor, + start_group_idx: Int32, + tile_count_searched: Int32, + ) -> Tuple[Int32, cute.Tensor]: + """ + Perform group search and load problem shape for the matched group. + + :param linear_idx: The linear index to be decomposed + :type linear_idx: Int32 + :param problem_shape_mnkl: Tensor containing gemm problem size (M, N, K, L) for all groups + :type problem_shape_mnkl: cute.Tensor + :param start_group_idx: The group idx to start the search with + :type start_group_idx: Int32 + :param tile_count_searched: The number of tiles we have searched + :type tile_count_searched: Int32 + :return: A tuple containing the final group index and the problem shape tensor + :rtype: Tuple[Int32, cute.Tensor] + """ + self.search_state = self._group_search( + linear_idx, + problem_shape_mnkl, + start_group_idx, + tile_count_searched, + ) + # get final group search state + final_group_idx = self.search_state.start_group_idx + # let's revisit if it's better to broadcast problem_shape_mnk in group_search + problem_mnkl = self._get_problem_for_group(problem_shape_mnkl, final_group_idx) + return final_group_idx, problem_mnkl diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/hardware_info.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/hardware_info.py new file mode 100644 index 0000000000000000000000000000000000000000..e86fcbefc86fbc7da333735fa2cebbd3af47f39e --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/hardware_info.py @@ -0,0 +1,174 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from cuda.bindings import driver, nvrtc + +import cutlass.cute as cute + +""" +This class is used to get the hardware info of given GPU device. +It provides methods to get the max active clusters for given cluster size. + +Prerequisite: +- CUDA driver is initialized via `driver.cuInit` or other CUDA APIs. +- CUDA context is created via `driver.cuCtxCreate` or other CUDA APIs. + +""" + + +class HardwareInfo: + """ + device_id: CUDA device ID to get the hardware info. + """ + + def __init__(self, device_id: int = 0): + count = self._checkCudaErrors(driver.cuDeviceGetCount()) + if device_id >= count: + raise ValueError( + f"Device ID {device_id} is out of range for device count {count}" + ) + self.device_id = device_id + self.device = self._checkCudaErrors(driver.cuDeviceGet(device_id)) + self.context = self._checkCudaErrors(driver.cuCtxGetCurrent()) + self.driver_version = self._checkCudaErrors(driver.cuDriverGetVersion()) + + # Getting the max active clusters for a given cluster size + def get_max_active_clusters(self, cluster_size: int) -> int: + self._get_device_function() + if self._cuda_driver_version_lt(11, 8): + raise RuntimeError( + "CUDA Driver version < 11.8, cannot get _max_active_clusters" + ) + if cluster_size <= 0 or cluster_size > 32: + raise ValueError( + f"Cluster size must be between 1 and 32, {cluster_size} is not supported" + ) + + max_shared_memory_per_block = self._checkCudaErrors( + driver.cuDeviceGetAttribute( + driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, + self.device, + ) + ) + self._checkCudaErrors( + driver.cuFuncSetAttribute( + self.kernel, + driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + max_shared_memory_per_block, + ) + ) + max_dynamic_shared_memory = self._checkCudaErrors( + driver.cuOccupancyAvailableDynamicSMemPerBlock( + self.kernel, 1, 1 # numBlocks # blockSize + ) + ) + max_active_blocks = self._checkCudaErrors( + driver.cuOccupancyMaxActiveBlocksPerMultiprocessor( + self.kernel, 1, max_dynamic_shared_memory # blockSize, + ) + ) + # allow non-portable cluster size to support detection of non-portable cluster size + self._checkCudaErrors( + driver.cuFuncSetAttribute( + self.kernel, + driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, + 1, + ) + ) + # prepare launch configuration + launch_config = driver.CUlaunchConfig() + launch_config.blockDimX = 128 + launch_config.blockDimY = 1 + launch_config.blockDimZ = 1 + launch_config.sharedMemBytes = max_dynamic_shared_memory + launch_config.numAttrs = 1 + # max possible cluster size is 32 + cluster_dims_attr = driver.CUlaunchAttribute() + cluster_dims_attr.id = ( + driver.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION + ) + value = driver.CUlaunchAttributeValue() + value.clusterDim.x = cluster_size + value.clusterDim.y = 1 + value.clusterDim.z = 1 + cluster_dims_attr.value = value + launch_config.attrs = [cluster_dims_attr] + launch_config.gridDimX = cluster_size + launch_config.gridDimY = max_active_blocks + launch_config.gridDimZ = 1 + + num_clusters = self._checkCudaErrors( + driver.cuOccupancyMaxActiveClusters(self.kernel, launch_config) + ) + return num_clusters + + def get_l2_cache_size_in_bytes(self) -> int: + return self._checkCudaErrors( + driver.cuDeviceGetAttribute( + driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE, + self.device, + ) + ) + + def get_device_multiprocessor_count(self) -> int: + return self._checkCudaErrors( + driver.cuDeviceGetAttribute( + driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, + self.device, + ) + ) + + def _checkCudaErrors(self, result) -> None: + if result[0].value: + raise RuntimeError( + "CUDA error code={}({})".format( + result[0].value, self._cudaGetErrorEnum(result[0]) + ) + ) + # CUDA APIs always return the status as the first element of the result tuple + if len(result) == 1: + return None + elif len(result) == 2: + return result[1] + else: + return result[1:] + + def _cudaGetErrorEnum(self, error) -> str: + if isinstance(error, driver.CUresult): + err, name = driver.cuGetErrorName(error) + return name if err == driver.CUresult.CUDA_SUCCESS else "" + elif isinstance(error, nvrtc.nvrtcResult): + return nvrtc.nvrtcGetErrorString(error)[1] + else: + raise RuntimeError("Unknown error type: {}".format(error)) + + def _cuda_driver_version_ge(self, major: int, minor: int) -> bool: + return self.driver_version >= (major * 1000 + 10 * minor) + + def _cuda_driver_version_lt(self, major: int, minor: int) -> bool: + return not self._cuda_driver_version_ge(major, minor) + + @cute.kernel + def _empty_kernel(self): + return + + @cute.jit + def _host_function(self): + self._empty_kernel().launch( + grid=[1, 1, 1], + block=[1, 1, 1], + ) + + # get a empty kernel to compute occupancy + def _get_device_function(self) -> None: + self.compiled_kernel = cute.compile(self._host_function) + self.module = next(iter(self.compiled_kernel.cuda_modules.modules)).cuda_module + self.kernel = next(iter(self.compiled_kernel.cuda_modules.modules)).kernel_ptr diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/hopper_helpers.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/hopper_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..4cd2bae3de66983dc5bf7883305f6a926b3c0d72 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/hopper_helpers.py @@ -0,0 +1,209 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from typing import Type, Tuple +from enum import Enum +from typing_extensions import deprecated +import warnings + +from cutlass.utils.layout import LayoutEnum +from cutlass.cutlass_dsl import ( + Float16, + BFloat16, + Float8E5M2, + Float8E4M3FN, + Numeric, + NumericMeta, + dsl_user_op, +) + +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu.common import CopyUniversalOp +from cutlass.cute.nvgpu.warp import StMatrix8x8x16bOp +from cutlass.cute.nvgpu.warpgroup import ( + MmaF16BF16Op, + MmaF8Op, + OperandMajorMode, + OperandSource, +) + + +@deprecated("Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead") +class SmemCapacity(Enum): + SM90_SMEM_CAPACITY_BYTES = (228 - 1) * 1024 + + +warnings.warn( + "SMEM_CAPACITY is deprecated: Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead", + DeprecationWarning, + stacklevel=2, +) +# Dictionary to map compute capability to SMEM capacity +SMEM_CAPACITY = { + "sm90": SmemCapacity.SM90_SMEM_CAPACITY_BYTES.value, +} + + +@dsl_user_op +def sm90_get_smem_store_op( + layout_d: LayoutEnum, + elem_ty_d: Type[Numeric], + elem_ty_acc: Type[Numeric], + *, + loc=None, + ip=None, +) -> cute.CopyAtom: + """ + Selects the largest vectorized smem store atom available subject to constraint of gmem layout. + + Parameters: + ----------- + layout_d : LayoutEnum + The layout enum of the output tensor D. + + elem_ty_d : Type[Numeric] + The element type for output tensor D. + + elem_ty_acc : Type[Numeric] + The element type for accumulator. + + Returns: + -------- + Either SmemStoreMatrix or SimtSyncCopy, based on the input parameters. + """ + + def validate_type(ty, ty_name): + if not isinstance(ty, NumericMeta): + raise TypeError(f"{ty_name} must be a Numeric, but got {ty}") + + validate_type(elem_ty_d, "elem_ty_d") + validate_type(elem_ty_acc, "elem_ty_acc") + + is_m_major = layout_d.is_m_major_c() + + if elem_ty_d.width == 16: + return cute.make_copy_atom( + StMatrix8x8x16bOp(is_m_major, 4), elem_ty_d, loc=loc, ip=ip + ) + else: + return cute.make_copy_atom(CopyUniversalOp(), elem_ty_d, loc=loc, ip=ip) + + +def make_trivial_tiled_mma( + a_dtype: Type[Numeric], + b_dtype: Type[Numeric], + a_leading_mode: OperandMajorMode, + b_leading_mode: OperandMajorMode, + acc_dtype: Type[Numeric], + atom_layout_mnk: Tuple[int, int, int], + tiler_mn: Tuple[int, int], + a_source: OperandSource = OperandSource.SMEM, + *, + loc=None, + ip=None, +) -> cute.TiledMma: + """Make a tiled MMA atom with given data type, leading dimension, cta group and mma tile shape. + By default, the MMA atom is created with SMEM operand source for A. + + :param a_dtype: Data type of operand A. + :type a_dtype: type[Numeric] + :param b_dtype: Data type of operand B. + :type b_dtype: type[Numeric] + :param a_leading_mode: Leading dimension of operand A (1 for K, 0 for M/N). + :type a_leading_mode: warpgroup.OperandMajorMode + :param b_leading_mode: Leading dimension of operand B (1 for K, 0 for M/N). + :type b_leading_mode: warpgroup.OperandMajorMode + :param acc_dtype: Data type of the accumulator. + :type acc_dtype: type[Numeric] + :param atom_layout_mnk: A integer tuple describing the tiling of Atom across threads. + :type atom_layout_mnk: Tuple[int, int, int] + :param tiler_mn: The shape (M, N) of the cta tiler. + :type tiler_mn: Tuple[int, int] + + :return: A tiled MMA atom. + :rtype: cute.TiledMma + + :raises TypeError: If the data type is not supported. + """ + + if a_dtype in {Float16, BFloat16}: + if cutlass.const_expr(a_dtype != b_dtype): + raise TypeError(f"Type mismatch: {a_dtype} != {b_dtype}") + if cutlass.const_expr(a_dtype.width != b_dtype.width): + raise TypeError(f"Type width mismatch: {a_dtype.width} != {b_dtype.width}") + + mma_op = MmaF16BF16Op( + a_dtype, + acc_dtype, + (*tiler_mn, 16), + a_source, + a_leading_mode, + b_leading_mode, + ) + elif a_dtype in {Float8E4M3FN, Float8E5M2} and b_dtype in { + Float8E4M3FN, + Float8E5M2, + }: + mma_op = MmaF8Op( + a_dtype, + b_dtype, + acc_dtype, + (*tiler_mn, 32), + a_source, + a_leading_mode, + b_leading_mode, + ) + else: + raise TypeError(f"unsupported a_dtype and b_dtype, got {a_dtype} and {b_dtype}") + + return cute.make_tiled_mma(cute.make_mma_atom(mma_op), atom_layout_mnk) + +def get_smem_layout_atom( + layout: LayoutEnum, + element_type: Type[Numeric], + major_mode_size: int, + *, + loc=None, + ip=None, +): + """Select the optimal shared memory layout atom based on parameters. + + :param layout: Layout enum of the tensor + :type layout: LayoutEnum + :param element_type: Data type of the elements + :type element_type: type[cutlass.Numeric] + :param major_mode_size: Size of the major mode dimension + :type major_mode_size: int + + :return: Selected shared memory layout atom kind + :rtype: cute.nvgpu.warpgroup.SmemLayoutAtomKind + """ + assert major_mode_size % 8 == 0 + sw128_num_contiguous_bits = 1024 + sw64_num_contiguous_bits = 512 + sw32_num_contiguous_bits = 256 + major_mode_size_bits = major_mode_size * element_type.width + if layout.sm90_mma_major_mode() == OperandMajorMode.MN: + if major_mode_size_bits % sw128_num_contiguous_bits == 0: + return cute.nvgpu.warpgroup.SmemLayoutAtomKind.MN_SW128 + if major_mode_size_bits % sw64_num_contiguous_bits == 0: + return cute.nvgpu.warpgroup.SmemLayoutAtomKind.MN_SW64 + if major_mode_size_bits % sw32_num_contiguous_bits == 0: + return cute.nvgpu.warpgroup.SmemLayoutAtomKind.MN_SW32 + return cute.nvgpu.warpgroup.SmemLayoutAtomKind.MN_INTER + if major_mode_size_bits % sw128_num_contiguous_bits == 0: + return cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_SW128 + if major_mode_size_bits % sw64_num_contiguous_bits == 0: + return cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_SW64 + if major_mode_size_bits % sw32_num_contiguous_bits == 0: + return cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_SW32 + return cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_INTER diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/layout.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/layout.py new file mode 100644 index 0000000000000000000000000000000000000000..4560c266cf9930ac024adeaa94859d06ecf3650a --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/layout.py @@ -0,0 +1,56 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from enum import Enum + +import cutlass.cute as cute +from cutlass.cute.nvgpu import warpgroup +from cutlass.cute.nvgpu import tcgen05 + + +class LayoutEnum(Enum): + ROW_MAJOR = "row_major" + COL_MAJOR = "col_major" + + def mma_major_mode(self): + return ( + tcgen05.OperandMajorMode.K + if self == LayoutEnum.ROW_MAJOR + else tcgen05.OperandMajorMode.MN + ) + + def sm90_mma_major_mode(self): + return ( + warpgroup.OperandMajorMode.K + if self == LayoutEnum.ROW_MAJOR + else warpgroup.OperandMajorMode.MN + ) + + def is_n_major_c(self): + return self == LayoutEnum.ROW_MAJOR + + def is_m_major_c(self): + return self == LayoutEnum.COL_MAJOR + + @staticmethod + def from_tensor(tensor: cute.Tensor) -> "LayoutEnum": + ret = None + if tensor.leading_dim == 1: + ret = LayoutEnum.ROW_MAJOR + elif tensor.leading_dim == 0: + ret = LayoutEnum.COL_MAJOR + else: + raise ValueError(f"Invalid leading dimension: {tensor.leading_dim}") + + return ret + + +__all__ = ["LayoutEnum"] diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/smem_allocator.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/smem_allocator.py new file mode 100644 index 0000000000000000000000000000000000000000..2500c06e1808bc06db5decce88e8ebf7837f17d0 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/smem_allocator.py @@ -0,0 +1,184 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from typing import Type, Union, overload + +from cutlass.cutlass_dsl import Int8, Numeric, NumericMeta, CutlassBaseDSL + +import cutlass.cute as cute +from cutlass.cute.arch import get_dyn_smem, get_dyn_smem_size + + +class SmemAllocator: + """A class for managing shared memory allocation on GPU. + + This class manages a chunk of shared memory and provides APIs for sub-allocation + inside the chunk. + + :ivar _base: The current base address of the shared memory as an i8 typed dynamic value. + :type _base: cute.Pointer + :ivar _allocated_bytes: The total number of bytes allocated in shared memory. + :type _allocated_bytes: int + + .. note:: + This class is responsible for managing the allocation of tensors in shared memory. + The base pointer is aligned to 1024 bytes upon initialization. + """ + + def __init__(self): + """Initialize the SmemAllocator instance. + + Creates a dynamic shared memory base pointer of type i8, aligned to 1024 bytes. + """ + self._base = get_dyn_smem(Int8, alignment=1024) + self._allocated_bytes = 0 + CutlassBaseDSL.track_smem_allocator(self, lambda cls: cls._allocated_bytes) + + @overload + def allocate(self, size_or_type: int, byte_alignment: int) -> cute.Pointer: ... + + @overload + def allocate( + self, size_or_type: cute.struct, byte_alignment: int + ) -> cute.Pointer: ... + + def allocate(self, size_or_type, byte_alignment: int = 1) -> cute.Pointer: + """Allocate a block of memory with specified size and alignment. + + This method adjusts the base pointer to ensure proper alignment and updates + the internal state to track allocated memory. + + :param size_or_type: The number of bytes to allocate or a struct class + :type size_or_type: Union[int, cute.struct] + :param byte_alignment: The byte alignment requirement, defaults to 1 (no alignment) + :type byte_alignment: int, optional + :return: Pointer to the start of the allocated memory block or struct instance + :rtype: cute.Pointer + :raises ValueError: If size is negative or alignment is less than 1 + :raises RuntimeError: If allocation would exceed available shared memory + """ + if isinstance(size_or_type, cute.struct): + alignment = max(byte_alignment, size_or_type.__alignof__()) + base_ptr = self.allocate(size_or_type.__sizeof__(), alignment) + return size_or_type(base_ptr) + + num_bytes = size_or_type + if num_bytes < 0: + raise ValueError("num_bytes must be non-negative") + if byte_alignment < 1: + raise ValueError("byte_alignment must be at least 1") + + self._base = self._base.align(byte_alignment) + ptr = self._base + self._base += num_bytes + if self._allocated_bytes % byte_alignment != 0: + self._allocated_bytes += ( + byte_alignment - self._allocated_bytes % byte_alignment + ) + self._allocated_bytes += num_bytes + + # Check bounds against available dynamic shared memory + cute.testing.assert_( + self._allocated_bytes <= get_dyn_smem_size(), + f"Allocation failed: shared memory allocation exceeds available memory set in kernel launch. " + f"Allocated bytes: {self._allocated_bytes} bytes. " + f"Please reduce the allocation or set a larger smem size in kernel launch.", + ) + return ptr + + def allocate_array(self, element_type: Type[Numeric], num_elems: int = 1): + """Allocate an array of elements in shared memory. + + :param element_type: The type of elements to allocate + :type element_type: Type[Numeric] + :param num_elems: Number of elements to allocate, defaults to 1 + :type num_elems: int, optional + :return: Pointer to the start of the allocated array + :rtype: cute.Pointer + :raises ValueError: If num_elems is less than 1 + :raises TypeError: If element_type is not a Numeric type + """ + if num_elems < 1: + raise ValueError("num_elems must be at least 1") + if not isinstance(element_type, NumericMeta): + raise TypeError( + f"value_ty must be a type of Numeric, but got {element_type}" + ) + + ptr = self.allocate( + element_type.width // 8 * num_elems, element_type.width // 8 + ) + + return cute.recast_ptr(ptr, dtype=element_type) + + def allocate_tensor( + self, + element_type: Type[Numeric], + layout: Union[int, cute.Layout, cute.ComposedLayout], + byte_alignment: int = 1, + swizzle: cute.Swizzle = None, + ): + """Allocate a tensor in shared memory. + + :param element_type: The type of elements in the tensor + :type element_type: Type[Numeric] + :param layout: The layout specification for the tensor + :type layout: Union[int, cute.Layout, cute.ComposedLayout] + :param byte_alignment: The byte alignment requirement, defaults to 1 + :type byte_alignment: int, optional + :param swizzle: Swizzle for position-dependent swizzling, defaults to None + :type swizzle: cute.Swizzle, optional + :return: The allocated tensor with specified properties + :rtype: cute.Tensor + :raises TypeError: If element_type is not a Numeric type or if swizzle conflicts with layout + :raises ValueError: If allocation is not byte-aligned + :raises NotImplementedError: If dynamic layout is specified + """ + if not isinstance(element_type, NumericMeta): + raise TypeError( + f"value_ty must be a type of Numeric, but got {element_type}" + ) + + if ( + isinstance(layout, cute.ComposedLayout) + and isinstance(layout.inner, cute.Swizzle) + ) and (swizzle is not None): + raise TypeError( + f"Invalid tensor type: cannot be both iterator swizzle (PDSL) and swizzle layout(PISL) at the same time." + ) + + if isinstance(layout, int): + layout = cute.make_layout(layout) + + profile = layout(0) + if isinstance(profile, tuple): + raise TypeError( + f"cannot allocate a shared memory tensor with a non-integer iterator" + ) + + if not cute.is_static(layout.type): + raise NotImplementedError(f"dynamic layout is not supported: {layout.type}") + + # At least align the allocation to the natural alignment given by the element type + if element_type.width // 8 > byte_alignment: + byte_alignment = element_type.width // 8 + + # Relevant only for sub-byte data types: verify that the entire allocation is byte-aligned + cosize_in_bits = cute.cosize(layout) * element_type.width + assert isinstance(cosize_in_bits, int) + if cosize_in_bits % 8 != 0: + raise ValueError("invalid allocation that is not byte-aligned") + + num_bytes = cosize_in_bits // 8 + ptr = self.allocate(num_bytes, byte_alignment) + ptr = cute.recast_ptr(ptr, swizzle, dtype=element_type) + res = cute.make_tensor(ptr, layout) + return res diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/smem_capacity.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/smem_capacity.py new file mode 100644 index 0000000000000000000000000000000000000000..87ddb990436caf8135a849b3a37bf52632eed2fc --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/smem_capacity.py @@ -0,0 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + + +SMEM_CAPACITY_MAP = { + "sm_120": (100 - 1) * 1024, + "sm_100": (228 - 1) * 1024, + "sm_90": (228 - 1) * 1024, + "sm_80": (164 - 1) * 1024, + "sm_86": (100 - 1) * 1024, + "sm_89": (100 - 1) * 1024, +} + + +def get_smem_capacity_in_bytes(compute_capability: str) -> int: + if compute_capability not in SMEM_CAPACITY_MAP: + raise ValueError(f"Unsupported compute capability: {compute_capability}") + return SMEM_CAPACITY_MAP[compute_capability] diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/static_persistent_tile_scheduler.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/static_persistent_tile_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..2873244d7cce9d8072f1fa71bbba1762022631b9 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/static_persistent_tile_scheduler.py @@ -0,0 +1,386 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from typing import Tuple + +from cutlass.cutlass_dsl import ( + Boolean, + Integer, + Int32, + min, + extract_mlir_values, + new_from_mlir_values, + dsl_user_op, +) +from cutlass._mlir import ir +import cutlass.cute as cute + +############################################################################## +# Static persistent tile scheduler +############################################################################## + + +class WorkTileInfo: + """A class to represent information about a work tile. + + :ivar tile_idx: The index of the tile. + :type tile_idx: cute.Coord + :ivar is_valid_tile: Whether the tile is valid. + :type is_valid_tile: Boolean + """ + + def __init__(self, tile_idx: cute.Coord, is_valid_tile: Boolean): + self._tile_idx = tile_idx + self._is_valid_tile = Boolean(is_valid_tile) + + def __extract_mlir_values__(self) -> list[ir.Value]: + values = extract_mlir_values(self.tile_idx) + values.extend(extract_mlir_values(self.is_valid_tile)) + return values + + def __new_from_mlir_values__(self, values: list[ir.Value]) -> "WorkTileInfo": + assert len(values) == 4 + new_tile_idx = new_from_mlir_values(self._tile_idx, values[:-1]) + new_is_valid_tile = new_from_mlir_values(self._is_valid_tile, [values[-1]]) + return WorkTileInfo(new_tile_idx, new_is_valid_tile) + + @property + def is_valid_tile(self) -> Boolean: + """Check latest tile returned by the scheduler is valid or not. Any scheduling + requests after all tasks completed will return an invalid tile. + + :return: The validity of the tile. + :rtype: Boolean + """ + return self._is_valid_tile + + @property + def tile_idx(self) -> cute.Coord: + """ + Get the index of the tile. + + :return: The index of the tile. + :rtype: cute.Coord + """ + return self._tile_idx + + +class PersistentTileSchedulerParams: + """A class to represent parameters for a persistent tile scheduler. + + This class is designed to manage and compute the layout of clusters and tiles + in a batched gemm problem. + + :ivar cluster_shape_mn: Shape of the cluster in (m, n) dimensions (K dimension cta count must be 1). + :type cluster_shape_mn: tuple + :ivar problem_layout_ncluster_mnl: Layout of the problem in terms of + number of clusters in (m, n, l) dimensions. + :type problem_layout_ncluster_mnl: cute.Layout + """ + + def __init__( + self, + problem_shape_ntile_mnl: cute.Shape, + cluster_shape_mnk: cute.Shape, + *, + loc=None, + ip=None, + ): + """ + Initializes the PersistentTileSchedulerParams with the given parameters. + + :param problem_shape_ntile_mnl: The shape of the problem in terms of + number of CTA (Cooperative Thread Array) in (m, n, l) dimensions. + :type problem_shape_ntile_mnl: cute.Shape + :param cluster_shape_mnk: The shape of the cluster in (m, n) dimensions. + :type cluster_shape_mnk: cute.Shape + + :raises ValueError: If cluster_shape_k is not 1. + """ + + if cluster_shape_mnk[2] != 1: + raise ValueError(f"unsupported cluster_shape_k {cluster_shape_mnk[2]}") + + self.problem_shape_ntile_mnl = problem_shape_ntile_mnl + # cluster_shape_mnk is kept for reconstruction + self._cluster_shape_mnk = cluster_shape_mnk + self.cluster_shape_mn = cluster_shape_mnk[:2] + self._loc = loc + + # By default, we follow m major (col-major) raster order, so make a col-major layout + self.problem_layout_ncluster_mnl = cute.make_layout( + cute.ceil_div( + self.problem_shape_ntile_mnl, cluster_shape_mnk[:2], loc=loc, ip=ip + ), + loc=loc, + ip=ip, + ) + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.problem_shape_ntile_mnl, self._cluster_shape_mnk]: + obj_values = extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip( + [self.problem_shape_ntile_mnl, self._cluster_shape_mnk], self._values_pos + ): + obj_list.append(new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return PersistentTileSchedulerParams(*(tuple(obj_list)), loc=self._loc) + + @dsl_user_op + def get_grid_shape( + self, max_active_clusters: Int32, *, loc=None, ip=None + ) -> Tuple[Integer, Integer, Integer]: + """ + Computes the grid shape based on the maximum active clusters allowed. + + :param max_active_clusters: The maximum number of active clusters that + can run in one wave. + :type max_active_clusters: Int32 + + :return: A tuple containing the grid shape in (m, n, persistent_clusters). + - m: self.cluster_shape_m. + - n: self.cluster_shape_n. + - persistent_clusters: Number of persistent clusters that can run. + """ + + # Total ctas in problem size + num_ctas_mnl = tuple( + x * y + for x, y in zip( + self.problem_layout_ncluster_mnl.shape, self.cluster_shape_mn + ) + ) + (self.problem_layout_ncluster_mnl.shape[2],) + + num_ctas_in_problem = cute.size(num_ctas_mnl, loc=loc, ip=ip) + + num_ctas_per_cluster = cute.size(self.cluster_shape_mn, loc=loc, ip=ip) + # Total ctas that can run in one wave + num_ctas_per_wave = max_active_clusters * num_ctas_per_cluster + + num_persistent_ctas = min(num_ctas_in_problem, num_ctas_per_wave) + num_persistent_clusters = num_persistent_ctas // num_ctas_per_cluster + + return (*self.cluster_shape_mn, num_persistent_clusters) + + +class StaticPersistentTileScheduler: + """A scheduler for static persistent tile execution in CUTLASS/CuTe kernels. + + :ivar params: Tile schedule related params, including cluster shape and problem_layout_ncluster_mnl + :type params: PersistentTileSchedulerParams + :ivar num_persistent_clusters: Number of persistent clusters that can be launched + :type num_persistent_clusters: Int32 + :ivar cta_id_in_cluster: ID of the CTA within its cluster + :type cta_id_in_cluster: cute.Coord + :ivar _num_tiles_executed: Counter for executed tiles + :type _num_tiles_executed: Int32 + :ivar _current_work_linear_idx: Current cluster index + :type _current_work_linear_idx: Int32 + """ + + def __init__( + self, + params: PersistentTileSchedulerParams, + num_persistent_clusters: Int32, + current_work_linear_idx: Int32, + cta_id_in_cluster: cute.Coord, + num_tiles_executed: Int32, + ): + """ + Initializes the StaticPersistentTileScheduler with the given parameters. + + :param params: Tile schedule related params, including cluster shape and problem_layout_ncluster_mnl. + :type params: PersistentTileSchedulerParams + :param num_persistent_clusters: Number of persistent clusters that can be launched. + :type num_persistent_clusters: Int32 + :param current_work_linear_idx: Current cluster index. + :type current_work_linear_idx: Int32 + :param cta_id_in_cluster: ID of the CTA within its cluster. + :type cta_id_in_cluster: cute.Coord + :param num_tiles_executed: Counter for executed tiles. + :type num_tiles_executed: Int32 + """ + self.params = params + self.num_persistent_clusters = num_persistent_clusters + self._current_work_linear_idx = current_work_linear_idx + self.cta_id_in_cluster = cta_id_in_cluster + self._num_tiles_executed = num_tiles_executed + + def __extract_mlir_values__(self) -> list[ir.Value]: + values = extract_mlir_values(self.num_persistent_clusters) + values.extend(extract_mlir_values(self._current_work_linear_idx)) + values.extend(extract_mlir_values(self.cta_id_in_cluster)) + values.extend(extract_mlir_values(self._num_tiles_executed)) + return values + + def __new_from_mlir_values__( + self, values: list[ir.Value] + ) -> "StaticPersistentTileScheduler": + assert len(values) == 6 + new_num_persistent_clusters = new_from_mlir_values( + self.num_persistent_clusters, [values[0]] + ) + new_current_work_linear_idx = new_from_mlir_values( + self._current_work_linear_idx, [values[1]] + ) + new_cta_id_in_cluster = new_from_mlir_values( + self.cta_id_in_cluster, values[2:5] + ) + new_num_tiles_executed = new_from_mlir_values( + self._num_tiles_executed, [values[5]] + ) + return StaticPersistentTileScheduler( + self.params, + new_num_persistent_clusters, + new_current_work_linear_idx, + new_cta_id_in_cluster, + new_num_tiles_executed, + ) + + # called by host + @dsl_user_op + @staticmethod + def create( + params: PersistentTileSchedulerParams, + block_idx: Tuple[Integer, Integer, Integer], + grid_dim: Tuple[Integer, Integer, Integer], + *, + loc=None, + ip=None, + ): + """Initialize the static persistent tile scheduler. + + :param params: Parameters for the persistent + tile scheduler. + :type params: PersistentTileSchedulerParams + :param block_idx: The 3d block index in the format (bidx, bidy, bidz). + :type block_idx: Tuple[Integer, Integer, Integer] + :param grid_dim: The 3d grid dimensions for kernel launch. + :type grid_dim: Tuple[Integer, Integer, Integer] + + :return: A StaticPersistentTileScheduler object. + :rtype: StaticPersistentTileScheduler + """ + params = params + + # Calculate the number of persistent clusters by dividing the total grid size + # by the number of CTAs per cluster + num_persistent_clusters = cute.size(grid_dim, loc=loc, ip=ip) // cute.size( + params.cluster_shape_mn, loc=loc, ip=ip + ) + + bidx, bidy, bidz = block_idx + + # Initialize workload index equals to the cluster index in the grid + current_work_linear_idx = Int32(bidz) + + # CTA id in the cluster + cta_id_in_cluster = ( + Int32(bidx % params.cluster_shape_mn[0]), + Int32(bidy % params.cluster_shape_mn[1]), + Int32(0), + ) + # Initialize number of tiles executed to zero + num_tiles_executed = Int32(0) + return StaticPersistentTileScheduler( + params, + num_persistent_clusters, + current_work_linear_idx, + cta_id_in_cluster, + num_tiles_executed, + ) + + # called by host + @staticmethod + def get_grid_shape( + params: PersistentTileSchedulerParams, + max_active_clusters: Int32, + *, + loc=None, + ip=None, + ) -> Tuple[Integer, Integer, Integer]: + """Calculates the grid shape to be launched on GPU using problem shape, + threadblock shape, and active cluster size. + + :param params: Parameters for grid shape calculation. + :type params: PersistentTileSchedulerParams + :param max_active_clusters: Maximum active clusters allowed. + :type max_active_clusters: Int32 + + :return: The calculated 3d grid shape. + :rtype: Tuple[Integer, Integer, Integer] + """ + + return params.get_grid_shape(max_active_clusters, loc=loc, ip=ip) + + # private method + def _get_current_work_for_linear_idx( + self, current_work_linear_idx: Int32, *, loc=None, ip=None + ) -> WorkTileInfo: + """Compute current tile coord given current_work_linear_idx and cta_id_in_cluster. + + :param current_work_linear_idx: The linear index of the current work. + :type current_work_linear_idx: Int32 + + :return: An object containing information about the current tile coordinates + and validity status. + :rtype: WorkTileInfo + """ + + is_valid = current_work_linear_idx < cute.size( + self.params.problem_layout_ncluster_mnl, loc=loc, ip=ip + ) + + cur_cluster_coord = self.params.problem_layout_ncluster_mnl.get_hier_coord( + current_work_linear_idx, loc=loc, ip=ip + ) + + # cur_tile_coord is a tuple of i32 values + cur_tile_coord = tuple( + Int32(x) * Int32(z) + Int32(y) + for x, y, z in zip( + cur_cluster_coord, + self.cta_id_in_cluster, + (*self.params.cluster_shape_mn, Int32(1)), + ) + ) + + return WorkTileInfo(cur_tile_coord, is_valid) + + @dsl_user_op + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + return self._get_current_work_for_linear_idx( + self._current_work_linear_idx, loc=loc, ip=ip + ) + + @dsl_user_op + def initial_work_tile_info(self, *, loc=None, ip=None) -> WorkTileInfo: + return self.get_current_work(loc=loc, ip=ip) + + @dsl_user_op + def advance_to_next_work(self, *, advance_count: int = 1, loc=None, ip=None): + self._current_work_linear_idx += Int32(advance_count) * Int32( + self.num_persistent_clusters + ) + self._num_tiles_executed += Int32(1) + + @property + def num_tiles_executed(self) -> Int32: + return self._num_tiles_executed + + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/tensormap_manager.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/tensormap_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..c6369c200e13ad280dfdecdb5cb4aa7ad081da4c --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/tensormap_manager.py @@ -0,0 +1,140 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from dataclasses import dataclass +from enum import Enum, auto +from typing import Tuple + +from cutlass.cutlass_dsl import const_expr + +import cutlass._mlir.dialects.cute as _cute_ir +import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir + +import cutlass.cute as cute + + +class TensorMapUpdateMode(Enum): + """ + Enum class defining tensor map update modes. + + Modes: + GMEM: Update tensormap in global memory + SMEM: Load tensormap from global memory to shared memory, + update it in shared memory, then store back to global memory + """ + + GMEM = auto() # Update tensormap in global memory + SMEM = auto() # Update tensormap in shared memory + + +@dataclass(frozen=True) +class TensorMapManager: + """ + Manages TensorMap operations including initialization and updates. + Provides utilities to convert tensormap pointer to across different memory spaces. + """ + + tensormap_update_mode: TensorMapUpdateMode + bytes_per_tensormap: int + + # convert given cute.Pointer or cutlass.Int64 to a cute.Pointer to tensormap. + # address_space: the address space of the resulting tensormap pointer. It could be generic or gmem + def get_tensormap_ptr( + self, + ptr: cute.Pointer, + address_space=_cute_ir.AddressSpace.gmem, + ) -> cute.Pointer: + if address_space not in [ + _cute_ir.AddressSpace.gmem, + _cute_ir.AddressSpace.generic, + ]: + raise ValueError(f"Invalid address space: {address_space} for tensormap") + + gmem_ptr_i64 = ptr.toint().ir_value() + gmem_ptr_i64_align_ty = _cute_ir.ConstrainedIntType.get( + self.bytes_per_tensormap, gmem_ptr_i64.type.width + ) + gmem_ptr_i64_align = _cute_ir.assume(gmem_ptr_i64_align_ty, gmem_ptr_i64) + gmem_ptr_ty = _cute_ir.PtrType.get( + _cute_nvgpu_ir.TmaDescriptorTiledType.get(), + address_space, + self.bytes_per_tensormap, + ) + return _cute_ir.inttoptr(gmem_ptr_ty, gmem_ptr_i64_align) + + # init tensormap pointed by dst_ptr with the one inside copy_atom. + # dst_ptr should be pointing to a global memory location or a smem location + # warp_id specifies which warp to perform the initialization + @cute.jit + def init_tensormap_from_atom( + self, copy_atom: cute.CopyAtom, dst_ptr: cute.Pointer, warp_id: int + ) -> None: + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + if warp_idx == warp_id: + with cute.arch.elect_one(): + cute.nvgpu.cpasync.copy_tensormap(copy_atom, dst_ptr) + cute.arch.sync_warp() + return + + # Perform a fence operation to ensure previous `init_tensormap_from_atom` calls have been completed + def fence_tensormap_initialization( + self, + ) -> None: + if self.tensormap_update_mode == TensorMapUpdateMode.GMEM: + cute.arch.fence_acq_rel_cta() + return + + # Perform a fence operation to ensure previous `update_tensormap` calls have been completed + def fence_tensormap_update( + self, + tensormap_ptr: cute.Pointer, + ) -> None: + cute.nvgpu.cpasync.fence_tma_desc_acquire(tensormap_ptr) + return + + @cute.jit + def update_tensormap( + self, + tensor_gmem: Tuple[cute.Tensor, ...], + tma_copy_atom: Tuple[cute.CopyAtom, ...], + tensormap_gmem_ptr: Tuple[cute.Pointer, ...], + warp_id: int, + tensormap_smem_ptr: Tuple[cute.Pointer, ...], + ) -> None: + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + # updates before touching tensormap in global memory + if warp_idx == warp_id: + if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM): + for copy_atom, tensor, smem_ptr in zip( + tma_copy_atom, tensor_gmem, tensormap_smem_ptr + ): + cute.nvgpu.cpasync.update_tma_descriptor( + copy_atom, tensor, smem_ptr + ) + # wait until it's safe to update tensormap in global memory + with cute.arch.elect_one(): + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=True) + cute.arch.sync_warp() + # updates to tensormap in global memory + if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM): + for gmem_ptr, smem_ptr in zip(tensormap_gmem_ptr, tensormap_smem_ptr): + cute.nvgpu.cpasync.cp_fence_tma_desc_release(gmem_ptr, smem_ptr) + else: + for copy_atom, tensor, gmem_ptr in zip( + tma_copy_atom, tensor_gmem, tensormap_gmem_ptr + ): + cute.nvgpu.cpasync.update_tma_descriptor( + copy_atom, tensor, gmem_ptr + ) + cute.arch.sync_warp() + cute.nvgpu.cpasync.fence_tma_desc_release() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass_dsl/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass_dsl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..06ea3f6f5f54b0b4f125c22504b06f41e8bf7697 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass_dsl/__init__.py @@ -0,0 +1,46 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from .cutlass import * + +from ..base_dsl.ast_helpers import ( + loop_selector, + if_selector, + if_executor, + while_selector, + while_executor, + range, + range_constexpr, + range_dynamic, + const_expr, + dynamic_expr, + assert_executor, + bool_cast, + compare_executor, + any_executor, + all_executor, + range_value_check, + range_perf_warning, + cf_symbol_check, + redirect_builtin_function, + copy_members, + get_locals_or_none, +) + +from ..base_dsl import * +from ..base_dsl.dsl import extract_mlir_values, new_from_mlir_values +from ..base_dsl.typing import _binary_op_type_promote +from ..base_dsl._mlir_helpers.gpu import * +from ..base_dsl._mlir_helpers.op import dsl_user_op +from ..base_dsl.runtime import * +from ..base_dsl.runtime import cuda as cuda_helpers +from ..base_dsl.compiler import compile +from ..base_dsl.runtime.jit_arg_adapters import * diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass_dsl/cutlass.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass_dsl/cutlass.py new file mode 100644 index 0000000000000000000000000000000000000000..1630c873c7a1be3e013f966ea153c904f2b776ff --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass_dsl/cutlass.py @@ -0,0 +1,1696 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +""" +This module provides a DSL for Cutlass Dialects. It also includes utils with +regarding to that dialect. +""" + +# Local module imports +from itertools import chain +from types import GenericAlias, SimpleNamespace, UnionType +from typing import Callable, Union, Type, List, Union, Sequence, ForwardRef, Any +import functools +import pkgutil +from dataclasses import is_dataclass, fields +from collections.abc import Sequence +import builtins + +from ..base_dsl import * +from ..base_dsl import compiler +from ..base_dsl.dsl import is_dynamic_expression, extract_mlir_values +from ..base_dsl.typing import * +from ..base_dsl.typing import DynamicExpression, get_mlir_types +from ..base_dsl.runtime.jit_arg_adapters import is_arg_spec_constexpr + +from ..base_dsl.ast_helpers import const_expr + +# MLIR Imports +from cutlass._mlir import ir, execution_engine, passmanager +from cutlass._mlir.dialects import arith, func, gpu, scf, cute, gpu as cutlass_gpu +from cutlass._mlir.dialects._ods_common import ( + get_op_result_or_op_results as _get_op_result_or_op_results, +) +from cutlass._mlir.extras import types as T + +# Helpers +from ..base_dsl._mlir_helpers import arith as cutlass_arith +from ..base_dsl._mlir_helpers import lru_cache_ir + +from ..base_dsl.ast_helpers import ( + loop_selector, + executor, + if_selector, + if_executor, + while_selector, + while_executor, + assert_executor, + const_expr, + dynamic_expr, + bool_cast, + compare_executor, + any_executor, + all_executor, + range_value_check, + range_perf_warning, + cf_symbol_check, +) + +from .cutlass_ast_decorators import ( + _loop_execute_range_dynamic, + _if_execute_dynamic, + _while_execute_dynamic, +) + +from .tree_utils import ( + is_constexpr_field, + tree_flatten, + tree_unflatten, + PyTreeDef, + is_frozen_dataclass, + DSLTreeFlattenError, +) +from ..base_dsl.runtime.jit_arg_adapters import JitArgAdapterRegistry + + +# ============================================================================= +# Cutlass DSL Base Abstract Class +# ============================================================================= + + +# Return a ctype class that represents the in-memory layout expected +# for a CuTe hierarchical tuple type. +def get_sparse_tuple_ctype(dyn): + # When there is a single dynamic value, the sparse CuTe + # representation is a single integer. + if isinstance(dyn, int): + return ctypes.c_int32 + + # For zero or greater than 1 dynamic values, the tuple + # representation will be a struct with a field for each dynamic + # value. The representation is flattened, even for hierarchical CuTe + # profiles (although we are only dealing with depth 1 inputs here). + class TupleDescriptor(ctypes.Structure): + _fields_ = [(f"x{idx}", ctypes.c_int32) for idx in range(len(dyn))] + + def __str__(self): + return f"struct<{str(self._fields_)}>" + + return TupleDescriptor + + +def is_cute_algebra_type(arg_spec): + # Walk through the arg_spec to check if it's a cute algebra type + _cute_algebra_type_aliases = ( + "Shape", + "Stride", + "Coord", + "Tile", + "IntTuple", + ) + + origin = get_origin(arg_spec) + if origin is Union: + for sub_ty in get_args(arg_spec): + sub_origin = get_origin(sub_ty) + if sub_origin is Tuple or ( + type(sub_origin) is type and issubclass(sub_origin, tuple) + ): + tuple_arg0 = get_args(sub_ty)[0] + if isinstance( + tuple_arg0, ForwardRef + ) and tuple_arg0.__forward_arg__ in (_cute_algebra_type_aliases): + return True + return False + + +def _get_c_pointers_cutlass(obj): + """ + This is an extended version of `get_c_pointers` that supports dataclasses, SimpleNamespace, and dict. + """ + if hasattr(obj, "__c_pointers__"): + return obj.__c_pointers__() + elif isinstance(obj, (tuple, list)): + return list(chain.from_iterable(_get_c_pointers_cutlass(x) for x in obj)) + elif isinstance(obj, SimpleNamespace): + return list( + chain.from_iterable( + _get_c_pointers_cutlass(x) for x in obj.__dict__.values() + ) + ) + elif isinstance(obj, dict): + return list( + chain.from_iterable(_get_c_pointers_cutlass(x) for x in obj.values()) + ) + elif is_dataclass(obj): + return list( + chain.from_iterable( + _get_c_pointers_cutlass(getattr(obj, f.name)) + for f in fields(obj) + if not is_constexpr_field(f) + ) + ) + elif isinstance(obj, set): + raise DSLRuntimeError( + "Sets are not supported in get_c_pointers to ensure order preservation", + context="The DSL attempted to generate JIT function argument(s) for an argument of type set but failed.", + suggestion="Consider using a list or tuple instead", + ) + else: + # Try get adapter + adapter = JitArgAdapterRegistry.get_registered_adapter(type(obj)) + if adapter is not None: + return _get_c_pointers_cutlass(adapter(obj)) + return [] + + +class CutlassBaseDSL(BaseDSL): + """This abstract class provides a DSL for Cutlass.""" + + def __init__( + self, + name: str, + compiler_provider: Any, + pass_sm_arch_name: str, + device_compilation_only: bool = False, + preprocess: bool = False, + ): + super().__init__( + name=name, + dsl_package_name=["cutlass"], + compiler_provider=compiler_provider, + pass_sm_arch_name=pass_sm_arch_name, + device_compilation_only=device_compilation_only, + preprocess=preprocess, + ) + self._smem_usage_tracker: tuple = None + + # this method is not useful for cutlass_dsl, so we only provide a dummy implementation. + def _is_tensor_descriptor(self, maybe_tensor_descriptor) -> bool: + return False + + # this method is not useful for cutlass_dsl, so we only provide a dummy implementation. + def _handle_tensor_descriptor( + self, maybe_tensor, arg_name: str, need_gpu_memory: bool + ) -> Any: + return False + + def _build_gpu_module(self, attrs): + self.gpu_module = gpu.GPUModuleOp(ir.StringAttr.get("kernels")) + with ir.InsertionPoint(self.gpu_module.bodyRegion.blocks.append(*[])): + pass + + for attr_name in attrs: + self.gpu_module.attributes[attr_name] = ir.Attribute.parse(attrs[attr_name]) + + def _get_pipeline(self, pipeline): + pipeline = super()._get_pipeline(pipeline) + if pipeline == None: + # cubin format is required to be cubin as we launch cuda module at python level. + return ( + "builtin.module(cute-to-nvvm{cubin-format=bin " + + self.compile_options.to_str() + + "})" + ) + + return pipeline + + def preprocess_pipeline(self, pipeline, arch) -> str: + pipeline = super().preprocess_pipeline(pipeline, arch) + pipeline = pipeline.rstrip(")") + ",external-kernel-for-gpu-launch)" + return pipeline + + def _enter_gpu_module(self): + return ir.InsertionPoint(self.gpu_module.bodyRegion.blocks[0]) + + def _generate_kernel_attrs(self, config: BaseDSL.LaunchConfig) -> dict: + assert isinstance( + config, BaseDSL.LaunchConfig + ), f"Expect LaunchConfig for @kernel, but got {type(config)}" + + ret = {} + # generate launch bound attr from LaunchConfig + max_threads = ", ".join(map(str, config.block)) + ret["nvvm.reqntid"] = ir.Attribute.parse(f"array") + # min_blocks_per_mp is optional for kernel + min_blocks = config.min_blocks_per_mp + if min_blocks > 0: + ret["nvvm.minctasm"] = ir.Attribute.parse(f"{min_blocks} : i32") + return ret + + @lru_cache(maxsize=1) + def get_version(self): + """ + Get the version of cutlass dsl, used for computing the hash key of the cache. + Including source python files and the shared library. + """ + dsl_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + # get the version hash of the cutlass shared library + version_hash = hashlib.sha256() + # update the version hash of the source python files + for lib in pkgutil.walk_packages([dsl_path], prefix="cutlass."): + try: + with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: + version_hash.update(f.read()) + except Exception: + raise DSLRuntimeError( + f"Failed to read module file {lib.name}. The file may not exist or may not be readable." + "Please re-install the package." + ) + try: + # update the version hash of the cutlass shared library + with open( + os.path.join(dsl_path, "_mlir/_mlir_libs/libCutlassIRPythonCAPI.so"), + "rb", + ) as f: + while True: + chunk = f.read(1024**2) + if not chunk: + break + version_hash.update(chunk) + except Exception: + raise DSLRuntimeError( + f"Failed to read the shared library file libCutlassIRPythonCAPI.so." + "The file may not exist or may not be readable." + "Please re-install the package." + ) + + return version_hash + + @staticmethod + def track_smem_allocator(allocator, callback): + """ + Tracks shared memory usage for kernel functions. + Find and set allocator to its parent dsl object. + """ + frame = inspect.currentframe().f_back + while frame: + obj = frame.f_locals.get("self", None) + if obj and isinstance(obj, CutlassBaseDSL): + obj._set_smem_tracking(allocator, callback) + return + frame = frame.f_back + warnings.warn("Cannot find parent dsl for allocator!", UserWarning) + + def _set_smem_tracking(self, allocator, callback): + # Registers an allocator and callback for current dsl + self._smem_usage_tracker = (allocator, callback) + + def _reset_smem_tracking(self): + # Clear an allocator and callback for current dsl + self._smem_usage_tracker = None + + def _get_smem_usage(self) -> int: + # Treat final allocated bytes of allocator as smem usage + if not self._smem_usage_tracker: + return 0 + allocator, callback = self._smem_usage_tracker + return callback(allocator) + + def _kernel_helper(self, funcBody, *args, **kwargs): + class _CutlassIrKernelGenHelper(BaseDSL._KernelGenHelper): + def __init__(self, dsl: CutlassBaseDSL): + super().__init__() + self.dsl = dsl + self.dsl._reset_smem_tracking() + + def generate_func_op(self, arg_types, arg_attrs, kernel_name, loc=None): + super().generate_func_op(arg_types, arg_attrs, kernel_name) + self.func_op = func.FuncOp( + kernel_name, ir.FunctionType.get(arg_types, []), loc=loc + ) + if arg_attrs is not None: + log().debug(arg_attrs) + self.func_op.arg_attrs = arg_attrs + return self.func_op + + def generate_func_ret_op(self): + return func.ReturnOp([]) + + def get_func_body_start(self): + assert self.func_op is not None, "Invalid func_op is not expected!" + return self.func_op.add_entry_block() + + def generate_launch_op(self, *args, **kwargs): + # Extract args and do validation + kernelSym = kwargs.get("kernelSym", None) + kernelOperands = kwargs.get("kernelOperands", None) + requiredArgs = kwargs.get("requiredArgs", None) + assert kernelSym is not None, "kernelSym being None is not expected!" + assert ( + requiredArgs is not None + ), "requiredArgs being None is not expected!" + assert ( + kernelOperands is not None + ), "kernelOperands being None is not expected!" + assert isinstance( + requiredArgs.config, BaseDSL.LaunchConfig + ), f"Expect LaunchConfig for @kernel, but got {type(requiredArgs.config)}" + + cfg = requiredArgs.config + + # Apply to grid, block, and cluster if present + cfg.grid = [to_index(size) for size in cfg.grid] + cfg.block = [to_index(size) for size in cfg.block] + if cfg.has_cluster: + cfg.cluster = [to_index(size) for size in cfg.cluster] + + smem_usage = self.dsl._get_smem_usage() + if any(not isinstance(x, int) for x in [cfg.smem, smem_usage]): + pass # cannot compare dynamic value inside kernel to launch op in py + elif cfg.auto_smem: + cfg.smem = smem_usage + elif smem_usage > cfg.smem: + warnings.warn( + f"Potential error: specified kernel launch smem bytes " + f"({cfg.smem}) is smaller than kernel usage ({smem_usage})!", + UserWarning, + ) + cfg.smem = const(cfg.smem) + + if not isinstance(cfg.async_deps, (list, tuple)): + cfg.async_deps = [cfg.async_deps] + is_async = len(cfg.async_deps) > 0 + token = gpu.launch_func( + gpu.AsyncTokenType.get() if is_async else None, + cfg.async_deps, + kernelSym, + *cfg.grid, + *cfg.block, + kernelOperands, + **dict( + zip( + ("cluster_size_x", "cluster_size_y", "cluster_size_z"), + tuple(cfg.cluster), + ) + ), + dynamic_shared_memory_size=cfg.smem, + ) + return token if is_async else None + + return KernelLauncher( + self, + lambda: _CutlassIrKernelGenHelper(self), + funcBody, + *args, + **kwargs, + ) + + def _preprocess_launch_config_args(self, args, kwargs): + """Helper to preprocess args and kwargs for LaunchConfig""" + if "stream" in kwargs: + kwargs["async_deps"] = kwargs.pop("stream") + + def mangle_name(self, function_name, args, args_spec: inspect.FullArgSpec): + """Mangle the name of the function to avoid conflicts with other functions""" + function_name = "cutlass_" + function_name + return super().mangle_name(function_name, args, args_spec) + + def _validate_arg(self, arg, arg_index, arg_name, arg_annotation): + """ + Validates if the arg is really of the annotated type. + """ + + if ( + is_arg_spec_constexpr(arg_annotation, arg_name, arg_index, None) + or arg_annotation is Any + ): + pass + else: + origin = get_origin(arg_annotation) + # Handle special case where annotation is Type[X] but arg is an actual type + if origin is type and isinstance(arg, type): + # Get the expected base type from Type[X] + expected_base = get_args(arg_annotation)[0] + if not issubclass(arg, expected_base): + return DSLRuntimeError( + f"expects argument #{arg_index+1} ({arg_name}) to be Type[{expected_base}], but got {arg}" + ) + # Handle Union types and generic types + elif origin is Union or isinstance(arg_annotation, UnionType): + # For Union types, check if arg matches any of the allowed types + allowed_types = get_args(arg_annotation) + if not any( + (ty is Any) + or (isinstance(ty, type) and isinstance(arg, ty)) + or (get_origin(ty) is tuple and isinstance(arg, tuple)) + for ty in allowed_types + ): + return DSLRuntimeError( + f"expects argument #{arg_index+1} ({arg_name}) to be one of {allowed_types}, but got {type(arg)}" + ) + elif isinstance(arg_annotation, type): + # Handle simple type annotations + if not isinstance(arg, arg_annotation) and arg is not None: + return DSLRuntimeError( + f"expects argument #{arg_index+1} ({arg_name}) to be {arg_annotation}, but got {type(arg)}" + ) + # Everything looks good if we are here + return None + + def _generate_jit_func_args_for_known_types( + self, + func, + arg, + arg_name, + arg_spec, + arg_index, + *, + is_host=True, + ): + jit_arg_type, jit_arg_attr, jit_exec_arg = [], [], [] + default_attr = ir.DictAttr.get({}) + + ( + jit_exec_arg, + jit_arg_type, + jit_arg_attr, + ) = super()._generate_jit_func_args_for_known_types( + func, arg, arg_name, arg_spec, arg_index, is_host=is_host + ) + + if jit_arg_type is not None and len(jit_arg_type) == 0: + # Handle DSL specific types + if is_cute_algebra_type(arg_spec): + dyn_vals = extract_mlir_values(arg) + if dyn_vals: + # Handle dynamic types + jit_arg_type.extend([v.type for v in dyn_vals]) + jit_arg_attr.extend([default_attr] * len(dyn_vals)) + jit_exec_arg.extend(get_c_pointers(arg) if is_host else dyn_vals) + else: + jit_exec_arg = jit_arg_type = jit_arg_attr = None + elif not hasattr(arg, "__extract_mlir_values__") and not hasattr( + arg, "__new_from_mlir_values__" + ): + # Try tree_flatten + try: + dyn_vals, _ = tree_flatten(arg) + except DSLTreeFlattenError: + # If fails, just return the original arg + return jit_exec_arg, jit_arg_type, jit_arg_attr + + if dyn_vals: + jit_arg_type.extend([v.type for v in dyn_vals]) + jit_arg_attr.extend([default_attr] * len(dyn_vals)) + jit_exec_arg.extend( + _get_c_pointers_cutlass(arg) if is_host else dyn_vals + ) + else: + # If tree flatten yields empty list, treat it as a constexpr thing + # Like a dataclass with all fields are constexpr, or an empty tuple or list + jit_exec_arg = jit_arg_type = jit_arg_attr = None + return jit_exec_arg, jit_arg_type, jit_arg_attr + + def _generate_execution_arguments_for_known_types( + self, arg, arg_spec, arg_name, i, fop_args, iv_block_args + ): + ir_arg, iv_block_args = super()._generate_execution_arguments_for_known_types( + arg, arg_spec, arg_name, i, fop_args, iv_block_args + ) + if not ir_arg: + # Handling DSL specific types + if is_cute_algebra_type(arg_spec): + n_args = len(get_mlir_types(arg)) + blk_args = fop_args[iv_block_args : iv_block_args + n_args] + ir_arg.append(new_from_mlir_values(arg, blk_args)) + iv_block_args += n_args + elif not hasattr(arg, "__extract_mlir_values__") and not hasattr( + arg, "__new_from_mlir_values__" + ): + # Try tree_unflatten + try: + dyn_vals, tree_def = tree_flatten(arg) + block_args = fop_args[iv_block_args : iv_block_args + len(dyn_vals)] + ir_arg.append(tree_unflatten(tree_def, block_args)) + iv_block_args += len(dyn_vals) + except DSLTreeFlattenError: + return ir_arg, iv_block_args + + return ir_arg, iv_block_args + + +# ============================================================================= +# Cute DSL Class +# ============================================================================= + + +class CuTeDSL(CutlassBaseDSL): + """ + This is a concrete DSL subclass for the CuTe dialect. + """ + + def __init__(self): + name = "CUTE_DSL" + compiler_provider = compiler.Compiler(passmanager, execution_engine) + pass_sm_arch_name = "cubin-chip" + + super().__init__(name, compiler_provider, pass_sm_arch_name, preprocess=True) + + +# ============================================================================= +# KernelLauncher +# ============================================================================= + + +class KernelLauncher: + """ + This class is used to launch a kernel function. + Usage: + ```python + @cute.kernel + def kernel(arg1, arg2, ...): + ... + + @cute.jit + def launch_kernel(): + kernel(arg1, arg2, ...).launch(grid=[1, 1, 1], block=[1, 1, 1], ...) + # or + kernel(arg1, arg2, ...)(grid=[1, 1, 1], block=[1, 1, 1], ...) + ``` + """ + + def __init__( + self, + dsl: "CutlassBaseDSL", + kernelGenHelper: BaseDSL._KernelGenHelper, + funcBody, + *func_args, + **func_kwargs, + ): + self.dsl = dsl + self.kernelGenHelper = kernelGenHelper + self.funcBody = funcBody + self.func_args = func_args + self.func_kwargs = func_kwargs + + self._check_func_args(funcBody, *func_args, **func_kwargs) + + def _check_func_args(self, funcBody, *func_args, **func_kwargs): + # Get function signature + sig = inspect.signature(funcBody) + + # func_args and func_kwargs should match funcBody's signature, + # no extra or missing arguments. + try: + sig.bind(*func_args, **func_kwargs) + except TypeError as e: + raise DSLRuntimeError( + f"Failed to bind arguments to function `{funcBody.__name__}` with signature `{sig}`", + cause=e, + ) + + def smem_usage(self) -> int: + """ + Check smem usage for this kernel, only available after `launch` + """ + return self.dsl._get_smem_usage() + + def launch(self, *args, **kwargs): + self.dsl.frame = inspect.currentframe().f_back + self.dsl._preprocess_launch_config_args(args, kwargs) + config = self.dsl.LaunchConfig(*args, **kwargs) + + kernel_generator = self.dsl.kernel_launcher( + requiredArgs=["config"], + unitAttrNames=["gpu.kernel", "cute.kernel"], + valueAttrDict=self.dsl._generate_kernel_attrs(config), + kernelGenHelper=self.kernelGenHelper, + )(self.funcBody) + + ret, name = kernel_generator(*self.func_args, **self.func_kwargs, config=config) + self.dsl.kernel_symbols.append(name) + self.dsl.frame = None + return ret.launch_op_ret + + def __call__(self, *args, **kwargs): + return self.launch(*args, **kwargs) + + +# ============================================================================= +# Utils +# ============================================================================= +def _filter_readonly_frozen_dataclass( + iter_args: List[Any], items_to_filter: List[Any], full_write_args_count: int +) -> List[Any]: + """ + Filter items based on whether corresponding iter_args are frozen dataclasses. + + This function filters items (which can be values or names) based on the same + logic: keep items if they correspond to full-write arguments (index < full_write_args_count) + or if the corresponding iter_arg is not a frozen dataclass. + + Args: + iter_args: List of arguments to check for frozen dataclass status + items_to_filter: List of items to filter (values or names) + full_write_args_count: Number of arguments that are always written (not read-only) + + Returns: + Filtered list of items + + Examples: + # Filter values (original remove_read_only_frozen_dataclass behavior) + filtered_values = _filter_readonly_frozen_dataclass(iter_args, iter_args, full_write_args_count) + + # Filter names (original filter_readonly_frozen_dataclass_names behavior) + filtered_names = _filter_readonly_frozen_dataclass(iter_args, iter_args_names, full_write_args_count) + """ + return [ + item + for i, item in enumerate(items_to_filter) + if i < full_write_args_count or not is_frozen_dataclass(iter_args[i]) + ] + + +def remove_read_only_frozen_dataclass( + iter_args: List[Any], full_write_args_count: int +) -> List[Any]: + """Filter out frozen dataclass arguments that are not full-write arguments.""" + return _filter_readonly_frozen_dataclass( + iter_args, iter_args, full_write_args_count + ) + + +def filter_readonly_frozen_dataclass_names( + iter_args: List[Any], iter_args_names: List[str], full_write_args_count: int +) -> List[str]: + """Filter names based on whether corresponding iter_args are frozen dataclasses.""" + return _filter_readonly_frozen_dataclass( + iter_args, iter_args_names, full_write_args_count + ) + + +def insert_read_only_frozen_dataclass( + iter_args: List[Any], original_iter_args: List[Any], full_write_args_count: int +) -> List[Any]: + """ + Insert read-only frozen dataclass arguments back into the iteration arguments. + + This function takes the new iteration arguments and the original arguments, + and preserves frozen dataclass instances from the original arguments while + using the new arguments for non-frozen dataclass instances. + + Args: + iter_args: New iteration arguments to use for non-frozen dataclass instances + original_iter_args: Original iteration arguments to preserve frozen dataclass instances from + full_write_args_count: Number of arguments that are always written (not read-only) + + Returns: + List of arguments with frozen dataclass instances preserved from original + """ + # Take full-write arguments from new iter_args + full_write_args = ( + iter_args[:full_write_args_count] if full_write_args_count > 0 else [] + ) + + # Process remaining arguments: preserve frozen dataclass from original, use new for others + remaining_original = original_iter_args[full_write_args_count:] + remaining_new = iter_args[full_write_args_count:] + + def process_remaining_arg(original_arg, new_arg_iter): + """Process a single remaining argument, preserving frozen dataclass if present""" + return original_arg if is_frozen_dataclass(original_arg) else next(new_arg_iter) + + # Use zip to pair original args with new args, then map the processing function + new_arg_iter = iter(remaining_new) + processed_remaining = [ + process_remaining_arg(orig_arg, new_arg_iter) for orig_arg in remaining_original + ] + + return full_write_args + processed_remaining + + +def unpack_to_irvalue( + mixed_values: List[Any], body_name: str, full_write_args_count: int +) -> Tuple[List[ir.Value], PyTreeDef]: + log().debug("===--- Values UNPack") + for idx, packed in enumerate(mixed_values): + log().debug("[%d]: will-unpacked: [type:%s] %s", idx, type(packed), packed) + + try: + unpacked_values, treedef = tree_flatten( + remove_read_only_frozen_dataclass(mixed_values, full_write_args_count) + ) + except DSLTreeFlattenError as e: + raise DSLRuntimeError( + f"The '{body_name}' statement encountered a user-defined Python object, which cannot be automatically converted into an dynamic expression.", + context={ + e.message: ( + f"All expressions within '{body_name}' must be dynamic expressions, " + "mixing Python objects and dynamic expressions is not supported. " + "The DSL failed to convert the Python object into dynamic expressions." + ) + }, + suggestion=( + f"Please ensure '{e.type_str}' implements the '{DynamicExpression.__name__}' or mark with `dataclass`, " + f"so it can be treated as a valid dynamic expression or mark '{body_name}' as a constant expression if conditions are Python objects." + ), + ) + + log().debug("------------------ ") + for idx, unpacked in enumerate(unpacked_values): + log().debug("[%d]: unpacked values: %s", idx, unpacked) + log().debug("treedef: %s", treedef) + log().debug("------------------ ") + + return unpacked_values, treedef + + +def pack_from_irvalue( + ir_values: List["ir.Value"], + pytree_def: PyTreeDef, + mixed_values: List[Any], + full_write_args_count: int, +) -> List[Any]: + """ + Packs MLIR values into a list of mixed values. + """ + log().debug("===--- Values Pack (%d)", len(ir_values)) + for idx, value in enumerate(ir_values): + log().debug("[%d]: will-packed: %s", idx, value) + log().debug("treedef: %s", pytree_def) + log().debug("------------------ ") + + unflattened = tree_unflatten(pytree_def, ir_values) + return insert_read_only_frozen_dataclass( + unflattened, mixed_values, full_write_args_count + ) + + +def to_index(value): + """Converts a value to an index, either by casting or coercing to int.""" + if is_dynamic_expression(value): + if isinstance(value, Numeric): + value = value.ir_value() + assert ir.IntegerType.isinstance( + value.type + ), f"expects integer type, but got {value.type}" + res = arith.index_cast(T.index(), value) + else: + res = const(int(value), ty=T.index()) + + return res + + +def _validate_iter_args_structure(iter_args, ir_values): + """ + Validates that iter_args structure contains the same number of atomic values + as there are IR values. + + Args: + iter_args: Original iteration arguments, possibly nested sequences + ir_values: Flattened MLIR values extracted from iter_args + + Returns: + bool: True if the number of atomic values in iter_args matches + the number of values in ir_values + """ + # Handle non-sequence case + if not isinstance(iter_args, (tuple, list, set)): + return not isinstance(ir_values, (tuple, list, set)) or len(ir_values) == 1 + + # If we have a sequence but ir_values isn't one, there's a mismatch + if not isinstance(ir_values, (tuple, list, set)): + return False + + # Count all non-sequence values recursively + def count_values(args): + if not isinstance(args, (tuple, list, set)): + return 1 + else: + return sum(count_values(arg) for arg in args) + + return count_values(iter_args) == len(ir_values) + + + +# ============================================================================= +# DSL implementation of Python Build-in Operators +# ============================================================================= + + +def _minmax(op, *args, loc=None, ip=None): + """Computes the minimum or maximum value from the provided arguments.""" + from ..base_dsl.typing import _binary_op, _binary_op_type_promote + + # AST Traversal doesn't support early exit in if executor + x = None + res = None + if len(args) == 1: + # Handle case for min([a, b, c, d, ..]) + if hasattr(args[0], "__iter__"): + x = op(*tuple(args[0])) + # Handle case for min(a) + else: + x = args[0] + # Handle case for min(a, b, c, ...) and min([x, y], [b]) and min(a, (x, y, z)) + elif len(args) > 1: + res, *xs = tuple(args) + for x in xs: + lhs = as_numeric(op(res, loc=loc, ip=ip)) + rhs = as_numeric(op(x, loc=loc, ip=ip)) + emitter = getattr(cutlass_arith, f"_{op.__name__}") + + lhs, rhs, res_type = _binary_op_type_promote(lhs, rhs, promote_bool=True) + + if isinstance(lhs.value, cutlass_arith.ArithValue) and isinstance( + lhs, Integer + ): + lhs_val = lhs.value.with_signedness(lhs.signed) + else: + lhs_val = lhs.value + + if isinstance(rhs.value, cutlass_arith.ArithValue) and isinstance( + rhs, Integer + ): + rhs_val = rhs.value.with_signedness(rhs.signed) + else: + rhs_val = rhs.value + + res = res_type(emitter(lhs_val, rhs_val), loc=loc, ip=ip) + x = res + else: + raise DSLNotImplemented(f"{type(args)} is not supported") + return x + + +def min(*args, loc=None, ip=None): + """Computes the minimum value from the provided arguments. + + This function differs from Python's built-in min() in that the return type + is determined by the static types of the inputs, not their dynamic values. + + :param args: One or more values or iterables to find the minimum of + :type args: tuple + :param loc: Source location for MLIR operation tracking + :type loc: object, optional + :param ip: Insertion point for MLIR operation + :type ip: object, optional + :return: The minimum value among all inputs + :rtype: Numeric + :raises DSLNotImplemented: If the input type is not supported + + Supports multiple calling patterns: + + - min(a): Returns a + - min([a, b, c, ...]): Returns minimum of all elements in the iterable + - min(a, b, c, ...): Returns minimum of all arguments + - min([x, y], [b]): Returns minimum across all elements in all iterables + - min(a, (x, y, z)): Returns minimum across all elements + + Examples: + + .. code-block:: python + + # Find minimum of two values + result = min(x, y) + + # Find minimum of multiple values + result = min(a, b, c, d) + + # Find minimum of values in a list + values = [a, b, c, d] + result = min(values) + + # Find minimum across mixed arguments + result = min(x, [y, z]) + + Difference from Python's built-in min(): + + .. code-block:: python + + # In Python, the return type depends on the dynamic values: + a = 5 + b = 3.14 + result = min(a, b) # Returns 3.14 (float) + + # In this DSL implementation, the return type is determined statically: + a = Int32(5) + b = Float32(3.14) + result = min(a, b) # Return type is determined by the type of operands, not values + """ + return _minmax(min, *args, loc=loc, ip=ip) + + +def max(*args, loc=None, ip=None): + """Computes the maximum value from the provided arguments. + + This function differs from Python's built-in max() in that the return type + is determined by the static types of the inputs, not their dynamic values. + + :param args: One or more values or iterables to find the maximum of + :type args: tuple + :param loc: Source location for MLIR operation tracking + :type loc: object, optional + :param ip: Insertion point for MLIR operation + :type ip: object, optional + :return: The maximum value among all inputs + :rtype: Numeric + :raises DSLNotImplemented: If the input type is not supported + + Supports multiple calling patterns: + + - max(a): Returns a + - max([a, b, c, ...]): Returns maximum of all elements in the iterable + - max(a, b, c, ...): Returns maximum of all arguments + - max([x, y], [b]): Returns maximum across all elements in all iterables + - max(a, (x, y, z)): Returns maximum across all elements + + Examples: + + .. code-block:: python + + # Find maximum of two values + result = max(x, y) + + # Find maximum of multiple values + result = max(a, b, c, d) + + # Find maximum of values in a list + values = [a, b, c, d] + result = max(values) + + # Find maximum across mixed arguments + result = max(x, [y, z]) + + Difference from Python's built-in max(): + + .. code-block:: python + + # In Python, the return type depends on the dynamic values: + a = 5 + b = 3.14 + result = max(a, b) # Returns 5 (int) + + # In this DSL implementation, the return type is determined statically: + a = Int32(5) + b = Float32(3.14) + result = max(a, b) # Return type is determined by the type of operands, not values + """ + return _minmax(max, *args, loc=loc, ip=ip) + + +def and_(*args, loc=None, ip=None): + """AND operation for value in DSL numeric types. + + :param *args: One or more numeric values to AND together + :type *args: Numeric + :param loc: Source location for MLIR operation tracking + :type loc: object, optional + :param ip: Insertion point for MLIR operation + :type ip: object, optional + :return: The result of the logical AND operation + :rtype: Numeric + :raises ValueError: If no arguments are provided + + Supports multiple calling patterns: + + - and_(a): Returns a + - and_(a, b, c, ...): if a is truthy, returns and_(b, c, ...), otherwise returns a + + All arguments must be of the same type. + + Examples: + + .. code-block:: python + + # In Python, 'and' returns the second operand if the first is truthy, + # otherwise it returns the first operand + a = 5 + b = 3 + result = a and b # Returns 3 + + # In this DSL implementation, the behavior is similar but works with DSL types + a = Int32(5) + b = Int32(3) + result = and_(a, b) # Returns b + """ + if len(args) == 0: + raise ValueError("and_() requires at least one argument") + + if len(args) == 1: + return args[0] + + def and_op(lhs, rhs): + if not isinstance(lhs, (Numeric, cutlass_arith.ArithValue, int, float, bool)): + raise DSLNotImplemented(f"{type(lhs)} is not supported") + elif isinstance(lhs, (int, float, bool)) and isinstance( + rhs, (int, float, bool) + ): + return lhs and rhs + else: + return as_numeric(lhs).__dsl_and__(as_numeric(rhs)) + + return functools.reduce(and_op, args[1:], args[0]) + + +def or_(*args, loc=None, ip=None): + """Logical OR operation for DSL numeric types. + + :param *args: One or more numeric values to OR together + :type *args: Numeric + :param loc: Source location for MLIR operation tracking + :type loc: object, optional + :param ip: Insertion point for MLIR operation + :type ip: object, optional + :return: The result of the logical OR operation + :rtype: Numeric + :raises ValueError: If no arguments are provided + + Supports multiple calling patterns: + + - or_(a): Returns a + - or_(a, b, c, ...): if a is truthy, returns a, otherwise returns or_(b, c, ...) + + Examples: + + .. code-block:: python + + # In Python, 'or' returns the first operand if it's truthy, + # otherwise it returns the second operand + a = 5 + b = 3 + result = a or b # Returns 5 + + # In this DSL implementation, the behavior is similar but works with DSL types + a = Int32(5) + b = Int32(3) + result = or_(a, b) # Returns a + """ + if len(args) == 0: + raise ValueError("or_() requires at least one argument") + + if len(args) == 1: + return args[0] + + def or_op(lhs, rhs): + if not isinstance(lhs, (Numeric, cutlass_arith.ArithValue, int, float, bool)): + raise DSLNotImplemented(f"{type(lhs)} is not supported") + elif isinstance(lhs, (int, float, bool)) and isinstance( + rhs, (int, float, bool) + ): + return lhs or rhs + else: + return as_numeric(lhs).__dsl_or__(as_numeric(rhs)) + + return functools.reduce(or_op, args[1:], args[0]) + + +def all_(iterable): + """Logical AND operation for all elements in an iterable. + + Returns True if all elements in the iterable are truthy, otherwise False. + This is the DSL equivalent of Python's built-in all() function. + + :param iterable: An iterable containing values to check + :type iterable: Iterable + :return: True if all elements are truthy, False otherwise + :rtype: Boolean + + Examples: + + .. code-block:: python + + # Check if all values are non-zero + values = [Int32(1), Int32(2), Int32(3)] + result = all_(values) # Returns True + + # Check if all conditions are met + conditions = [a > 0, b < 10, c != 0] + result = all_(conditions) # Returns True if all conditions are met + """ + bool_iterable = [Boolean(i) for i in iterable] + return functools.reduce( + lambda lhs, rhs: lhs.__dsl_and__(rhs) if hasattr(lhs, "__dsl_and__") else lhs, + bool_iterable, + Boolean(True), + ) + + +def any_(iterable): + """Logical OR operation for any element in an iterable. + + Returns True if any element in the iterable is truthy, otherwise False. + This is the DSL equivalent of Python's built-in any() function. + + :param iterable: An iterable containing values to check + :type iterable: Iterable + :return: True if any element is truthy, False otherwise + :rtype: Boolean + + Examples: + + .. code-block:: python + + # Check if any value is non-zero + values = [Int32(0), Int32(0), Int32(3)] + result = any_(values) # Returns True + + # Check if any condition is met + conditions = [a > 10, b < 0, c != 0] + result = any_(conditions) # Returns True if any condition is met + """ + bool_iterable = [Boolean(i) for i in iterable] + return functools.reduce( + lambda lhs, rhs: lhs.__dsl_or__(rhs) if hasattr(lhs, "__dsl_or__") else lhs, + bool_iterable, + Boolean(False), + ) + + +# ============================================================================= +# Conditional Expression +# ============================================================================= + + +def select_(cond, if_value, else_value): + def _as_scalar(value): + if isinstance(value, list): + if len(value) == 1: + return value[0] + else: + raise DSLRuntimeError( + "Conditional expression must have exactly one value in all expressions" + ) + return value + + if not is_dynamic_expression(cond): + raise DSLRuntimeError("Conditional expression must be dynamic") + + # Extract MLIR values + cond = extract_mlir_values(cond) + if is_dynamic_expression(if_value): + if_value = extract_mlir_values(if_value) + else: + if_value = const(if_value) + if is_dynamic_expression(else_value): + else_value = extract_mlir_values(else_value) + else: + else_value = const(else_value) + + return arith.SelectOp( + _as_scalar(cond), _as_scalar(if_value), _as_scalar(else_value) + ).result + + +# ============================================================================= +# Terminator +# ============================================================================= + + +def yield_out(args=[], loc=None, ip=None): + """ + Generate a yield operation. It it used to return values from a loop, if-else, or while region. + """ + scf.yield_(extract_mlir_values(args), loc=loc, ip=ip) + + +# ============================================================================= +# For Loop +# ============================================================================= + + +class LoopUnroll(ir.Attribute): + def __init__(self, **kwargs): + valid_keys = set(["count", "full"]) + def to_mlir_attr(val): + if isinstance(val, bool): + return "true" if val else "false" + elif isinstance(val, int): + return f"{val} : i32" + else: + raise DSLNotImplemented(f"{type(val)} is not supported") + + cfg = {key: to_mlir_attr(kwargs[key]) for key in valid_keys if key in kwargs} + if kwargs.get("count", None) == 1: + cfg["disable"] = "true" + + unroll = "<" + ", ".join(f"{key} = {value}" for key, value in cfg.items()) + ">" + + super().__init__( + ir.Attribute.parse(f"#llvm.loop_annotation") + ) + + +def for_generate( + start, + stop=None, + step=None, + iter_args: Optional[Sequence[ir.Value]] = None, + *, + unroll: LoopUnroll = None, + prefetch_stages=None, + loc=None, + ip=None, +): + """ + scf.for with yield support + """ + + if step is None: + step = 1 + if stop is None: + stop = start + start = 0 + start = const(start) + params = [start, stop, step] + for i, p in enumerate(params): + if isinstance(p, int): + p = const(p) + elif isinstance(p, float): + raise DSLRuntimeError(f"{p=} must be int.") + elif isinstance(p, Integer): + p = p.ir_value() + params[i] = p + + start, stop, step = params + + def _createI32Attr(value): + if not isinstance(value, int): + raise DSLRuntimeError(f"value must be int.") + return ir.IntegerAttr.get(ir.IntegerType.get_signless(32), value) + + ir_iter_args = extract_mlir_values(iter_args) if iter_args is not None else None + if not _validate_iter_args_structure(iter_args, ir_iter_args): + raise DSLRuntimeError("iter_args: Elements should be extractable as ir.Value.") + for_op = scf.ForOp(start, stop, step, ir_iter_args, loc=loc, ip=ip) + if unroll is not None: + for_op.attributes["loop_annotation"] = unroll + + if prefetch_stages is not None: + for_op.attributes["cutlass.pipelining"] = _createI32Attr(prefetch_stages) + + iv = for_op.induction_variable + new_results = new_from_mlir_values(iter_args, for_op.results) + new_iter_args = new_from_mlir_values(iter_args, for_op.inner_iter_args) + new_iter_args = () if new_iter_args is None else tuple(new_iter_args) + + with ir.InsertionPoint(for_op.body): + if len(new_iter_args) > 1: + yield iv, new_iter_args, new_results + elif len(new_iter_args) == 1: + yield iv, new_iter_args[0], new_results[0] + else: + yield iv + + +# ============================================================================= +# Logical Operators +# ============================================================================= + + +def not_(lhs: Union[ir.Value, bool], *, loc=None, ip=None): + """ + Logical Not + """ + res = None + # Handle Python bool first to prevent infinite recursion + if type(lhs) == bool: + res = lhs ^ True + elif hasattr(lhs, "__dsl_not__"): + res = lhs.__dsl_not__(loc=loc, ip=ip) + elif is_dynamic_expression(lhs): + # If lhs is MLIR value, compute not using xor + res = arith.XOrIOp(lhs, const(1, lhs.type)).result + else: + res = bool(lhs) ^ True + + return res + + +# ============================================================================= +# If/Else +# ============================================================================= + + +def if_generate( + cond: Boolean, + then_body: Callable, + else_body: Optional[Callable] = None, + input_args: List[DslType] = None, + return_types: List[DslType] = None, + *, + loc=None, + ip=None, +) -> List: + """ + Generate an IfOp with optional else branch and return values. + + Args: + cond: The condition expression + then_body: Function to execute in then branch + else_body: Optional function to execute in else branch + input_args: Arguments to pass to branch bodies + return_types: Expected return types for the operation + loc: Optional location information + ip: Optional insertion point + + Returns: + List of DSL typed results + """ + input_args = input_args or [] + mlir_return_types = [] + + # Validate and collect MLIR return types (if provided). + if return_types is not None: + for t in return_types: + if not isinstance(t, DslType): + raise DSLRuntimeError(f"{t=} must be a DslType.") + mlir_return_types.append(t.mlir_type) + + # Determine whether there's an else branch. + has_else = else_body is not None + + # Create the IfOp. + if_op = scf.IfOp( + Boolean(cond).ir_value(), mlir_return_types, hasElse=has_else, loc=loc, ip=ip + ) + + def _execute_and_yield_out(body, input_args): + yield_vals = body(*input_args) + if return_types is not None: + if not isinstance(yield_vals, Iterable): + # body only return single element + yield_vals = [yield_vals] + + yield_vals = [t(r) for t, r in zip(return_types, yield_vals)] + yield_out(yield_vals) + + # Generate the body for 'then'. + with ir.InsertionPoint(if_op.then_block): + _execute_and_yield_out(then_body, input_args) + + # Generate the body for 'else' if provided. + if has_else: + with ir.InsertionPoint(if_op.else_block): + _execute_and_yield_out(else_body, input_args) + + # Collect MLIR results. + mlir_results = _get_op_result_or_op_results(if_op) + + if not isinstance(mlir_results, list): + mlir_results = [mlir_results] + + # Wrap the results with their DSL types. + if return_types is None: + return [] + + vals = [t(r) for t, r in zip(return_types, mlir_results)] + + if len(vals) == 1: + return vals[0] + + return vals + + +# ============================================================================= +# While Loop +# ============================================================================= + + +class WhileLoopContext: + """ + Context manager for a dynamic while loop. + """ + + def __init__( + self, + inputs: Sequence[Union[ir.Value, Numeric]], + condition: Callable[[Sequence[ir.Value]], ir.Value], + *, + loc=None, + ip=None, + ): + # Keep original inputs and allow recover original type information + self.inputs = inputs + + self.input_ir_values = extract_mlir_values(inputs) + + if not _validate_iter_args_structure(inputs, self.input_ir_values): + raise DSLRuntimeError("inputs: Elements should be extractable as ir.Value.") + + self.condition = condition + self.input_ir_types = [i.type for i in self.input_ir_values] + self.while_op = scf.WhileOp( + self.input_ir_types, self.input_ir_values, loc=loc, ip=ip + ) + + self.before_region = self.while_op.before + self.after_region = self.while_op.after + + self.before_region.blocks.append(*self.input_ir_types) + self.before_block = self.before_region.blocks[0] + + self.after_region.blocks.append(*self.input_ir_types) + self.after_block = self.after_region.blocks[0] + + def __enter__(self): + with ir.InsertionPoint(self.before_block): + args = new_from_mlir_values(self.inputs, self.before_block.arguments) + cond = self.condition(*args) + cond_ir_val = extract_mlir_values(cond) + scf.ConditionOp(cond_ir_val[0], [*self.before_block.arguments]) + self.ipoint_op = ir.InsertionPoint(self.after_block) + self.ipoint_op.__enter__() + return new_from_mlir_values(self.inputs, self.after_block.arguments) + + def __exit__(self, exc_type, exc_value, traceback): + self.ipoint_op.__exit__(exc_type, exc_value, traceback) + + @property + def results(self): + return new_from_mlir_values(self.inputs, self.while_op.results_) + + +def while_generate( + inputs: Sequence[Union[ir.Value, Numeric]], + condition: Callable[[Sequence[Union[ir.Value, Numeric]]], Union[ir.Value, Numeric]], + *, + loc=None, + ip=None, +) -> WhileLoopContext: + """ + Generate a WhileLoopContext for a dynamic loop. + """ + return WhileLoopContext(inputs, condition, loc=loc, ip=ip) + + +def equal(lhs, rhs): + if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs): + return lhs == rhs + + # Both sequence + if isinstance(lhs, Sequence) and isinstance(rhs, Sequence): + # Short-circuit for unequal length + if len(lhs) != len(rhs): + return False + return all_(equal(l, r) for l, r in zip(lhs, rhs)) + return lhs == rhs + + +def not_equal(lhs, rhs): + if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs): + return lhs != rhs + + # Both sequence + if isinstance(lhs, Sequence) and isinstance(rhs, Sequence): + # Short-circuit for unequal length + if len(lhs) != len(rhs): + return True + return any_(not_equal(l, r) for l, r in zip(lhs, rhs)) + + if hasattr(lhs, "__ne__"): + return lhs != rhs + elif hasattr(rhs, "__ne__"): + return rhs != lhs + else: + return not_(equal(lhs, rhs)) + + +def in_(lhs, rhs): + if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs): + return lhs in rhs + + if not isinstance(rhs, Sequence): + raise DSLRuntimeError( + f"'in' not supported between instances of {type(lhs)} and {type(rhs)}" + ) + + return any_(equal(lhs, r) for r in rhs) + + +def _lte_gte(lhs, rhs, op): + def native_lte_gte(lhs, rhs, op): + match op: + case "<": + return lhs < rhs + case "<=": + if hasattr(lhs, "__le__"): + return lhs <= rhs + else: + return not_(lhs > rhs) + case ">": + return lhs > rhs + case ">=": + if hasattr(lhs, "__ge__"): + return lhs >= rhs + else: + return not_(lhs < rhs) + case _: + raise DSLRuntimeError(f"Unsupported comparison operator: {op}") + + if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs): + return native_lte_gte(lhs, rhs, op) + + # Both sequence, comparisons other than == and != do not allow mixing different types of sequences + if ( + isinstance(lhs, Sequence) + and isinstance(rhs, Sequence) + and type(lhs) == type(rhs) + ): + unequal_found = False + comp_results = [] + mask = [] + for l, r in zip(lhs, rhs): + is_equal = equal(l, r) + mask.append(not_(or_(is_equal, unequal_found))) + unequal_found = not_(is_equal) + comp_results.append(_lte_gte(l, r, op)) + + result = any_(and_(r, m) for r, m in zip(comp_results, mask)) + + if len(lhs) != len(rhs): + # Ref https://docs.python.org/3/tutorial/datastructures.html#comparing-sequences-and-other-types + # If one sequence is an initial sub-sequence of the other, the shorter sequence is the smaller (lesser) one + has_valid_mask = any_(mask) + match op: + case "<": + length_result = len(lhs) < len(rhs) + case ">": + length_result = len(lhs) > len(rhs) + case "<=": + length_result = len(lhs) <= len(rhs) + case ">=": + length_result = len(lhs) >= len(rhs) + if type(has_valid_mask) == bool: + return result if has_valid_mask else length_result + else: + return select_(has_valid_mask, result, length_result) + else: + if op in {"<=", ">="}: + # If no unequal, return True + return select_(unequal_found, result, True) + else: + return result + else: + return native_lte_gte(lhs, rhs, op) + + +def greater_than(lhs, rhs): + return _lte_gte(lhs, rhs, ">") + + +def greater_equal(lhs, rhs): + return _lte_gte(lhs, rhs, ">=") + + +def less_than(lhs, rhs): + return _lte_gte(lhs, rhs, "<") + + +def less_equal(lhs, rhs): + return _lte_gte(lhs, rhs, "<=") + + +def _compare_dispatch(lhs, rhs, op): + """ + Dispatches the comparison operation between lhs and rhs based on the given operator. + + :param lhs: The left-hand side operand for the comparison. + :param rhs: The right-hand side operand for the comparison. + :param op: The comparison operator as a string. Supported operators are: + - "is", "is not": Python identity comparisons. + - "in", "not in": Membership tests. + - "==", "!=": Equality and inequality. + - "<", ">", "<=", ">=": Relational comparisons. + :return: The result of the comparison, which may be a boolean or a DSL-specific type. + :raises DSLRuntimeError: If the operator is not supported. + """ + match op: + # 'is' and 'is not' are pure python operators + case "is": + return lhs is rhs + case "is not": + return lhs is not rhs + case "in": + return in_(lhs, rhs) + case "not in": + return not_(in_(lhs, rhs)) + case "==": + return equal(lhs, rhs) + case "!=": + return not_equal(lhs, rhs) + case "<": + return less_than(lhs, rhs) + case ">": + return greater_than(lhs, rhs) + case ">=": + return greater_equal(lhs, rhs) + case "<=": + return less_equal(lhs, rhs) + case _: + raise DSLRuntimeError(f"Unsupported comparison operator: {op}") + + +def _compare_executor(left, comparators, ops): + # Fast path for single comparison + if len(comparators) == 1: + return _compare_dispatch(left, comparators[0], ops[0]) + + # Chain comparison, dispatch in a loop + result = True + current = left + for comparator, op in zip(comparators, ops): + cmp_result = _compare_dispatch(current, comparator, op) + result = and_(result, cmp_result) + current = comparator + + return result + + +def _builtin_redirector(fcn): + if fcn == builtins.max: + return max + elif fcn == builtins.min: + return min + elif fcn == builtins.any: + return any_ + elif fcn == builtins.all: + return all_ + else: + raise DSLRuntimeError(f"Unsupported built-in function: {fcn}") + + +# ============================================================================= +# Set the AST decorator +# ============================================================================= + +# Set the DSL specific functions +executor.set_functions( + is_dynamic_expression=is_dynamic_expression, + loop_execute_range_dynamic=_loop_execute_range_dynamic, + if_dynamic=_if_execute_dynamic, + while_dynamic=_while_execute_dynamic, + compare_executor=_compare_executor, + any_executor=any_, + all_executor=all_, + builtin_redirector=_builtin_redirector, +) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass_dsl/cutlass_ast_decorators.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass_dsl/cutlass_ast_decorators.py new file mode 100644 index 0000000000000000000000000000000000000000..b5b4d8953d69b4100871a496623f051d60ab2a8d --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass_dsl/cutlass_ast_decorators.py @@ -0,0 +1,633 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from typing import List, Tuple +from types import NoneType +from cutlass._mlir import ir +from cutlass._mlir.dialects import scf, arith +from cutlass._mlir.extras import types as T +from collections.abc import Sequence + +from ..base_dsl.dsl import is_dynamic_expression +from ..base_dsl.ast_helpers import * +from ..base_dsl.utils.logger import log +from ..base_dsl import typing as t +from ..base_dsl.typing import ( + Int32, + Float32, + Boolean, + Numeric, + get_mlir_types, + as_numeric, +) +from . import cutlass as cutlass_dsl +from .tree_utils import PyTreeDef, check_tree_equal + +# ============================================================================= +# AST Helpers +# ============================================================================= + + +class LoopUnroll(ir.Attribute): + def __init__(self, **kwargs): + valid_keys = set(["count", "full"]) + def to_mlir_attr(val): + if isinstance(val, bool): + return "true" if val else "false" + elif isinstance(val, int): + return f"{val} : i32" + else: + raise DSLNotImplemented(f"{type(val)} is not supported") + + cfg = {key: to_mlir_attr(kwargs[key]) for key in valid_keys if key in kwargs} + if kwargs.get("count", None) == 1: + cfg["disable"] = "true" + + unroll = "<" + ", ".join(f"{key} = {value}" for key, value in cfg.items()) + ">" + + super().__init__( + ir.Attribute.parse(f"#llvm.loop_annotation") + ) + + +class ScfGenerator: + """ + Encapsulates common scf dialect functionality: pack, unpack, and SCF execution. + """ + + def __init__(self): + pass + + @staticmethod + def _normalize_region_result_to_list(region_result: Any) -> List[Any]: + """ + Convert region_result to a list if it is not already a list + If region_result is a list, return it as is. + If region_result is None, return an empty list. + If region_result is not a list, return a list containing region_result as the only element. + """ + if region_result is None: + region_result_list = [] + elif not isinstance(region_result, list): + region_result_list = [region_result] + else: + region_result_list = region_result + return region_result_list + + @staticmethod + def _check_region_result(original_value, region_value, arg_name, op_type_name): + """ + Validate that a region result maintains the same type as the original value. + + This method checks for type consistency between the original value passed to a dynamic + SCF operation (like for, if, while) and the value returned from the operation's region. + + Args: + original_value: The value before entering the SCF operation region + region_value: The value returned from the SCF operation region + arg_name: Name of the argument being checked (for error reporting) + op_type_name: Type of SCF operation (e.g., 'for', 'if', 'while') for error reporting + + Raises: + DSLRuntimeError: If the region value has a different type than the original value. + The error includes suggestions for using compile-time control flow instead. + + Note: + This method performs relaxed type checking that allows inheritance relationships. + For example, a child class can be returned where a parent class was expected. + However, fundamental type changes (like None to non-None, different sequence types, + or different numeric types) are not allowed in dynamic SCF operations. + """ + + def get_type_name(value): + if isinstance(value, NoneType): + return "None" + elif isinstance(value, Sequence): + return f"{type(value).__name__}<{len(value)}>" + else: + return type(value).__name__ + + # Check for type mismatches + type_mismatch = False + old_type_name = None + new_type_name = None + + # Handle None type changes + if isinstance(original_value, NoneType) != isinstance(region_value, NoneType): + type_mismatch = True + old_type_name = get_type_name(original_value) + new_type_name = get_type_name(region_value) + # Handle sequence type/length changes + elif isinstance(original_value, Sequence) and isinstance( + region_value, Sequence + ): + if type(original_value) != type(region_value) or len(original_value) != len( + region_value + ): + type_mismatch = True + old_type_name = get_type_name(original_value) + new_type_name = get_type_name(region_value) + # Handle numeric type changes + elif isinstance( + original_value, (Numeric, ArithValue, ir.Value, int, float, bool) + ) or isinstance( + region_value, (Numeric, ArithValue, ir.Value, int, float, bool) + ): + try: + original_numeric = as_numeric(original_value) + region_numeric = as_numeric(region_value) + if original_numeric.dtype != region_numeric.dtype: + type_mismatch = True + old_type_name = original_numeric.dtype.__name__ + new_type_name = region_numeric.dtype.__name__ + except Exception: + pass + # Handle general type changes (relaxed for inheritance) + elif type(original_value) != type(region_value): + old_type = type(original_value) + new_type = type(region_value) + if not (issubclass(old_type, new_type) or issubclass(new_type, old_type)): + type_mismatch = True + old_type_name = old_type.__name__ + new_type_name = new_type.__name__ + + if type_mismatch: + raise DSLRuntimeError( + f"`{arg_name}` is {old_type_name} prior to this `{op_type_name}`, " + f"and update to {new_type_name} inside of this `{op_type_name}` is not supported.", + suggestion=( + f"Please avoid changing type inside a dynamic `{op_type_name}`, " + f"or change to compile-time control flow by marking this `{op_type_name}` with " + f"`{'range_constexpr' if op_type_name == 'for' else 'const_expr'}`." + ), + ) + + def scf_execute_dynamic( + self, + op_type_name: str, + mix_iter_args: List[Any], + full_write_args_count: int, + mix_iter_arg_names: List[str], + create_op_func: Callable[[List[ir.Value]], ir.Operation], + region_builders: List[ + Callable[ + [ + "ir.Operation", + List["ir.Value"], # block_args + List["ir.Value"], # dyn_yield_ops + PyTreeDef, + List[Any], + int, + ], + Any, + ] + ], + # block_term_op_builder[region_builder] = scf_op_builder + # e.g. scf.ConditionOp for while loop + block_term_op_builder: Dict[Callable, Callable] = {}, + ) -> Any: + # 1) Unpack + ir_values, pytree_def = cutlass_dsl.unpack_to_irvalue( + mix_iter_args, op_type_name, full_write_args_count + ) + # 2) Create the SCF op + op = create_op_func(ir_values) + log().debug("Generated scf.%s \n[%s]", op_type_name, op) + + # 3) Build the regions + for i, builder in enumerate(region_builders): + region = op.regions[i] + block = region.blocks[0] + with ir.InsertionPoint(block): + block_args = list(block.arguments) + region_result = builder( + op, + block_args, + ir_values, + pytree_def, + mix_iter_args, + full_write_args_count, + ) + + # Use custom terminator if provided for this builder, otherwise use default YieldOp + if builder in block_term_op_builder: + # Use the provided terminator generator + block_term_op_builder[builder](region_result, full_write_args_count) + else: + # Normalize region_result + region_result_list = ScfGenerator._normalize_region_result_to_list( + region_result + ) + # For standard yield op, check result + for arg, result, name in zip( + mix_iter_args, + region_result_list, + mix_iter_arg_names, + ): + ScfGenerator._check_region_result( + arg, result, name, op_type_name + ) + + # Default behavior - generate YieldOp + region_values, yield_pytree_def = cutlass_dsl.unpack_to_irvalue( + region_result_list, op_type_name, full_write_args_count + ) + + mismatch = check_tree_equal(pytree_def, yield_pytree_def) + if mismatch != -1: + # Get arg name + filterd_arg_names = ( + cutlass_dsl.filter_readonly_frozen_dataclass_names( + mix_iter_args, mix_iter_arg_names, full_write_args_count + ) + ) + + raise DSLRuntimeError( + f"`{filterd_arg_names[mismatch]}` is structured different after this `{op_type_name}`.", + suggestion=( + f"Please avoid changing type structure inside a dynamic `{op_type_name}`, " + f"or change to compile-time control flow by marking this `{op_type_name}` with " + f"`{'range_constexpr' if op_type_name == 'for' else 'const_expr'}`." + ), + ) + + scf.YieldOp(region_values) + + log().debug("Completed scf.%s \n[%s]", op_type_name, op) + + # 4) Pack final results + final_results = cutlass_dsl.pack_from_irvalue( + op.results, pytree_def, mix_iter_args, full_write_args_count + ) + + # 5) Return in a nice pattern + if not final_results: + return + if len(final_results) == 1: + return final_results[0] + return final_results + + +def _attr_const_check(attr, expected_type, attr_name): + # Use strict type equality to prevent `bool` being accepted where `int` is required. + if is_dynamic_expression(attr) or type(attr) is not expected_type: + raise DSLRuntimeError( + f"loop attribute `{attr_name}` must be a Python value of type `{expected_type.__name__}`, got `{type(attr).__name__}`." + ) + + +def _loop_execute_range_dynamic( + func: Callable, + start: Any, + stop: Any, + step: Any, + mix_iter_args: List[Any] = [], + full_write_args_count: int = 0, + mix_iter_arg_names: List[str] = [], + unroll: int = -1, + unroll_full: bool = False, + prefetch_stages: int = None, +): + """ + Example: build an scf.for with optional unroll, using our universal helper. + """ + scf_gen = ScfGenerator() + + def create_for_op(dyn_yield_ops: List[ir.Value]): + for d in dyn_yield_ops: + if not isinstance(d, ir.Value): + raise DSLRuntimeError( + f"Invalid dyn_yield_ops: {dyn_yield_ops} \n\tExpected ir.Value, got {type(d)}" + ) + + # Convert Python ints or values to IR constants if needed + start_ = t.as_numeric(start) + stop_ = t.as_numeric(stop) + step_ = t.as_numeric(step) + assert start_ is not t.Int32, "Start is required for scf.for" + assert stop_ is not t.Int32, "Stop is required for scf.for" + assert step_ is not t.Int32, "Step is required for scf.for" + start_ = start_.ir_value() + stop_ = stop_.ir_value() + step_ = step_.ir_value() + + # Attributes must be pure Python value, add a check + _attr_const_check(unroll, int, "unroll") + _attr_const_check(unroll_full, bool, "unroll_full") + + # Possibly attach unroll attributes + unroll_attr = None + if unroll_full: + unroll_attr = LoopUnroll(full=True) + elif unroll != -1: + unroll_attr = LoopUnroll(count=unroll) + log().debug("Unroll attribute: %s", unroll_attr) + + prefetch_stages_attr = None + if prefetch_stages is not None: + _attr_const_check(prefetch_stages, int, "prefetch_stages") + if prefetch_stages >= 0: + prefetch_stages_attr = ir.IntegerAttr.get( + ir.IntegerType.get_signless(32), prefetch_stages + ) + else: + raise DSLRuntimeError( + f"loop attribute `prefetch_stages` must be non-negative, got `{prefetch_stages}`." + ) + log().debug("prefetch_stages attribute: %s", prefetch_stages_attr) + + log().debug( + "Creating scf.ForOp \n\t\tstart=%s: type : %s\n\t\tstop=%s: type : %s\n\t\tstep=%s: type : %s", + start_, + type(start_), + stop_, + type(stop_), + step_, + type(step_), + ) + # Create scf.ForOp, passing iteration args if any + try: + if not dyn_yield_ops: + for_op = scf.ForOp(start_, stop_, step_) + else: + for_op = scf.ForOp(start_, stop_, step_, list(dyn_yield_ops)) + except Exception as e: + yield_ops = "\n".join( + f"\t\t{i} => {d} : type : {type(d)}" + for i, d in enumerate(dyn_yield_ops) + ) + raise DSLRuntimeError( + f"Failed to create scf.ForOp \n\t\tstart={start_}: type : {type(start_)}" + f"\n\t\tstop={stop_}: type : {type(stop_)}\n\t\tstep={step_}: type : {type(step_)}" + f", \n\tdyn_yield_ops:\n{yield_ops}" + ) from e + + if unroll_attr is not None: + for_op.attributes["loop_annotation"] = unroll_attr + + if prefetch_stages_attr is not None: + for_op.attributes["cutlass.pipelining"] = prefetch_stages_attr + + return for_op + + def for_body_builder( + op, + block_args, + _, + pytree_def, + mix_iter_args, + full_write_args_count, + ): + # scf.ForOp block_args are typically [induction_var, iter_args...] + # But MLIR also gives you op.induction_variable + iv = t.as_numeric(op.induction_variable) + log().debug( + "For body builder: %s block_args: %s full_write_args_count: %s", + iv, + block_args, + full_write_args_count, + ) + # block_args[1:] are iteration variables + func_args = [] + func_args.extend( + cutlass_dsl.pack_from_irvalue( + block_args[1:], pytree_def, mix_iter_args, full_write_args_count + ) + ) + if not func_args: + # No iteration arguments, or only the induction var + func(iv) + return [] # yield nothing + else: + updated_func_args = func(iv, *func_args) + return updated_func_args + + # Now call the universal SCF executor with a single region builder + return scf_gen.scf_execute_dynamic( + op_type_name="for", + mix_iter_args=mix_iter_args, + full_write_args_count=full_write_args_count, + mix_iter_arg_names=mix_iter_arg_names, + create_op_func=create_for_op, + region_builders=[for_body_builder], + ) + + +def _if_execute_dynamic( + pred: "ir.Value", + then_block: Callable, + else_block: Callable = None, + mix_yield_args: List[Any] = [], + full_write_args_count: int = 0, + mix_yield_arg_names: List[str] = [], + if_constexpr=None, # ignoring for brevity +): + """ + Build an scf.if with optional else, using our universal helper. + """ + scf_gen = ScfGenerator() + + def create_if_op(dyn_yield_ops: List[ir.Value]): + # Assume final result types match the dynamic yields + result_types = [arg.type for arg in dyn_yield_ops] + + pred_ = Boolean(pred) + + try: + if_op = scf.IfOp( + pred_.ir_value(), + hasElse=(else_block is not None), + results_=result_types, + ) + except Exception as e: + raise DSLRuntimeError( + f"Failed to create scf.IfOp \n\t\tpred={pred_}: type : {type(pred_)}" + ) from e + return if_op + + def then_builder( + if_op, + _, + dyn_yield_ops, + pytree_def, + mix_iter_args, + full_write_args_count, + ): + flat_args = [] + flat_args.extend( + cutlass_dsl.pack_from_irvalue( + dyn_yield_ops, pytree_def, mix_iter_args, full_write_args_count + ) + ) + return then_block(*flat_args) + + region_builders = [then_builder] + + if else_block is not None: + + def else_builder( + if_op, + _, + dyn_yield_ops, + pytree_def, + mix_iter_args, + full_write_args_count, + ): + flat_args = [] + flat_args.extend( + cutlass_dsl.pack_from_irvalue( + dyn_yield_ops, pytree_def, mix_iter_args, full_write_args_count + ) + ) + return else_block(*flat_args) + + region_builders.append(else_builder) + + return scf_gen.scf_execute_dynamic( + op_type_name="if", + mix_iter_args=mix_yield_args, + full_write_args_count=full_write_args_count, + mix_iter_arg_names=mix_yield_arg_names, + create_op_func=create_if_op, + region_builders=region_builders, + ) + + +def _while_execute_dynamic( + while_before_block: Callable, + while_after_block: Callable = None, + write_args=[], + full_write_args_count=0, + write_args_names=[], +): + """ + Create and return an SCF WhileOp for dynamic loops. + Generate the dynamic loop body using SCF WhileOp. + + Args: + while_before_block: Function that returns (condition, updated_values) + while_after_block: Function that returns updated values + write_args: Values that are updated in the loop + + See create_while_function in ast_preprocessor.py for details on the input structure. + """ + log().debug("_while_execute_dynamic") + while_op_type_name = "while" + scf_gen = ScfGenerator() + + def create_while_op(dyn_yield_ops: List[ir.Value]): + # Create the while operation with the types from yield_args + result_types = [arg.type for arg in dyn_yield_ops] + try: + while_op = scf.WhileOp(result_types, dyn_yield_ops) + while_op.before.blocks.append(*result_types) + while_op.after.blocks.append(*result_types) + log().debug("[%s]", while_op) + return while_op + except Exception as e: + yield_ops = "\n".join( + f"\t\t{i} => {d} : type : {type(d)}" + for i, d in enumerate(dyn_yield_ops) + ) + raise DSLRuntimeError( + f"Failed to create scf.WhileOp with yield_ops:\n{yield_ops}" + ) from e + + def before_block_builder( + op, + block_args, + _, + pytree_def, + mix_iter_args, + full_write_args_count, + ): + # Build the before (condition) block + flat_args = [] + flat_args.extend( + cutlass_dsl.pack_from_irvalue( + block_args, pytree_def, mix_iter_args, full_write_args_count + ) + ) + + log().debug("before block args: %s", flat_args) + + cond, before_results = while_before_block(*flat_args) + + if not isinstance(before_results, (list, ir.OpResultList)): + before_results = [before_results] + + log().debug("cond [%s]", cond) + log().debug( + "before_results [%s]", + before_results, + ) + + return cond, before_results + + def before_block_terminator(cond_and_results, full_write_args_count): + # Generate a condition op instead of yield op + cond = cond_and_results[0] + before_result_list = ScfGenerator._normalize_region_result_to_list( + cond_and_results[1] + ) + ir_cond = as_numeric(cond).ir_value() + ir_results_list, pytree_def = cutlass_dsl.unpack_to_irvalue( + before_result_list, while_op_type_name, full_write_args_count + ) + log().debug( + "creating scf.ConditionOp with [%s], [%s]", + ir_cond, + ir_results_list, + ) + scf.ConditionOp(ir_cond, ir_results_list) + + def after_block_builder( + op, + block_args, + _, + pytree_def, + mix_iter_args, + full_write_args_count, + ): + # Build the after (body) block + flat_args = [] + flat_args.extend( + cutlass_dsl.pack_from_irvalue( + block_args, pytree_def, mix_iter_args, full_write_args_count + ) + ) + + log().debug("after block args: %s", flat_args) + + after_results = while_after_block(*flat_args) + + if not isinstance(after_results, (list, ir.OpResultList)): + after_results = [after_results] + + log().debug( + "after_results [%s]", + after_results, + ) + + return after_results + + # Call the universal SCF executor with two region builders + return scf_gen.scf_execute_dynamic( + op_type_name=while_op_type_name, + mix_iter_args=write_args, + full_write_args_count=full_write_args_count, + mix_iter_arg_names=write_args_names, + create_op_func=create_while_op, + region_builders=[before_block_builder, after_block_builder], + block_term_op_builder={ + before_block_builder: before_block_terminator + }, # Only customize the before block + ) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass_dsl/tree_utils.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass_dsl/tree_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..599b72ea5c6b1d378480ceeb1d43d14fd58b569d --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass_dsl/tree_utils.py @@ -0,0 +1,763 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from typing import Callable, Any, Iterable, Iterator, NamedTuple, Union, get_origin +import dataclasses +import itertools as it +from types import SimpleNamespace + +from ..base_dsl.typing import as_numeric, Numeric, Constexpr +from ..base_dsl._mlir_helpers.arith import ArithValue +from ..base_dsl.common import DSLBaseError +from .._mlir import ir + +# ============================================================================= +# Tree Utils +# ============================================================================= + + +class DSLTreeFlattenError(DSLBaseError): + """Exception raised when tree flattening fails due to unsupported types.""" + + def __init__(self, msg: str, type_str: str): + super().__init__(msg) + self.type_str = type_str + + +def unzip2(pairs: Iterable[tuple[Any, Any]]) -> tuple[list[Any], list[Any]]: + """Unzip a sequence of pairs into two lists.""" + lst1, lst2 = [], [] + for x1, x2 in pairs: + lst1.append(x1) + lst2.append(x2) + return lst1, lst2 + + +def get_fully_qualified_class_name(x: Any) -> str: + """ + Get the fully qualified class name of an object. + + Args: + x: Any object + + Returns: + str: Fully qualified class name in format 'module.class_name' + + Example: + >>> get_fully_qualified_class_name([1, 2, 3]) + 'builtins.list' + """ + return f"{x.__class__.__module__}.{x.__class__.__qualname__}" + + +def is_frozen_dataclass(obj_or_cls: Any) -> bool: + """ + Check if an object or class is a frozen dataclass. + + Args: + obj_or_cls: Either a dataclass instance or class + + Returns: + bool: True if the object/class is a dataclass declared with frozen=True, + False otherwise + + Example: + >>> from dataclasses import dataclass + >>> @dataclass(frozen=True) + ... class Point: + ... x: int + ... y: int + >>> is_frozen_dataclass(Point) + True + >>> is_frozen_dataclass(Point(1, 2)) + True + """ + cls = obj_or_cls if isinstance(obj_or_cls, type) else obj_or_cls.__class__ + + return ( + dataclasses.is_dataclass(cls) + and getattr(cls, "__dataclass_params__", None) is not None + and cls.__dataclass_params__.frozen + ) + + +def is_dynamic_expression(x: Any) -> bool: + """ + Check if an object implements the DynamicExpression protocol. + + Objects implementing this protocol must have both `__extract_mlir_values__` + and `__new_from_mlir_values__` methods. + + Args: + x: Any object to check + + Returns: + bool: True if the object implements the DynamicExpression protocol, + False otherwise + """ + return all( + hasattr(x, attr) + for attr in ("__extract_mlir_values__", "__new_from_mlir_values__") + ) + + +def is_constexpr_field(field: dataclasses.Field) -> bool: + """ + Check if a field is a constexpr field. + """ + if field.type is Constexpr: + return True + elif get_origin(field.type) is Constexpr: + return True + return False + + +# ============================================================================= +# PyTreeDef +# ============================================================================= + +class NodeType(NamedTuple): + """ + Represents a node in a pytree structure. + + Attributes: + name: String representation of the node type + to_iterable: Function to convert node to iterable form + from_iterable: Function to reconstruct node from iterable form + """ + name: str + to_iterable: Callable + from_iterable: Callable + + +class PyTreeDef(NamedTuple): + """ + Represents the structure definition of a pytree. + + Attributes: + node_type: The type of this node + node_metadata: SimpleNamespace metadata associated with this node + child_treedefs: Tuple of child tree definitions + """ + node_type: NodeType + node_metadata: SimpleNamespace + child_treedefs: tuple["PyTreeDef", ...] + + +@dataclasses.dataclass(frozen=True) +class Leaf: + """ + Represents a leaf node in a pytree structure. + + Attributes: + is_numeric: Whether this leaf contains a `Numeric` value + is_none: Whether this leaf represents None + node_metadata: SimpleNamespace metadata associated with this leaf + ir_type_str: String representation of the IR type + """ + is_numeric: bool = False + is_none: bool = False + node_metadata: SimpleNamespace = None + ir_type_str: str = None + + +# ============================================================================= +# Default to_iterable and from_iterable +# ============================================================================= + + +def extract_dataclass_members(x: Any) -> tuple[list[str], list[Any]]: + """ + Extract non-method, non-function attributes from a dataclass instance. + + Args: + x: A dataclass instance + + Returns: + tuple: (field_names, field_values) lists + """ + fields = [field.name for field in dataclasses.fields(x)] + + # If the dataclass has extra fields, raise an error + for k in x.__dict__.keys(): + if k not in fields: + raise DSLTreeFlattenError( + f"`{x}` has extra field `{k}`", + type_str=get_fully_qualified_class_name(x), + ) + + if not fields: + return [], [] + + # record constexpr fields + members = [] + constexpr_fields = [] + for field in dataclasses.fields(x): + if is_constexpr_field(field): + constexpr_fields.append(field.name) + fields.remove(field.name) + v = getattr(x, field.name) + if is_dynamic_expression(v): + raise DSLTreeFlattenError( + f"`{x}` has dynamic expression field `{field.name}` with a Constexpr type annotation `{field.type}`", + type_str=get_fully_qualified_class_name(x), + ) + else: + members.append(getattr(x, field.name)) + + return fields, members, constexpr_fields + + +def default_dataclass_to_iterable(x: Any) -> tuple[SimpleNamespace, list[Any]]: + """ + Convert a dataclass instance to iterable form for tree flattening. + + Extracts all non-method, non-function attributes that don't start with '__' + and returns them along with metadata about the dataclass. + + Args: + x: A dataclass instance + + Returns: + tuple: (metadata, members) where metadata contains type info and field names, + and members is the list of attribute values + """ + fields, members, constexpr_fields = extract_dataclass_members(x) + + metadata = SimpleNamespace( + type_str=get_fully_qualified_class_name(x), + fields=fields, + constexpr_fields=constexpr_fields, + original_obj=x, + ) + return metadata, members + + +def set_dataclass_attributes( + instance: Any, + fields: list[str], + values: Iterable[Any], + constexpr_fields: list[str], +) -> Any: + """ + Set attributes on a dataclass instance. + + Args: + instance: The dataclass instance + fields: List of field names + values: Iterable of field values + is_frozen: Whether the dataclass is frozen + + Returns: + The instance with attributes set + """ + if not fields: + return instance + + kwargs = dict(zip(fields, values)) + for field in constexpr_fields: + kwargs[field] = getattr(instance, field) + return dataclasses.replace(instance, **kwargs) + +def default_dataclass_from_iterable( + metadata: SimpleNamespace, children: Iterable[Any] +) -> Any: + """ + Reconstruct a dataclass instance from iterable form. + + Handles both regular and frozen dataclasses appropriately. + + Args: + metadata: Metadata containing type information and field names + children: Iterable of attribute values to reconstruct the instance + + Returns: + The reconstructed dataclass instance + """ + instance = metadata.original_obj + + new_instance = set_dataclass_attributes( + instance, metadata.fields, children, metadata.constexpr_fields + ) + metadata.original_obj = new_instance + return new_instance + + +def dynamic_expression_to_iterable(x: Any) -> tuple[SimpleNamespace, list[Any]]: + """ + Convert a dynamic expression to iterable form. + + Uses the object's `__extract_mlir_values__` method to extract MLIR values. + + Args: + x: A dynamic expression object + + Returns: + tuple: (metadata, mlir_values) where metadata marks this as a dynamic expression + and mlir_values are the extracted MLIR values + """ + return ( + SimpleNamespace(is_dynamic_expression=1, original_obj=x), + x.__extract_mlir_values__(), + ) + + +def dynamic_expression_from_iterable( + metadata: SimpleNamespace, children: Iterable[Any] +) -> Any: + """ + Reconstruct a dynamic expression from iterable form. + + Uses the object's `__new_from_mlir_values__` method to reconstruct from MLIR values. + + Args: + metadata: Metadata containing the original object + children: Iterable of MLIR values to reconstruct from + + Returns: + The reconstructed dynamic expression object + """ + return metadata.original_obj.__new_from_mlir_values__(list(children)) + + +def default_dict_to_iterable(x: Any) -> tuple[SimpleNamespace, list[Any]]: + """ + Convert a dict to iterable form. + """ + if isinstance(x, SimpleNamespace): + keys = list(x.__dict__.keys()) + values = list(x.__dict__.values()) + else: + keys = list(x.keys()) + values = list(x.values()) + + return ( + SimpleNamespace( + type_str=get_fully_qualified_class_name(x), original_obj=x, fields=keys + ), + values, + ) + + +def default_dict_from_iterable( + metadata: SimpleNamespace, children: Iterable[Any] +) -> Any: + """ + Reconstruct a dict from iterable form. + """ + instance = metadata.original_obj + fields = metadata.fields + is_simple_namespace = isinstance(instance, SimpleNamespace) + + for k, v in zip(fields, children): + if is_simple_namespace: + setattr(instance, k, v) + else: + instance[k] = v + + return instance + + +# ============================================================================= +# Register pytree nodes +# ============================================================================= + +_node_types: dict[type, NodeType] = {} + + +def register_pytree_node(ty: type, to_iter: Callable, from_iter: Callable) -> NodeType: + """ + Register a new node type for pytree operations. + + Args: + ty: The type to register + to_iter: Function to convert instances of this type to iterable form + from_iter: Function to reconstruct instances of this type from iterable form + + Returns: + NodeType: The created NodeType instance + """ + nt = NodeType(str(ty), to_iter, from_iter) + _node_types[ty] = nt + return nt + + +def register_default_node_types() -> None: + """Register default node types for pytree operations.""" + default_registrations = [ + ( + tuple, + lambda t: (SimpleNamespace(length=len(t)), list(t)), + lambda _, xs: tuple(xs), + ), + ( + list, + lambda l: (SimpleNamespace(length=len(l)), list(l)), + lambda _, xs: list(xs), + ), + ( + dict, + default_dict_to_iterable, + default_dict_from_iterable, + ), + ( + SimpleNamespace, + default_dict_to_iterable, + default_dict_from_iterable, + ), + ] + + for ty, to_iter, from_iter in default_registrations: + register_pytree_node(ty, to_iter, from_iter) + + +# Initialize default registrations +register_default_node_types() + + +# ============================================================================= +# tree_flatten and tree_unflatten +# ============================================================================= + +""" +Behavior of tree_flatten and tree_unflatten, for example: + +```python + a = (1, 2, 3) + b = MyClass(a=1, b =[1,2,3]) +``` + +yields the following tree: + +```python + tree_a = PyTreeDef(type = 'tuple', + metadata = {length = 3}, + children = [ + Leaf(type = int), + Leaf(type = int), + Leaf(type = int), + ], + ) + flattened_a = [1, 2, 3] + tree_b = PyTreeDef(type = 'MyClass', + metadata = {fields = ['a','b']}, + children = [ + PyTreeDef(type = `list`, + metadata = {length = 3}, + children = [ + Leaf(type=`int`), + Leaf(type=`int`), + Leaf(type=`int`), + ], + ), + Leaf(type=int), + ], + ) + flattened_b = [1, 1, 2, 3] +``` + +Passing the flattened values and PyTreeDef to tree_unflatten to reconstruct the original structure. + +``` python + unflattened_a = tree_unflatten(tree_a, flattened_a) + unflattened_b = tree_unflatten(tree_b, flattened_b) +``` + +yields the following structure: + +``` python + unflattened_a = (1, 2, 3) + unflattened_b = MyClass(a=1, b =[1,2,3]) +``` + +unflattened_a should be structurally identical to a, and unflattened_b should be structurally identical to b. + +""" + + +def tree_flatten(x: Any) -> tuple[list[Any], PyTreeDef]: + """ + Flatten a nested structure into a flat list of values and a tree definition. + + This function recursively traverses nested data structures (trees) and + flattens them into a linear list of leaf values, while preserving the + structure information in a PyTreeDef. + + Args: + x: The nested structure to flatten + + Returns: + tuple: (flat_values, treedef) where flat_values is a list of leaf values + and treedef is the tree structure definition + + Raises: + DSLTreeFlattenError: If the structure contains unsupported types + + Example: + >>> tree_flatten([1, [2, 3], 4]) + ([1, 2, 3, 4], PyTreeDef(...)) + """ + children_iter, treedef = _tree_flatten(x) + return list(children_iter), treedef + + +def get_registered_node_types_or_insert(x: Any) -> NodeType | None: + """ + Get the registered node type for an object, registering it if necessary. + + This function checks if a type is already registered for pytree operations. + If not, it automatically registers the type based on its characteristics: + - Dynamic expressions get registered with dynamic expression handlers + - Dataclasses get registered with default dataclass handlers + + Args: + x: The object to get or register a node type for + + Returns: + NodeType or None: The registered node type, or None if the type + cannot be registered + """ + node_type = _node_types.get(type(x)) + if node_type: + return node_type + elif is_dynamic_expression(x): + # If a class implements DynamicExpression protocol, register it before default dataclass one + return register_pytree_node( + type(x), dynamic_expression_to_iterable, dynamic_expression_from_iterable + ) + elif dataclasses.is_dataclass(x): + return register_pytree_node( + type(x), default_dataclass_to_iterable, default_dataclass_from_iterable + ) + else: + return None + + +def create_leaf_for_value( + x: Any, + is_numeric: bool = False, + is_none: bool = False, + node_metadata: SimpleNamespace = None, + ir_type_str: str = None, +) -> Leaf: + """ + Create a Leaf node for a given value. + + Args: + x: The value to create a leaf for + is_numeric: Whether this is a numeric value + is_none: Whether this represents None + node_metadata: Optional metadata + ir_type_str: Optional IR type string + + Returns: + Leaf: The created leaf node + """ + return Leaf( + is_numeric=is_numeric, + is_none=is_none, + node_metadata=node_metadata, + ir_type_str=ir_type_str or (str(x.type) if hasattr(x, "type") else None), + ) + + +def _tree_flatten(x: Any) -> tuple[Iterable[Any], PyTreeDef | Leaf]: + """ + Internal function to flatten a tree structure. + + This is the core implementation of tree flattening that handles different + types of objects including None, ArithValue, ir.Value, Numeric types, + and registered pytree node types. + + Args: + x: The object to flatten + + Returns: + tuple: (flattened_values, treedef) where flattened_values is an iterable + of leaf values and treedef is the tree structure + + Raises: + DSLTreeFlattenError: If the object type is not supported + """ + match x: + case None: + return [], create_leaf_for_value(x, is_none=True) + + case ArithValue() if is_dynamic_expression(x): + v = x.__extract_mlir_values__() + return v, create_leaf_for_value( + x, + node_metadata=SimpleNamespace(is_dynamic_expression=1, original_obj=x), + ir_type_str=str(v[0].type), + ) + + case ArithValue(): + return [x], create_leaf_for_value(x, is_numeric=True) + + case ir.Value(): + return [x], create_leaf_for_value(x) + + case Numeric(): + v = x.__extract_mlir_values__() + return v, create_leaf_for_value( + x, + node_metadata=SimpleNamespace(is_dynamic_expression=1, original_obj=x), + ir_type_str=str(v[0].type), + ) + + case _: + node_type = get_registered_node_types_or_insert(x) + if node_type: + node_metadata, children = node_type.to_iterable(x) + children_flat, child_trees = unzip2(map(_tree_flatten, children)) + flattened = it.chain.from_iterable(children_flat) + return flattened, PyTreeDef( + node_type, node_metadata, tuple(child_trees) + ) + + # Try to convert to numeric + try: + nval = as_numeric(x).ir_value() + return [nval], create_leaf_for_value(nval, is_numeric=True) + except Exception: + raise DSLTreeFlattenError( + "Flatten Error", get_fully_qualified_class_name(x) + ) + + +def tree_unflatten(treedef: PyTreeDef, xs: list[Any]) -> Any: + """ + Reconstruct a nested structure from a flat list of values and tree definition. + + This is the inverse operation of tree_flatten. It takes the flattened + values and the tree structure definition to reconstruct the original + nested structure. + + Args: + treedef: The tree structure definition from tree_flatten + xs: List of flat values to reconstruct from + + Returns: + The reconstructed nested structure + + Example: + >>> flat_values, treedef = tree_flatten([1, [2, 3], 4]) + >>> tree_unflatten(treedef, flat_values) + [1, [2, 3], 4] + """ + return _tree_unflatten(treedef, iter(xs)) + + +def _tree_unflatten(treedef: PyTreeDef | Leaf, xs: Iterator[Any]) -> Any: + """ + Internal function to reconstruct a tree structure. + + This is the core implementation of tree unflattening that handles + different types of tree definitions including Leaf nodes and PyTreeDef nodes. + + Args: + treedef: The tree structure definition + xs: Iterator of flat values to reconstruct from + + Returns: + The reconstructed object + """ + match treedef: + case Leaf(is_none=True): + return None + + case Leaf( + node_metadata=metadata + ) if metadata and metadata.is_dynamic_expression: + return metadata.original_obj.__new_from_mlir_values__([next(xs)]) + + case Leaf(is_numeric=True): + return as_numeric(next(xs)) + + case Leaf(): + return next(xs) + + case PyTreeDef(): + children = (_tree_unflatten(t, xs) for t in treedef.child_treedefs) + return treedef.node_type.from_iterable(treedef.node_metadata, children) + + +def _check_tree_equal(lhs: Union[PyTreeDef, Leaf], rhs: Union[PyTreeDef, Leaf]) -> bool: + """ + Check if two tree definitions are structurally equal. + + This is a helper function for check_tree_equal that recursively compares + tree structures. + + Args: + lhs: Left tree definition (PyTreeDef or Leaf) + rhs: Right tree definition (PyTreeDef or Leaf) + + Returns: + bool: True if the trees are structurally equal, False otherwise + """ + match (lhs, rhs): + case (Leaf(), Leaf()): + return lhs.is_none == rhs.is_none and lhs.ir_type_str == rhs.ir_type_str + + case (PyTreeDef(), PyTreeDef()): + lhs_metadata = lhs.node_metadata + rhs_metadata = rhs.node_metadata + + lhs_fields = getattr(lhs_metadata, "fields", []) + rhs_fields = getattr(rhs_metadata, "fields", []) + lhs_constexpr_fields = getattr(lhs_metadata, "constexpr_fields", []) + rhs_constexpr_fields = getattr(rhs_metadata, "constexpr_fields", []) + + return ( + lhs.node_type == rhs.node_type + and lhs_fields == rhs_fields + and lhs_constexpr_fields == rhs_constexpr_fields + and len(lhs.child_treedefs) == len(rhs.child_treedefs) + and all(map(_check_tree_equal, lhs.child_treedefs, rhs.child_treedefs)) + ) + + case _: + return False + + +def check_tree_equal(lhs: PyTreeDef, rhs: PyTreeDef) -> int: + """ + Check if two tree definitions are equal and return the index of first difference. + + This function compares two tree definitions and returns the index of the + first child that differs, or -1 if they are completely equal. + + Args: + lhs: Left tree definition + rhs: Right tree definition + + Returns: + int: Index of the first differing child, or -1 if trees are equal + + Example: + >>> treedef1 = tree_flatten([1, [2, 3]])[1] + >>> treedef2 = tree_flatten([1, [2, 4]])[1] + >>> check_tree_equal(treedef1, treedef2) + 1 # The second child differs + """ + assert len(lhs.child_treedefs) == len(rhs.child_treedefs) + + def find_first_difference( + index_and_pair: tuple[int, tuple[PyTreeDef, PyTreeDef]] + ) -> int: + index, (l, r) = index_and_pair + return index if not _check_tree_equal(l, r) else -1 + + differences = map( + find_first_difference, enumerate(zip(lhs.child_treedefs, rhs.child_treedefs)) + ) + return next((diff for diff in differences if diff != -1), -1) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9bdd259c0203aaca3c7a7e31e64a576630f369a9 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/__init__.py @@ -0,0 +1,213 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# +import logging +import os +import sys + +import cutlass_library + + +def _cuda_install_path_from_nvcc() -> str: + import subprocess + # Attempt to detect CUDA_INSTALL_PATH based on location of NVCC + result = subprocess.run(['/usr/bin/which', 'nvcc'], capture_output=True) + if result.returncode != 0: + raise Exception(f'Unable to find nvcc via `which` utility.') + + cuda_install_path = result.stdout.decode('utf-8').split('/bin/nvcc')[0] + if not os.path.isdir(cuda_install_path): + raise Exception(f'Environment variable "CUDA_INSTALL_PATH" is not defined, ' + f'and default path of {cuda_install_path} does not exist.') + + return cuda_install_path + + +CUTLASS_PATH = os.getenv("CUTLASS_PATH", cutlass_library.source_path) + +# Alias CUTLASS_PATH as source_path +source_path = CUTLASS_PATH + +_NVCC_VERSION = None +def nvcc_version(): + global _NVCC_VERSION + if _NVCC_VERSION is None: + import subprocess + + # Attempt to get NVCC version + result = subprocess.run(['nvcc', '--version'], capture_output=True) + if result.returncode != 0: + raise Exception('Unable to run `nvcc --version') + _NVCC_VERSION = str(result.stdout).split(" release ")[-1].split(",")[0] + return _NVCC_VERSION + +_CUDA_INSTALL_PATH = None +def cuda_install_path(): + """ + Helper method for on-demand fetching of the CUDA installation path. This allows + the import of CUTLASS to proceed even if NVCC is not available, preferring to + raise this error only when an operation that needs NVCC is being performed. + """ + global _CUDA_INSTALL_PATH + if _CUDA_INSTALL_PATH is None: + _CUDA_INSTALL_PATH = os.getenv("CUDA_INSTALL_PATH", _cuda_install_path_from_nvcc()) + return _CUDA_INSTALL_PATH + +CACHE_FILE = "compiled_cache.db" + +from cutlass_library import ( + DataType, + EpilogueScheduleType, + KernelScheduleType, + MathOperation, + LayoutType, + OpcodeClass, + TileDescription, + TileSchedulerType, +) + +this = sys.modules[__name__] +this.logger = logging.getLogger(__name__) + +# RMM is only supported for Python 3.9+ +if (sys.version_info.major == 3 and sys.version_info.minor > 8) or sys.version_info.major > 3: + try: + import rmm + this.use_rmm = True + except ImportError: + this.use_rmm = False +else: + this.use_rmm = False + + +def set_log_level(level: int): + """ + Sets the log level + + :param log_level: severity of logging level to use. See https://docs.python.org/3/library/logging.html#logging-levels for options + :type log_level: int + """ + this.logger.setLevel(level) + +set_log_level(logging.ERROR) + +from cutlass_cppgen.library_defaults import OptionRegistry +from cutlass_cppgen.backend.utils.device import device_cc + +this._option_registry = None +def get_option_registry(): + """ + Helper method for on-demand initialization of the options registry. This avoids building + the registry when CUTLASS is imported. + """ + if this._option_registry is None: + this.logger.info("Initializing option registry") + this._option_registry = OptionRegistry(device_cc()) + return this._option_registry + +this.__version__ = '4.2.1' + +from cutlass_cppgen.backend import create_memory_pool +from cutlass_cppgen.emit.pytorch import pytorch +from cutlass_cppgen.op.gemm import Gemm +from cutlass_cppgen.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad +from cutlass_cppgen.op.gemm_grouped import GroupedGemm +from cutlass_cppgen.op.op import OperationBase +from cutlass_cppgen.backend.evt.ir.tensor import Tensor +from cutlass_cppgen.utils.lazy_import import lazy_import + + +this.memory_pool = None +def get_memory_pool(): + """" + Helper method for on-demand memory pool. This avoids allocating the memory pool unnecessarily + whe CUTLASS is imported. + """ + if this.use_rmm and this.memory_pool is None: + this.memory_pool = create_memory_pool(init_pool_size=2 ** 30, max_pool_size=2 ** 32) + return this.memory_pool + + +base_cuda = lazy_import("cuda") +cuda = lazy_import("cuda.cuda") +cudart = lazy_import("cuda.cudart") + +this._device_id = None +this._nvcc_version = None + +def check_cuda_versions(): + # Strip any additional information from the CUDA version + _cuda_version = base_cuda.__version__.split("rc")[0] + # Check that Python CUDA version exceeds NVCC version + this._nvcc_version = nvcc_version() + _cuda_list = _cuda_version.split('.') + _nvcc_list = this._nvcc_version.split('.') + for val_cuda, val_nvcc in zip(_cuda_list, _nvcc_list): + if int(val_cuda) < int(val_nvcc): + raise Exception(f"Python CUDA version of {_cuda_version} must be greater than or equal to NVCC version of {this._nvcc_version}") + + if len(_nvcc_list) > len(_cuda_list): + if len(_nvcc_list) != len(_cuda_list) + 1: + raise Exception(f"Malformatted NVCC version of {this._nvcc_version}") + if _nvcc_list[:-1] == _cuda_list and int(_nvcc_list[-1]) != 0: + raise Exception(f"Python CUDA version of {_cuda_version} must be greater than or equal to NVCC version of {this._nvcc_version}") + +def initialize_cuda_context(): + check_cuda_versions() + + if this._device_id is not None: + return + + if this.use_rmm: + # This also covers initializing the CUDA context + get_memory_pool() + + device_id = os.getenv("CUTLASS_CUDA_DEVICE_ID") + if device_id is None: + if not this.use_rmm: + # Manually call cuInit() and create context by making a runtime API call + err, = cudart.cudaFree(0) + if err != cudart.cudaError_t.cudaSuccess: + raise RuntimeError(f"cudaFree failed with error {err}") + + err, device_count = cuda.cuDeviceGetCount() + if err != cuda.CUresult.CUDA_SUCCESS: + raise Exception(f"cuDeviceGetCount failed with error {err}") + if device_count <= 0: + raise Exception("No CUDA devices found") + device_id = 0 + + this._device_id = int(device_id) + + +def device_id() -> int: + initialize_cuda_context() + return this._device_id diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..59cfaf7154687fa3a971f2221f0cce2130ff1a4f --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/__init__.py @@ -0,0 +1,48 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from cutlass_cppgen.backend.arguments import * +from cutlass_cppgen.backend.c_types import * +from cutlass_cppgen.backend.compiler import ArtifactManager +from cutlass_cppgen.backend.conv2d_operation import * +from cutlass_cppgen.backend.epilogue import * +from cutlass_cppgen.backend.frontend import * +from cutlass_cppgen.backend.gemm_operation import * +from cutlass_cppgen.backend.library import * +from cutlass_cppgen.backend.memory_manager import PoolMemoryManager, create_memory_pool +from cutlass_cppgen.backend.operation import * +from cutlass_cppgen.backend.reduction_operation import * +from cutlass_cppgen.backend.type_hint import * +from cutlass_cppgen.backend.utils import * +from cutlass_cppgen.backend.utils.device import device_cc + +compiler = ArtifactManager() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/arguments.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..b1b0656a89a8b0a42b864429810b74bc433582d4 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/arguments.py @@ -0,0 +1,136 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from math import prod +from typing import Union + +from cutlass_cppgen.utils.lazy_import import lazy_import + +cuda = lazy_import("cuda.cuda") +cudart = lazy_import("cuda.cudart") +import numpy as np + +import cutlass_cppgen +from cutlass_cppgen.backend.frontend import CupyFrontend, NumpyFrontend, TorchFrontend +from cutlass_cppgen.backend.memory_manager import DevicePtrWrapper +from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor + + +class ArgumentBase: + """ + Base class for operation arguments + """ + + def __init__( + self, + A: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]", + B: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]", + C: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]", + D: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]", + **kwargs, + ) -> None: + # tensor_C can be interpreted as the bias with bias=True in keyword args + self.bias = kwargs.get("bias", False) + + self.stream = kwargs.get("stream", cuda.CUstream(0)) + + # RMM buffers used to track tensor lifetime + self.buffers = {} + # Host tensor to copy the computed result back + self.host_tensors = {} + + self.ptr_A = self.tensor_to_ptr(A, "A") + self.ptr_B = self.tensor_to_ptr(B, "B") + self.ptr_C = self.tensor_to_ptr(C, "C") + self.ptr_D = self.tensor_to_ptr(D, "D", is_output=True) + if C is not None: + if not isinstance(C, cuda.CUdeviceptr): + self.tensor_c_numel = prod(C.shape) + + def tensor_to_ptr(self, tensor, name, is_output=False): + """ + Convert and remember the input tensor to cuda.CUdeviceptr used by cuda python + For numpy.ndarray, it also remembers the host buffer for synchronization + """ + if tensor is None: + return cuda.CUdeviceptr(0) + if is_numpy_tensor(tensor): + if is_output: + assert name + self.buffers[name] = NumpyFrontend.argument(tensor, is_output) + if is_output: + self.host_tensors[name] = tensor + return self.buffers[name].ptr + elif is_torch_tensor(tensor): + return TorchFrontend.argument(tensor) + elif isinstance(tensor, cuda.CUdeviceptr): + return tensor + elif is_cupy_tensor(tensor): + return CupyFrontend.argument(tensor) + else: + raise TypeError("Unsupported Frontend. Only support numpy and torch") + + def sync(self, stream_sync=True): + if stream_sync: + (err,) = cudart.cudaDeviceSynchronize() + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError("CUDA Error %s" % str(err)) + + for key in self.host_tensors.keys(): + host_tensor = self.host_tensors[key] + (err,) = cuda.cuMemcpyDtoH( + host_tensor, + self.buffers[key].ptr, + host_tensor.size * host_tensor.itemsize, + ) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError("CUDA Error %s" % str(err)) + + self.free() + + def free(self): + """ + Frees allocated device-side memory + """ + # Free any device memory allocated manually + if not cutlass_cppgen.use_rmm: + for name, buf in self.buffers.items(): + if isinstance(buf, DevicePtrWrapper): + err, = cudart.cudaFree(buf.ptr) + if err != cudart.cudaError_t.cudaSuccess: + raise RuntimeError(f"cudaFree failed with error {err}") + + if hasattr(self, "workspace_buffer") and isinstance(self.workspace_buffer, DevicePtrWrapper): + err, = cudart.cudaFree(self.workspace_buffer.ptr) + if err != cudart.cudaError_t.cudaSuccess: + raise RuntimeError(f"cudaFree failed with error {err}") + del self.workspace_buffer diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/c_types.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/c_types.py new file mode 100644 index 0000000000000000000000000000000000000000..3f515aa38439e4b2e1392659d188cbe6a68e0481 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/c_types.py @@ -0,0 +1,625 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import ctypes + +from cutlass_library import ( + DataType, + KernelScheduleType, + TileSchedulerType +) +from cutlass_cppgen.backend.library import DataTypeSizeBytes + + +class GemmCoord_(ctypes.Structure): + _fields_ = [ + ("m", ctypes.c_int), + ("n", ctypes.c_int), + ("k", ctypes.c_int) + ] + + def __init__(self, m, n, k) -> None: + self.m = m + self.n = n + self.k = k + + +class GemmCoordBatched_(ctypes.Structure): + """ + Wrapper around a GemmCoord that also contains batch count. This is used for encoding + batched GEMM inputs to CUTLASS 3 GEMMs. + """ + + _fields_ = [ + ("m", ctypes.c_int), + ("n", ctypes.c_int), + ("k", ctypes.c_int), + ("batch_count", ctypes.c_int) + ] + + def __init__(self, gemm_coord, batch_count) -> None: + self.m = gemm_coord.m + self.n = gemm_coord.n + self.k = gemm_coord.k + self.batch_count = batch_count + + +class MatrixCoord_(ctypes.Structure): + _fields_ = [ + ("row", ctypes.c_int), + ("column", ctypes.c_int) + ] + + +class dim3_(ctypes.Structure): + _fields_ = [ + ("x", ctypes.c_int), + ("y", ctypes.c_int), + ("z", ctypes.c_int) + ] + + +class StrideBatched_(ctypes.Structure): + """ + CUTLASS 3.0 strides for operands contain one static dimension and two variable dimensions. The + variable dimensions represent the stride along non-unit-stride dimension of the row/column major + layout, and the batch stride. This structure encodes the two variable dimensions. + """ + _fields_ = [ + ("major_stride", ctypes.c_int64), + ("batch_stride", ctypes.c_int64) + ] + + + +class GenericMainloopArguments3x_(ctypes.Structure): + """ + Structure representing the superset of possible mainloop arguments. + This structure should not be passed to kernels directly, but, rather, + be used as an input to one of the more specific schedule arguments, which + will each select those arguments relevant to the particular schedule. + """ + _fields_ = [ + ("ptr_A", ctypes.c_void_p), + ("stride_A", StrideBatched_), + ("ptr_B", ctypes.c_void_p), + ("stride_B", StrideBatched_), + ("mma_promotion_interval", ctypes.c_int) + ] + + +class _PersistentTileSchedulerArguments(ctypes.Structure): + _fields_ = [ + ("max_swizzle_size", ctypes.c_int), + ("raster_order_option", ctypes.c_int), + ] + + +class _PersistentTileSchedulerStreamKArguments(ctypes.Structure): + _fields_ = [ + ("splits", ctypes.c_int), + ("max_swizzle_size", ctypes.c_int), + ("raster_order_option", ctypes.c_int), + ("reduction_mode", ctypes.c_int), + ("decomposition_mode", ctypes.c_int), + ] + + +def get_tile_scheduler_arguments_3x( + tile_scheduler: TileSchedulerType, + splits: int = 1): + max_swizzle_size = 1 + raster_order_option = 0 # Heuristic + if tile_scheduler in [TileSchedulerType.Default, TileSchedulerType.Persistent]: + return _PersistentTileSchedulerArguments( + max_swizzle_size, + raster_order_option, + ) + elif tile_scheduler == TileSchedulerType.StreamK: + reduction_mode = 0 # Deterministic + decomposition_mode = 0 # Heuristic + return _PersistentTileSchedulerStreamKArguments( + splits, + max_swizzle_size, + raster_order_option, + reduction_mode, + decomposition_mode, + ) + + +def get_mainloop_arguments_3x( + kernel_schedule: KernelScheduleType, + element_A, + element_B, + alignment_A: int, + alignment_B: int) -> ctypes.Structure: + """ + Returns the ctypes structure to be used for the 3.x kernel's mainloop parameters. + + :param kernel_schedule: type of kernel schedule to be used in the mainloop + :type kernel_schedule: cutlass_library.KernelScheduleType + :param element_A: data type of operand A + :param element_B: data type of operand B + :param alignment_A: alignment of operand A + :type alignment_A: int + :param alignment_B: alignment of operand B + :type alignment_B: int + + :returns: ctypes structure to be used for the 3.x kernel's mainloop parameters + :rtype: ctypes.Structure + """ + class _MainloopArgumentsTma(ctypes.Structure): + _fields_ = [ + ("ptr_A", ctypes.c_void_p), + ("stride_A", StrideBatched_), + ("ptr_B", ctypes.c_void_p), + ("stride_B", StrideBatched_), + ("mma_promotion_interval", ctypes.c_int) + ] + + @staticmethod + def from_generic_mainloop_args(args: GenericMainloopArguments3x_): + return _MainloopArgumentsTma( + args.ptr_A, args.stride_A, args.ptr_B, args.stride_B, + args.mma_promotion_interval + ) + + class _MainloopArgumentsMultistage(ctypes.Structure): + _fields_ = [ + ("ptr_A", ctypes.c_void_p), + ("stride_A", StrideBatched_), + ("ptr_B", ctypes.c_void_p), + ("stride_B", StrideBatched_), + ] + + @staticmethod + def from_generic_mainloop_args(args: GenericMainloopArguments3x_): + return _MainloopArgumentsMultistage( + args.ptr_A, args.stride_A, args.ptr_B, args.stride_B, + ) + + # Currently all 3.x kernels (CpAsync and Tma) have the same argument structure. + # Should that become not the case, this is the place to return custom ctypes + # structures based on selected kernel schedule. + return _MainloopArgumentsTma + + +def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor, scheduler_args, default_epilogue): + if not default_epilogue and hasattr(epilogue_functor, "epilogue_type_evt"): + _EpilogueOutputOpParams = epilogue_functor.epilogue_type_evt + else: + _EpilogueOutputOpParams = epilogue_functor.epilogue_type + + if hasattr(epilogue_functor, "visitor"): + class _EpilogueArguments(ctypes.Structure): + _fields_ = [ + ("epilogue", _EpilogueOutputOpParams), + ("arg_C", epilogue_functor.arg_c_type), + ("arg_D", epilogue_functor.arg_d_type) + ] + + def __init__(self, output_op, ptr_c, stride_c, ptr_d, stride_d) -> None: + self.epilogue = output_op + self.arg_C = epilogue_functor.arg_c_type(ptr_c) + self.arg_D = epilogue_functor.arg_d_type(ptr_d) + else: + class _EpilogueArguments(ctypes.Structure): + _fields_ = [ + ("epilogue", _EpilogueOutputOpParams), + ("ptr_C", ctypes.c_void_p), + ("stride_C", StrideBatched_), + ("ptr_D", ctypes.c_void_p), + ("stride_D", StrideBatched_), + ] + + class _HardwareInfo(ctypes.Structure): + _fields_ = [ + ("device_id", ctypes.c_int), + ("sm_count", ctypes.c_int), + ("max_active_clusters", ctypes.c_int), + ("cluster_shape", dim3_), + ("cluster_shape_fallback", dim3_), + ] + + class _GemmArguments(ctypes.Structure): + _fields_ = [ + ("mode", ctypes.c_int), + ("problem_size", GemmCoordBatched_), + ("mainloop", mainloop_arguments), + ("epilogue", _EpilogueArguments), + ("hw_info", _HardwareInfo), + ("scheduler", type(scheduler_args)), + ] + + return _GemmArguments, _EpilogueArguments, _EpilogueOutputOpParams, _HardwareInfo + + +def get_gemm_arguments(epilogue_functor): + _EpilogueOutputOpParams = epilogue_functor.epilogue_type + + class _GemmArguments(ctypes.Structure): + _fields_ = [ + # Arguments from UniversalArgumentsBase + ("mode", ctypes.c_int), + ("problem_size", GemmCoord_), + ("batch_count", ctypes.c_int), + ("batch_stride_D", ctypes.c_longlong), + # Remaining arguments + ("epilogue", _EpilogueOutputOpParams), + ("ptr_A", ctypes.c_void_p), + ("ptr_B", ctypes.c_void_p), + ("ptr_C", ctypes.c_void_p), + ("ptr_D", ctypes.c_void_p), + ("batch_stride_A", ctypes.c_longlong), + ("batch_stride_B", ctypes.c_longlong), + ("batch_stride_C", ctypes.c_longlong), + ("stride_a", ctypes.c_longlong), + ("stride_b", ctypes.c_longlong), + ("stride_c", ctypes.c_longlong), + ("stride_d", ctypes.c_longlong), + ("lda", ctypes.c_longlong), + ("ldb", ctypes.c_longlong), + ("ldc", ctypes.c_longlong), + ("ldd", ctypes.c_longlong), + ("ptr_gather_A_indices", ctypes.c_void_p), + ("ptr_gather_B_indices", ctypes.c_void_p), + ("ptr_scatter_D_indices", ctypes.c_void_p) + ] + + return _GemmArguments, _EpilogueOutputOpParams + + +def get_gemm_arguments_streamk(epilogue_functor): + _EpilogueOutputOpParams = epilogue_functor.epilogue_type + + class _GemmArguments(ctypes.Structure): + _fields_ = [ + ("mode", ctypes.c_int), + ("problem_size", GemmCoord_), + ("batch_count", ctypes.c_int), + ("epilogue", _EpilogueOutputOpParams), + ("ptr_A", ctypes.c_void_p), + ("ptr_B", ctypes.c_void_p), + ("ptr_C", ctypes.c_void_p), + ("ptr_D", ctypes.c_void_p), + ("batch_stride_A", ctypes.c_longlong), + ("batch_stride_B", ctypes.c_longlong), + ("batch_stride_C", ctypes.c_longlong), + ("batch_stride_D", ctypes.c_longlong), + ("stride_a", ctypes.c_longlong), + ("stride_b", ctypes.c_longlong), + ("stride_c", ctypes.c_longlong), + ("stride_d", ctypes.c_longlong), + ("lda", ctypes.c_longlong), + ("ldb", ctypes.c_longlong), + ("ldc", ctypes.c_longlong), + ("ldd", ctypes.c_longlong), + ("avail_sms", ctypes.c_int) + ] + + return _GemmArguments, _EpilogueOutputOpParams + + +########################################################################################### +# GEMM Grouped +########################################################################################### + + +def get_gemm_grouped_arguments(epilogue_functor): + _EpilogueOutputOpParams = epilogue_functor.epilogue_type + + class _GEMMGroupedArguments(ctypes.Structure): + _fields_ = [ + ("problem_sizes", ctypes.c_void_p), + ("problem_count", ctypes.c_int), + ("threadblock_count", ctypes.c_int), + ("output_op", _EpilogueOutputOpParams), + ("ptr_A", ctypes.c_void_p), + ("ptr_B", ctypes.c_void_p), + ("ptr_C", ctypes.c_void_p), + ("ptr_D", ctypes.c_void_p), + ("lda", ctypes.c_void_p), + ("ldb", ctypes.c_void_p), + ("ldc", ctypes.c_void_p), + ("ldd", ctypes.c_void_p), + ("host_problem_sizes", ctypes.c_void_p) + ] + + return _GEMMGroupedArguments, _EpilogueOutputOpParams + + +############################################################################################ +# Convolution2D +############################################################################################ + + +class Conv2DProblemSize_(ctypes.Structure): + _fields_ = [ + ("N", ctypes.c_int), + ("H", ctypes.c_int), + ("W", ctypes.c_int), + ("C", ctypes.c_int), + ("P", ctypes.c_int), + ("Q", ctypes.c_int), + ("K", ctypes.c_int), + ("R", ctypes.c_int), + ("S", ctypes.c_int), + ("pad_h", ctypes.c_int), + ("pad_w", ctypes.c_int), + ("stride_h", ctypes.c_int), + ("stride_w", ctypes.c_int), + ("dilation_h", ctypes.c_int), + ("dilation_w", ctypes.c_int), + ("mode", ctypes.c_int), # kCrossCorrelation: 0, kConvolution: 1 + ("split_k_slices", ctypes.c_int), + ("groups", ctypes.c_int) + ] + + def __init__(self, problem_size) -> None: + for field_name, _ in self._fields_: + setattr(self, field_name, getattr(problem_size, field_name)) + + +class Layout4D(ctypes.Structure): + _fields_ = [("stride", ctypes.c_int * 3)] + + def __init__(self, tensor_ref): + stride = tensor_ref.stride() + setattr(self, "stride", (stride.at(0), stride.at(1), stride.at(2))) + + +class TensorRef_(ctypes.Structure): + _fields_ = [ + ("ptr", ctypes.c_void_p), + ("layout", Layout4D) + ] + + def __init__(self, tensor_ref): + setattr(self, "ptr", tensor_ref.data()) + setattr(self, "layout", Layout4D(tensor_ref.layout())) + + +class TensorRef2D_(ctypes.Structure): + _fields_ = [ + ("ptr", ctypes.c_void_p), + ("stride", ctypes.c_int) + ] + + +def get_conv2d_arguments(epilogue_functor): + _EpilogueOutputOpParams = epilogue_functor.epilogue_type + + class _Conv2dArguments(ctypes.Structure): + _fields_ = [ + ("conv_kind", ctypes.c_int), + ("problem_size", Conv2DProblemSize_), + ("ptr_A", ctypes.c_void_p), + ("ptr_B", ctypes.c_void_p), + ("ptr_C", ctypes.c_void_p), + ("ptr_D", ctypes.c_void_p), + ("tensor_C_numel", ctypes.c_int), + ("output_op", _EpilogueOutputOpParams), + ("split_k_mode", ctypes.c_int) + ] + + return _Conv2dArguments, _EpilogueOutputOpParams + + +############################################################################################ +# Reduction +############################################################################################ + + +def get_reduction_params(epilogue_functor): + _EpilogueOutputParams = epilogue_functor.epilogue_type + + class _ReductionParams(ctypes.Structure): + _fields_ = [ + ("problem_size", MatrixCoord_), + ("partitions", ctypes.c_int), + ("partition_stride", ctypes.c_longlong), + ("workspace", TensorRef2D_), + ("destination", TensorRef2D_), + ("source", TensorRef2D_), + ("output_op", _EpilogueOutputParams), + ] + + return _ReductionParams, _EpilogueOutputParams + + +########################################################################################### +# Epilogue Visitor Type Factory +########################################################################################### + +class Empty(ctypes.Structure): + _fields_ = [] + + def __init__(self, *arg) -> None: + pass + +class EmptyByte(ctypes.Structure): + _fields_ = [ + ("byte", ctypes.c_byte) + ] + + def __init__(self, *arg) -> None: + pass + +class EBO: + def __init__(self, index: int, type) -> None: + self.index = index + self.type = type + + def __eq__(self, other) -> bool: + if isinstance(other, EBO): + return self.index == other.index and self.type == other.type + return False + + def __hash__(self) -> int: + return hash((self.index, self.type)) + + def __ne__(self, other): + return not self.__eq__(other) + + def __str__(self) -> str: + return f"<{self.index}, {self.type}>" + + +def tuple_factory_(input_tuple, dtype, constants=[0,1]): + """ + The factory function generating cute::Tuple with input tuple + :param input_tuple: the input tuple + :type input_tuple: tuple + :param dtype: the data type for non-constant values + :type dtype: str, "int32_t", "int", "int64_t" + :param constant: the values that will be treated as constants + :type constant: list[int] + + :return: ctype structure representing the cute::Tuple + :return: the empty base classes of the tuple + """ + + # The empty base classes of the current tuple + empty_bases = [] + # The first non empty base class + first_non_empty_base = None + # The ctype fields of the current tuple + ctype_fields = [] + + for idx, entry in enumerate(input_tuple): + # For nested tuples + if isinstance(entry, tuple): + sub_tuple_ctype, sub_empty_bases = tuple_factory_(entry, dtype, constants) + if ctypes.sizeof(sub_tuple_ctype) == 0: + # The empty tuple base class is also an empty EBO + empty_bases.append(EBO(idx, entry)) + else: + if first_non_empty_base is None: + first_non_empty_base = sub_empty_bases + ctype_fields.append((f"entry_{idx}", sub_tuple_ctype)) + else: + if entry in constants: + empty_bases.append(EBO(idx, entry)) + ctype_fields.append((f"entry_{idx}", Empty)) + else: + ctype_fields.append((f"entry_{idx}", dtype)) + if first_non_empty_base is None: + first_non_empty_base = [] + + # Create the ctype tuple + class TupleType(ctypes.Structure): + _fields_ = ctype_fields + + def __init__(self, args) -> None: + fields = self._fields_ + + assert len(fields) == len(args) + for field, arg in zip(fields, args): + name = field[0] + field_type = field[1] + setattr(self, name, field_type(arg)) + + return TupleType, empty_bases + +def tuple_factory(input_tuple, dtype: str, constants=[0,1]): + """ + The factory function generating cute::Tuple with input tuple + :param input_tuple: the input tuple + :type input_tuple: tuple + :param dtype: the data type for non-constant values + :type dtype: str, "int32_t", "int", "int64_t" + :param constant: the values that will be treated as constants + :type constant: list[int] + + :return: ctype structure representing the cute::Tuple + :return: the empty base classes of the tuple + """ + # Step 1: convert the dtype + if dtype == "int64_t": + dtype = ctypes.c_longlong + elif dtype in ["int", "int32_t"]: + dtype = ctypes.c_int32 + else: + raise NotImplementedError(f"Type {dtype} is not supported") + + tuple_type, _ = tuple_factory_(input_tuple, dtype, constants) + + if ctypes.sizeof(tuple_type) == 0: + return EmptyByte + return tuple_type + + +def visitor_factory(node_types, node_names): + """ + Creates the argument type of epilogue visitor type + + :param node_types: list of argument types under ctypes + :param node_names: list of argument names under str + + :return: tuple type in ctypes.Structure + """ + ctypes_field = [] + # Struct is used when number of nodes < 4 + # Because the Sm90VisitorImplBase has specification up to 4 nodes + # in `include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp` + if len(node_types) <= 4: + for idx, node_type in enumerate(node_types): + if ctypes.sizeof(node_type) == 0: + # Special case for empty struct + # 1 byte placeholder is used for correct alignment + ctypes_field.append((node_names[idx], ctypes.c_byte)) + else: + ctypes_field.append((node_names[idx], node_type)) + + class VisitorType(ctypes.Structure): + _fields_ = ctypes_field + + def __init__(self, kwargs) -> None: + for field in self._fields_: + fname, ftype = field + if ftype != ctypes.c_byte: + setattr(self, fname, ftype(kwargs)) + + # For cases with more than 4 nodes, tuple is used + else: + for idx, node_type in enumerate(node_types): + ctypes_field.append((node_names[idx], node_type)) + + class VisitorType(ctypes.Structure): + _fields_ = ctypes_field + + def __init__(self, kwargs) -> None: + for field in self._fields_: + fname, ftype = field + setattr(self, fname, ftype(kwargs)) + + return VisitorType diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/compiler.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/compiler.py new file mode 100644 index 0000000000000000000000000000000000000000..0b66ce8a2402a109e2da00613e7255760685855c --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/compiler.py @@ -0,0 +1,462 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import ctypes +import json +import os +import sqlite3 +import subprocess +import tempfile + +from cutlass_cppgen.utils.lazy_import import lazy_import +cuda = lazy_import("cuda.cuda") +cudart = lazy_import("cuda.cudart") +nvrtc = lazy_import("cuda.nvrtc") +from cutlass_library import SubstituteTemplate + +import cutlass_cppgen +from cutlass_cppgen import CACHE_FILE, CUTLASS_PATH, cuda_install_path, logger +from cutlass_cppgen.backend.gemm_operation import GemmOperationUniversal +from cutlass_cppgen.backend.library import ApiVersion +from cutlass_cppgen.backend.utils.device import device_cc + +IncludeTemplate = r"""#include "${include}" +""" + + +def compile_with_nvcc(cmd, source, error_file): + succeed = True + try: + subprocess.check_output(cmd, stderr=subprocess.STDOUT) + except subprocess.CalledProcessError as e: + error_message = e.output.decode() + with open(error_file, "w") as error_out: + error_log = "Compilation error for the following kernel: \n" + error_log += source + error_log += "\nError Message:\n" + error_log += error_message + error_out.write(error_log) + succeed = False + if not succeed: + # Print the error log to stdout if log level is set to warning or higher + # verbosity. Otherwise, simply point to the error log file. + logger.warning(error_log) + raise Exception(f"Invalid Kernel. See '{error_file}' for details.") + + +class CompilationOptions: + """ + Compilation options. + """ + + def __init__(self, flags, arch, include_paths=[]): + self.includes = [] + self.include_paths = include_paths + self.flags = flags + self.arch = arch + + def get_str(self): + opts = [] + for flag in self.flags: + opts.append(flag) + + for incl in self.include_paths: + opts.append(f"--include-path={incl}") + + arch_flag = f"-arch=sm_{self.arch}" + if self.arch in [90, 100, 101, 103, 120, 121] and int(cutlass_cppgen.nvcc_version().split('.')[0]) >= 12: + arch_flag += "a" + opts.append(arch_flag) + + return " ".join(opts) + + def get(self): + options = [] + + for flag in self.flags: + options.append(bytes(str.encode(flag))) + + for incl in self.include_paths: + options.append(bytes(str.encode(f" --include-path={incl}"))) + + arch_flag = f" -arch=sm_{self.arch}" + if self.arch in [90, 100, 101, 103, 120, 121]: + arch_flag += "a" + + options.append(bytes(str.encode(arch_flag))) + + return options + + +def convertToBinaryData(filename): + with open(filename, "rb") as file: + blobData = file.read() + return blobData + + +def CDLLBin(host_binary): + tempfile.tempdir = "./" + temp_so = tempfile.NamedTemporaryFile(prefix="host_func", suffix=".so", delete=True) + with open(temp_so.name, "wb") as file: + file.write(host_binary) + host_lib = ctypes.CDLL(temp_so.name) + return host_lib + + +class ArtifactManager: + """ + Artifact manager + """ + + def __init__(self) -> None: + connection = sqlite3.connect(CACHE_FILE) + cursor = connection.cursor() + # Create the table if it does not already exist + sqlite_create_table_query = """ + CREATE TABLE IF NOT EXISTS compiled_operations(op_key TEXT NOT NULL UNIQUE, + cubin BLOB NOT NULL, + hostbin BLOB NOT NULL, + op_name TEXT NOT NULL, + op_attrs TEXT NOT NULL) + """ + cursor.execute(sqlite_create_table_query) + connection.commit() + cursor.close() + + self._nvrtc_compile_options = ["-std=c++17", "-default-device"] + self._nvcc_compile_options = [ + "-std=c++17", + "--expt-relaxed-constexpr", + "-Xcudafe --diag_suppress=esa_on_defaulted_function_ignored", + ] + self.nvcc() + self.compiled_cache_device = {} + self.compiled_cache_host = {} + + def nvrtc(self): + self.backend = "nvrtc" + self.default_compile_options = self._nvrtc_compile_options + + def nvcc(self): + self.backend = "nvcc" + self.default_compile_options = self._nvcc_compile_options + + def insert_operation(self, op_key, cubin, hostfile, op_name, op_attrs): + connection = sqlite3.connect(CACHE_FILE) + cursor = connection.cursor() + sqlite_insert_blob_query = """ INSERT OR IGNORE INTO compiled_operations (op_key, cubin, hostbin, op_name, op_attrs) VALUES (?, ?, ?, ?, ?)""" + + hostbin = convertToBinaryData(hostfile) + + data_tuple = (op_key, cubin, hostbin, op_name, json.dumps(op_attrs)) + + cursor.execute(sqlite_insert_blob_query, data_tuple) + connection.commit() + cursor.close() + + def load_operation(self, op_key, extra_funcs): + connection = sqlite3.connect(CACHE_FILE) + cursor = connection.cursor() + sqlite_fetch_blob_query = """SELECT * from compiled_operations where op_key = ?""" + cursor.execute(sqlite_fetch_blob_query, (op_key,)) + record = cursor.fetchall() + if len(record) == 0: + return False + for row in record: + key, cubin_image, host_binary, operation_name, op_attr = row + op_attr = json.loads(op_attr) + err, module = cuda.cuModuleLoadData(cubin_image) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError("Cuda Error: {}".format(err)) + + err, kernel = cuda.cuModuleGetFunction(module, bytes(str.encode(operation_name))) + self.compiled_cache_device[key] = kernel + + compiled_host_fns = {} + host_lib = CDLLBin(host_binary) + + func_name = operation_name + "_get_params" + func = getattr(host_lib, func_name) + func.restype = ctypes.POINTER(ctypes.c_char * op_attr[0]) + compiled_host_fns["get_args"] = func + + func_name = operation_name + "_shared_memory_size" + func = getattr(host_lib, func_name) + compiled_host_fns["shared_memory_capacity"] = func() + + for attr in op_attr: + if isinstance(attr, str): + func_name = operation_name + "_" + attr + func = getattr(host_lib, func_name) + + # Set the return type of the function + if attr in extra_funcs and extra_funcs[attr] != None: + func.restype = extra_funcs[attr] + + compiled_host_fns[attr] = func + + self.compiled_cache_host[key] = compiled_host_fns + return True + + def emit_compile_(self, operation_list, compilation_options, host_compilation_options): + """ + Compile a list of kernels and store them into database + """ + source_buffer_device = "" + source_buffer_host = "" + # 1. include + includes = [] + for operation in operation_list: + for incl in operation.emitter.includes: + if incl not in includes: + includes.append(incl) + + includes_host = ["builtin_types.h", "device_launch_parameters.h", "cstddef"] + includes + for incl in includes: + source_buffer_device += SubstituteTemplate( + IncludeTemplate, + {"include": incl}, + ) + + for incl in includes_host: + source_buffer_host += SubstituteTemplate( + IncludeTemplate, + {"include": incl}, + ) + + # 2. Operations + for operation in operation_list: + source_buffer_device += operation.emit() + source_buffer_host += operation.emit() + values = { + "operation_name": operation.name(), + "operation_suffix": operation.emitter.operation_suffix, + } + source_buffer_device += SubstituteTemplate( + operation.KernelTemplate, + values, + ) + source_buffer_host += SubstituteTemplate(operation.HostTemplate, values) + + if self.backend == "nvrtc": + # 3. compile + err, program = nvrtc.nvrtcCreateProgram( + str.encode(source_buffer_device), + bytes(str.encode("module.cu")), + 0, [], []) + + if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: + raise RuntimeError("NVRTC Error: {}".format(err)) + + # Compile program + options = compilation_options.get() + + err, = nvrtc.nvrtcCompileProgram(program, len(options), options) + if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: + error_string = "NVRTC Error: {}\n".format(err) + + # Get log from compilation + err, logSize = nvrtc.nvrtcGetProgramLogSize(program) + if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: + raise RuntimeError("NVRTC Error: {}".format(err)) + + log = b" " * logSize + err, = nvrtc.nvrtcGetProgramLog(program, log) + if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: + raise RuntimeError("NVRTC Error: {}".format(err)) + + raise RuntimeError(error_string + log.decode() + source_buffer_device) + + # Get data from compilation + err, dataSize = nvrtc.nvrtcGetCUBINSize(program) + if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: + raise RuntimeError("NVRTC Error: {}".format(err)) + + cubin_image = b" " * dataSize + (err,) = nvrtc.nvrtcGetCUBIN(program, cubin_image) + if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: + raise RuntimeError("NVRTC Error: {}".format(err)) + + else: # with nvcc backend + # emit code + tempfile.tempdir = "./" + temp_cu = tempfile.NamedTemporaryFile( + prefix="kernel", suffix=".cu", delete=True) + temp_cubin = tempfile.NamedTemporaryFile( + prefix="kernel", suffix=".cubin", delete=True) + with open(temp_cu.name, "w") as file: + file.write(source_buffer_device) + + # compile with nvcc + cmd_template = "${cuda_install_path}/bin/nvcc ${options} -cubin ${srcfile} -o ${tarfile}" + values = { + "cuda_install_path": cuda_install_path(), + "options": compilation_options.get_str(), + "srcfile": temp_cu.name, + "tarfile": temp_cubin.name, + } + cmd = SubstituteTemplate(cmd_template, values) + compile_with_nvcc(cmd.split(" "), source_buffer_device, "./cutlass_python_compilation_device_error.txt") + + # load the cubin image + with open(temp_cubin.name, "rb") as file: + cubin_image = file.read() + + tempfile.tempdir = "./" + temp_src = tempfile.NamedTemporaryFile( + prefix="host_src", suffix=".cu", delete=True) + + # Write the host source + with open(temp_src.name, "w") as outfile: + outfile.write(source_buffer_host) + + temp_dst = tempfile.NamedTemporaryFile( + prefix="host_func", suffix=".so", delete=True) + + # Set up host compilation arguments + cmd = [] + cmd.append(f"{cuda_install_path()}/bin/nvcc") + cmd.extend(["-x", "cu", "-Xcompiler=-fpermissive", "-Xcompiler=-w", "-Xcompiler=-fPIC"]) + cmd.extend(host_compilation_options.get_str().split(" ")) + cmd.extend(["-shared", "-o", temp_dst.name, temp_src.name, "-lcudart", "-lcuda"]) + + # Comile and load the library + compile_with_nvcc( cmd, source_buffer_host, error_file="./cutlass_python_compilation_host_error.txt") + host_lib = ctypes.CDLL(temp_dst.name) + + return cubin_image, host_lib, temp_dst + + def add_module(self, operations, compile_options=None, bypass_cache=False): + """ + Insert a new compiled device module + """ + include_paths = [ + cuda_install_path() + "/include", + CUTLASS_PATH + "/include", + CUTLASS_PATH + "/tools/util/include", + CUTLASS_PATH + "/python/cutlass/cpp/include", + ] + + cutlass_cppgen.initialize_cuda_context() + arch = device_cc() + + host_compile_options = CompilationOptions( + self._nvcc_compile_options, arch, include_paths) + if compile_options is None: + compile_options = CompilationOptions( + self.default_compile_options, arch, include_paths) + # save the cubin + operation_key = [] + operation_list = [] + for operation in operations: + # step 1: get kernel string as key + key = operation.rt_module.emit() + operation.procedural_name() + self.backend + # step 1: check if the operation is in cache + compiled_kernel = self.compiled_cache_device.get(key) + + if compiled_kernel is None and not bypass_cache: + hit = self.load_operation(key, getattr( operation.rt_module, "extra_funcs", {})) + if hit: + compiled_kernel = self.compiled_cache_device.get(key) + assert compiled_kernel is not None + if compiled_kernel is not None: + operation.rt_module.kernel = compiled_kernel + compiled_host_fns = self.compiled_cache_host.get(key) + assert compiled_host_fns is not None + for key in compiled_host_fns.keys(): + setattr(operation.rt_module, key, compiled_host_fns[key]) + operation.rt_module.initialize() + else: + operation_list.append(operation.rt_module) + operation_key.append(key) + + if len(operation_list) > 0: + cubin_image, host_lib, host_file = self.emit_compile_( + operation_list, compile_options, host_compile_options) + + err, module = cuda.cuModuleLoadData(cubin_image) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError("Cuda Error: {}".format(err)) + + operation_name = [] + operation_attr = [] + for operation, key in zip(operation_list, operation_key): + # get device kernels + err, operation.kernel = cuda.cuModuleGetFunction( + module, + bytes(str.encode(operation.name())) + ) + operation_name.append(operation.name()) + self.compiled_cache_device[key] = operation.kernel + # get host functions + compiled_host_fns = {} + op_attr = [] + + # get param size + func_name = operation.name() + "_get_param_size" + func = getattr(host_lib, func_name) + param_size = func() + + func_name = operation.name() + "_get_params" + func = getattr(host_lib, func_name) + func.argtype = operation.argtype + func.restype = ctypes.POINTER(ctypes.c_char * param_size) + setattr(operation, "get_args", func) + compiled_host_fns["get_args"] = func + + # set shared memory size + func_name = operation.name() + "_shared_memory_size" + func = getattr(host_lib, func_name) + setattr(operation, "shared_memory_capacity", func()) + compiled_host_fns["shared_memory_capacity"] = func() + # set the maximum dynamic shared size + operation.initialize() + + # get extra functions + op_attr.append(param_size) + + if hasattr(operation, "extra_funcs"): + for suffix, ret_type in operation.extra_funcs.items(): + func_name = operation.name() + "_" + suffix + func = getattr(host_lib, func_name) + if ret_type is not None: + func.restype = ret_type + setattr(operation, suffix, func) + compiled_host_fns[suffix] = func + op_attr.append(suffix) + + operation_attr.append(op_attr) + self.compiled_cache_host[key] = compiled_host_fns + + for (key, operation_name, operation_attr,) in zip(operation_key, operation_name, operation_attr): + self.insert_operation( + key, cubin_image, host_file.name, operation_name, operation_attr) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/conv2d_operation.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/conv2d_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..03679c434e1a63e9d1f9f2d1571dacedcf6e1470 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/conv2d_operation.py @@ -0,0 +1,700 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# +from __future__ import annotations + +import ctypes +from typing import Union + +from cutlass_cppgen.utils.lazy_import import lazy_import +cuda = lazy_import("cuda.cuda") +from cutlass_library import SubstituteTemplate +import numpy as np + +from cutlass_library import ( + ConvKindNames, + ConvKindTag, + DataTypeNames, + DataTypeSize, + DataTypeTag, + IteratorAlgorithmNames, + IteratorAlgorithmTag, + LayoutTag, + LayoutType, + MathOperation, + MathOperationTag, + OpcodeClass, + OpcodeClassNames, + OpcodeClassTag, + OperationKind, + ShortDataTypeNames, + ShortLayoutTypeNames, + SplitKMode, + StrideSupport, + StrideSupportTag, + SwizzlingFunctor, + SwizzlingFunctorTag, + get_complex_from_real, +) + +from cutlass_cppgen.backend.arguments import ArgumentBase +from cutlass_cppgen.backend.c_types import dim3_, get_conv2d_arguments +from cutlass_cppgen.backend.library import ( + EmissionType, + TensorDescription, + TileDescription, +) +from cutlass_cppgen.backend.memory_manager import device_mem_alloc +from cutlass_cppgen.backend.operation import ExecutableOperation, LaunchConfiguration +from cutlass_cppgen.backend.utils.device import to_device_ptr +from cutlass_cppgen.shape import GemmCoord + + +class Conv2dArguments(ArgumentBase): + """ + Argument wrapper for Conv2d. It encodes problem information and + user-provide tensors into the kernel's argument. + + :param operation: the Conv2d operation to take the argument + :type operation: :class:`cutlass_cppgen.backend.Conv2dOperation` + :param problem_size: the Conv2d problem size + :type problem_size: :class:`cutlass_cppgen.shape.Conv2dProblemSize` + :param A: tensor A + :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + :param B: tensor B + :type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + :param C: tensor C + :type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + :param D: tensor D + :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + :param split_k_mode: conv2d split K mode, defaults to cutlass_library.library.SplitKMode.Serial + :type split_k_mode: cutlass_library.library.SplitKMode, optional + :param output_op: output operator, optional + :type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments` + :param stream: cuda stream, defaults to cuda.cuda.CUstream(0) + :type stream: :class:`cuda.cuda.CUstream` + """ + + def __init__(self, operation, problem_size, A, B, C, D, + split_k_mode=SplitKMode.Serial, **kwargs, ) -> None: + self.operation = operation + self.conv_kind = operation.conv_kind + self.layout_A = operation.A.layout + self.layout_B = operation.B.layout + self.layout_C = operation.C.layout + + self.element_A = operation.A.element + self.element_B = operation.B.element + self.element_C = operation.C.element + + if self.layout_C == LayoutType.TensorNC32HW32: + raise Exception("Layout type TensorNC32HW32 is not currently supported") + + super().__init__(A, B, C, D, **kwargs) + + if "split_k_slices" in kwargs.keys() and kwargs["split_k_slices"] > 1: + self.split_k_mode = split_k_mode + self.split_k_slices = kwargs["split_k_slices"] + else: + self.split_k_mode = SplitKMode.Serial + self.split_k_slices = 1 + + if "output_op" in kwargs.keys() and self.split_k_mode != SplitKMode.Parallel: + self.output_op = kwargs["output_op"] + else: + self.output_op = self.operation.epilogue_type(1.0, 0.0) + + self.problem_size = problem_size + self.problem_size.split_k_slices = self.split_k_slices + + self.initialize() + + def get_arguments(self): + tc_numel = -1 + if hasattr(self, "tensor_c_numel"): + tc_numel = self.tensor_c_numel + + self.c_arguments = self.operation.argument_type( + int(self.conv_kind), + self.problem_size.ctype, + int(to_device_ptr(self.ptr_A)), + int(to_device_ptr(self.ptr_B)), + int(to_device_ptr(self.ptr_C)), + int(to_device_ptr(self.ptr_D)), + tc_numel, + self.output_op, + int(self.split_k_mode) + ) + + def initialize(self): + self.launch_config = self.operation.rt_module.plan(self) + + self.get_arguments() + + # Allocate and initialize device workspace + device_workspace_size = self.operation.rt_module.get_workspace_size(self.c_arguments) + if device_workspace_size > 0: + self.workspace_buffer = device_mem_alloc(device_workspace_size) + workspace_ptr = self.workspace_buffer.ptr + err, = cuda.cuMemsetD32( + workspace_ptr, 0, device_workspace_size // 4) + else: + workspace_ptr = None + + self.semaphore = 0 + if workspace_ptr is not None and self.split_k_mode == SplitKMode.Parallel: + self.ptr_D = workspace_ptr + # Reset arguments now that ptr_D has been updated + self.get_arguments() + elif workspace_ptr is not None and self.split_k_mode == SplitKMode.Serial: + self.semaphore = workspace_ptr + + params_ = self.operation.rt_module.get_args( + self.c_arguments, ctypes.c_void_p(int(self.semaphore))) + self.host_workspace = bytearray(params_.contents) + self.device_workspace = None + + def sync(self): + """ + Synchronize the arguments. If the input tensor is in host, + copy it from device to host. + """ + return super().sync() + + +class Conv2dRT(ExecutableOperation): + """ + Conv2dRT manages the CUTLASS runtime components + """ + + KernelTemplate = r""" +extern "C" +__global__ void +${operation_name}(${operation_name}${operation_suffix}::Params params) { + + // Dynamic shared memory base pointer + extern __shared__ int SharedStorageBase[]; + + // Declare pointer to dynamic shared memory. + ${operation_name}${operation_suffix}::SharedStorage *shared_storage = + reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase); + + ${operation_name}${operation_suffix} op; + + op(params, *shared_storage); +} + """ + + HostTemplate = r""" +extern "C" { + // Get the size of params in bytes + int ${operation_name}_get_param_size(){ + return sizeof(${operation_name}${operation_suffix}::Params); + } + + // Get the size of dynamic shared memory in bytes + int ${operation_name}_shared_memory_size() { + return int(sizeof(${operation_name}${operation_suffix}::SharedStorage)); + } + + using ElementA = typename ${operation_name}_base::ElementA; + using ElementB = typename ${operation_name}_base::ElementB; + using ElementC = typename ${operation_name}_base::ElementC; + using LayoutA = typename ${operation_name}_base::LayoutA; + using LayoutB = typename ${operation_name}_base::LayoutB; + using LayoutC = typename ${operation_name}_base::LayoutC; + using EpilogueOutputOp = typename ${operation_name}_base::EpilogueOutputOp; + + struct ${operation_name}_TemporaryArgs { + int conv_kind; + cutlass::conv::Conv2dProblemSize problem_size; + ElementA* ptr_A; + ElementB* ptr_B; + ElementC* ptr_C; + ElementC* ptr_D; + int tensor_c_numel; + typename EpilogueOutputOp::Params epilogue_params; + int split_k_mode; + }; + + typename ${operation_name}${operation_suffix}::Arguments + construct_arguments(${operation_name}_TemporaryArgs args) { + cutlass::conv::Operator conv_operator = static_cast(args.conv_kind); + auto tc_A = cutlass::conv::implicit_gemm_tensor_a_extent(conv_operator, args.problem_size); + auto tc_B = cutlass::conv::implicit_gemm_tensor_b_extent(conv_operator, args.problem_size); + auto tc_C = cutlass::conv::implicit_gemm_tensor_c_extent(conv_operator, args.problem_size); + auto tc_D = cutlass::conv::implicit_gemm_tensor_c_extent(conv_operator, args.problem_size); + + auto size_C = tc_C.at(0) * tc_C.at(1) * tc_C.at(2) * tc_C.at(3); + if (args.tensor_c_numel >= 0 && args.tensor_c_numel == tc_C.at(3) && args.tensor_c_numel < size_C) { + // C is interpreted as bias + tc_C = {0, 0, 0, 0}; + } + + cutlass::TensorRef tref_A(args.ptr_A, LayoutA::packed(tc_A)); + cutlass::TensorRef tref_B(args.ptr_B, LayoutB::packed(tc_B)); + cutlass::TensorRef tref_C(args.ptr_C, LayoutC::packed(tc_C)); + cutlass::TensorRef tref_D(args.ptr_D, LayoutC::packed(tc_D)); + + return { + args.problem_size, + tref_A, + tref_B, + tref_C, + tref_D, + args.epilogue_params, + static_cast(args.split_k_mode) + }; + } + + // Get the params as byte array + char* ${operation_name}_get_params(${operation_name}_TemporaryArgs args, int *semaphore=nullptr) { + auto arguments = construct_arguments(args); + typename ${operation_name}${operation_suffix}::Params* params; + params = new ${operation_name}${operation_suffix}::Params(arguments, semaphore); + + char *bytes = ((char*)(params)); + char *output = new char[sizeof(${operation_name}${operation_suffix}::Params)]; + for (unsigned int i = 0; i < sizeof(${operation_name}${operation_suffix}::Params); i ++) + output[i] = bytes[i]; + + return output; + } + + dim3 ${operation_name}_get_grid_shape( + int conv_kind, + cutlass::conv::Conv2dProblemSize problem_size, + cutlass::gemm::GemmCoord tile_size, + int split_k_slices + ) { + + using Swizzle = typename ${operation_name}_base::ThreadblockSwizzle; + auto tiled_shape = Swizzle::get_tiled_shape( + static_cast(conv_kind), + problem_size, + tile_size, + split_k_slices); + + return Swizzle::get_grid_shape(tiled_shape); + } + + size_t ${operation_name}_get_workspace_size(${operation_name}_TemporaryArgs args) { + auto arguments = construct_arguments(args); + + // Temporarily define device::-level Conv2d so that we can call get_workspace_size + using DeviceConv = cutlass::conv::device::ImplicitGemmConvolution<${operation_name}_base>; + return DeviceConv::get_workspace_size(arguments); + } +} + + """ + + def __init__(self, operation: "Conv2dOperation"): + super().__init__(operation) + self.extra_funcs = { + "get_grid_shape": dim3_, + "get_workspace_size": ctypes.c_uint64 + } + self.argument_type, self.epilogue_type = get_conv2d_arguments(operation.epilogue_functor) + self.argtype = [ctypes.POINTER(self.argument_type), ctypes.c_void_p] + self.conv_kind = operation.conv_kind + + self.operation: Conv2dOperation = operation + + self.emitter = EmitConv2dInstance("_type") + + self.threads = operation.tile_description.num_threads + + self.swizzle_functor = operation.swizzling_functor + + def emit(self): + return self.emitter.emit(self.operation) + + def plan(self, arguments: Conv2dArguments): + tile_size = GemmCoord( + self.operation.tile_description.threadblock_shape[0], + self.operation.tile_description.threadblock_shape[1], + self.operation.tile_description.threadblock_shape[2], + ) + + grid = self.get_grid_shape( + int(self.conv_kind), + arguments.problem_size.ctype, + tile_size.ctype, + arguments.split_k_slices + ) + + return LaunchConfiguration( + [grid.x, grid.y, grid.z], [self.threads, 1, 1], + self.shared_memory_capacity) + + def initialize(self): + err, = cuda.cuFuncSetAttribute( + self.kernel, + attrib=cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + value=self.shared_memory_capacity) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError(f"CUDA Error: {err}") + + +class Conv2dOperation: + """ + CUTLASS Conv2d operation description. + + :param conv_kind: convolution operator + :type conv_kind: :class:`cutlass_library.library.ConvKind` + + :param iterator_algorithm: Selects among several implementation + variants trading off performance with simplicity + :type iterator_algorithm: :class:`cutlass_library.library.IteratorAlgorithm` + + :param arch: GPU compute capability (sm_xx) + :type arch: int + + :param tile_description: tile description + :type tile_description: :class:`cutlass_cppgen.backend.TileDescription` + + :param A: tensor A description + :type A: :class:`cutlass_cppgen.backend.TensorDescription` + + :param B: tensor B description + :type B: :class:`cutlass_cppgen.backend.TensorDescription` + + :param C: tensor C description + :type C: :class:`cutlass_cppgen.backend.TensorDescription` + + :param D: tensor D description + :type D: :class:`cutlass_cppgen.backend.TensorDescription` + + :param element_epilogue: element type for computation in epilogue \ + :type element_epilogue: cutlass_library.library.DataType + + :param stride_support: distinguish among partial specializations that \ + accelerate certain problems where convolution stride is unit \ + :type stride_support: :class:`cutlass_library.library.StrideSupport` + + :param epilogue_functor: convolution epilogue functor + :type epilogue_functor: :class:`EpilogueFunctor` + + :param swizzling_functor: threadblock swizzling functor + """ + def __init__( + self, + conv_kind, + iterator_algorithm, + arch: int, + tile_description: TileDescription, + A: TensorDescription, + B: TensorDescription, + C: TensorDescription, + stride_support, + epilogue_functor, + swizzling_functor=SwizzlingFunctor.Identity1, + emission_type=EmissionType.Kernel, + **kwargs + ): + self.operation_kind: OperationKind = OperationKind.Conv2d + self.arch: int = arch + self.tile_description: TileDescription = tile_description + self.conv_kind = conv_kind + self.A: TensorDescription = A + self.B: TensorDescription = B + self.C: TensorDescription = C + self.epilogue_functor = epilogue_functor + self.iterator_algorithm = iterator_algorithm + self.stride_support = stride_support + self.swizzling_functor = swizzling_functor + + self.emission_type = emission_type + + self.rt_module: Conv2dRT = Conv2dRT(self) + self.argument_type = self.rt_module.argument_type + self.epilogue_type = self.rt_module.epilogue_type + + def run(self, arguments: Conv2dArguments) -> cuda.CUresult: + """ + Launch the cuda kernel with input arguments + + :param arguments: conv2d arguments + :type arguments: :class:`cutlass_cppgen.backend.Conv2dArguments` + """ + + # launch the kernel + err = self.rt_module.run( + arguments.host_workspace, + arguments.device_workspace, + arguments.launch_config, + arguments.stream + ) + + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError(f"CUDA Error {err}") + + return err + + # + # Get function name + # + + def procedural_name(self): + """The full procedural name indicates architecture, extended name, tile size, and layout.""" + return self.configuration_name() + + def configuration_name(self): + """The full procedural name indicates architecture, extended name, tile size, and layout.""" + + opcode_class_name = OpcodeClassNames[ + self.tile_description.math_instruction.opcode_class + ] + + threadblock = "%dx%d_%dx%d" % ( + self.tile_description.threadblock_shape[0], + self.tile_description.threadblock_shape[1], + self.tile_description.threadblock_shape[2], + self.tile_description.stages, + ) + + if self.stride_support == StrideSupport.Unity: + configuration_name = "cutlass_sm${arch}_${opcode_class}_${extended_name}_${threadblock}_${layout}_unity_stride_align${alignment}" + else: + configuration_name = "cutlass_sm${arch}_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}" + + return SubstituteTemplate( + configuration_name, + { + "arch": str(self.arch), + "opcode_class": opcode_class_name, + "extended_name": self.extended_name(), + "threadblock": threadblock, + "layout": self.layout_name(), + "alignment": "%d" % self.A.alignment + }, + ) + + def extended_name(self): + """Append data types if they differ from compute type.""" + if self.C.element != self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${element_c}_${core_name}_${element_a}" + elif self.C.element == self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${core_name}_${element_a}" + else: + extended_name = "${core_name}" + + extended_name = SubstituteTemplate(extended_name, { + "element_a": DataTypeNames[self.A.element], + "element_c": DataTypeNames[self.C.element], + "core_name": self.core_name(), + }) + + return extended_name + + def layout_name(self): + return "%s" % (ShortLayoutTypeNames[self.A.layout]) + + def core_name(self): + """The basic operation kind is prefixed with a letter indicating the accumulation type.""" + + intermediate_type = "" + + if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp: + inst_shape = "%dx%dx%d" % tuple( + self.tile_description.math_instruction.instruction_shape) + if self.tile_description.math_instruction.element_a != self.A.element and \ + self.tile_description.math_instruction.element_a != self.accumulator_type(): + intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] + else: + inst_shape = "" + + return "%s%s%s%s_%s" % ( + ShortDataTypeNames[self.accumulator_type()], + inst_shape, + intermediate_type, + ConvKindNames[self.conv_kind], + IteratorAlgorithmNames[self.iterator_algorithm] + ) + + def is_complex(self): + complex_operators = [ + MathOperation.multiply_add_complex, + MathOperation.multiply_add_complex_gaussian, + ] + return self.tile_description.math_instruction.math_operation in complex_operators + + def accumulator_type(self): + accum = self.tile_description.math_instruction.element_accumulator + + if self.is_complex(): + return get_complex_from_real(accum) + + return accum + + def device_op(self): + """ + Returns a new Conv2dOperation object that is constructed with emission type + ``EmissionType.Device``. + + :return: operation ready for device-level code emission + :rtype: Conv2dOperation + """ + return Conv2dOperation( + self.conv_kind, self.iterator_algorithm, self.arch, self.tile_description, + self.A, self.B, self.C, self.stride_support, self.epilogue_functor, self.swizzling_functor, + emission_type=EmissionType.Device) + + +################################################################################################### +# +# Emits single instances of a CUTLASS device-wide operator +# +################################################################################################### + + +class EmitConv2dInstance: + def __init__(self, operation_suffix=""): + self.operation_suffix = operation_suffix + self.includes = [ + "cutlass/cutlass.h", + "cutlass/conv/kernel/default_conv2d_fprop.h", + "cutlass/conv/kernel/default_conv2d_dgrad.h", + "cutlass/conv/kernel/default_conv2d_wgrad.h", + "cutlass/conv/device/implicit_gemm_convolution.h" + ] + self.template = """ +// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}" +using ${operation_name}_base = +typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}< + ${element_a}, + ${layout_a}, + ${element_b}, + ${layout_b}, + ${element_c}, + ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}, + ${swizzling_functor}, + ${stages}, + ${math_operator}, + ${iterator_algorithm}, + ${stride_support}, + ${align_a}, + ${align_b} +>::Kernel; + +struct ${operation_name}${operation_suffix}: + public ${operation_name}_base { }; + +""" + + self.template_device = """ +// Conv2d operation ${operation_name} + +using Conv2d${conv_kind_name}Kernel = typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}< + ${element_a}, + ${layout_a}, + ${element_b}, + ${layout_b}, + ${element_c}, + ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}, + ${swizzling_functor}, + ${stages}, + ${math_operator}, + ${iterator_algorithm}, + ${stride_support}, + ${align_a}, + ${align_b} +>::Kernel; + +using DeviceKernel = + typename cutlass::conv::device::ImplicitGemmConvolution; +""" + + def emit(self, operation): + warp_shape = [int(operation.tile_description.threadblock_shape[idx] / + operation.tile_description.warp_count[idx]) for idx in range(3)] + + epilogue_vector_length = int(min( + operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) + + values = { + "operation_name": operation.procedural_name(), + "operation_suffix": self.operation_suffix, + "conv_kind": ConvKindTag[operation.conv_kind], + "conv_kind_name": ConvKindNames[operation.conv_kind].capitalize(), + "element_a": DataTypeTag[operation.A.element], + "layout_a": LayoutTag[operation.A.layout], + "element_b": DataTypeTag[operation.B.element], + "layout_b": LayoutTag[operation.B.layout], + "element_c": DataTypeTag[operation.C.element], + "layout_c": LayoutTag[operation.C.layout], + "element_accumulator": DataTypeTag[operation.accumulator_type()], + "opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + "arch": "cutlass::arch::Sm%d" % operation.arch, + "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]), + "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]), + "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]), + "warp_shape_m": str(warp_shape[0]), + "warp_shape_n": str(warp_shape[1]), + "warp_shape_k": str(warp_shape[2]), + "instruction_shape_m": str(operation.tile_description.math_instruction.instruction_shape[0]), + "instruction_shape_n": str(operation.tile_description.math_instruction.instruction_shape[1]), + "instruction_shape_k": str(operation.tile_description.math_instruction.instruction_shape[2]), + "epilogue_vector_length": str(epilogue_vector_length), + "epilogue_functor": operation.epilogue_functor.emit(), + "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor], + "stages": str(operation.tile_description.stages), + "iterator_algorithm": IteratorAlgorithmTag[operation.iterator_algorithm], + "iterator_algorithm_name": IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(), + "stride_support": StrideSupportTag[operation.stride_support], + "math_operator": "cutlass::arch::OpMultiplyAddComplex" if operation.is_complex() else MathOperationTag[operation.tile_description.math_instruction.math_operation], + "align_a": str(operation.A.alignment), + "align_b": str(operation.B.alignment), + } + + if operation.emission_type == EmissionType.Kernel: + conv2d_template = self.template + else: + conv2d_template = self.template_device + + return SubstituteTemplate(conv2d_template, values) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/epilogue.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/epilogue.py new file mode 100644 index 0000000000000000000000000000000000000000..49ad79c9c8ecc9cad6067a3d9543b2625344848b --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/epilogue.py @@ -0,0 +1,541 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import ctypes + +from cutlass_library import SubstituteTemplate +import numpy as np + +from cutlass_library import DataType, DataTypeTag +from cutlass_cppgen.backend.c_types import MatrixCoord_, tuple_factory +from cutlass_cppgen.backend.frontend import NumpyFrontend +from cutlass_cppgen.backend.library import ActivationOp, ActivationOpTag +from cutlass_cppgen.utils.datatypes import is_numpy_tensor, is_torch_available, is_torch_tensor + +dtype2ctype = { + DataType.f16: ctypes.c_uint16, + DataType.bf16: ctypes.c_uint16, + DataType.f32: ctypes.c_float, + DataType.f64: ctypes.c_double, + DataType.s8: ctypes.c_int8, + DataType.s32: ctypes.c_int32 +} + +if is_torch_available(): + import torch + import torch.nn.functional as F + + +def get_scalar(value): + """ + Returns a scalar value from a container (e.g., np.ndarray) + """ + if is_numpy_tensor(value): + if value.size != 1: + raise Exception("Scalars used in epilogue must be of size 1") + return value.reshape(-1)[0] + elif is_torch_tensor(value): + if value.size != 1: + raise Exception("Scalars used in epilogue must be of size 1") + return value.reshape(-1)[0] + else: + return value + + +def to_ctype_value(value, dtype): + """ + Converts ``value`` to the corresponding storage needed for the ctype that + will store ``value``. + """ + scalar = get_scalar(value) + if dtype == DataType.f16: + # Convert f16 value into an integer + return int.from_bytes(np.float16(scalar).tobytes(), "little") + else: + return scalar + + +################################################################################################# +# +# Epilogue Functors +# +################################################################################################# + + +class EpilogueFunctorBase: + """ + Base class for thread-level epilogue functors + """ + + def __init__(self) -> None: + pass + + def emit(self, tag, template_argument): + template = """${tag}<${arguments}>""" + arguments = "" + for idx, arg in enumerate(template_argument): + arguments += arg + if idx < len(template_argument) - 1: + arguments += ", " + values = { + "tag": tag, + "arguments": arguments, + } + + return SubstituteTemplate(template, values) + + +class LinearCombination(EpilogueFunctorBase): + """ + Apply a linear combination operator to an array of elements + D = alpha * accumulator + beta * source + + :param element_output: data type used to load and store tensors + + :param epilogue_vector_length: number of elements computed per operation. + Usually it is 128/sizeof_bits_v, but we use 64 and 32 sometimes + when there are not enough data to store + + :param element_accumulator: Accumulator data type + + :param element_epilogue: data type used to compute linear combination + """ + + tag = "cutlass::epilogue::thread::LinearCombination" + + def __init__( + self, element_output, epilogue_vector_length, + element_accumulator=None, element_epilogue=None) -> None: + super().__init__() + + if element_accumulator is None: + element_accumulator = element_output + if element_epilogue is None: + element_epilogue = element_output + + self.element_output = element_output + self.element_accumulator = element_accumulator + self.element_epilogue = element_epilogue + self.epilogue_vector_length = epilogue_vector_length + + self.template_arguments = [ + DataTypeTag[element_output], + str(epilogue_vector_length), + DataTypeTag[element_accumulator], + DataTypeTag[element_epilogue], + ] + + c_element_epilogue = dtype2ctype[self.element_epilogue] + element_epilogue = self.element_epilogue + + class _EpilogueOutputOpParamsEVT(ctypes.Structure): + """ + Epilogue params when using the default linear combination of EVT, which + does not currently use {alpha,beta}_ptr_array + """ + + stride_type = tuple_factory((0,0,1), "int64_t", [0]) + _fields_ = [ + ("alpha", c_element_epilogue), + ("beta", c_element_epilogue), + ("alpha_ptr", ctypes.c_void_p), + ("beta_ptr", ctypes.c_void_p), + ("dalpha", stride_type), + ("dbeta", stride_type), + ] + + def __init__(self, alpha, beta, *args) -> None: + self.alpha = to_ctype_value(alpha, element_epilogue) + self.beta = to_ctype_value(beta, element_epilogue) + + class _EpilogueOutputOpParams(ctypes.Structure): + _fields_ = [ + ("alpha", c_element_epilogue), + ("beta", c_element_epilogue), + ("alpha_ptr", ctypes.c_void_p), + ("beta_ptr", ctypes.c_void_p), + ("alpha_ptr_array", ctypes.c_void_p), + ("beta_ptr_array", ctypes.c_void_p), + ] + + def __init__(self, alpha, beta, *args) -> None: + self.alpha = to_ctype_value(alpha, element_epilogue) + self.beta = to_ctype_value(beta, element_epilogue) + + def to_evt_params(self) -> _EpilogueOutputOpParamsEVT: + return _EpilogueOutputOpParamsEVT(self.alpha, self.beta) + + self.epilogue_type = _EpilogueOutputOpParams + self.epilogue_type_evt = _EpilogueOutputOpParamsEVT + + def emit(self): + return super().emit(self.tag, self.template_arguments) + + +class LinearCombinationClamp(LinearCombination): + """ + Applies a linear combination operator to an array of elements then clamps + the output before converting to the output element type. + + D = alpha * accumulator + beta * source + uniform + + :param element_output: data type used to load and store tensors + + :param epilogue_vector_length: number of elements computed per operation. + Usually it is 128/sizeof_bits_v, but we use 64 and 32 sometimes + when there are not enough data to store + + :param element_accumulator: Accumulator data type + + :param element_epilogue: data type used to compute linear combination + """ + + tag = "cutlass::epilogue::thread::LinearCombinationClamp" + + def __init__( + self, element_output, epilogue_vector_length, + element_accumulator=None, element_epilogue=None) -> None: + # Base constructor + super().__init__( + element_output, + epilogue_vector_length, + element_accumulator, + element_epilogue, + ) + + c_element_epilogue = dtype2ctype[self.element_epilogue] + element_epilogue = self.element_epilogue + + class _EpilogueOutputOpParams(ctypes.Structure): + _fields_ = [ + ("alpha", c_element_epilogue), + ("beta", c_element_epilogue), + ("alpha_ptr", ctypes.c_void_p), + ("beta_ptr", ctypes.c_void_p), + ] + + def __init__(self, alpha, beta, *args) -> None: + self.alpha = to_ctype_value(alpha, element_epilogue) + self.beta = to_ctype_value(beta, element_epilogue) + + self.epilogue_type = _EpilogueOutputOpParams + + +class FastLinearCombinationClamp(EpilogueFunctorBase): + """ + Applies a linear combination operator to an array of elements then clamps + the output before converting to the output element type. + + D = alpha * accumulator + beta * source + + Note: The below method only when problem_size_K <= 256 for signed int8 gemm + or problem_size_K <= 128 for unsigned int8 gemm. The default approach is + above. + + :param element_output: data type used to load and store tensors + + :param epilogue_vector_length: number of elements computed per operation. + Usually it is 128/sizeof_bits_v, but we use 64 and 32 sometimes + when there are not enough data to store + """ + + tag = "cutlass::epilogue::thread::FastLinearCombinationClamp" + + def __init__(self, element_output, epilogue_vector_length, *args) -> None: + super().__init__() + + self.template_arguments = [ + DataTypeTag[element_output], str(epilogue_vector_length) + ] + + self.element_accumulator = DataType.s32 + self.element_epilogue = DataType.f32 + + # get epilogue output op + c_element_epilogue = dtype2ctype[self.element_epilogue] + element_epilogue = self.element_epilogue + + class _EpilogueOutputOpParams(ctypes.Structure): + _fields_ = [ + ("alpha", c_element_epilogue), + ("beta", c_element_epilogue), + ("alpha_ptr", ctypes.c_void_p), + ("beta_ptr", ctypes.c_void_p), + ] + + def __init__(self, alpha, beta, *args) -> None: + self.alpha = to_ctype_value(alpha, element_epilogue) + self.beta = to_ctype_value(beta, element_epilogue) + + self.epilogue_type = _EpilogueOutputOpParams + + def emit(self): + return super().emit(self.tag, self.template_arguments) + + +class LinearCombinationGeneric(LinearCombination): + """ + Applies a linear combination operator followed by an activation function + to an array of elements. + + D = activation(alpha * accumulator + beta * source) + + :param activation_functor: input activation functor + + :param element_output: data type used to load and store tensors + + :param epilogue_vector_length: number of elements computed per operation. + Usually it is 128/sizeof_bits_v, but we use 64 and 32 sometimes + when there are not enough data to store + + :param element_accumulator: Accumulator data type + + :param element_epilogue: data type used to compute linear combination + """ + + tag = "cutlass::epilogue::thread::LinearCombinationGeneric" + + def __init__( + self, activation_functor, + element_output, epilogue_vector_length, + element_accumulator=None, element_epilogue=None) -> None: + super().__init__( + element_output, + epilogue_vector_length, + element_accumulator, + element_epilogue, + ) + + self.template_arguments = [ + activation_functor.emit()] + self.template_arguments + + self.activation_functor = activation_functor + self.element_epilogue = element_epilogue + + # get epilogue output op + self.epilogue_type = self.activation_functor.epilogue_output_op(self.element_epilogue) + + +class ActivationFunctor: + """ + Base class for frequently used activation functions + """ + + @staticmethod + def numpy(x: np.ndarray): + raise NotImplementedError() + + @classmethod + def emit(cls): + return ActivationOpTag[cls.binding_type] + + @staticmethod + def epilogue_output_op(element_epilogue): + c_element_epilogue = dtype2ctype[element_epilogue] + + class _EpilogueOutputOpParams(ctypes.Structure): + _fields_ = [ + ("alpha", c_element_epilogue), + ("beta", c_element_epilogue), + ("alpha_ptr", ctypes.c_void_p), + ("beta_ptr", ctypes.c_void_p), + ] + + def __init__(self, alpha, beta, *args) -> None: + self.alpha = to_ctype_value(alpha, element_epilogue) + self.beta = to_ctype_value(beta, element_epilogue) + + return _EpilogueOutputOpParams + +class ActivationMeta(type): + @classmethod + def __call__(cls, x, *args): + if is_numpy_tensor(x): + return cls.numpy(x, *args) + elif is_torch_tensor(x): + return cls.torch(x, *args) + else: + raise NotImplementedError("Unsupported tensor type") + + @classmethod + def numpy(cls, *args): + raise NotImplementedError(f"Numpy reference for {cls.__name__[:-4]} is not implemented.") + + @classmethod + def torch(cls, *args): + raise NotImplementedError(f"PyTorch reference for {cls.__name__[:-4]} is not implemented.") + +############################################################################## +# identity operator +class identityMeta(ActivationMeta): + @classmethod + def numpy(cls, x): + return x + + @classmethod + def torch(cls, x): + return x + +class identity(ActivationFunctor, metaclass=identityMeta): + binding_type = ActivationOp.Identity + + +############################################################################## +# ReLu operator +class reluMeta(ActivationMeta): + @classmethod + def numpy(cls, x): + return np.where(x > 0, x, 0) + + @classmethod + def torch(cls, x): + return F.relu(x) + +class relu(ActivationFunctor, metaclass=reluMeta): + binding_type = ActivationOp.ReLU + + +############################################################################## +# Leaky ReLu operator +class leakyReLUMeta(ActivationMeta): + @classmethod + def numpy(cls, x, leaky_alpha): + return np.maximum(x, 0) + np.minimum(x, 0) * leaky_alpha + + @classmethod + def torch(cls, x, leaky_alpha): + return F.leaky_relu(x, leaky_alpha) + +class leaky_relu(ActivationFunctor, metaclass=leakyReLUMeta): + binding_type = ActivationOp.LeakyReLU + + @staticmethod + def epilogue_output_op(element_epilogue): + c_element_epilogue = dtype2ctype[element_epilogue] + + class _EpilogueOutputOpParams(ctypes.Structure): + _fields_ = [ + ("alpha", c_element_epilogue), + ("beta", c_element_epilogue), + ("alpha_ptr", ctypes.c_void_p), + ("beta_ptr", ctypes.c_void_p), + ("leaky_alpha", c_element_epilogue) + ] + + def __init__(self, alpha, beta, leaky_alpha=0.2, *args) -> None: + self.alpha = to_ctype_value(alpha, element_epilogue) + self.beta = to_ctype_value(beta, element_epilogue) + self.alpha_ptr = 0 + self.beta_ptr = 0 + self.leaky_alpha = to_ctype_value(leaky_alpha, element_epilogue) + + return _EpilogueOutputOpParams + + +############################################################################## +# Tanh operator +class tanhMeta(ActivationMeta): + @classmethod + def numpy(cls, x): + return np.tanh(x) + + @classmethod + def torch(cls, x): + return torch.tanh(x) + +class tanh(ActivationFunctor, metaclass=tanhMeta): + binding_type = ActivationOp.Tanh + + +############################################################################## +# Sigmoid operator +class sigmoidMeta(ActivationMeta): + @classmethod + def numpy(cls, x): + return 1.0 / (1.0 + np.exp(-x)) + + @classmethod + def torch(cls, x): + return F.sigmoid(x) + +class sigmoid(ActivationFunctor, metaclass=sigmoidMeta): + binding_type = ActivationOp.Sigmoid + + +############################################################################## +# SiLu operator +class siluMeta(ActivationMeta): + @classmethod + def numpy(cls, x): + return x * sigmoidMeta.numpy() + + @classmethod + def silu(cls, x): + return F.silu(x) + + +class silu(ActivationFunctor, metaclass=siluMeta): + binding_type = ActivationOp.SiLU + + +############################################################################## +# Hardswish operator +class hardswishMeta(ActivationMeta): + @classmethod + def numpy(cls, x): + relu6 = np.minimum(np.maximum(x + 3.0, 0), 6.0) + return x * relu6 / 6.0 + + @classmethod + def torch(cls, x): + return F.hardswish(x) + + +class hardswish(ActivationFunctor, metaclass=hardswishMeta): + binding_type = ActivationOp.HardSwish + + +############################################################################## +# GELU operator +class geluMeta(ActivationMeta): + @classmethod + def numpy(cls, x): + from scipy.special import erf + return 0.5 * x * (1 + erf(x / np.sqrt(2.0))) + + @classmethod + def torch(cls, x): + return F.gelu(x) + + +class gelu(ActivationFunctor, metaclass=geluMeta): + binding_type = ActivationOp.Gelu diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b61e983ab23bb5662d15e185184efa227351446d --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/__init__.py @@ -0,0 +1,34 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from cutlass_cppgen.backend.evt.epilogue import EpilogueFunctorVisitor +from cutlass_cppgen.backend.evt.frontend import PythonASTFrontend diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..945dcf80e307eb870f31722822f959da03e6c421 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/__init__.py @@ -0,0 +1,38 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from cutlass_cppgen.backend.evt.backend.sm80_emitter import Sm80Emitter +import cutlass_cppgen.backend.evt.backend.sm80_nodes as sm80_nodes +from cutlass_cppgen.backend.evt.backend.sm90_emitter import Sm90Emitter +import cutlass_cppgen.backend.evt.backend.sm90_nodes as sm90_nodes +from cutlass_cppgen.backend.evt.backend.sm100_emitter import Sm100Emitter +import cutlass_cppgen.backend.evt.backend.sm100_nodes as sm100_nodes diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/emitter_base.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/emitter_base.py new file mode 100644 index 0000000000000000000000000000000000000000..72a7d8c04db5c8df2595fab8befaa07bf238c2f2 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/emitter_base.py @@ -0,0 +1,159 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Base class for Epilogue Visitor Emitter +""" + +from cutlass_library import DataTypeTag +from cutlass_cppgen.backend.evt.ir import TopoVisitorNode, DAGIR + + +class FusionCallbacks: + def __init__(self, dag_ir: DAGIR, cc: int, emit_CD=True) -> None: + """ + Emit the EVT fusion callbacks + :param dag_ir: the DAG IR holding the epilogue visitor + :param cc: compute capability + :param emit_CD: whether to emit nodes C & D as a part of the fusion callbacks + For Sm90, set emit_CD=False, as Tensor C & D are hardcoded in the collective API + so that their shared memory can be explicitly reused + For Sm89, set emit_CD=True as they are treated as normal AuxLoad & AuxStore nodes. + """ + self.dag_ir = dag_ir + self.emit_CD = emit_CD + self.cc = cc + self.evt_cc = 90 if cc >= 90 else cc + if self.cc < 90: + self.namespace = "threadblock" + else: + self.namespace = "fusion" + + # + # Helper functions + # + + def get_visitor_name(self, node: str): + """ + Get the visitor name + """ + meta = self.dag_ir.get_node_meta(node) + if not isinstance(meta, TopoVisitorNode) and self.dag_ir.in_degree(node) > 0: + return f"EVT{meta.name_camel}" + else: + return meta.name_camel + + def emit(self): + node_metas = self.dag_ir.node_metas_topological_order() + epilogue_str = "" + # Step 1: emit individual node type decl + # emit the EVT & DAG connector + for meta in node_metas: + if not meta.disabled: + epilogue_str += self.emit_node(meta) + if not self.emit_CD and meta.name == "D": + continue + if isinstance(meta, TopoVisitorNode): + epilogue_str += self.emit_dag(meta) + else: + epilogue_str += self.emit_evt(meta) + + # Step 2: post-processing & get callback name + if not self.emit_CD: + if not self.dag_ir.has_node("C"): + epilogue_str += "using ElementC = void;\nusing StrideC = StrideD;\n" + output_node = self.dag_ir.get_all_inputs("D")[0] + # The callback is the src of node D + callback_name = self.get_visitor_name(output_node) + else: + # The callback is the last node in the topological order + callback_name = self.get_visitor_name(node_metas[-1].name) + return epilogue_str, callback_name + + def emit_evt(self, node): + if self.dag_ir.in_degree(node.name) == 0: + return "" + + evt_tmp = f""" +using EVT{node.name_camel} = cutlass::epilogue::{self.namespace}::Sm{self.evt_cc}EVT< + {node.name_camel}, +""" + sorted_children = self.dag_ir.get_all_inputs(node.name) + evt_node_strs = [f" {self.get_visitor_name(child_name)}" for child_name in sorted_children] + evt_tmp += ",\n".join(evt_node_strs) + ">;\n" + + return evt_tmp + + def emit_dag(self, node): + subgraph = node.subgraph + subgraph_nodes = subgraph.nodes_topological_order() + # Emit the Edge Tuple + edge_tuples = "cute::tuple<\n" + for n in subgraph_nodes[:-1]: + in_edges = subgraph.in_edges(n) + edge_weights = [subgraph.get_edge_weight(edge[0], edge[1]) for edge in in_edges] + sorted_children = [edge[0] for _, edge in sorted(zip(edge_weights, in_edges))] + edge_tuple = " cute::seq<" + edge_str = [str(subgraph_nodes.index(child)) for child in sorted_children] + edge_tuple += ", ".join(edge_str) + ">,\n" + + edge_tuples += edge_tuple + edge_tuples += " >" + + # Emit the node list + dag_nodes = "" + dag_node_strs = [] + for n in subgraph_nodes[:-1]: + n_meta = subgraph.get_node_meta(n) + if n_meta.disabled: + dag_node_strs.append(f" {self.get_visitor_name(n)}") + else: + dag_node_strs.append(f" {n_meta.name_camel}") + dag_nodes = ",\n".join(dag_node_strs) + + return f""" +using {node.name_camel} = cutlass::epilogue::{self.namespace}::Sm{self.evt_cc}TopologicalVisitor< + {DataTypeTag[node.subgraph.element_compute]}, + {edge_tuples}, +{dag_nodes} +>; +""" + + def emit_node(self, node): + if isinstance(node, TopoVisitorNode): + emission = "" + for node in node.subgraph.node_metas_topological_order(): + if not node.disabled: + emission += self.emit_node(node) + return emission + else: + return node.underlying_impl.type_decl diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm100_emitter.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm100_emitter.py new file mode 100644 index 0000000000000000000000000000000000000000..db521e5279c57734a8e408938dc6ea95a608c6d8 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm100_emitter.py @@ -0,0 +1,116 @@ +################################################################################################# +# +# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Emitter for Sm100 Epilogue Visitor +""" + +from cutlass_library import DataType, DataTypeTag, EpilogueScheduleTag, OpcodeClassTag +from cutlass_cppgen.backend.library import to_blackwell_threadblock_shape +from cutlass_cppgen.backend import GemmOperationUniversal +from cutlass_cppgen.backend.evt.backend.emitter_base import FusionCallbacks +from cutlass_cppgen.backend.evt.ir.node import TupleEmitter + + +class Sm100CollectiveEpilogue: + def __init__(self, tile_description, + kernel_schedule, + epilogue_schedule, + element_accumulator, + element_d, + fusion_callbacks) -> None: + + self.cta_tile_mnk, _ = to_blackwell_threadblock_shape(tile_description, tile_description.cluster_shape, kernel_schedule) + self.element_accumulator = element_accumulator + if fusion_callbacks.dag_ir.has_node("C"): + self.element_c = fusion_callbacks.dag_ir.get_node_meta("C").element + else: + self.element_c = DataType.void + self.element_d = element_d + self.schedule = epilogue_schedule + self.fusion_callbacks = fusion_callbacks + self.opclass = tile_description.math_instruction.opcode_class + + @property + def CtaTileMNK(self) -> str: + """ + The threadblock shape + """ + return f"cute::Shape<_{self.cta_tile_mnk[0]}, _{self.cta_tile_mnk[1]}, _{self.cta_tile_mnk[2]}>" + + @property + def EpilogueTileType(self) -> str: + """ + The epilogue tile type + """ + return "cutlass::epilogue::collective::EpilogueTileAuto" + + @property + def Schedule(self) -> str: + return EpilogueScheduleTag[self.schedule] + + def emit(self): + tuple_emitter = TupleEmitter("int64_t") + stride_D_str = self.fusion_callbacks.dag_ir.get_node_meta("D").underlying_impl.stride_mnl + stride_C_str = stride_D_str + if self.fusion_callbacks.dag_ir.has_node("C"): + stride_C_str = self.fusion_callbacks.dag_ir.get_node_meta("C").underlying_impl.stride_mnl + + callback_decl, callback_name = self.fusion_callbacks.emit() + return callback_name, f""" +using EpilogueDescriptor = cutlass::epilogue::collective::detail::Sm100EpilogueDescriptor< + {OpcodeClassTag[self.opclass]}, + {self.CtaTileMNK}, {self.EpilogueTileType}, + {DataTypeTag[self.element_accumulator]}, {DataTypeTag[self.element_c]}, {DataTypeTag[self.element_d]}, + {self.Schedule}, {stride_C_str}, {stride_D_str}, + false /* IsPerColScaleSupported */, + false /* IsBlockScaleSupported */ +>; +{callback_decl} +""" + + +class Sm100Emitter: + def __init__(self, operation: GemmOperationUniversal, graph) -> None: + fusion_callbacks = FusionCallbacks(graph, cc=100, emit_CD=False) + + self.collective_epilogue = Sm100CollectiveEpilogue( + tile_description=operation.tile_description, + kernel_schedule=operation.tile_description.kernel_schedule, + epilogue_schedule=operation.tile_description.epilogue_schedule, + element_accumulator=operation.tile_description.math_instruction.element_accumulator, + element_d=fusion_callbacks.dag_ir.get_node_meta("D").element, + fusion_callbacks=fusion_callbacks + ) + + def emit(self): + return self.collective_epilogue.emit() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm100_nodes.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm100_nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..33e77b4c9f2efbef808f8551e4402f5a6761ea4a --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm100_nodes.py @@ -0,0 +1,134 @@ +################################################################################################# +# +# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from pycute import product + +from cutlass_library import DataTypeSize, DataTypeTag + +from cutlass_cppgen.backend.evt.ir import AuxLoadImpl, AuxStoreImpl +import cutlass_cppgen.backend.evt.backend.sm90_nodes as sm90_nodes + +from cutlass_cppgen.backend.library import FloatRoundStyleTag + + +Sm100AccumulatorImpl = sm90_nodes.Sm90AccumulatorImpl +Sm100LoadSrcImpl = sm90_nodes.Sm90LoadSrcImpl +Sm100ScalarBroadcastImpl = sm90_nodes.Sm90ScalarBroadcastImpl +Sm100RowBroadcastImpl = sm90_nodes.Sm90RowBroadcastImpl +Sm100ColumnBroadcastImpl = sm90_nodes.Sm90ColumnBroadcastImpl +Sm100ComputeImpl = sm90_nodes.Sm90ComputeImpl +Sm100StoreDImpl = sm90_nodes.Sm90StoreDImpl +Sm100ColumnReductionImpl = sm90_nodes.Sm90ColumnReductionImpl +Sm100RowReductionImpl = sm90_nodes.Sm90RowReductionImpl +Sm100ScalarReductionImpl = sm90_nodes.Sm90ScalarReductionImpl + + +class Sm100AuxLoadImpl(AuxLoadImpl): + + @property + def descriptor(self) -> str: + """ + Descriptor for Aux Load + """ + return f"{self.name_camel}Descriptor" + + def decl_descriptor(self) -> str: + """ + Declare the descriptor type + """ + return f"\nusing {self.descriptor} = cutlass::epilogue::collective::detail::Sm100AuxLoadDescriptor;\n" + + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = self.decl_descriptor() + self._type_decl += f""" +using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxLoad< + {self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]}, + {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom, typename {self.descriptor}::CopyOpS2R +>; +""" + return self._type_decl + + def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles): + """ + Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d + """ + return (DataTypeSize[self.element] * stages_c * product(epilogue_tile_mn) // 8, 128) + + +class Sm100AuxStoreImpl(AuxStoreImpl): + + @property + def descriptor(self) -> str: + """ + Descriptor for Aux Load + """ + return f"{self.name_camel}Descriptor" + + def decl_descriptor(self) -> str: + """ + Declare the descriptor type + """ + return f""" +using {self.descriptor} = cutlass::epilogue::collective::detail::Sm100AuxStoreDescriptor< + EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]} +>; +""" + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = self.decl_descriptor() + self._type_decl += f""" +using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxStore< + {self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]}, + {FloatRoundStyleTag[self.round_style]}, {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom, + typename {self.descriptor}::CopyOpR2S +>; +""" + return self._type_decl + + def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles): + """ + Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d + """ + return (DataTypeSize[self.element] * stages_d * product(epilogue_tile_mn) // 8, 128) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm80_emitter.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm80_emitter.py new file mode 100644 index 0000000000000000000000000000000000000000..868453a7cf5049e5899bf6aef419485a1a5dbb43 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm80_emitter.py @@ -0,0 +1,47 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Emitter for Sm80 Epilogue Visitor +""" + +from cutlass_cppgen.backend.evt.backend.emitter_base import FusionCallbacks +from cutlass_cppgen.backend import GemmOperationUniversal + + +class Sm80Emitter: + def __init__(self, operation: GemmOperationUniversal, graph) -> None: + self.fusion_callbacks = FusionCallbacks(graph, cc=80) + + def emit(self): + callback_decl, callback_name = self.fusion_callbacks.emit() + return callback_name, callback_decl diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm80_nodes.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm80_nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..b9fc561354a471f4f97600b27e4dbb21950a9e79 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm80_nodes.py @@ -0,0 +1,258 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from cutlass_library import DataTypeSize, DataTypeTag + +from cutlass_cppgen.backend.evt.ir import ( + # Load Node + AccumulatorImpl, + AuxLoadImpl, + ColumnBroadcastImpl, + LoadNode, + LoadSrcImpl, + RowBroadcastImpl, + ScalarBroadcastImpl, + # Compute Node + ComputeImpl, + # Store Node + AuxStoreImpl, + ColumnReductionImpl, + RowReductionImpl, + ScalarReductionImpl +) + +from cutlass_cppgen.backend.library import ( + FloatRoundStyleTag, + FunctionalOp, + op_tag, +) + + +class Sm80AccumulatorImpl(AccumulatorImpl): + + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = f"""\nusing {self.name_camel} = cutlass::epilogue::threadblock::VisitorAccFetch;\n""" + return self._type_decl + + +class Sm80AuxLoadImpl(AuxLoadImpl): + + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = f""" +using {self.name_camel} = cutlass::epilogue::threadblock::VisitorAuxLoad< + OutputTileThreadMap, {DataTypeTag[self.element]}, {self.stride_mnl} +>; +""" + return self._type_decl + + +class Sm80LoadSrcImpl(Sm80AuxLoadImpl): + pass + + +class Sm80ScalarBroadcastImpl(ScalarBroadcastImpl): + def __init__(self, node: LoadNode) -> None: + super().__init__(node) + self.broadcast_count = 1 + self.reduction_fn = FunctionalOp.Multiplies + + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = f""" +using {self.name_camel} = cutlass::epilogue::threadblock::VisitorScalarBroadcast< + {DataTypeTag[self.element]}, {self.stride_mnl}, {self.broadcast_count}, {op_tag(self.reduction_fn)} +>; +""" + return self._type_decl + + +class Sm80RowBroadcastImpl(RowBroadcastImpl): + + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = f""" +using {self.name_camel} = cutlass::epilogue::threadblock::VisitorRowBroadcast< + OutputTileThreadMap, {DataTypeTag[self.element]}, + {self.stride_mnl} +>; +""" + return self._type_decl + + +class Sm80ColumnBroadcastImpl(ColumnBroadcastImpl): + + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = f""" +using {self.name_camel} = cutlass::epilogue::threadblock::VisitorColBroadcast< + OutputTileThreadMap, {DataTypeTag[self.element]}, + {self.stride_mnl} +>; +""" + return self._type_decl + + +class Sm80ComputeImpl(ComputeImpl): + + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = f""" +using {self.name_camel} = cutlass::epilogue::threadblock::VisitorCompute< + {op_tag(self.fn)}, {DataTypeTag[self.element_output]}, {DataTypeTag[self.element_compute]}, + {FloatRoundStyleTag[self.round_style]} +>; +""" + return self._type_decl + + +class Sm80AuxStoreImpl(AuxStoreImpl): + + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = f""" +using {self.name_camel} = cutlass::epilogue::threadblock::VisitorAuxStore< + OutputTileThreadMap, {DataTypeTag[self.element]}, {FloatRoundStyleTag[self.round_style]}, + {self.stride_mnl} +>; +""" + return self._type_decl + + +class Sm80StoreDImpl(Sm80AuxStoreImpl): + pass + + +class Sm80ColumnReductionImpl(ColumnReductionImpl): + + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = f""" +using {self.name_camel} = cutlass::epilogue::threadblock::VisitorColReduction< + {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)}, + OutputTileThreadMap, {DataTypeTag[self.element]}, + {DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]}, + {self.stride_mnl} +>; +""" + return self._type_decl + + +class Sm80RowReductionImpl(RowReductionImpl): + + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = f""" +using {self.name_camel} = cutlass::epilogue::threadblock::VisitorRowReduction< + {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)}, + OutputTileThreadMap, {DataTypeTag[self.element]}, + {DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]}, + {self.stride_mnl} +>; +""" + return self._type_decl + + +class Sm80ScalarReductionImpl(ScalarReductionImpl): + + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = f""" +using {self.name_camel} = cutlass::epilogue::threadblock::VisitorScalarReduction< + {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)}, + OutputTileThreadMap, {DataTypeTag[self.element]}, + {DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]}, + {self.stride_mnl} +>; +""" + return self._type_decl diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm90_emitter.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm90_emitter.py new file mode 100644 index 0000000000000000000000000000000000000000..3c058aa8f30a56d97ce3c3600f7c89189e7a15ad --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm90_emitter.py @@ -0,0 +1,98 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Emitter for Sm90 Epilogue Visitor +""" + +from cutlass_library import DataTypeTag, EpilogueScheduleTag +from cutlass_cppgen.backend import GemmOperationUniversal +from cutlass_cppgen.backend.evt.backend.emitter_base import FusionCallbacks + + +class CollectiveEpilogue: + def __init__(self, tile_description, + schedule, + element_c, + element_d, + fusion_callbacks) -> None: + + self.cta_tile_mnk = tile_description.threadblock_shape + self.element_c = element_c + self.element_d = element_d + self.schedule = schedule + self.fusion_callbacks = fusion_callbacks + + @property + def CtaTileMNK(self) -> str: + """ + The threadblock shape + """ + return f"cute::Shape<_{self.cta_tile_mnk[0]}, _{self.cta_tile_mnk[1]}, _{self.cta_tile_mnk[2]}>" + + @property + def EpilogueTileType(self) -> str: + """ + The epilogue tile type + """ + return "cutlass::epilogue::collective::EpilogueTileAuto" + + @property + def Schedule(self) -> str: + return EpilogueScheduleTag[self.schedule] + + def emit(self): + callback_decl, callback_name = self.fusion_callbacks.emit() + return callback_name, f""" +using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor< + {self.CtaTileMNK}, {self.EpilogueTileType}, + {DataTypeTag[self.element_c]}, {DataTypeTag[self.element_d]}, + {self.Schedule} +>; +{callback_decl} +""" + + +class Sm90Emitter: + def __init__(self, operation: GemmOperationUniversal, graph) -> None: + fusion_callbacks = FusionCallbacks(graph, cc=90, emit_CD=False) + + self.collective_epilogue = CollectiveEpilogue( + tile_description=operation.tile_description, + schedule=operation.tile_description.epilogue_schedule, + element_c=operation.C.element, + element_d=operation.C.element, + fusion_callbacks=fusion_callbacks + ) + + def emit(self): + return self.collective_epilogue.emit() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm90_nodes.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm90_nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..43601a424e3ecb175837fb31389436c1470d9c0b --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm90_nodes.py @@ -0,0 +1,329 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from pycute import product + +from cutlass_library import DataTypeSize, DataTypeTag +from cutlass_cppgen.backend.evt.ir import ( + # Load Node + AccumulatorImpl, + AuxLoadImpl, + ColumnBroadcastImpl, + LoadNode, + LoadSrcImpl, + RowBroadcastImpl, + ScalarBroadcastImpl, + # Compute Node + ComputeImpl, + ComputeNode, + # Store Node + AuxStoreImpl, + ColumnReductionImpl, + RowReductionImpl, + ScalarReductionImpl, + StoreNode, + StoreDImpl, +) +from cutlass_cppgen.backend.library import ( + FloatRoundStyleTag, + FunctionalOp, + op_tag, +) + + +class Sm90AccumulatorImpl(AccumulatorImpl): + + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = f"""\nusing {self.name_camel} = cutlass::epilogue::fusion::Sm90AccFetch;\n""" + return self._type_decl + + +class Sm90LoadSrcImpl(LoadSrcImpl): + + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = f""" +using ElementC = {DataTypeTag[self.element]}; +using StrideC = {self.stride_mnl}; +using {self.name_camel} = cutlass::epilogue::fusion::Sm90SrcFetch<{DataTypeTag[self.element]}>; +""" + return self._type_decl + + +class Sm90AuxLoadImpl(AuxLoadImpl): + + @property + def descriptor(self) -> str: + """ + Descriptor for Aux Load + """ + return f"{self.name_camel}Descriptor" + + def decl_descriptor(self) -> str: + """ + Declare the descriptor type + """ + return f"\nusing {self.descriptor} = cutlass::epilogue::collective::detail::AuxLoadDescriptor;\n" + + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = self.decl_descriptor() + self._type_decl += f""" +using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxLoad< + {self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]}, + {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom, typename {self.descriptor}::CopyOpS2R +>; +""" + return self._type_decl + + def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles): + """ + Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d + """ + return (DataTypeSize[self.element] * stages_c * product(epilogue_tile_mn) // 8, 128) + + +class Sm90ScalarBroadcastImpl(ScalarBroadcastImpl): + def __init__(self, node: LoadNode) -> None: + super().__init__(node) + self.broadcast_count = 1 + self.reduction_fn = FunctionalOp.Multiplies + + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = f""" +using {self.name_camel} = cutlass::epilogue::fusion::Sm90ScalarBroadcast< + {DataTypeTag[self.element]}, {self.stride_mnl}, {self.broadcast_count}, {op_tag(self.reduction_fn)} +>; +""" + return self._type_decl + + +class Sm90RowBroadcastImpl(RowBroadcastImpl): + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = f""" +using {self.name_camel} = cutlass::epilogue::fusion::Sm90RowBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]}, {DataTypeTag[self.element_output]}, + {self.stride_mnl} +>; +""" + return self._type_decl + + +class Sm90ColumnBroadcastImpl(ColumnBroadcastImpl): + + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = f""" +using {self.name_camel} = cutlass::epilogue::fusion::Sm90ColBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]}, {DataTypeTag[self.element_output]}, + {self.stride_mnl} +>; +""" + return self._type_decl + + +class Sm90ComputeImpl(ComputeImpl): + + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = f""" +using {self.name_camel} = cutlass::epilogue::fusion::Sm90Compute< + {op_tag(self.fn)}, {DataTypeTag[self.element_output]}, {DataTypeTag[self.element_compute]}, + {FloatRoundStyleTag[self.round_style]} +>; +""" + return self._type_decl + + +class Sm90AuxStoreImpl(AuxStoreImpl): + + @property + def descriptor(self) -> str: + """ + Descriptor for Aux Load + """ + return f"{self.name_camel}Descriptor" + + def decl_descriptor(self) -> str: + """ + Declare the descriptor type + """ + return f""" +using {self.descriptor} = cutlass::epilogue::collective::detail::AuxStoreDescriptor< + EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]} +>; +""" + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = self.decl_descriptor() + self._type_decl += f""" +using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxStore< + {self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]}, + {FloatRoundStyleTag[self.round_style]}, {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom, + typename {self.descriptor}::CopyOpR2S +>; +""" + return self._type_decl + + def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles): + """ + Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d + """ + return (DataTypeSize[self.element] * stages_d * product(epilogue_tile_mn) // 8, 128) + + +class Sm90StoreDImpl(StoreDImpl): + + @property + def type_decl(self): + """ + Return the string defining the type + """ + return f""" +using ElementD = {DataTypeTag[self.element]}; +using StrideD = {self.stride_mnl}; +""" + + +class Sm90ColumnReductionImpl(ColumnReductionImpl): + + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = f""" +using {self.name_camel} = cutlass::epilogue::fusion::Sm90ColReduction< + {op_tag(self.reg_reduce_fn)}, {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)}, 0, + typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]}, + {DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]}, + {self.stride_mnl} +>; +""" + return self._type_decl + + +class Sm90RowReductionImpl(RowReductionImpl): + + + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = f""" +using {self.name_camel} = cutlass::epilogue::fusion::Sm90RowReduction< + {op_tag(self.reg_reduce_fn)}, {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)}, 0 /* Stages */, + typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]}, + {DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]}, + {self.stride_mnl} +>; +""" + return self._type_decl + + +class Sm90ScalarReductionImpl(ScalarReductionImpl): + + + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = f""" +using {self.name_camel} = cutlass::epilogue::fusion::Sm90ScalarReduction< + {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)}, + {DataTypeTag[self.element]}, {DataTypeTag[self.element_compute]}, + {FloatRoundStyleTag[self.round_style]}, {self.stride_mnl} +>; +""" + return self._type_decl diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/epilogue.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/epilogue.py new file mode 100644 index 0000000000000000000000000000000000000000..da446e76d9ebd9de04950a89b2451480492147a9 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/epilogue.py @@ -0,0 +1,168 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Epilogue Visitor interface for compiling, and running visitor-based epilogue. +""" + +import ctypes + +from cutlass_cppgen.utils.lazy_import import lazy_import +cuda = lazy_import("cuda.cuda") +from cutlass_library import DataType +import numpy as np + +from cutlass_cppgen.backend.epilogue import EpilogueFunctorBase +import cutlass_cppgen.backend.evt.backend +from cutlass_cppgen.backend.frontend import TensorFrontend +from cutlass_cppgen.utils.datatypes import is_numpy_tensor +from cutlass_cppgen.backend.evt.passes.util import cc_map + + +class EpilogueFunctorVisitor(EpilogueFunctorBase): + """ + Apply an epilogue functor described by the epilogue EVT + + :param cc: compute capability + :param visitor_frontend: user-provide visitor frontend + + """ + def __init__(self, cc: int, visitor, element_compute=DataType.f32) -> None: + # Type of Emitter based on CC + self.emit_cls = getattr(cutlass_cppgen.backend.evt.backend, f"Sm{cc_map[cc]}Emitter") + + # Visitor Types + self.visitor = visitor + self.graph = visitor.dag_ir + + # Data types + self.element_epilogue = element_compute # element compute + self.element_output = self.graph.get_node_meta('D').underlying_impl.element + + # Epilogue Thread Type + epilogue_thread_type = self.visitor.epilogue_thread_type + if cc_map[cc] in [90, 100]: + self.arg_c_type = self.visitor.arg_c_type + self.arg_d_type = self.visitor.arg_d_type + output_names = self.visitor.return_names + reduction_names = self.visitor.reduction_names + + # Epilogue stages specialized for sm80 kernel + if cc == 80: + if hasattr(self.visitor, "epilogue_stages"): + self.epilogue_stages = self.visitor.epilogue_stages + assert self.epilogue_stages <= 2, "Only supports Stages <=2 in SM80 Epilogue" + + # Epilogue Argument Type + class _Arguments(ctypes.Structure): + """ + Concepts: + class _EpilogueArguments(ctypes.Structure): + _fields_ = [ + ("epilogue", _Arguments), <- this class + ("ptr_C", ctypes.c_void_p), + ("stride_C", StrideBatched_), + ("ptr_D", ctypes.c_void_p), + ("stride_D", StrideBatched_) + ] + """ + _fields_ = [ + ("output_op", epilogue_thread_type) + ] + + def __init__(self, kwargs: dict) -> None: + # The user-input kwargs is a dict of (name: tensors) + # We first convert all of them to device pointers + ptr_kwargs = {} + for key in kwargs.keys(): + is_output = key in output_names and key not in reduction_names + ptr_kwargs[key] = self.get_tensor_ptr(key, kwargs, is_output) + # Initialize the thread arguments + self.output_op = epilogue_thread_type(ptr_kwargs) + + def get_tensor_ptr(self, tensor_name, kwargs, is_output=False): + """ + Helper function for extracting device pointer + """ + # Skip the special tensors + if cc in [90, 100]: + if tensor_name in ["C", "D"]: + return 0 + if tensor_name not in kwargs.keys(): + raise ValueError(f"Tensor {tensor_name} is not provided.") + tensor = kwargs[tensor_name] + + # For float scalar constant, directly return the value + if isinstance(tensor, float): + return tensor + + # The tensor frontend returns a device buffer for np.ndarray + # and device ptr for other frontends + buffer_or_ptr = TensorFrontend.argument(tensor, is_output) + if is_numpy_tensor(tensor): + # Remember the host tensor for later synchronization + setattr(self, f"{tensor_name}_buffer", buffer_or_ptr) + setattr(self, f"{tensor_name}_host", tensor) + return int(buffer_or_ptr.ptr) + else: + return int(buffer_or_ptr) + + def sync(self): + """ + Synchronize the results from device to host + """ + for name in output_names: + if hasattr(self, f"{name}_host"): + host_tensor = getattr(self, f"{name}_host") + tensor_ptr = getattr(self, f"{name}_buffer").ptr + (err,) = cuda.cuMemcpyDtoH( + host_tensor, + tensor_ptr, + host_tensor.size * host_tensor.itemsize, + ) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError("CUDA Error %s" % str(err)) + + self.epilogue_type = _Arguments + + def emit(self, operation): + """ + Emit the C++ code + """ + emitter = self.emit_cls(operation, self.graph) + return emitter.emit() + + def get_smem_size(self, tile_description): + """ + Get the shared memory size in bytes + """ + return self.visitor.get_smem_size(tile_description) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/frontend/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/frontend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f2323278ed232adea205e41b901c62a268e56976 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/frontend/__init__.py @@ -0,0 +1,33 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from cutlass_cppgen.backend.evt.frontend.python_ast import PythonASTFrontend diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/frontend/frontend_base.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/frontend/frontend_base.py new file mode 100644 index 0000000000000000000000000000000000000000..213aafdbe3f922f22186e37ac9f2eefea74e71ce --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/frontend/frontend_base.py @@ -0,0 +1,272 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Base class for Python EVT Frontend +""" + +from typing import Union + +from cutlass_library import DataType +from cutlass_cppgen.backend.evt.ir import ( + ComputeNode, + DAGIR, + LayoutNode, + LoadNode, + StoreNode, +) +from cutlass_cppgen.backend.evt.passes import ( + EVTGraphDrawer, + EVTPassManager, + GetSmemSize, + PassDAG2Tree, + PassGetArgumentType, + PassGetImpl, + PassFixElementD, + PassLayoutManipulateElimination, + PassPreprocessRed, + PassShapeTypePropagation, +) +from cutlass_cppgen.backend.evt.passes.util import cc_map +from cutlass_cppgen.backend.utils import device_cc +from cutlass_cppgen.epilogue.evt_ops import permute, reshape +from cutlass_cppgen.utils.datatypes import library_type + + +class EVTFrontendBase: + layout_fns = { + "permute": permute, + "reshape": reshape + } + + def __init__(self, cc, element_compute=DataType.f32, additional_passes=[], **kwargs) -> None: + self.cc = cc + self.element_compute = library_type(element_compute) + self.dag_ir = DAGIR(self.cc, self.element_compute) + self.compute_cnt = 0 + self.layout_cnt = 0 + self.imm_cnt = 0 + + self.pass_manager = EVTPassManager( + self.dag_ir, + [ + PassPreprocessRed, + PassGetArgumentType, + PassShapeTypePropagation, + PassLayoutManipulateElimination, + PassGetImpl, + PassDAG2Tree, + PassFixElementD + ] + additional_passes) + + if self.cc == 80: + self._epilogue_stages = 1 + else: + self._epilogue_stages = None + + @property + def epilogue_stages(self): + return self._epilogue_stages + + @epilogue_stages.setter + def epilogue_stages(self, stages): + self._epilogue_stages = stages + + + def parse(self, *args, **kwargs): + raise NotImplementedError(f"The 'parse' function must be overloaded in frontend class") + + def trace(self, *args, **kwargs): + # Parse the input + self.parse(*args, **kwargs) + + # Verify the DAG IR to ensure that "D" is the output node with out_degree = 0 + if (self.cc >= 90): + if (self.dag_ir.out_degree("D") != 0): + raise RuntimeError( + f"On SM90 or higher, D is expected to be a output node with 0 users to " + f"enable smem reuse between C and D, but got {self.dag_ir.out_degree('D')}") + + # Run the passes + self.pass_manager() + # Set the epilogue type + self.epilogue_thread_type = self.dag_ir.epilogue_thread_type + if cc_map[self.cc] in [90, 100]: + self.arg_c_type = self.dag_ir.arg_c_type + self.arg_d_type = self.dag_ir.arg_d_type + self.reduction_names = self.dag_ir.reduction_names + + # + # Helper functions for DAG IR manipulation + # + + def add_node(self, node): + self.dag_ir.add_node(node) + + def add_edge(self, src, tgt, weight=0): + self.dag_ir.add_edge(src, tgt, weight=weight) + + def set_tensor(self, node_name, example): + """ + Add an example tensor to node {node_name} in the DAG IR + """ + meta = self.dag_ir.get_node_meta(node_name) + meta.tensor = {"tensor": example} + + def set_store_tensor(self, node_name, example): + """ + Add an example tensor to node {node_name} in the DAG IR + """ + meta = self.dag_ir.get_node_meta(node_name) + meta.store_tensor = {"tensor": example} + + def mark_output(self, node_name): + """ + Mark a store node as output + """ + meta = self.dag_ir.get_node_meta(node_name) + if not isinstance(meta, StoreNode): + raise ValueError( + f"Only StoreNodes can be marked as output. " + f"Got {type(meta).__name__}: {node_name}") + meta.is_output = True + + # Add node with specific type + + def add_load_node(self, name, example): + """ + Add a Load node to DAG IR + :param name: name of the loaded variable + :type name: str + :param example: example input + :type example: np.ndarray|torch.Tensor|cupy.ndarray|float + """ + if name is None: + raise ValueError(f"Name is not provided.") + if example is None: + raise ValueError(f"Example input for {name} is not provided.") + load_node = LoadNode(name) + load_node.tensor = {"tensor": example} + # Special logics for accumulator + if name == "accum": + if load_node.tensor.rank == 2: + new_shape = tuple([1, ] + list(load_node.tensor.shape)) + load_node.tensor.broadcast(new_shape) + elif load_node.tensor.rank < 2 or load_node.tensor.rank > 3: + raise ValueError(f"Expect example inputs for 'accum' be a rank-2 or rank-3 tensor. Got {load_node.tensor.shape}.") + self.add_node(load_node) + + def add_imm(self, value: Union[float,int]): + """ + Add an immediate scalar value to DAG IR + :param value: the value of the immediate scalar + :type value: float + """ + try: + value = float(value) + except: + raise ValueError(f"{type(value).__name__} cannot be converted to float.") + + name = f"imm_{value}_k{self.imm_cnt}".replace('.', '_') + self.imm_cnt += 1 + load_node = LoadNode(name) + load_node.tensor = {"tensor": value, "is_constant": True} + self.add_node(load_node) + return name + + def add_compute_node(self, op, name=None): + """ + Add a compute node. + :param op: the computation op + :param name: the node name (optional) + :type name: str + :return: the name of the compute node + """ + if name is None: + name = f"compute_{self.compute_cnt}" + self.compute_cnt += 1 + compute_node = ComputeNode( + name=name, fn=op, + element_output=self.element_compute, + element_compute=self.element_compute) + self.add_node(compute_node) + return compute_node.name + + def add_layout_node(self, op, kwargs, name=None): + """ + Add a layout node. + :param op: the layout op + :type op: evt_ops + :param name: the node name (optional) + :type name: str + :return: the name of the layout node + """ + if name is None: + name = f"layout_{self.layout_cnt}" + self.layout_cnt += 1 + layout_node = LayoutNode(name=name, fn=op, kwargs=kwargs) + self.add_node(layout_node) + return layout_node.name + + def add_store_node(self, name): + store_node = StoreNode(name) + self.add_node(store_node) + + # + # Visualization The DAG IR + # + + def visualize(self, name="dag_ir"): + """ + Visualize the dag ir with svg file + :param name: the name of the graph + """ + drawer = EVTGraphDrawer(self.dag_ir, name) + try: + for name, graph in drawer.get_dot_graph(): + graph.write_svg(f"./{name}.svg") + except: + raise RuntimeError( + "'dot' is not found in path. GraphDrawer is disabled. " + "Please install it with 'sudo apt-get install graphviz'." + ) + + # + # Get shared memory size + # + + def get_smem_size(self, tile_description): + """ + Get the shared memory size of the epilogue + """ + smem_size = GetSmemSize(self.dag_ir)(tile_description) + return smem_size diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/frontend/python_ast.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/frontend/python_ast.py new file mode 100644 index 0000000000000000000000000000000000000000..8727b754cd2b9a557d45760cb0a24a43619a373f --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/frontend/python_ast.py @@ -0,0 +1,194 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Python AST frontend that parses input into DAG IR +""" + +import ast +import inspect +import textwrap + +from cutlass_library import DataType + +import cutlass_cppgen +from cutlass_cppgen.backend.evt.frontend.frontend_base import EVTFrontendBase +from cutlass_cppgen.backend.epilogue import identity, relu, tanh, sigmoid, silu, hardswish, gelu +from cutlass_cppgen.backend.library import FunctionalOp + + +class PythonASTFrontend(EVTFrontendBase, ast.NodeVisitor): + def __init__(self, cc, element_compute=DataType.f32, **kwargs): + super().__init__(cc, element_compute, **kwargs) + # Flags + # If this state is True, visit_Constant returns values without creating imm node + self.no_imm = False + self.visiting_return = False + + def parse(self, example_inputs): + self.example_inputs = example_inputs + self.source = textwrap.dedent(inspect.getsource(self.__call__)) + self.ast = ast.parse(self.source) + self.visit(self.ast) + + # + # Helper functions + # + @staticmethod + def ast_op_to_bindings(op): + mapping = { + ast.Add: FunctionalOp.Plus, + ast.Sub: FunctionalOp.Minus, + ast.Mult: FunctionalOp.Multiplies, + ast.Div: FunctionalOp.Divides, + "maximum": FunctionalOp.Maximum, + "minimum": FunctionalOp.Minimum, + "identity": identity.binding_type, + "relu": relu.binding_type, + "tanh": tanh.binding_type, + "sigmoid": sigmoid.binding_type, + "silu": silu.binding_type, + "hardswish": hardswish.binding_type, + "gelu": gelu.binding_type, + "multiply_add": FunctionalOp.MultiplyAdd, + "sum": (FunctionalOp.Plus, FunctionalOp.AtomicAdd), + "max": (FunctionalOp.Maximum, FunctionalOp.AtomicMaximum), + "exp": FunctionalOp.Exp + } + return mapping[op] + + # + # Visiting different node types + # + + def visit_FunctionDef(self, node: ast.FunctionDef): + # Visit args and register load nodes + for arg in node.args.args: + self.visit(arg) + for expr in node.body: + self.visit(expr) + + def visit_arg(self, node: ast.arg): + # Name of the argument + name = node.arg + try: + example_tensor = self.example_inputs[name] + except: + raise RuntimeError(f"Example input for {name} is not provided.") + + self.add_load_node(name, example_tensor) + + def visit_Name(self, node: ast.Name): + return node.id + + def visit_Constant(self, node: ast.Constant): + if self.no_imm: + return node.value + else: + name = self.add_imm(node.value) + return name + + def visit_Tuple(self, node: ast.Tuple): + results = [] + for elt in node.elts: + results.append(self.visit(elt)) + return tuple(results) + + def visit_keyword(self, node: ast.keyword): + return {node.arg: self.visit(node.value)} + + def visit_BinOp(self, node: ast.BinOp): + if self.visiting_return: + raise SyntaxError("Return value cannot be an expression") + lhs = self.visit(node.left) + rhs = self.visit(node.right) + op = self.ast_op_to_bindings(type(node.op)) + name = self.add_compute_node(op) + + # Add edges + # The edge weights are used to sort the input args + self.add_edge(lhs, name, weight=0) + self.add_edge(rhs, name, weight=1) + return name + + def visit_Assign(self, node: ast.BinOp): + target = self.visit(node.targets[0]) + value = self.visit(node.value) + # Create the assign node + self.add_store_node(target) + + # Add edges + self.add_edge(value, target) + return target + + def visit_Call(self, node: ast.Call): + if self.visiting_return: + raise SyntaxError("Return value cannot be an expression") + func = self.visit(node.func) + args = [self.visit(arg) for arg in node.args] + + if func in self.layout_fns.keys(): + # Parse kwargs + # By default, visiting imm automatically creates a load node + # However, in function call, keyword args are used to set + # specific function attributes such as indices for permute + # So no_imm is set to True temporarily + self.no_imm = True + kwargs = {} + for kw in node.keywords: + kwargs.update(self.visit(kw)) + self.no_imm = False + op = self.layout_fns[func] + name = self.add_layout_node(op, kwargs) + else: + op = self.ast_op_to_bindings(func) + name = self.add_compute_node(op) + + # Add edges + for idx, arg in enumerate(args): + self.add_edge(arg, name, weight=idx) + return name + + def visit_Return(self, node: ast.Return): + self.visiting_return = True + results = self.visit(node.value) + self.visiting_return = False + self.return_names = results + if not isinstance(results, tuple): + results = (results,) + for rst in results: + try: + example_tensor = self.example_inputs[rst] + except: + raise RuntimeError(f"Example input for {rst} is not provided.") + self.set_store_tensor(rst, example_tensor) + self.mark_output(rst) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0f9e3f811a020164dc5ec5eb4a8dfaf3dc5728fe --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/__init__.py @@ -0,0 +1,53 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from cutlass_cppgen.backend.evt.ir.compute_nodes import ComputeNode, ComputeImpl +from cutlass_cppgen.backend.evt.ir.dag_ir import DAGIR +from cutlass_cppgen.backend.evt.ir.layout_nodes import LayoutNode +from cutlass_cppgen.backend.evt.ir.load_nodes import ( + LoadNode, + AccumulatorImpl, + LoadSrcImpl, + AuxLoadImpl, + RowBroadcastImpl, + ColumnBroadcastImpl, + ScalarBroadcastImpl +) +from cutlass_cppgen.backend.evt.ir.node import TopoVisitorNode, NoOpImpl +from cutlass_cppgen.backend.evt.ir.store_nodes import ( + StoreNode, + StoreDImpl, + AuxStoreImpl, + ColumnReductionImpl, + RowReductionImpl, + ScalarReductionImpl +) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/compute_nodes.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/compute_nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..02b05358648694dcf2a5afd7117e6fca6a2d136c --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/compute_nodes.py @@ -0,0 +1,91 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Python registration for compute nodes in EVT +""" + +from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase +from cutlass_cppgen.backend.library import FloatRoundStyle + + +class ComputeImplBase(ImplBase): + """ + Base class for compute implementation + """ + def __init__(self, node) -> None: + super().__init__(node) + + +class ComputeImpl(ComputeImplBase): + """ + Implementation for Compute Node + """ + def __init__(self, node) -> None: + super().__init__(node) + + self.fn = node.fn + self.element_output = node.element_output + self.element_compute = node.element_compute + self.round_style = node.round_style + + @staticmethod + def match(node, problem_size: tuple): + return True + + +class ComputeNode(NodeBase): + """ + Compute Node in DAG IR + """ + possible_impls = [ + ComputeImpl + ] + def __init__( + self, name: str, fn, element_output, + element_compute, + round_style=FloatRoundStyle.ToNearest) -> None: + super().__init__(name) + self.op = "compute" + self.fn = fn + self.element_compute = element_compute + self.round_style = round_style + + def type_propagation(self, *args, **kwargs): + """ + Load node loads tensor under type `tensor.element` and returns an array of type `tensor.element`. + """ + self.element = self.element_compute + # In general, the compute nodes have element_output = element_compute + # In certain cases like producer of D it is overwritten by other passes + if not hasattr(self, "element_output"): + self.element_output = self.element diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/dag_ir.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/dag_ir.py new file mode 100644 index 0000000000000000000000000000000000000000..e7e9f75a9727306d56c049bd491a95542a68bec8 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/dag_ir.py @@ -0,0 +1,254 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +DAG IR used by Python EVT +""" + +import networkx as nx + +from cutlass_library import DataType + +from cutlass_cppgen.backend.evt.ir.compute_nodes import ComputeNode +from cutlass_cppgen.backend.evt.ir.node import NodeBase +from cutlass_cppgen.backend.library import ActivationOp +from cutlass_cppgen.backend.utils import device_cc + + +class DAGIR: + """ + ``DAGIR`` is the main data structure used in the EVT Intermediate Representation. + It consists of a series of ``Node`` s, each representing epilogue visitor nodes. + + In the DAGIR, ``node`` is an string of its name. ``node_meta`` is the underlying class of the node + """ + def __init__(self, cc, element_compute=DataType.f32) -> None: + # The EVT DAGIR is managed through the nextworkX Digraph class + self._graph = nx.DiGraph() + + self.element_compute = element_compute + + self.reduction_names = [] + + self.cc = cc + + self.identity_counter = 0 + + # + # IR manipulator + # + + def add_node(self, meta: NodeBase): + """ + Add a node to dag ir + """ + if self.has_node(meta.name): + raise SyntaxError(f"Variable '{meta.name}' cannot be defined twice.") + self._graph.add_node(meta.name, meta=meta) + + def add_edge(self, src: str, dst: str, weight: int=0): + """ + Add an edge src -> dst to dag ir with weight + """ + if not self.has_node(src): + raise SyntaxError(f"Variable '{src}' is undefined.") + if not self.has_node(dst): + raise SyntaxError(f"Variable '{dst}' is undefined.") + + if self._graph.has_edge(src, dst): + # The DiGraph doesn't support multiple edges between two nodes + # We insert an identity node in such case as a workaround + identity_name = f"autogen_identity_{self.identity_counter}" + self.identity_counter += 1 + compute_node = ComputeNode( + name=identity_name, fn=ActivationOp.Identity, + element_output=self.element_compute, + element_compute=self.element_compute) + self.add_node(compute_node) + self.add_edge(src, identity_name, 0) + self.add_edge(identity_name, dst, weight) + else: + self._graph.add_edge(src, dst, weight=weight) + + def remove_node(self, node: str): + """ + Remove node from dag ir + """ + self._graph.remove_node(node) + + def remove_edge(self, src: str, dst: str): + """ + Remove edge src -> dst + """ + self._graph.remove_edge(src, dst) + + # + # Helper functions for getting attrs + # + + def has_node(self, node: str) -> bool: + """ + Check if the node is in the graph + """ + return self._graph.has_node(node) + + def in_degree(self, node: str): + """ + Get the input degree of node + """ + return self._graph.in_degree(node) + + def in_edges(self, node: str): + """ + Get the input edges of node + """ + return [edge for edge in self._graph.in_edges(node)] + + def out_degree(self, node: str): + """ + Get the output degree of node + """ + return self._graph.out_degree(node) + + def out_edges(self, node: str): + """ + Get the output edges of node + """ + return [edge for edge in self._graph.out_edges(node)] + + def get_node_meta(self, node: str): + """ + Get the meta data of the node + """ + return self._graph.nodes[node]["meta"] + + def get_edge_weight(self, src, dst): + """ + Get the edge weight of edge src->dst + """ + return self._graph.get_edge_data(src, dst)["weight"] + + # + # High-level helper functions + # + + def all_reachable_nodes(self, node: str): + """ + Get all the nodes reachable from the current node (exclude) + """ + return list(nx.dfs_preorder_nodes(self._graph, source=node)) + + def get_users(self, node: str): + """ + Get all users of the current node + """ + return [edge[1] for edge in self.out_edges(node)] + + def get_all_inputs(self, node: str): + """ + Get all the input nodes sorted by edge weight + """ + in_edges = self.in_edges(node) + edge_weights = [self.get_edge_weight(*edge) for edge in in_edges] + return [edge[0] for _, edge in sorted(zip(edge_weights, in_edges))] + + def get_all_inputs_meta(self, node: str): + """ + Get all the input node metas sorted by edge weight + """ + return [self.get_node_meta(input_node) for input_node in self.get_all_inputs(node)] + + def replace_all_uses_with(self, node1, node2): + """ + Replace all uses of node1 with node2 + """ + for edge in self.out_edges(node1): + weight = self.get_edge_weight(*edge) + user = edge[1] + self.add_edge(node2, user, weight) + self.remove_edge(node1, user) + self.remove_node(node1) + + # + # Node accessor + # + def nodes_topological_order(self): + """ + Get the nodes in the unique lexicographical topological order + It generates a unique ordering of nodes by first sorting topologically + and then additionally by sorting lexicographically. + + Although topological_sort alone also works, this generates a unique key + for each epilogue visitor pattern and ensures the compilation cache can be reused. + :return: list[str] + """ + return list(nx.lexicographical_topological_sort(self._graph)) + + def node_metas_topological_order(self): + """ + Get the node metas in topological order + :return: list[NodeBase] + """ + return [self.get_node_meta(node) for node in self.nodes_topological_order()] + + @property + def nodes(self): + """ + Get all nodes + :return: list[str] + """ + return list(self._graph.nodes) + + @property + def nodes_meta(self): + """ + Get all node metas + :return: list[NodeBase] + """ + return [data[1]['meta'] for data in self._graph.nodes.data()] + + @property + def edges(self): + """ + Get all edges + :return: list[(str, str)] + """ + return list(self._graph.edges) + + # + # Path + # + def has_path(self, src: str, target: str) -> bool: + """ + Return True is a path exists from src to target + """ + return nx.has_path(self._graph, src, target) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/layout_algorithm.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/layout_algorithm.py new file mode 100644 index 0000000000000000000000000000000000000000..9d453b1f4c41d002297c5348cbed8fd7f0ef3081 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/layout_algorithm.py @@ -0,0 +1,324 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Layout algebras +""" + +from pycute import Layout, composition, make_layout, flatten, product + + +def _infer_split(old_shape, new_shape): + old_shape = _tuple_to_list(old_shape) + new_shape = _tuple_to_list(new_shape) + if len(old_shape) == 0 and len(new_shape) == 0: + return [] + if len(old_shape) == 0: + if product(tuple(new_shape)) != 1: + raise ValueError("Invalid reshape size") + else: + return new_shape + if len(new_shape) == 0: + if product(tuple(old_shape)) != 1: + raise ValueError("Invalid reshape size") + else: + return old_shape + # This is done recursively by only process the last dimension at each time + old_dim = old_shape[-1] + new_dim = new_shape[-1] + # Exact match + if old_dim == new_dim: + return _infer_split(old_shape[:-1], new_shape[:-1]) + [new_dim,] + # Needs split + if old_dim > new_dim and old_dim % new_dim == 0: + residual = old_dim // new_dim + return _infer_split(old_shape[:-1] + [residual,], new_shape[:-1]) + [new_dim,] + # Needs merge + if old_dim < new_dim and new_dim % old_dim == 0: + residual = new_dim // old_dim + return _infer_split(old_shape[:-1], new_shape[:-1] + [residual,]) + [old_dim,] + + raise NotImplementedError(f"Unsupported split: {old_shape} -> {new_shape}") + +def _infer_merge(flatten_shape, shape): + flatten_shape = _tuple_to_list(flatten_shape) + shape = _tuple_to_list(shape) + idx_flat = 0 + merged_shape = [] + for dim in shape: + # Exact match + if dim == flatten_shape[idx_flat]: + merged_shape.append(dim) + idx_flat += 1 + # Need group + elif dim > flatten_shape[idx_flat] and dim % flatten_shape[idx_flat] == 0: + residual = dim + group = [] + while(residual > 1): + group.append(flatten_shape[idx_flat]) + residual = residual // flatten_shape[idx_flat] + idx_flat += 1 + merged_shape.append(group) + else: + raise NotImplementedError(f"Unsupported merge: {flatten_shape} -> {shape}") + + return merged_shape + +def _list_to_tuple(nested_list): + if isinstance(nested_list, list) or isinstance(nested_list, tuple): + return tuple(_list_to_tuple(item) for item in nested_list) + return nested_list + +def _tuple_to_list(nested_tuple): + if isinstance(nested_tuple, list) or isinstance(nested_tuple, tuple): + return list(_tuple_to_list(item) for item in nested_tuple) + return nested_tuple + +def _reverse_tuple(nested_tuple: tuple): + if isinstance(nested_tuple, tuple): + return tuple([_reverse_tuple(item) for item in nested_tuple][::-1]) + return nested_tuple + +def _get_first_lhs_nonzero_stride(stride_list, idx): + for i in reversed(range(idx)): + if stride_list[i] != 0: + return i + else: + return None + +def _get_first_rhs_nonzero_stride(stride_list, idx): + for i in range(idx+1, len(stride_list)): + if stride_list[i] != 0: + return i + else: + return None + +def reshape(layout, new_shape): + """ + General reshape of input layout. + It takes two steps: + 1. split the dimensions of the old layout + 2. merge the splitted dimensions according to the new shape + """ + # + # Step 1: Split the dimensions of the old layout + # + # 1.1 Flat old and new shape + old_flatten_shape = list(flatten(layout.shape)) + new_flatten_shape = list(flatten(new_shape)) + + # 1.2 Infer the flatten splitted shape + splitted_flatten_shape = _infer_split(old_flatten_shape, new_flatten_shape) + + # 1.3 Unflat the splitted shape based on the old shape + splited_shape = _infer_merge(splitted_flatten_shape, old_flatten_shape) + + # 1.4 Infer the type of each split + # If the split type is in row-major (R), the dimension list is reversed because + # the cute::composition only support column-major split + split_type = [] # the type of each split (ColumnMajor or RowMajor) + permuted_splitted_shape = [] + old_flatten_stride = list(flatten(layout.stride)) + for idx, dim in enumerate(splited_shape): + if not isinstance(dim, list): + permuted_splitted_shape.append(dim) + split_type.append("C") + else: + lhs_stride = _get_first_lhs_nonzero_stride(old_flatten_stride, idx) + rhs_stride = _get_first_rhs_nonzero_stride(old_flatten_stride, idx) + # Special case for single tuple + # Use column-major by default + if lhs_stride is None and rhs_stride is None: + permuted_splitted_shape.append(dim) + split_type.append("C") + else: + if lhs_stride is not None and rhs_stride is not None: + # We consider shape[idx]:stride[idx] + # Case 1: stride[idx - 1] <= stride[idx] <= stride[idx + 1]: column major + if lhs_stride <= old_flatten_stride[idx] and old_flatten_stride[idx] <= rhs_stride: + permuted_splitted_shape.append(dim) + split_type.append("C") + # Case 2: stride[idx - 1] > stride[idx] > stride[idx + 1]: row major + elif lhs_stride > old_flatten_stride[idx] and old_flatten_stride[idx] > rhs_stride: + permuted_splitted_shape.append([d for d in reversed(dim)]) + split_type.append("R") + # Case 3: stride[idx - 1] <= stride[idx] > stride[idx + 1]: concave + elif lhs_stride <= old_flatten_stride[idx] and old_flatten_stride[idx] > rhs_stride: + if lhs_stride >= rhs_stride: + permuted_splitted_shape.append(dim) + split_type.append("C") + else: + permuted_splitted_shape.append([d for d in reversed(dim)]) + split_type.append("R") + # Case 4: stride[idx - 1] > stride[idx] <= stride[idx + 1]: concave + elif lhs_stride > old_flatten_stride[idx] and old_flatten_stride[idx] <= rhs_stride: + if lhs_stride >= rhs_stride: + permuted_splitted_shape.append(dim) + split_type.append("C") + else: + permuted_splitted_shape.append([d for d in reversed(dim)]) + split_type.append("R") + else: + raise NotImplementedError() + elif lhs_stride is None: + # Case 1: dim's stride < dim+1's stride, expand in column major + if old_flatten_stride[idx] > rhs_stride: + permuted_splitted_shape.append([d for d in reversed(dim)]) + split_type.append("R") + else: + permuted_splitted_shape.append(dim) + split_type.append("C") + else: + # Case 1: dim's stride > dim-1's stride + if old_flatten_stride[idx] < lhs_stride: + permuted_splitted_shape.append([d for d in reversed(dim)]) + split_type.append("R") + else: + permuted_splitted_shape.append(dim) + split_type.append("C") + + # 1.4 Generate the splitted layout + permuted_splitted_layout = composition(layout, Layout(_list_to_tuple(permuted_splitted_shape))) + + # 1.5 Reverse the permutation in 1.4 before merge + splitted_shape = [] + splitted_stride = [] + for shape_dim, stride_dim, type in zip( + permuted_splitted_layout.shape, + permuted_splitted_layout.stride, + split_type): + if type == "C": + splitted_shape.append(shape_dim) + splitted_stride.append(stride_dim) + else: + splitted_shape.append(tuple([d for d in reversed(shape_dim)])) + splitted_stride.append(tuple([d for d in reversed(stride_dim)])) + splitted_layout = Layout(tuple(splitted_shape), tuple(splitted_stride)) + + + # + # Step 2: Merge the splitted dimensions according to the new shape + # + # 2.1 Merge layout + merged_layout = composition(splitted_layout, Layout(new_shape)) + + # 2.2 Cleaning up + output_layout = composition(merged_layout, Layout(new_shape)) + return output_layout + + +def permutation(layout, permutation): + """ + Permute the layout + """ + new_shape = tuple([layout.shape[idx] for idx in permutation]) + new_stride = tuple([layout.stride[idx] for idx in permutation]) + return Layout(new_shape, new_stride) + + +def _broadcast(layout, new_shape): + if len(layout) == 1 and isinstance(new_shape, int): + old_dim = layout.shape + old_stride = layout.stride + new_dim = new_shape + if old_dim == new_dim: + return Layout(old_dim, old_stride) + elif old_dim == 1: + return Layout(new_dim, 0) + else: + raise NotImplementedError(f"Invalid Broadcast: {old_dim} -> {new_dim}") + + # Align the dimensions + old_shape = layout.shape + if isinstance(old_shape, int): + old_shape = (old_shape,) + sub_layouts = [layout,] + else: + sub_layouts = [sub_layout for sub_layout in layout] + rhs_broadcast_layouts = [Layout(1, 0)] * (len(new_shape) - len(old_shape)) + # Get the broadcasted layout + broadcast_layouts = [] + try: + layout = make_layout(*sub_layouts, *rhs_broadcast_layouts) + broadcast_layouts = [] + for idx, sub_layout in enumerate(layout): + broadcast_layouts.append(_broadcast(sub_layout, new_shape[idx])) + except NotImplementedError: + layout = make_layout(*rhs_broadcast_layouts, *sub_layouts) + for idx, sub_layout in enumerate(layout): + broadcast_layouts.append(_broadcast(sub_layout, new_shape[idx])) + return make_layout(*broadcast_layouts) + + +def broadcast(layout, new_shape): + """ + Broadcast the new layout based on the input shape + The broadcasted shape equals to the new shape + The stride of broadcasted dimensions are 0 + """ + return _broadcast(layout, new_shape) + + +def debroadcast(layout, dims): + """ + Squeeze the 0-stride + """ + for dim in dims: + if layout.stride[dim] != 0: + raise ValueError(f"Dim{dim} cannot be debroadcasted as it has stride {layout.stride[dim]}") + new_shape = tuple([s for idx, s in enumerate(layout.shape) if idx not in dims]) + new_stride = tuple([s for idx, s in enumerate(layout.stride) if idx not in dims]) + return Layout(new_shape, new_stride) + + +def canonicalization_(shapes, strides): + if isinstance(shapes, tuple): + c_shapes = [] + c_strides = [] + for shape, stride in zip(shapes, strides): + c_shape, c_stride = canonicalization_(shape, stride) + c_shapes.append(c_shape) + c_strides.append(c_stride) + return tuple(c_shapes), tuple(c_strides) + else: + if shapes == 1: + return 1, 0 + else: + return shapes, strides + +def canonicalization(layout): + """ + Canonicalize the input layout + 1. set the stride of shape "1" to 0 + """ + new_shape, new_stride = canonicalization_(layout.shape, layout.stride) + return Layout(new_shape, new_stride) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/layout_nodes.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/layout_nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..1095e2ab1d956399b5e27ddaf140e53d9918ec26 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/layout_nodes.py @@ -0,0 +1,336 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Layout manipulation nodes and implementations + +The layout Nodes change the layout of intermediate nodes in epilogue visitor graph +""" + +from copy import deepcopy + +from cutlass_library import LayoutType +from pycute import product, flatten + +import cutlass_cppgen +from cutlass_cppgen.backend.evt.ir.layout_algorithm import _list_to_tuple, _tuple_to_list +from cutlass_cppgen.backend.evt.ir.node import NodeBase +from cutlass_cppgen.backend.evt.ir.tensor import Tensor + + +class PermutationImpl: + """ + Detailed implementation and helper functions for permutation + """ + def __init__(self, node) -> None: + assert "indices" in node.kwargs.keys() + self.indices = list(node.kwargs["indices"]) + self.inverse_indices = self.get_inverse_indices(self.indices) + + def get_inverse_impl(self): + inverse_impl = deepcopy(self) + inverse_impl.indices = self.inverse_indices + inverse_impl.inverse_indices = self.indices + return inverse_impl + + def update(self, shape): + num_dim = len(shape) + indices = self.indices + num_old_dim = len(indices) + # Add offset + for i, idx in enumerate(indices): + indices[i] = idx + num_dim - num_old_dim + # Add broadcast dims + for i in range(num_dim - num_old_dim): + indices = [i,] + indices + + self.indices = indices + self.inverse_indices = self.get_inverse_indices(self.indices) + + def get_inverse_indices(self, indices): + """ + Get the indices for inverse permutation + """ + num_dim = len(indices) + inverse_indices = [0] * num_dim + for i in range(num_dim): + inverse_indices[indices[i]] = i + return inverse_indices + + def shape_propagation(self, input_node_meta): + input_shape = input_node_meta.tensor.shape + output_shape = tuple([input_shape[idx] for idx in self.indices]) + return output_shape + + def broadcast(self, shape, node_meta: NodeBase): + """ + Broadcast the inputs based on current shape + """ + self.update(shape) + inverse_shape = tuple([shape[idx] for idx in self.inverse_indices]) + node_meta.tensor.broadcast(inverse_shape) + + def apply_to_user(self, usr_meta: NodeBase): + """ + Propagate the permutation to the users of the current nodes + """ + usr_meta.tensor.permute(self.inverse_indices) + if hasattr(usr_meta, "store_tensor"): + if usr_meta.store_tensor is not None: + usr_meta.store_tensor.permute(self.inverse_indices) + + def apply_to_input(self, input_meta: NodeBase): + """ + Propagate the permutation to inputs of the current nodes + """ + input_meta.tensor.permute(self.indices) + if hasattr(input_meta, "store_tensor"): + if input_meta.store_tensor is not None: + input_meta.store_tensor.permute(self.indices) + + +class ReshapeImpl: + """ + Detailed implementation and helper functions for reshape + """ + def __init__(self, node) -> None: + self.node = node + assert "new_shape" in node.kwargs.keys() + self.output_shape = _list_to_tuple(node.kwargs["new_shape"]) + + def get_inverse_impl(self): + inverse_impl = deepcopy(self) + inverse_impl.output_shape = self.input_shape + inverse_impl.input_shape = self.output_shape + return inverse_impl + + def shape_propagation(self, input_node_meta): + self.input_shape = input_node_meta.tensor.shape + return _list_to_tuple(self.output_shape) + + def broadcast(self, shape, node_meta: NodeBase): + """ + Broadcast the inputs based on current shape. + """ + # Step 1: infer split + flatten_split_shape = self.infer_split(flatten(self.input_shape), flatten(self.output_shape)) + split_input_shape = self.infer_merge(flatten_split_shape, self.input_shape) + split_output_shape = self.infer_merge(flatten_split_shape, self.output_shape) + + # broadcast shape -> split_output_shape -> flatten_split_shape + if len(shape) - len(split_output_shape) > 0: + for _ in range(len(shape) - len(split_output_shape)): + split_output_shape = [1,] + split_output_shape + flatten_split_shape = [1,] + flatten_split_shape + split_input_shape = [1,] + split_input_shape + broadcast_factor = [] + for dim, old_dim in zip(shape, split_output_shape): + if not isinstance(dim, list): + dim = [dim,] + if not isinstance(old_dim, list): + old_dim = [old_dim,] + if product(tuple(dim)) == product(tuple(old_dim)): + broadcast_factor += [1] * len(old_dim) + elif product(tuple(old_dim)) == 1: + assert len(dim) == 1 + broadcast_factor.append(dim[0]) + else: + raise NotImplementedError(f"Invalid Broadcast: {old_dim} -> {dim}") + + # flatten_split_shape -> split_input_shape + factor_idx = 0 + broadcast_split_input_shape = [] + for dim in split_input_shape: + if isinstance(dim, list): + new_dim = [] + for d in dim: + new_dim.append(d * broadcast_factor[factor_idx]) + factor_idx += 1 + broadcast_split_input_shape.append(new_dim) + else: + broadcast_split_input_shape.append(dim * broadcast_factor[factor_idx]) + factor_idx += 1 + broadcast_split_input_shape = _list_to_tuple(broadcast_split_input_shape) + node_meta.tensor.reshape(_list_to_tuple(split_input_shape)) + node_meta.tensor.broadcast(broadcast_split_input_shape) + # Last reshape op to clean up + broadcast_input_shape = tuple([product(dim) for dim in broadcast_split_input_shape]) + node_meta.tensor.reshape(broadcast_input_shape) + # Update the input shape and output shape + self.input_shape = _list_to_tuple(node_meta.tensor.shape) + self.output_shape = _list_to_tuple(shape) + + def apply_to_user(self, user_meta: NodeBase): + """ + Propagate the reshape to user nodes + """ + user_meta.tensor.reshape(tuple(self.input_shape)) + if hasattr(user_meta, "store_tensor"): + if user_meta.store_tensor is not None: + user_meta.store_tensor.reshape(tuple(self.input_shape)) + + def apply_to_input(self, input_meta: NodeBase): + """ + Propagate the reshape to input nodes + """ + input_meta.tensor.reshape(tuple(self.output_shape)) + if hasattr(input_meta, "store_tensor"): + if input_meta.store_tensor is not None: + input_meta.store_tensor.reshape(tuple(self.output_shape)) + + # + # Helper functions + # + + def infer_split(self, input_shape, output_shape): + """ + Infer the flatten splitted shape that can be merged to both input_shape and output_shape + """ + input_shape = _tuple_to_list(input_shape) + output_shape = _tuple_to_list(output_shape) + if len(input_shape) == 0 and len(output_shape) == 0: + return [] + if len(input_shape) == 0: + if product(tuple(output_shape)) != 1: + raise ValueError("Invalid reshape size") + else: + return output_shape + if len(output_shape) == 0: + if product(tuple(input_shape)) != 1: + raise ValueError("Invalid reshape size") + else: + return input_shape + # This is done recursively by only process the last dimension at each time + old_dim = input_shape[-1] + new_dim = output_shape[-1] + # Exact match + if old_dim == new_dim: + return self.infer_split(input_shape[:-1], output_shape[:-1]) + [new_dim,] + # Needs split + if old_dim > new_dim and old_dim % new_dim == 0: + residual = old_dim // new_dim + return self.infer_split(input_shape[:-1] + [residual,], output_shape[:-1]) + [new_dim,] + # Needs merge + if old_dim < new_dim and new_dim % old_dim == 0: + residual = new_dim // old_dim + return self.infer_split(input_shape[:-1], output_shape[:-1] + [residual,]) + [old_dim,] + + raise NotImplementedError(f"Unsupported split: {input_shape} -> {output_shape}") + + def infer_merge(self, flatten_shape, shape): + flatten_shape = _tuple_to_list(flatten_shape) + shape = _tuple_to_list(shape) + idx_flat = len(flatten_shape) - 1 + merged_shape = [] + for dim in reversed(shape): + # Exact match + if dim == flatten_shape[idx_flat]: + merged_shape.append(dim) + idx_flat -= 1 + # need group + elif dim > flatten_shape[idx_flat] and dim % flatten_shape[idx_flat] == 0: + residual = dim + group = [] + while(residual > 1): + group.append(flatten_shape[idx_flat]) + residual = residual // flatten_shape[idx_flat] + idx_flat -= 1 + merged_shape.append(group[::-1]) + else: + raise NotImplementedError(f"Unsupported merge: {flatten_shape} -> {shape}") + + return merged_shape[::-1] + + +class LayoutNode(NodeBase): + """ + Layout manipulation nodes + """ + fn_to_impl = { + "permute": PermutationImpl, + "reshape": ReshapeImpl + } + def __init__(self, name: str, fn, kwargs: dict) -> None: + super().__init__(name) + self.op = "layout" + self.fn = fn + self.kwargs = kwargs + self.underlying_impl = self.fn_to_impl[self.fn.__name__](self) + + def get_inverse_node(self): + inverse_node = deepcopy(self) + inverse_node.underlying_impl = self.underlying_impl.get_inverse_impl() + return inverse_node + + def shape_propagation(self, input_node_metas): + if self._tensor is not None: + return + assert len(input_node_metas) == 1, "Layout node can only have one input node" + + output_shape = self.underlying_impl.shape_propagation(input_node_metas[0]) + + self._tensor = Tensor( + element=self.element_output, + shape=output_shape, layout_tag=LayoutType.RowMajor + ) + + return super().shape_propagation(input_node_metas) + + def type_propagation(self, input_node_metas: 'list[NodeBase]'): + """ + The store nodes has element_output = element_input + """ + assert len(input_node_metas) == 1, "Layout node can only have one input node" + self.element_output = input_node_metas[0].element_output + + def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'): + """ + Propagate the broadcast in the reversed topological order + """ + if self.tensor is None: + raise RuntimeError(f"The tensor of node {self.name} is unknown.") + shape = self.tensor.shape + + for child in input_node_metas: + self.underlying_impl.broadcast(shape, child) + + def apply_to_user(self, usr_meta: NodeBase): + """ + Propagate the permutation to user nodes + """ + self.underlying_impl.apply_to_user(usr_meta) + + def apply_to_input(self, input_meta: NodeBase): + """ + Propagate the permutation to input nodes + """ + self.underlying_impl.apply_to_input(input_meta) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/load_nodes.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/load_nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..bff0aaa2c21ef2545f50745cdb33499270eeb9fb --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/load_nodes.py @@ -0,0 +1,294 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Load nodes and implementations +""" + +import ctypes + +from cutlass_cppgen.backend.c_types import tuple_factory +from cutlass_cppgen.backend.epilogue import dtype2ctype, to_ctype_value +from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase + + +class LoadImplBase(ImplBase): + """ + Base class for load node implementations + """ + reserved_names = ["accum", "C"] + def __init__(self, node) -> None: + super().__init__(node) + self.element = node.element + self.element_output = node.element_output + self.stride = node.tensor.stride + + +class AccumulatorImpl(LoadImplBase): + """ + Accumulator node implementation + """ + + @staticmethod + def match(node, problem_size: tuple): + return node.name == "accum" and node.tensor.shape == problem_size + + +class LoadSrcImpl(LoadImplBase): + """ + Load C implementation + """ + @property + def name_camel(self) -> str: + return "TensorC" + + @property + def argument_type_c(self): + stride_mnl = self.get_stride_mnl() + tuple_type = tuple_factory(stride_mnl, self.stride_dtype) + class _Argument(ctypes.Structure): + _fields_ = [ + ("ptr_C", ctypes.c_void_p), + ("stride_C", tuple_type) + ] + def __init__(self, ptr) -> None: + self.ptr_C = ptr + self.stride_C = tuple_type(stride_mnl) + + return _Argument + + @staticmethod + def match(node, problem_size: tuple): + return node.name == "C" and node.tensor.shape == problem_size + + +class AuxLoadImpl(LoadImplBase): + """ + Load arbitrary tensor + """ + @property + def argument_type(self): + stride_mnl = self.get_stride_mnl() + name = self.name + tuple_type = tuple_factory(stride_mnl, self.stride_dtype) + element_type = self.element + class _Argument(ctypes.Structure): + _fields_ = [ + ("ptr_aux", ctypes.c_void_p), + ("null_default", dtype2ctype[element_type]), + ("dAux", tuple_type) + ] + def __init__(self, kwargs) -> None: + ptr = kwargs[name] + self.ptr_aux = ptr + self.null_default = to_ctype_value(0, element_type) + self.dAux = tuple_type(stride_mnl) + + return _Argument + + @staticmethod + def match(node, problem_size: tuple): + if node.name in LoadImplBase.reserved_names: + return False + strideMN = node.tensor.stride[-2:] + if (strideMN[0] == 1 and strideMN[1] != 0 or + strideMN[0] != 0 and strideMN[1] == 1 ): + return True + else: + return False + + +class RowBroadcastImpl(LoadImplBase): + """ + Broadcast a row vector + """ + def __init__(self, node) -> None: + super().__init__(node) + self.stride_dtype = "int" + + @property + def argument_type(self): + stride_mnl = self.get_stride_mnl() + name = self.name + tuple_type = tuple_factory(stride_mnl, self.stride_dtype) + element_type = self.element + class _Argument(ctypes.Structure): + _fields_ = [ + ("ptr_row", ctypes.c_void_p), + ("null_default", dtype2ctype[element_type]), + ("dRow", tuple_type) + ] + def __init__(self, kwargs) -> None: + ptr = kwargs[name] + self.ptr_row = ptr + self.null_default = to_ctype_value(0, element_type) + self.dRow = tuple_type(stride_mnl) + + return _Argument + + @staticmethod + def match(node, problem_size: tuple): + if node.name in LoadImplBase.reserved_names: + return False + + strideMN = node.tensor.stride[-2:] + if strideMN == (0, 1): + return True + else: + return False + + +class ColumnBroadcastImpl(LoadImplBase): + """ + Broadcast a column vector + """ + def __init__(self, node) -> None: + super().__init__(node) + self.stride_dtype = "int" + + @property + def argument_type(self): + stride_mnl = self.get_stride_mnl() + name = self.name + tuple_type = tuple_factory(stride_mnl, self.stride_dtype) + element_type = self.element + class _Argument(ctypes.Structure): + _fields_ = [ + ("ptr_col", ctypes.c_void_p), + ("null_default", dtype2ctype[element_type]), + ("dCol", tuple_type) + ] + def __init__(self, kwargs) -> None: + ptr = kwargs[name] + self.ptr_col = int(ptr) + self.null_default = to_ctype_value(0, element_type) + self.dCol = tuple_type(stride_mnl) + + return _Argument + + @staticmethod + def match(node, problem_size: tuple): + if node.name in LoadImplBase.reserved_names: + return False + + strideMN = node.tensor.stride[-2:] + if strideMN == (1, 0): + return True + else: + return False + + +class ScalarBroadcastImpl(LoadImplBase): + """ + Broadcast a scalar + """ + def __init__(self, node) -> None: + super().__init__(node) + self.stride_dtype = "int" + + @property + def argument_type(self): + stride_mnl = self.get_stride_mnl() + name = self.name + tuple_type = tuple_factory(stride_mnl, self.stride_dtype) + element_type = self.element + + if self.tensor.is_constant: + value = self.tensor.value + class _Argument(ctypes.Structure): + _fields_ = [ + ("scalars", dtype2ctype[element_type]), + ("scalar_ptrs", ctypes.c_void_p), + ("dScalar", tuple_type) + ] + def __init__(self, kwargs) -> None: + self.scalars = to_ctype_value(value, element_type) + self.scalar_ptrs = 0 + self.dScalar = tuple_type(stride_mnl) + + else: + class _Argument(ctypes.Structure): + _fields_ = [ + ("scalars", dtype2ctype[element_type]), + ("scalar_ptrs", ctypes.c_void_p), + ("dScalar", tuple_type) + ] + def __init__(self, kwargs) -> None: + scalar_or_ptr = kwargs[name] + if isinstance(scalar_or_ptr, float): + self.scalars = to_ctype_value(scalar_or_ptr, element_type) + self.scalar_ptrs = 0 + else: + self.scalar_ptrs = int(scalar_or_ptr) + + self.dScalar = tuple_type(stride_mnl) + + return _Argument + + @staticmethod + def match(node, problem_size: tuple): + if node.name in LoadImplBase.reserved_names: + return False + + strideMN = node.tensor.stride[-2:] + if strideMN == (0, 0): + return True + else: + return False + + +class LoadNode(NodeBase): + """ + Load Node + """ + cnt = 0 + possible_impls = [ + AccumulatorImpl, LoadSrcImpl, AuxLoadImpl, + RowBroadcastImpl, ColumnBroadcastImpl, + ScalarBroadcastImpl + ] + def __init__(self, name: str) -> None: + if name is None: + name = f"load{LoadNode.cnt}" + LoadNode.cnt += 1 + super().__init__(name) + self.op = "load" + + def type_propagation(self, *args, **kwargs): + """ + Load node loads tensor under type `tensor.element` and returns an array of type `tensor.element`. + """ + if self.tensor is None: + raise RuntimeError(f"The tensor of node {self.name} is unknown.") + + self.element = self.tensor.element + self.element_output = self.tensor.element diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/node.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/node.py new file mode 100644 index 0000000000000000000000000000000000000000..606591b8e78c97114b85b329050d630d55460d7a --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/node.py @@ -0,0 +1,306 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Base & visitor classes of DAGIR Nodes +""" + +import ctypes +from re import sub + +from cutlass_library import LayoutType + +from cutlass_cppgen.backend.evt.ir.layout_algorithm import _list_to_tuple, _reverse_tuple +from cutlass_cppgen.backend.evt.ir.tensor import Tensor + + +class TupleEmitter: + """ + Emit the cute tuple to C++ code + """ + def __init__(self, stride_dtype): + self.stride_dtype = stride_dtype + + def emit(self, py_tuple): + if isinstance(py_tuple, int): + if py_tuple in [0, 1]: + return f"cute::Int<{py_tuple}>" + else: + return f"{self.stride_dtype}" + elif isinstance(py_tuple, tuple): + decl = "cute::Stride<" + for item in py_tuple: + decl += self.emit(item) + ", " + return decl[:-2] + ">" + else: + raise ValueError(f"TupleEmitter.emit only accepts tuple or int, got {type(py_tuple).__name__}") + + +class ImplBase: + """ + Base class for Node Implementation + """ + def __init__(self, node) -> None: + self.node = node + self.name = node.name + self.tensor = node.tensor + self._type_decl = None + self.tuple_emitter = TupleEmitter("int64_t") + + @property + def stride_dtype(self): + return self.tuple_emitter.stride_dtype + + @stride_dtype.setter + def stride_dtype(self, stride_dtype): + self.tuple_emitter.stride_dtype = stride_dtype + + @staticmethod + def match(node, problem_size: tuple): + """ + Match function used in get_underlying_impl + """ + raise NotImplementedError(f"The `match` function is not defined.") + + @property + def argument_type(self): + """ + Default class for Argument Type + """ + class _Argument(ctypes.Structure): + _fields_ = [] + + def __init__(self, *args, **kwargs) -> None: + pass + + return _Argument + + @property + def name_camel(self) -> str: + """ + Return the CamelCase name. + """ + return sub(r"(_|-)+", " ", self.name).title().replace(" ", "") + + @property + def stride_mnl(self): + """ + Typename StrideMNL + """ + stride = _list_to_tuple([self.stride[-2], self.stride[-1]] + list(_reverse_tuple(tuple(self.stride[:-2])))) + return self.tuple_emitter.emit(stride) + + def get_non_constant_stride(self, py_tuple): + if isinstance(py_tuple, int): + if py_tuple not in [0, 1]: + return py_tuple + else: + return None + non_constant_stride = [] + for item in py_tuple: + item_out = self.get_non_constant_stride(item) + if item_out: + non_constant_stride.append(item_out) + return tuple(non_constant_stride) + + def get_stride_mnl(self): + """ + Get the non-zero stride mnl. This is used in argument construction + """ + stride = _list_to_tuple([self.stride[-2], self.stride[-1]] + list(_reverse_tuple(tuple(self.stride[:-2])))) + return stride + + def get_smem_size(self, *args, **kwargs): + """ + Get the shared memory size and alignment of current node + """ + return (0, 1) + + +class NoOpImpl(ImplBase): + """ + The NoOpImpl does nothing but forward its input to users + """ + def __init__(self, node) -> None: + super().__init__(node) + + @staticmethod + def match(node, problem_size: tuple): + if node.op == "store": + # Store that is not output is a No OP + return not node.is_output + + +class NodeBase: + """ + Base class of DAG Node + """ + def __init__(self, name: str) -> None: + self.name = name + self.underlying_impl = None + + self._tensor = None + + # Whether the node is disabled for emit + self.disabled = False + + @property + def name_camel(self) -> str: + """ + Return the CamelCase name. + """ + return self.underlying_impl.name_camel + + @property + def tensor(self) -> Tensor: + """ + Return the output tensor (concept: cutlass_cppgen.backend.evt.ir.tensor) + """ + return self._tensor + + @tensor.setter + def tensor(self, kwargs): + """ + Setting the tensor + """ + self._tensor = Tensor(**kwargs) + + # + # Helper functions for type/shape propagation + # + + def shape_propagation(self, input_node_metas): + """ + Infer shape from input nodes + General Broadcasting Rules from NumPy + When operating on two arrays, we compare their shapes element-wise. + It starts with the trailing (i.e. rightmost) dimension and works its + way left. Two dimensions are compatible when + 1. they are equal + 2. one of them is 1 + """ + if self._tensor is not None: + return + + shape = None + for src in input_node_metas: + src_shape = src.tensor.shape + if shape is None: + shape = src_shape + else: + len_difference = len(shape) - len(src_shape) + if len_difference > 0: + for _ in range(len_difference): + src_shape = [1, ] + list(src_shape) + elif len_difference < 0: + for _ in range(-len_difference): + shape = [1, ] + list(shape) + broadcasted_shape = [] + # Infer broadcast shape + for shape_dim, src_dim in zip(reversed(shape), reversed(src_shape)): + if shape_dim == 1: + broadcasted_shape = [src_dim, ] + list(broadcasted_shape) + elif src_dim == 1: + broadcasted_shape = [shape_dim, ] + list(broadcasted_shape) + elif shape_dim == src_dim: + broadcasted_shape = [shape_dim, ] + list(broadcasted_shape) + else: + error_msg = "Dimension mismatch between " + for src_ in input_node_metas: + error_msg += f"{src_.name}{src_.tensor.shape}, " + error_msg = error_msg[:-2] + "." + raise RuntimeError(error_msg) + shape = tuple(broadcasted_shape) + + self._tensor = Tensor(element=self.element_output, shape=shape, layout_tag=LayoutType.RowMajor) + + def type_propagation(self, *args, **kwargs): + """ + Each node is associated with two data types: `element` and `element_output`. + The `element_output` is the type of return array of the node. The `element` + has specific meaning for different node types. + * Load Node: data type of tensor in gmem + * Compute Node: element compute + * Store Node: data type of tensor in gmem + This function must be overloaded in the derived classes + """ + raise NotImplementedError(f"Function `type_propagation` is not overloaded in {self.__class__.__name__}") + + def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'): + """ + Propagate the broadcast in the reversed topological order. + For example: + C[l, m, n] = A[m, 1] + B[l, m, n] + After the broadcast propagation, it will be come + C[l, m, n] = A[l, m, n] + B[l, m, n] + and each tensor will have a proper stride accessing the underlying tensor + """ + if self.tensor is None: + raise RuntimeError(f"The tensor of node {self.name} is unknown.") + for child in input_node_metas: + child.tensor.broadcast(self.tensor.shape) + + def get_underlying_impl(self, problem_size: tuple): + """ + Get the underlying implementation of the current node. + """ + if self.tensor is None: + raise RuntimeError(f"The Layout of node {self.name} is unknown. Please call PassShapeTypePropagation first.") + + for impl in self.possible_impls: + if impl.match(self, problem_size): + self.underlying_impl = impl(self) + break + + if self.underlying_impl is None: + raise NotImplementedError(f"No matching op for node {self.name} with stride {self.tensor.stride}.") + +# +# Visitor Nodes & Impls +# + +class TopoVisitorImpl(ImplBase): + """ + Impl for topological visitor + """ + def __init__(self, node) -> None: + super().__init__(node.output_node) + self.name = node.name + self.element_output = node.output_node.element_output + +class TopoVisitorNode(NodeBase): + def __init__(self, name: str, subgraph, output_node) -> None: + super().__init__(name) + self.subgraph = subgraph + self.output_node = output_node + self.op = "dag" + self.underlying_impl = TopoVisitorImpl(self) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/store_nodes.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/store_nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..708405e0647ca3cb22bd0c1d4770d71810a469e2 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/store_nodes.py @@ -0,0 +1,277 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Store node and implementations +""" + +import ctypes + +from cutlass_library import DataType + +from cutlass_cppgen.backend.c_types import tuple_factory +from cutlass_cppgen.backend.epilogue import dtype2ctype, to_ctype_value +from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase, NoOpImpl +from cutlass_cppgen.backend.evt.ir.tensor import Tensor +from cutlass_cppgen.backend.library import FloatRoundStyle, FunctionalOp + + +class StoreImplBase(ImplBase): + """ + Base class for store node implementation + """ + reserved_names = ["D"] + def __init__(self, node) -> None: + super().__init__(node) + self.element = node.element + self.element_output = node.element_output + self.stride = node.store_tensor.stride + + +class StoreDImpl(StoreImplBase): + """ + Store D implementation + """ + + @property + def argument_type_d(self): + stride_mnl = self.get_stride_mnl() + tuple_type = tuple_factory(stride_mnl, self.stride_dtype) + class _Argument(ctypes.Structure): + _fields_ = [ + ("ptr_D", ctypes.c_void_p), + ("stride_D", tuple_type) + ] + def __init__(self, ptr: int) -> None: + self.ptr_D = ptr + self.stride_D = tuple_type(stride_mnl) + + return _Argument + + @staticmethod + def match(node, problem_size: tuple): + if node.name == "D" and node.store_tensor.shape == problem_size: + return True + return False + + +class AuxStoreImpl(StoreImplBase): + def __init__(self, node) -> None: + super().__init__(node) + self.round_style = FloatRoundStyle.ToNearest + + @property + def argument_type(self): + stride_mnl = self.get_stride_mnl() + name = self.name + tuple_type = tuple_factory(stride_mnl, self.stride_dtype) + class _Argument(ctypes.Structure): + _fields_ = [ + ("ptr_aux", ctypes.c_void_p), + ("dAux", tuple_type) + ] + def __init__(self, kwargs) -> None: + ptr = kwargs[name] + self.ptr_aux = ptr + self.dAux = tuple_type(stride_mnl) + + return _Argument + + @staticmethod + def match(node, problem_size: tuple): + if not node.is_output: + return False + if node.name in StoreImplBase.reserved_names: + return False + + strideMN = node.store_tensor.stride[-2:] + if (strideMN[0] == 1 and strideMN[1] != 0 or + strideMN[0] != 0 and strideMN[1] == 1 ): + return True + else: + return False + + +class ReductionImplBase(StoreImplBase): + def __init__(self, node) -> None: + super().__init__(node) + self.element = node.store_tensor.element + self.element_compute = node.element_compute + self.reg_reduce_fn = self.node.reg_reduce_fn + self.gmem_reduce_fn = self.node.gmem_reduce_fn + self.round_style = node.round_style + self.stride_dtype = "int" + + def get_reduce_identity(self): + """ + Return the reduction identity of the current reduce_fn + """ + maxes = { + DataType.f32: (2 ** 31) - 1, + DataType.f16: (2 ** 15), + DataType.s32: (2 ** 31) - 1, + DataType.s8: (2 ** 7) - 1 + } + mins = { + DataType.f32: -maxes[DataType.f32], + DataType.f16: -maxes[DataType.f16], + DataType.s32: -maxes[DataType.s32], + DataType.s8: -maxes[DataType.s8] + } + if self.reg_reduce_fn == FunctionalOp.Maximum: + if self.element_compute not in mins: + raise Exception(f"No min entry for data type {self.element_compute}") + return to_ctype_value(mins[self.element_compute], self.element_compute) + elif self.reg_reduce_fn == FunctionalOp.Multiplies: + return to_ctype_value(1., self.element_compute) + elif self.reg_reduce_fn == FunctionalOp.Minimum: + if self.element_compute not in maxes: + raise Exception(f"No max entry for data type {self.element_compute}") + return to_ctype_value(maxes[self.element_compute], self.element_compute) + else: + return to_ctype_value(0., self.element_compute) + + @property + def argument_type(self): + self.get_reduce_identity() + stride_mnl = self.get_stride_mnl() + name = self.name + tuple_type = tuple_factory(stride_mnl, self.stride_dtype) + element_compute = self.element_compute + reduce_identity = self.get_reduce_identity() + class _Argument(ctypes.Structure): + _fields_ = [ + ("ptr", ctypes.c_void_p), + ("reduce_identity", dtype2ctype[element_compute]), + ("dMNL", tuple_type) + ] + def __init__(self, kwargs) -> None: + ptr = kwargs[name] + self.ptr = ptr + self.reduce_identity = reduce_identity + self.dMNL = tuple_type(stride_mnl) + + return _Argument + + +class ColumnReductionImpl(ReductionImplBase): + + @staticmethod + def match(node, problem_size: tuple): + if not node.is_output: + return False + if node.name in StoreImplBase.reserved_names: + return False + + strideMN = node.store_tensor.stride[-2:] + if strideMN == (1, 0): + return True + else: + return False + + +class RowReductionImpl(ReductionImplBase): + + @staticmethod + def match(node, problem_size: tuple): + if not node.is_output: + return False + if node.name in StoreImplBase.reserved_names: + return False + + strideMN = node.store_tensor.stride[-2:] + if strideMN == (0, 1): + return True + else: + return False + + +class ScalarReductionImpl(ReductionImplBase): + + @staticmethod + def match(node, problem_size: tuple): + if not node.is_output: + return False + if node.name in StoreImplBase.reserved_names: + return False + + strideMN = node.store_tensor.stride[-2:] + if strideMN == (0, 0): + return True + else: + return False + + +class StoreNode(NodeBase): + """ + Store node + """ + possible_impls = [ + AuxStoreImpl, RowReductionImpl, + ColumnReductionImpl, ScalarReductionImpl, + NoOpImpl, StoreDImpl + ] + def __init__(self, name: str) -> None: + super().__init__(name) + self.op = "store" + self.is_output = False + self._store_tensor = None + + @property + def store_tensor(self) -> Tensor: + """ + Return the output tensor (concept: cutlass_cppgen.backend.evt.ir.tensor) + """ + return self._store_tensor + + @store_tensor.setter + def store_tensor(self, kwargs): + """ + Setting the tensor + """ + self._store_tensor = Tensor(**kwargs) + + def type_propagation(self, input_node_metas: 'list[NodeBase]'): + """ + The store nodes has element_output = element_input + """ + if self.is_output: + if self.store_tensor is None: + raise RuntimeError(f"The store tensor of node {self.name} is unknown.") + self.element = self.store_tensor.element + assert len(input_node_metas) == 1, "Store node can only have one input node" + self.element_output = input_node_metas[0].element_output + + def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'): + super().broadcast_propagation(input_node_metas) + if self.is_output: + self._store_tensor.broadcast(self.tensor.shape) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/tensor.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..1a28b7306a140d08bd1edebd3486990ea69b9344 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/tensor.py @@ -0,0 +1,137 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +High-level class for tensor +""" + +from cutlass_library import LayoutType + +from cutlass_cppgen.backend.evt.ir.layout_algorithm import ( + Layout, + broadcast, + canonicalization, + permutation, + reshape, + _reverse_tuple +) +from cutlass_cppgen.utils.datatypes import get_datatype_and_layout, get_tensor_shape, library_type + + +class Tensor: + """ + The tensor abstracts the data type + """ + def __init__(self, tensor=None, element=None, shape=None, stride=None,layout_tag=None, is_constant=False) -> None: + if element is not None and tensor is not None: + raise Exception(f"Must not specify both element and tensor") + elif shape is not None and tensor is not None: + raise Exception(f"Must not specify both shape and tensor") + elif layout_tag is not None and tensor is not None: + raise Exception(f"Must not specify both layout_tag and tensor") + elif (element is None or (layout_tag is None and stride is None) or shape is None) and (tensor is None) : + raise Exception(f"Must specify one of (element, shape, layout/stride) or (tensor)") + elif stride is not None and tensor is not None: + raise Exception(f"Must not specify both stride and tensor") + elif stride is not None and layout_tag is not None: + raise Exception(f"Must not specify layout_tag when stride is provided") + + if isinstance(tensor, Tensor): + # Directly copy all the attributes + self.__dict__.update(vars(tensor)) + else: + if tensor is None: + self.element = library_type(element) + else: + self.element, layout_tag = get_datatype_and_layout(tensor) + shape = get_tensor_shape(tensor) + if stride is not None: + self.layout = Layout(shape[::-1], stride[::-1]) + else: + if layout_tag == LayoutType.RowMajor: + self.layout = Layout(shape[::-1]) + elif layout_tag == LayoutType.ColumnMajor: + self.layout = permutation(Layout(shape), [idx for idx in reversed(range(len(shape)))]) + self.layout = canonicalization(self.layout) + + self.is_constant = is_constant + # Save the tensor value if it is constant + if is_constant and tensor is not None: + self.value = tensor + + @property + def shape(self): + """ + Returns the RowMajor layout shape + """ + return _reverse_tuple(self.layout.shape) + + @property + def stride(self): + """ + Returns the RowMajor layout stride + """ + return _reverse_tuple(self.layout.stride) + + @property + def rank(self): + """ + Returns the rank of the tensor + """ + return len(self.shape) + + # + # Layout Algorithms + # + + def broadcast(self, shape): + """ + Broadcast self.layout to shape + """ + assert isinstance(shape, tuple) + self.layout = broadcast(self.layout, _reverse_tuple(shape)) + + def reshape(self, shape): + """ + Reshape self.layout to shape + """ + assert isinstance(shape, tuple) + reverse_shape = _reverse_tuple(shape) + self.layout = reshape(self.layout, reverse_shape) + + def permute(self, indices): + """ + Permute self.layout according to indices + """ + length = len(indices) + indices = [length - idx - 1 for idx in indices] + self.layout = permutation(self.layout, indices[::-1]) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..badc38d96a830992c94afa693ea4b56a8e404c96 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/__init__.py @@ -0,0 +1,42 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from cutlass_cppgen.backend.evt.passes.graph_drawer import EVTGraphDrawer +from cutlass_cppgen.backend.evt.passes.pass_argument_type import PassGetArgumentType +from cutlass_cppgen.backend.evt.passes.pass_dag_2_tree import PassDAG2Tree +from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl +from cutlass_cppgen.backend.evt.passes.pass_fix_element_d import PassFixElementD +from cutlass_cppgen.backend.evt.passes.pass_layout_elimination import PassLayoutManipulateElimination +from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassManager +from cutlass_cppgen.backend.evt.passes.pass_preprocess_red import PassPreprocessRed +from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation +from cutlass_cppgen.backend.evt.passes.smem_size_calculator import GetSmemSize diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/graph_drawer.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/graph_drawer.py new file mode 100644 index 0000000000000000000000000000000000000000..8a28c6e4e62d1a7bd7431c81aac366b8788fd8df --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/graph_drawer.py @@ -0,0 +1,143 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# +from __future__ import annotations + +import subprocess + +from cutlass_library import DataTypeTag + +from cutlass_cppgen.backend.evt.ir.dag_ir import DAGIR + + +_COLOR_MAP = { + "load": '"AliceBlue"', + "compute": "LemonChiffon1", + "accumulator": "LightGrey", + "store": "PowderBlue", + "layout": "lightseagreen", + "dag": "darkorange" +} + + +class EVTGraphDrawer: + """ + Visualize a EVT DAGIR with graphviz + """ + def __init__( + self, + graph: DAGIR, + name: str + ): + self._name = name + self._dot_graphs = {} + + self._dot_graphs[name] = self._to_dot(graph, name) + + def _get_node_style(self, node): + template = { + "shape": "record", + "fillcolor": "#CAFFE3", + "style": '"filled,rounded"', + "fontcolor": "#000000", + } + if node.op in _COLOR_MAP: + template["fillcolor"] = _COLOR_MAP[node.op] + else: + raise NotImplementedError("unknown node op") + if node.disabled: + template["fontcolor"] = "grey" + template["fillcolor"] = "white" + return template + + def _get_node_label(self, node): + label = "{" + f"name={node.name}|op={node.op}" + if node.op == "layout": + label += f"|fn={node.fn.__name__}" + for key in node.kwargs: + label += f"|{key}={node.kwargs[key]}" + if node.underlying_impl is not None: + label += f"|impl={type(node.underlying_impl).__name__}" + if node.op == "load": + label += f"|element_output={DataTypeTag[node.underlying_impl.element]}" + elif node.op == "compute": + label += f"|element_compute={DataTypeTag[node.underlying_impl.element_compute]}|element_output={DataTypeTag[node.underlying_impl.element_output]}" + elif node.op == "store": + label += f"|element_store={DataTypeTag[node.underlying_impl.element]}|element_output={DataTypeTag[node.underlying_impl.element_output]}" + elif node.op == "dag": + label += f"|element_output={DataTypeTag[node.underlying_impl.element_output]}" + if node.tensor is not None: + shape = node.tensor.shape + stride = node.tensor.stride + label += f"|shape={shape}|stride={stride}" + + if hasattr(node, "store_tensor"): + if node.store_tensor is not None: + store_shape = node.store_tensor.shape + store_stride = node.store_tensor.stride + label += f"|store_shape={store_shape}|stride_stride={store_stride}" + + label += "}" + return label + + def _to_dot( + self, + graph: DAGIR, + name: str + ): + import pydot + dot_graph = pydot.Dot(name, randir="TB") + for node in graph.nodes_meta: + style = self._get_node_style(node) + label = self._get_node_label(node) + dot_node = pydot.Node( + node.name, label=label, **style + ) + dot_graph.add_node(dot_node) + if node.op == "dag": + dot_subgraph = self._to_dot(node.subgraph, name=node.name) + self._dot_graphs[node.name] = dot_subgraph + + # Add edges + for src, dst in graph.edges: + weight = graph.get_edge_weight(src, dst) + dot_graph.add_edge(pydot.Edge(src, dst, label=weight)) + + return dot_graph + + def get_dot_graph(self) -> pydot.Dot: + return [(key, self.get_dot_graph_by_name(key)) for key in self._dot_graphs.keys()] + + def get_dot_graph_by_name(self, name) -> pydot.Dot: + return self._dot_graphs[name] + + def get_main_dot_graph(self) -> pydot.Dot: + return self._dot_graphs[self._name] diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_argument_type.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_argument_type.py new file mode 100644 index 0000000000000000000000000000000000000000..b0c3cdbde6d46ad8a7e84c3b95422bdb55e877c5 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_argument_type.py @@ -0,0 +1,120 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Construct the epilogue visitor argument type +""" + +from cutlass_cppgen.backend.c_types import visitor_factory +from cutlass_cppgen.backend.evt.ir import TopoVisitorNode +from cutlass_cppgen.backend.evt.passes.pass_dag_2_tree import PassDAG2Tree +from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl +from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase +from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation +from cutlass_cppgen.backend.evt.passes.util import cc_map + + +class PassGetArgumentType(EVTPassBase): + """ + Construct the epilogue visitor argument type + """ + dependencies = [ + PassShapeTypePropagation, # The Layout of all nodes must be set + PassDAG2Tree, # The type of each node must be set + PassGetImpl # The DAG subgraphs must be set + ] + + def requires(self) -> None: + # Check "D" is in the node list + if cc_map[self.cc] in [90, 100] and (not self.dag_ir.has_node("D")): + raise SyntaxError( + "Sm90+ EVT requires the epilogue to have a returned tensor D, " + "but the variable 'D' is not found in the return values.") + + def call(self): + nodes = self.dag_ir.nodes_topological_order() + self.argument_types = {} + for node in nodes: + meta = self.dag_ir.get_node_meta(node) + if not meta.disabled: + self.argument_types[node] = meta.underlying_impl.argument_type + if node == "D" and cc_map[self.cc] in [90, 100]: + continue + if isinstance(meta, TopoVisitorNode): + self.get_dag_argument_type(node) + else: + self.get_evt_argument_type(node) + + self.cc_specific_method(self.set_argument_type)() + + def get_evt_argument_type(self, node): + # Sort the input nodes by edge weight + input_types = [self.argument_types[child] for child in self.dag_ir.get_all_inputs(node)] + if len(input_types) > 0: + self.argument_types[node] = visitor_factory( + input_types + [self.argument_types[node],], self.dag_ir.get_all_inputs(node) + [node,]) + + def get_dag_argument_type(self, node): + meta = self.dag_ir.get_node_meta(node) + subgraph = meta.subgraph + subgraph_nodes = subgraph.nodes_topological_order() + # Visit the unvisited nodes in subgraph + for n in subgraph_nodes: + m = subgraph.get_node_meta(n) + if m.disabled: + continue + else: + self.argument_types[n] = m.underlying_impl.argument_type + input_types = [self.argument_types[child] for child in subgraph_nodes[:-1]] + if len(input_types) > 0: + self.argument_types[node] = visitor_factory(input_types, subgraph_nodes[:-1]) + + def set_argument_type(self): + pass + + def sm90_set_argument_type(self): + self.dag_ir.epilogue_thread_type = self.argument_types[self.dag_ir.get_all_inputs("D")[0]] + # Get the tensorD argument type + self.dag_ir.arg_d_type = self.dag_ir.get_node_meta("D").underlying_impl.argument_type_d + + # Get the tensorC argument type + if self.dag_ir.has_node("C"): + self.dag_ir.arg_c_type = self.dag_ir.get_node_meta("C").underlying_impl.argument_type_c + else: + self.dag_ir.arg_c_type = self.dag_ir.arg_d_type + + def sm100_set_argument_type(self): + self.sm90_set_argument_type() + + def sm80_set_argument_type(self): + nodes = self.dag_ir.nodes_topological_order() + self.dag_ir.epilogue_thread_type = self.argument_types[nodes[-1]] diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_dag_2_tree.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_dag_2_tree.py new file mode 100644 index 0000000000000000000000000000000000000000..469769664abdf757319949ab48b4e7d5e982f200 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_dag_2_tree.py @@ -0,0 +1,169 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Merge non-tree sub-graphs of the DAG IR into a single DAG. The fused DAG will be implemented +by the topological visitor, while the rest of the graph will be implemented with the tree visitor. +""" + +from copy import deepcopy + +from cutlass_cppgen.backend.evt.ir import DAGIR, TopoVisitorNode +from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl +from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase +from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation + + +class PassDAG2Tree(EVTPassBase): + """ + Convert the DAG IR to Tree by fusing subgraphs + """ + dependencies = [ + PassShapeTypePropagation, + PassGetImpl + ] + + def call(self): + # Step 1: find the nodes that have multiple parents + multi_parent_nodes = [] + + for node in self.dag_ir.nodes_topological_order(): + if self.dag_ir.out_degree(node) > 1: + multi_parent_nodes.append(node) + # Step 2: find the lowest common ancestor (LCA) of all its parents + for node in multi_parent_nodes: + # A multi-parent node could be already fused by the previous node + if not self.dag_ir.has_node(node): + continue + # A node uncovered by the previous fusions can have out degree change + # Case 1: it has <= 1 edges to the previously fused subgraph, no degree change + # Case 2: it has more than one edges to the previously fused subgraph, degree drops + if self.dag_ir.out_degree(node) <= 1: + continue + + # Otherwise, the node still + reachable_nodes = [] + # Complexity: O(Dout*N) + for parent in self.dag_ir.get_users(node): + reachable_nodes.append(set(self.dag_ir.all_reachable_nodes(parent))) + # get the common reachable objects + common_items = set.intersection(*reachable_nodes) + node_to_fuse = set.union(*reachable_nodes).difference(common_items) + + lca = None + # If common ancestor exists, find the lowest one + if len(common_items) > 0: + topo_order = self.dag_ir.nodes_topological_order() + topo_idx = -1 + for item in common_items: + if lca is None: + lca = item + topo_idx = topo_order.index(item) + else: + if topo_idx > topo_order.index(item): + lca = item + topo_idx = topo_order.index(item) + else: + # there is no common ancestor for all the parents, we pack all the reachable + # nodes into a single DAG node as a fallback. The lca should be the input node of + # one of the output nodes with out_degree = 0 + potential_output_nodes = [] + for node in node_to_fuse: + if self.dag_ir.out_degree(node) == 0: + potential_output_nodes.append(node) + if len(potential_output_nodes) == 0: + raise RuntimeError(f"No output node with out degree = 0 found.") + + output_node = None + if (self.dag_ir.cc >= 90): + # For SM90+, the lca should be the input node of D + if (not self.dag_ir.has_node("D")): + raise RuntimeError(f"D is not a node in the DAG IR.") + output_node = "D" + else: + output_node = potential_output_nodes[0] + + if (output_node is None): + raise RuntimeError(f"No output node found.") + lca = self.dag_ir.get_all_inputs(output_node)[0] + node_to_fuse.remove(output_node) + + # The lca is the output node of the DAG node + # Get the nodes to be fused + node_to_fuse.add(lca) + # Get all the input nodes + all_input_nodes = [] + all_output_nodes = [] + for node in node_to_fuse: + all_input_nodes.append(set(self.dag_ir.get_all_inputs(node))) + all_output_nodes.append(set(self.dag_ir.get_users(node))) + all_input_nodes = set.union(*all_input_nodes) + all_output_nodes = set.union(*all_output_nodes) + + new_subgraph_nodes = set.union(node_to_fuse, all_input_nodes, all_output_nodes) + + # Create the subgraph + subgraph_ = self.dag_ir._graph.subgraph(new_subgraph_nodes) + subgraph = DAGIR(self.dag_ir.cc) + for node in subgraph_.nodes: + meta = deepcopy(self.dag_ir.get_node_meta(node)) + if node not in node_to_fuse: + meta.disabled = True + subgraph.add_node(meta) + for edge in subgraph_.edges: + subgraph.add_edge(edge[0], edge[1], self.dag_ir.get_edge_weight(edge[0], edge[1])) + + + # Create the fused node + dag_node = TopoVisitorNode( + name=f"dag_{lca}", subgraph=subgraph, + output_node=self.dag_ir.get_node_meta(lca)) + self.dag_ir.add_node(dag_node) + + # Add input edges + for idx, node in enumerate(all_input_nodes): + self.dag_ir.add_edge(node, dag_node.name, weight=idx) + + # Replace all uses with DAG node (only 1 output node) + self.dag_ir.replace_all_uses_with(lca, dag_node.name) + + # Remove all fused nodes + node_to_fuse.remove(lca) + for node in node_to_fuse: + self.dag_ir.remove_node(node) + + def ensures(self) -> None: + # Ensure that after the pass, the resulting DAG becomes a tree + for node in self.dag_ir.nodes: + out_degree = self.dag_ir.out_degree(node) + if out_degree > 1: + raise RuntimeError(f"PassDAG2Tree failed. Node {node} still have outdegree = {out_degree}") diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_fix_element_d.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_fix_element_d.py new file mode 100644 index 0000000000000000000000000000000000000000..0d57c5b799d125ccc9491760259569731c0bf3ca --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_fix_element_d.py @@ -0,0 +1,64 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Fix the element_output of producer of D. + +In Sm90 epilogue visitor, the node writing D to gmem does not have internal +element converter, so the compute node producing D must have element_output = type(D). +""" + +from cutlass_cppgen.backend.evt.passes.pass_layout_elimination import PassLayoutManipulateElimination +from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase + + +class PassFixElementD(EVTPassBase): + """ + In Sm90 epilogue visitor, the node writing D to gmem does not have internal + element converter, so the compute node producing D must have + element_output = type(D) + """ + dependencies = [ + PassLayoutManipulateElimination + ] + def get_producer(self, node, element_D): + node_meta = self.dag_ir.get_node_meta(node) + if node_meta.op == "compute": + node_meta.element_output = element_D + elif node_meta.op == "store": + self.get_producer(self.dag_ir.get_all_inputs(node)[0], element_D) + + def call(self): + if self.dag_ir.has_node("D"): + node_d_meta = self.dag_ir.get_node_meta("D") + element_D = node_d_meta.store_tensor.element + self.get_producer("D", element_D) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_get_impl.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_get_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..90fdafe7d0e80492bd2e641c69f11d95aace6bba --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_get_impl.py @@ -0,0 +1,90 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Infer the underlying implement of each node. + +While the frontend only distinguish between Load/Store/Compute Node, +each of these nodes can have different underlying implementation based +on their layout. For instance, a LoadNode can be AuxLoad, Row/Col/Scalar broadcast, etc. +This pass infers the underlying impl of each node +""" + +import cutlass_cppgen.backend.evt.backend as evt_backend +from cutlass_cppgen.backend.evt.ir import DAGIR, LoadNode +from cutlass_cppgen.backend.evt.passes.pass_fix_element_d import PassFixElementD +from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase +from cutlass_cppgen.backend.evt.passes.pass_no_op_elimination import PassNoOpElimination +from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation +from cutlass_cppgen.backend.evt.passes.util import cc_map + + +class PassGetImpl(EVTPassBase): + """ + While the frontend only distinguish between Load/Store/Compute Node, + each of these nodes can have different underlying implementation based + on their layout. For instance, a LoadNode can be AuxLoad, Row/Col/Scalar broadcast, etc. + This pass infers the underlying impl of each node + """ + dependencies = [ + PassShapeTypePropagation, # The shape and type info are required for inference + PassFixElementD + ] + + def __init__(self, dag_ir: DAGIR) -> None: + super().__init__(dag_ir) + self.no_op_elimination = PassNoOpElimination(dag_ir) + + def requires(self) -> None: + # Verify "accum" is in the arg list + if not self.dag_ir.has_node("accum"): + raise SyntaxError("Cannot find 'accum' in the argument list.") + + def call(self): + # The loop structure of the epilogue is determined by the + # accumulator shape + accumulator: LoadNode = self.dag_ir.get_node_meta("accum") + problem_size = accumulator.tensor.shape + + for node_meta in self.dag_ir.node_metas_topological_order(): + node_meta.get_underlying_impl(problem_size) + + def ensures(self) -> None: + # Some nodes will be lowered to NoOp, eliminate them + self.no_op_elimination() + # Lower to cc-specific impl + for node_meta in self.dag_ir.nodes_meta: + node_impl_ccs = getattr(evt_backend, f"sm{cc_map[self.cc]}_nodes") + node_meta.underlying_impl = getattr( + node_impl_ccs, + f"Sm{cc_map[self.cc]}" + node_meta.underlying_impl.__class__.__name__ + )(node_meta) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_layout_elimination.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_layout_elimination.py new file mode 100644 index 0000000000000000000000000000000000000000..af147969f016b50ef05034fca99b173777948622 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_layout_elimination.py @@ -0,0 +1,217 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Eliminate layout manipulation nodes +""" + +from copy import deepcopy + +from cutlass_cppgen.backend.evt.ir import DAGIR, LayoutNode +from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase +from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation + + +class PassLayoutManipulateElimination(EVTPassBase): + """ + Eliminate layout manipulation nodes + """ + dependencies = [PassShapeTypePropagation] + + def __init__(self, dag_ir: DAGIR) -> None: + super().__init__(dag_ir) + self.copy_cnt = 0 + + def call(self): + self.layout_nodes_worklist = self.get_all_layout_nodes() + # Run while loop utill all layout nodes are eliminated + while(len(self.layout_nodes_worklist) > 0): + node = self.layout_nodes_worklist.pop(0) + # for node in layout_nodes: + # Step 1: get the propagation direction + direction = self.get_propagation_direction(node) + self.visited = [] + getattr(self, f"propagate_to_{direction}")(self.dag_ir.get_node_meta(node), node) + # Eliminate the current node + input_node = self.dag_ir.get_all_inputs(node)[0] + self.dag_ir.replace_all_uses_with(node, input_node) + # layout_nodes = self.get_all_layout_nodes() + + def get_all_layout_nodes(self): + layout_nodes = [] + for node_meta in reversed(self.dag_ir.node_metas_topological_order()): + if isinstance(node_meta, LayoutNode): + layout_nodes.append(node_meta.name) + return layout_nodes + + def get_propagation_direction(self, node: str): + """ + The logic is propagating all layout nodes away from the accumulator node. + """ + self.visited = [] + self.get_influenced_users(node) + nodes_influenced_dir_users = self.visited + self.visited = [] + self.get_influenced_inputs(node) + nodes_influenced_dir_inputs = self.visited + + if "accum" in nodes_influenced_dir_users and "accum" not in nodes_influenced_dir_inputs: + return "inputs" + elif "accum" not in nodes_influenced_dir_users and "accum" in nodes_influenced_dir_inputs: + return "users" + else: + raise RuntimeError("Unsolved propagation direction") + + # Get all influenced nodes if we propagate along the user direction + def get_influenced_users(self, node: str): + if node in self.visited: + return + self.visited.append(node) + + users = self.dag_ir.get_users(node) + for user in users: + self.get_influenced_users(user) + user_inputs = [] + for user in users: + user_inputs.append(set(self.dag_ir.get_all_inputs(user))) + if len(user_inputs) > 0: + user_inputs = set.union(*user_inputs) + user_inputs.remove(node) + for input in user_inputs: + self.get_influenced_inputs(input) + + # Get all influenced nodes if we propagate along the input direction + def get_influenced_inputs(self, node: str): + if node in self.visited: + return + self.visited.append(node) + + inputs = self.dag_ir.get_all_inputs(node) + for input in inputs: + self.get_influenced_inputs(input) + input_users = [] + for input in inputs: + input_users.append(set(self.dag_ir.get_users(input))) + if len(input_users) > 0: + input_users = set.union(*input_users) + input_users.remove(node) + for user in input_users: + self.get_influenced_users(user) + + def add_copy_before(self, layout_node_meta: LayoutNode, target: str): + copied_node_meta = deepcopy(layout_node_meta) + copied_node = f"{copied_node_meta.name}_copy{self.copy_cnt}" + self.copy_cnt += 1 + copied_node_meta.name = copied_node + self.dag_ir.add_node(copied_node_meta) + # Add edges + target_inputs = self.dag_ir.get_all_inputs(target) + for src in target_inputs: + self.dag_ir.remove_edge(src, target) + self.dag_ir.add_edge(src, copied_node) + self.dag_ir.add_edge(copied_node, target) + self.layout_nodes_worklist.append(copied_node) + + def add_copy_after(self, layout_node_meta: LayoutNode, target: str): + copied_node_meta = deepcopy(layout_node_meta) + copied_node = f"{copied_node_meta.name}_copy{self.copy_cnt}" + self.copy_cnt += 1 + copied_node_meta.name = copied_node + self.dag_ir.add_node(copied_node_meta) + # Add edges + users = self.dag_ir.get_users(target) + for user in users: + self.dag_ir.remove_edge(target, user) + self.dag_ir.add_edge(copied_node, user) + self.dag_ir.add_edge(target, copied_node) + self.layout_nodes_worklist.append(copied_node) + + # Propagate the layout `node` along the user direction + def propagate_to_users(self, layout_node_meta: LayoutNode, node: str): + """ + Propagate layout node to users + """ + if node in self.visited: + # Avoid applying twice + return + self.visited.append(node) + + node_meta = self.dag_ir.get_node_meta(node) + if layout_node_meta.name != node: + if isinstance(node_meta, LayoutNode): + # Layout node is not transparent with layout node + self.add_copy_before(layout_node_meta, node) + return + else: + layout_node_meta.apply_to_user(node_meta) + + users = self.dag_ir.get_users(node) + user_inputs = [] + for user in users: + user_inputs.append(set(self.dag_ir.get_all_inputs(user))) + for user in users: + self.propagate_to_users(layout_node_meta, user) + if len(user_inputs) > 0: + user_inputs = set.union(*user_inputs) + user_inputs.remove(node) + for input in user_inputs: + self.propagate_to_inputs(layout_node_meta.get_inverse_node(), input) + + # Propagate the layout `node` along the input direction + def propagate_to_inputs(self, layout_node_meta: LayoutNode, node: str): + """ + Propagate layout node to inputs + """ + if node in self.visited: + # Avoid applying twice + return + self.visited.append(node) + + node_meta = self.dag_ir.get_node_meta(node) + if layout_node_meta.name != node: + if isinstance(node_meta, LayoutNode): + # Layout node is not transparent with layout node + self.add_copy_after(layout_node_meta, node) + return + else: + layout_node_meta.apply_to_input(node_meta) + inputs = self.dag_ir.get_all_inputs(node) + input_users = [] + for input in inputs: + input_users.append(set(self.dag_ir.get_users(input))) + for input in inputs: + self.propagate_to_inputs(layout_node_meta, input) + if len(input_users) > 0: + input_users = set.union(*input_users) + input_users.remove(node) + for user in input_users: + self.propagate_to_users(layout_node_meta.get_inverse_node(), user) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_manager.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..e8b46bddb06e7c20be6d20526792777edef64b90 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_manager.py @@ -0,0 +1,164 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Pass manager for DAG IR. +""" + +from typing import Any + +import networkx as nx + +from cutlass_cppgen.backend.evt.ir import DAGIR +from cutlass_cppgen.backend.evt.passes.util import cc_map + + +class EVTPassBase: + """ + Base class for EVT Passes + """ + dependencies = [] + def __init__(self, dag_ir: DAGIR) -> None: + self.dag_ir = dag_ir + self.cc = self.dag_ir.cc + + def requires(self) -> None: + """ + This function will be called before the pass is run. + """ + pass + + def call(self) -> None: + """ + The pass that is run through the self.dag_ir + """ + raise NotImplementedError( + f"__call__ is not overwritten in Pass {self.__class__.__name__}") + + def ensures(self) -> None: + """ + This function will be called after the pass is run. + """ + pass + + def __call__(self) -> Any: + self.requires() + self.call() + self.ensures() + + def cc_specific_method(self, func): + """ + This enables defining function that behaves differently under different cc + The simplest example of using this function is the following + + .. highlight:: python + .. code-block:: python + + class ExamplePass(EVTPassBase): + + def call(sekf): + # This automatically select the smXX_func based on current cc + self.cc_specific_method(self.func)() + + # Interface func, can be empty + def func(self): + pass + + # Sm90 specific func + def sm90_func(self): + // sm90 specific method + return + + # Sm80 specific func + def sm80_func(self): + // sm80 specific method + return + """ + func_name = f"sm{cc_map[self.cc]}_{func.__name__}" + if hasattr(self, func_name): + return getattr(self, func_name) + else: + raise NotImplementedError(f"func {func.__name__} is not overwritten for Sm{self.cc}") + + +class EVTPassManager(nx.DiGraph): + """ + Topological-based Pass Manager. + Each registered pass has a list of dependencies. The pass manager organizes + the passes as a DAG and launch the compiler passes under topological order. + """ + def __init__(self, dag_ir: DAGIR, pass_list): + super().__init__() + self.dag_ir = dag_ir + for pass_cls in pass_list: + self.add_pass(pass_cls) + + self.sorted_passes = self.schedule() + + def get_callable(self, pass_name): + """ + Return the callable of the pass + """ + return self.nodes[pass_name]["callable"] + + def add_pass(self, pass_cls): + """ + Add a pass to the pass manager + :param pass_cls: the class of pass + :type pass_cls: derived class of EVTPassBase + """ + name = pass_cls.__name__ + pass_callable = pass_cls(self.dag_ir) + self.add_node(name, callable=pass_callable) + + def schedule(self): + """ + Schedule the added passes under topological order + """ + # Add edges + for pass_name in self.nodes: + callable = self.get_callable(pass_name) + for dependency_cls in callable.dependencies: + self.add_edge( + dependency_cls.__name__, + type(callable).__name__) + + # Topological sort + return list(nx.topological_sort(self)) + + def __call__(self) -> Any: + """ + Launch the registered passes + """ + for pass_name in self.sorted_passes: + callable = self.get_callable(pass_name) + callable() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_no_op_elimination.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_no_op_elimination.py new file mode 100644 index 0000000000000000000000000000000000000000..13107eb1d11c9a436348a4e50a92e62ce6f8b312 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_no_op_elimination.py @@ -0,0 +1,53 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +No op elimination node +""" + +from typing import Any + +from cutlass_cppgen.backend.evt.ir import NoOpImpl +from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase + + +class PassNoOpElimination(EVTPassBase): + """ + The dead node elimination pass removes nodes with NoOpImpl in DAG IR + """ + dependencies = [] + + def call(self) -> Any: + for node in self.dag_ir.nodes_topological_order(): + node_meta = self.dag_ir.get_node_meta(node) + if isinstance(node_meta.underlying_impl, NoOpImpl): + self.dag_ir.replace_all_uses_with(node, self.dag_ir.get_all_inputs(node)[0]) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_preprocess_red.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_preprocess_red.py new file mode 100644 index 0000000000000000000000000000000000000000..6423a2b845dd643650cf99037178030bee6f0dbd --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_preprocess_red.py @@ -0,0 +1,97 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Preprocess the reduction nodes. + +The parser treats reduction as Compute(op=(reg_reduce_fn, gmem_reduce_fn)) - Store() +This pass fuses these into a single store node, and then replaces all uses of the +current node with the new store node. +""" + +from cutlass_cppgen.backend.evt.ir import ComputeNode, StoreNode +from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase + + +class PassPreprocessRed(EVTPassBase): + """ + Preprocess red nodes + """ + + def call(self): + # Step 1: find the compute nodes with op=red + red_compute_nodes = [] + for node_meta in self.dag_ir.nodes_meta: + if isinstance(node_meta, ComputeNode): + if type(node_meta.fn) == tuple: + # To keep the frontend simple, the reduction nodes + # are parsed into compute nodes by default + # The simple heuristic to distinguish between compute + # and reduction node is that compute node is a single function, + # while the reduction node is a tuple of functions for + # in-register reduction and atomic global memory reduction + red_compute_nodes.append(node_meta.name) + + # Step 2: for each compute, merge it with the succeeding store + for node in red_compute_nodes: + # Verify + users = self.dag_ir.get_users(node) + inputs = self.dag_ir.get_all_inputs(node) + # Has a single user + assert len(users) == 1 + assert len(inputs) == 1 + user = users[0] + input = inputs[0] + + user_meta = self.dag_ir.get_node_meta(user) + # Must be a store node + assert isinstance(user_meta, StoreNode) + # With output degree == 0 + assert self.dag_ir.out_degree(user) == 0 + # Register the reduce op + node_meta = self.dag_ir.get_node_meta(node) + user_meta.reg_reduce_fn, user_meta.gmem_reduce_fn = node_meta.fn + user_meta.element_compute = node_meta.element_compute + user_meta.round_style = node_meta.round_style + + # Replace all uses + self.dag_ir.remove_edge(input, node) + input_users = self.dag_ir.get_users(input) + for iu in input_users: + weight = self.dag_ir.get_edge_weight(input, iu) + self.dag_ir.add_edge(user, iu, weight) + self.dag_ir.remove_edge(input, iu) + self.dag_ir.add_edge(input, user) + self.dag_ir.remove_node(node) + + # Register the reduction name + self.dag_ir.reduction_names.append(user) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_shape_type_propagation.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_shape_type_propagation.py new file mode 100644 index 0000000000000000000000000000000000000000..cb90a82c8f637429d3c64b3d881eb30d02c8c804 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_shape_type_propagation.py @@ -0,0 +1,59 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Shape and type propagation pass +""" + +from cutlass_cppgen.backend.evt.ir.node import NodeBase +from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase +from cutlass_cppgen.backend.evt.passes.pass_preprocess_red import PassPreprocessRed + + +class PassShapeTypePropagation(EVTPassBase): + """ + Propagate the shape and type of all nodes + """ + dependencies = [PassPreprocessRed] + + def call(self): + # Propagate the node shape and type + for node in self.dag_ir.nodes_topological_order(): + node_meta: NodeBase = self.dag_ir.get_node_meta(node) + input_node_metas = self.dag_ir.get_all_inputs_meta(node) + node_meta.type_propagation(input_node_metas) + node_meta.shape_propagation(input_node_metas) + + for node in reversed(self.dag_ir.nodes_topological_order()): + node_meta: NodeBase = self.dag_ir.get_node_meta(node) + input_node_metas = self.dag_ir.get_all_inputs_meta(node) + node_meta.broadcast_propagation(input_node_metas) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/smem_size_calculator.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/smem_size_calculator.py new file mode 100644 index 0000000000000000000000000000000000000000..8168c59733a5da15eacbbe583c890610655ecff5 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/smem_size_calculator.py @@ -0,0 +1,319 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Compute the shared memory size in bytes +""" + +from math import gcd + +import cutlass_library +from pycute import flatten, shape_div, product + +import cutlass_cppgen +from cutlass_cppgen.backend.evt.ir import TopoVisitorNode, DAGIR +from cutlass_cppgen.backend.library import DataType, DataTypeSize + + +class GetSmemSize: + """ + Get the size in byte of shared memory used by the kernel + """ + def __init__(self, dag_ir: DAGIR) -> None: + self.dag_ir = dag_ir + self.cc = self.dag_ir.cc + + # + # Sm90 epilogue specific + # + + def sm90_epilogue_tile(self, tile_description): + # Get the epilogue tile size + schedule = tile_description.epilogue_schedule + if schedule == cutlass_library.EpilogueScheduleType.TmaWarpSpecialized: + element_d = self.dag_ir.get_node_meta("D").element + nperf = 64 if (DataTypeSize[element_d] == 8 and tile_description.threadblock_shape[1] % 64 == 0) else 32 + epi_tile_m = min(64, tile_description.threadblock_shape[0]) + epi_tile_n = gcd(min(nperf, tile_description.threadblock_shape[1]), tile_description.threadblock_shape[1]) + epilogue_tile_mn = (epi_tile_m, epi_tile_n) + elif schedule == cutlass_library.EpilogueScheduleType.TmaWarpSpecializedCooperative: + epi_tile_m = min(128, tile_description.threadblock_shape[0]) + epi_tile_n = gcd(min(32, tile_description.threadblock_shape[1]), tile_description.threadblock_shape[1]) + epilogue_tile_mn = (epi_tile_m, epi_tile_n) + else: + raise NotImplementedError(f"Unsupported schedule: {schedule}") + + # Get the pipeline stages + stages_d = 2 + epi_tiles = product(shape_div(tuple(tile_description.threadblock_shape)[:2], epilogue_tile_mn)) + if self.dag_ir.has_node("C"): + element_c = self.dag_ir.get_node_meta("C").element + else: + element_c = None + + element_d = self.dag_ir.get_node_meta("D").element + if element_c == element_d: + reuse_smem_c = True + else: + reuse_smem_c = False + stages_c = max(epi_tiles, stages_d + 1) if reuse_smem_c else epi_tiles + + # Record the epilogue tile + self.cta_tile_mnk = tuple(tile_description.threadblock_shape) + self.epilogue_tile_mn = epilogue_tile_mn + self.epi_tiles = epi_tiles + self.stages_c = stages_c + self.stages_d = stages_d + self.reuse_smem_c = reuse_smem_c + self.element_c = element_c + self.element_d = element_d + self.is_source_supported = element_c is not None + + def sm90_or_sm100_epilogue_smem_size(self, tile_description): + # Get the Fusion Storage + nodes = self.dag_ir.nodes_topological_order() + self.smem_types = {} + for node in nodes: + meta = self.dag_ir.get_node_meta(node) + if not meta.disabled: + self.smem_types[node] = meta.underlying_impl.get_smem_size( + self.cta_tile_mnk, self.epilogue_tile_mn, + self.stages_c, self.stages_d, self.epi_tiles) + if node == "D": + continue + if isinstance(meta, TopoVisitorNode): + self.get_dag_smem_type(node) + else: + self.get_evt_smem_type(node) + + thread_smem_size = self.smem_types[self.dag_ir.get_all_inputs("D")[0]][0] + # Get the Tensor Storage + tensors = [] + if self.is_source_supported: + smem_C = DataTypeSize[self.element_c] * product(self.epilogue_tile_mn) * self.stages_c // 8 + tensors.append((smem_C, 128)) + else: + tensors.append((0, 1)) + if self.reuse_smem_c: + tensors.append((0, 128)) + else: + smem_D = DataTypeSize[self.element_d] * product(self.epilogue_tile_mn) * self.stages_d // 8 + tensors.append((smem_D, 128)) + tensors.append((thread_smem_size, 128)) + + tensor_smem_size = self.get_struct_size(tensors) + # Get pipeline storage size + # sizeof(uint64_t * stages_c * 2), alignment of uint64_t + # 2 is for FullBarrier and EmptyBarrier + pipeline_smem_size = (8 * self.stages_c * 2, 8) + + # get SharedStorage size + smem_size = self.get_struct_size([tensor_smem_size, pipeline_smem_size]) + return smem_size[0] + + def sm90_epilogue_smem_size(self, tile_description): + """ + Compute the shared memory size of sm90 collective epilogue + """ + self.sm90_epilogue_tile(tile_description) + return self.sm90_or_sm100_epilogue_smem_size(tile_description) + + # + # Sm100 epilogue specific + # + + def sm100_epilogue_tile(self, tile_description): + cta_tile = (tile_description.blackwell_threadblock_shape[0], tile_description.blackwell_threadblock_shape[1]) + mma_tile = cta_tile + + if tile_description.is_2sm: + cta_tile = (cta_tile[0] // 2, cta_tile[1]) + + if tile_description.is_2sm and mma_tile[0] == 128: + tmem_warps = (2, 2) + else: + tmem_warps = (4, 1) + + if self.dag_ir.has_node("C"): + element_c = self.dag_ir.get_node_meta("C").element + element_c_size = DataTypeSize[element_c] + else: + element_c = None + element_c_size = 0 + + element_d = self.dag_ir.get_node_meta("D").element + + DisableSource = element_c is None or not self.dag_ir.has_node("C") or self.dag_ir.get_node_meta("C").element == DataType.void + + CtaM = cta_tile[0] + CtaN = cta_tile[1] + WarpM = tmem_warps[0] + WarpN = tmem_warps[1] + MaxBits = max(element_c_size, DataTypeSize[element_d]) + DpFull = 32 + M = min(CtaM, DpFull * WarpM) + + if DisableSource: + # Epilogues w/o residual load are less sensitive to smem allocation + # Target a fixed amount of compute per epilogue iteration + if MaxBits == 4: + # Make epilogue tile larger to reduce the epilogue iterations. + # 64 is the experimental value. It will minimize epilogue iterations but keep the number of A/B buffers the same. + ComputeElts = 8192 + Nperf = ComputeElts // M + else: + ComputeElts = 4096 + Nperf = ComputeElts // M + else: + # Epilogues w/ residual load are more sensitive to smem allocation + # Target optimal smem distribution between epilogue+mainloop based on datatype+tilesize + if MaxBits == 32: + Nperf = 16 if CtaM > 64 and CtaN <= 128 else 32 + elif MaxBits == 16: + Nperf = 32 if CtaN <= 128 else 64 + else: + Nperf = 64 + + def is_m_major(layout): + return flatten(layout.stride[0]) == 1 + + if DisableSource or is_m_major(self.dag_ir.get_node_meta("C").tensor.layout): + N_min_C = 8 * WarpN + elif element_c_size == 6: + N_min_C = 128 * WarpN + else: + N_min_C = (128 // element_c_size) * WarpN + + if is_m_major(self.dag_ir.get_node_meta("D").tensor.layout): + N_min_D = 8 * WarpN + elif DataTypeSize[element_d] == 6: + N_min_D = 128 * WarpN + else: + N_min_D = (128 // DataTypeSize[element_d]) * WarpN + + N = min(CtaN, max(Nperf, N_min_C, N_min_D)) + + tile_m = M + tile_n_size = N // WarpN * WarpN + + epilogue_tile_mn = (tile_m, tile_n_size) + epi_tiles = product(shape_div(tuple(tile_description.threadblock_shape)[:2], epilogue_tile_mn)) + + stages_d = min(epi_tiles, 2) + reuse_smem_c = (element_c_size > 8) + + if reuse_smem_c: + stages_c = max(min(epi_tiles, 4), stages_d + 1) + else: + stages_c = min(epi_tiles, 4) + + # Record the epilogue tile + self.cta_tile_mnk = tuple(tile_description.threadblock_shape) + self.epilogue_tile_mn = epilogue_tile_mn + self.epi_tiles = epi_tiles + self.stages_c = stages_c + self.stages_d = stages_d + self.reuse_smem_c = reuse_smem_c + self.element_c = element_c + self.element_d = element_d + self.is_source_supported = not DisableSource + + def sm100_epilogue_smem_size(self, tile_description): + """ + Compute the shared memory size of sm100 collective epilogue + """ + self.sm100_epilogue_tile(tile_description) + return self.sm90_or_sm100_epilogue_smem_size(tile_description) + + def __call__(self, tile_description): + return getattr(self, f"sm{self.cc}_epilogue_smem_size")(tile_description) + + # + # Helper functions + # + + @staticmethod + def get_visitor_size(members: list, ebo: bool): + """ + Get the size of struct in bytes + """ + offset = 0 + max_alignment = 1 + if len(members) > 0: + # Get alignment + for _, alignment in members: + max_alignment = max(max_alignment, alignment) + + for type_size, _ in members: + if type_size != 0: + offset = ((offset + max_alignment - 1) // max_alignment) * max_alignment + if type_size == 0 and not ebo: + offset += 1 + else: + offset += type_size + offset = ((offset + max_alignment - 1) // max_alignment) * max_alignment + return (offset, max_alignment) + else: + # Struct size is at least 1 + return (1, 1) + + def get_struct_size(self, members: list): + """ + Get the size of struct in bytes + """ + return self.get_visitor_size(members, False) + + def get_evt_smem_type(self, node): + # Sort the input nodes by edge weight + input_types = [self.smem_types[child] for child in self.dag_ir.get_all_inputs(node)] + input_types.append(self.smem_types[node]) + if len(input_types) > 1: + ebo = len(input_types) > 4 + self.smem_types[node] = self.get_visitor_size(input_types, ebo) + + def get_dag_smem_type(self, node): + meta = self.dag_ir.get_node_meta(node) + subgraph = meta.subgraph + subgraph_nodes = subgraph.nodes_topological_order() + # Visit the unvisited nodes in subgraph + for n in subgraph_nodes: + m = subgraph.get_node_meta(n) + if m.disabled: + continue + else: + self.smem_types[n] = m.underlying_impl.get_smem_size( + self.cta_tile_mnk, self.epilogue_tile_mn, + self.stages_c, self.stages_d, self.epi_tiles) + input_types = [self.smem_types[child] for child in subgraph_nodes[:-1]] + if len(input_types) > 0: + ebo = len(input_types) > 4 + self.smem_types[node] = self.get_visitor_size(input_types, ebo) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/util.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/util.py new file mode 100644 index 0000000000000000000000000000000000000000..4b72e330523ca1e4fb8c5d4526289641e158e72e --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/util.py @@ -0,0 +1,46 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for passes +""" + +# Map from the CC of the kernel to the EVT implementation that the CC targets +cc_map = { + 80: 80, + 86: 80, + 89: 80, + 90: 90, + 100: 100, + 101: 100, + 103: 100, +} diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/frontend.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/frontend.py new file mode 100644 index 0000000000000000000000000000000000000000..a959976b8601b0793c4c7c1709d61c8c838df838 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/frontend.py @@ -0,0 +1,109 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# +from __future__ import annotations + +from cutlass_cppgen.utils.lazy_import import lazy_import +cuda = lazy_import("cuda.cuda") +import numpy as np + +from cutlass_cppgen.backend.memory_manager import device_mem_alloc, todevice +from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor + + +class NumpyFrontend: + """ + Frontend node for numpy + """ + + @staticmethod + def argument(np_tensor: "np.ndarray", is_output: "bool") -> cuda.CUdeviceptr: + """Convert the input numpy tensor to CUDA device pointer + + :param np_tensor: input numpy nd array + :param is_output: whether the tensor is output + + :return: CUDA device pointer + """ + # copy the data to device + if is_output: + return device_mem_alloc(np_tensor.size * np_tensor.itemsize) + else: + return todevice(np_tensor) + + +class TorchFrontend: + """ + Frontend node for torch + """ + + @staticmethod + def argument(torch_tensor: "torch.Tensor") -> cuda.CUdeviceptr: + """Convert the input torch tensor to CUDA device pointer + + :param torch_tensor: input torch tensor + :param is_output: whether the tensor is output + + :return: CUDA device pointer + """ + + # check the device of torch_tensor + if not torch_tensor.is_cuda: + torch_tensor = torch_tensor.to("cuda") + + return cuda.CUdeviceptr(torch_tensor.data_ptr()) + + +class CupyFrontend: + """ + Frontend node for cupy + """ + + @staticmethod + def argument(cupy_ndarray: "cp.ndarray"): + return cuda.CUdeviceptr(int(cupy_ndarray.data.ptr)) + + +class TensorFrontend: + """ + Universal Frontend for client-provide tensors + """ + + @staticmethod + def argument(tensor, is_output=False): + if is_numpy_tensor(tensor): + return NumpyFrontend.argument(tensor, is_output) + elif is_torch_tensor(tensor): + return TorchFrontend.argument(tensor) + elif is_cupy_tensor(tensor): + return CupyFrontend.argument(tensor) + else: + raise NotImplementedError("Unknown Tensor Type") diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/gemm_operation.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/gemm_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..5e2a3a30a097eb45c691554daf70f8db12e5bc48 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/gemm_operation.py @@ -0,0 +1,2145 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# +from __future__ import annotations + +import copy +import ctypes +import enum + +from cutlass_cppgen.utils.lazy_import import lazy_import +cuda = lazy_import("cuda.cuda") +cudart = lazy_import("cuda.cudart") +from cutlass_library import SubstituteTemplate +import numpy as np + +from cutlass_library import ( + ComplexTransformTag, + DataType, + DataTypeNames, + DataTypeSize, + DataTypeTag, + EpilogueScheduleSuffixes, + EpilogueScheduleTag, + EpilogueScheduleType, + GemmKind, + GemmKindNames, + GemmUniversalMode, + KernelScheduleSuffixes, + KernelScheduleTag, + KernelScheduleType, + LayoutTag, + LayoutType, + MathOperation, + MathOperationTag, + OpcodeClass, + OpcodeClassNames, + OpcodeClassTag, + OperationKind, + ShortComplexLayoutNames, + ShortDataTypeNames, + ShortLayoutTypeNames, + SwizzlingFunctor, + SwizzlingFunctorTag, + TileSchedulerSuffixes, + TileSchedulerTag, + TileSchedulerType, + get_complex_from_real +) +from cutlass_cppgen.backend.arguments import ArgumentBase +from cutlass_cppgen.backend.c_types import ( + GemmCoord_, + GemmCoordBatched_, + GenericMainloopArguments3x_, + StrideBatched_, + dim3_, + get_gemm_arguments, + get_gemm_arguments_3x, + get_gemm_arguments_streamk, + get_gemm_grouped_arguments, + get_mainloop_arguments_3x, + get_tile_scheduler_arguments_3x, +) +from cutlass_cppgen.backend.library import ( + ApiVersion, + EmissionType, + SchedulerMode, + SchedulerModeTag, + TensorDescription, + TileDescription, + api_version, +) +from cutlass_cppgen.backend.memory_manager import device_mem_alloc, todevice +from cutlass_cppgen.backend.operation import ExecutableOperation, LaunchConfiguration +from cutlass_cppgen.backend.type_hint import GemmOperation, Tensor +from cutlass_cppgen.backend.utils.device import device_sm_count +from cutlass_cppgen.shape import GemmCoord, MatrixCoord + + +################################################################################ +# +# Data structure modeling a GEMM operation +# +################################################################################ + + +def leading_dimension(layout: LayoutType, shape: MatrixCoord) -> int: + """ + Returns the leading dimenson of a tensor with layout ``layout`` and shape ``shape``. + + :param layout: layout of the tensor + :type layout: cutlass_cppgen.shape.LayoutType + :param shape: shape of the tensor + :type shape: cutlass_cppgen.shape.MatrixCoord + + :return: leading dimension of the tensor + :rtype: int + """ + if layout == LayoutType.RowMajor: + return shape.column + elif layout == LayoutType.ColumnMajor: + return shape.row + + +def transpose_layout(layout: LayoutType) -> LayoutType: + if layout == LayoutType.ColumnMajor: + return LayoutType.RowMajor + elif layout == LayoutType.RowMajor: + return LayoutType.ColumnMajor + else: + raise ValueError(f"Unsupported Layout {layout}") + + +class GemmArguments2x(ArgumentBase): + """ + Argument wrapper for GEMM in CUTLASS 2. It encodes problem information and + user-provide tensors into the kernel's argument + + :param operation: the GEMM operation to take the argument + :type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` | + :class:`cutlass_cppgen.backend.GemmOperationGrouped` + + :param problem_size: GEMM problem size gemm(M, N, K) + :type operation: :class:`cutlass_cppgen.shape.GemmCoord` + + :param A: tensor A + :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param B: tensor B + :type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param C: tensor C + :type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param D: tensor D + :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param gemm_mode: GEMM mode + :type gemm_mode: :class:`cutlass_library.GemmUniversalMode` + + :param output_op: output operator, optional + :type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments` + + :param stream: cuda stream, defaults to cuda.cuda.CUstream(0) + :type stream: :class:`cuda.cuda.CUstream` + """ + + def __init__(self, operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMode.Gemm, **kwargs): + self.operation = operation + + self.layout_A = operation.A.layout + self.layout_B = operation.B.layout + self.layout_C = operation.C.layout + + self.element_A = operation.A.element + self.element_B = operation.B.element + self.element_C = operation.C.element + + if operation.C.layout in [LayoutType.RowMajorInterleaved32, LayoutType.ColumnMajorInterleaved32]: + raise Exception("Interleaved layout not currently supported") + + if hasattr(self.operation.epilogue_functor, "visitor") and operation.arch not in [90, 100, 101, 103]: + super().__init__(A, B, None, None, **kwargs) + else: + super().__init__(A, B, C, D, **kwargs) + + if operation.switched: + self.problem_size = GemmCoord(problem_size.n, problem_size.m, problem_size.k) + self.ptr_A, self.ptr_B = self.ptr_B, self.ptr_A + else: + self.problem_size = problem_size + # If the number of elements in C = problem_size.n, C is treated as the bias + if hasattr(self, "tensor_c_numel"): + if self.tensor_c_numel == self.problem_size.n and self.problem_size.m != 1: + self.bias = True + + self.lda = leading_dimension(self.layout_A, self.problem_size.mk) + self.ldb = leading_dimension(self.layout_B, self.problem_size.kn) + self.ldc = leading_dimension(self.layout_C, self.problem_size.mn) + self.ldd = self.ldc + + if self.bias: + self.ldc = 0 + + if "output_op" in kwargs.keys() and gemm_mode != GemmUniversalMode.GemmSplitKParallel: + self.output_op = kwargs["output_op"] + else: + if self.operation.epilogue_functor.element_epilogue in [DataType.s8, DataType.s32, DataType.u8, DataType.u32]: + dtype = int + else: + dtype = float + self.output_op = self.operation.epilogue_type(dtype(1.0), dtype(0.0)) + + self.gemm_mode = gemm_mode + if gemm_mode in [GemmUniversalMode.Gemm, GemmUniversalMode.GemmSplitKParallel]: + if "split_k_slices" in kwargs.keys(): + self.batch_count = kwargs["split_k_slices"] + else: + self.batch_count = 1 + self.split_k_slices = self.batch_count + + if gemm_mode in [GemmUniversalMode.Batched, GemmUniversalMode.Array]: + if "batch" in kwargs.keys(): + self.batch_count = kwargs["batch"] + else: + self.batch_count = 1 + + if "batch_strides" in kwargs: + self.batched_stride_A = kwargs["batch_strides"]["A"] + self.batched_stride_B = kwargs["batch_strides"]["B"] + self.batched_stride_C = kwargs["batch_strides"]["C"] + self.batched_stride_D = kwargs["batch_strides"]["D"] + else: + self.batched_stride_A = self.problem_size.m * self.problem_size.k + self.batched_stride_B = self.problem_size.n * self.problem_size.k + self.batched_stride_C = self.problem_size.m * self.problem_size.n + self.batched_stride_D = self.problem_size.m * self.problem_size.n + + if self.bias: + self.batched_stride_C = self.problem_size.n + + if gemm_mode == GemmUniversalMode.Array: + self.ptr_A_array = [] + self.ptr_B_array = [] + self.ptr_C_array = [] + self.ptr_D_array = [] + + ptr_A_addr = int(self.ptr_A) + ptr_B_addr = int(self.ptr_B) + ptr_C_addr = int(self.ptr_C) + ptr_D_addr = int(self.ptr_D) + + stride_A = self.batched_stride_A * DataTypeSize[self.element_A] // 8 + stride_B = self.batched_stride_B * DataTypeSize[self.element_B] // 8 + stride_C = self.batched_stride_C * DataTypeSize[self.element_C] // 8 + stride_D = self.batched_stride_D * DataTypeSize[self.element_C] // 8 + for _ in range(self.batch_count): + self.ptr_A_array.append(ptr_A_addr) + self.ptr_B_array.append(ptr_B_addr) + self.ptr_C_array.append(ptr_C_addr) + self.ptr_D_array.append(ptr_D_addr) + + ptr_A_addr += stride_A + ptr_B_addr += stride_B + ptr_C_addr += stride_C + ptr_D_addr += stride_D + + self.ptr_A_array_buffer = todevice(self.ptr_A_array, dtype=np.int64) + self.ptr_B_array_buffer = todevice(self.ptr_B_array, dtype=np.int64) + self.ptr_C_array_buffer = todevice(self.ptr_C_array, dtype=np.int64) + self.ptr_D_array_buffer = todevice(self.ptr_D_array, dtype=np.int64) + + if isinstance(self.operation, GemmOperationUniversal): + self.initialize() + + def get_arguments(self): + problem_size_ = self.problem_size.ctype + grid_tiled_shape_ = GemmCoord( + self.grid_tiled_shape.x, + self.grid_tiled_shape.y, + self.grid_tiled_shape.z ).ctype + + if self.gemm_mode == GemmUniversalMode.Array: + arguments = self.operation.argument_type( + # Arguments from UniversalArgumentsBase + self.gemm_mode, + problem_size_, + self.batch_count, + 0, + # Remaining arguments + self.output_op, + int(self.ptr_A_array_buffer.ptr), + int(self.ptr_B_array_buffer.ptr), + int(self.ptr_C_array_buffer.ptr), + int(self.ptr_D_array_buffer.ptr), + 0, 0, 0, + self.lda, self.ldb, self.ldc, self.ldd, + self.lda, self.ldb, self.ldc, self.ldd, + 0, 0, 0 + ) + else: + arguments = self.operation.argument_type( + # Arguments from UniversalArgumentsBase + self.gemm_mode, problem_size_, self.batch_count, self.batched_stride_D, + # Remaining arguments + self.output_op, + int(self.ptr_A), + int(self.ptr_B), + int(self.ptr_C), + int(self.ptr_D), + self.batched_stride_A, + self.batched_stride_B, + self.batched_stride_C, + self.lda, self.ldb, self.ldc, self.ldd, + self.lda, self.ldb, self.ldc, self.ldd, + 0, 0, 0 + ) + + self.arguments = arguments, grid_tiled_shape_, self.gemm_k_size + + def initialize(self): + launch_config = self.operation.rt_module.plan(self) + + # Get the host and device workspace + device_workspace_size = self.operation.rt_module.get_device_workspace_size(self) + + if device_workspace_size > 0: + self.workspace_buffer = device_mem_alloc(device_workspace_size) + workspace_ptr = self.workspace_buffer.ptr + err, = cuda.cuMemsetD32( + workspace_ptr, 0, device_workspace_size // 4) + else: + workspace_ptr = None + + device_workspace = 0 + if workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.GemmSplitKParallel: + # In GEMM splik-K parallel, the D pointer is redirected to the workspace + self.ptr_D = cuda.CUdeviceptr(workspace_ptr) + elif workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.Gemm: + device_workspace = workspace_ptr + + self.get_arguments() + + arguments, grid_tiled_shape, gemm_k_size = self.arguments + res_arg = self.operation.rt_module.get_args( + ctypes.byref(arguments), ctypes.c_void_p(int(device_workspace))) + host_workspace = bytearray(res_arg.contents) + + device_workspace = None + + self.host_workspace = host_workspace + self.device_workspace = device_workspace + self.launch_config = launch_config + + def sync(self, stream_sync=True): + super().sync(stream_sync) + if hasattr(self.output_op, "sync"): + self.output_op.sync() + + +class GemmArguments2xStreamK(GemmArguments2x): + """ + Argument wrapper for stream-K GEMMs in CUTLASS 2. It encodes problem information and + user-provide tensors into the kernel's argument + + :param operation: the GEMM operation to take the argument + :type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` | + :class:`cutlass_cppgen.backend.GemmOperationGrouped` + + :param problem_size: GEMM problem size gemm(M, N, K) + :type operation: :class:`cutlass_cppgen.shape.GemmCoord` + + :param A: tensor A + :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param B: tensor B + :type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param C: tensor C + :type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param D: tensor D + :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param gemm_mode: GEMM mode + :type gemm_mode: :class:`cutlass_library.GemmUniversalMode` + + :param output_op: output operator, optional + :type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments` + """ + + def __init__(self, operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMode.Gemm, **kwargs): + if gemm_mode not in [GemmUniversalMode.Gemm, GemmUniversalMode.Batched]: + raise Exception(f"Unsupported GEMM mode {gemm_mode}.") + + super().__init__(operation, problem_size, A, B, C, D, gemm_mode, **kwargs) + + def get_arguments(self): + batch_stride_A = self.problem_size.m * self.problem_size.k + batch_stride_B = self.problem_size.k * self.problem_size.n + batch_stride_C = self.problem_size.m * self.problem_size.n + batch_stride_D = self.problem_size.m * self.problem_size.n + + arguments = self.operation.argument_type( + self.gemm_mode, + GemmCoord_(self.problem_size.m, self.problem_size.n, self.problem_size.k), + self.batch_count, + self.output_op, + int(self.ptr_A), + int(self.ptr_B), + int(self.ptr_C), + int(self.ptr_D), + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_stride_D, + self.lda, self.ldb, self.ldc, self.ldd, # strides + self.lda, self.ldb, self.ldc, self.ldd, + -1, # avail_sms + ) + return arguments + + def initialize(self): + # Get the host and device workspace + device_workspace_size = self.operation.rt_module.get_device_workspace_size( + self, + device_sm_count(), + self.operation.rt_module.occupancy + ) + + if device_workspace_size > 0: + self.workspace_buffer = device_mem_alloc(device_workspace_size) + workspace_ptr = self.workspace_buffer.ptr + err, = cuda.cuMemsetD32( + workspace_ptr, 0, device_workspace_size // 4) + else: + workspace_ptr = None + + device_workspace = 0 + if workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.GemmSplitKParallel: + # In GEMM splik-K parallel, the D pointer is redirected to the workspace + self.ptr_D = cuda.CUdeviceptr(workspace_ptr) + elif workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.Gemm: + device_workspace = workspace_ptr + + arguments = self.get_arguments() + + res_arg = self.operation.rt_module.get_args( + ctypes.byref(arguments), + ctypes.c_void_p(int(device_workspace)), + device_sm_count(), + self.operation.rt_module.occupancy + ) + host_workspace = bytearray(res_arg.contents) + + grid = self.operation.rt_module.get_grid_shape( + ctypes.byref(arguments), + device_sm_count(), + self.operation.rt_module.occupancy + ) + + device_workspace = None + + self.host_workspace = host_workspace + self.device_workspace = device_workspace + self.launch_config = LaunchConfiguration( + [grid.m, grid.n, grid.k], + [self.operation.rt_module.threads, 1, 1], + self.operation.rt_module.shared_memory_capacity + ) + + +class GemmArguments3x(GemmArguments2x): + """ + Argument wrapper for GEMM in CUTLASS 3. It encodes problem information and + user-provide tensors into the kernel's argument + + :param operation: the GEMM operation to take the argument + :type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` | + :class:`cutlass_cppgen.backend.GemmOperationGrouped` + + :param problem_size: GEMM problem size gemm(M, N, K) + :type operation: :class:`cutlass_cppgen.shape.GemmCoord` + + :param A: tensor A + :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param B: tensor B + :type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param C: tensor C + :type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param D: tensor D + :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param gemm_mode: GEMM mode + :type gemm_mode: GemmUniversalMode + + :param output_op: output operator, optional + :type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments` + """ + + def __init__(self, operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMode.Gemm, **kwargs): + if gemm_mode not in [GemmUniversalMode.Gemm, GemmUniversalMode.Batched]: + raise Exception(f"Unsupported GEMM mode {gemm_mode}.") + + super().__init__(operation, problem_size, A, B, C, D, gemm_mode, **kwargs) + + def get_arguments(self): + mainloop_args = get_mainloop_arguments_3x( + self.operation.tile_description.kernel_schedule, + self.operation.A.element, + self.operation.B.element, + self.operation.A.alignment, + self.operation.B.alignment + ) + scheduler_args = get_tile_scheduler_arguments_3x(self.operation.tile_description.tile_scheduler) + uses_default_epilogue = self.operation.rt_module.uses_default_epilogue() + argument_type, epilogue_args, epilogue_type, hw_info = get_gemm_arguments_3x( + mainloop_args, self.operation.epilogue_functor, scheduler_args, uses_default_epilogue) + + problem_size_ = GemmCoordBatched_(self.problem_size, self.batch_count) + + if self.batch_count > 1: + bsA = self.batched_stride_A + bsB = self.batched_stride_B + bsC = self.batched_stride_C + bsD = self.batched_stride_D + else: + bsA = 0 + bsB = 0 + bsC = 0 + bsD = 0 + stride_A = StrideBatched_(self.lda, bsA) + stride_B = StrideBatched_(self.ldb, bsB) + stride_C = StrideBatched_(self.ldc, bsC) + stride_D = StrideBatched_(self.ldd, bsD) + + # Superset of potential mainloop arguments + generic_args = GenericMainloopArguments3x_( + int(self.ptr_A), + stride_A, + int(self.ptr_B), + stride_B, + 4 # mma_promotion_interval + ) + + # Set of mainloop arguments needed for this kernel + mainloop = mainloop_args.from_generic_mainloop_args(generic_args) + + if not uses_default_epilogue and hasattr(self.output_op, "to_evt_params"): + self.output_op = self.output_op.to_evt_params() + + epilogue = epilogue_args( + self.output_op, + int(self.ptr_C), + stride_C, + int(self.ptr_D), + stride_D, + ) + + # Set hardware info + hw_info_ = hw_info( + 0, device_sm_count(), 0, + dim3_(0,0,0), + dim3_(0,0,0), + ) + + self.arguments = argument_type( + int(self.gemm_mode), + problem_size_, + mainloop, + epilogue, + hw_info_, + scheduler_args + ) + return self.arguments + + def initialize(self): + # Get the host and evice workspace + device_workspace_size = self.operation.rt_module.get_device_workspace_size(self) + + if device_workspace_size > 0: + self.workspace_buffer = device_mem_alloc(device_workspace_size) + workspace_ptr = self.workspace_buffer.ptr + err, = cuda.cuMemsetD32( + workspace_ptr, 0, device_workspace_size // 4) + else: + workspace_ptr = None + + device_workspace = 0 + if workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.GemmSplitKParallel: + # In GEMM splik-K parallel, the D pointer is redirected to the workspace + self.ptr_D = cuda.CUdeviceptr(workspace_ptr) + elif workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.Gemm: + device_workspace = workspace_ptr + + self.get_arguments() + res_arg = self.operation.rt_module.get_args( + ctypes.byref(self.arguments), + ctypes.c_void_p(int(device_workspace)), + ) + host_workspace = bytearray(res_arg.contents) + + grid = self.operation.rt_module.get_grid_shape( + ctypes.byref(self.arguments), + ctypes.c_void_p(int(device_workspace)), + ) + block = self.operation.rt_module.get_block_shape() + + device_workspace = None + + self.host_workspace = host_workspace + self.device_workspace = device_workspace + self.launch_config = LaunchConfiguration( + [grid.x, grid.y, grid.z], + [block.x, block.y, block.z], + self.operation.rt_module.shared_memory_capacity, + ) + + +def GemmArguments(operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMode.Gemm, **kwargs): + """ + Argument wrapper for GEMM in CUTLASS 2 or 3. It returns either 2x arguments + or 3x arguments depending on the `arch` field specified in `operation`. + + :param operation: the GEMM operation to take the argument + :type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` | + :class:`cutlass_cppgen.backend.GemmOperationGrouped` + + :param problem_size: GEMM problem size gemm(M, N, K) + :type operation: :class:`cutlass_cppgen.shape.GemmCoord` + + :param A: tensor A + :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param B: tensor B + :type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param C: tensor C + :type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param D: tensor D + :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param gemm_mode: GEMM mode + :type gemm_mode: :class:`cutlass_library.GemmUniversalMode` + + :param output_op: output operator, optional + :type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments` + """ + if operation.swizzling_functor == SwizzlingFunctor.StreamK: + if operation.api == ApiVersion.v3x: + raise Exception("Stream K is currently only supported in CUTLASS 2.x") + ArgClass = GemmArguments2xStreamK + else: + ArgClass = GemmArguments3x if operation.api == ApiVersion.v3x else GemmArguments2x + return ArgClass(operation, problem_size, A, B, C, D, gemm_mode, **kwargs) + + +class GemmGroupedArguments: + """ + Argument wrapper for GEMM Grouped. It encodes problem information and + user-provide tensors into the kernel's argument + + :param operation: the GEMM Grouped operation to take the argument + :type operation: :class:`cutlass_cppgen.backend.GemmOperationGrouped` + + :param problem_size: list of GEMM problem size gemm(M, N, K) + :type operation: list[:class:`cutlass_cppgen.shape.GemmCoord`] + + :param A: list of tensor A + :type A: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray] + + :param B: list of tensor B + :type B: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray] + + :param C: list of tensor C + :type C: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray] + + :param D: list of tensor D + :type D: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray] + + :param output_op: output operator, optional + :type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments` + + :param stream: cuda stream, defaults to cuda.cuda.CUstream(0) + :type stream: :class:`cuda.cuda.CUstream` + """ + + def __init__(self, operation, problem_sizes, A, B, C, D, **kwargs): + # Get number of problems in the group + self.problem_count = len(problem_sizes) + + # Check the input arguments + assert len(A) == self.problem_count + assert len(B) == self.problem_count + assert len(C) == self.problem_count + assert len(D) == self.problem_count + + problem_size_host = [] + self.ptr_A_host = [] + self.ptr_B_host = [] + self.ptr_C_host = [] + self.ptr_D_host = [] + + lda_host = [] + ldb_host = [] + ldc_host = [] + ldd_host = [] + + self.partitions = 1 + + self.operation = operation + + # Get the threadblock + threadblock_shape = operation.tile_description.threadblock_shape + self.threadblock_shape = GemmCoord( + threadblock_shape[0], + threadblock_shape[1], + threadblock_shape[2], + ) + self.threadblock_swizzle = operation.swizzling_functor + + self.total_tiles = 0 + + self.gemm_arguments = [] + + self.stream = kwargs.get("stream", cuda.CUstream(0)) + + # Process the input arguments + for idx, problem_size in enumerate(problem_sizes): + M, N, K = problem_size.m, problem_size.n, problem_size.k + temp_argument = GemmArguments2x( + operation=operation, + problem_size=GemmCoord(M, N, K), + A=A[idx], B=B[idx], C=C[idx], D=D[idx]) + self.gemm_arguments.append(temp_argument) + + problem_size_host.append( + [temp_argument.problem_size.m, + temp_argument.problem_size.n, + temp_argument.problem_size.k] + ) + + self.ptr_A_host.append(int(temp_argument.ptr_A)) + lda_host.append(temp_argument.lda) + + self.ptr_B_host.append(int(temp_argument.ptr_B)) + ldb_host.append(temp_argument.ldb) + + self.ptr_C_host.append(int(temp_argument.ptr_C)) + ldc_host.append(temp_argument.ldc) + + self.ptr_D_host.append(int(temp_argument.ptr_D)) + ldd_host.append(temp_argument.ldd) + + # Get number of tiles + grid = self.operation.rt_module.get_grid_shape( + self.operation.rt_module.get_tiled_shape( + temp_argument.problem_size.ctype, + self.threadblock_shape.ctype, + temp_argument.batch_count + ) + ) + self.total_tiles += grid.x * grid.y * grid.z + + self.problem_size_buffer = todevice(problem_size_host, np.int32) + self.ptr_A_buffer = todevice(self.ptr_A_host, np.int64) + self.ptr_B_buffer = todevice(self.ptr_B_host, np.int64) + self.ptr_C_buffer = todevice(self.ptr_C_host, np.int64) + self.ptr_D_buffer = todevice(self.ptr_D_host, np.int64) + + self.lda_buffer = todevice(lda_host, np.int64) + self.ldb_buffer = todevice(ldb_host, np.int64) + self.ldc_buffer = todevice(ldc_host, np.int64) + self.ldd_buffer = todevice(ldd_host, np.int64) + + if "output_op" in kwargs.keys(): + self.alpha = kwargs["output_op"].alpha + self.beta = kwargs["output_op"].beta + else: + self.alpha = 1.0 + self.beta = 0.0 + + if "output_op" in kwargs.keys(): + self.output_op = kwargs["output_op"] + else: + self.output_op = self.operation.epilogue_type(1.0, 0.0) + + # Get host problem size + self.host_problem_size_ptr = np.array(problem_size_host, dtype=np.int32).__array_interface__["data"][0] + + self.arguments = self.get_arguments() + + self.initialize() + + def get_arguments(self): + return self.operation.argument_type( + self.problem_size_buffer.ptr, + self.problem_count, + self.total_tiles, + self.output_op, + self.ptr_A_buffer.ptr, + self.ptr_B_buffer.ptr, + self.ptr_C_buffer.ptr, + self.ptr_D_buffer.ptr, + self.lda_buffer.ptr, + self.ldb_buffer.ptr, + self.ldc_buffer.ptr, + self.ldd_buffer.ptr, + ctypes.c_void_p(int(self.host_problem_size_ptr)), + ) + + def initialize(self): + # Get launch configuration + launch_config = self.operation.rt_module.plan(self) + + # Get the host and evice workspace + device_workspace_size = self.operation.rt_module.get_device_workspace_size(self) + + if device_workspace_size > 0: + self.workspace_buffer = device_mem_alloc(device_workspace_size) + workspace_ptr = self.workspace_buffer.ptr + err, = cuda.cuMemsetD32( + workspace_ptr, 0, device_workspace_size // 4) + else: + workspace_ptr = None + + if self.operation.precompute_mode == SchedulerMode.Host: + device_workspace_ptr = self.operation.rt_module.host_precompute( + self, self.operation.rt_module.get_workspace_size(self),) + else: + device_workspace_ptr = 0 + + result = self.operation.rt_module.get_args( + ctypes.byref(self.arguments), + self.total_tiles, + ctypes.c_void_p(int(device_workspace_ptr)), + ) + host_workspace = bytearray(result.contents) + + device_workspace = None + + self.host_workspace = host_workspace + self.device_workspace = device_workspace + self.launch_config = launch_config + + def sync(self): + err, = cudart.cudaDeviceSynchronize() + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError("CUDA Error %s" % str(err)) + for arg in self.gemm_arguments: + arg.sync(stream_sync=False) + + +################################################################################ +# Base class for GEMM runtime module +################################################################################ + + +class GemmRTbase(ExecutableOperation): + """ + GemmRT manages the CUTLASS runtime components + """ + + KernelTemplate = r""" +extern "C" +__global__ void +${operation_name}(${operation_name}${operation_suffix}::Params params) { + + // Dynamic shared memory base pointer + extern __shared__ int SharedStorageBase[]; + + // Declare pointer to dynamic shared memory. + ${operation_name}${operation_suffix}::SharedStorage *shared_storage = + reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase); + + ${operation_name}${operation_suffix}::invoke(params, *shared_storage); +} + """ + + def __init__(self, operation: "GemmOperation"): + super().__init__(operation) + + self.operation = operation + threadblock_shape = operation.tile_description.threadblock_shape + self.threadblock_shape = GemmCoord( + threadblock_shape[0], threadblock_shape[1], threadblock_shape[2]) + self.threadblock_swizzle = operation.swizzling_functor + + # Threads per threadblock + self.threads = operation.tile_description.num_threads + + def emit(self): + return self.emitter.emit(self.operation) + + def can_implement(self, configuration, arguments): + raise NotImplementedError() + + def get_host_workspace_size(self, arguments): + raise NotImplementedError() + + def get_device_workspace_size(self, arguments): + return 0 + + def initialize(self): + err, = cuda.cuFuncSetAttribute( + self.kernel, + attrib=cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + value=self.shared_memory_capacity) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError( + f"CUDA error on call to cuFuncSetAttribute: {cuda.cuGetErrorString(err)[1]}" + ) + + +################################################################################ +# Runtime module for GEMM Universal +################################################################################ + + +class GemmRTUniversal(GemmRTbase): + """ + GemmRTUniversal manages the CUTLASS runtime components + """ + + HostTemplate = r""" +extern "C" { + // Get the size of params in bytes + int ${operation_name}_get_param_size(){ + return sizeof(${operation_name}${operation_suffix}::Params); + } + + // Get the size of dynamic shared memory in bytes + int ${operation_name}_shared_memory_size() { + return int(sizeof(${operation_name}${operation_suffix}::SharedStorage)); + } + + // Get the params as byte array + char* ${operation_name}_get_params(${operation_name}_base::Arguments* argument, int* workspace){ + ${operation_name}_base::Params* params; + params = new ${operation_name}_base::Params(*argument, + -1, // SM count. Only used for stream-K + -1 // Occupancy. Only used for stream-K + ); + + // Semaphore holds the pointer to the workspace in the Params struct + params->semaphore = workspace; + + char *bytes = ((char*)(params)); + char *output = new char[sizeof(${operation_name}_base::Params)]; + for (unsigned int i = 0; i < sizeof(${operation_name}_base::Params); i ++) + output[i] = bytes[i]; + + return output; + } + + cutlass::gemm::GemmCoord ${operation_name}_get_tiled_shape( + cutlass::gemm::GemmCoord problem_size, cutlass::gemm::GemmCoord tile_size, int split_k_slices) { + return ${operation_name}_base::ThreadblockSwizzle::get_tiled_shape( + problem_size, tile_size, split_k_slices); + } + + dim3 ${operation_name}_get_grid_shape(cutlass::gemm::GemmCoord tiled_shape) { + return ${operation_name}_base::ThreadblockSwizzle::get_grid_shape(tiled_shape); + } +} + """ + + def __init__(self, operation): + super(GemmRTUniversal, self).__init__(operation) + self.extra_funcs = { + "get_tiled_shape": GemmCoord_, + "get_grid_shape": dim3_, + } + self.emitter = EmitGemmUniversalInstance( + "_type", operation.direct_store) + + self.argument_type, self.epilogue_type = get_gemm_arguments(operation.epilogue_functor) + self.argtype = [ + ctypes.POINTER(self.argument_type), + ctypes.POINTER(GemmCoord_), ctypes.c_int, ctypes.c_void_p + ] + + def plan(self, arguments): + grid = self.get_tiled_shape( + arguments.problem_size.ctype, + self.threadblock_shape.ctype, + arguments.batch_count + ) + + gemm_k_size = arguments.problem_size.k + if arguments.gemm_mode in [GemmUniversalMode.Gemm, GemmUniversalMode.GemmSplitKParallel]: + alignk = max(max(128 // DataTypeSize[self.operation.A.element], + 128 // DataTypeSize[self.operation.B.element]), 1) + + gemm_k_size = (((arguments.problem_size.k + arguments.batch_count - 1) // + arguments.batch_count + alignk - 1) // alignk) * alignk + + if gemm_k_size: + grid_z = (arguments.problem_size.k + gemm_k_size - 1) // gemm_k_size + grid = GemmCoord(grid.m, grid.n, grid_z).ctype + + arguments.grid_tiled_shape = dim3_(grid.m, grid.n, grid.k) + grid = self.get_grid_shape(grid) + arguments.gemm_k_size = gemm_k_size + return LaunchConfiguration( + [grid.x, grid.y, grid.z], + [self.threads, 1, 1], + self.shared_memory_capacity) + + def get_device_workspace_size(self, arguments: GemmArguments): + workspace_bytes = 0 + if arguments.gemm_mode == GemmUniversalMode.GemmSplitKParallel: + workspace_bytes = (DataTypeSize[arguments.operation.C.element] + * arguments.batched_stride_D * arguments.grid_tiled_shape.z // 8) + elif (arguments.gemm_mode == GemmUniversalMode.Gemm and + arguments.split_k_slices > 1): + workspace_bytes = 4 * arguments.grid_tiled_shape.x * arguments.grid_tiled_shape.y + + return workspace_bytes + + +class GemmRTUniversalStreamK(GemmRTUniversal): + """ + Manages the CUTLASS runtime components for 2.x stream K kernels + """ + + HostTemplate = r""" +extern "C" { + // Get the size of params in bytes + int ${operation_name}_get_param_size(){ + return sizeof(${operation_name}${operation_suffix}::Params); + } + + // Get the size of dynamic shared memory in bytes + int ${operation_name}_shared_memory_size() { + return int(sizeof(${operation_name}${operation_suffix}::SharedStorage)); + } + + using GemmType = ${operation_name}_base; + + // Get the params as byte array + char* ${operation_name}_get_params(GemmType::Arguments* argument, int* workspace, + int sm_count, int occupancy) { + GemmType::Params* params; + params = new GemmType::Params(*argument, sm_count, occupancy); + + params->init_workspace(workspace); + + char *bytes = ((char*)(params)); + char *output = new char[sizeof(GemmType::Params)]; + for (unsigned int i = 0; i < sizeof(GemmType::Params); i ++) + output[i] = bytes[i]; + + return output; + } + + dim3 ${operation_name}_get_grid_shape(GemmType::Arguments* args, int device_sms, int sm_occupancy) { + typename GemmType::Params params(*args, device_sms, sm_occupancy); + return params.get_grid_dims(); + } + + uint64_t ${operation_name}_get_kernel_workspace_size(GemmType::Arguments* args, int device_sms, int sm_occupancy) { + typename GemmType::Params params(*args, device_sms, sm_occupancy); + return params.get_workspace_size(); + } +} + """ + + def __init__(self, operation: "GemmOperation"): + super(GemmRTUniversalStreamK, self).__init__(operation) + self.extra_funcs = { + "get_grid_shape": GemmCoord_, + "get_kernel_workspace_size": ctypes.c_uint64, + } + self._occupancy = None + self.argument_type, self.epilogue_type = get_gemm_arguments_streamk(operation.epilogue_functor) + + @property + def occupancy(self): + if self._occupancy is None: + err, self._occupancy = cuda.cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags( + self.kernel, self.threads, self.shared_memory_capacity, + cuda.CUoccupancy_flags.CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE) + + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError( + "CUDA error on call to cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags: " + f"{cuda.cuGetErrorString(err)[1]}") + return self._occupancy + + def get_device_workspace_size(self, arguments: GemmArguments2xStreamK, device_sms: int, sm_occupancy: int): + return self.get_kernel_workspace_size(ctypes.byref(arguments.get_arguments()), device_sms, sm_occupancy) + + +################################################################################ +# Runtime module for GEMM Universal within CUTLASS 3 +################################################################################ + + +class GemmRTUniversal3x(GemmRTUniversal): + """ + Manages the CUTLASS runtime components for 3.x kernels + """ + + KernelTemplate = r""" + +using Operator = ${operation_name}${operation_suffix}; +extern "C" +__global__ __launch_bounds__(Operator::MaxThreadsPerBlock, Operator::MinBlocksPerMultiprocessor) +void ${operation_name}(__grid_constant__ typename Operator::Params const params) { + // Dynamic shared memory base pointer + extern __shared__ char smem[]; + + // Declare pointer to dynamic shared memory. + Operator op; + op(params, smem); +} + """ + HostTemplate = r""" +extern "C" { + // Get the size of params in bytes + int ${operation_name}_get_param_size(){ + return sizeof(${operation_name}${operation_suffix}::Params); + } + + // Get the size of dynamic shared memory in bytes + int ${operation_name}_shared_memory_size() { + return ${operation_name}${operation_suffix}::SharedStorageSize; + } + + using GemmType = ${operation_name}_base; + + bool ${operation_name}_uses_default_epilogue() { + return std::is_same_v; + } + + // Get the workspace size + uint64_t ${operation_name}_get_kernel_workspace_size(GemmType::Arguments* argument) { + return GemmType::get_workspace_size(*argument); + } + + // Get the params as byte array + char* ${operation_name}_get_params(GemmType::Arguments* argument, int* workspace){ + GemmType::Params params = GemmType::to_underlying_arguments(*argument, workspace); + char *bytes = ((char*)(¶ms)); + char *output = new char[sizeof(GemmType::Params)]; + for (unsigned int i = 0; i < sizeof(GemmType::Params); i ++) + output[i] = bytes[i]; + + return output; + } + + // Get the total number of blocks for a persistent kernel + uint64_t ${operation_name}_get_persistent_tiled_blk_shape_mnl(GemmType::ProblemShape problem) { + auto problem_shape_MNKL = append<4>(problem, Int<1>{}); + auto [problem_blocks_m, problem_blocks_n, problem_blocks_l] = + cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::get_tiled_cta_shape_mnl( + problem_shape_MNKL, GemmType::TileShape{}, GemmType::DispatchPolicy::ClusterShape{}); + return problem_blocks_m * problem_blocks_n * problem_blocks_l; + } + + // Get the grid shape + dim3 ${operation_name}_get_grid_shape(GemmType::Arguments* args, int* workspace) { + auto tmp_params = GemmType::to_underlying_arguments(*args, workspace); + return GemmType::get_grid_shape(tmp_params); + } + + // Get the block shape + dim3 ${operation_name}_get_block_shape() { + return GemmType::get_block_shape(); + } +} + """ + + def __init__(self, operation): + super(GemmRTUniversal3x, self).__init__(operation) + self.extra_funcs = { + "get_grid_shape": dim3_, + "get_block_shape": dim3_, + "get_persistent_tiled_blk_shape_mnl": ctypes.c_uint64, + "get_kernel_workspace_size": ctypes.c_uint64, + "uses_default_epilogue": ctypes.c_bool, + } + self.emitter = EmitGemmUniversalInstance3x("_type") + + def get_device_workspace_size(self, arguments: GemmArguments3x): + return self.get_kernel_workspace_size(ctypes.byref(arguments.get_arguments())) + + +class EmitGemmUniversalInstance3x: + """Responsible for emitting a CUTLASS 3 template definition""" + + def __init__(self, operation_suffix=""): + self.operation_suffix = operation_suffix + self.includes = [ + "cutlass/cutlass.h", + "cute/tensor.hpp", + "cute/atom/mma_atom.hpp", + "cutlass/numeric_types.h", + "cutlass/gemm/collective/collective_builder.hpp", + "cutlass/gemm/kernel/sm90_tile_scheduler.hpp", + "cutlass/gemm/kernel/gemm_universal.hpp", + "cutlass/epilogue/collective/collective_builder.hpp", + "cutlass/epilogue/collective/default_epilogue.hpp", + "cutlass/epilogue/thread/linear_combination.h" + ] + self.gemm_template_kernel = """ +using namespace cute; + +using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ${arch}, ${opcode_class}, + cute::Shape, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + ${element_accumulator}, ${element_epilogue}, + ${element_c}, ${layout_c}, ${align_c}, + ${element_d}, ${layout_d}, ${align_d}, + ${epilogue_schedule} + >::CollectiveOp; + +using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ${arch}, ${opcode_class}, + ${element_a}, ${layout_a}, ${align_a}, + ${element_b}, ${layout_b}, ${align_b}, + ${element_accumulator}, + cute::Shape, + cute::Shape, + ${stage_count_type}, + ${kernel_schedule} + >::CollectiveOp; + +// Gemm operator ${operation_name} +using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + ${tile_scheduler} +>; + +// Define named type +struct ${operation_name}${operation_suffix} : + public ${operation_name}_base { }; +""" + self.gemm_template_kernel_visitor = """ +using namespace cute; + +${callback_decl} + +using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ${arch}, ${opcode_class}, + cute::Shape, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + ${element_accumulator}, ${element_epilogue}, + ElementC, StrideC, ${align_c}, + ElementD, StrideD, ${align_d}, + ${epilogue_schedule}, + ${callback_name} + >::CollectiveOp; + +using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ${arch}, ${opcode_class}, + ${element_a}, ${layout_a}, ${align_a}, + ${element_b}, ${layout_b}, ${align_b}, + ${element_accumulator}, + cute::Shape, + cute::Shape, + ${stage_count_type}, + ${kernel_schedule} + >::CollectiveOp; + +// Gemm operator ${operation_name} +using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + ${tile_scheduler} +>; + +// Define named type +struct ${operation_name}${operation_suffix} : + public ${operation_name}_base { }; +""" + + self.gemm_template_device = self.gemm_template_kernel + """ + +// Define device-level operator +using DeviceKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_name}${operation_suffix}>; +""" + + def emit(self, operation): + # Support built-in epilogue functors or user-defined functions + + if operation.tile_description.stages is None or operation.tile_description.stages == 0: + stage_count_type = "cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>" + else: + stage_count_type = "_" + str(operation.tile_description.stages) + + if operation.emission_type == EmissionType.Kernel: + gemm_template = self.gemm_template_kernel + else: + gemm_template = self.gemm_template_device + + kschedule = KernelScheduleType.ScheduleAuto + eschedule = EpilogueScheduleType.ScheduleAuto + tschedule = TileSchedulerType.Default + if operation.tile_description.kernel_schedule is not None: + kschedule = operation.tile_description.kernel_schedule + if operation.tile_description.epilogue_schedule is not None: + eschedule = operation.tile_description.epilogue_schedule + if operation.tile_description.tile_scheduler is not None: + tschedule = operation.tile_description.tile_scheduler + + emit_tile_m, emit_tile_n, emit_tile_k = operation.tile_description.blackwell_threadblock_shape + + values = { + "operation_name": operation.procedural_name(), + "operation_suffix": self.operation_suffix, + "element_a": DataTypeTag[operation.A.element], + "layout_a": LayoutTag[operation.A.layout], + "element_b": DataTypeTag[operation.B.element], + "layout_b": LayoutTag[operation.B.layout], + "element_c": DataTypeTag[operation.C.element], + "layout_c": LayoutTag[operation.C.layout], + "element_d": DataTypeTag[operation.epilogue_functor.element_output], + "layout_d": LayoutTag[operation.C.layout], + "element_accumulator": DataTypeTag[operation.accumulator_type()], + "element_epilogue": DataTypeTag[operation.epilogue_functor.element_epilogue], + "opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + "arch": "cutlass::arch::Sm%d" % operation.arch, + "threadblock_shape_m": str(emit_tile_m), + "threadblock_shape_n": str(emit_tile_n), + "threadblock_shape_k": str(emit_tile_k), + "cluster_m": str(operation.tile_description.cluster_shape[0]), + "cluster_n": str(operation.tile_description.cluster_shape[1]), + "cluster_k": str(operation.tile_description.cluster_shape[2]), + "align_a": str(operation.A.alignment), + "align_b": str(operation.B.alignment), + "align_c": str(operation.C.alignment), + "align_d": str(operation.C.alignment), + "stage_count_type": stage_count_type, + "kernel_schedule": KernelScheduleTag[kschedule], + "epilogue_schedule": EpilogueScheduleTag[eschedule], + "tile_scheduler": TileSchedulerTag[tschedule] + } + if hasattr(operation.epilogue_functor, "visitor"): + callback_name, callback_decl = operation.epilogue_functor.emit(operation) + values["callback_name"] = callback_name + values["callback_decl"] = callback_decl + return SubstituteTemplate(self.gemm_template_kernel_visitor, values) + + else: + values["epilogue_functor"] = operation.epilogue_functor.emit() + return SubstituteTemplate(gemm_template, values) + + +################################################################################################### +# Runtime module for GEMM Grouped +################################################################################################### + + +class GemmRTGrouped(GemmRTbase): + """ + GemmRTGrouped manages the CUTLASS runtime components + """ + + KernelTemplate = r""" +extern "C" +__global__ void +${operation_name}(${operation_name}${operation_suffix}::Params params) { + + // Dynamic shared memory base pointer + extern __shared__ int SharedStorageBase[]; + + // Declare pointer to dynamic shared memory. + ${operation_name}${operation_suffix}::SharedStorage *shared_storage = + reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase); + + ${operation_name}${operation_suffix} op; + + op(params, *shared_storage); +} + """ + + HostTemplate = r""" + extern "C" { + + // precompute scheduling information + char * ${operation_name}_precompute(${operation_name}_base::Arguments const &args, int tile_count, size_t workspace_bytes) { + char* host_workspace = new char[workspace_bytes]; + ${operation_name}_base::ProblemVisitor::host_precompute( + args.host_problem_sizes, + args.problem_count, + args.threadblock_count, + (void*)host_workspace + ); + return host_workspace; + } + + // Get the size of params in bytes + int ${operation_name}_get_param_size(){ + return sizeof(${operation_name}${operation_suffix}::Params); + } + + // Get the size of dynamic shared memory in bytes + int ${operation_name}_shared_memory_size() { + return int(sizeof(${operation_name}${operation_suffix}::SharedStorage)); + } + + // Get the params as byte array + char* ${operation_name}_get_params(${operation_name}_base::Arguments* argument, int tile_count, void* workspace=nullptr){ + ${operation_name}_base::Params* params; + params = new ${operation_name}_base::Params(*argument, workspace, tile_count); + + char *bytes = ((char*)(params)); + char *output = new char[sizeof(${operation_name}_base::Params)]; + for (unsigned int i = 0; i < sizeof(${operation_name}_base::Params); i ++) + output[i] = bytes[i]; + + return output; + } + + cutlass::gemm::GemmCoord ${operation_name}_get_tiled_shape( + cutlass::gemm::GemmCoord problem_size, cutlass::gemm::GemmCoord tile_size, int split_k_slices) { + return ${operation_name}_base::ThreadblockSwizzle::get_tiled_shape( + problem_size, tile_size, split_k_slices); + } + + dim3 ${operation_name}_get_grid_shape(cutlass::gemm::GemmCoord tiled_shape) { + return ${operation_name}_base::ThreadblockSwizzle::get_grid_shape(tiled_shape); + } + } + """ + + def __init__(self, operation: "GemmOperation"): + super(GemmRTGrouped, self).__init__(operation) + self.extra_funcs = { + "precompute": None, + "get_tiled_shape": GemmCoord_, + "get_grid_shape": dim3_, + } + self.emitter = EmitGemmGroupedInstance("_type") + self.argument_type, self.epilogue_type = get_gemm_grouped_arguments(operation.epilogue_functor) + self.argtype = [ctypes.POINTER(self.argument_type), ctypes.c_int, ctypes.c_void_p] + + def host_precompute(self, arguments, workspace_bytes): + self.precompute.argtype = [ + self.argtype[0], ctypes.c_int, ctypes.c_longlong] + self.precompute.restype = ctypes.POINTER(ctypes.c_byte * workspace_bytes) + + problem_info = self.precompute( + ctypes.byref(arguments.arguments), + arguments.total_tiles, + workspace_bytes) + problem_info_array = bytearray(problem_info.contents) + + # copy to device memory + return todevice(problem_info_array).ptr + + def plan(self, arguments): + return LaunchConfiguration( + [arguments.total_tiles, 1, 1], + [self.threads, 1, 1], + self.shared_memory_capacity, + ) + + def get_workspace_size(self, arguments): + if self.operation.precompute_mode == SchedulerMode.Device: + return 0 + elif self.operation.precompute_mode == SchedulerMode.Host: + total_tiles = arguments.total_tiles + entries_per_block = 1 + return 8 * entries_per_block * total_tiles # three int32_t + + +################################################################################ +# Runtime module for GEMM and grouped GEMM +################################################################################ + + +class GemmOperationBase: + """ + CUTLASS GEMM operation + """ + + def __init__( + self, gemm_kind, arch, tile_description: TileDescription, + A: TensorDescription, B: TensorDescription, C: TensorDescription, + epilogue_functor, swizzling_functor=SwizzlingFunctor.Identity1, + api=ApiVersion.v2x, emission_type=EmissionType.Kernel, **kwargs): + self.operation_kind: OperationKind = OperationKind.Gemm + self.arch: int = arch + self.tile_description: TileDescription = tile_description + self.gemm_kind: GemmKind = gemm_kind + + self.api = api + self.prefix = "3x" if self.api == ApiVersion.v3x else "" + self.emission_type = emission_type + + # Optionally swap the TensorDescriptions for operands A and B and transpose their + # layouts. This is needed to mimic the transpose performed by device::GemmUniversal. + # The code below uses deep copy to avoid overwritting the original TensorDescription + self.switched = (self.api != ApiVersion.v3x and + self.emission_type == EmissionType.Kernel and + C.layout == LayoutType.ColumnMajor) + + self.A, self.B, self.C = GemmOperationBase.get_operands(A, B, C, self.switched) + + self.epilogue_functor = epilogue_functor + self.swizzling_functor = swizzling_functor + + if "direct_store" in kwargs: + self.direct_store = kwargs["direct_store"] + else: + self.direct_store = False + + @staticmethod + def get_operands(A: TensorDescription, B: TensorDescription, C: TensorDescription, swap: bool): + """ + Makes copies of A, B, and C, and possibly transposes their order. If ``swap`` is set, + A and B are swapped, and the layout of A, B, and C are transposed. + + :param A: description of operand A + :type A: TensorDescription + :param B: description of operand B + :type B: TensorDescription + :param C: description of operand C + :type C: TensorDescription + + :return: descriptions of operands A, B, and C + :rtype: tuple[TileDescription] + """ + if swap: + A_out = copy.deepcopy(B) + B_out = copy.deepcopy(A) + C_out = copy.deepcopy(C) + A_out.layout = transpose_layout(A_out.layout) + B_out.layout = transpose_layout(B_out.layout) + C_out.layout = transpose_layout(C_out.layout) + else: + A_out = copy.deepcopy(A) + B_out = copy.deepcopy(B) + C_out = copy.deepcopy(C) + return A_out, B_out, C_out + + def run(self, arguments: GemmArguments) -> cuda.CUresult: + """ + Configure and launch the cuda kernel with input arguments + """ + if self.emission_type == EmissionType.Device: + raise Exception('Running a kernel via PyCUTLASS is only enabled with emission type "Kernel"') + + err = self.rt_module.run( + arguments.host_workspace, + arguments.device_workspace, + arguments.launch_config, + arguments.stream + ) + + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError("CUDA Error %s" % str(err)) + + return err + + def is_complex(self): + complex_operators = [ + MathOperation.multiply_add_complex, + MathOperation.multiply_add_complex_gaussian, + MathOperation.multiply_add_complex_fast_f32, + ] + return self.tile_description.math_instruction.math_operation in complex_operators + + def is_planar_complex(self): + return self.gemm_kind in (GemmKind.PlanarComplex, GemmKind.PlanarComplexArray) + + def accumulator_type(self): + accum = self.tile_description.math_instruction.element_accumulator + + if self.is_complex(): + return get_complex_from_real(accum) + + return accum + + def short_math_name(self): + if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian: + return "g%s" % ShortDataTypeNames[self.accumulator_type()] + return ShortDataTypeNames[self.accumulator_type()] + + def core_name(self): + """The basic operation kind is prefixed with a letter indicating the accumulation type.""" + + inst_shape = "" + inst_operation = "" + intermediate_type = "" + + math_operations_map = { + MathOperation.xor_popc: "xor", + } + + if (self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or + self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp): + math_op = self.tile_description.math_instruction.math_operation + math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else "" + + if self.tile_description.math_instruction.instruction_shape is not None: + if self.api == ApiVersion.v3x and self.arch >= 90: + inst_shape = "%dx%dx%d" % tuple( + self.tile_description.math_instruction.instruction_shape) + else: + inst_shape = "%d%d%d" % tuple( + self.tile_description.math_instruction.instruction_shape) + else: + inst_shape = "Default" + inst_shape += math_op_string + + if (self.tile_description.math_instruction.element_a != self.A.element and + self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator): + intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] + + return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, GemmKindNames[self.gemm_kind]) + + def extended_name(self): + """Append data types if they differ from compute type.""" + if self.is_complex(): + extended_name = "${core_name}" + else: + if (self.C.element != self.tile_description.math_instruction.element_accumulator and + self.A.element != self.tile_description.math_instruction.element_accumulator): + extended_name = "${element_c}_${core_name}_${element_a}" + elif (self.C.element == self.tile_description.math_instruction.element_accumulator and + self.A.element != self.tile_description.math_instruction.element_accumulator): + extended_name = "${core_name}_${element_a}" + else: + extended_name = "${core_name}" + + extended_name = SubstituteTemplate(extended_name, { + "element_a": DataTypeNames[self.A.element], + "element_c": DataTypeNames[self.C.element], + "core_name": self.core_name(), + }) + + return extended_name + + def extended_name_3x(self): + """Generates a string representing the MMA atom. Assumes accumulator type is C type.""" + extended_name = "{core_name}_{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format( + element_a=DataTypeNames[self.A.element], + element_b=DataTypeNames[self.B.element], + element_acc=DataTypeNames[self.accumulator_type()], + element_c=DataTypeNames[self.C.element], + element_d=DataTypeNames[self.epilogue_functor.element_output], + core_name=self.core_name()) + return extended_name + + def layout_name(self): + if self.is_complex() or self.is_planar_complex(): + return "%s%s" % ( + ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)], + ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)] + ) + return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout]) + + # Generates a short string representing the ABC layout tags (e.g. ntn or tnn) + def layout_name_3x(self): + if self.is_complex() or self.is_planar_complex(): + return "{}{}{}".format( + ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)], + ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)], + ShortComplexLayoutNames[(self.C.layout, self.C.complex_transform)]) + else: + return "{}{}{}".format( + ShortLayoutTypeNames[self.A.layout], + ShortLayoutTypeNames[self.B.layout], + ShortLayoutTypeNames[self.C.layout]) + + # Generates a short string representing underlying kernel schedule type + def kernel_schedule_name_3x(self): + if self.tile_description.kernel_schedule is None: + return KernelScheduleSuffixes[KernelScheduleType.ScheduleAuto] + else: + return KernelScheduleSuffixes[self.tile_description.kernel_schedule] + + # Generates a short string representing underlying epilogue schedule type + def epilogue_schedule_name_3x(self): + if self.tile_description.epilogue_schedule is None: + return EpilogueScheduleSuffixes[EpilogueScheduleType.ScheduleAuto] + else: + return EpilogueScheduleSuffixes[self.tile_description.epilogue_schedule] + + def procedural_name(self): + """The full procedural name indicates architecture, extended name, tile size, and layout.""" + opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] + if self.api == ApiVersion.v3x and self.arch >= 90: + kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{l}_{s}_align{al}{k}{e}" + return kernel_name_template.format( + p=self.prefix, + ar=self.arch, + op=opcode_class_name, + ex=self.extended_name_3x(), + tbm=self.tile_description.threadblock_shape[0], + tbn=self.tile_description.threadblock_shape[1], + tbk=self.tile_description.threadblock_shape[2], + cm=self.tile_description.cluster_shape[0], + cn=self.tile_description.cluster_shape[1], + ck=self.tile_description.cluster_shape[2], + l=self.tile_description.stages, + s=self.layout_name_3x(), + al=str(self.A.alignment), + k=self.kernel_schedule_name_3x(), + e=self.epilogue_schedule_name_3x() + ) + else: + threadblock = self.tile_description.procedural_name_2x() + return "cutlass{p}_{op}_{ex}_{tb}_{l}_align{a}".format( + p=self.prefix, + op=opcode_class_name, + ex=self.extended_name(), + tb=threadblock, + l=self.layout_name(), + a=str(self.A.alignment) + ) + + def configuration_name(self): + """The full procedural name indicates architecture, extended name, tile size, and layout.""" + return self.procedural_name() + + +class GemmOperationUniversal(GemmOperationBase): + def __init__(self, arch, tile_description: TileDescription, A: TensorDescription, B, C, + epilogue_functor, swizzling_functor=SwizzlingFunctor.Identity1, **kwargs): + api = api_version(arch, tile_description.math_instruction.opcode_class, A.element) + super(GemmOperationUniversal, self).__init__(GemmKind.Universal, arch, tile_description, + A, B, C, epilogue_functor, swizzling_functor, + api=api, **kwargs, ) + if api == ApiVersion.v3x: + if swizzling_functor == SwizzlingFunctor.StreamK: + raise Exception("Stream K swizzle functor is currently only supported for CUTLASS 2.x kernels") + self.rt_module = GemmRTUniversal3x(self) + else: + if swizzling_functor == SwizzlingFunctor.StreamK: + self.rt_module = GemmRTUniversalStreamK(self) + else: + self.rt_module = GemmRTUniversal(self) + self.argument_type = self.rt_module.argument_type + self.epilogue_type = self.rt_module.epilogue_type + + def device_op(self): + """ + Returns a new GemmOperationUniversal object that is constructed with emission type + ``EmissionType.Device``. Since the device-emitted kernel does not require swapping, + any swappng performed by the kernel-emitted operation is reversed. + + :return: operation ready for device-level code emission + :rtype: GemmUniversalOperation + """ + A, B, C = GemmOperationBase.get_operands(self.A, self.B, self.C, self.switched) + return GemmOperationUniversal(self.arch, self.tile_description, A, B, C, + self.epilogue_functor, self.swizzling_functor, + emission_type=EmissionType.Device, direct_store=self.direct_store) + + +class GemmOperationGrouped(GemmOperationBase): + def __init__(self, arch, tile_description: TileDescription, A: TensorDescription, B, C, + epilogue_functor, swizzling_functor=SwizzlingFunctor.Identity1, **kwargs): + super(GemmOperationGrouped, self).__init__(GemmKind.Grouped, arch, tile_description, + A, B, C, epilogue_functor, swizzling_functor, **kwargs) + assert "precompute_mode" in kwargs.keys(), "missing keyword arguement 'precompute_mode'." + self.precompute_mode = kwargs["precompute_mode"] + self.rt_module = GemmRTGrouped(self) + self.argument_type = self.rt_module.argument_type + self.epilogue_type = self.rt_module.epilogue_type + + def device_op(self): + """ + Returns a new GemmOperationGrouped object that is constructed with emission type + ``EmissionType.Device``. Since the device-emitted kernel does not require swapping, + any swappng performed by the kernel-emitted operation is reversed. + + :return: operation ready for device-level code emission + :rtype: GemmOperationGrouped + """ + A, B, C = GemmOperationBase.get_operands(self.A, self.B, self.C, self.switched) + return GemmOperationGrouped( + self.arch, self.tile_description, A, B, C, self.epilogue_functor, + self.swizzling_functor, emission_type=EmissionType.Device, + direct_store=self.direct_store, precompute_mode=self.precompute_mode, ) + + +################################################################################################### +# +# Emits single instances of a CUTLASS device-wide operator +# +################################################################################################### + + +class EmitGemmUniversalInstance: + """Responsible for emitting a CUTLASS template definition""" + + def __init__( + self, + operation_suffix="", + direct_store=False + ): + self.operation_suffix = operation_suffix + self.direct_store = direct_store + self.includes = [ + "cutlass/cutlass.h", + "cutlass/gemm_coord.h", + "cutlass/numeric_types.h", + "cutlass/arch/arch.h", + "cutlass/arch/mma.h", + "cutlass/layout/matrix.h", + "cutlass/gemm/device/gemm.h", + "cutlass/gemm/device/gemm_universal_adapter.h", + "cutlass/gemm/kernel/default_gemm_universal.h", + ] + if self.direct_store: + self.includes.append( + "cutlass/epilogue/threadblock/default_epilogue_direct_store.h" + ) + self.gemm_template_kernel = """ +// Gemm operator ${operation_name} +using ${operation_name}_base = + typename cutlass::gemm::kernel::DefaultGemmUniversal< + ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, + ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}, + ${swizzling_functor}, + ${stages}, + ${math_operation} +>::GemmKernel; + +// Define named type +struct ${operation_name}${operation_suffix} : + public ${operation_name}_base { }; +""" + + self.gemm_template_device = """ +// Gemm operator ${operation_name} +using DeviceKernel = + typename cutlass::gemm::device::GemmUniversal< + // Data type and layout of operand A + ${element_a}, ${layout_a}, + // Data type and layout of operand B + ${element_b}, ${layout_b}, + // Data type and layout of operand C + ${element_c}, ${layout_c}, + // Data type of accumulator + ${element_accumulator}, + // Class of operation + ${opcode_class}, + // Compute capability of the target kernel + ${arch}, + // Threadblock tile shape + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + // Warp tile shape + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + // Instruction shape + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + // Epilogue functor + ${epilogue_functor}, + // Swizzling function + ${swizzling_functor}, + // Number of pipeline stages + ${stages}, + // Alignment of operands A and B + ${align_a}, ${align_b}, + // Type of math operation + ${math_operation}, + // Complex transform types of operands A and B + ${transform_a}, ${transform_b} + >; +""" + self.gemm_template_direct_store = """ +// Gemm operator ${operation_name} +using ${operation_name}_default = + typename cutlass::gemm::kernel::DefaultGemmUniversal< + ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, + ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}, + ${swizzling_functor}, + ${stages}, + ${math_operation} +>::GemmKernel; + +using ${operation_name}_base = + cutlass::gemm::kernel::GemmUniversal< + ${operation_name}_default::Mma, + cutlass::epilogue::threadblock::DefaultEpilogueDirectStore< + ${operation_name}_default::Epilogue + >::Epilogue, + ${operation_name}_default::ThreadblockSwizzle + >; + +// Define named type +struct ${operation_name}${operation_suffix} : + public ${operation_name}_base { }; +""" + self.gemm_template_kernel_visitor = """ + +using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout< + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + ${element_c}, + ${align_c}, + ${epilogue_stages} /* epilogue stages */ +>; + +${callback_decl} + +// Gemm operator ${operation_name} +using ${operation_name}_base = + typename cutlass::gemm::kernel::DefaultGemmWithVisitor< + ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, + ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, + ${element_c}, ${layout_c}, ${align_c}, + ${element_accumulator}, + ${element_epilogue}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${callback_name}, + ${swizzling_functor}, + ${stages}, + ${math_operation}, + ${epilogue_stages} /* epilogue stages */ +>::GemmKernel; + +// Define named type +struct ${operation_name}${operation_suffix} : + public ${operation_name}_base { }; +""" + + def instance_template(self): + return """ +${compile_guard_start} + manifest.append(new ${gemm_kind}< + cutlass::gemm::device::GemmUniversalAdapter<${operation_name}> + >("${operation_name}")); +${compile_guard_end} +""" + + def emit(self, operation): + threadblock_shape = operation.tile_description.threadblock_shape + warp_count = operation.tile_description.warp_count + + warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)] + + instance_layout_A, instance_layout_B, instance_layout_C = \ + (operation.A.layout, operation.B.layout, operation.C.layout) + + if operation.emission_type == EmissionType.Kernel: + if self.direct_store: + gemm_template = self.gemm_template_direct_store + else: + gemm_template = self.gemm_template_kernel + else: + gemm_template = self.gemm_template_device + + values = { + "operation_name": operation.procedural_name(), + "operation_suffix": self.operation_suffix, + "element_a": DataTypeTag[operation.A.element], + "layout_a": LayoutTag[instance_layout_A], + "element_b": DataTypeTag[operation.B.element], + "layout_b": LayoutTag[instance_layout_B], + "element_c": DataTypeTag[operation.C.element], + "layout_c": LayoutTag[instance_layout_C], + "element_accumulator": DataTypeTag[operation.accumulator_type()], + "opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + "arch": "cutlass::arch::Sm%d" % operation.arch, + "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]), + "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]), + "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]), + "warp_shape_m": str(warp_shape[0]), + "warp_shape_n": str(warp_shape[1]), + "warp_shape_k": str(warp_shape[2]), + "instruction_shape_m": str(operation.tile_description.math_instruction.instruction_shape[0]), + "instruction_shape_n": str(operation.tile_description.math_instruction.instruction_shape[1]), + "instruction_shape_k": str(operation.tile_description.math_instruction.instruction_shape[2]), + "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor], + "stages": str(operation.tile_description.stages), + "align_a": str(operation.A.alignment), + "align_b": str(operation.B.alignment), + "transform_a": ComplexTransformTag[operation.A.complex_transform], + "transform_b": ComplexTransformTag[operation.B.complex_transform], + "math_operation": MathOperationTag[operation.tile_description.math_instruction.math_operation], + } + + if hasattr(operation.epilogue_functor, "visitor"): + self.includes += [ + "cutlass/epilogue/threadblock/fusion/visitors.hpp", + "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" + ] + callback_name, callback_decl = operation.epilogue_functor.emit(operation) + values["callback_name"] = callback_name + values["callback_decl"] = callback_decl + values["align_c"] = str(operation.C.alignment) + values["element_epilogue"] = DataTypeTag[operation.epilogue_functor.element_epilogue] + if hasattr(operation.epilogue_functor, "epilogue_stages"): + epilogue_stages = operation.epilogue_functor.epilogue_stages + else: + epilogue_stages = 1 + values["epilogue_stages"] = str(epilogue_stages) + return SubstituteTemplate(self.gemm_template_kernel_visitor, values) + else: + values["epilogue_functor"] = operation.epilogue_functor.emit() + return SubstituteTemplate(gemm_template, values) + + +class EmitGemmGroupedInstance: + """Responsible for emitting a CUTLASS template definition""" + + def __init__(self, operation_suffix=""): + self.operation_suffix = operation_suffix + self.includes = [ + "cutlass/cutlass.h", + "cutlass/numeric_types.h", + "cutlass/arch/arch.h", + "cutlass/arch/mma.h", + "cutlass/layout/matrix.h", + "cutlass/gemm/kernel/gemm_grouped.h", + "cutlass/gemm/kernel/default_gemm_grouped.h", + ] + self.gemm_template_kernel = """ +// Gemm operator ${operation_name} +using ${operation_name}_base = + typename cutlass::gemm::kernel::DefaultGemmGrouped< + ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, + ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}, + ${swizzling_functor}, + ${stages}, + ${precompute_mode}, + ${math_operation} +>::GemmKernel; + +// Define named type +struct ${operation_name}${operation_suffix} : + public ${operation_name}_base { }; +""" + self.gemm_template_device = ( + self.gemm_template_kernel + + """ +using DeviceKernel = cutlass::gemm::device::GemmGrouped<${operation_name}_base>; +""" + ) + + def instance_template(self): + return """ +${compile_guard_start} + manifest.append(new ${gemm_kind}< + cutlass::gemm::device::GemmGrouped<${operation_name}> + >("${operation_name}")); +${compile_guard_end} +""" + + def emit(self, operation): + threadblock_shape = operation.tile_description.threadblock_shape + warp_count = operation.tile_description.warp_count + + warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)] + + instance_layout_A, instance_layout_B, instance_layout_C = \ + (operation.A.layout, operation.B.layout, operation.C.layout) + + # Support built-in epilogue functors or user-defined functions + epilogue_functor = operation.epilogue_functor.emit() + + values = { + "operation_name": operation.procedural_name(), + "operation_suffix": self.operation_suffix, + "element_a": DataTypeTag[operation.A.element], + "layout_a": LayoutTag[instance_layout_A], + "element_b": DataTypeTag[operation.B.element], + "layout_b": LayoutTag[instance_layout_B], + "element_c": DataTypeTag[operation.C.element], + "layout_c": LayoutTag[instance_layout_C], + "element_accumulator": DataTypeTag[operation.accumulator_type()], + "opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + "arch": "cutlass::arch::Sm%d" % operation.arch, + "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]), + "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]), + "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]), + "warp_shape_m": str(warp_shape[0]), + "warp_shape_n": str(warp_shape[1]), + "warp_shape_k": str(warp_shape[2]), + "instruction_shape_m": str(operation.tile_description.math_instruction.instruction_shape[0]), + "instruction_shape_n": str(operation.tile_description.math_instruction.instruction_shape[1]), + "instruction_shape_k": str(operation.tile_description.math_instruction.instruction_shape[2]), + "epilogue_functor": epilogue_functor, + "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor], + "stages": str(operation.tile_description.stages), + "align_a": str(operation.A.alignment), + "align_b": str(operation.B.alignment), + "transform_a": ComplexTransformTag[operation.A.complex_transform], + "transform_b": ComplexTransformTag[operation.B.complex_transform], + "precompute_mode": SchedulerModeTag[operation.precompute_mode], + "math_operation": MathOperationTag[operation.tile_description.math_instruction.math_operation], + } + + if operation.emission_type == EmissionType.Kernel: + gemm_template = self.gemm_template_kernel + else: + gemm_template = self.gemm_template_device + + return SubstituteTemplate(gemm_template, values) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/library.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/library.py new file mode 100644 index 0000000000000000000000000000000000000000..a77b302dcccf330cc0e0f9b3f1290ab7030c5932 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/library.py @@ -0,0 +1,509 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Common data types and string names/tags for them +""" + +import enum + +from cutlass_library import ( + ComplexTransform, + DataType, + DataTypeSize, + EpilogueScheduleType, + KernelScheduleSuffixes, + KernelScheduleType, + MathOperation, + OpcodeClass, + TileSchedulerType +) + + +# The following block implements enum.auto() for Python 3.5 variants that don't include it such +# as the default 3.5.2 on Ubuntu 16.04. +# +# https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility + +try: + from enum import auto as enum_auto +except ImportError: + __cutlass_library_auto_enum = 0 + + def enum_auto() -> int: + global __cutlass_library_auto_enum + i = __cutlass_library_auto_enum + __cutlass_library_auto_enum += 1 + return i + + +class DataTypeSizeBytes: + """ + Static class to mimic the `DataTypeSize` dictionary, but with checks for whether the + data type key is less than a full byte or a non-integer number of bytes. + """ + + @staticmethod + def __class_getitem__(datatype): + """ + Returns the number of bytes in size the data type is. Raises an exception if the data type + is either less than a full byte or a non-integer number of bytes in size. + + :param datatype: data type to query + + :return: number of bytes the data type occupies + :rtype: int + """ + bits = DataTypeSize[datatype] + if bits < 8: + raise Exception( + f"Data type {datatype} is less than one byte in size." + ) + elif bits % 8 != 0: + raise Exception( + f"Data type datatype is not an integer number of bytes." + ) + return bits // 8 + + +class SchedulerMode(enum.Enum): + Device = enum_auto() + Host = enum_auto() + + +SchedulerModeTag = { + SchedulerMode.Device: "cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly", + SchedulerMode.Host: "cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute", +} + + +ShortSchedulerModeNames = {SchedulerMode.Device: "Device", SchedulerMode.Host: "Host"} + + +class FunctionalOp(enum.Enum): + AtomicAdd = enum_auto() + AtomicMaximum = enum_auto() + Divides = enum_auto() + Maximum = enum_auto() + Minimum = enum_auto() + Minus = enum_auto() + Multiplies = enum_auto() + MultiplyAdd = enum_auto() + Plus = enum_auto() + Exp = enum_auto() + + +FunctionalOpTag = { + FunctionalOp.AtomicAdd: "cutlass::atomic_add", + FunctionalOp.AtomicMaximum: "cutlass::atomic_maximum", + FunctionalOp.Divides: "cutlass::divides", + FunctionalOp.Maximum: "cutlass::maximum", + FunctionalOp.Minimum: "cutlass::minimum", + FunctionalOp.Minus: "cutlass::minus", + FunctionalOp.Multiplies: "cutlass::multiplies", + FunctionalOp.MultiplyAdd: "cutlass::multiply_add", + FunctionalOp.Plus: "cutlass::plus", + FunctionalOp.Exp: "cutlass::fast_exp_op", +} + + +class ActivationOp(enum.Enum): + DGelu = enum_auto() + Gelu = enum_auto() + GeluTaylor = enum_auto() + HardSwish = enum_auto() + Identity = enum_auto() + LeakyReLU = enum_auto() + ReLU = enum_auto() + Sigmoid = enum_auto() + SiLU = enum_auto() + Tanh = enum_auto() + + +ActivationOpTag = { + ActivationOp.DGelu: "cutlass::epilogue::thread::dGELU", + ActivationOp.Gelu: "cutlass::epilogue::thread::GELU", + ActivationOp.GeluTaylor: "cutlass::epilogue::thread::GELU_taylor", + ActivationOp.HardSwish: "cutlass::epilogue::thread::HardSwish", + ActivationOp.Identity: "cutlass::epilogue::thread::Identity", + ActivationOp.LeakyReLU: "cutlass::epilogue::thread::LeakyReLU", + ActivationOp.ReLU: "cutlass::epilogue::thread::ReLu", + ActivationOp.Sigmoid: "cutlass::epilogue::thread::Sigmoid", + ActivationOp.SiLU: "cutlass::epilogue::thread::SiLu", + ActivationOp.Tanh: "cutlass::epilogue::thread::Tanh", +} + + +def op_tag(op) -> str: + """ + Dispatches `op` to the appropriate *Tag dictionary depending on whether + `op` is an ActivationOp or FunctionalOp. This is useful for cases in which + either type can be used. + + :param op: operation to emit a tag for + :type op: ActivationOp | FunctionalOp + + :return: tag corresponding to op + :rtype: str + """ + if isinstance(op, ActivationOp): + return ActivationOpTag[op] + elif isinstance(op, FunctionalOp): + return FunctionalOpTag[op] + else: + raise Exception(f"Unexpected op type {op}. Must be one of ActivationOp or FunctionalOp.") + + +class FloatRoundStyle(enum.Enum): + ToNearest = enum_auto() + ToNearestSatfinite = enum_auto() + Indeterminate = enum_auto() + TowardZero = enum_auto() + TowardInfinity = enum_auto() + TowardNegInfinity = enum_auto() + HalfUlpTruncDntz = enum_auto() + HalfUlpTruncate = enum_auto() + + +FloatRoundStyleTag = { + FloatRoundStyle.ToNearest: "cutlass::FloatRoundStyle::round_to_nearest", + FloatRoundStyle.ToNearestSatfinite: "cutlass::FloatRoundStyle::round_to_nearest_satfinite", + FloatRoundStyle.Indeterminate: "cutlass::FloatRoundStyle::round_indeterminate", + FloatRoundStyle.TowardZero: "cutlass::FloatRoundStyle::round_toward_zero", + FloatRoundStyle.TowardInfinity: "cutlass::FloatRoundStyle::round_toward_infinity", + FloatRoundStyle.TowardNegInfinity: "cutlass::FloatRoundStyle::round_toward_neg_infinity", + FloatRoundStyle.HalfUlpTruncDntz: "cutlass::FloatRoundStyle::round_half_ulp_trunc_dntz", + FloatRoundStyle.HalfUlpTruncate: "cutlass::FloatRoundStyle::round_half_ulp_truncate", +} + + +class MathInstruction: + """ + Description of a the lowest-level matrix-multiply-accumulate operation to be used in a kernel + """ + + def __init__( + self, + instruction_shape, + element_a, + element_b, + element_accumulator, + opcode_class=OpcodeClass.Simt, + math_operation=MathOperation.multiply_add, + ): + """ + :param instruction_shape: size of the [M, N, K] dimensions of the instruction + :type instruction_shape: list or tuple + :param element_a: data type of operand A + :param element_b: data type of operand B + :param element_accumulator: data type used in accumulation + :param opcode_class: higher-level class of the instruction (e.g., SIMT or Tensor Core) + :type opcode_class: cutlass_library.library.OpcodeClass + :param math_operation: the type of low-level operation to be performed (e.g., multiply accumulate) + :type math_operation: MathOperation + """ + self.instruction_shape = instruction_shape + self.element_a = element_a + self.element_b = element_b + self.element_accumulator = element_accumulator + self.opcode_class = opcode_class + self.math_operation = math_operation + + +def to_blackwell_threadblock_shape(tile_description, cluster_shape, kernel_schedule): + blackwell_threadblock_shape = tile_description.threadblock_shape + is_2sm = False if kernel_schedule is None else ("2sm" in KernelScheduleSuffixes[kernel_schedule]) + if cluster_shape[0] > 0: + blackwell_threadblock_shape = [ + tile_description.threadblock_shape[0] // cluster_shape[0], + tile_description.threadblock_shape[1] // cluster_shape[1], + tile_description.threadblock_shape[2] // cluster_shape[2] + ] + if is_2sm: + blackwell_threadblock_shape[0] *= 2 + else: + blackwell_threadblock_shape = tile_description.math_instruction.instruction_shape + return blackwell_threadblock_shape, is_2sm + + +class TileDescription: + """ + Description of a tile of computation to be performed in the kernel, encompassing threadblock, cluster, and warp shapes, + stage count, and math instruction specification + """ + + def __init__( + self, + threadblock_shape, + stages, + warp_count, + math_instruction, + cluster_shape=[1, 1, 1], + kernel_schedule: KernelScheduleType = None, + epilogue_schedule: EpilogueScheduleType = None, + tile_scheduler: TileSchedulerType = None + ): + """ + :param threadblock_shape: shape of a threadblock tyle + :type threadblock_shape: list or tuple + :param stages: number of pipline stages in the operation. For SM90 kernels, this can be set to `None` and the maximum + number of stages that can be supported for an operation on a given architecture will be computed at a later time + :type stages: int or None + :param warp_count: number of warps in each [M, N, K] dimension of a threadblock tile + :type warp_count: list, tuple, or None + :param math_instruction: specification of the instruction type and shape to be performed and the types of its operands + :type math_instruction: MathInstruction + :param cluster_shape: number of threadblocks in the [X, Y, Z] dimensions of a threadblock cluster + :param kernel_schedule: type of kernel schedule to use (only available for SM90+) + :type kernel_schedule: cutlass_library.KernelScheduleType + :param epilogue_schedule: type of epilogue schedule to use (only available for SM90+) + :type epilogue_schedule: cutlass_library.EpilogueScheduleType + :param tile_scheduler: type of tile scheduler to use (only available for SM90+) + :type tile_scheduler: cutlass_library.TileSchedulerType + """ + if ((kernel_schedule is None and epilogue_schedule is not None) or + (kernel_schedule is not None and epilogue_schedule is None)): + raise Exception("Kernel and epilogue schedule must either both be Auto or neither be Auto.") + + self.threadblock_shape = threadblock_shape + self.cluster_shape = cluster_shape + self.kernel_schedule = kernel_schedule + self.epilogue_schedule = epilogue_schedule + self.tile_scheduler = tile_scheduler + self.stages = stages + + self.math_instruction = math_instruction + self.instruction_shape = math_instruction.instruction_shape + + # Number of warps along x, y, z directions + self.warp_count = warp_count + + self.blackwell_threadblock_shape, self.is_2sm = to_blackwell_threadblock_shape(self, self.cluster_shape, self.kernel_schedule) + + def clone_and_update(self, td: dict): + attrs = { + "cluster_shape": None, + "threadblock_shape": None, + "warp_count": None, + "stages": None, + "instruction_shape": None, + "kernel_schedule": None, + "epilogue_schedule": None, + "tile_scheduler": None + } + for key in attrs.keys(): + if key in td.keys(): + attrs[key] = td[key] + else: + attrs[key] = getattr(self, key) + + attrs["math_instruction"] = MathInstruction( + attrs["instruction_shape"], + self.math_instruction.element_a, + self.math_instruction.element_b, + self.math_instruction.element_accumulator, + self.math_instruction.opcode_class, + self.math_instruction.math_operation + ) + + # Remove the instruction shape + del attrs["instruction_shape"] + + return TileDescription(**attrs) + + @property + def num_threads(self): + """ + Returns the number of threads in the threadblock + + :return: number of threads in the threadblock + :rtype: int or None (if warp count is None) + """ + if self.warp_count is not None: + threads = 32 + for cnt in self.warp_count: + threads *= cnt + return threads + return None + + def procedural_name(self): + """ + Returns a name identifying the tile description + + :return: name identifying the tile description + :rtype: int + """ + emit_stages = 0 if self.stages is None else self.stages + name = "%dx%dx%d_%dx%d_%dx%d" % ( + self.cluster_shape[0], + self.cluster_shape[1], + self.cluster_shape[2], + self.threadblock_shape[0], + self.threadblock_shape[1], + self.threadblock_shape[2], + emit_stages + ) + + return name + + def procedural_name_2x(self): + """ + Returns a name identifying the tile description + + :return: name identifying the tile description + :rtype: int + """ + return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages) + + def __str__(self): + """ + Returns a string with containing each of the tile description's values + + :return: contents of tile description + :rtype: str + """ + if self.kernel_schedule is not None: + kschedule = self.kernel_schedule + else: + kschedule = KernelScheduleType.ScheduleAuto + + if self.epilogue_schedule is not None: + eschedule = self.epilogue_schedule + else: + eschedule = EpilogueScheduleType.ScheduleAuto + + if self.tile_scheduler is not None: + tschedule = self.tile_scheduler.name + else: + tschedule = "None" + return f""" +{{ + ClusterShape: {self.cluster_shape} + ThreadblockShape: {self.threadblock_shape} + WarpCount: {self.warp_count} + Stages: {self.stages if self.stages is not None else 'Auto'} + InstructionShape: {self.math_instruction.instruction_shape} + Kernel schedule: {kschedule.name} + Epilogue schedule: {kschedule.name} + TileScheduler: {tschedule} +}}""" + + +class TensorDescription: + def __init__(self, element, layout, alignment=1, complex_transform=ComplexTransform.none): + self.element = element + self.layout = layout + if element != DataType.void: + self.alignment = min(128 // DataTypeSize[self.element], alignment) + else: + self.alignment = alignment + self.complex_transform = complex_transform + + +def CalculateSmemUsagePerStage(operation): + """ + Returns the amount of shared memory in bytes consumed in a single stage of a kernel. + + :param op: operation for which the maximum stages should be computed. If stages are + set via the `op.tile_description.stages` parameter, this setting is ignored + in the present calculation + :type op: cutlass_cppgen.backend.Operation + + :return: number of bytes of shared memory consumed by a single stage + :rtype: int + """ + m, n, k = operation.tile_description.threadblock_shape + + if operation.operation_kind == OperationKind.Gemm: + stage_barrier_bytes = 32 + return ( + (DataTypeSize[operation.A.element] * m * k // 8) + + (DataTypeSize[operation.B.element] * k * n // 8) + + stage_barrier_bytes + ) + else: + raise Exception("Unsupported operation kind {}.".format(operation.operation_kind)) + + +def CalculateSmemUsage(operation): + """ + Returns the amount of shared memory in bytes consumed by a kernel. + + :param op: operation for which the maximum stages should be computed. If stages are + set via the `op.tile_description.stages` parameter, this setting is ignored + in the present calculation + :type op: cutlass_cppgen.backend.Operation + + :return: int + """ + return operation.tile_description.stages * CalculateSmemUsagePerStage(operation) + + +class ApiVersion(enum.Enum): + """ + Differentiate between CUTLASS 2.x and 3.x API versions + """ + + v2x = enum_auto() + v3x = enum_auto() + + +def api_version(arch, opclass, dtype): + """ + Returns whether the architecture, opcode class, and datatype in question require using CUTLASS 2.x + or 3.x for code emission. + + :param arch: compute capability of device on which to run + :type arch: int + :param opclass: class of the operation being performed + :type opclass: cutlass_library.OpcodeClass + :param dtype: data type to be used in operation (assumes that ElementA and ElementB are the same) + :type dtype: cutlass_library.DataType + + :return: API version to be used in code emission + :rtype: ApiVersion + """ + if (arch in [90, 100, 101, 103] and + opclass == OpcodeClass.TensorOp and + (dtype != DataType.f64)): + return ApiVersion.v3x + else: + return ApiVersion.v2x + + +class EmissionType(enum.Enum): + """ + Tags for whether to emit a kernel- or device-level operation + """ + + Kernel = enum_auto() + Device = enum_auto() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/memory_manager.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/memory_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..30e6bb3108ddd30e3776cf92b0671fce4fae5a93 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/memory_manager.py @@ -0,0 +1,121 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import numpy as np + +import cutlass_cppgen +from cutlass_cppgen.utils.datatypes import is_numpy_tensor +from cutlass_cppgen.utils.lazy_import import lazy_import + +if cutlass_cppgen.use_rmm: + import rmm +else: + cudart = lazy_import("cuda.cudart") + + +class PoolMemoryManager: + def __init__(self, init_pool_size: int, max_pool_size: int) -> None: + self.pool = rmm.mr.PoolMemoryResource( + rmm.mr.CudaMemoryResource(), + initial_pool_size=init_pool_size, + maximum_pool_size=max_pool_size + ) + self.mr = rmm.mr.TrackingResourceAdaptor(self.pool) + rmm.mr.set_current_device_resource(self.mr) + + def pool_size(self): + return self.pool.pool_size() + + +class DevicePtrWrapper: + """ + Wrapper around a pointer to device memory to provide a uniform interface with the RMM DeviceBuffer + (at least in terms of the interface used by the CUTLASS Python interface) + """ + def __init__(self, dev_ptr): + self.dev_ptr = dev_ptr + + @property + def ptr(self): + return self.dev_ptr + + +def _todevice(host_data): + """ + Helper for transferring host data to device memory + """ + if cutlass_cppgen.use_rmm: + return rmm.DeviceBuffer.to_device(host_data.tobytes()) + else: + nbytes = len(host_data.tobytes()) + dev_ptr_wrapper = device_mem_alloc(nbytes) + err, = cudart.cudaMemcpy( + dev_ptr_wrapper.ptr, + host_data.__array_interface__['data'][0], + nbytes, + cudart.cudaMemcpyKind.cudaMemcpyHostToDevice + ) + if err != cudart.cudaError_t.cudaSuccess: + raise Exception(f"cudaMemcpy failed with error {err}") + return dev_ptr_wrapper + + +def todevice(host_data, dtype=np.float32): + """ + Pass the host_data to device memory + """ + if isinstance(host_data, list): + return _todevice(np.array(host_data, dtype=dtype)) + elif is_numpy_tensor(host_data): + return _todevice(host_data) + + +def device_mem_alloc(size): + if cutlass_cppgen.use_rmm: + return rmm.DeviceBuffer(size=size) + else: + err, ptr = cudart.cudaMalloc(size) + if err != cudart.cudaError_t.cudaSuccess: + raise Exception(f"cudaMalloc failed with error {err}") + return DevicePtrWrapper(ptr) + + +def align_size(size, alignment=256): + return ((size + alignment - 1) // alignment) * alignment + + +def create_memory_pool(init_pool_size=0, max_pool_size=2 ** 34): + if cutlass_cppgen.use_rmm: + memory_pool = PoolMemoryManager(init_pool_size=init_pool_size, max_pool_size=max_pool_size) + return memory_pool + else: + return None diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/operation.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/operation.py new file mode 100644 index 0000000000000000000000000000000000000000..10ee67bc6f547d079b6d990e7abea69a16549c16 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/operation.py @@ -0,0 +1,140 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import ctypes +from cutlass_cppgen.utils.lazy_import import lazy_import +cuda = lazy_import("cuda.cuda") + +from cutlass_cppgen.backend.utils.device import device_cc + +_supports_cluster_launch = None + + +def supports_cluster_launch(): + from cuda import __version__ + _version_splits = [int(x) for x in __version__.split("rc")[0].split(".post")[0].split(".")] + global _supports_cluster_launch + if _supports_cluster_launch is None: + major, minor = _version_splits[0], _version_splits[1] + _supports_cluster_launch = device_cc() in [90, 100, 101, 103] and (major > 11 or (major == 11 and minor >= 8)) + return _supports_cluster_launch + + +class LaunchConfiguration: + def __init__(self, grid=[1, 1, 1], block=[1, 1, 1], smem=0): + self.grid = grid + self.block = block + self.shared_memory_capacity = smem + + +class ExecutableOperation: + def __init__(self, operation): + self.operation = operation + self.module = None + self.kernel = None + + def name(self): + return self.operation.procedural_name() + + def emit(self): + return "" + + def can_implement(self, configuration, arguments): + raise NotImplementedError() + + def get_host_workspace_size(self, arguments): + raise NotImplementedError() + + def get_device_workspace_size(self, arguments): + raise NotImplementedError() + + def plan(self, arguments): + raise NotImplementedError() + + def initialize(self, host_workspace, device_workspace, launch_config, arguments, stream=None): + raise NotImplementedError() + + def run_with_clusters(self, launch_config, kernel_params, stream=None): + if not stream: + stream = cuda.CUstream(0) + if hasattr(self.operation, "tile_description") and hasattr(self.operation.tile_description, "cluster_shape"): + attr = cuda.CUlaunchAttribute() + attr.value.clusterDim.x, attr.value.clusterDim.y, attr.value.clusterDim.z = self.operation.tile_description.cluster_shape + attr.id = cuda.CUstreamAttrID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION + attrs = [attr] + + # Allow for non-portable cluster sizes + err, = cuda.cuFuncSetAttribute( + self.kernel, cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1) + if err != cuda.CUresult.CUDA_SUCCESS: + return err + else: + attrs = [] + + config = cuda.CUlaunchConfig() + config.gridDimX, config.gridDimY, config.gridDimZ = launch_config.grid + config.blockDimX, config.blockDimY, config.blockDimZ = launch_config.block + config.blockDimZ = launch_config.block[2] + config.sharedMemBytes = launch_config.shared_memory_capacity + config.hStream = stream + config.attrs = attrs + config.numAttrs = len(attrs) + + err, = cuda.cuLaunchKernelEx( + config, f=self.kernel, kernelParams=kernel_params, extra=0) + return err + + def run_without_clusters(self, launch_config, kernel_params, stream=None): + if not stream: + stream = cuda.CUstream(0) + err, = cuda.cuLaunchKernel( + self.kernel, + launch_config.grid[0], launch_config.grid[1], launch_config.grid[2], + launch_config.block[0], launch_config.block[1], launch_config.block[2], + launch_config.shared_memory_capacity, + stream, + kernel_params, + 0) + + return err + + def run(self, host_workspace, device_workspace, launch_config, stream=None): + if not stream: + stream = cuda.CUstream(0) + cArg = (ctypes.c_char * len(host_workspace)).from_buffer(host_workspace) + packed = (ctypes.c_void_p * 1)() + packed[0] = ctypes.addressof(cArg) + + if supports_cluster_launch(): + return self.run_with_clusters(launch_config, packed, stream) + else: + return self.run_without_clusters(launch_config, packed, stream) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/reduction_operation.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/reduction_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..535cea2cb2a23ccbb29cce7233f42147ed2ea5eb --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/reduction_operation.py @@ -0,0 +1,455 @@ +################################################################################ +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ +from __future__ import annotations + +import ctypes +from typing import Union + +from cutlass_cppgen.utils.lazy_import import lazy_import +cuda = lazy_import("cuda.cuda") +cudart = lazy_import("cuda.cudart") +import numpy as np + +from cutlass_library import ( + DataTypeNames, + DataTypeSize, + DataTypeTag, + LayoutType, + SubstituteTemplate +) + +import cutlass_cppgen +from cutlass_cppgen.backend.c_types import MatrixCoord_, TensorRef2D_, get_reduction_params +from cutlass_cppgen.backend.frontend import NumpyFrontend, TorchFrontend +from cutlass_cppgen.backend.library import TensorDescription +from cutlass_cppgen.backend.memory_manager import DevicePtrWrapper +from cutlass_cppgen.backend.operation import ExecutableOperation, LaunchConfiguration +from cutlass_cppgen.shape import MatrixCoord +from cutlass_cppgen.utils.datatypes import is_numpy_tensor, is_torch_tensor + + +class ReductionOperation: + pass + + +class ReductionArguments: + """ + Arguments of reduction + """ + + def __init__( + self, + operation: ReductionOperation, + problem_size: "list[int]", + partitions: int, + workspace: cuda.CUdeviceptr, + destination: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]", + source: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]", + **kwargs, + ) -> None: + # tensor_C can be interpreted as the bias with bias=True in keyword args + if "bias" in kwargs.keys(): + self.bias = kwargs["bias"] + else: + # by default, tensor_C is not bias + self.bias = False + if "stream" in kwargs.keys(): + self.stream = kwargs["stream"] + else: + self.stream = cuda.CUstream(0) + + self.operation = operation + self.ptr_workspace = workspace + + # number of split-k partitions + self.partitions = partitions + + if is_numpy_tensor(destination): + self.host_D = destination + self.destination_buffer = NumpyFrontend.argument(destination, True) + self.source_buffer = NumpyFrontend.argument(source, False) + self.ptr_destination = cuda.CUdeviceptr(self.destination_buffer.ptr) + self.ptr_source = cuda.CUdeviceptr(self.source_buffer.ptr) + elif is_torch_tensor(destination): + self.ptr_destination = TorchFrontend.argument(destination) + self.ptr_source = TorchFrontend.argument(source) + elif isinstance(destination, cuda.CUdeviceptr): + self.ptr_destination = destination + self.ptr_source = source + else: + raise TypeError("unknown Type") + + self.problem_size = MatrixCoord_(problem_size[0], problem_size[1]) + + self.partition_stride = ( + problem_size[0] * problem_size[1] * DataTypeSize[operation.C.element] // 8 + ) + + if "output_op" in kwargs.keys(): + self.output_op = kwargs["output_op"] + else: + self.output_op = self.operation.epilogue_type(1.0, 0.0) + + self.get_arguments() + + @staticmethod + def get_tensor_ref( + extent: "tuple[int]", + device_ptr: cuda.CUdeviceptr, + layout: LayoutType, + ): + if layout == LayoutType.RowMajor: + return TensorRef2D_(int(device_ptr), extent[1]) + else: + raise ValueError(f"Unknown layout type {layout}") + + def get_arguments(self): + ref_workspace = ReductionArguments.get_tensor_ref( + extent=[ + self.problem_size.row, + self.problem_size.column, + ], + device_ptr=self.ptr_workspace, + layout=LayoutType.RowMajor, + ) + if self.bias: + ref_source = ReductionArguments.get_tensor_ref( + extent=[0, 0], + device_ptr=self.ptr_source, + layout=LayoutType.RowMajor, + ) + else: + ref_source = ReductionArguments.get_tensor_ref( + extent=[ + self.problem_size.row, + self.problem_size.column, + ], + device_ptr=self.ptr_source, + layout=LayoutType.RowMajor, + ) + + ref_destination = ReductionArguments.get_tensor_ref( + extent=[ + self.problem_size.row, + self.problem_size.column, + ], + device_ptr=self.ptr_destination, + layout=LayoutType.RowMajor, + ) + + self.c_arguments = self.operation.argument_type( + self.problem_size, + self.partitions, + self.partition_stride, + ref_workspace, + ref_destination, + ref_source, + self.output_op, + ) + + params_ = self.operation.rt_module.get_args(ctypes.byref(self.c_arguments)) + self.host_workspace = bytearray(params_.contents) + + def sync(self): + (err,) = cudart.cudaDeviceSynchronize() + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError(f"CUDA Error {str(err)}") + + if hasattr(self, "host_D"): + (err,) = cuda.cuMemcpyDtoH( + self.host_D, + self.ptr_destination, + self.host_D.size * self.host_D.itemsize, + ) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError("CUDA Error %s" % str(err)) + + self.free() + + def free(self): + """ + Frees allocated device-side memory + """ + # Free any device memory allocated manually + if not cutlass_cppgen.use_rmm: + for attr in ["destination_buffer", "source_buffer"]: + if hasattr(self, attr): + buf = getattr(self, attr) + if isinstance(buf, DevicePtrWrapper): + err, = cudart.cudaFree(buf.ptr) + if err != cudart.cudaError_t.cudaSuccess: + raise RuntimeError(f"cudaFree failed with error {err}") + del buf + + +class ReductionRT(ExecutableOperation): + """ + ReductionRT manages the CUTLASS runtime components for reduction + """ + + KernelTemplate = r""" +extern "C" +__global__ void +${operation_name}(${operation_name}${operation_suffix}::Params params) { + + // Dynamic shared memory base pointer + extern __shared__ int SharedStorageBase[]; + + // Declare pointer to dynamic shared memory. + ${operation_name}${operation_suffix}::SharedStorage *shared_storage = + reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase); + + ${operation_name}${operation_suffix} op; + + op(params, *shared_storage); +} + """ + HostTemplate = r""" +extern "C" { + // Get the size of params in bytes + int ${operation_name}_get_param_size(){ + return sizeof(${operation_name}${operation_suffix}::Params); + } + + // Get the size of dynamic shared memory in bytes + int ${operation_name}_shared_memory_size() { + return int(sizeof(${operation_name}${operation_suffix}::SharedStorage)); + } + + // Get the params as byte array + char* ${operation_name}_get_params(${operation_name}${operation_suffix}::Params* params){ + char *bytes = ((char*)(params)); + char *output = new char[sizeof(${operation_name}${operation_suffix}::Params)]; + for (unsigned int i = 0; i < sizeof(${operation_name}${operation_suffix}::Params); i ++) + output[i] = bytes[i]; + + return output; + } +} + """ + + def __init__(self, operation: ReductionOperation): + super().__init__(operation) + + self.operation: ReductionOperation = operation + self.emitter = EmitReductionInstance("_type") + + self.elements_per_access = self.operation.count + ( + self.argument_type, + self.epilogue_type, + ) = get_reduction_params(operation.epilogue_functor) + self.argtype = [ctypes.POINTER(self.argument_type)] + + def emit(self): + return self.emitter.emit(self.operation) + + def plan(self, arguments: ReductionArguments): + block_shape = [ + self.operation.shape.column // self.elements_per_access, + self.operation.shape.row, + 1, + ] + grid_shape = [ + (arguments.problem_size.row + self.operation.shape.row - 1) + // self.operation.shape.row, + (arguments.problem_size.column + self.operation.shape.column - 1) + // self.operation.shape.column, + 1, + ] + return LaunchConfiguration( + grid_shape, + block_shape, + self.shared_memory_capacity, + ) + + def initialize(self): + (err,) = cuda.cuFuncSetAttribute( + self.kernel, + attrib=cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + value=self.shared_memory_capacity, + ) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError(f"CUDA Error: {err}") + + +class ReductionOperation: + """ + CUTLASS reduction Operation + """ + + def __init__( + self, + shape: MatrixCoord, + C: TensorDescription, + element_accumulator, + element_workspace=None, + element_compute=None, + epilogue_functor=None, + count: int = 1, + partitions_per_stage: int = 4, + ) -> None: + self.shape = shape + self.epilogue_functor = epilogue_functor + self.element_accumulator = element_accumulator + + if element_workspace is None: + self.element_workspace = element_accumulator + else: + self.element_workspace = element_workspace + + if element_compute is None: + self.element_compute = element_accumulator + else: + self.element_compute = element_compute + + self.element_output = C.element + self.C: TensorDescription = C + + # Reduce op processing size + self.count: int = count + + # Number of partitions to reduce per stage + self.partitions_per_stage: int = partitions_per_stage + + self.rt_module: ReductionRT = ReductionRT(self) + self.argument_type = self.rt_module.argument_type + self.epilogue_type = self.rt_module.epilogue_type + + def extended_name(self): + extend_name = "${element_workspace}_${element_accumulator}_${element_compute}_${element_output}" + + return SubstituteTemplate( + extend_name, + { + "element_workspace": DataTypeNames[self.element_workspace], + "element_accumulator": DataTypeNames[self.element_accumulator], + "element_compute": DataTypeNames[self.element_compute], + "element_output": DataTypeNames[self.element_output], + }, + ) + + def configuration_name(self): + """The full procedural name indicates architecture, extended name, tile size""" + + configuration_name = "cutlass_reduce_split_k_${extended_name}_${threadblock}" + + threadblock = "%dx%d" % ( + self.shape.row, + self.shape.column, + ) + + return SubstituteTemplate( + configuration_name, + { + "extended_name": self.extended_name(), + "threadblock": threadblock, + }, + ) + + def procedural_name(self): + """The full procedural name indicates architeture, extended name, tile size""" + return self.configuration_name() + + def run(self, arguments: ReductionArguments) -> cuda.CUresult: + """ + Configure and launch the cuda kernel with input arguments + """ + launch_config = self.rt_module.plan(arguments) + + host_workspace = arguments.host_workspace + device_workspace = None + + err = self.rt_module.run( + host_workspace, + device_workspace, + launch_config, + arguments.stream + ) + + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError(f"CUDA Error {str(err)}") + + return err + + +class EmitReductionInstance: + def __init__(self, operation_suffix="") -> None: + self.operation_suffix = operation_suffix + self.includes = [ + "cutlass/cutlass.h", + "cutlass/numeric_types.h", + "cutlass/arch/arch.h", + "cutlass/arch/mma.h", + "cutlass/layout/matrix.h", + "cutlass/gemm/device/gemm.h", + "cutlass/gemm/device/gemm_universal_adapter.h", + "cutlass/gemm/kernel/default_gemm_universal.h", + "cutlass/reduction/kernel/reduce_split_k.h", + "cutlass/reduction/thread/reduction_operators.h", + ] + self.template = """ +// Reduction kernel instance +using ${operation_name}_base = +typename cutlass::reduction::kernel::ReduceSplitK< + cutlass::MatrixShape<${shape_row}, ${shape_column}>, + ${epilogue_functor}, + cutlass::reduction::thread::ReduceAdd< + ${element_accumulator}, + ${element_output}, + ${count}>, + ${partition_per_stage}>; + +struct ${operation_name}${operation_suffix}: + public ${operation_name}_base { }; + """ + + def emit(self, operation: ReductionOperation): + vector_length_bits = min(operation.C.alignment * DataTypeSize[operation.C.element], 128) + epilogue_vector_length = vector_length_bits // DataTypeSize[operation.C.element] + + values = { + "operation_name": operation.configuration_name(), + "operation_suffix": self.operation_suffix, + "shape_row": str(operation.shape.row), + "shape_column": str(operation.shape.column), + "epilogue_functor": operation.epilogue_functor.emit(), + "element_output": DataTypeTag[operation.element_output], + "epilogue_vector_length": str(epilogue_vector_length), + "element_accumulator": DataTypeTag[operation.element_accumulator], + "element_compute": DataTypeTag[operation.element_compute], + "element_workspace": DataTypeTag[operation.element_workspace], + "count": str(operation.count), + "partition_per_stage": str(operation.partitions_per_stage), + } + + return SubstituteTemplate(self.template, values) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/type_hint.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/type_hint.py new file mode 100644 index 0000000000000000000000000000000000000000..fffa03360f7e0eb2f3a2a20e5c8a4e04d009bee9 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/type_hint.py @@ -0,0 +1,35 @@ +################################################################################ +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ + +GemmOperation = "Union[GemmOperationUniversal, GemmOperationGrouped]" + +Tensor = "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]" diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/utils/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0bae3bac1163c55a698dfc8722c62ac85cb25abf --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/utils/__init__.py @@ -0,0 +1,33 @@ +################################################################################ +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ + +from cutlass_cppgen.backend.utils.device import check_cuda_errors, device_cc diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/utils/device.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/utils/device.py new file mode 100644 index 0000000000000000000000000000000000000000..9ed4096a6f4b772a58702c2f4b089cc32d707614 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/utils/device.py @@ -0,0 +1,126 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utility functions for interacting with the device +""" +from __future__ import annotations + +from cutlass_cppgen.utils.lazy_import import lazy_import +cuda = lazy_import("cuda.cuda") +cudart = lazy_import("cuda.cudart") + +import cutlass_cppgen +from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor + + +def check_cuda_errors(result: list): + """ + Checks whether `result` contains a CUDA error raises the error as an exception, if so. Otherwise, + returns the result contained in the remaining fields of `result`. + + :param result: the results of the `cudart` method, consisting of an error code and any method results + :type result: list + + :return: non-error-code results from the `results` parameter + """ + # `result` is of the format : (cudaError_t, result...) + err = result[0] + if err.value: + raise RuntimeError("CUDA error: {}".format(cudart.cudaGetErrorName(err))) + + if len(result) == 1: + return None + elif len(result) == 2: + return result[1] + else: + return result[1:] + + +def device_cc(device: int = -1) -> int: + """ + Returns the compute capability of the device with ID `device`. + + :param device: ID of the device to query + :type device: int + + :return: compute capability of the queried device (e.g., 80 for SM80) + :rtype: int + """ + if device == -1: + device = cutlass_cppgen.device_id() + + deviceProp = check_cuda_errors(cudart.cudaGetDeviceProperties(device)) + major = str(deviceProp.major) + minor = str(deviceProp.minor) + return int(major + minor) + + +def device_sm_count(device: int = -1): + if device == -1: + device = cutlass_cppgen.device_id() + err, device_sm_count = cuda.cuDeviceGetAttribute( + cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device + ) + if err != cuda.CUresult.CUDA_SUCCESS: + raise Exception( + "Failed to retireve SM count. " + f"cuDeviceGetAttribute() failed with error: {cuda.cuGetErrorString(err)[1]}" + ) + + return device_sm_count + + +def to_device_ptr(tensor) -> cuda.CUdeviceptr: + """ + Converts a tensor to a CUdeviceptr + + :param tensor: tensor to convert + :type tensor: np.ndarray | torch.Tensor | cp.ndarray | int + + :return: device pointer + :rtype: cuda.CUdeviceptr + """ + if is_numpy_tensor(tensor): + ptr = cuda.CUdeviceptr(tensor.__array_interface__["data"][0]) + elif is_torch_tensor(tensor): + ptr = cuda.CUdeviceptr(tensor.data_ptr()) + elif is_cupy_tensor(tensor): + ptr = cuda.CUdeviceptr(int(tensor.data.ptr)) + elif isinstance(tensor, cuda.CUdeviceptr): + ptr = tensor + elif isinstance(tensor, int): + ptr = cuda.CUdeviceptr(tensor) + else: + raise NotImplementedError(tensor) + + return ptr diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8e4121b59e57e26e8a32022916089e0916db4988 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/__init__.py @@ -0,0 +1,33 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from cutlass_cppgen.emit.pytorch import pytorch diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/common.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/common.py new file mode 100644 index 0000000000000000000000000000000000000000..58f94e15148f934c92318b586d63b669757ed5f0 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/common.py @@ -0,0 +1,267 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Common utilities for emitting CUTLASS kernels +""" + +import cutlass_cppgen + +# Strings used for printing information about the generation of emitted scripts +_AUTOGEN_STR = f"This file was automatically generated by the CUTLASS {cutlass_cppgen.__version__} Python interface (https://github.com/nvidia/cutlass/python)" + + +_CSTYLE_AUTOGEN_COMMENT = f"""// {_AUTOGEN_STR} +""" + + +_PYSTYLE_AUTOGEN_COMMENT = f"""# {_AUTOGEN_STR} +""" + +_CUTLASS_KERNEL_ARGS_2x = """ + typename DeviceKernel::Arguments arguments { + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K}, // problem size + 1, + {alpha, beta}, + A, B, C, D, + 0, 0, 0, 0, // batch strides + DeviceKernel::LayoutA::packed({M, K}).stride(0), // lda + DeviceKernel::LayoutB::packed({K, N}).stride(0), // ldb + DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldc + DeviceKernel::LayoutC::packed({M, N}).stride(0) // ldd + }; +""" + +_CUTLASS_KERNEL_ARGS_2x_STREAM_K = """ + typename DeviceKernel::Arguments arguments { + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K}, // problem size + 1, + {alpha, beta}, + A, B, C, D, + 0, 0, 0, 0, // batch strides + DeviceKernel::LayoutA::packed({M, K}).stride(0), // lda + DeviceKernel::LayoutB::packed({K, N}).stride(0), // ldb + DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldc + DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldd + -1 // avail_sms + }; +""" + +_CUTLASS_KERNEL_RUN_GEMM_2x = """ +using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute; + +cutlass::Status ${name}_kernel_run(int M, int N, int K, + const DeviceKernel::ElementA* A, const DeviceKernel::ElementB* B, const DeviceKernel::ElementC* C, DeviceKernel::ElementC* D, + ElementCompute alpha, ElementCompute beta) { + ${args} + size_t workspace_size = DeviceKernel::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + DeviceKernel gemm_op; + cutlass::Status status = gemm_op.initialize(arguments, + workspace.get(), + nullptr); // CUDA stream + + if (status != cutlass::Status::kSuccess) { + return status; + } + + status = gemm_op(); + return status; +} +""" + +_CUTLASS_KERNEL_RUN_GEMM_3x = """ +using StrideA = typename DeviceKernel::GemmKernel::StrideA; +using StrideB = typename DeviceKernel::GemmKernel::StrideB; +using StrideC = typename DeviceKernel::GemmKernel::StrideC; +using StrideD = typename DeviceKernel::GemmKernel::StrideD; + +using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute; + +cutlass::Status ${name}_kernel_run( + int M, int N, int K, int L, + const DeviceKernel::ElementA* A, const DeviceKernel::ElementB* B, const DeviceKernel::ElementC* C, DeviceKernel::ElementC* D, + ElementCompute alpha, ElementCompute beta, const cutlass::KernelHardwareInfo& hw_info) { + + typename DeviceKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K, L}, // problem size + { + A, // ptrA + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)), // stride A + B, // ptrB + cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)), // stride B + }, + { + {alpha, beta}, + C, // ptrC + cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)), // stride C + D, // ptrD + cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)), // stride D + }, + hw_info + }; + + size_t workspace_size = DeviceKernel::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + DeviceKernel gemm_op; + cutlass::Status status = gemm_op.run(arguments, + workspace.get(), + nullptr); // CUDA stream + + return status; +} +""" + + +_CUTLASS_KERNEL_RUN_GROUPED_GEMM_2x = """ +using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute; + +int threadblock_count = DeviceKernel::sufficient(); + +cutlass::Status ${name}_kernel_run(int problem_count, cutlass::gemm::GemmCoord* problem_sizes, + DeviceKernel::ElementA** A, DeviceKernel::ElementB** B, DeviceKernel::ElementC** C, DeviceKernel::ElementC** D, + int64_t* lda, int64_t* ldb, int64_t* ldc, int64_t* ldd, + ElementCompute alpha, ElementCompute beta) { + + typename DeviceKernel::Arguments arguments { + problem_sizes, + problem_count, + threadblock_count, + {alpha, beta}, + A, B, C, D, + lda, ldb, ldc, ldd + }; + + size_t workspace_size = DeviceKernel::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + DeviceKernel gemm_op; + cutlass::Status status = gemm_op.initialize(arguments, + workspace.get(), + nullptr); // CUDA stream + + if (status != cutlass::Status::kSuccess) { + return status; + } + + status = gemm_op(); + return status; +} +""" + + +_CUTLASS_KERNEL_RUN_CONV2D_2x = """ + +using UnderlyingKernel = typename DeviceKernel::UnderlyingKernel; +namespace { +using TensorRefA = typename UnderlyingKernel::TensorRefA; +using TensorRefB = typename UnderlyingKernel::TensorRefB; +using TensorRefC = typename UnderlyingKernel::TensorRefC; +using ElementCompute = typename UnderlyingKernel::EpilogueOutputOp::ElementCompute; +} + +template +TensorRef get_tensor_ref(cutlass::Tensor4DCoord tensor_coord, Element* ptr){ + cutlass::layout::TensorNHWC layout = cutlass::layout::TensorNHWC::packed(tensor_coord); + TensorRef tensor_ref(ptr, layout); + return tensor_ref; +} + +cutlass::Status ${name}_kernel_run(cutlass::conv::Conv2dProblemSize* problem_size, + UnderlyingKernel::ElementA* A, UnderlyingKernel::ElementB* B, + UnderlyingKernel::ElementC* C, UnderlyingKernel::ElementC* D, + ElementCompute alpha, ElementCompute beta, std::string split_k_mode, + cudaStream_t stream, int device_id=0) { + // create the tensor references + cutlass::Tensor4DCoord tensor_coord_A = cutlass::conv::implicit_gemm_tensor_a_extent( + cutlass::conv::Operator::k${conv_kind_name}, *problem_size + ); + cutlass::Tensor4DCoord tensor_coord_B = cutlass::conv::implicit_gemm_tensor_b_extent( + cutlass::conv::Operator::k${conv_kind_name}, *problem_size + ); + cutlass::Tensor4DCoord tensor_coord_C = cutlass::conv::implicit_gemm_tensor_c_extent( + cutlass::conv::Operator::k${conv_kind_name}, *problem_size + ); + + TensorRefA tensor_ref_A = get_tensor_ref(tensor_coord_A, A); + TensorRefB tensor_ref_B = get_tensor_ref(tensor_coord_B, B); + TensorRefC tensor_ref_C = get_tensor_ref(tensor_coord_C, C); + TensorRefC tensor_ref_D = get_tensor_ref(tensor_coord_C, D); + + cutlass::conv::SplitKMode mode; + if (split_k_mode == "serial") { + mode = cutlass::conv::SplitKMode::kSerial; + } else if (split_k_mode == "parallel") { + mode = cutlass::conv::SplitKMode::kParallel; + } else { + throw std::runtime_error("Invalid split_k_mode: " + split_k_mode); + } + + typename DeviceKernel::Arguments arguments{ + *problem_size, + tensor_ref_A, + tensor_ref_B, + tensor_ref_C, + tensor_ref_D, + {alpha, beta}, + mode + }; + + DeviceKernel implicit_gemm_op; + + size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments); + + void* workspace_ptr = device_memory_allocation(workspace_size, device_id); + + cutlass::Status status = implicit_gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + return status; + } + + status = implicit_gemm_op.initialize(arguments, workspace_ptr, stream); + if (status != cutlass::Status::kSuccess) { + return status; + } + + // + // Launch initialized CUTLASS kernel + // + status = implicit_gemm_op(stream); + + return status; +} +""" diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/pytorch.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..fe96f3ede11163da01520f972eb97282a2ab2b14 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/pytorch.py @@ -0,0 +1,936 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for generating source for building a PyTorch CUDA extension that using a CUTLASS kernel. +If specified, the extension can be JIT compiled via PyTorch's ``cpp_extension.load`` method. + +Example usage with JIT compilation: + +.. highlight:: python +.. code-block:: python + + plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_library.LayoutType.RowMajor) + op = plan.construct() + mod = cutlass_cppgen.emit.pytorch(op, 'cutlass_gemm', 80, jit=True) + + # Generate inputs for the GEMM + A, B, C = [torch.ones((512, 512)).to('cuda') for _ in range(3)] + + # Run the module + D = mod.run(A, B, C) + + +Example usage without JIT compilation: + +.. highlight:: python +.. code-block:: python + + plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_cppgen.LayoutType.RowMajor) + op = plan.construct() + cutlass_cppgen.emit.pytorch(op, 'cutlass_gemm', 80, jit=False, sourcedir='output') + +After this call, the directory ``output`` contains ``setup.py``, +``cutlass_gemm.cpp``, and ``cutlass_gemm_kernel.cu``. The module can be built from +within ``output`` by running: ``TORCH_CUDA_ARCH_LIST="8.0" python setup.py develop --user``. + +The module can later be used in Python via: + +.. highlight:: python +.. code-block:: python + + import torch + import cutlass_gemm + + # Generate inputs for the GEMM + A, B, C = [torch.ones((512, 512)).to('cuda') for _ in range(3)] + + # Run the module + D = cutlass_gemm.run(A, B, C) +""" + +import logging +import os + +from cutlass_library import ConvKind, ConvKindNames, DataType, SubstituteTemplate + +from cutlass_cppgen import CUTLASS_PATH, logger, swizzle +from cutlass_cppgen.backend.gemm_operation import GemmOperationGrouped, GemmOperationUniversal +from cutlass_cppgen.backend.conv2d_operation import Conv2dOperation +from cutlass_cppgen.backend.library import ApiVersion +from cutlass_cppgen.emit import common +from cutlass_cppgen.utils.datatypes import is_torch_available + +if is_torch_available(): + import torch + + +_PYTORCH_CUDA_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """ +#include +#include +#include +#include +#include "cutlass/cutlass.h" +#include "cutlass/util/device_memory.h" + +// helper function allocating the memory +void* device_memory_allocation(size_t size, int device_id=0) { + if (size > 0) { + torch::Device device(torch::kCUDA, device_id); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + torch::TensorOptions options = torch::TensorOptions().dtype(torch::kI8).device(device); + at::Tensor device_tensor = torch::empty({(long)size,}, options); + return reinterpret_cast(device_tensor.data_ptr()); + } else { + return nullptr; + } +} + +${includes} +${declaration} +${impl} +""" + +_PYTORCH_GEMM_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """ +#include +#include +#include + +// CUDA forward declarations +at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional C=at::nullopt, float alpha=1.f, float beta=0.f); + +// C++ interface +at::Tensor ${name}(const at::Tensor& A, const at::Tensor& B, at::optional C=at::nullopt, float alpha=1.f, float beta=0.f) { + return ${name}_kernel(A, B, C, alpha, beta); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("run", py::overload_cast, float, float>(&${name}), py::arg("A"), py::arg("B"), py::arg("C") = nullptr, py::arg("alpha") = 1.f, py::arg("beta") = 0.f); +} +""" + +_PYTORCH_GROUPED_GEMM_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """ +#include +#include +#include + +// CUDA forward declarations +std::vector ${name}_kernel(const std::vector& A, const std::vector& B, at::optional> C=at::nullopt, float alpha=1.f, float beta=0.f); + +// C++ interface +std::vector ${name}(const std::vector& A, const std::vector& B, at::optional> C=at::nullopt, float alpha=1.f, float beta=0.f) { + return ${name}_kernel(A, B, C, alpha, beta); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("run", py::overload_cast&, const std::vector&, at::optional>, float, float>(&${name}), + py::arg("A"), py::arg("B"), py::arg("C") = nullptr, py::arg("alpha") = 1.f, py::arg("beta") = 0.f); +} +""" + +_PYTORCH_CONV2D_FPROP_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """ +#include +#include +#include + +// CUDA forward declarations +at::Tensor ${name}_kernel( + const at::Tensor& A, const at::Tensor& B, at::optional C=at::nullopt, + std::tuple stride={1, 1}, std::tuple padding={0, 0}, std::tuple dilation={1, 1}, + float alpha=1.f, float beta=0.f, + std::string split_k_mode="serial", int split_k_slices=1); + +// C++ interface +at::Tensor ${name}( + const at::Tensor& A, const at::Tensor& B, at::optional C=at::nullopt, + std::tuple stride={1, 1}, std::tuple padding={0, 0}, std::tuple dilation={1, 1}, + float alpha=1.f, float beta=0.f, + std::string split_k_mode="serial", int split_k_slices=1) { + return ${name}_kernel(A, B, C, stride, padding, dilation, alpha, beta, split_k_mode, split_k_slices); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("run", + py::overload_cast< + const at::Tensor&, const at::Tensor&, at::optional, + std::tuple, std::tuple, std::tuple, float, float, std::string, int>( + &${name}), py::arg("A"), py::arg("B"), py::arg("C") = nullptr, + py::arg("stride") = std::make_tuple(1, 1), py::arg("padding") = std::make_tuple(1, 1), py::arg("dilation") = std::make_tuple(1, 1), + py::arg("alpha") = 1.f, py::arg("beta") = 0.f, + py::arg("split_k_mode") = "serial", py::arg("split_k_slices") = 1); +} +""" + +_PYTORCH_CONV2D_GRAD_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """ +#include +#include +#include + +// CUDA forward declarations +at::Tensor ${name}_kernel( + std::tuple result_size, const at::Tensor& A, const at::Tensor& B, at::optional C=at::nullopt, + std::tuple stride={1, 1}, std::tuple padding={0, 0}, std::tuple dilation={1, 1}, + float alpha=1.f, float beta=0.f, + std::string split_k_mode="serial", int split_k_slices=1); + +// C++ interface +at::Tensor ${name}( + std::tuple result_size, const at::Tensor& A, const at::Tensor& B, at::optional C=at::nullopt, + std::tuple stride={1, 1}, std::tuple padding={0, 0}, std::tuple dilation={1, 1}, + float alpha=1.f, float beta=0.f, + std::string split_k_mode="serial", int split_k_slices=1) { + return ${name}_kernel(result_size, A, B, C, stride, padding, dilation, alpha, beta, split_k_mode, split_k_slices); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("run", + py::overload_cast< + std::tuple, const at::Tensor&, const at::Tensor&, at::optional, + std::tuple, std::tuple, std::tuple, float, float, std::string, int>( + &${name}), py::arg("result_size"), py::arg("A"), py::arg("B"), py::arg("C") = nullptr, + py::arg("stride") = std::make_tuple(1, 1), py::arg("padding") = std::make_tuple(1, 1), py::arg("dilation") = std::make_tuple(1, 1), + py::arg("alpha") = 1.f, py::arg("beta") = 0.f, + py::arg("split_k_mode") = "serial", py::arg("split_k_slices") = 1); +} +""" + +_PYTORCH_GEMM_INCLUDES = { + ApiVersion.v2x: """ +#include "cutlass/gemm/device/gemm_universal.h" +""", + ApiVersion.v3x: """ +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/util/packed_stride.hpp" +""", +} + +_PYTORCH_GROUPED_GEMM_INCLUDES = """ +#include "cutlass/gemm/kernel/default_gemm_grouped.h" +#include "cutlass/gemm/device/gemm_grouped.h" +""" + +_PYTORCH_CONV2D_INCLUDES = """ +#include "cutlass/conv/kernel/default_conv2d_fprop.h" +#include "cutlass/conv/kernel/default_conv2d_dgrad.h" +#include "cutlass/conv/kernel/default_conv2d_wgrad.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" +""" + +_CUTLASS_TYPE_TO_TORCH_TYPE = { + DataType.f16: "torch::kF16", + DataType.f32: "torch::kF32", + DataType.f64: "torch::kF64", + DataType.s8: "torch::kI8", + DataType.s32: "torch::kI32", + DataType.bf16: "torch::kBFloat16", +} + +_PYTORCH_GEMM_IMPL_TEMPLATE_2x = ( + common._CUTLASS_KERNEL_RUN_GEMM_2x + + """ +at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional C, float alpha, float beta) { + int M = A.size(0); + int N = B.size(1); + int K = A.size(1); + + typename DeviceKernel::ElementC* ptrC = (C == at::nullopt) ? + nullptr : + reinterpret_cast(C->contiguous().data_ptr()); + at::Tensor D = B.new_empty({M, N}, ${torch_type_C}); + + cutlass::Status status = ${name}_kernel_run(M, N, K, + reinterpret_cast(A.contiguous().data_ptr()), + reinterpret_cast(B.contiguous().data_ptr()), + ptrC, + reinterpret_cast(D.contiguous().data_ptr()), + ElementCompute(alpha), ElementCompute(beta)); + + TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed"); + return D; +} +""" +) + +_PYTORCH_GEMM_IMPL_TEMPLATE_3x = ( + common._CUTLASS_KERNEL_RUN_GEMM_3x + + """ +bool hw_info_queried = false; +cutlass::KernelHardwareInfo hw_info; + +at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional C, float alpha, float beta) { + int M = A.size(0); + int N = B.size(1); + int K = A.size(1); + int L = 1; + + // Query hardware info if we haven't already + if (!hw_info_queried) { + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + + typename DeviceKernel::ElementC* ptrC = (C == at::nullopt) ? + nullptr : + reinterpret_cast(C->contiguous().data_ptr()); + at::Tensor D = B.new_empty({M, N}, ${torch_type_C}); + + cutlass::Status status = ${name}_kernel_run(M, N, K, L, + reinterpret_cast(A.contiguous().data_ptr()), + reinterpret_cast(B.contiguous().data_ptr()), + ptrC, + reinterpret_cast(D.contiguous().data_ptr()), + ElementCompute(alpha), ElementCompute(beta), + hw_info); + + TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed"); + return D; +} +""" +) + + +_PYTORCH_GROUPED_GEMM_IMPL_TEMPLATE = ( + common._CUTLASS_KERNEL_RUN_GROUPED_GEMM_2x + + """ +std::vector ${name}_kernel(const std::vector& A, const std::vector& B, at::optional> C, float alpha, float beta) { + size_t num = A.size(); + + // To avoid performing many small cudaMallocs and host-to-device copies, + // we serialize the grouped GEMM arguments on the host, allocate one + // large chunk of device memory, and perform a single cudaMemcpy to + // copy the host data to the device. Allocation overheads could be + // avoided by using a memory pool. + + // Calculate the total size of the data to be copied from host to device + size_t total_size = sizeof(cutlass::gemm::GemmCoord) + + sizeof(DeviceKernel::ElementA*) + + sizeof(DeviceKernel::ElementB*) + + sizeof(DeviceKernel::ElementC*) + + sizeof(DeviceKernel::ElementC*) + + sizeof(int64_t) + + sizeof(int64_t) + + sizeof(int64_t); + total_size *= num; + + // num * sizeof(cutlass::gemm::GemmCoord) may leave one at a non-multiple + // of sizeof(DeviceKernel::ElementA*) (which will be 64 on a 64-bit system). + // To ensure that we don't end up having misaligned loads in the kernel, + // we pad to the nearest multiple of 8. + // + // Note that, even on a 32-bit system (for which sizeof(X*) will not equal + // sizeof(int64_t)), only padding between the list of GemmCoords and the + // list of ptr_As is sufficient because the set of four equal-length lists of pointers + // (A*, B*, C*, D*) will ensure that the first list of int64_ts will always + // start on a multiple of 8. + int64_t padding = 8 - (total_size % 8); + total_size += padding; + + uint8_t* host_data = new uint8_t[total_size]; + cutlass::DeviceAllocation device_data(total_size); + + uint8_t* start = host_data; + cutlass::gemm::GemmCoord* problem_sizes_host = reinterpret_cast(start); + + // Apply the padding after the list of GemmCoords + start += num * sizeof(cutlass::gemm::GemmCoord) + padding; + + int64_t ptr_A_offset = start - host_data; + DeviceKernel::ElementA** ptr_A_host = reinterpret_cast(start); + start += num * sizeof(DeviceKernel::ElementA*); + + int64_t ptr_B_offset = start - host_data; + DeviceKernel::ElementB** ptr_B_host = reinterpret_cast(start); + start += num * sizeof(DeviceKernel::ElementB*); + + int64_t ptr_C_offset = start - host_data; + DeviceKernel::ElementC** ptr_C_host = reinterpret_cast(start); + start += num * sizeof(DeviceKernel::ElementC*); + + int64_t ptr_D_offset = start - host_data; + DeviceKernel::ElementC** ptr_D_host = reinterpret_cast(start); + start += num * sizeof(DeviceKernel::ElementC*); + + int64_t lda_offset = start - host_data; + int64_t* lda_host = reinterpret_cast(start); + start += num * sizeof(int64_t); + + int64_t ldb_offset = start - host_data; + int64_t* ldb_host = reinterpret_cast(start); + start += num * sizeof(int64_t); + + int64_t ldc_offset = start - host_data; + int64_t* ldc_host = reinterpret_cast(start); + start += num * sizeof(int64_t); + + std::vector D(num); + + bool need_C = (C != at::nullopt) && (beta != 0.f); + for (size_t i = 0; i < num; ++i) { + int M = A[i].size(0); + int N = B[i].size(1); + int K = A[i].size(1); + *(problem_sizes_host + i) = {M, N, K}; + *(ptr_A_host + i) = reinterpret_cast(A[i].contiguous().data_ptr()); + *(ptr_B_host + i) = reinterpret_cast(B[i].contiguous().data_ptr()); + + if (need_C) { + *(ptr_C_host + i) = reinterpret_cast(C->at(i).contiguous().data_ptr()); + } + else { + *(ptr_C_host + i) = nullptr; + } + + D[i] = B[i].new_empty({M, N}, ${torch_type_C}); + *(ptr_D_host + i) = reinterpret_cast(D[i].contiguous().data_ptr()); + + *(lda_host + i) = DeviceKernel::LayoutA::packed({M, K}).stride(0); + *(ldb_host + i) = DeviceKernel::LayoutB::packed({K, N}).stride(0); + *(ldc_host + i) = DeviceKernel::LayoutC::packed({M, N}).stride(0); + } + + device_data.copy_from_host(host_data); + + cutlass::Status status = ${name}_kernel_run( + num, + reinterpret_cast(device_data.get()), + reinterpret_cast(device_data.get() + ptr_A_offset), + reinterpret_cast(device_data.get() + ptr_B_offset), + reinterpret_cast(device_data.get() + ptr_C_offset), + reinterpret_cast(device_data.get() + ptr_D_offset), + reinterpret_cast(device_data.get() + lda_offset), + reinterpret_cast(device_data.get() + ldb_offset), + reinterpret_cast(device_data.get() + ldc_offset), + reinterpret_cast(device_data.get() + ldc_offset), + ElementCompute(alpha), ElementCompute(beta)); + + delete[] host_data; + + TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed"); + return D; +} +""" +) + +_PYTORCH_CONV2D_IMPL_TEMPLATE_2x = """ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + cutlass::Status status = ${name}_kernel_run( + &problem_size, + reinterpret_cast(A.data_ptr()), + reinterpret_cast(B.data_ptr()), + ptrC, + reinterpret_cast(D.data_ptr()), + alpha, beta, + split_k_mode, stream, B.device().index()); + + TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed"); + return D; +} +""" + +_PYTORCH_CONV2D_FPROP_IMPL_TEMPLATE_2x = ( + common._CUTLASS_KERNEL_RUN_CONV2D_2x + + """ +at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional C=at::nullopt, + std::tuple stride={1, 1}, std::tuple padding={0, 0}, std::tuple dilation={1, 1}, + float alpha=1.f, float beta=0.f, std::string split_k_mode="serial", int split_k_slices=1) { + int N, H, W, C_, K, R, S, P, Q; + N = A.size(0); + C_ = A.size(1); + H = A.size(2); + W = A.size(3); + + K = B.size(0); + R = B.size(2); + S = B.size(3); + + cutlass::conv::Conv2dProblemSize problem_size( + cutlass::Tensor4DCoord(N, H, W, C_), + cutlass::Tensor4DCoord(K, R, S, C_), + cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)), + cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)), + cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)), + cutlass::conv::Mode::kCrossCorrelation, + split_k_slices + ); + + P = problem_size.P; + Q = problem_size.Q; + + typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ? + nullptr : + reinterpret_cast(C->data_ptr()); + + torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast); + at::Tensor D = torch::zeros({N, K, P, Q}, options); +""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x +) + + +_PYTORCH_CONV2D_DGRAD_IMPL_TEMPLATE_2x = ( + common._CUTLASS_KERNEL_RUN_CONV2D_2x + + """ +at::Tensor ${name}_kernel(std::tuple input_size, const at::Tensor& A, const at::Tensor& B, at::optional C=at::nullopt, + std::tuple stride={1, 1}, std::tuple padding={0, 0}, std::tuple dilation={1, 1}, float alpha=1.f, float beta=0.f, + std::string split_k_mode="serial", int split_k_slices=1) { + int N, H, W, C_, K, R, S; + N = std::get<0>(input_size); + C_ = std::get<1>(input_size); + H = std::get<2>(input_size); + W = std::get<3>(input_size); + + K = B.size(0); + R = B.size(2); + S = B.size(3); + + cutlass::conv::Conv2dProblemSize problem_size( + cutlass::Tensor4DCoord(N, H, W, C_), + cutlass::Tensor4DCoord(K, R, S, C_), + cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)), + cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)), + cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)), + cutlass::conv::Mode::kCrossCorrelation, + split_k_slices + ); + + typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ? + nullptr : + reinterpret_cast(C->data_ptr()); + + torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast); + at::Tensor D = torch::empty({N, C_, H, W}, options); +""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x +) + + +_PYTORCH_CONV2D_WGRAD_IMPL_TEMPLATE_2x = ( + common._CUTLASS_KERNEL_RUN_CONV2D_2x + + """ +at::Tensor ${name}_kernel(std::tuple weight_size, const at::Tensor& A, const at::Tensor& B, at::optional C=at::nullopt, + std::tuple stride={1, 1}, std::tuple padding={0, 0}, std::tuple dilation={1, 1}, float alpha=1.f, float beta=0.f, + std::string split_k_mode="serial", int split_k_slices=1) { + int N, H, W, C_, K, R, S; + K = std::get<0>(weight_size); + C_ = std::get<1>(weight_size); + R = std::get<2>(weight_size); + S = std::get<3>(weight_size); + + N = B.size(0); + H = B.size(2); + W = B.size(3); + + cutlass::conv::Conv2dProblemSize problem_size( + cutlass::Tensor4DCoord(N, H, W, C_), + cutlass::Tensor4DCoord(K, R, S, C_), + cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)), + cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)), + cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)), + cutlass::conv::Mode::kCrossCorrelation, + split_k_slices + ); + + typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ? + nullptr : + reinterpret_cast(C->data_ptr()); + + torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast); + at::Tensor D = torch::empty({K, C_, R, S}, options); +""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x +) + + +_PYTORCH_SETUP_PY = common._PYSTYLE_AUTOGEN_COMMENT + """ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +setup( + name='${name}', + ext_modules=[ + CUDAExtension('${name}', [ + '${name}.cpp', + '${name}_kernel.cu', + ], + include_dirs=['${cutlass_path}/include', '${cutlass_path}/tools/util/include'], + extra_compile_args={ + 'cxx': ['-std=c++17'], + 'nvcc': ['-std=c++17', ${extra_compile_args}], + }, + libraries=['cuda'] + ), + ], + cmdclass={ + 'build_ext': BuildExtension + }) + +""" + + +def _generate_setup(name: str, sourcedir: str, extra_compile_args: str=""): + """ + Generates a setup.py file for the extension + + :param name: name of the module to generate + :type name: str + :param sourcedir: directory to which generated source files should be written + :type sourcedir: str + :param extra_compile_args: additional arguments to pass to setup.py + :type extra_args: str + """ + setup_py_file = os.path.join(sourcedir, "setup.py") + setup_source = SubstituteTemplate( + _PYTORCH_SETUP_PY, {"name": name, "cutlass_path": CUTLASS_PATH, "extra_compile_args": extra_compile_args} + ) + with open(setup_py_file, "w") as outfile: + outfile.write(setup_source) + + +class _ArchListSetter: + """ + Utility context manager for temporarily setting the value of the ``TORCH_CUDA_ARCH_LIST`` + environment variable when building a PyTorch CUDA module. + + ``TORCH_CUDA_ARCH_LIST`` is a space-delmited list of compute capabilites for which a PyTorch + CUDA module should be compiled. + + For example, ``TORCH_CUDA_ARCH_LIST="7.0 8.0"`` would result in the inclusion of + ``-gencode=arch=compute_70,code=sm_70`` and ``-gencode=arch=compute_80,code=sm_80`` in the + compilation of the module. + + This utility wraps the building of a PyTorch CUDA module with a setting of this environment + variable according to the current compute capability being targetted. + + Example usage: + + .. highlight:: python + .. code-block:: python + + # Temporarily set TORCH_CUDA_ARCH_LIST="8.0" + with _ArchListSetter(80): + # Perform JIT compilation and loading of the module + mod = torch.utils.cpp_extension.load(...) + + :param cc: compute capability + :type cc: int + """ + + _TORCH_CUDA_ARCH_LIST = "TORCH_CUDA_ARCH_LIST" + + def __init__(self, cc: int): + self.cc_str = ".".join(list(str(cc))) + + def __enter__(self): + """ + Saves the old value of TORCH_CUDA_ARCH_LIST and reset it to the new value based on ``cc`` + """ + self.old_arch_list = os.getenv(_ArchListSetter._TORCH_CUDA_ARCH_LIST) + os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST] = self.cc_str + + return self + + def __exit__(self, exc_type, exc_val, traceback): + """ + Restores the old value of TORCH_CUDA_ARCH_LIST + """ + if self.old_arch_list is None: + del os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST] + else: + os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST] = self.old_arch_list + + +def _jit(name: str, cc: int, cpp_file: str, cuda_file: str): + """ + JIT compiles and loads a PyTorch CUDA extension. + + :param name: name of the module to generate + :type name: str + :param cc: compute capability of the device the module should target + :type cc: int + :param cpp_file: path to file containing extension's C++ interface + :type cpp_file: str + :param cuda_file: path to file containing extension's CUDA interface + :type cuda_file: str + + :return: loaded PyTorch module + """ + + from torch.utils.cpp_extension import load + + extra_cuda_cflags = ["-std=c++17"] + if cc in [90, 100, 101, 103]: + # PyTorch does not currently add the sm_90a target when compute capability + # 9.0 is set within TORCH_CUDA_ARCH_LIST. Thus, we manually add the sm_90a target. + extra_cuda_cflags.append(f"-gencode=arch=compute_{cc}a,code=sm_{cc}a") + + with _ArchListSetter(cc): + jitmodule = load( + name, + [cpp_file, cuda_file], + extra_cuda_cflags=extra_cuda_cflags, + extra_include_paths=[ + os.path.join(CUTLASS_PATH, "include"), + os.path.join(CUTLASS_PATH, "tools/util/include"), + ], + extra_ldflags=["-lcuda"], + verbose=(logger.level == logging.DEBUG) + ) + return jitmodule + + +def _pytorch_gemm(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""): + """ + Generates source for building a PyTorch CUDA module that leverages the CUTLASS GEMM + specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time + compiled, loaded, and returned. + + :param op: operation to emit in the module + :param name: name of the module to generate + :type name: str + :param cc: compute capability of the device the module should target + :type cc: int + :param jit: whether the module should be just-in-time compiled + :type jit: bool + :param sourcedir: directory to which generated source files should be written + :type sourcedir: str + + :return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise + """ + if sourcedir != "" and not os.path.isdir(sourcedir): + os.makedirs(sourcedir) + + cuda_file = os.path.join(sourcedir, name + "_kernel.cu") + extra_kw = {} + if op.api == ApiVersion.v3x: + impl_template = _PYTORCH_GEMM_IMPL_TEMPLATE_3x + else: + impl_template = _PYTORCH_GEMM_IMPL_TEMPLATE_2x + if op.swizzling_functor == swizzle.ThreadblockSwizzleStreamK: + extra_kw["args"] = common._CUTLASS_KERNEL_ARGS_2x_STREAM_K + else: + extra_kw["args"] = common._CUTLASS_KERNEL_ARGS_2x + impl_template = ( + _PYTORCH_GEMM_IMPL_TEMPLATE_3x + if op.api == ApiVersion.v3x + else _PYTORCH_GEMM_IMPL_TEMPLATE_2x + ) + cuda_impl = SubstituteTemplate(impl_template, {"name": name, **extra_kw}) + cuda_source = SubstituteTemplate( + _PYTORCH_CUDA_TEMPLATE, + { + "includes": _PYTORCH_GEMM_INCLUDES[op.api], + "declaration": op.rt_module.emit(), + "procedural_name": op.procedural_name(), + "impl": cuda_impl, + "torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element], + }, + ) + with open(cuda_file, "w") as outfile: + outfile.write(cuda_source) + + cpp_file = os.path.join(sourcedir, name + ".cpp") + cpp_source = SubstituteTemplate( + _PYTORCH_GEMM_CPP_TEMPLATE, + {"name": name, "description": f"CUTLASS {op.procedural_name()} GEMM"}, + ) + with open(cpp_file, "w") as outfile: + outfile.write(cpp_source) + + extra_compile_args = "" + if cc in [90, 100, 101, 103]: + extra_compile_args = f"'--generate-code=arch=compute_{cc}a,code=[sm_{cc}a]'" + _generate_setup(name, sourcedir, extra_compile_args) + + if jit: + return _jit(name, cc, cpp_file, cuda_file) + + return None + + +def _pytorch_grouped_gemm( + op, name: str, cc: int, jit: bool = False, sourcedir: str = "" +): + """ + Generates source for building a PyTorch CUDA module that leverages the CUTLASS grouped GEMM + specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time + compiled, loaded, and returned. + + :param op: operation to emit in the module + :param name: name of the module to generate + :type name: str + :param cc: compute capability of the device the module should target + :type cc: int + :param jit: whether the module should be just-in-time compiled + :type jit: bool + :param sourcedir: directory to which generated source files should be written + :type sourcedir: str + + :return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise + """ + if op.api != ApiVersion.v2x: + raise Exception("Grouped GEMM is currently only supported for CUTLASS 2.x") + + if sourcedir != "" and not os.path.isdir(sourcedir): + os.makedirs(sourcedir) + + cuda_file = os.path.join(sourcedir, name + "_kernel.cu") + cuda_impl = SubstituteTemplate(_PYTORCH_GROUPED_GEMM_IMPL_TEMPLATE, {"name": name}) + cuda_source = SubstituteTemplate( + _PYTORCH_CUDA_TEMPLATE, + { + "includes": _PYTORCH_GROUPED_GEMM_INCLUDES, + "declaration": op.rt_module.emit(), + "procedural_name": op.procedural_name(), + "impl": cuda_impl, + "torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element], + }, + ) + with open(cuda_file, "w") as outfile: + outfile.write(cuda_source) + + cpp_file = os.path.join(sourcedir, name + ".cpp") + cpp_source = SubstituteTemplate( + _PYTORCH_GROUPED_GEMM_CPP_TEMPLATE, + {"name": name, "description": f"CUTLASS {op.procedural_name()} grouped GEMM"}, + ) + with open(cpp_file, "w") as outfile: + outfile.write(cpp_source) + + _generate_setup(name, sourcedir) + + if jit: + return _jit(name, cc, cpp_file, cuda_file) + + return None + + +def _pytorch_conv2d(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""): + """ + Generates source for building a PyTorch CUDA module that leverages the CUTLASS Conv2d + specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time + compiled, loaded, and returned. + + :param op: operation to emit in the module + :param name: name of the module to generate + :type name: str + :param cc: compute capability of the device the module should target + :type cc: int + :param jit: whether the module should be just-in-time compiled + :type jit: bool + :param sourcedir: directory to which generated source files should be written + :type sourcedir: str + + Note that the when conv kind is `dgrad` or `wgrad`, the size of the input `(N, C, H, W)` or + weight `(K, C, R, S)` should be provided. This is because there are multiple valid solutions + for H/W/R/S given the same P/Q. + + :return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise + """ + if sourcedir != "" and not os.path.isdir(sourcedir): + os.makedirs(sourcedir) + cuda_file = os.path.join(sourcedir, name + "_kernel.cu") + extra_kw = {} + if op.conv_kind == ConvKind.Fprop: + impl_template = _PYTORCH_CONV2D_FPROP_IMPL_TEMPLATE_2x + cpp_template = _PYTORCH_CONV2D_FPROP_CPP_TEMPLATE + elif op.conv_kind == ConvKind.Dgrad: + impl_template = _PYTORCH_CONV2D_DGRAD_IMPL_TEMPLATE_2x + cpp_template = _PYTORCH_CONV2D_GRAD_CPP_TEMPLATE + elif op.conv_kind == ConvKind.Wgrad: + impl_template = _PYTORCH_CONV2D_WGRAD_IMPL_TEMPLATE_2x + cpp_template = _PYTORCH_CONV2D_GRAD_CPP_TEMPLATE + extra_kw["conv_kind_name"] = ConvKindNames[op.conv_kind].capitalize() + extra_kw["torch_type_C"] = _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element] + cuda_impl = SubstituteTemplate(impl_template, {"name": name, **extra_kw}) + cuda_source = SubstituteTemplate( + _PYTORCH_CUDA_TEMPLATE, + { + "includes": _PYTORCH_CONV2D_INCLUDES, + "declaration": op.rt_module.emit(), + "procedural_name": op.procedural_name(), + "impl": cuda_impl, + "torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element], + }, + ) + with open(cuda_file, "w") as outfile: + outfile.write(cuda_source) + + cpp_file = os.path.join(sourcedir, name + ".cpp") + cpp_source = SubstituteTemplate( + cpp_template, + {"name": name, "description": f"CUTLASS {op.procedural_name()} Conv2d"}, + ) + with open(cpp_file, "w") as outfile: + outfile.write(cpp_source) + + _generate_setup(name, sourcedir) + + if jit: + return _jit(name, cc, cpp_file, cuda_file) + + return None + + +def pytorch(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""): + """ + Generates source for building a PyTorch CUDA module that leverages the CUTLASS kernel + specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time + compiled, loaded, and returned. + + The result of this method is files within ``sourcedir`` that can be used for building + a PyTorch module. + + :param op: operation to emit in the module + :param name: name of the module to generate + :type name: str + :param cc: compute capability of the device the module should target + :type cc: int + :param jit: whether the module should be just-in-time compiled + :type jit: bool + :param sourcedir: directory to which generated source files should be written + :type sourcedir: str + + :return: loaded PyTorch module (if ``jit=True``) or None + """ + device_op = op.device_op() + if isinstance(op, GemmOperationUniversal): + return _pytorch_gemm(device_op, name, cc, jit, sourcedir) + elif isinstance(op, GemmOperationGrouped): + return _pytorch_grouped_gemm(device_op, name, cc, jit, sourcedir) + elif isinstance(op, Conv2dOperation): + return _pytorch_conv2d(device_op, name, cc, jit, sourcedir) + else: + raise Exception( + f"Operation type {type(op)} is not currently supported for PyTorch emission." + ) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..faf6896e99ba78130ede8e09be9b9115e9169541 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/__init__.py @@ -0,0 +1,56 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from cutlass_cppgen.epilogue.epilogue import ( + get_activations, + get_activation_epilogue, + gelu, + hardswish, + identity, + leaky_relu, + relu, + sigmoid, + silu, + tanh, + trace +) + +from cutlass_cppgen.epilogue.evt_ops import ( + max, + multiply_add, + sum, + permute, + reshape, + maximum, + minimum, + exp +) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/epilogue.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/epilogue.py new file mode 100644 index 0000000000000000000000000000000000000000..a3a17506ee2be609ed8d5b299114df52c55ca0cf --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/epilogue.py @@ -0,0 +1,176 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Registry of elementwise epilogues + +Elementwise epilogues can be added to many CUTLASS kernels in the CUTLAS Python interface via +code like the following for GEMM: + +.. highlight:: python +.. code-block:: python + + plan = cutlass_cppgen.op.Gemm(element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor) + plan.activation = cutlass_cppgen.epilogue.relu +""" + +from cutlass_cppgen.backend import epilogue, device_cc + + +gelu = epilogue.gelu +hardswish = epilogue.hardswish +identity = epilogue.identity +leaky_relu = epilogue.leaky_relu +relu = epilogue.relu +sigmoid = epilogue.sigmoid +silu = epilogue.silu +tanh = epilogue.tanh + + +_activations = [gelu, hardswish, identity, leaky_relu, relu, sigmoid, silu, tanh] + + +def get_activations() -> list: + """ + Returns a list of available activation functions + + :return: list of available activation functions + :rtype: list + """ + return _activations + + +def get_activation_epilogue( + activation, + element_output, + elements_per_access, + element_accumulator, + element_compute, +): + """ + Return an epilogue corresponding to the activation function, data types, and alignment + used in the kernel + + :param activation: elementwise activation function to use + :param element_output: data type of the output + :param elements_per_access: alignment of operand C of the kernel + :type elements_per_access: int + :param element_accumulator: data type of the accumulated output C + :param element_compute: data type in which compute operations should be performed + + :return: epilogue functor + """ + if activation not in _activations: + raise Exception( + f"Unsupported activation type {activation}. Available activations are: {_activations}" + ) + + if activation == identity: + return epilogue.LinearCombination( + element_output, elements_per_access, element_accumulator, element_compute + ) + else: + return epilogue.LinearCombinationGeneric( + activation, + element_output, + elements_per_access, + element_accumulator, + element_compute, + ) + + +""" +Frontend for EVT that generates epilogue functor through tracing the input function +""" +from cutlass_cppgen.backend.evt.frontend import PythonASTFrontend + + +def trace(fn, example_tensors, **kwargs): + """ + Trace `fn(**example_tensors)` and generates epilogue visitor + + :param fn or str: Python callable or string of the epilogue function + :param example_tensors: example inputs for fn + :type example_tensors: dict + + .. hightlight:: python + .. code-block:: python + import cutlass_cppgen.backend.evt + + # Define epilogue function as Python callable + def example_fn(accum, C, alpha, beta, gamma): + D = ((accum + C) * alpha - gamma) / beta + return D + + # Define the example tensors + example_inputs = { + "accum": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda"), + "C": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda"), + "alpha": 1.5, + "beta": 0.5, + "gamma": 2.5, + "D": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda") + } + + # Generate the epilogue functor + epilogue_visitor = cutlass_cppgen.epilogue.trace(example_fn, example_inputs) + """ + if callable(fn): + class EpilogueFunctor(PythonASTFrontend): + def __init__(self, cc=None, **kwargs): + if not cc: + cc = device_cc() + super().__init__(cc, **kwargs) + pass + setattr(EpilogueFunctor, "__call__", staticmethod(fn)) + + epilogue_functor = EpilogueFunctor(**kwargs) + epilogue_functor.trace(example_tensors) + return epilogue_functor + elif isinstance(fn, str): + class EpilogueFunctor(PythonASTFrontend): + def __init__(self, cc=None, **kwargs): + self.source = textwrap.dedent(fn) + if not cc: + cc = device_cc() + super().__init__(cc, **kwargs) + + def parse(self, example_inputs) -> None: + self.example_inputs = example_inputs + self.ast = ast.parse(self.source) + self.visit(self.ast) + + epilogue_functor = EpilogueFunctor(**kwargs) + epilogue_functor.trace(example_tensors) + return epilogue_functor + else: + raise NotImplementedError("Expect a callable Python function") diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/evt_ops.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/evt_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..7d8e2c01286886ffc936052c84205a60a5d869fb --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/evt_ops.py @@ -0,0 +1,98 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Collection of builtin functions used for host reference in EVT +""" + +import numpy as np + +from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_available, is_torch_tensor + +if is_torch_available(): + import torch + + +def multiply_add(x, y, z): + return x * y + z + + +def sum(x, dim): + if is_numpy_tensor(x): + return x.sum(axis=tuple(dim)) + elif is_torch_tensor(x): + return torch.sum(x, dim) + + +def max(x, dim): + if is_numpy_tensor(x): + return x.max(axis=tuple(dim)) + elif is_torch_tensor(x): + return torch.amax(x, dim) + + +def maximum(x, y): + if is_numpy_tensor(x): + return np.maximum(x, y) + elif is_torch_tensor(x): + return torch.maximum(x, torch.tensor(y)) + + +def minimum(x, y): + if is_numpy_tensor(x): + return np.minimum(x, y) + elif is_torch_tensor(x): + return torch.minimum(x, torch.tensor(y)) + +def exp(x): + if is_numpy_tensor(x): + return np.exp(x) + elif is_torch_tensor(x): + return torch.exp(x) + + +############################################################################## +# Layout manipulate nodes +############################################################################## + +def permute(x, indices: tuple): + if is_numpy_tensor(x): + return np.transpose(x, axes=indices) + elif is_torch_tensor(x): + return x.permute(*indices) + + +def reshape(x, new_shape: tuple): + if is_numpy_tensor(x): + return np.reshape(x, newshape=new_shape) + elif is_torch_tensor(x): + return x.view(new_shape) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/library_defaults.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/library_defaults.py new file mode 100644 index 0000000000000000000000000000000000000000..f5ea04419955f6a71225b6daaeab884dcc4e3399 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/library_defaults.py @@ -0,0 +1,569 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Classes containing valid operations for a given compute capability and data types. +""" + +from itertools import combinations_with_replacement +import logging + +import cutlass_library +from cutlass_library.library import ConvKind, IteratorAlgorithm, StrideSupport, GroupMode + +import cutlass_cppgen +from cutlass_cppgen.utils.check import valid_stage_count +from cutlass_cppgen.utils.datatypes import td_from_profiler_td, td_from_profiler_op + + +_generator_ccs = [50, 60, 61, 70, 75, 80, 90, 100] + + +class KernelsForDataType: + """ + Container class for keeping track of kernels that correspond to a particular combination + of data types for operands A, B, and accumulator + """ + + def __init__(self, datatype_comb: tuple, layout_comb: tuple): + self.datatype_comb = datatype_comb + self.layout_comb = layout_comb + self.math_operations = set() + + # Dictionary mapping from alignment (int) to a list of kernels that fit the alignment + # constraint for the data type combination + self.kernels_by_alignment = {} + + def add(self, operation): + """ + Add an operation to the list of supported kernels + """ + alignment_key = f"{operation.A.alignment} {operation.B.alignment} {operation.C.alignment}" + if alignment_key not in self.kernels_by_alignment: + self.kernels_by_alignment[alignment_key] = [] + self.kernels_by_alignment[alignment_key].append(operation) + self.math_operations.add(operation.tile_description.math_instruction.math_operation) + + def alignments(self, operand: str): + """ + Returns an unsorted list of alignments supported by this data type combination + + :param operand: identifier of operand in question (e.g., A, B, C) + :type operand: str + + :return: unsorted list of alignments supported by this data type combination + :rtype: list + """ + operand_idx = self._operand_idx(operand) + return [int(key.split(" ")[operand_idx]) for key in self.kernels_by_alignment.keys()] + + @property + def all_operations(self): + """ + Returns a list of all operations supported by this data type combination + + :return: list of all operations supported by this data type combination + :rtype: list + """ + ops = [] + for _, alignment_ops in self.kernels_by_alignment.items(): + ops.extend(alignment_ops) + return ops + + def default_operation(self, math_operation: cutlass_cppgen.MathOperation): + key = sorted(list(self.kernels_by_alignment.keys()))[0] + kernels = self.kernels_by_alignment[key] + if math_operation is not None: + kernels = [x for x in kernels if x.tile_description.math_instruction.math_operation == math_operation] + return kernels[0] + + def operations(self, alignment_A: int, alignment_B: int, alignment_C: int, math_operation: cutlass_cppgen.MathOperation): + """ + Returns operations satisfying the alignment constraints + + :param alignment_A: alignment constraint of operations to return + :type alignment_A: int + :param alignment_B: alignment constraint of operations to return + :type alignment_B: int + :param alignment_C: alignment constraint of operations to return + :type alignment_C: int + :param math_operation: math operation to consider + :type math_operation: cutlass_cppgen.MathOperation + + :return: list of operations + :rtype: list + """ + key = f"{alignment_A} {alignment_B} {alignment_C}" + + if key not in self.kernels_by_alignment: + og_key = key + # Reconcile A, B, and C alignments by trying to align to the minimum + min_alignment = min(alignment_A, alignment_B, alignment_C) + key = f"{min_alignment} {min_alignment} {min_alignment}" + if key not in self.kernels_by_alignment: + # Finally, go through all available alignment combinations and find + # one for which all values are less than those passed in. + key = None + alignments = sorted([tuple(int(x) for x in k.split(" ")) for k in self.kernels_by_alignment.keys()], reverse=True) + for align_A, align_B, align_C in alignments: + if alignment_A % align_A == 0 and alignment_B % align_B == 0 and alignment_C % align_C == 0: + key = f"{align_A} {align_B} {align_C}" + break + + if key is None: + raise Exception( + f"No operations of alignment {og_key} found for data type and layout " + f"combination {self.datatype_comb} {self.layout_comb}. Compatible alignments " + f"are {self.kernels_by_alignment.keys()}" + ) + + ops = self.kernels_by_alignment[key] + if math_operation is not None: + ops = [op for op in ops if op.tile_description.math_instruction.math_operation == math_operation] + return ops + + def _operand_idx(self, key: str) -> int: + operand_list = ["A", "B", "C"] + if key not in operand_list: + raise Exception(f"Unexpected operand {operand}") + + return operand_list.index(key) + + def find_alignment(self, shape: tuple, layout: cutlass_cppgen.LayoutType, operand=str) -> int: + """ + Returns the most preferable alignment for a given shape and layout + + :param shape: extent of each dimension of the tensor + :type shape: tuple + :param layout: layout of the tensor + :type layout: cutlass_cppgen.LayoutType + :param operand: descriptor of the operand in question + :type operand: str + + :return: maximum alignment supported by the data type combination and tensor size + :rtype: int + """ + operand_idx = self._operand_idx(operand) + + # Determine the leading dimension of the shape + if layout == cutlass_cppgen.LayoutType.ColumnMajor: + ld = shape[-2] + elif layout == cutlass_cppgen.LayoutType.RowMajor: + ld = shape[-1] + elif layout == cutlass_cppgen.LayoutType.TensorNHWC: + ld = shape[-1] + else: + raise Exception(f"Unexpected or unsupported layout {layout}") + + for alignments in sorted(list(self.kernels_by_alignment.keys()), reverse=True): + alignment = int(alignments.split(" ")[operand_idx]) + if ld % alignment == 0: + return alignment + + # Default to alignment of 1 if no others match + return 1 + + def sort(self): + """ + Sorts each list of kernels in `kernels_by_alignment` in descending order of threadblock shape + """ + key = lambda op: ( + op.tile_description.threadblock_shape[0] + * op.tile_description.threadblock_shape[1] + * op.tile_description.threadblock_shape[2] + ) + for alignment in self.kernels_by_alignment.keys(): + self.kernels_by_alignment[alignment].sort(key=key, reverse=True) + + def supports_math_operation(self, math_operation: cutlass_cppgen.MathOperation) -> bool: + """ + Returns whether `math_operation` is supported by at least one operation. + + :param math_operation: math operation to consider + :type math_operation: cutlass_cppgen.MathOperation + + :return: whether math_operation is supported by at least one operation + :rtype: bool + """ + return math_operation is None or math_operation in self.math_operations + + +class ArchOptions: + """ + Structure for keeping track of kernels available on a given compute capability + + :param target_cc: compute capability of the device on which kernels will be run + :type target_cc: int + :param kernel_cc: compute capability of the kernels to generate + :type kernel_cc: int + :param operation_kind: type of operation to register + :type operation_kind: cutlass_library.OperationKind + :param gemm_kinds: types of GEMM operations that can be included + :type gemm_kinds: list + :param allowed_math_operations: types of primitive math operations allowed + :type allowed_math_operations: list + """ + + def __init__( + self, + target_cc: int, + kernel_cc: int, + operation_kind: cutlass_library.OperationKind, + gemm_kinds: list, + allowed_math_operations: list = [ + cutlass_library.MathOperation.multiply_add, + cutlass_library.MathOperation.multiply_add_saturate, + cutlass_library.MathOperation.multiply_add_mixed_input_upcast, + cutlass_library.MathOperation.multiply_add_fast_f32 + ] + ): + self.cc = kernel_cc + + # Dictionary with following structure: + # Key: OpcodeClass + # Value: Dictionary with the following structure: + # Key: tuple of ((DataType, DataType, DataType), (LayoutType, LayoutType, LayoutType), + # representing ((element_a, element_b, element_accumulator), (layout_a, layout_b)) + # Value: KernelsForDataType + self.operations_by_opclass = {} + self.op_class = None + self.allowed_math_operations = allowed_math_operations + + if target_cc == 100 and kernel_cc == 90 or target_cc == 90 and kernel_cc == 100: + return + + # Identify the method within CUTLASS generator script that generates kernel + # descriptions for the target CC + generate_function_name = "GenerateSM" + str(kernel_cc) + if not hasattr(cutlass_library.generator, generate_function_name): + cutlass_cppgen.logger.warning(f"No generator found for architecture {kernel_cc}") + return + generate_function = getattr(cutlass_library.generator, generate_function_name) + + # Initialize a default manifest and populate it with valid kernel descriptions + # for the target CC + args = [ + "--kernels=all", + f"--log-level={logging.getLevelName(cutlass_cppgen.logger.level)}" + ] + manifest_args = cutlass_library.generator.define_parser().parse_args(args) + manifest = cutlass_library.manifest.Manifest(manifest_args) + generate_function(manifest, cutlass_cppgen._nvcc_version) + + if operation_kind not in manifest.operations: + # No kernels generated for this architecture, this could be because the CUDA + # toolkit is insufficient to support operations in this CC + cutlass_cppgen.logger.warning(f"No operations of type {operation_kind} found for CC {kernel_cc}") + return + + # Only one CC should be returned, given the setup above of calling only the generation scripts + # for a given CC + if len(manifest.operations[operation_kind].keys()) != 1 or kernel_cc not in manifest.operations[operation_kind]: + raise Exception(f"Error finding kernels for SM{kernel_cc}. Check that your CUDA toolkit version " + "is sufficient for the architecture in question.") + + # Iterate through the available operations for this operation kind and + # find available opclasses and data types + for name, op_list in manifest.operations[operation_kind][kernel_cc].items(): + for op in op_list: + + if operation_kind == cutlass_library.OperationKind.Gemm: + if op.gemm_kind not in gemm_kinds: + continue + + mi = op.tile_description.math_instruction + if mi.math_operation not in self.allowed_math_operations: + continue + + # Prune operations that don't fit in shared memory + td = td_from_profiler_op(op) + if not valid_stage_count(target_cc, kernel_cc, td, verbose=False)[0]: + continue + + if mi.opcode_class not in self.operations_by_opclass: + self.operations_by_opclass[mi.opcode_class] = {} + + datatype_comb = (mi.element_a, mi.element_b, mi.element_accumulator) + layout_comb = (op.A.layout, op.B.layout) + + # Register TF32 kernels as F32 to enable F32 -> TF32 conversion + TF32 Tensor Core operations + if datatype_comb == (cutlass_library.DataType.tf32, cutlass_library.DataType.tf32, cutlass_library.DataType.f32): + # TF32 kernels only supported on SM80 and beyond + if self.cc < 80: + continue + elif self.cc == 90 or self.cc == 100: + if (op.A.element != cutlass_library.DataType.f32 + or op.B.element != cutlass_library.DataType.f32 + or op.C.element != cutlass_library.DataType.f32): + continue + + datatype_comb = (cutlass_library.DataType.f32, cutlass_library.DataType.f32, cutlass_library.DataType.f32) + + opclass_dict = self.operations_by_opclass[mi.opcode_class] + key = (datatype_comb, layout_comb) + if key not in opclass_dict: + opclass_dict[key] = KernelsForDataType(datatype_comb, layout_comb) + opclass_dict[key].add(op) + + # Set the default opclass to TensorOp, if available. Otherwise default to SIMT + if cutlass_library.OpcodeClass.TensorOp in self.operations_by_opclass: + self.op_class = cutlass_library.OpcodeClass.TensorOp + else: + self.op_class = cutlass_library.OpcodeClass.Simt + + # The profiler's generator may generate only a limited set of combinations of operands for SIMT kernels. + # Here, we generate additional versions via a generic TileDescription. + if cutlass_library.OpcodeClass.Simt not in self.operations_by_opclass: + self.operations_by_opclass[cutlass_library.OpcodeClass.Simt] = {} + + if operation_kind == cutlass_library.OperationKind.Gemm: + types = [ + (cutlass_library.DataType.s8, cutlass_library.DataType.s8, cutlass_library.DataType.s8), + (cutlass_library.DataType.s8, cutlass_library.DataType.s8, cutlass_library.DataType.s32), + (cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f16), + (cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f32), + (cutlass_library.DataType.f32, cutlass_library.DataType.f32, cutlass_library.DataType.f32), + (cutlass_library.DataType.f64, cutlass_library.DataType.f64, cutlass_library.DataType.f64), + ] + + # Add FP8 A/B/C + fp8_types = [cutlass_library.DataType.e4m3, cutlass_library.DataType.e5m2] + for type_comb in combinations_with_replacement(fp8_types, 3): + types.append(type_comb) + + # Add FP8 A/B with FP32 C + for type_comb in combinations_with_replacement(fp8_types, 2): + types.append(type_comb + (cutlass_cppgen.DataType.f32,)) + + layouts = [ + (cutlass_library.LayoutType.RowMajor, cutlass_library.LayoutType.RowMajor), + (cutlass_library.LayoutType.RowMajor, cutlass_library.LayoutType.ColumnMajor), + (cutlass_library.LayoutType.ColumnMajor, cutlass_library.LayoutType.RowMajor), + (cutlass_library.LayoutType.ColumnMajor, cutlass_library.LayoutType.ColumnMajor), + ] + elif operation_kind == cutlass_library.OperationKind.Conv2d: + types = [ + (cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f16), + (cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f32), + (cutlass_library.DataType.f32, cutlass_library.DataType.f32, cutlass_library.DataType.f32), + (cutlass_library.DataType.f64, cutlass_library.DataType.f64, cutlass_library.DataType.f64), + ] + + layouts = [ + (cutlass_library.LayoutType.TensorNHWC, cutlass_library.LayoutType.TensorNHWC), + ] + else: + raise NotImplementedError(f"Operation kind {operation_kind} is currently unsupported.") + + alignment = 1 + epilogue_functor = cutlass_library.EpilogueFunctor.LinearCombination + swizzling_functor = cutlass_library.SwizzlingFunctor.Identity8 + for type_comb in types: + for layout_comb in layouts: + comb = (type_comb, layout_comb) + if comb in self.operations_by_opclass[cutlass_library.OpcodeClass.Simt]: + continue + + A = cutlass_library.TensorDescription(type_comb[0], layout_comb[0], alignment) + B = cutlass_library.TensorDescription(type_comb[1], layout_comb[1], alignment) + C = cutlass_library.TensorDescription(type_comb[2], cutlass_library.LayoutType.ColumnMajor, alignment) + math_inst = cutlass_library.MathInstruction( + [1, 1, 1], + type_comb[0], + type_comb[1], + type_comb[2], + cutlass_library.OpcodeClass.Simt, + cutlass_library.MathOperation.multiply_add + ) + + td = cutlass_library.TileDescription( + [128, 128, 8], 2, [4, 2, 1], math_inst, 50, 1024) + + # Prune operations that don't fit in shared memory + if not valid_stage_count(target_cc, kernel_cc, td_from_profiler_td(td), verbose=False)[0]: + continue + + new_kernels = KernelsForDataType(type_comb, layout_comb) + + if operation_kind == cutlass_library.OperationKind.Gemm: + new_operation = cutlass_library.manifest.GemmOperation( + cutlass_library.GemmKind.Universal, td.minimum_compute_capability, + td, A, B, C, type_comb[2], epilogue_functor, swizzling_functor) + new_kernels.add(new_operation) + elif operation_kind == cutlass_library.OperationKind.Conv2d: + for conv_kind in [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad]: + new_operation = cutlass_library.manifest.Conv2dOperation( + conv_kind, IteratorAlgorithm.Analytic, td.minimum_compute_capability, td, + A, B, C, type_comb[2], StrideSupport.Strided, epilogue_functor, swizzling_functor, + group_mode=GroupMode.SingleGroup + ) + new_kernels.add(new_operation) + + self.operations_by_opclass[cutlass_library.OpcodeClass.Simt][comb] = new_kernels + + # Sort all operations + for oc in self.operations_by_opclass.keys(): + for comb in self.operations_by_opclass[oc].keys(): + self.operations_by_opclass[oc][comb].sort() + + def opclass_supports_combination( + self, op_class: cutlass_library.OpcodeClass, datatype_comb: tuple, layout_comb: tuple, math_operation: cutlass_library.MathOperation + ) -> bool: + """ + Returns whether the provided operation class supports the provided data type and layout combination + + :param op_class: operation class to consider + :type op_class: cutlass_library.OpcodeClass + :param datatype_comb: tuple of data types for (element_A, element_B, element_accumulator) + :type datatype_comb: tuple[cutlass_library.DataType] + :param layout_comb: tuple of data types for (layout_A, layout_B) + :type layout_comb: tuple[cutlass_library.LayoutType] + :param math_operation: math operation to consider or None if any can be considered + :type math_operation: cutlass_cppgen.MathOperation + + :return: set of operation classes that support the provided data type and layout combination + :rtype: set + """ + if op_class not in self.operations_by_opclass: + raise Exception(f"Unexpected or unsupported operation class {op_class}") + + if operations := self.operations_by_opclass[op_class].get((datatype_comb, layout_comb)): + if math_operation is not None: + return operations.supports_math_operation(math_operation) + else: + return True + + return False + + + def supporting_opclasses( + self, + element_a: cutlass_library.DataType, + element_b: cutlass_library.DataType, + element_accumulator: cutlass_library.DataType, + layout_a: cutlass_library.LayoutType, + layout_b: cutlass_library.LayoutType, + math_operation: cutlass_library.MathOperation, + ) -> set: + """ + Returns a set of operation classes that support the provided data type combination + + :param element_a: data type of operand A + :type element_a: cutlass_library.DataType + :param element_b: data type of operand B + :type element_b: cutlass_library.DataType + :param element_accumulator: data type of accumulator + :type element_accumulator: cutlass_library.DataType + :param layout_a: layout of operand A + :type layout_a: cutlass_library.LayoutType + :param layout_b: layout of operand B + :type layout_b: cutlass_library.LayoutType + :param math_operation: math operation to consider + :type math_operation: cutlass_cppgen.MathOperation + + :return: set of operation classes that support the provided data type combination + :rtype: set + """ + supporting_op_classes = set() + datatype_comb = (element_a, element_b, element_accumulator) + layout_comb = (layout_a, layout_b) + + for op_class in self.operations_by_opclass.keys(): + if self.opclass_supports_combination(op_class, datatype_comb, layout_comb, math_operation): + supporting_op_classes.add(op_class) + return supporting_op_classes + + def operations( + self, + op_class: cutlass_library.OpcodeClass, + element_a: cutlass_library.DataType, + element_b: cutlass_library.DataType, + element_accumulator: cutlass_library.DataType, + layout_a: cutlass_library.LayoutType, + layout_b: cutlass_library.LayoutType, + math_operation: cutlass_library.MathOperation, + ) -> KernelsForDataType: + """ + Returns whether the provided operation class supports the provided data type combination + + :param op_class: operation class to consider + :type op_class: cutlass_library.OpcodeClass + :param element_a: data type of operand A + :type element_a: cutlass_library.DataType + :param element_b: data type of operand B + :type element_b: cutlass_library.DataType + :param element_accumulator: data type of accumulator + :type element_accumulator: cutlass_library.DataType + :param layout_a: layout of operand A + :type layout_a: cutlass_library.LayoutType + :param layout_b: layout of operand B + :type layout_b: cutlass_library.LayoutType + :param math_operation: math operation to consider + :type math_operation: cutlass_cppgen.MathOperation + + :return: container of kernels by alignment supported by the provided combination of parameters + :rtype: KernelsForDataType + """ + datatype_comb = (element_a, element_b, element_accumulator) + layout_comb = (layout_a, layout_b) + if not self.opclass_supports_combination(op_class, datatype_comb, layout_comb, math_operation): + raise Exception( + f"Data type layout combination {datatype_comb}, {layout_comb} " + f"is not supported by opcode class {op_class} on CC {self.cc}." + ) + return self.operations_by_opclass[op_class][(datatype_comb, layout_comb)] + + +class OptionRegistry: + """ + Container of all architecture-specific options + + :param target_cc: compute capability of the device on which operations will be run + :type target_cc: int + """ + + def __init__(self, target_cc: int): + self.registry = {} + + if target_cc > 100 and (target_cc not in [101, 103, 120, 121]): + raise Exception(f"Unsupported compute capability {target_cc}. The CUTLASS Python interface only supports compute capabilities up to the Blackwell architecture.") + + gemm_kinds = [cutlass_library.GemmKind.Universal, cutlass_library.GemmKind.Universal3x] + operation_kinds = [cutlass_library.OperationKind.Gemm, cutlass_library.OperationKind.Conv2d] + # Construct options for each CC + for kernel_cc in _generator_ccs: + self.registry[kernel_cc] = {} + for opkind in operation_kinds: + self.registry[kernel_cc][opkind] = ArchOptions(target_cc, kernel_cc, opkind, gemm_kinds) + + def options_for_cc(self, cc: int, op_kind=cutlass_library.OperationKind.Gemm) -> ArchOptions: + return self.registry.get(cc, None)[op_kind] diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0286907040fb3ded84f989bfc9d14e740307f6a9 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/__init__.py @@ -0,0 +1,36 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from cutlass_cppgen.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad +from cutlass_cppgen.op.gemm import Gemm +from cutlass_cppgen.op.gemm_grouped import GroupedGemm +from cutlass_cppgen.op.op import OperationBase diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/conv.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..711b27da13b54e30f8b25e839ffc4f51ed80dc5c --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/conv.py @@ -0,0 +1,997 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" + Ease-of-use interface for constructing, compiling, and running CONVs + + The ``Conv2d`` interface is meant to allow one to easily instantiate, compile, and run + CONV2D operations in CUTLASS via Python, without specifying many configuration parameters. + Under the hood, the interface will select sensible default parameters for the many template + parameters for CUTLASS CONVs. + + Note: optimal performance is not to be expected from this interface. To achieve optimal + performance, one should specify and tune each configuration parameter. + + The simplest example of using this interface is the following: + + .. highlight:: python + .. code-block:: python + + # A, B, C, and D are torch/numpy/cupy tensor objects + plan = cutlass_cppgen.op.Conv(A, B, C, D) + plan.run(stride=(1, 1), padding=(0, 0), dilation=(1, 1)) + + One can also use the interface by specifying data types of operands at construction + and using different tensor objects with these data types at runtime: + + .. highlight:: python + .. code-block:: python + + # The following is shorthand for: + # cutlass_cppgen.op.Conv2d(kind="fprop", + # element_A=torch.float32, element_B=torch.float32, + # element_C=torch.float32, element_D=torch.float32, + # element_accumulator=torch.float32) + plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=torch.float32) + + A0 = torch.rand((128, 256), dtype=torch.float32, device='cuda') + B0 = torch.rand((256, 64), dtype=torch.float32, device='cuda') + C0 = torch.zeros((128, 64), dtype=torch.float32, device='cuda') + D0 = torch.zeros((128, 64), dtype=torch.float32, device.'cuda') + plan.run(A0, B0, C0, D0, stride=(1, 1), padding=(0, 0), dilation=(1, 1)) + + A = torch.rand((32, 128), dtype=torch.float32, device='cuda') + B = torch.rand((128, 256), dtype=torch.float32, device='cuda') + C = torch.zeros((32, 256), dtype=torch.float32, device='cuda') + D = torch.zeros((32, 256), dtype=torch.float32, device.'cuda') + plan.run(A1, B1, C1, D1, stride=(1, 1), padding=(0, 0), dilation=(1, 1)) + + The interface additionally enables one to decouple the compilation of the underlying CUTLASS + kernel from its execution: + + .. highlight:: python + .. code-block:: python + + plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32) + + # Do other work... + + plan.run(A0, B0, C0, D0, stride=(1, 1), padding=(0, 0), dilation=(1, 1)) + + # Do other work... + + plan.run(A1, B1, C1, D1, stride=(1, 1), padding=(0, 0), dilation=(1, 1)) + + Elementwise activation functions are easily fused to the GEMM via the interface: + + .. highlight:: python + .. code-block:: python + + plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32) + plan.activation = cutlass_cppgen.epilogue.relu + + Operations can also be run asynchronously: + + .. highlight:: python + .. code-block:: python + + plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32) + args = plan.run() + + # Do other work... + + args.sync() +""" + +from __future__ import annotations +from typing import Optional +from cutlass_cppgen.utils.lazy_import import lazy_import +cuda = lazy_import("cuda.cuda") +cudart = lazy_import("cuda.cudart") +from cutlass_library import ( + ConvKind, + ConvMode, + DataTypeSize, + IteratorAlgorithm, + OperationKind, + SplitKMode, + StrideSupport, +) + +import cutlass_cppgen +from cutlass_cppgen import epilogue +from cutlass_cppgen.backend import compiler +from cutlass_cppgen.backend.conv2d_operation import Conv2dArguments, Conv2dOperation +from cutlass_cppgen.backend.reduction_operation import ReductionOperation, ReductionArguments +from cutlass_cppgen.backend.library import TensorDescription, TileDescription +from cutlass_cppgen.op.op import OperationBase +from cutlass_cppgen.shape import Conv2DProblemSize, MatrixCoord +from cutlass_cppgen.utils import check, datatypes + + +class Conv2d(OperationBase): + """ + Constructs a ``Conv2d`` object. + + The convolution kind (fprop, wgrad, degrad), the data types of operands A, B, and C, + along with the data type of output D and that used for accumulation, are bound to the ``Conv`` + object throughout its lifetime -- these are not to be changed after a ``Conv2d`` has been constructed. + + The constructor has optional parameters for flexibly setting these parameters. The following + constructors are equivalent: + + .. highlight:: python + .. code-block:: python + + # Use F32 for A, B, C, D, and accumulation in fprop + + # Use the generic ``element`` parameter to concisely set all data types for operands to the same values. + Conv2d(kind="fprop", element=cutlass_cppgen.DataType.f32) + + # Explicitly specify the data types to use for A, B, C, and D. + Conv2d(kind="fprop", element_A=cutlass_cppgen.DataType.f32, element_B=cutlass_cppgen.DataType.f32, + element_C=cutlass_cppgen.DataType.f32, element_D=cutlass_cppgen.DataType.f32) + + # Set the data types and elements from existing tensors. Note that one can use different tensors when + # executing GEMM via the ``run()`` method than passed in here (though those passed in to ``run()`` must + # have the same data type as those passed in here). + # A, B, C, and D are torch.Tensor objects of type torch.float32 under the channel-last layout + Conv2d(kind="fprop", A=A, B=B, C=C, D=D) + + # Explicitly specify the data type for only some of A, B, C, and D. Unspecified data types will inherit + # those passed in via the generic ``element`` + Conv2d(kind="fprop", element_A=cutlass_cppgen.DataType.f32, element_accumulator=cutlass_cppgen.DataType.f32, + element=cutlass_cppgen.DataType.f32) + + The order of precedence for the setting of the data type for a given operand/output is as follows: + 1) If the tensor type is specified (e.g., ``A``), use the data type inferred from this tensor + 2) Otherwise, if the data type (e.g., ``element_A``) is specified, use those + 3) Otherwise, use the generic values (e.g., ``element``) + + :param kind: the convolution kind (i.e. fprop, wgrad, and dgrad) + :type kind: str + :param A: tensor representing data type of operand A + :param B: tensor representing data type of operand B + :param C: tensor representing data type of operand C + :param D: tensor representing data type of operand D + :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B + :param beta: scalar parameter beta from GEMM operation that scales operand C + :param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type + :type element: cutlass_cppgen.DataType + :param element_A: data type to be used for operand A + :type element_A: cutlass_cppgen.DataType + :param element_B: data type to be used for operand B + :type element_B: cutlass_cppgen.DataType + :param element_C: data type to be used for operand C + :type element_C: cutlass_cppgen.DataType + :param element_D: data type to be used for operand D + :type element_D: cutlass_cppgen.DataType + :param element_accumulator: data type to be used in accumulation of the product of operands A and B + :type element_accumulator: cutlass_cppgen.DataType + :param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90 + :type cc: int + :param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80 + :type kernel_cc: int + """ + def __init__( + self, kind="fprop", + A=None, B=None, C=None, D=None, alpha=1.0, beta=0.0, + element=None, + element_A=None, element_B=None, element_C=None, element_D=None, + element_accumulator=None, + cc: int = None, kernel_cc: int = None + ): + super().__init__(cc=cc, kernel_cc=kernel_cc, operation_kind=OperationKind.Conv2d) + # Verify the kernel cc + if self.current_cc in [90, 100, 101, 103]: + # The Conv2d kernel on Hopper (SM90) is currently unsupported + # Revert to use SM80-tagged kernels + cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.") + self.specified_kernel_cc = 80 + self._reset_options(80) + + # The arch is used in testing + self.arch = self.current_cc + self.name = "conv2d" + kind + + # The convolution kind. (concept: cutlass_library.library.ConvKind) + self.conv_kind = datatypes.getattr_enum(ConvKind, kind) + + # The element types (concept: cutlass library types) of A, B, C, and D + elements = [] + layouts = [] + + # Complete the data types based on user-provided arguments + for elt, tens, name in zip([element_A, element_B, element_C, element_D], + [A, B, C, D], + ["A", "B", "C", "D"]): + if elt is not None and tens is not None: + raise Exception(f'Must not specify both element_{name} and tensor {name}') + if elt is None and tens is None and element is None: + raise Exception(f'Must specify one of element_{name}, tensor {name}, or generic element.') + + elt_to_set = None + lay_to_set = None + + if tens is not None: + elt_to_set, _ = datatypes.get_datatype_and_layout(tens) + else: + elt_to_set = elt if elt is not None else element + + assert elt_to_set is not None + + # Currently we only support layout TensorNHWC + lay_to_set = cutlass_cppgen.LayoutType.TensorNHWC + elements.append(datatypes.library_type(elt_to_set)) + layouts.append(lay_to_set) + + self._element_a, self._element_b, self._element_c, self._element_d = elements + self._layout_a, self._layout_b, self._layout_c, self._layout_d = layouts + + self.A, self.B, self.C, self.D, self.alpha, self.beta = A, B, C, D, alpha, beta + + if element_accumulator is None: + self._element_accumulator = self._element_c + else: + self._element_accumulator = datatypes.library_type(element_accumulator) + + # Default inputs if none is supplied in run() + self.A = A + self.B = B + self.C = C + self.D = D + + self.alpha = alpha + self.beta = beta + + # We only specify the stride of the swizzling functor here + # The actual swizzling functor is determined in run based on conv_kind and stride + self._swizzling_stride = 1 + + # Arguments that will be set to default value in _reset_operations + # The default tile_description and op_class are fetched from manifest of cutlass library + self._tile_description = None + self.op_class = None + # The default identity epilogue will be created + self.epilogue_functor = None + + self._reset_operations() + + # Arguments that will be determined online based on arguments of "run" + # based on stride, input/output channels, alignment, and conv_kind + self._iterator_algorithm = None + self._stride_support = None + + def _reset_operations(self, reset_epilogue: bool = True): + # Set the default op class + datatype_comb = (self._element_a, self._element_b, self._element_accumulator) + layout_comb = (self._layout_a, self._layout_b) + + self.possible_op_classes = self.options.supporting_opclasses( + self._element_a, self._element_b, self._element_accumulator, + self._layout_a, self._layout_b, self._math_operation + ) + + if cutlass_cppgen.OpcodeClass.TensorOp in self.possible_op_classes: + self.opclass = cutlass_cppgen.OpcodeClass.TensorOp + elif cutlass_cppgen.OpcodeClass.Simt in self.possible_op_classes: + self.opclass = cutlass_cppgen.OpcodeClass.Simt + else: + if self._math_operation is not None: + math_op_str = f' and math operation {self._math_operation}' + else: + math_op_str = '' + + raise Exception(f'No kernel configuration found for supported data type and layout ' + f'combination {datatype_comb}x{layout_comb}{math_op_str}') + + if reset_epilogue: + self._reset_epilogue_functor_activation(epilogue.identity) + + self.alignment_pref_A = min( + 128 // DataTypeSize[self._element_a], max(self.possible_operations.alignments("A"))) + self.alignment_pref_B = min( + 128 // DataTypeSize[self._element_b], max(self.possible_operations.alignments("B"))) + self.alignment_pref_C = min( + 128 // DataTypeSize[self._element_c], max(self.possible_operations.alignments("C"))) + + # + # Tile description Related + # + + @property + def tile_description(self) -> TileDescription: + """ + Returns the tile description + """ + return self._tile_description + + @tile_description.setter + def tile_description( + self, td=None): + """ + Set the tile description + + :param td: tile description + :type td: cutlass_cppgen.backend.TileDescription, or a dict with keys + { + "threadblock_shape": [int, int, int], + "warp_count": [int, int, int], + "stages": int, + "instruction_shape": [int, int, int] (optional), + "cluster_shape": [int, int, int] (optional) + } + """ + if td is None: + return + if isinstance(td, dict): + if self._tile_description is None: + op = self.possible_operations.default_operation(self._math_operation) + self._tile_description = datatypes.td_from_profiler_op(op) + if "cluster_shape" in td.keys(): + if td["cluster_shape"] != [1, 1, 1]: + cutlass_cppgen.logger.warning("Conv2d currently only support 'cluster_shape'=[1, 1, 1]'.") + td["cluster_shape"] = [1, 1, 1] + td = self._tile_description.clone_and_update(td) + + valid, msg = self._valid_tile_description(td) + if valid: + self._tile_description = td + else: + raise Exception(msg) + + def _valid_tile_description(self, td: TileDescription) -> tuple: + """ + Checks whether the provided tile description is valid for the given compute capability. At present, + this checks the following: + + - Does the tile description use a number of stages supported by the compute capability in question? + - Does the tile size requested fit within shared memory? + - Are cluster dimensions outside the valid range requested for a given architecture (e.g., + more non-unit cluster dimensions for pre-SM90 architectures)? + - Is the kernel schedule being used supported on the architecture in question? + + :param td: tile description to validate + :type td: cutlass_cppgen.backend.TileDescription + :return: tuple in which the first element is a bool indicating that the tile description is valid + and the second element is a string providing an optional error message. + :rtype: tuple + """ + valid, msg = check.valid_stage_count(self.cc, self.current_cc, td) + if not valid: + return (valid, msg) + + valid, msg = check.valid_cluster_shape(self.current_cc, td.cluster_shape) + if not valid: + return (valid, msg) + + return valid, msg + + def tile_descriptions(self) -> list: + """ + Returns a list of valid tile descriptions for the operations + + :returns: list of valid tile descriptions for the operations + :rtype: list + """ + descriptions = [] + description_str = [] + for op in self.possible_operations.all_operations: + td = datatypes.td_from_profiler_op(op) + + if self._math_operation is not None: + if td.math_instruction.math_operation != self._math_operation: + continue + + if str(td) not in description_str: + description_str.append(str(td)) + descriptions.append(td) + return descriptions + + # + # Swizzling functor Related + # + + @property + def swizzling_stride(self): + """ + Returns the stride of swizzling currently being used by the Conv2d + + :return: swizzing stride + """ + return self._swizzling_stride + + @swizzling_stride.setter + def swizzling_stride(self, stride: int): + """ + Sets the swizzling functor to the type specified by `swizzling_functor` + """ + if not isinstance(stride, int): + raise Exception(f"Expect integer (1, 2, 4, 8), got {stride}") + self._swizzling_stride = stride + + def _propose_swizzling_functor(self, stride): + """ + Automatically propose the swizzling functor based on the stride + """ + if self.conv_kind == ConvKind.Dgrad: + if stride[0] != 1 or stride[1] != 1: + return getattr(cutlass_cppgen.swizzle, f"StridedDgradIdentitySwizzle{self._swizzling_stride}") + + return getattr(cutlass_cppgen.swizzle, f"IdentitySwizzle{self._swizzling_stride}") + + # + # Iterator Algorithm Related + # + + @property + def iterator_algorithm(self) -> IteratorAlgorithm: + """ + Returns the iterator algorithm + """ + return self._iterator_algorithm + + @iterator_algorithm.setter + def iterator_algorithm(self, alg: str): + """ + Sets the iterator algorithm + + :param alg: The iterator algorithm + :type td: string, options: "analytic", "optimized", "few_channels", and "fixed_channels" + """ + iterator_alg = datatypes.getattr_enum(IteratorAlgorithm, alg) + + # Check if the iterator algorithm is valid + if iterator_alg in [IteratorAlgorithm.FewChannels, IteratorAlgorithm.FixedChannels] and self.conv_kind != ConvKind.Fprop: + raise Exception(f"{self.conv_kind} does not support iterator algorithm {alg}.") + + self._iterator_algorithm = iterator_alg + + def _propose_iterator_algorithm(self, problem_size, alignment_a, alignment_b) -> IteratorAlgorithm: + """ + Propose a valid iterator algorithm based on problem size and alignment + """ + if self.conv_kind == ConvKind.Fprop: + # Check whether the fixed channel is applicable + if problem_size.C == alignment_a: + return IteratorAlgorithm.FixedChannels + elif (problem_size.C % alignment_a == 0 and + problem_size.R <= 32 and problem_size.S <= 32): + return IteratorAlgorithm.Optimized + else: + return IteratorAlgorithm.Analytic + elif self.conv_kind == ConvKind.Dgrad: + if (problem_size.K % alignment_a == 0 and + problem_size.R <= 32 and problem_size.S <= 32 and + problem_size.C % alignment_b == 0): + return IteratorAlgorithm.Optimized + else: + return IteratorAlgorithm.Analytic + elif self.conv_kind == ConvKind.Wgrad: + if (problem_size.K % alignment_a == 0 and + problem_size.C % alignment_b == 0): + return IteratorAlgorithm.Optimized + else: + return IteratorAlgorithm.Analytic + + def _validate_iterator_algorithm(self, iterator_algorithm, problem_size, alignment_a, alignment_b) -> bool: + """ + Validate whether the user provide iterator algorithm works for the given problem size + """ + if self.conv_kind == ConvKind.Fprop: + if iterator_algorithm == IteratorAlgorithm.FixedChannels: + return problem_size.C == alignment_a + elif iterator_algorithm == IteratorAlgorithm.Optimized: + return (problem_size.C % alignment_a == 0 and + problem_size.R <= 32 and problem_size.S <= 32) + elif iterator_algorithm == IteratorAlgorithm.FewChannels: + return problem_size.C % alignment_a == 0 + elif self.conv_kind == ConvKind.Dgrad: + if iterator_algorithm == IteratorAlgorithm.Optimized: + return (problem_size.K % alignment_a == 0 and + problem_size.R <= 32 and problem_size.S <= 32 and + problem_size.C % alignment_b == 0) + elif self.conv_kind == ConvKind.Wgrad: + if iterator_algorithm == IteratorAlgorithm.Optimized: + return (problem_size.K % alignment_a == 0 and + problem_size.C % alignment_b == 0) + + return True + + # + # Stride Support Related + # + + def _propose_stride_support(self, stride): + if self.conv_kind == ConvKind.Dgrad: + if stride[0] == 1 and stride[1] == 1: + return StrideSupport.Unity + + return StrideSupport.Strided + + # + # Construct and Compilation + # + + def construct( + self, tile_description: TileDescription = None, + alignment_A: int = None, alignment_B: int = None, alignment_C: int = None, + iterator_algorithm: IteratorAlgorithm = None, + stride_support = None, swizzling_functor: cutlass_cppgen.swizzle = None, + epilogue_functor=None) -> cutlass_cppgen.backend.Conv2dOperation: + """ + Constructs a ``cutlass_cppgen.backend.Conv2dOperation`` based on the input parameters and current + kernel specification of the ``Conv2d`` object. + + :param tile_description: tile description specifying shapes and operand types to use in the kernel + :type tile_description: cutlass_cppgen.backend.TileDescription + :param alignment_A: alignment of operand A + :type alignment_A: int + :param alignment_B: alignment of operand B + :type alignment_B: int + :param alignment_C: alignment of operand C + :type alignment_C: int + :param iterator_algorithm: the iterator algorithm used + :type iterator_algorithm: cutlass_library.library.IteratorAlgorithm + :param stride_support: the stride support of dgrad + :type stride_support: cutlass_library.library.StrideSupport + :param swizzling_functor: the swizzling functor + :type swizzling_functor: cutlass_cppgen.swizzle + :param epilogue_functor: the epilogue functor + + :return: operation that was constructed + :rtype: cutlass_cppgen.backend.Conv2dOperation + """ + # Get alignment + alignment_A = check.alignment_or_default(alignment_A, self.alignment_pref_A) + alignment_B = check.alignment_or_default(alignment_B, self.alignment_pref_B) + alignment_C = check.alignment_or_default(alignment_C, self.alignment_pref_C) + + tensor_A = TensorDescription(self._element_a, self._layout_b, alignment_A) + tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B) + tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C) + + if tile_description is None: + if self.tile_description is not None: + tile_description = self.tile_description + else: + op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0] + tile_description = datatypes.td_from_profiler_op(op) + else: + valid, err_str = self._valid_tile_description(tile_description) + if not valid: + raise Exception(f"Invalid tile description. {err_str}") + self.tile_description = tile_description + + if iterator_algorithm is None: + # If the iterator algorithm is already set + if self.iterator_algorithm is not None: + iterator_algorithm = self.iterator_algorithm + else: + # Otherwise, we conservatively use the analytic iterator for correctness + iterator_algorithm = IteratorAlgorithm.Analytic + + if stride_support is None: + # If the stride support is already set + if self._stride_support is not None: + stride_support = self._stride_support + else: + # Otherwise, we assume strided + stride_support = StrideSupport.Strided + + if swizzling_functor is None: + # If the swizzling functor is already set + swizzling_functor = self._propose_swizzling_functor(stride=(2, 2)) + + if epilogue_functor is None: + if self.epilogue_functor is not None: + epilogue_functor = self.epilogue_functor + else: + epilogue_functor = self._create_epilogue_functor_activation(self._activation) + + # Reset the alignment of the epilogue functor + epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, epilogue_functor) + + operation = Conv2dOperation( + conv_kind=self.conv_kind, + iterator_algorithm=iterator_algorithm, + arch=self.current_cc, + tile_description=tile_description, + A=tensor_A, B=tensor_B, C=tensor_C, + stride_support=stride_support, + epilogue_functor=epilogue_functor, + swizzling_functor=swizzling_functor, + ) + + return operation + + def compile(self, tile_description: TileDescription = None, + alignment_A: int = None, alignment_B: int = None, alignment_C: int = None, + iterator_algorithm: IteratorAlgorithm = None, + stride_support = None, swizzling_functor: cutlass_cppgen.swizzle = None, + epilogue_functor = None, print_module: bool = False) -> cutlass_cppgen.backend.Conv2dOperation: + """ + Emits and compiles the kernel currently specified. If ``tile_description`` and any + of the ``alignment`` parameters are set, the kernel will be chosen using this + tile description and alignments. Otherwise, a default tile description and alignment + will be used. + + ::param tile_description: tile description specifying shapes and operand types to use in the kernel + :type tile_description: cutlass_cppgen.backend.TileDescription + :param alignment_A: alignment of operand A + :type alignment_A: int + :param alignment_B: alignment of operand B + :type alignment_B: int + :param alignment_C: alignment of operand C + :type alignment_C: int + :param iterator_algorithm: the iterator algorithm used + :type iterator_algorithm: cutlass_library.library.IteratorAlgorithm + :param stride_support: the stride support of dgrad + :type stride_support: cutlass_library.library.StrideSupport + :param swizzling_functor: the swizzling functor + :type swizzling_functor: cutlass_cppgen.swizzle + :param epilogue_functor: the epilogue functor + + :return: operation that was compiled + :rtype: cutlass_cppgen.backend.Conv2dOperation + """ + + self.operation = self.construct( + tile_description, alignment_A, alignment_B, alignment_C, + iterator_algorithm, stride_support, swizzling_functor, epilogue_functor) + + if print_module: + print(self.operation.rt_module.emit()) + + compiler.add_module([self.operation,]) + return self.operation + + # + # Run Related + # + + def _verify_type_and_layout(self, tensor, ref_type, ref_layout, name): + """ + Verifies that ``tensor`` has data type ``ref_type`` and layout ``ref_layout``. An exception + is raised if it does not. + + :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in + :type tensor: numpy/cupy/torch array/tensor object + :param ref_dtype: data type for the tensor that this object was initialized to + :param name: identifier of the tensor to verify. Used in raising exceptions + :type name: str + """ + dtype, _ = datatypes.get_datatype_and_layout(tensor) + if dtype != ref_type: + raise Exception(f'Tensor {name} with type and layout {dtype} ' + f'does not match the expected type of {ref_type}.') + + def _get_and_verify_conv_problem_size(self, A, B, C, stride, padding, dilation): + if self.conv_kind == ConvKind.Fprop: + input = A + weight = B + output = C + output_tensor = "C" + elif self.conv_kind == ConvKind.Dgrad: + output = A + weight = B + input = C + output_tensor = "A" + elif self.conv_kind == ConvKind.Wgrad: + output = A + input = B + weight = C + output_tensor = "A" + else: + raise Exception(f"Convolution kind {self.conv_kind} is not supported") + + N_, H_, W_, C_ = datatypes.get_tensor_shape(input, op="CONV") + K_, R_, S_, _ = datatypes.get_tensor_shape(weight, op="CONV") + _, P_, Q_, _ = datatypes.get_tensor_shape(output, op="CONV") + + problem_size = Conv2DProblemSize( + N_, H_, W_, C_, + K_, R_, S_, C_, + padding[0], padding[1], + stride[0], stride[1], + dilation[0], dilation[1], + ConvMode.CrossCorrelation, + 1, 1 + ) + + if P_ != problem_size.P or Q_ != problem_size.Q: + raise Exception( + f"Tensor {output_tensor} size should be ({N_}, {problem_size.P}, {problem_size.Q}, {K_}), got ({N_}, {P_}, {Q_}, {K_})") + + return problem_size + + def run(self, A=None, B=None, C=None, D=None, + stride=(1, 1), padding=(0, 0), dilation=(1, 1), + alpha=None, beta=None, + split_k=("serial", 1), sync: bool = True, + print_module: bool = False, + stream: Optional[cuda.CUstream] = None) -> Conv2dArguments: + """ + Runs the kernel currently specified. If it has not already been, the kernel is emitted and + compiled. Tensors holding operands and outputs of the kernel are sourced either from the + ``A``, ``B``, ``C``, ``D``, ``alpha``, and ``beta`` + parameters provided in the call, or from those + passed in on the construction of this object -- one of the two must be specified. + + By default, this call returns only once the kernel has completed. To launch the kernel + and immediately return, set ``sync=False``. In this case, it is the responsibility of the + caller to syncrhonize the results of the kernel before attempting to access outputs + by calling ``sync()`` on the arguments returned from this call. + + :param A: tensor representing data type and layout of operand A + :param B: tensor representing data type and layout of operand B + :param C: tensor representing data type and layout of operand C + :param D: tensor representing data type and layout of operand D + :param stride: (stride_h, stride_w) describing the convolution stride. Default: (1, 1) + :param padding: (pad_h, pad_w) describing the convolution padding. Default: (0, 0) + :param dilation: (dilation_h, dilation_w) describing the dilation of convolution. Default: (1, 1) + :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B + :param beta: scalar parameter beta from GEMM operation that scales operand C + :param split_k: a tuple (split_k_mode, split_k_slices) + :param sync: whether the call should wait for the kernel to complete before returning + :type sync: bool + :param print_module: whether to print the emitted C++ code + :type print_module: bool + :param stream: cuda stream, defaults to cuda.cuda.CUstream(0) + :type stream: :class:`cuda.cuda.CUstream` + + :return: arguments passed in to the kernel + :rtype: cutlass_cppgen.backend.Conv2dArguments + """ + if not stream: + stream = cuda.CUstream(0) + super().run_setup() + + A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A") + B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B") + C = self._verify_tensor(C, self.C, self._element_c, self._layout_c, "C") + D = self._verify_tensor(D, self.D, self._element_d, self._layout_d, "D") + alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha") + beta = self._verify_scalar(beta, self.beta, self._element_c, "beta") + + # handle the case when there is no C + if C is None: + if beta != 0: + raise Exception(f"With beta {beta} != 0, C has to be provided.") + else: + C = D + + # Construct problem size based on input + # It also verifies whether the A, B, C, D, stride, padding, and dilation are matching + problem_size = self._get_and_verify_conv_problem_size(A, B, C, stride, padding, dilation) + + # Propose stride support based on input + stride_support = self._propose_stride_support(stride) + + # Propose swizzling functor + swizzling_functor = self._propose_swizzling_functor(stride) + + shape_a = datatypes.get_tensor_shape(A, op="CONV") + shape_b = datatypes.get_tensor_shape(B, op="CONV") + shape_c = datatypes.get_tensor_shape(C, op="CONV") + + # Get the alignment + alignment_a = self.possible_operations.find_alignment(shape_a, self._layout_a, operand="A") + alignment_b = self.possible_operations.find_alignment(shape_b, self._layout_b, operand="B") + alignment_c = self.possible_operations.find_alignment(shape_c, self._layout_c, operand="C") + + alignment_a = check.update_alignment(alignment_a, self.alignment_pref_A) + alignment_b = check.update_alignment(alignment_b, self.alignment_pref_B) + alignment_c = check.update_alignment(alignment_c, self.alignment_pref_C) + + # Propose iterator algorithm based on input + if self._iterator_algorithm is None: + # Propose a default iterator algorithm based on the problem size + iterator_algorithm = self._propose_iterator_algorithm(problem_size, alignment_a, alignment_b) + else: + if (self._validate_iterator_algorithm(self._iterator_algorithm, problem_size, alignment_a, alignment_b)): + iterator_algorithm = self._iterator_algorithm + else: + raise Exception(f"Iterator algorithm {self._iterator_algorithm} is invalid for current problem.") + + epilogue_args = [alpha, beta] + + if hasattr(self, "_activation_args"): + if isinstance(self._activation_args, list): + epilogue_args += self._activation_args + else: + epilogue_args.append(self._activation_args) + + if split_k[0] == "parallel" and split_k[1] > 1: + epilogue_functor = self._create_epilogue_functor_activation(epilogue.identity) + else: + epilogue_functor = self.epilogue_functor + + # The alignment is determined by the iterator function (I believe) + self.compile(tile_description=self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b, + alignment_C=alignment_c, iterator_algorithm=iterator_algorithm, stride_support=stride_support, + swizzling_functor=swizzling_functor, epilogue_functor=epilogue_functor, print_module=print_module) + + # Create reduction operation for parallel split-k + if split_k[0] == "parallel" and split_k[1] > 1: + epilogue_functor_reduction = self._reset_epilogue_functor_alignment(alignment_c, self.epilogue_functor) + self.reduction_operation = ReductionOperation( + shape=MatrixCoord(4, 32 * alignment_c), C=self.operation.C, + element_accumulator=self._element_accumulator, + element_compute=self._element_accumulator, + epilogue_functor=epilogue_functor_reduction, + count=alignment_c + ) + if print_module: + print(self.reduction_operation.rt_module.emit()) + compiler.add_module([self.reduction_operation,]) + + arguments = Conv2dArguments( + operation=self.operation, problem_size=problem_size, + A=A, B=B, C=C, D=D, + output_op=self.operation.epilogue_type(*epilogue_args), + split_k_mode=datatypes.getattr_enum(SplitKMode, split_k[0]), + split_k_slices=split_k[1], + stream=stream + ) + + self.operation.run(arguments) + + if split_k[0] == "parallel" and split_k[1] > 1: + implicit_gemm_size = arguments.problem_size.implicit_gemm_size(self.conv_kind) + reduction_arguments = ReductionArguments( + self.reduction_operation, + problem_size=[implicit_gemm_size.m, implicit_gemm_size.n], + partitions=split_k[1], + workspace=arguments.ptr_D, + destination=D, + source=C, + output_op=self.reduction_operation.epilogue_type(*epilogue_args), + stream=stream + ) + self.reduction_operation.run(reduction_arguments) + + if sync: + if split_k[0] == "parallel" and split_k[1] > 1: + reduction_arguments.sync() + + # Free memory allocated by args because we are not + # calling `arguments.sync()` in this case (which will free memory) + arguments.free() + else: + arguments.sync() + + return arguments + + # + # Helper functions + # + @staticmethod + def output_size(input_size, weight_size, padding, stride, dilation): + problem_size = Conv2DProblemSize( + *input_size, + *weight_size, + padding[0], padding[1], + stride[0], stride[1], + dilation[0], dilation[1], + ConvMode.CrossCorrelation, + 1, 1 + ) + return (problem_size.N, problem_size.P, problem_size.Q, problem_size.K) + + +# +# Easy to use interfaces for fprop, wgrad, and dgrad +# + +class Conv2dFprop(Conv2d): + def __init__( + self, + input=None, weight=None, C=None, output=None, alpha=1, beta=0, + element=None, + element_input=None, element_weight=None, element_C=None, element_output=None, + element_accumulator=None, + cc: int = None, kernel_cc: int = None): + A, B, D = input, weight, output + element_A, element_B, element_D = element_input, element_weight, element_output + super().__init__( + "fprop", A, B, C, D, alpha, beta, element, + element_A, element_B, element_C, element_D, + element_accumulator, cc, kernel_cc) + + def run( + self, input=None, weight=None, C=None, output=None, alpha=None, beta=None, + stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1), + sync: bool = True, print_module: bool = False, + stream: Optional[cuda.CUstream] = None) -> Conv2dArguments: + + if not stream: + stream = cuda.CUstream(0) + + A, B, D = input, weight, output + return super().run( + A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream) + + +class Conv2dDgrad(Conv2d): + def __init__( + self, + grad_output=None, weight=None, C=None, grad_input=None, alpha=1, beta=0, + element=None, + element_grad_output=None, element_weight=None, element_C=None, element_grad_input=None, + element_accumulator=None, + cc: int = None, kernel_cc: int = None): + A, B, D = grad_output, weight, grad_input + element_A, element_B, element_D = element_grad_output, element_weight, element_grad_input + super().__init__( + "dgrad", A, B, C, D, alpha, beta, element, + element_A, element_B, element_C, element_D, + element_accumulator, cc, kernel_cc) + + def run(self, grad_output=None, weight=None, C=None, grad_input=None, alpha=None, beta=None, + stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1), + sync: bool = True, print_module: bool = False, + stream: Optional[cuda.CUstream] = None) -> Conv2dArguments: + # + if not stream: + stream = cuda.CUstream(0) + + A, B, D = grad_output, weight, grad_input + return super().run( + A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream) + + +class Conv2dWgrad(Conv2d): + def __init__( + self, + grad_output=None, input=None, C=None, grad_weight=None, alpha=1, beta=0, + element=None, + element_grad_output=None, element_input=None, element_C=None, element_grad_weight=None, + element_accumulator=None, + cc: int = None, kernel_cc: int = None): + A, B, D = grad_output, input, grad_weight + element_A, element_B, element_D = element_grad_output, element_input, element_grad_weight + super().__init__( + "wgrad", A, B, C, D, alpha, beta, element, + element_A, element_B, element_C, element_D, + element_accumulator, cc, kernel_cc) + + def run(self, grad_output=None, input=None, C=None, grad_weight=None, alpha=None, beta=None, + stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1), + sync: bool = True, print_module: bool = False, + stream: Optional[cuda.CUstream] = None) -> Conv2dArguments: + if not stream: + stream = cuda.CUstream(0) + + A, B, D = grad_output, input, grad_weight + return super().run( + A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..a6f9b1ab43a1c45d0024e99e50e45813ba18866e --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm.py @@ -0,0 +1,725 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" + Ease-of-use interface for constructing, compiling, and running GEMMs. + + The ``Gemm`` interface is meant to allow one to easily instantiate, compile, and run + GEMM operations in CUTLASS via Python, without specifying many configuration parameters. + Under the hood, the interface will select sensible default parameters for the many template + parameters for CUTLASS GEMMs. + + Note: optimal performance is not to be expected from this interface. To achieve optimal + performance, one should specify and tune each configuration parameter. + + The simplest example of using this interface is the following: + + .. highlight:: python + .. code-block:: python + + # A, B, C, and D are torch/numpy/cupy tensor objects + plan = cutlass_cppgen.op.Gemm(A, B, C, D) + plan.run() + + + One can also use the interface by specifying data types of operands at construction + and using different tensor objects with these data types at runtime: + + .. highlight:: python + .. code-block:: python + + # The following is shorthand for: + # cutlass_cppgen.op.Gemm(element_A=torch.float32, element_B=torch.float32, + # element_C=torch.float32, element_D=torch.float32, + # element_accumulator=torch.float32, + # layout=cutlass_cppgen.LayoutType.RowMajor) + plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_cppgen.LayoutType.RowMajor) + + A0 = torch.rand((128, 256), device='cuda') + B0 = torch.rand((256, 64), device='cuda') + C0 = torch.zeros((128, 64), device='cuda') + D0 = torch.zeros((128, 64), device.'cuda') + plan.run(A0, B0, C0, D0) + + A = torch.rand((32, 128), device='cuda') + B = torch.rand((128, 256), device='cuda') + C = torch.zeros((32, 256), device='cuda') + D = torch.zeros((32, 256), device.'cuda') + plan.run(A1, B1, C1, D1) + + The interface additionally enables one to decouple the compilation of the underlying CUTLASS + kernel from its execution: + + .. highlight:: python + .. code-block:: python + + plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor) + plan.compile() + + # Do other work... + + plan.run(A0, B0, C0, D0) + + # Do other work... + + plan.run(A1, B1, C1, D1) + + Elementwise activation functions are easily fused to the GEMM via the interface: + + .. highlight:: python + .. code-block:: python + + plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor) + plan.activation = cutlass_cppgen.epilogue.relu + + Operations can also be run asynchronously: + + .. highlight:: python + .. code-block:: python + + plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor) + args = plan.run() + + # Do other work... + + args.sync() +""" +from __future__ import annotations +from typing import Optional +from math import prod + +from cutlass_cppgen.utils.lazy_import import lazy_import +cuda = lazy_import("cuda.cuda") +from cutlass_library import ( + DataType, + DataTypeSize, + GemmUniversalMode, + KernelScheduleSuffixes, +) + +import cutlass_cppgen +from cutlass_cppgen import epilogue, swizzle +from cutlass_cppgen.backend import compiler +from cutlass_cppgen.backend.evt import EpilogueFunctorVisitor +from cutlass_cppgen.backend.gemm_operation import GemmArguments, GemmOperationUniversal +from cutlass_cppgen.backend.library import TensorDescription, TileDescription +from cutlass_cppgen.op.op import OperationBase +from cutlass_cppgen.shape import GemmCoord +from cutlass_cppgen.utils import check, datatypes + + +class Gemm(OperationBase): + """ + Constructs a ``Gemm`` object. + + The data types and layouts of operands A, B, and C, along with the data type of output D + and that used for accumulation, are bound to the ``Gemm`` object throughout its lifetime -- + these are not to be changed after a ``Gemm`` has been constructed. + + The constructor has optional parameters for flexibly setting these parameters. The following + constructors are equivalent: + + .. highlight:: python + .. code-block:: python + + # Use F32 for A, B, C, D, and accumulation. All operands are row major. + + # Use the generic ``element`` and ``layout`` parameters to concisely set all data types and layouts + # for operands to the same values. + Gemm(element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor) + + # Explicitly specify the data types to use for A, B, C, and D. Use the generic ``layout``. + Gemm(element_A=cutlass_cppgen.DataType.f32, element_B=cutlass_cppgen.DataType.f32, element_C=cutlass_cppgen.DataType.f32, + element_D=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor) + + # Set the data types and elements from existing tensors. Note that one can use different tensors when + # executing GEMM via the ``run()`` method than passed in here (though those passed in to ``run()`` must + # have the same data type and layout as those passed in here). + # A, B, C, and D are row-major torch.Tensor objects of type torch.float32 + Gemm(A=A, B=B, C=C, D=D) + + # Use the generic ``element`` and explicitly specify the layouts to use for A, B, and C (layout of D is + # the same as that for D, at present) + Gemm(element=cutlass_cppgen.DataType.f32, layout_A=cutlass_cppgen.LayoutType.RowMajor, + layout_B=cutlass_cppgen.LayoutType.RowMajor, layout_C=cutlass_cppgen.LayoutType.RowMajor) + + # Explicitly specify the data type and layout for only some of A, B, C, and D. Unspecified data types + # and layouts will inherit those passed in via the generic ``element`` and ``layout`` + Gemm(element_A=cutlass_cppgen.DataType.f32, layout_B=cutlass_cppgen.LayoutType.RowMajor, + element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor) + + The order of precedence for the setting of the data type and layout for a given operand/output is as follows: + 1) If the tensor type is specified (e.g., ``A``), use the data type and layout inferred from this tensor + 2) Otherwise, if the data type/layout (e.g., ``element_A``, ``layout_A``) is specified, use those + 3) Otherwise, use the generic values (e.g., ``element``, ``layout``) + + :param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90 + :type cc: int + :param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80 + :type kernel_cc: int + :param A: tensor representing data type and layout of operand A + :param B: tensor representing data type and layout of operand B + :param C: tensor representing data type and layout of operand C + :param D: tensor representing data type and layout of operand D + :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B + :param beta: scalar parameter beta from GEMM operation that scales operand C + :param element_accumulator: data type to be used in accumulation of the product of operands A and B + :type element_accumulator: cutlass_cppgen.DataType + :param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type + :type element: cutlass_cppgen.DataType + :param layout: generic layout type to be used for operands A, B, C, and D + :type layout: cutlass_cppgen.LayoutType + :param element_A: data type to be used for operand A + :type element_A: cutlass_cppgen.DataType + :param element_B: data type to be used for operand B + :type element_B: cutlass_cppgen.DataType + :param element_C: data type to be used for operand C + :type element_C: cutlass_cppgen.DataType + :param element_D: data type to be used for operand D + :type element_D: cutlass_cppgen.DataType + :param layout_A: layout of operand A + :type layout_A: cutlass_cppgen.LayoutType + :param layout_B: layout of operand B + :type layout_B: cutlass_cppgen.LayoutType + :param layout_C: layout of operand C + :type layout_C: cutlass_cppgen.LayoutType + :param layout_D: layout of operand D + :type layout_D: cutlass_cppgen.LayoutType + """ + + def __init__( + self, A=None, B=None, C=None, D=None, + alpha=1.0, beta=0.0, element_accumulator=None, + element=None, layout=None, + element_A=None, element_B=None, element_C=None, element_D=None, + layout_A=None, layout_B=None, layout_C=None, + cc: int = None, kernel_cc: int = None + ): + super().__init__(cc=cc, kernel_cc=kernel_cc) + self.name = "gemm" + self.compiled = False + + elements = [] + layouts = [] + + # Check that at least one of the following is set for each tensor (illustrated assuming tensor A): + # ``A``, ``element_A``, ``element`` and ``A``, ``layout_A``, ``layout`` + for elt, lay, tens, name in zip([element_A, element_B, element_C, element_D], + [layout_A, layout_B, layout_C, layout_C], + [A, B, C, D], + ["A", "B", "C", "D"]): + if elt is not None and tens is not None: + raise Exception(f'Must not specify both element_{name} and tensor {name}') + if lay is not None and tens is not None: + raise Exception(f'Must not specify both layout_{name} and tensor {name}') + if elt is None and tens is None and element is None: + raise Exception(f'Must specify one of element_{name}, tensor {name}, or generic element.') + if lay is None and tens is None and layout is None: + raise Exception(f'Must specify one of layout_{name}, tensor {name}, or generic layout.') + + elt_to_set = None + lay_to_set = None + if tens is not None: + elt_to_set, lay_to_set = datatypes.get_datatype_and_layout(tens) + else: + elt_to_set = elt if elt is not None else element + lay_to_set = lay if lay is not None else layout + + elements.append(datatypes.library_type(elt_to_set)) + layouts.append(lay_to_set) + + self._element_a, self._element_b, self._element_c, self._element_d = elements + self._layout_a, self._layout_b, self._layout_c, self._layout_d = layouts + + if element_accumulator is None: + self._element_accumulator = self._element_c + else: + self._element_accumulator = datatypes.library_type(element_accumulator) + + self.A = A + self.B = B + self.C = C + self.D = D + + self.alpha = alpha + self.beta = beta + + self.epilogue_functor = None + self.op_class = None + self._tile_description = None + + self._reset_operations() + + self._swizzling_functor = cutlass_cppgen.swizzle.IdentitySwizzle1 + + def _reset_operations(self, reset_epilogue: bool = True): + # Set the default op class + datatype_comb = (self._element_a, self._element_b, self._element_accumulator) + layout_comb = (self._layout_a, self._layout_b) + + self.possible_op_classes = self.options.supporting_opclasses( + self._element_a, self._element_b, self._element_accumulator, + self._layout_a, self._layout_b, self._math_operation) + + if cutlass_cppgen.OpcodeClass.TensorOp in self.possible_op_classes: + self.opclass = cutlass_cppgen.OpcodeClass.TensorOp + elif cutlass_cppgen.OpcodeClass.Simt in self.possible_op_classes: + self.opclass = cutlass_cppgen.OpcodeClass.Simt + else: + if self._math_operation is not None: + math_op_str = f' and math operation {self._math_operation}' + else: + math_op_str = '' + + raise Exception(f'No kernel configuration found for supported data type and layout ' + f'combination {datatype_comb}x{layout_comb}{math_op_str}') + + if reset_epilogue: + self._reset_epilogue_functor_activation(cutlass_cppgen.epilogue.identity) + + @property + def swizzling_functor(self): + """ + Returns the type of the swizzling functor currently being used by the GEMM + + :return: swizzing functor type + """ + return self._swizzling_functor + + @swizzling_functor.setter + def swizzling_functor(self, swizzling_functor): + """ + Sets the swizzling functor to the type specified by `swizzling_functor` + """ + if swizzling_functor == cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK: + if self.op_class == cutlass_cppgen.OpcodeClass.Simt: + raise Exception('ThreadblockSwizzleStreamK is currently only supported with opcode class TensorOp') + + if self.current_cc in [90, 100, 101, 103]: + raise Exception('ThreadblockSwizzleStreamK is currently unsupported on SM90+') + self._swizzling_functor = swizzling_functor + + # + # Tile description Related + # + + @property + def tile_description(self) -> TileDescription: + """ + Returns the tile description + """ + return self._tile_description + + @tile_description.setter + def tile_description( + self, td=None): + """ + Set the tile description + + :param td: tile description + :type td: cutlass_cppgen.backend.TileDescription, or a dict with keys + { + "threadblock_shape": [int, int, int], + "warp_count": [int, int, int], + "stages": int, + "instruction_shape": [int, int, int] (optional), + "cluster_shape": [int, int, int] (optional) + } + """ + if td is None: + return + if isinstance(td, dict): + if self._tile_description is None: + op = self.possible_operations.default_operation(self._math_operation) + self._tile_description = datatypes.td_from_profiler_op(op) + td = self._tile_description.clone_and_update(td) + + valid, msg = self._valid_tile_description(td) + if valid: + self._tile_description = td + else: + raise Exception(msg) + + def _valid_tile_description(self, td: TileDescription) -> tuple: + """ + Checks whether the provided tile description is valid for the given compute capability. At present, + this checks the following: + + - Does the tile description use a number of stages supported by the compute capability in question? + - Does the tile size requested fit within shared memory? + - Are cluster dimensions outside the valid range requested for a given architecture (e.g., + more non-unit cluster dimensions for pre-SM90 architectures)? + - Is the kernel schedule being used supported on the architecture in question? + + :param td: tile description to validate + :type td: cutlass_cppgen.backend.TileDescription + :return: tuple in which the first element is a bool indicating that the tile description is valid + and the second element is a string providing an optional error message. + :rtype: tuple + """ + valid, msg = check.valid_stage_count(self.cc, self.current_cc, td, self._element_c, self._element_d) + if not valid: + return (valid, msg) + + valid, msg = check.valid_cluster_shape(self.current_cc, td.cluster_shape) + if not valid: + return (valid, msg) + + valid, msg = check.valid_schedule(self.current_cc, td.kernel_schedule, td.epilogue_schedule, td.tile_scheduler) + + if self.cc in [100, 101, 103] and td.kernel_schedule is not None and td.is_2sm and td.cluster_shape[0] % 2 != 0: + valid = False + msg = "Cluster shape must be divisible by 2 for 2SM kernels on SM100, SM101, and SM103" + + return valid, msg + + def tile_descriptions(self) -> list: + """ + Returns a list of valid tile descriptions for the operations + + :returns: list of valid tile descriptions for the operations + :rtype: list + """ + tds = [datatypes.td_from_profiler_op(op) for op in self.possible_operations.all_operations] + if self._math_operation is not None: + tds = [td for td in tds if td.math_instruction.math_operation == self._math_operation] + return tds + + def construct( + self, tile_description: TileDescription = None, + alignment_A: int = None, alignment_B: int = None, alignment_C: int = None) -> GemmOperationUniversal: + """ + Constructs a ``cutlass_cppgen.backend.GemmUniversalOperation`` based on the input parameters and current + kernel specification of the ``Gemm`` object. + + :param tile_description: tile description specifying shapes and operand types to use in the kernel + :type tile_description: cutlass_cppgen.backend.TileDescription + :param alignment_A: alignment of operand A + :type alignment_A: int + :param alignment_B: alignment of operand B + :type alignment_B: int + :param alignment_C: alignment of operand C + :type alignment_C: int + + :return: operation that was constructed + :rtype: cutlass_cppgen.backend.GemmOperationUniversal + """ + alignment_pref_A = min(128 // DataTypeSize[self._element_a], max(self.possible_operations.alignments("A"))) + alignment_pref_B = min(128 // DataTypeSize[self._element_b], max(self.possible_operations.alignments("B"))) + alignment_A = check.alignment_or_default(alignment_A, alignment_pref_A) + alignment_B = check.alignment_or_default(alignment_B, alignment_pref_B) + + tensor_A = TensorDescription(self._element_a, self._layout_a, alignment_A) + tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B) + + if alignment_C is None: + alignment_C = max(self.possible_operations.alignments("C")) + if self._element_c != DataType.void: + alignment_C = min(128 // DataTypeSize[self._element_c], alignment_C) + + if tile_description is None: + if self._tile_description is None: + op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0] + tile_description = datatypes.td_from_profiler_op(op) + + # The selected op may have lower alignment than that determined above, so we must + # reset alignment here. + alignment_C = op.C.alignment + else: + tile_description = self._tile_description + else: + valid, err_str = self._valid_tile_description(tile_description) + if not valid: + raise Exception(f"Invalid tile description. {err_str}") + self._tile_description = tile_description + + tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C) + self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor) + + operation = GemmOperationUniversal( + arch=self.current_cc, + tile_description=tile_description, + A=tensor_A, B=tensor_B, C=tensor_C, + epilogue_functor=self.epilogue_functor, + swizzling_functor=self._swizzling_functor, + ) + + return operation + + def compile(self, tile_description: TileDescription = None, + alignment_A: int = None, alignment_B: int = None, alignment_C: int = None, + print_module: bool = False) -> cutlass_cppgen.backend.GemmOperationUniversal: + """ + Emits and compiles the kernel currently specified. If ``tile_description`` and any + of the ``alignment`` parameters are set, the kernel will be chosen using this + tile description and alignments. Otherwise, a default tile description and alignment + will be used. + + :param tile_description: tile description specifying shapes and operand types to use in the kernel + :type tile_description: cutlass_cppgen.backend.TileDescription + :param alignment_A: alignment of operand A + :type alignment_A: int + :param alignment_B: alignment of operand B + :type alignment_B: int + :param alignment_C: alignment of operand C + :type alignment_C: int + :param print_module: whether to print the emitted C++ code + :type print_module: bool + + :return: operation that was compiled + :rtype: cutlass_cppgen.backend.GemmOperationUniversal + """ + self.operation = self.construct(tile_description, alignment_A, alignment_B, alignment_C) + + if print_module: + print(self.operation.rt_module.emit()) + + compiler.add_module([self.operation,]) + return self.operation + + def _verify_rank(self, tensor): + """ + Verifies that ``tensor`` has rank greater than 1 + + :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in + :type tensor: numpy/cupy/torch array/tensor object + """ + if len(tensor.shape) < 2: + raise Exception(f"Tensors must be of rank greater than 1. Received tensor of shape: {tensor.shape}") + + def _get_batch_count(self, A, B, C, D) -> int: + """ + Returns the batch count specified by the tensors A, B, C, and D and verifies that these + tensors match in batch size. Presence of a batch dimension is detected by one of the + tensors being rank 3. If a batch dimension is present, it must be present in one of + operands A, B, or C (but need not be in all), and must be present in D. + + :param A: tensor A + :type A: numpy/cupy/torch array/tensor object + :param B: tensor B + :type B: numpy/cupy/torch array/tensor object + :param C: tensor C + :type C: numpy/cupy/torch array/tensor object + :param D: tensor D + :type D: numpy/cupy/torch array/tensor object + + :return: tuple of batch count dimensions + :rtype: tuple + """ + A_batch = prod(A.shape[:-2]) if len(A.shape) > 2 else 1 + B_batch = prod(B.shape[:-2]) if len(B.shape) > 2 else 1 + + if 1 not in [A_batch, B_batch]: + if A_batch != B_batch: + raise Exception(f"Get invalid batch counts: A={A_batch}, B={B_batch}") + return max(A_batch, B_batch) + + def _get_batch_stride(self, tensor) -> int: + """ + Returns the batch stride of ``tensor``. If ``tensor`` is only rank-2, batch stride is 0. + + :param tensor: tensor object to process + :type tensor: numpy/cupy/torch array/tensor object + + :return: stride between each matrix in the batch + :rtype: int + """ + if tensor is not None and len(tensor.shape) > 2: + return tensor.shape[-2] * tensor.shape[-1] + else: + return 0 + + def _get_problem_args(self, A, B, C, D) -> tuple: + """ + Returns the problem size and GEMM universal mode to use for the + given operands. + + :param A: tensor A + :type A: numpy/cupy/torch array/tensor object + :param B: tensor B + :type B: numpy/cupy/torch array/tensor object + :param C: tensor C + :type C: numpy/cupy/torch array/tensor object + :param D: tensor D + :type D: numpy/cupy/torch array/tensor object + + :return: tuple containing the problem size (cutlass_cppgen.shape.GemmCoord), the GEMM mode (cutlass_cppgen.GemmUniversalMode), and the batch count (int) + :rtype: tuple + """ + M, K = A.shape[-2:] + N = B.shape[-1] + mode = GemmUniversalMode.Gemm + + batch_count = self._get_batch_count(A, B, C, D) + returned_batch_count = batch_count + + # If we are running a batched GEMM in which there is a nonzero batch stride + # only for A, then we can fold the batched dimension of A into the M dimension + # (i.e., (b, m, k) x (k, n) -> (m*b, k) x (k, n)). This works only if both A + # and C are row major. A similar operation can be performed if only B has a nonzero + # batch dimension + if batch_count > 1: + A_row = self._layout_a == cutlass_cppgen.LayoutType.RowMajor + B_row = self._layout_b == cutlass_cppgen.LayoutType.RowMajor + C_row = self._layout_c == cutlass_cppgen.LayoutType.RowMajor + + # Consider a Tensor to be batched if its rank is > 2 and + # the product of the modes beyond rank 2 equals our pre-determined batch size. + batched = lambda x : x is None or (len(x.shape) > 2 and prod(x.shape[:-2]) == batch_count) + + if batched(A) and not batched(B) and (C is None or batched(C)) and A_row and C_row: + M *= batch_count + returned_batch_count = 1 + elif not batched(A) and batched(B) and (C is None or batched(C)) and not B_row and not C_row: + N *= batch_count + returned_batch_count = 1 + else: + mode = GemmUniversalMode.Batched + + return GemmCoord(M, N, K), mode, returned_batch_count + + def _verify_type_and_layout(self, tensor, ref_type, ref_layout, name): + """ + Verifies that ``tensor`` has data type ``ref_type`` and layout ``ref_layout``. An exception + is raised if it does not. + + :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in + :type tensor: numpy/cupy/torch array/tensor object + :param ref_dtype: data type for the tensor that this object was initialized to + :param ref_layout: layout for the tensor that this object was initialized to + :param name: identifier of the tensor to verify. Used in raising exceptions + :type name: str + """ + dtype, layout = datatypes.get_datatype_and_layout(tensor) + if dtype != ref_type or layout != ref_layout: + try: + # Attempt to transpose the tensor to fit the desired layout + tensor = tensor.transpose(-1, -2) + except: + raise Exception(f'Tensor {name} with type and layout ({dtype}, {layout}) ' + f'does not match the expected type and ' + f'layout of ({ref_type}, {ref_layout}) and transpose failed.') + + def run(self, A=None, B=None, C=None, D=None, + alpha=None, beta=None, sync: bool = True, print_module: bool = False, visitor_args: dict = None, + stream: Optional[cuda.CUstream] = None) -> GemmArguments: + """ + Runs the kernel currently specified. If it has not already been, the kernel is emitted and + compiled. Tensors holding operands and outputs of the kernel are sourced either from the + ``A``, ``B``, ``C``, ``D``, ``alpha``, and ``beta`` + parameters provided in this call, or from those + passed in on the construction of this object -- one of the two must be specified. + + By default, this call returns only once the kernel has completed. To launch the kernel + and immediately return, set ``sync=False``. In this case, it is the responsibility of the + caller to syncrhonize the results of the kernel before attempting to access outputs + by calling ``sync()`` on the arguments returned from this call. + + :param A: tensor representing data type and layout of operand A + :param B: tensor representing data type and layout of operand B + :param C: tensor representing data type and layout of operand C + :param D: tensor representing data type and layout of operand D + :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B + :param beta: scalar parameter beta from GEMM operation that scales operand C + :param sync: whether the call should wait for the kernel to complete before returning + :type sync: bool + :param print_module: whether to print the emitted C++ code + :type print_module: bool + :param stream: cuda stream, defaults to cuda.cuda.CUstream(0) + :type stream: :class:`cuda.cuda.CUstream` + + :return: arguments passed in to the kernel + :rtype: cutlass_cppgen.backend.GemmArguments + """ + if not stream: + stream = cuda.CUstream(0) + super().run_setup() + A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A") + B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B") + C = self._verify_tensor(C, self.C, self._element_c, self._layout_c, "C") + D = self._verify_tensor(D, self.D, self._element_d, self._layout_d, "D") + alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha") + beta = self._verify_scalar(beta, self.beta, self._element_c, "beta") + + is_void_c = self._element_c == DataType.void + + self._verify_rank(A) + self._verify_rank(B) + if not is_void_c: + self._verify_rank(C) + self._verify_rank(D) + + alignment_a = self.possible_operations.find_alignment(A.shape, self._layout_a, operand="A") + alignment_b = self.possible_operations.find_alignment(B.shape, self._layout_b, operand="B") + + # Set C alignment based on D.shape so as to correctly get an alignment with void-C + # kernels, for which `C` is None. + alignment_c = self.possible_operations.find_alignment(D.shape, self._layout_c, operand="C") + self.compile(self._tile_description, alignment_A=alignment_a, alignment_B=alignment_b, + alignment_C=alignment_c, print_module=print_module) + + problem_size, mode, batch_count = self._get_problem_args(A, B, C, D) + + if mode == GemmUniversalMode.Gemm or batch_count == 1: + kwargs = {'split_k_slices': 1} + else: + kwargs = { + 'batch': batch_count, + 'batch_strides': { + 'A': self._get_batch_stride(A), + 'B': self._get_batch_stride(B), + 'C': self._get_batch_stride(C), + 'D': self._get_batch_stride(D) + } + } + + kwargs['stream'] = stream + + if isinstance(self.epilogue_functor, EpilogueFunctorVisitor): + output_op = self.operation.epilogue_type(visitor_args) + else: + output_op = self.operation.epilogue_type(alpha, beta) + + arguments = GemmArguments( + operation=self.operation, problem_size=problem_size, + A=A, B=B, C=C, D=D, + output_op=output_op, + gemm_mode=mode, + **kwargs + ) + + self.operation.run(arguments) + + if sync: + arguments.sync() + + return arguments diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm_grouped.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm_grouped.py new file mode 100644 index 0000000000000000000000000000000000000000..59f90535c29a816541bc1a2155fea35afd1c94fd --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm_grouped.py @@ -0,0 +1,269 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" + Ease-of-use interface for constructing, compiling, and running GEMMs. + + The ``GroupedGemm`` interface is meant to allow one to easily instantiate, compile, and run + grouped GEMM operations in CUTLASS via Python, without specifying many configuration parameters. + Under the hood, the interface will select sensible default parameters for the many template + parameters for CUTLASS grouped GEMMs. + + Note: optimal performance is not to be expected from this interface. To achieve optimal + performance, one should specify and tune each configuration parameter. + + The simplest example of using this interface is the following: + + .. highlight:: python + .. code-block:: python + + # As, Bs, Cs, and Ds are torch/numpy/cupy tensor objects + plan = cutlass_cppgen.op.GroupedGemm(element=cutlass_cppgen.DataType.f16, layout=cutlass_cppgen.LayoutType.RowMajor) + plan.run([A0, A1], [B0, B1], [C0, C1], [D0, D1]) +""" +from __future__ import annotations +from typing import Optional +from cutlass_library import DataTypeSize + +from cutlass_cppgen.utils.lazy_import import lazy_import +cuda = lazy_import("cuda.cuda") +from cutlass_cppgen.backend.gemm_operation import ( + GemmGroupedArguments, + GemmOperationGrouped, +) +from cutlass_cppgen.backend.library import ( + SchedulerMode, + TensorDescription, + TileDescription, +) +from cutlass_cppgen.op.gemm import Gemm +from cutlass_cppgen.shape import GemmCoord +from cutlass_cppgen.utils import check, datatypes + + +class GroupedGemm(Gemm): + """ + Constructs a ``GroupedGemm`` object. + + The data types and layouts of operands A, B, and C, along with the data type of output D + and that used for accumulation, are bound to the ``GroupedGemm`` object throughout its lifetime -- + these are not to be changed after a ``GroupedGemm`` has been constructed. + + The constructor has optional parameters for flexibly setting these parameters. Please see the constructor + for ``Gemm`` for examples of these. + + :param cc: compute capability of device to generate kernels for + :type cc: int + :param A: tensor representing data type and layout of operands A + :param B: tensor representing data type and layout of operands B + :param C: tensor representing data type and layout of operands C + :param D: tensor representing data type and layout of operands D + :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B + :param beta: scalar parameter beta from GEMM operation that scales operand C + :param element_accumulator: data type to be used in accumulation of the product of operands A and B + :type element_accumulator: cutlass_cppgen.DataType + :param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type + :type element: cutlass_cppgen.DataType + :param layout: generic layout type to be used for operands A, B, C, and D + :type layout: cutlass_cppgen.LayoutType + :param element_A: data type to be used for operand A + :type element_A: cutlass_cppgen.DataType + :param element_B: data type to be used for operand B + :type element_B: cutlass_cppgen.DataType + :param element_C: data type to be used for operand C + :type element_C: cutlass_cppgen.DataType + :param element_D: data type to be used for operand D + :type element_D: cutlass_cppgen.DataType + :type layout_A: layout of operand A + :param layout_A: cutlass_cppgen.LayoutType + :type layout_B: layout of operand B + :param layout_B: cutlass_cppgen.LayoutType + :type layout_C: layout of operand C + :param layout_C: cutlass_cppgen.LayoutType + :type layout_D: layout of operand D + :param layout_D: cutlass_cppgen.LayoutType + """ + + def __init__( + self, A=None, B=None, C=None, D=None, + alpha=1.0, beta=0.0, element_accumulator=None, + element=None, layout=None, + element_A=None, element_B=None, element_C=None, element_D=None, + layout_A=None, layout_B=None, layout_C=None, + cc: int = None, + ): + super().__init__( + A=A, B=B, C=C, D=D, + alpha=alpha, beta=beta, + element_accumulator=element_accumulator, + element=element, layout=layout, + element_A=element_A, element_B=element_B, + element_C=element_C, element_D=element_D, + layout_A=layout_A, layout_B=layout_B, layout_C=layout_C, + cc=cc + ) + + # Grouped GEMM specializations for SM90 are currently unavailable. Revert to using SM80 + if self.current_cc in [90, 100, 101, 103]: + self._reset_options(80) + self._reset_operations(reset_epilogue=False) + + self.name = "grouped_gemm" + + @Gemm.swizzling_functor.setter + def swizzling_functor(self, swizzling_functor): + """ + Sets the swizzling functor to the type specified by `swizzling_functor` + """ + raise Exception('Grouped GEMM does not currently support different swizzling functors') + + def construct(self, tile_description: TileDescription = None, + alignment_A: int = None, + alignment_B: int = None, + alignment_C: int = None) -> GemmOperationGrouped: + """ + Constructs a ``cutlass_cppgen.backend.GemmOperationGrouped`` based on the input parameters and current + kernel specification of the ``Gemm`` object. + + :param tile_description: tile description specifying shapes and operand types to use in the kernel + :type tile_description: cutlass_cppgen.backend.TileDescription + :param alignment_A: alignment of operand A + :type alignment_A: int + :param alignment_B: alignment of operand B + :type alignment_B: int + :param alignment_C: alignment of operand C + :type alignment_C: int + + :return: operation that was constructed + :rtype: cutlass_cppgen.backend.GemmOperationGrouped + """ + alignment_A = check.alignment_or_default(alignment_A, max(self.possible_operations.alignments("A"))) + alignment_B = check.alignment_or_default(alignment_B, max(self.possible_operations.alignments("B"))) + alignment_C = check.alignment_or_default(alignment_C, max(self.possible_operations.alignments("C"))) + + self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor) + + tensor_A = TensorDescription(self._element_a, self._layout_b, alignment_A) + tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B) + tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C) + + if tile_description is None: + op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0] + tile_description = datatypes.td_from_profiler_op(op) + else: + valid, err_str = self._valid_tile_description(tile_description) + if not valid: + raise Exception(f"Invalid tile description. {err_str}") + self.tile_description = tile_description + + operation = GemmOperationGrouped( + arch=self.current_cc, + tile_description=tile_description, + A=tensor_A, B=tensor_B, C=tensor_C, + epilogue_functor=self.epilogue_functor, + swizzling_functor=self._swizzling_functor, + precompute_mode=SchedulerMode.Device) + + return operation + + def run(self, A, B, C, D, + alpha=None, beta=None, sync: bool = True, + print_module: bool = False, + stream: Optional[cuda.CUstream] = None) -> GemmGroupedArguments: + """ + Runs the kernel currently specified. + + By default, this call returns only once the kernel has completed. To launch the kernel + and immediately return, set ``sync=False``. In this case, it is the responsibility of the + caller to syncrhonize the results of the kernel before attempting to access outputs + by calling ``sync()`` on the arguments returned from this call. + + :param A: list of tensors representing data type and layout of operand A + :type A: list + :param B: list of tensors representing data type and layout of operand B + :type B: list + :param C: list of tensors representing data type and layout of operand C + :type C: list + :param D: list of tensors representing data type and layout of operand D + :type D: list + :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B + :param beta: scalar parameter beta from GEMM operation that scales operand C + :param sync: whether the call should wait for the kernel to complete before returning + :type sync: bool + :param print_module: whether to print the emitted C++ code + :type print_module: bool + :param stream: cuda stream, defaults to cuda.cuda.CUstream(0) + :type stream: :class:`cuda.cuda.CUstream` + + :return: arguments passed in to the kernel + :rtype: cutlass_cppgen.backend.GemmGroupedArguments + """ + if not stream: + stream = cuda.CUstream(0) + + super().run_setup() + + if len(A) != len(B) or len(A) != len(C) or len(A) != len(D): + raise Exception("Lengths of A, B, C, and D lists must be equal") + + problem_sizes = [] + As, Bs, Cs, Ds = ([None] * len(A) for _ in range(4)) + for i in range(len(A)): + As[i] = self._verify_tensor(A[i], self.A, self._element_a, self._layout_a, "A") + Bs[i] = self._verify_tensor(B[i], self.B, self._element_b, self._layout_b, "B") + Cs[i] = self._verify_tensor(C[i], self.C, self._element_c, self._layout_c, "C") + Ds[i] = self._verify_tensor(D[i], self.D, self._element_d, self._layout_d, "D") + problem_sizes.append(GemmCoord(A[i].shape[0], B[i].shape[1], A[i].shape[1])) + + alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha") + beta = self._verify_scalar(beta, self.beta, self._element_c, "beta") + + alignment_a = min((self.possible_operations.find_alignment(A.shape, self._layout_a, operand="A") for A in As)) + alignment_b = min((self.possible_operations.find_alignment(B.shape, self._layout_b, operand="B") for B in Bs)) + alignment_c = min((self.possible_operations.find_alignment(C.shape, self._layout_c, operand="C") for C in Cs)) + self.compile(self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b, + alignment_C=alignment_c, print_module=print_module) + + arguments = GemmGroupedArguments( + operation=self.operation, + problem_sizes=problem_sizes, + A=As, B=Bs, C=Cs, D=Ds, + output_op=self.operation.epilogue_type(alpha, beta), + stream=stream + ) + + self.operation.run(arguments) + + if sync: + arguments.sync() + + return arguments diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/op.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/op.py new file mode 100644 index 0000000000000000000000000000000000000000..bebf07a7e5b83a1cf14cfecf19e90f730e305dce --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/op.py @@ -0,0 +1,431 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv2d) +""" + +from bisect import bisect_left + +from cutlass_library import ( + DataType, + DataTypeSize, + MathOperation, + OperationKind, + SharedMemPerCC +) + +import cutlass_cppgen +from cutlass_cppgen import get_option_registry +from cutlass_cppgen.backend.evt import EpilogueFunctorVisitor +from cutlass_cppgen.backend.evt.passes.util import cc_map +from cutlass_cppgen.backend.utils.device import device_cc +from cutlass_cppgen.epilogue import get_activations, get_activation_epilogue, identity +from cutlass_cppgen.library_defaults import KernelsForDataType, _generator_ccs +from cutlass_cppgen.swizzle import get_swizzling_functors +from cutlass_cppgen.utils import datatypes, check + + +class OperationBase: + """ + Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv2d) + """ + + def __init__(self, cc: int = None, kernel_cc: int = None, operation_kind = OperationKind.Gemm): + """ + :param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90 + :type cc: int + :param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80 + :type kernel_cc: int + :param operation_kind: class of operation that will be performed (e.g., GEMM, Conv) + :type operation_kind: cutlass_library.OperationKind + """ + self.operation_kind = operation_kind + self.cc = cc if cc is not None else device_cc() + self.specified_kernel_cc = kernel_cc is not None + self.current_cc = kernel_cc if kernel_cc is not None else self._find_closest_cc(self.cc) + self.tile_description = None + self._math_operation = None + + self.options = get_option_registry().options_for_cc(self.current_cc, operation_kind) + + if self.options is None: + raise Exception(f"Invalid or unsupported compute capability: {self.current_cc}") + + # Default activation function: identity + self._activation = identity + + def _find_closest_cc(self, cc: int) -> int: + """ + Returns the closest CC in _generator_ccs less than or equal to `cc` + + :param cc: compute capability to query + :type cc: int + + :returns: closest CC in _generator_ccs less than or equal to `cc` + :rtype: int + """ + if cc in _generator_ccs: + return cc + + # Find closest CC lower than this CC + idx = bisect_left(_generator_ccs, cc) + if idx == 0: + raise Exception(f'No valid CC to fall back to for {cc}') + return _generator_ccs[idx-1] + + def activations(self) -> list: + """ + Returns possible activation functions that can be used + + :return: list of activation functions that can be used + :rtype: list + """ + return get_activations() + + def swizzling_functors(self) -> list: + """ + Returns possible swizzling functions that can be used + + :return: list of swizzling functions that can be used + :rtype: list + """ + return get_swizzling_functors() + + def _reset_options(self, cc: int): + """ + Resets the kernel options based on cc + + :param cc: compute capability to reset to + :type cc: int + """ + if cc != self.current_cc: + if cc not in _generator_ccs: + raise Exception(f'Invalid CC for CUTLASS kernels: {cc}.') + self.current_cc = cc + self.options = get_option_registry().options_for_cc(self.current_cc, self.operation_kind) + + def _verify_scalar(self, scalar, ref_scalar, ref_dtype, name): + """ + Verifies the following properties: + 1) Either ``scalar`` or ``ref_scakar`` must be set (i.e., not ``None``) + 2) If ``scalar`` is not ``None``, its datatype must match matches the current version + set by the plan (i.e., those in ``ref_dtype``) + + If either of these properties does not hold, an exception is raised. If these properties hold and + ``scalar`` is not ``None``, ``scalar`` is returned. Otherwise, ``ref_scalar`` is returned. + + :param scalar: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in + :type scalar: numpy/cupy/torch scalar + :param ref_scalar: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in + :type ref_scalar: numpy/cupy/torch scalar + :param ref_dtype: data type for the scalar that this object was initialized to + :param name: identifier of the scalar to verify. Used in raising exceptions + :type name: str + + :return: valid scalar to use + :rtype: numpy/cupy/torch scalar + """ + if scalar is None: + if ref_scalar is None: + raise Exception(f"Scalar {name} must be set.") + return ref_scalar + if hasattr(scalar, "dtype"): + dtype = datatypes.library_type(scalar.dtype) + if dtype != ref_dtype: + raise Exception( + f"Tensor {name} with type {dtype} does not match expected type {ref_dtype}." + ) + return scalar + + def _verify_tensor(self, tensor, ref_tensor, ref_dtype, ref_layout, name): + """ + Verifies the following properties: + If ref_dtype is not void: + 1) Either ``tensor`` or ``ref_tensor`` must be set (i.e., not ``None``) + 2) If ``tensor`` is not ``None``, its datatype and layout must match matches the current versions + set by the plan (i.e., those in ``ref_dtype`` and ``ref_layout``) + If ref_dtype is void: + Neither ``tensor`` nor ``ref_tensor`` are set + + If either of these properties does not hold, an exception is raised. If these properties hold and + ``tensor`` is not ``None``, ``tensor`` is returned. Otherwise, ``ref_tensor`` is returned. + + :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in + :type tensor: numpy/cupy/torch array/tensor object + :param ref_tensor: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in + :type ref_tensor: numpy/cupy/torch array/tensor object + :param ref_dtype: data type for the tensor that this object was initialized to + :param ref_layout: layout for the tensor that this object was initialized to + :param name: identifier of the tensor to verify. Used in raising exceptions + :type name: str + + :return: valid tensor object to use + :rtype: numpy/cupy/torch array/tensor object + """ + if ref_dtype == DataType.void: + if tensor is not None or ref_tensor is not None: + raise Exception("Operands with element DataType.void must not be provided a tensor") + return None + + if tensor is None: + if ref_tensor is None: + raise Exception(f"Tensor {name} must be set.") + return ref_tensor + + self._verify_type_and_layout(tensor, ref_dtype, ref_layout, name) + return tensor + + @property + def opclass(self) -> cutlass_cppgen.OpcodeClass: + """ + Returns the opcode class currently in use + + :return: opcode class currently in use + :rtype: cutlass_cppgen.OpcodeClass + """ + return self.op_class + + @opclass.setter + def opclass(self, oc: cutlass_cppgen.OpcodeClass): + if isinstance(oc, str): + oc = datatypes.getattr_enum(cutlass_cppgen.OpcodeClass, oc) + if oc in self.possible_op_classes: + self.op_class = oc + else: + raise Exception( + f'Unsupported operation class {oc} for CC {self.cc} and data type combination ' + f'({self._element_a}, {self._element_b}, {self._element_accumulator}) and ' + f'layout combination ({self._layout_a}, {self._layout_b}).') + + # Changing the op class also changes the possible operations available. Reset these. + self.possible_operations = self.options.operations( + self.op_class, self._element_a, self._element_b, + self._element_accumulator, self._layout_a, self._layout_b, self._math_operation) + + # Changing the op class changes the elements per access in the epilogue. Reset this. + if self.epilogue_functor is not None: + self.epilogue_functor = self._reset_epilogue_functor_alignment(self._elements_per_access(), self.epilogue_functor) + + @property + def math_operation(self) -> cutlass_cppgen.MathOperation: + """ + Returns the math operation currently in use + + :return: math operation currently in use + :rtype: cutlass_cppgen.MathOperation + """ + return self._math_operation + + @math_operation.setter + def math_operation(self, mo: cutlass_cppgen.MathOperation): + if isinstance(mo, str): + mo = datatypes.getattr_enum(cutlass_cppgen.MathOperation, mo) + + if not self.specified_kernel_cc: + if self.current_cc in [90, 100, 101, 103]: + # CUTLASS 3.0 kernels do not use different math operations. If one is specified, we + # revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels. + cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.") + self._reset_options(80) + self._reset_operations(reset_epilogue=False) + elif self.current_cc in [90, 100, 101, 103]: + raise Exception("CUTLASS 3.0 kernels do not use different math operations. " + "To use 2.x kernels with a specific math operation, do not set the `kernel_cc`" + "parameter when constructing the plan.") + + self._math_operation = mo + self._reset_operations() + + def _elements_per_access(self): + if self.op_class == cutlass_cppgen.OpcodeClass.Simt: + return 1 + elif self._element_c != DataType.void: + return 128 // DataTypeSize[self._element_c] + else: + return 128 // max(self.possible_operations.alignments("C")) + + def _create_epilogue_functor_activation(self, activation): + """ + Returns the epilogue functor with given activation function + """ + if self.epilogue_functor is None: + elements_per_access = self._elements_per_access() + else: + elements_per_access = self.epilogue_functor.epilogue_vector_length + + if not self.specified_kernel_cc: + if self.current_cc in [90, 100, 101, 103] and activation != identity: + # CUTLASS 3.0 kernels in Python currently only support identity activation. If one requests a non-identity activation, + # revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels. + cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.") + if self._element_c != self._element_d: + raise Exception("CUTLASS 2.x kernels require element C to be the same as element D") + self._reset_options(80) + self._reset_operations(reset_epilogue=False) + elif (self.cc in [90, 100, 101, 103] and self.current_cc not in [90, 100, 101, 103] and activation == identity and self._math_operation is None): + # SM80 fallback kernels are currently used. Since an identity activation is requested, + # we can switch back to using SM90 kernels. + self._reset_options(self.cc) + self._reset_operations(reset_epilogue=False) + else: + if self.current_cc in [90, 100, 101, 103] and activation != identity: + raise Exception("Epilogues with elementwise fusion are not currently supported " + "in the Python interface for 3.x kernels. To use 2.x kernels " + "with fused elementwise epilogues, do not set the `kernel_cc` " + "parameter when constructing the plan.") + + return get_activation_epilogue( + activation, + self._element_d, + elements_per_access, + self._element_accumulator, + self._element_accumulator, + ) + + def _reset_epilogue_functor_activation(self, activation): + """ + Set the epilogue functor based on the provided activation function + """ + self.epilogue_functor = self._create_epilogue_functor_activation(activation) + + def _reset_epilogue_functor_alignment(self, alignment, epilogue_functor): + """ + Reset the alignment of the current epilogue functor based on alignment C + """ + if isinstance(epilogue_functor, EpilogueFunctorVisitor): + return epilogue_functor + + if epilogue_functor is None or not hasattr(epilogue_functor, 'activation_functor'): + # Identity epilogue does not have 'activation_functor' + activation = identity + else: + activation = epilogue_functor.activation_functor + + epilogue_functor = get_activation_epilogue( + activation, + self._element_d, + alignment, + self._element_accumulator, + self._element_accumulator, + ) + return epilogue_functor + + @property + def activation(self): + """ + Returns the type of the current activation function used + """ + if hasattr(self.epilogue_functor, "activation_functor"): + return self.epilogue_functor.activation_functor + else: + return identity + + @activation.setter + def activation(self, act): + """ + Sets the type of the activation function to use + Activation can come with a set of arguments + + :param act: type of activation function to use + :type act: str or tuple. e.g. "relu", ("leaky_relu", 0.01) + + """ + if isinstance(act, tuple): + if isinstance(act[0], str): + act_fn = getattr(cutlass_cppgen.backend.epilogue, act[0]) + else: + act_fn = act[0] + self._reset_epilogue_functor_activation(act_fn) + self._activation_args = act[1] + self._activation = act[0] + else: + if isinstance(act, str): + act = getattr(cutlass_cppgen.backend.epilogue, act) + self._reset_epilogue_functor_activation(act) + self._activation = act + + @property + def epilogue_visitor(self): + """ + Return the epilogue functor + """ + return self.epilogue_functor + + @epilogue_visitor.setter + def epilogue_visitor(self, visitor): + """ + Create the epilogue visitor + """ + self.epilogue_functor = EpilogueFunctorVisitor(cc_map[self.cc], visitor) + + # The epilogue_functor may consume too much shared memory + # Reset the possible operations + if self.cc not in [90, 100, 101, 103]: + # The shared memory is only a concern for sm90+ epilogue + # In sm80, the epilogue and mainloop share the shared memory + return + + datatype_comb = self.possible_operations.datatype_comb + layout_comb = self.possible_operations.layout_comb + new_possible_operations = KernelsForDataType(datatype_comb, layout_comb) + for operation in self.possible_operations.all_operations: + td = datatypes.td_from_profiler_op(operation) + # Filter invalid epilogue schedules + if cc_map[self.cc] == 90 and td.epilogue_schedule not in [ + cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized, + cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative]: + continue + epilogue_smem_bytes = self.epilogue_functor.get_smem_size(td) + + # Verify the maximum number of mainloop stages + mainloop_smem_per_stage = check.calculate_smem_usage_per_stage(td, OperationKind.Gemm) + smem_capacity_bytes = SharedMemPerCC[self.cc] << 10 + mainloop_stages = (smem_capacity_bytes - epilogue_smem_bytes) // mainloop_smem_per_stage + if mainloop_stages < 2: + # Mainloop stages must >= 2 + continue + + new_possible_operations.add(operation) + if len(new_possible_operations.all_operations) == 0: + raise RuntimeError( + "The epilogue consumes too much shared memory. " + "No valid tile description is found in the generator.") + self.possible_operations = new_possible_operations + + + def run_setup(self): + """ + Steps that must be taken before caling `plan.run()` + """ + # Initialize the memory pool if, if not already done + cutlass_cppgen.get_memory_pool() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/shape.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/shape.py new file mode 100644 index 0000000000000000000000000000000000000000..a718f9bb4432f1f51457661abe27e24ea818aba4 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/shape.py @@ -0,0 +1,184 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for expressing shapes +""" + +from cutlass_library import ( + ConvMode, + ConvKind, + LayoutType +) +from cutlass_cppgen.backend.c_types import ( + Conv2DProblemSize_, + GemmCoord_, + GemmCoordBatched_ +) + + +class MatrixCoord: + def __init__(self, row, col): + self._row = row + self._col = col + + @property + def row(self): + return self._row + + @property + def column(self): + return self._col + + def leading_dimension(self, layout: LayoutType) -> int: + """ + Returns the leading dimension for a matrix with layout ``layout`` and shape provided by the MatrixCoord. + + :param layout: layout of matrix + :type layout: cutlass_library.LayoutType + + :returns: leading dimension + :rtype: int + """ + if layout == LayoutType.RowMajor: + return self._col + elif layout == LayoutType.ColumnMajor: + return self._row + else: + raise Exception(f'Unsupported layout for leading dimension calculation: {layout}') + + +class GemmCoord: + def __init__(self, m: int, n: int, k: int): + self._m = m + self._n = n + self._k = k + + @property + def m(self) -> int: + return self._m + + @property + def n(self) -> int: + return self._n + + @property + def k(self) -> int: + return self._k + + @property + def mk(self) -> MatrixCoord: + return MatrixCoord(self._m, self._k) + + @property + def mn(self) -> MatrixCoord: + return MatrixCoord(self._m, self._n) + + @property + def kn(self) -> MatrixCoord: + return MatrixCoord(self._k, self._n) + + @property + def ctype(self) -> GemmCoord_: + return GemmCoord_(self._m, self._n, self._k) + + def batched_ctype(self, batch_count: int) -> GemmCoordBatched_: + return GemmCoordBatched_(self._m, self._n, self._k, batch_count) + + +class Conv2DProblemSize: + def __init__( + self, n: int, h: int, w: int, c: int, + k: int, r: int, s: int, c_: int, + pad_h: int, pad_w: int, stride_h: int, stride_w: int, + dilation_h: int, dilation_w: int, mode: ConvMode=ConvMode.CrossCorrelation, + split_k_slices: int=1, groups: int=1): + + self.N = n + self.H = h + self.W = w + self.C = c + self.K = k + self.R = r + self.S = s + self.pad_h = pad_h + self.pad_w = pad_w + self.stride_h = stride_h + self.stride_w = stride_w + self.dilation_h = dilation_h + self.dilation_w = dilation_w + self.mode = int(mode) + self.split_k_slices = split_k_slices + self.groups = groups + self.P = ((h + pad_h * 2 - r * dilation_h) // stride_h) + 1 + self.Q = ((w + pad_w * 2 - s * dilation_w) // stride_w) + 1 + + @property + def ctype(self) -> Conv2DProblemSize_: + return Conv2DProblemSize_(self) + + def implicit_gemm_size(self, kind: ConvKind): + if kind == ConvKind.Fprop: + return GemmCoord( + self.N * self.P * self.Q, + self.K, + self.R * self.S * self.C // self.groups + ) + elif kind == ConvKind.Dgrad: + return GemmCoord( + self.N * self.H * self.W, + self.C, + self.R * self.S * self.K + ) + elif kind == ConvKind.Wgrad: + return GemmCoord( + self.K, + self.R * self.S * self.C, + self.N * self.P * self.Q + ) + + @staticmethod + def from_sizes(input_size, weight_size): + K, R, S, _ = weight_size + pad_h = R // 2 + pad_w = S // 2 + stride_h = 1 + stride_w = 1 + dilation_h = 1 + dilation_w = 1 + return Conv2DProblemSize( + *input_size, + *weight_size, + pad_h, pad_w, + stride_h, stride_w, + dilation_h, dilation_w + ) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/swizzle.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/swizzle.py new file mode 100644 index 0000000000000000000000000000000000000000..ffd9483415ea36716bf4643d27b8d92f3e9878a5 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/swizzle.py @@ -0,0 +1,65 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Registry of swizzling functions +""" + +from cutlass_library import SwizzlingFunctor + + +IdentitySwizzle1 = SwizzlingFunctor.Identity1 +IdentitySwizzle2 = SwizzlingFunctor.Identity2 +IdentitySwizzle4 = SwizzlingFunctor.Identity4 +IdentitySwizzle8 = SwizzlingFunctor.Identity8 +HorizontalSwizzle = SwizzlingFunctor.Horizontal +ThreadblockSwizzleStreamK = SwizzlingFunctor.StreamK +StridedDgradIdentitySwizzle1 = SwizzlingFunctor.StridedDgradIdentity1 +StridedDgradIdentitySwizzle4 = SwizzlingFunctor.StridedDgradIdentity4 +StridedDgradHorizontalSwizzle = SwizzlingFunctor.StridedDgradHorizontal + + +_swizzling_functors = [ + IdentitySwizzle1, + IdentitySwizzle2, + IdentitySwizzle4, + IdentitySwizzle8, + HorizontalSwizzle, + ThreadblockSwizzleStreamK, + StridedDgradIdentitySwizzle1, + StridedDgradIdentitySwizzle4, + StridedDgradHorizontalSwizzle, +] + + +def get_swizzling_functors(): + return _swizzling_functors diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..75d8416a15070ddcf2c6270248ccd9deff8e2137 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/__init__.py @@ -0,0 +1,41 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from cutlass_cppgen.utils.check import ( + alignment_or_default, + calculate_smem_usage, + calculate_smem_usage_per_stage, + valid_cluster_shape, + valid_schedule, + valid_stage_count, + update_alignment, +) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/check.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/check.py new file mode 100644 index 0000000000000000000000000000000000000000..108f268b4bc54ec0839afb5c1602ba63e5b98743 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/check.py @@ -0,0 +1,262 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utility functions for checking constraints on kernels and calculating kernel attributes +""" + +import ctypes + +from cutlass_library import DataTypeSize, KernelScheduleSuffixes, OperationKind, SharedMemPerCC + +import cutlass_cppgen +from cutlass_cppgen.backend.library import TileDescription + + +def calculate_smem_usage_per_stage(td: TileDescription, operation_kind: OperationKind) -> int: + """ + Returns the amount of shared memory in bytes consumed in a single stage of a kernel. + + :param td: tile description to compute shared memory of + :type td: TileDescription + :param operation_kind: identifier for the type of operation being performed + :type operation_kind: cutlass_library.OperationKind + + :return: number of bytes of shared memory consumed by a single stage + :rtype: int + """ + m, n, k = td.blackwell_threadblock_shape + if td.is_2sm: + m //= 2 + + if operation_kind == OperationKind.Gemm: + stage_barrier_bytes = 32 + return ( + (DataTypeSize[td.math_instruction.element_a] * m * k // 8) + + (DataTypeSize[td.math_instruction.element_b] * k * n // 8) + + stage_barrier_bytes + ) + else: + raise Exception(f"No available shared memory calculation for operation kind {operation.operation_kind}") + + +def calculate_smem_usage(operation) -> int: + """ + Returns the amount of shared memory in bytes consumed by a kernel. + + :return: number of bytes of shared memory consumed by the operation + :return: int + """ + _per_stage = calculate_smem_usage_per_stage(operation.tile_description, operation.operation_kind) + return _per_stage * operation.tile_description.stages + + +def valid_stage_count( + cc: int, + kernel_cc: int, + td: TileDescription, + element_C: cutlass_cppgen.DataType = None, + element_D: cutlass_cppgen.DataType = None, + verbose: bool = True) -> tuple: + """ + Checks whether a device with `cc` supports the number of stages within `tile_description`, both + based on raw limits on the number of stages and based on shared memory capacity + + :param cc: compute capability of device in question + :type cc: int + :param kernel_cc: compute capability that the kernel targets (corresponding to the arch::SMxy tag in CUTLASS) + :type kernel_cc: int + :param td: tile description to check + :type td: TileDescription + :param element_C: data type of operand C + :type element_C: cutlass_cppgen.DataType + :param element_D: data type of operand D + :type element_D: cutlass_cppgen.DataType + :param verbose: whether to log warnings + :type verbose: bool + + :return: tuple with the first element indicating whether the provided tile description is + valid for the provided device and the second element being an error message + :rtype: tuple + """ + if kernel_cc in [90, 100, 101, 103]: + if (td.stages is None or td.stages == 0): + # Stage count of None or 0 for SM90 indicates that the CollectiveBuilder automatically + # determines the stage count to use. Thus, all settings are valid in these scenarios. + return (True, "") + elif verbose: + cutlass_cppgen.logger.warning( + "Setting an explicit stage count for SM90 kernels currently may " + "result in compilation errors if the combination of tile shape, " + "stage count, and shared memory requirement of the epilogue exceeds " + "the available shared memory per SM.") + + if td.stages <= 0: + return (False, f"Stage counts must be positive integers. Tile description has stage count of {td.stages}.") + + if cc < 80 and td.stages != 2: + return (False, f"Tile description has stage count of {td.stages}, " + f"but only 2 stages are supported on SM{cc}.") + + # The calculation below does not consider shared memory used by the epilogue and, thus, + # only catches cases in which the mainloop exceeds the device's shared memory capacity. + # This is not a concern for CUTLASS 2.x kernels, for which the shared memory of the + # mainloop and epilogue is shared. + smem_per_stage = calculate_smem_usage_per_stage(td, OperationKind.Gemm) + smem_usage_mainloop = (smem_per_stage * td.stages) + smem_arch = SharedMemPerCC[cc] << 10 + if smem_usage_mainloop > smem_arch: + return ( False, + "Configuration uses too much shared memory. Consider reducing stage count or tile shape.\n" + f"Details:\n" + f"Mainloop uses {smem_per_stage} bytes of shared memory per stage, and " + f"{td.stages} stages for a total of {smem_usage_mainloop} bytes.\n" + f"The maxmium amount of shared memory that can be used per block on CC {cc} is {smem_arch}.") + + return (True, "") + + +def valid_cluster_shape(cc: int, cluster_shape: list) -> tuple: + """ + Checks whether a device with `cc` supports a thread block cluster of shape `cluster_shape`. + + :param cc: compute capability of device in question + :type cc: int + :param cluster_shape: dimensions of thread block cluster shape to check + :type cluster_shape: list + + :return: tuple with the first element indicating whether the provided cluster shape is + valid for the provided device and the second element being an error message + :rtype: tuple + """ + + if cc < 90 or cc in [120, 121]: + if cluster_shape != [1, 1, 1]: + return (False, + f"Cluster shape for pre-SM90 architectures and SM 120 and 121 must be [1, 1, 1]. Received cluster shape of " + f"{cluster_shape} for SM{cc}.") + else: + return (True, "") + + if len(cluster_shape) != 3: + return (False, + f"Cluster shapes must be rank-3. Received {cluster_shape} (rank {len(cluster_shape)}") + + if cluster_shape[2] != 1: + return (False, + "CUTLASS kernels currently require the third dimension of cluster shape to be 1. " + f"Received cluster shape of {cluster_shape}.") + + return (True, "") + + +def valid_schedule( + cc: int, + kernel_schedule: cutlass_cppgen.KernelScheduleType, + epilogue_schedule: cutlass_cppgen.EpilogueScheduleType, + tile_scheduler: cutlass_cppgen.TileSchedulerType) -> tuple: + """ + Checks that the kernel and epilogue schedules passed in are a valid combination for + a device of compute capability ``cc``. + + :param cc: compute capability of device in question + :type cc: int + :param kernel_schedule: kernel schedule type + :type kernel_schedule: cutlass_cppgen.KernelScheduleType + :param epilogue_schedule: epilogue schedule type + :type epilogue_schedule: cutlass_cppgen.EpilogueScheduleType + :param tile_scheduler: tile scheduler type + :type tile_scheduler: cutlass_cppgen.TileSchedulerType + + :return: tuple with the first element indicating whether the provided schedules are + valid for the provided device and the second element being an error message + :rtype: tuple + """ + kernel_auto = (kernel_schedule == cutlass_cppgen.KernelScheduleType.ScheduleAuto) + epilogue_auto = (epilogue_schedule == cutlass_cppgen.EpilogueScheduleType.ScheduleAuto) + tile_scheduler_default = (tile_scheduler == cutlass_cppgen.TileSchedulerType.Default) + if (cc < 90 or cc in [120, 121]) and not (kernel_auto and epilogue_auto and tile_scheduler_default): + return (False, "Non-default schedules are only supported on SM90 and beyond (excluding SM120 and SM121)") + + if cc == 90 and ((kernel_auto and not epilogue_auto) or (not kernel_auto and epilogue_auto)): + return (False, "Kernel and epilogue schedules must either both be auto or neither be auto") + + if not tile_scheduler_default: + cooperative_kernels = [cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative, + cutlass_cppgen.KernelScheduleType.CpAsyncWarpSpecializedCooperative] + if cc == 90 and (tile_scheduler == cutlass_cppgen.TileSchedulerType.StreamK) and (kernel_schedule not in cooperative_kernels): + return (False, "Stream-K tile scheduler is currently only supported with the cooperative kernel schedule") + return (True, "") + + +def alignment_or_default(alignment_provided: int, default_alignment: int) -> int: + """ + Returns `alignment_provided` if it is set, otherwise `default_alignment` and checks + that `alignment_provided` does not exceed `default_alignment`. + + :param alignment_provided: alignment preference specified. Can be None. + :type alignment_provided: int + :param default_alignment: alignment to use if `alignment_provided` is None + :type default_alignment: int + + :return: alignment to use + :rtype: int + """ + if alignment_provided is not None: + if alignment_provided > default_alignment: + raise Exception(f"Alignment {alignment_provided} exceeds the maximum supported of {default_alignment}.") + return alignment_provided + + return default_alignment + + +def update_alignment(alignment_provided:int, default_alignment: int) -> int: + """ + Returns `alignment_provided` if it is set, otherwise `default_alignment` and checks + that `alignment_provided` does not exceed `default_alignment`. + + :param alignment_provided: alignment preference specified. Can be None. + :type alignment_provided: int + :param default_alignment: alignment to use if `alignment_provided` is None + :type default_alignment: int + + :return: alignment to use + :rtype: int + """ + if alignment_provided is not None: + if alignment_provided > default_alignment: + if alignment_provided % default_alignment == 0: + return default_alignment + raise Exception(f"Alignment {alignment_provided} exceeds the maximum supported of {default_alignment}.") + return alignment_provided + + return default_alignment diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/datatypes.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/datatypes.py new file mode 100644 index 0000000000000000000000000000000000000000..c03a834dc47871bebe618752e4775a0a7434ff78 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/datatypes.py @@ -0,0 +1,362 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utility functions for converting between frontend datatypes and CUTLASS datatypes +""" + +import cutlass_cppgen +from cutlass_library import ( + DataTypeSize, + MathOperation, + MathInstruction +) +from cutlass_cppgen.backend.library import ( + TileDescription, +) + +bfloat16_available = None +cupy_available = None +numpy_available = None +torch_available = None +_library_to_cupy_dict = None +_library_to_numpy_dict = None +_library_to_torch_dict = None +_torch_to_library_dict = None + + +def is_numpy_available(): + global numpy_available, _library_to_numpy_dict + if numpy_available is None: + try: + import numpy as np + + numpy_available = True + _library_to_numpy_dict = { + cutlass_cppgen.DataType.f16: np.float16, + cutlass_cppgen.DataType.f32: np.float32, + cutlass_cppgen.DataType.f64: np.float64, + cutlass_cppgen.DataType.s8: np.int8, + cutlass_cppgen.DataType.s32: np.int32, + } + except ImportError: + numpy_available = False + _library_to_numpy_dict = {} + return numpy_available + + +def is_numpy_tensor(inp) -> bool: + if is_numpy_available(): + import numpy as np + return isinstance(inp, np.ndarray) + return False + + +def numpy_library_type(inp) -> cutlass_cppgen.DataType: + if is_numpy_available(): + import numpy as np + if inp == np.float16: + return cutlass_cppgen.DataType.f16 + elif inp == np.float32: + return cutlass_cppgen.DataType.f32 + elif inp == np.float64: + return cutlass_cppgen.DataType.f64 + elif inp == np.int8: + return cutlass_cppgen.DataType.s8 + elif inp == np.int32: + return cutlass_cppgen.DataType.s32 + return None + + +def numpy_type(inp): + return _library_to_numpy_dict.get(inp, None) + + +def is_cupy_available(): + global cupy_available + if cupy_available is None: + try: + import cupy as cp + + cupy_available = True + _library_to_cupy_dict = { + cutlass_cppgen.DataType.f16: cp.float16, + cutlass_cppgen.DataType.f32: cp.float32, + cutlass_cppgen.DataType.f64: cp.float64, + cutlass_cppgen.DataType.s8: cp.int8, + cutlass_cppgen.DataType.s32: cp.int32, + } + except ImportError: + cupy_available = False + _library_to_cupy_dict = {} + return cupy_available + + +def is_cupy_tensor(inp) -> bool: + if is_cupy_available(): + import cupy as cp + return isinstance(inp, cp.ndarray) + return False + + +def cupy_library_type(inp) -> cutlass_cppgen.DataType: + if is_cupy_available(): + import cupy as cp + if inp == cp.float16: + return cutlass_cppgen.DataType.f16 + elif inp == cp.float32: + return cutlass_cppgen.DataType.f32 + elif inp == cp.float64: + return cutlass_cppgen.DataType.f64 + return None + + +def cupy_type(inp): + return _library_to_cupy_dict.get(inp, None) + + +def is_torch_available(): + global torch_available, _library_to_torch_dict, _torch_to_library_dict + if torch_available is None: + try: + import torch + + torch_available = True + _torch_to_library_dict = { + torch.half: cutlass_cppgen.DataType.f16, + torch.float16: cutlass_cppgen.DataType.f16, + torch.bfloat16: cutlass_cppgen.DataType.bf16, + torch.float: cutlass_cppgen.DataType.f32, + torch.float32: cutlass_cppgen.DataType.f32, + torch.double: cutlass_cppgen.DataType.f64, + torch.float64: cutlass_cppgen.DataType.f64, + torch.int8: cutlass_cppgen.DataType.s8, + torch.int32: cutlass_cppgen.DataType.s32, + torch.uint8: cutlass_cppgen.DataType.u8, + } + + _library_to_torch_dict = { + cutlass_cppgen.DataType.f16: torch.half, + cutlass_cppgen.DataType.f16: torch.float16, + cutlass_cppgen.DataType.bf16: torch.bfloat16, + cutlass_cppgen.DataType.f32: torch.float, + cutlass_cppgen.DataType.f32: torch.float32, + cutlass_cppgen.DataType.f64: torch.double, + cutlass_cppgen.DataType.f64: torch.float64, + cutlass_cppgen.DataType.s8: torch.int8, + cutlass_cppgen.DataType.s32: torch.int32, + cutlass_cppgen.DataType.u8: torch.uint8, + } + + def possibly_add_type(torch_type_name, cutlass_type): + # Only try adding the type if the version of torch being used supports it + if hasattr(torch, torch_type_name): + torch_type = getattr(torch, torch_type_name) + _torch_to_library_dict[torch_type] = cutlass_type + _library_to_torch_dict[cutlass_type] = torch_type + + possibly_add_type("float8_e4m3fn", cutlass_cppgen.DataType.e4m3) + possibly_add_type("float8_e5m2", cutlass_cppgen.DataType.e5m2) + + except ImportError: + torch_available = False + _torch_to_library_dict = {} + _library_to_torch_dict = {} + return torch_available + + +def is_torch_tensor(inp) -> bool: + if is_torch_available(): + import torch + return isinstance(inp, torch.Tensor) + return False + + +def torch_library_type(inp) -> cutlass_cppgen.DataType: + return _torch_to_library_dict.get(inp, None) + + +def torch_type(inp): + return _library_to_torch_dict.get(inp, None) + + +def is_bfloat16_available(): + global bfloat16_available + + if bfloat16_available is None: + try: + import bfloat16 + + bfloat16_available = True + except ImportError: + bfloat16_available = False + return bfloat16_available + + +def bfloat16_library_type(inp) -> cutlass_cppgen.DataType: + if is_bfloat16_available(): + import bfloat16 + if inp == bfloat16.bfloat16: + return cutlass_cppgen.DataType.bf16 + + +def bfloat16_type(inp): + if is_bfloat16_available(): + import bfloat16 + if inp == cutlass_cppgen.DataType.bf16: + return bfloat16.bfloat16 + + +def library_type(inp): + if inp in DataTypeSize: + return inp + + for cvt_fn in [ + bfloat16_library_type, + cupy_library_type, + numpy_library_type, + torch_library_type, + ]: + out = cvt_fn(inp) + if out is not None: + return out + + raise Exception(f"No available conversion from type {inp} to a library type.") + + +def _tensor_from_numpy(np_tensor): + dtype = library_type(np_tensor.dtype) + if np_tensor.flags.c_contiguous: + layout = cutlass_cppgen.LayoutType.RowMajor + elif np_tensor.flags.f_contiguous: + layout = cutlass_cppgen.LayoutType.ColumnMajor + return (dtype, layout) + + +def _tensor_from_torch(pt_tensor): + dtype = library_type(pt_tensor.dtype) + return (dtype, cutlass_cppgen.LayoutType.RowMajor) + + +def get_datatype_and_layout(tensor): + if (is_numpy_tensor(tensor) or is_cupy_tensor(tensor)): + return _tensor_from_numpy(tensor) + elif is_torch_tensor(tensor): + return _tensor_from_torch(tensor) + elif isinstance(tensor, float) or isinstance(tensor, int): + return (cutlass_cppgen.DataType.f32, cutlass_cppgen.LayoutType.RowMajor) + else: + raise Exception(f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout.") + + +def get_tensor_shape(tensor, op="GEMM"): + if (is_numpy_tensor(tensor) or is_cupy_tensor(tensor)): + return tensor.shape + elif is_torch_tensor(tensor): + size = tensor.size() + if op == "CONV": + # PyTorch Tensors have shape NCHW + return (size[0], size[2], size[3], size[1]) + else: + return tuple(tensor.size()) + elif isinstance(tensor, float) or isinstance(tensor, int): + return (1,) + else: + raise Exception(f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout.") + + +_math_operation_value_map = {x.value: x for x in MathOperation} + + +def backend_math_operation(math_op: MathOperation): + if math_op.value not in _math_operation_value_map.keys(): + raise Exception(f"Unable to convert math operation of type {math_op} to backend math operation.") + return _math_operation_value_map[math_op.value] + + +def construct_backend_td(td: cutlass_cppgen.TileDescription, + kernel_schedule: cutlass_cppgen.KernelScheduleType, + epilogue_schedule: cutlass_cppgen.EpilogueScheduleType, + tile_scheduler: cutlass_cppgen.TileSchedulerType) -> TileDescription: + mi = td.math_instruction + backend_mi = MathInstruction( + mi.instruction_shape, + mi.element_a, + mi.element_b, + mi.element_accumulator, + mi.opcode_class, + backend_math_operation(mi.math_operation) + ) + cluster_shape = td.cluster_shape if hasattr(td, "cluster_shape") else [1, 1, 1] + return TileDescription(td.threadblock_shape, td.stages, td.warp_count, + backend_mi, cluster_shape, kernel_schedule, epilogue_schedule, tile_scheduler) + + +def td_from_profiler_op(op) -> TileDescription: + """ + Converts the profiler's TileDescription in ``op`` into the backend TileDescription + + :param op: profiler Operation + + :returns: backend TileDescription + :rtype: cutlass_cppgen.backend.TileDescription + """ + kschedule = op.kernel_schedule if hasattr(op, 'kernel_schedule') else None + eschedule = op.epilogue_schedule if hasattr(op, 'epilogue_schedule') else None + tschedule = op.tile_scheduler if hasattr(op, 'tile_scheduler') else None + return construct_backend_td(op.tile_description, kschedule, eschedule, tschedule) + + +def td_from_profiler_td(td: TileDescription) -> TileDescription: + """ + Converts the profiler's TileDescription into the backend TileDescription + + :param td: profiler TileDescription + :type td: cutlass_cppgen.TileDescription + + :returns: backend TileDescription + :rtype: cutlass_cppgen.backend.TileDescription + """ + return construct_backend_td(td, kernel_schedule=None, epilogue_schedule=None, tile_scheduler=None) + + +def to_camel_case(snake_str): + return "".join(x.capitalize() for x in snake_str.lower().split("_")) + + +def getattr_enum(obj, attr_name): + # The attr_name is under the snake_case + camel_attr = to_camel_case(attr_name) + if hasattr(obj, camel_attr): + return getattr(obj, camel_attr) + else: + raise Exception(f"Invalid option: {attr_name}") diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/lazy_import.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/lazy_import.py new file mode 100644 index 0000000000000000000000000000000000000000..16f6a185040f4c2f6167c6191c9bee766a92b1b9 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/lazy_import.py @@ -0,0 +1,41 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# +import importlib +from typing import Any + +def lazy_import(mod_name: str) -> Any: + class Lazy: + def __getattr__(self, name:str) -> Any: + module = importlib.import_module(mod_name) + return getattr(module, name) + + return Lazy() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/profiler.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..f53b1567978d17f2eaec0208d896aafb296f033f --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/profiler.py @@ -0,0 +1,196 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Profiler based on the cuda events +""" + +import re +import subprocess + +from cutlass_cppgen.utils.lazy_import import lazy_import +cuda = lazy_import("cuda.cuda") +cudart = lazy_import("cuda.cudart") +import numpy as np + +from cutlass_cppgen import CUTLASS_PATH +from cutlass_cppgen.backend.library import DataTypeSize +from cutlass_cppgen.op.op import OperationBase +from cutlass_cppgen.shape import GemmCoord +from cutlass_cppgen.utils.datatypes import is_numpy_tensor + + +class GpuTimer: + def __init__(self) -> None: + self.events = [ + cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1], + cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1], + ] + + def start(self, stream=None): + if not stream: + stream = cuda.CUstream(0) + + (err,) = cuda.cuEventRecord(self.events[0], stream) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError(f"CUDA Error {str(err)}") + + def stop(self, stream=None): + if not stream: + stream = cuda.CUstream(0) + + (err,) = cuda.cuEventRecord(self.events[1], stream) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError(f"CUDA Error {str(err)}") + pass + + def stop_and_wait(self, stream=None): + if not stream: + stream = cuda.CUstream(0) + + self.stop(stream) + if stream: + (err,) = cuda.cuStreamSynchronize(stream) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError(f"CUDA Error {str(err)}") + else: + (err,) = cudart.cudaDeviceSynchronize() + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError(f"CUDA Error {str(err)}") + + def duration(self, iterations=1): + err, duration = cuda.cuEventElapsedTime(self.events[0], self.events[1]) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError(f"CUDA Error {str(err)}") + return duration / float(iterations) + + +class CUDAEventProfiler: + def __init__(self, op: OperationBase, warmup_iterations: int=500, iterations: int=500, *args, **kwargs) -> None: + self.arguments = op.run(*args, **kwargs) + self.operation = op.operation + self.warmup_iterations = warmup_iterations + self.iterations = iterations + self.timer = GpuTimer() + + # + # Cutlass Python Interface Profiler + # + + def __call__(self): + for _ in range(self.warmup_iterations): + self.operation.run(self.arguments) + + self.timer.start() + for _ in range(self.iterations): + self.operation.run(self.arguments) + + self.timer.stop_and_wait() + runtime = self.timer.duration(self.iterations) + return runtime + + # + # CUTLASS Profiler + # + + def run_cutlass_profiler(self): + alpha = 1.0 + beta = 1.0 + + profiler_path = CUTLASS_PATH + "/build/tools/profiler/cutlass_profiler" + kernel_name = self.operation.procedural_name() + verification_providers = "device" + provider = "cutlass" + problem_size = self.arguments.problem_size + + if "cutlass3x" in kernel_name: + # cutlass3x generator only have column-major output + layout_name = self.operation.layout_name_3x() + if layout_name[-1] == "t": + new_layout_name = "".join(["n" for l in layout_name if l == "t" or "t"]) + problem_size = GemmCoord(problem_size.n, problem_size.m, problem_size.k) + kernel_name = kernel_name.replace(layout_name, new_layout_name) + + batch_count = self.arguments.batch_count + + cmd = f"{profiler_path} --kernels={kernel_name} --verification-providers={verification_providers} " \ + f"--providers={provider} --m={problem_size.m()} --n={problem_size.n()} --k={problem_size.k()} " \ + f"--batch_count={batch_count} --alpha={alpha} --beta={beta} "\ + f"--warmup-iterations={self.warmup_iterations} --profiling-iterations={self.iterations}" + + result = subprocess.getoutput(cmd) + + m = re.search(r"Runtime:\s+(?P\d+.\d+)", result) + runtime = float(m.group("runtime")) + + m = re.search(r"Bytes:\s+(?P\d+)", result) + bytes = int(m.group("bytes")) + + m = re.search(r"FLOPs:\s+(?P\d+)", result) + flops = int(m.group("flops")) + + # check if the problem size matches + assert bytes == self.bytes(problem_size, batch_count, beta) + assert flops == self.flops(problem_size, batch_count, beta) + + return runtime + + def bytes(self, problem_size, batch_count=1, beta=0.0): + m = problem_size.m() + n = problem_size.n() + k = problem_size.k() + + bytes = ( + (DataTypeSize[self.operation.A.element] * m // 8) * k + + (DataTypeSize[self.operation.B.element] * n // 8) * k + + (DataTypeSize[self.operation.C.element] * m // 8) * n + ) + + if beta != 0: + bytes += (DataTypeSize[self.operation.C.element] * m // 8) * n + + bytes *= batch_count + + return bytes + + def flops(self, problem_size, batch_count=1, beta=0.0): + m = problem_size.m() + n = problem_size.n() + k = problem_size.k() + + flops_ = (m * n * k) * 2 * batch_count + + if beta != 0: + flops_ += m * n * batch_count * 2 + + return flops_ + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..534eef47d810eb9f17a9ba6dbbe2e0dff935eb3f --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/__init__.py @@ -0,0 +1,63 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import os +import sys + +from . import conv2d_operation +from . import conv3d_operation +from . import emit_kernel_listing +from . import gemm_operation + +if '-m' not in sys.argv: + # Do not import generator when running python -m cutlass_library.generator to + # avoid double-import warnings + from . import generator + +from . import library +from . import manifest +from . import rank_2k_operation +from . import rank_k_operation +from . import symm_operation +from . import trmm_operation +# Make enum types from library.py accessible via cutlass_library.* +from .library import * + +# Set up `source` to point to the path containing the CUTLASS source. +# Check first if the path contains a `source` subdirectory -- this will +# be the case when the package has been installed via pip. Otherwise, +# default to the root of CUTLASS. +install_source_path = os.path.join(__path__[0], 'source') +if os.path.isdir(install_source_path): + source_path = install_source_path +else: + source_path = os.path.join(__path__[0], '../..') diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/conv2d_operation.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/conv2d_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..b674463a2c5795be8610883c4dc98a1e7123a01b --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/conv2d_operation.py @@ -0,0 +1,621 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for emitting Conv2d kernels +""" + +import enum +import logging +import os.path +import shutil +from string import Template + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * + from cutlass_library.conv3x_emitter import EmitConv3xInstance, EmitConv3xIncludes +except ImportError: + from library import * + from conv3x_emitter import EmitConv3xInstance, EmitConv3xIncludes + +_LOGGER = logging.getLogger(__name__) + +################################################################################################### + +# +class Conv2dOperation: + # + def __init__(self, conv_kind, iterator_algorithm, arch, tile_description, A, B, C, element_epilogue, \ + stride_support, epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity1, \ + group_mode = GroupMode.NoneGroup): + + self.operation_kind = OperationKind.Conv2d + self.arch = arch + self.tile_description = tile_description + self.conv_kind = conv_kind + self.A = A + self.B = B + self.C = C + self.element_epilogue = element_epilogue + self.epilogue_functor = epilogue_functor + self.iterator_algorithm = iterator_algorithm + self.stride_support = stride_support + self.swizzling_functor = swizzling_functor + self.group_mode = group_mode + # + def is_complex(self): + complex_operators = [ + MathOperation.multiply_add_complex, + MathOperation.multiply_add_complex_gaussian + ] + return self.tile_description.math_instruction.math_operation in complex_operators + + # + def is_mixed_input(self): + return self.A.element != self.B.element + + # + def accumulator_type(self): + accum = self.tile_description.math_instruction.element_accumulator + + if self.is_complex(): + return get_complex_from_real(accum) + + return accum + + # + def core_name(self): + ''' The basic operation kind is prefixed with a letter indicating the accumulation type. ''' + + intermediate_type = '' + + if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp: + inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape) + if self.tile_description.math_instruction.element_a != self.A.element and \ + self.tile_description.math_instruction.element_a != self.accumulator_type(): + intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] + else: + inst_shape = '' + + return "%s%s%s%s_%s" % (ShortDataTypeNames[self.accumulator_type()], \ + inst_shape, intermediate_type, ConvKindNames[self.conv_kind], IteratorAlgorithmNames[self.iterator_algorithm]) + + # + def extended_name(self): + ''' Append data types if they differ from compute type. ''' + if self.C.element != self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${element_c}_${core_name}_${element_a}" + elif self.C.element == self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${core_name}_${element_a}" + else: + extended_name = "${core_name}" + + extended_name = SubstituteTemplate(extended_name, { + 'element_a': DataTypeNames[self.A.element], + 'element_c': DataTypeNames[self.C.element], + 'core_name': self.core_name() + }) + + return extended_name + + # + def layout_name(self): + return "%s" % (ShortLayoutTypeNames[self.A.layout]) + + # + def configuration_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + + opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] + + threadblock = self.tile_description.procedural_name() + + # grouped conv + if self.group_mode != GroupMode.NoneGroup: + group_conv_name = f"{GroupModeNames[self.group_mode]}_" + else: + group_conv_name = "" + + if self.stride_support == StrideSupport.Unity and self.conv_kind == ConvKind.Dgrad: + configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_unity_stride_${group_conv_name}align${alignment}" + else: + configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${group_conv_name}align${alignment}" + + return SubstituteTemplate( + configuration_name, + { + 'opcode_class': opcode_class_name, + 'extended_name': self.extended_name(), + 'threadblock': threadblock, + 'layout': self.layout_name(), + 'alignment': "%d" % self.A.alignment, + 'group_conv_name': group_conv_name + } + ) + + # + def procedural_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + return self.configuration_name() + +################################################################################################### +# +# Emits single instances of a CUTLASS device-wide operator +# +################################################################################################### + +class EmitConv2dInstance: + def __init__(self): + # Emitter for CUTLASS 3 convolution operations + self.conv3x_emitter = EmitConv3xInstance() + self.template = """ + // Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}" + using ${operation_name}_base = + typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}< + ${element_a}, + ${layout_a}, + ${element_b}, + ${layout_b}, + ${element_c}, + ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>, + ${stages}, + ${math_operator}, + ${iterator_algorithm}, + ${stride_support}, + ${align_a}, + ${align_b} + >::Kernel; +""" + self.template_group_conv = """ + // Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}" + using ${operation_name}_base = + typename cutlass::conv::kernel::DefaultConv2dGroup${conv_kind_name}< + ${element_a}, + ${layout_a}, + ${element_b}, + ${layout_b}, + ${element_c}, + ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>, + ${stages}, + ${math_operator}, + ${group_mode}, + ${iterator_algorithm}, + ${stride_support}, + ${align_a}, + ${align_b} + >::Kernel; +""" + self.template_depthwise_direct_conv = """ + // Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}" + using ${operation_name}_base = + typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConv${conv_kind_name}< + ${element_a}, + ${layout_a}, + ${element_b}, + ${layout_b}, + ${element_c}, + ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::conv::TensorNHWCShape<${threadblock_output_shape_n}, ${threadblock_output_shape_p}, ${threadblock_output_shape_q}, ${groups_per_cta}>, + cutlass::MatrixShape<${filter_shape_r}, ${filter_shape_s}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue}, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling + >, + + cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle< + 1, + ${threadblock_output_shape_n}, + ${threadblock_output_shape_p}, + ${threadblock_output_shape_q}>, + ${stages}, + ${math_operator}, + ${iterator_algorithm}, + ${stride_support}, + cutlass::MatrixShape<${stride_r}, ${stride_s}>, + cutlass::MatrixShape<${dilation_r}, ${dilation_s}> + >::Kernel; +""" + + def arch_number_to_type(self, arch: int): + return f"cutlass::arch::Sm{arch}" + + def emit(self, operation): + _LOGGER.debug("*** EmitConv2dInstance::emit") + _LOGGER.debug("*** operation: procedural_name()=" + operation.procedural_name()) + + if hasattr(operation, 'is_3x') and operation.is_3x: + _LOGGER.debug("*** CUTLASS 3 operation") + return self.conv3x_emitter.emit(operation) + + _LOGGER.debug("*** CUTLASS 2 operation") + + warp_shape = [int(operation.tile_description.threadblock_shape[idx] / operation.tile_description.warp_count[idx]) for idx in range(3)] + + epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) + + values = { + 'operation_name': operation.procedural_name(), + 'conv_kind': ConvKindTag[operation.conv_kind], + 'conv_kind_name': ConvKindNames[operation.conv_kind].capitalize(), + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[operation.A.layout], + 'element_b': DataTypeTag[operation.B.element], + 'layout_b': LayoutTag[operation.B.layout], + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'element_accumulator': DataTypeTag[operation.accumulator_type()], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'epilogue_vector_length': str(epilogue_vector_length), + 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], + 'stages': str(operation.tile_description.stages), + 'iterator_algorithm': IteratorAlgorithmTag[operation.iterator_algorithm], + 'iterator_algorithm_name': IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(), + 'stride_support': StrideSupportTag[operation.stride_support], + 'math_operator': 'cutlass::arch::OpMultiplyAddComplex' if operation.is_complex() else \ + MathOperationTag[operation.tile_description.math_instruction.math_operation], + 'align_a': str(operation.A.alignment), + 'align_b': str(operation.B.alignment), + } + + if operation.group_mode == GroupMode.NoneGroup: + _LOGGER.debug("*** group_mode=NoneGroup") + return SubstituteTemplate(self.template, values) + + elif operation.group_mode == GroupMode.Depthwise: + _LOGGER.debug("*** group_mode=Depthwise") + values['group_mode'] = GroupModeTag[operation.group_mode] + # Setup other template params + values['threadblock_output_shape_n'] = str(operation.tile_description.threadblock_output_shape[0]) + values['threadblock_output_shape_p'] = str(operation.tile_description.threadblock_output_shape[1]) + values['threadblock_output_shape_q'] = str(operation.tile_description.threadblock_output_shape[2]) + + values['groups_per_cta'] = str(operation.tile_description.threadblock_output_shape[3]) + + values['filter_shape_r'] = str(operation.tile_description.filter_shape[0]) + values['filter_shape_s'] = str(operation.tile_description.filter_shape[1]) + + values['stride_r'] = str(operation.tile_description.stride[0]) + values['stride_s'] = str(operation.tile_description.stride[1]) + + values['dilation_r'] = str(operation.tile_description.dilation[0]) + values['dilation_s'] = str(operation.tile_description.dilation[1]) + + return SubstituteTemplate(self.template_depthwise_direct_conv, values) + + else: + _LOGGER.debug("*** group_mode=" + GroupModeTag[operation.group_mode]) + values['group_mode'] = GroupModeTag[operation.group_mode] + return SubstituteTemplate(self.template_group_conv, values) + +################################################################################################### +# +# Generator functions for all layouts +# +################################################################################################### + +# +def GenerateConv2dTensorOp(manifest, tile_descriptions, min_cc, align = 128): + _LOGGER.debug("*** GenerateConv2dTensorOp") + + for tile in tile_descriptions: + for conv_kind in [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad]: + + if conv_kind == ConvKind.Fprop or (tile.math_instruction.element_accumulator in [DataType.f16, DataType.f32]): + + # + output_types = [tile.math_instruction.element_a, tile.math_instruction.element_accumulator] \ + if DataTypeSize[tile.math_instruction.element_accumulator] == 32 \ + else [tile.math_instruction.element_accumulator,] + + for output_type in output_types: + A = TensorDescription(tile.math_instruction.element_a, LayoutType.TensorNHWC, int(align / DataTypeSize[tile.math_instruction.element_a])) + B = TensorDescription(tile.math_instruction.element_b, LayoutType.TensorNHWC, int(align / DataTypeSize[tile.math_instruction.element_b])) + C = TensorDescription(output_type, LayoutType.TensorNHWC, max(1, int(align / DataTypeSize[output_type]))) + + manifest.append(Conv2dOperation(conv_kind, min_cc, tile, A, B, C, tile.math_instruction.element_accumulator)) + +class EmitConv2dIncludes: + '''Emit includes that are specific to the operation.''' + + def __init__(self): + self.includes = ['conv2d_operation.h'] + self.emitter_3x = EmitConv3xIncludes() + + def operation_is_3x(self, operation) -> bool: + """Whether operation is a CUTLASS 3 convolution (as opposed to CUTLASS 2)""" + return hasattr(operation, 'is_3x') and operation.is_3x + + def emit(self, operation) -> str: + if self.operation_is_3x(operation): + return self.emitter_3x.emit(operation) + + return '\n'.join(f"#include \"{incl}\"" for incl in self.includes) + \ + "\n\n///////////////////////////////////////////////////////////////////////////////////////////////////" + +################################################################################################### +# +# Emitters functions for all targets +# +################################################################################################### + +class EmitConv2dConfigurationLibrary: + def __init__(self, operation_path, configuration_name): + self.configuration_name = configuration_name + self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name) + + self.instance_emitter = EmitConv2dInstance() + self.includes_emitter = EmitConv2dIncludes() + + self.header_template = """ +/* + Generated by conv2d_operation.py - Do not edit. +*/ + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "library_internal.h" +""" + + self.instance_template = """ +${stub_begin} +${operation_instance} +// Derived class +struct ${operation_name} : + public ${operation_name}_base { }; +${stub_end} +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + self.configuration_header = """ + +namespace cutlass { +namespace library { + +// Initialize all instances +void initialize_${configuration_name}(Manifest &manifest) { +""" + + self.configuration_instance = """${stub_begin} + using Operation_${operation_name} = cutlass::conv::device::${kernel_name}< + ${operation_name}>; + + manifest.append(new cutlass::library::${operation_wrapper}< + Operation_${operation_name} + >( + "${operation_name}" + )); +${stub_end} +""" + + self.configuration_epilogue = "}\n" + + self.epilogue_template = """ + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + def operation_is_3x(self, operation): + """Whether operation is a CUTLASS 3 convolution (as opposed to CUTLASS 2)""" + return hasattr(operation, 'is_3x') and operation.is_3x + + def __enter__(self): + """ + Open the configuration_file, and write the "header" C++ code to it. + + The "header" consists of a comment (that this is generated code, + so it should not be edited), and includes that are common + to all kinds of kernels. + """ + _LOGGER.debug('*** EmitConv2dConfigurationLibrary::__enter__') + _LOGGER.debug('*** configuration_path (file to write): ' + + str(self.configuration_path)) + _LOGGER.debug('*** configuration_name: ' + self.configuration_name) + self.configuration_file = open(self.configuration_path, "w") + + self.configuration_file.write(SubstituteTemplate(self.header_template, { + 'configuration_name': self.configuration_name + })) + self.operations = [] + return self + + def emit(self, operation): + """ + Write three pieces of C++ code to the configuration_file + (that was opened by the __enter__ method above): + + 1. the header includes that are specific to the operation + (CUTLASS 2 vs. CUTLASS 3); + + 2. the "operation instance" (a "using" declaration ending in "_base"); and + + 3. the "operation name" (declaration and definition of a derived class + of the above operation instance). + + The "using" declaration turns a C++ class name, possibly namespace-qualified, + possibly also with angle brackets, into a C-style, easily demangled identifier. + """ + _LOGGER.debug('*** EmitConv2dConfigurationLibrary::emit') + _LOGGER.debug('*** operation.procedural_name(): ' + operation.procedural_name()) + self.operations.append(operation) + + self.configuration_file.write(self.includes_emitter.emit(operation)) + + stub_begin = '' + stub_end = '' + # It can be useful to stub (comment) out instantiations for testing. + # In this case, one need only set is_stub to True. + is_stub = False + if is_stub: + stub_begin = "// STUB for now\n#if 0" + stub_end = '#endif // 0' + + self.configuration_file.write(Template(self.instance_template).substitute({ + 'configuration_name': self.configuration_name, + 'operation_name': operation.procedural_name(), + 'operation_instance': self.instance_emitter.emit(operation), + 'stub_begin': stub_begin, + 'stub_end': stub_end + })) + + def __exit__(self, exception_type, exception_value, traceback): + """ + Write the rest of the C++ code to the configuration_file, and close the file. + + The "rest of the C++ code" has the following components. + + 1. Configuration header: Open the namespace(s), and open the definition + of the "initialize_${configuration_name}" registration function + that registers the operation with the Manifest. + ("Registration" helps turn C++ compile-time polymorphism + (via template parameters) into a run-time choice of parameters.) + + 2. Configuration instance: In the body of the registration function, + make a "using" declaration Operation_${operation_name} for the + operation type (which uses operation_name as its template argument). + Then, tell the manifest about the operation via a "manifest.append" call. + The argument of the call is a new instance of + "SomethingOperation" + (replace Something with a specific name). + + 3. Configuration epilogue: Close the definition of the registration function. + + 4. Epilogue template: Close the namespace(s). + """ + + _LOGGER.debug('*** EmitConv2dConfigurationLibrary::__exit__') + _LOGGER.debug('*** configuration_path (file to write): ' + + str(self.configuration_path)) + _LOGGER.debug('*** configuration_name: ' + self.configuration_name) + + self.configuration_file.write(SubstituteTemplate(self.configuration_header, { + 'configuration_name': self.configuration_name + })) + + for operation in self.operations: + stub_begin = '' + stub_end = '' + # It can be useful to stub (comment) out instantiations for testing. + # In this case, one need only set is_stub to True. + is_stub = False + if is_stub: + stub_begin = "// STUB for now\n#if 0" + stub_end = "#endif // 0" + + if operation.group_mode == GroupMode.Depthwise: + kernel_name = 'DirectConvolution' + operation_wrapper = 'DirectConv2dOperation' + else: + kernel_name = 'ImplicitGemmConvolution' + operation_wrapper = 'Conv2dOperation' + if self.operation_is_3x(operation): + kernel_name = 'ConvUniversalAdapter' + operation_wrapper = 'ConvOperation3x' + + self.configuration_file.write(SubstituteTemplate(self.configuration_instance, { + 'configuration_name': self.configuration_name, + 'operation_name': operation.procedural_name(), + 'kernel_name': kernel_name, + 'operation_wrapper': operation_wrapper, + 'stub_begin': stub_begin, + 'stub_end': stub_end + })) + + self.configuration_file.write(self.configuration_epilogue) + self.configuration_file.write(self.epilogue_template) + self.configuration_file.close() + + +################################################################################################### +################################################################################################### diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/conv3d_operation.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/conv3d_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..b96b6db74224e52bd90b6e184a62624475385352 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/conv3d_operation.py @@ -0,0 +1,482 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for emitting Conv3d kernels +""" + +import enum +import logging +import os.path +import shutil +from string import Template + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * + from cutlass_library.conv3x_emitter import EmitConv3xInstance, EmitConv3xIncludes +except ImportError: + from library import * + from conv3x_emitter import EmitConv3xInstance, EmitConv3xIncludes + +_LOGGER = logging.getLogger(__name__) + +################################################################################################### + +# +class Conv3dOperation: + # + def __init__(self, conv_kind, iterator_algorithm, arch, tile_description, A, B, C, element_epilogue, \ + stride_support, epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4): + + self.operation_kind = OperationKind.Conv3d + self.arch = arch + self.tile_description = tile_description + self.conv_kind = conv_kind + self.A = A + self.B = B + self.C = C + self.element_epilogue = element_epilogue + self.epilogue_functor = epilogue_functor + self.iterator_algorithm = iterator_algorithm + self.stride_support = stride_support + self.swizzling_functor = swizzling_functor + + # + def is_mixed_input(self): + return self.A.element != self.B.element + + # + def core_name(self): + ''' The basic operation kind is prefixed with a letter indicating the accumulation type. ''' + + intermediate_type = '' + + if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp: + inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape) + if self.tile_description.math_instruction.element_a != self.A.element and \ + self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator: + intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] + else: + inst_shape = '' + + return "%s%s%s%s3d_%s" % (ShortDataTypeNames[self.tile_description.math_instruction.element_accumulator], \ + inst_shape, intermediate_type, ConvKindNames[self.conv_kind], IteratorAlgorithmNames[self.iterator_algorithm]) + + # + def extended_name(self): + ''' Append data types if they differ from compute type. ''' + if self.C.element != self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${element_c}_${core_name}_${element_a}" + elif self.C.element == self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${core_name}_${element_a}" + else: + extended_name = "${core_name}" + + extended_name = SubstituteTemplate(extended_name, { + 'element_a': DataTypeNames[self.A.element], + 'element_c': DataTypeNames[self.C.element], + 'core_name': self.core_name() + }) + + return extended_name + + # + def configuration_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + + opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] + + threadblock = "%dx%d_%dx%d" % ( + self.tile_description.threadblock_shape[0], + self.tile_description.threadblock_shape[1], + self.tile_description.threadblock_shape[2], + self.tile_description.stages + ) + + if self.stride_support == StrideSupport.Unity: + configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_unity_stride" + else: + configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}" + + return SubstituteTemplate( + configuration_name, + { + 'opcode_class': opcode_class_name, + 'extended_name': self.extended_name(), + 'threadblock': threadblock, + } + ) + + # + def procedural_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + return self.configuration_name() + +################################################################################################### +# +# Emits single instances of a CUTLASS device-wide operator +# +################################################################################################### + +class EmitConv3dInstance: + def __init__(self): + # Emitter for CUTLASS 3 convolution operations + self.conv3x_emitter = EmitConv3xInstance() + self.template = """ + // Conv3d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}" + using ${operation_name}_base = + typename cutlass::conv::kernel::DefaultConv3d${conv_kind_name}< + ${element_a}, + cutlass::layout::TensorNDHWC, + ${element_b}, + cutlass::layout::TensorNDHWC, + ${element_c}, + cutlass::layout::TensorNDHWC, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>, + ${stages}, + cutlass::arch::OpMultiplyAdd, + ${iterator_algorithm}, + ${stride_support} + >::Kernel; +""" + + def emit(self, operation): + _LOGGER.debug("*** EmitConv3dInstance::emit") + _LOGGER.debug("*** operation: procedural_name()=" + operation.procedural_name()) + + if hasattr(operation, 'is_3x') and operation.is_3x: + _LOGGER.debug("*** CUTLASS 3 operation") + return self.conv3x_emitter.emit(operation) + + _LOGGER.debug("*** CUTLASS 2 operation") + + warp_shape = [int(operation.tile_description.threadblock_shape[idx] / operation.tile_description.warp_count[idx]) for idx in range(3)] + + epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) + + values = { + 'operation_name': operation.procedural_name(), + 'conv_kind': ConvKindTag[operation.conv_kind], + 'conv_kind_name': ConvKindNames[operation.conv_kind].capitalize(), + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[operation.A.layout], + 'element_b': DataTypeTag[operation.B.element], + 'layout_b': LayoutTag[operation.B.layout], + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'epilogue_vector_length': str(epilogue_vector_length), + 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], + 'stages': str(operation.tile_description.stages), + 'iterator_algorithm': IteratorAlgorithmTag[operation.iterator_algorithm], + 'iterator_algorithm_name': IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(), + 'stride_support': StrideSupportTag[operation.stride_support] + } + + return SubstituteTemplate(self.template, values) + +################################################################################################### +# +# Generator functions for all layouts +# +################################################################################################### + +# +def GenerateConv3dTensorOp(manifest, tile_descriptions, min_cc, align = 128): + + for tile in tile_descriptions: + for conv_kind in [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad]: + + if conv_kind == ConvKind.Fprop or (tile.math_instruction.element_accumulator in [DataType.f16, DataType.f32]): + + # + output_types = [tile.math_instruction.element_a, tile.math_instruction.element_accumulator] \ + if DataTypeSize[tile.math_instruction.element_accumulator] == 32 \ + else [tile.math_instruction.element_accumulator,] + + for output_type in output_types: + A = TensorDescription(tile.math_instruction.element_a, LayoutType.TensorNDHWC, int(align / DataTypeSize[tile.math_instruction.element_a])) + B = TensorDescription(tile.math_instruction.element_b, LayoutType.TensorNDHWC, int(align / DataTypeSize[tile.math_instruction.element_b])) + C = TensorDescription(output_type, LayoutType.TensorNDHWC, max(1, int(align / DataTypeSize[output_type]))) + + manifest.append(Conv3dOperation(conv_kind, min_cc, tile, A, B, C, tile.math_instruction.element_accumulator)) + +class EmitConv3dIncludes: + '''Emit includes that are specific to the operation.''' + + def __init__(self): + self.includes = ['conv3d_operation.h'] + self.emitter_3x = EmitConv3xIncludes() + + def operation_is_3x(self, operation) -> bool: + """Whether operation is a CUTLASS 3 convolution (as opposed to CUTLASS 2)""" + return hasattr(operation, 'is_3x') and operation.is_3x + + def emit(self, operation) -> str: + if self.operation_is_3x(operation): + return self.emitter_3x.emit(operation) + + return '\n'.join(f"#include \"{incl}\"" for incl in self.includes) + \ + "\n\n///////////////////////////////////////////////////////////////////////////////////////////////////" + +################################################################################################### +# +# Emitters functions for all targets +# +################################################################################################### + +class EmitConv3dConfigurationLibrary: + def __init__(self, operation_path, configuration_name): + self.configuration_name = configuration_name + self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name) + + self.instance_emitter = EmitConv3dInstance() + self.includes_emitter = EmitConv3dIncludes() + + self.header_template = """ +/* + Generated by conv3d_operation.py - Do not edit. +*/ + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "library_internal.h" +""" + + self.instance_template = """ +${stub_begin} +${operation_instance} +// Derived class +struct ${operation_name} : + public ${operation_name}_base { }; +${stub_end} +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + self.configuration_header = """ + +namespace cutlass { +namespace library { + +// Initialize all instances +void initialize_${configuration_name}(Manifest &manifest) { +""" + + self.configuration_instance = """${stub_begin} + using Operation_${operation_name} = cutlass::conv::device::${kernel_name}< + ${operation_name}>; + + manifest.append(new cutlass::library::${operation_wrapper}< + Operation_${operation_name} + >( + "${operation_name}" + )); +${stub_end} +""" + + self.configuration_epilogue = "}\n" + + self.epilogue_template = """ + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + def operation_is_3x(self, operation): + """Whether operation is a CUTLASS 3 convolution (as opposed to CUTLASS 2)""" + return hasattr(operation, 'is_3x') and operation.is_3x + + def __enter__(self): + """ + Open the configuration_file, and write the "header" C++ code to it. + + The "header" consists of a comment (that this is generated code, + so it should not be edited), and includes that are common + to both the CUTLASS 2 and the CUTLASS 3 cases. + """ + _LOGGER.debug('*** EmitConv3dConfigurationLibrary::__enter__') + _LOGGER.debug('*** configuration_path (file to write): ' + + str(self.configuration_path)) + _LOGGER.debug('*** configuration_name: ' + self.configuration_name) + self.configuration_file = open(self.configuration_path, "w") + + self.configuration_file.write(SubstituteTemplate(self.header_template, { + 'configuration_name': self.configuration_name + })) + self.operations = [] + return self + + def emit(self, operation): + """ + Write three pieces of C++ code to the configuration_file + (that was opened by the __enter__ method above): + + 1. the header includes that are specific to the operation + (CUTLASS 2 vs. CUTLASS 3); + + 2. the "operation instance" (a "using" declaration ending in "_base"); and + + 3. the "operation name" (declaration and definition of a derived class + of the above operation instance). + + The "using" declaration turns a C++ class name, possibly namespace-qualified, + possibly also with angle brackets, into a C-style, easily demangled identifier. + """ + _LOGGER.debug('*** EmitConv3dConfigurationLibrary::emit') + _LOGGER.debug('*** operation.procedural_name(): ' + operation.procedural_name()) + self.operations.append(operation) + + self.configuration_file.write(self.includes_emitter.emit(operation)) + + stub_begin = '' + stub_end = '' + # It can be useful to stub (comment) out instantiations for testing. + # In this case, one need only set is_stub to True. + is_stub = False + if is_stub: + stub_begin = "// STUB for now\n#if 0" + stub_end = '#endif // 0' + + self.configuration_file.write(Template(self.instance_template).substitute({ + 'configuration_name': self.configuration_name, + 'operation_name': operation.procedural_name(), + 'operation_instance': self.instance_emitter.emit(operation), + 'stub_begin': stub_begin, + 'stub_end': stub_end + })) + + def __exit__(self, exception_type, exception_value, traceback): + """ + Write the rest of the C++ code to the configuration_file, and close the file. + + The "rest of the C++ code" has the following components. + + 1. Configuration header: Open the namespace(s), and open the definition + of the "initialize_${configuration_name}" registration function + that registers the operation with the Manifest. + ("Registration" helps turn C++ compile-time polymorphism + (via template parameters) into a run-time choice of parameters.) + + 2. Configuration instance: In the body of the registration function, + make a "using" declaration Operation_${operation_name} for the + operation type (which uses operation_name as its template argument). + Then, tell the manifest about the operation via a "manifest.append" call. + The argument of the call is a new instance of + "SomethingOperation" + (replace Something with a specific name). + + 3. Configuration epilogue: Close the definition of the registration function. + + 4. Epilogue template: Close the namespace(s). + """ + + _LOGGER.debug('*** EmitConv3dConfigurationLibrary::__exit__') + _LOGGER.debug('*** configuration_path (file to write): ' + + str(self.configuration_path)) + _LOGGER.debug('*** configuration_name: ' + self.configuration_name) + + self.configuration_file.write(SubstituteTemplate(self.configuration_header, { + 'configuration_name': self.configuration_name + })) + + for operation in self.operations: + stub_begin = '' + stub_end = '' + # It can be useful to stub (comment) out instantiations for testing. + # In this case, one need only set is_stub to True. + is_stub = False + if is_stub: + stub_begin = "// STUB for now\n#if 0" + stub_end = "#endif // 0" + + kernel_name = 'ImplicitGemmConvolution' + operation_wrapper = 'Conv3dOperation' + if self.operation_is_3x(operation): + kernel_name = 'ConvUniversalAdapter' + operation_wrapper = 'ConvOperation3x' + + self.configuration_file.write(SubstituteTemplate(self.configuration_instance, { + 'configuration_name': self.configuration_name, + 'operation_name': operation.procedural_name(), + 'kernel_name': kernel_name, + 'operation_wrapper': operation_wrapper, + 'stub_begin': stub_begin, + 'stub_end': stub_end + })) + + self.configuration_file.write(self.configuration_epilogue) + self.configuration_file.write(self.epilogue_template) + self.configuration_file.close() + + +################################################################################################### +################################################################################################### diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/conv3x_emitter.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/conv3x_emitter.py new file mode 100644 index 0000000000000000000000000000000000000000..33d6da1a4675c0bbd07315717a7f5ba0ba0dc10c --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/conv3x_emitter.py @@ -0,0 +1,250 @@ +################################################################################################# +# +# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for emitting CUTLASS >= 3 convolution kernels +""" + +import enum +import os.path +import shutil +import logging +from string import Template + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * +except ImportError: + from library import * + +_LOGGER = logging.getLogger(__name__) + +################################################################################################### +# +# Emits single instances of a CUTLASS device-wide operator +# +################################################################################################### + +class EmitConv3xInstance: + def __init__(self): + _LOGGER.debug("*** EmitConv3xInstance::__init__") + + # Define epilogue type first, so that the mainloop type + # can use it with StageCountAutoCarveout. + self.template = """ + +// CUTLASS >= 3 convolution ${conv_kind_name} kernel instance "${operation_name}" +using ${operation_name}_epilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ${arch}, + ${opcode_class_epi}, + ${mma_tile_shape}, // mma tile shape + ${cluster_shape}, // cluster shape + ${epi_tile_mn}, + ${element_accumulator}, + ${element_compute}, + ${element_c}, ${layout_c}, 128 / cute::sizeof_bits_v<${element_c}>, + ${element_d}, ${layout_d}, 128 / cute::sizeof_bits_v<${element_d}>, + ${epilogue_schedule} + // , class FusionOpOrCallbacks = cutlass::epilogue::fusion::LinearCombination + >::CollectiveOp; + +using ${operation_name}_mainloop = + typename cutlass::conv::collective::CollectiveBuilder< + ${arch}, + ${opcode_class_main}, + ${conv_kind}, // kFprop, kDgrad, or kWgrad + ${element_a}, ${layout_a}, 128 / cute::sizeof_bits_v<${element_a}>, + ${element_b}, ${layout_b}, 128 / cute::sizeof_bits_v<${element_b}>, + ${element_accumulator}, + ${mma_tile_shape}, // mma tile shape + ${cluster_shape}, // cluster shape + ${stages}, + ${kernel_schedule} + >::CollectiveOp; + +using ${operation_name}_problem_shape = cutlass::conv::ConvProblemShape<${conv_kind}, ${operation_name}_mainloop::NumSpatialDimensions>; + +// Unit tests call this "ConvKernel". +// Conv operator ${operation_name} +using ${operation_name}_base = cutlass::conv::kernel::ConvUniversal< + ${operation_name}_problem_shape, + ${operation_name}_mainloop, + ${operation_name}_epilogue, + ${tile_scheduler} + >; +""" + + def arch_number_to_type(self, arch: int) -> str: + return f"cutlass::arch::Sm{arch}" + + def mma_tile_shape(self, operation, cta_m, cta_n, cta_k) -> str: + mma_m = cta_m + mma_n = cta_n + mma_k = cta_k + + if operation.arch >= 100: + # MmaTileShape (mma_m, mma_n, mma_k) is passed to kernel mainloop where + # mma_m = cta_m for 1sm version and mma_m = cta_m * 2 for 2sm version. + # If schedule is auto and cluster size is static and cta_m % 64 == 0 and cluster_m % 2 == 0, 2sm kernel version is allocated, + # otherwise 1sm kernel is allocated. + cta_m_per_mma_instruction = 1 + if "2sm" in operation.procedural_name() : + cta_m_per_mma_instruction = 2 + elif "1sm" in operation.procedural_name() : + cta_m_per_mma_instruction = 1 + elif operation.tile_description.cluster_shape[0] > 0 and operation.tile_description.cluster_shape[0] % 2 == 0 and cta_m % 64 == 0 : + cta_m_per_mma_instruction = 2 + mma_m = cta_m * cta_m_per_mma_instruction + + # For all three kinds of convolutions, the tile shape's K mode + # differs from GEMM in that needs to be wrapped in a Shape. + # For Wgrad convolutions specifically, + # the N tile shape also needs to be wrapped in a Shape. + m_template = 'cute::_${mma_m}' + if operation.conv_kind == ConvKind.Wgrad: + n_template = 'cute::Shape' + else: + n_template = 'cute::_${mma_n}' + k_template = 'cute::Shape' + + mma_tile_shape_template = f'cute::Shape<{m_template}, {n_template}, {k_template}>' + values = { + 'mma_m': mma_m, + 'mma_n': mma_n, + 'mma_k': mma_k + } + return Template(mma_tile_shape_template).substitute(values) + + def cluster_shape(self, operation) -> str: + m_template = 'cute::_${cluster_shape_m}' if operation.tile_description.cluster_shape[0] > 0 else 'int(0)' + n_template = 'cute::_${cluster_shape_n}' if operation.tile_description.cluster_shape[1] > 0 else 'int(0)' + k_template = 'cute::_${cluster_shape_k}' if operation.tile_description.cluster_shape[2] > 0 else 'int(0)' + cluster_shape_template = f'cute::Shape<{m_template}, {n_template}, {k_template}>' + values = { + 'cluster_shape_m': operation.tile_description.cluster_shape[0], + 'cluster_shape_n': operation.tile_description.cluster_shape[1], + 'cluster_shape_k': operation.tile_description.cluster_shape[2], + } + return Template(cluster_shape_template).substitute(values) + + def stage_count(self, operation) -> str: + # stages == 0 tells builder to pick the number of stages automatically + namespace_prefix = 'cutlass::conv::collective::' + if operation.tile_description.stages > 0: + return f"{namespace_prefix}StageCount<{str(operation.tile_description.stages)}>" + else: + return f"{namespace_prefix}StageCountAutoCarveout" + + def emit(self, operation) -> str: + _LOGGER.debug("*** EmitConv3xInstance::emit") + _LOGGER.debug("*** operation: procedural_name()=" + operation.procedural_name()) + + # Identify the operation as CUTLASS 3 by its is_3x field + if (not hasattr(operation, 'is_3x')) or (not operation.is_3x): + raise RuntimeError("operation must be a CUTLASS 3 operation") + + epi_tile_mn = "cutlass::epilogue::collective::EpilogueTileAuto" + opcode_class_main = OpcodeClassTag[operation.tile_description.math_instruction.opcode_class] + opcode_class_epi = opcode_class_main + + tile_shape = operation.tile_description.tile_shape + cluster_m = operation.tile_description.cluster_shape[0] + cluster_n = operation.tile_description.cluster_shape[1] + + cta_m, cta_n, cta_k = tile_shape + # account for static/dynamic cluster shapes + if operation.arch >= 100: + cta_m = cta_m // cluster_m if cluster_m > 0 else cta_m + cta_n = cta_n // cluster_n if cluster_n > 0 else cta_n + + warp_count = operation.tile_description.warp_count + epilogue_schedule = EpilogueScheduleTag[operation.epilogue_schedule] + + # KernelScheduleTag and TileSchedulerTag both hard-code the + # namespace qualification of KernelScheduleAuto as + # "cutlass::gemm::collective::" (unless the tag is 'void'). + # + # For TileSchedulerTag, this namespace is fine, since CUTLASS 3 + # convolutions use the same tile schedulers (from the same + # cutlass::gemm::collective namespace) as GEMMs. + kernel_schedule = KernelScheduleTag[operation.kernel_schedule].replace('gemm::', 'conv::') + tile_scheduler = TileSchedulerTag[operation.tile_scheduler] + opcode_class = OpcodeClassTag[operation.tile_description.math_instruction.opcode_class] + + values = { + 'operation_name': operation.procedural_name(), + 'conv_kind': ConvKindTag[operation.conv_kind], + 'conv_kind_name': ConvKindNames[operation.conv_kind].capitalize(), + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[operation.A.layout], + 'align_a': int(operation.A.alignment), + 'element_b': DataTypeTag[operation.B.element], + 'layout_b': LayoutTag[operation.B.layout], + 'align_b': int(operation.B.alignment), + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'align_c': int(operation.C.alignment), + 'element_d': DataTypeTag[operation.D.element], + 'layout_d': LayoutTag[operation.D.layout], + 'align_d': int(operation.D.alignment), + 'element_accumulator': DataTypeTag[operation.accumulator_type()], + 'opcode_class': opcode_class, + 'arch': self.arch_number_to_type(operation.arch), + 'mma_tile_shape': self.mma_tile_shape(operation, cta_m, cta_n, cta_k), + 'cluster_shape': self.cluster_shape(operation), + 'opcode_class_epi': opcode_class_epi, + 'opcode_class_main': opcode_class_main, + 'epi_tile_mn': epi_tile_mn, + 'stages': self.stage_count(operation), + 'kernel_schedule': kernel_schedule, + 'epilogue_schedule': epilogue_schedule, + 'tile_scheduler': tile_scheduler, + 'element_compute': DataTypeTag[operation.element_compute] + } + return Template(self.template).substitute(values) + +class EmitConv3xIncludes: + def __init__(self): + _LOGGER.debug("*** EmitConv3xIncludes::__init__") + self.includes = ['conv_operation_3x.hpp', + 'cutlass/conv/device/conv_universal_adapter.hpp', + 'cutlass/conv/kernel/conv_universal.hpp', + 'cutlass/conv/collective/collective_builder.hpp', + 'cutlass/epilogue/collective/collective_builder.hpp'] + + def emit(self, operation) -> str: + _LOGGER.debug("*** EmitConv3xIncludes::emit") + return '\n'.join(f"#include \"{incl}\"" for incl in self.includes) + \ + "\n\n///////////////////////////////////////////////////////////////////////////////////////////////////" diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/emit_kernel_listing.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/emit_kernel_listing.py new file mode 100644 index 0000000000000000000000000000000000000000..fbe52eb587ab1b5e4595739be5790151b00e0a70 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/emit_kernel_listing.py @@ -0,0 +1,868 @@ +################################################################################################# +# +# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +# +# +# \brief Generates the CUTLASS kernel listing with kernel filtering +# + +# + +############################################################################### +# Example usage: +# generator.py --operations all --generator-target kernel_listing \ +# --architectures "70;75;80" --kernels "*" --disable-cutlass-package-imports +############################################################################### + +import collections +import csv +import json +import math +import os + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * +except ImportError: + from library import * + +audit_csv_fields = [ + "KernelType", "KernelName", "Type_A", "Type_B", "Type_C", "Type_Acc", "Type_EpilogueScale", "Type_D", "Type_SFA", "Type_SFD", + "Layout_A", "Layout_B", "Layout_C", "Layout_D", + "Alignment_A", "Alignment_B", "Alignment_C", "Alignment_D", + "1SM/2SM", + "StreamK Enabled", "Support Runtime_Cluster_Shape", "Support Runtime_Input_Types", + "Test Counts" +] + +audit_csv_runtime_fields = [ + "KerneIndex", "KernelName", + "Inst_M", "Inst_N", "Inst_K", "Tile_M", "Tile_N", "Tile_K", + "Cluster_M", "Cluster_N", "Cluster_K", "Preferred_Cluster_M", "Preferred_Cluster_N", "Preferred_Cluster_K", "Fallback_Cluster_M", "Fallback_Cluster_N", "Fallback_Cluster_K", + "M", "N", "K", "L", "Alpha_val", "Beta_val", + "Runtime_Input_Types Enabled", "Runtime_Cluster_Shape Enabled" +] + +def hash_cutlass_string(input_string): + mma_cluster_shape_pattern = r"_\d+x\d+x\d+" # Matches MMA and Cluster shapes (e.g., '_128x128x256', '_0x0x1') + + # Remove MMA and Cluster shapes (e.g., '_128x128x256', '_0x0x1') + output = re.sub(mma_cluster_shape_pattern, "", input_string) + + return output + +def transform_hashed_string(hashed_kernel_name, runtime_datatype_a, runtime_datatype_b): + # Define a dictionary mapping the detected types to runtime values + datatype_map = { + 'f4_f4': runtime_datatype_a + '_' + runtime_datatype_b, + 'f4_f6': runtime_datatype_a + '_' + runtime_datatype_b, + 'f4_f8': runtime_datatype_a + '_' + runtime_datatype_b, + 'f6_f4': runtime_datatype_a + '_' + runtime_datatype_b, + 'f6_f6': runtime_datatype_a + '_' + runtime_datatype_b, + 'f6_f8': runtime_datatype_a + '_' + runtime_datatype_b, + 'f8_f4': runtime_datatype_a + '_' + runtime_datatype_b, + 'f8_f6': runtime_datatype_a + '_' + runtime_datatype_b, + 'f8_f8': runtime_datatype_a + '_' + runtime_datatype_b, + 'ue8m0xf4_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, + 'ue4m3xf4_ue4m3xf4': 'ue4m3x' + runtime_datatype_a + '_ue4m3x' + runtime_datatype_b, + 'ue8m0xf4_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, + 'ue8m0xf4_ue8m0xf8': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, + 'ue8m0xf6_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, + 'ue8m0xf6_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, + 'ue8m0xf8_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, + 'ue8m0xf8_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, + 'ue8m0xf8_ue8m0xf8': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, + } + + # Regular expression to detect all the keys in datatype_map + pattern = re.compile(r'(' + '|'.join(map(re.escape, datatype_map.keys())) + r')') + + # Replace detected patterns using the dictionary + updated_kernel_name = pattern.sub(lambda match: datatype_map[match.group(0)], hashed_kernel_name) + + return updated_kernel_name + +# This helper function reports foundational kernel features: datatypes, layouts, alignment and stream-k. +def get_kernel_features(operation, kernel_name, + dynamic_datatype, runtime_input_datatype): + numcta_inst = "2sm" if "2sm" in kernel_name else "1sm" + math_inst = operation.tile_description.math_instruction + + if dynamic_datatype: + dtype_name_A = runtime_input_datatype[0] + dtype_name_B = runtime_input_datatype[1] + else: + dtype_name_A = DataTypeNames[operation.A.element] + dtype_name_B = DataTypeNames[operation.B.element] + + layout_name_A = ShortLayoutTypeNames[operation.A.layout] + layout_name_B = ShortLayoutTypeNames[operation.B.layout] + layout_name_C = ShortLayoutTypeNames[operation.C.layout] + layout_name_D = ShortLayoutTypeNames[operation.D.layout] + + scale_factor_D_type = operation.ScaleFactorD.element if hasattr(operation, "ScaleFactorD") else DataType.void + scale_factor_A_type = getattr(operation, "ScaleFactorA", DataType.void) + audit_vals = [ + "BlockScaledGEMM" if math_inst.opcode_class == OpcodeClass.BlockScaledTensorOp else "GEMM", + kernel_name, + dtype_name_A, + dtype_name_B, + DataTypeNames[operation.C.element], + DataTypeNames[operation.tile_description.math_instruction.element_accumulator], + DataTypeNames[operation.element_epilogue], + DataTypeNames[operation.D.element], + DataTypeNames[scale_factor_D_type], + DataTypeNames[scale_factor_A_type], + layout_name_A, + layout_name_B, + layout_name_C, + layout_name_D, + str(operation.A.alignment), + str(operation.B.alignment), + str(operation.C.alignment), + str(operation.D.alignment), + numcta_inst, + "Y" if 'stream_k' in kernel_name else "N", + ] + return audit_vals + +# This helper function reports other performance-related kernel parameters and those can be specified at runtime: cluster_shape, instruction shap, m/n/k and alpha/beta. +def get_kernel_params(operation, kernel_name, cluster_shape, fallback_cluster_shape, problem_shape, alpha, beta, dynamic_datatype, dynamic_cluster): + math_inst = operation.tile_description.math_instruction + audit_vals = [ + str(math_inst.instruction_shape[0]), + str(math_inst.instruction_shape[1]), + str(math_inst.instruction_shape[2]), + str(operation.tile_description.threadblock_shape[0]), + str(operation.tile_description.threadblock_shape[1]), + str(operation.tile_description.threadblock_shape[2]), + str(operation.tile_description.cluster_shape[0]), + str(operation.tile_description.cluster_shape[1]), + str(operation.tile_description.cluster_shape[2]), + str(cluster_shape[0]), + str(cluster_shape[1]), + str(cluster_shape[2]), + str(fallback_cluster_shape[0]), + str(fallback_cluster_shape[1]), + str(fallback_cluster_shape[2]), + str(problem_shape[0]), + str(problem_shape[1]), + str(problem_shape[2]), + str(problem_shape[3]), + str(alpha), + str(beta), + "Y" if dynamic_datatype else "N", + "Y" if dynamic_cluster else "N", + ] + return audit_vals + + +def _getSubOperationType(kernel): + + if kernel.operation_kind == OperationKind.Gemm: + return GemmKindNames[kernel.gemm_kind] + elif kernel.operation_kind == OperationKind.Conv2d: + return "conv_" + ConvKindNames[kernel.conv_kind] + elif kernel.operation_kind == OperationKind.Syrk: + return "syrk_" + SyrkKindNames[kernel.syrk_kind] + elif kernel.operation_kind == OperationKind.Trmm: + return "trmm_" + TrmmKindNames[kernel.trmm_kind] + elif kernel.operation_kind == OperationKind.Symm: + return "symm_" + SymmKindNames[kernel.symm_kind] + else: + raise Exception("Unsupported kernel type") + +def _get_inst_shape(math_instruction): + return "".join(str(x) for x in math_instruction.instruction_shape) + +def _is_simt_inst(math_instruction): + return _get_inst_shape(math_instruction) in ["111","114"] + +def _getInstType(input_precision, accumulate_precision, math_instruction): + + # inst_shape + inst_shape = _get_inst_shape(math_instruction) + + # input precision + if input_precision == "fp32" and inst_shape != "111": + inp = "tf32" + else: + inp = input_precision + + # Handle SIMT op types first + if _is_simt_inst(math_instruction): + + simt_input_precision_to_inst = { + "fp32": "FFMA", + "fp64": "DFMA", + "fp16": "HFMA", + "int8": "IDP4A", + } + inst = simt_input_precision_to_inst[input_precision] + + else: # Tensor op instructions + + if accumulate_precision == "cf64": + fp64_acc_map = { + MathOperation.multiply_add_complex_gaussian : "gz", + MathOperation.multiply_add_complex : "z", + } + acc = fp64_acc_map[math_instruction.math_operation] + else: + tensor_op_acc_map = { + "fp32" : "s", + "cf32" : "s", + "fp16" : "h", + "int32": "i", + "fp64" : "d", + } + acc = tensor_op_acc_map[accumulate_precision] + + inst = "{}{}{}".format(acc, inst_shape, inp) + + return inst +# TODO: Computes FLOps/Bytes for GEMM - revisit for conv +def _computeFlopsPerByte(operation, m, n, k, batch_count=1, beta=0.0, num_groups=1): + assert not (batch_count > 1 and num_groups > 1) + + # TODO: adjust for sparsity + gmem_bytes = ( + (DataTypeSize[operation.A.element] * m // 8) * k + + (DataTypeSize[operation.B.element] * n // 8) * k + + (DataTypeSize[operation.C.element] * m // 8) * n + ) + + # TODO: complex-valued support + flops = 2 * (m * n * k) + + if bool(beta): + gmem_bytes += (DataTypeSize[operation.C.element] * m // 8) * n + flops += 2 * m * n + + multiplier = max(batch_count, num_groups) + gmem_bytes *= multiplier + flops *= multiplier + + return flops / gmem_bytes + +def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode + ): + # For functional testing, we prefer to run reference computing on device if any + reference_device_archs = ["100a", "103a"] + run_reference_on_device = True if arch in reference_device_archs and mode in ["functional_L0", "functional_L1"] else False + profiler_flags_for_verification = "device" if run_reference_on_device else "host" + + # beta values for L0 and L1 + # TODO: randomize beta values for wider coverage + beta_values = [0.5] + + is_supported_arch = (arch in ["100a", "100f", "101a", "101f", "103a", "110a", "110f", "120a", "120f", "121a", "121f"]) + + is_runtime_datatype_enabled = mode == "functional_L0" and is_supported_arch + + if (mode == "functional_L0") and is_supported_arch: + problem_waves = [0.5, 1.25, 2.5] + + # + # Dense Gemm + # + + sm100_mma_data_type_general = [ + 'gemm_f16_f16_f16_f16_f16', + 'gemm_f16_f16_f16_void_f16', + #'gemm_f16_f16_f32_f16_f16', + 'tf32gemm_f32_f32_f32_f32_f32', + 'bf16gemm_f32_f32_f32_f32_f32', + ] + + exclude_archs = arch not in ("103a") + if exclude_archs: + sm100_mma_data_type_general.append('gemm_s8_s8_s32_s8_s8') + + sm100_mma_data_type_runtime_dtype = [ + 'gemm.*f4_f4_f32_f32_f32', + 'gemm.*f6_f6_f32_f32_f32', + 'gemm.*f8_f8_f32_f32_f32', + ] + + sm100_mma_cluster_size = [ + '8x1x1', + '4x4x1', '2x1x1', + '0x0x1' # dynamic cluster + ] + + # Restrict to two layouts to reduce L0 build and test time. + sm100_mma_layouts = [ + 'tnt', + 'ntn' + ] + + # regex list must be in kernel procedural name order + sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*" + sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*" + + sm100_mma_filter_regex_1sm_runtime = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*" + sm100_mma_filter_regex_2sm_runtime = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*" + + # + # Block Scale Gemm + # + + block_scaled_data_type = [ + # runtime datatypes + 'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2', + 'gemm.*ue4m3xf4_ue4m3xf4_f32_f16_e5m2', + 'gemm.*ue8m0xf4_ue8m0xf6_f32_f16_e5m2', + #'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1', + 'gemm.*ue8m0xf6_ue8m0xf6_f32_f16_ue8m0xe3m2', + ] + + block_scaled_tile_k = ['x128_', 'x256_'] + + sm103_block_scaled_data_type = [ + 'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2', + 'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1', + ] + + sm103_block_scaled_tile_k = ['x768_'] + + block_scaled_cluster_size = [ + '4x4x1', '2x1x1', + '0x0x1' # dynamic cluster + ] + + block_scaled_layouts = ['tnt'] + # regex list must be in kernel procedural name order + block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_tile_k, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*" + block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_tile_k, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*" + + sm103_block_scaled_prefetch_policy = ['tmapf'] + sm103_block_scaled_filter_regex_1sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, sm103_block_scaled_tile_k, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*(" + "|".join(sm103_block_scaled_prefetch_policy) + ").*" + sm103_block_scaled_filter_regex_2sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, sm103_block_scaled_tile_k, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*(" + "|".join(sm103_block_scaled_prefetch_policy) + ").*" + + if arch in ["100a", "100f"]: + kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \ + f"({sm100_mma_filter_regex_2sm})|" \ + f"({sm100_mma_filter_regex_1sm_runtime})|" \ + f"({sm100_mma_filter_regex_2sm_runtime})|" \ + f"({block_scaled_filter_regex_1sm})|" \ + f"({block_scaled_filter_regex_2sm})" + elif arch in ["101a", "101f", "110a", "110f"]: + kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \ + f"({sm100_mma_filter_regex_2sm})|" \ + f"({sm100_mma_filter_regex_1sm_runtime})|" \ + f"({sm100_mma_filter_regex_2sm_runtime})|" \ + f"({block_scaled_filter_regex_1sm})|" \ + f"({block_scaled_filter_regex_2sm})" + elif arch in ["103a"]: + kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \ + f"({sm100_mma_filter_regex_2sm})|" \ + f"({sm100_mma_filter_regex_1sm_runtime})|" \ + f"({sm100_mma_filter_regex_2sm_runtime})|" \ + f"({block_scaled_filter_regex_1sm})|" \ + f"({block_scaled_filter_regex_2sm})|" \ + f"({sm103_block_scaled_filter_regex_1sm})|" \ + f"({sm103_block_scaled_filter_regex_2sm})" + elif arch in ["120a", "120f", "121a", "121f"]: + + # blockscaled sm120_mma kernels + blockscaled_sm120_mma_kernel_cta_tiles = [ + [ '128x128' ] + ] + + # Restrict to two layouts to reduce L0 build and test time. + blockscaled_sm120_mma_layouts = [ 'tn' ] + filter_regex_blockscaled_sm120_mma = "cutlass3x_sm120_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [blockscaled_sm120_mma_kernel_cta_tiles[0], blockscaled_sm120_mma_layouts]]) + ").*" + + problem_waves = [0.5, 1.25, 2.5] + + kernel_filter = f"({filter_regex_blockscaled_sm120_mma})" + else: + error_message = "unsupported arch, only support sm100a, sm100f, sm101a, sm101f, sm110a, sm110f, sm103a, sm120a, sm120f, sm121a, sm121f" + raise Exception(error_message) + + elif mode == "functional_L1": + sm100_mma_cluster_size = [ + '0x0x1' # dynamic cluster + ] + # Restrict to two layouts to reduce L1 build and test time. + sm100_mma_layouts = ['tnt', 'ntn'] + sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*" + sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*" + block_scaled_data_type = [ + 'ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2', + 'ue8m0xe2m1_ue8m0xe2m3_f32_f16_e5m2', + 'ue8m0xmx8s26_ue8m0xmx8s26_f32_f16_e5m2', + 'ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1', + 'ue8m0xe2m3_ue8m0xe2m3_f32_f16_ue8m0xe3m2', + ] + + sm103_block_scaled_data_type = [ + 'ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2', + 'ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1', + ] + + block_scaled_cluster_size = ['0x0x1'] + block_scaled_layouts = ['tnt'] + + # regex list must be in kernel procedural name order + block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*" + block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*" + + sm103_block_scaled_filter_regex_1sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*" + sm103_block_scaled_filter_regex_2sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*" + + filter_regex_sm100_mma = f"({sm100_mma_filter_regex_1sm})|" \ + f"({sm100_mma_filter_regex_2sm})|" \ + f"({block_scaled_filter_regex_1sm})|" \ + f"({block_scaled_filter_regex_2sm})" \ + f"({sm103_block_scaled_filter_regex_1sm})|" \ + f"({sm103_block_scaled_filter_regex_2sm})" + # CTA tiles for sm120 MMA - only run one tile size to reduce build/test times + sm120_mma_kernel_cta_tiles = [ + # h1688, s1688, i16832, i8816 + [ '256x128' ], + # d884, c1688, + [ '128x128' ], + # c1688, z884 + [ '128x64' ], + # gz884 + [ '64x64' ] + ] + + # sm120 MMA instruction shapes, planar complex type excluded as they are not required + sm120_mma_instruction_shapes = [ + [ 'h1688gemm_(?!planar_complex)', + 's1688gemm_f16', + 's1688gemm_bf16', + 's1688gemm_tf32', + 'i16832gemm', + 'i8816gemm' ], + [ 'd884gemm', 'c1688tf32gemm' ] , + [ 'c1688gemm', + 'z884gemm' ], + [ 'gz884gemm'] + ] + + # It's not pretty, but not sure why different instructions support different tile sizes. + filter_regex_sm120_mma_0 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm120_mma_instruction_shapes[0], sm120_mma_kernel_cta_tiles[0]]]) + ").*" + filter_regex_sm120_mma_1 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm120_mma_instruction_shapes[1], sm120_mma_kernel_cta_tiles[1]]]) + ").*" + filter_regex_sm120_mma_2 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm120_mma_instruction_shapes[2], sm120_mma_kernel_cta_tiles[2]]]) + ").*" + filter_regex_sm120_mma_3 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm120_mma_instruction_shapes[3], sm120_mma_kernel_cta_tiles[3]]]) + ").*" + + filter_regex_sm120_mma = f"({filter_regex_sm120_mma_0})|({filter_regex_sm120_mma_1})|({filter_regex_sm120_mma_2})|({filter_regex_sm120_mma_3})" + + problem_waves = [0.5, 1.25, 2.5] + + if arch in ["120a", "120f", "121a", "121f"]: + kernel_filter = f"({filter_regex_sm120_mma})" + else: + kernel_filter = f"({filter_regex_sm100_mma})" + else: + raise ValueError() + + outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm.csv") + + audit_file_name = os.path.join(curr_build_dir, f"FK_{mode}_audit_SM{arch}_cutlass3x_gemm.csv") + + audit_file_params_name = os.path.join(curr_build_dir, f"FK_{mode}_audit_params_SM{arch}_cutlass3x_gemm.csv") + + kernel_filter_re = re.compile(kernel_filter) + testcase_counter = 0 + kernels_emitted = 0 + kernels_total = 0 + + perf_json_list = [] + kernel_name_set = set() + + testlist_csv_fields = ["testcase", "metadata"] + testlist_csv_rows = [] + auditlist_csv_map = {} + auditlist_csv_params_map = {} + + kernel_features = {} + + for cc in manifest.operations[OperationKind.Gemm].keys(): + for kernel_name, operation_l in manifest.operations[OperationKind.Gemm][cc].items(): + assert(len(operation_l) == 1) + kernels_total += 1 + if len(kernel_filter_re.findall(kernel_name)) == 0: + continue + # Only test f16 I/O void C kernels in void C kernel set + # Exception: Use void C kernels for more accurate perf testing + if '_void_' in kernel_name and 'perf_' not in mode: + if 'f16_f16_f16_void_f16' not in kernel_name : + continue + + kernels_emitted += 1 + kernel_name_set.add(kernel_name) + hashed_kernel_name = hash_cutlass_string(kernel_name) + operation = operation_l[0] + + dynamic_cluster = (operation.tile_description.cluster_shape[0] == 0 + or operation.tile_description.cluster_shape[1] == 0) + + dynamic_datatype = "f8" in kernel_name or "f6" in kernel_name or "f4" in kernel_name + + runtime_input_datatypes = [None] + + if dynamic_datatype: + if "f4_f4" in kernel_name: + runtime_input_datatypes = [['e2m1','e2m1']] + elif "f4_f6" in kernel_name: + runtime_input_datatypes = [['e2m1','e3m2']] + elif "f4_f8" in kernel_name: + runtime_input_datatypes = [['e2m1','e4m3']] + + elif "f6_f4" in kernel_name: + runtime_input_datatypes = [['e3m2','e2m1']] + elif "f6_f6" in kernel_name: + runtime_input_datatypes = [['e3m2','e3m2']] + elif "f6_f8" in kernel_name: + runtime_input_datatypes = [['e3m2','e4m3']] + + elif "f8_f4" in kernel_name: + runtime_input_datatypes = [['e4m3','e2m1']] + elif "f8_f6" in kernel_name: + runtime_input_datatypes = [['e4m3','e3m2']] + elif "f8_f8" in kernel_name: + runtime_input_datatypes = [ + # mask out those not covered in statically encoded test cases + # ['e5m2','e4m3'], + # ['e4m3','e5m2'], + ['e4m3','e4m3'] + ] + + # block scaled kernels + elif "ue8m0xf4_ue8m0xf4" in kernel_name: + runtime_input_datatypes = [['e2m1','e2m1']] + elif "ue4m3xf4_ue4m3xf4" in kernel_name: + runtime_input_datatypes = [['e2m1','e2m1']] + elif "ue8m0xf4_ue8m0xf6" in kernel_name: + runtime_input_datatypes = [['e2m1','e2m3']] + elif "ue8m0xf4_ue8m0xf8" in kernel_name: + runtime_input_datatypes = [['e2m1','e4m3']] + + elif "ue8m0xf6_ue8m0xf4" in kernel_name: + runtime_input_datatypes = [['e2m3','e2m1']] + elif "ue8m0xf6_ue8m0xf6" in kernel_name: + runtime_input_datatypes = [['e2m3','e2m3']] + elif "ue8m0xf8_ue8m0xf4" in kernel_name: + runtime_input_datatypes = [['e4m3','e2m1']] + + elif "ue8m0xf8_ue8m0xf4" in kernel_name: + runtime_input_datatypes = [['e4m3','e2m1']] + elif "ue8m0xf8_ue8m0xf6" in kernel_name: + runtime_input_datatypes = [['e4m3','e2m3']] + elif "ue8m0xf8_ue8m0xf8" in kernel_name: + runtime_input_datatypes = [['e4m3','e4m3']] + + if "bstensorop" in kernel_name or is_blockwise(manifest.operations_by_name[kernel_name].gemm_kind): + profiler_flags_for_verification = "host" + + # reduce L1 test runtime if reference kernel is not running on device. + if mode == "functional_L1" and profiler_flags_for_verification == "host" : + problem_waves = [0.5, 2.5] + + + if dynamic_cluster: + if mode == "functional_L0": + runtime_cluster_shapes = [[1,1,1], [2,2,1]] + else: + runtime_cluster_shapes = [[1,1,1], [1,2,1], [2,1,1], [2,2,1], [1,4,1], [4,1,1], [2,4,1], [4,2,1], [4,4,1]] + # reduce L1 test runtime if reference kernel is not running on device. + if profiler_flags_for_verification == "host": + runtime_cluster_shapes = [[1,1,1], [1,2,1], [2,1,1], [2,2,1], [1,4,1], [4,1,1]] + cta_tile_shape_m, cta_tile_shape_n, cta_tile_shape_k = operation.tile_description.threadblock_shape + else: + runtime_cluster_shapes = [operation.tile_description.cluster_shape] + cta_tile_shape_m = int(operation.tile_description.threadblock_shape[0] / operation.tile_description.cluster_shape[0]) + cta_tile_shape_n = int(operation.tile_description.threadblock_shape[1] / operation.tile_description.cluster_shape[1]) + cta_tile_shape_k = int(operation.tile_description.threadblock_shape[2] / operation.tile_description.cluster_shape[2]) + + alignment_a = operation.A.alignment + alignment_b = operation.B.alignment + alignment_c = operation.C.alignment + alignment_ab_max = max(alignment_a, alignment_b) + + layout3x = operation.layout_name_3x() + data_types = operation.datatype_name_3x() + + ctas_per_mma_instruction = 1 + if '_2sm' in kernel_name: + ctas_per_mma_instruction = 2 + valid_cluster_shapes = [] + + # Remove any cluster shapes that have cluster_m that is not divisible by 2 + for cs in runtime_cluster_shapes: + if cs[0] % 2 == 0: + valid_cluster_shapes.append(cs) + runtime_cluster_shapes = valid_cluster_shapes + + kernel_problem_waves = problem_waves + if mode == "functional_L0" or mode == "functional_L1": + # for functional testing, we want to perturb just a little from even shapes + # large K = 8 is chosen such that some kernels will warp around their smem buffers, and some will not + # -16 ensures that we are TMA aligned even for FP8/Int8 + min_k = alignment_ab_max if cta_tile_shape_k == alignment_ab_max else cta_tile_shape_k - alignment_ab_max + max_k = (cta_tile_shape_k*8) - alignment_ab_max + problem_shapes_k = [min_k, max_k] + sm_count = 16 + swizzle_sizes = [0] + # Larger k and less than half wave trigger streamk +separate reduction case to be generated + if 'stream_k' in kernel_name: + problem_shapes_k = [max_k, cta_tile_shape_k*32] + kernel_problem_waves = [0.125, 1.25, 2.5] + else: + raise ValueError + + if "void" in kernel_name: + beta_values = [0] + + alignment_shift_m = max(alignment_c, alignment_a) + alignment_shift_n = max(alignment_c, alignment_b) + + is_first_line = True + for index_waves, waves in enumerate(kernel_problem_waves): + for index_k, k in enumerate(problem_shapes_k): + for beta in beta_values: + for cluster_shape in runtime_cluster_shapes: + for runtime_input_datatype in runtime_input_datatypes: + for swizzle_size in swizzle_sizes: + grid_size = waves * sm_count + cluster_shape_m, cluster_shape_n, cluster_shape_k = tuple(cluster_shape) + if cluster_shape_m >= cluster_shape_n: + grid_m = cluster_shape_m + grid_n = grid_size / grid_m + grid_n = max( int((grid_n + cluster_shape_n - 1) / cluster_shape_n) * cluster_shape_n, 1) + else: + grid_n = cluster_shape_n + grid_m = grid_size / grid_n + grid_m = max( int((grid_m + cluster_shape_m - 1) / cluster_shape_m) * cluster_shape_m, 1) + + verification_required = False + if mode == "functional_L0" or mode == "functional_L1": + if '_void_' not in kernel_name: + verification_required = True + + m = max(int(grid_m * cta_tile_shape_m), alignment_ab_max) + n = max(int(grid_n * cta_tile_shape_n), alignment_ab_max) + k = int(k) + + # For functional testing, we want to perturb just a little from even shapes. + # Only do this if the perturbation does not cause one of the dimensions of the + # problem size to go to zero. This can occur for blockscaling kernels for which + # the alignment requirements for A and B can be quite large (e.g., 256). + if m > alignment_shift_m: + m -= alignment_shift_m + if n > alignment_shift_n: + n -= alignment_shift_n + + if '_n32t32_' in kernel_name: + continue + batch_count = 1 + if mode == "functional_L0" or mode == "functional_L1" : + if index_waves == 0 and index_k == 0 : + batch_count = 3 if mode == "functional_L0" else 5 + gemm_op = "gemm" + + grouped = is_grouped(manifest.operations_by_name[kernel_name].gemm_kind) + num_groups = 1 + if grouped: + gemm_op = "grouped_gemm" + num_groups = 3 # small to limit test time in host block-scaled reference kernels + batch_count = 1 + elif "bstensorop" in kernel_name: + gemm_op = "block_scaled_gemm" + elif is_blockwise(manifest.operations_by_name[kernel_name].gemm_kind): + gemm_op = "blockwise_gemm" + + problem_size_category = ['smallK','largeK'][index_k] + '_' + ['beta==0','beta!=0'][bool(beta)] + + assert m > 0 and n > 0 and k > 0 + + # Emit per-testcase metadata for perf testing usage, eventually in perf database + metadata_dict = { + "input_params": { + 'problem_size_category' : problem_size_category, + 'operation' : _getSubOperationType(operation), + 'datatype' : data_types, + 'layout' : layout3x, + 'm' : m, + 'n' : n, + 'k' : k, + 'beta' : beta, + 'flops_per_byte' : _computeFlopsPerByte(operation, m, n, k, batch_count, beta, num_groups) + }, + "runtime_params": { + 'ctas_per_mma_instruction' : ctas_per_mma_instruction, + 'tilesize_m' : cta_tile_shape_m, + 'tilesize_n' : cta_tile_shape_n, + 'tilesize_k' : cta_tile_shape_k, + 'cluster_shape_m' : cluster_shape_m, + 'cluster_shape_n' : cluster_shape_n, + } + } + + cluster_m_fallback = ctas_per_mma_instruction if dynamic_cluster else cluster_shape_m + cluster_n_fallback = 1 if dynamic_cluster else cluster_shape_n + cluster_k_fallback = 1 if dynamic_cluster else cluster_shape_k + + + if dynamic_datatype: + runtime_datatype_a, runtime_datatype_b = tuple(runtime_input_datatype) + metadata_dict["runtime_params"]["runtime_datatype_a"] = runtime_datatype_a + metadata_dict["runtime_params"]["runtime_datatype_b"] = runtime_datatype_b + + testcase_metadata = [ + f"cutlass_profiler --operation={gemm_op}" + + (f" --verification-providers=device --providers=cutlass" if profiler_flags_for_verification == "device" else " --mode=trace") + + f" --error-on-no-match --error-if-nothing-is-profiled" + + f" --kernels={kernel_name}" + + f" --m={str(m)}" + + f" --n={str(n)}" + + f" --k={str(k)}" + + (f" --num_groups={str(num_groups)}" if grouped else "") + + f" --cluster_m={str(cluster_shape_m)}" + + f" --cluster_n={str(cluster_shape_n)}" + + f" --cluster_k={str(cluster_shape_k)}" + + f" --cluster_m_fallback={str(cluster_m_fallback)}" + + f" --cluster_n_fallback={str(cluster_n_fallback)}" + + f" --cluster_k_fallback={str(cluster_k_fallback)}" + + f" --beta={str(beta)}" + + ("" if grouped else f" --batch_count={str(batch_count)}") + + f" --swizzle_size={str(swizzle_size)}" + + f" --verification-required={str(verification_required).lower()}" + ] \ + + output_dynamic_datatype = dynamic_datatype + if output_dynamic_datatype: + testcase_metadata[0] += (f" --runtime_input_datatype_a={runtime_datatype_a}" + + f" --runtime_input_datatype_b={runtime_datatype_b}") + + testcase_metadata.append(json.dumps(metadata_dict)) + testlist_csv_rows.append(testcase_metadata) + testcase_counter += 1 + + alpha = 1.0 + + if dynamic_datatype: + hashed_kernel_name = transform_hashed_string(hashed_kernel_name, runtime_datatype_a, runtime_datatype_b) + + # If kernel_name is new, initialize its feature set with defaults + if hashed_kernel_name not in kernel_features: + kernel_features[hashed_kernel_name] = { + "is_support_dynamic_cluster": False, + "is_support_dynamic_datatype": False, + } + + # Update features for the hashed kernel name + kernel_features[hashed_kernel_name]["is_support_dynamic_cluster"] |= dynamic_cluster + kernel_features[hashed_kernel_name]["is_support_dynamic_datatype"] |= dynamic_datatype + + if hashed_kernel_name not in auditlist_csv_params_map: + auditlist_csv_params_map[hashed_kernel_name] = [] + + audit_row_params = get_kernel_params( + operation, + hashed_kernel_name, + (cluster_shape_m, cluster_shape_n, cluster_shape_k), + (cluster_m_fallback, cluster_n_fallback, cluster_k_fallback), + (m, n, k, batch_count), + alpha, beta, + dynamic_datatype, dynamic_cluster + ) + + auditlist_csv_params_map[hashed_kernel_name].append(audit_row_params) + + if hashed_kernel_name not in auditlist_csv_map: + audit_row = get_kernel_features(operation, hashed_kernel_name, dynamic_datatype, runtime_input_datatype) + auditlist_csv_map[hashed_kernel_name] = audit_row + + with open(outfile_name, 'w') as testlist_csv: + csv_writer = csv.writer(testlist_csv, delimiter=',') + csv_writer.writerow(testlist_csv_fields) + csv_writer.writerows(testlist_csv_rows) + + with open(audit_file_name, 'w') as auditlist_csv: + csv_writer = csv.writer(auditlist_csv, delimiter=',') + csv_writer.writerow(audit_csv_fields) + for hashed_kernel_name, row in auditlist_csv_map.items(): + # Append the dynamic features as "Y" or "N" + dynamic_cluster_flag = "Y" if kernel_features[hashed_kernel_name]["is_support_dynamic_cluster"] else "N" + dynamic_datatype_flag = "Y" if kernel_features[hashed_kernel_name]["is_support_dynamic_datatype"] else "N" + test_count = len(auditlist_csv_params_map[hashed_kernel_name]) + csv_writer.writerow(row + [dynamic_cluster_flag, dynamic_datatype_flag, test_count]) + + with open(audit_file_params_name, 'w') as auditlist_csv: + csv_writer = csv.writer(auditlist_csv, delimiter=',') + csv_writer.writerow(audit_csv_runtime_fields) + for kernel_index, (hashed_kernel_name, rows) in enumerate(auditlist_csv_params_map.items(), start=1): + for i, row in enumerate(rows): + if i == 0: + csv_writer.writerow([kernel_index, hashed_kernel_name] + row) + else: + csv_writer.writerow(["", ""] + row) + + print(f"Generated a total of {testcase_counter} test cases for {kernels_emitted} kernels out of {kernels_total} total.") + + # Generate a newline separated list of kernel filters + assert(len(kernel_name_set) == kernels_emitted) + output_filter_enabled = True + if output_filter_enabled: + kernel_filter_outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm_kernel_filter.list") + with open(kernel_filter_outfile_name, "w") as file: + kernel_name_set = set(map(lambda x: x.replace("_epi_tma", ""), kernel_name_set)) + for kernel_name in kernel_name_set: + file.write(kernel_name + "\n") + + # Sort L0 and L1 kernel list and csv file to avoid mixing cutlass3.x kernels and sm120_mma kernels in cutlass2.x generated together. + if mode == "functional_L0" or mode == "functional_L1": + # Sort the .csv file + outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm.csv") + with open(outfile_name) as file: + data = file.readlines() + data.sort() + with open(outfile_name, 'w') as file: + for i in range(len(data)): + file.write(data[i]) + # Sort the kernel list + kernel_filter_outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm_kernel_filter.list") + with open(kernel_filter_outfile_name) as file: + data = file.readlines() + data.sort() + with open(kernel_filter_outfile_name, 'w') as file: + for i in range(len(data)): + file.write(data[i]) + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/gemm_operation.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/gemm_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..0d2449e769303b738212cdcd896c9f2793ca2632 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/gemm_operation.py @@ -0,0 +1,1613 @@ + +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for emitting GEMM kernels +""" + +import collections +import enum +import functools +import logging +import operator +import os.path +import shutil + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * +except ImportError: + from library import * + +_LOGGER = logging.getLogger(__name__) + +################################################################################################### +# +# Data structure modeling a GEMM operation +# +################################################################################################### + +# +class GemmOperation: + # + def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \ + epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, D = None, + kernel_schedule = KernelScheduleType.ScheduleAuto, epilogue_schedule = EpilogueScheduleType.ScheduleAuto, + tile_scheduler = TileSchedulerType.Default, mixed_input_mode = None, mixed_input_shuffle = False, + ScaleFactorA = None, ScaleFactorB = None, ScaleFactorD = None, + ScaleFactorMVecSize = None, ScaleFactorNVecSize = None, ScaleFactorKVecSize = None): + + kinds_3x = { + GemmKind.Universal3x, + GemmKind.SparseUniversal3x, + GemmKind.BlockScaledUniversal3x, + GemmKind.GroupedUniversal3x, + GemmKind.GroupedBlockScaledUniversal3x, + GemmKind.BlockwiseUniversal3x, + GemmKind.GroupedBlockwiseUniversal3x, + } + self.is_3x = gemm_kind in kinds_3x + self.prefix = "3x" if self.is_3x else "" + self.operation_kind = OperationKind.Gemm + self.arch = arch + self.tile_description = tile_description + self.gemm_kind = gemm_kind + self.A = A + self.B = B + self.C = C + self.D = D + + if is_block_scaled(gemm_kind): + self.ScaleFactorA = ScaleFactorA + self.ScaleFactorB = ScaleFactorB + self.ScaleFactorD = ScaleFactorD["tensor"] + self.ScaleFactorVectorSize = ScaleFactorD["vector_size"] + + if is_blockwise(gemm_kind): + self.ScaleFactorMVecSize = ScaleFactorMVecSize + self.ScaleFactorNVecSize = ScaleFactorNVecSize + self.ScaleFactorKVecSize = ScaleFactorKVecSize + + if self.D == None: + self.D = self.C + + if not self.is_3x: + assert(kernel_schedule == KernelScheduleType.ScheduleAuto) + assert(epilogue_schedule == EpilogueScheduleType.ScheduleAuto) + self.kernel_schedule = kernel_schedule + self.epilogue_schedule = epilogue_schedule + self.element_epilogue = element_epilogue + self.epilogue_functor = epilogue_functor + + if self.is_3x and epilogue_functor == EpilogueFunctor.LinearCombination: + self.epilogue_functor = EpilogueFunctor3x.LinearCombination + + self.swizzling_functor = swizzling_functor + self.tile_scheduler = tile_scheduler + + # Only enable mixed input mode and mixed input shuffle for Hopper + self.mixed_input_mode = None + if self.is_mixed_input() and self.arch >= 90 and self.arch < 100: + self.mixed_input_mode = mixed_input_mode + self.mixed_input_shuffle = (self.mixed_input_mode is not None) and mixed_input_shuffle + + # + def is_complex(self): + complex_operators = [ + MathOperation.multiply_add_complex, + MathOperation.multiply_add_complex_gaussian, + MathOperation.multiply_add_complex_fast_f32 + ] + return self.tile_description.math_instruction.math_operation in complex_operators + + # + def is_mixed_input(self): + return self.A.element != self.B.element + + # + def is_planar_complex(self): + return self.gemm_kind in (GemmKind.PlanarComplex, GemmKind.PlanarComplexArray) + + # + def accumulator_type(self): + accum = self.tile_description.math_instruction.element_accumulator + + if self.is_complex(): + return get_complex_from_real(accum) + + return accum + + # + def short_math_name(self): + if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian: + return "g%s" % ShortDataTypeNames[self.accumulator_type()] + return ShortDataTypeNames[self.accumulator_type()] + + + # + def core_name(self): + ''' The basic operation kind is prefixed with a letter indicating the accumulation type. ''' + + inst_shape = '' + inst_operation = '' + intermediate_type = '' + + math_operations_map = { + MathOperation.xor_popc: 'xor', + MathOperation.and_popc: 'and', + MathOperation.multiply_add_fast_accum: 'fastaccum', + } + + tensor_ops = [ + OpcodeClass.TensorOp, + OpcodeClass.WmmaTensorOp, + OpcodeClass.SparseTensorOp, + OpcodeClass.BlockScaledTensorOp, + ] + + is_tensor_op = self.tile_description.math_instruction.opcode_class in tensor_ops + + if is_tensor_op: + + math_op = self.tile_description.math_instruction.math_operation + math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else '' + + inst_shape = "{0}{1}{2}".format(*tuple(self.tile_description.math_instruction.instruction_shape)) if not self.is_3x else "" + + inst_shape += math_op_string + + if self.tile_description.math_instruction.element_a != self.A.element and \ + self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator: + intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] + + short_math_name = self.short_math_name() if not self.is_3x else "" + + return "%s%s%s%s" % (short_math_name, inst_shape, intermediate_type, GemmKindNames[self.gemm_kind]) + + # Generates a string representing the MMA instruction. + def extended_name(self): + ''' Append data types if they differ from compute type. ''' + element_sfa = "" + element_sfb = "" + if self.is_complex(): + extended_name = "${core_name}" + else: + if self.is_mixed_input(): + extended_name = "${core_name}_${element_a}_${element_b}" + if self.C.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${element_c}_" + extended_name + elif is_blockwise(self.gemm_kind): + extended_name = "${core_name}_${element_sfa}x${element_a}_${element_sfb}x${element_b}" + element_sfa = DataTypeNames[self.accumulator_type()] + element_sfb = DataTypeNames[self.accumulator_type()] + else: + extended_name = "${core_name}" + if self.C.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${element_c}_" + extended_name + if self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name += "_${element_a}" + + extended_name = SubstituteTemplate(extended_name, { + 'element_a': DataTypeNames[self.A.element], + 'element_sfa' : element_sfa, + 'element_b': DataTypeNames[self.B.element], + 'element_sfb' : element_sfb, + 'element_c': DataTypeNames[self.C.element], + 'core_name': self.core_name() + }) + + return extended_name + + # + def mixed_input_mode_name(self): + mode_name_mapping = { + MixedInputMode.ConvertOnly: "_cvt", + MixedInputMode.ScaleOnly: "_scl", + MixedInputMode.ScaleWithZeroPoint: "_sclzr" + } + mode_name = mode_name_mapping.get(self.mixed_input_mode, "") + if self.mixed_input_shuffle: + mode_name = mode_name + "_shfl" + return mode_name + + def extended_name_3x(self): + '''Generates a string representing the MMA atom. Assumes accumulator type is C type.''' + extended_name = "{core_name}_{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format( + element_a = DataTypeNames[self.A.element], + element_b = DataTypeNames[self.B.element], + element_acc = DataTypeNames[self.accumulator_type()], + element_c = DataTypeNames[self.C.element], + element_d = DataTypeNames[self.D.element], + core_name = self.core_name()) + + if is_block_scaled(self.gemm_kind): + d_type_names = DataTypeNames[self.D.element] + + if self.ScaleFactorD.element != DataType.void: + d_type_names = DataTypeNames[self.ScaleFactorD.element] + "x" + d_type_names + + extended_name = "{core_name}_{element_sfa}x{element_a}_{element_sfb}x{element_b}_{element_acc}_{element_c}_{element_d}".format( + element_sfa = DataTypeNames[self.ScaleFactorA], + element_a = DataTypeNames[self.A.element], + element_sfb = DataTypeNames[self.ScaleFactorB], + element_b = DataTypeNames[self.B.element], + element_acc = DataTypeNames[self.accumulator_type()], + element_c = DataTypeNames[self.C.element], + element_d = d_type_names, + core_name = self.core_name()) + + if is_blockwise(self.gemm_kind): + d_type_names = DataTypeNames[self.D.element] + + extended_name = "{core_name}_{sfvec_m_size}x{sfvec_k_size}{element_sfa}x{element_a}_{sfvec_n_size}x{sfvec_k_size}{element_sfb}x{element_b}_{element_acc}_{element_c}_{element_d}".format( + element_sfa = DataTypeNames[self.accumulator_type()], + element_a = DataTypeNames[self.A.element], + element_sfb = DataTypeNames[self.accumulator_type()], + element_b = DataTypeNames[self.B.element], + element_acc = DataTypeNames[self.accumulator_type()], + element_c = DataTypeNames[self.C.element], + element_d = d_type_names, + sfvec_m_size = self.ScaleFactorMVecSize, + sfvec_n_size = self.ScaleFactorNVecSize, + sfvec_k_size = self.ScaleFactorKVecSize, + core_name = self.core_name()) + + if self.mixed_input_mode != None: + extended_name = extended_name + self.mixed_input_mode_name() + return extended_name + + def datatype_name_3x(self): + '''Generates a string representing the MMA atom. Assumes accumulator type is C type.''' + datatype_name = "{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format( + element_a = DataTypeNames[self.A.element], + element_b = DataTypeNames[self.B.element], + element_acc = DataTypeNames[self.accumulator_type()], + element_c = DataTypeNames[self.C.element], + element_d = DataTypeNames[self.D.element]) + return datatype_name + + # Generates a short string representing the AB layout tags (e.g. nt or tn) + def layout_name(self): + if self.is_complex() or self.is_planar_complex(): + return "%s%s" % ( + ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)], + ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)] + ) + return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout]) + + # Generates a short string representing the ABC layout tags (e.g. ntn or tnn) + def layout_name_3x(self): + if self.is_complex() or self.is_planar_complex(): + return "{}{}{}".format( + ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)], + ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)], + ShortComplexLayoutNames[(self.C.layout, self.C.complex_transform)]) + else: + return "{}{}{}".format( + ShortLayoutTypeNames[self.A.layout], + ShortLayoutTypeNames[self.B.layout], + ShortLayoutTypeNames[self.C.layout]) + + # Generates a short string representing underlying kernel schedule type + def kernel_schedule_name_3x(self): + return KernelScheduleSuffixes[self.kernel_schedule] + + # Generates a short string representing underlying epilogue schedule type + def epilogue_schedule_name_3x(self): + + if is_block_scaled(self.gemm_kind): + if self.ScaleFactorD.element != DataType.void: + return EpilogueScheduleSuffixes[self.epilogue_schedule] + "_epiVs" + str(self.ScaleFactorVectorSize)+ShortLayoutTypeNames[self.ScaleFactorD.layout] + + return EpilogueScheduleSuffixes[self.epilogue_schedule] + + # Generate a short string representing the operation class + def opcode_class_name(self): + return OpcodeClassNames[self.tile_description.math_instruction.opcode_class] + + def get_collective_tile_shape(self): + """ + Get the tile shape passed to the collective builder. + On Blackwell, this is different than the operation.tile_description.tile_shape. + """ + is_sm100_kernel = (self.arch == 100 or self.arch == 103) + if not is_sm100_kernel: + return self.tile_description.tile_shape + + opcode_class_main = self.tile_description.math_instruction.opcode_class + instruction_shape = self.tile_description.math_instruction.instruction_shape + tile_shape_m, tile_shape_n, tile_shape_k = self.tile_description.tile_shape + if opcode_class_main in [OpcodeClass.TensorOp, OpcodeClass.BlockScaledTensorOp, OpcodeClass.SparseTensorOp]: + tile_shape_m = instruction_shape[0] + tile_shape_n = instruction_shape[1] + return (tile_shape_m, tile_shape_n, tile_shape_k) + + # Generates the full kernel function name + def procedural_name(self): + return self._procedural_name + + @functools.cached_property + def _procedural_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] + if self.arch >= 90: + kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}{ct}{cs}_{l}_{s}_align{al}{t}{k}{e}" + tile_shape = self.get_collective_tile_shape() + return kernel_name_template.format( + p = self.prefix, + ar = self.arch, + op = opcode_class_name, + ex = self.extended_name_3x(), + ct = '_' + 'x'.join([str(i) for i in tile_shape]) if tile_shape[0] > 0 else "", + cs = '_' + 'x'.join([str(i) for i in self.tile_description.cluster_shape]), + l = self.tile_description.stages, + s = self.layout_name_3x(), + al = str(max(self.A.alignment, self.B.alignment)), + t = TileSchedulerSuffixes[self.tile_scheduler], + k = self.kernel_schedule_name_3x(), + e = self.epilogue_schedule_name_3x()) + else: + threadblock = self.tile_description.procedural_name() + return "cutlass{p}_{op}_{ex}_{tb}_{l}_align{a}".format( + p = self.prefix, + op = opcode_class_name, + ex = self.extended_name(), + tb = threadblock, + l = self.layout_name(), + a = str(max(self.A.alignment, self.B.alignment))) + + # + def configuration_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + return self.procedural_name() + + def __hash__(self): + return hash(self.configuration_name()) + + def __eq__(self, other): + return self.configuration_name() == other.configuration_name() + +################################################################################################### +# +# Data structure modeling a grouped GEMM operation +# +################################################################################################### + +# +class GroupedGemmOperation(GemmOperation): + # + def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \ + epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, \ + scheduler_mode = GroupScheduleMode.Device): + super().__init__(gemm_kind, arch, tile_description, A, B, C, element_epilogue, \ + epilogue_functor, swizzling_functor) + + self.scheduler_mode = scheduler_mode + + # + def procedural_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + base = super().procedural_name() + return SubstituteTemplate( + base + "_schedule${schedule}", + { + 'schedule': ShortGroupScheduleModeNames[self.scheduler_mode] + }) + + +################################################################################################### +# +# Emits single instances of a CUTLASS device-wide operator +# +################################################################################################### + +# +class EmitGemmInstance: + ''' Responsible for emitting a CUTLASS template definition''' + + def __init__(self, operation_suffix = ''): + self.operation_suffix = operation_suffix + self.includes = [] + self.gemm_template = """ + // Gemm operator ${operation_name} + using Operation_${operation_name} = cutlass::gemm::device::Gemm< + ${element_a}, ${layout_a}, + ${element_b}, ${layout_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, + ${stages}, + ${align_a}, + ${align_b}, + false, + ${math_operation} + ${residual} + >; +""" + self.gemm_complex_template = """ + // Gemm operator ${operation_name} + using Operation_${operation_name} = cutlass::gemm::device::GemmComplex< + ${element_a}, ${layout_a}, + ${element_b}, ${layout_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, + ${stages}, + ${transform_a}, + ${transform_b}, + ${math_operation} + ${residual} + >; +""" + + # + def instance_template(self): + return """ +${compile_guard_start} + manifest.append(new ${gemm_kind}("${operation_name}")); +${compile_guard_end} +""" + + # + def emit(self, operation): + + warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)] + + epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) + + residual = '' + + values = { + 'operation_name': operation.procedural_name(), + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[operation.A.layout], + 'element_b': DataTypeTag[operation.B.element], + 'layout_b': LayoutTag[operation.B.layout], + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'element_accumulator': DataTypeTag[operation.accumulator_type()], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'epilogue_vector_length': str(epilogue_vector_length), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], + 'stages': str(operation.tile_description.stages), + 'align_a': str(operation.A.alignment), + 'align_b': str(operation.B.alignment), + 'transform_a': ComplexTransformTag[operation.A.complex_transform], + 'transform_b': ComplexTransformTag[operation.B.complex_transform], + 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], + 'residual': residual + } + + template = self.gemm_complex_template if operation.is_complex() else self.gemm_template + + return SubstituteTemplate(template, values) + +################################################################################################### + +class EmitSparseGemmInstance: + ''' Responsible for emitting a CUTLASS template definition''' + + def __init__(self, operation_suffix = ''): + self.operation_suffix = operation_suffix + self.includes = [] + self.gemm_template = """ + // Gemm operator ${operation_name} + using Operation_${operation_name} = cutlass::gemm::device::SparseGemm< + ${element_a}, ${layout_a}, + ${element_b}, ${layout_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, + ${stages}, + ${align_a}, + ${align_b}, + false, + ${math_operation} + ${residual} + >; +""" + + # + def instance_template(self): + return """ +${compile_guard_start} + manifest.append(new ${gemm_kind}("${operation_name}")); +${compile_guard_end} +""" + + # + def emit(self, operation): + + warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)] + + epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) + + residual = '' + + values = { + 'operation_name': operation.procedural_name(), + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[operation.A.layout], + 'element_b': DataTypeTag[operation.B.element], + 'layout_b': LayoutTag[operation.B.layout], + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'element_accumulator': DataTypeTag[operation.accumulator_type()], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'epilogue_vector_length': str(epilogue_vector_length), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], + 'stages': str(operation.tile_description.stages), + 'align_a': str(operation.A.alignment), + 'align_b': str(operation.B.alignment), + 'transform_a': ComplexTransformTag[operation.A.complex_transform], + 'transform_b': ComplexTransformTag[operation.B.complex_transform], + 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], + 'residual': residual + } + + template = self.gemm_template + + return SubstituteTemplate(template, values) + +################################################################################################### + + +# +class EmitGemmUniversalInstance: + ''' Responsible for emitting a CUTLASS template definition''' + + def __init__(self, operation_suffix = ''): + self.operation_suffix = operation_suffix + self.includes = [ + "cutlass/cutlass.h", + "cutlass/numeric_types.h", + "cutlass/arch/arch.h", + "cutlass/arch/mma.h", + "cutlass/layout/matrix.h", + "cutlass/gemm/device/gemm.h", + "cutlass/gemm/device/gemm_universal_adapter.h", + "cutlass/gemm/kernel/default_gemm_universal.h", + ] + self.builtin_epilogue_functor_template = """ + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + > +""" + self.gemm_template = """ +// Gemm operator ${operation_name} +using ${operation_name}_base = + typename cutlass::gemm::kernel::DefaultGemmUniversal< + ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, // transposed B operand + ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, // transposed A operand + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}, + ${swizzling_functor}, + ${stages}, + ${math_operation} +>::GemmKernel; + +// Define named type +struct ${operation_name}${operation_suffix} : + public ${operation_name}_base { }; +""" + self.gemm_template_interleaved = """ +// Gemm operator ${operation_name} +using ${operation_name}_base = + typename cutlass::gemm::kernel::DefaultGemmUniversal< + ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, + ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}, + ${swizzling_functor}, + ${stages}, + ${math_operation} +>::GemmKernel; + +// Define named type +struct ${operation_name}${operation_suffix} : + public ${operation_name}_base { }; +""" + + # + def instance_template(self): + return """ +${compile_guard_start} + manifest.append(new ${gemm_kind}< + cutlass::gemm::device::GemmUniversalAdapter<${operation_name}> + >("${operation_name}")); +${compile_guard_end} +""" + + # + def emit(self, operation): + + threadblock_shape = operation.tile_description.threadblock_shape + warp_count = operation.tile_description.warp_count + + warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)] + + transpose_layouts = { + LayoutType.ColumnMajor: LayoutType.RowMajor, + LayoutType.RowMajor: LayoutType.ColumnMajor + } + + if operation.A.layout in transpose_layouts.keys() and \ + operation.B.layout in transpose_layouts.keys() and \ + operation.C.layout in transpose_layouts.keys(): + + instance_layout_A = transpose_layouts[operation.A.layout] + instance_layout_B = transpose_layouts[operation.B.layout] + instance_layout_C = transpose_layouts[operation.C.layout] + + gemm_template = self.gemm_template + else: + instance_layout_A, instance_layout_B, instance_layout_C = \ + (operation.A.layout, operation.B.layout, operation.C.layout) + + gemm_template = self.gemm_template_interleaved + # + + # Support built-in epilogue functors or user-defined functions + if isinstance(operation.epilogue_functor, enum.Enum): + + epilogue_vector_length = \ + min(operation.C.alignment * DataTypeSize[operation.C.element], 128) // DataTypeSize[operation.C.element] + + values = { + 'epilogue_vector_length': str(epilogue_vector_length), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + } + epilogue_functor = SubstituteTemplate(self.builtin_epilogue_functor_template, values) + else: + epilogue_functor = self.epilogue_functor.emit_declaration() + # + + values = { + 'operation_name': operation.procedural_name(), + 'operation_suffix': self.operation_suffix, + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[instance_layout_A], + 'element_b': DataTypeTag[operation.B.element], + 'layout_b': LayoutTag[instance_layout_B], + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[instance_layout_C], + 'element_accumulator': DataTypeTag[operation.accumulator_type()], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'epilogue_functor': epilogue_functor, + 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], + 'stages': str(operation.tile_description.stages), + 'align_a': str(operation.A.alignment), + 'align_b': str(operation.B.alignment), + 'transform_a': ComplexTransformTag[operation.A.complex_transform], + 'transform_b': ComplexTransformTag[operation.B.complex_transform], + 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation] + } + + return SubstituteTemplate(gemm_template, values) + + +################################################################################################### + +class EmitGemmUniversal3xInstance: + ''' Responsible for emitting a CUTLASS 3.x template definition''' + + def __init__(self, operation_suffix = ''): + self.operation_suffix = operation_suffix + self.includes = [ + "cutlass/cutlass.h", + "cutlass/gemm/gemm.h", + "cutlass/numeric_types.h", + "cutlass/gemm/kernel/gemm_universal.hpp", + "cutlass/gemm/collective/collective_builder.hpp", + "cutlass/epilogue/collective/collective_builder.hpp", + "cutlass/detail/blockwise_scale_layout.hpp", + ] + self.builtin_epilogue_functor_template = \ +"""${epilogue_functor}< + ${element_d}, + ${element_epilogue}, + ${element_c}, + ${element_epilogue} + >""" + + self.gemm_template = """ + +using ${operation_name}_epilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ${arch}, ${opcode_class_epi}, + cute::Shape, + cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>, + ${epi_tile_mn}, + ${element_accumulator}, ${element_epilogue}, + ${element_c}, ${layout_c}, ${align_c}, + ${element_d}, ${layout_d}, ${align_d}, + ${epilogue_schedule}, + ${epilogue_functor} + >::CollectiveOp; + +${mixed_dtype_prepare_code} +${blockwise_prepare_code} + +using ${operation_name}_mainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ${arch}, ${opcode_class_main}, + ${element_a}, ${layout_a}, ${align_a}, + ${element_b}, ${layout_b}, ${align_b}, + ${element_accumulator}, + cute::Shape, + cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>, + ${stages}, + ${kernel_schedule} + >::CollectiveOp; + +// Gemm operator ${operation_name} +using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal< + ${problem_shape}, + ${operation_name}_mainloop, + ${operation_name}_epilogue, + ${tile_scheduler}>; + +// Define named type +struct ${operation_name} : + public ${operation_name}_base { }; + +""" + # + def instance_template(self): + return """ +${compile_guard_start} + { + using GemmKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>; + manifest.append( + new ${gemm_kind}("${operation_name}")); + } +${compile_guard_end} +""" + + + def emit_block_scale_epilogue_functor(self, operation): + block_scaled_template = """ + ${epilogue_functor}< + ${epi_vs}, + ${element_d}, + ${element_accumulator}, + ${element_sfd}, + ${layout_sfd}, + ${element_c}, + ${element_scalar} + > + """ + block_scaled_values = { + 'epi_vs' : str(operation.ScaleFactorVectorSize), + 'element_d': str(DataTypeTag[operation.D.element]), + 'element_sfd': str(DataTypeTag[operation.ScaleFactorD.element]), + 'layout_sfd': LayoutTag[operation.ScaleFactorD.layout], + 'epilogue_functor': EpilogueFunctor3xTag[EpilogueFunctor3x.LinearCombinationBlockScaleFactor], + 'element_accumulator': str(DataTypeTag[operation.accumulator_type()]), + 'element_scalar': str(DataTypeTag[operation.accumulator_type()]), + 'element_c': str(DataTypeTag[operation.C.element]), + } + return SubstituteTemplate(block_scaled_template, block_scaled_values) + + + @staticmethod + def pointerize_if_grouped(operation, layout): + return layout if not is_grouped(operation.gemm_kind) else layout + "* " + + @staticmethod + def transform_layout_A_if_blockwise(operation, layout): + layout_sfa = f"{operation.procedural_name()}_LayoutSFA" + layout_sfa = layout_sfa if not is_grouped(operation.gemm_kind) else layout_sfa + "* " + return layout if not is_blockwise(operation.gemm_kind) else f"cute::tuple<{layout}, {layout_sfa}>" + + @staticmethod + def transform_layout_B_if_blockwise(operation, layout): + layout_sfb = f"{operation.procedural_name()}_LayoutSFB" + layout_sfb = layout_sfb if not is_grouped(operation.gemm_kind) else layout_sfb + "* " + return layout if not is_blockwise(operation.gemm_kind) else f"cute::tuple<{layout}, {layout_sfb}>" + + @staticmethod + def problem_shape(operation): + gemm_shape_type = "cute::Shape" + grouped_gemm_shape_type = "cute::Shape" + grouped_gemm_shape_type = "cutlass::gemm::GroupProblemShape<" + grouped_gemm_shape_type + ">" + + return gemm_shape_type if not is_grouped(operation.gemm_kind) else grouped_gemm_shape_type + + def emit(self, operation): + _LOGGER.debug("*** EmitGemmConfigurationLibrary::emit(operation)") + _LOGGER.debug("*** operation.procedural_name(): " + operation.procedural_name()) + _LOGGER.debug("*** tile_shape: " + str(operation.tile_description.tile_shape)) + _LOGGER.debug("*** warp_count: " + str(operation.tile_description.warp_count)) + + opcode_class_main = operation.tile_description.math_instruction.opcode_class + opcode_class_epi = opcode_class_main + + tile_shape = operation.tile_description.tile_shape + instruction_shape = operation.tile_description.math_instruction.instruction_shape + cluster_m = operation.tile_description.cluster_shape[0] + cluster_n = operation.tile_description.cluster_shape[1] + cta_n = tile_shape[1] // cluster_n if cluster_n > 0 else tile_shape[1] + tile_shape_m, tile_shape_n, tile_shape_k = operation.get_collective_tile_shape() + + # stage count set to zero indicates builder automatic stage selection + if operation.tile_description.stages > 0: + stage_count_string = f"cutlass::gemm::collective::StageCount<{str(operation.tile_description.stages)}>" + elif opcode_class_main == OpcodeClass.SparseTensorOp and operation.arch == 100: + stage_count_string = f"cutlass::gemm::collective::StageCountAutoCarveoutEpi<{str(operation.procedural_name())}_epilogue>" + else: + stage_count_string = f"cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename {str(operation.procedural_name())}_epilogue::SharedStorage))>" + + epi_tile_mn = "cutlass::epilogue::collective::EpilogueTileAuto" + + instance_layout_A, instance_layout_B, instance_layout_C , instance_layout_D = \ + (operation.A.layout, operation.B.layout, operation.C.layout, operation.D.layout) + + # 3.0 profiler integration only supports trivial epilogues for now + epilogue_vector_length = 1 + + # Support built-in epilogue functors or user-defined functions + if isinstance(operation.epilogue_functor, enum.Enum): + values = { + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'epilogue_functor': EpilogueFunctor3xTag[operation.epilogue_functor], + } + epilogue_functor = SubstituteTemplate(self.builtin_epilogue_functor_template, values) + + if is_block_scaled(operation.gemm_kind) and operation.ScaleFactorD.element != DataType.void: + epilogue_functor = self.emit_block_scale_epilogue_functor(operation) + + + else: + epilogue_functor = self.epilogue_functor.emit_declaration() + + if is_block_scaled(operation.gemm_kind) and operation.ScaleFactorD.element != DataType.void: + epilogue_functor = self.emit_block_scale_epilogue_functor(operation) + + # + # Cutlass3x complex kernels' ElementA(B) is a tuple in collective mainloop builder, e.g. cute::tuple, Transform : cute::identity / cute::conjugate. + element_a = DataTypeTag[operation.A.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.A.element])},{str(ComplexTransformTag3x[operation.A.complex_transform])}>" + element_b = DataTypeTag[operation.B.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.B.element])},{str(ComplexTransformTag3x[operation.B.complex_transform])}>" + epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule] + + if opcode_class_main == OpcodeClass.BlockScaledTensorOp: + grouped = is_grouped(operation.gemm_kind) + if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, grouped): + epi_tile_mn = "cute::Shape" + if is_tma_epilogue(operation.epilogue_schedule): + epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)] + if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, grouped): + epi_tile_mn = "cute::Shape" + if is_tma_epilogue(operation.epilogue_schedule): + epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped)] + # SM103 FP4 Ultra + is_sm103_fp4_ultra_1sm_kernel_schedule = operation.kernel_schedule in [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch, grouped) + ] + is_sm103_fp4_ultra_2sm_kernel_schedule = operation.kernel_schedule in [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch, grouped) + ] + if cta_n == 256 and is_sm103_fp4_ultra_1sm_kernel_schedule: + epi_tile_mn = "cute::Shape" + if is_tma_epilogue(operation.epilogue_schedule): + epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)] + if cta_n == 256 and is_sm103_fp4_ultra_2sm_kernel_schedule: + epi_tile_mn = "cute::Shape" + if is_tma_epilogue(operation.epilogue_schedule): + epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped)] + + element_a = f'cute::tuple<{str(element_a)},{str(DataTypeTag[operation.ScaleFactorA])}>' + element_b = f'cute::tuple<{str(element_b)},{str(DataTypeTag[operation.ScaleFactorB])}>' + + alignment_c = get_tma_alignment(operation.C.element) \ + if is_tma_epilogue(operation.epilogue_schedule) and opcode_class_epi != OpcodeClass.Simt \ + else operation.C.alignment + alignment_d = get_tma_alignment(operation.D.element) \ + if is_tma_epilogue(operation.epilogue_schedule) and opcode_class_epi != OpcodeClass.Simt \ + else operation.D.alignment + + operation_name_str = operation.procedural_name() + layout_a_str = LayoutTag[instance_layout_A] + layout_b_str = LayoutTag[instance_layout_B] + mixed_dtype_prepare_code = "" + if operation.mixed_input_mode != None: + A_dtype = operation.A.element + B_dtype = operation.B.element + A_dtype_bits = DataTypeSize[A_dtype] + B_dtype_bits = DataTypeSize[B_dtype] + is_A_dtype_narrow = A_dtype_bits < B_dtype_bits + if is_A_dtype_narrow: + narrow_dtype, wide_dtype = (A_dtype, B_dtype) + narrow_dtype_bits, wide_dtype_bits = (A_dtype_bits, B_dtype_bits) + else: + narrow_dtype, wide_dtype = (B_dtype, A_dtype) + narrow_dtype_bits, wide_dtype_bits = (B_dtype_bits, A_dtype_bits) + + narrow_tag = DataTypeTag[narrow_dtype] + wide_tag = DataTypeTag[wide_dtype] + scale_tag = DataTypeTag[wide_dtype] + zero_tag = DataTypeTag[wide_dtype] + + do_shuffle = False + value_shuffle_str = "" + if narrow_dtype_bits == 4 and wide_dtype_bits == 16: + value_shuffle_str = "cute::Layout, cute::Stride>" + do_shuffle = True + if narrow_dtype_bits == 8 and wide_dtype_bits == 16: + value_shuffle_str = "cute::Layout, cute::Stride>" + do_shuffle = True + do_shuffle = operation.mixed_input_shuffle and do_shuffle + + if do_shuffle: + if is_A_dtype_narrow: + stride_narrow_str = f"cutlass::detail::TagToStrideA_t<{layout_a_str}>" + layout_a_str = f"{operation_name_str}_LayoutNarrowReordered" + else: + stride_narrow_str = f"cutlass::detail::TagToStrideB_t<{layout_b_str}>" + layout_b_str = f"{operation_name_str}_LayoutNarrowReordered" + # The {operation_name_str}_ prefixs in mixed_dtype_prepare_code and + # layout_{a, b}_str are to prevent errors in Windows platform unity build + mixed_dtype_prepare_code = f""" +using {operation_name_str}_StrideNarrow = {stride_narrow_str}; +using {operation_name_str}_ValueShuffle = {value_shuffle_str}; +static constexpr int {operation_name_str}_NumShuffleAtoms = 1; +using {operation_name_str}_MmaAtomShape = cute::Layout>>; +using {operation_name_str}_LayoutAtomQuant = decltype(cutlass::compute_memory_reordering_atom<{wide_tag}, {operation_name_str}_MmaAtomShape, {operation_name_str}_ValueShuffle>()); +using {operation_name_str}_LayoutNarrowReordered = decltype(cute::tile_to_shape({operation_name_str}_LayoutAtomQuant{{}}, cute::Layout, {operation_name_str}_StrideNarrow>{{}})); + """ + + mixed_input_modes_to_element = { + MixedInputMode.ConvertOnly: narrow_tag, + MixedInputMode.ScaleOnly: f"cute::tuple<{narrow_tag}, {scale_tag}>", + MixedInputMode.ScaleWithZeroPoint: f"cute::tuple<{narrow_tag}, {scale_tag}, {zero_tag}>" + } + narrow_element = mixed_input_modes_to_element.get(operation.mixed_input_mode, narrow_tag) + + if narrow_dtype == DataType.s4 and (wide_dtype == DataType.e4m3 or wide_dtype == DataType.e5m2): + narrow_element = f"cute::tuple<{narrow_tag}, cutlass::Array<{scale_tag}, 8>>" + + if is_A_dtype_narrow: + element_a = narrow_element + else: + element_b = narrow_element + + blockwise_prepare_code = "" + if is_blockwise(operation.gemm_kind): + sfm_vec_size = operation.ScaleFactorMVecSize + sfn_vec_size = operation.ScaleFactorNVecSize + sfk_vec_size = operation.ScaleFactorKVecSize + blockwise_prepare_code = f""" +using {operation_name_str}_ScaleConfig = cutlass::detail::Sm{operation.arch}BlockwiseScaleConfig<{sfm_vec_size}, {sfn_vec_size}, {sfk_vec_size}>; +using {operation_name_str}_LayoutSFA = decltype({operation_name_str}_ScaleConfig::deduce_layoutSFA()); +using {operation_name_str}_LayoutSFB = decltype({operation_name_str}_ScaleConfig::deduce_layoutSFB()); + """ + + values = { + 'operation_name': operation_name_str, + 'operation_suffix': self.operation_suffix, + 'problem_shape': self.problem_shape(operation), + 'element_a': element_a, + 'layout_a': self.transform_layout_A_if_blockwise(operation, self.pointerize_if_grouped(operation, layout_a_str)), + 'element_b': element_b, + 'layout_b': self.transform_layout_B_if_blockwise(operation, self.pointerize_if_grouped(operation, layout_b_str)), + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': self.pointerize_if_grouped(operation, LayoutTag[instance_layout_C]), + 'element_d': DataTypeTag[operation.D.element], + 'layout_d': self.pointerize_if_grouped(operation, LayoutTag[instance_layout_D]), + 'element_accumulator': DataTypeTag[operation.accumulator_type()], + 'opcode_class_main': OpcodeClassTag[opcode_class_main], + 'opcode_class_epi': OpcodeClassTag[opcode_class_epi], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'tile_shape_m': str(tile_shape_m), + 'tile_shape_n': str(tile_shape_n), + 'tile_shape_k': str(tile_shape_k), + 'cluster_shape_m': 'cute::_' + str(operation.tile_description.cluster_shape[0]) if operation.tile_description.cluster_shape[0] > 0 else "int", + 'cluster_shape_n': 'cute::_' + str(operation.tile_description.cluster_shape[1]) if operation.tile_description.cluster_shape[1] > 0 else "int", + 'cluster_shape_k': 'cute::_' + str(operation.tile_description.cluster_shape[2]) if operation.tile_description.cluster_shape[2] > 0 else "int", + 'instruction_shape_m': str(instruction_shape[0]), + 'instruction_shape_n': str(instruction_shape[1]), + 'instruction_shape_k': str(instruction_shape[2]), + 'kernel_schedule' : str(KernelScheduleTag[operation.kernel_schedule]), + 'epilogue_schedule' : str(epilogue_schedule_type), + 'epi_tile_mn' : epi_tile_mn, + 'epilogue_functor': epilogue_functor, + 'stages': stage_count_string, + 'align_a': str(operation.A.alignment), + 'align_b': str(operation.B.alignment), + 'align_c': str(alignment_c), + 'align_d': str(alignment_d), + 'transform_a': ComplexTransformTag[operation.A.complex_transform], + 'transform_b': ComplexTransformTag[operation.B.complex_transform], + 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], + 'epilogue_vector_length': str(epilogue_vector_length), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'tile_scheduler': str(TileSchedulerTag[operation.tile_scheduler]), + 'mixed_dtype_prepare_code': mixed_dtype_prepare_code, + 'blockwise_prepare_code' : blockwise_prepare_code + } + + return SubstituteTemplate(self.gemm_template, values) + +################################################################################################### + +# +class EmitGemmPlanarComplexInstance: + ''' Responsible for emitting a CUTLASS template definition''' + + def __init__(self, operation_suffix = ''): + self.operation_suffix = operation_suffix + self.includes = [] + self.template = """ + // Gemm operator ${operation_name} + using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< + ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a}, + ${element_b}, ${layout_b}, ${transform_b}, ${alignment_b}, + ${element_c}, cutlass::layout::RowMajor, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + cutlass::epilogue::thread::LinearCombinationPlanarComplex< + ${element_c}, + ${alignment_c}, + ${element_accumulator}, + ${element_epilogue} + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + ${stages}, + ${math_operator} + >::GemmKernel; + + struct ${operation_name} : + public Operation_${operation_name} { }; +""" + + # + def instance_template(self): + return """ +${compile_guard_start} + manifest.append(new ${gemm_kind}< + cutlass::gemm::device::GemmUniversalAdapter<${operation_name}> + >("${operation_name}")); +${compile_guard_end} +""" + + # + def emit(self, operation): + + warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)] + + # exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major + transposed_layout_A = TransposedLayout[operation.A.layout] + transposed_layout_B = TransposedLayout[operation.B.layout] + + values = { + 'operation_name': operation.procedural_name(), + 'element_a': DataTypeTag[operation.B.element], + 'layout_a': LayoutTag[transposed_layout_B], + 'transform_a': ComplexTransformTag[operation.B.complex_transform], + 'alignment_a': str(operation.B.alignment), + 'element_b': DataTypeTag[operation.A.element], + 'layout_b': LayoutTag[transposed_layout_A], + 'transform_b': ComplexTransformTag[operation.A.complex_transform], + 'alignment_b': str(operation.A.alignment), + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'alignment_c': str(operation.C.alignment), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'stages': str(operation.tile_description.stages), + 'math_operator': 'cutlass::arch::OpMultiplyAdd' + } + + return SubstituteTemplate(self.template, values) + +################################################################################################### + +# +class EmitGemmPlanarComplexArrayInstance: + ''' Responsible for emitting a CUTLASS template definition''' + + def __init__(self, operation_suffix = ''): + self.operation_suffix = operation_suffix + self.includes = [] + self.template = """ + // Gemm operator ${operation_name} + using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< + ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a}, + ${element_b}, ${layout_b}, ${transform_b}, ${alignment_b}, + ${element_c}, cutlass::layout::RowMajor, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + cutlass::epilogue::thread::LinearCombinationPlanarComplex< + ${element_c}, + ${alignment_c}, + ${element_accumulator}, + ${element_epilogue} + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + ${stages}, + ${math_operator} + >::GemmArrayKernel; + + struct ${operation_name} : public Operation_${operation_name} { }; +""" + + # + def instance_template(self): + return """ +${compile_guard_start} + manifest.append(new ${gemm_kind}< + cutlass::gemm::device::GemmUniversalAdapter<${operation_name}> + >("${operation_name}")); +${compile_guard_end} +""" + + # + def emit(self, operation): + + warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)] + + # exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major + transposed_layout_A = TransposedLayout[operation.A.layout] + transposed_layout_B = TransposedLayout[operation.B.layout] + + values = { + 'operation_name': operation.procedural_name(), + 'element_a': DataTypeTag[operation.B.element], + 'layout_a': LayoutTag[transposed_layout_B], + 'transform_a': ComplexTransformTag[operation.B.complex_transform], + 'alignment_a': str(operation.B.alignment), + 'element_b': DataTypeTag[operation.A.element], + 'layout_b': LayoutTag[transposed_layout_A], + 'transform_b': ComplexTransformTag[operation.A.complex_transform], + 'alignment_b': str(operation.A.alignment), + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'alignment_c': str(operation.C.alignment), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'stages': str(operation.tile_description.stages), + 'math_operator': 'cutlass::arch::OpMultiplyAdd' + } + + return SubstituteTemplate(self.template, values) + +################################################################################################### + +# +class EmitGemmGroupedInstance: + ''' Responsible for emitting a CUTLASS template definition''' + + def __init__(self, operation_suffix = ''): + self.operation_suffix = operation_suffix + self.includes = [ + "cutlass/cutlass.h", + "cutlass/numeric_types.h", + "cutlass/arch/arch.h", + "cutlass/arch/mma.h", + "cutlass/layout/matrix.h", + "cutlass/gemm/device/gemm.h", + "cutlass/gemm/kernel/gemm_grouped.h", + "cutlass/gemm/kernel/default_gemm_grouped.h", + "cutlass/gemm/device/gemm_grouped.h" + ] + self.builtin_epilogue_functor_template = \ +"""${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >""" + + self.gemm_template = """ +// Gemm operator ${operation_name} +using ${operation_name}_base = + typename cutlass::gemm::kernel::DefaultGemmGrouped< + ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, + ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}, + ${swizzling_functor}, + ${stages}, + ${scheduler_mode}, + ${math_operation} +>::GemmKernel; + +// Define named type +struct ${operation_name}${operation_suffix} : + public ${operation_name}_base { }; +""" + + # + def instance_template(self): + return """ +${compile_guard_start} + manifest.append(new ${gemm_kind}< + cutlass::gemm::device::GemmGrouped<${operation_name}> + >("${operation_name}")); +${compile_guard_end} +""" + + # + def emit(self, operation): + + threadblock_shape = operation.tile_description.threadblock_shape + warp_count = operation.tile_description.warp_count + + warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)] + + transpose_layouts = { + LayoutType.ColumnMajor: LayoutType.RowMajor, + LayoutType.RowMajor: LayoutType.ColumnMajor + } + + instance_layout_A, instance_layout_B, instance_layout_C = \ + (operation.A.layout, operation.B.layout, operation.C.layout) + # + + # Support built-in epilogue functors or user-defined functions + if isinstance(operation.epilogue_functor, enum.Enum): + + epilogue_vector_length = \ + min(operation.C.alignment * DataTypeSize[operation.C.element], 128) // DataTypeSize[operation.C.element] + + values = { + 'epilogue_vector_length': str(epilogue_vector_length), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + } + epilogue_functor = SubstituteTemplate(self.builtin_epilogue_functor_template, values) + else: + epilogue_functor = self.epilogue_functor.emit_declaration() + # + + values = { + 'operation_name': operation.procedural_name(), + 'operation_suffix': self.operation_suffix, + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[instance_layout_A], + 'element_b': DataTypeTag[operation.B.element], + 'layout_b': LayoutTag[instance_layout_B], + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[instance_layout_C], + 'element_accumulator': DataTypeTag[operation.accumulator_type()], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'epilogue_functor': epilogue_functor, + 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], + 'stages': str(operation.tile_description.stages), + 'align_a': str(operation.A.alignment), + 'align_b': str(operation.B.alignment), + 'transform_a': ComplexTransformTag[operation.A.complex_transform], + 'transform_b': ComplexTransformTag[operation.B.complex_transform], + 'scheduler_mode': GroupScheduleModeTag[operation.scheduler_mode], + 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation] + } + + return SubstituteTemplate(self.gemm_template, values) + +################################################################################################### +# +# Emitters functions for all targets +# +################################################################################################### + +class EmitGemmConfigurationLibrary: + def __init__(self, operation_path, configuration_name): + self.configuration_name = configuration_name + self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/') + + self.instance_emitter = { + GemmKind.Gemm: EmitGemmInstance, + GemmKind.Sparse: EmitSparseGemmInstance, + GemmKind.Universal: EmitGemmUniversalInstance, + GemmKind.Universal3x: EmitGemmUniversal3xInstance, + GemmKind.SparseUniversal3x: EmitGemmUniversal3xInstance, + GemmKind.BlockScaledUniversal3x: EmitGemmUniversal3xInstance, + GemmKind.PlanarComplex: EmitGemmPlanarComplexInstance, + GemmKind.PlanarComplexArray: EmitGemmPlanarComplexArrayInstance, + GemmKind.Grouped: EmitGemmGroupedInstance, + GemmKind.GroupedUniversal3x: EmitGemmUniversal3xInstance, + GemmKind.GroupedBlockScaledUniversal3x: EmitGemmUniversal3xInstance, + GemmKind.BlockwiseUniversal3x: EmitGemmUniversal3xInstance, + GemmKind.GroupedBlockwiseUniversal3x: EmitGemmUniversal3xInstance, + } + + self.gemm_kind_wrappers = { + GemmKind.Gemm: 'GemmOperation', + GemmKind.Sparse: 'GemmSparseOperation', + GemmKind.Universal: 'GemmUniversalOperation', + GemmKind.Universal3x: 'GemmUniversal3xOperation', + GemmKind.SparseUniversal3x: 'SparseGemmUniversal3xOperation', + GemmKind.BlockScaledUniversal3x: 'BlockScaledGemmUniversal3xOperation', + GemmKind.PlanarComplex: 'GemmPlanarComplexOperation', + GemmKind.PlanarComplexArray: 'GemmPlanarComplexArrayOperation', + GemmKind.Grouped: 'GemmGroupedOperation', + GemmKind.GroupedUniversal3x: 'GroupedGemmUniversal3xOperation', + GemmKind.GroupedBlockScaledUniversal3x: 'GroupedBlockScaledGemmUniversal3xOperation', + GemmKind.BlockwiseUniversal3x: 'BlockwiseGemmUniversal3xOperation', + GemmKind.GroupedBlockwiseUniversal3x: 'GroupedBlockwiseGemmUniversal3xOperation', + } + + self.wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)" + + self.separator = """ +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + self.header_template = """ +/* + Generated by gemm_operation.py - Do not edit. +*/ +""" + + self.initialize_function_template = """ + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +void initialize_${configuration_name}(Manifest &manifest) { + +""" + self.epilogue_template = """ + +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + def __enter__(self): + _LOGGER.debug("*** EmitGemmConfigurationLibrary::__enter__") + _LOGGER.debug("*** configuration_path (file to write): " + + str(self.configuration_path)) + + self.configuration_file = open(self.configuration_path, "w") + self.configuration_file.write(self.header_template) + self.configuration_file.write(self.separator) + + self.includes = collections.OrderedDict([ + ("cutlass/cutlass.h", None), + ("cutlass/library/library.h", None), + ("cutlass/library/manifest.h", None), + ("library_internal.h", None), + ("gemm_operation.h", None), + ("gemm_operation_3x.hpp", None), + ("grouped_gemm_operation_3x.hpp", None), + ("sparse_gemm_operation_3x.hpp", None), + ("block_scaled_gemm_operation_3x.hpp", None), + ("blockwise_gemm_operation_3x.hpp", None), + ("cutlass/arch/wmma.h", None), + ("cutlass/numeric_types.h", None) + ]) + self.instance_definitions = [] + self.instance_wrappers = [] + + self.operations = [] + return self + + def emit(self, operation): + _LOGGER.debug("*** EmitGemmConfigurationLibrary::emit(operation)") + _LOGGER.debug("*** operation.gemm_kind: " + str(operation.gemm_kind)) + + emitter = self.instance_emitter[operation.gemm_kind]() + + for incl in emitter.includes: + self.includes[incl] = None + + self.operations.append(operation) + + self.instance_definitions.append(emitter.emit(operation)) + + self.instance_wrappers.append(SubstituteTemplate(emitter.instance_template(), { + 'configuration_name': self.configuration_name, + 'operation_name': operation.procedural_name(), + 'gemm_kind': self.gemm_kind_wrappers[operation.gemm_kind], + 'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \ + if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "", + 'compile_guard_end': "#endif" \ + if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "" + })) + + def __exit__(self, exception_type, exception_value, traceback): + + # Write includes + for incl, _ in self.includes.items(): + include_statement = "#include \"%s\"\n" % incl + self.configuration_file.write(include_statement) + + self.configuration_file.write(self.separator) + + # Write instance definitions in top-level namespace + for instance_definition in self.instance_definitions: + self.configuration_file.write(instance_definition) + + # Add wrapper objects within initialize() function + self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, { + 'configuration_name': self.configuration_name + })) + + for instance_wrapper in self.instance_wrappers: + self.configuration_file.write(instance_wrapper) + + self.configuration_file.write(self.epilogue_template) + self.configuration_file.close() + +################################################################################################### +################################################################################################### diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/generator.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/generator.py new file mode 100644 index 0000000000000000000000000000000000000000..063e8fb1caa6626e8ba099133fee4dd3dc115e40 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/generator.py @@ -0,0 +1,10962 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for enumerating CUTLASS library kernels +""" + +import argparse +import enum +from itertools import chain, product +import logging +import os.path +import shutil +import sys +import copy +from typing import Any, Dict, Optional, Sequence, Tuple + +_LOGGER = logging.getLogger(__name__) + +def logging_prefix(indent_level: int = 0) -> str: + """String prefix for start of each debug log entry""" + prefix = '*** ' + indent = ' ' + return f"{prefix}{indent_level * indent}" + +def log_debug_line(line: str, indent_level: int = 0) -> None: + """Log one line of debug output""" + prefix = logging_prefix(indent_level) + _LOGGER.debug(prefix + line) + +# Certain usecases of cutlass_library nearly always prefer to run as scripts with +# relative imports, rather than via an installed Python package. An example of this +# is using CUTLASS's CMake system to generate a library of kernels to be profiled. +# To make it easy to use these use cases when an existing installation of cutlass_library +# exists, this global flag can be set to true (via command-line arguments) to ensure +# that package-based installations are not used. + +# Create a temporary argument parser to check only for the availability of the +# --disable-cutlass-package-imports argument, which controls whether package-based +# imports are disabled. +def _add_package_disablement_flag(argparser): + argparser.add_argument("--disable-cutlass-package-imports", action='store_true', required=False, + help="Disable use of cutlass_library from Python package") + +_parser = argparse.ArgumentParser() +_add_package_disablement_flag(_parser) +_args, _ = _parser.parse_known_args() + +# Add `CUTLASS_IGNORE_PACKAGE` to `builtins` so that it is visible for gating future +# imports without requiring importing another module. Ideally, we would just place this +# as a global variable in a module to that could be imported and checked (e.g., +# utils.CUTLASS_IGNORE_PACKAGE). However, this raises the issue of determining +# where this module should be sourced (from the cutlass_library package or from +# a relative import), which is the problem this variable is being used to solve in the +# first place. +import builtins +builtins.CUTLASS_IGNORE_PACKAGE = _args.disable_cutlass_package_imports + +try: + if CUTLASS_IGNORE_PACKAGE: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * + from cutlass_library.manifest import * + from cutlass_library.heuristics import * + from cutlass_library.emit_kernel_listing import emit_gemm_kernel_testlist +except ImportError: + from library import * + from manifest import * + from heuristics import * + from emit_kernel_listing import emit_gemm_kernel_testlist +################################################################################################### + +# +def CudaToolkitVersionSatisfies(semantic_ver_string, major, minor, patch = 0): + + # by default, use the latest CUDA Toolkit version + cuda_version = [11, 0, 132] + + # Update cuda_version based on parsed string + if semantic_ver_string != '': + for i, x in enumerate([int(x) for x in semantic_ver_string.split('.')[:3]]): + if i < len(cuda_version): + cuda_version[i] = x + else: + cuda_version.append(x) + return cuda_version >= [major, minor, patch] + +# From cuda 13.0, Thor SM is renumbered from 101 to 110 +def ThorSMRenumbering(cuda_version): + return 110 if CudaToolkitVersionSatisfies(cuda_version, 13, 0) else 101 + +################################################################################################### +################################################################################################### + +# +def EpilogueAlignment(max_alignment, tile, epilogue_steps = 8): + ''' Helper to compute the maximum alignment of the epilogue ''' + + def product(X, identity = 1): + result = identity + for item in X: + result *= item + return result + + elements_per_thread = product(tile.threadblock_shape[:-1]) // product(tile.warp_count) // 32 // epilogue_steps + return min(max_alignment, elements_per_thread) + +def DefaultSwizzlingFunctor(): + return SwizzlingFunctor.Identity8 + # To use StreamK decomposition for basic GEMMs, set `swizzling_functor = SwizzlingFunctor.StreamK` + +# +def CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, \ + alignment_constraints, complex_transforms = None, epilogue_functor = EpilogueFunctor.LinearCombination, \ + swizzling_functor = DefaultSwizzlingFunctor()): + + if complex_transforms is None: + complex_transforms = [(ComplexTransform.none, ComplexTransform.none),] + + element_a, element_b, element_c, element_epilogue = data_type + + operations = [] + + # by default, only generate the largest tile and largest alignment + if manifest.kernel_filter == '': + tile_descriptions = [tile_descriptions[0],] + alignment_constraints = [alignment_constraints[0],] + + for layout in layouts: + for tile_description in tile_descriptions: + for alignment in alignment_constraints: + for complex_transform in complex_transforms: + + # If alignment is a tuple or a list, then we have different alignments for A and B + alignment_a = alignment if isinstance(alignment, int) else alignment[0] + alignment_b = alignment if isinstance(alignment, int) else alignment[1] + alignment_c = min(8, alignment_a) if isinstance(alignment, int) else alignment[2] + + A = TensorDescription(element_a, layout[0], alignment_a, complex_transform[0]) + B = TensorDescription(element_b, layout[1], alignment_b, complex_transform[1]) + C = TensorDescription(element_c, layout[2], alignment_c) + + new_operation = GemmOperation(GemmKind.Universal, tile_description.minimum_compute_capability, \ + tile_description, A, B, C, element_epilogue, epilogue_functor, swizzling_functor) + + manifest.append(new_operation) + operations.append(new_operation) + + return operations + +# Generates 3.0 API based GemmUniversal API kernels. Alignment constraints are folded in with layouts +def CreateGemmUniversal3xOperator( + manifest, layouts, tile_descriptions, data_types, + schedules = [[KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto]], + complex_transforms=None, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=SwizzlingFunctor.Identity1, + tile_schedulers=[TileSchedulerType.Default], + gemm_kind=GemmKind.Universal3x): + + if type(data_types) is dict: + data_types = [data_types] + + for s in schedules: + assert(len(s) == 2) + + if complex_transforms is None: + complex_transforms = [(ComplexTransform.none, ComplexTransform.none), ] + + operations = [] + + # by default, only generate the largest tile and largest alignment + if manifest.kernel_filter == '': + if len(tile_descriptions) == 0: + return operations + tile_descriptions = [tile_descriptions[0]] + + combinations = product(layouts, tile_descriptions, data_types, complex_transforms, schedules, tile_schedulers) + for layout, tile_description, data_type, complex_transform, schedules, tile_scheduler in combinations: + kernel_schedule, epilogue_schedule = schedules + A = TensorDescription( + data_type["a_type"], layout[0][0], layout[0][1], complex_transform[0]) + B = TensorDescription( + data_type["b_type"], layout[1][0], layout[1][1], complex_transform[1]) + + C = TensorDescription(data_type["c_type"], layout[2][0], layout[2][1]) + D = TensorDescription(data_type["d_type"], layout[2][0], layout[2][1]) + + gemm_op_extra_args = {} + element_compute = data_type.get("epi_type", data_type["acc_type"]) + + if "sf_type" in data_type: + gemm_op_extra_args["ScaleFactorA"] = data_type["sf_type"] + gemm_op_extra_args["ScaleFactorB"] = data_type["sf_type"] + gemm_op_extra_args["ScaleFactorD"] = { "tensor": TensorDescription(data_type["sfd_type"]["type"], data_type["sfd_type"]["layout"]), + "vector_size" : data_type["sfd_type"]["vector_size"]} + assert is_block_scaled(gemm_kind) + + if tile_description.explicit_vector_sizes != None: + assert len(tile_description.explicit_vector_sizes) == 3 + gemm_op_extra_args["ScaleFactorMVecSize"] = tile_description.explicit_vector_sizes[0] + gemm_op_extra_args["ScaleFactorNVecSize"] = tile_description.explicit_vector_sizes[1] + gemm_op_extra_args["ScaleFactorKVecSize"] = tile_description.explicit_vector_sizes[2] + assert is_blockwise(gemm_kind) + else: + assert not is_blockwise(gemm_kind) + + A_dtype = data_type["a_type"] + B_dtype = data_type["b_type"] + A_dtype_bits = DataTypeSize[A_dtype] + B_dtype_bits = DataTypeSize[B_dtype] + is_A_dtype_narrow = A_dtype_bits < B_dtype_bits + if is_A_dtype_narrow: + narrow_dtype, wide_dtype = (A_dtype, B_dtype) + narrow_dtype_bits, wide_dtype_bits = (A_dtype_bits, B_dtype_bits) + else: + narrow_dtype, wide_dtype = (B_dtype, A_dtype) + narrow_dtype_bits, wide_dtype_bits = (B_dtype_bits, A_dtype_bits) + + mixed_input_modes = [None] + if narrow_dtype_bits != wide_dtype_bits: + if narrow_dtype == DataType.s4 and (wide_dtype == DataType.e4m3 or wide_dtype == DataType.e5m2): + mixed_input_modes = [MixedInputMode.ScaleOnly] + else: + mixed_input_modes = [MixedInputMode.ConvertOnly, MixedInputMode.ScaleOnly, MixedInputMode.ScaleWithZeroPoint] + + mixed_input_shuffle_options = [False] + if (mixed_input_modes[0] is not None) and (wide_dtype_bits == 16) and (narrow_dtype_bits == 4 or narrow_dtype_bits == 8): + mixed_input_shuffle_options = [False, True] + + for mixed_input_mode, mixed_input_shuffle in product(mixed_input_modes, mixed_input_shuffle_options): + operation = GemmOperation( + gemm_kind, tile_description.minimum_compute_capability, + tile_description, A, B, C, element_compute, epilogue_functor, swizzling_functor, D, + kernel_schedule, epilogue_schedule, tile_scheduler, + mixed_input_mode=mixed_input_mode, mixed_input_shuffle=mixed_input_shuffle, **gemm_op_extra_args) + manifest.append(operation) + operations.append(operation) + + return operations + +# Generates 3.0 API based GemmUniversal API kernels. Alignment constraints are folded in with layouts +def CreateSparseGemmUniversal3xOperator( + manifest, layouts, tile_descriptions, data_types, + schedules = [[KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto]], + complex_transforms=None, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=SwizzlingFunctor.Identity1, + tile_schedulers=[TileSchedulerType.Default]): + + if type(data_types) is dict: + data_types = [data_types] + + for s in schedules: + assert(len(s) == 2) + + if complex_transforms is None: + complex_transforms = [(ComplexTransform.none, ComplexTransform.none), ] + + operations = [] + + # by default, only generate the largest tile and largest alignment + if manifest.kernel_filter == '': + tile_descriptions = [tile_descriptions[0]] + + combinations = product(layouts, tile_descriptions, data_types, complex_transforms, schedules, tile_schedulers) + for layout, tile_description, data_type, complex_transform, schedules, tile_scheduler in combinations: + kernel_schedule, epilogue_schedule = schedules + A = TensorDescription( + data_type["a_type"], layout[0][0], layout[0][1], complex_transform[0]) + B = TensorDescription( + data_type["b_type"], layout[1][0], layout[1][1], complex_transform[1]) + + # Currently assume tensor C/D have same layout requirement. + C = TensorDescription(data_type["c_type"], layout[2][0], layout[2][1]) + D = TensorDescription(data_type["d_type"], layout[2][0], layout[2][1]) + + element_compute = data_type.get("epi_type", data_type["acc_type"]) + + operation = GemmOperation( + GemmKind.SparseUniversal3x, tile_description.minimum_compute_capability, + tile_description, A, B, C, element_compute, epilogue_functor, swizzling_functor, D, + kernel_schedule, epilogue_schedule, tile_scheduler) + + manifest.append(operation) + operations.append(operation) + + return operations + +# +def CreateSparseGemmOperator(manifest, layouts, tile_descriptions, data_type, \ + alignment_constraints, complex_transforms = None, epilogue_functor = EpilogueFunctor.LinearCombination, \ + swizzling_functor = SwizzlingFunctor.Identity8): + + if complex_transforms is None: + complex_transforms = [(ComplexTransform.none, ComplexTransform.none),] + + element_a, element_b, element_c, element_epilogue = data_type + + gemm_kinds = [GemmKind.Sparse] + + operations = [] + + # by default, only generate the largest tile and largest alignment + if manifest.kernel_filter == '': + tile_descriptions = [tile_descriptions[0],] + alignment_constraints = [alignment_constraints[0],] + + for layout in layouts: + for tile_description in tile_descriptions: + for alignment in alignment_constraints: + for complex_transform in complex_transforms: + + alignment_c = min(8, alignment) + + A = TensorDescription(element_a, layout[0], alignment, complex_transform[0]) + B = TensorDescription(element_b, layout[1], alignment, complex_transform[1]) + C = TensorDescription(element_c, layout[2], alignment_c) + + new_operation = GemmOperation(GemmKind.Sparse, tile_description.minimum_compute_capability, \ + tile_description, A, B, C, element_epilogue, epilogue_functor, swizzling_functor) + + manifest.append(new_operation) + operations.append(new_operation) + + return operations + +# +def CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, data_type, \ + alignment_constraints, complex_transforms): + + if complex_transforms is None: + complex_transforms = [(ComplexTransform.none, ComplexTransform.none),] + + element_a, element_b, element_c, element_epilogue = data_type + + gemm_kinds = [GemmKind.PlanarComplex, GemmKind.PlanarComplexArray] + + # by default, only generate the largest tile and largest alignment + if manifest.kernel_filter == '': + tile_descriptions = [tile_descriptions[0],] + alignment_constraints = [alignment_constraints[0],] + + for gemm_kind in gemm_kinds: + for layout in layouts: + for tile_description in tile_descriptions: + for alignment in alignment_constraints: + for complex_transform in complex_transforms: + + alignment_c = min(8, alignment) + + A = TensorDescription(element_a, layout[0], alignment, complex_transform[0]) + B = TensorDescription(element_b, layout[1], alignment, complex_transform[1]) + C = TensorDescription(element_c, layout[2], alignment_c) + + manifest.append(GemmOperation(gemm_kind, \ + tile_description.minimum_compute_capability, \ + tile_description, A, B, C, element_epilogue)) + return + +# +def CreateGemmGroupedOperator(manifest, layouts, tile_descriptions, data_type, \ + alignment_constraints, complex_transforms = None, epilogue_functor = EpilogueFunctor.LinearCombination, \ + swizzling_functor = SwizzlingFunctor.Identity8): + + if complex_transforms is None: + complex_transforms = [(ComplexTransform.none, ComplexTransform.none),] + + element_a, element_b, element_c, element_epilogue = data_type + + operations = [] + + # by default, only generate the largest tile and largest alignment + if manifest.kernel_filter == '': + tile_descriptions = [tile_descriptions[0],] + alignment_constraints = [alignment_constraints[0],] + + for layout in layouts: + for tile_description in tile_descriptions: + for alignment in alignment_constraints: + for complex_transform in complex_transforms: + + alignment_c = min(8, alignment) + + A = TensorDescription(element_a, layout[0], alignment, complex_transform[0]) + B = TensorDescription(element_b, layout[1], alignment, complex_transform[1]) + C = TensorDescription(element_c, layout[2], alignment_c) + + new_operation = GroupedGemmOperation(GemmKind.Grouped, tile_description.minimum_compute_capability, \ + tile_description, A, B, C, element_epilogue, epilogue_functor, swizzling_functor) + + manifest.append(new_operation) + operations.append(new_operation) + + return operations + +# +def CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, data_type, \ + alignment_constraints, blas_mode, epilogue_functor = EpilogueFunctor.LinearCombination, \ + swizzling_functor = SwizzlingFunctor.Identity8): + + element_a, element_c, element_epilogue = data_type + + operations = [] + + # by default, only generate the largest tile and largest alignment + if manifest.kernel_filter == '': + tile_descriptions = [tile_descriptions[0],] + alignment_constraints = [alignment_constraints[0],] + + for layout in layouts: + for fill_mode in fill_modes: + for tile_description in tile_descriptions: + for alignment in alignment_constraints: + + # SERK supported layouts (RowMajor, ColumnMajor) with no conjugation + complex_transform = ComplexTransform.none + + # HERK supported layouts (RowMajor + conj, ColumnMajor) + if blas_mode == BlasMode.hermitian and layout[0] == LayoutType.RowMajor: + complex_transform = ComplexTransform.conj + + alignment_c = 1 # Alignment only applies to A in SYRK + + A = TensorDescription(element_a, layout[0], alignment, complex_transform) + C = SymmetricTensorDescription(element_c, layout[1], fill_mode, alignment_c) + + # Rank-K update + new_operation = RankKOperation(RankKKind.Universal, tile_description.minimum_compute_capability, \ + tile_description, A, C, element_epilogue, epilogue_functor, swizzling_functor, blas_mode) + + manifest.append(new_operation) + operations.append(new_operation) + + # Rank-2K update + new_operation = Rank2KOperation(RankKKind.Universal, tile_description.minimum_compute_capability, \ + tile_description, A, C, element_epilogue, epilogue_functor, swizzling_functor, blas_mode) + + manifest.append(new_operation) + operations.append(new_operation) + + return operations + +# +def CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, data_type, \ + alignment_constraints, complex_transforms = None, epilogue_functor = EpilogueFunctor.LinearCombination, \ + swizzling_functor = SwizzlingFunctor.Identity8): + + if complex_transforms is None: + complex_transforms = [(ComplexTransform.none),] + + element_a, element_b, element_c, element_epilogue = data_type + + operations = [] + + # by default, only generate the largest tile and largest alignment + if manifest.kernel_filter == '': + tile_descriptions = [tile_descriptions[0],] + alignment_constraints = [alignment_constraints[0],] + + for layout in layouts: + for side_mode in side_modes: + for fill_mode in fill_modes: + for diag_type in diag_types: + for tile_description in tile_descriptions: + for alignment in alignment_constraints: + for complex_transform in complex_transforms: + + alignment_c = min(8, alignment) + + A = TriangularTensorDescription(element_a, layout[0], side_mode, fill_mode, diag_type, + alignment, complex_transform) + B = TensorDescription(element_b, layout[1], alignment) + C = TensorDescription(element_c, layout[2], alignment_c) + + new_operation = TrmmOperation(TrmmKind.Universal, tile_description.minimum_compute_capability, \ + tile_description, A, B, C, element_epilogue, epilogue_functor, swizzling_functor) + + manifest.append(new_operation) + operations.append(new_operation) + + return operations + +# +def CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, data_type, \ + alignment_constraints, blas_mode, epilogue_functor = EpilogueFunctor.LinearCombination, \ + swizzling_functor = SwizzlingFunctor.Identity8): + + element_a, element_b, element_c, element_epilogue = data_type + + operations = [] + + # by default, only generate the largest tile and largest alignment + if manifest.kernel_filter == '': + tile_descriptions = [tile_descriptions[0],] + alignment_constraints = [alignment_constraints[0],] + + for layout in layouts: + for side_mode in side_modes: + for fill_mode in fill_modes: + for tile_description in tile_descriptions: + for alignment in alignment_constraints: + + # SYMM supported layouts (RowMajor, ColumnMajor) with no conjugation + complex_transform = ComplexTransform.none + + alignment_a = 1 # No vectorized access for the triangular matrix + alignment_c = min(8, alignment) + + A = SymmetricTensorDescription(element_a, layout[0], fill_mode, alignment_a, complex_transform, side_mode) + # tensor A and B have same data type and layout + B = TensorDescription(element_b, layout[0], alignment) + C = TensorDescription(element_c, layout[1], alignment_c) + + # SYMM/HEMM update + new_operation = SymmOperation(SymmKind.Universal, tile_description.minimum_compute_capability, \ + tile_description, A, B, C, element_epilogue, epilogue_functor, swizzling_functor, blas_mode) + + manifest.append(new_operation) + operations.append(new_operation) + + # SYMM/HEMM update + new_operation = SymmOperation(SymmKind.Universal, tile_description.minimum_compute_capability, \ + tile_description, A, B, C, element_epilogue, epilogue_functor, swizzling_functor, blas_mode) + + manifest.append(new_operation) + operations.append(new_operation) + + return operations + +########################################################################################################### +# ConvolutionOperator support variations +# ____________________________________________________________________ +# ConvolutionalOperator | Analytic | Optimized +# ____________________________________________________________________ +# | Fprop | (strided) | (strided) +# | Dgrad | (strided, unity*) | (strided, unity) +# | Wgrad | (strided) | (strided) +# ____________________________________________________________________ +# +# Note : Operator marked (*) are supported but not generated to keep the instantiated kernel count low +########################################################################################################### +# Convolution for 2D operations +def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignment_constraints, \ + conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], \ + epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4): + + element_a, element_b, element_c, element_epilogue = data_type + + # one exceptional case + + # iterator algorithm (analytic and optimized) + iterator_algorithms = [IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized] + + # by default, only generate the largest tile size, largest alignment, and optimized iterator + if manifest.kernel_filter == '': + tile_descriptions = [tile_descriptions[0],] + alignment_constraints = [alignment_constraints[0],] + iterator_algorithms = [IteratorAlgorithm.Optimized] + + operations = [] + + for tile in tile_descriptions: + for alignment in alignment_constraints: + + alignment_c = min(8, alignment) + + A = TensorDescription(element_a, layout[0], alignment) + B = TensorDescription(element_b, layout[1], alignment) + C = TensorDescription(element_c, layout[2], alignment_c) + + swizzling_functor_ = swizzling_functor + + # + # Conv2d Fprop + # + if ConvKind.Fprop in conv_kinds: + + # Strided support for Analytic and Optimized Fprop + for iterator_algorithm in iterator_algorithms: + new_operations = [ + # None grouped kernel + Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor, swizzling_functor_), + ] + + # Instance group conv kernel + if tile.math_instruction.opcode_class == OpcodeClass.TensorOp and A.layout == LayoutType.TensorNHWC and \ + tile.minimum_compute_capability >= 80: + # SingleGroup kernel + new_operations.append(Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor, swizzling_functor_, group_mode=GroupMode.SingleGroup)) + + # Analytic iterator supports MultipleGroup mode + if iterator_algorithm == IteratorAlgorithm.Analytic: + new_operations.append(Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor, swizzling_functor_, group_mode=GroupMode.MultipleGroup)) + + for new_operation in new_operations: + manifest.append(new_operation) + operations.append(new_operation) + + # + # Conv2d Dgrad + # + if ConvKind.Dgrad in conv_kinds: + + # Unity stride for Analytic and Optimized Dgrad + for iterator_algorithm in iterator_algorithms: + new_operation = Conv2dOperation(ConvKind.Dgrad, iterator_algorithm, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor, swizzling_functor_) + + manifest.append(new_operation) + operations.append(new_operation) + + # Strided support for Analytic Dgrad + # strided dgrad uses a special threadblock swizzle + # note that SwizzlingFunctor.StridedDgradHorizontal might be + # better for problem sizes with large activation channel count + swizzling_functor_strided_dgrad_ = SwizzlingFunctor.StridedDgradIdentity1 + + if IteratorAlgorithm.Analytic in iterator_algorithms: + new_operation = Conv2dOperation(ConvKind.Dgrad, IteratorAlgorithm.Analytic, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_strided_dgrad_) + + manifest.append(new_operation) + operations.append(new_operation) + + # Strided support for Optimized Dgrad + if IteratorAlgorithm.Optimized in iterator_algorithms: + new_operation = Conv2dOperation(ConvKind.Dgrad, IteratorAlgorithm.Optimized, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_strided_dgrad_) + + manifest.append(new_operation) + operations.append(new_operation) + + # + # Conv2d Wgrad + # + if ConvKind.Wgrad in conv_kinds: + + # Strided support for Analytic and Optimized Wgrad + for iterator_algorithm in iterator_algorithms: + new_operation = Conv2dOperation(ConvKind.Wgrad, iterator_algorithm, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_) + + manifest.append(new_operation) + operations.append(new_operation) + + return operations + +# Convolution for 2D operations specialized for few channels +def CreateConv2dFixedChannelsOperator(manifest, layout, tile_descriptions, data_type, channel_counts, \ + conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], \ + epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4): + + element_a, element_b, element_c, element_epilogue = data_type + + # one exceptional case + + # iterator algorithm (analytic and optimized) + iterator_algorithms = [IteratorAlgorithm.FixedChannels,] + + # by default, only generate the largest tile size, largest alignment, and optimized iterator + if manifest.kernel_filter == '': + tile_descriptions = [tile_descriptions[0],] + channel_counts = [channel_counts[0],] + + operations = [] + + + + for tile in tile_descriptions: + for channel_count in channel_counts: + + alignment_c = EpilogueAlignment(channel_count, tile) + + A = TensorDescription(element_a, layout[0], channel_count) + B = TensorDescription(element_b, layout[1], channel_count) + C = TensorDescription(element_c, layout[2], alignment_c) + + swizzling_functor_ = swizzling_functor + + # + # Conv2d Fprop + # + if ConvKind.Fprop in conv_kinds: + + # Strided support for Analytic and Optimized Fprop + for iterator_algorithm in iterator_algorithms: + new_operation = Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_) + + manifest.append(new_operation) + operations.append(new_operation) + + return operations + +# Convolution for 2D operations specialized for few channels +def CreateConv2dFewChannelsOperator(manifest, layout, tile_descriptions, data_type, channel_counts, \ + conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], \ + epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4): + + element_a, element_b, element_c, element_epilogue = data_type + + # one exceptional case + + # iterator algorithm (analytic and optimized) + iterator_algorithms = [IteratorAlgorithm.FewChannels,] + + # by default, only generate the largest tile size, largest alignment, and optimized iterator + if manifest.kernel_filter == '': + tile_descriptions = [tile_descriptions[0],] + channel_counts = [channel_counts[0],] + + operations = [] + + for tile in tile_descriptions: + for channel_count in channel_counts: + + alignment_c = EpilogueAlignment(channel_count, tile) + + A = TensorDescription(element_a, layout[0], channel_count) + B = TensorDescription(element_b, layout[1], channel_count) + C = TensorDescription(element_c, layout[2], alignment_c) + + swizzling_functor_ = swizzling_functor + + # + # Conv2d Fprop + # + if ConvKind.Fprop in conv_kinds: + + # Strided support for Analytic and Optimized Fprop + for iterator_algorithm in iterator_algorithms: + new_operation = Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_) + + manifest.append(new_operation) + operations.append(new_operation) + + return operations + +# Convolution for 3D operations +def CreateConv3dOperator(manifest, layout, tile_descriptions, data_type, alignment, \ + conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], epilogue_functor = EpilogueFunctor.LinearCombination): + + element_a, element_b, element_c, element_epilogue = data_type + + # one exceptional case + alignment_c = min(8, alignment) + + # iterator algorithm (analytic and optimized) + iterator_algorithms = [IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized] + + # by default, only generate the largest tile size and optimized iterators + if manifest.kernel_filter == '': + tile_descriptions = [tile_descriptions[0],] + iterator_algorithms = [IteratorAlgorithm.Optimized] + + operations = [] + + # All tile sizes for Conv3dFprop and Conv3dWgrad + for tile in tile_descriptions: + A = TensorDescription(element_a, layout, alignment) + B = TensorDescription(element_b, layout, alignment) + C = TensorDescription(element_c, layout, alignment_c) + + # + # Conv3d Fprop + # + if ConvKind.Fprop in conv_kinds: + # Strided support for Analytic and Optimized Fprop + for iterator_algorithm in iterator_algorithms: + new_operation = Conv3dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Strided) + manifest.append(new_operation) + operations.append(new_operation) + # + # Conv3d Wgrad + # + if ConvKind.Wgrad in conv_kinds: + + # Strided support for Analytic and Optimized Wgrad + for iterator_algorithm in iterator_algorithms: + new_operation = Conv3dOperation(ConvKind.Wgrad, iterator_algorithm, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor) + manifest.append(new_operation) + operations.append(new_operation) + + # All tile sizes for Conv3dDgrad + for tile in tile_descriptions: + + A = TensorDescription(element_a, layout, alignment) + B = TensorDescription(element_b, layout, alignment) + C = TensorDescription(element_c, layout, alignment_c) + + # + # Conv3d Dgrad + # + if ConvKind.Dgrad in conv_kinds: + # Unity stride for Optimized Dgrad + new_operation = Conv3dOperation(ConvKind.Dgrad, IteratorAlgorithm.Optimized, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor) + + manifest.append(new_operation) + operations.append(new_operation) + + # Strided support for Analytic Dgrad + # Conv3dDgrad has a naive strided support which does not cut down redundant MMAs + new_operation = Conv3dOperation(ConvKind.Dgrad, IteratorAlgorithm.Analytic, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor) + + manifest.append(new_operation) + operations.append(new_operation) + + return operations + +# Convolution for Depthwise 2d conv +def CreateDepthwiseConv2dOperator(manifest, layout, tile_descriptions, data_type, alignment_constraints, \ + conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], \ + epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4): + + element_a, element_b, element_c, element_epilogue = data_type + + # iterator algorithm (FixedStrideDilation, Optimized) + iterator_algorithms = [IteratorAlgorithm.FixedStrideDilation, IteratorAlgorithm.Optimized] + + # by default, only generate the largest tile size, largest alignment, and optimized iterator + if manifest.kernel_filter == '': + tile_descriptions = [tile_descriptions[0],] + alignment_constraints = [alignment_constraints[0],] + + operations = [] + + for tile in tile_descriptions: + for alignment in alignment_constraints: + + alignment_c = min(8, alignment) + + A = TensorDescription(element_a, layout[0], alignment) + B = TensorDescription(element_b, layout[1], alignment) + C = TensorDescription(element_c, layout[2], alignment_c) + + swizzling_functor_ = swizzling_functor + + if ConvKind.Fprop in conv_kinds: + + # Strided support for Optimized and FixedStridedDilation Depthwise Conv + for iterator_algorithm in iterator_algorithms: + stride_support = StrideSupport.Strided + if iterator_algorithm == IteratorAlgorithm.FixedStrideDilation: + if tile.stride == [-1, -1] or tile.dilation == [-1,-1]: + continue + stride_support = StrideSupport.Fixed + + if iterator_algorithm == IteratorAlgorithm.Optimized: + if tile.stride != [-1, -1] or tile.dilation != [-1,-1]: + continue + new_operation = Conv2dOperation(ConvKind.Fprop, + iterator_algorithm, + tile.minimum_compute_capability, + tile, + A, B, C, + element_epilogue, + stride_support, + epilogue_functor, + swizzling_functor_, + group_mode=GroupMode.Depthwise) + + manifest.append(new_operation) + operations.append(new_operation) + + return operations + +class ConvOperation3x: + """All parameters of a CUTLASS 3 convolution operation. + + Unlike CUTLASS 2 convolutions, CUTLASS 3 convolutions do not + distinguish between 2-D and 3-D convolutions by kernel class name. + Instead, for CUTLASS 3 convolutions, the tensor layouts encode + whether the convolution is 2-D or 3-D. Thus, this class deduces + the OperationKind (either Conv2d or Conv3d) from the layouts, + rather than taking it as a constructor parameter. + """ + def __init__(self, + conv_kind: ConvKind, + tile_description: TileDescription, + A: TensorDescription, + B: TensorDescription, + C: TensorDescription, + element_compute: Optional[DataType] = None, + D: Optional[TensorDescription] = None, + kernel_schedule: KernelScheduleType = KernelScheduleType.ScheduleAuto, + epilogue_schedule: EpilogueScheduleType = EpilogueScheduleType.ScheduleAuto, + tile_scheduler: TileSchedulerType = TileSchedulerType.Default, + log_indent_level: int = 1): + log_debug_line(f'ConvOperation3x::init: conv_kind: {conv_kind}', log_indent_level) + log_indent_level = log_indent_level + 1 + + self.conv_kind = conv_kind + self.tile_description = tile_description + self.A = A + self.B = B + self.C = C + self.element_compute = C.element if element_compute is None else element_compute + self.kernel_schedule = kernel_schedule + self.epilogue_schedule = epilogue_schedule + + self.arch = tile_description.minimum_compute_capability + self.tile_scheduler = tile_scheduler + if D == None: + self.D = C + else: + self.D = D + + self.is_3x = True + self.group_mode = GroupMode.NoneGroup # CUTLASS 3 convolutions currently aren't grouped + + operation_kind = None + for layout in (A.layout, B.layout, C.layout): + assert(isinstance(layout, LayoutType)) + new_operation_kind = convolution_tensor_layout_type_to_operation_kind(layout) + if operation_kind is None: + operation_kind = new_operation_kind + else: # CUTLASS 3 convolutions don't permit mixing 2-D and 3-D layouts. + assert(operation_kind == new_operation_kind) + assert(operation_kind is not None) + self.operation_kind = operation_kind + + def __str__(self): + return f"ConvOperation3x: operation_kind={self.operation_kind}, conv_kind={self.conv_kind}, tile_description={self.tile_description}" + + def is_complex(self): + complex_operators = [ + MathOperation.multiply_add_complex, + MathOperation.multiply_add_complex_gaussian, + MathOperation.multiply_add_complex_fast_f32 + ] + return self.tile_description.math_instruction.math_operation in complex_operators + + def is_mixed_input(self): + return self.A.element != self.B.element + + def accumulator_type(self): + accum = self.tile_description.math_instruction.element_accumulator + if self.is_complex(): + return get_complex_from_real(accum) + return accum + + def short_math_name(self): + if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian: + return "g%s" % ShortDataTypeNames[self.accumulator_type()] + return ShortDataTypeNames[self.accumulator_type()] + + def core_name(self): + ''' The basic operation kind is prefixed with a letter indicating the accumulation type. ''' + + inst_shape = '' + inst_operation = '' + intermediate_type = '' + + math_operations_map = { + MathOperation.xor_popc: 'xor', + MathOperation.and_popc: 'and', + } + + tensor_ops = [ + OpcodeClass.TensorOp, + OpcodeClass.WmmaTensorOp, + OpcodeClass.SparseTensorOp, + OpcodeClass.BlockScaledTensorOp, + ] + + is_tensor_op = self.tile_description.math_instruction.opcode_class in tensor_ops + + if is_tensor_op: + + math_op = self.tile_description.math_instruction.math_operation + math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else '' + + if self.tile_description.math_instruction.element_a != self.A.element and \ + self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator: + intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] + + return "%s%s%s" % (math_op_string, intermediate_type, ConvKindNames[self.conv_kind]) + + def extended_name(self): + '''Generates a string representing the MMA atom. Assumes accumulator type is C type.''' + extended_name = "{core_name}_{element_a}{layout_a}_{element_b}{layout_b}_{element_acc}_{element_c}_{element_d}{layout_c}".format( + element_a = DataTypeNames[self.A.element], + layout_a = ShortLayoutTypeNames[self.A.layout], + element_b = DataTypeNames[self.B.element], + layout_b = ShortLayoutTypeNames[self.B.layout], + element_acc = DataTypeNames[self.accumulator_type()], + element_c = DataTypeNames[self.C.element], + layout_c = ShortLayoutTypeNames[self.C.layout], + element_d = DataTypeNames[self.D.element], + core_name = self.core_name()) + + return extended_name + + # Generates a short string representing underlying kernel schedule type + def kernel_schedule_name(self): + return KernelScheduleSuffixes[self.kernel_schedule] + + # Generates a short string representing underlying epilogue schedule type + def epilogue_schedule_name(self): + return EpilogueScheduleSuffixes[self.epilogue_schedule] + + # Generate a short string representing the operation class + def opcode_class_name(self): + return OpcodeClassNames[self.tile_description.math_instruction.opcode_class] + + # Generates the full kernel function name + def configuration_name(self): + ''' The full function name indicates architecture, extended name, tile size, and layout. ''' + kernel_name_template = "cutlass3x_sm{ar}_{op}_{ex}{ct}{cs}_{l}_align{al}{t}{k}{e}" + return kernel_name_template.format( + ar = self.arch, + op = self.opcode_class_name(), + ex = self.extended_name(), + ct = '_' + 'x'.join([str(i) for i in self.tile_description.tile_shape]) if self.tile_description.tile_shape[0] > 0 else "", + cs = '_' + 'x'.join([str(i) for i in self.tile_description.cluster_shape]), + l = self.tile_description.stages, + al = str(max(self.A.alignment, self.B.alignment)), + t = TileSchedulerSuffixes[self.tile_scheduler], + k = self.kernel_schedule_name(), + e = self.epilogue_schedule_name()) + + def procedural_name(self): + return self.configuration_name() + +def convolution_tensor_layout_type_to_operation_kind(layout: LayoutType) -> OperationKind: + if layout == LayoutType.TensorNHWC or layout == LayoutType.TensorKCSR: + return OperationKind.Conv2d + elif layout == LayoutType.TensorNDHWC or layout == LayoutType.TensorKCSRT: + return OperationKind.Conv3d + else: + raise RuntimeError(f'LayoutType {layout} does not have a corresponding OperationKind') + +def CreateConvOperator3x(manifest: Manifest, + dims_and_alignments: Sequence[Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]]], + tile_descriptions: Sequence[Sequence[TileDescription]], + data_types, + schedule_pairs: Sequence[Tuple[KernelScheduleType, KernelScheduleType]] = \ + [(KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto)], + complex_transforms: Optional[Sequence[ComplexTransform]] = None, + tile_schedulers: Sequence[TileSchedulerType] = [TileSchedulerType.Default], + conv_kind: ConvKind = ConvKind.Fprop, + log_indent_level: int = 1): + """ + Create zero or more CUTLASS 3 two-dimensional convolution operators. + + Create a CUTLASS 3 two-dimensional convolution operator + for all feasible combinations of the input parameters. + Add the operators to the manifest. + + dims_and_alignments: 3-level list. Each outer list term is a list [A, B, C]. + Each inner list (A, B, or C) has the form [num_spatial_dimensions, alignment]. + Both are integers; the first is the number of spatial dimensions + (currently, only 2 or 3 are supported), and the second is the byte alignment. + We deduce the operation_kind (either OperationKind.Conv2d or OperationKind.Conv3d) + from num_spatial_dimensions. + + This function doesn't take layouts, unlike the GEMM functions. + CUTLASS 3 convolutions currently support three input layouts: + + * TensorNWC for 1-D convolutions, + * TensorNHWC for 2-D convolutions, and + * TensorNDHWC for 3-D convolutions. + + Output (C and D) layouts are the same as input layouts, + except for Wgrad convolutions, where the layouts are + + * TensorKCS for 1-D convolutions, + * TensorKCSR for 2-D convolutions, and + * TensorKCSRT for 3-D convolutions. + + The output layouts are completely constrained by the input layouts + and the convolution kind. + + tile_descriptions: 2-level list. + Outer level has one list per math instruction. + Inner level has one TileDescription for each cluster shape. + + data_types: Either a single data_type dictionary, or a list of them. + Keys: 'a_type', 'b_type', 'c_type', 'd_type', 'acc_type', 'epi_type' + + complex_transforms: Optional list of pairs. + First element of each pair is the complex transform for A, and + second element of each pair is the complex transform for B. + + schedule_pairs: [(kernel_schedule, epilogue_schedule), ...] + + conv_kind: Convolution kind (Fprop, Dgrad, or Wgrad). + """ + log_debug_line('CreateConvOperator3x', log_indent_level) + log_indent_level = log_indent_level + 1 + log_debug_line(f'conv_kind: {conv_kind}', log_indent_level) + + for triple in dims_and_alignments: + assert(isinstance(triple, tuple) or isinstance(triple, list)) + assert(len(triple) == 3) + + spatial_dimensionality = None # to be determined by loop below + + for entry in triple: # [A, B, C] + assert(len(entry) == 2) + [dim, alignment] = entry + assert(type(dim) is int) + assert(dim == 2 or dim == 3) + assert(type(alignment) is int) + assert(alignment > 0) + if spatial_dimensionality is None: + spatial_dimensionality = dim + else: + # A, B, and C need to have the same spatial dimensionality + assert(spatial_dimensionality == dim) + + def input_and_output_layouts(spatial_dim: int, kind: ConvKind) -> Tuple[LayoutType, LayoutType]: + if spatial_dim == 1: + input_layout = LayoutType.TensorNWC + if kind == ConvKind.Wgrad: + output_layout = LayoutType.TensorKCS + else: + output_layout = input_layout + elif spatial_dim == 2: + input_layout = LayoutType.TensorNHWC + if kind == ConvKind.Wgrad: + output_layout = LayoutType.TensorKCSR + else: + output_layout = input_layout + elif spatial_dim == 3: + input_layout = LayoutType.TensorNDHWC + if kind == ConvKind.Wgrad: + output_layout = LayoutType.TensorKCSRT + else: + output_layout = input_layout + else: + assert(False) + return (input_layout, output_layout) + + def dims_to_layouts(A_B_C: Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]]) -> \ + Tuple[Tuple[LayoutType, int], Tuple[LayoutType, int], Tuple[LayoutType, int]]: + [A, B, C] = A_B_C + [spatial_dim, alignment] = A + [input_layout, output_layout] = input_and_output_layouts(spatial_dim, conv_kind) + return ((input_layout, A[1]), + (input_layout, B[1]), + (output_layout, C[1])) + + # layouts: list of triples (A, B, C). + # Each of A, B, and C has the form [layout, alignment]. + layouts = [dims_to_layouts(A_B_C) for A_B_C in dims_and_alignments] + + if type(data_types) is dict: + data_types = [data_types] + + for s in schedule_pairs: + assert(len(s) == 2) + + if complex_transforms is None: + complex_transforms = [(ComplexTransform.none, ComplexTransform.none)] + + # product produces a one-pass generator, so the loop must call it anew each time. + def make_combinations(): + return product( + layouts, + tile_descriptions, + data_types, + complex_transforms, + schedule_pairs, + tile_schedulers + ) + + operations = [] + for layout_triple, tile_description, data_type, complex_transform_pair, schedule_pair, tile_scheduler in make_combinations(): + A_layout, A_alignment = layout_triple[0] + A_xform = complex_transform_pair[0] + B_layout, B_alignment = layout_triple[1] + B_xform = complex_transform_pair[1] + C_layout, C_alignment = layout_triple[2] + D_layout = C_layout + D_alignment = C_alignment + + A = TensorDescription(data_type["a_type"], A_layout, A_alignment, A_xform) + B = TensorDescription(data_type["b_type"], B_layout, B_alignment, B_xform) + C = TensorDescription(data_type["c_type"], C_layout, C_alignment) + D = TensorDescription(data_type["d_type"], D_layout, D_alignment) + element_compute = data_type.get("epi_type", data_type["acc_type"]) + kernel_schedule, epilogue_schedule = schedule_pair + + operation = ConvOperation3x(conv_kind=conv_kind, + tile_description=tile_description, + A=A, + B=B, + C=C, + element_compute=element_compute, + D=D, + kernel_schedule=kernel_schedule, + epilogue_schedule=epilogue_schedule, + tile_scheduler=tile_scheduler, + log_indent_level=log_indent_level) + log_debug_line(f'Created ConvOperation3x: {str(operation)}', log_indent_level) + manifest.append(operation) + operations.append(operation) + + return operations + +################################################################################################### +################################################################################################### + +# +def GenerateSM50_Simt(manifest, cuda_version): + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [1, 1, 1], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add), + MathInstruction( \ + [1, 1, 1], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add), + ] + + min_cc = 50 + max_cc = 1024 + + alignment_constraints = [1,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + if math_inst.element_a == DataType.f32: + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) +# + +# +def GenerateSM50_Simt_complex(manifest, cuda_version): + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [1, 1, 1], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add_complex), + ] + + min_cc = 50 + max_cc = 1024 + + alignment_constraints = [1,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 64, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + DataType.cf32, + DataType.cf32, + DataType.cf32, + DataType.cf32, + ] + + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) +# + +# +def GenerateSM50(manifest, cuda_version): + GenerateSM50_Simt(manifest, cuda_version) + GenerateSM50_Simt_complex(manifest, cuda_version) + +################################################################################################### +################################################################################################### + +# +def GenerateSM60_Simt(manifest, cuda_version): + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [1, 1, 1], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add), + ] + + min_cc = 60 + max_cc = 1024 + + alignment_constraints = [1,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) +# +def GenerateSM60_Simt_DepthwiseConv2d(manifest, cuda_version): + + math_instructions = [ + MathInstruction( \ + [1, 1, 1], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add), + ] + + min_cc = 60 + max_cc = 1024 + + alignment_constraints = [8,] + + filter_3x3 = [3, 3] + filter_5x5 = [5, 5] + + # [stride_h, stride_w] + # [-1, -1] means all stride size. + strides = [[-1,-1], [1, 1], [2, 2]] + # [dilation_h, dilation_w] + # [-1, -1] means all dilation size. + dilations = [[-1,-1], [1, 1], [2, 2]] + + #groups per thread block + g16 = 16 + g32 = 32 + g64 = 64 + + #output shape per thread block + npq_1x4x4 = [1, 4, 4] + npq_1x8x8 = [1, 8, 8] + npq_1x10x10 = [1, 10, 10] + + tile_descriptions = [] + for math_inst in math_instructions: + for stride, dilation in product(strides, dilations): + tile_descriptions.extend([ + # filter3x3 ThreadBlock_output, filter, stage, warp + Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g32], filter_3x3, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g64], filter_3x3, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g16], filter_3x3, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + + Direct2dConvFixedStrideDilationTileDescription(npq_1x10x10+[g64], filter_3x3, 2, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + + Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g32], filter_3x3, 4, stride, dilation, [4, 1, 1], math_inst, min_cc, max_cc), + Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g64], filter_3x3, 4, stride, dilation,[4, 1, 1], math_inst, min_cc, max_cc), + Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g16], filter_3x3, 4, stride, dilation, [4, 1, 1], math_inst, min_cc, max_cc), + + # filter5x5 ThreadBlock_output, filter, stage, warp + Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g32], filter_5x5, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g64], filter_5x5, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g16], filter_5x5, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + + Direct2dConvFixedStrideDilationTileDescription(npq_1x10x10+[g64], filter_5x5, 2, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + + Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g32], filter_5x5, 4, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g64], filter_5x5, 4, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g16], filter_5x5, 4, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc) + ]) + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateDepthwiseConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) +# + +# +def GenerateSM60(manifest, cuda_version): + GenerateSM60_Simt(manifest, cuda_version) + GenerateSM60_Simt_DepthwiseConv2d(manifest, cuda_version) + +################################################################################################### +################################################################################################### + +# +def GenerateSM61_Simt(manifest, cuda_version): + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [1, 1, 4], \ + DataType.s8, DataType.s8, DataType.s32, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add), + ] + + min_cc = 61 + max_cc = 1024 + + alignment_constraints = [1,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) +# + +# +def GenerateSM61(manifest, cuda_version): + GenerateSM61_Simt(manifest, cuda_version) + +################################################################################################### +################################################################################################### + +# +def GenerateSM70_TensorOp_884(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 10, 1): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [8, 8, 4], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [8, 8, 4], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 70 + max_cc = 75 + + alignment_constraints = [8, 4, 2, 1] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 2, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints) + + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints) + +# +def GenerateSM70_PlanarComplexTensorOp_884(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 10, 1): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + math_instructions = [ + MathInstruction( \ + [8, 8, 4], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [8, 8, 4], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 70 + max_cc = 75 + + alignment_constraints = [8, 2, 1] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, complex_transforms) + + +# +def GenerateSM70_WmmaTensorOp_161616(manifest, cuda_version): + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 16, 16], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.WmmaTensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 16, 16], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.WmmaTensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 70 + max_cc = 1024 + + alignment_constraints = [8,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints) + +# +################################################################################################## +# + +def GenerateSM70(manifest, cuda_version): + GenerateSM70_TensorOp_884(manifest, cuda_version) + GenerateSM70_PlanarComplexTensorOp_884(manifest, cuda_version) + + # To limit build size, WMMA GEMMs are disabled for now. + # + #GenerateSM70_WmmaTensorOp_161616(manifest, cuda_version) + +################################################################################################### +################################################################################################### + +# +def GenerateSM75_TensorOp_1688_FewChannels(manifest, cuda_version, math_inst): + + min_cc = 75 + max_cc = 1024 + + tile_descriptions = [ + TileDescription([128, 64, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 2, [2, 2, 2], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + + CreateConv2dFixedChannelsOperator(manifest, conv_layout, tile_descriptions, data_type, [4, 8]) + CreateConv2dFewChannelsOperator(manifest, conv_layout, tile_descriptions, data_type, [1, 2, 4]) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateConv2dFixedChannelsOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, [4, 8]) + CreateConv2dFewChannelsOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, [1, 2, 4]) + +# +def GenerateSM75_TensorOp_1688(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 10, 2): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 8], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 75 + max_cc = 1024 + + alignment_constraints = [8, 4, 2, 1] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 2, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 2, [1, 2, 2], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints) + + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints) + + # Separate generator for 'few channels' specializations + GenerateSM75_TensorOp_1688_FewChannels(manifest, cuda_version, math_inst) + +# + +# +def GenerateSM75_PlanarComplexTensorOp_1688(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 10, 2): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 8], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 75 + max_cc = 1024 + + alignment_constraints = [8, 2, 1] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([ 64, 128, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, complex_transforms) + +# +def GenerateSM75_TensorOp_8816_TN(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 10, 2): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [8, 8, 16], \ + DataType.s8, DataType.s8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + MathInstruction( \ + [8, 8, 16], \ + DataType.u8, DataType.u8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + ] + + min_cc = 75 + max_cc = 90 + + alignment_constraints = [16,] + alignment_constraints_small_channels = [16, 8, 4] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 2, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 2, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 64], 2, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 256, 64], 2, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 64], 2, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 32, 64], 2, [2, 1, 1], math_inst, min_cc, max_cc), + + TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 2, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 32, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + DataType.s32, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombination) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + DataType.f32, + ] + + operations = [] + + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + + operations += CreateConv2dFixedChannelsOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, alignment_constraints_small_channels, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + + operations += CreateConv2dFewChannelsOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, alignment_constraints_small_channels, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + if op.tile_description.threadblock_shape[1] >= 128: + op.C.alignment = 16 + else: + op.C.alignment = 8 + +# + +# +def GenerateSM75_TensorOp_8816_Interleaved(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 10, 2): + return + + layouts = [ + (LayoutType.ColumnMajorInterleaved32, LayoutType.RowMajorInterleaved32, LayoutType.ColumnMajorInterleaved32), + ] + + math_instructions = [ + MathInstruction( \ + [8, 8, 16], \ + DataType.s8, DataType.s8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + MathInstruction( \ + [8, 8, 16], \ + DataType.u8, DataType.u8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + ] + + min_cc = 75 + max_cc = 90 + + alignment_constraints = [16,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 2, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 2, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + DataType.f32, + ] + + operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + conv_layout = (LayoutType.TensorNC32HW32, LayoutType.TensorC32RSK32, LayoutType.TensorNC32HW32) + + operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + op.C.alignment = 8 +# + +# +def GenerateSM75_TensorOp_8832_TN(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 10, 2): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [8, 8, 32], \ + DataType.s4, DataType.s4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + MathInstruction( \ + [8, 8, 32], \ + DataType.u4, DataType.u4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + ] + + min_cc = 75 + max_cc = 89 + + alignment_constraints = [32,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 128], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 2, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 2, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + DataType.s32, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombination) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + DataType.f32, + ] + + operations = [] + + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + if op.tile_description.threadblock_shape[1] >= 128: + op.C.alignment = 16 + elif op.tile_description.threadblock_shape[1] == 64: + op.C.alignment = 8 + else: + op.C.alignment = 8 + +# + +# +def GenerateSM75_TensorOp_8832_Interleaved(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 10, 2): + return + + layouts = [ + (LayoutType.ColumnMajorInterleaved64, LayoutType.RowMajorInterleaved64, LayoutType.ColumnMajorInterleaved64), + ] + + math_instructions = [ + MathInstruction( \ + [8, 8, 32], \ + DataType.s4, DataType.s4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + MathInstruction( \ + [8, 8, 32], \ + DataType.u4, DataType.u4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + ] + + min_cc = 75 + max_cc = 89 + + alignment_constraints = [32,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 128], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 2, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 2, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + DataType.f32, + ] + + operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + conv_layout = (LayoutType.TensorNC64HW64, LayoutType.TensorC64RSK64, LayoutType.TensorNC64HW64) + + operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + op.C.alignment = 16 +# + +# +def GenerateSM75_TensorOp_88128(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [8, 8, 128], \ + DataType.b1, DataType.b1, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.xor_popc), + ] + + min_cc = 75 + max_cc = { + MathOperation.xor_popc: 89, + MathOperation.and_popc: 90 + } + + alignment_constraints = [128,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 512], 2, [4, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([128, 256, 512], 2, [2, 4, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([128, 128, 512], 2, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([ 64, 256, 512], 2, [1, 4, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([256, 64, 512], 2, [4, 1, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([ 64, 128, 512], 2, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([128, 64, 512], 2, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([ 64, 64, 512], 2, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + ] + + data_type = [DataType.b1, DataType.b1, DataType.s32, DataType.s32] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + +# + +# +def GenerateSM75_WmmaTensorOp_161616(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 10, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 16, 16], \ + DataType.s8, DataType.s8, DataType.s32, \ + OpcodeClass.WmmaTensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 75 + max_cc = 1024 + + alignment_constraints = [16,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + DataType.f32, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + DataType.f32, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints) +# + +# +def GenerateSM75_Simt_complex(manifest, cuda_version): + math_instructions = [ + MathInstruction( \ + [1, 1, 1], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add_complex), + ] + + min_cc = 75 + max_cc = 1024 + + alignment_constraints = [1,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 128, 8], 5, [4, 2, 1], math_inst, min_cc, max_cc) + ] + data_type = [ + DataType.cf32, + DataType.cf32, + DataType.cf32, + DataType.cf32 + ] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) +# + +def GenerateSM75(manifest, cuda_version): + GenerateSM75_TensorOp_1688(manifest, cuda_version) + GenerateSM75_PlanarComplexTensorOp_1688(manifest, cuda_version) + GenerateSM75_TensorOp_8816_TN(manifest, cuda_version) + GenerateSM75_TensorOp_8816_Interleaved(manifest, cuda_version) + GenerateSM75_TensorOp_8832_TN(manifest, cuda_version) + GenerateSM75_TensorOp_8832_Interleaved(manifest, cuda_version) + GenerateSM75_TensorOp_88128(manifest, cuda_version) + #GenerateSM75_WmmaTensorOp_161616(manifest, cuda_version) + GenerateSM75_Simt_complex(manifest, cuda_version) + + +################################################################################################### +################################################################################################### + +# +def GenerateSM80_TensorOp_16816(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 16], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 16], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 16], \ + DataType.bf16, DataType.bf16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [8, 4, 2] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 10, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 3, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + CreateGemmGroupedOperator(manifest, layouts, tile_descriptions, data_type, alignment_constraints) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) + CreateConv2dFixedChannelsOperator(manifest, conv_layout, tile_descriptions, data_type, [4, 8]) + CreateConv3dOperator(manifest, LayoutType.TensorNDHWC, tile_descriptions, data_type, 8) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints) + + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints) + CreateConv2dFixedChannelsOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, [4, 8]) + CreateConv3dOperator(manifest, LayoutType.TensorNDHWC, tile_descriptions, data_type_mixed, 8) +# + +# +def GenerateSM80_SparseTensorOp_16832(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 1): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 32], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 32], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 32], \ + DataType.bf16, DataType.bf16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [8] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateSparseGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateSparseGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints) + +# + +# +def GenerateSM80_PlanarComplexTensorOp_16816(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 16], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 16], \ + DataType.bf16, DataType.bf16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 16], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [8, ] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([ 64, 128, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, complex_transforms) + +# +def GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + # Upcast on Operand A + math_instructions = [ + MathInstruction( \ + [16, 8, 16], \ + DataType.s8, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.u8, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.s8, DataType.bf16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.u8, DataType.bf16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.s8, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.u8, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + ] + + min_cc = 80 + max_cc = 1024 + + # For mixed-input alignment constraints are a list of lists, where the + # inner list contains the alignment constraints for operands/matrices + # [[alignA, alignB, alignC],..] + alignment_constraints = [[16, 8, 8],] + + for math_inst in math_instructions: + tile_descriptions = [ + # 128x128 + TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + # 128x64 + TileDescription([128, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + # 128x32 + TileDescription([128, 32, 64], 9, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + # 128x16 + TileDescription([128, 16, 64], 5, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 16, 64], 3, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + # streamk uses more regs which can cause spill for the biggest warp tile size when the accumulators are 32bit. + operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_b != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_b, + math_inst.element_accumulator, + ] + + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8) + + for op in operations: + if (DataTypeSize[op.C.element] == 16) and \ + (op.tile_description.threadblock_shape[1] <= 32): + op.C.alignment = 4 + +# +def GenerateSM80_TensorOp_16816_mixed_input_upcast_b(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 16], \ + DataType.f16, DataType.s8, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.f16, DataType.u8, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.bf16, DataType.s8, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.bf16, DataType.u8, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.f16, DataType.s8, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.f16, DataType.u8, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + ] + + min_cc = 80 + max_cc = 1024 + + # For mixed-input alignment constraints are a list of lists, where the + # inner list contains the alignment constraints for operands/matrices + # [[alignA, alignB, alignC],..] + alignment_constraints = [[8, 16, 8],] + + for math_inst in math_instructions: + tile_descriptions = [ + # 128x128 + TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + # 128x64 + TileDescription([128, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + # 128x32 + TileDescription([128, 32, 64], 9, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 32], 9, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), + # 128x16 + TileDescription([128, 16, 64], 5, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 16, 64], 3, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 16, 32], 9, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 16, 32], 5, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 16, 32], 3, [2, 1, 1], math_inst, min_cc, max_cc), + # 256x16 + TileDescription([256, 16, 32], 5, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([256, 16, 32], 3, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + # streamk uses more regs which can cause spill for the biggest warp tile size when the accumulators are 32bit. + operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8) + + for op in operations: + if op.tile_description.threadblock_shape[1] <= 32: + op.C.alignment = 4 + +# +def GenerateSM80_TensorOp_16832_TN(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 32], \ + DataType.s8, DataType.s8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + MathInstruction( \ + [16, 8, 32], \ + DataType.u8, DataType.u8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + ] + + min_cc = 80 + max_cc = 1024 + smem_usage = 164 + + alignment_constraints = [16,] + alignment_constraints_small_channels = [16, 8, 4] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 64], 6, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 64], 6, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 10, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [math_inst.element_a, math_inst.element_b, math_inst.element_accumulator, DataType.s32] + data_type_mixed = [math_inst.element_a, math_inst.element_b, math_inst.element_a, DataType.f32] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombination) + + operations = [] + + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + + operations += CreateConv2dFixedChannelsOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, alignment_constraints_small_channels, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + + operations += CreateConv2dFewChannelsOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, alignment_constraints_small_channels, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + if op.tile_description.threadblock_shape[1] >= 128: + if op.tile_description.threadblock_shape[0] == 32: + op.C.alignment = 8 + else: + op.C.alignment = 16 + else: + op.C.alignment = 8 + +# + +def GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_a(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + # Upcast on Operand A + math_instructions = [ + MathInstruction( \ + [16, 8, 32], \ + DataType.s4, DataType.s8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + ] + + min_cc = 80 + max_cc = 1024 + + # For mixed-input alignment constraints are a list of lists, where the + # inner list contains the alignment constraints for operands/matrices + # [[alignA, alignB, alignC],..] + alignment_constraints = [[32, 16, 4],] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + # streamk uses more regs which can cause spill for the biggest warp tile size when the accumulators are 32bit. + operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. S8 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + alignment_constraints = [[32, 16, 16],] + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_b, + DataType.f32 + ] + + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp, SwizzlingFunctor.Identity8) + + for op in operations: + if op.tile_description.threadblock_shape[1] >= 128: + if op.tile_description.threadblock_shape[0] == 32: + op.C.alignment = 8 + else: + op.C.alignment = 16 + else: + op.C.alignment = 8 +# + +# +def GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_b(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + # Upcast on Operand B + math_instructions = [ + MathInstruction( \ + [16, 8, 32], \ + DataType.s8, DataType.s4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + ] + + min_cc = 80 + max_cc = 1024 + + # For mixed-input alignment constraints are a list of lists, where the + # inner list contains the alignment constraints for operands/matrices + # [[alignA, alignB, alignC],..] + alignment_constraints = [[16, 32, 4],] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 64], 6, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + # streamk uses more regs which can cause spill for the biggest warp tile size when the accumulators are 32bit. + operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. S8 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + alignment_constraints = [[16, 32, 16],] + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + DataType.f32, + ] + + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp, SwizzlingFunctor.Identity8) + + for op in operations: + if op.tile_description.threadblock_shape[1] >= 128: + if op.tile_description.threadblock_shape[0] == 32: + op.C.alignment = 8 + else: + op.C.alignment = 16 + else: + op.C.alignment = 8 +# + +# +def GenerateSM80_SparseTensorOp_16864_TN(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 1): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 64], \ + DataType.s8, DataType.s8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [16,] + + tile_descriptions = [ + TileDescription([128, 64, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.s8, DataType.s8, DataType.s32, DataType.s32] + data_type_mixed = [DataType.s8, DataType.s8, DataType.s8, DataType.f32] + + CreateSparseGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination) + + operations = [] + + operations += CreateSparseGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + if op.tile_description.threadblock_shape[1] >= 128: + op.C.alignment = 16 + else: + op.C.alignment = 8 +# + +# +def GenerateSM80_TensorOp_16832_Interleaved(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajorInterleaved32, LayoutType.RowMajorInterleaved32, LayoutType.ColumnMajorInterleaved32), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 32], \ + DataType.s8, DataType.s8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + MathInstruction( \ + [16, 8, 32], \ + DataType.u8, DataType.u8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [16,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 10, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type_mixed = [math_inst.element_a, math_inst.element_b, math_inst.element_a, DataType.f32] + + operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + conv_layout = (LayoutType.TensorNC32HW32, LayoutType.TensorC32RSK32, LayoutType.TensorNC32HW32) + + operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + op.C.alignment = 8 +# + +# +def GenerateSM80_TensorOp_16864_TN(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 64], \ + DataType.s4, DataType.s4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + MathInstruction( \ + [16, 8, 64], \ + DataType.u4, DataType.u4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + ] + + min_cc = 80 + max_cc = 1024 + alignment_constraints = [32,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 128], 10, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 256], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 256], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 256], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 256], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 256], 5, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [math_inst.element_a, math_inst.element_b, math_inst.element_accumulator, DataType.s32] + data_type_mixed = [math_inst.element_a, math_inst.element_b, math_inst.element_a, DataType.f32] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination) + + operations = [] + + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombination) + + operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + if op.tile_description.threadblock_shape[1] >= 128: + op.C.alignment = 16 + elif op.tile_description.threadblock_shape[1] == 64: + op.C.alignment = 8 + else: + op.C.alignment = 8 +# + +# +def GenerateSM80_SparseTensorOp_168128_TN(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 1): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 128], \ + DataType.s4, DataType.s4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate) + + min_cc = 80 + max_cc = 1024 + alignment_constraints = [32,] + + tile_descriptions = [ + TileDescription([ 64, 64, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 256], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 256], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 256], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 256], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 256], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 512], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 512], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 512], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 512], 3, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.s4, DataType.s4, DataType.s32, DataType.s32] + data_type_mixed = [DataType.s4, DataType.s4, DataType.s4, DataType.f32] + + CreateSparseGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination) + + operations = [] + + operations += CreateSparseGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + if op.tile_description.threadblock_shape[1] > 128: + op.C.alignment = 16 + else: + op.C.alignment = 8 +# + +# +def GenerateSM80_TensorOp_16864_Interleaved(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajorInterleaved64, LayoutType.RowMajorInterleaved64, LayoutType.ColumnMajorInterleaved64), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 64], \ + DataType.s4, DataType.s4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + MathInstruction( \ + [16, 8, 64], \ + DataType.u4, DataType.u4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + ] + + min_cc = 80 + max_cc = 1024 + alignment_constraints = [32,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type_mixed = [math_inst.element_a, math_inst.element_b, math_inst.element_a, DataType.f32] + + operations = [] + + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + conv_layout = (LayoutType.TensorNC64HW64, LayoutType.TensorC64RSK64, LayoutType.TensorNC64HW64) + + operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + op.C.alignment = 16 +# + +# +def GenerateSM80_TensorOp_168256(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 256], \ + DataType.b1, DataType.b1, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.xor_popc), + MathInstruction( \ + [16, 8, 256], \ + DataType.b1, DataType.b1, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.and_popc), + ] + + min_cc = 80 + max_cc = { + MathOperation.xor_popc: 89, + MathOperation.and_popc: 90 + } + + alignment_constraints = [128,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 512], 3, [4, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([128, 256, 512], 3, [2, 4, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([256, 64, 512], 4, [4, 1, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([ 64, 256, 512], 4, [1, 4, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([128, 128, 512], 5, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([128, 64, 512], 6, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([ 64, 128, 512], 6, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([ 64, 64, 512], 10, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([256, 128, 1024], 3, [4, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([128, 256, 1024], 3, [2, 4, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([256, 64, 1024], 4, [4, 1, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([ 64, 256, 1024], 4, [1, 4, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([128, 128, 1024], 4, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([128, 64, 1024], 3, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([ 64, 128, 1024], 3, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([ 64, 64, 1024], 5, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + ] + + data_type = [DataType.b1, DataType.b1, DataType.s32, DataType.s32] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + +# + +# +def GenerateSM80_TensorOp_1688(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.tf32, DataType.tf32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add) + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [4, 2, 1] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 16], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 16], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints) +# + +# +def GenerateSM80_TensorOp_1688_fast_math(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.tf32, DataType.tf32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 8], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_fast_f16), + MathInstruction( \ + [16, 8, 8], \ + DataType.bf16, DataType.bf16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_fast_bf16), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [4, 2, 1] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 16], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 16], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f32, DataType.f32, DataType.f32, DataType.f32] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) +# + +# +def GenerateSM80_TensorOp_1688_fast_fp32_math(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_fast_f32), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [4, 2, 1] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 128, 16], 4, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f32, DataType.f32, DataType.f32, DataType.f32] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) +# + +def GenerateSM80_TensorOp_1688_fast_fp32_math_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_inst = MathInstruction( \ + [16, 8, 8], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_fast_f32) + + min_cc = 80 + max_cc = 1024 + + tile_descriptions = [ + TileDescription([128, 64, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + DataType.cf32, DataType.cf32, DataType.cf32, DataType.cf32 + ] + + alignment_constraints = [1,] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) + + +# +def GenerateSM80_SparseTensorOp_16816_fast_math(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 1): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 16], \ + DataType.tf32, DataType.tf32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [4] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f32, DataType.f32, DataType.f32, DataType.f32] + + CreateSparseGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) +# + +# +def GenerateSM80_TensorOp_1688_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_inst = MathInstruction( \ + [16, 8, 8], \ + DataType.tf32, DataType.tf32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex) + + min_cc = 80 + max_cc = 1024 + + tile_descriptions = [ + TileDescription([128, 128, 16], 4, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 4, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 4, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 4, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 4, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + DataType.cf32, DataType.cf32, DataType.cf32, DataType.cf32 + ] + + alignment_constraints = [1,] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) +# + +# +def GenerateSM80_TensorOp_1688_rank_k(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.tf32, DataType.tf32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 8], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_fast_f32), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1, 2, 4] # Alignment only applies to A in SYRK + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), + #TileDescription([256, 64, 16], 4, [4, 1, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 256, 16], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), + #TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 64, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f32, DataType.f32, DataType.f32] + + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) +# + +# +def GenerateSM80_TensorOp_1688_rank_k_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.tf32, DataType.tf32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex), + MathInstruction( \ + [16, 8, 8], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_fast_f32), + ] + + min_cc = 80 + max_cc = 1024 + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 64, 16], 4, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 4, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([64, 32, 16], 4, [2, 1, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + DataType.cf32, DataType.cf32, DataType.cf32 + ] + + alignment_constraints = [1,] + + # SYRK + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) + + # HERK + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.hermitian) +# + +# +def GenerateSM80_TensorOp_1688_trmm(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + diag_types = [ + DiagType.NonUnit, DiagType.Unit, + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.tf32, DataType.tf32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 8], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_fast_f32), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1, 2, 4] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 16], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 16], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), + #TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 64, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f32, DataType.f32, DataType.f32, DataType.f32] + + CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, \ + data_type, alignment_constraints) +# + +# +def GenerateSM80_TensorOp_1688_trmm_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + diag_types = [ + DiagType.NonUnit, DiagType.Unit, + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.tf32, DataType.tf32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex), + MathInstruction( \ + [16, 8, 8], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_fast_f32), + ] + + min_cc = 80 + max_cc = 1024 + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 64, 16], 4, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 4, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 4, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + DataType.cf32, DataType.cf32, DataType.cf32, DataType.cf32 + ] + + alignment_constraints = [1,] + + complex_transforms = [ + ComplexTransform.none, ComplexTransform.conj, + ] + + CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) +# + +# +def GenerateSM80_TensorOp_1688_symm(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + # A and B have same layouts + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.tf32, DataType.tf32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 8], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_fast_f32), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [ + 1, 2, 4 + ] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), + #TileDescription([256, 64, 16], 4, [4, 1, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 256, 16], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), + #TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 64, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f32, DataType.f32, DataType.f32, DataType.f32] + + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) +# + +# +def GenerateSM80_TensorOp_1688_symm_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.tf32, DataType.tf32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex), + MathInstruction( \ + [16, 8, 8], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_fast_f32), + ] + + min_cc = 80 + max_cc = 1024 + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 64, 16], 4, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 4, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([64, 32, 16], 4, [2, 1, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + DataType.cf32, DataType.cf32, DataType.cf32, DataType.cf32 + ] + + alignment_constraints = [1,] + + # SYMM + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) + + # HEMM + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.hermitian) +# + +# +def GenerateSM80_TensorOp_884(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 16], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([32, 256, 16], 3, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 16], 5, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 16], 5, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f64, DataType.f64, DataType.f64, DataType.f64] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) +# + +# +def GenerateSM80_TensorOp_884_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 64, 8 ], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 8 ], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 8 ], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8 ], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8 ], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 8 ], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 8 ], 4, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 8 ], 4, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 16], 4, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 16], 3, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) + +# +def GenerateSM80_TensorOp_884_complex_gaussian(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_gaussian) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) +# + +# +def GenerateSM80_TensorOp_884_rank_k(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 16], 5, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 16], 5, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f64, DataType.f64, DataType.f64] + + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) +# + +# +def GenerateSM80_TensorOp_884_rank_k_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 8], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 8], 3, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64] + + # SYRK computation + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) + + # HERK computation + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.hermitian) + +# + +# +def GenerateSM80_TensorOp_884_rank_k_complex_gaussian(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_gaussian) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ComplexTransform.none,] + + # SYRK computation + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) + + # HERK computation + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.hermitian) +# + +# +def GenerateSM80_TensorOp_884_trmm(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + diag_types = [ + DiagType.NonUnit, DiagType.Unit, + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f64, DataType.f64, DataType.f64, DataType.f64] + + CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, \ + data_type, alignment_constraints) +# + +# +def GenerateSM80_TensorOp_884_trmm_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + diag_types = [ + DiagType.NonUnit, DiagType.Unit, + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 8], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 8], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ + ComplexTransform.none, ComplexTransform.conj, + ] + + CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) +# + + +# +def GenerateSM80_TensorOp_884_trmm_complex_gaussian(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + diag_types = [ + DiagType.NonUnit, DiagType.Unit, + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_gaussian) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ + ComplexTransform.none, ComplexTransform.conj, + ] + + CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) +# + +# +def GenerateSM80_TensorOp_884_symm(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 16], 5, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 16], 5, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f64, DataType.f64, DataType.f64, DataType.f64] + + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) +# + +# +def GenerateSM80_TensorOp_884_symm_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 8], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 8], 3, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + # SYMM computation + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) + + # HEMM computation + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.hermitian) +# + +# +def GenerateSM80_TensorOp_884_symm_complex_gaussian(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_gaussian) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ComplexTransform.none,] + + # SYMM computation + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) + + # HEMM computation + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.hermitian) +# + +################################################################################################### + +# +def GenerateSM80_Simt_f32(manifest, cuda_version): + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [1, 1, 1], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 8], 5, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 8], 5, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 8], 5, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 8], 4, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 8], 4, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 8], 4, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 8], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 8], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 8], 5, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 8], 5, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 8], 5, [1, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) +# + + +# +def GenerateSM80_Simt_f64(manifest, cuda_version): + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [1, 1, 1], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 128, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 8], 5, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 8], 5, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 8], 5, [1, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) +# + + +################################################################################################## +# +def GenerateSM80_Simt_complex(manifest, cuda_version): + math_instructions = [ + MathInstruction( \ + [1, 1, 1], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add_complex), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + data_type = [ + DataType.cf32, + DataType.cf32, + DataType.cf32, + DataType.cf32 + ] + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + for math_inst in math_instructions: + + tile_descriptions = [ + TileDescription([128, 128, 8], 5, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 8], 4, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, alignment_constraints, complex_transforms) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) +# + +################################################################################################### + +# +def GenerateSM80(manifest, cuda_version): + GenerateSM80_TensorOp_16816(manifest, cuda_version) + GenerateSM80_SparseTensorOp_16832(manifest, cuda_version) + GenerateSM80_PlanarComplexTensorOp_16816(manifest, cuda_version) + GenerateSM80_TensorOp_1688(manifest, cuda_version) + GenerateSM80_TensorOp_1688_fast_math(manifest, cuda_version) + GenerateSM80_SparseTensorOp_16816_fast_math(manifest, cuda_version) + GenerateSM80_TensorOp_1688_complex(manifest, cuda_version) + # 3xTF32 + GenerateSM80_TensorOp_1688_fast_fp32_math(manifest, cuda_version) + GenerateSM80_TensorOp_1688_fast_fp32_math_complex(manifest, cuda_version) + GenerateSM80_TensorOp_1688_rank_k(manifest, cuda_version) + GenerateSM80_TensorOp_1688_rank_k_complex(manifest, cuda_version) + GenerateSM80_TensorOp_1688_trmm(manifest, cuda_version) + GenerateSM80_TensorOp_1688_trmm_complex(manifest, cuda_version) + GenerateSM80_TensorOp_1688_symm(manifest, cuda_version) + GenerateSM80_TensorOp_1688_symm_complex(manifest, cuda_version) + GenerateSM80_TensorOp_884(manifest, cuda_version) + GenerateSM80_TensorOp_884_complex(manifest, cuda_version) + GenerateSM80_TensorOp_884_complex_gaussian(manifest, cuda_version) + GenerateSM80_TensorOp_884_rank_k(manifest, cuda_version) + GenerateSM80_TensorOp_884_rank_k_complex(manifest, cuda_version) + GenerateSM80_TensorOp_884_rank_k_complex_gaussian(manifest, cuda_version) + GenerateSM80_TensorOp_884_trmm(manifest, cuda_version) + GenerateSM80_TensorOp_884_trmm_complex(manifest, cuda_version) + GenerateSM80_TensorOp_884_trmm_complex_gaussian(manifest, cuda_version) + GenerateSM80_TensorOp_884_symm(manifest, cuda_version) + GenerateSM80_TensorOp_884_symm_complex(manifest, cuda_version) + GenerateSM80_TensorOp_884_symm_complex_gaussian(manifest, cuda_version) + GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version) + GenerateSM80_TensorOp_16816_mixed_input_upcast_b(manifest, cuda_version) + GenerateSM80_TensorOp_16832_TN(manifest, cuda_version) + GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_a(manifest, cuda_version) + GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_b(manifest, cuda_version) + GenerateSM80_SparseTensorOp_16864_TN(manifest, cuda_version) + GenerateSM80_TensorOp_16832_Interleaved(manifest, cuda_version) + GenerateSM80_TensorOp_16864_TN(manifest, cuda_version) + GenerateSM80_SparseTensorOp_168128_TN(manifest, cuda_version) + GenerateSM80_TensorOp_16864_Interleaved(manifest, cuda_version) + GenerateSM80_TensorOp_168256(manifest, cuda_version) + GenerateSM80_Simt_f32(manifest, cuda_version) + GenerateSM80_Simt_f64(manifest, cuda_version) + GenerateSM80_Simt_complex(manifest, cuda_version) + +################################################################################################### + +def GenerateSM89_TensorOp_16832_fp8(manifest, element_acc): + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor) + ] + + math_instructions = [ + MathInstruction( + [16, 8, 32], + DataType.e4m3, DataType.e4m3, element_acc, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [16, 8, 32], + DataType.e4m3, DataType.e5m2, element_acc, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [16, 8, 32], + DataType.e5m2, DataType.e4m3, element_acc, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [16, 8, 32], + DataType.e5m2, DataType.e5m2, element_acc, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [16, 8, 32], + DataType.e4m3, DataType.e4m3, element_acc, + OpcodeClass.TensorOp, + MathOperation.multiply_add_fast_accum), + MathInstruction( + [16, 8, 32], + DataType.e4m3, DataType.e5m2, element_acc, + OpcodeClass.TensorOp, + MathOperation.multiply_add_fast_accum), + MathInstruction( + [16, 8, 32], + DataType.e5m2, DataType.e4m3, element_acc, + OpcodeClass.TensorOp, + MathOperation.multiply_add_fast_accum), + MathInstruction( + [16, 8, 32], + DataType.e5m2, DataType.e5m2, element_acc, + OpcodeClass.TensorOp, + MathOperation.multiply_add_fast_accum), + ] + + min_cc = 89 + max_cc = 100 + alignment_constraints = [16,] + alignment_constraints_small_channels = [16, 8, 4] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 64], 6, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 6, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 3, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 64], 6, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 64], 6, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 10, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_types = [ + [ + math_inst.element_a, + math_inst.element_b, + DataType.f32, + math_inst.element_accumulator + ], + [ + math_inst.element_a, + math_inst.element_b, + DataType.bf16, + math_inst.element_accumulator + ], + ] + + operations = [] + for data_type in data_types: + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, + alignment_constraints, None, EpilogueFunctor.LinearCombination) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombination) + + operations += CreateConv2dFixedChannelsOperator(manifest, conv_layout, tile_descriptions, + data_type, alignment_constraints_small_channels, [ConvKind.Fprop], EpilogueFunctor.LinearCombination) + + for op in operations: + if op.tile_description.threadblock_shape[1] >= 128: + if op.tile_description.threadblock_shape[0] == 32: + op.C.alignment = 8 + else: + op.C.alignment = 16 + else: + op.C.alignment = 8 + +def GenerateSM89_TensorOp_16832_fp8_fp32acc(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 4): + return + + GenerateSM89_TensorOp_16832_fp8(manifest, DataType.f32) + +def GenerateSM89_TensorOp_16832_fp8_fp16acc(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + GenerateSM89_TensorOp_16832_fp8(manifest, DataType.f16) + +# +def GenerateSM89_SparseTensorOp_16864_fp8(manifest, cuda_version): + + if ( + not CudaToolkitVersionSatisfies(cuda_version, 12, 4) + ): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor) + ] + + math_instructions = [ + MathInstruction( + [16, 8, 64], + DataType.e4m3, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [16, 8, 64], + DataType.e4m3, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [16, 8, 64], + DataType.e5m2, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [16, 8, 64], + DataType.e5m2, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [16, 8, 64], + DataType.e4m3, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add_fast_accum), + MathInstruction( + [16, 8, 64], + DataType.e4m3, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add_fast_accum), + MathInstruction( + [16, 8, 64], + DataType.e5m2, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add_fast_accum), + MathInstruction( + [16, 8, 64], + DataType.e5m2, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add_fast_accum), + ] + + min_cc = 89 + max_cc = 89 + + alignment_constraints = [16,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 64, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_types = [ + [ + math_inst.element_a, + math_inst.element_b, + DataType.f32, + math_inst.element_accumulator + ], + ] + + operations = [] + for data_type in data_types: + operations += CreateSparseGemmOperator(manifest, layouts, tile_descriptions, data_type, + alignment_constraints, None, EpilogueFunctor.LinearCombination) + + for op in operations: + if op.tile_description.threadblock_shape[1] >= 128: + op.C.alignment = 16 + else: + op.C.alignment = 8 + +################################################################################################### + +# +def GenerateSM89(manifest, cuda_version): + GenerateSM89_TensorOp_16832_fp8_fp32acc(manifest, cuda_version) + GenerateSM89_TensorOp_16832_fp8_fp16acc(manifest, cuda_version) + GenerateSM89_SparseTensorOp_16864_fp8(manifest, cuda_version) + +################################################################################################### + + +try: + from .sm90_utils import ( + generate_fp16_bf16_math_instructions_sm90, + generate_tf32_math_instructions_sm90, + generate_int8_math_instructions_sm90, + generate_fp8_math_instructions_sm90, + generate_mixed_dtype_math_instructions_sm90, + make_sparse_math_instructions, + generate_tile_descriptions_sm90, + get_valid_schedules, + generate_data_types_from_math_instruction, + fix_alignments, + ) +except ImportError: + from sm90_utils import ( + generate_fp16_bf16_math_instructions_sm90, + generate_tf32_math_instructions_sm90, + generate_int8_math_instructions_sm90, + generate_fp8_math_instructions_sm90, + generate_mixed_dtype_math_instructions_sm90, + make_sparse_math_instructions, + generate_tile_descriptions_sm90, + get_valid_schedules, + generate_data_types_from_math_instruction, + fix_alignments, + ) + +def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.Universal3x): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 3 if is_grouped(gemm_kind) else 0): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=131, exhaustive_level=9992) + is_aligned = True + + # layouts for ABC and their alignments. + layouts = [ + [[LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 1]], + ] + + math_instructions = generate_fp16_bf16_math_instructions_sm90(instantiation_level) + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + data_type_w_source = generate_data_types_from_math_instruction(math_inst) + data_type_wo_source = generate_data_types_from_math_instruction(math_inst, element_source=DataType.void) + data_types = [data_type_w_source, data_type_wo_source] + + # for mixed precision kernels, also generate kernels that write output matrix in the A/B format + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + data_type_mixed_w_source = generate_data_types_from_math_instruction( + math_inst, + element_source=math_inst.element_a, + element_dest=math_inst.element_a + ) + data_type_mixed_wo_source = generate_data_types_from_math_instruction( + math_inst, + element_source=DataType.void, + element_dest=math_inst.element_a + ) + data_types.append(data_type_mixed_w_source) + data_types.append(data_type_mixed_wo_source) + + for layout in layouts: + for data_type in data_types: + layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + gemm_kind=gemm_kind, + ) + + if len(schedules): + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules, gemm_kind=gemm_kind) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + + +def GenerateSM90_TensorOp_16b_WGMMA_alignx_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=101, exhaustive_level=9992) + is_aligned = False + + # layouts for ABC and their alignments. + layouts = [ + [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 2], [LayoutType.ColumnMajor, 2], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 2], [LayoutType.RowMajor, 2], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 2], [LayoutType.ColumnMajor, 2], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 2], [LayoutType.RowMajor, 2], [LayoutType.ColumnMajor, 1]], + ] + + math_instructions = generate_fp16_bf16_math_instructions_sm90(instantiation_level) + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + data_type_w_source = generate_data_types_from_math_instruction(math_inst) + data_types = [data_type_w_source] + + # for mixed precision kernels, also generate kernels that write output matrix in the A/B format + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + data_type_mixed_w_source = generate_data_types_from_math_instruction( + math_inst, + element_source=math_inst.element_a, + element_dest=math_inst.element_a + ) + data_types.append(data_type_mixed_w_source) + + for layout in layouts: + for data_type in data_types: + layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + ) + + if len(schedules): + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + +def GenerateSM90_SparseTensorOp_16b_WGMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 2): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=131, exhaustive_level=9992) + is_aligned = True + + # layouts for ABC and their alignments. + layouts = [ + [[LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 1]], + ] + + math_instructions = make_sparse_math_instructions(generate_fp16_bf16_math_instructions_sm90(instantiation_level)) + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + data_type_w_source = generate_data_types_from_math_instruction(math_inst) + data_type_wo_source = generate_data_types_from_math_instruction(math_inst, element_source=DataType.void) + data_types = [data_type_w_source, data_type_wo_source] + + # for mixed precision kernels, also generate kernels that write output matrix in the A/B format + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + data_type_mixed_w_source = generate_data_types_from_math_instruction( + math_inst, + element_source=math_inst.element_a, + element_dest=math_inst.element_a + ) + data_type_mixed_wo_source = generate_data_types_from_math_instruction( + math_inst, + element_source=DataType.void, + element_dest=math_inst.element_a + ) + data_types.append(data_type_mixed_w_source) + data_types.append(data_type_mixed_wo_source) + + for layout in layouts: + for data_type in data_types: + layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + ) + + if len(schedules): + CreateSparseGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateSparseGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + + +def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=120, default_level=121, exhaustive_level=9992) + is_aligned = True + + # layouts for ABC and their alignments + layouts = [ + [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]], + [[LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4]], + [[LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]], + [[LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4]], + ] + + math_instructions = generate_tf32_math_instructions_sm90(instantiation_level) + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + + for layout in layouts: + data_type_tf32 = generate_data_types_from_math_instruction(math_inst) + data_type_tf32_wo_source = generate_data_types_from_math_instruction(math_inst, element_source=DataType.void) + data_type_f32 = copy.deepcopy(data_type_tf32) + data_type_f32_wo_source = copy.deepcopy(data_type_tf32_wo_source) + data_type_f32["a_type"] = DataType.f32 + data_type_f32["b_type"] = DataType.f32 + data_type_f32["epi_type"] = DataType.f32 + data_type_f32_wo_source["a_type"] = DataType.f32 + data_type_f32_wo_source["b_type"] = DataType.f32 + data_type_f32_wo_source["epi_type"] = DataType.f32 + data_types = [data_type_tf32, data_type_f32, data_type_tf32_wo_source, data_type_f32_wo_source] + + for data_type in data_types: + layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + ) + + if len(schedules): + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + + +def GenerateSM90_TensorOp_tf32_WGMMA_alignx_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=101, exhaustive_level=9992) + is_aligned = False + + # layouts for ABC and their alignments. + layouts = [ + [[LayoutType.RowMajor, 2], [LayoutType.ColumnMajor, 2], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 2], [LayoutType.RowMajor, 2], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 2], [LayoutType.ColumnMajor, 2], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 2], [LayoutType.RowMajor, 2], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 1], [LayoutType.ColumnMajor, 1], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 1], [LayoutType.RowMajor, 1], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 1], [LayoutType.ColumnMajor, 1], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 1], [LayoutType.RowMajor, 1], [LayoutType.ColumnMajor, 1]], + ] + + math_instructions = generate_tf32_math_instructions_sm90(instantiation_level) + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + + for layout in layouts: + # Inconsistency: TF32 does not stamp out void-C + data_type_tf32 = generate_data_types_from_math_instruction(math_inst) + data_type_f32 = copy.deepcopy(data_type_tf32) + data_type_f32["a_type"] = DataType.f32 + data_type_f32["b_type"] = DataType.f32 + data_type_f32["epi_type"] = DataType.f32 + for data_type in [data_type_tf32, data_type_f32]: + # Inconsistency: alignments aren't fixed in TF32 / alignx + # layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + ) + + if len(schedules): + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + + +def GenerateSM90_SparseTensorOp_tf32_WGMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 2): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=120, default_level=121, exhaustive_level=9992) + is_aligned = True + + # layouts for ABC and their alignments + layouts = [ + [[LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]], + ] + + math_instructions = make_sparse_math_instructions(generate_tf32_math_instructions_sm90(instantiation_level)) + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + + for layout in layouts: + data_type_tf32 = generate_data_types_from_math_instruction(math_inst) + data_type_tf32_wo_source = generate_data_types_from_math_instruction(math_inst, element_source=DataType.void) + data_type_f32 = copy.deepcopy(data_type_tf32) + data_type_f32_wo_source = copy.deepcopy(data_type_tf32_wo_source) + data_type_f32["a_type"] = DataType.f32 + data_type_f32["b_type"] = DataType.f32 + data_type_f32["epi_type"] = DataType.f32 + data_type_f32_wo_source["a_type"] = DataType.f32 + data_type_f32_wo_source["b_type"] = DataType.f32 + data_type_f32_wo_source["epi_type"] = DataType.f32 + data_types = [data_type_tf32, data_type_f32, data_type_tf32_wo_source, data_type_f32_wo_source] + + for data_type in data_types: + layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + ) + + if len(schedules): + CreateSparseGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateSparseGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + + +def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=111, exhaustive_level=9992) + is_aligned = True + + # layouts for ABC and their alignments + layouts = [ + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16]], + ] + + math_instructions = generate_int8_math_instructions_sm90(instantiation_level) + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + data_type_w_source = generate_data_types_from_math_instruction(math_inst) + data_type_wo_source = generate_data_types_from_math_instruction(math_inst, element_source=DataType.void) + data_type_int8_output = generate_data_types_from_math_instruction( + math_inst, + element_source=DataType.s8, + element_dest=math_inst.element_a, + element_epilogue=DataType.f32 + ) + data_types = [data_type_w_source, data_type_wo_source, data_type_int8_output] + + for layout in layouts: + for data_type in data_types: + layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + ) + + if len(schedules): + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + + +def GenerateSM90_TensorOp_int8_WGMMA_alignx_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=111, exhaustive_level=9992) + is_aligned = False + + # layouts for ABC and their alignments + layouts = [ + [[LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 1]], + ] + + math_instructions = generate_int8_math_instructions_sm90(instantiation_level) + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + data_type_w_source = generate_data_types_from_math_instruction(math_inst) + data_type_int8_output = generate_data_types_from_math_instruction( + math_inst, + element_source=DataType.s8, + element_dest=math_inst.element_a, + element_epilogue=DataType.f32 + ) + data_types = [data_type_w_source, data_type_int8_output] + + for layout in layouts: + for data_type in data_types: + layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + ) + + if len(schedules): + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + + +def GenerateSM90_SparseTensorOp_int8_WGMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 2): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=111, exhaustive_level=9992) + is_aligned = True + + # layouts for ABC and their alignments + layouts = [ + [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16]], + ] + + math_instructions = make_sparse_math_instructions(generate_int8_math_instructions_sm90(instantiation_level)) + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + # s8.u8 and u8.s8 wgmma variants require PTX 8.4 + if math_inst.element_a != math_inst.element_b and not CudaToolkitVersionSatisfies(cuda_version, 12, 4): + continue + data_type_w_source = generate_data_types_from_math_instruction(math_inst) + data_type_wo_source = generate_data_types_from_math_instruction(math_inst, element_source=DataType.void) + data_type_int8_output = generate_data_types_from_math_instruction( + math_inst, + element_source=DataType.s8, + element_dest=math_inst.element_a, + element_epilogue=DataType.f32 + ) + data_types = [data_type_w_source, data_type_wo_source, data_type_int8_output] + + for layout in layouts: + for data_type in data_types: + layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + ) + + if len(schedules): + CreateSparseGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateSparseGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + + +def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.Universal3x): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 3 if is_grouped(gemm_kind) else 0): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=20, default_level=121, exhaustive_level=9992) + is_aligned = True + + # layouts for ABC and their alignments + layouts = [ + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 1]], # TN Layout + ] + + math_instructions = generate_fp8_math_instructions_sm90(instantiation_level) + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + data_types = [] + fp8_types = [DataType.e4m3, DataType.e5m2] + valid_types_for_d = [DataType.f32, DataType.bf16, DataType.f16, DataType.e4m3, DataType.e5m2] + valid_types_for_c = copy.deepcopy(valid_types_for_d) + valid_types_for_c.append(DataType.void) + for c_type, d_type in product(valid_types_for_c, valid_types_for_d): + data_types.append( + generate_data_types_from_math_instruction( + math_inst, + element_source=c_type, + element_dest=d_type, + ) + ) + else: + for d_type in valid_types_for_d: + data_types.append( + generate_data_types_from_math_instruction( + math_inst, + element_source=DataType.void, + element_dest=d_type, + ) + ) + + for layout in layouts: + for data_type in data_types: + # Inconsistency: alignments aren't fixed in FP8 + # layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + gemm_kind=gemm_kind, + ) + + if len(schedules): + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules, gemm_kind=gemm_kind) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + +def GenerateSM90_TensorOp_fp8_WGMMA_gemm_with_blockwise(manifest, cuda_version, gemm_kind=GemmKind.BlockwiseUniversal3x): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 3 if is_grouped(gemm_kind) else 0): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=20, default_level=121, exhaustive_level=9992) + is_aligned = True + + # layouts for ABC and their alignments + layouts = [ + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 1]], # TN Layout + ] + + math_instructions = generate_fp8_math_instructions_sm90(instantiation_level) + tile_descriptions_ = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + tile_descriptions = list() + + for desc in tile_descriptions_: + desc.explicit_vector_sizes = [1, desc.tile_shape[1], desc.tile_shape[2]] + tile_descriptions.append(copy.deepcopy(desc)) + desc.explicit_vector_sizes = [desc.tile_shape[0], desc.tile_shape[1], desc.tile_shape[2]] + tile_descriptions.append(copy.deepcopy(desc)) + desc.explicit_vector_sizes = [desc.tile_shape[0], desc.tile_shape[1], desc.tile_shape[2]] + tile_descriptions.append(copy.deepcopy(desc)) + desc.explicit_vector_sizes = [1, 1, desc.tile_shape[2]] + tile_descriptions.append(copy.deepcopy(desc)) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + data_types = [] + fp8_types = [DataType.e4m3, DataType.e5m2] + valid_types_for_d = [DataType.f32, DataType.bf16, DataType.f16, DataType.e4m3, DataType.e5m2] + valid_types_for_c = copy.deepcopy(valid_types_for_d) + valid_types_for_c.append(DataType.void) + for c_type, d_type in product(valid_types_for_c, valid_types_for_d): + data_types.append( + generate_data_types_from_math_instruction( + math_inst, + element_source=c_type, + element_dest=d_type, + ) + ) + else: + for d_type in valid_types_for_d: + data_types.append( + generate_data_types_from_math_instruction( + math_inst, + element_source=DataType.void, + element_dest=d_type, + ) + ) + + for layout in layouts: + for data_type in data_types: + # Inconsistency: alignments aren't fixed in FP8 + # layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + gemm_kind=gemm_kind, + enable_fp8_fast_acc=False, + ) + + if len(schedules): + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules, gemm_kind=gemm_kind) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK], + gemm_kind=gemm_kind) + + + +def GenerateSM90_TensorOp_fp8_WGMMA_alignx_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=0, default_level=101, exhaustive_level=9992) + is_aligned = False + + # layouts for ABC and their alignments + layouts = [ + [[LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 1]], # TN Layout + [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 1]], # TN Layout + ] + + math_instructions = generate_fp8_math_instructions_sm90(instantiation_level) + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + data_types = [generate_data_types_from_math_instruction(math_inst)] + fp8_types = [DataType.e4m3, DataType.e5m2] + valid_types_for_d = [DataType.f32, DataType.bf16, DataType.f16, DataType.e4m3, DataType.e5m2] + valid_types_for_c = copy.deepcopy(valid_types_for_d) + valid_types_for_c.append(DataType.void) + for c_type, d_type in product(valid_types_for_c, valid_types_for_d): + data_types.append( + generate_data_types_from_math_instruction( + math_inst, + element_source=c_type, + element_dest=d_type, + ) + ) + + for layout in layouts: + for data_type in data_types: + # Inconsistency: alignments aren't fixed in FP8 + # layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + ) + + if len(schedules): + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + +def GenerateSM90_TensorOp_mixed_dtype_WGMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 1): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=20, default_level=121, exhaustive_level=9999) + is_aligned = True + + # layouts for ABC, their alignments will be fixed later based on the data type + layouts = [ + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16]], + ] + + valid_types_for_a_b_acc = [ + (DataType.e4m3, DataType.f16, DataType.f32), + (DataType.e4m3, DataType.bf16, DataType.f32), + (DataType.e5m2, DataType.f16, DataType.f32), + (DataType.e5m2, DataType.bf16, DataType.f32), + (DataType.s8, DataType.f16, DataType.f32), + (DataType.s8, DataType.bf16, DataType.f32), + (DataType.u8, DataType.f16, DataType.f32), + (DataType.u8, DataType.bf16, DataType.f32), + (DataType.s4, DataType.f16, DataType.f32), + (DataType.s4, DataType.bf16, DataType.f32), + (DataType.s4, DataType.e4m3, DataType.f32), + (DataType.s4, DataType.e5m2, DataType.f32), + (DataType.u4, DataType.f16, DataType.f32), + (DataType.u4, DataType.bf16, DataType.f32), + (DataType.u2, DataType.f16, DataType.f32), + (DataType.u2, DataType.bf16, DataType.f32), + (DataType.s2, DataType.f16, DataType.f32), + (DataType.s2, DataType.bf16, DataType.f32), + ] + # Note: For sizeof(a_type) > sizeof(b_type), some generated kernels might crash due to a compiler bug. Disable it for now. + #swapped_valid_types_for_a_b_acc = [(b_type, a_type, acc_type) for a_type, b_type, acc_type in valid_types_for_a_b_acc] + #valid_types_for_a_b_acc = valid_types_for_a_b_acc + swapped_valid_types_for_a_b_acc + + math_instructions = generate_mixed_dtype_math_instructions_sm90(instantiation_level, valid_types_for_a_b_acc) + + valid_types_for_d = [DataType.f32, DataType.bf16, DataType.f16, DataType.e4m3, DataType.e5m2] + valid_types_for_c = copy.deepcopy(valid_types_for_d) + + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + data_types = [] + + # Limit C/D types to avoid a giant number of instantiations. + # A typical use case for mixed dtype in DL is weight quantization (tensor A), + # therefore we can limit the output type to that of activation (tensor B). + valid_types_for_c = [math_inst.element_b] + valid_types_for_d = [math_inst.element_b] + + for c_type, d_type in product(valid_types_for_c, valid_types_for_d): + data_types.append( + generate_data_types_from_math_instruction( + math_inst, + element_source=c_type, + element_dest=d_type, + ) + ) + + for layout in layouts: + for data_type in data_types: + # Fix alignments, DataTypeSize are in the unit of bits + alignment_bits = 128 + layout[0][1] = alignment_bits // DataTypeSize[data_type['a_type']] + layout[1][1] = alignment_bits // DataTypeSize[data_type['b_type']] + layout[2][1] = alignment_bits // DataTypeSize[data_type['c_type']] + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + ) + + if len(schedules): + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + + +def GenerateSM90_SparseTensorOp_fp8_WGMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 2): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=20, default_level=121, exhaustive_level=9992) + is_aligned = True + + # layouts for ABC and their alignments + layouts = [ + [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 1]], # TN Layout + ] + + math_instructions = make_sparse_math_instructions(generate_fp8_math_instructions_sm90(instantiation_level)) + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + data_types = [] + fp8_types = [DataType.e4m3, DataType.e5m2] + valid_types_for_d = [DataType.f32, DataType.bf16, DataType.f16, DataType.e4m3, DataType.e5m2] + valid_types_for_c = copy.deepcopy(valid_types_for_d) + valid_types_for_c.append(DataType.void) + for c_type, d_type in product(valid_types_for_c, valid_types_for_d): + data_types.append( + generate_data_types_from_math_instruction( + math_inst, + element_source=c_type, + element_dest=d_type, + ) + ) + else: + for d_type in valid_types_for_d: + data_types.append( + generate_data_types_from_math_instruction( + math_inst, + element_source=DataType.void, + element_dest=d_type, + ) + ) + + for layout in layouts: + for data_type in data_types: + # Inconsistency: alignments aren't fixed in FP8 + # layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + ) + + if len(schedules): + CreateSparseGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateSparseGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + + +def GenerateSM90_TensorOp_1684(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_inst = MathInstruction( + [16, 8, 4], + DataType.f64, DataType.f64, DataType.f64, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + + min_cc = 90 + max_cc = 90 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 16], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([32, 256, 16], 3, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 16], 5, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 16], 5, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f64, DataType.f64, DataType.f64, DataType.f64] + + CreateGemmOperator(manifest, layouts, tile_descriptions, + data_type, alignment_constraints) + +# + +# +def GenerateSM90_TensorOp_1684_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex) + + min_cc = 90 + max_cc = 90 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 64, 8 ], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 8 ], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 8 ], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8 ], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8 ], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 8 ], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 8 ], 4, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 8 ], 4, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 16], 4, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 16], 3, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) +# + +# +def GenerateSM90_TensorOp_1684_complex_gaussian(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_gaussian) + + min_cc = 90 + max_cc = 90 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) +# + +# +def GenerateSM90_TensorOp_1684_rank_k(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add) + + min_cc = 90 + max_cc = 90 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 16], 5, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 16], 5, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f64, DataType.f64, DataType.f64] + + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) +# + +# +def GenerateSM90_TensorOp_1684_rank_k_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex) + + min_cc = 90 + max_cc = 90 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 8], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 8], 3, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64] + + # SYRK computation + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) + + # HERK computation + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.hermitian) + +# + +# +def GenerateSM90_TensorOp_1684_rank_k_complex_gaussian(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_gaussian) + + min_cc = 90 + max_cc = 90 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ComplexTransform.none,] + + # SYRK computation + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) + + # HERK computation + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.hermitian) +# + +# +def GenerateSM90_TensorOp_1684_trmm(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + diag_types = [ + DiagType.NonUnit, DiagType.Unit, + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add) + + min_cc = 90 + max_cc = 90 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f64, DataType.f64, DataType.f64, DataType.f64] + + CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, \ + data_type, alignment_constraints) +# + +# +def GenerateSM90_TensorOp_1684_trmm_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + diag_types = [ + DiagType.NonUnit, DiagType.Unit, + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex) + + min_cc = 90 + max_cc = 90 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 8], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 8], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ + ComplexTransform.none, ComplexTransform.conj, + ] + + CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) +# + + +# +def GenerateSM90_TensorOp_1684_trmm_complex_gaussian(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + diag_types = [ + DiagType.NonUnit, DiagType.Unit, + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_gaussian) + + min_cc = 90 + max_cc = 90 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ + ComplexTransform.none, ComplexTransform.conj, + ] + + CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) +# + +# +def GenerateSM90_TensorOp_1684_symm(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add) + + min_cc = 90 + max_cc = 90 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 16], 5, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 16], 5, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f64, DataType.f64, DataType.f64, DataType.f64] + + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) +# + +# +def GenerateSM90_TensorOp_1684_symm_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex) + + min_cc = 90 + max_cc = 90 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 8], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 8], 3, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + # SYMM computation + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) + + # HEMM computation + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.hermitian) +# + +# +def GenerateSM90_TensorOp_1684_symm_complex_gaussian(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_gaussian) + + min_cc = 90 + max_cc = 90 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ComplexTransform.none,] + + # SYMM computation + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) + + # HEMM computation + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.hermitian) +# + + + +# Blackwell SM 100 generators + +try: + import cutlass_library.sm100_utils + from cutlass_library.sm100_utils import ( + generate_tf32_math_instructions_sm100, + generate_16b_math_instructions_sm100, + generate_f8f6f4_math_instructions_sm100, + generate_mxf8f6f4_math_instructions_sm100, + generate_mxf4nvf4_math_instructions_sm100, + generate_fp8_math_instructions_sm100, + generate_cluster_shapes_sm100, + get_pruning_level_from_global_level + ) +except ImportError: + import sm100_utils + from sm100_utils import ( + generate_tf32_math_instructions_sm100, + generate_16b_math_instructions_sm100, + generate_f8f6f4_math_instructions_sm100, + generate_mxf8f6f4_math_instructions_sm100, + generate_mxf4nvf4_math_instructions_sm100, + generate_fp8_math_instructions_sm100, + generate_cluster_shapes_sm100, + get_pruning_level_from_global_level + ) + +################################################################################################### + +def get_tma_alignment_elt(data_type : DataType, is_f8f6f4 : bool = True ) -> int: + if DataTypeSize[data_type] < 8 and is_f8f6f4: + return int(128) + return int(16 * 8 / DataTypeSize[data_type]) + +sm100_cluster_shape_1sm = [ + [4,4,1] + , DynamicClusterShape +] + +sm100_cluster_shape_2sm = [ + # cluster_m % 2 == 0 for 2sm + [4,4,1] + , DynamicClusterShape +] + +def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=490, default_level=490, exhaustive_level=9999) + + # layouts for ABC and their alignments. + layouts = [ + [[LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]], + [[LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4]], + [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]], + [[LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4]], + [[LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4]], + [[LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4]], + [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4]], + [[LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4]], + ] + + data_types = [ + { + "a_type" : DataType.f32, + "b_type" : DataType.f32, + "c_type" : DataType.f32, + "d_type" : DataType.f32, + "acc_type" : DataType.f32, + "epi_type" : DataType.f32, + }, + { + "a_type" : DataType.f32, + "b_type" : DataType.f32, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : DataType.f32, + "epi_type" : DataType.f32, + }, + ] + + thor_sm = ThorSMRenumbering(cuda_version) + + min_cc = 100 + max_cc = thor_sm + + math_instructions_1sm, math_instructions_2sm = generate_tf32_math_instructions_sm100(instantiation_level) + + cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level) + + if thor_sm in manifest.compute_capabilities_baseline : + if [4,4,1] in cluster_shapes_1sm : + cluster_shapes_1sm.remove([4,4,1]) + if [4,4,1] in cluster_shapes_2sm : + cluster_shapes_2sm.remove([4,4,1]) + + tile_schedulers = [ + TileSchedulerType.Default, TileSchedulerType.StreamK + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, + [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], + tile_schedulers=tile_schedulers) + + # 2xSM MMA kernels + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + if math_inst.instruction_shape[0] == 128: + epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm + else: + epi_schedule = EpilogueScheduleType.ScheduleAuto + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, + [[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers) + +def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.Universal3x): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=490, default_level=490, exhaustive_level=9999) + + # layouts for ABC and their alignments. C alignment will be set later based on output type + layouts = [ + [[LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 0]], + [[LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 0]], + [[LayoutType.RowMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 0]], + [[LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.RowMajor, 0]], + ] + + thor_sm = ThorSMRenumbering(cuda_version) + + math_instructions_1sm, math_instructions_2sm = generate_16b_math_instructions_sm100(instantiation_level) + + min_cc = 100 + max_cc = thor_sm + grouped = is_grouped(gemm_kind) + + cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level) + + if thor_sm in manifest.compute_capabilities_baseline : + if [4,4,1] in cluster_shapes_1sm : + cluster_shapes_1sm.remove([4,4,1]) + if [4,4,1] in cluster_shapes_2sm : + cluster_shapes_2sm.remove([4,4,1]) + + tile_schedulers = [ + TileSchedulerType.Default, TileSchedulerType.StreamK + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_accumulator, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + ] + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + kernel_schedule = KernelScheduleType.TmaWarpSpecialized1SmSm100 if not grouped else KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100 + epi_schedule = EpilogueScheduleType.TmaWarpSpecialized1Sm if not grouped else EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, + [[kernel_schedule, epi_schedule]], + tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) + + # for mixed precision kernels, also generate kernels that write output matrix in the A/B format + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + data_types_mixed = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_a, + "d_type" : math_inst.element_a, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : math_inst.element_a, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + ] + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types_mixed[0]["d_type"]] + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types_mixed, + [[kernel_schedule, epi_schedule]], + tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) + + # 2xSM MMA kernels + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_accumulator, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + ] + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + if grouped: + epi_schedule = EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm + elif math_inst.instruction_shape[0] == 128: + epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm + else: + epi_schedule = EpilogueScheduleType.ScheduleAuto + kernel_schedule = to_grouped_schedule(KernelScheduleType.TmaWarpSpecialized2SmSm100, grouped) + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, + [[kernel_schedule, epi_schedule]], tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) + + # for mixed precision kernels, also generate kernels that write output matrix in the A/B format + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + data_types_mixed = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_a, + "d_type" : math_inst.element_a, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : math_inst.element_a, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + ] + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types_mixed[0]["d_type"]] + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types_mixed, + [[kernel_schedule, epi_schedule]], tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) + +def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.Universal3x): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=591 , default_level=591 , exhaustive_level=9999) + + # layouts for ABC and their alignments. + layouts = [ + [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]], + [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]], + ] + + thor_sm = ThorSMRenumbering(cuda_version) + + min_cc = 100 + max_cc = thor_sm + + epi_type = DataType.f32 + grouped = is_grouped(gemm_kind) + + math_instructions_1sm, math_instructions_2sm = generate_fp8_math_instructions_sm100(instantiation_level, enable_runtime_dtype=not grouped) + + cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level) + + if thor_sm in manifest.compute_capabilities_baseline : + if [4,4,1] in cluster_shapes_1sm : + cluster_shapes_1sm.remove([4,4,1]) + if [4,4,1] in cluster_shapes_2sm : + cluster_shapes_2sm.remove([4,4,1]) + + tile_schedulers = [ + TileSchedulerType.Default, TileSchedulerType.StreamK + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f32, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + } + ] + + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + for data_type in data_types: + if ( data_type["a_type"] == DataType.e4m3 ) and ( data_type["b_type"] == DataType.e4m3 ) and\ + ( data_type["d_type"] == DataType.e5m2 ): + continue + kernel_schedule = to_grouped_schedule(KernelScheduleType.TmaWarpSpecialized1SmSm100, grouped) + epi_schedule = to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped) + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + [[kernel_schedule, epi_schedule]], + tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) + + # 2xSM MMA kernels + + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f32, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + } + ] + + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + for data_type in data_types: + if ( data_type["a_type"] == DataType.e4m3 ) and ( data_type["b_type"] == DataType.e4m3 ) and\ + ( data_type["d_type"] == DataType.e5m2 ): + continue + + if grouped: + epi_schedule = EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm + elif math_inst.instruction_shape[0] == 128: + epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm + else: + epi_schedule = EpilogueScheduleType.ScheduleAuto + kernel_schedule = to_grouped_schedule(KernelScheduleType.TmaWarpSpecialized2SmSm100, grouped) + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + [[kernel_schedule, epi_schedule]], tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) + +def GenerateSM100_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version, gemm_kind=GemmKind.BlockwiseUniversal3x): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=593, default_level=593, exhaustive_level=9999) + + grouped = is_grouped(gemm_kind) + + # layouts for ABC and their alignments. + layouts = [ + [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]], + [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]], + ] + + min_cc = 100 + max_cc = 100 + epi_type = DataType.f32 + + pruning_level = get_pruning_level_from_global_level(instantiation_level) + + math_instructions_1sm, math_instructions_2sm = generate_fp8_math_instructions_sm100(instantiation_level, enable_compile_time_dtype=grouped or pruning_level >= 1, enable_runtime_dtype=not grouped) + + cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level) + + tile_schedulers = [ + TileSchedulerType.Default, + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape, + [math_inst.instruction_shape[0], math_inst.instruction_shape[1], + math_inst.instruction_shape[2] * 4])) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape, + [1, math_inst.instruction_shape[1], + math_inst.instruction_shape[2] * 4])) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape, + [math_inst.instruction_shape[0], 1, + math_inst.instruction_shape[2] * 4])) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f32, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + ] + + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8) + for data_type in data_types: + if ( data_type["a_type"] == DataType.e4m3 ) and ( data_type["b_type"] == DataType.e4m3 ) and\ + ( data_type["d_type"] == DataType.e5m2 ): + continue + + is_runtime_datatype_a = is_runtime_datatype(data_type["a_type"]) + is_runtime_datatype_b = is_runtime_datatype(data_type["d_type"]) + + # A/B datatypes should be both static or dynamic + if (is_runtime_datatype_a != is_runtime_datatype_b): + continue + + kernel_schedule = to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100, grouped) + epi_schedule = to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped) + epi_schedule_nosmem = to_grouped_schedule(EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm, grouped) + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + [[kernel_schedule, epi_schedule], [kernel_schedule, epi_schedule_nosmem]], + tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) + +def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.Universal3x): + + # SM100 MMA with mixed F4/F6/F8 inputs + without block scale + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=590, default_level=590, exhaustive_level=9999) + + grouped = is_grouped(gemm_kind) + + # layouts for ABC and their alignments. + layouts = [ + [[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]], + ] + + math_instructions_1sm, math_instructions_2sm = generate_f8f6f4_math_instructions_sm100(instantiation_level, enable_runtime_dtype=not grouped) + + def change_priority_func(shapes_1sm, shapes_2sm): + shapes_1sm[(1,2,1)] = 6 + shapes_1sm[(1,4,1)] = 6 + shapes_2sm[(2,2,1)] = 6 + shapes_2sm[(2,4,1)] = 6 + shapes_2sm[(4,2,1)] = 6 + + cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level, change_priority_func) + + tile_schedulers = [ + TileSchedulerType.Default, TileSchedulerType.StreamK + ] + + thor_sm = ThorSMRenumbering(cuda_version) + + min_cc = 100 + max_cc = thor_sm + + epi_type = DataType.f32 + + is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8) + + if thor_sm in manifest.compute_capabilities_baseline : + if [4,4,1] in cluster_shapes_1sm : + cluster_shapes_1sm.remove([4,4,1]) + if [4,4,1] in cluster_shapes_2sm : + cluster_shapes_2sm.remove([4,4,1]) + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + kernel_data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f32, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + } + ] + + for kernel_data_type in kernel_data_types: + # Filter out some kernel + if ( kernel_data_type["a_type"] == DataType.e4m3 ) and ( kernel_data_type["b_type"] == DataType.e4m3 ) and\ + ( kernel_data_type["d_type"] == DataType.e5m2 ): + continue + + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_copy = copy.deepcopy(layouts) + for layout in layouts_copy: + # alignment for a + layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) + # alignment for b + layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + + CreateGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], tile_schedulers=tile_schedulers) + + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + kernel_data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f32, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + } + ] + + for kernel_data_type in kernel_data_types: + # Filter some kernel + if ( kernel_data_type["a_type"] == DataType.e4m3 ) and ( kernel_data_type["b_type"] == DataType.e4m3 ) and\ + ( kernel_data_type["d_type"] == DataType.e5m2 ): + continue + + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_copy = copy.deepcopy(layouts) + for layout in layouts_copy: + # alignment for a + layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) + # alignment for b + layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + + if math_inst.instruction_shape[0] == 128: + CreateGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.TmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]], tile_schedulers=tile_schedulers) + else: + CreateGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.TmaWarpSpecialized2SmSm100, EpilogueScheduleType.ScheduleAuto]], tile_schedulers=tile_schedulers) + +def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x): + + # SM100 MMA with mixed F4/F6/F8 inputs + block scale + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=590, default_level=590, exhaustive_level=9999) + + grouped = is_grouped(gemm_kind) + + layouts = [ + [[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 128], [LayoutType.RowMajor, 0]], + ] + + math_instructions_1sm, math_instructions_2sm = generate_mxf8f6f4_math_instructions_sm100(instantiation_level, enable_runtime_dtype=not grouped) + + def change_priority_func(shapes_1sm, shapes_2sm): + shapes_1sm[(1,2,1)] = 6 + shapes_1sm[(1,4,1)] = 6 + shapes_2sm[(2,2,1)] = 6 + shapes_2sm[(2,4,1)] = 6 + shapes_2sm[(4,2,1)] = 6 + + cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level, change_priority_func) + + ab_types = [ + DataType.f4, DataType.f6, + DataType.e2m1, + DataType.e2m3, + DataType.e3m2, + DataType.e5m2, + DataType.e4m3, + ] + + acc_types = [ DataType.f32 ] + + def tile_schedulers(sfdtype): + # Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void, + # the epilogue is the traditional linear combination, for which we already have tests with stream-K. + if sfdtype["type"] == DataType.void or grouped: + return [TileSchedulerType.Default] + else: + return [TileSchedulerType.Default, TileSchedulerType.StreamK] + + thor_sm = ThorSMRenumbering(cuda_version) + + min_cc = 100 + max_cc = thor_sm + + epi_type = DataType.f32 + + is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8) + + if thor_sm in manifest.compute_capabilities_baseline : + if [4,4,1] in cluster_shapes_1sm : + cluster_shapes_1sm.remove([4,4,1]) + if [4,4,1] in cluster_shapes_2sm : + cluster_shapes_2sm.remove([4,4,1]) + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + assert math_inst.opcode_class == OpcodeClass.BlockScaledTensorOp + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e3m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }] + + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + for data_type in data_types: + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + [[to_grouped_schedule(KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100, grouped), to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)]] + , tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) + + for math_inst in math_instructions_2sm: + assert math_inst.opcode_class == OpcodeClass.BlockScaledTensorOp + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e3m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + }, + ] + + # Set alignment d based on Destination format. + for data_type in data_types: + for layout in layouts: + # alignment for a + layout[0][1] = get_tma_alignment_elt(data_type["a_type"]) + # alignment for b + layout[1][1] = get_tma_alignment_elt(data_type["b_type"]) + # alignment for d + layout[2][1] = get_tma_alignment_elt(data_type["d_type"]) + for tile in tile_descriptions: + math_inst = tile.math_instruction + # Filter some kernels that does not meet the alignment requirements. + if layout[0][0] == LayoutType.ColumnMajor: + if math_inst.instruction_shape[0] // 2 % layout[0][1] != 0: + continue + else: + if tile.threadblock_shape[2] // tile.cluster_shape[2] % layout[0][1] != 0: + continue + + if layout[1][0] == LayoutType.RowMajor: + if math_inst.instruction_shape[1] // 2 % layout[1][1] != 0: + continue + else: + if tile.threadblock_shape[2] // tile.cluster_shape[2] % layout[1][1] != 0: + continue + + if grouped: + CreateGemmUniversal3xOperator(manifest, [layout], [tile], [data_type], + [[to_grouped_schedule(KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100, grouped), to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped)]] + , tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) + elif math_inst.instruction_shape[0] == 128: + CreateGemmUniversal3xOperator(manifest, [layout], [tile], [data_type], + [[KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]] + , tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) + else: + CreateGemmUniversal3xOperator(manifest, [layout], [tile], [data_type], + [[KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100, EpilogueScheduleType.ScheduleAuto]] + , tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) + + + +def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x): + # SM100 MMA with F4 + block scale + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=591, default_level=591, exhaustive_level=9999) + + grouped = is_grouped(gemm_kind) + + # layouts for ABC and their alignments. + layouts = [ + [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.ColumnMajor, 0]], + ] + + math_instructions_1sm, math_instructions_2sm = generate_mxf4nvf4_math_instructions_sm100(instantiation_level, enable_runtime_dtype=not grouped) + + def change_priority_func(shapes_1sm, shapes_2sm): + shapes_1sm[(1,2,1)] = 6 + shapes_1sm[(1,4,1)] = 6 + shapes_2sm[(2,2,1)] = 6 + shapes_2sm[(2,4,1)] = 6 + shapes_2sm[(4,2,1)] = 6 + + cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level, change_priority_func=change_priority_func) + + acc_types = [ DataType.f32 ] # Accumulator is always 32 bits for block scaled MMA instructions + + def tile_schedulers(sfdtype): + # Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void, + # the epilogue is the traditional linear combination, for which we already have tests with stream-K. + if sfdtype["type"] == DataType.void or grouped: + return [TileSchedulerType.Default] + else: + return [TileSchedulerType.Default, TileSchedulerType.StreamK] + + thor_sm = ThorSMRenumbering(cuda_version) + + min_cc = 100 + max_cc = thor_sm + + epi_type = DataType.f32 + + is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8) + + if thor_sm in manifest.compute_capabilities_baseline : + if [4,4,1] in cluster_shapes_1sm : + cluster_shapes_1sm.remove([4,4,1]) + if [4,4,1] in cluster_shapes_2sm : + cluster_shapes_2sm.remove([4,4,1]) + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + assert math_inst.opcode_class == OpcodeClass.BlockScaledTensorOp + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + assert math_inst.instruction_shape[2] * 4 == 256 + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + } + ] + + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + for layout in layouts: + for data_type in data_types: + if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.RowMajor): + data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout. + if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.ColumnMajor): + continue + + # E2M1 x E2M1, vector size 32, E8 + # E2M1 x E2M1, vector size 16, UE4M3 + isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1 + epi_schedule = to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped) + epi_nosmem_schedule = to_grouped_schedule(EpilogueScheduleType.NoSmemWarpSpecialized1Sm, grouped) + nvfp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, grouped) + fp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100, grouped) + + nvfp4_schedules = [[nvfp4_kernel_schedule, epi_schedule], [nvfp4_kernel_schedule, epi_nosmem_schedule]] + fp4_schedules = [[fp4_kernel_schedule, epi_schedule], [fp4_kernel_schedule, epi_nosmem_schedule]] + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, nvfp4_schedules + , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind + ) + if isFp4: + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, fp4_schedules + , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind + ) + + for math_inst in math_instructions_2sm: + assert math_inst.opcode_class == OpcodeClass.BlockScaledTensorOp + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + } + ] + + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + for layout in layouts: + for data_type in data_types: + if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.RowMajor): + data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout. + if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.ColumnMajor): + continue + + # E2M1 x E2M1, vector size 32, E8 + isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1 + + epi_schedule = EpilogueScheduleType.ScheduleAuto if not grouped else EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm + epi_nosmem_schedule = to_grouped_schedule(EpilogueScheduleType.NoSmemWarpSpecialized2Sm, grouped) + nvfp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, grouped) + fp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100, grouped) + + nvfp4_schedules = [[nvfp4_kernel_schedule, epi_schedule], [nvfp4_kernel_schedule, epi_nosmem_schedule]] + fp4_schedules = [[fp4_kernel_schedule, epi_schedule], [fp4_kernel_schedule, epi_nosmem_schedule]] + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, nvfp4_schedules + , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) + if isFp4: + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, fp4_schedules + , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) + +def GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x): + # SM100 MMA with F4 + block scale + if not CudaToolkitVersionSatisfies(cuda_version, 13, 0): + return + + grouped = is_grouped(gemm_kind) + + # layouts for ABC and their alignments. + layouts = [ + [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.ColumnMajor, 0]], + ] + + instruction_sizes_1sm = [ + [128, 128, 96], + ] + + instruction_sizes_2sm = [ + [256, 128, 96], + [256, 192, 96], + [256, 256, 96] + ] + + ab_types = [ + DataType.f4, + DataType.e2m1, + ] + + sf_types = [ + DataType.ue4m3, + DataType.ue8m0 + ] + + acc_types = [ DataType.f32 ] # Accumulator is always 32 bits for block scaled MMA instructions + + def tile_schedulers(sfdtype): + # Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void, + # the epilogue is the traditional linear combination, for which we already have tests with stream-K. + if grouped: + return [TileSchedulerType.Default] + if sfdtype["type"] == DataType.void: + return [TileSchedulerType.Default] + else: + return [TileSchedulerType.Default, TileSchedulerType.StreamK] + + min_cc = 103 + max_cc = 103 + epi_type = DataType.f32 + + math_instructions_1sm = [] + + is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8) + + for instr_size, a_type, b_type, sf_type, acc_type in product(instruction_sizes_1sm, ab_types, ab_types, sf_types, acc_types): + is_runtime_datatype_a = is_runtime_datatype(a_type) + is_runtime_datatype_b = is_runtime_datatype(b_type) + + # A/B datatypes should be both static or dynamic + if (is_runtime_datatype_a != is_runtime_datatype_b): + continue + + math_instructions_1sm.append( + MathInstruction( + instr_size, + a_type, b_type, acc_type, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + sf_type) + ) + + math_instructions_2sm = [] + + for instr_size, a_type, b_type, sf_type, acc_type in product(instruction_sizes_2sm, ab_types, ab_types, sf_types, acc_types): + is_runtime_datatype_a = is_runtime_datatype(a_type) + is_runtime_datatype_b = is_runtime_datatype(b_type) + + # A/B datatypes should be both static or dynamic + if (is_runtime_datatype_a != is_runtime_datatype_b): + continue + + math_instructions_2sm.append( + MathInstruction( + instr_size, + a_type, b_type, acc_type, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + sf_type) + ) + + cluster_shapes_1sm = [ + [1,1,1], + # [1,2,1], + [2,1,1], + # [1,4,1], + [4,4,1], + DynamicClusterShape + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + 768], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + } + ] + + # Set alignment d based on Destination format. + for layout in layouts: + for data_type in data_types: + # Set alignment d based on Destination format. + if DataTypeSize[data_type["c_type"]] == 0 : + layout[2][1] = 256 // DataTypeSize[data_type["d_type"]] + else: + layout[2][1] = min(256 // DataTypeSize[data_type["d_type"]], 256 // DataTypeSize[data_type["c_type"]]) + + if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.RowMajor): + data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout. + if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.ColumnMajor): + continue + # E2M1 x E2M1, vector size 32, E8 + isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1 + + epilogue_1sm_schedule = to_grouped_schedule(EpilogueScheduleType.NoSmemWarpSpecialized1Sm, grouped) + + nvfp4_schedule = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103, grouped), epilogue_1sm_schedule] + nvfp4_schedule_disable_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch, grouped), epilogue_1sm_schedule] + nvfp4_schedule_tma_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch, grouped), epilogue_1sm_schedule] + fp4_schedule = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103, grouped), epilogue_1sm_schedule] + fp4_schedule_disable_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch, grouped), epilogue_1sm_schedule] + fp4_schedule_tma_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch, grouped), epilogue_1sm_schedule] + nvfp4_schedules = [nvfp4_schedule, nvfp4_schedule_disable_prefetch, nvfp4_schedule_tma_prefetch] + fp4_schedules = [fp4_schedule, fp4_schedule_disable_prefetch, fp4_schedule_tma_prefetch] + + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, + nvfp4_schedules, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) + if isFp4: + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, + fp4_schedules, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) + + cluster_shapes_2sm = [ + [2,1,1], + # [2,2,1], + # [2,4,1], + [4,1,1], + # [4,2,1], + [4,4,1], + DynamicClusterShape + ] + + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 8 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + } + ] + + # Set alignment d based on Destination format. + for layout in layouts: + for data_type in data_types: + # Set alignment d based on Destination format. + if DataTypeSize[data_type["c_type"]] == 0 : + layout[2][1] = 256 // DataTypeSize[data_type["d_type"]] + else: + layout[2][1] = min(256 // DataTypeSize[data_type["d_type"]], 256 // DataTypeSize[data_type["c_type"]]) + + if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.RowMajor): + data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout. + if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.ColumnMajor): + continue + # E2M1 x E2M1, vector size 32, E8 + isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1 + + epilogue_2sm_schedule = to_grouped_schedule(EpilogueScheduleType.NoSmemWarpSpecialized2Sm, grouped) + + nvfp4_schedule = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103, grouped), epilogue_2sm_schedule] + nvfp4_schedule_disable_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch, grouped), epilogue_2sm_schedule] + nvfp4_schedule_tma_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch, grouped), epilogue_2sm_schedule] + fp4_schedule = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103, grouped), epilogue_2sm_schedule] + fp4_schedule_disable_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch, grouped), epilogue_2sm_schedule] + fp4_schedule_tma_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch, grouped), epilogue_2sm_schedule] + nvfp4_schedules = [nvfp4_schedule, nvfp4_schedule_disable_prefetch, nvfp4_schedule_tma_prefetch] + fp4_schedules = [fp4_schedule, fp4_schedule_disable_prefetch, fp4_schedule_tma_prefetch] + + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, + nvfp4_schedules, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) + if isFp4: + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, + fp4_schedules, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) + + +def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + # layouts for ABC and their alignments. + layouts = [ + [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]], + [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]], + ] + + thor_sm = ThorSMRenumbering(cuda_version) + + min_cc = 100 + max_cc = thor_sm + + epi_type = DataType.f32 + + math_instructions_1sm = [ + MathInstruction( + [64, 128, 32], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 128, 32], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 32], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add)] + + cluster_shapes_1sm = [[1,2,1], [2,1,1], [1,1,1], [1,4,1], [4,4,1] + , DynamicClusterShape + ] + + if thor_sm in manifest.compute_capabilities_baseline : + cluster_shapes_1sm = [[1,2,1], [2,1,1], [1,1,1], [1,4,1] + , DynamicClusterShape + ] + + tile_schedulers = [ + TileSchedulerType.Default, TileSchedulerType.StreamK + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_accumulator, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + ] + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, + [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], + tile_schedulers=tile_schedulers) + + # for mixed precision kernels, also generate kernels that write output matrix in the A/B format + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + data_types_mixed = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_a, + "d_type" : math_inst.element_a, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : math_inst.element_a, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + ] + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types_mixed[0]["d_type"]] + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types_mixed, + [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], + tile_schedulers=tile_schedulers) + + # 2xSM MMA kernels + math_instructions_2sm = [ + MathInstruction( + [128, 128, 32], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 32], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [256, 128, 32], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [256, 256, 32], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + ] + + cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1] + , DynamicClusterShape + ] + + if thor_sm in manifest.compute_capabilities_baseline : + cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1] + , DynamicClusterShape + ] + + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_accumulator, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + ] + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + if math_inst.instruction_shape[0] == 128: + epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm + else: + epi_schedule = EpilogueScheduleType.ScheduleAuto + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, + [[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers) + + # for mixed precision kernels, also generate kernels that write output matrix in the A/B format + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + data_types_mixed = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_a, + "d_type" : math_inst.element_a, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : math_inst.element_a, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + ] + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types_mixed, + [[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers) + + +def GenerateSM100_SparseTensorOp_32b_UMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + # layouts for ABC and their alignments. + layouts = [ + # Alignment requirement will be over-write below + [[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]], + ] + + thor_sm = ThorSMRenumbering(cuda_version) + + min_cc = 100 + max_cc = thor_sm + + tile_schedulers = [ + TileSchedulerType.Default, TileSchedulerType.StreamK + ] + + kernel_data_types = [ + # void_c + { + "a_type" : DataType.f32, + "b_type" : DataType.f32, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : DataType.f32, + "epi_type" : DataType.f32, + }, + # none void_c + { + "a_type" : DataType.f32, + "b_type" : DataType.f32, + "c_type" : DataType.f32, + "d_type" : DataType.f32, + "acc_type" : DataType.f32, + "epi_type" : DataType.f32, + }, + ] + + math_instructions_1sm = [ + MathInstruction( + [128, 128, 16], + DataType.tf32, DataType.tf32, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 16], + DataType.tf32, DataType.tf32, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + ] + + math_instructions_2sm = [ + MathInstruction( + [256, 128, 16], + DataType.tf32, DataType.tf32, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [256, 256, 16], + DataType.tf32, DataType.tf32, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in sm100_cluster_shape_1sm: + if thor_sm in manifest.compute_capabilities_baseline : + if cluster_shape == [4,4,1] : + continue + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 2 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + for kernel_data_type in kernel_data_types: + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_copy = copy.deepcopy(layouts) + for layout in layouts_copy: + # alignment for a, 2 for sparsity + layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) + # alignment for b + layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + + CreateSparseGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.SparseTmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], + tile_schedulers=tile_schedulers) + + # 2xSM MMA kernels + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in sm100_cluster_shape_2sm: + if thor_sm in manifest.compute_capabilities_baseline : + if cluster_shape == [4,4,1] : + continue + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 2 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + for kernel_data_type in kernel_data_types: + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_copy = copy.deepcopy(layouts) + for layout in layouts_copy: + # alignment for a, 2 for sparsity + layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) + # alignment for b + layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + + CreateSparseGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.SparseTmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]], + tile_schedulers=tile_schedulers) + +def GenerateSM100_SparseTensorOp_16b_UMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + # layouts for ABC and their alignments. + layouts = [ + # Alignment requirement will be over-write below + [[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]], + ] + + thor_sm = ThorSMRenumbering(cuda_version) + + min_cc = 100 + max_cc = thor_sm + + tile_schedulers = [ + TileSchedulerType.Default, TileSchedulerType.StreamK + ] + + kernel_data_types = [ + # void_c + { + "a_type" : DataType.f16, + "b_type" : DataType.f16, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : DataType.f32, + "epi_type" : DataType.f32, + }, + # none void_c + { + "a_type" : DataType.f16, + "b_type" : DataType.f16, + "c_type" : DataType.f16, + "d_type" : DataType.f16, + "acc_type" : DataType.f32, + "epi_type" : DataType.f32, + }, + ] + + math_instructions_1sm = [ + MathInstruction( + [128, 128, 32], + DataType.f16, DataType.f16, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 32], + DataType.f16, DataType.f16, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + ] + + math_instructions_2sm = [ + MathInstruction( + [256, 128, 32], + DataType.f16, DataType.f16, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [256, 256, 32], + DataType.f16, DataType.f16, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in sm100_cluster_shape_1sm: + if thor_sm in manifest.compute_capabilities_baseline : + if cluster_shape == [4,4,1] : + continue + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 2 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + for kernel_data_type in kernel_data_types: + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_copy = copy.deepcopy(layouts) + for layout in layouts_copy: + # alignment for a, 2 for sparsity + layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) + # alignment for b + layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + + CreateSparseGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.SparseTmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], + tile_schedulers=tile_schedulers) + + # 2xSM MMA kernels + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in sm100_cluster_shape_2sm: + if thor_sm in manifest.compute_capabilities_baseline : + if cluster_shape == [4,4,1] : + continue + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 2 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + for kernel_data_type in kernel_data_types: + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_copy = copy.deepcopy(layouts) + for layout in layouts_copy: + # alignment for a, 2 for sparsity + layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) + # alignment for b + layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + + CreateSparseGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.SparseTmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]], + tile_schedulers=tile_schedulers) + +def GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + # layouts for ABC and their alignments. + layouts = [ + # Alignment requirement will be over-write below + [[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]], + ] + + thor_sm = ThorSMRenumbering(cuda_version) + + min_cc = 100 + max_cc = thor_sm + + tile_schedulers = [ + TileSchedulerType.Default, TileSchedulerType.StreamK + ] + + kernel_data_types = [ + # void_c + { + "a_type" : DataType.s8, + "b_type" : DataType.s8, + "c_type" : DataType.void, + "d_type" : DataType.s8, + "acc_type" : DataType.f32, + "epi_type" : DataType.f32, + }, + # none void_c + { + "a_type" : DataType.s8, + "b_type" : DataType.s8, + "c_type" : DataType.s8, + "d_type" : DataType.s8, + "acc_type" : DataType.f32, + "epi_type" : DataType.f32, + }, + ] + + math_instructions_1sm = [ + MathInstruction( + [128, 128, 64], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 64], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add)] + + math_instructions_2sm = [ + MathInstruction( + [256, 128, 64], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [256, 256, 64], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in sm100_cluster_shape_1sm: + if thor_sm in manifest.compute_capabilities_baseline : + if cluster_shape == [4,4,1] : + continue + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 2 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + for kernel_data_type in kernel_data_types: + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_copy = copy.deepcopy(layouts) + for layout in layouts_copy: + # alignment for a, 2 for sparsity + layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) + # alignment for b + layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + + CreateSparseGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.SparseTmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], + tile_schedulers=tile_schedulers) + + # 2xSM MMA kernels + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in sm100_cluster_shape_2sm: + if thor_sm in manifest.compute_capabilities_baseline : + if cluster_shape == [4,4,1] : + continue + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 2 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + for kernel_data_type in kernel_data_types: + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_copy = copy.deepcopy(layouts) + for layout in layouts_copy: + # alignment for a, 2 for sparsity + layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) + # alignment for b + layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + + CreateSparseGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.SparseTmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]], + tile_schedulers=tile_schedulers) + +def GenerateSM100_SparseTensorOp_fp8_UMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + # layouts for ABC and their alignments. + layouts = [ + # Alignment requirement will be over-write below + [[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]], + ] + + thor_sm = ThorSMRenumbering(cuda_version) + + min_cc = 100 + max_cc = thor_sm + + tile_schedulers = [ + TileSchedulerType.Default, TileSchedulerType.StreamK + ] + + kernel_data_types = [ + # NOTE: a/b type in kernel will be overwrite below. + #* void_c + # f8_f8_f32_void_f16 + { + "a_type" : DataType.e4m3, + "b_type" : DataType.e4m3, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : DataType.f32, + "epi_type" : DataType.f32, + }, + #* non-void_c + # f8_f8_f32_f16_f8 + { + "a_type" : DataType.e4m3, + "b_type" : DataType.e4m3, + "c_type" : DataType.f16, + "d_type" : DataType.e4m3, + "acc_type" : DataType.f32, + "epi_type" : DataType.f32, + }, + ] + + math_instructions_1sm = [ + # Runtime DType + MathInstruction( + [128, 128, 64], + DataType.f8, DataType.f8, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 64], + DataType.f8, DataType.f8, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + ] + + math_instructions_2sm = [ + # Runtime DType + MathInstruction( + [256, 128, 64], + DataType.f8, DataType.f8, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [256, 256, 64], + DataType.f8, DataType.f8, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in sm100_cluster_shape_1sm: + if thor_sm in manifest.compute_capabilities_baseline : + if cluster_shape == [4,4,1] : + continue + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 2 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + for kernel_data_type in kernel_data_types: + # Update input AB type + kernel_data_type["a_type"] = math_inst.element_a + kernel_data_type["b_type"] = math_inst.element_b + + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_copy = copy.deepcopy(layouts) + for layout in layouts_copy: + # alignment for a, 2 for sparsity + layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) + # alignment for b + layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + + CreateSparseGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.SparseTmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], + tile_schedulers=tile_schedulers) + + # 2xSM MMA kernels + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in sm100_cluster_shape_2sm: + if thor_sm in manifest.compute_capabilities_baseline : + if cluster_shape == [4,4,1] : + continue + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 2 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + for kernel_data_type in kernel_data_types: + # Update input AB type + kernel_data_type["a_type"] = math_inst.element_a + kernel_data_type["b_type"] = math_inst.element_b + + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_copy = copy.deepcopy(layouts) + for layout in layouts_copy: + # alignment for a, 2 for sparsity + layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) + # alignment for b + layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + + CreateSparseGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.SparseTmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]], + tile_schedulers=tile_schedulers) + +def GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + # layouts for ABC and their alignments. + layouts = [ + # Alignment requirement will be over-write below + [[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]], + ] + + thor_sm = ThorSMRenumbering(cuda_version) + + min_cc = 100 + max_cc = thor_sm + + tile_schedulers = [ + TileSchedulerType.Default, TileSchedulerType.StreamK + ] + + math_instructions_1sm = [ + # Runtime Dtype + MathInstruction( + [128, 128, 64], + DataType.f4, DataType.f4, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 64], + DataType.f4, DataType.f4, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + + MathInstruction( + [128, 128, 64], + DataType.f6, DataType.f6, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 64], + DataType.f6, DataType.f6, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + ] + + math_instructions_2sm = [ + # Runtime DType + MathInstruction( + [256, 128, 64], + DataType.f4, DataType.f4, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [256, 256, 64], + DataType.f4, DataType.f4, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + + MathInstruction( + [256, 128, 64], + DataType.f6, DataType.f6, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [256, 256, 64], + DataType.f6, DataType.f6, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in sm100_cluster_shape_1sm: + if thor_sm in manifest.compute_capabilities_baseline : + if cluster_shape == [4,4,1] : + continue + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 2 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + kernel_data_types = [ + # void_c + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32, + }, + # none void_c + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32, + }, + ] + + for kernel_data_type in kernel_data_types: + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_filtered = [] + for layout in layouts: + layout_filter = copy.deepcopy(layout) + # * A_K : Logical TileShape_K % 256 == 0 + # * A_M : TileShape_M % 128 == 0 + # * B_N : TileSize_N % 128 == 0 + # * B_K : TileSize_K % 128 == 0 + if ((layout_filter[0][0] == LayoutType.RowMajor and (math_inst.instruction_shape[2] * 2) % 256 == 0) or \ + (layout_filter[0][0] == LayoutType.ColumnMajor and math_inst.instruction_shape[0] % 128 == 0)) and \ + ((layout_filter[1][0] == LayoutType.RowMajor and math_inst.instruction_shape[1] % 128 == 0) or \ + (layout_filter[1][0] == LayoutType.ColumnMajor and (math_inst.instruction_shape[0] * 2) % 128 == 0)): + # alignment for a, 2 for sparsity + layout_filter[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) + # alignment for b + layout_filter[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout_filter[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + layouts_filtered.append(layout_filter) + + CreateSparseGemmUniversal3xOperator(manifest, layouts_filtered, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.SparseTmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], + tile_schedulers=tile_schedulers) + + # 2xSM MMA kernels + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in sm100_cluster_shape_2sm: + if thor_sm in manifest.compute_capabilities_baseline : + if cluster_shape == [4,4,1] : + continue + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 2 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + kernel_data_types = [ + # void_c + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32, + }, + # none void_c + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32, + }, + ] + + for kernel_data_type in kernel_data_types: + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_filtered = [] + for layout in layouts: + layout_filter = copy.deepcopy(layout) + # * A_K : Logical TileShape_K % 256 == 0 + # * A_M : TileShape_M % 128 == 0 + # * B_N : TileSize_N % 256 == 0 + # * B_K : TileSize_K % 128 == 0 + if ((layout_filter[0][0] == LayoutType.RowMajor and (math_inst.instruction_shape[2] * 2) % 256 == 0) or \ + (layout_filter[0][0] == LayoutType.ColumnMajor and math_inst.instruction_shape[0] % 128 == 0)) and \ + ((layout_filter[1][0] == LayoutType.RowMajor and math_inst.instruction_shape[1] % 256 == 0) or \ + (layout_filter[1][0] == LayoutType.ColumnMajor and (math_inst.instruction_shape[0] * 2) % 128 == 0)): + # alignment for a, 2 for sparsity + layout_filter[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) + # alignment for b + layout_filter[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout_filter[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + layouts_filtered.append(layout_filter) + + CreateSparseGemmUniversal3xOperator(manifest, layouts_filtered, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.SparseTmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]], + tile_schedulers=tile_schedulers) + +# Conv Utility functions +def make_dims_and_alignments_triple(dim: int, bit_per_element_A: int, bit_per_element_B: int, bit_per_element_C: int): + bit_alignment_required_by_tma = 128 + return ((dim, bit_alignment_required_by_tma // bit_per_element_A), # A + (dim, bit_alignment_required_by_tma // bit_per_element_B), # B + (dim, bit_alignment_required_by_tma // bit_per_element_C)) # C + +def make_math_instruction_w_output(data_types: Tuple[DataType, DataType, DataType, DataType], + instruction_shape: Tuple[int, int, int]) -> (MathInstruction, DataType): + default_opcode = OpcodeClass.TensorOp + default_math_op = MathOperation.multiply_add + [A_data_type, B_data_type, Acc_data_type, Out_data_type] = data_types + return (MathInstruction( + instruction_shape, + A_data_type, B_data_type, Acc_data_type, + default_opcode, + default_math_op + ), Out_data_type) + +""" +Generate CUTLASS 3 convolution kernel(s) for SM100. + +This is meant to be called from GenerateSM100. +""" +def GenerateSM100_TensorOp_16b_UMMA_conv3x(manifest, cuda_version, + log_indent_level: int = 0): + log_debug_line('GenerateSM100_TensorOp_16b_UMMA_conv3x', log_indent_level) + log_indent_level = log_indent_level + 1 + + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + thor_sm = ThorSMRenumbering(cuda_version) + + minimum_compute_capability = 100 + maximum_compute_capability = thor_sm + + spatial_dims = [2, 3] + + conv_kinds = [ + ConvKind.Fprop, + ConvKind.Dgrad, + ConvKind.Wgrad + ] + + stages = 0 # zero means "deduce the number of stages automatically" + + data_types_and_instruction_shapes_1sm = [ + # ((A,B,Acc,C/D), (InstM,InstN,InstK)) + ((DataType.f16, DataType.f16, DataType.f16, DataType.f16), (64, 128, 16)), + ((DataType.f16, DataType.f16, DataType.f16, DataType.f16), (128, 128, 16)), + ((DataType.f16, DataType.f16, DataType.f16, DataType.f16), (128, 256, 16)), + ((DataType.f16, DataType.f16, DataType.f32, DataType.f16), (64, 128, 16)), + ((DataType.f16, DataType.f16, DataType.f32, DataType.f16), (128, 128, 16)), + ((DataType.f16, DataType.f16, DataType.f32, DataType.f16), (128, 256, 16)), + ((DataType.bf16, DataType.bf16, DataType.f32, DataType.bf16), (64, 128, 16)), + ((DataType.bf16, DataType.bf16, DataType.f32, DataType.bf16), (128, 128, 16)), + ((DataType.bf16, DataType.bf16, DataType.f32, DataType.bf16), (128, 256, 16)), + ] + math_instructions_w_output_1sm = map(lambda x: make_math_instruction_w_output(*x), + data_types_and_instruction_shapes_1sm) + + cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1],[4,4,1]] + + if thor_sm in manifest.compute_capabilities_baseline : + cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1]] + + # tile_descriptions is a 2-level list. + # Each inner list is for each cluster shape. + for math_inst, output_type in math_instructions_w_output_1sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + cluster_multiplier = cluster_shape + # Unlike SM90, SM100 tile shape calculation includes cluster shape. + tile_shape = [ + math_inst.instruction_shape[0] * cluster_multiplier[0], + math_inst.instruction_shape[1] * cluster_multiplier[1], + math_inst.instruction_shape[2] * 4 * cluster_multiplier[2] + ] + warp_count = [4, 1, 1] + tile_description = TileDescription( + tile_shape, stages, warp_count, math_inst, + minimum_compute_capability, maximum_compute_capability, + cluster_shape) + tile_descriptions.append(tile_description) + + # It's typical to get the data types from the math instruction. + data_type = { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : output_type, + "d_type" : output_type, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + } + + dims_and_alignments = [make_dims_and_alignments_triple(dim, DataTypeSize[data_type["a_type"]], DataTypeSize[data_type["b_type"]], DataTypeSize[data_type["d_type"]]) for dim in spatial_dims] + + # Schedules + mainloop_schedule = KernelScheduleType.ImplicitTmaWarpSpecialized1SmSm100 + epilogue_schedule = EpilogueScheduleType.ScheduleAuto + schedule_pairs = [ + (mainloop_schedule, epilogue_schedule) + ] + + for conv_kind in conv_kinds: + CreateConvOperator3x(manifest, + dims_and_alignments = dims_and_alignments, + tile_descriptions = tile_descriptions, + data_types = data_type, + schedule_pairs = schedule_pairs, + conv_kind = conv_kind, + log_indent_level = log_indent_level) + + data_types_and_instruction_shapes_2sm = [ + # ((A,B,Acc,C/D), (InstM,InstN,InstK)) + ((DataType.f16, DataType.f16, DataType.f16, DataType.f16), (128, 128, 16)), + ((DataType.f16, DataType.f16, DataType.f16, DataType.f16), (128, 256, 16)), + ((DataType.f16, DataType.f16, DataType.f16, DataType.f16), (256, 256, 16)), + ((DataType.f16, DataType.f16, DataType.f32, DataType.f16), (128, 128, 16)), + ((DataType.f16, DataType.f16, DataType.f32, DataType.f16), (128, 256, 16)), + ((DataType.f16, DataType.f16, DataType.f32, DataType.f16), (256, 256, 16)), + ((DataType.bf16, DataType.bf16, DataType.f32, DataType.bf16), (128, 128, 16)), + ((DataType.bf16, DataType.bf16, DataType.f32, DataType.bf16), (128, 256, 16)), + ((DataType.bf16, DataType.bf16, DataType.f32, DataType.bf16), (256, 256, 16)), + ] + math_instructions_w_output_2sm = map(lambda x: make_math_instruction_w_output(*x), + data_types_and_instruction_shapes_2sm) + + cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1]] + if thor_sm in manifest.compute_capabilities_baseline : + cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]] + + for math_inst, output_type in math_instructions_w_output_2sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + cluster_multiplier = (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + # Unlike SM90, SM100 tile shape calculation includes cluster shape. + tile_shape = [ + math_inst.instruction_shape[0] * cluster_multiplier[0], + math_inst.instruction_shape[1] * cluster_multiplier[1], + math_inst.instruction_shape[2] * 4 * cluster_multiplier[2] + ] + warp_count = [4, 1, 1] + tile_description = TileDescription( + tile_shape, stages, warp_count, math_inst, + minimum_compute_capability, maximum_compute_capability, + cluster_shape) + tile_descriptions.append(tile_description) + + # It's typical to get the data types from the math instruction. + data_type = { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : output_type, + "d_type" : output_type, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + } + + dims_and_alignments = [make_dims_and_alignments_triple(dim, DataTypeSize[data_type["a_type"]], DataTypeSize[data_type["b_type"]], DataTypeSize[data_type["d_type"]]) for dim in spatial_dims] + + # Schedules + mainloop_schedule = KernelScheduleType.ImplicitTmaWarpSpecialized2SmSm100 + epilogue_schedule = EpilogueScheduleType.ScheduleAuto + schedule_pairs = [ + (mainloop_schedule, epilogue_schedule) + ] + + for conv_kind in conv_kinds: + CreateConvOperator3x(manifest, + dims_and_alignments = dims_and_alignments, + tile_descriptions = tile_descriptions, + data_types = data_type, + schedule_pairs = schedule_pairs, + conv_kind = conv_kind, + log_indent_level = log_indent_level) + +def GenerateSM100_TensorOp_fp8_UMMA_conv3x(manifest, cuda_version, + log_indent_level: int = 0): + # Instantiate Fp8 Fprop kernels with e4m3 A/B, f32 Acc, e4m3/bf16/f16/f32 C/D + log_debug_line('GenerateSM100_TensorOp_fp8_UMMA_conv3x', log_indent_level) + log_indent_level = log_indent_level + 1 + + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + thor_sm = ThorSMRenumbering(cuda_version) + + minimum_compute_capability = 100 + maximum_compute_capability = thor_sm + + spatial_dims = [2, 3] + stages = 0 # zero means "deduce the number of stages automatically" + + data_types_and_instruction_shapes_1sm = [ + # ((A,B,Acc,C/D), (InstM,InstN,InstK)) + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.e4m3), (64, 128, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.e4m3), (128, 128, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.e4m3), (128, 256, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f16), (64, 128, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f16), (128, 128, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f16), (128, 256, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.bf16), (64, 128, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.bf16), (128, 128, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.bf16), (128, 256, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f32), (64, 128, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f32), (128, 128, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f32), (128, 256, 32)), + ] + math_instructions_w_output_1sm = map(lambda x: make_math_instruction_w_output(*x), + data_types_and_instruction_shapes_1sm) + + cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1],[4,4,1]] + if thor_sm in manifest.compute_capabilities_baseline : + cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1]] + + for math_inst, output_type in math_instructions_w_output_1sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + cluster_multiplier = cluster_shape + # Unlike SM90, SM100 tile shape calculation includes cluster shape. + tile_shape = [ + math_inst.instruction_shape[0] * cluster_multiplier[0], + math_inst.instruction_shape[1] * cluster_multiplier[1], + math_inst.instruction_shape[2] * 4 * cluster_multiplier[2] + ] + warp_count = [4, 1, 1] + tile_description = TileDescription( + tile_shape, stages, warp_count, math_inst, + minimum_compute_capability, maximum_compute_capability, + cluster_shape) + tile_descriptions.append(tile_description) + + data_type = { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : output_type, + "d_type" : output_type, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + } + + dims_and_alignments = [make_dims_and_alignments_triple(dim, DataTypeSize[data_type["a_type"]], DataTypeSize[data_type["b_type"]], DataTypeSize[data_type["d_type"]]) for dim in spatial_dims] + + # Schedules + mainloop_schedule = KernelScheduleType.ImplicitTmaWarpSpecialized1SmSm100 + epilogue_schedule = EpilogueScheduleType.ScheduleAuto + schedule_pairs = [ + (mainloop_schedule, epilogue_schedule) + ] + + CreateConvOperator3x(manifest, + dims_and_alignments = dims_and_alignments, + tile_descriptions = tile_descriptions, + data_types = data_type, + schedule_pairs = schedule_pairs, + conv_kind = ConvKind.Fprop, + log_indent_level = log_indent_level) + + data_types_and_instruction_shapes_2sm = [ + # ((A,B,Acc,C/D), (InstM,InstN,InstK)) + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.e4m3), (128, 128, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.e4m3), (128, 256, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.e4m3), (256, 256, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f16), (128, 128, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f16), (128, 256, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f16), (256, 256, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.bf16), (128, 128, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.bf16), (128, 256, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.bf16), (256, 256, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f32), (128, 128, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f32), (128, 256, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f32), (256, 256, 32)), + ] + math_instructions_w_output_2sm = map(lambda x: make_math_instruction_w_output(*x), + data_types_and_instruction_shapes_2sm) + + cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1]] + if thor_sm in manifest.compute_capabilities_baseline : + cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]] + + for math_inst, output_type in math_instructions_w_output_2sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + cluster_multiplier = (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + # Unlike SM90, SM100 tile shape calculation includes cluster shape. + tile_shape = [ + math_inst.instruction_shape[0] * cluster_multiplier[0], + math_inst.instruction_shape[1] * cluster_multiplier[1], + math_inst.instruction_shape[2] * 4 * cluster_multiplier[2] + ] + warp_count = [4, 1, 1] + tile_description = TileDescription( + tile_shape, stages, warp_count, math_inst, + minimum_compute_capability, maximum_compute_capability, + cluster_shape) + tile_descriptions.append(tile_description) + + data_type = { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : output_type, + "d_type" : output_type, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + } + + dims_and_alignments = [make_dims_and_alignments_triple(dim, DataTypeSize[data_type["a_type"]], DataTypeSize[data_type["b_type"]], DataTypeSize[data_type["d_type"]]) for dim in spatial_dims] + + # Schedules + mainloop_schedule = KernelScheduleType.ImplicitTmaWarpSpecialized2SmSm100 + epilogue_schedule = EpilogueScheduleType.ScheduleAuto + schedule_pairs = [ + (mainloop_schedule, epilogue_schedule) + ] + + CreateConvOperator3x(manifest, + dims_and_alignments = dims_and_alignments, + tile_descriptions = tile_descriptions, + data_types = data_type, + schedule_pairs = schedule_pairs, + conv_kind = ConvKind.Fprop, + log_indent_level = log_indent_level) + +def GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version): + # SM120 MMA with mixed F4/F6/F8 inputs + block scale + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + layouts = [ + [[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 0]] + ] + + instruction_sizes = [ + [16, 8, 32] + ] + + tile_sizes = [ + [128, 128, 128] + ] + + cluster_shape = [1,1,1] + + ab_types = [ + DataType.e2m1, + DataType.e2m3, + DataType.e3m2, + DataType.e5m2, + DataType.e4m3, + ] + + acc_types = [ DataType.f32 ] + + def is_pingpong(kernel_schedule): + if kernel_schedule == KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120: + return True + else: + return False + + def tile_schedulers(sfdtype, kernel_schedule): + # Pingpong kernel schedule doesn't support stream-K. + # Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void, + # the epilogue is the traditional linear combination, for which we already have tests with stream-K + if is_pingpong(kernel_schedule): + return [TileSchedulerType.Default] + elif sfdtype["type"] == DataType.void: + return [TileSchedulerType.Default] + else: + return [TileSchedulerType.Default, TileSchedulerType.StreamK] + + min_cc = 120 + max_cc = 121 + + epi_type = DataType.f32 + + math_instructions = [] + + kernel_schedules = [ + KernelScheduleType.Mxf8f6f4TmaWarpSpecializedCooperativeSm120, + KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120 + ] + + for instr_size, a_type, b_type, acc_type in product(instruction_sizes, ab_types, ab_types, acc_types): + math_instructions.append( + MathInstruction( + instr_size, + a_type, b_type, acc_type, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + + for math_inst in math_instructions: + tile_descriptions = [] + for tile_size in tile_sizes: + tile_descriptions.append( + TileDescription(tile_size, 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e3m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + } + ] + + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + for data_type, kernel_schedule in product(data_types, kernel_schedules): + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + [[kernel_schedule, EpilogueScheduleType.ScheduleAuto]], + tile_schedulers = tile_schedulers(data_type["sfd_type"], kernel_schedule), + gemm_kind = GemmKind.BlockScaledUniversal3x + ) + +def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version): + # SM120 MMA with with F4 + block scale + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + # layouts for ABC and their alignments. + layouts = [ + [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.RowMajor, 0]] + ] + + instruction_sizes = [ + [16, 8, 64] + ] + + tile_sizes_cooperative = [ + [128, 128, 128], + [128, 128, 256], + [256, 128, 128] + ] + + tile_sizes_pingpong = [ + [128, 128, 128], + [128, 128, 256] + ] + + cluster_shape = [1,1,1] + + ab_types = [ + DataType.e2m1 + ] + + sf_types = [ + DataType.ue4m3, + DataType.ue8m0 + ] + + acc_types = [ DataType.f32 ] + + def is_pingpong(kernel_schedule): + if kernel_schedule == KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120 or \ + kernel_schedule == KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120: + return True + else: + return False + + def is_nvf4(kernel_schedule): + if kernel_schedule == KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120 or \ + kernel_schedule == KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120: + return True + else: + return False + + def tile_schedulers(sfdtype, kernel_schedule): + # Pingpong kernel schedule doesn't support stream-K. + # Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void, + # the epilogue is the traditional linear combination, for which we already have tests with stream-K + if is_pingpong(kernel_schedule): + return [TileSchedulerType.Default] + elif sfdtype["type"] == DataType.void: + return [TileSchedulerType.Default] + else: + return [TileSchedulerType.Default, TileSchedulerType.StreamK] + + min_cc = 120 + max_cc = 121 + + epi_type = DataType.f32 + + math_instructions = [] + + kernel_schedules = [ + KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120, + KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120, + KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120, + KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120 + ] + + for instr_size, a_type, b_type, acc_type, sf_type in product(instruction_sizes, ab_types, ab_types, acc_types, sf_types): + math_instructions.append( + MathInstruction( + instr_size, + a_type, b_type, acc_type, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + sf_type) + ) + + for math_inst in math_instructions: + for kernel_schedule in kernel_schedules: + tile_descriptions = [] + tile_sizes = tile_sizes_pingpong if is_pingpong(kernel_schedule) else tile_sizes_cooperative + for tile_size in tile_sizes: + # nvf4 kernel only supports ue4m3 SF + # mxf4 kernel only supports ue8m0 SF + if (math_inst.element_scale_factor == DataType.ue4m3 and is_nvf4(kernel_schedule)) or \ + (math_inst.element_scale_factor == DataType.ue8m0 and not is_nvf4(kernel_schedule)): + tile_descriptions.append( + TileDescription(tile_size, 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + } + ] + + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + for data_type in data_types: + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + [[kernel_schedule, EpilogueScheduleType.ScheduleAuto]], + tile_schedulers = tile_schedulers(data_type["sfd_type"], kernel_schedule), + gemm_kind = GemmKind.BlockScaledUniversal3x + ) + +def GenerateSM120_Sparse_TensorOp_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + layouts = [ + [[LayoutType.RowMajor, 256], [LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 0]] + ] + + tile_sizes = [ + [128, 128, 256] + ] + + cluster_shape = [1,1,1] + + warp_count = [4, 2, 1] + + acc_types = [ DataType.f32 ] + + instruction_sizes_mxf8f6f4 = [ + [16, 8, 64] + ] + + ab_types_mxf8f6f4 = [ + DataType.e2m1, + #DataType.e2m3, + DataType.e3m2, + #DataType.e5m2, + DataType.e4m3, + ] + + def tile_schedulers(kernel_schedule): + return [TileSchedulerType.Default] + + min_cc = 120 + max_cc = 121 + + kernel_schedules = [ + KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120, + ] + + math_instructions_mxf8f6f4 = [] + + for instr_size, a_type, b_type, acc_type in product(instruction_sizes_mxf8f6f4, ab_types_mxf8f6f4, ab_types_mxf8f6f4, acc_types): + math_instructions_mxf8f6f4.append( + MathInstruction( + instr_size, + a_type, b_type, acc_type, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add) + ) + + # Create gemm operator for mxf8f6f4 + for math_inst in math_instructions_mxf8f6f4: + tile_descriptions_mxf8f6f4 = [] + for tile_size in tile_sizes: + tile_descriptions_mxf8f6f4.append( + TileDescription(tile_size, 0, warp_count, math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32 + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32 + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32 + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32 + } + ] + + for data_type, kernel_schedule in product(data_types, kernel_schedules): + # Set alignment d based on Destination format + for layout in layouts: + layout[2][1] = int(128 // DataTypeSize[data_type["d_type"]]) + # Create gemm operator + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions_mxf8f6f4, data_type, + [[kernel_schedule, EpilogueScheduleType.ScheduleAuto]], + tile_schedulers = tile_schedulers(kernel_schedule), + gemm_kind = GemmKind.SparseUniversal3x) + +def GenerateSM120_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version, gemm_kind=GemmKind.BlockwiseUniversal3x): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + layouts = [ + [[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 16]], + [[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.ColumnMajor, 16]] + ] + + cooperative_tile_sizes = [ + [128, 128, 128] + ] + pingpong_tile_sizes = [ + [64, 128, 128] + ] + + def get_tile_sizes(kernel_scheduler): + if kernel_scheduler == KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120: + return pingpong_tile_sizes + return cooperative_tile_sizes + + def get_warp_count(kernel_scheduler): + if kernel_scheduler == KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120: + return [2, 2, 1] + return [4, 2, 1] + + def get_sf_sizes(tile_size): + sf_sizes = [] + for vec_m in [1, 128]: + if tile_size[0] % vec_m > 0: + continue + for vec_n in [1, 128]: + if tile_size[1] % vec_m > 0: + continue + sf_sizes.append( + [vec_m, vec_n, 128] + ) + return sf_sizes + + cluster_shape = [1,1,1] + + acc_types = [ DataType.f32 ] + + instruction_sizes = [ + [16, 8, 32] + ] + + def tile_schedulers(kernel_schedule): + return [TileSchedulerType.Default] + + min_cc = 120 + max_cc = 121 + + kernel_schedulers = [ + KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120, + KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120 + ] + + ab_types = [ + [DataType.e4m3, DataType.e4m3], + [DataType.e4m3, DataType.e5m2] + ] + + math_instructions = [] + + for instr_size, ab_type, acc_type in product(instruction_sizes, ab_types, acc_types): + a_type, b_type = ab_type + math_instructions.append( + MathInstruction( + instr_size, + a_type, b_type, acc_type, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + # Create gemm operator for mxf8f6f4 + for kernel_schedule in kernel_schedulers: + tile_sizes = get_tile_sizes(kernel_schedule) + warp_count = get_warp_count(kernel_schedule) + for math_inst in math_instructions: + tile_descriptions = [] + for tile_size in tile_sizes: + sf_sizes = get_sf_sizes(tile_size) + for sf_size in sf_sizes: + tile_descriptions.append( + TileDescription(tile_size, 0, warp_count, math_inst, min_cc, max_cc, cluster_shape, + explicit_vector_sizes=sf_size) + ) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32 + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32 + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32 + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32 + } + ] + + for data_type in data_types: + # Set alignment d based on Destination format + for layout in layouts: + layout[2][1] = int(128 // DataTypeSize[data_type["d_type"]]) + # Create gemm operator + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + [[kernel_schedule, EpilogueScheduleType.ScheduleAuto]], + tile_schedulers = tile_schedulers(kernel_schedule), + gemm_kind = gemm_kind) + +def GenerateSM100(manifest, cuda_version): + arch_family_cc = ['100f', '101f', '103a'] + if CudaToolkitVersionSatisfies(cuda_version, 13, 0): + for old_cc, new_cc in [('101f', '110f')]: + arch_family_cc = [cc.replace(old_cc, new_cc) for cc in arch_family_cc] + + # + # Dense Gemm + # + GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version) + + GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version) + + if not bool(set(manifest.compute_capabilities_feature_set).intersection(arch_family_cc)): + GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version) + + GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version) + # grouped GEMM + GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedUniversal3x) + GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedUniversal3x) + + # StreamK is included in regular generation + GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version) + + # Blockwise kernels + GenerateSM100_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version) + GenerateSM100_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockwiseUniversal3x) + + # + # Sparse Gemm + # + GenerateSM100_SparseTensorOp_32b_UMMA_gemm(manifest, cuda_version) + GenerateSM100_SparseTensorOp_16b_UMMA_gemm(manifest, cuda_version) + if not bool(set(manifest.compute_capabilities_feature_set).intersection(arch_family_cc)): + GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version) + GenerateSM100_SparseTensorOp_fp8_UMMA_gemm(manifest, cuda_version) + GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version) + + # + # Block Scaled Gemm + # + GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version) + GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockScaledUniversal3x) + GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version) + GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockScaledUniversal3x) + + GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_version) + GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockScaledUniversal3x) + # + # Conv + # + GenerateSM100_TensorOp_16b_UMMA_conv3x(manifest, cuda_version) + GenerateSM100_TensorOp_fp8_UMMA_conv3x(manifest, cuda_version) + + +def GenerateSM120(manifest, cuda_version): + # StreamK is included in regular generation # + # + # Dense Block Scaled Gemm + # + GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version) + GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version) + + # + # Sparse Gemm + # + GenerateSM120_Sparse_TensorOp_gemm(manifest, cuda_version) + GenerateSM120_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version) + GenerateSM120_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockwiseUniversal3x) + +################################################################################################### + +def GenerateSM90_Conv3x(manifest, cuda_version, + log_indent_level: int = 0): + """ + Generate CUTLASS 3 convolution kernel(s) for SM90. + + This is meant to be called from GenerateSM90. + """ + log_debug_line('GenerateSM90_Conv3x', log_indent_level) + log_indent_level = log_indent_level + 1 + + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + minimum_compute_capability = 90 + maximum_compute_capability = 90 + + spatial_dims = (2, 3) + + # MMA shapes (MMA_M, MMA_N, MMA_K): + # + # Different hardware MMA instructions may have different MMA shapes. + # This function may generate kernels with different MMA shapes for + # different data types, either because the hardware only supports + # certain shapes for certain types, or for performance reasons + # (CUTLASS doesn't need to generate all valid kernels for the + # profiler library, just the best-performing ones). + # + # The kernel names refer to tile shapes (TILE_M, TILE_N, TILE_K) + # instead of MMA shapes. For SM >= 90 kernels, TILE_K = 4 * MMA_K, + # where 4, the "number of MMA instructions per tile," is determined + # through some combination of modeling and experiment. + # + # For performance on sm90, generally CUTLASS generates 64x128 + # instead of 128x64. + mma_64x64x16 = ( 64, 64, 16) + mma_64x64x8 = ( 64, 64, 8) + + num_mma_per_tile = 4 + + # Cluster shapes (1, 1, 1) and (2, 2, 1) are valid, + # but not included, because they tend not to perform as well. + cluster_shapes = ( + (2, 1, 1), + (1, 2, 1), + ) + + fp16 = DataType.f16 + bf16 = DataType.bf16 + fp32 = DataType.f32 + s8 = DataType.s8 + s32 = DataType.s32 + + # When generating kernels, the usual way is to specify 4 types, + # (A, B, Acc, C/D). Tests instead have 5 types, + # (ElementAct, ElementFlt, ElementOut, ElementAcc, ElementCompute), + # where ElementCompute is also called 'epi_type', + # and corresponds to the type of epilogue activations. + # This script maps tests' 5 types to 4 types + # by making ElementCompute the same as ElementOut. + + fp16_fp32_fp16_fp32 = { + 'a_type': fp16, # ElementAct(ivation) + 'b_type': fp16, # ElementF(i)lt(er) + 'c_type': fp32, # ElementAcc + 'd_type': fp32, # ElementOut (used only by CollectiveEpilogue) + 'acc_type': fp16, # ElementAcc + 'epi_type': fp32, # ElementCompute (used only by CollectiveEpilogue) + 'alignment_A': 8, # tma alignment elements of A + 'alignment_B': 8, # tma alignment elements of B + 'alignment_C': 4, # tma alignment elements of C + } + fp16_fp32_fp32_fp32 = { + 'a_type': fp16, + 'b_type': fp16, + 'c_type': fp32, + 'd_type': fp32, + 'acc_type': fp32, + 'epi_type': fp32, + 'alignment_A': 8, + 'alignment_B': 8, + 'alignment_C': 4, + } + fp32_fp32_fp32_fp32 = { + 'a_type': fp32, + 'b_type': fp32, + 'c_type': fp32, + 'd_type': fp32, + 'acc_type': fp32, + 'epi_type': fp32, + 'alignment_A': 4, + 'alignment_B': 4, + 'alignment_C': 4, + } + s8_s32_s32_s32 = { + 'a_type': s8, + 'b_type': s8, + 'c_type': s32, + 'd_type': s32, + 'acc_type': s32, + 'epi_type': s32, + 'alignment_A': 16, + 'alignment_B': 16, + 'alignment_C': 4, + } + + # Other NVIDIA libraries may have the habit of specifying data types like this. + bf16bf16_bf16f32_f32 = { + 'a_type': bf16, + 'b_type': bf16, + 'c_type': fp32, + 'd_type': fp32, + 'acc_type': fp32, + 'epi_type': fp32, + 'alignment_A': 8, + 'alignment_B': 8, + 'alignment_C': 4, + } + f16f16_f16f16_f16 = { + 'a_type': fp16, + 'b_type': fp16, + 'c_type': fp16, + 'd_type': fp16, + 'acc_type': fp16, + 'epi_type': fp16, + 'alignment_A': 8, + 'alignment_B': 8, + 'alignment_C': 8, + } + f16f16_f16f32_f32 = { + 'a_type': fp16, + 'b_type': fp16, + 'c_type': fp16, + 'd_type': fp16, + 'acc_type': fp32, + 'epi_type': fp32, + 'alignment_A': 8, + 'alignment_B': 8, + 'alignment_C': 8, + } + f32f32_tf32f32_f32 = fp32_fp32_fp32_fp32 + + i8i8_i8i32_f32 = { + 'a_type': s8, + 'b_type': s8, + 'c_type': s32, + 'd_type': s32, + 'acc_type': s32, + 'epi_type': s32, + 'alignment_A': 16, + 'alignment_B': 16, + 'alignment_C': 4, + } + + # Each element in the outermost iterable is one combination of + # + # (ConvKind, spatial_dimension, data_types, byte_alignments, mma_sizes, cluster_sizes) + # + # for which to generate a kernel. spatial_dimension is the spatial + # dimension of the convolution: either 1, 2, or 3. byte_alignments + # is a triple of required minimum byte alignments for A, B, and C. + # + # Note that itertools functions produce a single-pass generator. + # The code doesn't need a multipass iterable, but if one did, one + # could call `tuple` or `list` on the generator. + # + # While this happens to use the same cluster sizes for each element, + # the code doesn't require that. Different convolution kinds, data + # types, or mma sizes might have different optimal cluster sizes. + combinations_of_parameters = chain( + # The following are all the kernels exercised in the unit tests. + # Please try to keep in sync with the unit tests. + product( + ( + ConvKind.Fprop, + ), + spatial_dims, + ( + fp16_fp32_fp16_fp32, + fp16_fp32_fp32_fp32, + s8_s32_s32_s32, + ), + ( + mma_64x64x16, + ), + cluster_shapes + ), + product( + ( + ConvKind.Fprop, + ), + spatial_dims, + ( + fp32_fp32_fp32_fp32, + ), + ( + mma_64x64x8, + ), + cluster_shapes + ), + product( + ( + ConvKind.Dgrad, + ConvKind.Wgrad + ), + spatial_dims, + ( + fp16_fp32_fp16_fp32, + fp16_fp32_fp32_fp32, + ), + ( + mma_64x64x16, + ), + cluster_shapes + ), + # Kernels not necessarily in the unit tests, but used elsewhere + # and thus useful to have generated for profiling. They may + # duplicate kernels above. All of them are 2-D. In general, + # CUTLASS prefers 64 x 128 to 128 x 64 on sm90, even if the + # hardware permits 128 x 64. + ( + # Fprop + # + # bf16bf16_bf16f32_f32 + # + # cluster shape (2, 1, 1) + # + (ConvKind.Fprop, 2, bf16bf16_bf16f32_f32, (128, 256, 8), (2, 1, 1)), + (ConvKind.Fprop, 2, bf16bf16_bf16f32_f32, (128, 256, 16), (2, 1, 1)), + (ConvKind.Fprop, 2, bf16bf16_bf16f32_f32, (256, 128, 8), (2, 1, 1)), + (ConvKind.Fprop, 2, bf16bf16_bf16f32_f32, (256, 128, 16), (2, 1, 1)), + # + # f16f16_f16f16_f16 + # + # cluster shape (1, 1, 1) + # + (ConvKind.Fprop, 2, f16f16_f16f16_f16, ( 64, 64, 8), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, ( 64, 64, 16), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, ( 64, 128, 8), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, ( 64, 128, 16), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, ( 64, 256, 8), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, ( 64, 256, 16), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, (128, 128, 8), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, (128, 128, 16), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, (128, 256, 8), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, (128, 256, 16), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, (256, 64, 8), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, (256, 64, 16), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, (256, 128, 8), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, (256, 128, 16), (1, 1, 1)), + # + # f16f16_f16f32_f32 + # + # cluster shape (2, 1, 1) + # + (ConvKind.Fprop, 2, f16f16_f16f32_f32, (128, 192, 8), (2, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f32_f32, (128, 192, 16), (2, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f32_f32, (128, 256, 8), (2, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f32_f32, (128, 256, 16), (2, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f32_f32, (256, 96, 8), (2, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f32_f32, (256, 96, 16), (2, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f32_f32, (256, 128, 8), (2, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f32_f32, (256, 128, 16), (2, 1, 1)), + # + # f32f32_tf32f32_f32 + # + # cluster shape (2, 1, 1) + # + (ConvKind.Fprop, 2, f32f32_tf32f32_f32, (128, 192, 8), (2, 1, 1)), + (ConvKind.Fprop, 2, f32f32_tf32f32_f32, (128, 256, 8), (2, 1, 1)), + (ConvKind.Fprop, 2, f32f32_tf32f32_f32, (256, 128, 8), (2, 1, 1)), + (ConvKind.Fprop, 2, f32f32_tf32f32_f32, (256, 96, 8), (2, 1, 1)), + # + # i8i8_i8i32_f32 + # + # cluster shape (2, 1, 1) + # + (ConvKind.Fprop, 2, i8i8_i8i32_f32, (128, 256, 16), (2, 1, 1)), + (ConvKind.Fprop, 2, i8i8_i8i32_f32, (128, 256, 32), (2, 1, 1)), + (ConvKind.Fprop, 2, i8i8_i8i32_f32, (256, 128, 16), (2, 1, 1)), + (ConvKind.Fprop, 2, i8i8_i8i32_f32, (256, 128, 32), (2, 1, 1)), + # + # Dgrad + # + # bf16bf16_bf16f32_f32 + # + # cluster shape (2, 1, 1) + # + (ConvKind.Dgrad, 2, bf16bf16_bf16f32_f32, (128, 256, 8), (2, 1, 1)), + (ConvKind.Dgrad, 2, bf16bf16_bf16f32_f32, (128, 256, 16), (2, 1, 1)), + (ConvKind.Dgrad, 2, bf16bf16_bf16f32_f32, (256, 128, 8), (2, 1, 1)), + (ConvKind.Dgrad, 2, bf16bf16_bf16f32_f32, (256, 128, 16), (2, 1, 1)), + # + # f16f16_f16f16_f16 + # + # cluster shape (1, 1, 1) + # + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, ( 64, 64, 8), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, ( 64, 64, 16), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, ( 64, 128, 8), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, ( 64, 128, 16), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, ( 64, 256, 8), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, ( 64, 256, 16), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, (128, 128, 8), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, (128, 128, 16), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, (128, 256, 8), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, (128, 256, 16), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, (256, 64, 8), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, (256, 64, 16), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, (256, 128, 8), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, (256, 128, 16), (1, 1, 1)), + # + # f16f16_f16f32_f32 + # + # cluster shape (2, 1, 1) + # + (ConvKind.Dgrad, 2, f16f16_f16f32_f32, (128, 256, 8), (2, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f32_f32, (128, 256, 16), (2, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f32_f32, (256, 128, 8), (2, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f32_f32, (256, 128, 16), (2, 1, 1)), + ), + ) + + # SM >= 90 kernels don't actually use warp_count, but the + # TileDescription class needs it. The 4 in the default + # warp_count has nothing to do with num_mma_per_tile. + warp_count = [4, 1, 1] + + stages = 0 # zero means "deduce the number of stages automatically" + + mainloop_schedule = KernelScheduleType.ImplicitTmaWarpSpecializedSm90 + epilogue_schedule = EpilogueScheduleType.TmaWarpSpecialized + schedule_pairs = ( + (mainloop_schedule, epilogue_schedule), + ) + tile_schedulers = ( + TileSchedulerType.Default, # -> void + ) + + def make_math_instruction(data_types: Dict[str, DataType], + mma_shape: Tuple[int, int, int]) -> MathInstruction: + default_opcode = OpcodeClass.TensorOp + default_math_op = MathOperation.multiply_add + return MathInstruction( + mma_shape, + data_types['a_type'], data_types['b_type'], data_types['c_type'], + default_opcode, + default_math_op + ) + + for (conv_kind, spatial_dim, data_types, mma_shape, cluster_shape) in combinations_of_parameters: + math_inst = make_math_instruction(data_types, mma_shape) + tile_shape = (mma_shape[0], mma_shape[1], num_mma_per_tile * mma_shape[2]) + tile_description = TileDescription(tile_shape, stages, warp_count, math_inst, + minimum_compute_capability, maximum_compute_capability, cluster_shape) + assert(isinstance(spatial_dim, int)) + dims_and_alignments = ( + ( + (spatial_dim, data_types['alignment_A']), + (spatial_dim, data_types['alignment_B']), + (spatial_dim, data_types['alignment_C']), + ), + ) + CreateConvOperator3x(manifest, + dims_and_alignments = dims_and_alignments, + tile_descriptions = [tile_description], + data_types = data_types, + schedule_pairs = schedule_pairs, + tile_schedulers = tile_schedulers, + conv_kind = conv_kind, + log_indent_level = log_indent_level) + +def GenerateSM90(manifest, cuda_version): + GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version) + GenerateSM90_TensorOp_16b_WGMMA_alignx_gemm(manifest, cuda_version) + GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version) + GenerateSM90_TensorOp_tf32_WGMMA_alignx_gemm(manifest, cuda_version) + GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version) + GenerateSM90_TensorOp_int8_WGMMA_alignx_gemm(manifest, cuda_version) + GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version) + GenerateSM90_TensorOp_fp8_WGMMA_alignx_gemm(manifest, cuda_version) + GenerateSM90_TensorOp_mixed_dtype_WGMMA_gemm(manifest, cuda_version) + GenerateSM90_TensorOp_1684(manifest, cuda_version) + GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedUniversal3x) + GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedUniversal3x) + GenerateSM90_TensorOp_1684_complex(manifest, cuda_version) + GenerateSM90_TensorOp_1684_complex_gaussian(manifest, cuda_version) + GenerateSM90_TensorOp_1684_rank_k(manifest, cuda_version) + GenerateSM90_TensorOp_1684_rank_k_complex(manifest, cuda_version) + GenerateSM90_TensorOp_1684_rank_k_complex_gaussian(manifest, cuda_version) + GenerateSM90_TensorOp_1684_trmm(manifest, cuda_version) + GenerateSM90_TensorOp_1684_trmm_complex(manifest, cuda_version) + GenerateSM90_TensorOp_1684_trmm_complex_gaussian(manifest, cuda_version) + GenerateSM90_TensorOp_1684_symm(manifest, cuda_version) + GenerateSM90_TensorOp_1684_symm_complex(manifest, cuda_version) + GenerateSM90_TensorOp_1684_symm_complex_gaussian(manifest, cuda_version) + GenerateSM90_Conv3x(manifest, cuda_version) + GenerateSM90_SparseTensorOp_16b_WGMMA_gemm(manifest, cuda_version) + GenerateSM90_SparseTensorOp_tf32_WGMMA_gemm(manifest, cuda_version) + GenerateSM90_SparseTensorOp_int8_WGMMA_gemm(manifest, cuda_version) + GenerateSM90_SparseTensorOp_fp8_WGMMA_gemm(manifest, cuda_version) + GenerateSM90_TensorOp_fp8_WGMMA_gemm_with_blockwise(manifest, cuda_version) + GenerateSM90_TensorOp_fp8_WGMMA_gemm_with_blockwise(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockwiseUniversal3x) + +################################################################################################### + +def numeric_log_level(log_level: str) -> int: + """ + Converts the string identifier of the log level + into the numeric identifier used in setting the log level. + + :param x: string representation of log level (e.g., 'INFO', 'DEBUG') + :type x: str + + :return: numeric representation of log level + :rtype: int + """ + numeric_level = getattr(logging, log_level.upper(), None) + if not isinstance(numeric_level, int): + raise ValueError(f'Invalid log level: {log_level}') + return numeric_level + +# This function for defining the ArgumentParser is used to make it easy for the CUTLASS Python interface +# to leverage the functionality in this file without running this script via a shell prompt. +def define_parser(): + parser = argparse.ArgumentParser(description="Generates device kernel registration code for CUTLASS Kernels") + parser.add_argument("--operations", default="all", help="Specifies the operation to generate (gemm, all)") + parser.add_argument("--build-dir", default=".", required=False, help="CUTLASS top-level build directory") + parser.add_argument("--curr-build-dir", default=".", help="CUTLASS current build directory. cmake files will be emitted in this directory") + parser.add_argument("--generator-target", default='library', help="Target of CUTLASS Library Generator.") + parser.add_argument("--architectures", default='53;60;61;70;75;80;90;100', help="Target compute architectures") + parser.add_argument("--kernels", default='', help='Comma-delimited list to filter kernels by name. ' + + 'Specifying this as \"all\" includes ALL the kernels, ' + + 'while not specifying this includes only the default set of kernels.') + parser.add_argument("--ignore-kernels", default='', help='Comma-delimited list of kernels ' + + 'to exclude from build. For backwards compatibility reasons, ' + + 'this option only takes effect if --kernels is set to a nonempty value.') + parser.add_argument("--exclude-kernels", default='', help='Comma-delimited list of kernels ' + + 'to exclude from build. In contrast to --ignore-kernels, ' + + 'this option always takes effect, ' + + 'whether or not --kernels is set to a nonempty value. ' + + 'It also can exclude kernels from the filter file ' + + '(see --kernel-filter-file option below).') + parser.add_argument("--filter-by-cc", default='True', type=str, help='If enabled, kernels whose compute capability range is not satisfied by the build target are excluded.') + parser.add_argument("--cuda-version", default="11.0.0", help="Semantic version string of CUDA Toolkit") + parser.add_argument('--kernel-filter-file', type=str, default=None, required=False, help='Full path of filter file') + parser.add_argument('--heuristics-problems-file', type=str, default=None, required=False, help='Full path of heuristics problem size description file, as a json list') + parser.add_argument('--heuristics-testlist-file', type=str, default=None, required=False, help='Full path of heuristics testlist CSV file, to be passed to cutlass_profiler') + parser.add_argument('--heuristics-gpu', type=str, default=None, required=False, help='GPU to use for evaluating heuristics offline. None or `auto` to autodetect using cuda', choices=['', 'auto', 'H100_SXM', 'H100_PCIE', 'H100_NVL', 'H200_SXM', 'H20_SXM', 'B200', 'GB200_NVL', 'RTX_5080', 'RTX_5090', 'RTX_PRO_6000']) + parser.add_argument('--heuristics-configs-per-problem', type=int, default=10, required=False, help='Number of kernel configs to generate for each problem in the problem list') + parser.add_argument('--heuristics-restrict-kernels', action='store_true', help='Restrict heuristics mode to use only the default set of kernels emitted by generator.py') + parser.add_argument('--selected-kernel-list', type=str, default=None, required=False, + help='Specify the output log file containing all enabled kernels in this build') + parser.add_argument("--interface-dir", default=None, required=False, help="Interface header to kernels") + parser.add_argument("--disable-full-archs-compilation", action="store_true", required=False, help="Disable compilation for every archs in --architectures") + parser.add_argument("--log-level", default='info', type=numeric_log_level, required=False, + help='Logging level to be used by the generator script') + parser.add_argument('--instantiation-level', type=str, default="", required=False, help="Instantiation level for SM90 kernels. Set to `max` and make sure `--kernels` is not empty to generate all possible configurations.") + _add_package_disablement_flag(parser) + return parser + + +if __name__ == "__main__": + parser = define_parser() + args = parser.parse_args() + + # Set the logging level based on the user-provided `--log-level` command-line option + logging.basicConfig(level=args.log_level) + + manifest = Manifest(args) + + archs = args.architectures.split(';') + + if args.heuristics_problems_file: + filter_manifest_and_write_heuristics_file(manifest, args) + + GenerateSM50(manifest, args.cuda_version) + GenerateSM60(manifest, args.cuda_version) + GenerateSM61(manifest, args.cuda_version) + GenerateSM70(manifest, args.cuda_version) + GenerateSM75(manifest, args.cuda_version) + GenerateSM80(manifest, args.cuda_version) + GenerateSM89(manifest, args.cuda_version) + GenerateSM90(manifest, args.cuda_version) + + blackwell_arch_list = [ + "100a", "100f", + "101a", "101f", + "103a", "103f", + "110a", "110f", + "120a", "120f", + "121a", "121f", + ] + blackwell_enabled_arch = any(arch in blackwell_arch_list for arch in archs) + if blackwell_enabled_arch: + GenerateSM100(manifest, args.cuda_version) + GenerateSM120(manifest, args.cuda_version) + + if 'library' in args.generator_target.split(','): + manifest.emit(GeneratorTarget.Library) + + if 'kernel_testlist_l0' in args.generator_target.split(','): + emit_gemm_kernel_testlist(manifest, args.curr_build_dir, args.architectures, "functional_L0") + + if 'kernel_testlist_l1' in args.generator_target.split(','): + emit_gemm_kernel_testlist(manifest, args.curr_build_dir, args.architectures, "functional_L1") + + if args.selected_kernel_list is not None: + if len(manifest.selected_kernels) > 0: + with open(args.selected_kernel_list, 'w') as file_writer: + for line in manifest.selected_kernels: + file_writer.write("%s\n" % line) + +################################################################################################### diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/heuristics.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/heuristics.py new file mode 100644 index 0000000000000000000000000000000000000000..83421a06427acdc3b059855991cf95a1d2f118b3 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/heuristics.py @@ -0,0 +1,415 @@ +################################################################################################# +# +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for selecting CUTLASS library kernels based on problem description +""" +import json +import csv + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * + from cutlass_library.generator import * + from cutlass_library.heuristics_provider import * +except ImportError: + from library import * + from generator import * + from heuristics_provider import * + +try: + from .sm90_utils import ( + get_valid_schedules, + generate_data_types_from_math_instruction, + fix_alignments, + ) +except ImportError: + from sm90_utils import ( + get_valid_schedules, + generate_data_types_from_math_instruction, + fix_alignments, + ) + +_LOGGER = logging.getLogger(__name__) + +dtype_map = {v: k for k, v in DataTypeNames.items()} + +def serialize_heuristics_results_to_json(problems_with_configs, outfile_path): + """ + Utilitiy function to write heuristics results to a json file for debug + + args: + problems_with_configs: List of problems provided to the heuristic, with a list of operations added to each problem dict + outfile_path: Outfile path + + returns: + None + """ + pc_copy = problems_with_configs.copy() + for p in pc_copy: + for k, v in p.items(): + if isinstance(v, DataType): + p[k] = DataTypeNames[v] + elif isinstance(v, LayoutType): + p[k] = ShortLayoutTypeNames[v] + configs = p['configs'] + for c in configs: + for k, v in c.items(): + if isinstance(v, DataType): + c[k] = DataTypeNames[v] + elif isinstance(v, LayoutType): + c[k] = ShortLayoutTypeNames[v] + with open(outfile_path, 'w') as f: + json.dump(pc_copy, f, indent=2) + +def get_single_gemm_config(m, n, k, batch_count, layouts, dtypes, alignment_a, alignment_b, voidC=False, use_fast_acc=True, count=1, provider=None): + """ + Get heuristic-suggested GEMM kernel configurations for a single GEMM problem. + + args: + m, n, k: GEMM dimensions + batch_count: batch count + layouts: tuple of layouts of type LayoutType + use_fast_acc: Use fast accumulation for FP8. Ignored for other precisions + count: Number of configs to return + provider: Heuristics provider to use + + returns: + A list of dictionaries containing the suggested kernel configurations and additional info from the input required to define a Cutlass GemmOperation, with the following keys: + - 'cta_tile_m', 'cta_tile_m', 'cta_tile_k': CTA tile size + - 'instr_tile_m', 'instr_tile_n', 'instr_tile_k': Instruction tile size + - 'stages': kernel pipeline stage count + - 'cluster_m', 'cluster_n', 'cluster_k': cluster size + - 'layout_a', 'layout_b': input tensor layouts of type LayoutType + - 'alignment_a', 'alignment_b': input tensor alignments, in count of elements + - 'dtype_a', 'dtype_b', 'dtype_acc': dtypes of a, b, and accumulator, of type DataType + - 'swizzle_size' : suggested threadblock swizzle + - 'split_k_slices': number of partitions of the k dimension for splitK + - 'raster_order': raster order for CTAs over output tiles ('along_m' or 'along_n') + """ + if provider is None: + provider = MatmulHeuristics() + return provider.get_configs(m, n, k, batch_count, dtypes, layouts, alignment_a, alignment_b, voidC=voidC, use_fast_acc=use_fast_acc, count=count) + +def get_gemm_configs(problems, provider=None, count=1): + """ + Get heuristic-suggested GEMM kernel configurations for a set of GEMM problems. + + args: + problems: List of dictionaries describing GEMM problems with the following keys: + - 'm', 'n', 'k': Matrix dimensions (required) + - 'dtype_a': Data type of matrix A (required) + - 'dtype_b': Data type of matrix B (required) + - 'dtype_c': Data type of matrix C (default: None) + - 'dtype_d': Data type of matrix D (required) + - 'dtype_acc': Compute data type (default 'f32') + - 'layout': Operation layout (e.g. 'tnt') + - 'alignment_a': Memory access granularity of A, in units of elements (default: 16 bytes equivalent elements) + - 'alignment_b': Memory access granularity of B, in units of elements (default: 16 bytes equivalent elements) + - 'alpha': Scalar multiplier for A*B (default: 1.0) + - 'beta': Scalar multiplier for C (default: 0.0) + - 'batch_count': Number of GEMM operations in batch (default: 1) + - 'use_fast_acc': Enable fast accumulation for FP8 on Hopper (default: True) + provider: Heuristics provider to use + count: Number of configurations to return per problem (defualt: 1) + + returns: + A copy of the input dictionary, with key `configs` added containing the selected gemm configs + """ + ret = [] + + for problem in problems: + problem = problem.copy() + + try: + m = problem['m'] + n = problem['n'] + k = problem['k'] + dtype_a = problem['dtype_a'] + dtype_b = problem['dtype_b'] + dtype_d = problem['dtype_d'] + layout = problem['layout'] + except KeyError as e: + _LOGGER.error(f"Missing required parameter {e} for problem {problem}") + raise + + operation = problem.get('operation', 'gemm') + batch_count = problem.get('batch_count', 1) + dtype_acc = problem.get('dtype_acc', 'f32') + dtype_c = problem.get('dtype_c', None) + alpha = problem.get('alpha', 1.0) + beta = problem.get('beta', 0.0) + use_fast_acc = problem.get('use_fast_acc', True) + + if operation != OperationKindNames[OperationKind.Gemm]: + raise ValueError(f"Unsupported operation {operation}") + if not (len(layout) == 3 and all(c in "nt" for c in layout)): + raise ValueError(f"layout must be a 3-character string containing only 'n' or 't', got {layout}") + layouts = tuple(LayoutType.RowMajor if l == 't' else LayoutType.ColumnMajor for l in layout) + + try: + dtype_list = [dtype_a.lower(), dtype_b.lower(), dtype_acc.lower(), dtype_c.lower() if dtype_c is not None else dtype_d.lower(), dtype_d.lower()] + dtypes = tuple(dtype_map[dt] for dt in dtype_list) + except KeyError as dt: + _LOGGER.error(f"Unsupported data type: {dt}") + raise + + alignment_a = problem.get('alignment_a', 128 // DataTypeSize[dtypes[0]]) + alignment_b = problem.get('alignment_b', 128 // DataTypeSize[dtypes[1]]) + + configs = get_single_gemm_config(m, n, k, batch_count, layouts, dtypes, alignment_a, alignment_b, beta==0.0, use_fast_acc, count, provider) + problem['configs'] = configs + + ret.append(problem) + + return ret + + +def generate_sm100_from_heuristics_configs(manifest, cuda_version, kernel_configs): + """ + Generate CUTLASS operations based on the list of configs provided by the heuristic provider + + args: + manifest: manifest argument to which to add operations, or None to just return the operations without a manifest (for pruning an existing manifest) + cuda_version: Cuda compiler version for generating cutlass operations + kernel_configs: list of configs generated by the heuristic + + returns: + (configs, operations): a list of heuristic-provided kernel configs along with a one-to-one corresponding list of the generated operations + """ + min_cc = 100 + max_cc = 101 + if manifest is None: + # Use a dummy manifest so we can use existing CreateGemmOperator functions + manifest = Manifest() + + configs = [] + operations = [] + for config in kernel_configs: + layout = ([config['layout_a'], config['alignment_a']], [config['layout_b'], config['alignment_b']], [config['layout_d'], 128 // DataTypeSize[config['dtype_d']]]) + element_a, element_b, element_accumulator, element_c, element_d = config['dtype_a'], config['dtype_b'], config['dtype_acc'], config['dtype_c'], config['dtype_d'] + + # nvMMH assumes 2sm instruction for !(cluster_m % 2) + is_2sm = config['cluster_m'] % 2 == 0 + instruction_shape = [(2 * config['cta_tile_m']) if is_2sm else config['cta_tile_m'], config['cta_tile_n'], config['cta_tile_k'] // 4] + math_instruction = MathInstruction( + instruction_shape, + element_a, element_b, element_accumulator, + OpcodeClass.TensorOp, + MathOperation.multiply_add + ) + + data_types = [ + { + "a_type" : math_instruction.element_a, + "b_type" : math_instruction.element_b, + "c_type" : DataType.void if config['voidC'] else math_instruction.element_accumulator, + "d_type" : element_d, + "acc_type" : math_instruction.element_accumulator, + "epi_type" : math_instruction.element_accumulator, + } + ] + + tile_multiplier = (config['cluster_m'] // (2 if is_2sm else 1), config['cluster_n'], config['cluster_k']) + tile_description = TileDescription( + [instruction_shape[0] * tile_multiplier[0], + instruction_shape[1] * tile_multiplier[1], + instruction_shape[2] * 4 * tile_multiplier[2]], + 0, + [4,1,1], + math_instruction, + min_cc, + max_cc, + cluster_shape=(config['cluster_m'], config['cluster_n'], config['cluster_k']) + ) + + schedules = [] + if is_2sm: + schedules.append([KernelScheduleType.TmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]) + else: + schedules.append([KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]) + + for o in CreateGemmUniversal3xOperator(manifest, [layout], [tile_description], data_types, schedules, tile_schedulers=[TileSchedulerType.Default, TileSchedulerType.StreamK], gemm_kind=GemmKind.Universal3x): + configs.append(config) + operations.append(o) + + + return configs, operations + + +def generate_sm90_from_heuristics_configs(manifest, cuda_version, kernel_configs): + """ + Generate CUTLASS operations based on the list of configs provided by the heuristic provider + + args: + manifest: manifest argument to which to add operations, or None to just return the operations without a manifest (for pruning an existing manifest) + cuda_version: Cuda compiler version for generating cutlass operations + kernel_configs: list of configs generated by the heuristic + + returns: + (configs, operations): a list of heuristic-provided kernel configs along with a one-to-one corresponding list of the generated operations + """ + min_cc, max_cc = 90, 90 + + if manifest is None: + # Use a dummy manifest so we can use existing CreateGemmOperator functions + manifest = Manifest() + + configs = [] + operations = [] + for config in kernel_configs: + + is_aligned = (config['alignment_a'] * DataTypeSize[config['dtype_a']] >= 128) and (config['alignment_b'] * DataTypeSize[config['dtype_b']] >= 128) + layout = ([config['layout_a'], config['alignment_a']], [config['layout_b'], config['alignment_b']], [LayoutType.ColumnMajor, 1]) + element_a, element_b, element_accumulator, element_c, element_d = config['dtype_a'], config['dtype_b'], config['dtype_acc'], config['dtype_c'], config['dtype_d'] + + # instr shape and warp config are unused for emitting 3x collective builder code + dummy_instr_shape = [0, 0, 0] + math_instruction = MathInstruction( + dummy_instr_shape, + element_a, element_b, element_accumulator, + OpcodeClass.TensorOp, + MathOperation.multiply_add + ) + + data_types = generate_data_types_from_math_instruction(math_instruction, element_source=element_c, element_dest=element_d) + if is_aligned: + layout = fix_alignments(data_types, layout, alignment_bits=128) + + # instr shape and warp config are unused for emitting 3x collective builder code + dummy_warp_count = [0, 0, 0] + tile_description = TileDescription( + [config['cta_tile_m'], config['cta_tile_n'], config['cta_tile_k']], + 0, + dummy_warp_count, + math_instruction, + min_cc, + max_cc, + cluster_shape=(config['cluster_m'], config['cluster_n'], config['cluster_k']) + ) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_description, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_types, + instantiation_level=9000, # don't prune schedules: we didn't get any schedule suggestion from the heuristic + layout=layout, + gemm_kind=GemmKind.Universal3x, + enable_fp8_fast_acc=config['use_fast_acc'] + ) + + if len(schedules): + for o in CreateGemmUniversal3xOperator(manifest, [layout], [tile_description], data_types, schedules, gemm_kind=GemmKind.Universal3x): + configs.append(config) + operations.append(o) + + if len(stream_k_schedules): + for o in CreateGemmUniversal3xOperator(manifest, [layout], [tile_description], data_types, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]): + configs.append(config) + operations.append(o) + + + return configs, operations + +def filter_manifest_and_write_heuristics_file(manifest, args): + """ + Prune a manifest according to heuristics suggestions from the problems file + + args: + manifest: Cutlass manifest to prune + args: generator.py args, requires: + - args.heuristics_problems_file + - args.heuristics_gpu + - args.heuristics_testlist_file + + returns: + A list of dictionaries, each of which has information about an operation and a problem from the input problems + """ + heuristics_problems = [] + with open(args.heuristics_problems_file, 'r') as f: + heuristics_problems = json.load(f) + gpu = None if (args.heuristics_gpu == "auto" or args.heuristics_gpu == "") else args.heuristics_gpu + mmh = MatmulHeuristics(gpu=gpu) + if any(('100' in arch) for arch in args.architectures.split(';')): + mmh.set_cta_div_n(64) + problems_with_configs = get_gemm_configs(heuristics_problems, provider=mmh, count=args.heuristics_configs_per_problem) + + all_configs_and_operations = [] + operations = [] + for problem in problems_with_configs: + if any('90' in arch for arch in args.architectures.split(';')): + problem_configs, problem_operations = generate_sm90_from_heuristics_configs(None if args.heuristics_restrict_kernels else manifest, args.cuda_version, problem['configs']) + if any(('100' in arch) or ('101' in arch) for arch in args.architectures.split(';')): + problem_configs, problem_operations = generate_sm100_from_heuristics_configs(None if args.heuristics_restrict_kernels else manifest, args.cuda_version, problem['configs']) + + operations += problem_operations + problem_without_configs = {k: v for k, v in problem.items() if k != 'configs'} + with_problem_size = [{'operation_name': o.procedural_name(), **problem_without_configs, **c} for c, o in zip(problem_configs, problem_operations)] + all_configs_and_operations += with_problem_size + + for operation in operations: + manifest.add_kernel_filter(f"^{operation.procedural_name()}$") + if not all_configs_and_operations: + raise Exception("No valid configurations generated") + write_profiler_testlist_to_csv(all_configs_and_operations, args.heuristics_testlist_file) + return all_configs_and_operations + +def write_profiler_testlist_to_csv(configs_list, outfile_path): + """ + Write a list of configs to a testlist to be consumed by cutlass_profiler + + args: + configs_list: List of kernel configs along with runtime arguments and any other columns to include in the CSV, expressed as a list of dictionaries + outfile_path: Outfile path + + returns: + None + """ + profiler_testlist = configs_list.copy() + for c in profiler_testlist: + for k, v in c.items(): + if isinstance(v, DataType): + c[k] = DataTypeNames[v] + elif isinstance(v, LayoutType): + c[k] = ShortLayoutTypeNames[v] + + with open(outfile_path, mode='w', newline='') as ofile: + k_names = profiler_testlist[0].keys() + + writer = csv.DictWriter(ofile, fieldnames=k_names) + writer.writeheader() + writer.writerows(profiler_testlist) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/heuristics_provider.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/heuristics_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..01a4112a34c87d73a792cce368fede96a9315ac1 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/heuristics_provider.py @@ -0,0 +1,175 @@ +################################################################################################# +# +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Providers for kernel selection heuristics +""" + +import sys +import os +import glob +import logging +import ctypes +import functools + + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import DataType, LayoutType +except ImportError: + from library import DataType, LayoutType + +class MatmulHeuristics: + + def __init__(self, gpu = None): + import nvMatmulHeuristics + self.mmh_lib = nvMatmulHeuristics + self.gpu = gpu + + if 'CUTLASS_NVMMH_SO_PATH' in os.environ: + nvmmhInterfaceEx = functools.partial(self.mmh_lib.NvMatmulHeuristicsInterfaceEx, path=os.environ['CUTLASS_NVMMH_SO_PATH']) + else: + nvmmhInterfaceEx = self.mmh_lib.NvMatmulHeuristicsInterfaceEx + + self.lh = nvmmhInterfaceEx( + backend=self.mmh_lib.NvMatmulHeuristicsTarget["CUTLASS3"], + flags=self.mmh_lib.NvMatmulHeuristicsFlags.PERF_MODEL_BASED_AUTO_TUNING, + load_discovery_implicitly=True, + gpu=self.mmh_lib.NvMatmulHeuristicsNvidiaGpu[self.gpu] if self.gpu else None + ) + self.backend = self.lh.createBackend(self.mmh_lib.NvMatmulHeuristicsTarget["CUTLASS3"]) + + def _layout_from_cutlass(self, layouts): + assert(len(layouts)==3) + full_layout_str = ''.join('t' if l == LayoutType.RowMajor else 'n' for l in layouts) + input_layouts = full_layout_str[:2].upper() + lh_layout = input_layouts + '_' + str("ROW_MAJOR" if full_layout_str[-1]=='t' else "COL_MAJOR") + return self.mmh_lib.NvMatmulHeuristicsMatmulLayout[lh_layout] + + def _precision_from_cutlass_dtypes(self, dtypes): + dtype_to_cublas = { + DataType.f64: 'D', + DataType.f32: 'S', + DataType.f16: 'H', + DataType.bf16: 'T', + DataType.e4m3: 'Q', + DataType.e5m2: 'R', + DataType.s32: 'I', + DataType.s8: 'B', + } + + dtype_a, dtype_b, dtype_compute, dtype_c, dtype_d = dtypes + + a_c = dtype_to_cublas[dtype_a] + + if a_c.lower() != 'q': + return a_c + dtype_to_cublas[dtype_compute] + dtype_to_cublas[dtype_d] + else: + return a_c + dtype_to_cublas[dtype_b] + dtype_to_cublas[dtype_c] + dtype_to_cublas[dtype_compute] + dtype_to_cublas[dtype_d] + + def set_cta_div_n(self, div_n): + cta_n_div_requirement = ctypes.c_int(div_n) + self.lh.setBackendValueProperty( + self.backend, + self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_N_DIV_REQUIREMENT, + ctypes.byref(cta_n_div_requirement), + ctypes.sizeof(cta_n_div_requirement) + ) + + def set_cta_div_m(self, div_m): + cta_m_div_requirement = ctypes.c_int(div_m) + self.lh.setBackendValueProperty( + self.backend, + self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_M_DIV_REQUIREMENT, + ctypes.byref(cta_m_div_requirement), + ctypes.sizeof(cta_m_div_requirement) + ) + + def get_configs(self, m, n, k, batch_count, dtypes, layouts, align_a, align_b, voidC=False, use_fast_acc=True, count=1): + if use_fast_acc: + disable_fast_acc_for_fp8 = ctypes.c_int(0) + else: + disable_fast_acc_for_fp8 = ctypes.c_int(1) + self.lh.setBackendValueProperty( + self.backend, + self.mmh_lib.NvMatmulHeuristicsBackendProperty.DISABLE_FAST_ACC_FOR_FP8, + ctypes.byref(disable_fast_acc_for_fp8), + ctypes.sizeof(disable_fast_acc_for_fp8) + ) + + precision = self._precision_from_cutlass_dtypes(dtypes) + layout = self._layout_from_cutlass(layouts) + + matmul_problem = self.lh.makeNvMatmulHeuristicsProblem(m, n, k, layout, batch_count) + configs = self.lh.getEx(matmul_problem, count, self.backend, precision=precision) + + ret = [] + for c in configs: + kernel = c['kernel'] + problem = c['problem'] + + r = {} + r['estimated_runtime'] = c['runtime'] + r['cta_tile_m'] = kernel.cta_tile_m + r['cta_tile_n'] = kernel.cta_tile_n + r['cta_tile_k'] = kernel.cta_tile_k + r['instr_tile_m'] = kernel.instr_tile_m + r['instr_tile_n'] = kernel.instr_tile_n + r['instr_tile_k'] = kernel.instr_tile_k + r['warp_tile_m'] = kernel.warp_tile_m + r['warp_tile_n'] = kernel.warp_tile_n + r['warp_tile_k'] = kernel.warp_tile_k + r['cluster_m'] = kernel.cluster_m + r['cluster_n'] = kernel.cluster_n + r['cluster_k'] = 1 + r['layout_a'] = layouts[0] + r['layout_b'] = layouts[1] + r['layout_d'] = layouts[2] + r['dtype_a'] = dtypes[0] + r['dtype_b'] = dtypes[1] + r['dtype_acc'] = dtypes[2] + r['dtype_c'] = dtypes[3] + r['dtype_d'] = dtypes[4] + r['alignment_a'] = align_a + r['alignment_b'] = align_b + r['swizzle_size'] = kernel.swizzle_factor + r['raster_order'] = 'along_m' if kernel.cta_order==0 else 'along_n' + r['split_k_slices'] = kernel.split_k + r['use_fast_acc'] = use_fast_acc + r['voidC'] = voidC + + ret.append(r) + + return ret + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/library.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/library.py new file mode 100644 index 0000000000000000000000000000000000000000..56d22dc4b0705b4813b15b1b09decf53b38f7f37 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/library.py @@ -0,0 +1,1531 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Data types and tags used for emitting CUTLASS C++ kernels +""" + +import enum +import re + +# The following block implements enum.auto() for Python 3.5 variants that don't include it such +# as the default 3.5.2 on Ubuntu 16.04. +# +# https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility + +try: + from enum import auto as enum_auto +except ImportError: + __cutlass_library_auto_enum = 0 + def enum_auto() -> int: + global __cutlass_library_auto_enum + i = __cutlass_library_auto_enum + __cutlass_library_auto_enum += 1 + return i + +################################################################################################### + +# +class GeneratorTarget(enum.Enum): + Library = enum_auto() +# +GeneratorTargetNames = { + GeneratorTarget.Library: 'library' +} +# + +################################################################################################### + +# +class DataType(enum.Enum): + void = enum_auto() # primarily used to disable C tensor for epilogues + b1 = enum_auto() + u2 = enum_auto() + u4 = enum_auto() + u8 = enum_auto() + u16 = enum_auto() + u32 = enum_auto() + u64 = enum_auto() + s2 = enum_auto() + s4 = enum_auto() + s8 = enum_auto() + s16 = enum_auto() + s32 = enum_auto() + s64 = enum_auto() + e4m3 = enum_auto() + e5m2 = enum_auto() + f8 = enum_auto() + f6 = enum_auto() + f4 = enum_auto() + e3m2 = enum_auto() + e2m3 = enum_auto() + e2m1 = enum_auto() + ue8m0 = enum_auto() + ue4m3 = enum_auto() + f16 = enum_auto() + bf16 = enum_auto() + f32 = enum_auto() + tf32 = enum_auto() + f64 = enum_auto() + cf16 = enum_auto() + cbf16 = enum_auto() + cf32 = enum_auto() + ctf32 = enum_auto() + cf64 = enum_auto() + cs2 = enum_auto() + cs4 = enum_auto() + cs8 = enum_auto() + cs16 = enum_auto() + cs32 = enum_auto() + cs64 = enum_auto() + cu2 = enum_auto() + cu4 = enum_auto() + cu8 = enum_auto() + cu16 = enum_auto() + cu32 = enum_auto() + cu64 = enum_auto() + invalid = enum_auto() + +# +ShortDataTypeNames = { + DataType.s32: 'i', + DataType.e4m3: 'e4m3', + DataType.e5m2: 'e5m2', + DataType.f16: 'h', + DataType.f32: 's', + DataType.f64: 'd', + DataType.cf32: 'c', + DataType.cf64: 'z', + DataType.f8: 'f8', + DataType.f6: 'f6', + DataType.f4: 'f4', +} + +# +DataTypeNames = { + DataType.void: "void", + DataType.b1: "b1", + DataType.u2: "u2", + DataType.u4: "u4", + DataType.u8: "u8", + DataType.u16: "u16", + DataType.u32: "u32", + DataType.u64: "u64", + DataType.s2: "s2", + DataType.s4: "s4", + DataType.s8: "s8", + DataType.s16: "s16", + DataType.s32: "s32", + DataType.s64: "s64", + DataType.e4m3: 'e4m3', + DataType.e5m2: 'e5m2', + DataType.f8: 'f8', + DataType.f6: 'f6', + DataType.f4: 'f4', + DataType.e2m3: 'e2m3', + DataType.e3m2: 'e3m2', + DataType.e2m1: 'e2m1', + DataType.ue8m0: 'ue8m0', + DataType.ue4m3: 'ue4m3', + DataType.f16: "f16", + DataType.bf16: "bf16", + DataType.f32: "f32", + DataType.tf32: "tf32", + DataType.f64: "f64", + DataType.cf16: "cf16", + DataType.cbf16: "cbf16", + DataType.cf32: "cf32", + DataType.ctf32: "ctf32", + DataType.cf64: "cf64", + DataType.cu2: "cu2", + DataType.cu4: "cu4", + DataType.cu8: "cu8", + DataType.cu16: "cu16", + DataType.cu32: "cu32", + DataType.cu64: "cu64", + DataType.cs2: "cs2", + DataType.cs4: "cs4", + DataType.cs8: "cs8", + DataType.cs16: "cs16", + DataType.cs32: "cs32", + DataType.cs64: "cs64", +} + +DataTypeTag = { + DataType.void: "void", + DataType.b1: "cutlass::uint1b_t", + DataType.u2: "cutlass::uint2b_t", + DataType.u4: "cutlass::uint4b_t", + DataType.u8: "uint8_t", + DataType.u16: "uint16_t", + DataType.u32: "uint32_t", + DataType.u64: "uint64_t", + DataType.s2: "cutlass::int2b_t", + DataType.s4: "cutlass::int4b_t", + DataType.s8: "int8_t", + DataType.s16: "int16_t", + DataType.s32: "int32_t", + DataType.s64: "int64_t", + DataType.e4m3: 'cutlass::float_e4m3_t', + DataType.e5m2: 'cutlass::float_e5m2_t', + DataType.f8: 'cutlass::type_erased_dynamic_float8_t', + DataType.f6: 'cutlass::type_erased_dynamic_float6_t', + DataType.f4: 'cutlass::type_erased_dynamic_float4_t', + DataType.e2m3: 'cutlass::float_e2m3_t', + DataType.e3m2: 'cutlass::float_e3m2_t', + DataType.e2m1: 'cutlass::float_e2m1_t', + DataType.ue8m0: 'cutlass::float_ue8m0_t', + DataType.ue4m3: 'cutlass::float_ue4m3_t', + DataType.f16: "cutlass::half_t", + DataType.bf16: "cutlass::bfloat16_t", + DataType.f32: "float", + DataType.tf32: "cutlass::tfloat32_t", + DataType.f64: "double", + DataType.cf16: "cutlass::complex", + DataType.cbf16: "cutlass::complex", + DataType.cf32: "cutlass::complex", + DataType.ctf32: "cutlass::complex", + DataType.cf64: "cutlass::complex", + DataType.cu2: "cutlass::complex", + DataType.cu4: "cutlass::complex", + DataType.cu8: "cutlass::complex", + DataType.cu16: "cutlass::complex", + DataType.cu32: "cutlass::complex", + DataType.cu64: "cutlass::complex", + DataType.cs2: "cutlass::complex", + DataType.cs4: "cutlass::complex", + DataType.cs8: "cutlass::complex", + DataType.cs16: "cutlass::complex", + DataType.cs32: "cutlass::complex", + DataType.cs64: "cutlass::complex", +} + +DataTypeSize = { + DataType.void: 0, + DataType.b1: 1, + DataType.u2: 2, + DataType.u4: 4, + DataType.u8: 8, + DataType.u16: 16, + DataType.u32: 32, + DataType.u64: 64, + DataType.s2: 2, + DataType.s4: 4, + DataType.s8: 8, + DataType.s16: 16, + DataType.s32: 32, + DataType.s64: 64, + DataType.e4m3: 8, + DataType.e5m2: 8, + DataType.f8: 8, + DataType.f6: 6, + DataType.f4: 4, + DataType.e2m3: 6, + DataType.e3m2: 6, + DataType.e2m1: 4, + DataType.ue8m0: 8, + DataType.ue4m3: 8, + DataType.f16: 16, + DataType.bf16: 16, + DataType.f32: 32, + DataType.tf32: 32, + DataType.f64: 64, + DataType.cf16: 32, + DataType.cbf16: 32, + DataType.cf32: 64, + DataType.ctf32: 32, + DataType.cf64: 128, + DataType.cu2: 4, + DataType.cu4: 8, + DataType.cu8: 16, + DataType.cu16: 32, + DataType.cu32: 64, + DataType.cu64: 128, + DataType.cs2: 4, + DataType.cs4: 8, + DataType.cs8: 16, + DataType.cs16: 32, + DataType.cs32: 64, + DataType.cs64: 128, +} + +################################################################################################### +# +class BlasMode(enum.Enum): + symmetric = enum_auto() + hermitian = enum_auto() + +# +BlasModeTag = { + BlasMode.symmetric: 'cutlass::BlasMode::kSymmetric', + BlasMode.hermitian: 'cutlass::BlasMode::kHermitian', +} + +# +class ComplexTransform(enum.Enum): + none = enum_auto() + conj = enum_auto() + +# +ComplexTransformTag = { + ComplexTransform.none: 'cutlass::ComplexTransform::kNone', + ComplexTransform.conj: 'cutlass::ComplexTransform::kConjugate', +} + +# Used for cutlass3x complex kernel collective mainloop builder instantiation +ComplexTransformTag3x = { + ComplexTransform.none: 'cute::identity', + ComplexTransform.conj: 'cute::conjugate', +} + +# +RealComplexBijection = [ + (DataType.f16, DataType.cf16), + (DataType.f32, DataType.cf32), + (DataType.f64, DataType.cf64), +] + +# +def is_complex(data_type): + for r, c in RealComplexBijection: + if data_type == c: + return True + return False + +def is_block_scaled(gemm_kind): + return gemm_kind in (GemmKind.BlockScaledUniversal3x, GemmKind.GroupedBlockScaledUniversal3x) + +def is_blockwise(gemm_kind): + return gemm_kind in (GemmKind.BlockwiseUniversal3x, GemmKind.GroupedBlockwiseUniversal3x) + +def is_grouped(gemm_kind): + return gemm_kind in (GemmKind.GroupedUniversal3x, + GemmKind.GroupedBlockScaledUniversal3x, GemmKind.GroupedBlockwiseUniversal3x) + +# +def get_complex_from_real(real_type): + for r, c in RealComplexBijection: + if real_type == r: + return c + return DataType.invalid + +# +def get_real_from_complex(complex_type): + for r, c in RealComplexBijection: + if complex_type == c: + return r + return DataType.invalid + +# TMA requires an alignment of 128 bits for all data types +def get_tma_alignment(data_type): + if data_type == DataType.void: + return 0 + elif DataTypeSize[data_type] == 6: + return 128 # 96B alignment for 16U6 format + else: + return 128 // DataTypeSize[data_type] + +# +class ComplexMultiplyOp(enum.Enum): + multiply_add = enum_auto() + gaussian = enum_auto() + +################################################################################################### + +# +class MathOperation(enum.Enum): + multiply_add = enum_auto() + multiply_add_saturate = enum_auto() + multiply_add_mixed_input_upcast = enum_auto() + xor_popc = enum_auto() + and_popc = enum_auto() + multiply_add_fast_bf16 = enum_auto() + multiply_add_fast_f16 = enum_auto() + multiply_add_fast_f32 = enum_auto() + multiply_add_complex_fast_f32 = enum_auto() + multiply_add_complex = enum_auto() + multiply_add_complex_gaussian = enum_auto() + multiply_add_fast_accum = enum_auto() + +# +MathOperationTag = { + MathOperation.multiply_add: 'cutlass::arch::OpMultiplyAdd', + MathOperation.multiply_add_saturate: 'cutlass::arch::OpMultiplyAddSaturate', + MathOperation.multiply_add_mixed_input_upcast: 'cutlass::arch::OpMultiplyAddMixedInputUpcast', + MathOperation.xor_popc: 'cutlass::arch::OpXorPopc', + MathOperation.and_popc: 'cutlass::arch::OpAndPopc', + MathOperation.multiply_add_fast_bf16: 'cutlass::arch::OpMultiplyAddFastBF16', + MathOperation.multiply_add_fast_f16: 'cutlass::arch::OpMultiplyAddFastF16', + MathOperation.multiply_add_fast_f32: 'cutlass::arch::OpMultiplyAddFastF32', + MathOperation.multiply_add_complex_fast_f32: 'cutlass::arch::OpMultiplyAddComplexFastF32', + MathOperation.multiply_add_complex: 'cutlass::arch::OpMultiplyAddComplex', + MathOperation.multiply_add_complex_gaussian: 'cutlass::arch::OpMultiplyAddGaussianComplex', + MathOperation.multiply_add_fast_accum: 'cutlass::arch::OpMultiplyAddFastAccum', +} + +################################################################################################### + +# +class LayoutType(enum.Enum): + ColumnMajor = enum_auto() + RowMajor = enum_auto() + ColumnMajorInterleaved2 = enum_auto() + RowMajorInterleaved2 = enum_auto() + ColumnMajorInterleaved32 = enum_auto() + RowMajorInterleaved32 = enum_auto() + ColumnMajorInterleaved64 = enum_auto() + RowMajorInterleaved64 = enum_auto() + TensorNWC = enum_auto() + TensorNHWC = enum_auto() + TensorNDHWC = enum_auto() + TensorNCHW = enum_auto() + TensorNGHWC = enum_auto() + TensorNC32HW32 = enum_auto() + TensorNC64HW64 = enum_auto() + TensorC32RSK32 = enum_auto() + TensorC64RSK64 = enum_auto() + TensorKCS = enum_auto() + TensorKCSR = enum_auto() + TensorKCSRT = enum_auto() + +# +LayoutTag = { + LayoutType.ColumnMajor: 'cutlass::layout::ColumnMajor', + LayoutType.RowMajor: 'cutlass::layout::RowMajor', + LayoutType.ColumnMajorInterleaved2: 'cutlass::layout::ColumnMajorInterleaved<2>', + LayoutType.RowMajorInterleaved2: 'cutlass::layout::RowMajorInterleaved<2>', + LayoutType.ColumnMajorInterleaved32: 'cutlass::layout::ColumnMajorInterleaved<32>', + LayoutType.RowMajorInterleaved32: 'cutlass::layout::RowMajorInterleaved<32>', + LayoutType.ColumnMajorInterleaved64: 'cutlass::layout::ColumnMajorInterleaved<64>', + LayoutType.RowMajorInterleaved64: 'cutlass::layout::RowMajorInterleaved<64>', + LayoutType.TensorNWC: 'cutlass::layout::TensorNWC', + LayoutType.TensorNHWC: 'cutlass::layout::TensorNHWC', + LayoutType.TensorNDHWC: 'cutlass::layout::TensorNDHWC', + LayoutType.TensorNCHW: 'cutlass::layout::TensorNCHW', + LayoutType.TensorNGHWC: 'cutlass::layout::TensorNGHWC', + LayoutType.TensorNC32HW32: 'cutlass::layout::TensorNCxHWx<32>', + LayoutType.TensorC32RSK32: 'cutlass::layout::TensorCxRSKx<32>', + LayoutType.TensorNC64HW64: 'cutlass::layout::TensorNCxHWx<64>', + LayoutType.TensorC64RSK64: 'cutlass::layout::TensorCxRSKx<64>', + LayoutType.TensorKCS: 'cutlass::layout::TensorKCS', + LayoutType.TensorKCSR: 'cutlass::layout::TensorKCSR', + LayoutType.TensorKCSRT: 'cutlass::layout::TensorKCSRT' +} + +# +TransposedLayout = { + LayoutType.ColumnMajor: LayoutType.RowMajor, + LayoutType.RowMajor: LayoutType.ColumnMajor, + LayoutType.ColumnMajorInterleaved2: LayoutType.RowMajorInterleaved2, + LayoutType.RowMajorInterleaved2: LayoutType.ColumnMajorInterleaved2, + LayoutType.ColumnMajorInterleaved32: LayoutType.RowMajorInterleaved32, + LayoutType.RowMajorInterleaved32: LayoutType.ColumnMajorInterleaved32, + LayoutType.ColumnMajorInterleaved64: LayoutType.RowMajorInterleaved64, + LayoutType.RowMajorInterleaved64: LayoutType.ColumnMajorInterleaved64, + LayoutType.TensorNHWC: LayoutType.TensorNHWC +} + +# +ShortLayoutTypeNames = { + LayoutType.ColumnMajor: 'n', + LayoutType.ColumnMajorInterleaved2: 'n2', + LayoutType.ColumnMajorInterleaved32: 'n32', + LayoutType.ColumnMajorInterleaved64: 'n64', + LayoutType.RowMajor: 't', + LayoutType.RowMajorInterleaved2: 't2', + LayoutType.RowMajorInterleaved32: 't32', + LayoutType.RowMajorInterleaved64: 't64', + LayoutType.TensorNWC: 'nwc', + LayoutType.TensorNHWC: 'nhwc', + LayoutType.TensorNDHWC: 'ndhwc', + LayoutType.TensorNCHW: 'nchw', + LayoutType.TensorNGHWC: 'nghwc', + LayoutType.TensorNC32HW32: 'nc32hw32', + LayoutType.TensorNC64HW64: 'nc64hw64', + LayoutType.TensorC32RSK32: 'c32rsk32', + LayoutType.TensorC64RSK64: 'c64rsk64', + LayoutType.TensorKCS: 'kcs', + LayoutType.TensorKCSR: 'kcsr', + LayoutType.TensorKCSRT: 'kcsrt' +} + +# +ShortComplexLayoutNames = { + (LayoutType.ColumnMajor, ComplexTransform.none): 'n', + (LayoutType.ColumnMajor, ComplexTransform.conj): 'c', + (LayoutType.RowMajor, ComplexTransform.none): 't', + (LayoutType.RowMajor, ComplexTransform.conj): 'h' +} + +################################################################################################### +class KernelScheduleType(enum.Enum): + ScheduleAuto = enum_auto() + Multistage = enum_auto() + CpAsyncWarpSpecialized = enum_auto() + CpAsyncWarpSpecializedPingpong = enum_auto() + CpAsyncWarpSpecializedCooperative = enum_auto() + Tma = enum_auto() + TmaWarpSpecialized = enum_auto() + TmaWarpSpecializedPingpong = enum_auto() + TmaWarpSpecializedCooperative = enum_auto() + TmaWarpSpecializedFP8FastAccum = enum_auto() + TmaWarpSpecializedCooperativeFP8FastAccum = enum_auto() + TmaWarpSpecializedPingpongFP8FastAccum = enum_auto() + ImplicitTmaWarpSpecializedSm90 = enum_auto() + PtrArrayTmaWarpSpecializedCooperative = enum_auto() + PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum = enum_auto() + PtrArrayTmaWarpSpecializedPingpong = enum_auto() + PtrArrayTmaWarpSpecializedPingpongFP8FastAccum = enum_auto() + + BlockwiseTmaWarpSpecializedCooperative = enum_auto() + PtrArrayBlockwiseTmaWarpSpecializedCooperative = enum_auto() + BlockwiseTmaWarpSpecializedPingpong = enum_auto() + PtrArrayBlockwiseTmaWarpSpecializedPingpong = enum_auto() + + TmaWarpSpecialized1SmSm100 = enum_auto() + TmaWarpSpecialized2SmSm100 = enum_auto() + ImplicitTmaWarpSpecialized1SmSm100 = enum_auto() + ImplicitTmaWarpSpecialized2SmSm100 = enum_auto() + + PtrArrayTmaWarpSpecialized1SmSm100 = enum_auto() + PtrArrayTmaWarpSpecialized2SmSm100 = enum_auto() + + PtrArrayTmaWarpSpecialized1SmBlockScaledSm100 = enum_auto() + PtrArrayTmaWarpSpecialized2SmBlockScaledSm100 = enum_auto() + PtrArrayNvf4TmaWarpSpecialized1SmSm100 = enum_auto() + PtrArrayNvf4TmaWarpSpecialized2SmSm100 = enum_auto() + PtrArrayMxf4TmaWarpSpecialized1SmSm100 = enum_auto() + PtrArrayMxf4TmaWarpSpecialized2SmSm100 = enum_auto() + PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100 = enum_auto() + PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100 = enum_auto() + + SparseTmaWarpSpecialized1SmSm100 = enum_auto() + SparseTmaWarpSpecialized2SmSm100 = enum_auto() + + BlockScaledTmaWarpSpecialized1SmSm100 = enum_auto() + BlockScaledTmaWarpSpecialized2SmSm100 = enum_auto() + Mxf8f6f4TmaWarpSpecialized1SmSm100 = enum_auto() + Mxf8f6f4TmaWarpSpecialized2SmSm100 = enum_auto() + + BlockwiseTmaWarpSpecialized1SmSm100 = enum_auto() + BlockwiseTmaWarpSpecialized2SmSm100 = enum_auto() + + PtrArrayBlockwiseTmaWarpSpecialized1SmSm100 = enum_auto() + PtrArrayBlockwiseTmaWarpSpecialized2SmSm100 = enum_auto() + + + Mxf4TmaWarpSpecialized1SmSm100 = enum_auto() + Mxf4TmaWarpSpecialized2SmSm100 = enum_auto() + Nvf4TmaWarpSpecialized1SmSm100 = enum_auto() + Nvf4TmaWarpSpecialized2SmSm100 = enum_auto() + + # FP4 Ultra + MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103 = enum_auto() + MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103 = enum_auto() + MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103 = enum_auto() + MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103 = enum_auto() + + MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch = enum_auto() + MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch = enum_auto() + MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch = enum_auto() + MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch = enum_auto() + + MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch = enum_auto() + MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch = enum_auto() + MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch = enum_auto() + MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch = enum_auto() + + PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103 = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103 = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103 = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103 = enum_auto() + + PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch = enum_auto() + + PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch = enum_auto() + + Mxf8f6f4TmaWarpSpecializedCooperativeSm120 = enum_auto() + Mxf8f6f4TmaWarpSpecializedPingpongSm120 = enum_auto() + Nvf4TmaWarpSpecializedCooperativeSm120 = enum_auto() + Nvf4TmaWarpSpecializedPingpongSm120 = enum_auto() + Mxf4TmaWarpSpecializedCooperativeSm120 = enum_auto() + Mxf4TmaWarpSpecializedPingpongSm120 = enum_auto() + + F8f6f4SparseTmaWarpSpecializedCooperativeSm120 = enum_auto() + + BlockwiseTmaWarpSpecializedCooperativeSm120 = enum_auto() + BlockwiseTmaWarpSpecializedPingpongSm120 = enum_auto() + +KernelScheduleTag = { + KernelScheduleType.ScheduleAuto: 'cutlass::gemm::collective::KernelScheduleAuto', + KernelScheduleType.Multistage: 'cutlass::gemm::KernelMultistage', + KernelScheduleType.CpAsyncWarpSpecialized: 'cutlass::gemm::KernelCpAsyncWarpSpecialized', + KernelScheduleType.CpAsyncWarpSpecializedPingpong: 'cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong', + KernelScheduleType.CpAsyncWarpSpecializedCooperative: 'cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative', + KernelScheduleType.Tma: 'cutlass::gemm::KernelTma', + KernelScheduleType.TmaWarpSpecialized: 'cutlass::gemm::KernelTmaWarpSpecialized', + KernelScheduleType.TmaWarpSpecializedPingpong: 'cutlass::gemm::KernelTmaWarpSpecializedPingpong', + KernelScheduleType.TmaWarpSpecializedCooperative: 'cutlass::gemm::KernelTmaWarpSpecializedCooperative', + KernelScheduleType.TmaWarpSpecializedFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum', + KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum', + KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum', + KernelScheduleType.ImplicitTmaWarpSpecializedSm90: 'cutlass::conv::KernelImplicitTmaWarpSpecializedSm90', + + KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8Blockwise', + KernelScheduleType.BlockwiseTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8Blockwise', + + KernelScheduleType.TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmSm100', + KernelScheduleType.TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmSm100', + + KernelScheduleType.ImplicitTmaWarpSpecialized1SmSm100: 'cutlass::conv::KernelImplicitTmaWarpSpecialized1SmSm100', + KernelScheduleType.ImplicitTmaWarpSpecialized2SmSm100: 'cutlass::conv::KernelImplicitTmaWarpSpecialized2SmSm100', + + KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100', + KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100', + + KernelScheduleType.SparseTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100', + KernelScheduleType.SparseTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100', + + KernelScheduleType.BlockScaledTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledSm100', + KernelScheduleType.BlockScaledTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledSm100', + KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100', + KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100', + + KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100', + KernelScheduleType.BlockwiseTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100', + + KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100', + KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise2SmSm100', + + KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmMxf4Sm100', + KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmMxf4Sm100', + KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100', + KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100', + + # FP4 Ultra + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103', + + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch', + + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch', + + KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative', + KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum', + KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong', + KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum', + + KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise', + KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise', + + KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledSm100", + KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledSm100", + KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100", + KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmNvf4Sm100", + KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf4Sm100", + KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf4Sm100", + KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf8f6f4Sm100", + KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf8f6f4Sm100", + + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch', + + KernelScheduleType.Mxf8f6f4TmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedMxf8f6f4Sm120', + KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongMxf8f6f4Sm120', + KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedNvf4Sm120', + KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongNvf4Sm120', + KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedMxf4Sm120', + KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongMxf4Sm120', + + KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelScheduleSparseF8f6f4Sm120', + + KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwiseCooperativeSm120', + KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwisePingpongSm120', +} + +# +KernelScheduleSuffixes = { + KernelScheduleType.ScheduleAuto: '', + KernelScheduleType.Multistage: '_cpasync', + KernelScheduleType.CpAsyncWarpSpecialized: '_cpasync_warpspecialized', + KernelScheduleType.CpAsyncWarpSpecializedPingpong: '_cpasync_warpspecialized_pingpong', + KernelScheduleType.CpAsyncWarpSpecializedCooperative: '_cpasync_warpspecialized_cooperative', + KernelScheduleType.Tma: '_unspecialized', + KernelScheduleType.TmaWarpSpecialized: '_warpspecialized', + KernelScheduleType.TmaWarpSpecializedPingpong: '_warpspecialized_pingpong', + KernelScheduleType.TmaWarpSpecializedCooperative: '_warpspecialized_cooperative', + KernelScheduleType.TmaWarpSpecializedFP8FastAccum: '_warpspecialized_fp8_fastaccum', + KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum', + KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum', + KernelScheduleType.ImplicitTmaWarpSpecializedSm90: '_warpspecialized', + + KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative: '_warpspecialized_cooperative', + KernelScheduleType.BlockwiseTmaWarpSpecializedPingpong: '_warpspecialized_pingpong', + + KernelScheduleType.TmaWarpSpecialized1SmSm100: '_1sm', + KernelScheduleType.TmaWarpSpecialized2SmSm100: '_2sm', + + KernelScheduleType.ImplicitTmaWarpSpecialized1SmSm100: '_1sm', + KernelScheduleType.ImplicitTmaWarpSpecialized2SmSm100: '_2sm', + + KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100: '_1sm', + KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100: '_2sm', + + KernelScheduleType.SparseTmaWarpSpecialized1SmSm100: '_1sm', + KernelScheduleType.SparseTmaWarpSpecialized2SmSm100: '_2sm', + + KernelScheduleType.BlockScaledTmaWarpSpecialized1SmSm100: '_1sm', + KernelScheduleType.BlockScaledTmaWarpSpecialized2SmSm100: '_2sm', + KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100: '_q_1sm', + KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100: '_q_2sm', + + KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100: '_1sm', + KernelScheduleType.BlockwiseTmaWarpSpecialized2SmSm100: '_2sm', + KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized1SmSm100: '_1sm', + KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized2SmSm100: '_2sm', + + KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100: '_o_vs32_1sm', + KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm', + KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm', + KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm', + + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: '_o_vs16_ultra_1sm', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: '_o_vs16_ultra_2sm', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: '_o_vs32_ultra_1sm', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: '_o_vs32_ultra_2sm', + + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: '_o_vs16_ultra_1sm_nopf', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: '_o_vs16_ultra_2sm_nopf', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: '_o_vs32_ultra_1sm_nopf', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: '_o_vs32_ultra_2sm_nopf', + + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: '_o_vs16_ultra_1sm_tmapf', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: '_o_vs16_ultra_2sm_tmapf', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: '_o_vs32_ultra_1sm_tmapf', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: '_o_vs32_ultra_2sm_tmapf', + + KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_warpspecialized_cooperative', + KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum', + KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_warpspecialized_pingpong', + KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum', + + KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedCooperative: '_warpspecialized_cooperative', + KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedPingpong: '_warpspecialized_pingpong', + + KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: '_1sm', + KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: '_2sm', + KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm', + KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm', + KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100: '_o_vs32_1sm', + KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm', + KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100: '_o_vs32_1sm', + KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm', + + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: '_o_vs16_ultra_1sm', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: '_o_vs16_ultra_2sm', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: '_o_vs32_ultra_1sm', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: '_o_vs32_ultra_2sm', + + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: '_o_vs16_ultra_1sm_nopf', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: '_o_vs16_ultra_2sm_nopf', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: '_o_vs32_ultra_1sm_nopf', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: '_o_vs32_ultra_2sm_nopf', + + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: '_o_vs16_ultra_1sm_tmapf', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: '_o_vs16_ultra_2sm_tmapf', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: '_o_vs32_ultra_1sm_tmapf', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: '_o_vs32_ultra_2sm_tmapf', + + KernelScheduleType.Mxf8f6f4TmaWarpSpecializedCooperativeSm120: '_cooperative_q', + KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120: '_pingpong_q', + KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120: '_cooperative_o_vs16', + KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120: '_pingpong_o_vs16', + KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120: '_cooperative_o_vs32', + KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120: '_pingpong_o_vs32', + + KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120: '_q', + + KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120: '_cooperative_q', + KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120: '_pingpong_q' +} + +class EpilogueScheduleType(enum.Enum): + ScheduleAuto = enum_auto() + EpilogueTransposed = enum_auto() + NoSmemWarpSpecialized = enum_auto() + PtrArrayNoSmemWarpSpecialized = enum_auto() + NoSmemWarpSpecialized1Sm = enum_auto() + NoSmemWarpSpecialized2Sm = enum_auto() + FastF32NoSmemWarpSpecialized1Sm = enum_auto() + FastF32NoSmemWarpSpecialized2Sm = enum_auto() + BlockwiseNoSmemWarpSpecialized1Sm = enum_auto() + BlockwiseNoSmemWarpSpecialized2Sm = enum_auto() + PtrArrayNoSmemWarpSpecialized1Sm = enum_auto() + PtrArrayNoSmemWarpSpecialized2Sm = enum_auto() + PtrArrayFastF32NoSmemWarpSpecialized1Sm = enum_auto() + PtrArrayFastF32NoSmemWarpSpecialized2Sm = enum_auto() + PtrArrayBlockwiseNoSmemWarpSpecialized1Sm = enum_auto() + PtrArrayBlockwiseNoSmemWarpSpecialized2Sm = enum_auto() + TmaWarpSpecialized = enum_auto() + TmaWarpSpecializedCooperative = enum_auto() + TmaWarpSpecialized1Sm = enum_auto() + TmaWarpSpecialized2Sm = enum_auto() + PtrArrayTmaWarpSpecialized1Sm = enum_auto() + PtrArrayTmaWarpSpecialized2Sm = enum_auto() + PtrArrayTmaWarpSpecializedPingpong = enum_auto() + PtrArrayTmaWarpSpecializedCooperative = enum_auto() + +# +EpilogueScheduleTag = { + EpilogueScheduleType.ScheduleAuto: 'cutlass::epilogue::collective::EpilogueScheduleAuto', + EpilogueScheduleType.EpilogueTransposed: 'cutlass::gemm::EpilogueTransposed', + EpilogueScheduleType.NoSmemWarpSpecialized: 'cutlass::epilogue::NoSmemWarpSpecialized', + EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized', + EpilogueScheduleType.NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::NoSmemWarpSpecialized1Sm', + EpilogueScheduleType.NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::NoSmemWarpSpecialized2Sm', + EpilogueScheduleType.FastF32NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::FastF32NoSmemWarpSpecialized1Sm', + EpilogueScheduleType.FastF32NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::FastF32NoSmemWarpSpecialized2Sm', + EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::BlockwiseNoSmemWarpSpecialized1Sm', + EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::BlockwiseNoSmemWarpSpecialized2Sm', + EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized1Sm', + EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized2Sm', + EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayFastF32NoSmemWarpSpecialized1Sm', + EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayFastF32NoSmemWarpSpecialized2Sm', + EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayBlockwiseNoSmemWarpSpecialized1Sm', + EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayBlockwiseNoSmemWarpSpecialized2Sm', + EpilogueScheduleType.TmaWarpSpecialized: 'cutlass::epilogue::TmaWarpSpecialized', + EpilogueScheduleType.TmaWarpSpecializedCooperative: 'cutlass::epilogue::TmaWarpSpecializedCooperative', + EpilogueScheduleType.TmaWarpSpecialized1Sm: 'cutlass::epilogue::TmaWarpSpecialized1Sm', + EpilogueScheduleType.TmaWarpSpecialized2Sm: 'cutlass::epilogue::TmaWarpSpecialized2Sm', + EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm', + EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm', + EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: 'cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative', + EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: 'cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong', +} + +# +EpilogueScheduleSuffixes = { + EpilogueScheduleType.ScheduleAuto: '', + EpilogueScheduleType.EpilogueTransposed: '', + EpilogueScheduleType.NoSmemWarpSpecialized: '_epi_nosmem', + EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized: '_epi_nosmem', + EpilogueScheduleType.NoSmemWarpSpecialized1Sm: '_epi_nosmem', + EpilogueScheduleType.NoSmemWarpSpecialized2Sm: '_epi_nosmem', + EpilogueScheduleType.FastF32NoSmemWarpSpecialized1Sm: '_epi_nosmem_fastf32', + EpilogueScheduleType.FastF32NoSmemWarpSpecialized2Sm: '_epi_nosmem_fastf32', + EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm: '_epi_nosmem', + EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized2Sm: '_epi_nosmem', + EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: '_epi_nosmem', + EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: '_epi_nosmem', + EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized1Sm: '_epi_nosmem_fastf32', + EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized2Sm: '_epi_nosmem_fastf32', + EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized1Sm: '_epi_nosmem', + EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized2Sm: '_epi_nosmem', + EpilogueScheduleType.TmaWarpSpecialized: '_epi_tma', + EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma', + EpilogueScheduleType.TmaWarpSpecialized1Sm: '', + EpilogueScheduleType.TmaWarpSpecialized2Sm: '_epi_tma', + EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm: '', + EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm: '_epi_tma', + EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_epi_tma', + EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_epi_tma', +} + +class EpilogueFunctor3x(enum.Enum): + LinearCombination = enum_auto() + LinearCombinationBlockScaleFactor = enum_auto() + +# +EpilogueFunctor3xTag = { + EpilogueFunctor3x.LinearCombination: 'cutlass::epilogue::fusion::LinearCombination', + EpilogueFunctor3x.LinearCombinationBlockScaleFactor: 'cutlass::epilogue::fusion::LinCombBlockScaleFactor', +} + +# TMA epilogues have certain alignment requirements as calculated in get_tma_alignment(data_type) +def is_tma_epilogue(epilogue_schedule_type): + return epilogue_schedule_type in [ + EpilogueScheduleType.ScheduleAuto, + EpilogueScheduleType.TmaWarpSpecialized, + EpilogueScheduleType.TmaWarpSpecializedCooperative, + EpilogueScheduleType.TmaWarpSpecialized1Sm, + EpilogueScheduleType.TmaWarpSpecialized2Sm, + EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm, + EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm, + EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative, + EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong, + ] + +def to_grouped_schedule(schedule, grouped): + if not grouped: + return schedule + + group_schedule_map = { + # SM90 + KernelScheduleType.TmaWarpSpecializedCooperative : KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative, + KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative : KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedCooperative, + KernelScheduleType.BlockwiseTmaWarpSpecializedPingpong : KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedPingpong, + KernelScheduleType.TmaWarpSpecializedPingpong : KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong, + KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum : KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum, + KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum : KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum, + EpilogueScheduleType.TmaWarpSpecialized : EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong, + EpilogueScheduleType.TmaWarpSpecializedCooperative : EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative, + EpilogueScheduleType.NoSmemWarpSpecialized : EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized, + # SM100 + KernelScheduleType.TmaWarpSpecialized1SmSm100: KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100, + KernelScheduleType.TmaWarpSpecialized2SmSm100: KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100, + KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100, + KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100, + KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100, + KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized2SmSm100, + KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100, + KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100, + KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized1SmSm100, + KernelScheduleType.BlockwiseTmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized2SmSm100, + EpilogueScheduleType.TmaWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm, + EpilogueScheduleType.TmaWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm, + EpilogueScheduleType.NoSmemWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm, + EpilogueScheduleType.NoSmemWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm, + EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized1Sm, + EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized2Sm, + # SM103 + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch, + } + + return group_schedule_map[schedule] + +class TileSchedulerType(enum.Enum): + Default = enum_auto() + Persistent = enum_auto() + StreamK = enum_auto() +# +TileSchedulerTag = { + TileSchedulerType.Default: 'void', + TileSchedulerType.Persistent: 'cutlass::gemm::PersistentScheduler', + TileSchedulerType.StreamK: 'cutlass::gemm::StreamKScheduler', +} + +# +TileSchedulerSuffixes = { + TileSchedulerType.Default: '', + TileSchedulerType.Persistent: '', + TileSchedulerType.StreamK: '_stream_k', +} + +################################################################################################### + +# +class SideMode(enum.Enum): + Left = enum_auto() + Right = enum_auto() + +# +SideModeTag = { + SideMode.Left: 'cutlass::SideMode::kLeft', + SideMode.Right: 'cutlass::SideMode::kRight' +} + +# +ShortSideModeNames = { + SideMode.Left: 'ls', + SideMode.Right: 'rs' +} + +################################################################################################### + +# +class FillMode(enum.Enum): + Lower = enum_auto() + Upper = enum_auto() + +# +FillModeTag = { + FillMode.Lower: 'cutlass::FillMode::kLower', + FillMode.Upper: 'cutlass::FillMode::kUpper' +} + +# +ShortFillModeNames = { + FillMode.Lower: 'l', + FillMode.Upper: 'u' +} + +################################################################################################### + +# +class DiagType(enum.Enum): + NonUnit = enum_auto() + Unit = enum_auto() + +# +DiagTypeTag = { + DiagType.NonUnit: 'cutlass::DiagType::kNonUnit', + DiagType.Unit: 'cutlass::DiagType::kUnit' +} + +# +ShortDiagTypeNames = { + DiagType.NonUnit: 'nu', + DiagType.Unit: 'un' +} + +################################################################################################### + +# +class OpcodeClass(enum.Enum): + Simt = enum_auto() + TensorOp = enum_auto() + WmmaTensorOp = enum_auto() + SparseTensorOp = enum_auto() + BlockScaledTensorOp = enum_auto() + + +OpcodeClassNames = { + OpcodeClass.Simt: 'simt', + OpcodeClass.TensorOp: 'tensorop', + OpcodeClass.WmmaTensorOp: 'wmma_tensorop', + OpcodeClass.SparseTensorOp: 'sptensorop', + OpcodeClass.BlockScaledTensorOp: 'bstensorop' +} + +OpcodeClassTag = { + OpcodeClass.Simt: 'cutlass::arch::OpClassSimt', + OpcodeClass.TensorOp: 'cutlass::arch::OpClassTensorOp', + OpcodeClass.WmmaTensorOp: 'cutlass::arch::OpClassWmmaTensorOp', + OpcodeClass.SparseTensorOp: 'cutlass::arch::OpClassSparseTensorOp', + OpcodeClass.BlockScaledTensorOp: 'cutlass::arch::OpClassBlockScaledTensorOp' +} + +################################################################################################### + +# +class OperationKind(enum.Enum): + Gemm = enum_auto() + RankK = enum_auto() + Rank2K = enum_auto() + Trmm = enum_auto() + Symm = enum_auto() + Conv2d = enum_auto() + Conv3d = enum_auto() + +# +OperationKindNames = { + OperationKind.Gemm: 'gemm' + , OperationKind.RankK: 'rank_k' + , OperationKind.Rank2K: 'rank_2k' + , OperationKind.Trmm: 'trmm' + , OperationKind.Symm: 'symm' + , OperationKind.Conv2d: 'conv2d' + , OperationKind.Conv3d: 'conv3d' +} + +# +class Target(enum.Enum): + library = enum_auto() +# +ArchitectureNames = { + 50: 'maxwell', + 60: 'pascal', + 61: 'pascal', + 70: 'volta', + 75: 'turing', + 80: 'ampere', + 89: 'ada', + 90: 'hopper' +} + +# +SharedMemPerCC = { + 70: 96, # 96KB of SMEM + 72: 96, # 96KB of SMEM + 75: 64, # 64KB of SMEM + 80: 163, # 163KB of SMEM - 1KB reserved for the driver + 86: 99, # 99KB of SMEM - 1KB reserved for the driver + 87: 163, # 163KB of SMEM - 1KB reserved for the driver + 89: 99, # 99KB of SMEM - 1KB reserved for the driver + 90: 227, # 227KB of SMEM - 1KB reserved for the driver + 100: 227, # 227KB of SMEM - 1KB reserved for the driver +} + +################################################################################################### + +# +def SubstituteTemplate(template, values): + text = template + changed = True + while changed: + changed = False + for key, value in values.items(): + regex = "\\$\\{%s\\}" % key + newtext = re.sub(regex, value, text) + if newtext != text: + changed = True + text = newtext + return text + +################################################################################################### + +# +class GemmKind(enum.Enum): + Gemm = enum_auto() + Sparse = enum_auto() + Universal = enum_auto() + Universal3x = enum_auto() + SparseUniversal3x = enum_auto() + PlanarComplex = enum_auto() + PlanarComplexArray = enum_auto() + Grouped = enum_auto() + BlockScaledUniversal3x = enum_auto() + GroupedUniversal3x = enum_auto() + GroupedBlockScaledUniversal3x = enum_auto() + BlockwiseUniversal3x = enum_auto() + GroupedBlockwiseUniversal3x = enum_auto() + +# +GemmKindNames = { + GemmKind.Gemm: "gemm", + GemmKind.Sparse: "spgemm", + GemmKind.Universal: "gemm", + GemmKind.Universal3x: "gemm", + GemmKind.SparseUniversal3x: "spgemm", + GemmKind.PlanarComplex: "gemm_planar_complex", + GemmKind.PlanarComplexArray: "gemm_planar_complex_array", + GemmKind.Grouped: "gemm_grouped", + GemmKind.BlockScaledUniversal3x: "gemm", + GemmKind.GroupedUniversal3x: "gemm_grouped", + GemmKind.GroupedBlockScaledUniversal3x: "gemm_grouped", + GemmKind.BlockwiseUniversal3x: "gemm", + GemmKind.GroupedBlockwiseUniversal3x: "gemm_grouped" +} + +# +class RankKKind(enum.Enum): + Universal = enum_auto() + +# +RankKKindNames = { + RankKKind.Universal: "rank_k" +} + +# +class TrmmKind(enum.Enum): + Universal = enum_auto() + +# +TrmmKindNames = { + TrmmKind.Universal: "trmm" +} + +# +class SymmKind(enum.Enum): + Universal = enum_auto() + +# +SymmKindNames = { + SymmKind.Universal: "symm" +} + +# +class EpilogueFunctor(enum.Enum): + LinearCombination = enum_auto() + LinearCombinationClamp = enum_auto() + +# +EpilogueFunctorTag = { + EpilogueFunctor.LinearCombination: 'cutlass::epilogue::thread::LinearCombination', + EpilogueFunctor.LinearCombinationClamp: 'cutlass::epilogue::thread::LinearCombinationClamp', +} + +# +class MixedInputMode(enum.Enum): + ConvertOnly = enum_auto() + ScaleOnly = enum_auto() + ScaleWithZeroPoint = enum_auto() + +# +class SwizzlingFunctor(enum.Enum): + Identity1 = enum_auto() + Identity2 = enum_auto() + Identity4 = enum_auto() + Identity8 = enum_auto() + Horizontal = enum_auto() + StridedDgradIdentity1 = enum_auto() + StridedDgradIdentity4 = enum_auto() + StridedDgradHorizontal = enum_auto() + StreamK = enum_auto() + +# +SwizzlingFunctorTag = { + SwizzlingFunctor.Identity1: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>', + SwizzlingFunctor.Identity2: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>', + SwizzlingFunctor.Identity4: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>', + SwizzlingFunctor.Identity8: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>', + SwizzlingFunctor.Horizontal: 'cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle', + SwizzlingFunctor.StridedDgradIdentity1: 'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>', + SwizzlingFunctor.StridedDgradIdentity4: 'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<4>', + SwizzlingFunctor.StridedDgradHorizontal: 'cutlass::conv::threadblock::StridedDgradHorizontalThreadblockSwizzle', + SwizzlingFunctor.StreamK: 'cutlass::gemm::threadblock::ThreadblockSwizzleStreamK', +} + +# +class GroupScheduleMode(enum.Enum): + Device = enum_auto(), + Host = enum_auto() + +# +GroupScheduleModeTag = { + GroupScheduleMode.Device: 'cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly', + GroupScheduleMode.Host: 'cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute' +} + +# +ShortGroupScheduleModeNames = { + GroupScheduleMode.Device: 'Device', + GroupScheduleMode.Host: 'Host' +} + +################################################################################################### + +# +class ConvKind(enum.IntEnum): + Fprop = 0 + Dgrad = 1 + Wgrad = 2 + +# +ConvKindTag = { + ConvKind.Fprop: 'cutlass::conv::Operator::kFprop', + ConvKind.Dgrad: 'cutlass::conv::Operator::kDgrad', + ConvKind.Wgrad: 'cutlass::conv::Operator::kWgrad' +} + +ConvKindNames = { + ConvKind.Fprop: 'fprop', + ConvKind.Dgrad: 'dgrad', + ConvKind.Wgrad: 'wgrad', +} + +class ConvMode(enum.IntEnum): + CrossCorrelation = 0 + Convolution = 1 + +# +class IteratorAlgorithm(enum.Enum): + Analytic = 0 + Optimized = 1 + FixedChannels = 2 + FewChannels = 3 + FixedStrideDilation = 4 + +# +IteratorAlgorithmTag = { + IteratorAlgorithm.Analytic: 'cutlass::conv::IteratorAlgorithm::kAnalytic', + IteratorAlgorithm.Optimized: 'cutlass::conv::IteratorAlgorithm::kOptimized', + IteratorAlgorithm.FixedChannels: 'cutlass::conv::IteratorAlgorithm::kFixedChannels', + IteratorAlgorithm.FewChannels: 'cutlass::conv::IteratorAlgorithm::kFewChannels', + IteratorAlgorithm.FixedStrideDilation: 'cutlass::conv::IteratorAlgorithm::kFixedStrideDilation' +} + +IteratorAlgorithmNames = { + IteratorAlgorithm.Analytic: 'analytic', + IteratorAlgorithm.Optimized: 'optimized', + IteratorAlgorithm.FixedChannels: 'fixed_channels', + IteratorAlgorithm.FewChannels: 'few_channels', + IteratorAlgorithm.FixedStrideDilation: 'fixed_stride_dilation' +} + +# +class StrideSupport(enum.Enum): + Strided = 0 + Unity = 1 + Fixed = 2 + +# +StrideSupportTag = { + StrideSupport.Strided: 'cutlass::conv::StrideSupport::kStrided', + StrideSupport.Unity: 'cutlass::conv::StrideSupport::kUnity', + StrideSupport.Fixed: 'cutlass::conv::StrideSupport::kFixed' +} + +StrideSupportNames = { + StrideSupport.Strided: '', + StrideSupport.Unity: 'unity_stride', + StrideSupport.Fixed: 'fixed_stride' +} + +# +class GroupMode(enum.Enum): + NoneGroup = enum_auto() # dense conv (G=1) + SingleGroup = enum_auto() # grouped convolution (single group per CTA) + MultipleGroup = enum_auto() # grouped convolution ( multiple groups per CTA) + Depthwise = enum_auto() # Depthwise convolution ( C=K=G ) + +# +GroupModeTag = { + GroupMode.NoneGroup: 'cutlass::conv::GroupMode::kNone', + GroupMode.SingleGroup: 'cutlass::conv::GroupMode::kSingleGroup', + GroupMode.MultipleGroup: 'cutlass::conv::GroupMode::kMultipleGroup', + GroupMode.Depthwise: 'cutlass::conv::GroupMode::kDepthwise', +} + +GroupModeNames = { + GroupMode.NoneGroup: '', + GroupMode.SingleGroup: 'single_group', + GroupMode.MultipleGroup: 'multiple_group', + GroupMode.Depthwise: 'depthwise', +} + +DynamicClusterShape = [0, 0, 1] + +################################################################################################### + +# +class MathInstruction: + def __init__(self, + instruction_shape, \ + element_a, element_b, element_accumulator, \ + opcode_class, math_operation = MathOperation.multiply_add \ + , element_scale_factor = None + ): + + self.instruction_shape = instruction_shape + self.element_a = element_a + self.element_b = element_b + self.element_accumulator = element_accumulator + self.opcode_class = opcode_class + self.math_operation = math_operation + self.element_scale_factor = element_scale_factor + +# +class TileDescription: + + def __init__(self, threadblock_shape, stages, warp_count, math_instruction, min_compute, max_compute, cluster_shape = [1,1,1], explicit_vector_sizes = None): + self.threadblock_shape = threadblock_shape + self.tile_shape = threadblock_shape + self.stages = stages + self.warp_count = warp_count + self.math_instruction = math_instruction + self.minimum_compute_capability = min_compute + self.maximum_compute_capability = max_compute + self.cluster_shape = cluster_shape + self.explicit_vector_sizes = explicit_vector_sizes + + def procedural_name(self): + if self.minimum_compute_capability >= 90: + return "{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{s}".format( + tbm = self.threadblock_shape[0], + tbn = self.threadblock_shape[1], + tbk = self.threadblock_shape[2], + cm = self.cluster_shape[0], + cn = self.cluster_shape[1], + ck = self.cluster_shape[2], + s = self.stages) + else: + return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages) + +# +class Direct2dConvFixedStrideDilationTileDescription: + def __init__(self, threadblock_output_shape, filter_shape, stages, stride, dilation, warp_count, math_instruction, min_compute, max_compute): + self.threadblock_shape = [threadblock_output_shape[0]*threadblock_output_shape[1]*threadblock_output_shape[2], threadblock_output_shape[3], filter_shape[0]*filter_shape[1]] + self.threadblock_output_shape = threadblock_output_shape + self.filter_shape = filter_shape + self.stages = stages + self.warp_count = warp_count + self.stride = stride + self.dilation = dilation + self.math_instruction = math_instruction + self.minimum_compute_capability = min_compute + self.maximum_compute_capability = max_compute + + def procedural_name(self): + str_name = "%dx%dx%d_%dx%dx%dx%d_%d_filter%dx%d" % (self.threadblock_shape[0], + self.threadblock_shape[1], + self.threadblock_shape[2], + self.threadblock_output_shape[0], + self.threadblock_output_shape[1], + self.threadblock_output_shape[2], + self.threadblock_output_shape[3], + self.stages, + self.filter_shape[0], + self.filter_shape[1]) + # Fixed Strided and dilation + if self.stride != [-1, -1] and self.dilation != [-1, -1]: + str_name += "_stride%dx%d_dilation%dx%d" % (self.stride[0], + self.stride[1], + self.dilation[0], + self.dilation[1]) + return str_name + +# +class Direct2dConvFixedStrideDilationTileDescription: + def __init__(self, threadblock_output_shape, filter_shape, stages, stride, dilation, warp_count, math_instruction, min_compute, max_compute): + self.threadblock_shape = [threadblock_output_shape[0]*threadblock_output_shape[1]*threadblock_output_shape[2], threadblock_output_shape[3], filter_shape[0]*filter_shape[1]] + self.threadblock_output_shape = threadblock_output_shape + self.filter_shape = filter_shape + self.stages = stages + self.warp_count = warp_count + self.stride = stride + self.dilation = dilation + self.math_instruction = math_instruction + self.minimum_compute_capability = min_compute + self.maximum_compute_capability = max_compute + + def procedural_name(self): + str_name = "%dx%dx%d_%dx%dx%dx%d_%d_filter%dx%d" % (self.threadblock_shape[0], + self.threadblock_shape[1], + self.threadblock_shape[2], + self.threadblock_output_shape[0], + self.threadblock_output_shape[1], + self.threadblock_output_shape[2], + self.threadblock_output_shape[3], + self.stages, + self.filter_shape[0], + self.filter_shape[1]) + # Fixed Strided and dilation + if self.stride != [-1, -1] and self.dilation != [-1, -1]: + str_name += "_stride%dx%d_dilation%dx%d" % (self.stride[0], + self.stride[1], + self.dilation[0], + self.dilation[1]) + return str_name + +# +class TensorDescription: + def __init__(self, element, layout, alignment = 1, complex_transform = ComplexTransform.none): + self.element = element + self.layout = layout + self.alignment = alignment + self.complex_transform = complex_transform + +# +class SymmetricTensorDescription: + def __init__(self, element, layout, fill_mode, alignment = 1, complex_transform = ComplexTransform.none, side_mode = SideMode.Left): + self.element = element + self.layout = layout + self.fill_mode = fill_mode + self.alignment = alignment + self.complex_transform = complex_transform + self.side_mode = side_mode + +# +class TriangularTensorDescription: + def __init__(self, element, layout, side_mode, fill_mode, diag_type, alignment = 1, complex_transform = ComplexTransform.none): + self.element = element + self.layout = layout + self.side_mode = side_mode + self.fill_mode = fill_mode + self.diag_type = diag_type + self.alignment = alignment + self.complex_transform = complex_transform + +# +def CalculateSmemUsage(operation): + cta_shape = operation.tile_description.threadblock_shape + stages = operation.tile_description.stages + + if operation.operation_kind == OperationKind.Gemm and operation.gemm_kind == GemmKind.Sparse: + # Elements represented by 8 bits of metadata (based on 4:8, 2:4 or 1:2 sparsity) + if DataTypeSize[operation.A.element] == 32: + elements_per_8b_md = 2 + elif DataTypeSize[operation.A.element] == 4: + elements_per_8b_md = 8 + else: + elements_per_8b_md = 4 + + smem_per_stage = DataTypeSize[operation.A.element] * cta_shape[0] * (cta_shape[2] // 2) // 8 + \ + DataTypeSize[operation.B.element] * cta_shape[1] * cta_shape[2] // 8 + \ + cta_shape[0] * (cta_shape[2] // 2) // elements_per_8b_md + else: + # Few BLAS3 operations only have A tensor + data_type_size_a = DataTypeSize[operation.A.element] + data_type_size_b = DataTypeSize[operation.A.element] + if operation.is_mixed_input(): + data_type_size_b = DataTypeSize[operation.B.element] + + smem_per_stage = data_type_size_a * cta_shape[0] * cta_shape[2] // 8 + \ + data_type_size_b * cta_shape[1] * cta_shape[2] // 8 + + smem_usage = smem_per_stage * stages + return (smem_usage >> 10) + + +class GemmUniversalMode(enum.IntEnum): + """ + Types corresponding to GemmUniversalMode + """ + Gemm = 0 + GemmSplitKParallel = 1 + Batched = 2 + Array = 3 + + +class SplitKMode(enum.IntEnum): + """ + Types corresponding to SplitKMode + """ + NoneSplitK = 0 + Serial = 1 + Parallel = 2 diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/manifest.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/manifest.py new file mode 100644 index 0000000000000000000000000000000000000000..5733ef26322794ee650dfa0c8c2b170bd8c6f3e5 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/manifest.py @@ -0,0 +1,868 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for filtering CUTLASS library kernels and emitting library intitialization +and building code +""" + +import enum +import logging +import os.path +import shutil + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * + from cutlass_library.gemm_operation import * + from cutlass_library.rank_k_operation import * + from cutlass_library.rank_2k_operation import * + from cutlass_library.trmm_operation import * + from cutlass_library.symm_operation import * + from cutlass_library.conv2d_operation import * + from cutlass_library.conv3d_operation import * +except ImportError: + from library import * + from gemm_operation import * + from rank_k_operation import * + from rank_2k_operation import * + from trmm_operation import * + from symm_operation import * + from conv2d_operation import * + from conv3d_operation import * + +################################################################################################### +_LOGGER = logging.getLogger(__name__) + + +class EmitOperationKindAll: + """ + Emit the OperationKind-level CUTLASS library initialization code. + The code is generated in the {generated_path}/{operation_kind} directory + (e.g., tools/library/generated/gemm in the build directory, + for OperationKind=Gemm), in the all_{operation_kind}_operations.cu file + (e.g., all_gemm_operations.cu for OperationKind=Gemm). + That file declares several functions in namespace cutlass::library. + The functions all have this form, + + void initialize_{configuration_name}(Manifest& manifest); + + The file also _defines_ the following function in that namespace. + + void initialize_all_{operation_kind}_operations(Manifest& manifest); + + That function calls all of the functions declared in this file. + Those functions are defined in subdirectories + (which this class does not create). + """ + + def __init__(self, generated_path, kind, args): + self.generated_path = generated_path + self.kind = kind + self.args = args + + self.header_template =""" +/* + Generated by manifest.py - Do not edit. +*/ + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + self.entry_template = """ + +// +// Entry point to construct operations +// +void initialize_all_${operation_name}_operations(Manifest &manifest) { +""" + self.configuration_prototype_template = "void initialize_${configuration_name}(Manifest &manifest);\n" + self.configuration_template =" initialize_${configuration_name}(manifest);\n" + + self.epilogue_template ="""} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +""" + + # + def __enter__(self): + _LOGGER.debug("*** EmitOperationKindAll::__enter__") + + self.operation_path = os.path.join(self.generated_path, OperationKindNames[self.kind]) + _LOGGER.debug('*** operation_path (directory to create): ' + + str(self.operation_path)); + os.makedirs(self.operation_path, exist_ok=True) + + self.top_level_path = os.path.join(self.operation_path, f"all_{OperationKindNames[self.kind]}_operations.cu") + _LOGGER.debug(f"*** top_level_path (file to write): {str(self.top_level_path)}") + + self.top_level_file = open(self.top_level_path, "w") + self.top_level_file.write(self.header_template) + + self.source_files = [self.top_level_path,] + + self.configurations = [] + + return self + + # + def emit(self, operations): + _LOGGER.debug('*** EmitOperationKindAll::emit') + _LOGGER.debug(f"*** len(operations): {len(operations)}") + _LOGGER.debug(f"*** min_cc list: {sorted(min_cc for min_cc, _ in operations.items())}") + + for min_cc, configurations in sorted(operations.items()): + _LOGGER.debug(f"*** min_cc={min_cc}") + + for configuration_name, _ in configurations.items(): + _LOGGER.debug(f"*** configuration_name={configuration_name}") + self.configurations.append(configuration_name) + self.top_level_file.write(SubstituteTemplate(self.configuration_prototype_template, {'configuration_name': configuration_name} )) + + # + def __exit__(self, exception_type, exception_value, traceback): + _LOGGER.debug("*** EmitOperationKindAll::__exit__") + + self.top_level_file.write(SubstituteTemplate(self.entry_template, {'operation_name': OperationKindNames[self.kind]})) + + for configuration_name in self.configurations: + self.top_level_file.write(SubstituteTemplate(self.configuration_template, {'configuration_name': configuration_name})) + + self.top_level_file.write(self.epilogue_template) + self.top_level_file.close() + + +class EmitOperationKindLibrary: + """ + Emit the CUTLASS library initialization code for each OperationKind. + The code is generated in the directory + {generated_path}/{operation_kind}/{min_cc} + (e.g., tools/library/generated/gemm/90 in the build directory, + for min_cc=90 and OperationKind=Gemm), in the file + all_sm{min_cc}_{operation_kind}_operations.cu + (e.g., all_sm90_gemm_operations.cu for min_cc=90 and OperationKind=Gemm). + The min_cc variable here indicates the minimum GPU architecture version + that the things to be initialized require. + For example, min_cc=90 indicates sm90. + + That file declares several functions in namespace cutlass::library. + The functions all have this form, + + void initialize_all_sm{min_cc}_{subclass_name}_{extended_name}_operations(Manifest& manifest); + + where extended_name is operation.extended_name() for all the operations + given to the emit method (which see below). (All operations for a given + configuration_name are guaranteed to have the same extended_name().) + + The file also _defines_ the following function in that namespace. + + void initialize_all_sm{min_cc}__{operation_kind}_operations(Manifest& manifest); + + That function calls all of the functions declared in this file. + Those functions are defined in subdirectories. + The mapping from OperationKind to emitter handles the details + of what happens in each of those subdirectories. + """ + + def __init__(self, generated_path, min_cc, kind, args): + self.generated_path = generated_path + self.min_cc = min_cc + self.kind = kind + self.args = args + self.emitters = { + OperationKind.Gemm: EmitGemmConfigurationLibrary, + OperationKind.Conv2d: EmitConv2dConfigurationLibrary, + OperationKind.Conv3d: EmitConv3dConfigurationLibrary, + OperationKind.RankK: EmitRankKConfigurationLibrary, + OperationKind.Rank2K: EmitRank2KConfigurationLibrary, + OperationKind.Trmm: EmitTrmmConfigurationLibrary, + OperationKind.Symm: EmitSymmConfigurationLibrary + } + + self.header_template =""" +/* + Generated by manifest.py - Do not edit. +*/ + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + self.entry_template = """ + +// +// Entry point to construct operations +// +void initialize_all_sm${min_cc}_${subclass_name}_${operation_name}_operations(Manifest &manifest) { +""" + self.configuration_prototype_template = "void initialize_${configuration_name}(Manifest &manifest);\n" + self.configuration_template = " initialize_${configuration_name}(manifest);\n" + self.subclass_call_template = " initialize_all_sm${min_cc}_${subclass_name}_${operation_name}_operations(manifest);\n" + self.subclass_prototype_template = "void initialize_all_sm${min_cc}_${subclass_name}_${operation_name}_operations(Manifest &manifest);\n" + self.epilogue_template ="""} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +""" + + # + def __enter__(self): + _LOGGER.debug("*** EmitOperationKindLibrary::__enter__") + _LOGGER.debug(f"*** generated_path: {str(self.generated_path)}") + _LOGGER.debug(f"*** OperationKindNames[kind]: {OperationKindNames[self.kind]}") + _LOGGER.debug(f"*** min_cc: {self.min_cc}") + + self.operation_path = os.path.join(self.generated_path, OperationKindNames[self.kind], str(self.min_cc)) + _LOGGER.debug(f"*** operation_path (directory to make): {str(self.operation_path)}") + os.makedirs(self.operation_path) + + self.top_level_path = os.path.join(self.operation_path, f"all_sm{self.min_cc}_{OperationKindNames[self.kind]}_operations.cu") + _LOGGER.debug(f"*** top_level_path (file to write): {str(self.top_level_path)}") + + self.top_level_file = open(self.top_level_path, "w") + self.top_level_file.write(self.header_template) + + self.source_files = {} + + # Each {operation_kind x cc} combination is further decomposed by the instruction + # types used. This dictionary used to track the file handles for the top-level + # files of each subclass + self.subclass_files = {} + + # Configurations in each sub class + self.subclass_configurations = {} + + return self + + # + def emit(self, configuration_name, operations): + _LOGGER.debug("*** EmitOperationKindLibrary::emit") + _LOGGER.debug(f"*** configuration_name: {configuration_name}") + + assert len(operations) > 0 + + # The extended name for all operations of a given configuration_name is guaranteed + # to be the same because extended_name() is used in defining configuration_name. Thus, + # we can safely use the extended_name() of the first operation. + extended_name = operations[0].extended_name() + _LOGGER.debug('*** extended_name (for all ops): ' + extended_name) + + # Create a directory for operations with this subclass if it does not exist + if extended_name not in self.subclass_files: + subclass_path = os.path.join(self.operation_path, extended_name) + _LOGGER.debug(f"*** subclass_path: {str(subclass_path)}") + os.mkdir(subclass_path) + + self.subclass_configurations[extended_name] = [] + + # Open a new top-level file for this sub class + subclass_top_level_path = os.path.join( + subclass_path, f"all_sm{self.min_cc}_{extended_name}_{OperationKindNames[self.kind]}_operations.cu") + _LOGGER.debug('*** subclass_top_level_path (min_cc, extended_name, ' + + 'OperationKind): ' + str(subclass_top_level_path)) + + self.subclass_files[extended_name] = open(subclass_top_level_path, "w") + self.subclass_files[extended_name].write(self.header_template) + + self.source_files[extended_name] = [subclass_top_level_path] + + subclass_dir = os.path.dirname(self.subclass_files[extended_name].name) + _LOGGER.debug('*** subclass_dir: ' + str(subclass_dir)) + + with self.emitters[self.kind](subclass_dir, configuration_name) as configuration_emitter: + for operation in operations: + configuration_emitter.emit(operation) + + _LOGGER.debug('*** configuration_emitter.configuration_path: ' + + str(configuration_emitter.configuration_path)) + self.source_files[extended_name].append(configuration_emitter.configuration_path) + + self.subclass_configurations[extended_name].append(configuration_name) + self.subclass_files[extended_name].write(SubstituteTemplate(self.configuration_prototype_template, {'configuration_name': configuration_name} )) + + # + def __exit__(self, exception_type, exception_value, traceback): + _LOGGER.debug("*** EmitOperationKindLibrary::__exit__") + for subclass_name, subclass_file in sorted(self.subclass_files.items()): + subclass_cfg = { + 'min_cc': str(self.min_cc), + 'subclass_name': subclass_name, + 'operation_name': OperationKindNames[self.kind] + } + self.top_level_file.write(SubstituteTemplate(self.subclass_prototype_template, subclass_cfg)) + + self.top_level_file.write( + SubstituteTemplate(self.entry_template, { + 'min_cc': str(self.min_cc), + 'subclass_name': '', + 'operation_name': OperationKindNames[self.kind] + })) + + # Finish and close all subclass files + for subclass_name, subclass_file in sorted(self.subclass_files.items()): + subclass_cfg = { + 'min_cc': str(self.min_cc), + 'subclass_name': subclass_name, + 'operation_name': OperationKindNames[self.kind] + } + subclass_file.write(SubstituteTemplate(self.entry_template, subclass_cfg)) + + for configuration in self.subclass_configurations[subclass_name]: + subclass_file.write( + SubstituteTemplate(self.configuration_template, { + 'configuration_name': configuration + })) + + subclass_file.write(self.epilogue_template) + subclass_file.close() + + # Write the call to initialize_all for this subclass to the top-level file + self.top_level_file.write(SubstituteTemplate(self.subclass_call_template, subclass_cfg)) + + self.top_level_file.write(self.epilogue_template) + self.top_level_file.close() + +class EmitInterfaceLibrary: + """ + Emit the topmost-level CUTLASS library initialization code. + The code is generated in the generated_path directory + (e.g., tools/library/generated in the build directory), + in the initialize_all.cpp file. + That file declares several functions in namespace cutlass::library. + The functions all have this form, + + void initialize_all_{operation_kind}_operations(Manifest& manifest); + + where {operation_kind} abbreviates the "kind" of operation + (e.g., gemm for matrix-matrix multiply, conv2d for 2-d convolution, + or trmm for triangular solve with multiple right-hand sides). + The definitions of these functions live in subdirectories. + + The file also _defines_ the following function in that namespace. + + void initialize_all(Manifest& manifest); + + That function first prepares the manifest, and then + calls all of the functions declared in this file. + """ + + def __init__(self, generated_path, operation_count, args): + self.generated_path = generated_path + self.args = args + + self.prototypes = [] + self.fn_calls = [] + self.operation_count = str(operation_count) + + self.top_level_hdr_template = ''' +/* + Generated by manifest.py - Do not edit. +*/ +''' + self.top_level_prologue = ''' + +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +namespace cutlass { +\tnamespace library { + +${prototypes} +''' + + self.top_level_initialize_kind = ''' +\t\tvoid initialize_all_${kind}_operations(Manifest &manifest) { +${fn_calls} +\t\t} +''' + + self.top_level_initialize = ''' +\t\tvoid initialize_all(Manifest &manifest) { +\t\t\tmanifest.reserve(${operation_count});\n +${fn_calls} +\t\t} +''' + + self.top_level_suffix = ''' +\t} // namespace library +} // namespace cutlass + +''' + + # + def __enter__(self): + _LOGGER.debug("*** EmitInterfaceLibrary::__enter__") + + self.top_level_path = os.path.join(self.generated_path, 'initialize_all.cpp') + _LOGGER.debug("*** top_level_path: " + str(self.top_level_path)) + + self.top_level_file = open(self.top_level_path, "w") + self.top_level_file.write(self.top_level_hdr_template) + + self.source_files = [self.top_level_path,] + + return self + + # + def emit(self, operation_name): + _LOGGER.debug("*** EmitInterfaceLibrary::emit") + _LOGGER.debug("*** operation_name: " + operation_name) + + self.prototypes.append(SubstituteTemplate( + "\t\tvoid initialize_all_${operation_kind}_operations(Manifest &manifest);", + {'operation_kind': operation_name})) + + self.fn_calls.append(SubstituteTemplate( + "\t\t\tinitialize_all_${operation_kind}_operations(manifest);", + {'operation_kind': operation_name})) + + # + def __exit__(self, exception_type, exception_value, traceback): + _LOGGER.debug("*** EmitInterfaceLibrary::__exit__") + + self.top_level_file.write(SubstituteTemplate(self.top_level_prologue, {'prototypes':"\n".join(self.prototypes)})) + + # Write out initialize_all method + self.top_level_file.write(SubstituteTemplate(self.top_level_initialize, + {'operation_count': self.operation_count, 'fn_calls':"\n".join(self.fn_calls)})) + + self.top_level_file.write(self.top_level_suffix) + self.top_level_file.close() + +################################################################################################### +################################################################################################### + +class Options: + def __init__(self): + pass + +################################################################################################### + +# +class Manifest: + + # + def __init__(self, args = None): + self.operations = {} + self.args = args + self.operation_count = 0 + self.operations_by_name = {} + + self.kernel_filter = '' + self.kernel_filter_list = [] + self.kernel_names = [] + self.operations_enabled = [] + self.selected_kernels = [] + self.ignore_kernel_names = [] + self.exclude_kernel_names = [] + self.compute_capabilities_baseline = [50,] + self.compute_capabilities_feature_set = ['50',] + self.curr_build_dir = '.' + self.filter_by_cc = True + + if self.args: + self.kernel_filter = self.args.kernels + self.curr_build_dir = args.curr_build_dir + + # A common user error is to use commas instead of semicolons. + if ',' in args.architectures: + raise RuntimeError("The list of architectures (CMake option CUTLASS_NVCC_ARCHS) must be semicolon-delimited.\nDon't use commas to separate the architectures; use semicolons.\nYou specified the list as: " + args.architectures) + + self.compute_capabilities_feature_set = args.architectures.split(';') if len(args.architectures) else ['50',] + self.compute_capabilities_baseline = sorted(set(int(arch.split('a')[0].split('f')[0]) for arch in self.compute_capabilities_feature_set)) + + if args.filter_by_cc in ['false', 'False', '0']: + self.filter_by_cc = False + + if args.operations == 'all': + self.operations_enabled = [] + else: + operations_list = [ + OperationKind.Gemm + , OperationKind.Conv2d + , OperationKind.Conv3d + , OperationKind.RankK + , OperationKind.Trmm + , OperationKind.Symm + ] + self.operations_enabled = [x for x in operations_list if OperationKindNames[x] in args.operations.split(',')] + + if args.kernels == 'all': + self.kernel_names = [] + else: + self.kernel_names = [x for x in args.kernels.split(',') if x != ''] + + self.ignore_kernel_names = [x for x in args.ignore_kernels.split(',') if x != ''] + self.exclude_kernel_names = [x for x in args.exclude_kernels.split(',') if x != ''] + + if args.kernel_filter_file is None: + self.kernel_filter_list = [] + else: + self.kernel_filter_list = self.get_kernel_filters(args.kernel_filter_file) + _LOGGER.debug("Using {filter_count} kernel filters from {filter_file}".format( + filter_count = len(self.kernel_filter_list), + filter_file = args.kernel_filter_file)) + + self.operation_count = 0 + self.operations_by_name = {} + self.disable_full_archs_compilation = args.disable_full_archs_compilation + self.is_kernel_filter_set_to_all = args.instantiation_level == "max" and args.kernels != '' + self.instantiation_level = 0 + try: + self.instantiation_level = int(args.instantiation_level) + except ValueError: + self.instantiation_level = 0 + + def add_kernel_filter(self, filter_str): + filter_re = re.compile(filter_str) + + self.kernel_filter_list.append(filter_re) + + def get_instantiation_level(self, pruned_level=0, default_level=111, exhaustive_level=9992): + # Non-negative integer which determines how many kernels are instantiated. + # 0 = 0000 generates the fewest kernels, 9999 generates all possible combinations. + # increasing first digit reduces schedule / mixed type pruning, + # increasing second digit generates more cluster sizes, + # increasing third digit generates more MMA multipliers, + # increasing fourth digit generates more instruction shapes. + + if self.instantiation_level > 0: + return self.instantiation_level + + elif self.is_kernel_filter_set_to_all: + return exhaustive_level + + elif self.kernel_filter == '': + return pruned_level + + else: + return default_level + + + def get_kernel_filters(self, kernelListFile): + if os.path.isfile(kernelListFile): + with open(kernelListFile, 'r') as fileReader: + lines = [line.rstrip() for line in fileReader if not line.startswith("#")] + + lines = [re.compile(line) for line in lines if line] + return lines + else: + return [] + + # + def filter_out_kernels(self, kernel_name, kernel_filter_list): + + for kernel_filter_re in kernel_filter_list: + if kernel_filter_re.search(kernel_name) is not None: + return True + + return False + + + # + def _filter_string_matches(self, filter_string, haystack): + ''' Returns true if all substrings appear in the haystack in order''' + substrings = filter_string.split('*') + for sub in substrings: + idx = haystack.find(sub) + if idx < 0: + return False + haystack = haystack[idx + len(sub):] + return True + + # + def filter(self, operation): + ''' Filtering operations based on various criteria''' + + # filter based on compute capability + enabled = not (self.filter_by_cc) + + for cc in self.compute_capabilities_baseline: + + if cc >= operation.tile_description.minimum_compute_capability and \ + cc <= operation.tile_description.maximum_compute_capability and \ + (cc not in SharedMemPerCC or SharedMemPerCC[cc] >= CalculateSmemUsage(operation)): + + enabled = True + break + + if not enabled: + return False + + if len(self.operations_enabled) and not operation.operation_kind in self.operations_enabled: + return False + + name = operation.procedural_name() + + # eliminate duplicates + if name in self.operations_by_name.keys(): + return False + + # Filter based on list of valid substrings + if len(self.kernel_names): + enabled = False + + # compare against the include list + for name_substr in self.kernel_names: + if self._filter_string_matches(name_substr, name): + _LOGGER.debug(f"Kernel {name} included due to filter string '{name_substr}'.") + enabled = True + break + else: + _LOGGER.debug(f"Kernel {name} NOT included due to not matching '{name_substr}'.") + + # compare against the exclude list + for name_substr in self.ignore_kernel_names: + if self._filter_string_matches(name_substr, name): + _LOGGER.debug(f"Kernel {name} ignored due to filter string '{name_substr}'.") + enabled = False + break + else: + _LOGGER.debug(f"Kernel {name} NOT ignored due to not matching '{name_substr}'.") + + if len(self.kernel_filter_list) > 0: + if self.filter_out_kernels(name, self.kernel_filter_list): + _LOGGER.debug(f"Kernel {name} matched via kernel filter file.") + enabled = True + else: + _LOGGER.debug(f"Kernel {name} culled due to no match in kernel filter file.") + enabled = False + + # CUTLASS_LIBRARY_IGNORE_KERNELS ("ignore" list) only takes effect + # if CUTLASS_LIBRARY_KERNELS was specified. + # Changing that would break backwards compatibility. + # Thus, CUTLASS has introduced the new CMake option CUTLASS_LIBRARY_EXCLUDE_KERNELS, + # that always takes effect, whether or not CUTLASS_LIBRARY_KERNELS was specified. + for name_substr in self.exclude_kernel_names: + if self._filter_string_matches(name_substr, name): + _LOGGER.debug(f"Kernel {name} excluded due to filter string '{name_substr}'.") + enabled = False + break + else: + _LOGGER.debug(f"Kernel {name} NOT excluded due to not matching '{name_substr}'.") + + # TODO: filter based on compute data type + return enabled + # + + # + def append(self, operation): + ''' + Inserts the operation. + + operation_kind -> configuration_name -> [] + ''' + + if self.filter(operation): + + self.selected_kernels.append(operation.procedural_name()) + + self.operations_by_name[operation.procedural_name()] = operation + + # add the configuration + configuration_name = operation.configuration_name() + + # Split operations by minimum CC + min_cc = operation.arch + + if operation.operation_kind not in self.operations.keys(): + self.operations[operation.operation_kind] = {} + + if min_cc not in self.operations[operation.operation_kind]: + self.operations[operation.operation_kind][min_cc] = {} + + if configuration_name not in self.operations[operation.operation_kind][min_cc].keys(): + self.operations[operation.operation_kind][min_cc][configuration_name] = [] + + self.operations[operation.operation_kind][min_cc][configuration_name].append(operation) + self.operation_count += 1 + else: + _LOGGER.debug("Culled {} from manifest".format(operation.procedural_name())) + # + + def emit_manifest_cmake(self, manifest_path, top_level_path, source_files): + with open(manifest_path, "w") as manifest_file: + + target_text = SubstituteTemplate("""cutlass_target_sources(cutlass_library_objs PRIVATE + """, { }) + manifest_file.write(target_text + '\n\n') + manifest_file.write(" %s\n" % str(top_level_path.replace('\\', '/'))) + generated_path = os.path.join(self.curr_build_dir, 'generated') + for kind in self.operations.keys(): + kind_str = OperationKindNames[kind] + all_kind_file = os.path.join(generated_path, kind_str, f"all_{kind_str}_operations.cu").replace('\\', '/') + manifest_file.write(f" {all_kind_file}\n") + manifest_file.write(')\n\n') + + for kind in self.operations.keys(): + for min_cc in sorted(self.operations[kind].keys()): + for subclass in sorted(source_files[kind][min_cc].keys()): + target_text = SubstituteTemplate("""cutlass_add_cutlass_library( + SUFFIX ${kind}_sm${min_cc}_${subclass} +""", { 'min_cc': str(min_cc), 'kind': OperationKindNames[kind], 'subclass': subclass }) + manifest_file.write(target_text + '\n\n') + + for source_file in source_files[kind][min_cc][subclass]: + manifest_file.write(" %s\n" % str(source_file.replace('\\', '/'))) + + manifest_file.write(")\n") + + if self.disable_full_archs_compilation: + self.emit_disable_full_archs_compilation(manifest_file, source_files) + + def emit_disable_full_archs_compilation(manifest_file, source_files): + def for_hopper(name): + pass + + def for_ampere(name): + return "16816" in name or \ + "16832" in name or \ + "16864" in name or \ + ("1688" in name and "tf32" in name) + + def for_turing(name): + return ("1688" in name and "tf32" not in name) or \ + "8816" in name + + def for_volta(name): + return "884" in name + + def is_cpp(name): + return name.endswith(".cpp") + + def get_src_archs_str_given_requested_cuda_archs(archs, source_file): + intersected_archs = archs & set(self.compute_capabilities_baseline) + if intersected_archs == set(): + raise RuntimeError( + """ + Empty archs set for file {} after taking + the intersection of {} (global requested archs) and + {} (per file requested archs) + """.format(source_file, set(self.compute_capabilities_baseline), archs)) + else: + return " ".join(map(str, intersected_archs)) + + for min_cc in sorted(source_files.keys()): + for source_file in source_files[min_cc]: + if is_cpp(source_file): + continue # skip because source is cpp + elif for_ampere(source_file): + archs_str = get_src_archs_str_given_requested_cuda_archs({80, 87, 90}, source_file) + elif for_turing(source_file): + archs_str = get_src_archs_str_given_requested_cuda_archs({75}, source_file) + elif for_volta(source_file): + archs_str = get_src_archs_str_given_requested_cuda_archs({70, 72}, source_file) + else: + raise RuntimeError("Per file archs are not set {}, as there is no rule specified for this file pattern".format(source_file)) + + manifest_file.write("cutlass_apply_cuda_gencode_flags({} SM_ARCHS {})\n".format(str(source_file.replace('\\', '/')), archs_str)) + + # + def emit(self, target = GeneratorTarget.Library): + + operation_emitters = { + GeneratorTarget.Library: EmitOperationKindLibrary + } + + # Emitters for all operations that fall under a particular kind (e.g., GEMM, Conv2d) + kind_emitters = { + GeneratorTarget.Library: EmitOperationKindAll + } + + interface_emitters = { + GeneratorTarget.Library: EmitInterfaceLibrary + } + + generated_path = os.path.join(self.curr_build_dir, 'generated') + + # create generated/ + if os.path.exists(generated_path): + shutil.rmtree(generated_path) + + os.mkdir(generated_path) + + with interface_emitters[target](generated_path, self.operation_count, self.args) as iface_emitter: + top_level_path = iface_emitter.top_level_path + for operation_kind in self.operations.keys(): + iface_emitter.emit(OperationKindNames[operation_kind]) + + source_files = {} + for kind in self.operations.keys(): + source_files[kind] = {} + for min_cc in self.operations[kind].keys(): + source_files[kind][min_cc] = {} + + for operation_kind, ops in self.operations.items(): + for min_cc, configurations in sorted(ops.items()): + with operation_emitters[target](generated_path, min_cc, operation_kind, self.args) as operation_kind_emitter: + for configuration_name, operations in configurations.items(): + _LOGGER.info(f"Emitting {configuration_name} with {len(operations)} operation{'' if len(operations) == 1 else 's'}.") + operation_kind_emitter.emit(configuration_name, operations) + + for subclass, files in operation_kind_emitter.source_files.items(): + if subclass not in source_files[operation_kind][min_cc]: + source_files[operation_kind][min_cc][subclass] = [] + source_files[operation_kind][min_cc][subclass].extend(operation_kind_emitter.source_files[subclass]) + + # Emit top level all_{gemm, conv2d, ...}_operations.cu files + with kind_emitters[target](generated_path, operation_kind, self.args) as operation_kind_emitter: + operation_kind_emitter.emit(ops) + + # write the manifest.cmake file containing paths from all targets + manifest_path = os.path.join(generated_path, "manifest.cmake") + + self.emit_manifest_cmake(manifest_path, top_level_path, source_files) + +################################################################################################### diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/rank_2k_operation.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/rank_2k_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..29ef056f26f914a9c3c33e13900c33642ad2f1b7 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/rank_2k_operation.py @@ -0,0 +1,438 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for emitting Rank2K kernels +""" + +import enum +import functools +import operator +import os.path +import shutil + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * +except ImportError: + from library import * + + +################################################################################################### +# +# Data structure modeling a Rank K update operation +# +################################################################################################### + +# +class Rank2KOperation: + # + def __init__(self, rank_k_kind, arch, tile_description, A, C, element_epilogue, \ + epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, \ + blas_mode = BlasMode.symmetric): + + self.blas_mode = blas_mode + self.operation_kind = OperationKind.Rank2K + self.arch = arch + self.tile_description = tile_description + self.rank_k_kind = rank_k_kind + # tensor A and B have same data type and layout + self.A = A + self.B = A + self.C = C + self.element_epilogue = element_epilogue + self.epilogue_functor = epilogue_functor + self.swizzling_functor = swizzling_functor + + # + def is_complex(self): + complex_operators = [ + MathOperation.multiply_add_complex, + MathOperation.multiply_add_complex_gaussian, + MathOperation.multiply_add_complex_fast_f32 + ] + return self.tile_description.math_instruction.math_operation in complex_operators + return False + + # + def is_mixed_input(self): + return self.A.element != self.B.element + + # + def is_planar_complex(self): + return False + + # + def accumulator_type(self): + accum = self.tile_description.math_instruction.element_accumulator + + if self.is_complex(): + return get_complex_from_real(accum) + + return accum + + # + def short_math_name(self): + if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian: + return "g%s" % ShortDataTypeNames[self.accumulator_type()] + return ShortDataTypeNames[self.accumulator_type()] + + + # + def core_name(self): + ''' The basic operation kind is prefixed with a letter indicating the accumulation type. ''' + + inst_shape = '' + inst_operation = '' + intermediate_type = '' + + math_operations_map = { + MathOperation.xor_popc: 'xor', + MathOperation.and_popc: 'and' + } + + if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \ + self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp: + + math_op = self.tile_description.math_instruction.math_operation + math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else '' + + inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape) + inst_shape += math_op_string + + if self.tile_description.math_instruction.element_a != self.A.element and \ + self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator: + intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] + + operation_name = 'syr2k' if self.blas_mode == BlasMode.symmetric else 'her2k' + + return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, operation_name) + + # + def extended_name(self): + ''' Append data types if they differ from compute type. ''' + if self.is_complex(): + extended_name = "${core_name}" + else: + if self.C.element != self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${element_c}_${core_name}_${element_a}" + elif self.C.element == self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${core_name}_${element_a}" + else: + extended_name = "${core_name}" + + extended_name = SubstituteTemplate(extended_name, { + 'element_a': DataTypeNames[self.A.element], + 'element_c': DataTypeNames[self.C.element], + 'core_name': self.core_name() + }) + + return extended_name + + # + def layout_name(self): + if self.is_complex() or self.is_planar_complex(): + return "%s" % ( + ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)] + ) + return "%s" % (ShortLayoutTypeNames[self.A.layout]) + + # + def fill_mode_name(self): + return "%s" % (ShortFillModeNames[self.C.fill_mode]) + + # + def procedural_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + threadblock = self.tile_description.procedural_name() + + opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] + + alignment = max([self.A.alignment, self.C.alignment]) + + return SubstituteTemplate( + "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${fill_mode}_align${alignment}", + { + 'opcode_class': opcode_class_name, + 'extended_name': self.extended_name(), + 'threadblock': threadblock, + 'layout': self.layout_name(), + 'fill_mode': self.fill_mode_name(), + 'alignment': "%d" % self.A.alignment, + } + ) + + # + def configuration_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + return self.procedural_name() + +################################################################################################### +# +# Emits single instances of a CUTLASS device-wide operator +# +################################################################################################### + +# +class EmitRank2KUniversalInstance: + ''' Responsible for emitting a CUTLASS template definition''' + + def __init__(self): + self.rank_k_template = """ +// Rank K operator ${operation_name} +using Operation_${operation_name} = + typename cutlass::gemm::device::Rank2K< + ${element_a}, ${layout_a}, + ${element_b}, ${layout_b}, + ${element_c}, ${layout_c}, ${fill_mode}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, + ${stages}, + ${align_a}, + ${align_b}, + ${split_k_serial}, + ${math_operation} +>; +""" + self.rank_k_complex_template = """ +// Rank K operator ${operation_name} +using Operation_${operation_name} = + typename cutlass::gemm::device::Rank2K< + ${element_a}, ${layout_a}, + ${element_b}, ${layout_b}, + ${element_c}, ${layout_c}, ${fill_mode}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, + ${stages}, + ${align_a}, + ${align_b}, + ${split_k_serial}, + ${math_operation}, + ${transform_a}, + ${transform_b}, + ${blas_mode} +>; +""" + + def emit(self, operation): + + threadblock_shape = operation.tile_description.threadblock_shape + + warp_count = operation.tile_description.warp_count + warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)] + + epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) + + values = { + 'operation_name': operation.procedural_name(), + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[operation.A.layout], + 'element_b': DataTypeTag[operation.B.element], + 'layout_b': LayoutTag[operation.B.layout], + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'fill_mode': FillModeTag[operation.C.fill_mode], + 'element_accumulator': DataTypeTag[operation.accumulator_type()], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'epilogue_vector_length': str(epilogue_vector_length), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], + 'stages': str(operation.tile_description.stages), + 'align_a': str(operation.A.alignment), + 'align_b': str(operation.B.alignment), + 'split_k_serial': 'false', + 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], + 'transform_a': ComplexTransformTag[operation.A.complex_transform], + 'transform_b': ComplexTransformTag[operation.B.complex_transform], + 'blas_mode': BlasModeTag[operation.blas_mode] + } + + rank_k_template = self.rank_k_complex_template if operation.is_complex() else self.rank_k_template + + return SubstituteTemplate(rank_k_template, values) + +################################################################################################### + + +################################################################################################### +# +# Emitters functions for all targets +# +################################################################################################### + +class EmitRank2KConfigurationLibrary: + def __init__(self, operation_path, configuration_name): + self.configuration_name = configuration_name + self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/') + + self.instance_emitter = { + RankKKind.Universal: EmitRank2KUniversalInstance, + } + + self.rank_k_kind_wrappers = { + RankKKind.Universal: 'Rank2KOperation', + } + + self.instance_template = { + RankKKind.Universal: """ +${compile_guard_start} + manifest.append(new ${rank_k_kind}< + Operation_${operation_name} + >("${operation_name}")); +${compile_guard_end} +""" + } + + self.header_template = """ +/* + Generated by rank_2k_operation.py - Do not edit. +*/ + +/////////////////////////////////////////////////////////////////////////////////////////////////// +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "library_internal.h" +#include "rank_2k_operation.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + self.initialize_function_template = """ + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +void initialize_${configuration_name}(Manifest &manifest) { + +""" + self.epilogue_template = """ + +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + def __enter__(self): + self.configuration_file = open(self.configuration_path, "w") + self.configuration_file.write(self.header_template) + + self.instance_definitions = [] + self.instance_wrappers = [] + + self.operations = [] + return self + + def emit(self, operation): + emitter = self.instance_emitter[operation.rank_k_kind]() + + self.operations.append(operation) + + self.instance_definitions.append(emitter.emit(operation)) + + self.instance_wrappers.append(SubstituteTemplate(self.instance_template[operation.rank_k_kind], { + 'configuration_name': self.configuration_name, + 'operation_name': operation.procedural_name(), + 'rank_k_kind': self.rank_k_kind_wrappers[operation.rank_k_kind], + 'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \ + if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "", + 'compile_guard_end': "#endif" \ + if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "" + })) + + def __exit__(self, exception_type, exception_value, traceback): + + # Write instance definitions in top-level namespace + for instance_definition in self.instance_definitions: + self.configuration_file.write(instance_definition) + + # Add wrapper objects within initialize() function + self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, { + 'configuration_name': self.configuration_name + })) + + for instance_wrapper in self.instance_wrappers: + self.configuration_file.write(instance_wrapper) + + self.configuration_file.write(self.epilogue_template) + self.configuration_file.close() + +################################################################################################### diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/rank_k_operation.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/rank_k_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..9841952332a170d6f401dbe34a0093540c166fb8 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/rank_k_operation.py @@ -0,0 +1,427 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for emitting RankK kernels +""" + +import enum +import functools +import operator +import os.path +import shutil + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * +except ImportError: + from library import * + + +################################################################################################### +# +# Data structure modeling a Rank K update operation +# +################################################################################################### + +# +class RankKOperation: + # + def __init__(self, rank_k_kind, arch, tile_description, A, C, element_epilogue, \ + epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, \ + blas_mode = BlasMode.symmetric): + + self.blas_mode = blas_mode + self.operation_kind = OperationKind.RankK + self.arch = arch + self.tile_description = tile_description + self.rank_k_kind = rank_k_kind + self.A = A + self.C = C + self.element_epilogue = element_epilogue + self.epilogue_functor = epilogue_functor + self.swizzling_functor = swizzling_functor + + # + def is_complex(self): + complex_operators = [ + MathOperation.multiply_add_complex, + MathOperation.multiply_add_complex_gaussian, + MathOperation.multiply_add_complex_fast_f32 + ] + return self.tile_description.math_instruction.math_operation in complex_operators + return False + + # + def is_mixed_input(self): + return False + + # + def is_planar_complex(self): + return False + + # + def accumulator_type(self): + accum = self.tile_description.math_instruction.element_accumulator + + if self.is_complex(): + return get_complex_from_real(accum) + + return accum + + # + def short_math_name(self): + if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian: + return "g%s" % ShortDataTypeNames[self.accumulator_type()] + return ShortDataTypeNames[self.accumulator_type()] + + + # + def core_name(self): + ''' The basic operation kind is prefixed with a letter indicating the accumulation type. ''' + + inst_shape = '' + inst_operation = '' + intermediate_type = '' + + math_operations_map = { + MathOperation.xor_popc: 'xor', + MathOperation.and_popc: 'and' + } + + if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \ + self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp: + + math_op = self.tile_description.math_instruction.math_operation + math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else '' + + inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape) + inst_shape += math_op_string + + if self.tile_description.math_instruction.element_a != self.A.element and \ + self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator: + intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] + + operation_name = 'syrk' if self.blas_mode == BlasMode.symmetric else 'herk' + + return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, operation_name) + + # + def extended_name(self): + ''' Append data types if they differ from compute type. ''' + if self.is_complex(): + extended_name = "${core_name}" + else: + if self.C.element != self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${element_c}_${core_name}_${element_a}" + elif self.C.element == self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${core_name}_${element_a}" + else: + extended_name = "${core_name}" + + extended_name = SubstituteTemplate(extended_name, { + 'element_a': DataTypeNames[self.A.element], + 'element_c': DataTypeNames[self.C.element], + 'core_name': self.core_name() + }) + + return extended_name + + # + def layout_name(self): + if self.is_complex() or self.is_planar_complex(): + return "%s" % ( + ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)] + ) + return "%s" % (ShortLayoutTypeNames[self.A.layout]) + + # + def fill_mode_name(self): + return "%s" % (ShortFillModeNames[self.C.fill_mode]) + + # + def procedural_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + threadblock = self.tile_description.procedural_name() + + opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] + + alignment = max([self.A.alignment, self.C.alignment]) + + return SubstituteTemplate( + "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${fill_mode}_align${alignment}", + { + 'opcode_class': opcode_class_name, + 'extended_name': self.extended_name(), + 'threadblock': threadblock, + 'layout': self.layout_name(), + 'fill_mode': self.fill_mode_name(), + 'alignment': "%d" % self.A.alignment, + } + ) + + # + def configuration_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + return self.procedural_name() + +################################################################################################### +# +# Emits single instances of a CUTLASS device-wide operator +# +################################################################################################### + +# +class EmitRankKUniversalInstance: + ''' Responsible for emitting a CUTLASS template definition''' + + def __init__(self): + self.rank_k_template = """ +// Rank K operator ${operation_name} +using Operation_${operation_name} = + typename cutlass::gemm::device::RankK< + ${element_a}, ${layout_a}, + ${element_c}, ${layout_c}, ${fill_mode}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, + ${stages}, + ${align_a}, + ${split_k_serial}, + ${math_operation} +>; +""" + self.rank_k_complex_template = """ +// Rank K operator ${operation_name} +using Operation_${operation_name} = + typename cutlass::gemm::device::RankK< + ${element_a}, ${layout_a}, + ${element_c}, ${layout_c}, ${fill_mode}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, + ${stages}, + ${align_a}, + ${split_k_serial}, + ${math_operation}, + ${transform_a}, + ${blas_mode} +>; +""" + + def emit(self, operation): + + threadblock_shape = operation.tile_description.threadblock_shape + + warp_count = operation.tile_description.warp_count + warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)] + + epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) + + values = { + 'operation_name': operation.procedural_name(), + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[operation.A.layout], + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'fill_mode': FillModeTag[operation.C.fill_mode], + 'element_accumulator': DataTypeTag[operation.accumulator_type()], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'epilogue_vector_length': str(epilogue_vector_length), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], + 'stages': str(operation.tile_description.stages), + 'align_a': str(operation.A.alignment), + 'split_k_serial': 'false', + 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], + 'transform_a': ComplexTransformTag[operation.A.complex_transform], + 'blas_mode': BlasModeTag[operation.blas_mode] + } + + rank_k_template = self.rank_k_complex_template if operation.is_complex() else self.rank_k_template + + return SubstituteTemplate(rank_k_template, values) + +################################################################################################### + + +################################################################################################### +# +# Emitters functions for all targets +# +################################################################################################### + +class EmitRankKConfigurationLibrary: + def __init__(self, operation_path, configuration_name): + self.configuration_name = configuration_name + self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/') + + self.instance_emitter = { + RankKKind.Universal: EmitRankKUniversalInstance, + } + + self.rank_k_kind_wrappers = { + RankKKind.Universal: 'RankKOperation', + } + + self.instance_template = { + RankKKind.Universal: """ +${compile_guard_start} + manifest.append(new ${rank_k_kind}< + Operation_${operation_name} + >("${operation_name}")); +${compile_guard_end} +""" + } + + self.header_template = """ +/* + Generated by rank_k_operation.py - Do not edit. +*/ + +/////////////////////////////////////////////////////////////////////////////////////////////////// +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "library_internal.h" +#include "rank_k_operation.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + self.initialize_function_template = """ + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +void initialize_${configuration_name}(Manifest &manifest) { + +""" + self.epilogue_template = """ + +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + def __enter__(self): + self.configuration_file = open(self.configuration_path, "w") + self.configuration_file.write(self.header_template) + + self.instance_definitions = [] + self.instance_wrappers = [] + + self.operations = [] + return self + + def emit(self, operation): + emitter = self.instance_emitter[operation.rank_k_kind]() + + self.operations.append(operation) + + self.instance_definitions.append(emitter.emit(operation)) + + self.instance_wrappers.append(SubstituteTemplate(self.instance_template[operation.rank_k_kind], { + 'configuration_name': self.configuration_name, + 'operation_name': operation.procedural_name(), + 'rank_k_kind': self.rank_k_kind_wrappers[operation.rank_k_kind], + 'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \ + if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "", + 'compile_guard_end': "#endif" \ + if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "" + })) + + def __exit__(self, exception_type, exception_value, traceback): + + # Write instance definitions in top-level namespace + for instance_definition in self.instance_definitions: + self.configuration_file.write(instance_definition) + + # Add wrapper objects within initialize() function + self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, { + 'configuration_name': self.configuration_name + })) + + for instance_wrapper in self.instance_wrappers: + self.configuration_file.write(instance_wrapper) + + self.configuration_file.write(self.epilogue_template) + self.configuration_file.close() + +################################################################################################### diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/sm100_shapes.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/sm100_shapes.py new file mode 100644 index 0000000000000000000000000000000000000000..32e4376513679f06dc085ead068e258b3d8b5e72 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/sm100_shapes.py @@ -0,0 +1,342 @@ +################################################################################################# +# +# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Valid tcgen05 shapes and cluster sizes for SM100, associated with levels. +These shape and level pairs are defined as dicts, where keys are shapes and values are their +associated levels. If the user input level for that category (tcgen05 shape, cluster +size) is smaller than a shape's associated level, it will be excluded, and otherwise, included. +Higher levels are therefore less likely emitted, but lower levels are more emitted more frequently. +Level 0 is always emitted. +""" + +try: + from .library import DynamicClusterShape +except: + from library import DynamicClusterShape + +SM100_CLUSTER_SHAPES_1SM = { + tuple(DynamicClusterShape) : 0, + # size 1 cluster + (1, 1, 1): 1, + # size 2 cluster + (1, 2, 1): 2, + (2, 1, 1): 5, + # size 4 clusters + (2, 2, 1): 6, + (1, 4, 1): 3, + (4, 1, 1): 6, + # size 8 clusters + (2, 4, 1): 7, + (4, 2, 1): 7, + (1, 8, 1): 8, + (8, 1, 1): 8, + # size 16 cluster + (4, 4, 1): 4, +} + +SM100_CLUSTER_SHAPES_2SM = { + tuple(DynamicClusterShape) : 0, + # size 2 cluster + (2, 1, 1): 1, + # size 4 clusters + (2, 2, 1): 2, + (4, 1, 1): 2, + # size 8 clusters + (2, 4, 1): 3, + (4, 2, 1): 3, + (8, 1, 1): 6, + # size 16 cluster + (4, 4, 1): 4, +} + +# MMA shapes + +# 16b Dense + +SM100_MMA_SHAPES_16b_DENSE_1SM = { + (64, 8, 16): 5, + (64, 16, 16): 2, + (64, 24, 16): 5, + (64, 32, 16): 2, + (64, 40, 16): 5, + (64, 48, 16): 5, + (64, 56, 16): 5, + (64, 64, 16): 2, + (64, 72, 16): 5, + (64, 80, 16): 5, + (64, 88, 16): 5, + (64, 96, 16): 5, + (64, 104, 16): 5, + (64, 112, 16): 5, + (64, 120, 16): 5, + (64, 128, 16): 0, + (64, 136, 16): 5, + (64, 144, 16): 5, + (64, 152, 16): 5, + (64, 160, 16): 5, + (64, 168, 16): 5, + (64, 176, 16): 5, + (64, 184, 16): 5, + (64, 192, 16): 3, + (64, 200, 16): 5, + (64, 208, 16): 5, + (64, 216, 16): 5, + (64, 224, 16): 5, + (64, 232, 16): 5, + (64, 240, 16): 5, + (64, 248, 16): 5, + (64, 256, 16): 3, + + (128, 16, 16): 2, + (128, 32, 16): 2, + (128, 48, 16): 5, + (128, 64, 16): 2, + (128, 80, 16): 5, + (128, 96, 16): 5, + (128, 112, 16): 5, + (128, 128, 16): 0, + (128, 144, 16): 5, + (128, 160, 16): 5, + (128, 176, 16): 5, + (128, 192, 16): 3, + (128, 208, 16): 5, + (128, 224, 16): 5, + (128, 240, 16): 5, + (128, 256, 16): 0, + +} + + +SM100_MMA_SHAPES_16b_DENSE_2SM = { + (128, 32, 16): 2, + (128, 64, 16): 2, + (128, 96, 16): 5, + (128, 128, 16): 0, + (128, 160, 16): 5, + (128, 192, 16): 5, + (128, 224, 16): 5, + (128, 256, 16): 0, + + (256, 32, 16): 2, + (256, 64, 16): 2, + (256, 96, 16): 5, + (256, 128, 16): 0, + (256, 160, 16): 5, + (256, 192, 16): 3, + (256, 224, 16): 5, + (256, 256, 16): 0, +} + +# TF32 Dense + +SM100_MMA_SHAPES_TF32_DENSE_1SM = { + (64, 8, 8): 5, + (64, 16, 8): 2, + (64, 24, 8): 5, + (64, 32, 8): 2, + (64, 40, 8): 5, + (64, 48, 8): 5, + (64, 56, 8): 5, + (64, 64, 8): 1, + (64, 72, 8): 5, + (64, 80, 8): 5, + (64, 88, 8): 5, + (64, 96, 8): 5, + (64, 104, 8): 5, + (64, 112, 8): 5, + (64, 120, 8): 5, + (64, 128, 8): 0, + (64, 136, 8): 5, + (64, 144, 8): 5, + (64, 152, 8): 5, + (64, 160, 8): 5, + (64, 168, 8): 5, + (64, 176, 8): 5, + (64, 184, 8): 5, + (64, 192, 8): 3, + (64, 200, 8): 5, + (64, 208, 8): 5, + (64, 216, 8): 5, + (64, 224, 8): 5, + (64, 232, 8): 5, + (64, 240, 8): 5, + (64, 248, 8): 5, + (64, 256, 8): 3, + + (128, 16, 8): 2, + (128, 32, 8): 2, + (128, 48, 8): 5, + (128, 64, 8): 2, + (128, 80, 8): 5, + (128, 96, 8): 5, + (128, 112, 8): 5, + (128, 128, 8): 0, + (128, 144, 8): 5, + (128, 160, 8): 5, + (128, 176, 8): 5, + (128, 192, 8): 3, + (128, 208, 8): 5, + (128, 224, 8): 5, + (128, 240, 8): 5, + (128, 256, 8): 0, + +} + +SM100_MMA_SHAPES_TF32_DENSE_2SM = { + (128, 32, 8): 2, + (128, 64, 8): 1, + (128, 96, 8): 5, + (128, 128, 8): 0, + (128, 160, 8): 5, + (128, 192, 8): 5, + (128, 224, 8): 5, + (128, 256, 8): 0, + + (256, 32, 8): 2, + (256, 64, 8): 1, + (256, 96, 8): 5, + (256, 128, 8): 0, + (256, 160, 8): 5, + (256, 192, 8): 5, + (256, 224, 8): 5, + (256, 256, 8): 0, +} + +# F8F6F4 +SM100_MMA_SHAPES_F8F6F4_DENSE_1SM = { + (64, 8, 32): 4, + (64, 16, 32): 4, + (64, 24, 32): 5, + (64, 32, 32): 3, + (64, 40, 32): 5, + (64, 48, 32): 5, + (64, 56, 32): 5, + (64, 64, 32): 2, + (64, 72, 32): 5, + (64, 80, 32): 5, + (64, 88, 32): 5, + (64, 96, 32): 5, + (64, 104, 32): 5, + (64, 112, 32): 5, + (64, 120, 32): 5, + (64, 128, 32): 0, + (64, 136, 32): 5, + (64, 144, 32): 5, + (64, 152, 32): 5, + (64, 160, 32): 5, + (64, 168, 32): 5, + (64, 176, 32): 5, + (64, 184, 32): 5, + (64, 192, 32): 5, + (64, 200, 32): 5, + (64, 208, 32): 5, + (64, 216, 32): 5, + (64, 224, 32): 5, + (64, 232, 32): 5, + (64, 240, 32): 5, + (64, 248, 32): 5, + (64, 256, 32): 0, + + (128, 16, 32): 4, + (128, 32, 32): 3, + (128, 48, 32): 5, + (128, 64, 32): 2, + (128, 80, 32): 5, + (128, 96, 32): 5, + (128, 112, 32): 5, + (128, 128, 32): 0, + (128, 144, 32): 5, + (128, 160, 32): 5, + (128, 176, 32): 5, + (128, 192, 32): 5, + (128, 208, 32): 5, + (128, 224, 32): 5, + (128, 240, 32): 5, + (128, 256, 32): 0, + +} + +SM100_MMA_SHAPES_F8F6F4_DENSE_2SM = { + (128, 32, 32): 3, + (128, 64, 32): 2, + (128, 96, 32): 5, + (128, 128, 32): 1, + (128, 160, 32): 5, + (128, 192, 32): 5, + (128, 224, 32): 5, + (128, 256, 32): 1, + + (256, 32, 32): 2, + (256, 64, 32): 2, + (256, 96, 32): 5, + (256, 128, 32): 0, + (256, 160, 32): 5, + (256, 192, 32): 5, + (256, 224, 32): 5, + (256, 256, 32): 0, +} + +# MXF8F6F4 +SM100_MMA_SHAPES_MXF8F6F4_DENSE_1SM = { + (128, 64, 32): 1, + (128, 128, 32): 0, + (128, 192, 32): 1, + (128, 256, 32): 0, +} + + +SM100_MMA_SHAPES_MXF8F6F4_DENSE_2SM = { + (256, 64, 32): 1, + (256, 128, 32): 0, + (256, 192, 32): 1, + (256, 256, 32): 0, + + +} + +# MXF4NVF4 +SM100_MMA_SHAPES_MXF4NVF4_DENSE_1SM = { + (128, 64, 64): 1, + (128, 128, 64): 0, + (128, 192, 64): 1, + (128, 256, 64): 0, +} + +SM100_MMA_SHAPES_MXF4NVF4_DENSE_2SM = { + # Multiples of 16 for N + (256, 64, 64): 1, + (256, 128, 64): 0, + (256, 192, 64): 1, + (256, 256, 64): 0, + +} diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/sm100_utils.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/sm100_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9bf24fe7f528020be4dcfc6ac41cfe949dd63be5 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/sm100_utils.py @@ -0,0 +1,661 @@ +################################################################################################# +# +# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for enumerating CUTLASS library SM100 kernels +""" + +import argparse +import enum +from itertools import product +import math +import logging +import os.path +import shutil +import sys +import copy +from typing import Any, Optional, Sequence, Tuple, List, Union, Callable + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * +except ImportError: + from library import * + +#### Step 0: define levels + +# One integer level controls multiple "generators" and how many +# combinations they generate. That is the "global" level. +# "Generators" are WGMMA shapes, MMA multipliers, cluster sizes, and +# anything that is eventually involved in the Cartesian product +# which yields our kernel configurations. +# For simplicity, each generator defines their own levels, +# starting from 0. As a rule we assume 10 or fewer levels, making +# their level a digit. +# The "global" level simply stacks these digits and represents them +# as a single integer. +# +# For example, level 500 indicates cluster sizes are at level 5, MMA +# multipliers are at level 0, and WGMMA shapes are at level 0 as well. +# +# Here we define the global level to generator level mappings. + + +def get_tcgen05_level_from_global_level(global_level: int): + return global_level % 10 + +def get_mma_level_from_global_level(global_level: int): + return (global_level // 10) % 10 + + +def get_cluster_level_from_global_level(global_level: int): + return (global_level // 100) % 10 + + +def get_pruning_level_from_global_level(global_level: int): + return (global_level // 1000) % 10 + + +#### Step 1: generate MMA instruction shapes based on levels + +try: + from .sm100_shapes import * +except: + from sm100_shapes import * + +########### + +def generate_tf32_math_instructions_sm100(level: int): + """ + Generate all TensorOp math instructions for TF32 MMA that are supported by SM100 at or above the given level. + + Args: + level: The global level to generate math instructions for. + + Returns: + A tuple of two lists of MathInstruction objects. + The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM. + """ + tcgen05_level = get_tcgen05_level_from_global_level(level) + math_instructions_1sm = [] + math_instructions_2sm = [] + + shapes_1sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_TF32_DENSE_1SM.items() if tcgen05_level >= min_level + ] + shapes_2sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_TF32_DENSE_2SM.items() if tcgen05_level >= min_level + ] + + for shape in shapes_1sm: + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.tf32, DataType.tf32, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + for shape in shapes_2sm: + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.tf32, DataType.tf32, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + return math_instructions_1sm, math_instructions_2sm + +def generate_16b_math_instructions_sm100(level: int): + """ + Generate all TensorOp math instructions for 16b MMA that are supported by SM100 at or above the given level. + + Args: + level: The global level to generate math instructions for. + + Returns: + A tuple of two lists of MathInstruction objects. + The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM. + """ + tcgen05_level = get_tcgen05_level_from_global_level(level) + math_instructions_1sm = [] + math_instructions_2sm = [] + + shapes_1sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_16b_DENSE_1SM.items() if tcgen05_level >= min_level + ] + shapes_2sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_16b_DENSE_2SM.items() if tcgen05_level >= min_level + ] + + for shape in shapes_1sm: + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.f16, DataType.f16, DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.f16, DataType.f16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.bf16, DataType.bf16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + + for shape in shapes_2sm: + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.f16, DataType.f16, DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.f16, DataType.f16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.bf16, DataType.bf16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + return math_instructions_1sm, math_instructions_2sm + + +def generate_fp8_math_instructions_sm100(level: int, enable_runtime_dtype = True, enable_compile_time_dtype = True): + """ + Generate all TensorOp math instructions for FP8 MMA that are supported by SM100 at or above the given level. + + Args: + level: The global level to generate math instructions for. + enable_runtime_dtype: Whether to generate runtime dtype math instructions. + enable_compile_time_dtype: Whether to generate compile time dtype math instructions. + + Returns: + A tuple of two lists of MathInstruction objects. + The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM. + """ + + tcgen05_level = get_tcgen05_level_from_global_level(level) + pruning_level = get_pruning_level_from_global_level(level) + math_instructions_1sm = [] + math_instructions_2sm = [] + + shapes_1sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_F8F6F4_DENSE_1SM.items() if tcgen05_level >= min_level + ] + shapes_2sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_F8F6F4_DENSE_2SM.items() if tcgen05_level >= min_level + ] + + for shape in shapes_1sm: + if enable_runtime_dtype: + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.f8, DataType.f8, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + if enable_compile_time_dtype: + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.e4m3, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.e5m2, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.e4m3, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + if pruning_level >= 2: + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.e5m2, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + for shape in shapes_2sm: + if enable_runtime_dtype: + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.f8, DataType.f8, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + if enable_compile_time_dtype: + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.e4m3, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.e5m2, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.e4m3, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + if pruning_level >= 2: + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.e5m2, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + return math_instructions_1sm, math_instructions_2sm + +def generate_f8f6f4_math_instructions_sm100(level: int, enable_runtime_dtype = True, enable_compile_time_dtype = True): + """ + Generate all TensorOp math instructions for FP8 FP6 and FP4 MMA that are supported by SM100 at or above the given level. + + Args: + level: The global level to generate math instructions for. + enable_runtime_dtype: Whether to generate runtime dtype math instructions. + enable_compile_time_dtype: Whether to generate compile time dtype math instructions. + + Returns: + A tuple of two lists of MathInstruction objects. + The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM. + """ + + tcgen05_level = get_tcgen05_level_from_global_level(level) + math_instructions_1sm = [] + math_instructions_2sm = [] + + shapes_1sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_F8F6F4_DENSE_1SM.items() if tcgen05_level >= min_level + ] + shapes_2sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_F8F6F4_DENSE_2SM.items() if tcgen05_level >= min_level + ] + + for shape in shapes_1sm: + if enable_runtime_dtype: + + runtime_types = [ DataType.f8, DataType.f6, DataType.f4 ] + + for a_type, b_type in product(runtime_types, repeat=2): + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + if enable_compile_time_dtype: + compile_time_types = [ DataType.e4m3, DataType.e5m2, DataType.e3m2, DataType.e2m1 ] + + for a_type, b_type in product(compile_time_types, repeat=2): + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + + for shape in shapes_2sm: + if enable_runtime_dtype: + + runtime_types = [ DataType.f8, DataType.f6, DataType.f4 ] + + for a_type, b_type in product(runtime_types, repeat=2): + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + if enable_compile_time_dtype: + compile_time_types = [ DataType.e4m3, DataType.e5m2, DataType.e3m2, DataType.e2m1 ] + + for a_type, b_type in product(compile_time_types, repeat=2): + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + return math_instructions_1sm, math_instructions_2sm + +def generate_mxf8f6f4_math_instructions_sm100(level: int, enable_runtime_dtype = True, enable_compile_time_dtype = True): + """ + Generate all BlockScaledTensorOp math instructions for MXFP8, MXFP6, and MXFP4 MMA that are supported by SM100 at or above the given level. + + Args: + level: The global level to generate math instructions for. + enable_runtime_dtype: Whether to generate runtime dtype math instructions. + enable_compile_time_dtype: Whether to generate compile time dtype math instructions. + + Returns: + A tuple of two lists of MathInstruction objects. + The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM. + """ + + tcgen05_level = get_tcgen05_level_from_global_level(level) + pruning_level = get_pruning_level_from_global_level(level) + + math_instructions_1sm = [] + math_instructions_2sm = [] + + shapes_1sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_MXF8F6F4_DENSE_1SM.items() if tcgen05_level >= min_level + ] + shapes_2sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_MXF8F6F4_DENSE_2SM.items() if tcgen05_level >= min_level + ] + + for shape in shapes_1sm: + if enable_runtime_dtype: + + runtime_types = [ DataType.f8, DataType.f6, DataType.f4 ] + + for a_type, b_type in product(runtime_types, repeat=2): + + if pruning_level < 2 and ((a_type == DataType.f8 or b_type == DataType.f8)): + continue + + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + + if enable_compile_time_dtype: + compile_time_types = [ DataType.e4m3, + DataType.e5m2, + DataType.e3m2, + DataType.e2m3, + DataType.e2m1 ] + + for a_type, b_type in product(compile_time_types, repeat=2): + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + + + for shape in shapes_2sm: + if enable_runtime_dtype: + + runtime_types = [ DataType.f8, DataType.f6, DataType.f4 ] + + for a_type, b_type in product(runtime_types, repeat=2): + + if pruning_level < 2 and ((a_type == DataType.f8 or b_type == DataType.f8)): + continue + + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + + if enable_compile_time_dtype: + compile_time_types = [ DataType.e4m3, + DataType.e5m2, + DataType.e3m2, + DataType.e2m3, + DataType.e2m1 ] + + for a_type, b_type in product(compile_time_types, repeat=2): + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + + return math_instructions_1sm, math_instructions_2sm + +def generate_mxf4nvf4_math_instructions_sm100(level: int, enable_runtime_dtype = True, enable_compile_time_dtype = True): + """ + Generate all BlockScaledTensorOp math instructions for MXFP4 and MXFP4 MMA that are supported by SM100 at or above the given level. + + Args: + level: The global level to generate math instructions for. + enable_runtime_dtype: Whether to generate runtime dtype math instructions. + enable_compile_time_dtype: Whether to generate compile time dtype math instructions. + + Returns: + A tuple of two lists of MathInstruction objects. + The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM. + """ + tcgen05_level = get_tcgen05_level_from_global_level(level) + math_instructions_1sm = [] + math_instructions_2sm = [] + + shapes_1sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_MXF4NVF4_DENSE_1SM.items() if tcgen05_level >= min_level + ] + shapes_2sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_MXF4NVF4_DENSE_2SM.items() if tcgen05_level >= min_level + ] + + for shape in shapes_1sm: + if enable_runtime_dtype: + + runtime_types = [ DataType.f4 ] + + for a_type, b_type in product(runtime_types, repeat=2): + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue4m3) + ) + + + if enable_compile_time_dtype: + compile_time_types = [ DataType.e2m1, + ] + + for a_type, b_type in product(compile_time_types, repeat=2): + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue4m3) + ) + + + for shape in shapes_2sm: + if enable_runtime_dtype: + + runtime_types = [ DataType.f4 ] + + for a_type, b_type in product(runtime_types, repeat=2): + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue4m3) + ) + + + if enable_compile_time_dtype: + compile_time_types = [ DataType.e2m1, + ] + + for a_type, b_type in product(compile_time_types, repeat=2): + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue4m3) + ) + + + return math_instructions_1sm, math_instructions_2sm + + +def generate_cluster_shapes_sm100(level: int, change_priority_func : Union[Callable, None] = None): + """ + Generate all cluster shapes for SM100 at or above the given level. + + Args: + level: The global level to generate cluster shapes for. + + Returns: + A tuple of two lists of cluster shapes. + The first list contains the cluster shapes for 1SM, and the second list contains the cluster shapes for 2SM. + """ + cluster_level = get_cluster_level_from_global_level(level) + + assert cluster_level >= 4 + + if change_priority_func is not None: + SM100_CLUSTER_SHAPES_1SM_CPY = copy.deepcopy(SM100_CLUSTER_SHAPES_1SM) + SM100_CLUSTER_SHAPES_2SM_CPY = copy.deepcopy(SM100_CLUSTER_SHAPES_2SM) + change_priority_func(SM100_CLUSTER_SHAPES_1SM_CPY, SM100_CLUSTER_SHAPES_2SM_CPY) + shapes_1sm = [ + list(shape) for shape, min_level in SM100_CLUSTER_SHAPES_1SM_CPY.items() if cluster_level >= min_level + ] + shapes_2sm = [ + list(shape) for shape, min_level in SM100_CLUSTER_SHAPES_2SM_CPY.items() if cluster_level >= min_level + ] + + return shapes_1sm, shapes_2sm + + else: + + shapes_1sm = [ + list(shape) for shape, min_level in SM100_CLUSTER_SHAPES_1SM.items() if cluster_level >= min_level + ] + shapes_2sm = [ + list(shape) for shape, min_level in SM100_CLUSTER_SHAPES_2SM.items() if cluster_level >= min_level + ] + + return shapes_1sm, shapes_2sm diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/sm90_shapes.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/sm90_shapes.py new file mode 100644 index 0000000000000000000000000000000000000000..e14761aae6494f877e6dc6521b30baea0db7509c --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/sm90_shapes.py @@ -0,0 +1,212 @@ +################################################################################################# +# +# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Valid WGMMA shapes, MMA multipliers, and cluster sizes for SM90, associated with levels. +These shape and level pairs are defined as dicts, where keys are shapes and values are their +associated levels. If the user input level for that category (MMA multiplier, WGMMA shape, cluster +size) is smaller than a shape's associated level, it will be excluded, and otherwise, included. +Higher levels are therefore less likely emitted, but lower levels are more emitted more frequently. +Level 0 is always emitted. The default behavior in `generator.py` is that level 1 is only emitted +when the `--kernel` argument is non-empty. +""" + +# NOTE: more combinations are possible here. +# Levels [0, 3] exist in order to control exactly what configs are generated in different dtypes. +# The rest are only used in the exhaustive mode (when the corresponding level digit is 9). +# MMA multipliers are multiplied by MMA instruction shapes (WGMMA shapes) to get CTA shapes. +SM90_MMA_MULTIPLIERS = { + (2, 1, 4): 0, + (1, 1, 4): 1, + (4, 1, 4): 2, + (2, 2, 4): 3, + (2, 1, 8): 4, + (4, 1, 8): 4, + (1, 1, 8): 4, + (2, 2, 8): 4, + (2, 1, 16): 5, + (4, 1, 16): 5, + (1, 1, 16): 5, + (2, 2, 16): 5, +} + +# Level 0: only (1, 2, 1) -- fp8 dense gemms in pruned case +# Level 1: clusters with 2 CTAs -- all but fp8 (s8, u8, f16, b16, f32, tf32) dense gemms in pruned case +# Level 2: clusters with 1 or 2 CTAs +# Level 3: clusters with 1, 2, or 4 CTAs +# Level 4: clusters with 1, 2, 4, or 8 CTAs +# Level 5: clusters with 1, 2, 4, 8, or 16 CTAs +SM90_CLUSTER_SIZES = { + (1, 2, 1): 0, + (2, 1, 1): 1, + (1, 1, 1): 2, + (2, 2, 1): 3, + (1, 4, 1): 3, + (4, 1, 1): 3, + (2, 4, 1): 4, + (4, 2, 1): 4, + (1, 8, 1): 4, + (8, 1, 1): 4, + (4, 4, 1): 5, +} + + +# WGMMA shapes +# Level 0: "default" shape only, +# Level 1: additional shapes for the unpruned case (tf32 only) +# Level 2: shapes that are all powers of 2 +# Level 3: all other shapes +SM90_WGMMA_SHAPES_FP16_BF16_DENSE = { + (64, 8, 16): 2, + (64, 16, 16): 2, + (64, 24, 16): 3, + (64, 32, 16): 2, + (64, 40, 16): 3, + (64, 48, 16): 3, + (64, 56, 16): 3, + (64, 64, 16): 2, + (64, 72, 16): 3, + (64, 80, 16): 3, + (64, 88, 16): 3, + (64, 96, 16): 3, + (64, 104, 16): 3, + (64, 112, 16): 3, + (64, 120, 16): 3, + (64, 128, 16): 0, + (64, 136, 16): 3, + (64, 144, 16): 3, + (64, 152, 16): 3, + (64, 160, 16): 3, + (64, 168, 16): 3, + (64, 176, 16): 3, + (64, 184, 16): 3, + (64, 192, 16): 3, + (64, 200, 16): 3, + (64, 208, 16): 3, + (64, 216, 16): 3, + (64, 224, 16): 3, + (64, 232, 16): 3, + (64, 240, 16): 3, + (64, 248, 16): 3, + (64, 256, 16): 1, +} + +SM90_WGMMA_SHAPES_TF32_DENSE = { + (64, 8, 8): 2, + (64, 16, 8): 2, + (64, 24, 8): 3, + (64, 32, 8): 2, + (64, 40, 8): 3, + (64, 48, 8): 3, + (64, 56, 8): 3, + (64, 64, 8): 2, + (64, 72, 8): 3, + (64, 80, 8): 3, + (64, 88, 8): 3, + (64, 96, 8): 3, + (64, 104, 8): 3, + (64, 112, 8): 3, + (64, 120, 8): 3, + (64, 128, 8): 0, + (64, 136, 8): 3, + (64, 144, 8): 3, + (64, 152, 8): 3, + (64, 160, 8): 3, + (64, 168, 8): 3, + (64, 176, 8): 3, + (64, 184, 8): 3, + (64, 192, 8): 3, + (64, 200, 8): 3, + (64, 208, 8): 3, + (64, 216, 8): 3, + (64, 224, 8): 3, + (64, 232, 8): 3, + (64, 240, 8): 3, + (64, 248, 8): 3, + (64, 256, 8): 1, +} + +SM90_WGMMA_SHAPES_FP8_DENSE = { + (64, 8, 32): 2, + (64, 16, 32): 2, + (64, 24, 32): 3, + (64, 32, 32): 2, + (64, 40, 32): 3, + (64, 48, 32): 3, + (64, 56, 32): 3, + (64, 64, 32): 2, + (64, 72, 32): 3, + (64, 80, 32): 3, + (64, 88, 32): 3, + (64, 96, 32): 3, + (64, 104, 32): 3, + (64, 112, 32): 3, + (64, 120, 32): 3, + (64, 128, 32): 0, + (64, 136, 32): 3, + (64, 144, 32): 3, + (64, 152, 32): 3, + (64, 160, 32): 3, + (64, 168, 32): 3, + (64, 176, 32): 3, + (64, 184, 32): 3, + (64, 192, 32): 3, + (64, 200, 32): 3, + (64, 208, 32): 3, + (64, 216, 32): 3, + (64, 224, 32): 3, + (64, 232, 32): 3, + (64, 240, 32): 3, + (64, 248, 32): 3, + (64, 256, 32): 1, +} + +SM90_WGMMA_SHAPES_INT8_DENSE = { + (64, 8, 32): 2, + (64, 16, 32): 2, + (64, 24, 32): 3, + (64, 32, 32): 2, + (64, 48, 32): 3, + (64, 64, 32): 2, + (64, 80, 32): 3, + (64, 96, 32): 3, + (64, 112, 32): 3, + (64, 128, 32): 0, + (64, 144, 32): 3, + (64, 160, 32): 3, + (64, 176, 32): 3, + (64, 192, 32): 3, + (64, 208, 32): 3, + (64, 224, 32): 3, + (64, 240, 32): 3, + (64, 256, 32): 1, +} diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/sm90_utils.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/sm90_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fc5fdf14abb85835f71ecfd704a2738f5792af50 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/sm90_utils.py @@ -0,0 +1,753 @@ +################################################################################################# +# +# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for enumerating CUTLASS library SM90 kernels +""" + +import argparse +import enum +from itertools import product +import math +import logging +import os.path +import shutil +import sys +import copy +from typing import Any, Optional, Sequence, Tuple, List + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * +except ImportError: + from library import * + +# NOTE: this is a duplicate of CudaToolkitVersionSatisfies in generator.py +def CudaToolkitVersionSatisfies(semantic_ver_string, major, minor, patch = 0): + + # by default, use the latest CUDA Toolkit version + cuda_version = [11, 0, 132] + + # Update cuda_version based on parsed string + if semantic_ver_string != '': + for i, x in enumerate([int(x) for x in semantic_ver_string.split('.')[:3]]): + if i < len(cuda_version): + cuda_version[i] = x + else: + cuda_version.append(x) + return cuda_version >= [major, minor, patch] + +#### Step 0: define levels + +# One integer level controls multiple "generators" and how many +# combinations they generate. That is the "global" level. +# "Generators" are WGMMA shapes, MMA multipliers, cluster sizes, and +# anything that is eventually involved in the Cartesian product +# which yields our kernel configurations. +# For simplicity, each generator defines their own levels, +# starting from 0. As a rule we assume 10 or fewer levels, making +# their level a digit. +# The "global" level simply stacks these digits and represents them +# as a single integer. +# +# For example, level 500 indicates cluster sizes are at level 5, MMA +# multipliers are at level 0, and WGMMA shapes are at level 0 as well. +# +# Here we define the global level to generator level mappings. + + +def get_wgmma_level_from_global_level(global_level: int): + return global_level % 10 + + +def get_mma_level_from_global_level(global_level: int): + return (global_level // 10) % 10 + + +def get_cluster_level_from_global_level(global_level: int): + return (global_level // 100) % 10 + + +def get_pruning_level_from_global_level(global_level: int): + return (global_level // 1000) % 10 + + +#### Step 1: generate MMA instruction shapes based on levels + +try: + from .sm90_shapes import ( + SM90_MMA_MULTIPLIERS, + SM90_CLUSTER_SIZES, + SM90_WGMMA_SHAPES_TF32_DENSE, + SM90_WGMMA_SHAPES_FP16_BF16_DENSE, + SM90_WGMMA_SHAPES_FP8_DENSE, + SM90_WGMMA_SHAPES_INT8_DENSE, + ) +except: + from sm90_shapes import ( + SM90_MMA_MULTIPLIERS, + SM90_CLUSTER_SIZES, + SM90_WGMMA_SHAPES_TF32_DENSE, + SM90_WGMMA_SHAPES_FP16_BF16_DENSE, + SM90_WGMMA_SHAPES_FP8_DENSE, + SM90_WGMMA_SHAPES_INT8_DENSE, + ) + + +def generate_tf32_math_instruction_shapes_sm90(level: int): + assert isinstance(level, int) and level >= 0 + filtered_list_of_wgmma_shapes = [ + wgmma_shape for wgmma_shape, min_level in SM90_WGMMA_SHAPES_TF32_DENSE.items() if level >= min_level + ] + return filtered_list_of_wgmma_shapes + +def generate_fp16_bf16_math_instruction_shapes_sm90(level: int): + assert isinstance(level, int) and level >= 0 + filtered_list_of_wgmma_shapes = [ + wgmma_shape for wgmma_shape, min_level in SM90_WGMMA_SHAPES_FP16_BF16_DENSE.items() if level >= min_level + ] + return filtered_list_of_wgmma_shapes + +def generate_fp8_math_instruction_shapes_sm90(level: int): + assert isinstance(level, int) and level >= 0 + filtered_list_of_wgmma_shapes = [ + wgmma_shape for wgmma_shape, min_level in SM90_WGMMA_SHAPES_FP8_DENSE.items() if level >= min_level + ] + return filtered_list_of_wgmma_shapes + +def generate_int8_math_instruction_shapes_sm90(level: int): + assert isinstance(level, int) and level >= 0 + filtered_list_of_wgmma_shapes = [ + wgmma_shape for wgmma_shape, min_level in SM90_WGMMA_SHAPES_INT8_DENSE.items() if level >= min_level + ] + return filtered_list_of_wgmma_shapes + +def generate_mixed_dtype_math_instructions_shapes_sm90(wgmma_level: int, a_type: DataType, b_type: DataType): + # DataTypeSize are in the unit of bits + a_bytes = DataTypeSize[a_type] // 8 + b_bytes = DataTypeSize[b_type] // 8 + if a_bytes == 4 or b_bytes == 4: + return generate_tf32_math_instruction_shapes_sm90(wgmma_level) + elif a_bytes == 2 or b_bytes == 2: + return generate_fp16_bf16_math_instruction_shapes_sm90(wgmma_level) + else: + return generate_fp8_math_instruction_shapes_sm90(wgmma_level) + +########### + +def generate_tf32_math_instructions_sm90(level: int): + wgmma_level = get_wgmma_level_from_global_level(level) + math_instructions = [] + for math_instruction_shape in generate_tf32_math_instruction_shapes_sm90(wgmma_level): + math_instructions.append( + MathInstruction( + math_instruction_shape, + DataType.tf32, DataType.tf32, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + return math_instructions + +def generate_fp16_bf16_math_instructions_sm90(level: int): + wgmma_level = get_wgmma_level_from_global_level(level) + math_instructions = [] + for math_instruction_shape in generate_fp16_bf16_math_instruction_shapes_sm90(wgmma_level): + math_instructions += [ + MathInstruction( + math_instruction_shape, + DataType.f16, DataType.f16, DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + math_instruction_shape, + DataType.f16, DataType.f16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + math_instruction_shape, + DataType.bf16, DataType.bf16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + ] + return math_instructions + +def generate_fp8_math_instructions_sm90(level: int): + wgmma_level = get_wgmma_level_from_global_level(level) + math_instructions = [] + for math_instruction_shape in generate_fp8_math_instruction_shapes_sm90(wgmma_level): + math_instructions += [ + MathInstruction( + math_instruction_shape, + DataType.e4m3, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + math_instruction_shape, + DataType.e4m3, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + math_instruction_shape, + DataType.e5m2, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + math_instruction_shape, + DataType.e5m2, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + ] + return math_instructions + +def generate_mixed_dtype_math_instructions_sm90(level: int, types_of_a_b_acc: List[Tuple[DataType, DataType, DataType]]): + wgmma_level = get_wgmma_level_from_global_level(level) + math_instructions = [] + for a_type, b_type, acc_type in types_of_a_b_acc: + math_instruction_shapes = generate_mixed_dtype_math_instructions_shapes_sm90(wgmma_level, a_type, b_type) + for math_instruction_shape in math_instruction_shapes: + math_instructions += [ + MathInstruction( + math_instruction_shape, + a_type, b_type, acc_type, + OpcodeClass.TensorOp, + MathOperation.multiply_add + ), + ] + return math_instructions + +def generate_int8_math_instructions_sm90(level: int): + wgmma_level = get_wgmma_level_from_global_level(level) + math_instructions = [] + for math_instruction_shape in generate_int8_math_instruction_shapes_sm90(wgmma_level): + math_instructions += [ + MathInstruction( + math_instruction_shape, + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + math_instruction_shape, + DataType.u8, DataType.u8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + ] + return math_instructions + +def make_sparse_math_instructions(math_instructions): + sparse_instructions = [] + for inst in math_instructions: + if inst.opcode_class == OpcodeClass.TensorOp: + sparse_instructions.append(MathInstruction( + (inst.instruction_shape[0], inst.instruction_shape[1], inst.instruction_shape[2] * 2), + inst.element_a, inst.element_b, inst.element_accumulator, + OpcodeClass.SparseTensorOp, + inst.math_operation),) + return sparse_instructions + + +#### Step 2: generate tile descriptions from math instruction shapes + +def is_tile_desc_valid(tile_description): + if tile_description.minimum_compute_capability != 90 or tile_description.maximum_compute_capability != 90: + return False + + element_a, element_b, element_accum = ( + tile_description.math_instruction.element_a, + tile_description.math_instruction.element_b, + tile_description.math_instruction.element_accumulator + ) + + cluster_size, cta_shape = ( + tile_description.cluster_shape, + tile_description.threadblock_shape, + ) + grid_size = ( + cta_shape[0] * cluster_size[0] + + cta_shape[1] * cluster_size[1] + + cta_shape[2] * cluster_size[2] + ) + num_ctas_in_cluster = cluster_size[0] * cluster_size[1] * cluster_size[2] + cluster_shape = ( + cluster_size[0] * cta_shape[0], + cluster_size[1] * cta_shape[1], + cluster_size[2] * cta_shape[2] + ) + + FP32_TYPES = [DataType.f32, DataType.tf32] + FP16_TYPES = [DataType.f16, DataType.bf16] + is_fp32 = element_a in FP32_TYPES and element_b in FP32_TYPES + is_fp16 = element_a in FP16_TYPES and element_b in FP16_TYPES + + # Maximum number of CTAs per cluster is 8 for Hopper, but up to 16 is + # allowed for non portable clusters. + if num_ctas_in_cluster > 16 or num_ctas_in_cluster < 1: + return False + + if grid_size < 1: + return False + + # SM90 WGMMA shapes are always 64 across M, therefore + # CTA shape across M must always be a multiple of 64. + if cta_shape[0] < 64 or cta_shape[0] % 64 != 0: + return False + + # The minimum WGMMA shape across N is 8, and increments + # vary across different dtypes, but they're never smaller + # than 8. The minimum CTA shape allowed across N though is 16. + if cta_shape[1] < 16 or cta_shape[1] % 8 != 0: + return False + + # SM90 WGMMA shapes across K are always 8 for 32 bit dense + # operations, 16 for 16 bit, and 32 for 8 bit. In any case, + # the CTA shape across K should be a multiple of 8 and at least + # twice the WGMMA shape across K. + if cta_shape[2] < 16 or cta_shape[2] % 8 != 0: + return False + + # Minimum of 2 stages (very rough heuristic that may filter out valid kernel configs) + if (cluster_shape[0] >= 128 or cluster_shape[1] >= 128) and cluster_shape[2] >= 256: + return False + + if is_fp32 and (cluster_shape[0] >= 128 or cluster_shape[1] >= 128) and cluster_shape[2] >= 128: + return False + + if is_fp32 and cluster_shape[0] >= 256 and cluster_shape[1] >= 256 and cluster_shape[2] >= 64: + return False + + if is_fp16 and cluster_shape[0] >= 256 and cluster_shape[1] >= 256 and cluster_shape[2] >= 128: + return False + + # CTA shape upper bound: <256, 256, 256> + if cta_shape[0] > 256 or cta_shape[1] > 256 or cta_shape[2] > 256: + return False + + return True + +def get_mma_multipliers(level: int): + assert isinstance(level, int) and level >= 0 + mma_level = get_mma_level_from_global_level(level) + return [ + mma_mul for mma_mul, mma_min_level in SM90_MMA_MULTIPLIERS.items() if mma_level >= mma_min_level + ] + +def get_cluster_sizes(level: int, is_aligned: bool): + if not is_aligned: + return [(1, 1, 1)] + assert isinstance(level, int) and level >= 0 + cluster_level = get_cluster_level_from_global_level(level) + return [ + cluster_size for cluster_size, cluster_min_level in SM90_CLUSTER_SIZES.items() if cluster_level >= cluster_min_level + ] + +def generate_tile_descriptions_sm90(math_instructions, is_aligned: bool, level: int): + tile_descriptions = set() + mma_multipliers, cluster_sizes = get_mma_multipliers(level), get_cluster_sizes(level, is_aligned) + for math_inst, mma_mul, cluster_size in product(math_instructions, mma_multipliers, cluster_sizes): + + # generator can stamp out duplicate kernels, because it doesn't explicitly set instruction + # shape for SM90 kernels, and the 3.X collective API doesn't directly expose them when using + # the auto kernel schedule. + + math_inst_stub = copy.deepcopy(math_inst) + math_inst_stub.instruction_shape = [0, 0, 0] + + tile_desc = TileDescription( + threadblock_shape=[ + math_inst.instruction_shape[0] * mma_mul[0], + math_inst.instruction_shape[1] * mma_mul[1], + math_inst.instruction_shape[2] * mma_mul[2] + ], + stages=0, + warp_count=[4, 1, 1], + math_instruction=math_inst_stub, + min_compute=90, + max_compute=90, + cluster_shape=cluster_size) + # For sparse kernels K-tile is twice as large (due to 2x MMA-K size) + # Reduce it to same size as dense to afford more smem stages + if math_inst.opcode_class == OpcodeClass.SparseTensorOp: + tile_desc.threadblock_shape[2] = tile_desc.threadblock_shape[2] // 2 + if is_tile_desc_valid(tile_desc): + tile_descriptions.add(tile_desc) + + return tile_descriptions + +#### Step 3: map tile description to valid schedules + +def is_tile_desc_compatible_with_cooperative(tile_description): + # Cooperative kernels require a minimum CTA-M of 128 + return tile_description.threadblock_shape[0] % 128 == 0 + + +def can_tile_desc_use_shmem_in_epilogue(tile_description, data_types): + dtype_a, dtype_b, dtype_c, dtype_d, dtype_acc, dtype_epi = ( + data_types["a_type"], + data_types["b_type"], + data_types["c_type"], + data_types["d_type"], + data_types["acc_type"], + data_types["epi_type"] + ) + mn = tile_description.threadblock_shape[0] * tile_description.threadblock_shape[1] + bitsize_c, bitsize_d = DataTypeSize[dtype_c], DataTypeSize[dtype_d] + + shmem_bits_c, shmem_bits_d = bitsize_c * mn, bitsize_d * mn + shmem_bits_total = shmem_bits_c + shmem_bits_d + # Magic number: 2^20 + # Existing logic suggested that tile shape 256x128 (or 128x256) + # would run out of shmem if D is FP32, and source is needed. + # That would be 256 * 128 * 32 == 2^21 (~262 KB), which is over the limit. + # Hopper's max shmem size is 228 KB, and 2^20 ~= 131 KB. + # Since epilogue can't possibly use ALL of the shmem available + # we can just settle on 2^20 bits (~ 131 KB) being the upper bound + # we would allow for epilogue. + # This can be different for non-persistent kernels where epilogue and + # mainloop shmem is shared. + if shmem_bits_total > 2 ** 20: + return False + + return True + + +def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types, layout, + instantiation_level, enable_fp8_fast_acc=True, gemm_kind=GemmKind.Universal3x): + # Level 0: prune according to existing generator.py behavior + # Level >= 1: no pruning + level = get_pruning_level_from_global_level(instantiation_level) + schedules = [] + stream_k_schedules = [] + + if not is_tile_desc_valid(tile_description): + return schedules, stream_k_schedules + + FP16_TYPES = [DataType.f16, DataType.bf16] + is_fp16 = data_types["a_type"] in FP16_TYPES and data_types["b_type"] in FP16_TYPES + + FP8_TYPES = [DataType.e4m3, DataType.e5m2] + is_fp8 = data_types["a_type"] in FP8_TYPES and data_types["b_type"] in FP8_TYPES + can_do_fp8_fast_accum = is_fp8 and enable_fp8_fast_acc + + FP32_TYPES = [DataType.f32, DataType.tf32] + is_fp32 = data_types["a_type"] in FP32_TYPES and data_types["b_type"] in FP32_TYPES + requires_transposed_epilogue = is_fp32 and layout[0][0] == LayoutType.RowMajor and layout[1][0] == LayoutType.RowMajor + + can_do_cooperative = is_tile_desc_compatible_with_cooperative(tile_description) + can_do_tma_epilogue = is_aligned and not requires_transposed_epilogue and can_tile_desc_use_shmem_in_epilogue(tile_description, data_types) + + default_epilogue = EpilogueScheduleType.NoSmemWarpSpecialized if not requires_transposed_epilogue else EpilogueScheduleType.EpilogueTransposed + auto_epilogue = EpilogueScheduleType.ScheduleAuto if not requires_transposed_epilogue else EpilogueScheduleType.EpilogueTransposed + + cta_m, cta_n, cta_k = ( + tile_description.threadblock_shape[0], + tile_description.threadblock_shape[1], + tile_description.threadblock_shape[2] + ) + c_type = data_types["c_type"] + d_type = data_types["d_type"] + is_void_c = c_type == DataType.void + + # Filter out invalid kernels + is_nt = layout[0][0] == LayoutType.ColumnMajor and layout[1][0] == LayoutType.RowMajor + is_tn = layout[0][0] == LayoutType.RowMajor and layout[1][0] == LayoutType.ColumnMajor + is_nn = layout[0][0] == LayoutType.ColumnMajor and layout[1][0] == LayoutType.ColumnMajor + + # static_assert(size<0>(SmemLayoutB{}) % WarpgroupTileSize == 0, + # "Copy size must evenly divide SMEM tile."); + if is_fp32 and is_nt and (cta_n % cta_k != 0): + return [], [] + + # static_assert(!TransposeB || (cutlass::bits_to_bytes((size<1>(SmemLayoutB{}) * sizeof_bits::value))) == 128, + # "SmemLayoutB K must be 128bytes to be transposed.") + if is_fp32 and is_nt and cta_k != 32: + return [], [] + + # Static assert failure when instantiating SmemLayoutB + if is_fp32 and (is_tn or is_nn) and (cta_n % cta_k != 0): + return [], [] + + grouped = is_grouped(gemm_kind) + if grouped: + # the following cases are unsupported by grouped GEMM + if not is_aligned: + return [], [] + if requires_transposed_epilogue: + return [], [] + + # Early pruning + if level < 1: + # Don't stamp out FP16/BF16 kernels smaller than or equal to 64x128x64 + if is_fp16 and cta_m <= 64 and cta_n <= 128 and cta_k <= 64: + return [], [] + + # FP8 configs with CTA tile larger than or equal to 256x128x128 limit data types and schedules + is_large_fp8_tile = is_fp8 and cta_m >= 256 and cta_n >= 128 and cta_k >= 128 + if is_large_fp8_tile: + # Only void-C, and only FP8 outputs allowed + if not is_void_c or d_type not in FP8_TYPES: + return [], [] + if CudaToolkitVersionSatisfies(cuda_version, 12, 1) and can_do_cooperative and can_do_tma_epilogue: + schedules = [] + if is_blockwise(gemm_kind): + schedules.append( + [ + to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, grouped), + to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped) + ]) + else: + schedules.append( + [ + to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped), + to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped) + ]) + schedules.append( + [ + to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped), + to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped) + ]) + return schedules, [] + return [], [] + + if is_fp8 and not is_large_fp8_tile: + valid_dtypes_for_c = [DataType.f32, DataType.bf16, DataType.f16, DataType.void] + # Prune all configs with fp8 source, and all configs with non-fp8 output + # that have different dtypes for source and output. + if c_type not in valid_dtypes_for_c or (d_type not in FP8_TYPES and c_type != d_type): + return [], [] + + # FP32/TF32 kernels don't stamp out void-C + if is_fp32 and is_void_c: + return [], [] + + # Void-c only makes a difference for TMA epilogues + if is_void_c and not can_do_tma_epilogue: + return [], [] + + # For mixed input data types + a_type_size = DataTypeSize[data_types["a_type"]] + b_type_size = DataTypeSize[data_types["b_type"]] + if a_type_size != b_type_size and CudaToolkitVersionSatisfies(cuda_version, 12, 1): + schedules = [] + stream_k_schedules = [] + epilogue_schedule = EpilogueScheduleType.TmaWarpSpecialized + if a_type_size > b_type_size: + epilogue_schedule = EpilogueScheduleType.EpilogueTransposed + + if not is_blockwise(gemm_kind): + schedules.append([ + KernelScheduleType.TmaWarpSpecialized, + epilogue_schedule + ]) + schedules.append([ + KernelScheduleType.TmaWarpSpecializedPingpong, + epilogue_schedule + ]) + if cta_m >= 128: + if a_type_size > b_type_size: + epilogue_schedule = EpilogueScheduleType.EpilogueTransposed + else: + epilogue_schedule = EpilogueScheduleType.TmaWarpSpecializedCooperative + if is_blockwise(gemm_kind): + schedules.append([ + KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, + epilogue_schedule + ]) + else: + schedules.append([ + KernelScheduleType.TmaWarpSpecializedCooperative, + epilogue_schedule + ]) + stream_k_schedules.append([ + KernelScheduleType.TmaWarpSpecializedCooperative, + epilogue_schedule + ]) + return schedules, stream_k_schedules + + if not is_aligned and not is_blockwise(gemm_kind): + schedules = [[KernelScheduleType.CpAsyncWarpSpecialized, + default_epilogue]] + stream_k_schedules = [] + + if CudaToolkitVersionSatisfies(cuda_version, 12, 1) and can_do_cooperative: + schedules.append([ + KernelScheduleType.CpAsyncWarpSpecializedCooperative, + default_epilogue + ]) + stream_k_schedules.append([ + KernelScheduleType.CpAsyncWarpSpecializedCooperative, + default_epilogue + ]) + + return schedules, stream_k_schedules + + schedules = [] + # Pruning: emit Void-C and Grouped kernels with persistent kernels only + if (level >= 1 or not is_void_c) and not grouped and not is_blockwise(gemm_kind): + # Pruning: don't stamp out fp8 kernels with auto schedule + if not is_fp8: + schedules.append([KernelScheduleType.ScheduleAuto, auto_epilogue]) + schedules.append([KernelScheduleType.TmaWarpSpecialized, default_epilogue]) + stream_k_schedules = [] + + if CudaToolkitVersionSatisfies(cuda_version, 12, 0): + if can_do_tma_epilogue: + assert not requires_transposed_epilogue + # Inconsistency: fp8 pingpong only gets stamped out with fast accum + if (not is_fp8 or level >= 1) and not is_blockwise(gemm_kind): + schedules.append([ + to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpong, grouped), + to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized, grouped) + ]) + if can_do_fp8_fast_accum: + schedules.append([ + to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, grouped), + to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized, grouped) + ]) + + if CudaToolkitVersionSatisfies(cuda_version, 12, 1): + # Pruning: don't stamp out fp8 ping-pong kernel with non-tma epilogue + if not is_fp8 or level >= 1: + if not is_blockwise(gemm_kind): + schedules.append([to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpong, grouped), to_grouped_schedule(default_epilogue, grouped)]) + else: + schedules.append([to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecializedPingpong, grouped), to_grouped_schedule(default_epilogue, grouped)]) + + if can_do_fp8_fast_accum: + if not grouped: + schedules.append([KernelScheduleType.TmaWarpSpecializedFP8FastAccum, default_epilogue]) + schedules.append([to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, grouped), to_grouped_schedule(default_epilogue, grouped)]) + + if can_do_cooperative: + if is_blockwise(gemm_kind): + schedules.append([ + to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, grouped), + to_grouped_schedule(default_epilogue, grouped) + ]) + stream_k_schedules.append([ + KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, + default_epilogue + ]) + else: + schedules.append([ + to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped), + to_grouped_schedule(default_epilogue, grouped) + ]) + stream_k_schedules.append([ + KernelScheduleType.TmaWarpSpecializedCooperative, + default_epilogue + ]) + if can_do_fp8_fast_accum: + schedules.append([ + to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped), + to_grouped_schedule(default_epilogue, grouped) + ]) + stream_k_schedules.append([ + KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, + default_epilogue + ]) + + # persistent kernels with TMA epilogues + if can_do_tma_epilogue: + assert not requires_transposed_epilogue + if can_do_cooperative: + if is_blockwise(gemm_kind): + schedules.append([ + to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, grouped), + to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped) + ]) + stream_k_schedules.append([ + KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, + EpilogueScheduleType.TmaWarpSpecializedCooperative + ]) + else: + schedules.append([ + to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped), + to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped) + ]) + stream_k_schedules.append([ + KernelScheduleType.TmaWarpSpecializedCooperative, + EpilogueScheduleType.TmaWarpSpecializedCooperative + ]) + if can_do_fp8_fast_accum: + schedules.append([ + to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped), + to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped) + ]) + stream_k_schedules.append([ + KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, + EpilogueScheduleType.TmaWarpSpecializedCooperative + ]) + # Grouped GEMM do not support Stream-K scheduler + if grouped: + return schedules, [] + return schedules, stream_k_schedules + + +#### Misc: helpers + +def generate_data_types_from_math_instruction(math_instruction, element_source = None, element_dest = None, element_epilogue = None): + element_a, element_b = math_instruction.element_a, math_instruction.element_b + element_accumulator = math_instruction.element_accumulator + element_c = element_source or element_accumulator + element_d = element_dest or element_accumulator + element_epilogue = element_epilogue or element_accumulator + data_types = { + "a_type" : element_a, + "b_type" : element_b, + "c_type" : element_c, + "d_type" : element_d, + "acc_type" : element_accumulator, + "epi_type" : element_epilogue + } + return data_types + +def fix_alignments(data_types, layout, alignment_bits = 128): + operand_keys = ["a_type", "b_type", "c_type"] + operands_to_fix = ["c_type"] + new_layout = [] + assert len(layout) == len(operand_keys) + for i, k in enumerate(operand_keys): + assert k in data_types and data_types[k] in DataTypeSize + dtype = data_types[k] + dtype_size_bits = DataTypeSize[dtype] + + layout_type = layout[i][0] + layout_alignment = layout[i][1] + + # Don't modify alignment if dtype's been changed to void + if k in operands_to_fix and dtype_size_bits >= 1: + layout_alignment = alignment_bits // dtype_size_bits + + new_layout.append([layout_type, layout_alignment]) + + return new_layout diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/symm_operation.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/symm_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..8661ff798b2e3e0987fdf7e050b6ad2e0f8f3678 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/symm_operation.py @@ -0,0 +1,440 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for emitting Symm kernels +""" + +import enum +import functools +import operator +import os.path +import shutil + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * +except ImportError: + from library import * + + +################################################################################################### +# +# Data structure modeling a Symm update operation +# +################################################################################################### + +# +class SymmOperation: + # + def __init__(self, symm_kind, arch, tile_description, A, B, C, element_epilogue, \ + epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, \ + blas_mode = BlasMode.symmetric): + + self.blas_mode = blas_mode + self.operation_kind = OperationKind.Symm + self.arch = arch + self.tile_description = tile_description + self.symm_kind = symm_kind + # tensor A and B have same data type and layout + self.A = A + self.B = B + self.C = C + self.element_epilogue = element_epilogue + self.epilogue_functor = epilogue_functor + self.swizzling_functor = swizzling_functor + + # + def is_complex(self): + complex_operators = [ + MathOperation.multiply_add_complex, + MathOperation.multiply_add_complex_gaussian, + MathOperation.multiply_add_complex_fast_f32 + ] + return self.tile_description.math_instruction.math_operation in complex_operators + return False + + # + def is_mixed_input(self): + return self.A.element != self.B.element + + # + def is_planar_complex(self): + return False + + # + def accumulator_type(self): + accum = self.tile_description.math_instruction.element_accumulator + + if self.is_complex(): + return get_complex_from_real(accum) + + return accum + + # + def short_math_name(self): + if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian: + return "g%s" % ShortDataTypeNames[self.accumulator_type()] + return ShortDataTypeNames[self.accumulator_type()] + + + # + def core_name(self): + ''' The basic operation kind is prefixed with a letter indicating the accumulation type. ''' + + inst_shape = '' + inst_operation = '' + intermediate_type = '' + + math_operations_map = { + MathOperation.xor_popc: 'xor', + MathOperation.and_popc: 'and' + } + + if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \ + self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp: + + math_op = self.tile_description.math_instruction.math_operation + math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else '' + + inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape) + inst_shape += math_op_string + + if self.tile_description.math_instruction.element_a != self.A.element and \ + self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator: + intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] + + operation_name = 'symm' if self.blas_mode == BlasMode.symmetric else 'hemm' + + return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, operation_name) + + # + def extended_name(self): + ''' Append data types if they differ from compute type. ''' + if self.is_complex(): + extended_name = "${core_name}" + else: + if self.C.element != self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${element_c}_${core_name}_${element_a}" + elif self.C.element == self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${core_name}_${element_a}" + else: + extended_name = "${core_name}" + + extended_name = SubstituteTemplate(extended_name, { + 'element_a': DataTypeNames[self.A.element], + 'element_c': DataTypeNames[self.C.element], + 'core_name': self.core_name() + }) + + return extended_name + + # + def layout_name(self): + if self.is_complex() or self.is_planar_complex(): + return "%s" % ( + ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)] + ) + return "%s" % (ShortLayoutTypeNames[self.A.layout]) + + # + def side_mode_name(self): + return "%s" % (ShortSideModeNames[self.A.side_mode]) + + # + def fill_mode_name(self): + return "%s" % (ShortFillModeNames[self.A.fill_mode]) + + # + def procedural_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + threadblock = self.tile_description.procedural_name() + + opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] + + alignment = self.C.alignment + + return SubstituteTemplate( + "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${side_mode}_${fill_mode}_align${alignment}", + { + 'opcode_class': opcode_class_name, + 'extended_name': self.extended_name(), + 'threadblock': threadblock, + 'layout': self.layout_name(), + 'side_mode': self.side_mode_name(), + 'fill_mode': self.fill_mode_name(), + 'alignment': "%d" % alignment, + } + ) + + # + def configuration_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + return self.procedural_name() + +################################################################################################### +# +# Emits single instances of a CUTLASS device-wide operator +# +################################################################################################### + +# +class EmitSymmUniversalInstance: + ''' Responsible for emitting a CUTLASS template definition''' + + def __init__(self): + self.symm_template = """ +// Symm operator ${operation_name} +using Operation_${operation_name} = + typename cutlass::gemm::device::Symm< + ${element_a}, ${layout_a}, ${side_mode}, ${fill_mode}, + ${element_b}, ${layout_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, + ${stages}, + ${align_a}, + ${align_b}, + ${split_k_serial}, + ${math_operation} +>; +""" + self.symm_complex_template = """ +// Symm operator ${operation_name} +using Operation_${operation_name} = + typename cutlass::gemm::device::Symm< + ${element_a}, ${layout_a}, ${side_mode}, ${fill_mode}, + ${element_b}, ${layout_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, + ${stages}, + ${align_a}, + ${align_b}, + ${split_k_serial}, + ${math_operation}, + ${blas_mode} +>; +""" + + def emit(self, operation): + + threadblock_shape = operation.tile_description.threadblock_shape + + warp_count = operation.tile_description.warp_count + warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)] + + epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) + + values = { + 'operation_name': operation.procedural_name(), + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[operation.A.layout], + 'side_mode': SideModeTag[operation.A.side_mode], + 'fill_mode': FillModeTag[operation.A.fill_mode], + 'element_b': DataTypeTag[operation.B.element], + 'layout_b': LayoutTag[operation.B.layout], + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'element_accumulator': DataTypeTag[operation.accumulator_type()], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'epilogue_vector_length': str(epilogue_vector_length), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], + 'stages': str(operation.tile_description.stages), + 'align_a': str(operation.A.alignment), + 'align_b': str(operation.B.alignment), + 'split_k_serial': 'false', + 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], + 'blas_mode': BlasModeTag[operation.blas_mode] + } + + symm_template = self.symm_complex_template if operation.is_complex() else self.symm_template + + return SubstituteTemplate(symm_template, values) + +################################################################################################### + + +################################################################################################### +# +# Emitters functions for all targets +# +################################################################################################### + +class EmitSymmConfigurationLibrary: + def __init__(self, operation_path, configuration_name): + self.configuration_name = configuration_name + self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/') + + self.instance_emitter = { + SymmKind.Universal: EmitSymmUniversalInstance, + } + + self.symm_kind_wrappers = { + SymmKind.Universal: 'SymmOperation', + } + + self.instance_template = { + SymmKind.Universal: """ +${compile_guard_start} + manifest.append(new ${symm_kind}< + Operation_${operation_name} + >("${operation_name}")); +${compile_guard_end} +""" + } + + self.header_template = """ +/* + Generated by symm_operation.py - Do not edit. +*/ + +/////////////////////////////////////////////////////////////////////////////////////////////////// +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "library_internal.h" +#include "symm_operation.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + self.initialize_function_template = """ + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +void initialize_${configuration_name}(Manifest &manifest) { + +""" + self.epilogue_template = """ + +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + def __enter__(self): + self.configuration_file = open(self.configuration_path, "w") + self.configuration_file.write(self.header_template) + + self.instance_definitions = [] + self.instance_wrappers = [] + + self.operations = [] + return self + + def emit(self, operation): + emitter = self.instance_emitter[operation.symm_kind]() + + self.operations.append(operation) + + self.instance_definitions.append(emitter.emit(operation)) + + self.instance_wrappers.append(SubstituteTemplate(self.instance_template[operation.symm_kind], { + 'configuration_name': self.configuration_name, + 'operation_name': operation.procedural_name(), + 'symm_kind': self.symm_kind_wrappers[operation.symm_kind], + 'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \ + if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "", + 'compile_guard_end': "#endif" \ + if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "" + })) + + def __exit__(self, exception_type, exception_value, traceback): + + # Write instance definitions in top-level namespace + for instance_definition in self.instance_definitions: + self.configuration_file.write(instance_definition) + + # Add wrapper objects within initialize() function + self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, { + 'configuration_name': self.configuration_name + })) + + for instance_wrapper in self.instance_wrappers: + self.configuration_file.write(instance_wrapper) + + self.configuration_file.write(self.epilogue_template) + self.configuration_file.close() + +################################################################################################### diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/trmm_operation.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/trmm_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..46ba360cb615c955d329b390c0ab93d13ed88c7c --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/cutlass_library/trmm_operation.py @@ -0,0 +1,447 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for emitting Trmm kernels +""" + +import enum +import functools +import operator +import os.path +import shutil + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * +except ImportError: + from library import * + + +################################################################################################### +# +# Data structure modeling a TRMM operation +# +################################################################################################### + +# +class TrmmOperation: + # + def __init__(self, trmm_kind, arch, tile_description, A, B, C, element_epilogue, \ + epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8): + + self.operation_kind = OperationKind.Trmm + self.arch = arch + self.tile_description = tile_description + self.trmm_kind = trmm_kind + self.A = A + self.B = B + self.C = C + self.element_epilogue = element_epilogue + self.epilogue_functor = epilogue_functor + self.swizzling_functor = swizzling_functor + + # + def is_complex(self): + complex_operators = [ + MathOperation.multiply_add_complex, + MathOperation.multiply_add_complex_gaussian, + MathOperation.multiply_add_complex_fast_f32 + ] + return self.tile_description.math_instruction.math_operation in complex_operators + return False + + # + def is_planar_complex(self): +# return self.trmm_kind in (TrmmKind.PlanarComplex, TrmmKind.PlanarComplexArray) + return False + + # + def is_mixed_input(self): + return self.A.element != self.B.element + + # + def accumulator_type(self): + accum = self.tile_description.math_instruction.element_accumulator + + if self.is_complex(): + return get_complex_from_real(accum) + + return accum + + # + def short_math_name(self): + if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian: + return "g%s" % ShortDataTypeNames[self.accumulator_type()] + return ShortDataTypeNames[self.accumulator_type()] + + + # + def core_name(self): + ''' The basic operation kind is prefixed with a letter indicating the accumulation type. ''' + + inst_shape = '' + inst_operation = '' + intermediate_type = '' + + math_operations_map = { + MathOperation.xor_popc: 'xor', + MathOperation.and_popc: 'and' + } + + if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \ + self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp: + + math_op = self.tile_description.math_instruction.math_operation + math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else '' + + inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape) + inst_shape += math_op_string + + if self.tile_description.math_instruction.element_a != self.A.element and \ + self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator: + intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] + + return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, TrmmKindNames[self.trmm_kind]) + + # + def extended_name(self): + ''' Append data types if they differ from compute type. ''' + if self.is_complex(): + extended_name = "${core_name}" + else: + if self.C.element != self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${element_c}_${core_name}_${element_a}" + elif self.C.element == self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${core_name}_${element_a}" + else: + extended_name = "${core_name}" + + extended_name = SubstituteTemplate(extended_name, { + 'element_a': DataTypeNames[self.A.element], + 'element_c': DataTypeNames[self.C.element], + 'core_name': self.core_name() + }) + + return extended_name + + # + def layout_name(self): + if self.is_complex() or self.is_planar_complex(): + return "%s%s" % ( + ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)], + ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)] + ) + return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout]) + + # + def side_mode_name(self): + return "%s" % (ShortSideModeNames[self.A.side_mode]) + + # + def fill_mode_name(self): + return "%s" % (ShortFillModeNames[self.A.fill_mode]) + + # + def diag_type_name(self): + return "%s" % (ShortDiagTypeNames[self.A.diag_type]) + + # + def procedural_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + threadblock = self.tile_description.procedural_name() + + opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] + + alignment = max([self.C.alignment]) + + return SubstituteTemplate( + "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${side_mode}_${fill_mode}_${diag_type}_align${alignment}", + { + 'opcode_class': opcode_class_name, + 'extended_name': self.extended_name(), + 'threadblock': threadblock, + 'layout': self.layout_name(), + 'side_mode': self.side_mode_name(), + 'fill_mode': self.fill_mode_name(), + 'diag_type': self.diag_type_name(), + 'alignment': "%d" % self.C.alignment, + } + ) + + # + def configuration_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + return self.procedural_name() + +################################################################################################### +# +# Emits single instances of a CUTLASS device-wide operator +# +################################################################################################### + +# +class EmitTrmmUniversalInstance: + ''' Responsible for emitting a CUTLASS template definition''' + + def __init__(self): + self.trmm_template = """ +// Trmm operator ${operation_name} +using Operation_${operation_name} = + typename cutlass::gemm::device::Trmm< + ${element_a}, ${layout_a}, + ${side_mode}, ${fill_mode}, ${diag_type}, + ${element_b}, ${layout_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue}, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling + >, + ${swizzling_functor}, + ${stages}, + ${align_a}, + ${align_b}, + ${split_k_serial}, + ${math_operation} +>; +""" + self.trmm_complex_template = """ +// Trmm operator ${operation_name} +using Operation_${operation_name} = + typename cutlass::gemm::device::Trmm< + ${element_a}, ${layout_a}, + ${side_mode}, ${fill_mode}, ${diag_type}, + ${element_b}, ${layout_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue}, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling + >, + ${swizzling_functor}, + ${stages}, + ${align_a}, + ${align_b}, + ${split_k_serial}, + ${math_operation}, + ${transform_a} +>; +""" + + def emit(self, operation): + + threadblock_shape = operation.tile_description.threadblock_shape + warp_count = operation.tile_description.warp_count + + warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)] + + epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) + + values = { + 'operation_name': operation.procedural_name(), + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[operation.A.layout], + 'side_mode' : SideModeTag[operation.A.side_mode], + 'fill_mode': FillModeTag[operation.A.fill_mode], + 'diag_type' : DiagTypeTag[operation.A.diag_type], + 'element_b': DataTypeTag[operation.B.element], + 'layout_b': LayoutTag[operation.B.layout], + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'element_accumulator': DataTypeTag[operation.accumulator_type()], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'epilogue_vector_length': str(epilogue_vector_length), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], + 'stages': str(operation.tile_description.stages), + 'align_a': str(1), # TRMM A's alignment is always 1 for no padding to work until we make zfill work with variable bytes + 'align_b': str(operation.B.alignment), + 'split_k_serial': 'false', + 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], + 'transform_a': ComplexTransformTag[operation.A.complex_transform] + } + + trmm_template = self.trmm_complex_template if operation.is_complex() else self.trmm_template + + return SubstituteTemplate(trmm_template, values) + +################################################################################################### + + +################################################################################################### +# +# Emitters functions for all targets +# +################################################################################################### + +class EmitTrmmConfigurationLibrary: + def __init__(self, operation_path, configuration_name): + self.configuration_name = configuration_name + self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/') + + self.instance_emitter = { + TrmmKind.Universal: EmitTrmmUniversalInstance, + } + + self.trmm_kind_wrappers = { + TrmmKind.Universal: 'TrmmOperation', + } + + self.instance_template = { + TrmmKind.Universal: """ +${compile_guard_start} + manifest.append(new ${trmm_kind}< + Operation_${operation_name} + >("${operation_name}")); +${compile_guard_end} +""" + } + + self.header_template = """ +/* + Generated by trmm_operation.py - Do not edit. +*/ + +/////////////////////////////////////////////////////////////////////////////////////////////////// +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "library_internal.h" +#include "trmm_operation.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + self.initialize_function_template = """ + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +void initialize_${configuration_name}(Manifest &manifest) { + +""" + self.epilogue_template = """ + +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + def __enter__(self): + self.configuration_file = open(self.configuration_path, "w") + self.configuration_file.write(self.header_template) + + self.instance_definitions = [] + self.instance_wrappers = [] + + self.operations = [] + return self + + def emit(self, operation): + emitter = self.instance_emitter[operation.trmm_kind]() + + self.operations.append(operation) + + self.instance_definitions.append(emitter.emit(operation)) + + self.instance_wrappers.append(SubstituteTemplate(self.instance_template[operation.trmm_kind], { + 'configuration_name': self.configuration_name, + 'operation_name': operation.procedural_name(), + 'trmm_kind': self.trmm_kind_wrappers[operation.trmm_kind], + 'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \ + if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "", + 'compile_guard_end': "#endif" \ + if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "" + })) + + def __exit__(self, exception_type, exception_value, traceback): + + # Write instance definitions in top-level namespace + for instance_definition in self.instance_definitions: + self.configuration_file.write(instance_definition) + + # Add wrapper objects within initialize() function + self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, { + 'configuration_name': self.configuration_name + })) + + for instance_wrapper in self.instance_wrappers: + self.configuration_file.write(instance_wrapper) + + self.configuration_file.write(self.epilogue_template) + self.configuration_file.close() + +################################################################################################### diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/docs_src/source/conf.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/docs_src/source/conf.py new file mode 100644 index 0000000000000000000000000000000000000000..c396d75a5534493f1ebf90043f2a182eb46abb7f --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/docs_src/source/conf.py @@ -0,0 +1,132 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +# Configuration file for the Sphinx documentation builder. +# +# For the full list of built-in configuration values, see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +import os +import sys + +sys.path.insert(0, os.path.abspath('..')) +sys.path.insert(0, os.path.abspath('../..')) +sys.path.insert(0, os.path.abspath('../../media/docs')) + +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information + +project = 'CUTLASS Python interface' +copyright = '2023, NVIDIA' +author = 'NVIDIA' +release = '3.1.0' + +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration + + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + 'myst_parser', + 'nbsphinx', + 'nbsphinx_link', + 'sphinx_copybutton', + 'sphinx.ext.autodoc', + 'sphinx.ext.autosectionlabel', + 'sphinx.ext.autosummary', + 'sphinx.ext.coverage', + 'sphinx.ext.extlinks', + 'sphinx.ext.ifconfig', + 'sphinx.ext.intersphinx', + 'sphinx.ext.mathjax', + 'sphinx.ext.napoleon', + 'sphinx.ext.viewcode', + 'sphinx_inline_tabs', + ] + +source_suffix = { + '.rst': 'restructuredtext', + '.md': 'markdown', +} + +autodoc_typehints = 'description' + +pygments_style = "sphinx" +pygments_dark_style = "monokai" + +templates_path = ['_templates'] +exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] + +# Ignore errors when converting notebooks +nbsphinx_allow_errors = True + +language = 'en' +# -- Options for HTML output ------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output + +html_static_path = ['_static'] + +html_title = "CUTLASS Python" +html_baseurl = 'docs' +html_theme = 'furo' +html_theme_options = { + "light_logo": "cutlass-logo-small.png", + "dark_logo": "cutlass-logo-small.png", + "light_css_variables": { + "color-brand-primary": "#76B900", + "color-brand-content": "#76B900", + }, + "dark_css_variables": { + "color-brand-primary": "#76B900", + "color-brand-content": "#76B900", + }, + "footer_icons": [ + { + "name": "GitHub", + "url": "https://github.com/NVIDIA/cutlass", + "html": """ + + + + """, + "class": "", + }, + ], +} diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/pycute/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/pycute/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..308a5676b06f00089d1cdfe0fb83b442ca2df36e --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/pycute/__init__.py @@ -0,0 +1,36 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from .int_tuple import * +from .layout import * +from .swizzle import * +from .typing import * diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/pycute/int_tuple.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/pycute/int_tuple.py new file mode 100644 index 0000000000000000000000000000000000000000..3d722130c52142e68a3bcd54ac708012aeeeaad3 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/pycute/int_tuple.py @@ -0,0 +1,225 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Functions for manipulating IntTuples +""" + +from functools import reduce +from itertools import chain +from typing import Union +from .typing import Integer + + +def is_int(x): + return isinstance(x, Integer) + + +def is_tuple(x): + return isinstance(x, tuple) + + +def flatten(t): + if is_tuple(t): + if len(t) == 0: + return () + else: + return tuple(i for a in t for i in flatten(a)) + else: + return (t,) + + +def signum(a): + return bool(a > 0) - bool(a < 0) + + +def product(a): + if is_tuple(a): + return reduce(lambda val,elem : val*product(elem), a, 1) + else: + return a + + +def inner_product(a, b): + if is_tuple(a): # tuple tuple + assert len(a) == len(b) + return sum(inner_product(x,y) for x,y in zip(a,b)) + else: # "int" "int" + assert not is_tuple(b) + return a * b + + +def tuple_max(a): + if is_tuple(a): + return max(tuple_max(x) for x in a) + else: + return a + + +def elem_scale(a, b): + if is_tuple(a): + if is_tuple(b): # tuple tuple + assert len(a) == len(b) + return tuple(elem_scale(x,y) for x,y in zip(a,b)) + else: # tuple "int" + assert False # Error + else: + if is_tuple(b): # "int" tuple + return elem_scale(a, product(b)) + else: # "int" "int" + return a * b + + +# Inclusive prefix ceil div with output congruent to input a +def shape_div(a, b): + if is_tuple(a): + if is_tuple(b): # tuple tuple + assert len(a) == len(b) + return tuple(shape_div(x,y) for x,y in zip(a,b)) + else: # tuple "int" + #r = [shape_div(a[0],b)] + [shape_div(a[i],b := shape_div(b, product(a[i-1]))) for i in range(1,len(a))] + r = [] + for v in a: + r.append(shape_div(v,b)) + b = shape_div(b,product(v)) + return tuple(r) + else: + if is_tuple(b): # "int" tuple + return shape_div(a, product(b)) + else: # "int" "int" + assert a % b == 0 or b % a == 0 + return (a + b - 1) // b + +# Exclusive prefix product with output congruent to input a +def prefix_product(a, init=1): + if is_tuple(a): + if is_tuple(init): # tuple tuple + assert len(a) == len(init) + return tuple(prefix_product(x,i) for x,i in zip(a,init)) + else: # tuple "int" + #r = [prefix_product(a[0],init)] + [prefix_product(a[i],init := init * product(a[i-1])) for i in range(1,len(a))] + r = [] + for v in a: + r.append(prefix_product(v,init)) + init = init * product(v) + return tuple(r) + else: + if is_tuple(init): # "int" tuple + assert False # Error + else: # "int" "int" + return init + + +def idx2crd(idx, shape, stride=None): + if stride is None: + stride = prefix_product(shape) + + if is_tuple(idx): + if is_tuple(shape): # tuple tuple tuple + assert len(idx) == len(shape) and len(idx) == len(stride) + return tuple(idx2crd(i, s, d) for i, s, d in zip(idx,shape,stride)) + else: # tuple "int" "int" + assert False # Error + else: + if is_tuple(shape): # "int" tuple tuple + assert len(shape) == len(stride) + return tuple(idx2crd(idx, s, d) for s,d in zip(shape,stride)) + else: # "int" "int" "int" + return (idx // stride) % shape + + +def crd2idx(crd, shape, stride=None): + if stride is None: + stride = prefix_product(shape) + + if is_tuple(crd): + if is_tuple(shape): # tuple tuple tuple + assert len(crd) == len(shape) and len(crd) == len(stride) + return sum(crd2idx(c, s, d) for c, s, d in zip(crd, shape, stride)) + else: # tuple "int" "int" + assert False, f"crd={crd}, shape={shape}" # Error + else: + if crd is None: + crd = 0 + + if is_tuple(shape): # "int" tuple tuple + assert len(shape) == len(stride) + result = 0 + for i in range(len(shape)-1): + result += crd2idx(crd % product(shape[i]), shape[i], stride[i]) + crd = crd // product(shape[i]) + return result + crd2idx(crd, shape[-1], stride[-1]) + else: # "int" "int" "int" + return crd * stride + + +# Transform crd into the dst_shape's iteration space +def crd2crd(crd, dst_shape, src_shape=None): + if is_tuple(crd): + if is_tuple(dst_shape): # tuple tuple + assert len(crd) == len(dst_shape) + return tuple(crd2crd(x, y) for x, y in zip(crd,dst_shape)) + else: # tuple "int" + # Ambiguous unless we have src_shape + assert src_shape is not None + return crd2idx(crd, src_shape) + else: + if is_tuple(dst_shape): # "int" tuple + return idx2crd(crd, dst_shape) + else: # "int" "int" + assert crd < dst_shape + return crd + + +# Filter trg according to crd: keep only elements of trg that are paired with None +def slice_(crd: Union[None, tuple, int], + trg: Union[tuple, int]): + if is_tuple(crd): + if is_tuple(trg): # tuple tuple + assert len(crd) == len(trg) + # match C++ behavior of `filter_tuple` using `tuple_cat(...)` + return tuple(chain(*filter(lambda x: x != (), [slice_(c, s) for c, s in zip(crd, trg)]))) + else: + assert False # tuple "int" : Error + elif crd is None: + # match C++ behavior `return cute::tuple{b};` + return (trg,) + else: + return () + + +# Determine if None appears at any of an int_tuples' terminals +def has_none(a: Union[None, tuple, int]): + if is_tuple(a): + return any(has_none(v) for v in a) + else: + return a is None diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/pycute/layout.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/pycute/layout.py new file mode 100644 index 0000000000000000000000000000000000000000..7c220eb16dd089c65fdbe6d6929b357ace0a77c1 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/pycute/layout.py @@ -0,0 +1,367 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Definition of CuTe Layouts and functions to manipulate them +""" + +from itertools import chain +from typing import Union + +from .int_tuple import * + + +class LayoutBase: + pass + + +def is_layout(x): + return isinstance(x, LayoutBase) + + +class Layout(LayoutBase): + def __init__(self, _shape, _stride=None): + self.shape = _shape + if _stride is None: + self.stride = prefix_product(self.shape) + else: + self.stride = _stride + + # operator == + def __eq__(self, other): + return self.shape == other.shape and self.stride == other.stride + + # operator len(L) (len [rank] like tuples) + def __len__(self): + if is_tuple(self.shape): + return len(self.shape) + else: + return 1 + + # operator () (map coord to idx) + def __call__(self, *args): + """ + Map a logical coordinate to a linear index (Coord has no Underscore slice operators) + OR + Slice the layout and return the sublayout (Coord has an Underscore slice op) + + Follow the same behavior of `Layout::operator(Coord const&)` in cute C++ + """ + if has_none(args): + if len(args) == 1: + return Layout(slice_(args[0], self.shape), slice_(args[0], self.stride)) + else: + return Layout(slice_(args, self.shape), slice_(args, self.stride)) + else: + if len(args) == 1: + return crd2idx(args[0], self.shape, self.stride) + else: + return crd2idx(args, self.shape, self.stride) + + # operator [] (get-i like tuples) + def __getitem__(self, i): + if is_tuple(self.shape): + return Layout(self.shape[i], self.stride[i]) + else: + assert i == 0 + return Layout(self.shape, self.stride) + + # size(layout) Size of the domain + def size(self): + return product(self.shape) + + # cosize(layout) Size of the codomain + def cosize(self): + return self(self.size() - 1) + 1 + + # print and str + def __str__(self): + return f"{self.shape}:{self.stride}" + + # error msgs and representation + def __repr__(self): + return f"Layout({self.shape},{self.stride})" + + +# Make Layout from a list of layouts (each layout it's own mode in the result) +def make_layout(*layouts): + if len(layouts) == 1 and not is_layout(layouts[0]): + layouts = layouts[0] + + shape, stride = zip(*((a.shape,a.stride) for a in layouts)) + return Layout(shape, stride) + + +# Size of the domain +def size(layout): + if is_layout(layout): + return layout.size() + return product(layout) + + +# Size of the codomain +def cosize(layout): + return layout.cosize() + + +# Layout coalesce -- flatten and combine as many modes as possible while preserving the int-to-int function +def coalesce(layout, profile=None): + if is_tuple(profile): + assert len(layout) >= len(profile) + return make_layout(chain((coalesce(layout[i], profile[i]) for i in range( 0,len(profile))), + (layout[i] for i in range(len(profile),len(layout))))) + + result_shape = [1] + result_stride = [0] + for (shape,stride) in zip(flatten(layout.shape),flatten(layout.stride)): + # skip their shape-1s + if shape == 1: + continue + # replace our shape-1 with anything + elif result_shape[-1] == 1: + result_shape[-1] = shape + result_stride[-1] = stride + # merge modes if the shape*stride match + elif result_shape[-1] * result_stride[-1] == stride: + result_shape[-1] = result_shape[-1] * shape + # append a new mode + else: + result_shape.append(shape) + result_stride.append(stride) + + if len(result_shape) == 1: + return Layout(result_shape[0], result_stride[0]) + else: + return Layout(tuple(result_shape), tuple(result_stride)) + + +# Layout filter -- replace all stride-0 modes with size-1 and then coalesce to remove them +def filter(layout, profile=None): + if is_tuple(profile): + assert len(layout) >= len(profile) + return make_layout(chain((filter(layout[i], profile[i]) for i in range( 0,len(profile))), + (layout[i] for i in range(len(profile),len(layout))))) + + result_shape = [] + result_stride = [] + for (shape,stride) in zip(flatten(layout.shape),flatten(layout.stride)): + # skip their shape-1s and stride-0s + if not (shape == 1 or stride == 0): + result_shape.append(shape) + result_stride.append(stride) + + if len(result_shape) == 0: + return Layout(1,0) + else: + return coalesce(Layout(tuple(result_shape), tuple(result_stride))) + + +# Layout composition +# Use tuples-of-layouts to perform this operation by-mode and None as no-op +def composition(layoutA, layoutB): + if layoutB is None: + return layoutA + elif is_int(layoutB): + return composition(layoutA, Layout(layoutB)) + elif is_tuple(layoutB): + assert len(layoutA) >= len(layoutB) + return make_layout(chain((composition(layoutA[i], layoutB[i]) for i in range( 0,len(layoutB))), + (layoutA[i] for i in range(len(layoutB),len(layoutA))))) + elif is_tuple(layoutB.shape): + return make_layout(composition(layoutA, layoutB_i) for layoutB_i in layoutB) + + if layoutB.stride == 0: + return Layout(layoutB.shape, 0) + else: + result_shape = [] + result_stride = [] + rest_shape = layoutB.shape + rest_stride = layoutB.stride + flat_A = coalesce(layoutA) + for (curr_shape, curr_stride) in zip(flatten(flat_A.shape)[:-1], flatten(flat_A.stride)[:-1]): + assert curr_shape % rest_stride == 0 or rest_stride % curr_shape == 0 + new_shape = min(max(1, curr_shape // rest_stride), rest_shape) + + if new_shape != 1: + result_shape.append(new_shape) + result_stride.append(rest_stride * curr_stride) + + rest_shape = rest_shape // new_shape + rest_stride = -(-rest_stride // curr_shape) # Python exclusive impl: "//" is always floor div so == ceil_div(abs(rest_stride), curr_shape) * signum(rest_stride) + + if rest_shape != 1 or len(result_shape) == 0: + result_shape.append(rest_shape) + result_stride.append(rest_stride * flatten(flat_A.stride)[-1]) + + if len(result_shape) == 1: + return Layout(result_shape[0], result_stride[0]) + else: + return Layout(tuple(result_shape), tuple(result_stride)) + + +# Layout complement +def complement(layout, max_idx=1): + if is_int(layout): + return complement(Layout(layout)) + + result_shape = [] + result_stride = [] + current_idx = 1 + + sorted_DS = sorted(zip(flatten(layout.stride), flatten(layout.shape))) + for (stride, shape) in sorted_DS: + if stride == 0 or shape == 1: + continue + + in_bound = current_idx <= shape * stride + # To support symbolic value which can't be evaluated now + assert (type(in_bound) is not bool) or in_bound + + result_shape.append(stride // current_idx) + result_stride.append(current_idx) + current_idx = shape * stride + + result_shape.append((max_idx + current_idx - 1) // current_idx) # ceil_div + result_stride.append(current_idx) + + return coalesce(Layout(tuple(result_shape), tuple(result_stride))) + + +# Layout right inverse +def right_inverse(layout): + if layout is None: + return None + elif is_int(layout): + return Layout(layout) + + result_shape = [] + result_stride = [] + current_idx = 1 + + flat_shape = flatten(layout.shape) + flat_stride = flatten(layout.stride) + sorted_DSA = sorted(zip(flat_stride, flat_shape, prefix_product(flat_shape))) + for (stride,shape,rstride) in sorted_DSA: + if shape == 1: + continue + if current_idx != stride: + break + + result_shape.append(shape) + result_stride.append(rstride) + current_idx = shape * stride + + return coalesce(Layout(tuple(result_shape), tuple(result_stride))) + + +# Layout left inverse +def left_inverse(layout): + if layout is None: + return None + elif is_int(layout): + return Layout(layout) + return right_inverse(make_layout(layout, complement(layout))) + + +# Split a layout by the composition of B and the "rest" +# Use tuples-of-layouts to perform this operation by-mode and None as no-op +def logical_divide(layoutA, layoutB): + if layoutB is None: + return layoutA + elif is_int(layoutB): + return logical_divide(layoutA, Layout(layoutB)) + elif is_tuple(layoutB): + assert len(layoutA) >= len(layoutB) + return make_layout(chain((logical_divide(layoutA[i], layoutB[i]) for i in range( 0,len(layoutB))), + (layoutA[i] for i in range(len(layoutB),len(layoutA))))) + + return composition(layoutA, make_layout(layoutB, complement(layoutB, size(layoutA)))) + + +# Reproduce a layoutA over a layoutB +# Use tuples-of-layouts to perform this operation by-mode and None as no-op +def logical_product(layoutA, layoutB): + if layoutB is None: + return layoutA + elif is_int(layoutB): + return logical_divide(layoutA, Layout(layoutB)) + elif is_tuple(layoutB): + assert len(layoutA) >= len(layoutB) + return make_layout(chain((logical_product(layoutA[i], layoutB[i]) for i in range( 0,len(layoutB))), + (layoutA[i] for i in range(len(layoutB),len(layoutA))))) + + return make_layout(layoutA, composition(complement(layoutA, size(layoutA)*cosize(layoutB)), layoutB)); + + +# Gather the modes from a hierarchical logical_divide or logical_product +def hier_unzip(splitter, layoutA, layoutB): + if layoutB is None: + return make_layout(Layout(1,0), layoutA) + elif is_tuple(layoutB): + assert len(layoutA) >= len(layoutB) + # A layout with shape ((A,a),(B,b),(C,c)) + split = make_layout(hier_unzip(splitter, layoutA[i], layoutB[i]) for i in range(0,len(layoutB))) + # Gather to shape ((A,B,C,...),(a,b,c,...,y,z)) + return make_layout(make_layout( split[i][0] for i in range( 0,len(layoutB))), + make_layout(chain((split[i][1] for i in range( 0,len(layoutB))), + (layoutA[i] for i in range(len(layoutB),len(layoutA)))))) + + # splitter must return a rank-2 layout + return splitter(layoutA, layoutB) + + +# Apply logical divide hierarchically and gather the split modes into two modes +def zipped_divide(layoutA, layoutB): + return hier_unzip(logical_divide, layoutA, layoutB) + + +# Perform logical divide hierarchically and gather tiles (B-layouts) into a new mode +def tiled_divide(layoutA, layoutB): + result = zipped_divide(layoutA, layoutB) + return make_layout([result[0]] + [result[1][i] for i in range(len(result[1]))]) + + +# Apply logical product hierarchically and gather the split modes into two modes +def zipped_product(layoutA, layoutB): + return hier_unzip(logical_product, layoutA, layoutB) + + +# Perform logical product hierarchically and gather tiles (B-layouts) into a new mode +def tiled_product(layoutA, layoutB): + result = zipped_product(layoutA, layoutB) + return make_layout([result[0]] + [result[1][i] for i in range(len(result[1]))]) + + +def slice_and_offset(crd: tuple, + layout: Layout): + return (Layout(slice_(crd, layout.shape), slice_(crd, layout.stride)), + crd2idx(crd, layout.shape, layout.stride)) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/pycute/swizzle.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/pycute/swizzle.py new file mode 100644 index 0000000000000000000000000000000000000000..308aee0c3838a82c4de53833fb8a36950b30f62d --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/pycute/swizzle.py @@ -0,0 +1,129 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Methods for layout swizzling +""" + +from .layout import * + + +def shiftr(a, s): + return a >> s if s > 0 else shiftl(a, -s) + + +def shiftl(a, s): + return a << s if s > 0 else shiftr(a, -s) + + +## A generic Swizzle functor + # 0bxxxxxxxxxxxxxxxYYYxxxxxxxZZZxxxx + # ^--^ Base is the number of least-sig bits to keep constant + # ^-^ ^-^ Bits is the number of bits in the mask + # ^---------^ Shift is the distance to shift the YYY mask + # (pos shifts YYY to the right, neg shifts YYY to the left) + # + # e.g. Given + # 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxZZxxx + # the result is + # 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxAAxxx where AA = ZZ xor YY + # +class Swizzle: + def __init__(self, bits, base, shift): + assert bits >= 0 + assert base >= 0 + assert abs(shift) >= bits + self.bits = bits + self.base = base + self.shift = shift + bit_msk = (1 << bits) - 1 + self.yyy_msk = bit_msk << (base + max(0,shift)) + self.zzz_msk = bit_msk << (base - min(0,shift)) + + # operator () (transform integer) + def __call__(self, offset): + return offset ^ shiftr(offset & self.yyy_msk, self.shift) + + # Size of the domain + def size(self): + return 1 << (self.bits + self.base + abs(self.shift)) + + # Size of the codomain + def cosize(self): + return self.size() + + # print and str + def __str__(self): + return f"SW_{self.bits}_{self.base}_{self.shift}" + + # error msgs and representation + def __repr__(self): + return f"Swizzle({self.bits},{self.base},{self.shift})" + + +class ComposedLayout(LayoutBase): + def __init__(self, layoutB, offset, layoutA): + self.layoutB = layoutB + self.offset = offset + self.layoutA = layoutA + + # operator == + def __eq__(self, other): + return self.layoutB == other.layoutB and self.offset == other.offset and self.layoutA == other.layoutA + + # operator len(L) (len [rank] like tuples) + def __len__(self): + return len(self.layoutA) + + # operator () (map coord to idx) + def __call__(self, *args): + return self.layoutB(self.offset + self.layoutA(*args)) + + # operator [] (get-i like tuples) + def __getitem__(self, i): + return ComposedLayout(self.layoutB, self.offset, self.layoutA[i]) + + # size(layout) Size of the domain + def size(self): + return size(self.layoutA) + + # cosize(layout) Size of the codomain + def cosize(self): + return cosize(self.layoutB) + + # print and str + def __str__(self): + return f"{self.layoutB} o {self.offset} o {self.layoutA}" + + # error msgs and representation + def __repr__(self): + return f"ComposedLayout({repr(self.layoutB)},{repr(self.offset)},{repr(self.layoutA)})" diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/pycute/typing.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/pycute/typing.py new file mode 100644 index 0000000000000000000000000000000000000000..834f7e5411f5c2a4e218f9ce8a4f0a229d039710 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/pycute/typing.py @@ -0,0 +1,42 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from abc import ABC + + +class Integer(ABC): + @classmethod + def __subclasshook__(cls, c): + if c in [bool, float]: + return False + + return issubclass(c, int) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/setup_cutlass.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/setup_cutlass.py new file mode 100644 index 0000000000000000000000000000000000000000..acc0c46e540735443a4943908852010a80d02187 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/setup_cutlass.py @@ -0,0 +1,74 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + + +import copy +import os +import setuptools +from setuptools import setup +from setuptools.command.build_ext import build_ext + +import setup_pycute +import setup_library + + +# Install cutlass_library package +setup_library.perform_setup() + + +# Install the PyCuTe package +setup_pycute.perform_setup() + + +setup( + name='cutlass_cppgen', + version='4.2.0', + description='CUTLASS Pythonic Interface', + package_dir={'': '.'}, + packages=[ + 'cutlass_cppgen', + 'cutlass_cppgen.emit', + 'cutlass_cppgen.op', + 'cutlass_cppgen.utils', + 'cutlass_cppgen.backend', + 'cutlass_cppgen.backend.utils' + ], + setup_requires=['pybind11'], + install_requires=[ + 'bfloat16', + 'cuda-python>=11.8.0', + 'pybind11', + 'scikit-build', + 'treelib', + 'pydot' + ] +) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/setup_library.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/setup_library.py new file mode 100644 index 0000000000000000000000000000000000000000..c56d6b5556fea2d5e56209b13f5b95e487ca22fb --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/setup_library.py @@ -0,0 +1,46 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from setuptools import setup + + +def perform_setup(): + setup( + name='cutlass_library', + version='4.2.1', + description='CUTLASS library generation scripts', + packages=['cutlass_library'] + ) + + +if __name__ == '__main__': + perform_setup() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/setup_pycute.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/setup_pycute.py new file mode 100644 index 0000000000000000000000000000000000000000..0bad050fcade8b26d33043abbb0f8226be7d816c --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/python/setup_pycute.py @@ -0,0 +1,46 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from setuptools import setup + + +def perform_setup(): + setup( + name='pycute', + version='4.2.1', + description='Python implementation of CuTe', + packages=['pycute'], + ) + + +if __name__ == '__main__': + perform_setup() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_problem_sizes.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_problem_sizes.py new file mode 100644 index 0000000000000000000000000000000000000000..852c0277ebae2fce7e0b083ce2f497a2c828256f --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_problem_sizes.py @@ -0,0 +1,661 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for defining Conv2D problem sizes for testing. + +This file was ported from the C++ version in test/unit/conv/device/conv2d_problems.h +""" + +from cutlass_library import ConvMode + +import cutlass_cppgen +from cutlass_cppgen.shape import Conv2DProblemSize + + +class TestbedConv2dProblemSizes: + def __init__(self, minimum_channel_size: int): + conv2d_default_sizes = self.initialize_conv2d_default_sizes(minimum_channel_size) + conv2d_rigorous_sizes = self.initialize_conv2d_rigorous_sizes(minimum_channel_size) + conv2d_resnet50_sizes = self.initialize_conv2d_resnet50_sizes(1) + conv2d_resnet50_sizes_perf = self.initialize_conv2d_resnet50_sizes(34) + grouped_sizes = self.initialize_conv2d_grouped_sizes() + + # Filter all problems + self.all = [] + for size_list in [conv2d_default_sizes, conv2d_rigorous_sizes, conv2d_resnet50_sizes, conv2d_resnet50_sizes_perf, grouped_sizes]: + for size in size_list: + if (size.C // size.groups) % minimum_channel_size == 0: + self.all.append(size) + + + def initialize_conv2d_default_sizes(self, minimum_channel_size): + # Small input size x stride (1,1) + # C < CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64} + + conv2d_default_sizes = [] + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 1, 1, minimum_channel_size, + 8, 1, 1, minimum_channel_size, + 1, 1, + 1, 1, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 1, 8, minimum_channel_size, + 8, 1, 3, minimum_channel_size, + 1, 1, + 1, 1, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 7, 8, minimum_channel_size, + 8, 3, 3, minimum_channel_size, + 1, 1, + 1, 1, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 7, 9, minimum_channel_size, + 8, 4, 4, minimum_channel_size, + 1, 1, + 1, 1, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 2, 7, 9, minimum_channel_size, + 8, 5, 5, minimum_channel_size, + 1, 1, + 1, 1, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 3, 7, 9, minimum_channel_size, + 8, 6, 5, minimum_channel_size, + 1, 1, + 1, 1, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 3, 7, 9, minimum_channel_size, + 8, 6, 6, minimum_channel_size, + 1, 1, + 1, 1, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 3, 7, 9, minimum_channel_size, + 8, 7, 7, minimum_channel_size, + 1, 1, + 1, 1, + 1, 1, + )) + + ############################################## + # Small input size x stride (2,2) + # C < CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64} + ############################################## + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 11, 7, minimum_channel_size, + 8, 1, 1, minimum_channel_size, + 0, 0, + 2, 2, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 11, 7, minimum_channel_size, + 8, 3, 3, minimum_channel_size, + 1, 1, + 2, 2, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 13, 11, minimum_channel_size, + 8, 1, 1, minimum_channel_size, + 1, 1, + 2, 2, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 17, 19, minimum_channel_size, + 16, 2, 2, minimum_channel_size, + 1, 1, + 2, 2, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 23, 5, minimum_channel_size, + 16, 3, 3, minimum_channel_size, + 1, 1, + 2, 2, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 13, 17, 8, + 24, 3, 3, 8, + 0, 0, + 2, 2, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 23, 21, 8, + 24, 3, 3, 8, + 1, 1, + 3, 3, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 20, 24, 8, + 40, 3, 3, 8, + 3, 3, + 3, 3, + 1, 1, + )) + + ########################################## + # Medium input size (1x16x16x128), filter size (1x1, 2x2, 3x3, 5x5), stride (1, 1) + ########################################## + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 15, 19, 160, + 224, 1, 1, 160, + 0, 0, + 1, 1, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 19, 37, 160, + 224, 3, 3, 160, + 1, 1, + 2, 2, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 16, 16, 160, + 224, 2, 3, 160, + 1, 1, + 1, 1, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 23, 21, 128, + 224, 3, 3, 128, + 1, 1, + 1, 1, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 29, 37, 160, + 224, 5, 5, 160, + 2, 2, + 1, 1, + 1, 1, + )) + + ########################################## + # C > CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64} + ########################################## + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 15, 19, 32 + minimum_channel_size, + 96, 3, 3, 32 + minimum_channel_size, + 1, 1, + 1, 1, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 16, 24, 64 + minimum_channel_size, + 96, 3, 3, 64 + minimum_channel_size, + 1, 1, + 1, 1, + 1, 1, + )) + + ########################################## + # Medium input size, filter size (1x1, 3,x3, 5x5, 7x7), stride (2, 2) + ########################################## + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 13, 16, 288, + 160, 5, 5, 288, + 2, 2, + 2, 2, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 55, 51, 256, + 512, 1, 1, 256, + 0, 0, + 2, 2, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 71, 80, 32, + 64, 5, 5, 32, + 2, 2, + 2, 2, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 224, 224, 8, + 64, 7, 7, 8, + 3, 3, + 2, 2, + 1, 1, + )) + + ########################################## + # Medium input size stride (3, 3), filter (3, 3), non-default padding + ########################################## + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 27, 23, 256, + 512, 3, 3, 256, + 0, 0, + 3, 3, + 1, 1, + )) + + ########################################## + # Medium input size padding > stride, asymmetric filter, padding and striding + ########################################## + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 27, 31, 256, + 512, 3, 3, 256, + 5, 7, + 3, 4, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 27, 35, 256, + 512, 7, 5, 256, + 11, 7, + 3, 5, + 1, 1, + )) + + ########################################## + # Medium input size *mixed* stride (1, 2) and (2, 1), + # filter (3, 3), default padding + ########################################## + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 27, 27, 256, + 512, 3, 3, 256, + 1, 1, + 1, 2, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 27, 27, 256, + 512, 3, 3, 256, + 1, 1, + 2, 1, + 1, 1, + )) + + ######################################/ + # Additional input size + ######################################/ + conv2d_default_sizes.append(Conv2DProblemSize( + 3, 28, 28, 256, + 256, 2, 2, 256, + 0, 0, + 2, 2, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 32, 32, 16, + 32, 3, 3, 16, + 1, 1, + 6, 2, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 32, 24, 32, 32, + 32, 1, 2, 32, + 0, 0, + 1, 1, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 4, 2, 3, 256, + 328, 3, 5, 256, + 1, 1, + 1, 1, + 1, 1, + )) + return conv2d_default_sizes + + # Add a few large and rigorous convolution problem sizes + def initialize_conv2d_rigorous_sizes(self, minimum_channel_size): + sizes = [] + if False: + sizes.append(Conv2DProblemSize.from_sizes( + (1, 124, 224, 2 * minimum_channel_size), + (24, 7, 7, 2 * minimum_channel_size), + )) + + sizes.append(Conv2DProblemSize.from_sizes( + (1, 233, 35, minimum_channel_size), + (24, 7, 5, minimum_channel_size), + )) + return sizes + + # Add resent50 layers to unit testing sizes + def initialize_conv2d_resnet50_sizes(self, batch_size): + conv2d_problem_vector = [] + conv2d_problem_vector.append(Conv2DProblemSize( + batch_size, 56, 56, 64, + 256, 1, 1, 64, + 0, 0, + 1, 1, + 1, 1, + )) + + conv2d_problem_vector.append(Conv2DProblemSize( + batch_size, 56, 56, 64, + 64, 1, 1, 64, + 0, 0, + 1, 1, + 1, 1, + )) + + conv2d_problem_vector.append(Conv2DProblemSize( + batch_size, 56, 56, 64, + 64, 3, 3, 64, + 1, 1, + 1, 1, + 1, 1, + )) + + conv2d_problem_vector.append(Conv2DProblemSize( + batch_size, 56, 56, 256, + 64, 1, 1, 256, + 0, 0, + 1, 1, + 1, 1, + )) + + conv2d_problem_vector.append(Conv2DProblemSize( + batch_size, 56, 56, 256, + 512, 1, 1, 256, + 0, 0, + 2, 2, + 1, 1, + )) + + conv2d_problem_vector.append(Conv2DProblemSize( + batch_size, 56, 56, 256, + 128, 1, 1, 256, + 0, 0, + 2, 2, + 1, 1, + )) + + conv2d_problem_vector.append(Conv2DProblemSize( + batch_size, 28, 28, 128, + 128, 3, 3, 128, + 1, 1, + 1, 1, + 1, 1, + )) + + conv2d_problem_vector.append(Conv2DProblemSize( + batch_size, 28, 28, 128, + 512, 1, 1, 128, + 0, 0, + 1, 1, + 1, 1, + )) + + conv2d_problem_vector.append(Conv2DProblemSize( + batch_size, 28, 28, 512, + 128, 1, 1, 512, + 0, 0, + 1, 1, + 1, 1, + )) + + conv2d_problem_vector.append(Conv2DProblemSize( + batch_size, 28, 28, 512, + 1024, 1, 1, 512, + 0, 0, + 2, 2, + 1, 1, + )) + + conv2d_problem_vector.append(Conv2DProblemSize( + batch_size, 28, 28, 512, + 256, 1, 1, 512, + 0, 0, + 2, 2, + 1, 1, + )) + + conv2d_problem_vector.append(Conv2DProblemSize( + batch_size, 14, 14, 256, + 256, 3, 3, 256, + 1, 1, + 1, 1, + 1, 1, + )) + + conv2d_problem_vector.append(Conv2DProblemSize( + batch_size, 14, 14, 256, + 1024, 1, 1, 256, + 0, 0, + 1, 1, + 1, 1, + )) + + conv2d_problem_vector.append(Conv2DProblemSize( + batch_size, 14, 14, 1024, + 256, 1, 1, 1024, + 0, 0, + 1, 1, + 1, 1, + )) + + conv2d_problem_vector.append(Conv2DProblemSize( + batch_size, 14, 14, 1024, + 2048, 1, 1, 1024, + 0, 0, + 2, 2, + 1, 1, + )) + + conv2d_problem_vector.append(Conv2DProblemSize( + batch_size, 14, 14, 1024, + 512, 1, 1, 1024, + 0, 0, + 2, 2, + 1, 1, + )) + + conv2d_problem_vector.append(Conv2DProblemSize( + batch_size, 7, 7, 512, + 512, 3, 3, 512, + 1, 1, + 1, 1, + 1, 1, + )) + + conv2d_problem_vector.append(Conv2DProblemSize( + batch_size, 7, 7, 512, + 2048, 1, 1, 512, + 0, 0, + 1, 1, + 1, 1, + )) + + conv2d_problem_vector.append(Conv2DProblemSize( + batch_size, 7, 7, 2048, + 512, 1, 1, 2048, + 0, 0, + 1, 1, + 1, 1, + )) + + return conv2d_problem_vector + + def initialize_conv2d_grouped_sizes(self): + threadblock_n = 128 + threadblock_k = 32 + + sizes = [] + ########################################## + # One group calculated by one or multiple CTAs: k_per_group % CTA::N = 0 + # One CTA calculates a single group + ########################################## + for cta_per_group_k in range(1, 4): + for groups in range(2, 5): + conv_k = cta_per_group_k * threadblock_n * groups + sizes.append(Conv2DProblemSize( + 1, 8, 8, threadblock_k * 2 * groups, + conv_k, 3, 3, threadblock_k * 2, + 1, 1, + 1, 1, + 1, 1, + ConvMode.CrossCorrelation, + 1, + groups + )) + + # Partial gemm_k: k_per_group == CTA::N && channels_per_group < CTA::K + sizes.append(Conv2DProblemSize( + 1, 8, 8, threadblock_k, + threadblock_n * 2, 3, 3, threadblock_k // 2, + 1, 1, + 1, 1, + 1, 1, + ConvMode.CrossCorrelation, + 1, + 2 + )) + + sizes.append(Conv2DProblemSize( + 1, 56, 56, 696, + 768, 3, 3, 232, + 1, 1, + 2, 2, + 1, 1, + ConvMode.CrossCorrelation, + 1, + 3 + )) + sizes.append(Conv2DProblemSize( + 1, 14, 14, 1392, + 1536, 3, 3, 232, + 1, 1, + 1, 1, + 1, 1, + ConvMode.CrossCorrelation, + 1, + 3 + )) + + ########################################## + # One CTA calculate multiple groups: CTA::N % k_per_group = 0 + ########################################## + + # 2 groups per CTA + sizes.append(Conv2DProblemSize( + 1, 8, 8, threadblock_k * 4, + threadblock_n, 3, 3, threadblock_k * 2, + 1, 1, + 1, 1, + 1, 1, + ConvMode.CrossCorrelation, + 1, + 2 + )) + + # 2 groups per CTA and partial gemm_k + sizes.append(Conv2DProblemSize( + 1, 8, 8, threadblock_k, + threadblock_n, 3, 3, threadblock_k // 2, + 1, 1, + 1, 1, + 1, 1, + ConvMode.CrossCorrelation, + 1, + 2 + )) + + # 4 groups per CTA + sizes.append(Conv2DProblemSize( + 1, 8, 8, threadblock_k * 8, + threadblock_n // 2, 3, 3, threadblock_k * 2, + 1, 1, + 1, 1, + 1, 1, + ConvMode.CrossCorrelation, + 1, + 4 + )) + + # 4 groups per CTA and partial gemm_k + sizes.append(Conv2DProblemSize( + 1, 8, 8, threadblock_k * 2, + threadblock_n // 2, 3, 3, threadblock_k // 2, + 1, 1, + 1, 1, + 1, 1, + ConvMode.CrossCorrelation, + 1, + 4 + )) + + return sizes diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_sm80.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_sm80.py new file mode 100644 index 0000000000000000000000000000000000000000..f77a0ec831be087bd3badc929eee955f0b37c489 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_sm80.py @@ -0,0 +1,146 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for Conv2d opreations on SM80 +""" + +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc + +from conv2d_test_utils import * + + +cutlass_cppgen.set_log_level(logging.WARNING) +cc = 80 + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is invalid for SM80 tests.') +class Conv2dSm80(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +conv_problems = get_conv_problems() + + +# Tests for optimized & analytic +for conv_kind in ["fprop", "wgrad", "dgrad"]: + # F16, simt + add_test( + Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16, + opclass="simt", threadblock_shape=[128, 128, 8], + warp_count=[4, 2, 1], stages=2, instruction_shape=[1, 1, 1]) + # F16, tensor op + add_test( + Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16, + opclass="tensor_op", threadblock_shape=[128, 128, 64], + warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16]) + # F16, tensor op, analytic iterator + add_test( + Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, + opclass="tensor_op", threadblock_shape=[128, 128, 64], + warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], iterator_algorithm="analytic") + # F16, tensor op, f32 output + add_test( + Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, + opclass="tensor_op", threadblock_shape=[128, 128, 64], + warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16]) + # F16, tensor op, different tile description + add_test( + Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16, + opclass="tensor_op", threadblock_shape=[128, 64, 32], + warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 8]) + # F32, simt + add_test( + Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, + opclass="simt", threadblock_shape=[128, 128, 8], + warp_count=[4, 2, 1], stages=4, instruction_shape=[1, 1, 1]) + # Tf32, tensorop + add_test( + Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, + opclass="tensor_op", threadblock_shape=[128, 128, 16], + warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 8] + ) + # Split-K + add_test( + Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16, + opclass="tensor_op", threadblock_shape=[128, 128, 64], + warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], split_k_mode="serial", + split_k_slices=2) + add_test( + Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16, + opclass="tensor_op", threadblock_shape=[128, 128, 64], + warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], split_k_mode="parallel", + split_k_slices=5) + # Swizzling functor + add_test( + Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16, + opclass="tensor_op", threadblock_shape=[128, 64, 32], + warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 8], swizzle=4) + +# Tests for few channels and fixed channels +# F16, tensor op, few channels +for c, tb, stage, inst in zip([2, 1], + [[128, 128, 64], [128, 128, 32]], + [3, 2], + [[16, 8, 16], [16, 8, 8]]): + add_test( + Conv2dSm80, cc, "fprop", conv2d_few_channel_problemsizes(c), cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16, + opclass="tensor_op", threadblock_shape=tb, + warp_count=[2, 2, 1], stages=stage, instruction_shape=inst, iterator_algorithm="few_channels" + ) +# F16, tensor op, fixed channels +for c in [8, 4, 2]: + add_test( + Conv2dSm80, cc, "fprop", conv2d_few_channel_problemsizes(c), cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16, + opclass="tensor_op", threadblock_shape=[128, 128, 64], + warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], iterator_algorithm="fixed_channels" + ) + +# Test activations +for activation in ["relu", "leaky_relu"]: + for split_k_mode, split_k_slices in zip(["parallel", "serial", "parallel"], [1, 7, 5]): + add_test( + Conv2dSm80, cc, "fprop", conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16, + opclass="tensor_op", threadblock_shape=[128, 128, 64], + warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], split_k_mode=split_k_mode, + split_k_slices=split_k_slices, activation=activation) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_test_utils.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9bc4542cd5ccf72341f7db3c7947d481b032926d --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_test_utils.py @@ -0,0 +1,428 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utility functions for Conv2d tests. +""" + +from cutlass_library import SubstituteTemplate +import torch + +import cutlass_cppgen +from cutlass_library import ( + ConvKind, + ConvMode, + DataType, + DataTypeNames, + EpilogueScheduleSuffixes, + KernelScheduleSuffixes, + LayoutType, + OpcodeClassNames, + ShortDataTypeNames, + ShortLayoutTypeNames, + SplitKMode, +) +from cutlass_cppgen.shape import Conv2DProblemSize +from cutlass_cppgen.utils.datatypes import numpy_type, torch_type + +from conv2d_problem_sizes import TestbedConv2dProblemSizes + + +def get_name_conv2d( + arch, + conv_kind, + element, + element_accumulator, + element_output, + opclass, + threadblock_shape, + warp_count, + instruction_shape, + stages, + iterator_algorithm, + swizzle, + split_k_mode, + split_k_slices, + activation +): + """ + Generates a procedural name for a test case for conv2d + + :param arch: compute capability of kernel being generated + :type arch: int + :param conv_kind: the convolution type (i.e. fprop, dgrad, wgrad) + :type conv_kind: str + :param iterator_algorithm: the iterator algorithm applied + :type iterator_algorithm: cutlass_library.library.IteratorAlgorithm + :param element_a: data type of operand A + :param element_b: data type of operand B + :param element_c: data type of operand C + :param element_accumulator: data type used in accumulation + :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) + :type opclass: cutlass_cppgen.OpcodeClass + :param threadblock_shape: indexable container of dimensions of threadblock tiles + :param stages: number of pipeline stages to use in the kernel + :type stages: int + :param stride_support: stride support of dgrad + :param alignment: int + :type alignment: int + + :return: str + """ + if iterator_algorithm is None: + iterator_algorithm = "AUTO" + if swizzle is None: + swizzle = 1 + name_format = "test_SM${arch}_Device_Conv2d_${conv_kind}_${iter_alg}_ImplicitGemm_${eA}nhwc_${eB}nhwc_${eC}nhwc_${opclass}_${acc}_${tbM}x${tbN}x${tbK}_${wM}x${wN}x${wK}_${IM}${IN}${IK}_stage${stages}_swizzle${swizzle}_${split_k_mode}${split_k_slices}_${activation}" + + return SubstituteTemplate( + name_format, + { + "arch": str(arch), + "conv_kind": conv_kind, + "iter_alg": iterator_algorithm, + "eA": DataTypeNames[element], + "eB": DataTypeNames[element], + "eC": DataTypeNames[element_output], + "opclass": opclass, + "acc": DataTypeNames[element_accumulator], + "tbM": str(threadblock_shape[0]), + "tbN": str(threadblock_shape[1]), + "tbK": str(threadblock_shape[2]), + "wM": str(threadblock_shape[0] // warp_count[0]), + "wN": str(threadblock_shape[1] // warp_count[1]), + "wK": str(threadblock_shape[2] // warp_count[2]), + "IM": str(instruction_shape[0]), + "IN": str(instruction_shape[1]), + "IK": str(instruction_shape[2]), + "stages": str(stages), + "swizzle": str(swizzle), + "split_k_mode": split_k_mode, + "split_k_slices": str(split_k_slices), + "activation": activation + } + ) + + +def conv2d_few_channel_problemsizes(channels): + problem_sizes = [ + Conv2DProblemSize( + 1, 8, 8, channels, + 16, 3, 3, channels, + 1, 1, + 2, 2, + 1, 1, + ConvMode.CrossCorrelation, + 1, 1 + ), + Conv2DProblemSize( + 1, 16, 16, channels, + 16, 3, 3, channels, + 1, 1, + 2, 2, + 1, 1, + ConvMode.CrossCorrelation, + 1, 1 + ), + Conv2DProblemSize( + 1, 16, 16, channels, + 16, 7, 7, channels, + 1, 1, + 1, 1, + 1, 1, + ConvMode.CrossCorrelation, + 1, 1 + ), + Conv2DProblemSize( + 1, 224, 224, channels, + 32, 7, 7, channels, + 1, 1, + 1, 1, + 1, 1, + ConvMode.CrossCorrelation, + 1, 1 + ), + Conv2DProblemSize( + 1, 224, 224, channels, + 64, 7, 7, channels, + 1, 1, + 2, 2, + 1, 1, + ConvMode.CrossCorrelation, + 1, 1 + ), + Conv2DProblemSize( + 1, 224, 224, channels, + 64, 5, 5, channels, + 1, 1, + 1, 1, + 1, 1, + ConvMode.CrossCorrelation, + 1, 1 + ), + Conv2DProblemSize( + 1, 224, 224, channels, + 64, 5, 5, channels, + 1, 1, + 2, 2, + 1, 1, + ConvMode.CrossCorrelation, + 1, 1 + ), + ] + + return problem_sizes + + +def validate_problem_size(ps, conv_kind, split_k_slices): + P = (ps.H + 2 * ps.pad_h - ps.dilation_h * (ps.R - 1) - 1) // ps.stride_h + 1 + Q = (ps.W + 2 * ps.pad_w - ps.dilation_w * (ps.S - 1) - 1) // ps.stride_w + 1 + if P != ps.P or Q != ps.Q: + return False + + # Split-K (serial or parallel) is not supported for strided dgrad + if conv_kind == "dgrad" and split_k_slices > 1 and (ps.stride_h > 1 or ps.stride_w > 1): + return False + return True + + +class Conv2dLauncherFrontend: + def __init__(self, plan: cutlass_cppgen.Conv2d, seed: int = 80, backend="numpy"): + self.operation = plan + self.conv_kind = plan.conv_kind + self.seed = seed + self.backend = backend + + self.dtype_A = plan._element_a + self.dtype_B = plan._element_b + self.dtype_C = plan._element_c + self.dtype_acc = plan._element_accumulator + self.layout_A = LayoutType.TensorNHWC + self.layout_B = LayoutType.TensorNHWC + self.layout_C = LayoutType.TensorNHWC + self.layout_D = LayoutType.TensorNHWC + + self.element_compute = DataType.f32 + + if self.dtype_A in [cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.bf16]: + self.rand_max = 1 + else: + self.rand_max = 4 + self.activation = plan.activation + + def uniform_init(self, size, dtype): + tensor = torch.ceil( + torch.empty(size=size, dtype=torch_type(dtype), device="cuda").uniform_(-self.rand_max - 0.5, self.rand_max - 0.5) + ).to(memory_format=torch.channels_last) + return tensor + + def reference(self, ps, A, B, C, alpha, beta, activation): + if self.conv_kind == ConvKind.Fprop: + torch_result = alpha * torch.ops.aten.conv2d( + A, + B, + stride=(ps.stride_h, ps.stride_w), + padding=(ps.pad_h, ps.pad_w), + dilation=(ps.dilation_h, ps.dilation_w) + ) + beta * C + elif self.conv_kind == ConvKind.Dgrad: + torch_result = alpha * torch.nn.grad.conv2d_input( + (ps.N, ps.C, ps.H, ps.W), + B, + A, + padding=(ps.pad_h, ps.pad_w), + stride=(ps.stride_h, ps.stride_w) + ) + beta * C + elif self.conv_kind == ConvKind.Wgrad: + torch_result = alpha * torch.nn.grad.conv2d_weight( + B, + (ps.K, ps.C, ps.R, ps.S), + A, + padding=(ps.pad_h, ps.pad_w), + stride=(ps.stride_h, ps.stride_w) + ) + beta * C + else: + raise Exception(f"Conv kind {self.conv_kind} is currently unsupported.") + + if activation == cutlass_cppgen.backend.epilogue.relu: + torch_result = torch.nn.functional.relu(torch_result) + elif activation == cutlass_cppgen.backend.epilogue.leaky_relu: + torch_result = torch.nn.functional.leaky_relu(torch_result, 0.5) + return torch_result + + def run(self, ps, split_k_mode=SplitKMode.Serial, split_k_slices=1, alpha=1.0, beta=0.0): + if self.conv_kind == ConvKind.Fprop: + tensor_A_size = (ps.N, ps.C, ps.H, ps.W) + tensor_B_size = (ps.K, ps.C, ps.R, ps.S) + tensor_C_size = (ps.N, ps.K, ps.P, ps.Q) + elif self.conv_kind == ConvKind.Dgrad: + tensor_A_size = (ps.N, ps.K, ps.P, ps.Q) + tensor_B_size = (ps.K, ps.C, ps.R, ps.S) + tensor_C_size = (ps.N, ps.C, ps.H, ps.W) + elif self.conv_kind == ConvKind.Wgrad: + tensor_A_size = (ps.N, ps.K, ps.P, ps.Q) + tensor_B_size = (ps.N, ps.C, ps.H, ps.W) + tensor_C_size = (ps.K, ps.C, ps.R, ps.S) + else: + raise Exception(f"Conv kind {self.conv_kind} is not supported") + + torch.manual_seed(self.seed) + + tensor_A = self.uniform_init(size=tensor_A_size, dtype=self.dtype_A) + tensor_B = self.uniform_init(size=tensor_B_size, dtype=self.dtype_B) + tensor_C = self.uniform_init(size=tensor_C_size, dtype=self.dtype_C) + tensor_D = torch.zeros_like(tensor_C).to(memory_format=torch.channels_last) + args = self.operation.run(tensor_A, tensor_B, tensor_C, tensor_D, + stride=(ps.stride_h, ps.stride_w), + padding=(ps.pad_h, ps.pad_w), + dilation=(ps.dilation_h, ps.dilation_w), + alpha=alpha, beta=beta, + split_k=(split_k_mode, split_k_slices)) + + args.sync() + + tensor_D_ref = self.reference(ps, tensor_A, tensor_B, tensor_C, alpha, beta, self.activation) + + torch.cuda.synchronize() + passed = torch.allclose(tensor_D, tensor_D_ref, atol=2e-06) + + return passed + + +def add_test( + cls, + cc, + conv_kind, + problem_sizes, + element, + element_accumulator, + element_output, + opclass, + threadblock_shape, + warp_count, + instruction_shape, + stages, + iterator_algorithm=None, + swizzle=None, + split_k_mode="serial", + split_k_slices=1, + activation = "identity" +): + """Create a test-running function with the given specification""" + test_name = get_name_conv2d( + cc, conv_kind, element, element_accumulator, + element_output, opclass, threadblock_shape, warp_count, instruction_shape, stages, + iterator_algorithm, swizzle, split_k_mode, split_k_slices, activation) + + def run(self): + # Create the plan + plan = cutlass_cppgen.Conv2d( + kind=conv_kind, + element=element, + element_accumulator=element_accumulator, + element_C=element_output, + element_D=element_output + ) + + # Set the opclass + plan.opclass = opclass + # Set the tile description + td = { + "threadblock_shape": threadblock_shape, + "warp_count": warp_count, + "stages": stages, + "instruction_shape": instruction_shape, + } + + plan.tile_description = td + # Set iterator algorithm + if iterator_algorithm is not None: + plan.iterator_algorithm = iterator_algorithm + # Set swizzling functor + if swizzle is not None: + plan.swizzling_stride = swizzle + + if activation != "identity": + if activation == "leaky_relu": + plan.activation = (cutlass_cppgen.epilogue.leaky_relu, 0.5) + else: + plan.activation = getattr(cutlass_cppgen.epilogue, activation) + + conv2d_launcher = Conv2dLauncherFrontend(plan, 80, backend="torch") + + for ps in problem_sizes: + if not validate_problem_size(ps, conv_kind, split_k_slices): + continue + + self.assertTrue(conv2d_launcher.run(ps, split_k_mode, split_k_slices, 1.0, 2.0)) + + setattr(cls, test_name, run) + + return run + + +def get_conv_problems(): + # 64: minimum channel size + conv_problems = TestbedConv2dProblemSizes(64).all + + # Insert alignment 4 & 2 tests + conv_problems += [ + Conv2DProblemSize( + 1, 4, 4, 12, + 8, 3, 3, 12, + 0, 0, + 3, 3, + 1, 1, + ConvMode.CrossCorrelation, + 1, 1 + ), + Conv2DProblemSize( + 1, 4, 4, 14, + 8, 3, 3, 14, + 0, 0, + 3, 3, + 1, 1, + ConvMode.CrossCorrelation, + 1, 1 + ), + Conv2DProblemSize( + 1, 23, 56, 98, + 128, 3, 3, 98, + 4, 5, + 3, 3, + 1, 1, + ConvMode.CrossCorrelation, + 1, 1 + ), + ] + + return conv_problems diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/run_all_tests.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/run_all_tests.py new file mode 100644 index 0000000000000000000000000000000000000000..d892b5df047d5121345d902a77aadf2256b4c3b5 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/run_all_tests.py @@ -0,0 +1,44 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import pathlib +import unittest + + +if __name__ == '__main__': + loader = unittest.TestLoader() + script_dir = str(pathlib.Path(__file__).parent.resolve()) + '/' + tests = loader.discover(script_dir, 'conv2d_*.py') + testRunner = unittest.runner.TextTestRunner() + results = testRunner.run(tests) + if not results.wasSuccessful(): + raise Exception('Test cases failed') diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/emit/pytorch.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/emit/pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..c9d4c52a9f75fb4c3bc809947bf48ba85356ec70 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/emit/pytorch.py @@ -0,0 +1,309 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Tests emitting a CUTLASS kernel to a PyTorch CUDA extension +""" + +import random +import tempfile +import unittest + +from cutlass_library import ConvMode + +import cutlass_cppgen + +if cutlass_cppgen.utils.datatypes.is_torch_available(): + import torch + + +def _initialize(dtype, M: int, N: int, K: int): + """ + Utility function to initialize A, B, C, and D matrices corresponding to dimensions M, N, and K + + :param dtype: data type of tensors + :param M: M dimension of GEMM problem + :type M: int + :param N: N dimension of GEMM problem + :type N: int + :param K: N dimension of GEMM problem + :type K: int + + :return: initialized tensors A, B, C, and D + :rtype: list + """ + sizes = [(M, K), (K, N), (M, N), (M, N)] + return [torch.randint(-3, 3, size, device='cuda').to(dtype) for size in sizes] + + +def _generate_problems(dtype, num): + """ + Utility function to generate `num` GEMMs of random sizes + + :param dtype: data type of tensors + :param num: number of GEMMs to generate + :type num: int + + :return: lists of A, B, C, and D tensors + :rtype: list + """ + valid_sizes = [128, 256, 512, 1024] + As, Bs, Cs, Ds = [], [], [], [] + for _ in range(num): + M, N, K = [random.choice(valid_sizes) for _ in range(3)] + A, B, C, D = _initialize(dtype, M, N, K) + As.append(A) + Bs.append(B) + Cs.append(C) + Ds.append(D) + return As, Bs, Cs, Ds + +def _generate_conv2d_problem(conv_kind, dtype, ps): + """ + Utility function to generate conv2d inputs + + :param conv_kind: kind of convolution + :type conv_kind: str + :param dtype: data type of tensors + :param problem_size: the conv2d problem size + :type problem_size: cutlass_cppgen.shape.Conv2DProblemSize + + :return: initialized tensors A, B, C, and D + :rtype: list + """ + if conv_kind == "fprop": + tensor_A_size = (ps.N, ps.C, ps.H, ps.W) + tensor_B_size = (ps.K, ps.C, ps.R, ps.S) + tensor_C_size = (ps.N, ps.K, ps.P, ps.Q) + elif conv_kind == "dgrad": + tensor_A_size = (ps.N, ps.K, ps.P, ps.Q) + tensor_B_size = (ps.K, ps.C, ps.R, ps.S) + tensor_C_size = (ps.N, ps.C, ps.H, ps.W) + else: + tensor_A_size = (ps.N, ps.K, ps.P, ps.Q) + tensor_B_size = (ps.N, ps.C, ps.H, ps.W) + tensor_C_size = (ps.K, ps.C, ps.R, ps.S) + sizes = [tensor_A_size, tensor_B_size, tensor_C_size] + return [torch.ceil(torch.empty(size, dtype=dtype, device='cuda').uniform_(-4.5, 3.5)).to(memory_format=torch.channels_last) for size in sizes] + + +@unittest.skipIf(not cutlass_cppgen.utils.datatypes.is_torch_available(), 'PyTorch must be available to run PyTorch extension tests') +class PyTorchExtensionTest(unittest.TestCase): + + def test_gemm(self): + random.seed(2023) + + dtype = torch.float16 + plan = cutlass_cppgen.op.Gemm(element=dtype, layout=cutlass_cppgen.LayoutType.RowMajor) + op = plan.construct() + + with tempfile.TemporaryDirectory() as tmpdir: + mod = cutlass_cppgen.emit.pytorch(op, name='gemm_mod', cc=plan.cc, sourcedir=tmpdir, jit=True) + + A, B, C, _ = _initialize(dtype, 1024, 256, 512) + + D_ref = A @ B + D = mod.run(A, B) + assert torch.allclose(D, D_ref) + + D = mod.run(A, B, C) + assert torch.allclose(D, D_ref) + + D = mod.run(A, B, C, 1.0) + assert torch.allclose(D, D_ref) + + D = mod.run(A, B, C, 1.0, 0.0) + assert torch.allclose(D, D_ref) + + alpha = 2.0 + beta = -1.0 + D_ref = (A @ B) * alpha + (beta * C) + D = mod.run(A, B, C, alpha, beta) + assert torch.allclose(D, D_ref) + + def test_grouped_gemm(self): + random.seed(2023) + + dtype = torch.float16 + plan = cutlass_cppgen.op.GroupedGemm(element=dtype, layout=cutlass_cppgen.LayoutType.RowMajor) + op = plan.construct() + + with tempfile.TemporaryDirectory() as tmpdir: + mod = cutlass_cppgen.emit.pytorch(op, name='grouped_gemm_mod', cc=plan.cc, sourcedir=tmpdir, jit=True) + + As, Bs, Cs, _ = _generate_problems(dtype, 50) + + def check_all(X, Y): + for x, y in zip(X, Y): + assert torch.allclose(x, y) + + Ds_ref = [a @ b for a, b in zip(As, Bs)] + Ds = mod.run(As, Bs) + check_all(Ds, Ds_ref) + + Ds = mod.run(As, Bs, Cs) + check_all(Ds, Ds_ref) + + Ds = mod.run(As, Bs, Cs, 1.0) + check_all(Ds, Ds_ref) + + Ds = mod.run(As, Bs, Cs, 1.0, 0.0) + check_all(Ds, Ds_ref) + + alpha = 2.0 + beta = -1.0 + Ds_ref = [(a @ b) * alpha + (beta * c) for a, b, c in zip(As, Bs, Cs)] + Ds = mod.run(As, Bs, Cs, alpha, beta) + check_all(Ds, Ds_ref) + + def test_conv2d_fprop(self): + torch.manual_seed(2023) + + dtype = torch.float16 + plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=dtype, element_accumulator=torch.float32) + plan.activation = "relu" + + op = plan.construct() + with tempfile.TemporaryDirectory() as tmpdir: + mod = cutlass_cppgen.emit.pytorch(op, name="conv2d_mod", cc=plan.cc, sourcedir=tmpdir, jit=True) + + problem_size = cutlass_cppgen.shape.Conv2DProblemSize( + 1, 4, 4, 16, + 8, 3, 3, 16, + 0, 0, + 3, 3, + 1, 1 + ) + + A, B, C = _generate_conv2d_problem("fprop", dtype, problem_size) + stride = (problem_size.stride_h, problem_size.stride_w) + padding = (problem_size.pad_h, problem_size.pad_w) + + alpha = 1.0 + beta = 0.5 + + D_ref = alpha * torch.ops.aten.conv2d( + A, B, stride=stride, padding=padding + ) + beta * C + D_ref = torch.nn.functional.relu(D_ref) + D = mod.run(A, B, C, stride, padding, alpha=alpha, beta=beta) + + assert torch.allclose(D, D_ref) + + # Test serial split-K + D_serial_split_k = mod.run(A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="serial", split_k_slices=3) + assert torch.allclose(D, D_serial_split_k) + + # Test parallel split-K + D_parallel_split_k = mod.run(A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="parallel", split_k_slices=7) + assert torch.allclose(D, D_parallel_split_k) + + + def test_conv2d_dgrad(self): + torch.manual_seed(2023) + dtype = torch.float16 + plan = cutlass_cppgen.op.Conv2d(kind="dgrad", element=dtype, element_accumulator=torch.float32) + + op = plan.construct() + with tempfile.TemporaryDirectory() as tmpdir: + mod = cutlass_cppgen.emit.pytorch(op, name="conv2d_dgrad_mod", cc=plan.cc, sourcedir=tmpdir, jit=True) + + problem_size = cutlass_cppgen.shape.Conv2DProblemSize( + 1, 4, 4, 16, + 8, 3, 3, 16, + 0, 0, + 3, 3, + 1, 1, + ConvMode.CrossCorrelation, + 1, 1 + ) + + A, B, C = _generate_conv2d_problem("dgrad", dtype, problem_size) + stride = (problem_size.stride_h, problem_size.stride_w) + padding = (problem_size.pad_h, problem_size.pad_w) + + alpha = 1.0 + beta = 0.5 + input_size = (problem_size.N, problem_size.C, problem_size.H, problem_size.W) + D_ref = alpha * torch.nn.grad.conv2d_input( + input_size, B, A, + stride=stride, padding=padding + ) + beta * C + D = mod.run(input_size, A, B, C, stride, padding, alpha=alpha, beta=beta, ) + + assert torch.allclose(D, D_ref) + + def test_conv2d_wgrad(self): + torch.manual_seed(2023) + dtype = torch.float16 + plan = cutlass_cppgen.op.Conv2d(kind="wgrad", element=dtype, element_accumulator=torch.float32) + + op = plan.construct() + with tempfile.TemporaryDirectory() as tmpdir: + mod = cutlass_cppgen.emit.pytorch(op, name="conv2d_wgrad_mod", cc=plan.cc, sourcedir=tmpdir, jit=True) + + problem_size = cutlass_cppgen.shape.Conv2DProblemSize( + 1, 4, 4, 16, + 8, 3, 3, 16, + 0, 0, + 3, 3, + 1, 1, + ConvMode.CrossCorrelation, + 1, 1 + ) + + A, B, C = _generate_conv2d_problem("wgrad", dtype, problem_size) + stride = (problem_size.stride_h, problem_size.stride_w) + padding = (problem_size.pad_h, problem_size.pad_w) + + alpha = 1.0 + beta = 0.5 + weight_size = (problem_size.K, problem_size.C, problem_size.R, problem_size.S) + D_ref = alpha * torch.nn.grad.conv2d_weight( + B, weight_size, A, + stride=stride, padding=padding + ) + beta * C + D = mod.run(weight_size, A, B, C, stride, padding, alpha=alpha, beta=beta) + + assert torch.allclose(D, D_ref) + + # Test serial split-K + D_serial_split_k = mod.run(weight_size, A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="serial", split_k_slices=3) + assert torch.allclose(D, D_serial_split_k) + + # Test parallel split-K + D_parallel_split_k = mod.run(weight_size, A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="parallel", split_k_slices=7) + assert torch.allclose(D, D_parallel_split_k) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_compute_sm80_90.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_compute_sm80_90.py new file mode 100644 index 0000000000000000000000000000000000000000..5467469e74e05573fb297b009914e0980e5ab222 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_compute_sm80_90.py @@ -0,0 +1,198 @@ +################################################################################ +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ +""" +Unit test for compute node in SM90 +""" + +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend import * +from cutlass_cppgen.epilogue import * +from cutlass_cppgen import swizzle + +from utils.evt_testbed import EVTTestBed, EVTTestCaseBase + +cutlass_cppgen.set_log_level(logging.WARNING) + + +@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]") +class TestEVTCompute(EVTTestCaseBase): + + def test_arith(self): + """ + Test Arithmatic op + """ + def evt_arith_compute(accum, C, alpha, beta, gamma): + D = ((accum + C) * alpha - gamma) / beta + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "C": self.fake_tensor(self.element, (l, m, n)), + "alpha": 1.5, + "beta": 0.5, + "gamma": 2.5, + "D": self.fake_tensor(self.element, (l, m, n)) + } + + launcher = EVTTestBed(self.element, evt_arith_compute, example_inputs) + input_keys = ["C", "alpha", "beta", "gamma"] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_func_call(self): + """ + Test Function call + """ + def evt_func_call(accum, C, alpha, beta, gamma): + D = multiply_add(relu(accum + alpha) + C, beta, gamma) + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "C": self.fake_tensor(self.element, (l, m, n)), + "alpha": 1.5, + "beta": 0.5, + "gamma": 2.5, + "D": self.fake_tensor(self.element, (l, m, n)) + } + + launcher = EVTTestBed(self.element, evt_func_call, example_inputs) + input_keys = ["C", "alpha", "beta", "gamma"] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_func_call2(self): + """ + Test Function call + """ + + def evt_func_call2(accum, C, alpha, beta): + D = maximum(alpha * accum + beta * C, 0.0) + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "C": self.fake_tensor(self.element, (l, m, n)), + "alpha": 1.5, + "beta": 0.5, + "D": self.fake_tensor(self.element, (l, m, n)) + } + + launcher = EVTTestBed(self.element, evt_func_call2, example_inputs) + input_keys = ["C", "alpha", "beta"] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_tanh(self): + """ + Test Tanh op + """ + def evt_tanh(accum): + D = tanh(accum) + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (l, m, n)) + } + + launcher = EVTTestBed(self.element, evt_tanh, example_inputs) + input_keys = [] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_sigmoid(self): + """ + Test Sigmoid op + """ + def evt_sigmoid(accum): + D = sigmoid(accum) + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (l, m, n)) + } + + launcher = EVTTestBed(self.element, evt_sigmoid, example_inputs) + input_keys = [] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_gelu(self): + """ + Test GELU op + """ + def evt_gelu(accum): + D = gelu(accum) + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (l, m, n)) + } + + launcher = EVTTestBed(self.element, evt_gelu, example_inputs) + input_keys = [] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_exp(self): + """ + Test Exp op + """ + def evt_exp(accum): + D = exp(accum) + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (l, m, n)) + } + + launcher = EVTTestBed(self.element, evt_exp, example_inputs) + input_keys = [] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_layout_sm80_90.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_layout_sm80_90.py new file mode 100644 index 0000000000000000000000000000000000000000..f5a7b7f7a336dce0651f299d26b17df04952be99 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_layout_sm80_90.py @@ -0,0 +1,173 @@ +################################################################################ +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ + +""" +Unit test for store nodes in SM90 +""" + +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend import * +from cutlass_cppgen.epilogue import * + +from utils.evt_testbed import EVTTestBed, EVTTestCaseBase + +cutlass_cppgen.set_log_level(logging.WARNING) + + +@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]") +class TestEVTLayout(EVTTestCaseBase): + + def test_permute_1(self): + """ + Returning a tensor with shape [m, n] + """ + def evt_permute(accum, alpha, C): + F = alpha * accum + F_permute = permute(F, indices=(0, 2, 1)) + D_permute = F_permute + permute(C, indices=(0, 2, 1)) + D = permute(D_permute, indices=(0, 2, 1)) + return D, F + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 0.5, + "C": self.fake_tensor(self.element, (l, m, n)), + "F": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_permute, example_inputs) + input_keys = ["C", "alpha"] + result_keys = ["D", "F"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + @unittest.skipIf(device_cc() != 90, "This unittest is for cc = Sm90 only") + def test_permute_2(self): + """ + Returning a tensor with shape [m, n] + """ + def evt_permute(accum, alpha, C): + F = alpha * accum + F_permute = permute(F, indices=(0, 2, 1)) + D = F_permute + C + return D, F + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 0.5, + "C": self.fake_tensor(self.element, (l, n, m)), + "F": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (l, n, m)), + } + + launcher = EVTTestBed(self.element, evt_permute, example_inputs) + input_keys = ["C", "alpha"] + result_keys = ["D", "F"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + @unittest.skipIf(device_cc() != 90, "This unittest is for cc = Sm90 only") + def test_permute_3(self): + """ + Returning a tensor with shape [m, n] + """ + def evt_permute(accum, alpha, C): + F = alpha * accum + F_permute = permute(F, indices=(1, 0, 2)) + D = F_permute + C + return D, F + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 0.5, + "C": self.fake_tensor(self.element, (m, l, n)), + "F": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (m, l, n)), + } + + launcher = EVTTestBed(self.element, evt_permute, example_inputs) + input_keys = ["C", "alpha"] + result_keys = ["D", "F"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_reshape(self): + """ + Test reshape + """ + def evt_reshape(accum, alpha, TensorE): + F = alpha * accum + E_reshape = reshape(TensorE, new_shape=(512, 1)) + D = F + E_reshape + return D + + example_inputs = { + "accum": self.fake_tensor(self.element, (self.l, self.m, self.n)), + "alpha": 0.5, + "TensorE": self.fake_tensor(self.element, (16, 32)), + "D": self.fake_tensor(self.element, (self.l, self.m, self.n)), + } + + launcher = EVTTestBed(self.element, evt_reshape, example_inputs) + input_keys = ["alpha", "TensorE"] + result_keys = ["D"] + launcher.verify(self.problem_size, input_keys, result_keys, self.l) + + def test_reshape2(self): + """ + Test reshape + """ + def evt_reshape(accum, alpha, TensorE): + F = alpha * accum + F_reshape = reshape(F, new_shape=(2, 3, 512, 256)) + D = F_reshape + TensorE + return D + + example_inputs = { + "accum": self.fake_tensor(self.element, (self.l, self.m, self.n)), + "alpha": 0.5, + "TensorE": self.fake_tensor(self.element, (2, 3, 1, self.n)), + "D": self.fake_tensor(self.element, (2, 3, self.m, self.n)), + } + + launcher = EVTTestBed(self.element, evt_reshape, example_inputs) + input_keys = ["alpha", "TensorE"] + result_keys = ["D"] + launcher.verify(self.problem_size, input_keys, result_keys, self.l) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_load_sm80_90.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_load_sm80_90.py new file mode 100644 index 0000000000000000000000000000000000000000..57a5c6bb17bb44bf294cc7a6a749c706601034f6 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_load_sm80_90.py @@ -0,0 +1,142 @@ +################################################################################ +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ + +""" +Unit test for load nodes in SM90 +""" + +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend import * +from cutlass_cppgen.epilogue import * + +from utils.evt_testbed import EVTTestBed, EVTTestCaseBase + +cutlass_cppgen.set_log_level(logging.WARNING) + + +@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]") +class TestEVTLoad(EVTTestCaseBase): + + def test_tensor_load(self): + """ + Load extra tensor with shape [m, n] + """ + def evt_tensor_load(accum, C, aux, aux_batch): + D = accum + C + aux + aux_batch + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "C": self.fake_tensor(self.element, (l, m, n)), + "aux": self.fake_tensor(self.element, (m, n)), + "aux_batch": self.fake_tensor(np.float32, (l, m, n)), + "D": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_tensor_load, example_inputs) + input_keys = ["C", "aux", "aux_batch"] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_row_broadcast(self): + """ + Load extra tensor with shape [1, n] + """ + def evt_row_broadcast(accum, C, bias, bias_batch): + D = accum + C + bias + bias_batch + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "C": self.fake_tensor(self.element, (l, m, n)), + "bias": self.fake_tensor(self.element, (n,)), + "bias_batch": self.fake_tensor(np.float32, (l, 1, n)), + "D": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_row_broadcast, example_inputs) + input_keys = ["C", "bias", "bias_batch"] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_column_broadcast(self): + """ + Load extra tensor with shape [m, 1] + """ + def evt_column_broadcast(accum, C, bias, bias_batch): + D = accum + C + bias + bias_batch + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "C": self.fake_tensor(self.element, (l, m, n)), + "bias": self.fake_tensor(self.element, (m, 1)), + "bias_batch": self.fake_tensor(np.float32, (l, m, 1)), + "D": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_column_broadcast, example_inputs) + input_keys = ["C", "bias", "bias_batch"] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_scalar_broadcast(self): + """ + Load extra tensor with shape [1, 1] + """ + def evt_scalar_broadcast(accum, C, alpha, alpha_batch): + D = accum + C + alpha + alpha_batch + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "C": self.fake_tensor(self.element, (l, m, n)), + "alpha": 0.5, + "alpha_batch": self.fake_tensor(np.float32, (l, 1, 1)), + "D": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_scalar_broadcast, example_inputs) + input_keys = ["C", "alpha", "alpha_batch"] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_mixed_sm80_90.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_mixed_sm80_90.py new file mode 100644 index 0000000000000000000000000000000000000000..30dc8fe0d5ec413f1da57a8fa0875ed5e7baa887 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_mixed_sm80_90.py @@ -0,0 +1,319 @@ +################################################################################ +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ + +""" +Unittest for mixed types of nodes in SM90 +""" + +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend import * +from cutlass_cppgen.epilogue import * +from cutlass_cppgen.swizzle import ThreadblockSwizzleStreamK + +from utils.evt_testbed import EVTTestBed, EVTTestCaseBase + +cutlass_cppgen.set_log_level(logging.WARNING) + + +@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]") +class TestEVTMixed(EVTTestCaseBase): + + def test_same_variable_used_multiple_times(self): + """ + The same variable z0 is used multiple times + """ + def evt_aux_store(accum): + z0 = relu(accum) + D = z0 + z0 + return z0, D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (l, m, n)), + "z0": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_aux_store, example_inputs) + input_keys = ["accum"] + result_keys = ["z0", "D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_no_lca(self): + """ + The same variable z0 is used multiple times + """ + def evt_no_lca(accum, bias): + E = relu(accum) + F = E + bias + tmp_2 = E + 2 + D = tmp_2 + E + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (l, m, n)), + "bias": self.fake_tensor(self.element, (m,1), stride=(1,0)), + } + + launcher = EVTTestBed(self.element, evt_no_lca, example_inputs) + input_keys = ["accum", "bias"] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_mixed_dag(self): + def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias): + F = alpha * accum + (beta * C + aux) + F_row_max = max(F, dim=[0, 1]) + E = relu(F + 1) + cbias + rbias + E_col_max = max(E, dim=[0, 2]) + D = E + F + return D, F, F_row_max, E_col_max + + if device_cc() == 80: + alignments = [2, 4, 8] + else: + # Sm90 EVT currently only supports 128-bit alignment + alignments = [8,] + for align in alignments: + for m, n, k, l in self.get_problem_sizes(align): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 1.0, + "C": self.fake_tensor(self.element, (l, m, n)), + "beta": 1.0, + "aux": self.fake_tensor(self.element, (l, m, n)), + "cbias": self.fake_tensor(self.element, (m, 1)), + "rbias": self.fake_tensor(self.element, (n,)), + "D": self.fake_tensor(self.element, (l, m, n)), + "F": self.fake_tensor(self.element, (l, m, n)), + "F_row_max": self.fake_tensor(DataType.f32, (n,)), + "E_col_max": self.fake_tensor(DataType.f32, (m, 1)) + } + + launcher = EVTTestBed(self.element, evt_mixed_dag, example_inputs) + input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"] + result_keys = ["D", "F", "F_row_max", "E_col_max"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + @unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only") + def test_mixed_dag_float(self): + def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias): + F = alpha * accum + (beta * C + aux) + F_row_max = max(F, dim=[0, 1]) + E = relu(F + 1) + cbias + rbias + E_col_max = max(E, dim=[0, 2]) + D = E + F + return D, F, F_row_max, E_col_max + + for align in [3, 2, 4]: + for m, n, k, l in self.get_problem_sizes(align): + example_inputs = { + "accum": self.fake_tensor(np.float32, (l, m, n)), + "alpha": 1.0, + "C": self.fake_tensor(np.float32, (l, m, n)), + "beta": 1.0, + "aux": self.fake_tensor(np.float32, (l, m, n)), + "cbias": self.fake_tensor(np.float32, (m, 1)), + "rbias": self.fake_tensor(np.float32, (n,)), + "D": self.fake_tensor(np.float32, (l, m, n)), + "F": self.fake_tensor(np.float32, (l, m, n)), + "F_row_max": self.fake_tensor(np.float32, (n,)), + "E_col_max": self.fake_tensor(np.float32, (m, 1)) + } + launcher = EVTTestBed(DataType.f32, evt_mixed_dag, example_inputs) + input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"] + result_keys = ["D", "F", "F_row_max", "E_col_max"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + @unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only") + def test_mixed_dag_stage2(self): + def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias): + F = alpha * accum + (beta * C + aux) + F_row_max = max(F, dim=[0, 1]) + E = relu(F + 1) + cbias + rbias + E_col_max = max(E, dim=[0, 2]) + D = E + F + return D, F, F_row_max, E_col_max + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 1.0, + "C": self.fake_tensor(self.element, (l, m, n)), + "beta": 1.0, + "aux": self.fake_tensor(self.element, (l, m, n)), + "cbias": self.fake_tensor(self.element, (m, 1)), + "rbias": self.fake_tensor(self.element, (n,)), + "D": self.fake_tensor(self.element, (l, m, n)), + "F": self.fake_tensor(self.element, (l, m, n)), + "F_row_max": self.fake_tensor(DataType.f32, (n,)), + "E_col_max": self.fake_tensor(DataType.f32, (m, 1)) + } + + launcher = EVTTestBed(self.element, evt_mixed_dag, example_inputs, epilogue_stages=2) + input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"] + result_keys = ["D", "F", "F_row_max", "E_col_max"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + @unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only") + def test_mixed_dag_partition_k(self): + def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias): + F = alpha * accum + (beta * C + aux) + F_row_max = max(F, dim=[0, 1]) + E = relu(F + 1) + cbias + rbias + E_col_max = max(E, dim=[0, 2]) + D = E + F + return D, F, F_row_max, E_col_max + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 1.0, + "C": self.fake_tensor(self.element, (l, m, n)), + "beta": 1.0, + "aux": self.fake_tensor(self.element, (l, m, n)), + "cbias": self.fake_tensor(self.element, (m, 1)), + "rbias": self.fake_tensor(self.element, (n,)), + "D": self.fake_tensor(self.element, (l, m, n)), + "F": self.fake_tensor(self.element, (l, m, n)), + "F_row_max": self.fake_tensor(DataType.f32, (n,)), + "E_col_max": self.fake_tensor(DataType.f32, (m, 1)) + } + + tile_description = { + "threadblock_shape": [128, 128, 64], + "warp_count": [2, 2, 2] + } + + launcher = EVTTestBed(self.element, evt_mixed_dag, example_inputs, tile_description=tile_description, epilogue_stages=2) + input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"] + result_keys = ["D", "F", "F_row_max", "E_col_max"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + @unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only") + def test_mixed_dag_stream_k(self): + def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias): + F = alpha * accum + (beta * C + aux) + F_row_max = max(F, dim=[0, 1]) + E = relu(F + 1) + cbias + rbias + E_col_max = max(E, dim=[0, 2]) + D = E + F + return D, F, F_row_max, E_col_max + + # High per-sm occupancy tile_description + tile_description = { + "threadblock_shape": [128, 128, 32], + "warp_count": [2, 2, 1], + "stages": 3 + } + tds = [None, tile_description] + for td in tds: + for m, n, k, l in self.get_problem_sizes(8, k=960, batch_count=[1, 3]): + if l == 1: + example_inputs = { + "accum": self.fake_tensor(self.element, (m, n)), + "alpha": 1.0, + "C": self.fake_tensor(self.element, (m, n)), + "beta": 1.0, + "aux": self.fake_tensor(self.element, (m, n)), + "cbias": self.fake_tensor(self.element, (m, 1)), + "rbias": self.fake_tensor(self.element, (n,)), + "D": self.fake_tensor(self.element, (m, n)), + "F": self.fake_tensor(self.element, (m, n)), + "F_row_max": self.fake_tensor(DataType.f32, (n,)), + "E_col_max": self.fake_tensor(DataType.f32, (m, 1)) + } + else: + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 1.0, + "C": self.fake_tensor(self.element, (l, m, n)), + "beta": 1.0, + "aux": self.fake_tensor(self.element, (l, m, n)), + "cbias": self.fake_tensor(self.element, (m, 1)), + "rbias": self.fake_tensor(self.element, (n,)), + "D": self.fake_tensor(self.element, (l, m, n)), + "F": self.fake_tensor(self.element, (l, m, n)), + "F_row_max": self.fake_tensor(DataType.f32, (n,)), + "E_col_max": self.fake_tensor(DataType.f32, (m, 1)) + } + + if td is not None: + launcher = EVTTestBed( + self.element, evt_mixed_dag, example_inputs, + tile_description=td, + swizzling_functor=ThreadblockSwizzleStreamK, backend="torch") + else: + launcher = EVTTestBed( + self.element, evt_mixed_dag, example_inputs, + swizzling_functor=ThreadblockSwizzleStreamK, backend="torch") + + input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"] + result_keys = ["D", "F", "F_row_max", "E_col_max"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_mixed_dag_no_batch(self): + def evt_mixed_dag_no_batch(accum, alpha, C, beta, aux, cbias, rbias): + F = alpha * accum + (beta * C + aux) + F_row_max = max(F, dim=[0, 1]) + E = relu(F + 1) + cbias + rbias + E_col_max = max(E, dim=[0, 2]) + D = E + F + return D, F, F_row_max, E_col_max + + for m, n, k, _ in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (m, n)), + "alpha": 1.0, + "C": self.fake_tensor(self.element, (m, n)), + "beta": 1.0, + "aux": self.fake_tensor(self.element, (m, n)), + "cbias": self.fake_tensor(self.element, (m, 1)), + "rbias": self.fake_tensor(self.element, (n,)), + "D": self.fake_tensor(self.element, (m, n)), + "F": self.fake_tensor(self.element, (m, n)), + "F_row_max": self.fake_tensor(DataType.f32, (n,)), + "E_col_max": self.fake_tensor(DataType.f32, (m, 1)) + } + + launcher = EVTTestBed(self.element, evt_mixed_dag_no_batch, example_inputs) + input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"] + result_keys = ["D", "F", "F_row_max", "E_col_max"] + launcher.verify((m, n, k), input_keys, result_keys, 1) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_store_sm80_90.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_store_sm80_90.py new file mode 100644 index 0000000000000000000000000000000000000000..b47f11e4f3bde3499948ae68b1b5bb79347f0fd1 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_store_sm80_90.py @@ -0,0 +1,180 @@ +################################################################################ +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ + +""" +Unit test for store nodes in SM90 +""" + +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend import * +from cutlass_cppgen.epilogue import * + +from utils.evt_testbed import EVTTestBed, EVTTestCaseBase + +cutlass_cppgen.set_log_level(logging.WARNING) + + +@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]") +class TestEVTStore(EVTTestCaseBase): + + @unittest.skipIf(device_cc() != 90, "This test is only for CC 90") + def test_invalid_store(self): + """ + Test invalid store + """ + def evt_invalid_store(accum): + D = accum + F = D + 1 # D has users, which is not allowed on SM90 or higher + return D, F + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (l, m, n)), + "F": self.fake_tensor(self.element, (l, m, n)) + } + with self.assertRaisesRegex( + RuntimeError, + r"On SM90 or higher, D is expected to be a output node with 0 users " + r"to enable smem reuse between C and D, but got 1" + ): + launcher = EVTTestBed(self.element, evt_invalid_store, example_inputs) + + break # Only need to test once + + def test_aux_store(self): + """ + Returning a tensor with shape [m, n] + """ + def evt_aux_store(accum, alpha, C): + F = alpha * accum + D = F + C + return D, F + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 0.5, + "C": self.fake_tensor(self.element, (l, m, n)), + "F": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_aux_store, example_inputs) + input_keys = ["C", "alpha"] + result_keys = ["D", "F"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_col_reduce(self): + """ + Reduction [m, n] -> [m, 1] + """ + def evt_row_reduce(accum, alpha, C): + acc_row_max = max(accum, dim=[2,]) + F = alpha * accum + F_row_max = max(F, dim=[0, 2]) + D = F + C + return D, F_row_max, acc_row_max + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 2.0, + "C": self.fake_tensor(self.element, (l, m, n)), + "F_row_max": self.fake_tensor(np.float32, (m, 1)), + "acc_row_max": self.fake_tensor(np.float32, (l, m, 1)), + "D": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_row_reduce, example_inputs) + input_keys = ["C", "alpha"] + result_keys = ["D", "F_row_max", "acc_row_max"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_row_reduce(self): + """ + Reduction [m, n] -> [n] + """ + def evt_col_reduce(accum, alpha, C): + acc_col_max = max(accum, dim=[1,]) + F = alpha * accum + F_col_max = max(F, dim=[0, 1]) + D = F + C + return D, F_col_max, acc_col_max + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 2.0, + "C": self.fake_tensor(self.element, (l, m, n)), + "F_col_max": self.fake_tensor(np.float32, (n,)), + "acc_col_max": self.fake_tensor(np.float32, (l, 1, n)), + "D": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_col_reduce, example_inputs) + input_keys = ["C", "alpha"] + result_keys = ["D", "F_col_max", "acc_col_max"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_scalar_reduce(self): + """ + Reduction [m, n] -> [1,] + """ + def evt_scalar_reduce(accum, alpha, C): + acc_max = max(accum, dim=[1, 2]) + F = alpha * accum + F_max = max(F, dim=[0, 1, 2]) + D = F + C + return D, F_max, acc_max + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 2.0, + "C": self.fake_tensor(self.element, (l, m, n)), + "acc_max": self.fake_tensor(np.float32, (l, 1, 1)), + "F_max": self.fake_tensor(np.float32, (1,)), + "D": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_scalar_reduce, example_inputs) + input_keys = ["C", "alpha"] + result_keys = ["D", "F_max", "acc_max"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/run_all_tests.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/run_all_tests.py new file mode 100644 index 0000000000000000000000000000000000000000..5bb84e2e8c85e602b45b9ee18ce324accd3a32cd --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/run_all_tests.py @@ -0,0 +1,44 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import pathlib +import unittest + + +if __name__ == '__main__': + loader = unittest.TestLoader() + script_dir = str(pathlib.Path(__file__).parent.resolve()) + '/' + tests = loader.discover(script_dir, 'evt_*.py') + testRunner = unittest.runner.TextTestRunner() + results = testRunner.run(tests) + if not results.wasSuccessful(): + raise Exception('Test cases failed') diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/utils/evt_testbed.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/utils/evt_testbed.py new file mode 100644 index 0000000000000000000000000000000000000000..62d375d856ffaef6be50b39b76121e0eb78a7465 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/utils/evt_testbed.py @@ -0,0 +1,235 @@ +################################################################################ +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ + +""" +Testbed classes of EVT +""" + +import torch +import unittest + +import cutlass_cppgen +from cutlass_cppgen import Tensor +import cutlass_cppgen.backend.evt +from cutlass_cppgen.shape import GemmCoord +from cutlass_cppgen.utils.datatypes import torch_type +from cutlass_cppgen.utils.profiler import CUDAEventProfiler + + +class EVTReferenceModule: + def __init__(self, layout_A, layout_B, layout_C, epilogue_visitor): + self.layout_A = layout_A + self.layout_B = layout_B + self.layout_C = layout_C + self.epilogue_visitor = epilogue_visitor + + def run(self, A, B, C, problem_size, alpha, beta, batch=1): + if self.layout_A == cutlass_cppgen.LayoutType.RowMajor: + A_row = A.view((batch, problem_size.m, problem_size.k)) + else: + A_col = A.view((batch, problem_size.k, problem_size.m)) + A_row = torch.permute(A_col, (0, 2, 1)) + + if self.layout_B == cutlass_cppgen.LayoutType.RowMajor: + B_row = B.view((batch, problem_size.k, problem_size.n)) + else: + B_col = B.view((batch, problem_size.n, problem_size.k)) + B_row = torch.permute(B_col, (0, 2, 1)) + + if self.layout_C == cutlass_cppgen.LayoutType.RowMajor: + C_row = C.view((batch, problem_size.m, problem_size.n)) + else: + C_col = C.view((batch, problem_size.n, problem_size.m)) + C_row = torch.permute(C_col, (0, 2, 1)) + + out_row = torch.matmul(A_row, B_row) * alpha + C_row * beta + + if self.layout_C == cutlass_cppgen.LayoutType.ColumnMajor: + out = torch.permute(out_row, (0, 2, 1)) + else: + out = out_row + + return torch.flatten(out) + + def __call__(self, A, B, C, problem_size, batch=1, epilogue_args=None): + # Running the mainloop + accum = self.run( + A, B, C, problem_size, 1.0, 0.0, batch=batch + ).reshape(batch, problem_size.m, problem_size.n) + + # Running the epilogue + epilogue_args["accum"] = accum + references = self.epilogue_visitor(**epilogue_args) + + # Return the results + if not isinstance(references, tuple): + references = (references,) + return references + + +class EVTTestBed: + """ + Epilogue Visitor Testbed + """ + def __init__(self, element, evt_fn, example_inputs, profile=False, **kwargs) -> None: + self.element = element + layout = cutlass_cppgen.LayoutType.RowMajor + self.example_inputs = example_inputs + + # Create the Gemm plan + self.plan = cutlass_cppgen.op.Gemm(element=element, layout=layout, element_accumulator=torch.float32) + + if "tile_description" in kwargs: + self.plan.tile_description = kwargs["tile_description"] + + if "swizzling_functor" in kwargs: + self.plan.swizzling_functor = kwargs["swizzling_functor"] + + # Compile the epilogue visitor + epilogue_visitor = cutlass_cppgen.epilogue.trace(evt_fn, example_inputs) + if "epilogue_stages" in kwargs: + epilogue_visitor.epilogue_stages = kwargs["epilogue_stages"] + self.plan.epilogue_visitor = epilogue_visitor + + # Reference model + self.reference_fn = EVTReferenceModule(layout, layout, layout, epilogue_visitor) + + self.profile = profile + + def get_torch_tensor(self, shape, dtype=None, fill=None): + if dtype is None: + dtype = self.element + + dtype = torch_type(dtype) + if fill is None: + return torch.ceil( + torch.empty(size=shape, dtype=dtype, device="cuda").uniform_(-4.5, 3.5) + ) + else: + return torch.full(shape, fill, dtype=dtype, device="cuda") + + def verify(self, problem_size, input_keys, result_keys, batch_count=1): + """ + Verify the results + """ + problem_size = GemmCoord(*problem_size) + + # Initiate the GEMM arguments + tensor_A = self.get_torch_tensor((batch_count, problem_size.m, problem_size.k)) + tensor_B = self.get_torch_tensor((batch_count, problem_size.k, problem_size.n)) + + # Initialize the epilogue args + epilogue_args = {} + for key in self.example_inputs.keys(): + if key in input_keys: + tensor = self.example_inputs[key] + if isinstance(tensor, Tensor): + epilogue_args[key] = self.get_torch_tensor(tensor.shape, tensor.element) + else: + epilogue_args[key] = tensor + elif key in result_keys: + tensor = self.example_inputs[key] + if isinstance(tensor, Tensor): + if "max" in key: + fill = -1000 + else: + fill = 0 + epilogue_args[key] = self.get_torch_tensor(tensor.shape, tensor.element, fill=fill) + else: + epilogue_args[key] = tensor + + tensor_D = epilogue_args["D"] + if "C" in epilogue_args: + tensor_C = epilogue_args["C"] + else: + tensor_C = tensor_D + # Run the device kernel + self.plan.run(tensor_A, tensor_B, tensor_C, tensor_D, visitor_args=epilogue_args) + + # Run the host reference + evt_args_inputs = {} + for key in input_keys: + evt_args_inputs[key] = epilogue_args[key] + + reference_results = self.reference_fn( + tensor_A, tensor_B, tensor_C, problem_size, batch_count, evt_args_inputs) + + # Compare the results + for result, ref in zip(result_keys, reference_results): + assert torch.equal( + epilogue_args[result].flatten(), + ref.masked_fill(torch.isnan(ref), float('inf')).flatten()) + + # Run profile + if self.profile: + profiler = CUDAEventProfiler( + self.plan, 100, 100, tensor_A, tensor_B, tensor_C, tensor_D, + visitor_args = epilogue_args + ) + print(f"Cutlass Python Duration: {profiler()}") + + +class EVTTestCaseBase(unittest.TestCase): + """ + Base class for EVT Unittest + """ + def __init__(self, methodName: str = "runTest", lmnk=(6, 512, 256, 128)) -> None: + super().__init__(methodName) + + self.element = cutlass_cppgen.DataType.f16 + self.l, self.m, self.n, self.k = lmnk + + self.problem_size = (self.m, self.n, self.k) + + torch.random.manual_seed(42) + + def fake_tensor(self, element, shape, stride=None): + if stride is None: + return Tensor(element=element, shape=shape, layout_tag=cutlass_cppgen.LayoutType.RowMajor) + else: + return Tensor(element=element, shape=shape, stride=stride) + + def get_problem_sizes(self, alignment, k=None, batch_count=[3,]): + k = k if k else self.k + problem_size_m = [alignment, 512 - 3 * alignment] + problem_size_n = [alignment, 512 - alignment] + if alignment % 8 == 0: + problem_size_m.append(768) + problem_size_n.append(768) + problem_size_l = batch_count + problem_sizes = [] + for m in problem_size_m: + for n in problem_size_n: + for l in problem_size_l: + problem_sizes.append((m, n, k, l)) + + return problem_sizes diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_batched.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_batched.py new file mode 100644 index 0000000000000000000000000000000000000000..155426ab902d1f99eafc7b03c388fc79b4520317 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_batched.py @@ -0,0 +1,134 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +High-level tests for running batched GEMMs +""" + +from functools import partial +import logging +from math import prod +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc +import torch + +from utils import LayoutCombination + +cutlass_cppgen.set_log_level(logging.WARNING) + +torch.manual_seed(2023) + + +def pytorch_reference(A, B, C, alpha, beta): + # Get the batch count. Assume that any of A, B, and C + # with a batch dimension ahve matching batch count. Thus, + # we break out of the loop once we have found the first + # tensor containing a batch dimension. + batch_count = (1,) + for tensor in [A, B, C]: + if len(tensor.shape) > 2: + batch_count = tensor.shape[:-2] + break + + int_batch_count = prod(batch_count) + + def add_batch(tensor): + if len(tensor.shape) == 2: + return tensor.unsqueeze(0).repeat(int_batch_count, 1, 1) + else: + return tensor.reshape(-1, tensor.size(-2), tensor.size(-1)) + + # Reshape tensors to have batch dimension + A = add_batch(A) + B = add_batch(B) + C = add_batch(C) + + ret = (torch.bmm(A, B) * alpha) + (C * beta) + reshape_vals = batch_count + C.shape[-2:] + return ret.reshape(*reshape_vals) + + +def initialize(rows, cols, batch): + tensor = torch.randint(-3, 3, size=(rows*cols*prod(batch),), device='cuda').half() + if len(batch) > 0 and prod(batch) > 1: + reshape_vals = batch + (rows, cols) + return tensor.reshape(*reshape_vals) + else: + return tensor.reshape(rows, cols) + + +class GemmF16Batched(unittest.TestCase): + def run_batched(self, batch_count: tuple, batch_A: bool, batch_B: bool, batch_C: bool): + M = 512 + N = 256 + K = 128 + alpha = 1. + beta = 2. + + A = initialize(M, K, batch_count if batch_A else (1,)) + B = initialize(K, N, batch_count if batch_B else (1,)) + C = initialize(M, N, batch_count if batch_C else (1,)) + D = initialize(M, N, batch_count) + + plan = cutlass_cppgen.op.Gemm(A=A, B=B, C=C, D=D, element_accumulator=cutlass_cppgen.DataType.f32) + plan.run(A, B, C, D, alpha, beta) + reference = pytorch_reference(A, B, C, alpha, beta) + assert reference.equal(D) + + def test_batched_ABC(self): + self.run_batched((3,), True, True, True) + self.run_batched((2, 3), True, True, True) + + def test_batched_AB(self): + self.run_batched((3,), True, True, False) + self.run_batched((2, 3), True, True, False) + + def test_batched_AC(self): + self.run_batched((3,), True, False, True) + self.run_batched((2, 3), True, False, True) + + def test_batched_BC(self): + self.run_batched((3,), False, True, True) + self.run_batched((2, 3), False, True, True) + + def test_batched_A(self): + self.run_batched((3,), True, False, False) + self.run_batched((2, 3), True, False, False) + + def test_batched_B(self): + self.run_batched((3,), False, True, False) + self.run_batched((2, 3), False, True, False) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f16_sm80.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f16_sm80.py new file mode 100644 index 0000000000000000000000000000000000000000..dbd26951ec5d8a1eb6cbe38491c64fde2873b9c3 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f16_sm80.py @@ -0,0 +1,128 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for GEMM with F16 operands on SM80 +""" + +from functools import partial +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc + +from utils import LayoutCombination, add_test_gemm + + +cutlass_cppgen.set_log_level(logging.WARNING) +cc = 80 +dtype = cutlass_cppgen.DataType.f16 + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmF16Sm80(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmF16Sm80StreamK(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + +add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1]) + +# Tests using TensorOp +add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) + +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NTN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NTT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TTN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TTT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 128, 32], warp_count=[1, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 64, 32], warp_count=[2, 1, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=5) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[2, 2, 2], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) + +# Tests using SIMT +add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt) + +add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) +add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2) +add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2) +add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2) +add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) + +# Stream K tests +add_test_streamk = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp, swizzle=cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK) +add_test_streamk(cls=GemmF16Sm80StreamK, layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_streamk(cls=GemmF16Sm80StreamK, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=5) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f16_sm90.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f16_sm90.py new file mode 100644 index 0000000000000000000000000000000000000000..61aa295b966daf5943e7092572c98ee20143e2b5 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f16_sm90.py @@ -0,0 +1,146 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for GEMM with F16 operands on SM90 +""" + +from functools import partial +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc + +from utils import LayoutCombination, add_test_gemm + + +cutlass_cppgen.set_log_level(logging.WARNING) +cc = 90 +dtype = cutlass_cppgen.DataType.f16 + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmF16Sm90(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_specialized = partial(add_test_gemm, cls=GemmF16Sm90, element=dtype, + warp_count=None, compilation_modes=['nvcc']) + +add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) + +# Tests with 1x1x1 clusters +add_test_unit_cluster = partial(add_test_tensorop, cluster_shape=[1, 1, 1]) +add_test_unit_cluster(layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=3) +add_test_unit_cluster(layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None) +add_test_unit_cluster(layouts=LayoutCombination.NTN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None) +add_test_unit_cluster(layouts=LayoutCombination.NTT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None) +add_test_unit_cluster(layouts=LayoutCombination.TNN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None) +add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None) +add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], stages=None) +add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], stages=None) +add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 64], stages=5) +add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[2, 2, 2], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], stages=None) + +# Tests with different cluster shapes +add_test_cluster_shape = partial(add_test_tensorop, threadblock_shape=[64, 128, 64], stages=None) +add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f16, cluster_shape=[2, 2, 1]) +add_test_cluster_shape(layouts=LayoutCombination.TNN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 2, 1]) +add_test_cluster_shape(layouts=LayoutCombination.NTN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 2, 1]) +add_test_cluster_shape(layouts=LayoutCombination.NNN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 2, 1]) +add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 4, 1]) +add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 4, 1]) +add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[4, 1, 1]) +add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[4, 2, 1]) + +# Tests for different schedule modes +add_test_schedule = partial(add_test_specialized, layouts=LayoutCombination.TTN, alignments=[8, 8, 4], + element_output=cutlass_cppgen.DataType.f32, element_accumulator=cutlass_cppgen.DataType.f32, + opclass=cutlass_cppgen.OpcodeClass.TensorOp, threadblock_shape=[128, 128, 64], stages=None) +add_test_schedule( + cluster_shape=[1, 1, 1], + kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong, + epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized +) +add_test_schedule( + cluster_shape=[1, 1, 1], + kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative, + epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative +) +add_test_schedule( + cluster_shape=[2, 1, 1], + kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong, + epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized +) +add_test_schedule( + cluster_shape=[2, 1, 1], + kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative, + epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative +) + +# Tests using SIMT +add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt, alignments=[1, 1, 1], cluster_shape=[1, 1, 1], stages=2) +add_test_simt(layouts=LayoutCombination.NNN, element_output=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 8]) +add_test_simt(layouts=LayoutCombination.TNN, element_output=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 128, 8]) +add_test_simt(layouts=LayoutCombination.NTN, element_output=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 64, 8]) +add_test_simt(layouts=LayoutCombination.TTN, element_output=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 8]) +add_test_simt(layouts=LayoutCombination.NNT, element_output=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 8]) + +# Tests with void-C kernels +add_test_cluster_shape(layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None, + cluster_shape=[2, 1, 1], element_C=cutlass_cppgen.DataType.void) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f32_sm80.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f32_sm80.py new file mode 100644 index 0000000000000000000000000000000000000000..bf662b9208ab2a5343d0fd11106835b7d9a5b2e9 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f32_sm80.py @@ -0,0 +1,104 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for GEMM with F32 operands on SM80 +""" + +from functools import partial +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc + +from utils import LayoutCombination, add_test_gemm + + +cutlass_cppgen.set_log_level(logging.WARNING) +cc = 80 +dtype = cutlass_cppgen.DataType.f32 + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmF32Sm80(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmF32Sm80StreamK(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1]) + +# Tests using TensorOp +add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) + +add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[4, 4, 4], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 64, 128, 32], warp_count=[1, 2, 1], stages=3) +add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 64, 64, 32], warp_count=[1, 1, 1], stages=4) +# Tests using SIMT +add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt) + +add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) +add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2) +add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2) +add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2) +add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) + +# Stream K tests +add_test_streamk = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp, swizzle=cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK) +add_test_streamk(cls=GemmF32Sm80StreamK, layouts=LayoutCombination.TTN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f64_sm80.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f64_sm80.py new file mode 100644 index 0000000000000000000000000000000000000000..3075ddf74bf2a119759ca1a3e47c0815f4b0923c --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f64_sm80.py @@ -0,0 +1,103 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for GEMM with F64 operands on SM80 +""" + +from functools import partial +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc + +from utils import LayoutCombination, add_test_gemm + + +cutlass_cppgen.set_log_level(logging.WARNING) +cc = 80 +dtype = cutlass_cppgen.DataType.f64 + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmF64Sm80(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmF64Sm80StreamK(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1]) + +# Tests using TensorOp +add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) + +add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 16], warp_count=[4, 2, 1], stages=3) +add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 64, 64, 16], warp_count=[2, 2, 1], stages=4) +add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 32, 32, 16], warp_count=[2, 1, 1], stages=5) + +# Tests using SIMT +add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt) + +add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) +add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2) +add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2) +add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2) +add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) + +# Stream K tests +add_test_streamk = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp, swizzle=cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK) +add_test_streamk(cls=GemmF64Sm80StreamK, layouts=LayoutCombination.NTT, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 16], warp_count=[4, 2, 1], stages=3) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f64_sm90.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f64_sm90.py new file mode 100644 index 0000000000000000000000000000000000000000..9bf36fc77436fef22882e98c752b7a599cf7fb95 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f64_sm90.py @@ -0,0 +1,71 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for GEMM with F64 operands on SM90 +""" + +from functools import partial +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc + +from utils import LayoutCombination, add_test_gemm + + +cutlass_cppgen.set_log_level(logging.WARNING) +cc = 90 +dtype = cutlass_cppgen.DataType.f64 + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmF64Sm90(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_specialized = partial(add_test_gemm, cls=GemmF64Sm90, alignments=[1, 1, 1], cluster_shape=[1, 1, 1], + element=dtype, element_output=dtype, element_accumulator=dtype, compilation_modes=['nvcc']) + +add_test_specialized(opclass=cutlass_cppgen.OpcodeClass.TensorOp, layouts=LayoutCombination.NNT, threadblock_shape=[128, 128, 32], stages=3) +add_test_specialized(opclass=cutlass_cppgen.OpcodeClass.TensorOp, layouts=LayoutCombination.TNN, threadblock_shape=[128, 128, 32], stages=3) +add_test_specialized( opclass=cutlass_cppgen.OpcodeClass.Simt, layouts=LayoutCombination.NNN, threadblock_shape=[128, 128, 8], stages=2) +add_test_specialized( opclass=cutlass_cppgen.OpcodeClass.Simt, layouts=LayoutCombination.TTT, threadblock_shape=[ 64, 128, 8], stages=2) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f8_sm90.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f8_sm90.py new file mode 100644 index 0000000000000000000000000000000000000000..fef6d457a6528a61613d1295877a2b6b8f80fef5 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f8_sm90.py @@ -0,0 +1,112 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for GEMM with S8 operands on SM90 +""" + +from functools import partial +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc + +from utils import LayoutCombination, add_test_gemm + + +cutlass_cppgen.set_log_level(logging.WARNING) +cc = 90 +dtype = cutlass_cppgen.DataType.e4m3 + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmF8E4M3Sm90(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_specialized = partial(add_test_gemm, cls=GemmF8E4M3Sm90, element=dtype, compilation_modes=['nvcc']) + +add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) + +# Test with 1x1x1 clusters +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.e4m3, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None) + +# Tests with different cluster shapes +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.e4m3, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 2, 1], threadblock_shape=[128, 128, 128], stages=None) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.e4m3, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 4, 1], threadblock_shape=[128, 128, 128], stages=None) + +# Tests with warp-specialized ping-pong schedule +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.e4m3, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 1, 1], threadblock_shape=[128, 128, 128], stages=None, + kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong, + epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized) + +# Tests for SIMT +add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt) +add_test_simt(layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.e4m3, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 1, 1], threadblock_shape=[64, 32, 8], stages=2) + + +# +# Add a test for E5M2 +# +dtype = cutlass_cppgen.DataType.e5m2 + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmF8E5M2Sm90(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_specialized = partial(add_test_gemm, cls=GemmF8E5M2Sm90, element=dtype, compilation_modes=['nvcc']) + +add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) + +# Tests with 1x1x1 clusters +add_test_tensorop(layouts=LayoutCombination.TNN, alignments=[16, 16, 16], element_output=dtype, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=3) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_mixed_sm80.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_mixed_sm80.py new file mode 100644 index 0000000000000000000000000000000000000000..0a002a5fbad80de5f7b29e42db0806469244914c --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_mixed_sm80.py @@ -0,0 +1,75 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for GEMM with mixed operands on SM80 +""" + +from functools import partial +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc + +from utils import LayoutCombination, add_test_gemm + + +cutlass_cppgen.set_log_level(logging.WARNING) +cc = 80 +dtype =cutlass_cppgen.DataType.f16 + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmMixedSm80(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_mixed = partial(add_test_gemm, cls=GemmMixedSm80, element=dtype, cc=cc, cluster_shape=[1, 1, 1], + opclass=cutlass_cppgen.OpcodeClass.TensorOp, threadblock_shape=[128, 128, 64], + warp_count=[2, 2, 1], stages=3, element_accumulator=cutlass_cppgen.DataType.f32) + +# Test with upcast on A +add_test_mixed(element_A=cutlass_cppgen.DataType.s8, alignments=[16, 8, 8], layouts=LayoutCombination.TNT) +add_test_mixed(element_A=cutlass_cppgen.DataType.s8, alignments=[16, 8, 8], layouts=LayoutCombination.TNN) + +# Test with upcast on B +add_test_mixed(element_B=cutlass_cppgen.DataType.s8, alignments=[8, 16, 8], layouts=LayoutCombination.TNT) +add_test_mixed(element_B=cutlass_cppgen.DataType.s8, alignments=[8, 16, 8], layouts=LayoutCombination.TNN) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_s8_sm80.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_s8_sm80.py new file mode 100644 index 0000000000000000000000000000000000000000..e226e23684147cb0a9cd5c1270468eb96c67ba15 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_s8_sm80.py @@ -0,0 +1,103 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for GEMM with S8 operands on SM80 +""" + +from functools import partial +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc + +from utils import LayoutCombination, add_test_gemm + + +cutlass_cppgen.set_log_level(logging.WARNING) +cc = 80 +dtype = cutlass_cppgen.DataType.s8 + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmS8Sm80(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmS8Sm80StreamK(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1]) + +# Tests using TensorOp +add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) + +add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[256, 128, 64], warp_count=[4, 2, 1], stages=3) +add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 256, 64], warp_count=[2, 4, 1], stages=3) +add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[16, 16, 4], element_output=cutlass_cppgen.DataType.s32, element_C=cutlass_cppgen.DataType.s32, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=4) + +# Tests using SIMT +add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt) + +add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) +add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2) +add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2) +add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s32, element_C=cutlass_cppgen.DataType.s32, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2) +add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s32, element_C=cutlass_cppgen.DataType.s32, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) + +# Stream K tests +add_test_streamk = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp, swizzle=cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK) +add_test_streamk(cls=GemmS8Sm80StreamK, layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 256, 64], warp_count=[2, 4, 1], stages=3) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_s8_sm90.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_s8_sm90.py new file mode 100644 index 0000000000000000000000000000000000000000..ec0101f78da3b62b599a5deeb89f5596a7e515ce --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_s8_sm90.py @@ -0,0 +1,98 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for GEMM with S8 operands on SM90 +""" + +from functools import partial +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc + +from utils import LayoutCombination, add_test_gemm + + +cutlass_cppgen.set_log_level(logging.WARNING) +cc = 90 +dtype = cutlass_cppgen.DataType.s8 + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmS8Sm90(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_specialized = partial(add_test_gemm, cls=GemmS8Sm90, element=dtype, compilation_modes=['nvcc']) + +add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) + +# Tests with 1x1x1 clusters +add_test_tensorop(layouts=LayoutCombination.TNN, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=3) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 8], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[64, 128, 128], stages=None) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 64, 32], stages=None) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[ 4, 4, 16], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None) + +# Tests with different cluster shapes +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[2, 2, 1], threadblock_shape=[128, 128, 128], stages=None) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 4, 1], threadblock_shape=[128, 128, 128], stages=None) + +# Tests with warp-specialized ping-pong schedule +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[2, 1, 1], threadblock_shape=[128, 128, 128], stages=None, + kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong, + epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized) + +# Tests for SIMT +add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt) +add_test_simt(layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[64, 32, 8], stages=2) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_testbed.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_testbed.py new file mode 100644 index 0000000000000000000000000000000000000000..6ffda5b47e37f184c2352f0ee4e737635dbd4147 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_testbed.py @@ -0,0 +1,423 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from math import prod +import os +import re +import subprocess + +import torch + +from cutlass_library import ( + DataType, + DataTypeSize, + GemmUniversalMode, + LayoutType, + OpcodeClass, + ShortDataTypeNames, + SwizzlingFunctor +) + +from cutlass_cppgen.backend import compiler +from cutlass_cppgen.backend.gemm_operation import GemmArguments, GemmOperationUniversal +from cutlass_cppgen.backend.reduction_operation import ReductionArguments, ReductionOperation +from cutlass_cppgen.shape import GemmCoord, MatrixCoord +from cutlass_cppgen.utils.datatypes import torch_type + + +class GemmUniversalLauncher: + def __init__( + self, + operation, + seed=2080, + verification=True, + iterations=500, + compiler_mode= "nvcc", + **kwargs, + ) -> None: + self.math_operation = operation.tile_description.math_instruction.math_operation + self.verification = verification + + if compiler_mode == "nvcc": + compiler.nvcc() + elif compiler_mode == "nvrtc": + compiler.nvrtc() + else: + raise Exception(f"Unexpected compiler string {compiler_mode}") + + op_list = [operation] + if operation.arch < 90: + # Split K via Python is currently only supported for pre-SM90 kernels + self.reduction_operation: ReductionOperation = ReductionOperation( + shape=MatrixCoord(4, 32 * operation.C.alignment), + C=operation.C, + element_accumulator=operation.tile_description.math_instruction.element_accumulator, + element_compute=operation.epilogue_functor.element_epilogue, + epilogue_functor=operation.epilogue_functor, + count=operation.C.alignment, + ) + op_list.append(self.reduction_operation) + + compiler.add_module(op_list, bypass_cache=False) + + self.operation = operation + + self.dtype_A = torch_type(operation.A.element if not self.operation.switched else self.operation.B.element) + self.dtype_B = torch_type(operation.B.element if not self.operation.switched else self.operation.A.element) + self.dtype_C = torch_type(operation.C.element) + self.dtype_D = torch_type(operation.epilogue_functor.element_output) + + element_size = min(DataTypeSize[operation.A.element], DataTypeSize[operation.B.element]) + + if element_size == 1: + self.rand_max = 1 + self.rand_min = 0 + elif element_size <= 8: + self.rand_max = 1 + self.rand_min = -1 + elif element_size == 16: + self.rand_max = 4 + self.rand_min = -4 + else: + self.rand_max = 8 + self.rand_min = -8 + + self.seed = seed + + self.compute_type = operation.epilogue_functor.element_epilogue + self.accumulator_type = operation.tile_description.math_instruction.element_accumulator + + def print_problem_size(self, p, mode, batch_count): + if mode == GemmUniversalMode.Gemm: + mode = "Gemm" + elif mode == GemmUniversalMode.Batched: + mode = "GemmBatched" + elif mode == GemmUniversalMode.GemmSplitKParallel: + mode = "GemmSplitKParallel" + print(f"problem: {p.m}, {p.n}, {p.k}\n batch_count: {batch_count}\n mode: {mode}") + + def uniform_init(self, shape, dtype, layout): + size = prod(shape) + if dtype.is_floating_point: + # Initialize data in FP32 and call convert to the data type we desire. + # This is a workaround for the following error that occurs when attempting to + # call uniform_ on a tensor with torch.float8_e4m3fn data: + # RuntimeError: "check_uniform_bounds" not implemented for 'Float8_e4m3fn' + data = torch.ceil( + torch.empty(size=(size,), dtype=torch.float32, device="cuda").uniform_( + self.rand_min - 0.5, self.rand_max - 0.5) + ).to(dtype) + else: + # PyTorch does not currently support integer-typed matrix multiplications on GPU. + # Fall back to CPU for integer type references. + data = torch.empty(size=(size,), dtype=dtype, device="cpu").random_(self.rand_min, self.rand_max + 1) + + is_fp8 = dtype == getattr(torch, "float8_e4m3fn", -1) or dtype == dtype == getattr(torch, "float8_e5m2", -1) + + if dtype == torch.float64 or dtype == torch.float32 or is_fp8: + data = data.to("cpu") + + data_ref = data.reshape(shape) + + if layout == LayoutType.RowMajor: + data_cutlass = data_ref + else: + data_cutlass = data_ref.transpose(-1, -2).contiguous() + + data_cutlass = data_cutlass.to("cuda") + + # As of this writing, few operations in PyTorch are supported with FP8 data. + # Thus, we perform computation in FP32 for FP8 reference checks. + if is_fp8: + data_ref = data_ref.to(torch.float32) + + return data_cutlass, data_ref + + def reference(self, problem_size, tensor_A, tensor_B, tensor_C, alpha, beta): + # If any tensor is on CPU, place all tensors on CPU unless only + # tensor C is on CPU + # Handle mixed-input cases by casting to the larger data type and overriding + # to whatever the data type of the larger type is + if self.dtype_A != self.dtype_B: + if DataTypeSize[self.operation.A.element] < DataTypeSize[self.operation.B.element]: + tensor_A = tensor_A.to(self.dtype_B).to(tensor_B.device) + else: + tensor_B = tensor_B.to(self.dtype_A).to(tensor_A.device) + + devices = [x.device.type for x in [tensor_A, tensor_B]] + if tensor_C is not None: + devices.append(tensor_C.device.type) + + if "cpu" in devices and devices != ["cuda", "cuda", "cpu"]: + device = torch.device("cpu") + else: + device = tensor_A.device + + tensor_A = tensor_A.to(device) + tensor_B = tensor_B.to(device) + if tensor_C is not None: + tensor_C = tensor_C.to(device) + + dtype = torch_type(self.compute_type) + alpha_torch = torch.tensor([alpha], device=device).to(dtype) + beta_torch = torch.tensor([beta], device=device).to(dtype) + + tmp = tensor_A @ tensor_B + tensor_D_ref = (alpha_torch * tmp) + if tensor_C is not None: + tensor_D_ref += (tensor_C * beta_torch) + return tensor_D_ref.to(self.dtype_D) + + def run(self, mode, problem_size, batch_count=1, split_k_slices=1, alpha=1.0, beta=0.0): + torch.random.manual_seed(self.seed) + + # Assign an actual batch count in cases where we are not running in batched mode. + # This is to differentiate between the number of split K slices and the batch count, + # which are overloaded within the single `batch_count` variable. + if mode == GemmUniversalMode.Batched: + true_batch_count = batch_count + else: + true_batch_count = 1 + + def transpose(layout): + if layout == LayoutType.RowMajor: + return LayoutType.ColumnMajor + else: + return LayoutType.RowMajor + + tensor_A, tensor_A_ref = self.uniform_init( + (true_batch_count, problem_size.m, problem_size.k), + self.dtype_A, + self.operation.A.layout if not self.operation.switched else transpose(self.operation.B.layout), + ) + tensor_B, tensor_B_ref = self.uniform_init( + (true_batch_count, problem_size.k, problem_size.n), + self.dtype_B, + self.operation.B.layout if not self.operation.switched else transpose(self.operation.A.layout), + ) + if self.dtype_C is not None: + tensor_C, tensor_C_ref = self.uniform_init( + (true_batch_count, problem_size.m, problem_size.n), + self.dtype_C, + self.operation.C.layout if not self.operation.switched else transpose(self.operation.C.layout), + ) + else: + tensor_C = None + tensor_C_ref = None + + tensor_D, _ = self.uniform_init( + (true_batch_count, problem_size.m, problem_size.n), + self.dtype_D, + self.operation.C.layout if not self.operation.switched else transpose(self.operation.C.layout), + ) + tensor_D = torch.zeros_like(tensor_D) + + if self.compute_type in [DataType.s8, DataType.s32, DataType.u8, DataType.u32]: + alpha = int(alpha) + beta = int(beta) + + # + # Launch kernel + # + + arguments = GemmArguments( + operation=self.operation, + problem_size=problem_size, + A=tensor_A, + B=tensor_B, + C=tensor_C, + D=tensor_D, + output_op=self.operation.epilogue_type(alpha, beta), + gemm_mode=mode, + split_k_slices=split_k_slices, + batch=batch_count, + ) + + if mode == GemmUniversalMode.GemmSplitKParallel: + reduction_arguments = ReductionArguments( + self.reduction_operation, + problem_size=[problem_size.m, problem_size.n], + partitions=split_k_slices, + workspace=arguments.ptr_D, + destination=tensor_D, + source=tensor_C, + output_op=self.reduction_operation.epilogue_type(alpha, beta), + ) + + self.operation.run(arguments) + + if mode == GemmUniversalMode.GemmSplitKParallel: + self.reduction_operation.run(reduction_arguments) + + passed = True + + if self.verification: + if mode == GemmUniversalMode.GemmSplitKParallel: + reduction_arguments.sync() + + # Free memory allocated by args because we are not + # calling `arguments.sync()` in this case (which will free memory) + arguments.free() + else: + arguments.sync() + tensor_D_ref = self.reference( + problem_size, + tensor_A_ref, + tensor_B_ref, + tensor_C_ref, + alpha, + beta, + ) + + tensor_D_ref = tensor_D_ref.to('cuda') + + if self.operation.switched or self.operation.C.layout == LayoutType.ColumnMajor: + tensor_D = tensor_D.transpose(-1, -2).contiguous() + + passed = tensor_D.equal(tensor_D_ref) + + try: + assert passed + except AssertionError: + self.print_problem_size(problem_size, mode, batch_count) + del arguments + if mode == GemmUniversalMode.GemmSplitKParallel: + del reduction_arguments + + return passed + + +def test_all_gemm(operation: "GemmOperationUniversal", testcase="universal", compilation_mode="nvcc"): + passed = True + + minimum_operand_element_size = min( + DataTypeSize[operation.A.element], DataTypeSize[operation.B.element] + ) + opcode_class = operation.tile_description.math_instruction.opcode_class + + if opcode_class == OpcodeClass.Simt: + alignment = 1 + else: + alignment = 128 // minimum_operand_element_size + + alignment_m = alignment + alignment_n = alignment + alignment_k = alignment + + # INT8 alignment constraints + if opcode_class == OpcodeClass.Simt: + A_is_s8 = operation.A.element == DataType.s8 + B_is_s8 = operation.B.element == DataType.s8 + + if A_is_s8 and operation.A.layout == LayoutType.ColumnMajor: + alignment_m = 4 + if B_is_s8 == DataType.s8 and operation.A.layout == LayoutType.RowMajor: + alignment_n = 4 + if A_is_s8 and B_is_s8 and (operation.A.layout == LayoutType.RowMajor or operation.B.layout == LayoutType.ColumnMajor): + alignment_k = 4 + + threadblock_k = operation.tile_description.threadblock_shape[2] + + assert testcase != "interleaved" + + supports_split_k = operation.arch < 90 and not operation.swizzling_functor == SwizzlingFunctor.StreamK + + if testcase == "multistage": + modes = [GemmUniversalMode.Gemm] + problem_size_m = [16, 528] + problem_size_n = [16, 528] + problem_size_k = [ + threadblock_k, + threadblock_k * operation.tile_description.stages + + operation.tile_description.math_instruction.instruction_shape[2], + ] + problem_alpha = [1.0] + problem_beta = [0.0] + batch_counts = [1] + else: + modes = [GemmUniversalMode.Gemm] + batch_counts = [1, 2, 3, 5, 7] + if supports_split_k: + modes.append(GemmUniversalMode.GemmSplitKParallel) + + problem_size_m = [alignment_m, 512 - 3 * alignment_m] + problem_size_n = [alignment_n, 512 - 2 * alignment_n] + if operation.tile_description.stages is None: + stages_for_k_calc = 7 + else: + stages_for_k_calc = operation.tile_description.stages + problem_size_k = [ + alignment_k, + threadblock_k * stages_for_k_calc - alignment_k, + threadblock_k * stages_for_k_calc * 3 - alignment_k, + ] + problem_alpha = [1.0] + problem_beta = [2.0] + + testbed = GemmUniversalLauncher(operation, compiler_mode=compilation_mode) + + for mode in modes: + for m in problem_size_m: + for n in problem_size_n: + for k in problem_size_k: + for batch_count in batch_counts: + for alpha in problem_alpha: + for beta in problem_beta: + # skip very small K problems + if testcase == "universal": + if k // batch_count < 2 * threadblock_k: + continue + + problem_size = GemmCoord(m, n, k) + + if supports_split_k: + split_k_slices = batch_count + else: + split_k_slices = 1 + + overridden_mode = mode + if mode == GemmUniversalMode.Gemm and batch_count > 1: + overridden_mode = GemmUniversalMode.Batched + + passed = testbed.run( + overridden_mode, + problem_size, + batch_count, + split_k_slices, + alpha, + beta, + ) + + if not passed: + return False + + return passed diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/run_all_tests.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/run_all_tests.py new file mode 100644 index 0000000000000000000000000000000000000000..bc5e7467b1e0040ce3012ff8541dfbac381bb861 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/run_all_tests.py @@ -0,0 +1,44 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import pathlib +import unittest + + +if __name__ == '__main__': + loader = unittest.TestLoader() + script_dir = str(pathlib.Path(__file__).parent.resolve()) + '/' + tests = loader.discover(script_dir, 'gemm_*.py') + testRunner = unittest.runner.TextTestRunner() + results = testRunner.run(tests) + if not results.wasSuccessful(): + raise Exception('Test cases failed') diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/utils.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..28bba3e922961c96df75f8685e3064ab55cbbc87 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/utils.py @@ -0,0 +1,260 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from cutlass_library import SubstituteTemplate + +import cutlass_cppgen +from cutlass_library import ( + DataTypeNames, + EpilogueScheduleSuffixes, + KernelScheduleSuffixes, + LayoutType, + OpcodeClassNames, + ShortDataTypeNames, + ShortLayoutTypeNames +) +from cutlass_cppgen.backend import library + +from gemm_testbed import test_all_gemm + + +class Layout: + """ + Utility class to map transpose and non-transpose terminology to row- and column-major terminology + """ + + T = LayoutType.RowMajor + N = LayoutType.ColumnMajor + + +class LayoutCombination: + """ + Utility class defining all combinations of row- and column-major layouts for operands to a GEMMs + """ + + NNN = (Layout.N, Layout.N, Layout.N) + NNT = (Layout.N, Layout.N, Layout.T) + NTN = (Layout.N, Layout.T, Layout.N) + NTT = (Layout.N, Layout.T, Layout.T) + TNN = (Layout.T, Layout.N, Layout.N) + TNT = (Layout.T, Layout.N, Layout.T) + TTN = (Layout.T, Layout.T, Layout.N) + TTT = (Layout.T, Layout.T, Layout.T) + + +def get_name( + layouts, + alignments, + element_output, + element_accumulator, + element_epilogue, + cluster_shape, + threadblock_shape, + stages, + element_a, + element_b, + element_c, + arch, + opclass, + kernel_schedule=None, + epilogue_schedule=None, + suffix="", +): + """ + Generates a procedural name for a test case. + + :param layouts: indexable container of layouts of A, B, and C operands + :param alignments: indexable container of alignments of A, B, and C operands + :param element_output: data type of the output element + :param element_accumulator: data type used in accumulation + :param element_epilogue: data type used in computing the epilogue + :param cluster_shape: indexable container of dimensions of threadblock cluster to be launched + :param threadblock_shape: indexable container of dimensions of threadblock tiles + :param stages: number of pipeline stages to use in the kernel + :type stages: int + :param element_a: data type of operand A + :param element_b: data type of operand B + :param element_c: data type of operand C + :param arch: compute capability of kernel being generated + :type arch: int + :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) + :type opclass: cutlass_cppgen.OpcodeClass + :param kernel_schedule: kernel_schedule type + :type kernel_schedule: cutlass_cppgen.KernelScheduleType + :param epilogue_schedule: epilogue_schedule type + :type epilogue_schedule: cutlass_cppgen.EpilogueScheduleType + :param suffix: additional string to add to the suffix of the name + :type suffix: str + + :return: str + """ + name_format = "test_SM${arch}_Device_Gemm_${eA}${lA}_${eB}${lB}_${eC}${lC}_${opclass}_${acc}_${tbM}x${tbN}x${tbK}_${cM}x${cN}x${cK}_${stages}_align${aA}-${aB}-${aC}${k}${e}${suffix}" + return SubstituteTemplate( + name_format, + { + "arch": str(arch), + "eA": DataTypeNames[element_a], + "eB": DataTypeNames[element_b], + "eC": DataTypeNames[element_c], + "lA": ShortLayoutTypeNames[layouts[0]], + "lB": ShortLayoutTypeNames[layouts[1]], + "lC": ShortLayoutTypeNames[layouts[2]], + "opclass": OpcodeClassNames[opclass], + "acc": DataTypeNames[element_accumulator], + "cM": str(cluster_shape[0]), + "cN": str(cluster_shape[1]), + "cK": str(cluster_shape[2]), + "tbM": str(threadblock_shape[0]), + "tbN": str(threadblock_shape[1]), + "tbK": str(threadblock_shape[2]), + "stages": str(stages) if stages is not None else "auto", + "aA": str(alignments[0]), + "aB": str(alignments[1]), + "aC": str(alignments[2]), + "k": "" if kernel_schedule is None else KernelScheduleSuffixes[kernel_schedule], + "e": "" if epilogue_schedule is None else EpilogueScheduleSuffixes[epilogue_schedule], + "suffix": "" if suffix is None else suffix, + }, + ) + + +def add_test_gemm( + cls=None, + cc=None, + element=None, + layouts=None, + alignments=None, + element_output=None, + element_accumulator=None, + cluster_shape=None, + threadblock_shape=None, + warp_count=None, + stages=None, + opclass=None, + swizzle=None, + kernel_schedule=None, + epilogue_schedule=None, + compilation_modes=['nvcc', 'nvrtc'], + element_A=None, + element_B=None, + element_C=None): + """ + Create test-running functions with the given specification and set it as a method of ``cls``. + + :param cls: class to which the generated method will be added + :type cls: type + :param cc: compute capability to compile for + :type cc: int + :param element: data type of A and B operands + :type element: cutlass_cppgen.DataType.f16 + :param layouts: layouts of A, B, and C operands + :type layouts: list or tuple + :param alignments: alingments of A, B, and C operands + :type alignments: list or tuple + :param element_output: data type of the output element + :type element_output: cutlass_cppgen.DataType + :param element_accumulator: data type used in accumulation + :type element_accumulator: cutlass_cppgen.DataType + :param cluster_shape: dimensions of clusters + :type cluster_shape: list or tuple + :param threadblock_shape: dimensions of threadblock tiles + :type threadblock_shape: list or tuple + :param warp_count: warps to be launched per threadblock dimension + :type warp_count: list or tuple + :param stages: number of pipeline stages to use in the kernel + :type stages: int + :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) + :type opclass: cutlass_cppgen.OpcodeClass + :param swizzle: threadblock swizzling functor + :param kernel_schedule: kernel schedule to use + :type kernel_schedule: cutlass_cppgen.KernelScheduleType + :param epilogue_schedule: epilogue schedule to use + :type epilogue_schedule: cutlass_cppgen.EpilogueScheduleType + :param compilation_modes: list of compilers to used in testing the kernel (options: 'nvrtc', 'nvcc') + :type compilation_modes: list, + :param element_A: data type of operand A. If set, overrides ``element`` + :type element_A: cutlass_cppgen.DataType + :param element_B: data type of operand B. If set, overrides ``element`` + :type element_B: cutlass_cppgen.DataType + :param element_C: data type of operand C. If set, overrides ``element`` + :type element_C: cutlass_cppgen.DataType + """ + + if element_A is None: + element_A = element + if element_B is None: + element_B = element + if element_C is None: + element_C = element + if element_output is None: + element_output = element + if element_accumulator is None: + element_accumulator = element + + for compilation_mode in compilation_modes: + def run(self): + """ + Dynamically-generated function that constructs a GEMM operation and verifies it against + multiple test cases. + """ + + layout_A, layout_B, layout_C = layouts + alignment_A, alignment_B, alignment_C = alignments + + plan = cutlass_cppgen.op.Gemm(element_A=element_A, element_B=element_B, + element_C=element_C, element_D=element_output, + layout_A=layout_A, layout_B=layout_B, layout_C=layout_C, + element_accumulator=element_accumulator, + kernel_cc=cc) + + plan.opclass = opclass + if swizzle is not None: + plan.swizzling_functor = swizzle + + td = plan.tile_descriptions()[0] + + if warp_count is not None: + td.warp_count = warp_count + td.threadblock_shape = threadblock_shape + td.stages = stages + td.cluster_shape = cluster_shape + op = plan.construct(tile_description=td, alignment_A=alignment_A, alignment_B=alignment_B, alignment_C=alignment_C) + self.assertTrue(test_all_gemm(op, 'universal', compilation_mode=compilation_mode)) + + element_epilogue = element_accumulator + name = get_name( + layouts=layouts, alignments=alignments, element_output=element_output, element_accumulator=element_accumulator, + element_epilogue=element_epilogue, cluster_shape=cluster_shape, threadblock_shape=threadblock_shape, + stages=stages, element_a=element_A, element_b=element_B, element_c=element_C, arch=cc, opclass=opclass, + kernel_schedule=kernel_schedule, epilogue_schedule=epilogue_schedule, suffix=f'_{compilation_mode}') + + setattr(cls, name, run) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/installation.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/installation.py new file mode 100644 index 0000000000000000000000000000000000000000..f550c394812c7fede55070e4c99c4471a69c2f88 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/installation.py @@ -0,0 +1,57 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Tests for a successful installation of the CUTLASS Python interface +""" + +import os +import unittest + +import cutlass_cppgen +import cutlass_library + + +class InstallationTest(unittest.TestCase): + def test_cutlass_source_paths(self): + """ + Tests that CUTLASS source is available as part of the cutlass and cutlass_library packages + """ + src_file = 'include/cutlass/cutlass.h' + library_file = os.path.join(cutlass_library.source_path, src_file) + cutlass_file = os.path.join(cutlass_cppgen.CUTLASS_PATH, src_file) + assert os.path.isfile(library_file), f"Unable to locate file {library_file}. Installation has not succeeded." + assert os.path.isfile(cutlass_file), f"Unable to locate file {cutlass_file}. Installation has not succeeded." + + +if __name__ == "__main__": + unittest.main() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/conv2d_interface.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/conv2d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..2b5d46d45d617198a46bec85cd7218cb5431a7b1 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/conv2d_interface.py @@ -0,0 +1,284 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Tests the high-level Conv2d interface +""" + +from math import ceil +import unittest + +import cutlass_cppgen +import cutlass_cppgen.utils.datatypes as datatypes +from cutlass_cppgen.backend.utils.device import device_cc +from utils import ExpectException +import os + + +class Conv2dEquivalence: + """ + Helper class for testing the equivalence of different constructions of the Conv2d interface + """ + def __init__(self, conv_kind, element_A, element_B, element_C, element_D, element_accumulator, + alignment_A, alignment_B, alignment_C): + + self.element_A = element_A + self.element_B = element_B + self.element_C = element_C + self.element_D = element_D + self.element_accumulator = element_accumulator + self.alignment_A = alignment_A + self.alignment_B = alignment_B + self.alignment_C = alignment_C + + self.conv_kind = conv_kind + + self.plan = cutlass_cppgen.op.Conv2d( + kind=self.conv_kind, element_A=element_A, element_B=element_B, element_C=element_C, + element_D=element_D, element_accumulator=element_accumulator) + + self.op = self.plan.construct( + alignment_A=self.alignment_A, alignment_B=self.alignment_B, + alignment_C=self.alignment_C) + + def _plans_equal(self, other_plan) -> bool: + """ + Compares whether two plans are equal + + :param other_plan: plan to compare against the default Conv2d + :type other_plan: cutlass_cppgen.op.Conv2d + + :return: whether `other_plan` is equivalent to `self.plan` + :rtype: bool + """ + other_op = other_plan.construct( + alignment_A=self.alignment_A, alignment_B=self.alignment_B, + alignment_C=self.alignment_C) + + return self.op.rt_module.emit() == other_op.rt_module.emit() + + def generic_test(self): + """ + Tests the equivalence of various constructions of the Conv2d interface when using CUTLASS data types + and layouts for constructing the Conv2d interface + """ + if not datatypes.is_numpy_available(): + return + + # Test when specifying all parameters + plan_other = cutlass_cppgen.op.Conv2d( + kind=self.conv_kind, + element_A=self.element_A, element_B=self.element_B, element_C=self.element_C, + element_D=self.element_D, element_accumulator=self.element_accumulator) + assert self._plans_equal(plan_other) + + # Test when specifying all parameters but A + plan_other = cutlass_cppgen.op.Conv2d( + kind=self.conv_kind, + element_B=self.element_B, element_C=self.element_C, + element_D=self.element_D, element_accumulator=self.element_accumulator, + element=self.element_A) + assert self._plans_equal(plan_other) + + # Test when specifying all parameters but A and B as tensors using generic element and output + plan_other = cutlass_cppgen.op.Conv2d( + kind=self.conv_kind, + element_C=self.element_C, + element_D=self.element_D, element_accumulator=self.element_accumulator, + element=self.element_A) + assert self._plans_equal(plan_other) + + # Test without explicit accumulator. Only run if the type of C and the accumulator are equal + if self.element_C == self.element_accumulator: + plan_other = cutlass_cppgen.op.Conv2d( + kind=self.conv_kind, + element_C=self.element_C, + element_D=self.element_D, + element=self.element_A) + assert self._plans_equal(plan_other) + + # Test with only the generic types. Only rune if the types of A, B, C, and D are the same + if (self.element_A == self.element_B and self.element_A == self.element_C and self.element_A == self.element_D + and self.element_A == self.element_accumulator): + plan_other = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, element=self.element_A) + assert self._plans_equal(plan_other) + + def numpy_test(self): + """ + Tests the equivalence of various constructions of the Conv2d interface when using numpy as a frontend + """ + if not datatypes.is_numpy_available(): + return + + import numpy as np + type_A = datatypes.numpy_type(self.element_A) + type_B = datatypes.numpy_type(self.element_B) + type_C = datatypes.numpy_type(self.element_C) + type_D = datatypes.numpy_type(self.element_D) + type_accum = datatypes.numpy_type(self.element_accumulator) + + size = (2, 2) + A = np.zeros(size, dtype=type_A) + B = np.zeros(size, dtype=type_B) + C = np.zeros(size, dtype=type_C) + D = np.zeros(size, dtype=type_D) + + return self.tensor_test(type_A, type_B, type_C, type_D, type_accum, A, B, C, D) + + def torch_test(self): + """ + Tests the equivalence of various constructions of the Conv2d interface when using torch as a frontend + """ + if not datatypes.is_torch_available(): + return + + import torch + type_A = datatypes.torch_type(self.element_A) + type_B = datatypes.torch_type(self.element_B) + type_C = datatypes.torch_type(self.element_C) + type_D = datatypes.torch_type(self.element_D) + type_accum = datatypes.torch_type(self.element_accumulator) + + size = (2, 2) + + A = torch.empty(size, dtype=type_A) + B = torch.empty(size, dtype=type_B) + C = torch.empty(size, dtype=type_C) + D = torch.empty(size, dtype=type_D) + + return self.tensor_test(type_A, type_B, type_C, type_D, type_accum, A, B, C, D) + + def tensor_test(self, type_A, type_B, type_C, type_D, type_accum, A, B, C, D): + # Test when specifying all parameters via tensors + plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, A=A, B=B, C=C, D=D, element_accumulator=type_accum) + assert self._plans_equal(plan_np) + + # Test when specifying all parameters but A as tensors + plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, B=B, C=C, D=D, element_accumulator=type_accum, element_A=type_A) + assert self._plans_equal(plan_np) + + # Test when specifying all parameters but A and B as tensors and using generic element and output + if type_A == type_B: + plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, C=C, D=D, element_accumulator=type_accum, element=type_A) + assert self._plans_equal(plan_np) + + # Test without explicit accumulator. Only run if the type of C and the accumulator. + if type_C == type_accum: + plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, A=A, B=B, C=C, D=D) + assert self._plans_equal(plan_np) + + # Test with only the generic types and layouts. Only run if types and layouts of A, B, C, and D are the same. + if (type_A == type_B and type_A == type_C and type_A == type_D and type_A == type_accum): + plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, element=type_A) + assert self._plans_equal(plan_np) + + def test_all(self): + """ + Runs all tests on the Gemm interface + """ + self.generic_test() + self.numpy_test() + self.torch_test() + + +@unittest.skipIf(device_cc() <= 80, 'Device compute capability is insufficient for SM80 tests.') +class ConvEquivalenceTest(unittest.TestCase): + """ + Tests the equivalence of different constructions of the Conv2d interface + """ + pass + +type2alignment = { + cutlass_cppgen.DataType.f16: 8, + cutlass_cppgen.DataType.f32: 4 +} + +def add_test(conv_kind, element_A, element_B, element_C, element_D, element_accumulator): + + test_name = f"test_conv2d_{conv_kind}_{element_A}_{element_B}_{element_C}_{element_D}_{element_accumulator}" + + def run(self): + conv2d_eq = Conv2dEquivalence( + conv_kind=conv_kind, + element_A=element_A, element_B=element_B, + element_C=element_C, element_D=element_D, + element_accumulator=element_accumulator, + alignment_A=type2alignment[element_A], alignment_B=type2alignment[element_B], + alignment_C=type2alignment[element_C] + ) + conv2d_eq.test_all() + + setattr(ConvEquivalenceTest, test_name, run) + +for conv_kind in ["fprop", "wgrad", "dgrad"]: + for types in [ + [cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16], + [cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32], + [cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16], + [cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32], + [cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32] + ]: + add_test(conv_kind, types[0], types[1], types[2], types[3], types[4]) + + +@unittest.skipIf(device_cc() <= 80, 'Device compute capability is insufficient for SM80 tests.') +class Conv2dErrorTests(unittest.TestCase): + """ + Tests various error scenarios that arise with the high-level Gemm interface + """ + + def test_alignment(self): + """ + Tests case in which the alignment specified is unsupported + """ + plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=cutlass_cppgen.DataType.f16) + + with ExpectException(True, 'Alignment 3 is not supported for F16. The construction should fail.'): + op = plan.construct(alignment_A=3, alignment_B=3, alignment_C=3) + + def test_invalid_tile_description(self): + """ + Tests scenarios in which an invalid tile description is provided for a given CC + """ + plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=cutlass_cppgen.DataType.f16) + + td = plan.tile_descriptions()[0] + td.threadblock_shape=[17, 32, 5] + + plan.tile_description = td + with ExpectException(True, 'The threadblock shape is invalid. The compilation should fail.'): + plan.compile() + # Clean up the error message + os.remove("./cutlass_python_compilation_device_error.txt") + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/evt_interface.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/evt_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..e7d67f4d07f01b0936ff5796bfb6fe4c98b5c031 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/evt_interface.py @@ -0,0 +1,254 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Test the EVT interface +""" + +import numpy as np +import unittest + +import cutlass_cppgen +from cutlass_cppgen import LayoutType, Tensor +from cutlass_cppgen.backend.utils.device import device_cc +from cutlass_cppgen.epilogue import reshape, permute + +from utils import ExpectException + + +@unittest.skipIf(device_cc() not in [80, 90], "This unittest is for Sm80 and Sm90 only") +class EVTErrorTests(unittest.TestCase): + """ + Tests various error scenarios that arise with the EVT interface + """ + @unittest.skipIf(device_cc() != 90, "Only Sm90 EVT requires root node be 'D'") + def test_root_not_d(self): + """ + Test when "D" does not exist in Sm90 EVT + """ + def evt_root_not_d(accum, alpha): + F = accum * alpha + return F + + example_tensors = { + "accum": self.fake_tensor(np.float16, (6, 512, 512)), + "alpha": 1.2, + "F": self.fake_tensor(np.float16, (6, 512, 512)) + } + + with ExpectException(device_cc() == 90, + "SyntaxError: Sm90 EVT requires the epilogue to have a returned tensor D, " + "but the variable 'D' is not found in the return values.", True): + + cutlass_cppgen.epilogue.trace(evt_root_not_d, example_tensors) + + def test_no_accum(self): + """ + Test when "accum" is not in input arguments + """ + def evt_no_accum(alpha, C): + D = alpha * C + return D + + example_tensors = { + "C": self.fake_tensor(np.float16, (6, 512, 512)), + "alpha": 1.2, + "D": self.fake_tensor(np.float16, (6, 512, 512)) + } + + with ExpectException(True, "SyntaxError: Cannot find 'accum' in the argument list.", True): + cutlass_cppgen.epilogue.trace(evt_no_accum, example_tensors) + + @unittest.skipIf(device_cc() != 90, "Only Sm90 EVT has concern on smem size") + def test_too_much_shared_memory(self): + """ + Test when the epilogue consumes too much shared memory + """ + def evt_too_much_shared_memory(accum, C1, C2, C3, C4, C5, C6, C7, C8): + D1 = accum + C1 + D2 = D1 + C2 + D3 = D2 + C3 + D4 = D3 + C4 + D5 = D4 + C5 + D6 = D5 + C6 + D7 = D6 + C7 + D = D7 + C8 + return D, D1, D2, D3, D4, D5, D6, D7 + + example_tensors = { + "accum": self.fake_tensor(np.float16, (6, 512, 512)), + "C1": self.fake_tensor(np.float16, (6, 512, 512)), + "C2": self.fake_tensor(np.float16, (6, 512, 512)), + "C3": self.fake_tensor(np.float16, (6, 512, 512)), + "C4": self.fake_tensor(np.float16, (6, 512, 512)), + "C5": self.fake_tensor(np.float16, (6, 512, 512)), + "C6": self.fake_tensor(np.float16, (6, 512, 512)), + "C7": self.fake_tensor(np.float16, (6, 512, 512)), + "C8": self.fake_tensor(np.float16, (6, 512, 512)), + "D1": self.fake_tensor(np.float16, (6, 512, 512)), + "D2": self.fake_tensor(np.float16, (6, 512, 512)), + "D3": self.fake_tensor(np.float16, (6, 512, 512)), + "D4": self.fake_tensor(np.float16, (6, 512, 512)), + "D5": self.fake_tensor(np.float16, (6, 512, 512)), + "D6": self.fake_tensor(np.float16, (6, 512, 512)), + "D7": self.fake_tensor(np.float16, (6, 512, 512)), + "D": self.fake_tensor(np.float16, (6, 512, 512)) + } + + epilogue_visitor = cutlass_cppgen.epilogue.trace(evt_too_much_shared_memory, example_tensors) + + plan = cutlass_cppgen.op.Gemm( + element=np.float16, layout=cutlass_cppgen.LayoutType.RowMajor, + element_accumulator=np.float32 + ) + + with ExpectException(True, + "RuntimeError: The epilogue consumes too much shared memory. " + "No valid tile description is found in the generator.", True): + plan.epilogue_visitor = epilogue_visitor + + def test_not_ssa(self): + """ + Test when the epilogue is not in SSA + """ + def evt_redefine(accum, C, alpha): + F = accum + C + F = F * alpha + D = F + return D, F + + example_tensors = { + "accum": self.fake_tensor(np.float16, (6, 512, 512)), + "C": self.fake_tensor(np.float16, (6, 512, 512)), + "alpha": 1.5, + "D": self.fake_tensor(np.float16, (6, 512, 512)), + "F": self.fake_tensor(np.float16, (6, 512, 512)) + } + + with ExpectException(True, "SyntaxError: Variable 'F' cannot be defined twice.", True): + cutlass_cppgen.epilogue.trace(evt_redefine, example_tensors) + + def evt_undefine(accum, alpha): + F = accum + C + D = F * alpha + return D, F + + example_tensors = { + "accum": self.fake_tensor(np.float16, (6, 512, 512)), + "alpha": 1.5, + "D": self.fake_tensor(np.float16, (6, 512, 512)), + "F": self.fake_tensor(np.float16, (6, 512, 512)) + } + + with ExpectException(True, "SyntaxError: Variable 'C' is undefined.", True): + cutlass_cppgen.epilogue.trace(evt_undefine, example_tensors) + + def test_missing_example_tensor(self): + """ + Test when the example tensor of an input/output variable is not provided + """ + def evt_missing_example_tensor(accum, C): + D = accum + C + return D + + example_tensors = { + "accum": self.fake_tensor(np.float16, (6, 512, 512)), + "C": self.fake_tensor(np.float16, (6, 512, 512)), + } + + with ExpectException(True, "RuntimeError: Example input for D is not provided.", True): + cutlass_cppgen.epilogue.trace(evt_missing_example_tensor, example_tensors) + + example_tensors = { + "accum": self.fake_tensor(np.float16, (6, 512, 512)), + "D": self.fake_tensor(np.float16, (6, 512, 512)), + } + + with ExpectException(True, "RuntimeError: Example input for C is not provided.", True): + cutlass_cppgen.epilogue.trace(evt_missing_example_tensor, example_tensors) + + def test_return_expression(self): + """ + Test when the return value is an expression + """ + def evt_return_expr(accum, C): + return accum + C + + example_tensors = { + "accum": self.fake_tensor(np.float16, (6, 512, 512)), + "C": self.fake_tensor(np.float16, (6, 512, 512)), + } + + with ExpectException(True, "SyntaxError: Return value cannot be an expression", True): + cutlass_cppgen.epilogue.trace(evt_return_expr, example_tensors) + + def test_incompatible_shape(self): + """ + Test when the shape of example tensors are incompatible + """ + def evt_incompatible_shape(accum, C): + D = accum + C + return D + + example_tensors = { + "accum": self.fake_tensor(np.float16, (6, 256, 512)), + "C": self.fake_tensor(np.float16, (6, 512, 512)), + "D": self.fake_tensor(np.float16, (6, 512, 512)) + } + + with ExpectException(True, + "RuntimeError: Dimension mismatch between accum(6, 256, 512), C(6, 512, 512).", True): + cutlass_cppgen.epilogue.trace(evt_incompatible_shape, example_tensors) + + def test_no_matching_impl(self): + def evt_no_matching_impl(accum, bias): + D = accum + reshape(permute(bias, indices=(1, 0)), new_shape=(512, 1)) + return D + + example_tensors = { + "accum": self.fake_tensor(np.float16, (6, 512, 256)), + "bias": self.fake_tensor(np.float16, (16, 32)), + "D": self.fake_tensor(np.float16, (6, 512, 256)) + } + + with ExpectException(True, "NotImplementedError: No matching op for node bias with stride (0, (1, 32), 0).", True): + cutlass_cppgen.epilogue.trace(evt_no_matching_impl, example_tensors) + # + # Helper functions + # + + def fake_tensor(self, element, shape): + return Tensor(element=element, shape=shape, layout_tag=LayoutType.RowMajor) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/gemm_interface.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/gemm_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..2913d5933f5342cc58b4f252657a724d2c7692da --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/gemm_interface.py @@ -0,0 +1,354 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Tests the high-level GEMM interface +""" + +from math import ceil +import unittest + +import cutlass_cppgen +import cutlass_cppgen.utils.datatypes as datatypes +from cutlass_cppgen.backend.utils.device import device_cc +from utils import ExpectException + + +class GemmEquivalence: + """ + Helper class for testing the equivalence of different constructions of the Gemm interface + """ + def __init__(self, element_A, element_B, element_C, element_D, element_accumulator, + layout_A, layout_B, layout_C, alignment_A, alignment_B, alignment_C): + self.element_A = element_A + self.element_B = element_B + self.element_C = element_C + self.element_D = element_D + self.element_accumulator = element_accumulator + self.layout_A = layout_A + self.layout_B = layout_B + self.layout_C = layout_C + self.alignment_A = alignment_A + self.alignment_B = alignment_B + self.alignment_C = alignment_C + self.plan = cutlass_cppgen.op.Gemm(element_A=element_A, element_B=element_B, element_C=element_C, + element_D=element_D, element_accumulator=element_accumulator, + layout_A=layout_A, layout_B=layout_B, layout_C=layout_C) + self.op = self.plan.construct(alignment_A=alignment_A, alignment_B=alignment_B, alignment_C=alignment_C) + + def _plans_equal(self, other_plan) -> bool: + """ + Compares whether two plans are equal + + :param other_plan: plan to compare against the default GEMM + :type other_plan: cutlass_cppgen.op.Gemm + + :return: whether `other_plan` is equivalent to `self.plan` + :rtype: bool + """ + other_op = other_plan.construct(alignment_A=self.alignment_A, alignment_B=self.alignment_B, alignment_C=self.alignment_C) + + # Compare whether the operations are equal by comparing the C++ code that would be emitted for them + return self.op.rt_module.emit() == other_op.rt_module.emit() + + def generic_test(self): + """ + Tests the equivalence of various constructions of the Gemm interface when using CUTLASS data types + and layouts for constructing the Gemm interface + """ + if not datatypes.is_numpy_available(): + return + + # Test when specifying all parameters + plan_other = cutlass_cppgen.op.Gemm(element_A=self.element_A, element_B=self.element_B, element_C=self.element_C, + element_D=self.element_D, element_accumulator=self.element_accumulator, + layout_A=self.layout_A, layout_B=self.layout_B, layout_C=self.layout_C) + assert self._plans_equal(plan_other) + + # Test when specifying all parameters but A + plan_other = cutlass_cppgen.op.Gemm(element_B=self.element_B, element_C=self.element_C, + element_D=self.element_D, element_accumulator=self.element_accumulator, + layout_B=self.layout_B, layout_C=self.layout_C, + element=self.element_A, layout=self.layout_A) + assert self._plans_equal(plan_other) + + # Test when specifying all parameters but A and B as tensors and using generic element and output + # Only run this test if the layouts and types for A and B are equal. + if self.element_A == self.element_B and self.layout_A == self.layout_B: + plan_other = cutlass_cppgen.op.Gemm(element_C=self.element_C, element_D=self.element_D, element_accumulator=self.element_accumulator, + layout_C=self.layout_C, element=self.element_A, layout=self.layout_A) + assert self._plans_equal(plan_other) + + # Test without explicit accumulator. Only run if the type of C and the accumulator. + if self.element_C == self.element_accumulator: + plan_other = cutlass_cppgen.op.Gemm(element_A=self.element_A, element_B=self.element_B, element_C=self.element_C, + element_D=self.element_D, layout_A=self.layout_A, layout_B=self.layout_B, + layout_C=self.layout_C) + assert self._plans_equal(plan_other) + + # Test with only the generic types and layouts. Only run if types and layouts of A, B, C, and D are the same. + if (self.element_A == self.element_B and self.element_A == self.element_C and self.element_A == self.element_D + and self.element_A == self.element_accumulator and + self.layout_A == self.layout_B and self.layout_A == self.layout_C): + plan_other = cutlass_cppgen.op.Gemm(element=self.element_A, layout=self.layout_A) + assert self._plans_equal(plan_other) + + def numpy_test(self): + """ + Tests the equivalence of various constructions of the Gemm interface when using numpy as a frontend + """ + if not datatypes.is_numpy_available(): + return + + import numpy as np + type_A = datatypes.numpy_type(self.element_A) + type_B = datatypes.numpy_type(self.element_B) + type_C = datatypes.numpy_type(self.element_C) + type_D = datatypes.numpy_type(self.element_D) + type_accum = datatypes.numpy_type(self.element_accumulator) + + layout_to_order = { + cutlass_cppgen.LayoutType.RowMajor: 'C', + cutlass_cppgen.LayoutType.ColumnMajor: 'F' + } + size = (2, 2) + A = np.zeros(size, order=layout_to_order[self.layout_A], dtype=type_A) + B = np.zeros(size, order=layout_to_order[self.layout_B], dtype=type_B) + C = np.zeros(size, order=layout_to_order[self.layout_C], dtype=type_C) + D = np.zeros(size, order=layout_to_order[self.layout_C], dtype=type_D) + + # Test when specifying all parameters via tensors + plan_np = cutlass_cppgen.op.Gemm(A=A, B=B, C=C, D=D, element_accumulator=type_accum) + assert self._plans_equal(plan_np) + + # Test when specifying all parameters but A as tensors + plan_np = cutlass_cppgen.op.Gemm(B=B, C=C, D=D, element_accumulator=type_accum, element_A=type_A, layout_A=self.layout_A) + assert self._plans_equal(plan_np) + + # Test when specifying all parameters but A and B as tensors and using generic element and output + # Only run this test if the layouts and types for A and B are equal. + if type_A == type_B and self.layout_A == self.layout_B: + plan_np = cutlass_cppgen.op.Gemm(C=C, D=D, element_accumulator=type_accum, element=type_A, layout=self.layout_A) + assert self._plans_equal(plan_np) + + # Test without explicit accumulator. Only run if the type of C and the accumulator. + if type_C == type_accum: + plan_np = cutlass_cppgen.op.Gemm(A=A, B=B, C=C, D=D) + assert self._plans_equal(plan_np) + + # Test with only the generic types and layouts. Only run if types and layouts of A, B, C, and D are the same. + if (type_A == type_B and type_A == type_C and type_A == type_D and type_A == type_accum and + self.layout_A == self.layout_B and self.layout_A == self.layout_C): + plan_np = cutlass_cppgen.op.Gemm(element=type_A, layout=self.layout_A) + assert self._plans_equal(plan_np) + + def test_all(self): + """ + Runs all tests on the Gemm interface + """ + self.generic_test() + self.numpy_test() + + +class GemmEquivalenceTest(unittest.TestCase): + """ + Tests the equivalence of different constructions of the Gemm interface + """ + @unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for FP16 Tensor Core tests.") + def test_gemm_equivalence_f16_f16_f16_f16_f16_ttt_8_8_8(self): + gemm_eq = GemmEquivalence( + element_A=cutlass_cppgen.DataType.f16, element_B=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_D=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f16, + layout_A=cutlass_cppgen.LayoutType.RowMajor, layout_B=cutlass_cppgen.LayoutType.RowMajor, layout_C=cutlass_cppgen.LayoutType.RowMajor, + alignment_A=8, alignment_B=8, alignment_C=8) + gemm_eq.test_all() + + @unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for FP16 Tensor Core tests.") + def test_gemm_equivalence_f16_f16_f16_f16_f32_ntn_8_8_8(self): + gemm_eq = GemmEquivalence( + element_A=cutlass_cppgen.DataType.f16, element_B=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_D=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f32, + layout_A=cutlass_cppgen.LayoutType.ColumnMajor, layout_B=cutlass_cppgen.LayoutType.RowMajor, layout_C=cutlass_cppgen.LayoutType.ColumnMajor, + alignment_A=8, alignment_B=8, alignment_C=8) + gemm_eq.test_all() + + @unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for FP16 Tensor Core tests.") + def test_gemm_equivalence_f16_f16_f16_f16_f16_ttt_4_4_4(self): + gemm_eq = GemmEquivalence( + element_A=cutlass_cppgen.DataType.f16, element_B=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_D=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f16, + layout_A=cutlass_cppgen.LayoutType.RowMajor, layout_B=cutlass_cppgen.LayoutType.RowMajor, layout_C=cutlass_cppgen.LayoutType.RowMajor, + alignment_A=8, alignment_B=8, alignment_C=8) + gemm_eq.test_all() + + @unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for F64 Tensor Core tests.") + def test_gemm_equivalence_f64_f64_f64_f64_f64_tnt_1_1_1(self): + gemm_eq = GemmEquivalence( + element_A=cutlass_cppgen.DataType.f64, element_B=cutlass_cppgen.DataType.f64, element_C=cutlass_cppgen.DataType.f64, + element_D=cutlass_cppgen.DataType.f64, element_accumulator=cutlass_cppgen.DataType.f64, + layout_A=cutlass_cppgen.LayoutType.RowMajor, layout_B=cutlass_cppgen.LayoutType.ColumnMajor, layout_C=cutlass_cppgen.LayoutType.RowMajor, + alignment_A=1, alignment_B=1, alignment_C=1) + gemm_eq.test_all() + + +class GemmErrorTests(unittest.TestCase): + """ + Tests various error scenarios that arise with the high-level Gemm interface + """ + + def test_alignment(self): + """ + Tests case in which the alignment specified is unsupported + """ + plan = cutlass_cppgen.op.Gemm(element=cutlass_cppgen.DataType.f16, layout=cutlass_cppgen.LayoutType.RowMajor) + + with ExpectException(True, 'Alignment 16 is not supported for F16. The construction should fail.'): + op = plan.construct(alignment_A=16, alignment_B=16, alignment_C=16) + + def test_tensorop_availability(self): + """ + Tests case in which only SIMT operations are available but TensorOp is requested + """ + cc = device_cc() + + # F64 Tensor Core operations are only avaiable on certain devices + supports_tensorop_f64 = cc in [80, 89, 90] + plan = cutlass_cppgen.op.Gemm(cc=cc, element=cutlass_cppgen.DataType.f64, layout=cutlass_cppgen.LayoutType.RowMajor) + + error_msg = f'Incorrectly raised an exception for availability of TensorOp with F64 operands on SM{cc}' + with ExpectException(not supports_tensorop_f64, error_msg): + plan.opclass = cutlass_cppgen.OpcodeClass.TensorOp + + expected_opclass = cutlass_cppgen.OpcodeClass.TensorOp if supports_tensorop_f64 else cutlass_cppgen.OpcodeClass.Simt + assert plan.opclass == expected_opclass, f'Expected opclass to be {expected_opclass}, but received {plan.opclass} for SM{cc}' + + @unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for F16 Tensor Core tests.") + def test_opclass_switch(self): + """ + Tests cases in which the opcode class in question is switched (e.g., from TensorOp to SIMT) + """ + plan = cutlass_cppgen.op.Gemm( element=cutlass_cppgen.DataType.f16, layout=cutlass_cppgen.LayoutType.RowMajor) + assert plan.opclass == cutlass_cppgen.OpcodeClass.TensorOp + + # Ensure that all tile descriptions have opclass of TensorOp + for td in plan.tile_descriptions(): + assert td.math_instruction.opcode_class == cutlass_cppgen.OpcodeClass.TensorOp + + plan.opclass = cutlass_cppgen.OpcodeClass.Simt + + # Ensure that all tile descriptions have opclass of Simt + for td in plan.tile_descriptions(): + assert td.math_instruction.opcode_class == cutlass_cppgen.OpcodeClass.Simt + + def test_invalid_tile_description(self): + """ + Tests scenarios in which an invalid tile description is provided for a given CC + """ + cc = device_cc() + plan = cutlass_cppgen.op.Gemm(cc=cc, element=cutlass_cppgen.DataType.f16, layout=cutlass_cppgen.LayoutType.RowMajor) + td = plan.tile_descriptions()[0] + stages = td.stages + + # Zero stage count is valid for SM90+, as this is used to indicate that the builder's auto stage + # count should be used + with ExpectException(cc < 90, f'Requested zero stages'): + td.stages = 0 + plan.construct(td) + + if cc < 90: + with ExpectException(cc < 80, f'Requested more than 2 stages on SM{cc}'): + td.stages = 3 + plan.construct(td) + elif cc == 90: + original_kschedule = td.kernel_schedule + original_eschedule = td.epilogue_schedule + with ExpectException(False, f'Incorrectly flagged an error for insufficient shared memory'): + td.kernel_schedule = cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong + td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.NoSmemWarpSpecialized + td.stages = 3 + plan.construct(td) + # Reset schedules + td.kernel_schedule = original_kschedule + td.epilogue_schedule = original_eschedule + elif cc in [100, 101, 103]: + with ExpectException(False, f'Incorrectly flagged an error for insufficient shared memory'): + td.stages = 3 + plan.construct(td) + + with ExpectException(True, f'Requested too many stages'): + td.stages = 100 + plan.construct(td) + + # Reset stage count + td.stages = stages + + cluster_shape = td.cluster_shape + with ExpectException(cc < 90, f'Requested non-unit cluster shape on SM{cc}'): + td.cluster_shape = [2, 1, 1] + plan.construct(td) + + # Reset cluster shape + td.cluster_shape = cluster_shape + + with ExpectException(cc < 90, f'Requested a non-auto schedule on SM{cc}'): + td.kernel_schedule = cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong + td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized + plan.construct(td) + + with ExpectException(cc == 90, f'Requested a non-auto kernel schedule with an auto epilogue schedule'): + td.kernel_schedule = cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong + td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.ScheduleAuto + plan.construct(td) + + with ExpectException(cc == 90, f'Requested an auto kernel schedule with a non-auto epilogue schedule'): + td.kernel_schedule = cutlass_cppgen.KernelScheduleType.ScheduleAuto + td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized + plan.construct(td) + + with ExpectException(cc < 90, f'Requested a tile scheduler on SM{cc}'): + td.kernel_schedule = cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative + td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative + td.tile_scheduler = cutlass_cppgen.TileSchedulerType.StreamK + plan.construct(td) + + # Ensure that all returned tile descriptions are unique + ops = {} + for i, td in enumerate(plan.tile_descriptions()): + op = plan.construct(td) + code_str = op.rt_module.emit() + if code_str in ops: + conflicting_td = ops[code_str] + assert False, f'Multiple tile descriptions emitted {code_str}\nTile descriptions are:\n{td}\n{conflicting_td}' + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/utils.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9f93ca26e2d79a15dab4dd0045836ebd9fe62757 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/utils.py @@ -0,0 +1,69 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Helper functions & classes for interface test +""" +class ExpectException: + """ + Utility class to assert that an exception was raised when expected + + Example: + + .. highlight:: python + .. code-block:: python + + with ExceptionExpected(True, 'Division by zero'): + x = 1.0 / 0.0 + + :param exception_expected: whether an exception is expected to be raised + :type exception_expected: bool + :param message: message to print if an exception is raised when not expected or vice versa + :type message: str + """ + def __init__(self, exception_expected: bool, message: str = '', verify_msg=False): + self.exception_expected = exception_expected + self.message = message + self.verify_msg = verify_msg + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, traceback): + exception_raised = exc_type is not None + assert self.exception_expected == exception_raised, self.message + if self.verify_msg: + exc_message = f"{exc_type.__name__}: {exc_val}" + assert exc_message == self.message, f"expect error message {self.message}, got {exc_message}" + + # Suppress the exception + return True diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/pycute/run_all_tests.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/pycute/run_all_tests.py new file mode 100644 index 0000000000000000000000000000000000000000..b7cdc421ccffffeb7bd1696aaf9916330a6625ca --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/pycute/run_all_tests.py @@ -0,0 +1,75 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utility script for discovering and running all PyCuTe tests +""" + +import argparse +import logging +import pathlib +import unittest + + +def numeric_log_level(log_level: str) -> int: + """ + Converts the string identifier of the log level into the numeric identifier used + in setting the log level + + :param x: string representation of log level (e.g., 'INFO', 'DEBUG') + :type x: str + + :return: numeric representation of log level + :rtype: int + """ + numeric_level = getattr(logging, log_level.upper(), None) + if not isinstance(numeric_level, int): + raise ValueError(f"Invalid log level: {log_level}") + return numeric_level + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--log-level", default='info', type=numeric_log_level, required=False, + help='Logging level to be used by the generator script') + args = parser.parse_args() + + # Set the logging level based on the user-provided `--log-level` command-line option + logging.basicConfig(level=args.log_level) + + loader = unittest.TestLoader() + script_dir = str(pathlib.Path(__file__).parent.resolve()) + '/' + tests = loader.discover(script_dir, "test_*.py") + test_runner = unittest.runner.TextTestRunner() + results = test_runner.run(tests) + if not results.wasSuccessful(): + raise Exception("Test cases failed") diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_coalesce.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_coalesce.py new file mode 100644 index 0000000000000000000000000000000000000000..d4330377cab7079ea16422f194ddf4f2403ea507 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_coalesce.py @@ -0,0 +1,95 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Unit tests for pycute.coalesce +""" + +import logging +import unittest + +from pycute import * + +_LOGGER = logging.getLogger(__name__) + + +class TestCoalesce(unittest.TestCase): + def helper_test_coalesce(self, layout): + layoutR = coalesce(layout) + + _LOGGER.debug(f"{layout} => {layoutR}") + + self.assertEqual(size(layoutR), size(layout)) + + for i in range(size(layout)): + self.assertEqual(layoutR(i), layout(i)) + + def test_coalesce(self): + layout = Layout(1,0) + self.helper_test_coalesce(layout) + + layout = Layout(1,1) + self.helper_test_coalesce(layout) + + layout = Layout((2,4)) + self.helper_test_coalesce(layout) + + layout = Layout((2,4,6)) + self.helper_test_coalesce(layout) + + layout = Layout((2,4,6), (1,6,2)) + self.helper_test_coalesce(layout) + + layout = Layout((2,1,6), (1,7,2)) + self.helper_test_coalesce(layout) + + layout = Layout((2,1,6), (4,7,8)) + self.helper_test_coalesce(layout) + + layout = Layout((2,(4,6))) + self.helper_test_coalesce(layout) + + layout = Layout((2,4), (4,1)) + self.helper_test_coalesce(layout) + + layout = Layout((2,4,6), (24,6,1)) + self.helper_test_coalesce(layout) + + layout = Layout((2,1,3), (2,4,4)) + self.helper_test_coalesce(layout) + + layout = Layout(((2,2),(2,2)), ((1,4),(8,32))) + self.helper_test_coalesce(layout) + + +if __name__ == "__main__": + unittest.main() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_complement.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_complement.py new file mode 100644 index 0000000000000000000000000000000000000000..5a8684a55b19c90eae11ddd1cca011c2ff8270b5 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_complement.py @@ -0,0 +1,92 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Unit tests for pycute.complement +""" + +import logging +import unittest + +from pycute import * + +_LOGGER = logging.getLogger(__name__) + + +class TestComplement(unittest.TestCase): + def helper_test_complement(self, layout): + layoutR = complement(layout) + + _LOGGER.debug(f"{layout} => {layoutR}") + + # Post-condition: test disjointness of the codomains + for a in range(size(layout)): + for b in range(size(layoutR)): + assert (layout(a) != layoutR(b)) or (layout(a) == 0 and layoutR(b) == 0) + + def test_complement(self): + test = Layout(1,0) + self.helper_test_complement(test) + + test = Layout(1,1) + self.helper_test_complement(test) + + test = Layout(4,0) + self.helper_test_complement(test) + + test = Layout((2,4),(1,2)) + self.helper_test_complement(test) + + test = Layout((2,3),(1,2)) + self.helper_test_complement(test) + + test = Layout((2,4),(1,4)) + self.helper_test_complement(test) + + test = Layout((2,4,8),(8,1,64)) + self.helper_test_complement(test) + + test = Layout(((2,2),(2,2)),((1,4),(8,32))) + self.helper_test_complement(test) + + test = Layout((2,(3,4)),(3,(1,6))) + self.helper_test_complement(test) + + test = Layout((4,6),(1,6)) + self.helper_test_complement(test) + + test = Layout((4,10),(1,10)) + self.helper_test_complement(test) + + +if __name__ == "__main__": + unittest.main() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_composition.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_composition.py new file mode 100644 index 0000000000000000000000000000000000000000..6c27eb7fe6cbb7bbbea7bd644ac8e64a2fc853c9 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_composition.py @@ -0,0 +1,213 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Unit tests for pycute.composition +""" + +import logging +import unittest + +from pycute import * + +_LOGGER = logging.getLogger(__name__) + + +class TestComposition(unittest.TestCase): + def helper_test_composition(self, layoutA, layoutB): + layoutR = composition(layoutA, layoutB) + + _LOGGER.debug(f"{layoutA} o {layoutB} => {layoutR}") + + # True post-condition: Every coordinate c of layoutB with L1D(c) < size(layoutR) is a coordinate of layoutR. + + # Test that R(c) = A(B(c)) for all coordinates c in layoutR + for i in range(size(layoutR)): + self.assertEqual(layoutR(i), layoutA(layoutB(i))) + + def test_composition(self): + layoutA = Layout(1,0) + layoutB = Layout(1,0) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout(1,0) + layoutB = Layout(1,1) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout(1,1) + layoutB = Layout(1,0) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout(1,1) + layoutB = Layout(1,1) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4)) + layoutB = Layout((4)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4), (2)) + layoutB = Layout((4)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4)) + layoutB = Layout((4), (2)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4), (0)) + layoutB = Layout((4)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4)) + layoutB = Layout((4), (0)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((1), (0)) + layoutB = Layout((4)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4)) + layoutB = Layout((1), (0)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4)) + layoutB = Layout((2)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4), (2)) + layoutB = Layout((2)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4)) + layoutB = Layout((2), (2)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4), (2)) + layoutB = Layout((2), (2)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((12)) + layoutB = Layout((4,3)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((12), (2)) + layoutB = Layout((4,3)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((12)) + layoutB = Layout((4,3), (3,1)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((12), (2)) + layoutB = Layout((4,3), (3,1)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((12)) + layoutB = Layout((2,3), (2,4)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,3)) + layoutB = Layout((4,3)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,3)) + layoutB = Layout((12)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,3)) + layoutB = Layout((6), (2)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,3)) + layoutB = Layout((6,2), (2,1)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,3), (3,1)) + layoutB = Layout((4,3)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,3), (3,1)) + layoutB = Layout((12)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,3), (3,1)) + layoutB = Layout((6), (2)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,3), (3,1)) + layoutB = Layout((6,2), (2,1)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((8,8)) + layoutB = Layout(((2,2,2), (2,2,2)),((1,16,4), (8,2,32))) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((8,8), (8,1)) + layoutB = Layout(((2,2,2), (2,2,2)),((1,16,4), (8,2,32))) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout(((2,2,2), (2,2,2)),((1,16,4), (8,2,32))) + layoutB = Layout(8, 4) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout(((4,2)), ((1,16))) + layoutB = Layout((4,2), (2,1)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((2,2), (2,1)) + layoutB = Layout((2,2), (2,1)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,8,2)) + layoutB = Layout((2,2,2), (2,8,1)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,8,2), (2,8,1)) + layoutB = Layout((2,2,2), (1,8,2)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,8,2), (2,8,1)) + layoutB = Layout((4,2,2), (2,8,1)) + self.helper_test_composition(layoutA, layoutB) + + # Pre-coalesced LHS + layoutA = Layout((4,6,8),(1,4,7)) + layoutB = Layout((6),(1)) + self.helper_test_composition(layoutA, layoutB) + + # Mid-layout truncation + layoutA = Layout((4,6,8,10),(2,3,5,7)) + layoutB = Layout(6,12) + self.helper_test_composition(layoutA, layoutB) + +if __name__ == "__main__": + unittest.main() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_int_tuple.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_int_tuple.py new file mode 100644 index 0000000000000000000000000000000000000000..0dbf443c9725735b0051d0a225a55eece9c663a8 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_int_tuple.py @@ -0,0 +1,80 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Unit tests for pycute.int_tuple +""" + +import unittest + +from pycute import * + + +class TestIntTuple(unittest.TestCase): + def test_product(self): + self.assertEqual(product(2), 2) + + self.assertEqual(product((3,2)), 6) + + self.assertEqual(product(product(((2,3),4))), 24) + + def test_inner_product(self): + self.assertEqual(inner_product(2, 3), 6) + + self.assertEqual(inner_product((1,2), (3,2)), 7) + + self.assertEqual(inner_product(((2,3),4), ((2,1),2)), 15) + + def test_shape_div(self): + self.assertEqual(shape_div((3,4), 6), (1,2)) + + self.assertEqual(shape_div((3,4), 12), (1,1)) + + self.assertEqual(shape_div((3,4), 36), (1,1)) + + self.assertEqual(shape_div(((3,4),6), 36), ((1,1),2)) + + self.assertEqual(shape_div((6,(3,4)), 36), (1,(1,2))) + + def test_prefix_product(self): + self.assertEqual(prefix_product(2), 1) + + self.assertEqual(prefix_product((3,2)), (1,3)) + + self.assertEqual(prefix_product((3,2,4)), (1,3,6)) + + self.assertEqual(prefix_product(((2,3),4)), ((1,2),6)) + + self.assertEqual(prefix_product(((2,3),(2, 1, 2),( 5, 2, 1))), + ((1,2),(6,12,12),(24,120,240))) + + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_left_inverse.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_left_inverse.py new file mode 100644 index 0000000000000000000000000000000000000000..a6501fd6c7c6fc5a518e4d22bf93dc0e4746a8ba --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_left_inverse.py @@ -0,0 +1,87 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Unit tests for pycute.left_inverse +""" + +import logging +import unittest + +from pycute import * + +_LOGGER = logging.getLogger(__name__) + + +class TestLeftInverse(unittest.TestCase): + def helper_test_left_inverse(self, layout): + inv_layout = left_inverse(layout) + + _LOGGER.debug(f"{layout} => {inv_layout}") + + for i in range(size(layout)): + self.assertEqual(inv_layout(layout(i)), i) + + def test_left_inverse(self): + test = Layout(1,0) + self.helper_test_left_inverse(test) + + test = Layout((1,1),(0,0)) + self.helper_test_left_inverse(test) + + test = Layout(1,1) + self.helper_test_left_inverse(test) + + test = Layout(4,1) + self.helper_test_left_inverse(test) + + test = Layout(4,2) + self.helper_test_left_inverse(test) + + test = Layout((8,4),(1,8)) + self.helper_test_left_inverse(test) + + test = Layout((8,4),(4,1)) + self.helper_test_left_inverse(test) + + test = Layout((2,4,6),(1,2,8)) + self.helper_test_left_inverse(test) + + test = Layout((2,4,6),(4,1,8)) + self.helper_test_left_inverse(test) + + test = Layout((4,2),(1,16)) + self.helper_test_left_inverse(test) + + +if __name__ == "__main__": + unittest.main() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_right_inverse.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_right_inverse.py new file mode 100644 index 0000000000000000000000000000000000000000..2ed9759d7808da8087fe9c76761d2dd9eaeab08b --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_right_inverse.py @@ -0,0 +1,96 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Unit tests for pycute.left_inverse +""" + +import logging +import unittest + +from pycute import * + +_LOGGER = logging.getLogger(__name__) + + +class TestRightInverse(unittest.TestCase): + def helper_test_right_inverse(self, layout): + inv_layout = right_inverse(layout) + + _LOGGER.debug(f"{layout} => {inv_layout}") + + for i in range(size(inv_layout)): + self.assertEqual(layout(inv_layout(i)), i) + + def test_right_inverse(self): + test = Layout(1,0) + self.helper_test_right_inverse(test) + + test = Layout((1,1),(0,0)) + self.helper_test_right_inverse(test) + + test = Layout((3,7),(0,0)) + self.helper_test_right_inverse(test) + + test = Layout(1,1) + self.helper_test_right_inverse(test) + + test = Layout(4,0) + self.helper_test_right_inverse(test) + + test = Layout(4,1) + self.helper_test_right_inverse(test) + + test = Layout(4,2) + self.helper_test_right_inverse(test) + + test = Layout((2,4),(0,2)) + self.helper_test_right_inverse(test) + + test = Layout((8,4),(1,8)) + self.helper_test_right_inverse(test) + + test = Layout((8,4),(4,1)) + self.helper_test_right_inverse(test) + + test = Layout((2,4,6),(1,2,8)) + self.helper_test_right_inverse(test) + + test = Layout((2,4,6),(4,1,8)) + self.helper_test_right_inverse(test) + + test = Layout((4,2),(1,16)) + self.helper_test_right_inverse(test) + + +if __name__ == "__main__": + unittest.main() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_typing.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_typing.py new file mode 100644 index 0000000000000000000000000000000000000000..9eb99a4833529e18fa22d65a235ce80dad372365 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_typing.py @@ -0,0 +1,59 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Unit tests for pycute.typing +""" + +import logging +import unittest +from pycute import * + +_LOGGER = logging.getLogger(__name__) + + +class TestTyping(unittest.TestCase): + def helper_test_typing(self, _cls, _obj, cls, expected: bool): + _LOGGER.debug(f"issubclass({_cls}, {cls})") + _LOGGER.debug(f"isinstance({_obj}, {cls})") + + self.assertEqual(expected, issubclass(_cls, cls)) + self.assertEqual(expected, isinstance(_obj, cls)) + + def test_typing(self): + self.helper_test_typing(int, 1, Integer, True) + self.helper_test_typing(float, 1., Integer, False) + self.helper_test_typing(str, 'hi', Integer, False) + self.helper_test_typing(bool, False, Integer, False) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/common/cutlass_unit_test.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/common/cutlass_unit_test.h new file mode 100644 index 0000000000000000000000000000000000000000..86b7823785a9f2a957cf505740d6cfde45ccfef1 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/common/cutlass_unit_test.h @@ -0,0 +1,102 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once +#pragma warning (disable : 4068 ) /* disable unknown pragma warnings for visual studio */ + +#pragma nv_diag_suppress boolean_controlling_expr_is_constant +#include +#pragma nv_diag_warning boolean_controlling_expr_is_constant +#pragma warning( disable : 4503) + +#include +#include + +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Gets a CUDA device +cudaDeviceProp GetCudaDevice(); + +/// Prints device properties +std::ostream &operator<<(std::ostream &out, cudaDeviceProp const &device); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Sets flags for Unit test +void FilterArchitecture(); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Reads environment variable `CUTLASS_UNIT_TEST_PROBLEM_COUNT` to control the number and order +// of problem sizes run by CUTLASS unit tests +int CutlassUnitTestProblemCount(); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// active test macro +#define CUTLASS_TEST_LEVEL_ACTIVE(LEVEL,NAME_STATIC,NAME_DYNAMIC,...) \ + TEST(NAME_STATIC,L##LEVEL##_##NAME_DYNAMIC) __VA_ARGS__ + +// disabled test macro +#define CUTLASS_TEST_LEVEL_DISABLED(LEVEL,NAME_STATIC,NAME_DYNAMIC,...) \ + TEST(NAME_STATIC,DISABLED_L##LEVEL##_##NAME_DYNAMIC) {} + +#if CUTLASS_TEST_LEVEL == 0 +#define CUTLASS_TEST_L0(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(0,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) +#define CUTLASS_TEST_L1(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_DISABLED(1,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) +#define CUTLASS_TEST_L2(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_DISABLED(2,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) +#elif CUTLASS_TEST_LEVEL == 1 +#define CUTLASS_TEST_L0(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(0,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) +#define CUTLASS_TEST_L1(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(1,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) +#define CUTLASS_TEST_L2(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_DISABLED(2,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) +#else +#define CUTLASS_TEST_L0(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(0,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) +#define CUTLASS_TEST_L1(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(1,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) +#define CUTLASS_TEST_L2(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(2,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) +#endif + +#if !defined(CUTLASS_TEST_UNIT_ENABLE_WARNINGS) +#define CUTLASS_TEST_UNIT_ENABLE_WARNINGS false +#endif + +#if (__CUDACC_VER_MAJOR__ >= 12) + #define CUDA_12_0_SM90_FEATURES_SUPPORTED true +#else + #define CUDA_12_0_SM90_FEATURES_SUPPORTED false +#endif + +#include +#include +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/cache_testbed_output.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/cache_testbed_output.h new file mode 100644 index 0000000000000000000000000000000000000000..3035e9862bcb79b749b4cbc4a74341bceac9c598 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/cache_testbed_output.h @@ -0,0 +1,907 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Helper to construct cached name for +*/ +#pragma once + +#include +#include +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" + +#include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/core_io.h" +#include "cutlass/util/tensor_view_io.h" + +#include "thrust/universal_vector.h" + +#ifndef CUTLASS_TEST_ENABLE_CACHED_RESULTS +#define CUTLASS_TEST_ENABLE_CACHED_RESULTS false +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test::conv::device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Result of a test +struct CachedTestKey { + + std::string op; ///< Concatenated string representation of operation performed + std::string problem; ///< Concatenated string representation of problem description + std::string types; ///< Concatenated string representation of operand types + uint32_t A; ///< Hashed result of tensor A + uint32_t B; ///< Hashed result of tensor B + uint32_t C; ///< Hashed result of tensor C + + // + // Methods + // + inline CachedTestKey(): A(), B(), C() { } + + inline CachedTestKey( + std::string op, ///< Concatenated string representation of operation performed + std::string problem, ///< Concatenated string representation of problem description + std::string types, ///< Concatenated string representation of operand types + uint32_t A, ///< Hashed result of tensor A + uint32_t B, ///< Hashed result of tensor B + uint32_t C ///< Hashed result of tensor C + ): + op(op), problem(problem), types(types), A(A), B(B), C(C) + { } + + /// Checks for equality of the problem + bool operator==(CachedTestKey const &rhs) const { + return op == rhs.op && problem == rhs.problem && types == rhs.types && A == rhs.A && B == rhs.B && C == rhs.C; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +inline std::istream &operator>>(std::istream &in, CachedTestKey &result) { + + in >> result.op; + in >> result.problem; + in >> result.types; + in >> result.A; + in >> result.B; + in >> result.C; + + return in; +} + +inline std::ostream &operator<<(std::ostream &out, CachedTestKey const &result) { + + out << result.op << " "; + out << result.problem << " "; + out << result.types << " "; + out << result.A << " "; + out << result.B << " "; + out << result.C << " "; + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct CachedTestResult { + uint32_t D; + // + // Methods + // + + CachedTestResult(): D() + { } + + CachedTestResult(uint32_t D): D(D) + { } + + operator bool() const { + return bool(D); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +inline std::istream &operator>>(std::istream &in, CachedTestResult &result) { + in >> result.D; + return in; +} + +inline std::ostream &operator<<(std::ostream &out, CachedTestResult const &result) { + out << result.D; + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct CachedTestResultListing { + + std::list> results; + + // + // Methods + // + + inline CachedTestResultListing(std::string const &path) { + std::ifstream file(path); + + while (file.good()) { + CachedTestKey key; + file >> key; + + CachedTestResult result; + file >> result; + + if (result) { + results.push_back(std::make_pair(key, result)); + } + } + } + + /// Returns the cached result + std::pair find(CachedTestKey const &rhs) const { + for (auto const & result : results) { + if (result.first == rhs) { + return std::make_pair(true, result.second); + } + } + return std::make_pair(false, CachedTestResult()); + } + + /// Appends an entry + void append(CachedTestKey const &key, CachedTestResult const &result) { + if (result) { + results.push_back(std::make_pair(key, result)); + } + } + + /// Writes the entire listing to a file + bool write(std::string const &path) { + std::ofstream file(path); + if (!file.good()) { + return false; + } + + for (auto const &result : results) { + file << result.first << result.second << std::endl; + } + + return true; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct ScalarEncoder { + Element scalar; + + ScalarEncoder(Element s): scalar(s) { } + + std::string str() const { + std::stringstream ss; + Element s = scalar; + if (s < Element()) { + s = -s; + ss << "n"; + } + ss << s; + return ss.str(); + } +}; + +template +ScalarEncoder EncodeScalar(Element a) { + return ScalarEncoder(a); +} + +template +struct ScalarEncoder> { + cutlass::complex scalar; + + ScalarEncoder(cutlass::complex s): scalar(s) { } + + std::string str() const { + std::stringstream ss; + ss << EncodeScalar(scalar.real()) << "_" << EncodeScalar(scalar.imag()) << "i"; + return ss.str(); + } +}; + +template +std::ostream &operator<<(std::ostream &out, ScalarEncoder const &scalar) { + out << scalar.str(); + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +inline char const *EncodeOperator(cutlass::conv::Operator conv_op) { + switch (conv_op) { + case cutlass::conv::Operator::kFprop: return "fprop"; + case cutlass::conv::Operator::kDgrad: return "dgrad"; + case cutlass::conv::Operator::kWgrad: return "wgrad"; + case cutlass::conv::Operator::kDeconv: return "deconv"; + } + return "conv_unknown"; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Encode GemmCoord (Gemm problem size) +inline std::ostream &EncodeProblemSize( + std::ostream &out, + cutlass::gemm::GemmCoord const &problem) { + + out << problem.m() << "x" << problem.n() << "x" << problem.k() << "_"; + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Encode Conv2dProblemSize +inline std::ostream &EncodeProblemSize( + std::ostream &out, + cutlass::conv::Conv2dProblemSize const &problem) { + + out << problem.N << "x" << problem.H << "x" << problem.W << "x" << problem.C << "_" + << problem.P << "x" << problem.Q << "_" << problem.K << "x" << problem.R << "x" << problem.S << "_"; + + out << "pad_h" << problem.pad_h << "w" << problem.pad_w << "_"; + out << "stride_h" << problem.stride_h << "w" << problem.stride_w << "_"; + out << "dil_h" << problem.dilation_h << "w" << problem.dilation_w << "_"; + + switch (problem.mode) { + case cutlass::conv::Mode::kCrossCorrelation: + out << "corr"; + break; + case cutlass::conv::Mode::kConvolution: + out << "conv"; + break; + } + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Encode Conv3dProblemSize +inline std::ostream &EncodeProblemSize( + std::ostream &out, + cutlass::conv::Conv3dProblemSize const &problem) { + + out << problem.N << "x" << problem.D << "x" << problem.H << "x" << problem.W << "x" << problem.C << "_" + << problem.Z << problem.P << "x" << problem.Q << "_" << problem.K << "x" << problem.R << "x" << problem.S << "_"; + + out << "pad_d" << problem.pad_h << "h" << problem.pad_h << "w" << problem.pad_w << "_"; + out << "stride_d" << problem.stride_d << "h" << problem.stride_h << "w" << problem.stride_w << "_"; + out << "dil_d" << problem.dilation_d << "h" << problem.dilation_h << "w" << problem.dilation_w << "_"; + + switch (problem.mode) { + case cutlass::conv::Mode::kCrossCorrelation: + out << "corr"; + break; + case cutlass::conv::Mode::kConvolution: + out << "conv"; + break; + } + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Encode 3.x ConvNd ProblemShape +template +inline std::ostream &EncodeProblemSize( + std::ostream &out, + ProblemShape const& problem_shape) { + + out << problem_shape.shape_A << "_"; + out << problem_shape.shape_B << "_"; + + out << "padl" << problem_shape.lower_padding << "_"; + out << "padu" << problem_shape.upper_padding << "_"; + out << "str" << problem_shape.traversal_stride << "_"; + out << "dil" << problem_shape.dilation << "_"; + + switch (problem_shape.mode) { + case cutlass::conv::Mode::kCrossCorrelation: + out << "corr"; + break; + case cutlass::conv::Mode::kConvolution: + out << "conv"; + break; + } + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline std::string ElementTypeName() { + return std::string(typeid(Element).name()); +} + +template <> +inline std::string ElementTypeName() { + return "h"; +} + +template <> +inline std::string ElementTypeName>() { + return "ch"; +} + +template <> +inline std::string ElementTypeName() { + return "bf16"; +} + +template <> +inline std::string ElementTypeName>() { + return "cbf16"; +} + +template <> +inline std::string ElementTypeName() { + return "tf32"; +} + +template <> +inline std::string ElementTypeName>() { + return "ctf32"; +} + +template <> +inline std::string ElementTypeName>() { + return "c"; +} + +template <> +inline std::string ElementTypeName>() { + return "z"; +} + +template <> +inline std::string ElementTypeName>() { + return "q"; +} + +template <> +inline std::string ElementTypeName() { + return "s8"; +} + +template <> +inline std::string ElementTypeName() { + return "u8"; +} + +template <> +inline std::string ElementTypeName() { + return "s4"; +} + +template <> +inline std::string ElementTypeName() { + return "u4"; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline std::string LayoutTypeName() { + return std::string(typeid(Layout).name()); +} + +template <> +inline std::string LayoutTypeName() { + return "n"; +} + +template <> +inline std::string LayoutTypeName() { + return "t"; +} + +template <> +inline std::string LayoutTypeName() { + return "nhwc"; +} + +template <> +inline std::string LayoutTypeName>() { + return "nc32hw32"; +} + +template <> +inline std::string LayoutTypeName>() { + return "nc64hw64"; +} + +template <> +inline std::string LayoutTypeName>() { + return "c32rsk32"; +} + +template <> +inline std::string LayoutTypeName>() { + return "c64rsk64"; +} + +template <> +inline std::string LayoutTypeName() { + return "ndhwc"; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline std::string TensorTypeName() { + std::stringstream ss; + ss << ElementTypeName() << LayoutTypeName(); + return ss.str(); +} + +template +inline std::string TensorTypeName() { + std::stringstream ss; + ss << ElementTypeName(); + return ss.str(); +} +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Hash function on a byte array +struct CRC32 { + + uint32_t table[256]; + + // + // Methods + // + + CRC32() { + + uint32_t rem; + int i, j; + + for (i = 0; i < 256; i++) { + rem = i; + for (j = 0; j < 8; j++) { + if (rem & 1) { + rem >>= 1; + rem ^= 0xedb88320; + } else + rem >>= 1; + } + table[i] = rem; + } + } + + /// Computes the CRC of an array of bytes + uint32_t operator()(void const *start, size_t length, uint32_t crc = uint32_t()) const { + uint8_t const *p = static_cast(start); + uint8_t const *q = static_cast(start) + length; + + crc = ~crc; + + for (; p != q; ++p) { + uint8_t octet = *p; + crc = (crc >> 8) ^ table[(crc & 0xff) ^ octet]; + } + + return ~crc; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Element, typename Layout +> +uint32_t TensorHash( + cutlass::TensorView view, + CRC32 const &hash = CRC32(), + uint32_t crc = uint32_t() +) { + + return hash(view.data(), view.capacity() * cutlass::sizeof_bits::value / 8, crc); +} + +template +uint32_t TensorHash( + thrust::universal_vector& tensor, + CRC32 const &hash = CRC32(), + uint32_t crc = uint32_t() +) { + + return hash(tensor.data().get(), tensor.size() * cutlass::sizeof_bits::value / 8, crc); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, typename LayoutA, + typename ElementB, typename LayoutB, + typename ElementC, typename LayoutC, + typename ElementAccumulator, + typename ElementCompute +> +inline std::ostream &EncodeTypes( + std::ostream &out +) { + + out << TensorTypeName() << "_" + << TensorTypeName() << "_" + << TensorTypeName() << "_" + << ElementTypeName() << "_" + << ElementTypeName(); + + return out; +} + +template < + typename ElementA, + typename ElementB, + typename ElementC, + typename ElementD +> +inline std::ostream &EncodeTypes( + std::ostream &out +) { + + out << TensorTypeName() << "_" + << TensorTypeName() << "_" + << TensorTypeName() << "_" + << ElementTypeName(); + + return out; +} +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, typename LayoutA, + typename ElementB, typename LayoutB, + typename ElementC, typename LayoutC, + typename ElementAccumulator, + typename ElementCompute +> +inline CachedTestKey CreateCachedGemmTestKey( + cutlass::gemm::GemmCoord const &problem, + ElementCompute alpha, + ElementCompute beta, + cutlass::TensorView A, + cutlass::TensorView B, + cutlass::TensorView C +) { + + CachedTestKey key; + + // Encode gemm operator and problem sizes + key.op = "gemm"; + + std::stringstream ss_problem; + EncodeProblemSize(ss_problem, problem); + ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); + key.problem = ss_problem.str(); + + // Encode problem data types + std::stringstream ss_types; + EncodeTypes< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + ElementCompute>(ss_types); + key.types = ss_types.str(); + + // Encode hash for problem data + CRC32 crc_hash; + key.A = TensorHash(A, crc_hash); + key.B = TensorHash(B, crc_hash); + key.C = TensorHash(C, crc_hash); + + return key; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +template < + typename ElementA, typename LayoutA, + typename ElementB, typename LayoutB, + typename ElementC, typename LayoutC, + typename ElementAccumulator, + typename ElementCompute +> +inline CachedTestKey CreateCachedConv2dTestKey( + + cutlass::conv::Operator conv_operator, + cutlass::conv::Conv2dProblemSize const &problem, + ElementCompute alpha, + ElementCompute beta, + cutlass::TensorView A, + cutlass::TensorView B, + cutlass::TensorView C +) { + + CachedTestKey key; + + // Encode conv2d operator and problem sizes + key.op = "conv2d"; + + std::stringstream ss_problem; + ss_problem << EncodeOperator(conv_operator) << "_"; + EncodeProblemSize(ss_problem, problem); + ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); + + key.problem = ss_problem.str(); + + // Encode problem data types + std::stringstream ss_types; + EncodeTypes< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + ElementCompute>(ss_types); + key.types = ss_types.str(); + + // Encode hash for problem data + CRC32 crc_hash; + + key.A = TensorHash(A, crc_hash); + key.B = TensorHash(B, crc_hash); + key.C = TensorHash(C, crc_hash); + + return key; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, typename LayoutA, + typename ElementB, typename LayoutB, + typename ElementC, typename LayoutC, + typename ElementAccumulator, + typename ElementCompute +> +inline CachedTestKey CreateCachedConv2dWithBroadcastTestKey( + + cutlass::conv::Operator conv_operator, + cutlass::conv::Conv2dProblemSize const &problem, + ElementCompute alpha, + ElementCompute beta, + cutlass::TensorView A, + cutlass::TensorView B, + cutlass::TensorView C +) { + + CachedTestKey key; + + // Encode conv2d operator and problem sizes + key.op = "conv2d_with_broadcast"; + + std::stringstream ss_problem; + ss_problem << EncodeOperator(conv_operator) << "_"; + EncodeProblemSize(ss_problem, problem); + ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); + + key.problem = ss_problem.str(); + + // Encode problem data types + std::stringstream ss_types; + EncodeTypes< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + ElementCompute>(ss_types); + key.types = ss_types.str(); + + // Encode hash for problem data + CRC32 crc_hash; + + key.A = TensorHash(A, crc_hash); + key.B = TensorHash(B, crc_hash); + key.C = TensorHash(C, crc_hash); + + return key; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, typename LayoutA, + typename ElementB, typename LayoutB, + typename ElementC, typename LayoutC, + typename ElementAccumulator, + typename ElementCompute +> +inline CachedTestKey CreateCachedConv2dWithReductionTestKey( + + cutlass::conv::Operator conv_operator, + cutlass::conv::Conv2dProblemSize const &problem, + ElementCompute alpha, + ElementCompute beta, + cutlass::TensorView A, + cutlass::TensorView B, + cutlass::TensorView C +) { + + CachedTestKey key; + + // Encode conv2d operator and problem sizes + key.op = "conv2d_with_reduction"; + + std::stringstream ss_problem; + ss_problem << EncodeOperator(conv_operator) << "_"; + EncodeProblemSize(ss_problem, problem); + ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); + + key.problem = ss_problem.str(); + + // Encode problem data types + std::stringstream ss_types; + EncodeTypes< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + ElementCompute>(ss_types); + key.types = ss_types.str(); + + // Encode hash for problem data + CRC32 crc_hash; + + key.A = TensorHash(A, crc_hash); + key.B = TensorHash(B, crc_hash); + key.C = TensorHash(C, crc_hash); + + return key; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, typename LayoutA, + typename ElementB, typename LayoutB, + typename ElementC, typename LayoutC, + typename ElementAccumulator, + typename ElementCompute +> +inline CachedTestKey CreateCachedConv3dTestKey( + cutlass::conv::Operator conv_operator, + cutlass::conv::Conv3dProblemSize const &problem, + ElementCompute alpha, + ElementCompute beta, + cutlass::TensorView A, + cutlass::TensorView B, + cutlass::TensorView C +) { + + CachedTestKey key; + + // Encode conv3d operator and problem sizes + key.op = "conv3d"; + + std::stringstream ss_problem; + + ss_problem << EncodeOperator(conv_operator) << "_"; + EncodeProblemSize(ss_problem, problem); + ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); + + key.problem = ss_problem.str(); + + // Encode problem data types + std::stringstream ss_types; + EncodeTypes< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + ElementCompute>(ss_types); + key.types = ss_types.str(); + + // Encode problem data + CRC32 crc_hash; + key.A = TensorHash(A, crc_hash); + key.B = TensorHash(B, crc_hash); + key.C = TensorHash(C, crc_hash); + + return key; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape, + typename ElementA, + typename ElementB, + typename ElementC, + typename ElementD +> +inline CachedTestKey CreateCachedConvNd3xTestKey( + cutlass::conv::Operator conv_operator, + ProblemShape const& problem_shape, + double alpha, + double beta, + thrust::universal_vector A, + thrust::universal_vector B, + thrust::universal_vector C +) { + + CachedTestKey key; + + // Encode convNd operator and problem sizes + std::stringstream ss_op; + ss_op << "conv" << ProblemShape::RankS << "d"; + key.op = ss_op.str(); + + std::stringstream ss_problem; + ss_problem << EncodeOperator(conv_operator) << "_"; + EncodeProblemSize(ss_problem, problem_shape); + ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); + key.problem = ss_problem.str(); + + // Encode problem data types + std::stringstream ss_types; + EncodeTypes< + ElementA, + ElementB, + ElementC, + ElementD>(ss_types); + key.types = ss_types.str(); + + // Encode problem data + CRC32 crc_hash; + key.A = TensorHash(A, crc_hash); + key.B = TensorHash(B, crc_hash); + key.C = TensorHash(C, crc_hash); + + return key; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace test::conv::device + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_problems.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_problems.h new file mode 100644 index 0000000000000000000000000000000000000000..a14134b2854732e669977831207a456d28beed9f --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_problems.h @@ -0,0 +1,927 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implicit GEMM testbed sizes for Conv2d problem +*/ +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" + +namespace test { +namespace conv { +namespace device { + +using Conv2dProblemVector = std::vector; + +// +// Structures to prune items from Conv2dProblemVector +// +// Specification template for pruning items for convolution problem lists +template struct Specification +{ + virtual ~Specification() = default; + virtual bool is_satisfied(T item) const = 0; +}; + +// input size (NHWC) specification +struct InputSizeSpecification : Specification +{ + cutlass::Tensor4DCoord input_size; + + InputSizeSpecification(cutlass::Tensor4DCoord input_size_) : input_size(input_size_) {} + + bool is_satisfied(cutlass::conv::Conv2dProblemSize item) const override { + return ((input_size.n() == item.N) && (input_size.h() == item.H) && (input_size.w() == item.W) && (input_size.c() == item.C)); + } +}; + +// stride (stride_h, stride_w) specification +struct StrideSpecification : Specification +{ + cutlass::MatrixCoord stride; + + StrideSpecification(cutlass::MatrixCoord stride_) : stride(stride_) {} + + bool is_satisfied(cutlass::conv::Conv2dProblemSize item) const override { + return ((stride.row() == item.stride_h) && (stride.column() == item.stride_h)); + } +}; + +// channel (C,K) specification, must be multiple of minimum channel +struct ChannelDivisibilitySpecification : Specification +{ + int channel_multiple; + + ChannelDivisibilitySpecification(int channel_multiple_) : channel_multiple(channel_multiple_) {} + + bool is_satisfied(cutlass::conv::Conv2dProblemSize item) const override { + return ((item.K % channel_multiple == 0) && (item.C % channel_multiple == 0)); + } +}; + +// +// Pruning function for items from Conv2dProblemVector based on a Specification +// +inline Conv2dProblemVector prune(Conv2dProblemVector const &items, + Specification const &spec) +{ + Conv2dProblemVector pruned_list; + + for (auto& p : items) + if (spec.is_satisfied(p)) + pruned_list.push_back(p); + return pruned_list; +} + + +//////////////////////////////////////////////////////////////////////////// +/// Structure TestbedConv2dProblemSizes initializes and holds conv default and +/// important network sizes +//////////////////////////////////////////////////////////////////////////// +struct TestbedConv2dProblemSizes { + + // + // Data members + // + int minimum_channel_size; + + Conv2dProblemVector conv2d_default_sizes; + Conv2dProblemVector conv2d_rigorous_sizes; + Conv2dProblemVector conv2d_resnet50_sizes; + Conv2dProblemVector conv2d_resnet50_sizes_perf; + + // + // Methods + // + /// Default ctor + TestbedConv2dProblemSizes(int minimum_channel_size_ = 64): minimum_channel_size (minimum_channel_size_) { + initialize_conv2d_default_sizes(); + initialize_conv2d_rigorous_sizes(); + initialize_conv2d_resnet50_sizes(conv2d_resnet50_sizes, 1 /*batch-size*/); + + initialize_conv2d_resnet50_sizes(conv2d_resnet50_sizes_perf, 34 /*batch-size*/); + filter_all(); + } + + /// Eliminates some illegal cases + void filter_all() { + + Conv2dProblemVector *problems_vectors[] = { + &conv2d_default_sizes, + &conv2d_rigorous_sizes, + &conv2d_resnet50_sizes, + &conv2d_resnet50_sizes_perf + }; + + for (Conv2dProblemVector *problems : problems_vectors) { + Conv2dProblemVector filtered; + + for (cutlass::conv::Conv2dProblemSize const & problem : *problems) { + if (!(problem.C % minimum_channel_size)) { + filtered.push_back(problem); + } + } + + *problems = filtered; + } + } + + // Add a few standard convolution problem sizes + void initialize_conv2d_default_sizes() { + + //////////////////////////////////////////////////////////////////////////////////////////// + // Small input size x stride (1,1) + // C < CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64} + //////////////////////////////////////////////////////////////////////////////////////////// + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 1, 1, minimum_channel_size}, // input size (NHWC) + {8, 1, 1, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 1, 8, minimum_channel_size}, // input size (NHWC) + {8, 1, 3, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 7, 8, minimum_channel_size}, // input size (NHWC) + {8, 3, 3, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 4, 4, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {2, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 5, 5, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {3, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 6, 5, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {3, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 6, 6, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {3, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 7, 7, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + //////////////////////////////////////////////////////////////////////////////////////////// + // Small input size x stride (1,1) asymmetric paddings (1, 0, 1, 0) + // C < CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64} + //////////////////////////////////////////////////////////////////////////////////////////// + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 1, 1, minimum_channel_size}, // input size (NHWC) + {8, 1, 1, minimum_channel_size}, // filter size (KRSC) + {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 1, 8, minimum_channel_size}, // input size (NHWC) + {8, 1, 3, minimum_channel_size}, // filter size (KRSC) + {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 7, 8, minimum_channel_size}, // input size (NHWC) + {8, 3, 3, minimum_channel_size}, // filter size (KRSC) + {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 4, 4, minimum_channel_size}, // filter size (KRSC) + {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {2, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 5, 5, minimum_channel_size}, // filter size (KRSC) + {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {3, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 6, 5, minimum_channel_size}, // filter size (KRSC) + {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {3, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 6, 6, minimum_channel_size}, // filter size (KRSC) + {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {3, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 7, 7, minimum_channel_size}, // filter size (KRSC) + {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + //////////////////////////////////////////////////////////////////////////////////////////// + // Small input size x stride (2,2) + // C < CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64} + //////////////////////////////////////////////////////////////////////////////////////////// + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 11, 7, minimum_channel_size}, // input size (NHWC) + {8, 1, 1, minimum_channel_size}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 11, 7, minimum_channel_size}, // input size (NHWC) + {8, 3, 3, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 13, 11, minimum_channel_size}, // input size (NHWC) + {8, 1, 1, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 17, 19, minimum_channel_size}, // input size (NHWC) + {16, 2, 2, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 23, 5, minimum_channel_size}, // input size (NHWC) + {16, 3, 3, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 13, 17, 8}, // input size (NHWC) + {24, 3, 3, 8}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 23, 21, 8}, // input size (NHWC) + {24, 3, 3, 8}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {3, 3}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 20, 24, 8}, // input size (NHWC) + {40, 3, 3, 8}, // filter size (KRSC) + {3, 3, 3, 3}, // padding (pad_h, _, pad_w, _) + {3, 3}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + //////////////////////////////////////////////////////////////////////////////////// + // Medium input size (1x16x16x128), filter size (1x1, 2x2, 3x3, 5x5), stride (1, 1) + //////////////////////////////////////////////////////////////////////////////////// + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 15, 19, 160}, // input size (NHWC) + {224, 1, 1, 160}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 19, 37, 160}, // input size (NHWC) + {224, 3, 3, 160}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 16, 16, 160}, // input size (NHWC) + {224, 2, 3, 160}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 23, 21, 128}, // input size (NHWC) + {224, 3, 3, 128}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 29, 37, 160}, // input size (NHWC) + {224, 5, 5, 160}, // filter size (KRSC) + {2, 2, 2, 2}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + //////////////////////////////////////////////////////////////////////////////////// + // C > CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64} + //////////////////////////////////////////////////////////////////////////////////// + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 15, 19, 32 + minimum_channel_size}, // input size (NHWC) + {96, 3, 3, 32 + minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 16, 24, 64 + minimum_channel_size}, // input size (NHWC) + {96, 3, 3, 64 + minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + //////////////////////////////////////////////////////////////////////////////////// + // Medium input size, filter size (1x1, 3,x3, 5x5, 7x7), stride (2, 2) + //////////////////////////////////////////////////////////////////////////////////// + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 13, 16, 288}, // input size (NHWC) + {160, 5, 5, 288}, // filter size (KRSC) + {2, 2, 2, 2}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 55, 51, 256}, // input size (NHWC) + {512, 1, 1, 256}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 71, 80, 32}, // input size (NHWC) + {64, 5, 5, 32}, // filter size (KRSC) + {2, 2, 2, 2}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 224, 224, 8}, // input size (NHWC) + {64, 7, 7, 8}, // filter size (KRSC) + {3, 3, 3, 3}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + //////////////////////////////////////////////////////////////////////////////////// + // Medium input size stride (3, 3), filter (3, 3), non-default padding + //////////////////////////////////////////////////////////////////////////////////// + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 27, 23, 256}, // input size (NHWC) + {512, 3, 3, 256}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {3, 3}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + //////////////////////////////////////////////////////////////////////////////////// + // Medium input size padding > stride, asymmetric filter, padding and striding + //////////////////////////////////////////////////////////////////////////////////// + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 27, 31, 256}, // input size (NHWC) + {512, 3, 3, 256}, // filter size (KRSC) + {5, 5, 7, 7}, // padding (pad_h, _, pad_w, _) + {3, 4}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 27, 35, 256}, // input size (NHWC) + {512, 7, 5, 256}, // filter size (KRSC) + {11, 11, 7, 7}, // padding (pad_h, _, pad_w, _) + {3, 5}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + //////////////////////////////////////////////////////////////////////////////////// + // Medium input size *mixed* stride (1, 2) and (2, 1), + // filter (3, 3), default padding + //////////////////////////////////////////////////////////////////////////////////// + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 27, 27, 256}, // input size (NHWC) + {512, 3, 3, 256}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 27, 27, 256}, // input size (NHWC) + {512, 3, 3, 256}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + ///////////////////////////////////////////////////////////////////////////// + // Additional input size + ///////////////////////////////////////////////////////////////////////////// + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {3, 28, 28, 256}, // input size (NHWC) + {256, 2, 2, 256}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 32, 32, 16}, // input size (NHWC) + {32, 3, 3, 16}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {6, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {32, 24, 32, 32}, // input size (NHWC) + {32, 1, 2, 32}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {4, 4, 5, 128}, // input size (NHWC) + {256, 3, 6, 128}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + {4, 3, 3, 256} // output size (NPQK) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {4, 2, 3, 256}, // input size (NHWC) + {328, 3, 5, 256}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + {4, 1, 1, 328} // output size (NPQK) + )); + } + + + // Add a few large and rigorous convolution problem sizes + void initialize_conv2d_rigorous_sizes() { + +#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED + conv2d_rigorous_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 124, 224, 96}, // input size (NHWC) + {24, 7, 7, 96}, // filter size (KRSC) + {1, 229, 129, 32} // output size (NPQK) + )); + + conv2d_rigorous_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 233, 35, 48}, // input size (NHWC) + {24, 7, 5, 48}, // filter size (KRSC) + {1, 233, 35, 24} // output size (NPQK) + )); + +#endif + + } + + + // Add resent50 layers to unit testing sizes + void initialize_conv2d_resnet50_sizes(Conv2dProblemVector &conv2d_problem_vector, int batch_size = 1){ + +#if 0 // Resnet50 first layer (layer_id = 0) with channel = 3 is not supported in cutlass + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + [1, 224, 224, 3], // input size (NHWC) + [64, 7, 7, 3], // filter size (KRSC) + [3, 3, 3, 3], // padding (pad_h, _, pad_w, _) + [2, 2], // stride (stride_h, stride_w) + [1, 1], // dilation (dilation_h, dilation_w) + )); +#endif + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 56, 56, 64}, // input size (NHWC) + {256, 1, 1, 64}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 56, 56, 64}, // input size (NHWC) + {64, 1, 1, 64}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 56, 56, 64}, // input size (NHWC) + {64, 3, 3, 64}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 56, 56, 256}, // input size (NHWC) + {64, 1, 1, 256}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 56, 56, 256}, // input size (NHWC) + {512, 1, 1, 256}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 56, 56, 256}, // input size (NHWC) + {128, 1, 1, 256}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 28, 28, 128}, // input size (NHWC) + {128, 3, 3, 128}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 28, 28, 128}, // input size (NHWC) + {512, 1, 1, 128}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 28, 28, 512}, // input size (NHWC) + {128, 1, 1, 512}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 28, 28, 512}, // input size (NHWC) + {1024, 1, 1, 512}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 28, 28, 512}, // input size (NHWC) + {256, 1, 1, 512}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 14, 14, 256}, // input size (NHWC) + {256, 3, 3, 256}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 14, 14, 256}, // input size (NHWC) + {1024, 1, 1, 256}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 14, 14, 1024}, // input size (NHWC) + {256, 1, 1, 1024}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 14, 14, 1024}, // input size (NHWC) + {2048, 1, 1, 1024}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 14, 14, 1024}, // input size (NHWC) + {512, 1, 1, 1024}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 7, 7, 512}, // input size (NHWC) + {512, 3, 3, 512}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 7, 7, 512}, // input size (NHWC) + {2048, 1, 1, 512}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 7, 7, 2048}, // input size (NHWC) + {512, 1, 1, 2048}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + } + +}; + + +//////////////////////////////////////////////////////////////////////////// +/// Structure TestbedGroupConv2dProblemSizes initializes and holds group conv default and +/// important network sizes +//////////////////////////////////////////////////////////////////////////// +struct TestbedGroupConv2dProblemSizes { + + // + // Data members + // + int threadblock_n; + int threadblock_k; + int minimum_channel_size; + + Conv2dProblemVector default_single_group_sizes; + Conv2dProblemVector default_multiple_group_sizes; + + // + // Methods + // + /// Default ctor + TestbedGroupConv2dProblemSizes( + int threadblock_n_, + int threadblock_k_, + int minimum_channel_size_ = 64) + : threadblock_n (threadblock_n_), + threadblock_k (threadblock_k_), + minimum_channel_size (minimum_channel_size_) { + initialize_group_conv2d_default_sizes(); + filter_all(); + } + + /// Eliminates some illegal cases + void filter_all() { + + Conv2dProblemVector *problems_vectors[] = { + &default_single_group_sizes, + &default_multiple_group_sizes + }; + + for (Conv2dProblemVector *problems : problems_vectors) { + Conv2dProblemVector filtered; + + for (cutlass::conv::Conv2dProblemSize const & problem : *problems) { + if (!((problem.C / problem.groups) % minimum_channel_size)) { + filtered.push_back(problem); + } + } + + *problems = filtered; + } + } + + // Add a few standard convolution problem sizes + void initialize_group_conv2d_default_sizes() { + + //////////////////////////////////////////////////////////////////////////////////// + // One group calculated by one or multiple CTAs: k_per_group % CTA::N = 0 + // One CTA calculates a single group + //////////////////////////////////////////////////////////////////////////////////// + + for (int cta_per_group_k = 1; cta_per_group_k < 4; ++cta_per_group_k) { + // groups = 2, 3, 4 + for (int groups = 2; groups < 5; ++groups) { + + int conv_k = cta_per_group_k * threadblock_n * groups; + default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 8, 8, threadblock_k * 2 * groups}, // input size (NHWC) + {conv_k, 3, 3, threadblock_k * 2}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + groups // groups + )); + + } // loop groups + } // loop cta_per_group_k + + // Partial gemm_k: k_per_group == CTA::N && channels_per_group < CTA::K + default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 8, 8, threadblock_k}, // input size (NHWC) + {threadblock_n * 2, 3, 3, threadblock_k / 2}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + 2 // groups + )); + + // Larger problem sizes + + default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 56, 56, 696}, // input size (NHWC) + {768, 3, 3, 232}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + 3 // groups + )); + default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 14, 14, 1392}, // input size (NHWC) + {1536, 3, 3, 232}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + 3 // groups + )); + + //////////////////////////////////////////////////////////////////////////////////// + // One CTA calculate multiple groups: CTA::N % k_per_group = 0 + //////////////////////////////////////////////////////////////////////////////////// + + // 2 groups per CTA + default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 8, 8, threadblock_k * 4}, // input size (NHWC) + {threadblock_n, 3, 3, threadblock_k * 2}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + 2 // groups + )); + + // 2 groups per CTA and partial gemm_k + default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 8, 8, threadblock_k}, // input size (NHWC) + {threadblock_n, 3, 3, threadblock_k / 2}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + 2 // groups + )); + + // 4 groups per CTA + default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 8, 8, threadblock_k * 8}, // input size (NHWC) + {threadblock_n / 2, 3, 3, threadblock_k * 2}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + 4 // groups + )); + + // 4 groups per CTA and partial gemm_k + default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 8, 8, threadblock_k * 2}, // input size (NHWC) + {threadblock_n / 2, 3, 3, threadblock_k / 2}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + 4 // groups + )); + } + +}; + + +} // namespace device +} // namespace conv +} // namespace test diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_testbed.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..34588ecb467b824cc0fcbbff0bc0d99e4385d80e --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_testbed.h @@ -0,0 +1,818 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implicit GEMM testbed +*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/reduction/device/reduce_split_k.h" +#include "cutlass/reduction/thread/reduction_operators.h" + +#include "conv2d_problems.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_compare.h" + +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/device/convolution.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/tensor_view_io.h" + +#include "../cache_testbed_output.h" + +namespace test { +namespace conv { +namespace device { + +template +class TestbedConv2d { +public: + + using ElementA = typename Conv2d::ElementA; + using LayoutA = typename Conv2d::LayoutA; + using ElementB = typename Conv2d::ElementB; + using LayoutB = typename Conv2d::LayoutB; + using ElementC = typename Conv2d::ElementC; + using LayoutC = typename Conv2d::LayoutC; + using ElementAccumulator = typename Conv2d::ElementAccumulator; + using ElementCompute = typename Conv2d::ElementCompute; + using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; + + static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; + + /// Reduction kernel + using ReductionOp = cutlass::reduction::thread::ReduceAdd< + ElementAccumulator, + typename EpilogueOutputOp::ElementAccumulator, + EpilogueOutputOp::kCount + >; + + using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< + cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, + EpilogueOutputOp, + ReductionOp + >; + + using ReductionDevice = cutlass::reduction::device::ReduceSplitK; + using ReductionStrideIndex = typename ReductionDevice::StrideIndex; + +public: + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + int tested_problem_count; + +public: + + TestbedConv2d( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_), tested_problem_count(0) { + + } + + /// Helper to initialize a tensor view + template + void initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 8) { + scope = 2; + } + else if (bits == 16) { + if (cutlass::sizeof_bits::value <= 16) { + scope = 3; + } + else { + scope = 5; + } + } + else { + scope = 8; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope, -scope, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + } + } + + void initialize( + cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) { + + tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); + tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + + initialize_tensor(tensor_A.host_view(), init_A, seed); + initialize_tensor(tensor_B.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C.host_view(), init_C, seed * 39); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + tensor_D_reference.sync_device(); + } + + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Conv2d::UnderlyingKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::conv::Conv2dProblemSize const &problem_size, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + // increment tested problem count run by the testbed + tested_problem_count++; + +#if 0 // display conv2d problem size for debugging + std::cout << problem_size << std::endl + << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl + << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl + << std::endl; +#endif + + initialize(problem_size); + + // configure the operator + Conv2d conv2d_op; + + typename Conv2d::Arguments conv2d_args( + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_computed.device_ref(), + {alpha, beta}, + split_k_mode + ); + + // find workspace requirement for parallel split-k reduction + size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = conv2d_op.initialize(conv2d_args, workspace.get()); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // conv2d operation with parallel split-k-mode + if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { + + // conv2d output is written to workspace in global memory + conv2d_args.ref_D.reset(reinterpret_cast(workspace.get())); + // accumulate mma for each cta in k-dimension (1.0 * A * B) + conv2d_args.output_op = {ElementCompute(1), ElementCompute(0)}; + // update conv2d operator arguments + status = conv2d_op.update(conv2d_args, workspace.get()); + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + // run conv2d operator + status = conv2d_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to run." << std::endl; + return false; + } + + + if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { + + // configure parallel reduction operator + ReductionDevice reduction_op; + + typename ReductionDevice::Arguments reduction_args( + cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(), + problem_size.split_k_slices, + cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size), + { + reinterpret_cast (workspace.get()), + ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) + }, + { + tensor_D_computed.device_data(), + ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) + }, + { + tensor_C.device_data(), + ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) + }, + // apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C + {alpha, beta} + ); + + status = reduction_op.initialize(reduction_args, nullptr); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + // run prallel reduction kernel + status = reduction_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + } + bool passed = false; + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " device reference error: " + << cudaGetErrorString(result); + + tensor_D_computed.sync_host(); + + // + // Reference check - support caching results + // + + CachedTestKey cached_test_key = CreateCachedConv2dTestKey< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + ElementCompute + >( + kConvolutionalOperator, + problem_size, + alpha, + beta, + tensor_A.host_view(), + tensor_B.host_view(), + tensor_C.host_view() + ); + + // + // Look for the cached key + // + + bool cached_result_loaded = false; + CachedTestResult cached_test_result; + + std::string conv2d_result_cache_name = + std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt"; + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + + CachedTestResultListing cached_results(conv2d_result_cache_name); + + auto cached = cached_results.find(cached_test_key); + + cached_result_loaded = cached.first; + if (cached_result_loaded) { + cached_test_result = cached.second; + } + } + + if (!cached_result_loaded) { + +#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED + + cutlass::reference::device::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_reference.device_ref(), + alpha, + beta); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_D_reference.sync_host(); + +#else + + cutlass::reference::host::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.host_ref(), + tensor_B.host_ref(), + tensor_C.host_ref(), + tensor_D_reference.host_ref(), + alpha, + beta); + +#endif + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + + cached_test_result.D = TensorHash(tensor_D_reference.host_view()); + + CachedTestResultListing cached_results(conv2d_result_cache_name); + + cached_results.append(cached_test_key, cached_test_result); + cached_results.write(conv2d_result_cache_name); + } + } // if (!cached_result_loaded) + + uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view()); + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + passed = (tensor_D_hash == cached_test_result.D); + + EXPECT_EQ(tensor_D_hash, cached_test_result.D) + << "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n"; + } + else { + + passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view()); + } + + EXPECT_TRUE(passed); + + std::stringstream ss_problem_size_text; + ss_problem_size_text << "nhwc_" + << problem_size.N << "x" + << problem_size.H << "x" + << problem_size.W << "x" + << problem_size.C + << "_krsc_" + << problem_size.K << "x" + << problem_size.R << "x" + << problem_size.S << "x" + << problem_size.C + << "_padding_" + << problem_size.pad_h << "x" + << problem_size.pad_w + << "_stride_" + << problem_size.stride_h << "x" + << problem_size.stride_w + << "_dilation_" + << problem_size.dilation_h << "x" + << problem_size.dilation_w << "_" + << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_"); + + if (!passed) { + std::stringstream fname; + + fname << "error_Conv2d_ImplicitGemm_device_" + << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") + << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDeconv ? "deconv_" : "wgrad_"))) + << ss_problem_size_text.str() + << Conv2d::ThreadblockShape::kM << "x" + << Conv2d::ThreadblockShape::kN << "x" + << Conv2d::ThreadblockShape::kK << "_" + << Conv2d::WarpShape::kM << "x" + << Conv2d::WarpShape::kN << "x" + << Conv2d::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n"; + + results << "\nD reference (hash: " << cached_test_result.D << ")\n"; + + if (!cached_result_loaded) { + results + << tensor_D_reference.host_view() << "\n"; + } + + results + << "\nD computed (hash: " << tensor_D_hash << ")\n" + << tensor_D_computed.host_view() << "\n"; + + } + + return passed; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestSpecificConv2d( + const Conv2dProblemVector & problem_sizes) { + + bool passed = true; + + // + // Testbed object + // + + TestbedConv2d testbed; + + // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for(auto conv_problem : problem_sizes) { + + // + // Test + // + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + } + + return true; +} + +///////////////////////////////////////////////////////////////////////////////////////////////////////// +// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference +// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes +// Additionally, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes +// (conv_blacklist_sizes) +///////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestAllConv2d( + const Conv2dProblemVector & conv_test_sizes = Conv2dProblemVector(), + const Conv2dProblemVector & conv_blacklist_sizes = Conv2dProblemVector()) { + + bool passed = true; + + // + // Testbed object + // + + TestbedConv2d testbed; + + // + // Get conv problem sizes to run conv operator + // + TestbedConv2dProblemSizes conv_problems(128/cutlass::sizeof_bits::value); + + // Vector of conv2d problem sizes to avoid duplicate runs + Conv2dProblemVector conv_tested_sizes; + + // Vectors of Conv2dProblemVector (lenient/easiest to rigorous problem sizes) + std::vector problem_vectors = { + conv_test_sizes, // run user specified sizes + conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes + //conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes +#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED + conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled +#endif + }; + + // Flatten 2D problem_vectors into a 1D problem_sizes + std::vector problem_sizes; + for (auto problem_vector : problem_vectors) { + for(auto conv_problem : problem_vector) { + problem_sizes.push_back(conv_problem); + } + } + + // If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set reverse the order (rigorous to lenient) + // run the most rigorous problem size first + if (CutlassUnitTestProblemCount()) { + std::reverse(problem_sizes.begin(), problem_sizes.end()); + } + + // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for(auto conv_problem : problem_sizes) { + + // Skip blacklist and avoid duplicate problem sizes + if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || + std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { + continue; + } + + // + // Procedurally disable certain cases + // + + // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kUnity)) { + if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { + continue; + } + } + + // Fixed channels algorithm requires channel count to match access size + if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm == + cutlass::conv::IteratorAlgorithm::kFixedChannels) { + if (conv_problem.C != ImplicitGemm::UnderlyingKernel::Mma::IteratorA::AccessType::kElements) { + continue; + } + } + + // Few channels algorithm requires channel count to match access size + if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm == + cutlass::conv::IteratorAlgorithm::kFewChannels) { + if (conv_problem.C % ImplicitGemm::UnderlyingKernel::Mma::IteratorA::AccessType::kElements) { + continue; + } + } + + // CUTLASS DGRAD's *strided* stride specialization supports all stride {stride_h, stride_w} + // Although strided dgrad works for all stride combinations, we are only going + // to run strided dgrad for non-unity strides + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kStrided)) { + if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { + continue; + } + } + + // + // Test + // + // push back tested problem size to avoid re-running duplicates + conv_tested_sizes.push_back(conv_problem); + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set reduce the number of tested problem counts + if (CutlassUnitTestProblemCount() && + testbed.tested_problem_count > CutlassUnitTestProblemCount()) { + return true; + } + } + + // Small-channels convolution can't run here. + if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm == + cutlass::conv::IteratorAlgorithm::kFixedChannels) { + + return true; + } + + // Small-channels convolution can't run here. + if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm == + cutlass::conv::IteratorAlgorithm::kFewChannels) { + + return true; + } + + // CUTLASS DGRAD's *strided* specialization does not support split-k mode + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kStrided)) { + + passed = testbed.run( + cutlass::conv::Conv2dProblemSize( + {1, 56, 56, 8}, // input size (NHWC) + {8, 1, 1, 8}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1}), // dilation (dilation_h, dilation_w) + cutlass::conv::SplitKMode::kSerial, + cutlass::from_real(2.0), + cutlass::from_real(2.0)); + + passed = testbed.run( + cutlass::conv::Conv2dProblemSize( + {1, 56, 56, 8}, // input size (NHWC) + {8, 1, 1, 8}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}) // dilation (dilation_h, dilation_w) + .reset_split_k_slices(2), + cutlass::conv::SplitKMode::kSerial, + cutlass::from_real(2.0), + cutlass::from_real(2.0)); + + if (!passed) { + return false; + } + + return passed; + } + // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for + // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters + // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep + // alpha and beta for local testing, but only runs one value for alpha and beta. + cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size ( + {1, 17, 11, 288}, // input size (NHWC) + {160, 3, 3, 288}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + ); + + cutlass::conv::SplitKMode split_k_modes [] = { + cutlass::conv::SplitKMode::kSerial, + cutlass::conv::SplitKMode::kParallel, + }; + + int split_k_slices[] = { + 1, 2, 3, 4, 201 + }; + + double problem_alpha[] = { + 2.0 + }; + + double problem_beta[] = { + 2.0 + }; + + for (auto split_k_mode : split_k_modes) { + for (auto split_k_slice : split_k_slices) { + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + passed = testbed.run( + conv2d_split_k_test_size.reset_split_k_slices(split_k_slice), + split_k_mode, + cutlass::from_real(alpha), + cutlass::from_real(beta)); + + if (!passed) { + return false; + } + + // If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set reduce the number of tested problem counts + if (CutlassUnitTestProblemCount() && + testbed.tested_problem_count > CutlassUnitTestProblemCount()) { + return true; + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace conv +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_testbed_interleaved.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_testbed_interleaved.h new file mode 100644 index 0000000000000000000000000000000000000000..cf075674da673cf8e056172732f912b8acba3c5b --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_testbed_interleaved.h @@ -0,0 +1,666 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implicit GEMM testbed +*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/reduction/device/reduce_split_k.h" +#include "cutlass/reduction/thread/reduction_operators.h" + +#include "conv2d_problems.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/host_reorder.h" + +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/device/convolution.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/tensor_view_io.h" + +#include "../cache_testbed_output.h" + +namespace test { +namespace conv { +namespace device { + +template +class InterleavedTestbedConv2d { +public: + + using ElementA = typename Conv2d::ElementA; + using LayoutA = typename Conv2d::LayoutA; + using ElementB = typename Conv2d::ElementB; + using LayoutB = typename Conv2d::LayoutB; + using ElementC = typename Conv2d::ElementC; + using LayoutC = typename Conv2d::LayoutC; + using ElementAccumulator = typename Conv2d::ElementAccumulator; + using ElementCompute = typename Conv2d::ElementCompute; + using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; + + static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; + + /// Reduction kernel + using ReductionOp = cutlass::reduction::thread::ReduceAdd< + ElementAccumulator, + typename EpilogueOutputOp::ElementAccumulator, + EpilogueOutputOp::kCount + >; + + using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< + cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, + EpilogueOutputOp, + ReductionOp + >; + + using ReductionDevice = cutlass::reduction::device::ReduceSplitK; + using ReductionStrideIndex = typename ReductionDevice::StrideIndex; + +public: + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_B_reordered; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + +public: + + InterleavedTestbedConv2d( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { + + } + + /// Helper to initialize a tensor view + template + void initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 8) { + scope = 2; + } + else if (bits == 16) { + scope = 3; + } + else { + scope = 8; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope, -scope, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + } + } + + void initialize( + cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) { + + tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); + tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_B_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + + initialize_tensor(tensor_A.host_view(), init_A, seed); + initialize_tensor(tensor_B.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C.host_view(), init_C, seed * 39); + + cutlass::reorder_convK( + tensor_B_reordered.host_ref(), tensor_B.host_ref(), implicit_gemm_problem_size(kConvolutionalOperator, problem_size)); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_B_reordered.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + tensor_D_reference.sync_device(); + } + + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Conv2d::UnderlyingKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerMultiprocessor < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::conv::Conv2dProblemSize const &problem_size, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + +#if 0 //display conv2d problem size for debugging + std::cout << problem_size << std::endl + << "alpha, beta: (" << float(alpha) << ", " << float(beta) << ")" << std::endl + << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl + << std::endl; +#endif + + initialize(problem_size); + + // configure the operator + Conv2d conv2d_op; + + typename Conv2d::Arguments conv2d_args( + problem_size, + tensor_A.device_ref(), + tensor_B_reordered.device_ref(), + tensor_C.device_ref(), + tensor_D_computed.device_ref(), + {alpha, beta}, + split_k_mode + ); + + // find workspace requirement for parallel split-k reduction + size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = conv2d_op.initialize(conv2d_args, workspace.get()); + + // conv2d operation with parallel split-k-mode + if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { + + // conv2d output is written to workspace in global memory + conv2d_args.ref_D.reset(reinterpret_cast(workspace.get())); + // accumulate mma for each cta in k-dimension (1.0 * A * B) + conv2d_args.output_op = {ElementCompute(1), ElementCompute(0)}; + // update conv2d operator arguments + status = conv2d_op.update(conv2d_args, workspace.get()); + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + // run conv2d operator + status = conv2d_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { + + // configure parallel reduction operator + ReductionDevice reduction_op; + + typename ReductionDevice::Arguments reduction_args( + cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(), + problem_size.split_k_slices, + cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size), + { + reinterpret_cast (workspace.get()), + ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) + }, + { + tensor_D_computed.device_data(), + ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) + }, + { + tensor_C.device_data(), + ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) + }, + // apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C + {alpha, beta} + ); + + status = reduction_op.initialize(reduction_args, nullptr); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + // run prallel reduction kernel + status = reduction_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + } + bool passed = false; + + tensor_D_computed.sync_host(); + + // + // Reference check - support caching results + // + + CachedTestKey cached_test_key = CreateCachedConv2dTestKey< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + ElementCompute + >( + kConvolutionalOperator, + problem_size, + alpha, + beta, + tensor_A.host_view(), + tensor_B.host_view(), + tensor_C.host_view() + ); + + // + // Look for the cached key + // + + bool cached_result_loaded = false; + CachedTestResult cached_test_result; + + std::string conv2d_result_cache_name = + std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt"; + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + + CachedTestResultListing cached_results(conv2d_result_cache_name); + + auto cached = cached_results.find(cached_test_key); + + cached_result_loaded = cached.first; + if (cached_result_loaded) { + cached_test_result = cached.second; + } + } + + if (!cached_result_loaded) { + +#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED + + cutlass::reference::device::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + cutlass::NumericConverterClamp + >( + kConvolutionalOperator, + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_reference.device_ref(), + alpha, + beta); + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " device reference error: " + << cudaGetErrorString(result); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_D_reference.sync_host(); + +#else + + cutlass::reference::host::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ElementC, + cutlass::NumericConverterClamp + >( + kConvolutionalOperator, + problem_size, + tensor_A.host_ref(), + tensor_B.host_ref(), + tensor_C.host_ref(), + tensor_D_reference.host_ref(), + alpha, + beta); + +#endif + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + + cached_test_result.D = TensorHash(tensor_D_reference.host_view()); + + CachedTestResultListing cached_results(conv2d_result_cache_name); + + cached_results.append(cached_test_key, cached_test_result); + cached_results.write(conv2d_result_cache_name); + } + } // if (!cached_result_loaded) + + uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view()); + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + passed = (tensor_D_hash == cached_test_result.D); + + EXPECT_EQ(tensor_D_hash, cached_test_result.D) + << "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n"; + } + else { + + passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view()); + } + + EXPECT_TRUE(passed); + + if (!passed) { + std::stringstream fname; + + fname << "error_Conv2d_ImplicitGemm_device_" + << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") + << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_")) + << "ncxhwx_" + << problem_size.N << "x" + << problem_size.H << "x" + << problem_size.W << "x" + << problem_size.C + << "_cxrskx_" + << problem_size.K << "x" + << problem_size.R << "x" + << problem_size.S << "x" + << problem_size.C + << "_padding_" + << problem_size.pad_h << "x" + << problem_size.pad_w + << "_stride_" + << problem_size.stride_h << "x" + << problem_size.stride_w + << "_dilation_" + << problem_size.dilation_h << "x" + << problem_size.dilation_w << "_" + << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_") + << Conv2d::ThreadblockShape::kM << "x" + << Conv2d::ThreadblockShape::kN << "x" + << Conv2d::ThreadblockShape::kK << "_" + << Conv2d::WarpShape::kM << "x" + << Conv2d::WarpShape::kN << "x" + << Conv2d::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n"; + + results << "\nD reference (hash: " << cached_test_result.D << ")\n"; + + if (!cached_result_loaded) { + results + << tensor_D_reference.host_view() << "\n"; + } + + results + << "\nD computed (hash: " << tensor_D_hash << ")\n" + << tensor_D_computed.host_view() << "\n"; + + } + + return passed; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////// +// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference +// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes +// Additionally, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes +// (conv_blacklist_sizes) +///////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestAllInterleavedConv2d( + const Conv2dProblemVector & conv_test_sizes = Conv2dProblemVector(), + const Conv2dProblemVector & conv_blacklist_sizes = Conv2dProblemVector()) { + + bool passed = true; + + // + // Testbed object + // + + InterleavedTestbedConv2d testbed; + + // + // Get conv problem sizes to run conv operator + // + TestbedConv2dProblemSizes conv_problems(InterleavedK); // minimum channel size must be multiple of InterleavedK for interleaved layout + + // Vector of conv2d problem sizes to avoid duplicate runs + Conv2dProblemVector conv_tested_sizes; + + Conv2dProblemVector const *problem_vectors[] = { + &conv_test_sizes, // run user specified sizes + &conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes + &conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes +#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED + &conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled +#endif + }; + + // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for (Conv2dProblemVector const * problem_vector : problem_vectors) { + + ChannelDivisibilitySpecification channel_spec(InterleavedK); //input and output channels must be multiple of InterleavedK + auto pruned_problem_vector = prune(*problem_vector, channel_spec); + + // Run conv testbed on default convolution sizes + for(auto conv_problem : pruned_problem_vector) { + + // Skip blacklist and avoid duplicate problem sizes + if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || + std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { + continue; + } + + // + // Procedurally disable certain cases + // + + // CUTLASS DGRAD's unity stride specialization only support stride {1, 1} + if ((ImplicitGemm::kConvolutionalOperator == + cutlass::conv::Operator::kDgrad) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kUnity)) { + if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { + continue; + } + } + + // + // Test + // + // push back tested problem size to avoid re-running duplicates + conv_tested_sizes.push_back(conv_problem); + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + } + } + +#if 0 + // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for + // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters + // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep + // alpha and beta for local testing, but only runs one value for alpha and beta. + cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size ( + {1, 17, 11, 288}, // input size (NHWC) + {160, 3, 3, 288}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + ); + + cutlass::conv::SplitKMode split_k_modes [] = { + cutlass::conv::SplitKMode::kSerial, + cutlass::conv::SplitKMode::kParallel, + }; + + int split_k_slices[] = { + 1, 2, 3, 4, 201 + }; + + double problem_alpha[] = { + 2.0 + }; + + double problem_beta[] = { + 2.0 + }; + + for (auto split_k_mode : split_k_modes) { + for (auto split_k_slice : split_k_slices) { + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + passed = testbed.run( + conv2d_split_k_test_size.reset_split_k_slices(split_k_slice), + split_k_mode, + cutlass::from_real(alpha), + cutlass::from_real(beta)); + + if (!passed) { + return false; + } + } + } + } + } +#endif + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace conv +} // namespace test diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_absmax_testbed.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_absmax_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..ad7b2ce61a66a79f852c0aac0895d10ba18e5466 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_absmax_testbed.h @@ -0,0 +1,622 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Testbed for running device-level Conv2Ds with absolute maximum calculation and scaling +*/ + +#pragma once + +#include +#include +#include + +#include "conv2d_problems.h" +#include "../../common/cutlass_unit_test.h" +#include "../../gemm/device/testbed_utils.h" + +#include "cutlass/matrix_coord.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/layout/matrix.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_reduce.h" + +namespace test { +namespace conv { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Conv, + template class ActivationFunctor +> +struct TestbedConv2dWithAbsMax { + + using ElementAccumulator = typename Conv::ElementAccumulator; + using ElementCompute = typename Conv::UnderlyingKernel::Epilogue::OutputOp::ElementCompute; + using ElementScalingFactor = typename Conv::EpilogueOutputOp::ElementScalingFactor; + using ElementAbsmax = typename Conv::EpilogueOutputOp::ElementAbsmax; + static cutlass::conv::Operator const kConvolutionalOperator = Conv::kConvolutionalOperator; + + static bool const kScaleAux = Conv::EpilogueOutputOp::kIsScalingAndAmaxAuxOutputNeeded; + static bool const kScaleOutput = Conv::EpilogueOutputOp::kIsScalingAndAmaxOutputNeeded; + bool doScaleA; + bool doScaleB; + bool doScaleC; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_Aux; + cutlass::HostTensor tensor_D; + cutlass::HostTensor tensor_Vector; + cutlass::HostTensor tmp_D; + cutlass::HostTensor reference_D; + cutlass::HostTensor reference_Aux; + cutlass::HostTensor scale_A; + cutlass::HostTensor scale_B; + cutlass::HostTensor scale_C; + cutlass::HostTensor scale_D; + cutlass::HostTensor scale_Aux; + cutlass::HostTensor abs_max_Aux; + cutlass::HostTensor abs_max_D; + cutlass::HostTensor reference_abs_max_Aux; + cutlass::HostTensor reference_abs_max_D; + + // + // Methods + // + + TestbedConv2dWithAbsMax( + bool scaleA = true, + bool scaleB = true, + bool scaleC = true, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + doScaleA(scaleA), doScaleB(scaleB), doScaleC(scaleC), + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize scaling factors + template + bool initialize_scale_factor(cutlass::TensorView view, uint64_t seed, int bits=0) { + cutlass::reference::host::TensorFillRandomUniform(view, seed, double(1.), double(0.), bits); + return true; + } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Initializes data structures + void initialize(cutlass::conv::Conv2dProblemSize const &problem_size) { + // + // Allocate the GEMM workspace + // + + tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); + tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_Vector.resize({1, 1, 1, implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c()}); + reference_D.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size), false); + tmp_D.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size), false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + EXPECT_TRUE(initialize_tensor(tensor_Vector.host_view(), init_C, seed + 2020)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + cutlass::Coord<4> origin(0); + tensor_A.host_view().at(origin) = typename Conv::ElementA(1); + tensor_B.host_view().at(origin) = typename Conv::ElementB(1); + tensor_C.host_view().at(origin) = typename Conv::ElementC(1); + tensor_Vector.host_view().at(origin) = typename Conv::ElementC(1); + + cutlass::reference::host::TensorFill(tensor_D.host_view()); + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + tensor_Vector.sync_device(); + + int scale_bits = 2; + if (doScaleA) { + scale_A.resize({1, 1, 1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_A.host_view(), seed + 2021, scale_bits)); + scale_A.sync_device(); + } + + if (doScaleB) { + scale_B.resize({1, 1, 1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_B.host_view(), seed + 2022, scale_bits)); + scale_B.sync_device(); + } + + if (doScaleC) { + scale_C.resize({1, 1, 1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_C.host_view(), seed + 2023, scale_bits)); + scale_C.sync_device(); + } + + if (kScaleOutput) { + scale_D.resize({1, 1, 1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_D.host_view(), seed + 2024, scale_bits)); + scale_D.sync_device(); + + abs_max_D.resize({1, 1, 1, 1}); + cutlass::reference::host::TensorFill(abs_max_D.host_view()); + abs_max_D.sync_device(); + + reference_abs_max_D.resize({1, 1, 1, 1}); + } + + if (kScaleAux) { + tensor_Aux.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + cutlass::reference::host::TensorFill(tensor_Aux.host_view()); + tensor_Aux.sync_device(); + + scale_Aux.resize({1, 1, 1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_Aux.host_view(), seed + 2025, scale_bits)); + scale_Aux.sync_device(); + + abs_max_Aux.resize({1, 1, 1, 1}); + cutlass::reference::host::TensorFill(abs_max_Aux.host_view()); + abs_max_Aux.sync_device(); + + reference_Aux.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size), false); + reference_abs_max_Aux.resize({1, 1, 1, 1}); + } + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::conv::Conv2dProblemSize const &problem_size, + ElementCompute alpha, + ElementCompute beta) { + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); + + if (kScaleAux) { + tensor_Aux.sync_host(); + abs_max_Aux.sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Aux.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(abs_max_Aux.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_Aux.host_view()), 0); + passed &= cutlass::reference::host::TensorEquals(reference_Aux.host_view(), tensor_Aux.host_view()); + passed &= cutlass::reference::host::TensorEquals(abs_max_Aux.host_view(), reference_abs_max_Aux.host_view()); + } + + if (kScaleOutput) { + abs_max_D.sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(abs_max_D.host_view()), 0); + passed &= cutlass::reference::host::TensorEquals(abs_max_D.host_view(), reference_abs_max_D.host_view()); + } + + EXPECT_TRUE(passed) << " mismatched reference"; + + if (!passed) { + + std::ofstream file0("conv_testbed_with_amax_errors_reference.txt"); + std::ofstream file1("conv_testbed_with_amax_errors_computed.txt"); + + std::ofstream file("conv_testbed_with_amax_errors.txt"); + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\nVector =\n" << tensor_Vector.host_view() + << "\nScaleA = " << scale_A.host_view() + << "\nScaleB = " << scale_B.host_view() + << "\nScaleC = " << scale_C.host_view() + << "\nScaleD = " << scale_D.host_view() + << "\nScaleAux = " << scale_Aux.host_view() + << std::endl; + + file0 << "\n\nReference D =\n" << reference_D.host_view() << std::endl; + file1 << "\n\nComputed D =\n" << tensor_D.host_view() << std::endl; + if (kScaleAux) { + file0 << "\n\nReference Aux =\n" << reference_Aux.host_view() << std::endl; + file1 << "\n\nComputed Aux =\n" << tensor_Aux.host_view() << std::endl; + file0 << "\n\nReference Absmax Aux = " << reference_abs_max_Aux.host_view() << std::endl; + file1 << "\n\nComputed Absmax Aux = " << abs_max_Aux.host_view() << std::endl; + } + if (kScaleOutput) { + file0 << "\n\nReference Absmax D = " << reference_abs_max_D.host_view() << std::endl; + file1 << "\n\nComputed Absmax D = " << abs_max_D.host_view() << std::endl; + } + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + cutlass::conv::Conv2dProblemSize const &problem_size, + ElementCompute alpha, + ElementCompute beta) { + + cutlass::Coord<4> origin(0); + ElementCompute scaled_alpha = alpha; + if (doScaleA) { + scaled_alpha *= scale_A.host_view().at(origin); + } + if (doScaleB) { + scaled_alpha *= scale_B.host_view().at(origin); + } + + ElementCompute scaled_beta = beta; + if (doScaleC) { + scaled_beta *= scale_C.host_view().at(origin); + } + + // + // Verify + // + + cutlass::reference::host::Conv2d< + typename Conv::ElementA, typename Conv::LayoutA, + typename Conv::ElementB, typename Conv::LayoutB, + typename Conv::ElementC, typename Conv::LayoutC, + ElementCompute, ElementAccumulator, ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.host_ref(), + tensor_B.host_ref(), + tensor_C.host_ref(), + tmp_D.host_ref(), + scaled_alpha, + scaled_beta + ); + + ElementCompute tmp_abs_max_Aux(0.); + ElementCompute tmp_abs_max_D(0.); + + cutlass::NumericConverter cvt_c_to_compute; + cutlass::NumericConverter cvt_accum_to_compute; + cutlass::NumericConverter cvt_compute_to_absmax; + cutlass::NumericConverter cvt_compute_to_d; + cutlass::NumericConverter cvt_compute_to_aux; + + cutlass::absolute_value_op abs; + cutlass::maximum_with_nan_propogation max; + ActivationFunctor act; + + ElementScalingFactor d_scale = kScaleOutput ? scale_D.host_view().at(origin) : ElementScalingFactor(1.); + + for (int n = 0; n < problem_size.N; ++n) { + for (int p = 0; p < problem_size.P; ++p) { + for (int q = 0; q < problem_size.Q; ++q) { + for (int k = 0; k < problem_size.K; ++k) { + ElementCompute intermediate = cvt_accum_to_compute(tmp_D.host_view().at({n, p, q, k})); + ElementCompute bias = cvt_c_to_compute(tensor_Vector.host_view().at({0, 0, 0, k})); + ElementCompute aux = intermediate + bias; + ElementCompute d = act(aux); + tmp_abs_max_Aux = max(abs(aux), tmp_abs_max_Aux); + tmp_abs_max_D = max(abs(d), tmp_abs_max_D); + reference_D.host_view().at({n, p, q, k}) = cvt_compute_to_d(d * d_scale); + + if (kScaleAux) { + reference_Aux.host_view().at({n, p, q, k}) = cvt_compute_to_aux(aux * scale_Aux.host_view().at(origin)); + } + } + } + } + } + if (kScaleAux) { + reference_abs_max_Aux.host_view().at(origin) = cvt_compute_to_absmax(tmp_abs_max_Aux); + } + + if (kScaleOutput) { + reference_abs_max_D.host_view().at(origin) = cvt_compute_to_absmax(tmp_abs_max_D); + } + + return compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Conv::UnderlyingKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::conv::Conv2dProblemSize const &problem_size, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) + { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Conv::EpilogueOutputOp::Params::ActivationParams activation_params{alpha, beta}; + typename Conv::EpilogueOutputOp::Params epilogue_params{ + activation_params, + scale_A.device_data(), + scale_B.device_data(), + scale_C.device_data(), + scale_D.device_data(), + scale_Aux.device_data(), + abs_max_Aux.device_data(), + abs_max_D.device_data() + }; + + typename Conv::Arguments arguments{ + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D.device_ref(), + tensor_Aux.device_ref(), + epilogue_params, + cutlass::conv::SplitKMode::kSerial, + tensor_Vector.device_data(), + 0 + }; + + Conv conv2d_op; + + cutlass::Status status = conv2d_op.can_implement(arguments); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + size_t workspace_size = Conv::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + status = conv2d_op.initialize(arguments, workspace.get()); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the GEMM + // + + status = conv2d_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + cudaError_t cuda_error = cudaDeviceSynchronize(); + EXPECT_TRUE(cuda_error == cudaSuccess) << cudaGetErrorString(cuda_error); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + if (!passed) { + std::cout << "Failed" << std::endl; + } + + return passed; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ImplicitGemm, + template class ActivationFunctor = cutlass::epilogue::thread::Identity +> +bool TestAllConv2dWithAbsmax(bool scaleA=true, bool scaleB=true, bool scaleC=true) { + const Conv2dProblemVector &conv_test_sizes = Conv2dProblemVector(); + const Conv2dProblemVector &conv_blacklist_sizes = Conv2dProblemVector(); + + // + // Testbed object + // + + TestbedConv2dWithAbsMax testbed(scaleA, scaleB, scaleC); + + // + // Get conv problem sizes to run conv operator + // + TestbedConv2dProblemSizes conv_problems(128/cutlass::sizeof_bits::value); + + // Vector of conv2d problem sizes to avoid duplicate runs + Conv2dProblemVector conv_tested_sizes; + + Conv2dProblemVector const *problem_vectors[] = { + &conv_test_sizes, // run user specified sizes + &conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes + &conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes +#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED + &conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled +#endif + }; + + bool passed = true; + + // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for (Conv2dProblemVector const * problem_vector : problem_vectors) { + + // Prune all problems with channels that aren't divisible by the number of elements accessed per + // load for operands A and B. This is meant to align with the requirements of iterators used for + // fprop kernels. + ChannelDivisibilitySpecification channel_spec(128 / cutlass::sizeof_bits::value); + auto pruned_problem_vector = prune(*problem_vector, channel_spec); + + // Run conv testbed on default convolution sizes + for(auto conv_problem : pruned_problem_vector) { + + // Skip blacklist and avoid duplicate problem sizes + if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || + std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { + continue; + } + + // + // Test + // + // push back tested problem size to avoid re-running duplicates + conv_tested_sizes.push_back(conv_problem); + + // test mode = xcross + passed &= testbed.run(conv_problem); + + if (!passed) { + return false; + } + + // test mode = convolution + passed &= testbed.run(conv_problem.reset_mode(cutlass::conv::Mode::kConvolution)); + + if (!passed) { + return false; + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace conv +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_broadcast_testbed.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_broadcast_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..f768f5b25f425910a49058599d3854352136caef --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_broadcast_testbed.h @@ -0,0 +1,734 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implicit GEMM for fused epilogue broadcast testbed + + Parallel split-k is not tested because we can just use regular conv kernel + when we need to use parallel-splitk. Broadcast can happen in the reduction + kernel. +*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/reduction/device/reduce_split_k.h" +#include "cutlass/reduction/thread/reduction_operators.h" + +#include "conv2d_problems.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_compare.h" + +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/device/convolution.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/tensor_view_io.h" + +#include "../cache_testbed_output.h" + +namespace test { +namespace conv { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Conv2dWithBroadcastReferenceOp { + + using OutputOp = typename Conv2d::EpilogueOutputOp; + + using ElementCompute = typename OutputOp::ElementCompute; + using ElementZ = typename OutputOp::ElementZ; + using ElementT = typename OutputOp::ElementT; + + typename OutputOp::BinaryOp binary_op; + typename OutputOp::ElementwiseOp elementwise_op; + + Conv2dWithBroadcastReferenceOp() { } + + void operator()(ElementZ &Z, ElementT &T, ElementCompute conv2d, ElementCompute bias) { + ElementCompute t_full = binary_op(conv2d, bias); + T = ElementT(t_full); + + ElementCompute z_full = elementwise_op(t_full); + Z = ElementZ(z_full); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Fused testbed +// +// Y = CONV(AB, C) +// +// T[n, p, q, k] = ReductionOp(Y[n, p, q, k], Broadcast[k]) +// +// Z[n, p, q, k] = Elementwise(T[n, p, q, k]) +// + +template < + typename Conv2d, + typename ReferenceOp, + bool AddBroadcastFirst = false +> +class TestbedConv2dWithBroadcast { +public: + + using ElementA = typename Conv2d::ElementA; + using LayoutA = typename Conv2d::LayoutA; + using ElementB = typename Conv2d::ElementB; + using LayoutB = typename Conv2d::LayoutB; + using ElementC = typename Conv2d::ElementC; + using LayoutC = typename Conv2d::LayoutC; + using ElementAccumulator = typename Conv2d::ElementAccumulator; + using ElementCompute = typename Conv2d::ElementCompute; + using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; + using ElementZ = typename EpilogueOutputOp::ElementZ; + using ElementT = typename EpilogueOutputOp::ElementT; + using ElementVector = typename EpilogueOutputOp::ElementVector; + + static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; + static const bool kAddBroadcastFirst = AddBroadcastFirst; + static const bool kStoreT = EpilogueOutputOp::kStoreT; + +public: + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_C_reference; + cutlass::HostTensor tensor_Z_computed; + cutlass::HostTensor tensor_Z_reference; + cutlass::HostTensor tensor_T_computed; + cutlass::HostTensor tensor_T_reference; + cutlass::HostTensor tensor_Y_reference; + cutlass::HostTensor tensor_Broadcast; // Input Broadcast + +public: + + TestbedConv2dWithBroadcast( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { + + } + + /// Helper to initialize a tensor view + template + void initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 8) { + scope = 2; + } + else if (bits == 16) { + if (cutlass::sizeof_bits::value <= 16) { + scope = 3; + } + else { + scope = 5; + } + } + else { + scope = 8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope, -scope, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + } + } + + void initialize( + cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) { + + tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); + tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_C_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_Z_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_Z_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_T_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_T_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_Y_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_Broadcast.resize({ + 1, + 1, + 1, + implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c(), + }); + + initialize_tensor(tensor_A.host_view(), init_A, seed); + initialize_tensor(tensor_B.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C.host_view(), init_C, seed * 39); + initialize_tensor(tensor_Broadcast.host_view(), init_C, seed * 39); + + for (int n = 0; n < tensor_C_reference.extent().n(); ++n) { + for (int p = 0; p < tensor_C_reference.extent().h(); ++p) { + for (int q = 0; q < tensor_C_reference.extent().w(); ++q) { + for (int k = 0; k < tensor_C_reference.extent().c(); ++k) { + tensor_C_reference.at({n, p, q, k}) = ElementAccumulator(tensor_C.at({n, p, q, k})); + } + } + } + } + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_Broadcast.sync_device(); + tensor_C_reference.sync_device(); + tensor_Z_computed.sync_device(); + tensor_Z_reference.sync_device(); + tensor_T_computed.sync_device(); + tensor_T_reference.sync_device(); + tensor_Y_reference.sync_device(); + } + + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Conv2d::UnderlyingKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::conv::Conv2dProblemSize const &problem_size, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(1)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + +#if 0 //display conv2d problem size for debugging + std::cout << problem_size << std::endl + << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl + << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl + << std::endl; +#endif + + initialize(problem_size); + + // configure the operator + Conv2d conv2d_op; + typename Conv2d::Arguments conv2d_args( + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_Z_computed.device_ref(), + {alpha, beta}, + split_k_mode, + tensor_Broadcast.device_data(), + kStoreT ? tensor_T_computed.device_data() : nullptr, + 0, // This must be zero + implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c() + ); + + // initialize the kernel + size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = conv2d_op.initialize(conv2d_args, workspace.get()); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // run conv2d operator + status = conv2d_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + bool passed = false; + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " device reference error: " + << cudaGetErrorString(result); + + tensor_T_computed.sync_host(); + tensor_Z_computed.sync_host(); + + // + // Reference check + // + + // When kAddBroadcastFirst is true, add bias on the host + ElementCompute beta_ref = kAddBroadcastFirst ? ElementCompute(0) : beta; + +#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED + + cutlass::reference::device::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementAccumulator, + LayoutC, + ElementAccumulator, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C_reference.device_ref(), + tensor_Y_reference.device_ref(), + alpha, + beta_ref); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_Y_reference.sync_host(); + +#else + + cutlass::reference::host::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementAccumulator, + LayoutC, + ElementAccumulator, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.host_ref(), + tensor_B.host_ref(), + tensor_C_reference.host_ref(), + tensor_Y_reference.host_ref(), + alpha, + beta_ref); + +#endif + ReferenceOp reference_op; + + // compute tensor Z and tensor T + for (int n = 0; n < problem_size.N; ++n) { + for (int p = 0; p < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.P : problem_size.H); ++p) { + for (int q = 0; q < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.Q : problem_size.W); ++q) { + for (int k = 0; k < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.K : problem_size.C); ++k) { + + ElementZ z{}; + ElementT t{}; + + ElementCompute accum = tensor_Y_reference.at({n, p, q, k}); + ElementCompute bias = ElementCompute(tensor_Broadcast.at({0, 0, 0, k})); + + + if (kAddBroadcastFirst) { + reference_op(z, t, accum + bias, + beta * ElementCompute(tensor_C_reference.at({n, p, q, k}))); + } else { + reference_op(z, t, accum, bias); + } + + tensor_Z_reference.at({n, p, q, k}) = z; + tensor_T_reference.at({n, p, q, k}) = t; + } + } + } + } + + if (kStoreT) { + passed = cutlass::reference::host::TensorEquals( + tensor_T_computed.host_view(), + tensor_T_reference.host_view()); + + EXPECT_TRUE(passed); + } + + passed = cutlass::reference::host::TensorEquals( + tensor_Z_computed.host_view(), + tensor_Z_reference.host_view()); + + EXPECT_TRUE(passed); + + if (!passed) { + std::stringstream fname; + + fname << "error_Conv2d_ImplicitGemm_device_" + << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") + << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDeconv ? "deconv_" : "wgrad_"))) + << "nhwc_" + << problem_size.N << "x" + << problem_size.H << "x" + << problem_size.W << "x" + << problem_size.C + << "_krsc_" + << problem_size.K << "x" + << problem_size.R << "x" + << problem_size.S << "x" + << problem_size.C + << "_padding_" + << problem_size.pad_h << "x" + << problem_size.pad_w + << "_stride_" + << problem_size.stride_h << "x" + << problem_size.stride_w + << "_dilation_" + << problem_size.dilation_h << "x" + << problem_size.dilation_w << "_" + << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_") + << Conv2d::ThreadblockShape::kM << "x" + << Conv2d::ThreadblockShape::kN << "x" + << Conv2d::ThreadblockShape::kK << "_" + << Conv2d::WarpShape::kM << "x" + << Conv2d::WarpShape::kN << "x" + << Conv2d::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n" + << "\nBroadcast:\n" << tensor_Broadcast.host_view() << "\n" + << "\nY reference:\n" << tensor_Y_reference.host_view() << "\n" + << "\nT reference:\n" << tensor_T_reference.host_view() << "\n" + << "\nT computed:\n" << tensor_T_computed.host_view() << "\n" + << "\nZ reference:\n" << tensor_Z_reference.host_view() << "\n" + << "\nZ computed:\n" << tensor_Z_computed.host_view() << "\n"; + } + + return passed; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////////////////// + +template , + bool AddBroadcastFirst = false> +bool TestSpecificConv2dWithBroadcast( + const Conv2dProblemVector & problem_sizes) { + + bool passed = true; + + // + // Testbed object + // + + TestbedConv2dWithBroadcast testbed; + + // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for(auto conv_problem : problem_sizes) { + + // + // Test + // + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + } + + return true; +} + +///////////////////////////////////////////////////////////////////////////////////////////////////////// +// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference +// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes +// Additionally, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes +// (conv_blacklist_sizes) +///////////////////////////////////////////////////////////////////////////////////////////////////////////// +template , + bool AddBroadcastFirst = false, + bool TestSplitK = true +> +bool TestAllConv2dWithBroadcast( + const Conv2dProblemVector &conv_test_sizes = Conv2dProblemVector(), + const Conv2dProblemVector &conv_blacklist_sizes = Conv2dProblemVector()) { + + bool passed = true; + + // + // Testbed object + // + + TestbedConv2dWithBroadcast testbed; + + // + // Get conv problem sizes to run conv operator + // + TestbedConv2dProblemSizes conv_problems(128/cutlass::sizeof_bits::value); + + // Vector of conv2d problem sizes to avoid duplicate runs + Conv2dProblemVector conv_tested_sizes; + + Conv2dProblemVector const *problem_vectors[] = { + &conv_test_sizes, // run user specified sizes + &conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes + &conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes +#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED + &conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled +#endif + }; + + // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for (Conv2dProblemVector const * problem_vector : problem_vectors) { + + // Run conv testbed on default convolution sizes + for(auto conv_problem : *problem_vector) { + + // Skip blacklist and avoid duplicate problem sizes + if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || + std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { + continue; + } + + // + // Procedurally disable certain cases + // + + // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kUnity)) { + if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { + continue; + } + } + +#if 0 // relax restrictions on analytic strided dgrad + // CUTLASS DGRAD's *strided* specialization only support stride >= {2, 2} + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kStrided)) { + if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { + continue; + } + } +#endif + + // + // Test + // + // push back tested problem size to avoid re-running duplicates + conv_tested_sizes.push_back(conv_problem); + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + } + } + + // CUTLASS DGRAD's *strided* specialization does not support split-k mode + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kStrided)) { + + passed = testbed.run( + cutlass::conv::Conv2dProblemSize( + {1, 56, 56, 8}, // input size (NHWC) + {8, 1, 1, 8}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1}), // dilation (dilation_h, dilation_w) + cutlass::conv::SplitKMode::kSerial, + cutlass::from_real(2.0), + cutlass::from_real(2.0)); + + if (!passed) { + return false; + } + + return passed; + } + + if (!TestSplitK) + return passed; + + // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for + // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters + // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep + // alpha and beta for local testing, but only runs one value for alpha and beta. + cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size ( + {1, 17, 11, 288}, // input size (NHWC) + {160, 3, 3, 288}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + ); + + cutlass::conv::SplitKMode split_k_modes [] = { + cutlass::conv::SplitKMode::kSerial + }; + + int split_k_slices[] = { + 1, 2, 3, 4, 201 + }; + + double problem_alpha[] = { + 2.0 + }; + + double problem_beta[] = { + 2.0 + }; + + for (auto split_k_mode : split_k_modes) { + for (auto split_k_slice : split_k_slices) { + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + passed = testbed.run( + conv2d_split_k_test_size.reset_split_k_slices(split_k_slice), + split_k_mode, + cutlass::from_real(alpha), + cutlass::from_real(beta)); + + if (!passed) { + return false; + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace conv +} // namespace test diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_reduction_testbed.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_reduction_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..a8ec16ca5de369470f5dc50bb6f8b5e2da3da10d --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_reduction_testbed.h @@ -0,0 +1,643 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implicit GEMM testbed +*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/reduction/device/tensor_reduce.h" +#include "cutlass/reduction/device/reduce_split_k.h" +#include "cutlass/reduction/thread/reduction_operators.h" + +#include "conv2d_problems.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_compare.h" + +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/device/convolution.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/tensor_view_io.h" + +#include "../cache_testbed_output.h" + +namespace test { +namespace conv { +namespace device { + +template +class TestbedConv2dWithReduction { +public: + + using ElementA = typename Conv2d::ElementA; + using LayoutA = typename Conv2d::LayoutA; + using ElementB = typename Conv2d::ElementB; + using LayoutB = typename Conv2d::LayoutB; + using ElementC = typename Conv2d::ElementC; + using LayoutC = typename Conv2d::LayoutC; + using ElementAccumulator = typename Conv2d::ElementAccumulator; + using ElementCompute = typename Conv2d::ElementCompute; + using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; + using ElementT = typename EpilogueOutputOp::ElementTensor; + + static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; + +public: + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + + cutlass::HostTensor tensor_Reduction; + cutlass::HostTensor tensor_Tensor; + cutlass::HostTensor tensor_Final_Reduction; + + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + +public: + + TestbedConv2dWithReduction( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { + + } + + /// Helper to initialize a tensor view + template + void initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + int scope = 2; + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope, -scope, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + } + } + + void initialize( + cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) { + + tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); + tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + + tensor_Reduction.resize({ + 1, + 1, + (problem_size.N * problem_size.P * problem_size.Q - 1 + Conv2d::ThreadblockShape::kM) / Conv2d::ThreadblockShape::kM, + (problem_size.K) + }); + + tensor_Final_Reduction.resize({ + 1, + 1, + 1, + (problem_size.K) + }); + + tensor_Tensor.resize({(problem_size.N * problem_size.P * problem_size.Q), problem_size.K}); + + tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + + initialize_tensor(tensor_A.host_view(), init_A, seed); + initialize_tensor(tensor_B.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C.host_view(), init_C, seed * 39); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + tensor_D_reference.sync_device(); + } + + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Conv2d::UnderlyingKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::conv::Conv2dProblemSize const &problem_size, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + +#if 0 //display conv2d problem size for debugging + std::cout << problem_size << std::endl + << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl + << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl + << std::endl; +#endif + + initialize(problem_size); + + // configure the operator + Conv2d conv2d_op; + + typename Conv2d::Arguments conv2d_args( + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_computed.device_ref(), + {alpha, beta}, + split_k_mode, + tensor_Reduction.device_data(), + tensor_Tensor.device_data(), + static_cast(tensor_Reduction.stride()[0]), + static_cast(tensor_Tensor.stride()[0]) + ); + + // find workspace requirement for parallel split-k reduction + size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = conv2d_op.initialize(conv2d_args, workspace.get()); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // conv2d operation with parallel split-k-mode + if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { + + // conv2d output is written to workspace in global memory + conv2d_args.ref_D.reset(reinterpret_cast(workspace.get())); + // accumulate mma for each cta in k-dimension (1.0 * A * B) + conv2d_args.output_op = {ElementCompute(1), ElementCompute(0)}; + // update conv2d operator arguments + status = conv2d_op.update(conv2d_args, workspace.get()); + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + // run conv2d operator + status = conv2d_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + bool passed = false; + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " device reference error: " + << cudaGetErrorString(result); + + // Final reduction over the partial reduction tensor + using Functor = cutlass::plus; + using TensorReduction = cutlass::reduction::device::TensorReduction< + ElementAccumulator, + ElementAccumulator, + LayoutC, + Functor, + 8, + ElementAccumulator + >; + + TensorReduction reduction(tensor_Reduction.extent(), 2); + + cutlass::DeviceAllocation reduction_device_workspace(reduction.workspace_size()); + + status = reduction.reduce( + tensor_Final_Reduction.device_ref(), + tensor_Reduction.device_ref(), + reduction_device_workspace.get(), + ElementAccumulator()); + + EXPECT_EQ(status, cutlass::Status::kSuccess); + EXPECT_EQ(cudaDeviceSynchronize(), cudaSuccess); + + // + // Reference check + // + + tensor_D_computed.sync_host(); + +#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED + + cutlass::reference::device::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_reference.device_ref(), + alpha, + beta); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_D_reference.sync_host(); + +#else + + cutlass::reference::host::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.host_ref(), + tensor_B.host_ref(), + tensor_C.host_ref(), + tensor_D_reference.host_ref(), + alpha, + beta); + +#endif + + passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view()); + + EXPECT_TRUE(passed); + + // + // Reference check on reduction results + // + + tensor_Reduction.sync_host(); + tensor_Final_Reduction.sync_host(); + + // compute backwards for reduction results + cutlass::HostTensor reference_Reduction; + reference_Reduction.resize({ + 1, + 1, + 1, + (problem_size.K) + }); + + for (int k = 0; k < problem_size.K; ++k) { + ElementAccumulator reduced_value = ElementAccumulator(); + for (int n = 0; n < problem_size.N; ++n) { + for (int p = 0; p < problem_size.P; ++p) { + for (int q = 0; q < problem_size.Q; ++q) { + reduced_value += tensor_D_reference.at({n, p, q, k}); + } + } + } + reference_Reduction.at({0, 0, 0, k}) = reduced_value; + } + + passed = cutlass::reference::host::TensorEquals( + tensor_Final_Reduction.host_view(), + reference_Reduction.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + std::stringstream fname; + + fname << "error_Conv2d_ImplicitGemm_device_" + << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") + << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_")) + << "nhwc_" + << problem_size.N << "x" + << problem_size.H << "x" + << problem_size.W << "x" + << problem_size.C + << "_krsc_" + << problem_size.K << "x" + << problem_size.R << "x" + << problem_size.S << "x" + << problem_size.C + << "_padding_" + << problem_size.pad_h << "x" + << problem_size.pad_w + << "_stride_" + << problem_size.stride_h << "x" + << problem_size.stride_w + << "_dilation_" + << problem_size.dilation_h << "x" + << problem_size.dilation_w << "_" + << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_") + << Conv2d::ThreadblockShape::kM << "x" + << Conv2d::ThreadblockShape::kN << "x" + << Conv2d::ThreadblockShape::kK << "_" + << Conv2d::WarpShape::kM << "x" + << Conv2d::WarpShape::kN << "x" + << Conv2d::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n" + << "\nD reference:\n" << tensor_D_reference.host_view() << "\n" + << "\nD computed:\n" << tensor_D_computed.host_view() << "\n" + << "\nreduction reference:\n" << reference_Reduction.host_view() << "\n" + << "\nreduction computed:\n" << tensor_Reduction.host_view() << "\n"; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////// +// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference +// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes +// Additionally, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes +// (conv_blacklist_sizes) +///////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestAllConv2dWithReduction( + const Conv2dProblemVector & conv_test_sizes = Conv2dProblemVector(), + const Conv2dProblemVector & conv_blacklist_sizes = Conv2dProblemVector()) { + + bool passed = true; + + // + // Testbed object + // + + TestbedConv2dWithReduction testbed; + + // + // Get conv problem sizes to run conv operator + // + TestbedConv2dProblemSizes conv_problems(128/cutlass::sizeof_bits::value); + + // Vector of conv2d problem sizes to avoid duplicate runs + Conv2dProblemVector conv_tested_sizes; + + Conv2dProblemVector const *problem_vectors[] = { + &conv_test_sizes, // run user specified sizes + &conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes + &conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes +#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED + &conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled +#endif + }; + + // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for (Conv2dProblemVector const * problem_vector : problem_vectors) { + + // Run conv testbed on default convolution sizes + for(auto conv_problem : *problem_vector) { + + // Skip blacklist and avoid duplicate problem sizes + if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || + std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { + continue; + } + + // + // Procedurally disable certain cases + // + + // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} + if ((ImplicitGemm::kConvolutionalOperator == + cutlass::conv::Operator::kDgrad) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kUnity)) { + if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { + continue; + } + } + +#if 0 // relax restrictions on analytic strided dgrad + // CUTLASS DGRAD's *strided* specialization only support stride >= {2, 2} + if ((ImplicitGemm::kConvolutionalOperator == + cutlass::conv::Operator::kDgrad) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kStrided)) { + if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { + continue; + } + } +#endif + + // + // Test + // + // push back tested problem size to avoid re-running duplicates + conv_tested_sizes.push_back(conv_problem); + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + } + } + + // CUTLASS DGRAD's *strided* specialization does not support split-k mode + if ((ImplicitGemm::kConvolutionalOperator == + cutlass::conv::Operator::kDgrad) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kStrided)) { + + passed = testbed.run( + cutlass::conv::Conv2dProblemSize( + {1, 56, 56, 8}, // input size (NHWC) + {8, 1, 1, 8}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1}), // dilation (dilation_h, dilation_w) + cutlass::conv::SplitKMode::kSerial, + cutlass::from_real(2.0), + cutlass::from_real(2.0)); + + if (!passed) { + return false; + } + + return passed; + } + + // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for + // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters + // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep + // alpha and beta for local testing, but only runs one value for alpha and beta. + cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size ( + {1, 17, 11, 288}, // input size (NHWC) + {160, 3, 3, 288}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + ); + + // Parallel SplitK is not tested. + cutlass::conv::SplitKMode split_k_modes [] = { + cutlass::conv::SplitKMode::kSerial, + }; + + int split_k_slices[] = { + 1, 2, 3, 4, 201 + }; + + double problem_alpha[] = { + 2.0 + }; + + double problem_beta[] = { + 2.0 + }; + + for (auto split_k_mode : split_k_modes) { + for (auto split_k_slice : split_k_slices) { + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + passed = testbed.run( + conv2d_split_k_test_size.reset_split_k_slices(split_k_slice), + split_k_mode, + cutlass::from_real(alpha), + cutlass::from_real(beta)); + + if (!passed) { + return false; + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace conv +} // namespace test diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_problems.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_problems.h new file mode 100644 index 0000000000000000000000000000000000000000..fae7d6194fb671594221a90faea7cac1e5fbeb9f --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_problems.h @@ -0,0 +1,293 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implicit GEMM testbed sizes for Conv2d problem +*/ +#pragma once + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/cutlass.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/numeric_types.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/core_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" + +namespace test { +namespace conv { +namespace device { + +using Conv3dProblemVector = std::vector; + +//////////////////////////////////////////////////////////////////////////// +/// Structure TestbedConv3dProblemSizes initializes and holds conv default and +/// important network sizes +//////////////////////////////////////////////////////////////////////////// +struct TestbedConv3dProblemSizes { + + // + // Data members + // + int minimum_channel_size; + Conv3dProblemVector conv3d_default_sizes; + Conv3dProblemVector conv3d_vnet_medical_sizes; + + // + // Methods + // + /// Default ctor + TestbedConv3dProblemSizes(int minimum_channel_size_ = 64): minimum_channel_size (minimum_channel_size_) { + + initialize_conv3d_default_sizes(); + initialize_conv3d_vnet_medical_sizes(conv3d_vnet_medical_sizes, 1 /*batch-size*/); + + filter_all(); + } + + /// Eliminates some illegal cases + void filter_all() { + + Conv3dProblemVector *problems_vectors[] = { + &conv3d_default_sizes, + &conv3d_vnet_medical_sizes + }; + + for (Conv3dProblemVector *problems : problems_vectors) { + Conv3dProblemVector filtered; + + for (cutlass::conv::Conv3dProblemSize const & problem : *problems) { + if (!(problem.C % minimum_channel_size)) { + filtered.push_back(problem); + } + } + + *problems = filtered; + } + } + + // Add a few standard convolution problem sizes + void initialize_conv3d_default_sizes() { + + conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 1, 3, 3, minimum_channel_size}, // input size (NDHWC) + {8, 1, 1, 1, minimum_channel_size}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 1, 1, 8, minimum_channel_size}, // input size (NDHWC) + {8, 1, 1, 3, minimum_channel_size}, // filter size (KTRSC) + cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 1, 1, 8, minimum_channel_size}, // input size (NDHWC) + {8, 1, 1, 3, minimum_channel_size}, // filter size (KTRSC) + CUTLASS_STL_NAMESPACE::make_tuple( + cutlass::Coord<3>({1, 1, 1}), // near padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({0, 0, 0}) // far padding (pad_d, pad_h, pad_w) + ), + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 8, 8, 8, minimum_channel_size}, // input size (NDHWC) + {8, 3, 3, 3, minimum_channel_size}, // filter size (KTRSC) + cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 8, 8, 8, minimum_channel_size}, // input size (NDHWC) + {8, 3, 3, 3, minimum_channel_size}, // filter size (KTRSC) + CUTLASS_STL_NAMESPACE::make_tuple( + cutlass::Coord<3>({1, 1, 1}), // near padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({0, 0, 0}) // far padding (pad_d, pad_h, pad_w) + ), + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 16, 16, 16, minimum_channel_size}, // input size (NDHWC) + {8, 3, 3, 3, minimum_channel_size}, // filter size (KTRSC) + cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 1, 15, 19, 160}, // input size (NDHWC) + {224, 1, 3, 6, 160}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 2, 1, 1, minimum_channel_size}, // input size (NDHWC) + {8, 2, 1, 1, minimum_channel_size}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 1, 7, 7, minimum_channel_size}, // input size (NDHWC) + {16, 1, 3, 3, minimum_channel_size}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 11, 15, 19, 64}, // input size (NDHWC) + {32, 4, 3, 6, 64}, // filter size (KTRSC) + cutlass::Coord<3>({2, 1, 3}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + } + + // Add vnet layers to unit testing sizes + void initialize_conv3d_vnet_medical_sizes(Conv3dProblemVector &conv3d_problem_vector, int batch_size = 1) { + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 32, 32, 32, 16}, // input size (NDHWC) + {32, 2, 2, 2, 16}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 16, 16, 16, 32}, // input size (NDHWC) + {32, 3, 3, 3, 32}, // filter size (KTRSC) + cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 16, 16, 16, 32}, // input size (NDHWC) + {64, 2, 2, 2, 32}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 8, 8, 8, 64}, // input size (NDHWC) + {64, 3, 3, 3, 64}, // filter size (KTRSC) + cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 8, 8, 8, 64}, // input size (NDHWC) + {128, 2, 2, 2, 64}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 4, 4, 4, 128}, // input size (NDHWC) + {128, 3, 3, 3, 128}, // filter size (KTRSC) + cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 8, 8, 8, 128}, // input size (NDHWC) + {128, 3, 3, 3, 128}, // filter size (KTRSC) + cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 16, 16, 16, 64}, // input size (NDHWC) + {64, 3, 3, 3, 64}, // filter size (KTRSC) + cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 32, 32, 32, 16}, // input size (NDHWC) + {64, 2, 2, 2, 16}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 16, 16, 16, 32}, // input size (NDHWC) + {128, 2, 2, 2, 32}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + } + +}; + +} // namespace device +} // namespace conv +} // namespace test diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_testbed.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..029f5effb9103bebd4ee61767795d3883541d986 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_testbed.h @@ -0,0 +1,716 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implicit GEMM testbed +*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + + +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/reduction/device/reduce_split_k.h" +#include "cutlass/reduction/thread/reduction_operators.h" + +#include "cutlass/util/reference/host/tensor_fill.h" + +#include "cutlass/util/reference/host/convolution.h" + +#include "cutlass/util/reference/host/tensor_compare.h" + +#include "cutlass/util/reference/device/convolution.h" +#include "cutlass/util/reference/device/tensor_compare.h" + +#include "conv3d_problems.h" +#include "cutlass/core_io.h" + +#include "../cache_testbed_output.h" + +namespace test { +namespace conv { +namespace device { + +template +class TestbedConv3d { +public: + + using ElementA = typename Conv3d::ElementA; + using LayoutA = typename Conv3d::LayoutA; + using ElementB = typename Conv3d::ElementB; + using LayoutB = typename Conv3d::LayoutB; + using ElementC = typename Conv3d::ElementC; + using LayoutC = typename Conv3d::LayoutC; + using ElementAccumulator = typename Conv3d::ElementAccumulator; + using ElementCompute = typename Conv3d::ElementCompute; + using EpilogueOutputOp = typename Conv3d::EpilogueOutputOp; + + static cutlass::conv::Operator const kConvolutionalOperator = Conv3d::kConvolutionalOperator; + + /// Reduction kernel + using ReductionOp = cutlass::reduction::thread::ReduceAdd< + ElementAccumulator, + typename EpilogueOutputOp::ElementAccumulator, + EpilogueOutputOp::kCount + >; + + using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< + cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, + EpilogueOutputOp, + ReductionOp + >; + + using ReductionDevice = cutlass::reduction::device::ReduceSplitK; + using ReductionStrideIndex = typename ReductionDevice::StrideIndex; + +public: + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + +public: + + TestbedConv3d( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { + + } + + /// Helper to initialize a tensor view + template + void initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 8) { + scope = 2; + } + else if (bits == 16) { + scope = 4; + } + else { + scope = 8; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope, -scope, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + } + } + + void initialize( + cutlass::conv::Conv3dProblemSize const &problem_size, uint64_t seed = 2019) { + + tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); + tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + + initialize_tensor(tensor_A.host_view(), init_A, seed); + initialize_tensor(tensor_B.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C.host_view(), init_C, seed * 39); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + tensor_D_reference.sync_device(); + } + + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Conv3d::UnderlyingKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + + /// Executes one test + bool run( + cutlass::conv::Conv3dProblemSize const &problem_size, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute()) { + + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + +#if 0 //display conv2d problem size for debugging + std::cout << problem_size << std::endl + << "alpha, beta: (" << float(alpha) << ", " << float(beta) << ")" << std::endl + << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl + << std::endl; +#endif + + initialize(problem_size); + + // configure the operator + Conv3d conv3d_op; + + typename Conv3d::Arguments conv3d_args( + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_computed.device_ref(), + {alpha, beta}, + split_k_mode + ); + + cutlass::Status status = conv3d_op.can_implement(conv3d_args); + if (status != cutlass::Status::kSuccess) { + std::cerr << "can_implement failed for the given problem_size: \n"; + return false; + } + + // find workspace requirement for parallel split-k reduction + size_t workspace_size = Conv3d::get_workspace_size(conv3d_args); + + cutlass::device_memory::allocation workspace(workspace_size); + + status = conv3d_op.initialize(conv3d_args, workspace.get()); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // conv3d operation with parallel split-k-mode + if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { + + // conv3d output is written to workspace in global memory + conv3d_args.ref_D.reset(reinterpret_cast(workspace.get())); + // accumulate mma for each cta in k-dimension (1.0 * A * B) + conv3d_args.output_op = {1.0, 0.0}; + // update conv3d operator arguments + status = conv3d_op.update(conv3d_args, workspace.get()); + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + // run conv3d operator + status = conv3d_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { + + // configure parallel reduction operator + ReductionDevice reduction_op; + + typename ReductionDevice::Arguments reduction_args( + cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(), + problem_size.split_k_slices, + cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size), + { + reinterpret_cast (workspace.get()), + ReductionStrideIndex(tensor_C.stride()[Conv3d::UnderlyingKernel::kTensorCStrideIdx]) + }, + { + tensor_D_computed.device_data(), + ReductionStrideIndex(tensor_C.stride()[Conv3d::UnderlyingKernel::kTensorCStrideIdx]) + }, + { + tensor_C.device_data(), + ReductionStrideIndex(tensor_C.stride()[Conv3d::UnderlyingKernel::kTensorCStrideIdx]) + }, + // apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C + {alpha, beta} + ); + + status = reduction_op.initialize(reduction_args, nullptr); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + // run prallel reduction kernel + status = reduction_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + } + bool passed = false; + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " device reference error: " + << cudaGetErrorString(result); + + tensor_D_computed.sync_host(); + + // + // Reference check - support caching results + // + + CachedTestKey cached_test_key = CreateCachedConv3dTestKey< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + ElementCompute + >( + kConvolutionalOperator, + problem_size, + alpha, + beta, + tensor_A.host_view(), + tensor_B.host_view(), + tensor_C.host_view() + ); + + // + // Look for the cached key + // + + bool cached_result_loaded = false; + CachedTestResult cached_test_result; + + std::string conv3d_result_cache_name = + std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt"; + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + + CachedTestResultListing cached_results(conv3d_result_cache_name); + + auto cached = cached_results.find(cached_test_key); + + cached_result_loaded = cached.first; + if (cached_result_loaded) { + cached_test_result = cached.second; + } + } + + if (!cached_result_loaded) { + +#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED + + cutlass::reference::device::Conv3d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + ElementCompute + >( + kConvolutionalOperator, + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_reference.device_ref(), + alpha, + beta + ); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_D_reference.sync_host(); + +#else + cutlass::reference::host::Conv3d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + ElementCompute + >( + kConvolutionalOperator, + problem_size, + tensor_A.host_ref(), + tensor_B.host_ref(), + tensor_C.host_ref(), + tensor_D_reference.host_ref(), + alpha, + beta + ); +#endif + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + + cached_test_result.D = TensorHash(tensor_D_reference.host_view()); + + CachedTestResultListing cached_results(conv3d_result_cache_name); + + cached_results.append(cached_test_key, cached_test_result); + cached_results.write(conv3d_result_cache_name); + } + } // if (!cached_result_loaded) + + uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view()); + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + passed = (tensor_D_hash == cached_test_result.D); + + EXPECT_EQ(tensor_D_hash, cached_test_result.D) + << "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n"; + } + else { + + passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view()); + } + + EXPECT_TRUE(passed); + + if (!passed) { + std::stringstream fname; + + fname << "error_Conv3d_ImplicitGemm_device_" + << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") + << (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : + (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : + (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDeconv ? "deconv_" : "wgrad_"))) + << "ndhwc_" + << problem_size.N << "x" + << problem_size.D << "x" + << problem_size.H << "x" + << problem_size.W << "x" + << problem_size.C + << "_ktrsc_" + << problem_size.K << "x" + << problem_size.T << "x" + << problem_size.R << "x" + << problem_size.S << "x" + << problem_size.C + << "_padding_" + << problem_size.pad_d << "x" + << problem_size.pad_h << "x" + << problem_size.pad_w + << "_stride_" + << problem_size.stride_d << "x" + << problem_size.stride_h << "x" + << problem_size.stride_w + << "_dilation_" + << problem_size.dilation_d << "x" + << problem_size.dilation_h << "x" + << problem_size.dilation_w << "_" + << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_") + << Conv3d::ThreadblockShape::kM << "x" + << Conv3d::ThreadblockShape::kN << "x" + << Conv3d::ThreadblockShape::kK << "_" + << Conv3d::WarpShape::kM << "x" + << Conv3d::WarpShape::kN << "x" + << Conv3d::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n"; + + + results << "\nD reference (hash: " << cached_test_result.D << ")\n"; + + if (!cached_result_loaded) { + results + << tensor_D_reference.host_view() << "\n"; + } + + results + << "\nD computed (hash: " << tensor_D_hash << ")\n" + << tensor_D_computed.host_view() << "\n"; + + } + + return passed; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////// +// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference +// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes +// Additionally, each conv3d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes +// (conv_blacklist_sizes) +///////////////////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestAllConv3d( + const Conv3dProblemVector & conv_test_sizes = Conv3dProblemVector(), + const Conv3dProblemVector & conv_blacklist_sizes = Conv3dProblemVector()) { + + bool passed = true; + + // + // Testbed object + // + + //TestbedConv3d testbed(cutlass::Distribution::Sequential, cutlass::Distribution::Sequential, cutlass::Distribution::Sequential); + TestbedConv3d testbed; + + // + // Get conv problem sizes to run conv operator + // + TestbedConv3dProblemSizes conv3d_problems(128/cutlass::sizeof_bits::value); + + // Vector of conv3d problem sizes to avoid duplicate runs + Conv3dProblemVector conv_tested_sizes; + + Conv3dProblemVector const *problem_vectors[] = { + &conv3d_problems.conv3d_default_sizes, + &conv3d_problems.conv3d_vnet_medical_sizes, + &conv_test_sizes + }; + + // Sweep conv3d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for (Conv3dProblemVector const * problem_vector : problem_vectors) { + + // Run conv testbed on default convolution sizes + for(auto conv_problem : *problem_vector) { + + // Skip blacklist and avoid duplicate problem sizes + if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || + std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { + continue; + } + + // + // Procedurally disable certain cases + // + + // CUTLASS DGRAD's unity stride specialization only support stride {1, 1, 1} + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && + ((ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kUnity) || + (ImplicitGemm::UnderlyingKernel::Mma::IteratorB::kStrideSupport == + cutlass::conv::StrideSupport::kUnity))) { + if (!((conv_problem.stride_d == 1) && + (conv_problem.stride_h == 1) && + (conv_problem.stride_w == 1)) + ) { + continue; + } + } + + // + // Test + // + // push back tested problem size to avoid re-running duplicates + conv_tested_sizes.push_back(conv_problem); + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + } + } + + // Sweep split-k-slice using serial reduction with non-unity alpha and non-zero beta for + // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters + // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep + // alpha and beta for local testing, but only runs one value for alpha and beta. + cutlass::conv::Conv3dProblemSize conv3d_split_k_test_size ( + {1, 8, 8, 8, 32}, // input size (NDHWC) + {32, 3, 3, 3, 32}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + ); + + cutlass::conv::SplitKMode split_k_modes [] = { + cutlass::conv::SplitKMode::kSerial, + cutlass::conv::SplitKMode::kParallel + }; + + int split_k_slices[] = { + 1, 2, 3, 4, 201 + }; + + double problem_alpha[] = { + 2.0 + }; + + double problem_beta[] = { + 2.0 + }; + + for (auto split_k_mode : split_k_modes) { + for (auto split_k_slice : split_k_slices) { + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + passed = testbed.run( + conv3d_split_k_test_size.reset_split_k_slices(split_k_slice), + split_k_mode, + cutlass::from_real(alpha), + cutlass::from_real(beta)); + + if (!passed) { + return false; + } + } + } + } + } + + return passed; +} + +template +bool TestSpecificConv3d( + const Conv3dProblemVector & problem_sizes) { + + bool passed = true; + + // + // Testbed object + // + + TestbedConv3d testbed; + + // Sweep conv3d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for(auto conv_problem : problem_sizes) { + + // + // Test + // + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + } + + return true; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace conv +} // namespace test diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_with_broadcast_testbed.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_with_broadcast_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..f8ba785c9d0ecbdd518711714558c9e166c0209a --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_with_broadcast_testbed.h @@ -0,0 +1,732 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implicit GEMM for fused epilogue broadcast testbed + + Parallel split-k is not tested because we can just use regular conv kernel + when we need to use parallel-splitk. Broadcast can happen in the reduction + kernel. +*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/reduction/device/reduce_split_k.h" +#include "cutlass/reduction/thread/reduction_operators.h" + +#include "conv3d_problems.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_compare.h" + +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/device/convolution.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/tensor_view_io.h" + +#include "../cache_testbed_output.h" + +namespace test { +namespace conv { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Conv3dWithBroadcastReferenceOp { + + using OutputOp = typename Conv3d::EpilogueOutputOp; + + using ElementCompute = typename OutputOp::ElementCompute; + using ElementZ = typename OutputOp::ElementZ; + using ElementT = typename OutputOp::ElementT; + + typename OutputOp::BinaryOp binary_op; + typename OutputOp::ElementwiseOp elementwise_op; + + Conv3dWithBroadcastReferenceOp() { } + + void operator()(ElementZ &Z, ElementT &T, ElementCompute conv3d, ElementCompute bias) { + ElementCompute t_full = binary_op(conv3d, bias); + T = ElementT(t_full); + + ElementCompute z_full = elementwise_op(t_full); + Z = ElementZ(z_full); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Fused testbed +// +// Y = CONV(AB, C) +// +// T[n, o, p, q, k] = ReductionOp(Y[n, o, p, q, k], Broadcast[k]) +// +// Z[n, o, p, q, k] = Elementwise(T[n, o, p, q, k]) +// + +template < + typename Conv3d, + typename ReferenceOp, + bool AddBroadcastFirst = false +> +class TestbedConv3dWithBroadcast { +public: + + using ElementA = typename Conv3d::ElementA; + using LayoutA = typename Conv3d::LayoutA; + using ElementB = typename Conv3d::ElementB; + using LayoutB = typename Conv3d::LayoutB; + using ElementC = typename Conv3d::ElementC; + using LayoutC = typename Conv3d::LayoutC; + using ElementAccumulator = typename Conv3d::ElementAccumulator; + using ElementCompute = typename Conv3d::ElementCompute; + using EpilogueOutputOp = typename Conv3d::EpilogueOutputOp; + using ElementZ = typename EpilogueOutputOp::ElementZ; + using ElementT = typename EpilogueOutputOp::ElementT; + using ElementVector = typename EpilogueOutputOp::ElementVector; + + static cutlass::conv::Operator const kConvolutionalOperator = Conv3d::kConvolutionalOperator; + static const bool kAddBroadcastFirst = AddBroadcastFirst; + static const bool kStoreT = EpilogueOutputOp::kStoreT; + +public: + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_C_reference; + cutlass::HostTensor tensor_Z_computed; + cutlass::HostTensor tensor_Z_reference; + cutlass::HostTensor tensor_T_computed; + cutlass::HostTensor tensor_T_reference; + cutlass::HostTensor tensor_Y_reference; + cutlass::HostTensor tensor_Broadcast; // Input Broadcast + +public: + + TestbedConv3dWithBroadcast( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { + + } + + /// Helper to initialize a tensor view + template + void initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 8) { + scope = 2; + } + else if (bits == 16) { + if (cutlass::sizeof_bits::value <= 16) { + scope = 3; + } + else { + scope = 5; + } + } + else { + scope = 8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope, -scope, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + } + } + + void initialize( + cutlass::conv::Conv3dProblemSize const &problem_size, bool non_packed_test = false, uint64_t seed = 2019) { + + // to make the layout of tensors a little bit bigger than the problem size + cutlass::Tensor5DCoord stride_increment = cutlass::Tensor5DCoord(8, 16, 32, 32, 64); + + cutlass::Tensor5DCoord tensor_A_extent = implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size); + cutlass::Tensor5DCoord tensor_B_extent = implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size); + cutlass::Tensor5DCoord tensor_C_extent = implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size); + + if (non_packed_test) { + tensor_A_extent += stride_increment; + tensor_C_extent += stride_increment; + } + + tensor_A.resize(tensor_A_extent); + tensor_B.resize(tensor_B_extent); + tensor_C.resize(tensor_C_extent); + tensor_C_reference.resize(tensor_C_extent); + tensor_Z_computed.resize(tensor_C_extent); + tensor_Z_reference.resize(tensor_C_extent); + tensor_T_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_T_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_Y_reference.resize(tensor_C_extent); + tensor_Broadcast.resize({ + 1, + 1, + 1, + 1, + implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c(), + }); + + initialize_tensor(tensor_A.host_view(), init_A, seed); + initialize_tensor(tensor_B.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C.host_view(), init_C, seed * 39); + initialize_tensor(tensor_Broadcast.host_view(), init_C, seed * 39); + for (int n = 0; n < tensor_C_reference.extent().n(); ++n) { + for (int o = 0; o < tensor_C_reference.extent().d(); ++o) { + for (int p = 0; p < tensor_C_reference.extent().h(); ++p) { + for (int q = 0; q < tensor_C_reference.extent().w(); ++q) { + for (int k = 0; k < tensor_C_reference.extent().c(); ++k) { + tensor_C_reference.at({n, o, p, q, k}) = ElementAccumulator(tensor_C.at({n, o, p, q, k})); + } + } + } + } + } + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_Broadcast.sync_device(); + tensor_C_reference.sync_device(); + tensor_Z_computed.sync_device(); + tensor_Z_reference.sync_device(); + tensor_T_computed.sync_device(); + tensor_T_reference.sync_device(); + tensor_Y_reference.sync_device(); + } + + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Conv3d::UnderlyingKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::conv::Conv3dProblemSize const &problem_size, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + bool non_packed_test = false, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(1)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + +#if 0 //display conv3d problem size for debugging + std::cout << problem_size << std::endl + << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl + << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl + << std::endl; +#endif + + initialize(problem_size, non_packed_test); + + // configure the operator + Conv3d conv3d_op; + typename Conv3d::Arguments conv3d_args( + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_Z_computed.device_ref(), + {alpha, beta}, + split_k_mode, + tensor_Broadcast.device_data(), + kStoreT ? tensor_T_computed.device_data() : nullptr, + 0, // This must be zero + implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c() + ); + + // initialize the kernel + size_t workspace_size = Conv3d::get_workspace_size(conv3d_args); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = conv3d_op.initialize(conv3d_args, workspace.get()); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // run conv3d operator + status = conv3d_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + bool passed = false; + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " device reference error: " + << cudaGetErrorString(result); + + tensor_T_computed.sync_host(); + tensor_Z_computed.sync_host(); + + // + // Reference check + // + + // When kAddBroadcastFirst is true, add bias on the host + ElementCompute beta_ref = kAddBroadcastFirst ? ElementCompute(0) : beta; + +#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED + + cutlass::reference::device::Conv3d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementAccumulator, + LayoutC, + ElementAccumulator, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C_reference.device_ref(), + tensor_Y_reference.device_ref(), + alpha, + beta_ref); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_Y_reference.sync_host(); + +#else + + cutlass::reference::host::Conv3d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementAccumulator, + LayoutC, + ElementAccumulator, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.host_ref(), + tensor_B.host_ref(), + tensor_C_reference.host_ref(), + tensor_Y_reference.host_ref(), + alpha, + beta_ref); + +#endif + ReferenceOp reference_op; + + // compute tensor Z and tensor T + for (int n = 0; n < problem_size.N; ++n) { + for (int o = 0; o < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.Z : problem_size.D); ++o) { + for (int p = 0; p < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.P : problem_size.H); ++p) { + for (int q = 0; q < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.Q : problem_size.W); ++q) { + for (int k = 0; k < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.K : problem_size.C); ++k) { + + ElementZ z{}; + ElementT t{}; + + ElementCompute accum = tensor_Y_reference.at({n, o, p, q, k}); + ElementCompute bias = ElementCompute(tensor_Broadcast.at({0, 0, 0, 0, k})); + + + if (kAddBroadcastFirst) { + reference_op(z, t, accum + bias, + beta * ElementCompute(tensor_C_reference.at({n, o, p, q, k}))); + } else { + reference_op(z, t, accum, bias); + } + + tensor_Z_reference.at({n, o, p, q, k}) = z; + tensor_T_reference.at({n, o, p, q, k}) = t; + } + } + } + } + } + + if (kStoreT) { + passed = cutlass::reference::host::TensorEquals( + tensor_T_computed.host_view(), + tensor_T_reference.host_view()); + + EXPECT_TRUE(passed); + } + + passed = cutlass::reference::host::TensorEquals( + tensor_Z_computed.host_view(), + tensor_Z_reference.host_view()); + + EXPECT_TRUE(passed); + + if (!passed) { + std::stringstream fname; + + fname << "error_Conv3d_ImplicitGemm_device_" + << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") + << (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : + (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : + (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDeconv ? "deconv_" : "wgrad_"))) + << "nnhwc_" + << problem_size.N << "x" + << problem_size.D << "x" + << problem_size.H << "x" + << problem_size.W << "x" + << problem_size.C + << "_krsc_" + << problem_size.K << "x" + << problem_size.T << "x" + << problem_size.R << "x" + << problem_size.S << "x" + << problem_size.C + << "_padding_" + << problem_size.pad_d << "x" + << problem_size.pad_h << "x" + << problem_size.pad_w + << "_stride_" + << problem_size.stride_d << "x" + << problem_size.stride_h << "x" + << problem_size.stride_w + << "_dilation_" + << problem_size.dilation_d << "x" + << problem_size.dilation_h << "x" + << problem_size.dilation_w << "_" + << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_") + << (non_packed_test ? "non_packed_tensor_test_" : "packed_tensor_test_") + << Conv3d::ThreadblockShape::kM << "x" + << Conv3d::ThreadblockShape::kN << "x" + << Conv3d::ThreadblockShape::kK << "_" + << Conv3d::WarpShape::kM << "x" + << Conv3d::WarpShape::kN << "x" + << Conv3d::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n" + << "\nBroadcast:\n" << tensor_Broadcast.host_view() << "\n" + << "\nY reference:\n" << tensor_Y_reference.host_view() << "\n" + << "\nT reference:\n" << tensor_T_reference.host_view() << "\n" + << "\nT computed:\n" << tensor_T_computed.host_view() << "\n" + << "\nZ reference:\n" << tensor_Z_reference.host_view() << "\n" + << "\nZ computed:\n" << tensor_Z_computed.host_view() << "\n"; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////// +// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference +// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv3dProblemSizes +// Additionally, each conv3d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes +// (conv_blacklist_sizes) +///////////////////////////////////////////////////////////////////////////////////////////////////////////// +template , + bool AddBroadcastFirst = false, + bool TestSplitK = true +> +bool TestAllConv3dWithBroadcast( + const Conv3dProblemVector &conv_test_sizes = Conv3dProblemVector(), + const Conv3dProblemVector &conv_blacklist_sizes = Conv3dProblemVector(), + bool non_packed_test = false) { + + bool passed = true; + + // + // Testbed object + // + + TestbedConv3dWithBroadcast testbed; + + // + // Get conv problem sizes to run conv operator + // + TestbedConv3dProblemSizes conv3d_problems(128/cutlass::sizeof_bits::value); + + // Vector of conv3d problem sizes to avoid duplicate runs + Conv3dProblemVector conv_tested_sizes; + + Conv3dProblemVector const *problem_vectors[] = { + &conv3d_problems.conv3d_default_sizes, + &conv3d_problems.conv3d_vnet_medical_sizes, + &conv_test_sizes + }; + + // Sweep conv3d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for (Conv3dProblemVector const * problem_vector : problem_vectors) { + + // Run conv testbed on default convolution sizes + for(auto conv_problem : *problem_vector) { + + // Skip blacklist and avoid duplicate problem sizes + if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || + std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { + continue; + } + + // + // Procedurally disable certain cases + // + + // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kUnity)) { + if (!((conv_problem.stride_d == 1) && + (conv_problem.stride_h == 1) && + (conv_problem.stride_w == 1)) + ) { + continue; + } + } + +#if 0 // relax restrictions on analytic strided dgrad + // CUTLASS DGRAD's *strided* specialization only support stride >= {2, 2} + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kStrided)) { + if (((conv_problem.stride_d == 1) && (conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { + continue; + } + } +#endif + + // + // Test + // + // push back tested problem size to avoid re-running duplicates + conv_tested_sizes.push_back(conv_problem); + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial, non_packed_test); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial, non_packed_test); + + if (!passed) { + return false; + } + } + } + + if (!TestSplitK) + return passed; + + // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for + // a single conv3d problem size. Convolution unit tests take a long time to run so only sweep parameters + // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep + // alpha and beta for local testing, but only runs one value for alpha and beta. + cutlass::conv::Conv3dProblemSize conv3d_split_k_test_size ( + {1, 8, 8, 8, 32}, // input size (NDHWC) + {32, 3, 3, 3, 32}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + ); + + cutlass::conv::SplitKMode split_k_modes [] = { + cutlass::conv::SplitKMode::kSerial + }; + + int split_k_slices[] = { + 1, 2, 3, 4, 201 + }; + + double problem_alpha[] = { + 2.0 + }; + + double problem_beta[] = { + 2.0 + }; + + for (auto split_k_mode : split_k_modes) { + for (auto split_k_slice : split_k_slices) { + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + passed = testbed.run( + conv3d_split_k_test_size.reset_split_k_slices(split_k_slice), + split_k_mode, + false,/*non_packed_test*/ + cutlass::from_real(alpha), + cutlass::from_real(beta)); + + if (!passed) { + return false; + } + } + } + } + } + + return passed; +} + +template , + bool AddBroadcastFirst = false> +bool TestSpecificConv3dWithBroadcast( + const Conv3dProblemVector & problem_sizes, + bool non_packed_test = false) { + + bool passed = true; + + // + // Testbed object + // + + TestbedConv3dWithBroadcast testbed; + + // Sweep conv3d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for(auto conv_problem : problem_sizes) { + + // + // Test + // + + // test mode = xcross, non_packed_test = false + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial, non_packed_test); + + if (!passed) { + return false; + } + + // test mode = convolution, non_packed_test = false + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial, non_packed_test); + + if (!passed) { + return false; + } + } + + return true; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace conv +} // namespace test diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/depthwise_conv2d_direct_conv_testbed.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/depthwise_conv2d_direct_conv_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..cef5f981c595dfbbb95658fb757865b219538192 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/depthwise_conv2d_direct_conv_testbed.h @@ -0,0 +1,473 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Depthwise Direct Conv testbed +*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" +#include "../cache_testbed_output.h" +#include "conv2d_problems.h" +#include "cutlass/conv/device/direct_convolution.h" + +#include "cutlass/core_io.h" +#include "cutlass/cutlass.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/device/convolution.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +namespace test { +namespace conv { +namespace device { + +template +class TestbedDepthwiseDirectConv2d { + public: + + using ElementA = typename Conv2d::ElementA; + using LayoutA = typename Conv2d::LayoutA; + using ElementB = typename Conv2d::ElementB; + using LayoutB = typename Conv2d::LayoutB; + using ElementC = typename Conv2d::ElementC; + using LayoutC = typename Conv2d::LayoutC; + using ElementAccumulator = typename Conv2d::ElementAccumulator; + using ElementCompute = typename Conv2d::ElementCompute; + using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; + + static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; + + public: + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_reordered_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + int tested_problem_count; + + public: + TestbedDepthwiseDirectConv2d(cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080) + : init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_), tested_problem_count(0) {} + + /// Helper to initialize a tensor view + template + void initialize_tensor(cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + if (dist_kind == cutlass::Distribution::Uniform) { + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 8) { + scope = 2; + } else if (bits == 16) { + if (cutlass::sizeof_bits::value <= 16) { + scope = 3; + } else { + scope = 5; + } + } else { + scope = 8; + } + cutlass::reference::host::TensorFillRandomUniform(view, seed, scope, -scope, 0); + } else if (dist_kind == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(view); + + } else if (dist_kind == cutlass::Distribution::Gaussian) { + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } else { + } + } + + void initialize(cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) { + tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); + tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_reordered_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + + initialize_tensor(tensor_A.host_view(), init_A, seed); + initialize_tensor(tensor_B.host_view(), init_B, seed * 17); + initialize_tensor(tensor_reordered_B.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C.host_view(), init_C, seed * 39); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_reordered_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + tensor_D_reference.sync_device(); + } + + bool sufficient(int smem_size) const { + // + // Determine SMEM requirements and waive if not satisfied + // + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < static_cast(smem_size)) { + return false; + } + + return true; + } + + /// Executes one test + bool run(cutlass::conv::Conv2dProblemSize const &problem_size, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + ElementCompute alpha = ElementCompute(1.5), + ElementCompute beta = ElementCompute(1)) { + // increment tested problem count run by the testbed + tested_problem_count++; + +#if 0 // display conv2d problem size for debugging + std::cout << problem_size << std::endl + << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl + << "split_k_mode: " + << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") + << std::endl + << std::endl; +#endif + + initialize(problem_size); + + // configure the operator + Conv2d conv2d_op; + + typename Conv2d::Arguments conv2d_args(problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_computed.device_ref(), + {alpha, beta}, + tensor_reordered_B.device_ref(), + split_k_mode); + + // find workspace requirement for parallel split-k reduction + size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = conv2d_op.can_implement(problem_size); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + status = conv2d_op.initialize(conv2d_args, workspace.get()); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + if (!sufficient(conv2d_op.get_smem_size())) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + // run conv2d operator + status = conv2d_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to run." << std::endl; + return false; + } + + bool passed = false; + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " device reference error: " << cudaGetErrorString(result); + + tensor_D_computed.sync_host(); + + // + // Reference check - support caching results + // + + CachedTestKey cached_test_key = + CreateCachedConv2dTestKey(kConvolutionalOperator, + problem_size, + alpha, + beta, + tensor_A.host_view(), + tensor_B.host_view(), + tensor_C.host_view()); + + // + // Look for the cached key + // + + bool cached_result_loaded = false; + CachedTestResult cached_test_result; + + std::string conv2d_result_cache_name = + std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt"; + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + + CachedTestResultListing cached_results(conv2d_result_cache_name); + + auto cached = cached_results.find(cached_test_key); + + cached_result_loaded = cached.first; + if (cached_result_loaded) { + cached_test_result = cached.second; + } + } + + if (!cached_result_loaded) { +#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED + + cutlass::reference::device::Conv2d(kConvolutionalOperator, + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_reference.device_ref(), + alpha, + beta); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_D_reference.sync_host(); + +#else + + cutlass::reference::host::Conv2d(kConvolutionalOperator, + problem_size, + tensor_A.host_ref(), + tensor_B.host_ref(), + tensor_C.host_ref(), + tensor_D_reference.host_ref(), + alpha, + beta); + +#endif + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + + cached_test_result.D = TensorHash(tensor_D_reference.host_view()); + + CachedTestResultListing cached_results(conv2d_result_cache_name); + + cached_results.append(cached_test_key, cached_test_result); + cached_results.write(conv2d_result_cache_name); + } + } // if (!cached_result_loaded) + + uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view()); + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + passed = (tensor_D_hash == cached_test_result.D); + + EXPECT_EQ(tensor_D_hash, cached_test_result.D) + << "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n"; + } + else { + + passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view()); + } + + EXPECT_TRUE(passed); + + std::stringstream ss_problem_size_text; + ss_problem_size_text << "nhwc_" + << problem_size.N << "x" + << problem_size.H << "x" + << problem_size.W << "x" + << problem_size.C + << "_krsc_" + << problem_size.K << "x" + << problem_size.R << "x" + << problem_size.S << "x" + << problem_size.C + << "_padding_" + << problem_size.pad_h << "x" + << problem_size.pad_w + << "_stride_" + << problem_size.stride_h << "x" + << problem_size.stride_w + << "_dilation_" + << problem_size.dilation_h << "x" + << problem_size.dilation_w << "_" + << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_"); + + if (!passed) { + std::stringstream fname; + + fname << "error_Conv2d_DirectConv_device_" + << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") + << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_")) + << ss_problem_size_text.str() + << Conv2d::ThreadblockShape::kM << "x" + << Conv2d::ThreadblockShape::kN << "x" + << Conv2d::ThreadblockShape::kK << "_" + << Conv2d::WarpShape::kM << "x" + << Conv2d::WarpShape::kN << "x" + << Conv2d::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n"; + + results << "\nD reference (hash: " << cached_test_result.D << ")\n"; + + if (!cached_result_loaded) { + results + << tensor_D_reference.host_view() << "\n"; + } + + results + << "\nD computed (hash: " << tensor_D_hash << ")\n" + << tensor_D_computed.host_view() << "\n"; + + } + + return passed; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestSpecificDepthwiseDirectConv2d(const Conv2dProblemVector &problem_sizes) { + bool passed = true; + + // + // Testbed object + // + TestbedDepthwiseDirectConv2d testbed; + + // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for (auto conv_problem : problem_sizes) { + // + // Test + // + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + } + + return true; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace conv +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/device_3x/conv_problem_sizes.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/device_3x/conv_problem_sizes.hpp new file mode 100644 index 0000000000000000000000000000000000000000..54c11281e14b813b249d7f9710542843b37bcc68 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/device_3x/conv_problem_sizes.hpp @@ -0,0 +1,1385 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief CUTLASS 3.x Implicit GEMM testbed sizes for ConvNd problem +*/ +#pragma once + +#include "cutlass/conv/convnd_problem_shape.hpp" +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test::conv::device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +std::vector> +inline +get_conv_problem_vector(); + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Fprop +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Specialization for 1D fprop problems +template<> +std::vector> inline +get_conv_problem_vector<1, cutlass::conv::Operator::kFprop>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 64}, // nwc + {64, 1, 64}, // ksc + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {1}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + // non-packed input strides. + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 64}, // nwc + {800, 80, 1}, // stride (nwc) + {64, 1, 64}, // ksc + {64, 64, 1}, // stride (ksc) + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {1}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + // non-packed output strides. + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 64}, // nwc + {512, 64, 1}, // stride (nwc) + {64, 1, 64}, // ksc + {64, 64, 1}, // stride (ksc) + {800, 80, 1}, // stride (nqk) + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {1}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + // Filter-K = 16 for predication + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 64}, + {16,1, 64}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // N = 2 and K = 128 for a larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 64}, + {96, 1, 64}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // N = 7 and K = 256 for a even larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {7, 8, 64}, + {256, 1, 64}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // 3 filter, no padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 64}, + {256, 3, 64}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // 3 filter, symmetric padding with c % cta_k !=0 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 32}, + {256, 3, 32}, + {1}, + {1}, + {1}, + {1}, + 1 + }); + // 4 filter, asymmetric padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 64}, + {256, 4, 64}, + {0}, + {1}, + {1}, + {1}, + 1 + }); + // 3 filter, asymmetric padding and tstride of 2 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 64}, + {256, 3, 64}, + {0}, + {1}, + {2}, + {1}, + 1 + }); + // 3 filter, asymmetric padding and dilation of 2 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 64}, + {256, 3, 64}, + {0}, + {1}, + {1}, + {2}, + 1 + }); + return problem_shapes; +} + +// Specialization for 2D fprop problems +template<> +std::vector> inline +get_conv_problem_vector<2, cutlass::conv::Operator::kFprop>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 64}, // nhwc + {64, 1, 1, 64}, // krsc + {0, 0}, // padding lower (pad_h, pad_w) + {0, 0}, // padding upper (pad_h, pad_w) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + 1 // group + }); + // non-packed input strides. + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 64}, // nhwc + {8000, 800, 80, 1}, // stride (nhwc) + {64, 1, 1, 64}, // krsc + {64, 64, 64, 1}, // stride (krsc) + {0, 0}, // padding lower (pad_h, pad_w) + {0, 0}, // padding upper (pad_h, pad_w) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + 1 // group + }); + // non-packed output strides. + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 64}, // nhwc + {4096, 512, 64, 1}, // stride (nhwc) + {64, 1, 1, 64}, // krsc + {64, 64, 64, 1}, // stride (krsc) + {8000, 800, 80, 1}, // stride (npqk) + {0, 0}, // padding lower (pad_h, pad_w) + {0, 0}, // padding upper (pad_h, pad_w) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + 1 // group + }); + // Filter-K = 16 for predication + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 64}, + {16, 1, 1, 64}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // N = 2 and K = 128 for a larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 64}, + {96, 1, 1, 64}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // N = 7 and K = 256 for a even larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {7, 8, 8, 64}, + {256, 1, 1, 64}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // 3x3 filter, no padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 64}, + {256, 3, 3, 64}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // 3x3 filter, symmetric padding with c % cta_k !=0 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 32}, + {256, 3, 3, 32}, + {1, 1}, + {1, 1}, + {1, 1}, + {1, 1}, + 1 + }); + // 2x5 filter, asymmetric padding 1,2/1,2 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 64}, + {256, 2, 5, 64}, + {1, 1}, + {2, 2}, + {1, 1}, + {1, 1}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0, w/ stride + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 7, 7, 64}, + {256, 2, 5, 64}, + {1, 1}, + {0, 0}, + {2, 3}, + {1, 1}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 16, 64}, + {256, 2, 5, 64}, + {1, 1}, + {0, 0}, + {1, 1}, + {2, 3}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0, w/ stride, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 15, 64}, + {256, 2, 5, 64}, + {1, 1}, + {0, 0}, + {2, 3}, + {2, 3}, + 1 + }); + return problem_shapes; +} + +// Specialization for 3D fprop problems +template<> +std::vector> inline +get_conv_problem_vector<3, cutlass::conv::Operator::kFprop>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1, 8, 8, 64}, // ndhwc + {64, 1, 1, 1, 64}, // ktrsc + {0, 0, 0}, // padding lower (pad_d, pad_h, pad_w) + {0, 0, 0}, // padding upper (pad_d, pad_h, pad_w) + {1, 1, 1}, // stride (stride_d, stride_h, stride_w) + {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) + 1 // group + }); + // non-packed input output strides. + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1, 8, 8, 64}, // ndhwc + {8000, 8000, 800, 80, 1}, // stride (ndhwc) + {64, 1, 1, 1, 64}, // ktrsc + {64, 64, 64, 64, 1}, // stride (ktrsc) + {8000, 8000, 800, 80, 1}, // stride (nzpqk) + {0, 0, 0}, // padding lower (pad_d, pad_h, pad_w) + {0, 0, 0}, // padding upper (pad_d, pad_h, pad_w) + {1, 1, 1}, // stride (stride_d, stride_h, stride_w) + {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) + 1 // group + }); + // Filter-K = 16 for predication + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1, 8, 8, 64}, + {16, 1, 1, 1, 64}, + {0, 0, 0}, + {0, 0, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // N = 7 and K = 256 for a larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 1, 8, 8, 64}, + {96, 1, 1, 1, 64}, + {0, 0, 0}, + {0, 0, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x3x3 + no padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 3, 5, 8, 64}, + {96, 3, 3, 3, 64}, + {0, 0, 0}, + {0, 0, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x3x3 + symmetric padding with c % cta_k !=0 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 3, 5, 8, 32}, + {96, 3, 3, 3, 32}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + symmetric padding 111 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 3, 5, 8, 64}, + {96, 3, 4, 5, 64}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + asymmetric padding 102/010 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 3, 5, 8, 64}, + {96, 3, 4, 5, 64}, + {1, 0, 1}, + {0, 2, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + asymmetric padding 102/010, w/ stride + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 10, 16, 64}, + {96, 3, 4, 5, 64}, + {1, 0, 1}, + {0, 2, 0}, + {2, 2, 3}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + asymmetric padding 102/010, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 10, 16, 64}, + {96, 3, 4, 5, 64}, + {1, 0, 1}, + {0, 2, 0}, + {1, 1, 1}, + {2, 2, 3}, + 1 + }); + // Filter 3x4x5 + asymmetric padding 102/010, w/ stride, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 10, 16, 64}, + {96, 3, 4, 5, 64}, + {1, 0, 1}, + {0, 2, 0}, + {2, 2, 3}, + {2, 2, 3}, + 1 + }); + return problem_shapes; +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Wgrad +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Specialization for 1D wgrad problems +template<> +std::vector> inline +get_conv_problem_vector<1, cutlass::conv::Operator::kWgrad>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 64}, // nwc + {64, 1, 64}, // ksc + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {1}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + // Filter-K = 16 for predication + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 64}, + {16,1, 64}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // N = 2 and K = 128 for a larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 64}, + {96, 1, 64}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // N = 7 and K = 256 for a even larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {7, 8, 64}, + {256, 1, 64}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // 3 filter, no padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 32}, + {256, 3, 32}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // 3 filter, symmetric padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 32}, + {256, 3, 32}, + {1}, + {1}, + {1}, + {1}, + 1 + }); + // 4 filter, asymmetric padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 32}, + {256, 4, 32}, + {0}, + {1}, + {1}, + {1}, + 1 + }); + // 3 filter, asymmetric padding and tstride of 2 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 32}, + {256, 3, 32}, + {0}, + {1}, + {2}, + {1}, + 1 + }); + // 3 filter, asymmetric padding and dilation of 2 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 32}, + {256, 3, 32}, + {0}, + {1}, + {1}, + {2}, + 1 + }); + // To test streamk, equals to gemm-MxNxK size 128x640x2048 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 1024, 128}, + {640, 1, 128}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // To test streamk, equals to gemm-MxNxK size 128x640x2080 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 1040, 128}, + {640, 1, 128}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + return problem_shapes; +} + +// Specialization for 2D wgrad problems +template<> +std::vector> inline +get_conv_problem_vector<2, cutlass::conv::Operator::kWgrad>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 64}, // nhwc + {64, 1, 1, 64}, // krsc + {0, 0}, // padding lower (pad_h, pad_w) + {0, 0}, // padding upper (pad_h, pad_w) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + 1 // group + }); + // Filter-K = 16 for predication + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 64}, + {16, 1, 1, 64}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // N = 2 and K = 128 for a larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 64}, + {96, 1, 1, 64}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // N = 7 and K = 256 for a even larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {7, 8, 8, 64}, + {256, 1, 1, 64}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // 3x3 filter, no padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 32}, + {256, 3, 3, 32}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // 3x3 filter, symmetric padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 32}, + {256, 3, 3, 32}, + {1, 1}, + {1, 1}, + {1, 1}, + {1, 1}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 32}, + {256, 2, 5, 32}, + {1, 1}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0, w/ stride + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 15, 16, 32}, + {256, 2, 5, 32}, + {1, 1}, + {0, 0}, + {2, 3}, + {1, 1}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 16, 32}, + {256, 2, 5, 32}, + {1, 1}, + {0, 0}, + {1, 1}, + {2, 3}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0, w/ stride, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 15, 32}, + {256, 2, 5, 32}, + {1, 1}, + {0, 0}, + {2, 3}, + {2, 3}, + 1 + }); + // To test streamk, equals to gemm-MxNxK size 128x640x2048 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 64, 16, 128}, + {640, 1, 1, 128}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // To test streamk, equals to gemm-MxNxK size 128x640x2080 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 65, 16, 128}, + {640, 1, 1, 128}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + return problem_shapes; +} + +// Specialization for 3D wgrad problems +template<> +std::vector> inline +get_conv_problem_vector<3, cutlass::conv::Operator::kWgrad>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 1, 8, 8, 64}, // ndhwc + {64, 1, 1, 1, 64}, // ktrsc + {0, 0, 0}, // padding lower (pad_d, pad_h, pad_w) + {0, 0, 0}, // padding upper (pad_d, pad_h, pad_w) + {1, 1, 1}, // stride (stride_d, stride_h, stride_w) + {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) + 1 // group + }); + // Filter 3x3x3 + no padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 3, 5, 8, 32}, + {96, 3, 3, 3, 32}, + {0, 0, 0}, + {0, 0, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + asymmetric padding 102/010 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 3, 5, 8, 32}, + {96, 3, 4, 5, 32}, + {1, 0, 1}, + {0, 2, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + asymmetric padding 102/010, w/ stride + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 10, 16, 32}, + {96, 3, 4, 5, 32}, + {1, 0, 1}, + {0, 2, 0}, + {2, 2, 3}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + asymmetric padding 102/010, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 10, 16, 32}, + {96, 3, 4, 5, 32}, + {1, 0, 1}, + {0, 2, 0}, + {1, 1, 1}, + {2, 2, 3}, + 1 + }); + // To test streamk, equals to gemm-MxNxK size 128x640x2048 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 1, 64, 16, 128}, + {640, 1, 1, 1, 128}, + {0, 0, 0}, + {0, 0, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // To test streamk, equals to gemm-MxNxK size 128x640x2080 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 1, 65, 16, 128}, + {640, 1, 1, 1, 128}, + {0, 0, 0}, + {0, 0, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + return problem_shapes; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Grouped Wgrad +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Get problem size vectors for group conv problems +template +std::vector> +inline +get_grouped_conv_problem_vector(int GroupsPerTile); + +// Specialization for 3D wgrad problems +template<> +std::vector> inline +get_grouped_conv_problem_vector<3, cutlass::conv::Operator::kWgrad>(int GroupsPerTile) { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + + if (GroupsPerTile == 1) { + // channel_per_group == 64 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1, 16, 16, 2048}, // ndhwc + {2048, 1, 3, 3, 64}, // ktrsc + {0, 1, 1}, // padding lower (pad_d, pad_h, pad_w) + {0, 1, 1}, // padding upper (pad_d, pad_h, pad_w) + {1, 1, 1}, // stride (stride_d, stride_h, stride_w) + {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) + 32 // groups + }); + } + else if (GroupsPerTile == 2) { + // channel_per_group == 32 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1, 16, 16, 1024}, // ndhwc + {1024, 1, 3, 3, 32}, // ktrsc + {0, 1, 1}, // padding lower (pad_d, pad_h, pad_w) + {0, 1, 1}, // padding upper (pad_d, pad_h, pad_w) + {1, 1, 1}, // stride (stride_d, stride_h, stride_w) + {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) + 32 // groups + }); + } + else if (GroupsPerTile == 4) { + // channel_per_group == 16 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1, 16, 16, 512}, // ndhwc + {512, 1, 3, 3, 16}, // ktrsc + {0, 1, 1}, // padding lower (pad_d, pad_h, pad_w) + {0, 1, 1}, // padding upper (pad_d, pad_h, pad_w) + {1, 1, 1}, // stride (stride_d, stride_h, stride_w) + {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) + 32 // groups + }); + } + else if (GroupsPerTile == 8) { + // channel_per_group == 8 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1, 16, 16, 256}, // ndhwc + {256, 1, 3, 3, 8}, // ktrsc + {0, 1, 1}, // padding lower (pad_d, pad_h, pad_w) + {0, 1, 1}, // padding upper (pad_d, pad_h, pad_w) + {1, 1, 1}, // stride (stride_d, stride_h, stride_w) + {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) + 32 // groups + }); + } + return problem_shapes; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Unit Stride Dgrad +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Specialization for 1D dgrad problems +template<> +std::vector> inline +get_conv_problem_vector<1, cutlass::conv::Operator::kDgrad, false>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 64}, // nqk + {64, 1, 64}, // ksc + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {1}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + // non-packed input strides. + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 64}, // nqk + {800, 80, 1}, // stride (nqk) + {64, 1, 64}, // ksc + {64, 64, 1}, // stride (ksc) + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {1}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + // non-packed output strides. + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 64}, // nqk + {512, 64, 1}, // stride (nqk) + {64, 1, 64}, // ksc + {64, 64, 1}, // stride (ksc) + {800, 80, 1}, // stride (nwc) + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {1}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + // Filter-K = 16 for predication + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 16}, + {64, 1, 16}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // N = 2 and K = 128 for a larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 96}, + {64, 1, 96}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // N = 7 and K = 256 for a even larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {7, 8, 256}, + {64, 1, 256}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // 3 filter, no padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 256}, + {64, 3, 256}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // 3 filter, symmetric padding with k % cta_k !=0 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 256}, + {32, 3, 256}, + {1}, + {1}, + {1}, + {1}, + 1 + }); + // 4 filter, asymmetric padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 256}, + {64, 4, 256}, + {0}, + {1}, + {1}, + {1}, + 1 + }); + // 3 filter, asymmetric padding and dilation of 2 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 64}, + {256, 3, 64}, + {0}, + {1}, + {1}, + {2}, + 1 + }); + return problem_shapes; +} + +// Specialization for 2D dgrad problems +template<> +std::vector> inline +get_conv_problem_vector<2, cutlass::conv::Operator::kDgrad, false>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 64}, // npqk + {64, 1, 1, 64}, // krsc + {0, 0}, // padding lower (pad_h, pad_w) + {0, 0}, // padding upper (pad_h, pad_w) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + 1 // group + }); + // non-packed input strides. + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 64}, // npqk + {8000, 800, 80, 1}, // stride (npqk) + {64, 1, 1, 64}, // krsc + {64, 64, 64, 1}, // stride (krsc) + {0, 0}, // padding lower (pad_h, pad_w) + {0, 0}, // padding upper (pad_h, pad_w) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + 1 // group + }); + // non-packed output strides. + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 64}, // npqk + {4096, 512, 64, 1}, // stride (npqk) + {64, 1, 1, 64}, // krsc + {64, 64, 64, 1}, // stride (krsc) + {8000, 800, 80, 1}, // stride (nhwc) + {0, 0}, // padding lower (pad_h, pad_w) + {0, 0}, // padding upper (pad_h, pad_w) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + 1 // group + }); + // Filter-K = 16 for predication + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 16}, + {64, 1, 1, 16}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // N = 2 and K = 128 for a larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 96}, + {64, 1, 1, 96}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // N = 7 and K = 256 for a even larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {7, 8, 8, 256}, + {64, 1, 1, 256}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // 3x3 filter, no padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 256}, + {64, 3, 3, 256}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // 3x3 filter, symmetric padding with k % cta_k !=0 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 256}, + {32, 3, 3, 256}, + {1, 1}, + {1, 1}, + {1, 1}, + {1, 1}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 256}, + {64, 2, 5, 256}, + {1, 1}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 16, 64}, + {256, 2, 5, 64}, + {1, 1}, + {0, 0}, + {1, 1}, + {2, 3}, + 1 + }); + return problem_shapes; +} + +// Specialization for 3D dgrad problems +template<> +std::vector> inline +get_conv_problem_vector<3, cutlass::conv::Operator::kDgrad, false>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + // Filter-K = 16 for predication + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1, 8, 8, 16}, + {64, 1, 1, 1, 16}, + {0, 0, 0}, + {0, 0, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // non-packed input output strides. + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1, 8, 8, 64}, // nzpqk + {8000, 8000, 800, 80, 1}, // stride (nzpqk) + {64, 1, 1, 1, 64}, // ktrsc + {64, 64, 64, 64, 1}, // stride (ktrsc) + {8000, 8000, 800, 80, 1}, // stride (ndhwc) + {0, 0, 0}, // padding lower (pad_d, pad_h, pad_w) + {0, 0, 0}, // padding upper (pad_d, pad_h, pad_w) + {1, 1, 1}, // stride (stride_d, stride_h, stride_w) + {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) + 1 // group + }); + // N = 7 and K = 256 for a larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 1, 8, 8, 96}, + {64, 1, 1, 1, 96}, + {0, 0, 0}, + {0, 0, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + symmetric padding 111 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 3, 5, 8, 96}, + {64, 3, 4, 5, 96}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + asymmetric padding 102/010 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 3, 5, 8, 96}, + {64, 3, 4, 5, 96}, + {1, 0, 1}, + {0, 2, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + asymmetric padding 102/010, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 10, 16, 64}, + {64, 3, 4, 5, 96}, + {1, 0, 1}, + {0, 2, 0}, + {1, 1, 1}, + {2, 2, 3}, + 1 + }); + return problem_shapes; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Strided Dgrad +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Specialization for 1D dgrad problems +template<> +std::vector> inline +get_conv_problem_vector<1, cutlass::conv::Operator::kDgrad, true>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + // Test TMA truncation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 512, 64}, // nqk + {64, 1, 64}, // ksc + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {2}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1024, 64}, // nqk + {64, 1, 64}, // ksc + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {4}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 2048, 64}, // nqk + {64, 1, 64}, // ksc + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {8}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + // non-packed input/output strides. + // stride divides dilation + // asymmetric padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {3, 8, 64}, // nqk + {800, 80, 1}, // stride (nqk) + {64, 3, 64}, // ksc + {64, 64, 1}, // stride (ksc) + {800, 80, 1}, // stride (nwc) + {0}, // padding lower (pad_w) + {1}, // padding upper (pad_w) + {2}, // stride (stride_w) + {4}, // dilation (dilation_w) + 1 // group + }); + // non-packed input/output strides. + // dilation divides stride + // asymmetric padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {3, 8, 64}, // nqk + {800, 80, 1}, // stride (nqk) + {64, 3, 64}, // ksc + {64, 64, 1}, // stride (ksc) + {800, 80, 1}, // stride (nwc) + {1}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {4}, // stride (stride_w) + {2}, // dilation (dilation_w) + 1 // group + }); + // non-packed input/output strides. + // stride dilation dont divide + // asymmetric padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {3, 8, 64}, // nqk + {800, 80, 1}, // stride (nqk) + {64, 3, 64}, // ksc + {64, 64, 1}, // stride (ksc) + {800, 80, 1}, // stride (nwc) + {1}, // padding lower (pad_w) + {2}, // padding upper (pad_w) + {2}, // stride (stride_w) + {3}, // dilation (dilation_w) + 1 // group + }); + return problem_shapes; +} + +// Specialization for 2D dgrad problems +template<> +std::vector> inline +get_conv_problem_vector<2, cutlass::conv::Operator::kDgrad, true>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + // 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation + // mode 0 stride divides dilation + // mode 1 dilation divides stride + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {3, 16, 16, 64}, + {256, 2, 5, 64}, + {1, 0}, + {0, 1}, + {2, 4}, + {4, 2}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation + // mode 0 dilation divides stride + // mode 1 stride divides dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {3, 16, 16, 64}, + {256, 2, 5, 64}, + {1, 0}, + {0, 1}, + {4, 2}, + {2, 4}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation + // stride dilation dont divide + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {3, 16, 16, 64}, + {256, 2, 5, 64}, + {1, 0}, + {0, 1}, + {3, 2}, + {2, 3}, + 1 + }); + return problem_shapes; +} + +// Specialization for 3D dgrad problems +template<> +std::vector> inline +get_conv_problem_vector<3, cutlass::conv::Operator::kDgrad, true>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + // Filter 3x4x5 + asymmetric padding 102/010, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 10, 16, 64}, + {64, 3, 4, 5, 96}, + {1, 0, 1}, + {0, 2, 0}, + {2, 1, 2}, + {4, 2, 3}, + 1 + }); + return problem_shapes; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::test diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/device_3x/testbed_conv.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/device_3x/testbed_conv.hpp new file mode 100644 index 0000000000000000000000000000000000000000..99ba9c407cec38e919812fedeee38ba75d9129f7 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/conv/device_3x/testbed_conv.hpp @@ -0,0 +1,768 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implicit GEMM testbed for 3.x API +*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "../../common/cutlass_unit_test.h" + +#include "cute/tensor.hpp" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/convnd_problem_shape.hpp" +#include "../test/unit/gemm/device/gemm_testbed_3x.hpp" + +#include "thrust/universal_vector.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/host/conv.hpp" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "conv_problem_sizes.hpp" +#include "../cache_testbed_output.h" + +#include + +#include "cute/layout.hpp" +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test::conv::device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Initializes a flat device buffer +template +static void +initialize_values( + thrust::universal_vector& dst_ptr, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + if (cutlass::Distribution::Uniform == dist_kind) { + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 8) { + scope = 2; + } + else if (bits == 16) { + scope = 4; + } + else { + scope = 8; + } + cutlass::reference::host::BlockFillRandomUniform( + dst_ptr.data().get(), dst_ptr.size(), seed, scope, -scope, 0); + } + else if (cutlass::Distribution::Identity == dist_kind) { + cutlass::reference::host::BlockFillRandomUniform( + dst_ptr.data().get(), dst_ptr.size(), seed, 0, 0, 0); + } + else if (cutlass::Distribution::Gaussian == dist_kind) { + cutlass::reference::host::BlockFillRandomGaussian(dst_ptr.data().get(), dst_ptr.size(), seed, 0, 0.5); + } + else if (cutlass::Distribution::Sequential == dist_kind) { + cutlass::reference::host::BlockFillSequential(dst_ptr.data().get(), dst_ptr.size()); + } + else { + std::cerr << "Invalid distribution kind!\n."; + exit(1); + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// utils for sparse or dense conv parameters + +template +struct DenseConvParams { + // Default Kernel data types + using ElementA = typename Conv::ConvKernel::ElementA; + using ElementB = typename Conv::ConvKernel::ElementB; + + static constexpr cutlass::conv::Operator ConvOp = Conv::DispatchPolicy::ConvOp; + static constexpr int NumSpatialDimensions = Conv::NumSpatialDimensions; + using ProblemShape = cutlass::conv::ConvProblemShape; + + // get the default arguments without sparse data + auto get_mainloop_arguments( + [[maybe_unused]] ProblemShape const& problem_shape, + thrust::universal_vector& tensor_A, + thrust::universal_vector& tensor_B + ) { + auto args = typename Conv::ConvKernel::MainloopArguments { + tensor_A.data().get(), + tensor_B.data().get(), + }; + return args; + } +}; + +template +struct SparseConvParams { +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +struct ConvTestbed { + // Kernel data types + using ElementA = typename Conv::ConvKernel::ElementA; + using ElementB = typename Conv::ConvKernel::ElementB; + using ElementC = cute::conditional_t, + typename Conv::ConvKernel::ElementD, typename Conv::ConvKernel::ElementC>; + using ElementD = typename Conv::ConvKernel::ElementD; + using ElementAccumulator = typename Conv::ConvKernel::ElementAccumulator; + + // ConvTest for sparse kernel + static constexpr bool isSparseEnabled = isSparseEnabled_; + using ConvParams = cute::conditional_t, DenseConvParams>; + ConvParams params; + + // + // FusionOperation derived types/queries + // + using FusionOp = typename Conv::EpilogueOutputOp; + + // fusion types are potentially void if the fusion is not supported + // helper so we don't try to construct HostTensor with void type + template + using non_void_t = cute::conditional_t, U, T>; + using ElementScalar = typename FusionOp::ElementScalar; + using ElementCompute = typename FusionOp::ElementCompute; + using BiasType = typename cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithBias::type; + using ElementBias = non_void_t; + using ActivationType = non_void_t::type, + cutlass::epilogue::thread::Identity>; + static constexpr bool IsActivationEnabled = cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithActivation::value; + using ActivationFunctor = cute::conditional_t>; + + static constexpr bool IsBiasEnabled = cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithBias::value && + !cute::is_same_v; + static constexpr bool IsPerChannelScaleEnabled = cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithPerChannelScaled::value; + + static constexpr bool DisableSource = cute::is_void_v; + + static constexpr bool IsResidualEnabled = cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithResidualAdd::value; + + using StrideC = typename Conv::ConvKernel::StrideC; + using StrideD = typename Conv::ConvKernel::StrideD; + using ThreadEpilogueOp = typename Conv::ConvKernel::CollectiveEpilogue::ThreadEpilogueOp; + + static constexpr cutlass::conv::Operator ConvOp = Conv::DispatchPolicy::ConvOp; + static constexpr int NumSpatialDimensions = Conv::NumSpatialDimensions; + using ProblemShape = cutlass::conv::ConvProblemShape; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + using MaxSwizzleSize = typename gemm::device::detail::MaxSwizzleSize; + using Splits = typename gemm::device::detail::Splits; + + using Schedule = typename Conv::DispatchPolicy::Schedule; + /// Initialization + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform; + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform; + cutlass::Distribution::Kind init_C = cutlass::Distribution::Uniform; + cutlass::Distribution::Kind init_bias = cutlass::Distribution::Uniform; + cutlass::Distribution::Kind init_disable = cutlass::Distribution::Identity; // all zeros + uint64_t seed = 6090; + float epsilon = 0.0f; + int split_p_slices = 1; + thrust::universal_vector tensor_A; + thrust::universal_vector tensor_B; + thrust::universal_vector tensor_C; + thrust::universal_vector tensor_D_computed; + thrust::universal_vector tensor_D_reference; + thrust::universal_vector tensor_bias; + thrust::universal_vector tensor_alpha; + thrust::universal_vector tensor_beta; + + // Return true on success, else false + bool initialize(ProblemShape const& problem_shape, uint64_t seed = 6090) { + tensor_A.resize(sizeof(ElementA) * problem_shape.size_A()); + tensor_B.resize(sizeof(ElementB) * problem_shape.size_B()); + tensor_C.resize(sizeof(ElementC) * problem_shape.size_C()); + tensor_D_computed.resize(sizeof(ElementD) * problem_shape.size_C()); + tensor_D_reference.resize(sizeof(ElementD) * problem_shape.size_C()); + tensor_bias.resize(sizeof(ElementBias) * cute::size(cute::get<0>(problem_shape.get_shape_B()))); + if constexpr (IsPerChannelScaleEnabled) { + tensor_alpha.resize(sizeof(ElementScalar) * cute::size(cute::get<0>(problem_shape.get_shape_B()))); + tensor_beta.resize(sizeof(ElementScalar) * cute::size(cute::get<0>(problem_shape.get_shape_B()))); + } + initialize_values(tensor_A, init_A, seed); + initialize_values(tensor_B, init_B, seed * 11); + initialize_values(tensor_C, init_C, seed * 17); + initialize_values(tensor_bias, init_bias, seed * 19); + if constexpr (IsPerChannelScaleEnabled) { + initialize_values(tensor_alpha, init_bias, seed * 23); + if constexpr (DisableSource) { + initialize_values(tensor_beta, init_disable, seed * 27); + } + else { + initialize_values(tensor_beta, init_bias, seed * 27); + } + } + + bool flag = true; + if constexpr (isSparseEnabled) { + flag &= params.initialize(problem_shape, tensor_B, static_cast(seed + 2023)); + } + + return flag; + } + + // Determine SMEM requirements and waive if not satisfied + bool sufficient() const { + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + int max_smem_size; + result = cudaDeviceGetAttribute(&max_smem_size, cudaDevAttrMaxSharedMemoryPerBlockOptin, device_idx); + if (result != cudaSuccess) { + throw std::runtime_error("cudaDeviceGetAttribute() failed"); + } + + return max_smem_size >= Conv::ConvKernel::SharedStorageSize; + } + + auto transform_shape_and_stride_with_groups(ProblemShape const& problem_shape) { + using TensorExtent = cute::array; + using TensorStride = cute::array; + + TensorExtent shape_a_g{}; + TensorExtent shape_b_g{}; + TensorExtent shape_c_g{}; + TensorStride stride_a_g{}; + TensorStride stride_b_g{}; + TensorStride stride_c_g{}; + + auto shape_a = cute::reverse(problem_shape.shape_A); + auto shape_b = cute::reverse(problem_shape.shape_B); + auto shape_c = cute::reverse(problem_shape.shape_C); + auto stride_a = cute::reverse(problem_shape.stride_A); + auto stride_b = cute::reverse(problem_shape.stride_B); + auto stride_c = cute::reverse(problem_shape.stride_C); + + int32_t G = problem_shape.groups; + + if constexpr (ConvOp == cutlass::conv::Operator::kFprop || + ConvOp == cutlass::conv::Operator::kDgrad) { + // shape_a_g = (c,w,h,d,n,g) or (k,q,p,z,n,g) + // shape_b_g = (c,s,r,k,t,g) + // shape_c_g = (k,q,p,z,n,g) or (c,w,h,d,n,g) + shape_a_g = cute::to_array(tuple_cat( + cute::make_shape(cute::size<0>(shape_a) / G), + cute::take<1,NumSpatialDimensions + 2>(shape_a), + cute::make_shape(G))); + shape_b_g = cute::to_array(tuple_cat( + cute::take<0,NumSpatialDimensions + 1>(shape_b), + cute::make_shape(cute::size(shape_b) / G, G))); + shape_c_g = cute::to_array(tuple_cat( + cute::make_shape(cute::size<0>(shape_c) / G), + cute::take<1,NumSpatialDimensions + 2>(shape_c), + cute::make_shape(G))); + + stride_a_g = cute::to_array(append(stride_a, cute::size<0>(shape_a) / G)); + stride_b_g = cute::to_array(append(stride_b, + cute::size(stride_b) * cute::size(shape_b) / G)); + stride_c_g = cute::to_array(append(stride_c, cute::size<0>(shape_c) / G)); + } + else if constexpr (ConvOp == cutlass::conv::Operator::kWgrad) { + // shape_a_g = (k,q,p,z,n,g) + // shape_b_g = (c,w,h,d,n,g) + // shape_c_g = (c,s,r,k,t,g) + shape_a_g = cute::to_array(tuple_cat( + cute::make_shape(cute::size<0>(shape_a) / G), + cute::take<1,NumSpatialDimensions + 2>(shape_a), + cute::make_shape(G))); + shape_b_g = cute::to_array(tuple_cat( + cute::make_shape(cute::size<0>(shape_b) / G), + cute::take<1,NumSpatialDimensions + 2>(shape_b), + cute::make_shape(G))); + shape_c_g = cute::to_array(tuple_cat( + cute::take<0,NumSpatialDimensions + 1>(shape_c), + cute::make_shape(cute::size(shape_c) / G, G))); + + stride_a_g = cute::to_array(append(stride_a, cute::size<0>(shape_a) / G)); + stride_b_g = cute::to_array(append(stride_b, cute::size<0>(shape_b) / G)); + stride_c_g = cute::to_array(append(stride_c, + cute::size(stride_c) * cute::size(shape_c) / G)); + } + + return make_tuple(shape_a_g, shape_b_g, shape_c_g, + stride_a_g, stride_b_g, stride_c_g); + } + + // Executes one test + bool run( + ProblemShape const& problem_shape, + ElementScalar alpha = ElementScalar(1), + ElementScalar beta = ElementScalar(0), + dim3 cluster_shape = dim3(0, 0, 0), + dim3 cluster_shape_fallback = dim3(0, 0, 0), + RasterOrderOptions raster_order = RasterOrderOptions::Heuristic, + MaxSwizzleSize max_swizzle = MaxSwizzleSize{}, + Splits splits = Splits{}, + DecompositionMode decomposition_mode = DecompositionMode::Heuristic + ) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device.\n"; + } + return true; + } + + bool ret = initialize(problem_shape); + + if (!ret) { + std::cerr << "initialize failed for the given problem_shape: \n"; + return false; + } + + cutlass::KernelHardwareInfo hw_info; + cudaGetDevice(&hw_info.device_id); + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + hw_info.cluster_shape = cluster_shape; + hw_info.cluster_shape_fallback = cluster_shape_fallback; + + // configure the operator + Conv conv_op; + auto stride_C = StrideC{}; + auto stride_D = StrideD{}; + if constexpr (ConvOp == cutlass::conv::Operator::kWgrad) { + stride_C = cutlass::make_cute_packed_stride( + StrideC{}, problem_shape.shape_C, problem_shape.stride_C, ConvOp); + stride_D = cutlass::make_cute_packed_stride( + StrideD{}, problem_shape.shape_C, problem_shape.stride_C, ConvOp); + } + // Need to support non-packed output strides for fprop and dgrad kernel. + else { + cute::for_each(cute::make_seq(StrideC{})>{}, [&](auto i) { + cute::get<0, i>(stride_C) = problem_shape.stride_C[ProblemShape::RankT-2-i]; + }); + cute::for_each(cute::make_seq(StrideD{})>{}, [&](auto i) { + cute::get<0, i>(stride_D) = problem_shape.stride_C[ProblemShape::RankT-2-i]; + }); + } + + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + + typename Conv::ConvKernel::TileScheduler::Arguments scheduler_args{}; + if constexpr (cute::is_same_v) { + scheduler_args = { static_cast(splits), static_cast(max_swizzle), raster_order, decomposition_mode }; + } + + auto mainloop_args = params.get_mainloop_arguments(problem_shape, tensor_A, tensor_B); + + auto epilogue_args = typename Conv::ConvKernel::EpilogueArguments { + {}, + tensor_C.data().get(), + stride_C, + tensor_D_computed.data().get(), + stride_D, + }; + + auto args = typename Conv::Arguments { + problem_shape, + mainloop_args, // MainloopArguments + epilogue_args, // EpilogueArguments + hw_info, + scheduler_args + }; + + auto &fusion_args = args.epilogue.thread; + + fusion_args.alpha = alpha; + fusion_args.beta = beta; + + if constexpr (IsPerChannelScaleEnabled) { + fusion_args.alpha_ptr = tensor_alpha.data().get(); + fusion_args.beta_ptr = tensor_beta.data().get(); + } + + if constexpr (IsBiasEnabled) { + fusion_args.bias_ptr = tensor_bias.data().get(); + } + + // Clamp bound + if constexpr (cute::is_same_v>) { + fusion_args.activation.lower_bound = CUTLASS_STL_NAMESPACE::numeric_limits::lowest(); + fusion_args.activation.upper_bound = CUTLASS_STL_NAMESPACE::numeric_limits::max(); + } + + // Scale + if constexpr (cute::is_same_v> || + cute::is_same_v> || + cute::is_same_v> || + cute::is_same_v> ) { + fusion_args.activation.scale = ElementCompute{1}; + } + + // LeakyRelu + if constexpr (cute::is_same_v> ) { + fusion_args.activation.leaky_alpha = ElementCompute{0}; + } + + cutlass::Status status = cutlass::Status::kInvalid; + + status = conv_op.can_implement(args); + EXPECT_EQ(conv_op.can_implement(args), cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + std::cerr << "can_implement failed for the given problem_shape: \n"; + print(problem_shape); + return false; + } + + // find workspace requirement for parallel split-k reduction + size_t workspace_size = Conv::get_workspace_size(args); + thrust::universal_vector workspace(workspace_size); + + status = conv_op.initialize(args, workspace.data().get()); + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // run conv3d operator + status = conv_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + bool passed = false; + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " Kernel execution error: " + << cudaGetErrorString(result); + + // Create cute::Tensors using the logical rank-3 MNK multi-mode shapes the mainloop gives us + auto [shape_mA, shape_mB, shape_mC, stride_mA, stride_mB, stride_mC] = + transform_shape_and_stride_with_groups(problem_shape); + auto shape_mBias = cute::make_shape(cute::size(cute::get<0>(problem_shape.get_shape_B()))); + + auto mA = make_tensor(tensor_A.data().get(), make_layout(shape_mA, stride_mA)); + auto mB = make_tensor(tensor_B.data().get(), make_layout(shape_mB, stride_mB)); + auto mC = make_tensor(tensor_C.data().get(), make_layout(shape_mC, stride_mC)); + auto mD_ref = make_tensor(tensor_D_reference.data().get(), make_layout(shape_mC, stride_mC)); + auto mD_computed = make_tensor(tensor_D_computed.data().get(), make_layout(shape_mC, stride_mC)); + auto mBias = make_tensor(tensor_bias.data().get(), make_layout(shape_mBias)); + auto mAlpha = make_tensor(tensor_alpha.data().get(), make_layout(shape_mBias)); + auto mBeta = make_tensor(tensor_beta.data().get(), make_layout(shape_mBias)); + + cutlass::reference::host::ConvEpilogueFusionParams< + ElementAccumulator, + ElementScalar, + ElementCompute, + ElementC, + ElementD, + IsResidualEnabled, + decltype(mAlpha), + decltype(mBeta), + decltype(mBias), + ActivationFunctor> + epilogue_fusion_params{}; + + epilogue_fusion_params.alpha = alpha; + epilogue_fusion_params.beta = beta; + + if constexpr (IsPerChannelScaleEnabled) { + epilogue_fusion_params.tensor_alpha = mAlpha; + epilogue_fusion_params.tensor_beta = mBeta; + } + + if constexpr (IsBiasEnabled) { + epilogue_fusion_params.tensor_bias = mBias; + } + + auto padding = cute::reverse(problem_shape.lower_padding); + auto tstride = cute::reverse(problem_shape.traversal_stride); + auto dilation = cute::reverse(problem_shape.dilation); + + cutlass::reference::host::ConvReferenceImpl< + ConvOp, + NumSpatialDimensions, + decltype(mA), + decltype(mB), + decltype(mC), + decltype(mD_ref), + decltype(padding), + decltype(tstride), + decltype(dilation), + decltype(epilogue_fusion_params)> + reference_impl(mA, mB, mC, mD_ref, padding, tstride, dilation, epilogue_fusion_params); + + // + // Reference check - support caching results + // + + CachedTestKey cached_test_key = CreateCachedConvNd3xTestKey< + ProblemShape, + ElementA, + ElementB, + ElementC, + ElementD + >( + ConvOp, + problem_shape, + alpha, + beta, + tensor_A, + tensor_B, + tensor_C + ); + + // + // Look for the cached key + // + + bool cached_result_loaded = false; + CachedTestResult cached_test_result; + + std::string convnd_result_cache_name = + std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt"; + + #if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) + CachedTestResultListing cached_results(convnd_result_cache_name); + + auto cached = cached_results.find(cached_test_key); + + cached_result_loaded = cached.first; + if (cached_result_loaded) { + cached_test_result = cached.second; + } + #endif + + if (!cached_result_loaded) { + // Compute reference + reference_impl.compute_reference(); + + #if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) + cached_test_result.D = TensorHash(tensor_D_reference); + CachedTestResultListing cached_results(convnd_result_cache_name); + + cached_results.append(cached_test_key, cached_test_result); + cached_results.write(convnd_result_cache_name); + #endif + } // if (!cached_result_loaded) + + #if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) + uint32_t tensor_D_computed_hash = TensorHash(tensor_D_computed); + passed = (tensor_D_computed_hash == cached_test_result.D); + // If hash fails, double check against reference implementation. + if(!passed) { + std::cerr << "Hash-based comparison unsuccessful for key:" << "\n" << cached_test_key + << ", comparing with reference implementation now.\n"; + if (cached_result_loaded) { + // Compute reference + reference_impl.compute_reference(); + } + // Validate kernel against reference + passed = compare_reference(mD_ref, mD_computed, mA, mB, mAlpha, mBeta, mBias, this->epsilon); + } + #else + // Validate kernel against reference + passed = compare_reference(mD_ref, mD_computed, mA, mB, mAlpha, mBeta, mBias, this->epsilon); + #endif + + EXPECT_TRUE(passed); + return passed; + } + + template< + class Engine, class Layout, + class EngineA, class LayoutA, + class EngineB, class LayoutB, + class EngineAlpha, class LayoutAlpha, + class EngineBeta, class LayoutBeta, + class EngineBias, class LayoutBias> + static constexpr bool + compare_reference( + cute::Tensor const& reference, + cute::Tensor const& computed, + cute::Tensor const& A, + cute::Tensor const& B, + cute::Tensor const& tensor_alpha, + cute::Tensor const& tensor_beta, + cute::Tensor const& tensor_bias, + float epsilon = 0.0f) { + if (size(reference) != size(computed)) { + return false; + } + + bool passed = true; + if (epsilon == 0.0f) { + // fast refcheck w/o epsilon + for (size_t i = 0; i < size_t(size(reference)); ++i) { + if (reference(i) != computed(i)) { + passed = false; + printf("[%llu] %f, %f\n", static_cast(i), + float(reference(i)), float(computed(i))); + break; + } + } + } else { + // refcheck with epsilon + for (size_t i = 0; i < size_t(size(reference)); ++i) { + auto ref = static_cast(reference(i)); + auto act = static_cast(computed(i)); + auto abs_error = std::abs(act - ref); + auto rel_error = abs_error / (std::max(std::abs(act), std::abs(ref)) + 0.00001f); + if (std::isnan(abs_error) || std::isnan(rel_error) || + std::min(abs_error, rel_error) > epsilon) { + passed = false; + printf("[%llu] %f, %f\n", static_cast(i), + float(reference(i)), float(computed(i))); + break; + } + } + } + #if CUTLASS_DEBUG_TRACE_LEVEL > 1 + if (not passed) { + cute::print("Reference:"); + cute::print_tensor(reference); + cute::print("\nComputed:"); + cute::print_tensor(computed); + cute::print("\n"); + + for (size_t i = 0; i < size_t(size(A)); ++i) { + printf("[%llu]: A = %f\n", static_cast(i), float(A(i))); + } + for (size_t i = 0; i < size_t(size(B)); ++i) { + printf("[%llu]: B = %f\n", static_cast(i), float(B(i))); + } + if constexpr (IsPerChannelScaleEnabled) { + for (size_t i = 0; i < size_t(size(tensor_alpha)); ++i) { + printf("[%llu]: alpha = %f\n", static_cast(i), + float(tensor_alpha(i))); + } + for (size_t i = 0; i < size_t(size(tensor_beta)); ++i) { + printf("[%llu]: beta = %f\n", static_cast(i), + float(tensor_beta(i))); + } + } + if constexpr (IsBiasEnabled) { + for (size_t i = 0; i < size_t(size(tensor_bias)); ++i) { + printf("[%llu]: bias = %f\n", static_cast(i), + float(tensor_bias(i))); + } + } + for (size_t i = 0; i < size_t(size(reference)); ++i) { + printf("[%llu]: ref = %f, computed = %f\n", static_cast(i), + float(reference(i)), float(computed(i))); + } + } + #endif + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestAllConv(double alpha = 1.0, double beta = 0.0, float epsilon = 0.0f, + dim3 cluster_shape = dim3(0, 0, 0), + dim3 cluster_shape_fallback = dim3(0, 0, 0) + ) { + using ElementScalar = typename Conv::EpilogueOutputOp::ElementScalar; + + bool passed = true; + ConvTestbed testbed; + testbed.epsilon = epsilon; + auto problem_vector = get_conv_problem_vector< + Conv::NumSpatialDimensions, Conv::DispatchPolicy::ConvOp, SupportStrides>(); + + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using MaxSwizzleSize = typename gemm::device::detail::MaxSwizzleSize; + using Splits = typename gemm::device::detail::Splits; + + std::vector decomposition_modes = {DecompositionMode::Heuristic}; + static constexpr bool UsesStreamKScheduler = cute::is_same_v; + if constexpr (UsesStreamKScheduler) { + decomposition_modes.push_back(DecompositionMode::DataParallel); + decomposition_modes.push_back(DecompositionMode::SplitK); + decomposition_modes.push_back(DecompositionMode::StreamK); + } + + for (auto conv_problem : problem_vector) { + #if CUTLASS_DEBUG_TRACE_LEVEL > 0 + print(conv_problem); + #endif + for (DecompositionMode decomp_mode : decomposition_modes) { + std::vector problem_splits = {Splits{1}}; + if constexpr (UsesStreamKScheduler) { + if (decomp_mode == DecompositionMode::SplitK) { + problem_splits.push_back(Splits{2}); + problem_splits.push_back(Splits{4}); + } + } + for (auto splits : problem_splits) { + + passed = testbed.run( + conv_problem, + cutlass::from_real(alpha), + cutlass::from_real(beta), + cluster_shape, + cluster_shape_fallback, + RasterOrderOptions::Heuristic, // raster_order + MaxSwizzleSize(1), + splits, + decomp_mode + ); + if (!passed) { + printf("Failed test for "); print(conv_problem); + return false; + } + } // splits + } // decomposition_mode + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace test::conv::device + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/cute/ampere/tiled_cp_async_testbed.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/cute/ampere/tiled_cp_async_testbed.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ff170be142ff9d0d02cc684c2873c3ec014bd236 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/cute/ampere/tiled_cp_async_testbed.hpp @@ -0,0 +1,158 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +using namespace cute; + +template +struct SharedStorage +{ + cute::ArrayEngine> smem; +}; + +template +__global__ void +test_tiled_cp_async_device_cute(T const* g_in, T* g_out, + TiledCopy const tiled_copy, + GmemLayout gmem_layout, SmemLayout smem_layout) +{ + using namespace cute; + + extern __shared__ char shared_memory[]; + using SharedStorage = SharedStorage; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + + auto thr_copy = tiled_copy.get_slice(threadIdx.x); + Tensor gA = make_tensor(make_gmem_ptr(g_in), gmem_layout); + Tensor gB = make_tensor(make_gmem_ptr(g_out), gmem_layout); + + // Construct SMEM tensor + Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.begin()), smem_layout); + + auto tAgA = thr_copy.partition_S(gA); + auto tAsA = thr_copy.partition_D(sA); + +#if 0 + if (thread0()) { + print("gA : "); print(gA.layout()); print("\n"); + print("sA : "); print(sA.layout()); print("\n"); + print("tAgA: "); print(tAgA.layout()); print("\n"); + print("tAsA: "); print(tAsA.layout()); print("\n"); + } +#endif + + copy(tiled_copy, tAgA, tAsA); + + cp_async_fence(); + cp_async_wait<0>(); + __syncthreads(); + + // Store trivially smem -> gmem + + if (thread0()) { + copy(sA, gB); + } + +} + +template +void +test_tiled_cp_async( + TiledCopy const tiled_copy, + GMEM_Layout const& gmem_layout, + SMEM_Layout const& smem_layout) +{ + using namespace cute; + + // Allocate and initialize host test data + size_t N = ceil_div(cosize(gmem_layout) * sizeof_bits::value, 8); + thrust::host_vector h_in(N); + Tensor hA_in = make_tensor(recast_ptr(h_in.data()), gmem_layout); + for (int i = 0; i < size(hA_in); ++i) { hA_in(i) = static_cast(i % 13); } + + // Allocate and initialize device test data + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(h_in.size(), T(-1)); + + // Launch + int smem_size = int(sizeof(SharedStorage)); + test_tiled_cp_async_device_cute<<<1, 128, smem_size>>>( + reinterpret_cast(raw_pointer_cast(d_in.data())), + reinterpret_cast (raw_pointer_cast(d_out.data())), + tiled_copy, + gmem_layout, + smem_layout); + + // Copy results back to host + thrust::host_vector h_out = d_out; + Tensor hA_out = make_tensor(recast_ptr(h_out.data()), gmem_layout); + + // Validate the results. Print only the first 3 errors. + int count = 3; + for (int i = 0; i < size(hA_out) && count > 0; ++i) { + EXPECT_EQ(hA_in(i), hA_out(i)); + if (hA_in(i) != hA_out(i)) { + --count; + } + } +} + +template +void test_cp_async_no_swizzle() { + using namespace cute; + auto smem_atom = SMEM_LAYOUT{}; + auto smem_layout = tile_to_shape(smem_atom, Shape{}); + auto gmem_layout = make_layout(make_shape(M{}, N{}), GMEM_STRIDE_TYPE{}); + test_tiled_cp_async(TILED_COPY{}, gmem_layout, smem_layout); +} + +template +void test_cp_async_with_swizzle() { + using namespace cute; + auto swizzle_atom = SWIZZLE_ATOM{}; + auto smem_atom = composition(swizzle_atom, SMEM_LAYOUT{}); + auto smem_layout = tile_to_shape(smem_atom, Shape{}); + auto gmem_layout = make_layout(make_shape(M{}, N{}), GMEM_STRIDE_TYPE{}); + test_tiled_cp_async(TILED_COPY{}, gmem_layout, smem_layout); +} diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/cute/cooperative_gemm_common.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/cute/cooperative_gemm_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3ff20d4087ee2fd6f4f74338e3e63eef27c221d3 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/cute/cooperative_gemm_common.hpp @@ -0,0 +1,775 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/relatively_equal.h" +#include "cutlass_unit_test.h" +#include "cutlass/util/reference/host/tensor_compare.h" + +#include + +#include +#include + +#include + +using namespace cute; + +template +struct fp64_tester { + using value_type = double; +}; + +template +struct fp64_tester> { + using value_type = complex; +}; + +template // logical shape (M, N) +auto host_generate_gemm_inputs( + ALayout a_layout, + BLayout b_layout, + CLayout c_layout +) { + thrust::host_vector h_a(cosize(a_layout)); + thrust::host_vector h_b(cosize(b_layout)); + thrust::host_vector h_c(cosize(c_layout)); + thrust::host_vector h_c_out(cosize(c_layout)); + + auto h_a_tensor = make_tensor(h_a.data(), a_layout); + auto h_b_tensor = make_tensor(h_b.data(), b_layout); + auto h_c_tensor = make_tensor(h_c.data(), c_layout); + size_t max_size = std::max({static_cast(size(a_layout)), + static_cast(size(b_layout)), + static_cast(size(c_layout))}); + for (size_t i = 0; i < max_size; ++i) { + double di = static_cast(i); + if(i < size(a_layout)) { + h_a_tensor(i) = static_cast(di / size(a_layout)); + } + if(i < size(b_layout)) { + h_b_tensor(i) = static_cast(di / size(a_layout)); + } + if(i < size(c_layout)) { + h_c_tensor(i) = static_cast((di*di) / size(a_layout)); + } + } + + return std::make_tuple(h_a, h_b, h_c, h_c_out); +} + +template +thrust::host_vector +host_reference_gemm(Alpha alpha, + Tensor const& h_a_tensor, + Tensor const& h_b_tensor, + Beta beta, + Tensor const& h_c_tensor, + ALoadTransform const& a_load_transform = {}, + BLoadTransform const& b_load_transform = {}, + CLoadTransform const& c_load_transform = {}, + CStoreTransform const& c_store_transform = {}) + { + // Cannot use ::value_type because it propagates to complex::value_type, + // so ViewEngine>::value_type == double + using TA = remove_cv_t; + using TB = remove_cv_t; + using TC = remove_cv_t; + + using tester = fp64_tester; + using ABC_64 = typename tester::value_type; + + static_assert(std::is_same_v::value_type, typename fp64_tester::value_type>); + static_assert(std::is_same_v::value_type, typename fp64_tester::value_type>); + + thrust::host_vector h_c_ref(cosize(h_c_tensor.layout()), static_cast(0.0)); + auto h_c_ref_tensor = make_tensor(h_c_ref.data(), h_c_tensor.layout()); + // A * B + for (int k = 0; k < size<1>(h_a_tensor); k++) { + for (int m = 0; m < size<0>(h_a_tensor); m++) { + for (int n = 0; n < size<0>(h_b_tensor); n++) { + const auto a_value = a_load_transform(h_a_tensor(m, k)); + const auto b_value = b_load_transform(h_b_tensor(n, k)); + const auto a_value_fp64 = static_cast(a_value); + const auto b_value_fp64 = static_cast(b_value); + h_c_ref_tensor(m, n) += static_cast(a_value_fp64 * b_value_fp64); + } + } + } + // C = A*B + C + for (int i = 0; i < size(h_c_ref_tensor); i++) { + const auto ab_value_fp64 = static_cast(h_c_ref_tensor(i)); + const auto c_value_fp64 = static_cast(c_load_transform(h_c_tensor(i))); + h_c_ref_tensor(i) = c_store_transform(static_cast(alpha * ab_value_fp64 + beta * c_value_fp64)); + } + + return h_c_ref; +} + +template +void verify_gemm_correctness(cute::Tensor const& h_c_out_tensor, + cute::Tensor const& h_c_ref_tensor) +{ + // Cannot use ::value_type because it propagates to complex::value_type, + // so ViewEngine>::value_type == double + using TC = remove_cv_t; + + using tester = fp64_tester; + using ABC_64 = typename tester::value_type; + + for (int i = 0; i < size(h_c_ref_tensor); i++) { + ABC_64 h_c_ref_i = h_c_ref_tensor(i); + ABC_64 h_c_out_i = h_c_out_tensor(i); + double epsilon(0.1f); + double nonzero_floor(std::numeric_limits::min()); + bool passed = cutlass::relatively_equal(h_c_out_i, h_c_ref_i, epsilon, nonzero_floor); + ASSERT_TRUE(passed) << i << " - result:" << h_c_out_i << " expected:" << h_c_ref_i; + } +} + + +template +__launch_bounds__(ThreadBlockSize) __global__ void +cooperative_gemm_kernel(GMemALayout gmem_a_layout, + GMemBLayout gmem_b_layout, + GMemCLayout gmem_c_layout, + SMemALayout smem_a_layout, + SMemBLayout smem_b_layout, + SMemCLayout smem_c_layout, + TA const* a, + TB const* b, + TC const* c, + TC * c_out, + Alpha const alpha, + Beta const beta, + TiledMma tiled_mma, + ALoadTransform a_load_transform, + BLoadTransform b_load_transform, + CLoadTransform c_load_transform, + CStoreTransform c_store_transform, + SMemCopyOpA a_copy_op, + SMemCopyOpB b_copy_op, + SMemCopyLdOpC c_copy_ld_op, + SMemCopyStOpC c_copy_st_op) +{ + using namespace cute; + + Tensor g_a_tensor = make_tensor(make_gmem_ptr(a), gmem_a_layout); + Tensor g_b_tensor = make_tensor(make_gmem_ptr(b), gmem_b_layout); + Tensor g_c_tensor = make_tensor(make_gmem_ptr(c), gmem_c_layout); + Tensor g_c_out_tensor = make_tensor(make_gmem_ptr(c_out), gmem_c_layout); + + constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8; + + extern __shared__ float4 smem_buf[]; + auto* smem_ptr = reinterpret_cast(smem_buf); + auto* smem_ptr_a = smem_ptr; + auto* smem_ptr_b = smem_ptr_a + round_up((sizeof(TA) * cosize(smem_a_layout)), copy_max_vec_bytes); + auto* smem_ptr_c = smem_ptr_b + round_up((sizeof(TB) * cosize(smem_b_layout)), copy_max_vec_bytes); + + Tensor s_a_tensor = make_tensor(make_smem_ptr(smem_ptr_a), smem_a_layout); + Tensor s_b_tensor = make_tensor(make_smem_ptr(smem_ptr_b), smem_b_layout); + Tensor s_c_tensor = make_tensor(make_smem_ptr(smem_ptr_c), smem_c_layout); + + cooperative_copy(threadIdx.x, g_a_tensor, s_a_tensor); + cooperative_copy(threadIdx.x, g_b_tensor, s_b_tensor); + cooperative_copy(threadIdx.x, g_c_tensor, s_c_tensor); + + cp_async_fence(); + cp_async_wait<0>(); + __syncthreads(); + + cooperative_gemm( + threadIdx.x, tiled_mma, + alpha, s_a_tensor, s_b_tensor, beta, s_c_tensor, + a_load_transform, b_load_transform, c_load_transform, c_store_transform, + a_copy_op, b_copy_op, c_copy_ld_op, c_copy_st_op + ); + __syncthreads(); + + cooperative_copy(threadIdx.x, s_c_tensor, g_c_out_tensor); +} + +template +__launch_bounds__(ThreadBlockSize) __global__ void +cooperative_gemm_kernel_rmem_c(GMemALayout gmem_a_layout, + GMemBLayout gmem_b_layout, + GMemCLayout gmem_c_layout, + SMemALayout smem_a_layout, + SMemBLayout smem_b_layout, + TA const* a, + TB const* b, + TC const* c, + TC * c_out, + TiledMma tiled_mma, + ALoadTransform a_load_transform, + BLoadTransform b_load_transform, + CLoadTransform c_load_transform, + CStoreTransform c_store_transform, + SMemCopyOpA a_copy_op, + SMemCopyOpB b_copy_op) + { + using namespace cute; + + Tensor g_a_tensor = make_tensor(make_gmem_ptr(a), gmem_a_layout); + Tensor g_b_tensor = make_tensor(make_gmem_ptr(b), gmem_b_layout); + Tensor g_c_tensor = make_tensor(make_gmem_ptr(c), gmem_c_layout); + Tensor g_c_out_tensor = make_tensor(make_gmem_ptr(c_out), gmem_c_layout); + + constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8; + + extern __shared__ float4 smem_buf[]; + auto* smem_ptr = reinterpret_cast(smem_buf); + auto* smem_ptr_a = smem_ptr; + auto* smem_ptr_b = smem_ptr_a + round_up((sizeof(TA) * cosize(smem_a_layout)), copy_max_vec_bytes); + + Tensor s_a_tensor = make_tensor(make_smem_ptr(smem_ptr_a), smem_a_layout); + Tensor s_b_tensor = make_tensor(make_smem_ptr(smem_ptr_b), smem_b_layout); + + cooperative_copy(threadIdx.x, g_a_tensor, s_a_tensor); + cooperative_copy(threadIdx.x, g_b_tensor, s_b_tensor); + + cp_async_fence(); + cp_async_wait<0>(); + __syncthreads(); + + // Create C fragment for storing intermediate results + auto thr_mma = TiledMma().get_thread_slice(threadIdx.x); + Tensor g_c_partition = thr_mma.partition_C(g_c_tensor); + Tensor g_c_out_partition = thr_mma.partition_C(g_c_out_tensor); + Tensor r_c_partition = thr_mma.make_fragment_C(g_c_partition); + + // Create indexing help for predicated GEMMs + Tensor cC = make_identity_tensor(shape(gmem_c_layout)); + Tensor tCcC = thr_mma.partition_C(cC); + + // Load C from global + // (always loading in predicated way) + CUTE_UNROLL + for (int i = 0; i < size(r_c_partition); ++i) + { + if (elem_less(tCcC(i), shape(g_c_tensor))) + { + r_c_partition(i) = c_load_transform(g_c_partition(i)); + } + } + + cooperative_gemm( + threadIdx.x, tiled_mma, s_a_tensor, s_b_tensor, r_c_partition, + a_load_transform, b_load_transform, a_copy_op, b_copy_op + ); + + __syncthreads(); + + // Store C to global + // (always storing in predicated way) + CUTE_UNROLL + for (int i = 0; i < size(r_c_partition); ++i) + { + if (elem_less(tCcC(i), shape(g_c_tensor))) + { + g_c_out_partition(i) = c_store_transform(r_c_partition(i)); + } + } +} + +template, + class BSMemCopyOp = AutoVectorizingCopyWithAssumedAlignment, + class CSMemCopyLdOp = AutoVectorizingCopyWithAssumedAlignment, + class CSMemCopyStOp = AutoVectorizingCopyWithAssumedAlignment> +void test_cooperative_gemm(GMemALayout gmem_a_layout, + GMemBLayout gmem_b_layout, + GMemCLayout gmem_c_layout, + SMemALayout smem_a_layout, + SMemBLayout smem_b_layout, + SMemCLayout smem_c_layout, + TiledMma tiled_mma, + ALoadTransform a_load_transform = {}, + BLoadTransform b_load_transform = {}, + CLoadTransform c_load_transform = {}, + CStoreTransform c_store_transform = {}, + ASMemCopyOp a_smem_copy_op = {}, + BSMemCopyOp b_smem_copy_op = {}, + CSMemCopyLdOp c_smem_copy_ld_op = {}, + CSMemCopyStOp c_smem_copy_st_op = {}) +{ + static_assert(std::is_same_v::value_type, typename fp64_tester::value_type>); + static_assert(std::is_same_v::value_type, typename fp64_tester::value_type>); + + static_assert(size<0>(gmem_a_layout) == size<0>(gmem_c_layout)); // AM == CM + static_assert(size<0>(gmem_b_layout) == size<1>(gmem_c_layout)); // BN == CN + static_assert(size<1>(gmem_a_layout) == size<1>(gmem_b_layout)); // AK == BK + + static_assert(size<0>(smem_a_layout) == size<0>(smem_c_layout)); // AM == CM + static_assert(size<0>(smem_b_layout) == size<1>(smem_c_layout)); // BN == CN + static_assert(size<1>(smem_a_layout) == size<1>(smem_b_layout)); // AK == BK + + static_assert(cute::size(gmem_a_layout) == cute::size(smem_a_layout)); + static_assert(cute::size(gmem_b_layout) == cute::size(smem_b_layout)); + static_assert(cute::size(gmem_c_layout) == cute::size(smem_c_layout)); + +#if 0 + print(" "); print("gmem: "); print(gmem_layout); print("\n"); + print(" "); print("smem: "); print(smem_layout); print("\n"); + print(" "); print("threads: "); print(ThreadBlockSize); print("\n"); +#endif + + const auto alpha = static_cast(1.1); + const auto beta = static_cast(1.2); + + // Generate inputs + auto [h_a, h_b, h_c, h_c_out] = host_generate_gemm_inputs(gmem_a_layout, gmem_b_layout, gmem_c_layout); + + thrust::device_vector d_a(h_a); + thrust::device_vector d_b(h_b); + thrust::device_vector d_c(h_c); + thrust::device_vector d_c_out(h_c_out.size(), TC(float(-1))); + + constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8; + + const size_t shared_memory_size = round_up(sizeof(TA) * h_a.size(), copy_max_vec_bytes) + + round_up(sizeof(TB) * h_b.size(), copy_max_vec_bytes) + + sizeof(TC) * h_c.size(); + + + auto kernel = cooperative_gemm_kernel< + ThreadBlockSize, CopyMaxVecBits, + GMemALayout, GMemBLayout, GMemCLayout, + SMemALayout, SMemBLayout, SMemCLayout, + TA, TB, TC, decltype(alpha), decltype(beta), + TiledMma, + ALoadTransform, BLoadTransform, CLoadTransform, CStoreTransform, + ASMemCopyOp, BSMemCopyOp, CSMemCopyLdOp, CSMemCopyStOp + >; + + ASSERT_EQ(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast(shared_memory_size)), 0); + + kernel<<<1, ThreadBlockSize, shared_memory_size>>>( + gmem_a_layout, + gmem_b_layout, + gmem_c_layout, + smem_a_layout, + smem_b_layout, + smem_c_layout, + thrust::raw_pointer_cast(d_a.data()), + thrust::raw_pointer_cast(d_b.data()), + thrust::raw_pointer_cast(d_c.data()), + thrust::raw_pointer_cast(d_c_out.data()), + alpha, + beta, + tiled_mma, + a_load_transform, + b_load_transform, + c_load_transform, + c_store_transform, + a_smem_copy_op, + b_smem_copy_op, + c_smem_copy_ld_op, + c_smem_copy_st_op + ); + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + cudaError_t error = cudaGetLastError(); + FAIL() << "Error at kernel sync: " << cudaGetErrorString(error) << "\n"; + } + + // Reference gemm + auto h_c_ref = host_reference_gemm(alpha, + make_tensor(h_a.data(), gmem_a_layout), + make_tensor(h_b.data(), gmem_b_layout), + beta, + make_tensor(h_c.data(), gmem_c_layout), + a_load_transform, + b_load_transform, + c_load_transform, + c_store_transform); + + // Copy result data + h_c_out = d_c_out; + + // Verify correctness + verify_gemm_correctness(make_tensor(h_c_out.data(), gmem_c_layout), + make_tensor(h_c_ref.data(), gmem_c_layout)); +} + +template, + class BSMemCopyOp = AutoVectorizingCopyWithAssumedAlignment> +void test_cooperative_gemm_rmem_c(GMemALayout gmem_a_layout, + GMemBLayout gmem_b_layout, + GMemCLayout gmem_c_layout, + SMemALayout smem_a_layout, + SMemBLayout smem_b_layout, + TiledMma tiled_mma, + ALoadTransform a_load_transform = {}, + BLoadTransform b_load_transform = {}, + CLoadTransform c_load_transform = {}, + CStoreTransform c_store_transform = {}, + ASMemCopyOp a_smem_copy_op = {}, + BSMemCopyOp b_smem_copy_op = {}) +{ + static_assert(size<0>(gmem_a_layout) == size<0>(gmem_c_layout)); // AM == CM + static_assert(size<0>(gmem_b_layout) == size<1>(gmem_c_layout)); // BN == CN + static_assert(size<1>(gmem_a_layout) == size<1>(gmem_b_layout)); // AK == BK + + static_assert(size<1>(smem_a_layout) == size<1>(smem_b_layout)); // AK == BK + + static_assert(cute::size(gmem_a_layout) == cute::size(smem_a_layout)); + static_assert(cute::size(gmem_b_layout) == cute::size(smem_b_layout)); + +#if 0 + print(" "); print("gmem: "); print(gmem_layout); print("\n"); + print(" "); print("smem: "); print(smem_layout); print("\n"); + print(" "); print("threads: "); print(ThreadBlockSize); print("\n"); +#endif + + const auto alpha = static_cast(1.0); + const auto beta = static_cast(1.0); + + // Generate inputs + auto [h_a, h_b, h_c, h_c_out] = + host_generate_gemm_inputs(gmem_a_layout, gmem_b_layout, gmem_c_layout); + + thrust::device_vector d_a(h_a); + thrust::device_vector d_b(h_b); + thrust::device_vector d_c(h_c); + thrust::device_vector d_c_out(h_c_out.size(), static_cast(-1)); + + constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8; + + const size_t shared_memory_size = round_up(sizeof(TA) * h_a.size(), copy_max_vec_bytes) + + round_up(sizeof(TB) * h_b.size(), copy_max_vec_bytes); + + + auto kernel = cooperative_gemm_kernel_rmem_c< + ThreadBlockSize, CopyMaxVecBits, + GMemALayout, GMemBLayout, GMemCLayout, + SMemALayout, SMemBLayout, + TA, TB, TC, + TiledMma, + ALoadTransform, BLoadTransform, CLoadTransform, CStoreTransform, + ASMemCopyOp, BSMemCopyOp + >; + + ASSERT_EQ(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast(shared_memory_size)), 0); + + kernel<<<1, ThreadBlockSize, shared_memory_size>>>( + gmem_a_layout, + gmem_b_layout, + gmem_c_layout, + smem_a_layout, + smem_b_layout, + thrust::raw_pointer_cast(d_a.data()), + thrust::raw_pointer_cast(d_b.data()), + thrust::raw_pointer_cast(d_c.data()), + thrust::raw_pointer_cast(d_c_out.data()), + tiled_mma, + a_load_transform, b_load_transform, c_load_transform, c_store_transform, + a_smem_copy_op, b_smem_copy_op + ); + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + cudaError_t error = cudaGetLastError(); + FAIL() << "Error at kernel sync: " << cudaGetErrorString(error) << "\n"; + } + + // Copy result data + h_c_out = d_c_out; + + // Reference gemm + auto h_c_ref = host_reference_gemm(alpha, + make_tensor(h_a.data(), gmem_a_layout), + make_tensor(h_b.data(), gmem_b_layout), + beta, + make_tensor(h_c.data(), gmem_c_layout), + a_load_transform, + b_load_transform, + c_load_transform, + c_store_transform); + + // Verify correctness + verify_gemm_correctness(make_tensor(h_c_out.data(), gmem_c_layout), + make_tensor(h_c_ref.data(), gmem_c_layout)); +} + +template +void test_cooperative_gemm_col_major_layout(ShapeMNK shape_mnk, + TiledMma tiled_mma, + Ops ... ops) +{ + auto a_layout = make_layout(select<0, 2>(shape_mnk)); + auto b_layout = make_layout(select<1, 2>(shape_mnk), GenRowMajor{}); + auto c_layout = make_layout(select<0, 1>(shape_mnk)); + + test_cooperative_gemm + (a_layout, + b_layout, + c_layout, + a_layout, + b_layout, + c_layout, + tiled_mma, + ops...); +} + + +template +std::enable_if_t, + cute::is_layout, + cute::is_layout>> +test_cooperative_gemm_col_major_layout(SMemAtomLayoutA smem_atom_layout_a, + SMemAtomLayoutB smem_atom_layout_b, + SMemAtomLayoutC smem_atom_layout_c, + ShapeMNK shape_mnk, + TiledMma tiled_mma, + Ops&& ... ops) +{ + auto gmem_a_layout = make_layout(select<0, 2>(shape_mnk)); + auto gmem_b_layout = make_layout(select<1, 2>(shape_mnk), GenRowMajor{}); + auto gmem_c_layout = make_layout(select<0, 1>(shape_mnk)); + + auto smem_a_layout = tile_to_shape( + smem_atom_layout_a, + make_shape(shape<0>(gmem_a_layout), shape<1>(gmem_a_layout))); + + auto smem_b_layout = tile_to_shape( + smem_atom_layout_b, + make_shape(shape<0>(gmem_b_layout), shape<1>(gmem_b_layout))); + + auto smem_c_layout = tile_to_shape( + smem_atom_layout_c, + make_shape(shape<0>(gmem_c_layout), shape<1>(gmem_c_layout))); + + test_cooperative_gemm + (gmem_a_layout, + gmem_b_layout, + gmem_c_layout, + smem_a_layout, + smem_b_layout, + smem_c_layout, + tiled_mma, + ops...); +} + + +template +void test_cooperative_gemm_col_major_layout_rmem_c(ShapeMNK shape_mnk, + TiledMma tiled_mma, + Ops ... ops) +{ + auto a_layout = make_layout(select<0, 2>(shape_mnk)); + auto b_layout = make_layout(select<1, 2>(shape_mnk), GenRowMajor{}); + auto c_layout = make_layout(select<0, 1>(shape_mnk)); + + + test_cooperative_gemm_rmem_c + (a_layout, + b_layout, + c_layout, + a_layout, + b_layout, + tiled_mma, + ops...); +} + +template +std::enable_if_t, + cute::is_layout>> +test_cooperative_gemm_col_major_layout_rmem_c(SMemAtomLayoutA smem_atom_layout_a, + SMemAtomLayoutB smem_atom_layout_b, + ShapeMNK shape_mnk, + TiledMma tiled_mma, + Ops ... ops) +{ + auto gmem_a_layout = make_layout(select<0, 2>(shape_mnk)); + auto gmem_b_layout = make_layout(select<1, 2>(shape_mnk), GenRowMajor{}); + auto gmem_c_layout = make_layout(select<0, 1>(shape_mnk)); + + auto smem_a_layout = tile_to_shape( + smem_atom_layout_a, + make_shape(shape<0>(gmem_a_layout), shape<1>(gmem_a_layout))); + + auto smem_b_layout = tile_to_shape( + smem_atom_layout_b, + make_shape(shape<0>(gmem_b_layout), shape<1>(gmem_b_layout))); + + test_cooperative_gemm_rmem_c + (gmem_a_layout, + gmem_b_layout, + gmem_c_layout, + smem_a_layout, + smem_b_layout, + tiled_mma, + ops...); +} + +template +void test_cooperative_gemm_col_major_layout_rmem_c(Args&& ... args) +{ + test_cooperative_gemm_col_major_layout_rmem_c, + T, T, T> + (static_cast(args)...); +} + +template +void test_cooperative_gemm_col_major_layout(Args&& ... args) +{ + test_cooperative_gemm_col_major_layout, + T, T, T> + (static_cast(args)...); +} diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_load_testbed.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_load_testbed.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4d2620e62ff247e36ae49809ab4ef3416560ae31 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_load_testbed.hpp @@ -0,0 +1,217 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass_unit_test.h" + +#include +#include + +#include +#include + +#include + +namespace cutlass::test { + +template +struct SharedStorage +{ + cute::ArrayEngine> smem; + alignas(16) cute::uint64_t tma_load_mbar[1]; +}; + +#if CUDA_12_0_SM90_FEATURES_SUPPORTED + +template +__global__ void +tma_test_device_cute(T const* g_in, T* g_out, + CUTE_GRID_CONSTANT TiledCopy const tma, CTA_Tiler cta_tiler, + GmemLayout gmem_layout, SmemLayout smem_layout) +{ + using namespace cute; + CUTE_STATIC_ASSERT_V(product_each(shape(cta_tiler)) == product_each(shape(smem_layout))); + + // Use Shared Storage structure to allocate and distribute aligned SMEM addresses + extern __shared__ char shared_memory[]; + using SharedStorage = SharedStorage; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + + // Construct SMEM tensor + Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.begin()), smem_layout); // (CTA_TILE_M,CTA_TILE_N,...) + // Shared memory barriers use 64bits in SMEM for synchronization + uint64_t* tma_load_mbar = shared_storage.tma_load_mbar; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA = tma.get_tma_tensor(shape(gmem_layout)); + Tensor mB = make_tensor(make_gmem_ptr(g_out), gmem_layout); + + constexpr int R = rank_v; + Tensor gA = flat_divide(mA, cta_tiler); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...) + Tensor gB = flat_divide(mB, cta_tiler); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...) + + // + // Prepare the TMA_LOAD + // + + auto cta_tma = tma.get_slice(Int<0>{}); // CTA slice + Tensor tAgA_x = cta_tma.partition_S(gA); // (TMA,TMA_M,TMA_N,REST_M,REST_N) + Tensor tAsA_x = cta_tma.partition_D(sA); // (TMA,TMA_M,TMA_N) + +#if 0 + if (thread0()) { + print(tma); + print("TILE : "); print(cta_tiler); print("\n"); + print(" mA : "); print( mA); print("\n"); + print(" mB : "); print( mB); print("\n"); + print(" gA : "); print( gA); print("\n"); + print(" gB : "); print( gB); print("\n"); + print(" sA : "); print( sA); print("\n"); + print("tAgA_x: "); print(tAgA_x); print("\n"); + print("tAsA_x: "); print(tAsA_x); print("\n"); + } +#endif + + // + // Perform the TMA_LOAD + // + + // INPUT: Group the REST_X modes and the TMA_X modes to easily iterate through the tiles + Tensor tAgA = group_modes<1,rank(tAgA_x)>(tAgA_x); // (TMA,REST) + Tensor tAsA = group_modes<1,rank(tAsA_x)>(tAsA_x); // (TMA,REST) + static_assert(size<1>(tAsA) == 1); + + // OUTPUT: Group the CTA_TILE_X modes and REST_X modes for output + Tensor tBgB = group_modes<0,R>(group_modes(gB)); // (CTA_TILE, REST) + +#if 0 + if (thread0()) { + print("tAgA : "); print(tAgA); print("\n"); + print("tAsA : "); print(tAsA); print("\n"); + print("tBgB : "); print(tBgB); print("\n"); + } +#endif + + // Test L2 prefetch + if (threadIdx.x == 0) { + prefetch(tma, tAgA); + } + + // Loop over the TMA stages, using smem as our buffer + for (int stage = 0; stage < size<1>(tAgA); ++stage) + { + // Set the bytes transferred in this TMA transaction (may involve multiple issues) + constexpr int kTmaTransactionBytes = sizeof(make_tensor_like(tensor<0>(tAsA))); + + if (threadIdx.x == 0) + { + /// Initialize shared memory barrier + tma_load_mbar[0] = 0; + cute::initialize_barrier(tma_load_mbar[0], 1 /*numThreads*/); + cute::set_barrier_transaction_bytes(tma_load_mbar[0], kTmaTransactionBytes); + + copy(tma.with(tma_load_mbar[0]), tAgA(_,stage), tAsA(_,0)); + } + __syncthreads(); + + /// Wait on the shared memory barrier until the phase bit flips from kPhaseBit value + constexpr int kPhaseBit = 0; + cute::wait_barrier(tma_load_mbar[0], kPhaseBit); + + // + // Write out trivially smem -> gmem + // + + // Subbyte elements could cause race conditions, so be even more conservative + if (thread0()) { + copy(sA, tBgB(_,stage)); + } + + __syncthreads(); + } +} + +template +auto +test_tma_load(CopyOp const& copy_op, + GMEM_Layout const& gmem_layout, + SMEM_Layout const& smem_layout, + CTA_Tile const& cta_tile) +{ + using namespace cute; + + // Allocate and initialize host test data + size_t N = ceil_div(cosize(gmem_layout) * sizeof_bits::value, 8); + thrust::host_vector h_in(N); + for (size_t i = 0; i < h_in.size(); ++i) { + h_in[i] = uint8_t(i % 13); + } + Tensor hA_in = make_tensor(recast_ptr(h_in.data()), gmem_layout); + + // Allocate and initialize device test data + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(h_in.size(), uint8_t(-1)); // overflow uint + + // Create TMA for this device Tensor + Tensor gA = make_tensor(make_gmem_ptr(raw_pointer_cast(d_in.data())), gmem_layout); + auto tma = make_tma_copy(copy_op, gA, smem_layout, cta_tile, Int<1>{}); + //print(tma); + + // Launch + int smem_size = int(sizeof(SharedStorage)); + tma_test_device_cute<<<1, 128, smem_size>>>( + reinterpret_cast(raw_pointer_cast(d_in.data())), + reinterpret_cast (raw_pointer_cast(d_out.data())), + tma, cta_tile, + gmem_layout, + smem_layout); + + // Copy results back to host + thrust::host_vector h_out = d_out; + Tensor hA_out = make_tensor(recast_ptr(h_out.data()), gmem_layout); + + // Validate the results. Print only the first 3 errors. + int count = 3; + for (int i = 0; i < int(size(hA_out)) && count > 0; ++i) { + EXPECT_EQ(hA_in(i), hA_out(i)); + if (hA_in(i) != hA_out(i)) { + --count; + } + } + + return tma; +} + +#endif + +} // end namespace cutlass::test diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_mcast_load_testbed.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_mcast_load_testbed.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3e0ec46df1b672c35c3c38f731c09b0134d4cd80 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_mcast_load_testbed.hpp @@ -0,0 +1,242 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass_unit_test.h" + +#include +#include + +#include +#include + +#include +#include +#include + +namespace cutlass::test { + +template +struct SharedStorage +{ + cute::ArrayEngine> smem; + alignas(16) cute::uint64_t tma_load_mbar[1]; +}; + +#if CUDA_12_0_SM90_FEATURES_SUPPORTED + +template +__global__ void +tma_test_device_cute(T const* g_in, T* g_out, GmemLayout gmem_layout, SmemLayout smem_layout, + CUTE_GRID_CONSTANT CopyAtom const tma, CTA_Tiler cta_tiler, Cluster_Size cluster_size) +{ + using namespace cute; + CUTE_STATIC_ASSERT_V(product_each(shape(cta_tiler)) == product_each(shape(smem_layout))); + + // Use Shared Storage structure to allocate and distribute aligned SMEM addresses + extern __shared__ char shared_memory[]; + using SharedStorage = SharedStorage; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + + // Construct SMEM tensor + Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.begin()), smem_layout); // (CTA_TILE_M,CTA_TILE_N,...) + // Shared memory barriers use 64bits in SMEM for synchronization + uint64_t* tma_load_mbar = shared_storage.tma_load_mbar; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA = tma.get_tma_tensor(shape(gmem_layout)); + Tensor mB = make_tensor(make_gmem_ptr(g_out), gmem_layout); + + Tensor gA = zipped_divide(mA, cta_tiler); // ((CTA_TILE_M,CTA_TILE_N,...),(REST_M,REST_N,...)) + Tensor gB = zipped_divide(mB, cta_tiler); // ((CTA_TILE_M,CTA_TILE_N,...),(REST_M,REST_N,...)) + +#if 1 + if (thread0()) { + print(tma); + print("TILE : "); print(cta_tiler); print("\n"); + print(" mA : "); print( mA); print("\n"); + print(" mB : "); print( mB); print("\n"); + print(" gA : "); print( gA); print("\n"); + print(" gB : "); print( gB); print("\n"); + print(" sA : "); print( sA); print("\n"); + } __syncthreads(); cute::cluster_sync(); +#endif + + // + // Prepare the TMA_LOAD + // + + Tensor sA_x = make_tensor(sA.data(), make_layout(sA.layout(), Layout<_1>{})); // ((CTA_TILE_M,CTA_TILE_N,...),_1) + Tensor tBgB = gB; // ((CTA_TILE_M,CTA_TILE_N,...),(REST_M,REST_N,...)) + + int cta_rank_in_cluster = cute::block_rank_in_cluster(); + auto [tAgA, tAsA] = tma_partition(tma, cta_rank_in_cluster, make_layout(cluster_size), sA_x, gA); + +#if 1 + if (thread0()) { + print("sA_x : "); print(sA_x); print("\n"); + print("tBgB : "); print(tBgB); print("\n"); + print("tAgA : "); print(tAgA); print("\n"); + print("tAsA : "); print(tAsA); print("\n"); + } __syncthreads(); cute::cluster_sync(); +#endif + + // + // TMA Multicast Masks -- Get a mask of the active ctas in each TMA + // + + + int elected_cta_rank = 0; + bool elect_one_cta = (elected_cta_rank == cta_rank_in_cluster); + bool elect_one_thr = cute::elect_one_sync(); + + uint16_t tma_mcast_mask = ((uint16_t(1) << cluster_size) - 1); + +#if 1 + if (thread0()) { + print("tma_mcast_mask : "); print(tma_mcast_mask); print("\n"); + } __syncthreads(); cute::cluster_sync(); +#endif + + // + // Perform the TMA_LOAD + // + + if (elect_one_thr) { + // Initialize TMA barrier + cute::initialize_barrier(tma_load_mbar[0], /* num_threads */ 1); + } + int tma_phase_bit = 0; + // Ensures all CTAs in the Cluster have initialized + __syncthreads(); + cute::cluster_sync(); + + // Loop over the TMA stages, using smem as our buffer + for (int stage = 0; stage < size<1>(tAgA); ++stage) + { + // Set the bytes transferred in this TMA transaction (may involve multiple issues) + constexpr int kTmaTransactionBytes = sizeof(ArrayEngine); + + if (elect_one_thr) + { + cute::set_barrier_transaction_bytes(tma_load_mbar[0], kTmaTransactionBytes); + + copy(tma.with(tma_load_mbar[0], tma_mcast_mask), tAgA(_,stage), tAsA(_,0)); + } + __syncthreads(); + + /// Wait on the shared memory barrier until the phase bit flips from tma_phase_bit value + cute::wait_barrier(tma_load_mbar[0], tma_phase_bit); + tma_phase_bit ^= 1; + + // + // Write out trivially smem -> gmem + // + + // Subbyte elements could cause race conditions, so be even more conservative + if (elect_one_cta && elect_one_thr) { + copy(sA, tBgB(_,stage)); + } + + __syncthreads(); + cute::cluster_sync(); + } +} + +template +auto +test_tma_load(CopyOp const& copy_op, + GMEM_Layout const& gmem_layout, + SMEM_Layout const& smem_layout, + CTA_Tiler const& cta_tiler, + Cluster_Size const& cluster_size) +{ + using namespace cute; + + // Allocate and initialize host test data + size_t N = ceil_div(cosize(gmem_layout) * sizeof_bits::value, 8); + thrust::host_vector h_in(N); + for (size_t i = 0; i < h_in.size(); ++i) { + h_in[i] = uint8_t(i % 13); + } + Tensor hA_in = make_tensor(recast_ptr(h_in.data()), gmem_layout); + + // Allocate and initialize device test data + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(h_in.size(), uint8_t(-1)); // overflow uint + + // Create TMA for this device Tensor + Tensor gA = make_tensor(make_gmem_ptr(raw_pointer_cast(d_in.data())), gmem_layout); + auto tma = make_tma_atom(copy_op, gA, smem_layout, cta_tiler, cluster_size); + //print(tma); + + // Launch + + dim3 dimBlock(32); + dim3 dimCluster(size(cluster_size)); + dim3 dimGrid = dimCluster; + int smem_size = sizeof(SharedStorage); + + void* kernel_ptr = (void*) &tma_test_device_cute; + + cutlass::launch_kernel_on_cluster({dimGrid, dimBlock, dimCluster, smem_size}, + kernel_ptr, + reinterpret_cast(raw_pointer_cast(d_in.data())), + reinterpret_cast(raw_pointer_cast(d_out.data())), + gmem_layout, + smem_layout, + tma, cta_tiler, cluster_size); + + // Copy results back to host + thrust::host_vector h_out = d_out; + Tensor hA_out = make_tensor(recast_ptr(h_out.data()), gmem_layout); + + // Validate the results. Print only the first 3 errors. + int count = 3; + for (int i = 0; i < int(size(hA_out)) && count > 0; ++i) { + EXPECT_EQ(hA_in(i), hA_out(i)); + if (hA_in(i) != hA_out(i)) { + --count; + } + } + + return tma; +} + +#endif + +} // end namespace cutlass::test diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_store_testbed.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_store_testbed.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0429d2435fbf43c690f311c1f7c04f7025a2dd94 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_store_testbed.hpp @@ -0,0 +1,201 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass_unit_test.h" + +#include +#include + +#include +#include + +#include + +namespace cutlass::test { + +template +struct SharedStorage +{ + cute::ArrayEngine> smem; +}; + +#if CUDA_12_0_SM90_FEATURES_SUPPORTED + +template +__global__ void +tma_test_device_cute(T const* g_in, T* g_out, + CUTE_GRID_CONSTANT TiledCopy const tma, CTA_Tiler cta_tiler, + GmemLayout gmem_layout, SmemLayout smem_layout) +{ + using namespace cute; + CUTE_STATIC_ASSERT_V(product_each(shape(cta_tiler)) == product_each(shape(smem_layout))); + + // Use Shared Storage structure to allocate and distribute aligned SMEM addresses + extern __shared__ char shared_memory[]; + using SharedStorage = SharedStorage; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + + // Construct SMEM tensor + Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem.begin()), smem_layout); // (CTA_TILE_M,CTA_TILE_N,...) + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA = make_tensor(make_gmem_ptr(g_in), gmem_layout); + Tensor mB = tma.get_tma_tensor(shape(gmem_layout)); + + constexpr int R = rank_v; + Tensor gA = flat_divide(mA, cta_tiler); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...) + Tensor gB = flat_divide(mB, cta_tiler); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...) + + // + // Prepare the TMA_STORE + // + + auto cta_tma = tma.get_slice(Int<0>{}); // CTA slice + Tensor tBsB_x = cta_tma.partition_S(sB); // (TMA,TMA_M,TMA_N) + Tensor tBgB_x = cta_tma.partition_D(gB); // (TMA,TMA_M,TMA_N,REST_M,REST_N) + +#if 0 + if (thread0()) { + print(tma); + print("TILE : "); print(cta_tiler); print("\n"); + print(" mB : "); print( mB.data()); print(" o "); print( mB.layout()); print("\n"); + print(" gB : "); print( gB.data()); print(" o "); print( gB.layout()); print("\n"); + print("tBgB_x: "); print(tBgB_x.data()); print(" o "); print(tBgB_x.layout()); print("\n"); + print(" sB : "); print( sB.data()); print(" o "); print( sB.layout()); print("\n"); + print("tBsB_x: "); print(tBsB_x.data()); print(" o "); print(tBsB_x.layout()); print("\n"); + } +#endif + + // + // Perform the TMA_STORE + // + + // INPUT: Group the CTA_TILE_X modes and REST_X modes for input + Tensor tAgA = group_modes<0,R>(group_modes(gA)); // (CTA_TILE, REST) + + // OUTPUT: Group the REST_X modes and the TMA_X modes to easily iterate through the tiles + Tensor tBgB = group_modes<1,rank(tBgB_x)>(tBgB_x); // (TMA,REST) + Tensor tBsB = group_modes<1,rank(tBsB_x)>(tBsB_x); // (TMA,REST) + static_assert(size<1>(tBsB) == 1); + +#if 0 + if (thread0()) { + print("tAgA : "); print(tAgA.data()); print(" o "); print(tAgA.layout()); print("\n"); + print("tBsB : "); print(tBsB.data()); print(" o "); print(tBsB.layout()); print("\n"); + print("tBgB : "); print(tBgB.data()); print(" o "); print(tBgB.layout()); print("\n"); + } +#endif + + // Test L2 prefetch + cooperative_prefetch<128>(threadIdx.x, gA); + + // Loop over the TMA stages, using smem as our buffer + for (int stage = 0; stage < size<1>(tBgB); ++stage) + { + // + // Read in trivially gmem -> smem + // + // Subbyte elements could cause race conditions, so be even more conservative + if (thread0()) { + copy(tAgA(_,stage), sB); + } + + __syncthreads(); + cute::cp_async_wait<0>(); + + // + // Perform the TMA_STORE + // + + if (threadIdx.x == 0) { + copy(tma, tBsB(_,0), tBgB(_,stage)); + } + + tma_store_wait<0>(); + __syncthreads(); + } +} + +template +void +test_tma_store(CopyOp const& copy_op, + GMEM_Layout const& gmem_layout, + SMEM_Layout const& smem_layout, + CTA_Tile const& cta_tile) +{ + using namespace cute; + + // Allocate and initialize host test data + size_t N = ceil_div(cosize(gmem_layout) * sizeof_bits::value, 8); + thrust::host_vector h_in(N); + for (size_t i = 0; i < h_in.size(); ++i) { + h_in[i] = uint8_t(i % 13); + } + Tensor hA_in = make_tensor(recast_ptr(h_in.data()), gmem_layout); + + // Allocate and initialize device test data + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(h_in.size(), uint8_t(-1)); // overflow uint + + // Create TMA for this device Tensor + Tensor gA = make_tensor(make_gmem_ptr(raw_pointer_cast(d_out.data())), gmem_layout); + auto tma = make_tma_copy(copy_op, gA, smem_layout, cta_tile, Int<1>{}); + //print(tma); + + // Launch + int smem_size = int(sizeof(SharedStorage)); + tma_test_device_cute<<<1, 128, smem_size>>>( + reinterpret_cast(raw_pointer_cast(d_in.data())), + reinterpret_cast (raw_pointer_cast(d_out.data())), + tma, cta_tile, + gmem_layout, + smem_layout); + + // Copy results back to host + thrust::host_vector h_out = d_out; + Tensor hA_out = make_tensor(recast_ptr(h_out.data()), gmem_layout); + + // Validate the results. Print only the first 3 errors. + int count = 3; + for (int i = 0; i < int(size(hA_out)) && count > 0; ++i) { + EXPECT_EQ(hA_in(i), hA_out(i)); + if (hA_in(i) != hA_out(i)) { + --count; + } + } +} + +#endif + +} // end namespace cutlass::test diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/epilogue_with_reduction_testbed.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/epilogue_with_reduction_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..3163a0d0eaa24513ee210bd2b310d1bf233773a9 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/epilogue_with_reduction_testbed.h @@ -0,0 +1,417 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + + \brief Unit tests for epilogues +*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/half.h" +#include "cutlass/complex.h" + +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace kernel { + +template +__global__ void epilogue_with_reduction_threadblock( + typename Epilogue::ElementVector *ptr_Reduction, + typename Epilogue::OutputTileIterator::Params params_D, + typename Epilogue::OutputTileIterator::Element *ptr_D, + typename Epilogue::OutputTileIterator::Params params_C, + typename Epilogue::OutputTileIterator::Element *ptr_C, + typename Epilogue::TensorTileIterator::Params params_Tensor, + typename Epilogue::TensorTileIterator::Element *ptr_Tensor, + typename Epilogue::OutputOp::Params params_output_op, + cutlass::MatrixCoord problem_size, + cutlass::TensorRef< + typename Epilogue::WarpMmaOperator::ElementC, + typename Epilogue::WarpMmaOperator::LayoutC> accumulator_ref, + int epilogue_count = 1) { + + __shared__ typename Epilogue::SharedStorage shared_storage; + + int thread_idx = threadIdx.x; + int warp_idx = threadIdx.x / 32; + int lane_idx = threadIdx.x % 32; + + // + // Construct the epilogue + // + + // Tile iterator writing to output tile + typename Epilogue::OutputTileIterator iterator_D( + params_D, + ptr_D, + problem_size, + thread_idx + ); + + // Tile iterator writing to output tile + typename Epilogue::OutputTileIterator iterator_C( + params_C, + ptr_C, + problem_size, + thread_idx + ); + + // Tile iterator writing to output tile + typename Epilogue::TensorTileIterator iterator_T( + params_Tensor, + ptr_Tensor, + problem_size, + thread_idx + ); + + // Epilogue operator + Epilogue epilogue( + shared_storage, + thread_idx, + warp_idx, + lane_idx); + + // + // Initialize the accumulators + // + + int warp_mn = warp_idx % (Epilogue::WarpCount::kM * Epilogue::WarpCount::kN); + int warp_m = warp_mn % Epilogue::WarpCount::kM; + int warp_n = warp_mn / Epilogue::WarpCount::kM; + + accumulator_ref.add_coord_offset({ + warp_m * Epilogue::WarpMmaOperator::Shape::kM, + warp_n * Epilogue::WarpMmaOperator::Shape::kN}); + + typename Epilogue::WarpMmaOperator::IteratorC accumulator_iterator(accumulator_ref, lane_idx); + + typename Epilogue::AccumulatorTile accumulators; + + accumulators.clear(); + accumulator_iterator.load(accumulators); + +#if 0 + // For debugging, enable this block of code to fill each accumulator element with its + // source thread ID. + CUTLASS_PRAGMA_UNROLL + for (size_t i = 0; i < accumulators.size(); ++i) { + typename Epilogue::WarpMmaOperator::ElementC x(threadIdx.x); + accumulators[i] = x; + } + + __syncthreads(); + +#endif + + // + // Perform the epilogue operation + // + + typename Epilogue::OutputOp output_op(params_output_op); + + // Place the epilogue in a loop + for (int iter = 0; iter < epilogue_count; ++iter) { + epilogue(output_op, ptr_Reduction, iterator_D, accumulators, iterator_C, iterator_T); + } +} + +} // namespace kernel +} // namespace test + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Epilogue_ +> +class EpilogueWithReductionTestbed { +public: + + using Epilogue = Epilogue_; + using ElementAccumulator = typename Epilogue::ElementAccumulator; + using ElementCompute = typename Epilogue::OutputOp::ElementCompute; + using ElementTensor = typename Epilogue::TensorTileIterator::Element; + using ElementOutput = typename Epilogue::ElementOutput; + using OutputOpParams = typename Epilogue::OutputOp::Params; + +public: + + // + // Data members + // + + cutlass::MatrixCoord quantized_size; + cutlass::HostTensor accumulator_tensor; + cutlass::HostTensor source_tensor; + cutlass::HostTensor output_tensor; + cutlass::HostTensor additional_tensor; + cutlass::HostTensor reduction_tensor; + + +public: + + // + // Methods + // + + EpilogueWithReductionTestbed(): + quantized_size(Epilogue::Shape::kM, Epilogue::Shape::kN), + accumulator_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + source_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + output_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + additional_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + reduction_tensor({1, Epilogue::Shape::kN}) { + + // + // Initialize problem space + // + + uint64_t seed = 2019; + + cutlass::reference::host::TensorFillRandomUniform( + accumulator_tensor.host_view(), + seed, + 20, + -20, + 0); + + cutlass::reference::host::TensorFillRandomUniform( + source_tensor.host_view(), + seed + 2018, + 20, + -20, + 0); + + cutlass::reference::host::TensorFill(additional_tensor.host_view(), ElementTensor(1)); + } + + bool run_all() { + + /* + double alpha_values[] = {1, 0, 2.25}; + double beta_values[] = {0, 1, -1.25}; + + // Test runtime explodes if we tried to test every case exhaustively. This tests the full + // output tile and several smaller sizes to stress predication. + for (int m_idx = 0; m_idx < 3; ++m_idx) { + for (int n_idx = 0; n_idx < 3; ++n_idx) { + + int m = quantized_size.row() - m_idx * 3; + int n = quantized_size.column() - n_idx * Epilogue::kElementsPerAccess; + + for (double const &alpha : alpha_values) { + for (double const &beta : beta_values) { + + bool passed = run({m, n}, {cutlass::from_real(alpha), cutlass::from_real(beta)}); + + if (!passed) { + return false; + } + } + } + } + } + return true; + */ + + double alpha = 1; + double beta = 0; + + return run( + {quantized_size.row(), quantized_size.column()}, + {cutlass::from_real(alpha), cutlass::from_real(beta)}); + } + + /// Runs the test + bool run( + cutlass::MatrixCoord problem_size, + OutputOpParams output_params) { + + // + // Initialize problem space + // + + ElementOutput default_output = ElementOutput(-127); + ElementAccumulator default_reduction = ElementAccumulator(); + + cutlass::reference::host::TensorFill(output_tensor.host_view(), default_output); + cutlass::reference::host::TensorFill(reduction_tensor.host_view(), default_reduction); + + accumulator_tensor.sync_device(); + output_tensor.sync_device(); + source_tensor.sync_device(); + additional_tensor.sync_device(); + reduction_tensor.sync_device(); + + // + // Initialize epilogue parameters + // + + typename Epilogue::OutputTileIterator::Params params_D(output_tensor.device_ref().layout()); + typename Epilogue::OutputTileIterator::Params params_C(source_tensor.device_ref().layout()); + typename Epilogue::TensorTileIterator::Params params_T(additional_tensor.device_ref().layout()); + + // + // Launch kernel + // + + dim3 grid(1, 1); + dim3 block(Epilogue::WarpCount::kCount * 32, 1); + + test::kernel::epilogue_with_reduction_threadblock<<< grid, block >>>( + reduction_tensor.device_data(), + params_D, + output_tensor.device_data(), + params_C, + source_tensor.device_data(), + params_T, + additional_tensor.device_data(), + output_params, + problem_size, + accumulator_tensor.device_view()); + + cudaError_t result = cudaDeviceSynchronize(); + + if (result != cudaSuccess) { + std::cerr << "Kernel error: " << cudaGetErrorString(result) << std::endl; + return false; + } + + // + // Verify results + // + output_tensor.sync_host(); + reduction_tensor.sync_host(); + + int errors = 0; + int const kMaxErrors = 5; + + // + // The output has two parts: + // - GEMM tensor epilogue in canonical layout + // - partial reduction in canonical row-major layout + // + + // Verify the GEMM tensor output + for (int r = 0; errors < kMaxErrors && r < quantized_size.row(); ++r) { + for (int c = 0; errors < kMaxErrors && c < quantized_size.column(); ++c) { + + cutlass::MatrixCoord coord{r, c}; + ElementOutput got = output_tensor.at(coord); + + ElementOutput expected; + if (coord.row() < problem_size.row() && coord.column() < problem_size.column()) { + + expected = ElementOutput(output_params.alpha * ElementCompute(accumulator_tensor.at(coord)) + + output_params.beta * ElementCompute(source_tensor.at(coord))); + } + else { + expected = default_output; + } + + if (expected != got) { + + using OutputIO = cutlass::ScalarIO; + + EXPECT_TRUE(false) + << "-------\n" + << "Error - output element (" << coord << ") - expected: " + << OutputIO(expected) + << ", got: " << OutputIO(got) << std::endl; + + ++errors; + } + } + } + + // Verify the partial reduction + for (int c = 0; c < quantized_size.column(); ++c) { + + ElementAccumulator reduction_acc = ElementAccumulator(); + + for (int r = 0; r < quantized_size.row(); ++r) { + reduction_acc += accumulator_tensor.at({r, c}); + } + + ElementAccumulator expected = default_reduction; + ElementAccumulator got = reduction_tensor.at({0, c}); + + if (c < problem_size.column()) { + expected = reduction_acc; + } + else { + expected = default_reduction; + } + + if (expected != got) { + + using OutputIO = cutlass::ScalarIO; + + EXPECT_TRUE(false) + << "-------\n" + << "Error - reduction element (" << c << ") - expected: " + << OutputIO(expected) + << ", got: " << OutputIO(got) << std::endl; + } + } + + // + // Report results on error + // + + if (errors) { + std::stringstream ss; + ss + << "output_tensor_op_" << Epilogue::Shape::kM << "x" << Epilogue::Shape::kN << "_" + << Epilogue::WarpTileIterator::WarpShape::kM << "x" + << Epilogue::WarpTileIterator::WarpShape::kN + << "_slice_" << Epilogue::WarpCount::kK << ".csv"; + + std::ofstream output_file(ss.str()); + output_file << output_tensor.host_view(); + } + + return !errors; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/testbed.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..e2457fdb4817e1dfb3af73149ae1e4c4458670a2 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/testbed.h @@ -0,0 +1,356 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit tests for epilogues +*/ +#pragma once + +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/half.h" +#include "cutlass/complex.h" +#include "cutlass/quaternion.h" +#include "cutlass/platform/platform.h" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace kernel { + +template +__global__ void epilogue_threadblock( + typename Epilogue::OutputTileIterator::Params params_D, + typename Epilogue::OutputTileIterator::Element *ptr_D, + typename Epilogue::OutputTileIterator::Params params_C, + typename Epilogue::OutputTileIterator::Element *ptr_C, + typename Epilogue::OutputOp::Params params_output_op, + cutlass::MatrixCoord problem_size, + cutlass::TensorRef< + typename Epilogue::WarpMmaOperator::ElementC, + typename Epilogue::WarpMmaOperator::LayoutC> accumulator_ref, + int epilogue_count = 1) { + + __shared__ typename Epilogue::SharedStorage shared_storage; + + int thread_idx = threadIdx.x; + int warp_idx = threadIdx.x / 32; + int lane_idx = threadIdx.x % 32; + + // + // Construct the epilogue + // + + // Tile iterator writing to output tile + typename Epilogue::OutputTileIterator iterator_D( + params_D, + ptr_D, + problem_size, + thread_idx + ); + + // Tile iterator writing to output tile + typename Epilogue::OutputTileIterator iterator_C( + params_C, + ptr_C, + problem_size, + thread_idx + ); + + // Epilogue operator + Epilogue epilogue( + shared_storage, + thread_idx, + warp_idx, + lane_idx); + + // + // Initialize the accumulators + // + + int warp_mn = warp_idx % (Epilogue::WarpCount::kM * Epilogue::WarpCount::kN); + int warp_m = warp_mn % Epilogue::WarpCount::kM; + int warp_n = warp_mn / Epilogue::WarpCount::kM; + + accumulator_ref.add_coord_offset({ + warp_m * Epilogue::WarpMmaOperator::Shape::kM, + warp_n * Epilogue::WarpMmaOperator::Shape::kN}); + + typename Epilogue::WarpMmaOperator::IteratorC accumulator_iterator(accumulator_ref, lane_idx); + + typename Epilogue::AccumulatorTile accumulators; + + accumulators.clear(); + accumulator_iterator.load(accumulators); + +#if 0 + // For debugging, enable this block of code to fill each accumulator element with its + // source thread ID. + CUTLASS_PRAGMA_UNROLL + for (size_t i = 0; i < accumulators.size(); ++i) { + typename Epilogue::WarpMmaOperator::ElementC x(threadIdx.x); + accumulators[i] = x; + } + + __syncthreads(); + +#endif + + // + // Perform the epilogue operation + // + + typename Epilogue::OutputOp output_op(params_output_op); + + // Place the epilogue in a loop + for (int iter = 0; iter < epilogue_count; ++iter) { + epilogue(output_op, iterator_D, accumulators, iterator_C); + } +} + +} // namespace kernel +} // namespace test + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Epilogue_ +> +class EpilogueTestbed { +public: + + using Epilogue = Epilogue_; + using ElementAccumulator = typename Epilogue::ElementAccumulator; + using ElementCompute = typename Epilogue::OutputOp::ElementCompute; + using ElementOutput = typename Epilogue::ElementOutput; + using OutputOpParams = typename Epilogue::OutputOp::Params; + +public: + + // + // Data members + // + + cutlass::MatrixCoord quantized_size; + cutlass::HostTensor accumulator_tensor; + cutlass::HostTensor source_tensor; + cutlass::HostTensor output_tensor; + +public: + + // + // Methods + // + + EpilogueTestbed(): + quantized_size(Epilogue::Shape::kM, Epilogue::Shape::kN), + accumulator_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + source_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + output_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}) { + + // + // Initialize problem space + // + + uint64_t seed = 2019; + + cutlass::reference::host::TensorFillRandomUniform( + accumulator_tensor.host_view(), + seed, + 2, + -2, + 0); + + cutlass::reference::host::TensorFillRandomUniform( + source_tensor.host_view(), + seed + 2018, + 2, + -2, + 0); + } + + bool run_all() { + + double alpha_values[] = {1, 0, 2.25}; + double beta_values[] = {0, 1, -1.25}; + + // Test runtime explodes if we tried to test every case exhaustively. This tests the full + // output tile and several smaller sizes to stress predication. + for (int m_idx = 0; m_idx < 3; ++m_idx) { + for (int n_idx = 0; n_idx < 3; ++n_idx) { + + int m = quantized_size.row() - m_idx * 3; + int n = quantized_size.column() - n_idx * Epilogue::kElementsPerAccess; + + for (double const &alpha : alpha_values) { + for (double const &beta : beta_values) { + + bool passed = run({m, n}, {cutlass::from_real(alpha), cutlass::from_real(beta)}); + + if (!passed) { + return false; + } + } + } + } + } + + return true; + } + + /// Runs the test + bool run( + cutlass::MatrixCoord problem_size, + OutputOpParams output_params) { + + // + // Initialize problem space + // + + ElementOutput default_output = ElementOutput(-127); + cutlass::reference::host::TensorFill(output_tensor.host_view(), default_output); + + accumulator_tensor.sync_device(); + output_tensor.sync_device(); + source_tensor.sync_device(); + + // + // Initialize epilogue parameters + // + + typename Epilogue::OutputTileIterator::Params params_D(output_tensor.device_ref().layout()); + typename Epilogue::OutputTileIterator::Params params_C(source_tensor.device_ref().layout()); + + // + // Launch kernel + // + + dim3 grid(1, 1); + dim3 block(Epilogue::WarpCount::kCount * 32, 1); + + test::kernel::epilogue_threadblock<<< grid, block >>>( + params_D, + output_tensor.device_data(), + params_C, + source_tensor.device_data(), + output_params, + problem_size, + accumulator_tensor.device_view()); + + cudaError_t result = cudaDeviceSynchronize(); + + if (result != cudaSuccess) { + std::cerr << "Kernel error: " << cudaGetErrorString(result) << std::endl; + return false; + } + + // + // Verify results + // + output_tensor.sync_host(); + + int errors = 0; + int const kMaxErrors = 5; + + for (int r = 0; errors < kMaxErrors && r < quantized_size.row(); ++r) { + for (int c = 0; errors < kMaxErrors && c < quantized_size.column(); ++c) { + + cutlass::MatrixCoord coord{r, c}; + ElementOutput got = output_tensor.at(coord); + + ElementOutput expected; + if (coord.row() < problem_size.row() && coord.column() < problem_size.column()) { + ElementCompute intermediate = + output_params.alpha * ElementCompute(accumulator_tensor.at(coord)) + + output_params.beta * ElementCompute(source_tensor.at(coord)); + + if ((cutlass::platform::is_same::value + || cutlass::platform::is_same::value + || std::numeric_limits::is_integer) + && !std::numeric_limits::is_integer) { + std::fesetround(FE_TONEAREST); + expected = ElementOutput(std::nearbyint(float(cutlass::real(intermediate)))); + } else { + expected = ElementOutput(intermediate); + } + } else { + expected = default_output; + } + + if (expected != got) { + + using OutputIO = cutlass::ScalarIO; + + EXPECT_TRUE(false) + << "-------\n" + << "Error - output element (" << coord << ") - expected: " + << OutputIO(expected) + << ", got: " << OutputIO(got) + << ", accum: " << (accumulator_tensor.at(coord)) + << ", source: " << OutputIO(source_tensor.at(coord)) + << ", alpha: " << (output_params.alpha) + << ", beta: " << (output_params.beta) << "\n"; + + ++errors; + } + } + } + + // + // Report results on error + // + + if (errors) { + std::stringstream ss; + ss + << "output_tensor_op_" << Epilogue::Shape::kM << "x" << Epilogue::Shape::kN << "_" + << Epilogue::WarpTileIterator::WarpShape::kM << "x" + << Epilogue::WarpTileIterator::WarpShape::kN + << "_slice_" << Epilogue::WarpCount::kK << ".csv"; + + std::ofstream output_file(ss.str()); + output_file << output_tensor.host_view(); + } + + return !errors; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/testbed_planar_complex.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/testbed_planar_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..a76578f7638ac1d30161a9bcb55ecec70b5c43e0 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/testbed_planar_complex.h @@ -0,0 +1,394 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit tests for epilogues +*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/half.h" +#include "cutlass/complex.h" + +#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" + +#include "cutlass/util/host_tensor_planar_complex.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace kernel { + +template +__global__ void epilogue_planar_complex_threadblock( + typename Epilogue::OutputTileIterator::Params params_D, + typename Epilogue::OutputTileIterator::Element *ptr_D, + int64_t imaginary_stride_D, + typename Epilogue::OutputTileIterator::Params params_C, + typename Epilogue::OutputTileIterator::Element *ptr_C, + int64_t imaginary_stride_C, + typename Epilogue::OutputOp::Params params_output_op, + cutlass::MatrixCoord problem_size, + cutlass::TensorRef< + typename Epilogue::WarpMmaOperator::ElementC, + typename Epilogue::WarpMmaOperator::LayoutC> accumulator_ref, + int64_t imaginary_stride_accum, + int epilogue_count = 1) { + + __shared__ typename Epilogue::SharedStorage shared_storage; + + int thread_idx = threadIdx.x; + int warp_idx = threadIdx.x / 32; + int lane_idx = threadIdx.x % 32; + + // + // Construct the epilogue + // + + // Tile iterator writing to output tile + typename Epilogue::OutputTileIterator iterator_D_real( + params_D, + ptr_D, + problem_size, + thread_idx + ); + + typename Epilogue::OutputTileIterator iterator_D_imag( + params_D, + ptr_D + imaginary_stride_D, + problem_size, + thread_idx + ); + + // Tile iterator writing to output tile + typename Epilogue::OutputTileIterator iterator_C_real( + params_C, + ptr_C, + problem_size, + thread_idx + ); + + typename Epilogue::OutputTileIterator iterator_C_imag( + params_C, + ptr_C + imaginary_stride_C, + problem_size, + thread_idx + ); + + // Epilogue operator + Epilogue epilogue( + shared_storage, + thread_idx, + warp_idx, + lane_idx); + + // + // Initialize the accumulators + // + + int warp_mn = warp_idx % (Epilogue::WarpCount::kM * Epilogue::WarpCount::kN); + int warp_m = warp_mn % Epilogue::WarpCount::kM; + int warp_n = warp_mn / Epilogue::WarpCount::kM; + + accumulator_ref.add_coord_offset({ + warp_m * Epilogue::WarpMmaOperator::Shape::kM, + warp_n * Epilogue::WarpMmaOperator::Shape::kN}); + + // + // Load accumulators + // + + typename Epilogue::WarpMmaOperator::IteratorC accumulator_iterator(accumulator_ref, lane_idx); + + typename Epilogue::AccumulatorTile accumulators; + + accumulators.clear(); + + accumulator_iterator.load(accumulators.real); + accumulator_iterator.load_with_pointer_offset(accumulators.imag, imaginary_stride_accum); + + // + // Perform the epilogue operation + // + + typename Epilogue::OutputOp output_op(params_output_op); + + // Place the epilogue in a loop so assembly is clearly visible + for (int iter = 0; iter < epilogue_count; ++iter) { + epilogue( + output_op, + iterator_D_real, + iterator_D_imag, + accumulators, + iterator_C_real, + iterator_C_imag); + } +} + +} // namespace kernel +} // namespace test + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Epilogue_ +> +class EpiloguePlanarComplexTestbed { +public: + + using Epilogue = Epilogue_; + using ElementAccumulator = typename Epilogue::ElementAccumulator; + using ElementCompute = typename Epilogue::OutputOp::ElementCompute; + using ElementOutput = typename Epilogue::ElementOutput; + using OutputOpParams = typename Epilogue::OutputOp::Params; + + using ComplexElementOutput = cutlass::complex; + using ComplexElementAccumulator = cutlass::complex; + using ComplexElementCompute = cutlass::complex; + +public: + + // + // Data members + // + + cutlass::MatrixCoord quantized_size; + cutlass::HostTensorPlanarComplex accumulator_tensor; + cutlass::HostTensorPlanarComplex source_tensor; + cutlass::HostTensorPlanarComplex output_tensor; + +public: + + // + // Methods + // + + EpiloguePlanarComplexTestbed(): + quantized_size(Epilogue::Shape::kM, Epilogue::Shape::kN), + accumulator_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + source_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + output_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}) { + + // + // Initialize problem space + // + + #if 1 + uint64_t seed = 2019; + + cutlass::reference::host::TensorFillRandomUniform( + accumulator_tensor.host_view(), + seed, + 20, + -20, + 0); + + cutlass::reference::host::TensorFillRandomUniform( + source_tensor.host_view(), + seed + 2018, + 20, + -20, + 0); + #else + + cutlass::reference::host::BlockFillSequential(accumulator_tensor.host_data(), accumulator_tensor.capacity()); + + #endif + } + + bool run_all() { + + cutlass::complex alpha_values[3]; + + alpha_values[0] = cutlass::complex(1, 0); + alpha_values[1] = cutlass::complex(0, 0); + alpha_values[2] = cutlass::complex(2.25f, -0.5f); + + cutlass::complex beta_values[3]; + + beta_values[0] = cutlass::complex(0, 0); + beta_values[1] = cutlass::complex(1, 0); + beta_values[2] = cutlass::complex(0.5f, -2.25f); + + // Test runtime explodes if we tried to test every case exhaustively. This tests the full + // output tile and several smaller sizes to stress predication. + for (int m_idx = 0; m_idx < 3; ++m_idx) { + for (int n_idx = 0; n_idx < 3; ++n_idx) { + + cutlass::MatrixCoord problem_size( + quantized_size.row() - m_idx * 3, + quantized_size.column() - n_idx * Epilogue::kElementsPerAccess + ); + + for (auto const &alpha : alpha_values) { + for (auto const &beta : beta_values) { + + bool passed = run(problem_size, {alpha, beta}); + + if (!passed) { + return false; + } + } + } + } + } + + return true; + } + + /// Runs the test + bool run( + cutlass::MatrixCoord problem_size, + OutputOpParams output_params) { + + // + // Initialize problem space + // + + ComplexElementOutput default_output = ComplexElementOutput(ElementOutput(-127), ElementOutput(-101)); + + cutlass::reference::host::TensorFill(output_tensor.host_view(), default_output); + + accumulator_tensor.sync_device(); + output_tensor.sync_device(); + source_tensor.sync_device(); + + // + // Initialize epilogue parameters + // + + typename Epilogue::OutputTileIterator::Params params_D(output_tensor.layout()); + typename Epilogue::OutputTileIterator::Params params_C(source_tensor.layout()); + + // + // Launch kernel + // + + dim3 grid(1, 1); + dim3 block(Epilogue::WarpCount::kCount * 32, 1); + + test::kernel::epilogue_planar_complex_threadblock<<< grid, block >>>( + params_D, + output_tensor.device_data(), + output_tensor.imaginary_stride(), + params_C, + source_tensor.device_data(), + source_tensor.imaginary_stride(), + output_params, + problem_size, + accumulator_tensor.device_view_real(), + accumulator_tensor.imaginary_stride() + ); + + cudaError_t result = cudaDeviceSynchronize(); + + if (result != cudaSuccess) { + std::cerr << "Kernel error: " << cudaGetErrorString(result) << std::endl; + return false; + } + + // + // Verify results + // + output_tensor.sync_host(); + + int errors = 0; + int const kMaxErrors = 5; + + for (int r = 0; errors < kMaxErrors && r < quantized_size.row(); ++r) { + for (int c = 0; errors < kMaxErrors && c < quantized_size.column(); ++c) { + + cutlass::MatrixCoord coord{r, c}; + ComplexElementOutput got = output_tensor.at(coord); + + ComplexElementOutput expected = default_output; + + if (coord.row() < problem_size.row() && coord.column() < problem_size.column()) { + + ComplexElementOutput src = source_tensor.at(coord); + + ComplexElementCompute tmp = + output_params.alpha * ComplexElementCompute(accumulator_tensor.at(coord)) + + output_params.beta * ComplexElementCompute(src.real(), src.imag()); + + expected = ComplexElementOutput(ElementOutput(tmp.real()), ElementOutput(tmp.imag())); + } + + if (expected != got) { + + using OutputIO = cutlass::ScalarIO; + + EXPECT_TRUE(false) + << "-------\n" + << "Error - output element (" << coord << ") - expected: " + << OutputIO(expected) + << ", got: " << OutputIO(got) << std::endl; + + ++errors; + } + } + } + + // + // Report results on error + // + + if (errors) { + + + std::cout << "Incorrect result for problem(" + << problem_size.row() << ", " + << problem_size.column() << ") for alpha: " << output_params.alpha << ", beta: " << output_params.beta << std::endl; + + std::stringstream ss; + ss + << "output_tensor_op_" << Epilogue::Shape::kM << "x" << Epilogue::Shape::kN << "_" + << Epilogue::WarpTileIterator::WarpShape::kM << "x" + << Epilogue::WarpTileIterator::WarpShape::kN + << "_slice_" << Epilogue::WarpCount::kK << ".csv"; + + std::ofstream output_file(ss.str()); + output_file << output_tensor.host_view(); + + std::cout << "Wrote workspace to '" << ss.str() << "'" << std::endl; + } + + return !errors; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/default_gemm_configuration.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/default_gemm_configuration.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0054a1b6757a232e9177407fdd2041b6a91cffb9 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/default_gemm_configuration.hpp @@ -0,0 +1,1384 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/atom/mma_atom.hpp" +#include "cute/atom/copy_atom.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/layout/layout.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +namespace cutlass { +namespace gemm { +namespace device { +using namespace cute; + +// This type is only intended to demonstrate porting 2.x kernels to 3.0 +template< + class OperatorClass, class ArchTag, + class ElementA, class LayoutA, + class ElementB, class LayoutB, + class ElementC, class LayoutC, + class ElementAccumulator> +struct DefaultGemmConfigurationToCutlass3Types { + static_assert(sizeof(ElementA) == 0, "No valid DefaultGemmConfigurationToCutlass3Types configuration exists."); +}; + +/////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct DefaultGemm_TensorOpSm80_OperandA; + +template +struct DefaultGemm_TensorOpSm80_OperandB; + +// +// F16: 128-by-128-by-64 +// + +/// Operand A - Row-major (K-Major) +template <> +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<3,3,3>{}, + Layout, + Stride<_64, _1>>{})); + using SmemCopyAtom = Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, half_t>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); +}; + +/// Operand A - Column-major (M-major) +template +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<3,3,3>{}, + Layout, + Stride< _1,_64>>{})); + using SmemCopyAtom = Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, half_t>{}, + Layout, + Stride< _1,_16>>{}, + Layout>{})); +}; + +// Because the F32F16 TiledMMA is A-B symmetric, we can reuse the DefaultOperands + +// Operand B - Column-Major (K-major) +template +struct DefaultGemm_TensorOpSm80_OperandB + : DefaultGemm_TensorOpSm80_OperandA +{}; + +// Operand B - Row-Major (N-major) +template +struct DefaultGemm_TensorOpSm80_OperandB + : DefaultGemm_TensorOpSm80_OperandA +{}; + +// +// F16: 128-by-128-by-32 (small k-block) +// + +/// Operand A - Row-major (K-Major) +template <> +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<2,3,3>{}, + Layout, + Stride<_32, _1>>{})); + using SmemCopyAtom = Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, half_t>{}, + Layout, + Stride< _4,_1>>{}, + Layout>{})); +}; + +} + +/////////////////////////////////////////////////////////////////////////////// + +// Ampere MMA F32F16 +template +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm80, + half_t, LayoutA, + half_t, LayoutB, + float, LayoutC, + float> +{ + using TileShape = Shape<_128, _128, _32>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, + Layout>, // 2x2x1 thread group + Tile<_32,_32,_16>>; // 32x32x16 MMA for LDSM, 1x2x1 value group + + // A + static constexpr int kAlignmentA = 8; + using DefaultOperandA = detail::DefaultGemm_TensorOpSm80_OperandA< + half_t, LayoutA, kAlignmentA, 32>; + using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; // M, K + using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom; + using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; + + // B + static constexpr int kAlignmentB = 8; + using DefaultOperandB = detail::DefaultGemm_TensorOpSm80_OperandB< + half_t, LayoutB, kAlignmentB, 32>; + using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; // N, K + using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom; + using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + half_t, TagToStrideA_t, + half_t, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + float, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +// +// TF32: 128-by-128-by-kblock (kBlock = 16, 32) +// + +/// Operand A - Row-major (K-major) (kBlock = 32) +template <> +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<3,2,3>{}, + Layout, + Stride<_32, _1>>{})); + using SmemCopyAtom = Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, tfloat32_t>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); +}; + +/// Operand A - Row-major (K-major) (kBlock = 16) +template <> +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<2,2,3>{}, + Layout, + Stride<_16, _1>>{})); + using SmemCopyAtom = Copy_Atom; + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, tfloat32_t>{}, + Layout, + Stride< _4,_1>>{}, + Layout>{})); +}; + +/// Operand A - Column-major (M-major) +template +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<3,2,3>{}, + Layout, + Stride< _1,_32>>{})); + using SmemCopyAtom = Copy_Atom, tfloat32_t>; + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, tfloat32_t>{}, + Layout, + Stride< _1,_16>>{}, + Layout>{})); +}; + +// Because the TF32 TiledMMA is A-B symmetric, we can reuse the DefaultOperands + +// Operand B - Column-Major (K-major) +template +struct DefaultGemm_TensorOpSm80_OperandB + : DefaultGemm_TensorOpSm80_OperandA +{}; + +// Operand B - Row-Major (N-major) +template +struct DefaultGemm_TensorOpSm80_OperandB + : DefaultGemm_TensorOpSm80_OperandA +{}; + +} + +/////////////////////////////////////////////////////////////////////////////// + +// Ampere MMA F32TF32 +template +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm80, + tfloat32_t, LayoutA, + tfloat32_t, LayoutB, + float, LayoutC, + float> +{ + using TileShape = Shape<_128, _128, _32>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, + Layout, Stride<_2, _1, _1>>, // 2x2x1 thread group + Tile<_32,_32,_8>>; // 32x32x8 MMA for LDSM, 1x2x1 value group + + // A + static constexpr int kAlignmentA = 4; + using DefaultOperandA = detail::DefaultGemm_TensorOpSm80_OperandA< + tfloat32_t, LayoutA, kAlignmentA, 32>; + using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; // M, K + using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom; + using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; + + // B + static constexpr int kAlignmentB = 4; + using DefaultOperandB = detail::DefaultGemm_TensorOpSm80_OperandB< + tfloat32_t, LayoutB, kAlignmentB, 32>; + using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; // N, K + using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom; + using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + tfloat32_t, TagToStrideA_t, + tfloat32_t, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + float, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// +template +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm80, + int8_t, cutlass::layout::RowMajor, + int8_t, cutlass::layout::ColumnMajor, + int32_t, LayoutC, + int32_t> +{ + using TileShape = Shape<_128, _128, _64>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, + Layout>, // 2x2x1 thread group + Tile<_32,_32,_32>>; // 16x16x32 MMA for LDSM, 1x2x1 value group + + // A (M,K) K-major + using SmemLayoutAtomA = decltype( + composition( + Swizzle<2,4,3>{}, + Layout, + Stride<_64, _1>>{})); + static constexpr int kAlignmentA = 16; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, int8_t>{}, + Layout, + Stride< _4,_1>>{}, + Layout>>{})); + // LDS.32- or LDSM-based copy atom + // using SmemCopyAtomA = Copy_Atom; + using SmemCopyAtomA = Copy_Atom; // LDSM works + + // B (N,K) K-major + using SmemLayoutAtomB = decltype( + composition( + Swizzle<2,4,3>{}, + Layout, + Stride<_64, _1>>{})); + static constexpr int kAlignmentB = 16; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, int8_t>{}, + Layout, + Stride< _4,_1>>{}, + Layout>>{})); + + // LDS.32- or LDSM-based copy atom + // using SmemCopyAtomB = Copy_Atom; + using SmemCopyAtomB = Copy_Atom; // LDSM works + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + int8_t, TagToStrideA_t, + int8_t, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + int32_t, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////////// SIMT TWO STAGE /////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct DefaultGemm_Simt_OperandA; + +/////////////////////////////////////////////////////////////////////////////// + +template +struct DefaultGemm_Simt_OperandA +{ + using SmemLayoutAtom = Layout, + Stride< _1,_128>>; + + using SmemCopyAtom = Copy_Atom; + + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + Layout, + Stride< _1,_32>>{}, + Layout>{})); +}; + +template +struct DefaultGemm_Simt_OperandA +{ + using SmemLayoutAtom = Layout, + Stride< _1,Int<128 + 4>>>; // Padded + + using SmemCopyAtom = Copy_Atom; + + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + Layout, + Stride< _8, _1>>{}, + Layout>{})); + +}; + +template +struct DefaultGemm_Simt_OperandB; + +template +struct DefaultGemm_Simt_OperandB + : DefaultGemm_Simt_OperandA {}; + +template +struct DefaultGemm_Simt_OperandB + : DefaultGemm_Simt_OperandA {}; + +} // end namespace detail + +// SIMT Two Stage +template < + class ArchTag, + class ElementA, class LayoutA, + class ElementB, class LayoutB, + class ElementC, class LayoutC, + class ElementAccumulator> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, ArchTag, + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator> +{ + using TileShape = Shape<_128, _128, _8>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm70TwoStage; + using TiledMma = TiledMMA< + MMA_Atom>, + Layout>>; + + // A + static constexpr int kAlignmentA = 1; + using DefaultOperandA = detail::DefaultGemm_Simt_OperandA; + using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; + using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom; + using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; + + // B + static constexpr int kAlignmentB = 1; + using DefaultOperandB = detail::DefaultGemm_Simt_OperandB; + using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; + using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom; + using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + ElementC, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + + +// +// DP4A - int8 Proof-of-concept +// + +// SIMT Two Stage TN - idp4a +template < + class ArchTag, + class ElementC, class LayoutC> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, ArchTag, + int8_t, cutlass::layout::RowMajor, + int8_t, cutlass::layout::ColumnMajor, + ElementC, LayoutC, + int32_t> +{ + using TileShape = Shape<_128, _128, _32>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm70TwoStage; + // NOTE: permuting MMA M mode lets us generate 128b smem loads (LDS.128) but has worst case bank conflicts + using TiledMma = TiledMMA< + MMA_Atom, + Layout>>; // Tile of atoms (threads) + + // A (M,K) K-major + using ElementA = int8_t; + // 40% from regular M and N major layout + // using SmemLayoutAtomA = Layout, + // Stride< _1,_128>>; + // 80% from interleaved layouts + using SmemLayoutAtomA = Layout>, + Stride< _4, Stride<_1,_512>>>; + + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 4; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); + + // B (N,K) K-major + using ElementB = int8_t; + // 40% from regular M and N major layout + // using SmemLayoutAtomB = Layout, + // Stride< _1,_128>>; + // 80% from interleaved layouts + using SmemLayoutAtomB = Layout>, + Stride< _4, Stride<_1,_512>>>; + + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 4; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + ElementC, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// SIMT Two Stage NN - idp4a +template < + class ArchTag, + class ElementC, class LayoutC> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, ArchTag, + int8_t, cutlass::layout::ColumnMajor, + int8_t, cutlass::layout::ColumnMajor, + ElementC, LayoutC, + int32_t> +{ + using TileShape = Shape<_128, _128, _32>; + static constexpr int ThreadCount = 256; + + using DispatchPolicy = MainloopSm70TwoStage; + + using TiledMma = TiledMMA< + MMA_Atom, + Layout>>; + + // A (M,K) M-major + using ElementA = int8_t; + using SmemLayoutAtomA = Layout>, + Stride< _4, Stride<_1,_512>>>; + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 1; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout, + Stride< _1,_32>>{}, + Layout>{})); + + // B (N,K) K-major + using ElementB = int8_t; + using SmemLayoutAtomB = Layout>, + Stride< _4, Stride<_1,_512>>>; + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 4; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + ElementC, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// SIMT Two Stage NT - idp4a +template < + class ArchTag, + class ElementC, class LayoutC> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, ArchTag, + int8_t, cutlass::layout::ColumnMajor, + int8_t, cutlass::layout::RowMajor, + ElementC, LayoutC, + int32_t> +{ + using TileShape = Shape<_128, _128, _32>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm70TwoStage; + using TiledMma = TiledMMA< + MMA_Atom, + Layout>>; + + // A (M,K) M-major + using ElementA = int8_t; + using SmemLayoutAtomA = Layout>, + Stride< _4, Stride<_1,_512>>>; + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 1; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout, + Stride< _1,_32>>{}, + Layout>{})); + + // B (N,K) N-major + using ElementB = int8_t; + using SmemLayoutAtomB = Layout>, + Stride< _4, Stride<_1,_512>>>; + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 1; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout, + Stride< _1,_32>>{}, + Layout>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + ElementC, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// SIMT Two Stage TT - idp4a +template < + class ArchTag, + class ElementC, class LayoutC> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, ArchTag, + int8_t, cutlass::layout::RowMajor, + int8_t, cutlass::layout::RowMajor, + ElementC, LayoutC, + int32_t> +{ + using TileShape = Shape<_128, _128, _32>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm70TwoStage; + using TiledMma = TiledMMA< + MMA_Atom, + Layout>>; + + // A (M,K) K-major + using ElementA = int8_t; + using SmemLayoutAtomA = Layout>, + Stride< _4, Stride<_1,_512>>>; + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 4; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); + + // B (N,K) N-major + using ElementB = int8_t; + using SmemLayoutAtomB = Layout>, + Stride< _4, Stride<_1,_512>>>; + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 1; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout, + Stride< _1,_32>>{}, + Layout>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + ElementC, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// SIMT MULTI STAGE ////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +// SIMT Multi Stage NT +template < + class ElementA, + class ElementB, + class ElementC, class LayoutC, + class ElementAccumulator> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, arch::Sm80, + ElementA, cutlass::layout::ColumnMajor, + ElementB, cutlass::layout::RowMajor, + ElementC, LayoutC, + ElementAccumulator> +{ + using TileShape = Shape<_128, _128, _16>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom>, + Layout>, // 16x16x1 thread group + Tile,Stride<_2,_1>>, // 32x32x1 MMA with perm for load vectorization + Layout,Stride<_2,_1>>,Underscore>>; + + // A (M,K) M-major + using SmemLayoutAtomA = Layout>; + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 2; + using AlignmentTypeA = cute::uint_byte_t(sizeof(ElementA)) * kAlignmentA>; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout>{}, + Layout>{})); + + // B (N,K) N-major + using SmemLayoutAtomB = Layout>; + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 2; + using AlignmentTypeB = cute::uint_byte_t(sizeof(ElementB)) * kAlignmentB>; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout>{}, + Layout>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + ElementC, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// SIMT Multi Stage TN +template < + class ElementA, + class ElementB, + class ElementC, class LayoutC, + class ElementAccumulator> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, arch::Sm80, + ElementA, cutlass::layout::RowMajor, + ElementB, cutlass::layout::ColumnMajor, + ElementC, LayoutC, + ElementAccumulator> +{ + using TileShape = Shape<_128, _128, _16>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom>, + Layout>>; + + // A (M,K) K-major + using SmemLayoutAtomA = Layout, + Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentA + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 1; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout, + Stride<_16, _1>>{})); + + // B (N,K) K-major + using SmemLayoutAtomB = Layout, + Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentB + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 1; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout, + Stride<_16, _1>>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + ElementC, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// SIMT Multi Stage NN +template < + class ElementA, + class ElementB, + class ElementC, class LayoutC, + class ElementAccumulator> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, arch::Sm80, + ElementA, cutlass::layout::ColumnMajor, + ElementB, cutlass::layout::ColumnMajor, + ElementC, LayoutC, + ElementAccumulator> +{ + using TileShape = Shape<_128, _128, _16>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom>, + Layout>, // 16x16x1 thread group + Tile,Stride<_2,_1>>,Underscore,Underscore>>; // 32x16x1 MMA with perm for load vectorization + + // A (M,K) M-major + using SmemLayoutAtomA = Layout>; + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 2; + using AlignmentTypeA = cute::uint_byte_t(sizeof(ElementA)) * kAlignmentA>; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout>{}, + Layout>{})); + + // B (N,K) K-major + using SmemLayoutAtomB = Layout, + Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentB + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 1; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout, + Stride<_16, _1>>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + ElementC, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// SIMT Multi Stage TT +template < + class ElementA, + class ElementB, + class ElementC, class LayoutC, + class ElementAccumulator> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, arch::Sm80, + ElementA, cutlass::layout::RowMajor, + ElementB, cutlass::layout::RowMajor, + ElementC, LayoutC, + ElementAccumulator> +{ + using TileShape = Shape<_128, _128, _16>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom>, + Layout>, // 16x16x1 thread group + Tile,Stride<_2,_1>>,Underscore>>; // 16x32x1 MMA with perm for load vectorization + + // A (M,K) K-major + using SmemLayoutAtomA = Layout, + Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentA + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 1; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout, + Stride<_16, _1>>{})); + + // B (N,K) N-major + using SmemLayoutAtomB = Layout>; + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 2; + using AlignmentTypeB = cute::uint_byte_t(sizeof(ElementB)) * kAlignmentB>; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout>{}, + Layout>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + ElementC, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// Ampere fp64 MMA TN (K-Major A and K-Major B) +template <> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm80, + double, cutlass::layout::RowMajor, + double, cutlass::layout::ColumnMajor, + double, cutlass::layout::ColumnMajor, + double> +{ + using TileShape = Shape<_128, _64, _16>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, // Atom + Layout>, // Atom layout + Tile,Stride<_2,_1>>, // 32x32x4 MMA with perm for load vectorization + Layout,Stride<_2,_1>>, + Underscore>>; + + // A (M,K) K-Major + using SmemLayoutAtomA = decltype( + composition(Swizzle<2,0,4>{}, + Layout, + Stride<_1, _4>>{})); // M, K + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 1; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride<_16, _1>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 1x1 doubles + + // B (N,K) K-Major + using SmemLayoutAtomB = decltype( + composition(Swizzle<2,0,4>{}, + Layout, + Stride<_1, _4>>{})); // N, K + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 1; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride<_16, _1>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 1x1 doubles + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + double, TagToStrideA_t, + double, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + double, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; + +/* + using EpilogueOutputOp = epilogue::collective::Epilogue< + epilogue::thread::LinearCombination, + Layout, + Stride< _1,_64>>, // SMEM layout + Copy_Atom,double>, // R2S with tiled_mma layout + decltype(make_tiled_copy(Copy_Atom,double>{},// S2R + Layout, + Stride< _1,_16>>{}, // Thread layout + Layout>{})), // Value layout + Copy_Atom,double> // R2G with S2R_dst layout + >; +*/ +}; + +/////////////////////////////////////////////////////////////////////////////// + +// Ampere fp64 MMA NN (M-Major A and K-Major B) +template <> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm80, + double, cutlass::layout::ColumnMajor, + double, cutlass::layout::ColumnMajor, + double, cutlass::layout::ColumnMajor, + double> +{ + using TileShape = Shape<_128, _64, _16>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, // Atom + Layout>, // Atom layout + Tile,Stride<_2,_1>>, // 32x32x4 MMA with perm for load vectorization + Layout,Stride<_2,_1>>, + Underscore>>; + + // A (M,K) M-Major + using SmemLayoutAtomA = decltype( + composition(Swizzle<2,2,2>{}, + Layout, + Stride< _1,_16>>{})); // M, K + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 2; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride< _1,_16>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 2x1 doubles + + // B (N,K) K-Major + using SmemLayoutAtomB = decltype( + composition(Swizzle<2,0,4>{}, + Layout, + Stride<_1, _4>>{}));// N, K + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 1; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride<_16, _1>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 1x1 doubles + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + double, TagToStrideA_t, + double, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + double, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// Ampere fp64 MMA NT (M-Major A and N-Major B) +template <> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm80, + double, cutlass::layout::ColumnMajor, + double, cutlass::layout::RowMajor, + double, cutlass::layout::ColumnMajor, + double> +{ + using TileShape = Shape<_128, _64, _16>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, // Atom + Layout>, // Atom layout + Tile,Stride<_2,_1>>, // 32x32x4 MMA with perm for load vectorization + Layout,Stride<_2,_1>>, + Underscore>>; + + // A (M,K) M-Major + using SmemLayoutAtomA = decltype( + composition(Swizzle<2,2,2>{}, + Layout, + Stride< _1,_16>>{})); // M, K + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 2; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride< _1,_16>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 2x1 doubles + + // B (N,K) N-Major + using SmemLayoutAtomB = decltype( + composition(Swizzle<2,2,2>{}, + Layout, + Stride< _1,_16>>{})); // N, K + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 2; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride< _1,_16>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 2x1 doubles + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + double, TagToStrideA_t, + double, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + double, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// Ampere fp64 MMA TT (K-Major A and N-Major B) +template <> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm80, + double, cutlass::layout::RowMajor, + double, cutlass::layout::RowMajor, + double, cutlass::layout::ColumnMajor, + double> +{ + using TileShape = Shape<_128, _64, _16>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, // Atom + Layout>, // Atom layout + Tile,Stride<_2,_1>>, // 32x32x4 MMA with perm for load vectorization + Layout,Stride<_2,_1>>, + Underscore>>; + + // A (M,K) K-Major + using SmemLayoutAtomA = decltype( + composition(Swizzle<2,0,4>{}, + Layout, + Stride<_1, _4>>{})); // M, K + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 1; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride<_16, _1>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 1x1 doubles + + // B (N,K) N-Major + using SmemLayoutAtomB = decltype( + composition(Swizzle<2,2,2>{}, + Layout, + Stride< _1,_16>>{})); // N, K + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 2; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride< _1,_16>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 2x1 doubles + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + double, TagToStrideA_t, + double, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + double, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// Hopper fp64 MMA TN +template <> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm90, + double, cutlass::layout::RowMajor, + double, cutlass::layout::ColumnMajor, + double, cutlass::layout::ColumnMajor, + double> +{ + using TileShape = Shape<_128, _64, _16>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, + Layout>>; + + // A (M,K) K-major + using SmemLayoutAtomA = decltype( + make_ordered_layout(Shape<_128,_16>{}, + Step < _2, _1>{})); // M, K + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 2; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, double>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); + + // B (N,K) K-major + using SmemLayoutAtomB = decltype( + make_ordered_layout(Shape<_64,_16>{}, + Step < _2, _1>{})); // N, K + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 2; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, double>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + double, TagToStrideA_t, + double, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + double, double, + double, cutlass::layout::ColumnMajor, 1, + double, cutlass::layout::ColumnMajor, 1, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x.hpp new file mode 100644 index 0000000000000000000000000000000000000000..89755dd7d3162b114a537e58c6aa33cac80078f9 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x.hpp @@ -0,0 +1,3993 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include +#include +#include +#include // std::lcm + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/fusion/operations.hpp" +#include "cutlass/complex.h" +#include "cutlass/transform/device/transform_universal_adapter.hpp" +#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" +#include "cutlass/detail/collective.hpp" + +#include "testbed_utils.h" + +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/layout/matrix.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/gemm/gemm.h" + +#include "cute/int_tuple.hpp" +#include "cute/layout.hpp" +#include "cute/numeric/int.hpp" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +enum class ScalarLoc { + ON_HOST = 0, + ON_DEVICE = 1 +}; + +enum class VectorScale { + DISABLED = 0, + ENABLED = 1 +}; + +enum class CheckEquality { + EXACT = 0, + RELATIVE = 1 +}; + +namespace detail { + +inline constexpr auto decomp_mode_to_string = + [] (cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode mode) -> std::string { + using Mode = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + if (mode == Mode::Heuristic) { + return "Heuristic"; + } + else if (mode == Mode::DataParallel) { + return "DataParallel"; + } + else if (mode == Mode::SplitK) { + return "SplitK"; + } + else if (mode == Mode::StreamK) { + return "StreamK"; + } + else { + return "Unknown"; + } + }; + +inline constexpr auto raster_order_to_string = + [] (cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90Params::RasterOrderOptions mode) -> std::string { + using Mode = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90Params::RasterOrderOptions; + if (mode == Mode::Heuristic) { + return "Heuristic"; + } + else if (mode == Mode::AlongM) { + return "AlongM"; + } + else if (mode == Mode::AlongN) { + return "AlongN"; + } + else { + return "Unknown"; + } + }; + +// Helper classes that take default data type when +// the Gemm::EpilogueOutputOp does not have ElementCompute +// and ElementScalar. +// (e.g. when Sm90TreeVisitor is used as FusionCallbacks) +template +struct ElementComputeType { + using Type = Default; +}; + +template +struct ElementComputeType>> { + using Type = typename Gemm::EpilogueOutputOp::ElementCompute; +}; + +template +struct ElementScalarType { + using Type = Default; +}; + +template +struct ElementScalarType>> { + using Type = typename Gemm::EpilogueOutputOp::ElementScalar; +}; + + +template +struct IsF8F6F4Kernel { + static constexpr bool value = false; +}; + +template +struct IsF8F6F4Kernel> { + static constexpr bool value = true; +}; + + +template +struct IsSfdEpi : cute::false_type {}; + +template +struct IsSfdEpi> : cute::true_type {}; + +// The maximum swizzle size to use +// +// This class, like Splits above makes it harder to confuse +// the order of arguments of the various run(...) functions in this file. +class MaxSwizzleSize { +public: + MaxSwizzleSize() = default; + + template && + !cute::is_same_v)) > + explicit MaxSwizzleSize(IntegralNotBool max_swizzle_size) : max_swizzle_size_(max_swizzle_size) {} + explicit operator int() const { return max_swizzle_size_; } +private: + int max_swizzle_size_ = 1; +}; + +template +auto make_iterator(T* ptr) { + return cute::recast_ptr(ptr); +} + +template +struct IsDefaultEpilogue { + static constexpr bool value = false; +}; + +template +struct IsDefaultEpilogue> { + static constexpr bool value = true; +}; + +template +struct IsDefaultEpilogue> { + static constexpr bool value = true; +}; + +template +struct IsLegacyEpiloguePolicy { + static constexpr bool value = false; +}; + +template +struct IsLegacyEpiloguePolicy> { + using EpiloguePolicy = typename Epilogue::DispatchPolicy; + static constexpr bool value = cute::is_same_v< + EpiloguePolicy, + cutlass::epilogue::Sm90TmaWarpSpecializedBiasElementwise< + EpiloguePolicy::StagesC, EpiloguePolicy::StagesD, EpiloguePolicy::FragmentSize>>; +}; + +// The number of splits to test. +// +// This class makes it harder to confuse the order of arguments +// of the various run(...) functions in this file. The constructor +// is explicit, so one can't just type 42 (or false, which the +// compiler unhelpfully turns into 0); one has to type Splits(42). +// Splits() picks the default number of splits, 1. +// +// The conversion-to-int operator (operator int()) MUST be explicit! +// Conversion to int MUST require static_cast. +// Otherwise, that defeats a key purpose of this class, +// which is to catch common errors of confusing the order +// of function arguments. +class Splits { +public: + Splits() = default; + + template && + !cute::is_same_v)) > + explicit Splits(IntegralNotBool splits) : splits_(splits) {} + explicit operator int() const { return splits_; } +private: + int splits_ = 1; +}; + +// The number of iterations to test. +// +// This class, like Splits above makes it harder to confuse +// the order of arguments of the various run(...) functions in this file. +// Iterations() picks the default number of iterations, 20. +class Iterations { +public: + Iterations() = default; + + template && + !cute::is_same_v)) > + explicit Iterations(IntegralNotBool iterations) : iterations_(iterations) {} + explicit operator int() const { return iterations_; } +private: + int iterations_ = 20; +}; + +template +bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + + else if (bits_input <= 6) { + scope_max = 2; + scope_min = -2; + } + + else if (bits_input <= 8) { + + if constexpr ( + cute::is_same_v){ + scope_max = 4; + scope_min = 1; + } + else { + + scope_max = 1; + scope_min = -1; + + } + + } + else{ + scope_max = 4; + scope_min = -4; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + + else if (dist_kind == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(view); + } + + else if (dist_kind == cutlass::Distribution::Gaussian) { + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + + else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; +} + +// Looks at Cute Stride to check Row / Column Major +template +static constexpr bool is_row_or_col_major(){ + int stride_0 = int(cute::size<0>(Stride{})); + int stride_1 = int(cute::size<1>(Stride{})); + int depth = cute::depth(Stride{}); + return ((stride_0 == 1) || (stride_1 == 1)) && (depth == 1); +} + + +// +// Default MMA input Operands : A , B +// +template< + class ScheduleType_, + class Gemm, + class ElementA_ = typename Gemm::GemmKernel::ElementA, + class ElementB_ = typename Gemm::GemmKernel::ElementB, + class Enable = void> +struct HostCollectiveMainloop { + // Kernel data types + using ElementA = ElementA_; + using StrideA = typename Gemm::GemmKernel::StrideA; + using ElementB = ElementB_; + using StrideB = typename Gemm::GemmKernel::StrideB; + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using LayoutTagA = cutlass::detail::StrideToLayoutTagA_t; + using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t; + + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; + + using Arguments = typename Gemm::GemmKernel::MainloopArguments; + + cutlass::ComplexTransform TransformA = Gemm::kTransformA; + cutlass::ComplexTransform TransformB = Gemm::kTransformB; + + StrideA stride_a; + StrideB stride_b; + + typename LayoutTagA::Stride stride_factor_A; + typename LayoutTagB::Stride stride_factor_B; + + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + // Note: this limitation comes from testbed / not the library + static_assert(is_row_or_col_major(), + "ERROR : A Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : B Layout is neither Row / Column Major)"); + + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed, + typename LayoutTagA::Stride stride_factor_A_ = typename LayoutTagA::Stride(), + typename LayoutTagB::Stride stride_factor_B_ = typename LayoutTagB::Stride() + ): + stride_factor_A(stride_factor_A_), + stride_factor_B(stride_factor_B_), + init_A(init_A_), init_B(init_B_), seed(seed_), + check_relative_equality(check_relative_equality_) { } + + template + bool initialize(ProblemShapeType problem_size) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveMainloop (generic)::initialize(problem_shape)"); +#endif + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + + stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto a_coord = cutlass::make_Coord(M * L, K); + // Cutlass has Row/Col major refers to MxK times KxN matrix product, + // so the HostTensorB should be treated as KxN in "coord"'s view + auto b_coord = cutlass::make_Coord(K, N * L); + + try { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor_A.resize"); +#endif + tensor_A.resize(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A)); +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor_B.resize"); +#endif + tensor_B.resize(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B)); + } + catch (std::exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor A or B resize threw an exception: " << e.what()); + throw; + } + catch (...) { + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor A or B resize threw an unknown exception"); + throw; + } + + try { + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2022)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2021)); + } + catch (cutlass::cuda_exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: checked initialize_tensor threw cutlass::cuda_exception: " << e); + throw; + } + catch (std::exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: checked initialize_tensor threw an exception: " << e.what()); + throw; + } + catch (...) { + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: checked_initialize_tensor threw an unknown exception"); + throw; + } + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = ElementA(1); + tensor_B.host_view().at({0, 0}) = ElementB(1); + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + { + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: Check last error before sync_device()"); + cudaError_t error = cudaGetLastError(); + const auto error_str = cudaGetErrorString(error); + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: cudaGetLastError() is " << error_str); + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor_A.host_data()=" << tensor_A.host_data() << ", tensor_A.device_data()=" << tensor_A.device_data()); + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor_B.host_data()=" << tensor_B.host_data() << ", tensor_B.device_data()=" << tensor_B.device_data()); + } +#endif + try { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor_A.sync_device"); +#endif + tensor_A.sync_device(); +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor_B.sync_device"); +#endif + tensor_B.sync_device(); + } + catch (cutlass::cuda_exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: sync_device() threw cutlass::cuda_exception: " << e); + throw; + } + catch (std::exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: sync_device() threw an exception: " << e.what()); + throw; + } + catch (...) { + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: sync_device() threw an unknown exception"); + throw; + } + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: Reached end"); +#endif + return true; + } + + Arguments to_args() { + + + // Runtime datatype selection + if constexpr (not cute::is_same_v) { + using ArrayElementA = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB; + return { + reinterpret_cast(tensor_A.device_data()), stride_a, + reinterpret_cast(tensor_B.device_data()), stride_b + }; + } + else { + + Arguments arguments = + { + tensor_A.device_data(), stride_a, tensor_B.device_data(), stride_b + }; + return arguments; + } + } + + auto to_host_args(ProblemShapeType problem_size) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + auto A = make_tensor(make_iterator(tensor_A.host_data()), + make_layout(make_shape(M, K, L), stride_a)); + auto B = make_tensor(make_iterator(tensor_B.host_data()), + make_layout(make_shape(N, K, L), stride_b)); + + + auto dummy_SFA = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, K, L), stride_a)); + auto dummy_SFB = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(N, K, L), stride_b)); + + cutlass::reference::host::GettMainloopParams mainloop_params{}; + + mainloop_params.A = A; + mainloop_params.B = B; + mainloop_params.transform_A = TransformA; + mainloop_params.transform_B = TransformB; + + return mainloop_params; + } + + void print_tensors(std::ofstream& file) { + file << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view(); + } + + template < + class Element, + class Layout + > + bool equality_check( + cutlass::TensorView const& lhs, + cutlass::TensorView const& rhs) const { + + // Factors used for calculating relative equality. CUTLASS's relative-equality + // checks in include/cutlass/relatively_equal.h are inspired by + // https://floating-point-gui.de/errors/comparison/. This reference suggests using + // the minimum normal value of a given type as the nonzero_floor. + Element epsilon(static_cast(0.1f)); + Element nonzero_floor(std::numeric_limits::min()); + + if constexpr (!cutlass::is_complex::value) { + if (check_relative_equality == CheckEquality::RELATIVE) { + return cutlass::reference::host::TensorRelativelyEquals( + lhs, rhs, epsilon, nonzero_floor); + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + + bool compare_reference( + cute::Shape problem_shape_MNKL) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + + bool passed = true; + return passed; + } +}; + +// +// Sparse MMA host implementation +// +template< + class Gemm, + class ElementA_, + class ElementB_> +struct HostCollectiveMainloopSparse +{ + + // Kernel data types + using ElementA = ElementA_; + // CuTe layout A for the kernel's sparse tensorA. + using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA; + using ElementB = ElementB_; + using StrideB = typename Gemm::GemmKernel::StrideB; + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + + using ElementE = typename Gemm::GemmKernel::CollectiveMainloop::ElementE; + // CuTe layout E for the kernel's metadata tensor. + using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE; + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; + using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig; + + // The following typenames are for the reference host tensors. They are non-sparse tensors. + using LayoutTagA = decltype(SparseConfig::deduce_layoutA_tag(LayoutA{})); + using StrideA = cutlass::gemm::TagToStrideA_t; + // We don't care about the actual strideE for the host tensor, but just need one to allocate memory. + using StrideE = StrideA; + + // Deduce Cutlass Layouts (RowMajor & ColumnMajor) + using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t; + using LayoutTagE = cutlass::detail::StrideToLayoutTagA_t; + + using ArchTag = typename Gemm::ArchTag; + + using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility< + cute::Shape, + ElementA, + LayoutTagA, + SparseConfig>; + + using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor< + cute::Shape, + ElementA, + LayoutTagA, + SparseConfig, + ArchTag>; + + using Compressor = cutlass::transform::device::TransformUniversalAdapter; + + using Arguments = typename Gemm::GemmKernel::MainloopArguments; + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + + // Note: this limitation comes from testbed / not the library + static_assert(is_row_or_col_major(), + "ERROR : A Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : B Layout is neither Row / Column Major)"); + + StrideA stride_a; + StrideA stride_a_compressed; + StrideB stride_b; + StrideE stride_e; + + LayoutA layout_a; + LayoutE layout_e; + + typename LayoutTagA::Stride stride_factor_A; + typename LayoutTagB::Stride stride_factor_B; + typename LayoutTagE::Stride stride_factor_E; + + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_A_Comp; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_E; + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + static constexpr int MaxSmCount = 16; + + HostCollectiveMainloopSparse( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed, + typename LayoutTagA::Stride stride_factor_A_ = typename LayoutTagA::Stride(), + typename LayoutTagB::Stride stride_factor_B_ = typename LayoutTagB::Stride(), + typename LayoutTagE::Stride stride_factor_E_ = typename LayoutTagE::Stride() + ): + check_relative_equality(check_relative_equality_), + stride_factor_A(stride_factor_A_), + stride_factor_B(stride_factor_B_), + stride_factor_E(stride_factor_E_), + init_A(init_A_), init_B(init_B_), seed(seed_) { } + + template + bool initialize(ProblemShapeType problem_size) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveMainloopSparse::initialize"); +#endif + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + + stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + + CompressorUtility compressor_utility(problem_shape_MNKL, stride_a); + + // TensorE + // In unit of ElementE (uint8_t), after alignment requirement + // M-dim: TensorEAtom_M alignment + // K-dim: TensorEAtom_K alignment + int KAlignedE = compressor_utility.get_metadata_k_physical(); + int MAlignedE = compressor_utility.get_metadata_m_physical(); + + // TensorA Compressed + // In unit of ElementARaw, after alignment requirement + // M-dim: TMA alignment + // K-dim: TMA alignment + int KAlignedAC = compressor_utility.get_tensorA_k_physical(); + int MAlignedAC = compressor_utility.get_tensorA_m_physical(); + + stride_a_compressed = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, KAlignedAC, L)); + stride_e = cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(MAlignedE, KAlignedE, L)); + + auto a_coord = cutlass::make_Coord(M * L, K); + auto b_coord = cutlass::make_Coord(K, N * L); + auto e_coord = cutlass::make_Coord(MAlignedE * L, KAlignedE); + auto a_comp_coord = cutlass::make_Coord(MAlignedAC * L, KAlignedAC); + + tensor_A.resize(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A)); + tensor_A_Comp.resize(a_comp_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_comp_coord, stride_factor_A)); + tensor_B.resize(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B)); + tensor_E.resize(e_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(e_coord, stride_factor_E)); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2022)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2021)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = ElementA(1); + tensor_B.host_view().at({0, 0}) = ElementB(1); + + compressor_utility.structure_sparse_zero_mask_fill(tensor_A.host_data(), static_cast(seed + 2023)); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_E.sync_device(); + tensor_A_Comp.sync_device(); + + cutlass::Status status {cutlass::Status::kSuccess }; + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + typename Compressor::Arguments arguments{ + {M, N, K, L}, + {tensor_A.device_data(), + stride_a, + tensor_A_Comp.device_data(), + tensor_E.device_data()}, + {hw_info} + }; + + Compressor compressor_op; + size_t workspace_size = Compressor::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + status = compressor_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + return false; + } + + status = compressor_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + return false; + } + + status = compressor_op.run(); + + auto result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } + + layout_a = SparseConfig::fill_layoutA(problem_shape_MNKL); + layout_e = SparseConfig::fill_layoutE(problem_shape_MNKL); + + tensor_E.sync_host(); + tensor_A_Comp.sync_host(); + + return true; + } + + Arguments to_args() { + using ArrayElementA = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB; + return { + reinterpret_cast(tensor_A_Comp.device_data()), layout_a, + reinterpret_cast(tensor_B.device_data()), stride_b, + tensor_E.device_data(), layout_e + }; + } + + auto to_host_args(ProblemShapeType problem_size) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + auto A = make_tensor(make_iterator(tensor_A.host_data()), + make_layout(make_shape(M, K, L), stride_a)); + auto B = make_tensor(make_iterator(tensor_B.host_data()), + make_layout(make_shape(N, K, L), stride_b)); + + cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; + return mainloop_params; + } + + void print_tensors(std::ofstream& file) { + file << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view(); + } + + bool compare_reference( + cute::Shape problem_shape_MNKL) { + auto [M, N, K, L] = problem_shape_MNKL; + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + return true; + } +}; + +template< + class ScheduleType_, + class Gemm, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + typename Gemm::CollectiveMainloop::DispatchPolicy>>> + : HostCollectiveMainloopSparse +{ + using HostCollectiveMainloopSparse::HostCollectiveMainloopSparse; +}; + +// +// Sparse MMA input Operands : A_compressed, B, metadata +// +// Structured Sparse Gemm Input Operands + +template< + class Gemm, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + typename ElementA_, + typename ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> + : HostCollectiveMainloopSparse +{ + using HostCollectiveMainloopSparse::HostCollectiveMainloopSparse; +}; + +// +// Sparse Gemm Input Operands : A , B, E +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_ >; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride(), + typename Base::LayoutTagE::Stride stride_factor_E_ = typename Base::LayoutTagE::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, + stride_factor_B_, + stride_factor_E_) {} +}; + +// +// Sparse Gemm Input Operands : A , B, E +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_ >; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride(), + typename Base::LayoutTagE::Stride stride_factor_E_ = typename Base::LayoutTagE::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, + stride_factor_B_, + stride_factor_E_) {} +}; + +// +// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + // Kernel data types + using ElementA = ElementA_; + using StrideA = typename Gemm::GemmKernel::StrideA; + using ElementB = ElementB_; + using StrideB = typename Gemm::GemmKernel::StrideB; + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using LayoutTagA = cutlass::detail::StrideToLayoutTagA_t; + using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t; + + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; + + static constexpr int SFVecSize = Gemm::GemmKernel::CollectiveMainloop::SFVecSize; + + using ElementSF = typename Gemm::GemmKernel::CollectiveMainloop::ElementSF; + using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; + using Blk_SF = typename Sm1xxBlkScaledConfig::Blk_SF; + using SfAtom = typename Sm1xxBlkScaledConfig::SfAtom; + using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; + using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; + + using Arguments = typename Gemm::GemmKernel::MainloopArguments; + + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + + StrideA stride_a; + StrideB stride_b; + + LayoutSFA layout_sfa; + LayoutSFB layout_sfb; + + typename LayoutTagA::Stride stride_factor_A; + typename LayoutTagB::Stride stride_factor_B; + + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_SFA; + cutlass::HostTensor tensor_SFB; + + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + // Note: this limitation comes from testbed / not the library + static_assert(is_row_or_col_major(), + "ERROR : A Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : B Layout is neither Row / Column Major)"); + + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed, + typename LayoutTagA::Stride stride_factor_A_ = typename LayoutTagA::Stride(), + typename LayoutTagB::Stride stride_factor_B_ = typename LayoutTagB::Stride() + ): + check_relative_equality(check_relative_equality_), + stride_factor_A(stride_factor_A_), + stride_factor_B(stride_factor_B_), + init_A(init_A_), init_B(init_B_), seed(seed_) { } + + template + bool initialize(ProblemShapeType problem_size) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveMainloop (KernelTmaWarpSpecializedBlockScaledSm100)::initialize"); +#endif + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + + stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto a_coord = cutlass::make_Coord(M * L, K); + // Cutlass has Row/Col major refers to MxK times KxN matrix product, + // so the HostTensorB should be treated as KxN in "coord"'s view + auto b_coord = cutlass::make_Coord(K, N * L); + + tensor_A.resize(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A)); + tensor_B.resize(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B)); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2022)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2021)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = ElementA(1); + tensor_B.host_view().at({0, 0}) = ElementB(1); + + tensor_A.sync_device(); + tensor_B.sync_device(); + + using namespace cute; + auto k_blks = cutlass::ceil_div(K, size<1>(shape(SfAtom{}))); + auto m_blks = cutlass::ceil_div(M, Blk_MN{}); + auto n_blks = cutlass::ceil_div(N, Blk_MN{}); + layout_sfa = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(problem_shape_MNKL); + layout_sfb = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(problem_shape_MNKL); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto sfa_coord = cutlass::make_Coord(m_blks * Blk_MN{} * L, k_blks * Blk_SF{}); + auto sfb_coord = cutlass::make_Coord(n_blks * Blk_MN{} * L, k_blks * Blk_SF{}); + + tensor_SFA.resize(sfa_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfa_coord, stride_factor_A)); + tensor_SFB.resize(sfb_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfb_coord, stride_factor_B)); + + EXPECT_TRUE(initialize_tensor(tensor_SFA.host_view(), init_A, seed + 2024)); + EXPECT_TRUE(initialize_tensor(tensor_SFB.host_view(), init_B, seed + 2025)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_SFA.host_view().at({0, 0}) = ElementSF(1); + tensor_SFB.host_view().at({0, 0}) = ElementSF(1); + + tensor_SFA.sync_device(); + tensor_SFB.sync_device(); + + return true; + } + + Arguments to_args() { + using ArrayElementA = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB; + return { + reinterpret_cast(tensor_A.device_data()), stride_a, + reinterpret_cast(tensor_B.device_data()), stride_b, + tensor_SFA.device_data(), layout_sfa, + tensor_SFB.device_data(), layout_sfb + }; + } + + auto to_host_args(ProblemShapeType problem_size) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + auto A = make_tensor(make_iterator(tensor_A.host_data()), + make_layout(make_shape(M, K, L), stride_a)); + auto SfA = make_tensor(tensor_SFA.host_data(), layout_sfa); + + auto B = make_tensor(make_iterator(tensor_B.host_data()), + make_layout(make_shape(N, K, L), stride_b)); + auto SfB = make_tensor(tensor_SFB.host_data(), layout_sfb); + + cutlass::reference::host::GettMainloopParams + mainloop_params{A, SfA, B, SfB}; + return mainloop_params; + } + + void print_tensors(std::ofstream& file) { + file << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nSFA =\n" << tensor_SFA.host_view() + << "\nSFB =\n" << tensor_SFB.host_view(); + } + + bool compare_reference( + cute::Shape problem_shape_MNKL) { + auto [M, N, K, L] = problem_shape_MNKL; + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_SFA.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_SFB.host_view()), 0); + return true; + } +}; + + +// +// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_>; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, stride_factor_B_) {} +}; + +// +// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_>; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, stride_factor_B_) {} +}; + +// +// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_>; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, stride_factor_B_) {} +}; + +// +// Block Scaled Structured Sparse Gemm Input Operands : A_compressed, B, metadata, scalefactorA, scalefactorB +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + typename ElementA_, + typename ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + // Kernel data types + using ElementA = ElementA_; + // CuTe layout A for the kernel's sparse tensorA. + using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA; + using ElementB = ElementB_; + using StrideB = typename Gemm::GemmKernel::StrideB; + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + + using ElementE = typename Gemm::GemmKernel::CollectiveMainloop::ElementE; + // CuTe layout E for the kernel's metadata tensor. + using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE; + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; + using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig; + + // The following typenames are for the reference host tensors. They are non-sparse tensors. + using LayoutTagA = decltype(SparseConfig::deduce_layoutA_tag(LayoutA{})); + using StrideA = cutlass::gemm::TagToStrideA_t; + // We don't care about the actual strideE for the host tensor, but just need one to allocate memory. + using StrideE = StrideA; + + static constexpr int SFVecSize = Gemm::GemmKernel::CollectiveMainloop::SFVecSize; + // Deduce Cutlass Layouts (RowMajor & ColumnMajor) + using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t; + + using LayoutTagE = cutlass::detail::StrideToLayoutTagA_t; + + using ElementSF = typename Gemm::GemmKernel::CollectiveMainloop::ElementSF; + using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; + using Blk_SF = typename Sm1xxBlkScaledConfig::Blk_SF; + using SfAtom = typename Sm1xxBlkScaledConfig::SfAtom; + using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; + using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; + + using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility< + cute::Shape, + ElementA, + LayoutTagA, + SparseConfig>; + using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor< + cute::Shape, + ElementA, + LayoutTagA, + SparseConfig, + cutlass::arch::Sm100>; + + using Compressor = cutlass::transform::device::TransformUniversalAdapter; + + using Arguments = typename Gemm::GemmKernel::MainloopArguments; + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + + StrideA stride_a; + StrideA stride_a_compressed; + StrideB stride_b; + StrideE stride_e; + + LayoutA layout_a; + LayoutE layout_e; + LayoutSFA layout_sfa; + LayoutSFB layout_sfb; + + typename LayoutTagA::Stride stride_factor_A; + typename LayoutTagB::Stride stride_factor_B; + typename LayoutTagE::Stride stride_factor_E; + + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_A_Comp; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_E; + cutlass::HostTensor tensor_SFA; + cutlass::HostTensor tensor_SFB; + + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + // Note: this limitation comes from testbed / not the library + static_assert(is_row_or_col_major(), + "ERROR : A Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : B Layout is neither Row / Column Major)"); + + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed, + typename LayoutTagA::Stride stride_factor_A_ = typename LayoutTagA::Stride(), + typename LayoutTagB::Stride stride_factor_B_ = typename LayoutTagB::Stride(), + typename LayoutTagE::Stride stride_factor_E_ = typename LayoutTagE::Stride() + ): + check_relative_equality(check_relative_equality_), + stride_factor_A(stride_factor_A_), + stride_factor_B(stride_factor_B_), + stride_factor_E(stride_factor_E_), + init_A(init_A_), init_B(init_B_), seed(seed_) { } + + template + bool initialize(ProblemShapeType problem_size) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveMainloop (KernelSparseTmaWarpSpecializedBlockScaledSm100)::initialize"); +#endif + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + + stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + + CompressorUtility compressor_utility(problem_shape_MNKL, stride_a); + + // TensorE + // In unit of ElementE (uint8_t), after alignment requirement + // M-dim: TensorEAtom_M alignment + // K-dim: TensorEAtom_K alignment + int KAlignedE = compressor_utility.get_metadata_k_physical(); + int MAlignedE = compressor_utility.get_metadata_m_physical(); + + // TensorA Compressed + // In unit of ElementARaw, after alignment requirement + // M-dim: TMA alignment + // K-dim: TMA alignment + int KAlignedAC = compressor_utility.get_tensorA_k_physical(); + int MAlignedAC = compressor_utility.get_tensorA_m_physical(); + + stride_a_compressed = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, KAlignedAC, L)); + stride_e = cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(MAlignedE, KAlignedE, L)); + + auto a_coord = cutlass::make_Coord(M * L, K); + auto b_coord = cutlass::make_Coord(K, N * L); + auto e_coord = cutlass::make_Coord(MAlignedE * L, KAlignedE); + auto a_comp_coord = cutlass::make_Coord(MAlignedAC * L, KAlignedAC); + + tensor_A.resize(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A)); + tensor_A_Comp.resize(a_comp_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_comp_coord, stride_factor_A)); + tensor_B.resize(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B)); + tensor_E.resize(e_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(e_coord, stride_factor_E)); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2022)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2021)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = ElementA(1); + tensor_B.host_view().at({0, 0}) = ElementB(1); + + compressor_utility.structure_sparse_zero_mask_fill(tensor_A.host_data(), static_cast(seed + 2023)); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_E.sync_device(); + tensor_A_Comp.sync_device(); + + cutlass::Status status {cutlass::Status::kSuccess }; + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + typename Compressor::Arguments arguments{ + {M, N, K, L}, + {tensor_A.device_data(), + stride_a, + tensor_A_Comp.device_data(), + tensor_E.device_data()}, + {hw_info} + }; + + Compressor compressor_op; + size_t workspace_size = Compressor::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + status = compressor_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + return false; + } + + status = compressor_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + return false; + } + + status = compressor_op.run(); + + auto result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } + + layout_a = SparseConfig::fill_layoutA(problem_shape_MNKL); + layout_e = SparseConfig::fill_layoutE(problem_shape_MNKL); + + tensor_E.sync_host(); + tensor_A_Comp.sync_host(); + + using namespace cute; + auto k_blks = cutlass::ceil_div(K, size<1>(shape(SfAtom{}))); + auto m_blks = cutlass::ceil_div(M, Blk_MN{}); + auto n_blks = cutlass::ceil_div(N, Blk_MN{}); + layout_sfa = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(problem_shape_MNKL); + layout_sfb = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(problem_shape_MNKL); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto sfa_coord = cutlass::make_Coord(m_blks * Blk_MN{} * L, k_blks * Blk_SF{}); + auto sfb_coord = cutlass::make_Coord(n_blks * Blk_MN{} * L, k_blks * Blk_SF{}); + + tensor_SFA.resize(sfa_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfa_coord, stride_factor_A)); + tensor_SFB.resize(sfb_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfb_coord, stride_factor_B)); + + EXPECT_TRUE(initialize_tensor(tensor_SFA.host_view(), init_A, seed + 2024)); + EXPECT_TRUE(initialize_tensor(tensor_SFB.host_view(), init_B, seed + 2025)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_SFA.host_view().at({0, 0}) = ElementSF(1); + tensor_SFB.host_view().at({0, 0}) = ElementSF(1); + + tensor_SFA.sync_device(); + tensor_SFB.sync_device(); + + return true; + } + + Arguments to_args() { + using ArrayElementA = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB; + return { + reinterpret_cast(tensor_A_Comp.device_data()), layout_a, + reinterpret_cast(tensor_B.device_data()), stride_b, + tensor_E.device_data(), layout_e, + tensor_SFA.device_data(), layout_sfa, + tensor_SFB.device_data(), layout_sfb + }; + } + + auto to_host_args(ProblemShapeType problem_size) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + auto A = make_tensor(make_iterator(tensor_A.host_data()), + make_layout(make_shape(M, K, L), stride_a)); + auto SfA = make_tensor(tensor_SFA.host_data(), layout_sfa); + + auto B = make_tensor(make_iterator(tensor_B.host_data()), + make_layout(make_shape(N, K, L), stride_b)); + auto SfB = make_tensor(tensor_SFB.host_data(), layout_sfb); + + // return {A, SfA, B, SfB}; + cutlass::reference::host::GettMainloopParams + mainloop_params{A, SfA, B, SfB}; + return mainloop_params; + } + + void print_tensors(std::ofstream& file) { + file << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nSFA =\n" << tensor_SFA.host_view() + << "\nSFB =\n" << tensor_SFB.host_view(); + } + + bool compare_reference( + cute::Shape problem_shape_MNKL) { + auto [M, N, K, L] = problem_shape_MNKL; + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_SFA.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_SFB.host_view()), 0); + return true; + } +}; + +template< + class Gemm, + int SchedulerPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_>; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride(), + typename Base::LayoutTagE::Stride stride_factor_E_ = typename Base::LayoutTagE::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, + stride_factor_B_, + stride_factor_E_) {} +}; + +template< + class Gemm, + int SchedulerPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_>; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride(), + typename Base::LayoutTagE::Stride stride_factor_E_ = typename Base::LayoutTagE::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, + stride_factor_B_, + stride_factor_E_) {} +}; + +template +struct HostCollectiveDefaultEpilogue { + // fusion types are potentially void if the fusion is not supported + // helper so we don't try to construct HostTensor with void type + template + using non_void_t = cute::conditional_t, U, T>; + + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using kernel = typename Gemm::GemmKernel; + using Epilogue = typename kernel::CollectiveEpilogue; + + using ElementD = typename kernel::ElementD; + using StrideD = typename kernel::StrideD; + using ElementC = non_void_t; + using StrideC = typename kernel::StrideC; + + using FusionOp = typename Gemm::EpilogueOutputOp; + + static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + static_assert(is_row_or_col_major(), + "ERROR : C Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : D Layout is neither Row / Column Major)"); + + // Deduce Cutlass Layouts (RowMajor & ColumnMajor) + using LayoutTagC = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagD = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagScalar = cutlass::layout::PackedVectorLayout; // scalars are size-1 vectors + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + using ElementAccumulator = typename kernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename kernel::ProblemShape; + using ElementCompute = typename ElementComputeType::Type; + using ElementScalar = typename ElementScalarType::Type; + + using Arguments = typename Gemm::GemmKernel::EpilogueArguments; + + /// Initialization + StrideC stride_c; + StrideD stride_d; + + typename LayoutTagC::Stride stride_factor_C; + typename LayoutTagD::Stride stride_factor_D; + + cutlass::HostTensor tensor_C; + // Inputs + ElementScalar alpha; + ElementScalar beta; + + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + // Are scalars copied to device memory before kernel launch + ScalarLoc use_device_scalars = ScalarLoc::ON_HOST; + // If per-row scale is enabled and this is disabled, alpha/beta are passed as a host or device scalar instead of device vector + VectorScale vector_scale_mode = VectorScale::DISABLED; + + cutlass::Distribution::Kind init_C; + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + HostCollectiveDefaultEpilogue( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): init_C(init_C_), seed(seed_), + stride_factor_C(typename LayoutTagC::Stride()), + stride_factor_D(typename LayoutTagD::Stride()), + check_relative_equality(check_relative_equality_), + use_device_scalars(use_device_scalars_){ } + + bool initialize(ProblemShapeType problem_size, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveDefaultEpilogue::initialize(problem_size, alpha, beta)"); +#endif + // Initialize Epilogue tensors + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto c_coord = cutlass::make_Coord(M * L, N); + try { + tensor_C.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_C)); + tensor_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D)); + reference_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D), false); + } + catch (std::exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveDefaultEpilogue::initialize: resizing tensors threw an exception: " << e.what()); + throw; + } + catch (...) { + CUTLASS_TRACE_HOST("HostCollectiveDefaultEpilogue::initialize: resizing tensors threw an unknown exception"); + throw; + } + { + const bool init_succeeded = initialize_tensor(tensor_C.host_view(), init_C, seed + 2020); + if (not init_succeeded) { + CUTLASS_TRACE_HOST("HostCollectiveDefaultEpilogue::initialize: initialize_tensor returned false"); + } + EXPECT_TRUE(init_succeeded); + } + tensor_C.host_view().at({0, 0}) = ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + try { + tensor_C.sync_device(); + tensor_D.sync_device(); + } + catch (std::exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveDefaultEpilogue::initialize: sync_device() threw an exception: " << e.what()); + throw; + } + catch (...) { + CUTLASS_TRACE_HOST("HostCollectiveDefaultEpilogue::initialize: sync_device() threw an unknown exception"); + throw; + } + + alpha = alpha_; + beta = beta_; + + return true; + } + + template < + class Element, + class Layout + > + bool equality_check( + cutlass::TensorView const& lhs, + cutlass::TensorView const& rhs) const { + + // Factors used for calculating relative equality. CUTLASS's relative-equality + // checks in include/cutlass/relatively_equal.h are inspired by + // https://floating-point-gui.de/errors/comparison/. This reference suggests using + // the minimum normal value of a given type as the nonzero_floor. + Element epsilon(static_cast(0.1f)); + Element nonzero_floor(std::numeric_limits::min()); + + if constexpr (!cutlass::is_complex::value) { + if (check_relative_equality == CheckEquality::RELATIVE) { + return cutlass::reference::host::TensorRelativelyEquals( + lhs, rhs, epsilon, nonzero_floor); + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + + bool compare_reference( + cute::Shape problem_shape_MNKL, + ElementScalar alpha, + ElementScalar beta) { + auto [M, N, K, L] = problem_shape_MNKL; + + tensor_D.sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + if (tensor_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + } + + if (reference_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + } + + bool passed = equality_check(reference_D.host_view(), tensor_D.host_view()); + if(!passed) { + std::cout<<"D is incorrect"<(problem_size, 1); + auto M = cute::get<0>(problem_shape_MNKL); + auto N = cute::get<1>(problem_shape_MNKL); + auto K = cute::get<2>(problem_shape_MNKL); + auto L = cute::get<3>(problem_shape_MNKL); + auto coord_0 = cutlass::make_Coord(0); + auto C = cute::make_tensor(detail::make_iterator(tensor_C.host_data()), + cute::make_layout(cute::make_shape(M, N, L), stride_c)); + auto D = cute::make_tensor(detail::make_iterator(reference_D.host_data()), + cute::make_layout(cute::make_shape(M, N, L), stride_d)); + + cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementScalar, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D)> + epilogue_params{}; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = alpha; + epilogue_params.beta = beta; + + return epilogue_params; + } +}; + +template +struct HostCollectiveEpilogue { + // fusion types are potentially void if the fusion is not supported + // helper so we don't try to construct HostTensor with void type + template + using non_void_t = cute::conditional_t, U, T>; + + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using kernel = typename Gemm::GemmKernel; + using Epilogue = typename kernel::CollectiveEpilogue; + static_assert(IsDefaultEpilogue::value == false, "Default Epilogue is not supported"); + + using ElementD = typename kernel::ElementD; + using StrideD = typename kernel::StrideD; + using ElementC = non_void_t; + using StrideC = typename kernel::StrideC; + + static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + static_assert(is_row_or_col_major(), + "ERROR : C Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : D Layout is neither Row / Column Major)"); + + // Deduce Cutlass Layouts (RowMajor & ColumnMajor) + using LayoutTagC = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagD = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagScalar = cutlass::layout::PackedVectorLayout; // scalars are size-1 vectors + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + using ElementAccumulator = typename kernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename kernel::ProblemShape; + + // + // FusionOperation derived types/queries + // + static constexpr bool IsLegacy = detail::IsLegacyEpiloguePolicy::value; + + // FFMA2 SGEMM uses ThreadEpilogueOp for bias and relu support instead of FusionOp, so we compose LinCombPerRowBiasEltAct FusionOp by hand to test the functionality. + static constexpr bool IsFfma2Kernel = cute::is_same_v; + using FusionOp = cute::conditional_t, + typename Gemm::EpilogueOutputOp>; + static_assert(cute::is_base_of_v); + + + // Scale factor Generation related + using SfStrategy = cutlass::reference::host::SfStrategy; + static constexpr bool IsBlockScaleSupported = FusionOp::IsBlockScaleSupported; + static constexpr SfStrategy SfGenStrategy = (!IsBlockScaleSupported) ? SfStrategy::None : SfStrategy::SfDGen; + static constexpr int32_t SFD_VectorSize = IsBlockScaleSupported ? FusionOp::SFVecSize : 1; + static constexpr bool IsKMajorSFD = cute::is_same_v; + using ElementSFD = non_void_t; + using Sm1xxBlockScaledOutputConfig= cutlass::detail::Sm1xxBlockScaledOutputConfig; + using Blk_MN = typename Sm1xxBlockScaledOutputConfig::Blk_MN; + using Blk_SF = typename Sm1xxBlockScaledOutputConfig::Blk_SF; + using OutputSFAtom = typename Sm1xxBlockScaledOutputConfig::SfAtom; + cutlass::HostTensor tensor_SFD; + cutlass::HostTensor reference_SFD; + + using ElementCompute = typename FusionOp::ElementCompute; + using ElementScalar = typename FusionOp::ElementScalar; + using ElementBias = non_void_t; + using ElementAux = non_void_t; + using ElementAmax = non_void_t; + using LayoutTagAux = non_void_t; + using ActivationFunctor = non_void_t>; + + static constexpr bool IsRowBiasEnabled = FusionOp::IsPerRowBiasSupported; + static constexpr bool IsColBiasEnabled = FusionOp::IsPerColBiasSupported; + static_assert(not (IsColBiasEnabled && IsRowBiasEnabled)); + + static constexpr bool IsDeBiasEnabled = FusionOp::IsDePerRowBiasSupported; + static constexpr bool IsPerRowScaleEnabled = FusionOp::IsPerRowScaleSupported; + static constexpr bool IsPerColScaleEnabled = FusionOp::IsPerColScaleSupported; + static constexpr bool IsScaleFactorEnabled = FusionOp::IsScaleFactorSupported; + static constexpr bool IsAuxInEnabled = FusionOp::IsAuxInSupported; + static constexpr bool IsAuxOutEnabled = FusionOp::IsAuxOutSupported; + static constexpr bool IsAbsMaxEnabledD = FusionOp::IsAbsMaxSupported && + (cute::is_same_v || + cute::is_same_v); + static constexpr bool IsAbsMaxEnabledAux = IsAuxOutEnabled && FusionOp::IsAbsMaxSupported && + (cute::is_same_v || + cute::is_same_v); + using Arguments = typename Gemm::GemmKernel::EpilogueArguments; + + /// Initialization + StrideC stride_c; + StrideD stride_d; + + typename LayoutTagC::Stride stride_factor_C; + typename LayoutTagD::Stride stride_factor_D; + + // Inputs + cutlass::HostTensor alpha; + cutlass::HostTensor beta; + cutlass::HostTensor scale_A; + cutlass::HostTensor scale_B; + cutlass::HostTensor scale_C; + cutlass::HostTensor scale_D; + cutlass::HostTensor scale_Aux; + cutlass::HostTensor bias; + cutlass::HostTensor tensor_C; + cutlass::HostTensor norm_constant; + + // Outputs + cutlass::HostTensor abs_max_Aux; + cutlass::HostTensor abs_max_D; + cutlass::HostTensor tensor_Aux; + cutlass::gemm::TagToStrideC_t< LayoutTagAux > stride_Aux; + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + + // References + cutlass::HostTensor reference_dbias; + cutlass::HostTensor reference_Aux; + cutlass::HostTensor reference_abs_max_Aux; + cutlass::HostTensor reference_abs_max_D; + + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + // Are scalars copied to device memory before kernel launch + ScalarLoc use_device_scalars = ScalarLoc::ON_HOST; + // If vector scale is supported and this is disabled, alpha/beta are passed as a host or device scalar instead of device vector + VectorScale vector_scale_mode = VectorScale::DISABLED; + + // Random distribution with which to initialize the A/B/C/D/Aux scaling factors + cutlass::Distribution::Kind init_scale = cutlass::Distribution::Uniform; + // Random distribution with which to initialize the bias vector + cutlass::Distribution::Kind init_bias = cutlass::Distribution::Uniform; + cutlass::Distribution::Kind init_C; + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + HostCollectiveEpilogue( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): init_scale(init_scale_), init_bias(init_bias_), + init_C(init_C_), seed(seed_), + stride_factor_C(typename LayoutTagC::Stride()), + stride_factor_D(typename LayoutTagD::Stride()), + check_relative_equality(check_relative_equality_), + use_device_scalars(use_device_scalars_){ } + + bool initialize(ProblemShapeType problem_size, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize(problem_size, alpha, beta)"); +#endif + // Initialize Epilogue tensors + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + + stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto c_coord = cutlass::make_Coord(M * L, N); + try { + tensor_C.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_C)); + tensor_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D)); + reference_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D), false); + } + catch (std::exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize: resizing tensors threw an exception: " << e.what()); + throw; + } + catch (...) { + CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize: resizing tensors threw an unknown exception"); + throw; + } + + try { + const bool initialize_tensor_C_succeeded = + initialize_tensor(tensor_C.host_view(), init_C, seed + 2020); + if (not initialize_tensor_C_succeeded) { + CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize: initialize_tensor returned false"); + } + EXPECT_TRUE(initialize_tensor_C_succeeded); + } + catch (std::exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize: initialize_tensor threw an exception: " << e.what()); + throw; + } + catch (...) { + CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize: initialize_tensor threw an unknown exception"); + throw; + } + + tensor_C.host_view().at({0, 0}) = ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + try { + tensor_C.sync_device(); + tensor_D.sync_device(); + } + catch (std::exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize: sync_device() threw an exception: " << e.what()); + throw; + } + catch (...) { + CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize: sync_device() threw an unknown exception"); + throw; + } + + auto scalar_coord = cutlass::make_Coord(1); + auto col_vector_coord = cutlass::make_Coord(M); + auto row_vector_coord = cutlass::make_Coord(N); + auto batch_vector_coord = cutlass::make_Coord(L); + if constexpr (IsPerRowScaleEnabled or IsPerColScaleEnabled) { + // scalars + if (vector_scale_mode == VectorScale::DISABLED) { + // batched scalars + if (use_device_scalars == ScalarLoc::ON_DEVICE) { + alpha.resize(batch_vector_coord, true); + beta.resize(batch_vector_coord, true); + EXPECT_TRUE(initialize_tensor(alpha.host_view(), init_scale, seed + 2023)); + if (beta_ != ElementScalar(0)) { + EXPECT_TRUE(initialize_tensor(beta.host_view(), init_scale, seed + 2024)); + } + else { + cutlass::reference::host::TensorFill(beta.host_view(), beta_); + } + } + // non-batched scalars + else { + alpha.resize(scalar_coord, false); + beta.resize(scalar_coord, false); + cutlass::reference::host::TensorFill(alpha.host_view(), alpha_); + cutlass::reference::host::TensorFill(beta.host_view(), beta_); + } + } + // batched vectors + else { + auto batched_vector_coord = cutlass::make_Coord((IsPerRowScaleEnabled ? M : N) * L); + alpha.resize(batched_vector_coord, true); + beta.resize(batched_vector_coord, true); + EXPECT_TRUE(initialize_tensor(alpha.host_view(), init_scale, seed + 2023)); + if (beta_ != ElementScalar(0)) { + EXPECT_TRUE(initialize_tensor(beta.host_view(), init_scale, seed + 2024)); + } + else { + cutlass::reference::host::TensorFill(beta.host_view(), beta_); + } + } + } + else { + if (use_device_scalars == ScalarLoc::ON_DEVICE) { + // Set alpha beta for different batches. + alpha.resize(batch_vector_coord, true); + beta.resize(batch_vector_coord, true); + cutlass::reference::host::TensorFill(alpha.host_view(), alpha_); + for (int l = 0; l < L; ++l) { + beta.host_view().at(cutlass::make_Coord(l)) = beta_ + ElementScalar(l); + } + } + else { + alpha.resize(scalar_coord, false); + beta.resize(scalar_coord, false); + cutlass::reference::host::TensorFill(alpha.host_view(), alpha_); + cutlass::reference::host::TensorFill(beta.host_view(), beta_); + } + } + alpha.sync_device(); + beta.sync_device(); + + if constexpr (IsScaleFactorEnabled) { + scale_A.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + scale_B.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + scale_C.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + scale_D.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + EXPECT_TRUE(initialize_tensor(scale_A.host_view(), init_scale, seed + 2023)); + EXPECT_TRUE(initialize_tensor(scale_B.host_view(), init_scale, seed + 2024)); + EXPECT_TRUE(initialize_tensor(scale_C.host_view(), init_scale, seed + 2025)); + EXPECT_TRUE(initialize_tensor(scale_D.host_view(), init_scale, seed + 2026)); + scale_A.sync_device(); + scale_B.sync_device(); + scale_C.sync_device(); + scale_D.sync_device(); + } + + if constexpr (IsRowBiasEnabled or IsColBiasEnabled) { + bias.resize(IsRowBiasEnabled ? col_vector_coord : row_vector_coord); + EXPECT_TRUE(initialize_tensor(bias.host_view(), init_bias, seed + 2023)); + bias.sync_device(); + } + + if constexpr (IsDeBiasEnabled) { + bias.resize(col_vector_coord); + reference_dbias.resize(col_vector_coord); + cutlass::reference::host::TensorFill(bias.host_view(), ElementBias(0)); + cutlass::reference::host::TensorFill(reference_dbias.host_view(), ElementBias(0)); + bias.sync_device(); + } + + if constexpr (IsAbsMaxEnabledD) { + abs_max_D.resize(scalar_coord); + // ensure in-place device reductions perform their own initialization + cutlass::reference::host::TensorFill(abs_max_D.host_view(), + CUTLASS_STL_NAMESPACE::numeric_limits::max()); + abs_max_D.sync_device(); + reference_abs_max_D.resize(scalar_coord); + cutlass::reference::host::TensorFill(reference_abs_max_D.host_view(), ElementAmax(0)); + } + + if constexpr (IsAuxInEnabled) { + auto aux_coord = cutlass::make_Coord(M * L, N); + auto aux_layout = cutlass::layout::Affine2Layout_Factory::layout_factory(aux_coord, typename LayoutTagAux::Stride{}); + tensor_Aux.resize(aux_coord, aux_layout); + EXPECT_TRUE(initialize_tensor(tensor_Aux.host_view(), init_C, seed + 2023)); + tensor_Aux.sync_device(); + stride_Aux = cutlass::make_cute_packed_stride(cutlass::gemm::TagToStrideC_t{}, cute::make_shape(M, N, L)); + } + + if constexpr (IsAuxOutEnabled) { + auto aux_coord = cutlass::make_Coord(M * L, N); + auto aux_layout = cutlass::layout::Affine2Layout_Factory::layout_factory(aux_coord, typename LayoutTagAux::Stride{}); + tensor_Aux.resize(aux_coord, aux_layout); + reference_Aux.resize(aux_coord, aux_layout, false); + tensor_Aux.sync_device(); + stride_Aux = cutlass::make_cute_packed_stride(cutlass::gemm::TagToStrideC_t{}, cute::make_shape(M, N, L)); + + if constexpr (IsScaleFactorEnabled) { + scale_Aux.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + EXPECT_TRUE(initialize_tensor(scale_Aux.host_view(), init_scale, seed + 2027)); + scale_Aux.sync_device(); + } + + if constexpr (IsAbsMaxEnabledAux) { + abs_max_Aux.resize(scalar_coord); + // ensure in-place device reductions perform their own initialization + cutlass::reference::host::TensorFill(abs_max_Aux.host_view(), + CUTLASS_STL_NAMESPACE::numeric_limits::max()); + abs_max_Aux.sync_device(); + reference_abs_max_Aux.resize(scalar_coord); + cutlass::reference::host::TensorFill(reference_abs_max_Aux.host_view(), ElementAmax(0)); + } + } + + + if constexpr (IsBlockScaleSupported) { + auto m_blks = cutlass::ceil_div(M, cute::size<0>(cute::shape(OutputSFAtom{}))); + auto n_blks = cutlass::ceil_div(N, cute::size<1>(cute::shape(OutputSFAtom{}))); + auto sfd_coord = [&] () { + if constexpr (IsKMajorSFD) { + return cutlass::make_Coord(m_blks * Blk_MN{} * L, n_blks * Blk_SF{}); + } + else { + return cutlass::make_Coord(m_blks * Blk_SF{} * L, n_blks * Blk_MN{}); + } + }(); + tensor_SFD.resize(sfd_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfd_coord, stride_factor_D)); + reference_SFD.resize(sfd_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfd_coord, stride_factor_D), false); + tensor_SFD.sync_device(); + norm_constant.resize(scalar_coord, true); + EXPECT_TRUE(initialize_tensor(norm_constant.host_view(), init_scale, seed + 2023)); + norm_constant.sync_device(); + } + + + return true; + } + + template < + class Element, + class Layout + > + bool equality_check( + cutlass::TensorView const& lhs, + cutlass::TensorView const& rhs) const { + + // Factors used for calculating relative equality. CUTLASS's relative-equality + // checks in include/cutlass/relatively_equal.h are inspired by + // https://floating-point-gui.de/errors/comparison/. This reference suggests using + // the minimum normal value of a given type as the nonzero_floor. + Element epsilon(static_cast(0.1f)); + Element nonzero_floor(std::numeric_limits::min()); + + if constexpr (!cutlass::is_complex::value) { + if (check_relative_equality == CheckEquality::RELATIVE) { + return cutlass::reference::host::TensorRelativelyEquals( + lhs, rhs, epsilon, nonzero_floor); + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + + bool compare_reference( + cute::Shape problem_shape_MNKL, + ElementScalar alpha, + ElementScalar beta) { + tensor_D.sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + if (tensor_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + } + + if (reference_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + } + + bool passed = equality_check(reference_D.host_view(), tensor_D.host_view()); + if(!passed) { + #if 0 + auto [M, N, K, L] = problem_shape_MNKL; + auto ref = cute::make_tensor(detail::make_iterator(reference_D.host_data()), + cute::make_layout(cute::make_shape(M, N, L), stride_d)); + auto comp = cute::make_tensor(detail::make_iterator(tensor_D.host_data()), + cute::make_layout(cute::make_shape(M, N, L), stride_d)); + for(int i=0; i(ElementD(ref(i, j, l))) != static_cast((ElementD(comp(i, j, l))))) { + printf(" ref: %f comp: %f\n", i, j, l, static_cast(ElementD(ref(i, j, l))), static_cast((ElementD(comp(i, j, l))))); + } + } + } + } + #endif + std::cout<<"D is incorrect"<(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + Arguments arguments = + { + {}, + tensor_C.device_data(), stride_c, tensor_D.device_data(), stride_d + }; + + auto &fusion_args = arguments.thread; + if constexpr (IsLegacy) { + arguments.thread = { + alpha.at(coord_0), + beta.at(coord_0), + alpha.device_data(), + beta.device_data() + }; + arguments.ptr_Bias = bias.device_data(); + arguments.ptr_T = tensor_Aux.device_data(); + } + else { + fusion_args.alpha = alpha.at(coord_0); + fusion_args.alpha_ptr = alpha.device_data(); + // Only initializing beta/beta_ptr for non-void source + if constexpr (not cute::is_void_v) { + fusion_args.beta = beta.at(coord_0); + fusion_args.beta_ptr = beta.device_data(); // if vector_scale_mode is true this is nullptr + } + + if constexpr (IsPerRowScaleEnabled) { + int32_t m_stride = vector_scale_mode == VectorScale::ENABLED ? 1 : 0; + int64_t l_stride = vector_scale_mode == VectorScale::ENABLED ? M : (use_device_scalars == ScalarLoc::ON_DEVICE ? 1 : 0); + fusion_args.dAlpha = cute::make_stride(bool(m_stride),cute::_0{}, l_stride); + fusion_args.dBeta = cute::make_stride(bool(m_stride),cute::_0{}, l_stride); + } + else if constexpr (IsPerColScaleEnabled) { + int32_t n_stride = vector_scale_mode == VectorScale::ENABLED ? 1 : 0; + int64_t l_stride = vector_scale_mode == VectorScale::ENABLED ? N : (use_device_scalars == ScalarLoc::ON_DEVICE ? 1 : 0); + fusion_args.dAlpha = cute::make_stride(cute::_0{}, bool(n_stride), l_stride); + fusion_args.dBeta = cute::make_stride(cute::_0{}, bool(n_stride), l_stride); + } + else { + if constexpr (not IsFfma2Kernel) { + if (use_device_scalars == ScalarLoc::ON_DEVICE) { + if (L > 1) { + fusion_args.dAlpha = cute::make_stride(cute::_0{},cute::_0{}, int64_t(1)); + fusion_args.dBeta = cute::make_stride(cute::_0{},cute::_0{}, int64_t(1)); + } + } + } + } + + if constexpr (IsScaleFactorEnabled) { + fusion_args.scale_a = scale_A.at(coord_0); + fusion_args.scale_b = scale_B.at(coord_0); + fusion_args.scale_c = scale_C.at(coord_0); + fusion_args.scale_d = scale_D.at(coord_0); + fusion_args.scale_a_ptr = scale_A.device_data(); + fusion_args.scale_b_ptr = scale_B.device_data(); + fusion_args.scale_c_ptr = scale_C.device_data(); + fusion_args.scale_d_ptr = scale_D.device_data(); + } + + if constexpr (IsRowBiasEnabled or IsColBiasEnabled) { + fusion_args.bias_ptr = bias.device_data(); + } + + if constexpr (IsDeBiasEnabled) { + fusion_args.dbias_ptr = bias.device_data(); + } + + // example of how to set kernel activation arguments + // see ActivationFunctor::Arguments in activation.h for definition + // if Arguments doesn't exist then fusion_args.activation is empty + auto init_activation_args = [] (auto activation, auto& args) { + using Activation = cute::remove_cvref_t; + if constexpr (cute::is_same_v>) { + args.lower_bound = 0; // Treat Clamp as ReLU + args.upper_bound = cutlass::platform::identity_for_minimum(); + } + if constexpr (cute::is_same_v>) { + args.scale = ElementCompute(1); + } + }; + + if constexpr (not cute::is_same_v>) { + init_activation_args(ActivationFunctor{}, fusion_args.activation); + } + if constexpr (IsAbsMaxEnabledD) { + fusion_args.amax_D_ptr = abs_max_D.device_data(); + } + + if constexpr (IsAuxInEnabled) { + fusion_args.aux_ptr = tensor_Aux.device_data(); + fusion_args.dAux = stride_Aux; + } + + if constexpr (IsAuxOutEnabled) { + fusion_args.aux_ptr = tensor_Aux.device_data(); + fusion_args.dAux = stride_Aux; + if constexpr (IsScaleFactorEnabled) { + fusion_args.scale_aux = scale_Aux.at(coord_0); + fusion_args.scale_aux_ptr = scale_Aux.device_data(); + } + if constexpr (IsAbsMaxEnabledAux) { + fusion_args.amax_aux_ptr = abs_max_Aux.device_data(); + } + } + + + if constexpr (IsBlockScaleSupported) { + arguments.thread.block_scale_factor_ptr = tensor_SFD.device_data(); + arguments.thread.norm_constant_ptr = norm_constant.device_data(); + } + } + + return arguments; + } + + auto to_host_args(ProblemShapeType problem_size) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::get<0>(problem_shape_MNKL); + auto N = cute::get<1>(problem_shape_MNKL); + auto K = cute::get<2>(problem_shape_MNKL); + auto L = cute::get<3>(problem_shape_MNKL); + auto coord_0 = cutlass::make_Coord(0); + auto C = cute::make_tensor(detail::make_iterator(tensor_C.host_data()), + cute::make_layout(cute::make_shape(M, N, L), stride_c)); + auto D = cute::make_tensor(detail::make_iterator(reference_D.host_data()), + cute::make_layout(cute::make_shape(M, N, L), stride_d)); + auto Bias = cute::make_tensor(detail::make_iterator(IsDeBiasEnabled ? reference_dbias.host_data() : bias.host_data()), + cute::make_layout(cute::make_shape(IsRowBiasEnabled ? M : N))); + auto Aux = cute::make_tensor(detail::make_iterator(IsAuxInEnabled ? tensor_Aux.host_data() : reference_Aux.host_data()), + cute::make_layout(cute::make_shape(M, N, L), stride_Aux)); + auto Valpha = [&](){ + if constexpr (IsPerRowScaleEnabled) { + int m_stride = vector_scale_mode == VectorScale::ENABLED ? 1 : 0; + int l_stride = vector_scale_mode == VectorScale::ENABLED ? M : (use_device_scalars == ScalarLoc::ON_DEVICE ? 1 : 0); + return cute::make_tensor(detail::make_iterator(alpha.host_data()), + cute::make_layout(cute::make_shape(M, N, L), make_stride(m_stride, cute::_0{}, l_stride))); + } + else if constexpr (IsPerColScaleEnabled) { + int n_stride = vector_scale_mode == VectorScale::ENABLED ? 1 : 0; + int l_stride = vector_scale_mode == VectorScale::ENABLED ? N : (use_device_scalars == ScalarLoc::ON_DEVICE ? 1 : 0); + return cute::make_tensor(detail::make_iterator(alpha.host_data()), + cute::make_layout(cute::make_shape(M, N, L), make_stride(cute::_0{}, n_stride, l_stride))); + } + else { + return cute::make_tensor(detail::make_iterator(alpha.host_data()), + cute::make_layout(cute::make_shape(M, N, L), make_stride(cute::_0{}, cute::_0{}, cute::_1{}))); + } + }(); + + auto Vbeta = [&]() { + if constexpr (IsPerRowScaleEnabled) { + int m_stride = vector_scale_mode == VectorScale::ENABLED ? 1 : 0; + int l_stride = vector_scale_mode == VectorScale::ENABLED ? M : (use_device_scalars == ScalarLoc::ON_DEVICE ? 1 : 0); + return cute::make_tensor(detail::make_iterator(beta.host_data()), + cute::make_layout(cute::make_shape(M, N, L), make_stride(m_stride, cute::_0{}, l_stride))); + } + else if constexpr (IsPerColScaleEnabled) { + int n_stride = vector_scale_mode == VectorScale::ENABLED ? 1 : 0; + int l_stride = vector_scale_mode == VectorScale::ENABLED ? N : (use_device_scalars == ScalarLoc::ON_DEVICE ? 1 : 0); + return cute::make_tensor(detail::make_iterator(beta.host_data()), + cute::make_layout(cute::make_shape(M, N, L), make_stride(cute::_0{}, n_stride, l_stride))); + } + else { + return cute::make_tensor(detail::make_iterator(beta.host_data()), + cute::make_layout(cute::make_shape(M, N, L), make_stride(cute::_0{}, cute::_0{}, cute::_1{}))); + } + }(); + + auto SfD = [&](){ + if constexpr (IsBlockScaleSupported) { + auto tensor = make_tensor(detail::make_iterator(reference_SFD.host_data()), + Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(problem_shape_MNKL)); + return tensor; + } + else { + // Reference kernel has a logic to ignore scalefactor computation if we pass the tensor type same as output D tensor. + return D; + } + }(); + cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementScalar, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D), + decltype(Bias), + decltype(Aux), + decltype(Valpha), + decltype(Vbeta), + ActivationFunctor, + decltype(SfD), + Int, + cutlass::plus, + IsColBiasEnabled + , SfGenStrategy + > epilogue_params{}; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = alpha.at(coord_0); + epilogue_params.beta = beta.at(coord_0); + + if constexpr (IsScaleFactorEnabled) { + epilogue_params.scale_a = scale_A.at(coord_0); + epilogue_params.scale_b = scale_B.at(coord_0); + epilogue_params.scale_c = scale_C.at(coord_0); + epilogue_params.scale_d = scale_D.at(coord_0); + } + + if constexpr (IsRowBiasEnabled or IsColBiasEnabled or IsDeBiasEnabled) + { + epilogue_params.Bias = Bias; + } + + if constexpr (IsAbsMaxEnabledD) { + epilogue_params.abs_max_D = reference_abs_max_D.host_data(); + } + + if constexpr (IsAuxInEnabled) { + epilogue_params.Aux = Aux; + } + + if constexpr (IsAuxOutEnabled) { + epilogue_params.Aux = Aux; + if constexpr (IsScaleFactorEnabled) { + epilogue_params.scale_aux = scale_Aux.at(coord_0); + } + if constexpr (IsAbsMaxEnabledAux) { + epilogue_params.abs_max_Aux = reference_abs_max_Aux.host_data(); + } + } + + if constexpr (IsPerRowScaleEnabled or IsPerColScaleEnabled) { + epilogue_params.Valpha = Valpha; + if (vector_scale_mode == VectorScale::ENABLED) { + epilogue_params.Vbeta = Vbeta; + } + } + else { + if (use_device_scalars == ScalarLoc::ON_DEVICE) { + epilogue_params.Valpha = Valpha; + epilogue_params.Vbeta = Vbeta; + } + } + + if constexpr (IsBlockScaleSupported) { + epilogue_params.SfD = SfD; + epilogue_params.st = norm_constant.at(coord_0); + } + return epilogue_params; + } +}; + +template < + typename Gemm, + template class ActivationFunctor_ = cutlass::epilogue::thread::Identity, + bool force_legacy_epilogue = false, + typename ElementA = typename Gemm::GemmKernel::ElementA, + typename ElementB = typename Gemm::GemmKernel::ElementB + , typename RuntimeDatatypeA = void* + , typename RuntimeDatatypeB = void* +> +struct TestbedImpl { + // Kernel data types + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + // All Collective MMA operands are defined by HostCollectiveMainloopType based on the schedule type + using HostCollectiveMainloopType = HostCollectiveMainloop; + + using CollectiveEpilogue = cute::conditional_t::value || force_legacy_epilogue, + HostCollectiveDefaultEpilogue, + HostCollectiveEpilogue>; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementCompute = typename ElementComputeType::Type; + using ElementScalar = typename ElementScalarType::Type; + + using LayoutTagA = typename HostCollectiveMainloopType::LayoutTagA; + using LayoutTagB = typename HostCollectiveMainloopType::LayoutTagB; + using LayoutTagC = typename CollectiveEpilogue::LayoutTagC; + using LayoutTagD = typename CollectiveEpilogue::LayoutTagD; + + + using InternalElementA = typename Gemm::GemmKernel::ElementA; + using InternalElementB = typename Gemm::GemmKernel::ElementB; + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB in a GEMM kernel should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + + uint32_t sm_count; + // Used to force multi-wave tests for persistent kernel schedules + constexpr static int MaxSmCount = 16; + static constexpr uint64_t kDefaultSeed = 4096; + static constexpr uint32_t mma_promotion_interval = 4; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + + HostCollectiveMainloopType collective_mma_inputs; + CollectiveEpilogue collective_epilogue; + + // + // Methods + // + + TestbedImpl( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): collective_mma_inputs(HostCollectiveMainloopType(check_relative_equality_, init_A_, init_B_, seed_)), + collective_epilogue(CollectiveEpilogue(check_relative_equality_, use_device_scalars_, vector_scale_mode_, init_C_, init_scale_, init_bias_, seed_)) { } + + TestbedImpl( + typename LayoutTagA::Stride stride_factor_A_, + typename LayoutTagB::Stride stride_factor_B_, + typename LayoutTagC::Stride stride_factor_C_, + typename LayoutTagD::Stride stride_factor_D_, + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): collective_mma_inputs(HostCollectiveMainloopType(check_relative_equality_, stride_factor_A_, stride_factor_B_, init_A_, init_B_, seed_)), + collective_epilogue(CollectiveEpilogue(check_relative_equality_, use_device_scalars_, vector_scale_mode_, init_C_, init_scale_, init_bias_, seed_)) { } + + /// Initializes data structures + bool initialize(ProblemShapeType problem_size, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::initialize(problem_size, alpha, beta)"); +#endif + collective_mma_inputs.initialize(problem_size); + collective_epilogue.initialize(problem_size, alpha_, beta_); + + return true; + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cute::Shape problem_shape_MNKL, + ElementScalar alpha, + ElementScalar beta) + { + auto [M, N, K, L] = problem_shape_MNKL; + + bool passed = collective_mma_inputs.compare_reference(problem_shape_MNKL); + passed &= collective_epilogue.compare_reference(problem_shape_MNKL, alpha, beta); + EXPECT_TRUE(passed); + if (!passed) { + std::stringstream fname; + fname << "error_Gemm_device_" + << M << "x" << N << "x" << K << "x" << L << "_" + << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt"; + + std::ofstream file(fname.str()); + file + << "problem: " << ' ' << M << "x" << N << "x" << K << ", Batch count = " << L + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + collective_mma_inputs.print_tensors(file); + collective_epilogue.print_tensors(file); + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + ProblemShapeType problem_size, + ElementScalar alpha, + ElementScalar beta) + { + using namespace cute; + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto mainloop_params = collective_mma_inputs.to_host_args(problem_size); + auto epilogue_params = collective_epilogue.to_host_args(problem_size); + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + bool passed = compare_reference(problem_shape_MNKL, alpha, beta); + return passed; + } + + /// Determine if the CUDA device is sufficient to run the kernel + bool sufficient() { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = static_cast(Gemm::GemmKernel::SharedStorageSize); + + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + cudaDeviceProp properties; + result = cudaGetDeviceProperties(&properties, device_idx); + this->sm_count = properties.multiProcessorCount; + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + printf("failed due to smem_size\n"); + printf("hardware smem_size: %d, required smem_size: %d\n\n", int(properties.sharedMemPerBlockOptin), int(smem_size)); + return false; + } + + return true; + } + + bool profile( + ProblemShapeType problem_size, + int iterations, + Gemm& gemm_op, + typename Gemm::Arguments& arguments, + cutlass::device_memory::allocation& workspace) { + int M = cute::size<0>(problem_size); + int N = cute::size<1>(problem_size); + int K = cute::size<2>(problem_size); + int L = 1; + if constexpr(cute::rank(ProblemShapeType{}) == 4) { + L = cute::size<3>(problem_size); + } + + + cutlass::Status status; + // + // Run the GEMM + // + cudaError_t result; + + for (int iter = 0; iter < iterations; ++iter) { + status = gemm_op(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + return false; + } + } + + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } + + return true; + } + + /// Executes one test + bool run( + ProblemShapeType problem_size, + ElementScalar alpha = ElementScalar(1), + ElementScalar beta = ElementScalar(0), + bool profiling = false, + detail::Iterations iterations = detail::Iterations{}, + RasterOrderOptions raster_order = RasterOrderOptions::Heuristic, + detail::MaxSwizzleSize max_swizzle = detail::MaxSwizzleSize{}, + detail::Splits splits = detail::Splits{}, + DecompositionMode decomposition_mode = DecompositionMode::Heuristic + , RuntimeDatatypeA runtime_input_datatype_a = {} + , RuntimeDatatypeB runtime_input_datatype_b = {} + ) + { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run"); +#endif + + // Fail test if insufficient CUDA device + if (!sufficient()) { + CUTLASS_TRACE_HOST("TestbedImpl::run: Test failed due to insufficient CUDA device"); + std::cout << "Test failed due to insufficient CUDA device." << std::endl; + return false; + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + else { + CUTLASS_TRACE_HOST("TestbedImpl::run: sufficient() returned true"); + } +#endif + + try { + const bool initialized = this->initialize(problem_size, alpha, beta); + if (not initialized) { + CUTLASS_TRACE_HOST("TestbedImpl::run: this->initialize returned false"); + std::cerr << "Initialization failed \n"; + return false; + } + } + catch ([[maybe_unused]] std::exception const& e) { + CUTLASS_TRACE_HOST("TestbedImpl::run: this->initialize threw an exception: " << e.what()); + throw; + } + catch (...) { + CUTLASS_TRACE_HOST("TestbedImpl::run: this->initialize threw an unknown exception"); + throw; + } + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: this->initialize() returned true"); +#endif + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments; + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + if (not profiling) { + this->sm_count = std::min(MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); + hw_info.sm_count = this->sm_count; + } + else { + this->sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + hw_info.sm_count = this->sm_count; + } + + typename Gemm::GemmKernel::TileScheduler::Arguments scheduler_args; + if constexpr (cute::is_same_v) { + scheduler_args = { static_cast(splits), static_cast(max_swizzle), raster_order, decomposition_mode }; + } + else { + scheduler_args = { static_cast(max_swizzle), raster_order }; + } + typename HostCollectiveMainloopType::Arguments mainloop_args; + + mainloop_args = collective_mma_inputs.to_args(); + + + if constexpr (IsRuntimeDataType) { + mainloop_args.runtime_data_type_a = runtime_input_datatype_a; + mainloop_args.runtime_data_type_b = runtime_input_datatype_b; + } + + + arguments = + { + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + mainloop_args, + collective_epilogue.to_args(problem_size), + hw_info, + scheduler_args + }; + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Creating gemm_op"); +#endif + Gemm gemm_op; + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Calling Gemm::get_workspace_size"); +#endif + size_t workspace_size = Gemm::get_workspace_size(arguments); +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Allocating workspace of size " << workspace_size); +#endif + cutlass::device_memory::allocation workspace(workspace_size); + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Calling gemm_op.can_implement"); +#endif + cutlass::Status status = gemm_op.can_implement(arguments); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + const auto error_str = cudaGetErrorString(error); + CUTLASS_TRACE_HOST("TestbedImpl::run: cudaGetLastError() is " << error_str); + std::cerr << "This test is not supported: " << error_str << "\n"; + return true; + } + + // + // Run the GEMM + // + + if (profiling) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Calling profile"); +#endif + return profile(problem_size, static_cast(iterations), gemm_op, arguments, workspace); + } + else { + cudaError_t result; +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Calling gemm_op.initialize"); +#endif + status = gemm_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + const auto error_str = cudaGetErrorString(error); + CUTLASS_TRACE_HOST("TestbedImpl::run: cudaGetLastError() is " << error_str); + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Calling gemm_op.run"); +#endif + status = gemm_op.run(); + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + const auto error_str = cudaGetErrorString(error); + CUTLASS_TRACE_HOST("TestbedImpl::run: cudaGetLastError() is " << error_str); + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Calling cudaDeviceSynchronize"); +#endif + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST("TestbedImpl::run: cudaDeviceSynchronize reports non-success"); + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Calling this->verify"); +#endif + bool passed = this->verify(problem_size, alpha, beta); + if (!passed) { + CUTLASS_TRACE_HOST("TestbedImpl::run: this->verify FAILED"); + cudaError_t error = cudaGetLastError(); + const auto error_str = cudaGetErrorString(error); + CUTLASS_TRACE_HOST("TestbedImpl::run: cudaGetLastError() is " << error_str); + + std::cout << "Error : Failed : with alpha: " << alpha << ", beta: " << beta + << "\n"; + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + else { + CUTLASS_TRACE_HOST("TestbedImpl::run: this->verify passed"); + } +#endif + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Reached end"); +#endif + return passed; + } + } +}; + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Gemm, + template class ActivationFunctor = cutlass::epilogue::thread::Identity, + bool force_legacy_epilogue = false, + typename ElementA = typename Gemm::GemmKernel::ElementA, + typename ElementB = typename Gemm::GemmKernel::ElementB + , typename RuntimeDatatypeA = void* + , typename RuntimeDatatypeB = void* +> +struct Testbed3x { + + using TestBedImpl = typename detail::TestbedImpl< + Gemm, + ActivationFunctor, + force_legacy_epilogue, + ElementA, + ElementB + , RuntimeDatatypeA + , RuntimeDatatypeB + >; + using Kernel = typename Gemm::GemmKernel; + using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; + + using ElementAccumulator = typename TestBedImpl::ElementAccumulator; + using ElementCompute = typename TestBedImpl::ElementCompute; + using ElementScalar = typename TestBedImpl::ElementScalar; + + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + + // Detail Implementation + TestBedImpl impl_; + + // + // Methods + // + Testbed3x( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_DEVICE, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed) + : impl_(check_relative_equality_, use_device_scalars_, vector_scale_mode_, init_A_, init_B_, init_C_, init_scale_, init_bias_, seed_) {} + + /// Executes one test + bool run( + typename TestBedImpl::ProblemShapeType problem_size, + ElementScalar alpha = ElementScalar(1), + ElementScalar beta = ElementScalar(0), + RasterOrderOptions raster_order = RasterOrderOptions::Heuristic, + detail::MaxSwizzleSize max_swizzle = detail::MaxSwizzleSize{}, + detail::Splits splits = detail::Splits{}, + DecompositionMode decomposition_mode = DecompositionMode::Heuristic, + bool profiling = false, + detail::Iterations iterations = detail::Iterations{} + , RuntimeDatatypeA runtime_input_datatype_a = {} + , RuntimeDatatypeB runtime_input_datatype_b = {} + ) + { + return impl_.run( + problem_size, alpha, beta, profiling, iterations, raster_order, max_swizzle, splits, decomposition_mode + , runtime_input_datatype_a, runtime_input_datatype_b + ); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestGemmPerf3x(int iterations = 20) { + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementScalar = ElementAccumulator; + bool passed = true; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + + std::vector problem_size_m = { 4608 }; + std::vector problem_size_n = { 4608 }; + std::vector problem_size_k = { 8192 }; + + Testbed3x testbed; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + ProblemShapeType problem_size; + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + problem_size = ProblemShapeType{m, n, k, /* l */ 1}; + } + else { + problem_size = ProblemShapeType{m, n, k}; + } + + passed = testbed.run( + problem_size, + cutlass::from_real(1), + cutlass::from_real(0), + RasterOrderOptions{}, detail::MaxSwizzleSize(1), detail::Splits{1}, DecompositionMode{}, + true, // profiling + detail::Iterations{iterations}); + + if (!passed) { + return false; + } + } + } + } + + return true; +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +template < + typename Gemm, + typename RuntimeDataTypeA, + typename RuntimeDataTypeB, + bool force_legacy_epilogue = false> +bool TestRuntimeDataTypeSmall( + RuntimeDataTypeA runtime_input_datatype_a, + RuntimeDataTypeB runtime_input_datatype_b, + double alpha = 1.0, double beta = cute::is_same_v ? 0.0 : 1.0, + CheckEquality check_relative_equality = CheckEquality::RELATIVE, ScalarLoc use_device_scalars = ScalarLoc::ON_DEVICE, VectorScale vector_scale_mode = VectorScale::ENABLED, std::vector override_problem_size_k = {}) { + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar; + using CtaShape_MNK = typename Gemm::GemmKernel::CollectiveMainloop::CtaShape_MNK; + using DispatchPolicy = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy; + + using InternalElementA = typename Gemm::GemmKernel::ElementA; + using InternalElementB = typename Gemm::GemmKernel::ElementB; + + CtaShape_MNK cta_shape; + static constexpr int SmCount = 16; + static constexpr int MultiplierOffsetM = 1; + static constexpr int MultiplierOffsetN = 2; + static constexpr int MultiplierOffsetK = 3; + int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + + float waves[] = {0.5, 1.25, 2.5}; + int cluster_m = 1; + int cluster_n = 1; + + std::vector problem_size_k; + if (override_problem_size_k.empty()) { + problem_size_k = {256 + max_alignment * MultiplierOffsetK, 512 + max_alignment * MultiplierOffsetK}; + } + else { + problem_size_k = override_problem_size_k; + } + + if constexpr(DispatchPolicy::ArchTag::kMinComputeCapability >= 90) { + typename DispatchPolicy::ClusterShape cluster_shape; + cluster_m = cute::size<0>(cluster_shape); + cluster_n = cute::size<1>(cluster_shape); + } + + [[maybe_unused]] constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{}); + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + + std::vector decomposition_modes = {DecompositionMode::Heuristic}; + static constexpr bool UsesStreamKScheduler = cute::is_same_v; + if constexpr (UsesStreamKScheduler) { + decomposition_modes.push_back(DecompositionMode::DataParallel); + decomposition_modes.push_back(DecompositionMode::SplitK); + decomposition_modes.push_back(DecompositionMode::StreamK); + } + bool passed = true; + + for (float wave : waves) { + for (int k : problem_size_k) { + int grid_m, grid_n = 0; + int num_grid = int(wave * SmCount); + + if (cluster_m >= cluster_n) { + grid_m = cluster_m; + grid_n = num_grid / grid_m; + // Align grid_n to cluster_n + grid_n = std::max((grid_n + cluster_n - 1 ) / cluster_n * cluster_n, 1); + } + else { + grid_n = cluster_n; + grid_m = num_grid / grid_n; + // Align grid_m to cluster_m + grid_m = std::max((grid_m + cluster_m - 1 ) / cluster_m * cluster_m, 1); + } + + int m = grid_m * cute::size<0>(cta_shape) + MultiplierOffsetM * max_alignment; + int n = grid_n * cute::size<1>(cta_shape) + MultiplierOffsetN * max_alignment; + + ProblemShapeType problem_size; + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + problem_size = ProblemShapeType{m, n, k, /* l */ 1}; + } + else { + problem_size = ProblemShapeType{m, n, k}; + } + + for (DecompositionMode decomp_mode : decomposition_modes) { + std::vector problem_splits = {detail::Splits{1}}; + if (decomp_mode == DecompositionMode::Heuristic || decomp_mode == DecompositionMode::SplitK) { + problem_splits.push_back(detail::Splits{2}); + } + for (auto splits : problem_splits) { + + if constexpr (cute::is_same_v && + cute::is_same_v) { + // e2m1_e2m1 + if (runtime_input_datatype_a == cute::UMMA::MXF4Format::E2M1 && + runtime_input_datatype_b == cute::UMMA::MXF4Format::E2M1) { + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + else { + std::cout << "Unsupported configuration for runtime datatype MXFP4." << std::endl; + return false; + } + } + + else + if constexpr (cute::is_same_v && + cute::is_same_v) { + static_assert((cute::is_same_v || + cute::is_same_v || + cute::is_same_v) && + (cute::is_same_v || + cute::is_same_v || + cute::is_same_v), + "Runtime datatype must be selected with an appropriate static umbrella data type."); + if constexpr (cute::is_same_v && + cute::is_same_v) { + // e4m3_e2m1 + if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E4M3 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E2M1) { + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // Unsupport + else { + std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." << std::endl; + return false; + } + } + // f6xf4 + else if constexpr (cute::is_same_v && + cute::is_same_v) { + // e3m2_e2m1 + if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E3M2 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E2M1) { + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // Unsupport + else { + std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." << std::endl; + return false; + } + } + else if constexpr (cute::is_same_v && + cute::is_same_v) { + // e2m1_e2m1 + if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E2M1 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E2M1) { + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // Unsupport + else { + std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." << std::endl; + return false; + } + } + else if constexpr (cute::is_same_v && + cute::is_same_v) { + // e4m3_e3m2 + if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E4M3 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E3M2) { + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // Unsupport + else { + std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." << std::endl; + return false; + } + } + else if constexpr (cute::is_same_v && + cute::is_same_v) { + // e3m2_e2m3 + if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E3M2 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E2M3) { + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // Unsupported + else { + std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." << std::endl; + return false; + } + } + else + if constexpr (cute::is_same_v && + cute::is_same_v) { + // e5m2_e5m2 + if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E5M2 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E5M2) { + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // e4m3_e5m2 + else if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E4M3 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E5M2){ + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // e5m2_e4m3 + else if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E5M2 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E4M3){ + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // e4m3_e4m3 + else if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E4M3 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E4M3){ + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // Unsupported + else { + std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." << std::endl; + return false; + } + } + // Unsupported + else { + std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." << std::endl; + return false; + } + } + + else { + static_assert(cutlass::detail::dependent_false, + "Unsupported configuration for runtime datatype."); + } + + if (!passed) { + std::cout << __FILE__ << ':' << __LINE__ << " : GEMM MNK " << m << " " << n << " " << k << " FAILED.\n"; + return false; + } + } // splits + } // decomposition_mode + } // k + } // waves + + return passed; +} + +template +bool TestSmall(double alpha = 1.0, double beta = cute::is_same_v ? 0.0 : 1.0, + CheckEquality check_relative_equality = CheckEquality::RELATIVE, + ScalarLoc use_device_scalars = ScalarLoc::ON_DEVICE, + VectorScale vector_scale_mode = VectorScale::ENABLED, + std::vector override_problem_size_k = {}) { + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar; + using CtaShape_MNK = typename Gemm::GemmKernel::CollectiveMainloop::CtaShape_MNK; + using DispatchPolicy = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy; + CtaShape_MNK cta_shape; + Testbed3x testbed(check_relative_equality, use_device_scalars, vector_scale_mode); + static constexpr int SmCount = 16; + static constexpr int MultiplierOffsetM = 1; + static constexpr int MultiplierOffsetN = 2; + static constexpr int MultiplierOffsetK = 3; + int max_alignment_k = 0; + int max_alignment_m = 0; + int max_alignment_n = 0; + + if constexpr (apply_alignment_offset) { + max_alignment_k = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + max_alignment_n = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + max_alignment_m = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + } + // Alignment for SFD + if constexpr (detail::IsSfdEpi::value) { + using GmemLayoutTagScalefactor = typename Gemm::GemmKernel::CollectiveEpilogue::FusionCallbacks::Operation::GmemLayoutTagScalefactor; + constexpr int SFDVecSize = Gemm::GemmKernel::CollectiveEpilogue::FusionCallbacks::Operation::SFVecSize; + if constexpr (cute::is_same_v) { + max_alignment_n = std::lcm(max_alignment_n, SFDVecSize); + } + else { + max_alignment_m = std::lcm(max_alignment_m, SFDVecSize); + } + } + + float waves[] = {0.5, 1.25, 2.5}; + int cluster_m = 1; + int cluster_n = 1; + + std::vector problem_size_k; + if (override_problem_size_k.empty()) { + problem_size_k = {256 + max_alignment_k * MultiplierOffsetK, 512 + max_alignment_k * MultiplierOffsetK}; + } + else { + problem_size_k = override_problem_size_k; + } + + if constexpr(DispatchPolicy::ArchTag::kMinComputeCapability >= 90) { + typename DispatchPolicy::ClusterShape cluster_shape; + cluster_m = cute::size<0>(cluster_shape); + cluster_n = cute::size<1>(cluster_shape); + } + + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + + std::vector decomposition_modes = {DecompositionMode::Heuristic}; + static constexpr bool UsesStreamKScheduler = cute::is_same_v; + if constexpr (UsesStreamKScheduler) { + decomposition_modes.push_back(DecompositionMode::DataParallel); + decomposition_modes.push_back(DecompositionMode::SplitK); + decomposition_modes.push_back(DecompositionMode::StreamK); + } + bool passed = true; + + std::vector raster_order_options = {RasterOrderOptions::Heuristic}; + for (float wave : waves) { + for (int k : problem_size_k) { + int grid_m, grid_n = 0; + int num_grid = int(wave * SmCount); + + if (cluster_m >= cluster_n) { + grid_m = cluster_m; + grid_n = num_grid / grid_m; + // Align grid_n to cluster_n + grid_n = std::max((grid_n + cluster_n - 1 ) / cluster_n * cluster_n, 1); + } + else { + grid_n = cluster_n; + grid_m = num_grid / grid_n; + // Align grid_m to cluster_m + grid_m = std::max((grid_m + cluster_m - 1 ) / cluster_m * cluster_m, 1); + } + + int m = grid_m * cute::size<0>(cta_shape) + MultiplierOffsetM * max_alignment_m; + int n = grid_n * cute::size<1>(cta_shape) + MultiplierOffsetN * max_alignment_n; + int l = test_batched_alpha_beta && wave == waves[0] && k == problem_size_k[0] ? 2 : 1; // only test the smallest problem size + ProblemShapeType problem_size; + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + problem_size = ProblemShapeType{m, n, k, l}; + } + else { + problem_size = ProblemShapeType{m, n, k}; + } + + for (DecompositionMode decomp_mode : decomposition_modes) { + for (RasterOrderOptions raster_order : raster_order_options) { + std::vector problem_splits = {detail::Splits{1}}; + if constexpr (UsesStreamKScheduler) { + if (decomp_mode == DecompositionMode::SplitK) { + problem_splits.push_back(detail::Splits{2}); + problem_splits.push_back(detail::Splits{4}); + } + } + for (auto splits : problem_splits) { + try { + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + raster_order, // raster_order + detail::MaxSwizzleSize(0), + splits, + decomp_mode + ); + } + catch (std::exception const& e) { + EXPECT_TRUE(false) << "TestSmall: testbed.run {" + << "m: " << m << ", n: " << n << ", k: " << k << ", l: " << l + << ", alpha: " << alpha << ", beta: " << beta + << ", raster_order: " << detail::raster_order_to_string(raster_order) + << ", max_swizzle_size: 1" + << ", splits: " << static_cast(splits) + << ", decomp_mode: " << detail::decomp_mode_to_string(decomp_mode) + << "} threw an exception: " << e.what(); + throw; + } + catch (...) { + EXPECT_TRUE(false) << "TestSmall: testbed.run {" + << "m: " << m << ", n: " << n << ", k: " << k << ", l: " << l + << ", alpha: " << alpha << ", beta: " << beta + << ", raster_order: " << detail::raster_order_to_string(raster_order) + << ", max_swizzle_size: 1" + << ", splits: " << static_cast(splits) + << ", decomp_mode: " << detail::decomp_mode_to_string(decomp_mode) + << "} threw an exception (unknown)"; + throw; + } + EXPECT_TRUE(passed) << "TestSmall: testbed.run {" + << "m: " << m << ", n: " << n << ", k: " << k << ", l: " << l + << ", alpha: " << alpha << ", beta: " << beta + << ", raster_order: " << detail::raster_order_to_string(raster_order) + << ", max_swizzle_size: 1" + << ", splits: " << static_cast(splits) + << ", decomp_mode: " << detail::decomp_mode_to_string(decomp_mode) + << "} failed"; + + if (!passed) { + std::cout << __FILE__ << ':' << __LINE__ << " : GEMM MNKL " << m << " " << n << " " << k << " " << l << " FAILED.\n"; + return false; + } + } // splits + } // raster_order + } // decomposition_mode + } // k + } // waves + + return passed; +} + +template +bool TestSmallFusion(double alpha = 1.0, double beta = cute::is_same_v ? 0.0 : 1.0, + CheckEquality check_relative_equality = CheckEquality::RELATIVE, + ScalarLoc use_device_scalars = ScalarLoc::ON_DEVICE, + VectorScale vector_scale_mode = VectorScale::ENABLED, + std::vector override_problem_size_k = {}) { + return TestSmall(alpha, + beta, + check_relative_equality, + use_device_scalars, + vector_scale_mode, + override_problem_size_k); +} + + + +template < + typename Gemm, + template class ActivationFunctor = cutlass::epilogue::thread::Identity +> +bool TestAll(double alpha = 1.0, double beta = cute::is_same_v ? 0.0 : 1.0, CheckEquality check_relative_equality = CheckEquality::RELATIVE) { + using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + Testbed3x testbed(check_relative_equality, ScalarLoc::ON_HOST, VectorScale::DISABLED); + + int max_alignment_m = std::max({Gemm::kAlignmentA, Gemm::kAlignmentC, Gemm::kAlignmentD}); + int max_alignment_n = std::max({Gemm::kAlignmentB, Gemm::kAlignmentC, Gemm::kAlignmentD}); + if constexpr (std::is_base_of_v) { + max_alignment_m = std::max(max_alignment_m, Gemm::EpilogueOutputOp::AlignmentAux); + max_alignment_n = std::max(max_alignment_n, Gemm::EpilogueOutputOp::AlignmentAux); + } + std::vector problem_size_m = {max_alignment_m, 512 - 3 * max_alignment_m}; + std::vector problem_size_n = {max_alignment_n, 512 - 2 * max_alignment_n}; + + if constexpr (cute::is_same_v) { + problem_size_m.push_back(768); + problem_size_n.push_back(768); + } + + constexpr int Stages = Gemm::GemmKernel::DispatchPolicy::Stages; + constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{}); + + int max_alignment_k = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + std::vector problem_size_k = {max_alignment_k, TileShapeK * (Stages + 1) - max_alignment_k}; + + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + std::vector decomposition_modes = {DecompositionMode::Heuristic}; + std::vector problem_splits = {detail::Splits{1}}; + static constexpr bool UsesStreamKScheduler = cute::is_same_v; + if constexpr (UsesStreamKScheduler) { + problem_splits.push_back(detail::Splits{2}); + problem_splits.push_back(detail::Splits{3}); + + decomposition_modes.push_back(DecompositionMode::DataParallel); + decomposition_modes.push_back(DecompositionMode::SplitK); + decomposition_modes.push_back(DecompositionMode::StreamK); + + // Use larger K sizes for stream-K tests + static constexpr int min_tiles_per_sk_unit = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::min_iters_per_sk_unit_; + problem_size_k = {TileShapeK * min_tiles_per_sk_unit, TileShapeK * 3 * min_tiles_per_sk_unit - max_alignment_k}; + } + + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + std::vector raster_orders = {RasterOrderOptions::AlongM, RasterOrderOptions::AlongN}; + std::vector max_swizzle_sizes{detail::MaxSwizzleSize{1}, detail::MaxSwizzleSize{4}}; + + bool passed = true; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (auto raster_order : raster_orders) { + for (auto max_swizzle_size : max_swizzle_sizes) { + for (DecompositionMode decomp_mode : decomposition_modes) { + + std::vector problem_splits = {detail::Splits{1}}; + if (decomp_mode == DecompositionMode::Heuristic || decomp_mode == DecompositionMode::SplitK) { + auto max_splits = (k + TileShapeK - 1) / TileShapeK; + if (max_splits > 2) { + problem_splits.push_back(detail::Splits{2}); + } + if (max_splits > 3) { + problem_splits.push_back(detail::Splits{3}); + } + + problem_splits.push_back(detail::Splits{max_splits}); + + // Test the case in which we ask for more splits than there are K tiles in the GEMM. In this + // case, split-K will fall back to a splitting factor of `max_splits`. + problem_splits.push_back(detail::Splits{max_splits + 1}); + } + for (auto splits : problem_splits) { + ProblemShapeType problem_size; + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + problem_size = ProblemShapeType{m, n, k, /* l */ 1}; + } + else { + problem_size = ProblemShapeType{m, n, k}; + } + + try { + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + raster_order, + max_swizzle_size, + splits, + decomp_mode + ); + } + catch (std::exception const& e) { + EXPECT_TRUE(false) << "TestAll: testbed.run {" + << "m: " << m << ", n: " << n << ", k: " << k + << ", alpha: " << alpha << ", beta: " << beta + << ", raster_order: ???" + << ", max_swizzle_size: " << static_cast(max_swizzle_size) + << ", splits: " << static_cast(splits) + << ", decomp_mode: " << detail::decomp_mode_to_string(decomp_mode) + << "} threw an exception: " << e.what(); + throw; + } + catch (...) { + EXPECT_TRUE(false) << "TestAll: testbed.run {" + << "m: " << m << ", n: " << n << ", k: " << k + << ", alpha: " << alpha << ", beta: " << beta + << ", raster_order: ???" + << ", max_swizzle_size: " << static_cast(max_swizzle_size) + << ", splits: " << static_cast(splits) + << ", decomp_mode: " << detail::decomp_mode_to_string(decomp_mode) + << "} threw an exception (unknown)"; + throw; + } + + EXPECT_TRUE(passed) << "TestAll: testbed.run {" + << "m: " << m << ", n: " << n << ", k: " << k + << ", alpha: " << alpha << ", beta: " << beta + << ", raster_order: ???" + << ", max_swizzle_size: " << static_cast(max_swizzle_size) + << ", splits: " << static_cast(splits) + << ", decomp_mode: " << detail::decomp_mode_to_string(decomp_mode) + << "} failed"; + + if (!passed) { + std::cout << __FILE__ << ':' << __LINE__ << " : GEMM MNK " << m << " " << n << " " << k << " FAILED.\n"; + return false; + } + } // splits + } // decomposition_mode + } // max_swizzle_size + } // raster_order + } // k + } // n + } // m + + // if we do support batched GEMM, just run one test on it to save on test time + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + auto problem_size = ProblemShapeType{256 + max_alignment_m, 256 + max_alignment_n, 160 + max_alignment_k, /* l */ 3}; + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + if (!passed) { + return false; + } + } + + return passed; +} + +template +bool TestAllBiasElementwise(double alpha = 1.0, double beta = cute::is_same_v ? 0.0 : 1.0, CheckEquality check_relative_equality = CheckEquality::EXACT) { + return TestAll(alpha, beta, check_relative_equality); +} + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_evt.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_evt.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f18a7b39cbfe7dfb8d3251b2750e49261522de8a --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_evt.hpp @@ -0,0 +1,1742 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Testbed and host reference for EVT unittest +*/ + + +#pragma once +#include "gemm_testbed_3x.hpp" + +namespace test { +namespace gemm { +namespace device { + +/// Host-side tapply, tapply in cute is HOST_DEVICE +template +constexpr auto +tapply(T&& t, F&& f, G&& g, cute::seq) +{ + return g(f(std::get(static_cast(t)))...); +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT: Base class for EVT Node + +template < class ElementCompute_ > +class HostEVTNodeBase { +public: + using ElementCompute = ElementCompute_; + +private: + bool check_relative_equality_; + // Factors used for calculating relative equality. These default + // values are borrowed from those used by default in the CUTLASS + // profiler for performing relative equality checks. + float epsilon_ = 0.05f; + float nonzero_floor_ = 1.0f / 256.0f; + +public: + HostEVTNodeBase(){} + HostEVTNodeBase(bool check_relative_equality): + check_relative_equality_(check_relative_equality) { } + + + template < + class Element, + class Layout + > + bool equality_check( + cutlass::TensorView const& lhs, + cutlass::TensorView const& rhs) const { + if (check_relative_equality_) { + return cutlass::reference::host::TensorRelativelyEquals( + lhs, rhs, Element(epsilon_), Element(nonzero_floor_) + ); + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + + void* get_tensor_C_ptr() { + return nullptr; + } + + void* get_tensor_D_ptr() { + return nullptr; + } + + bool compare_reference(std::stringstream& error_ss) { + return true; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Accumulator + +template< class ElementCompute = float > +class HostAccumulator: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + + struct Arguments { }; + +public: + HostAccumulator(){} + template + HostAccumulator(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) + :Base(check_relative_equality) {} + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + cutlass::NumericConverter accumulator_converter; + return accumulator_converter(acc); + } + + Arguments get_arguments() { + return Arguments{}; + } + + auto get_flatten_arguments() { + return cute::make_tuple(); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Scalar Broadcast + +template < + int Value, + int BroadcastCount = 1, + class StrideMNL = cute::Stride, + template class ReductionFn = cutlass::multiplies, + class ElementCompute = float +> +class HostScalarBroadcast : public HostEVTNodeBase { +public: + + using Base = HostEVTNodeBase; + struct Arguments { + ElementCompute scalar[BroadcastCount] = {0}; + ElementCompute const* scalar_ptrs[BroadcastCount] = { nullptr }; + StrideMNL dScalar[BroadcastCount] = {}; + }; +private: + ElementCompute scalar_{}; + StrideMNL dScalar{}; + ElementCompute scalar_reduced_{}; +public: + HostScalarBroadcast(){} + + template + HostScalarBroadcast(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) + : Base(check_relative_equality), scalar_(ElementCompute(Value)) { + scalar_ = ElementCompute(Value); + scalar_reduced_ = scalar_; + for (int i = 1; i < BroadcastCount; ++i) { + scalar_reduced_ = ReductionFn{}(scalar_reduced_, ElementCompute(Value)); + } + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + + return scalar_reduced_; + } + + bool compare_reference(std::stringstream& error_ss) { + error_ss << "Scalar: " << float(scalar_) << "\n\n"; + return true; + } + + Arguments get_arguments() { + if constexpr (BroadcastCount == 1) + return Arguments{{scalar_}, {nullptr}, {dScalar}}; + else if constexpr (BroadcastCount == 2) + return Arguments{{scalar_, scalar_}, {nullptr, nullptr}, {dScalar, dScalar}}; + else if constexpr (BroadcastCount == 3) + return Arguments{{scalar_, scalar_, scalar_}, {nullptr, nullptr, nullptr}, {dScalar, dScalar, dScalar}}; + else + return Arguments{{scalar_}, {nullptr}, {dScalar}}; + } + + auto get_flatten_arguments() { + if constexpr (BroadcastCount == 1) { + return cute::make_tuple(scalar_, nullptr); + } + else if constexpr (BroadcastCount == 2) { + return cute::make_tuple(scalar_, scalar_, nullptr, nullptr); + } + else if constexpr (BroadcastCount == 3) { + return cute::make_tuple(scalar_, scalar_, scalar_, nullptr, nullptr, nullptr); + } + else { + return cute::make_tuple(scalar_, nullptr); + } + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Row Broadcast +template < + typename ElementBias_, + typename StrideMNL = cute::Stride, + typename ElementCompute = float +> +class HostRowBroadcast: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + using ElementBias = ElementBias_; + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + struct Arguments { + ElementBias const* ptr_row = nullptr; + ElementBias null_default = ElementBias(0); + StrideMNL dRow = {}; + }; +private: + cutlass::NumericConverter bias_converter_; + cutlass::HostTensor bias_; + int N_; +public: + HostRowBroadcast(){} + template + HostRowBroadcast(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) + : Base(check_relative_equality) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + N_ = cute::get<1>(problem_shape_MNKL); + bias_.resize(cutlass::Coord<1>(N_)); + + EXPECT_TRUE( + detail::initialize_tensor( + bias_.host_view(), cutlass::Distribution::Uniform, + seed + ) + ); + bias_.sync_device(); + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + auto TensorBias = cute::make_tensor(bias_.host_data(), + cute::make_layout(cute::make_shape(cute::_1{}, N_))); + + return bias_converter_(TensorBias(1, n + n_b)); + } + + bool compare_reference(std::stringstream& error_ss) { + error_ss + << "PerColumnBias = \n" << bias_.host_view() << "\n\n"; + return true; + } + + Arguments get_arguments() { + return {bias_.device_data()}; + } + + auto get_flatten_arguments() { + return cute::make_tuple(bias_.device_data(), ElementBias(0), StrideMNL{}); + } + +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Column Broadcast +template < + typename ElementBias_, + typename StrideMNL = cute::Stride, + typename ElementCompute = float +> +class HostColBroadcast: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + using ElementBias = ElementBias_; + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + struct Arguments { + ElementBias const* ptr_row = nullptr; + ElementBias null_default = ElementBias(0); + StrideMNL dRow = {}; + }; +private: + cutlass::NumericConverter bias_converter_; + cutlass::HostTensor bias_; + int M_; +public: + HostColBroadcast(){} + template + HostColBroadcast(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) + : Base(check_relative_equality) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + M_ = cute::get<0>(problem_shape_MNKL); + bias_.resize(cutlass::Coord<1>(M_)); + + EXPECT_TRUE( + detail::initialize_tensor( + bias_.host_view(), cutlass::Distribution::Uniform, + seed + ) + ); + bias_.sync_device(); + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + auto TensorBias = cute::make_tensor(bias_.host_data(), + cute::make_layout(cute::make_shape(M_, cute::_1{}))); + + return bias_converter_(TensorBias(m + m_b, 1)); + } + + bool compare_reference(std::stringstream& error_ss) { + error_ss + << "PerRowBias = \n" << bias_.host_view() << "\n\n"; + return true; + } + + Arguments get_arguments() { + return {bias_.device_data()}; + } + + auto get_flatten_arguments() { + return cute::make_tuple(bias_.device_data(), ElementBias(0), StrideMNL{}); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Aux Load + +template < + typename ElementAuxLoad_, + typename LayoutTagAux_, + bool isC = false, + typename ElementCompute = float +> +class HostAuxLoad: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + using ElementAuxLoad = ElementAuxLoad_; + using LayoutTagAux = LayoutTagAux_; + + using StrideAux = cutlass::gemm::TagToStrideC_t; + struct Arguments_Aux { + ElementAuxLoad const *ptr_aux = nullptr; + ElementAuxLoad null_default = ElementAuxLoad(0); + StrideAux dAux = {}; + }; + + struct Arguments_C {}; + + using Arguments = cute::conditional_t; + +private: + cutlass::NumericConverter aux_load_converter_; + cutlass::HostTensor tensor_aux_load_; + + int M_, N_, L_; + + StrideAux stride_aux_; +public: + HostAuxLoad(){} + template + HostAuxLoad(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) + : Base(check_relative_equality) { + auto problem_shape_NMKL = cute::append<4>(problem_size, 1); + auto [M_, N_, K, L_] = problem_shape_NMKL; + auto aux_coord = cutlass::make_Coord(M_ * L_, N_); + tensor_aux_load_.resize( + aux_coord, + cutlass::layout::Affine2Layout_Factory::layout_factory( + aux_coord, typename LayoutTagAux::Stride() + ) + ); + EXPECT_TRUE( + detail::initialize_tensor( + tensor_aux_load_.host_view(), + cutlass::Distribution::Uniform, + seed + ) + ); + tensor_aux_load_.sync_device(); + stride_aux_ = cutlass::make_cute_packed_stride(StrideAux{}, cute::make_shape(M_, N_, L_)); + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + + + auto TensorAuxLoad = cute::make_tensor(tensor_aux_load_.host_data(), + cute::make_layout(cute::make_shape(M_, N_, L_), stride_aux_)); + return aux_load_converter_(TensorAuxLoad(m + m_b, n + n_b, l)); + } + + bool compare_reference(std::stringstream& error_ss) { + if constexpr (!isC) { + error_ss + << "AuxLoad = \n" << tensor_aux_load_.host_view()<< "\n\n"; + } + return true; + } + + void* get_tensor_C_ptr() { + if constexpr (isC) { + return static_cast(tensor_aux_load_.device_data()); + } + else { + return nullptr; + } + } + + Arguments get_arguments() { + if constexpr (isC) + return {}; + else + return {tensor_aux_load_.device_data(), ElementAuxLoad(0), stride_aux_}; + } + + auto get_flatten_arguments() { + if constexpr (isC) + return cute::make_tuple(); + else + return cute::make_tuple(tensor_aux_load_.device_data(), ElementAuxLoad(0), stride_aux_); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Compute + +template +T* findNonNullPtr(T* first_ptr) { + return first_ptr; +} + +template +T* findNonNullPtr(T* first_ptr, Args... args) { + if (first_ptr) { + return first_ptr; + } + return findNonNullPtr(args...); +} + +template < + template class ComputeOp_, + typename ElementCompute = float +> +class HostCompute: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + using ComputeOp = ComputeOp_; + + struct Arguments { + struct OpArgs {} op; + }; +private: + ComputeOp op_; +public: + HostCompute(){} + template + HostCompute(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024): + Base(check_relative_equality) { } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc, Args... frg_inputs) { + return op_(frg_inputs...); + } + + Arguments get_arguments(){ + return {}; + } + + auto get_flatten_arguments() { + return cute::make_tuple(); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Aux Store + +template < + class ElementAuxStore_, + typename LayoutTagAux_, + bool isD = false, + bool isRelu = false, + typename ElementCompute = float +> +class HostAuxStore: public HostEVTNodeBase { +public: + using ElementAuxStore = ElementAuxStore_; + using LayoutTagAux = LayoutTagAux_; + + using Base = HostEVTNodeBase; + + using StrideAux = cutlass::gemm::TagToStrideC_t; + struct Arguments_Aux { + struct OpArgs { + ElementAuxStore* ptr_aux = nullptr; + StrideAux dAux = {}; + } op; + }; + + struct Arguments_D {}; + + using Arguments = cute::conditional_t; + + +private: + cutlass::NumericConverter destination_converter_; + cutlass::HostTensor tensor_aux_store_; + cutlass::HostTensor reference_aux_store_; + int M_, N_, L_; + StrideAux stride_aux_; +public: + HostAuxStore(){} + template + HostAuxStore(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024): + Base(check_relative_equality) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M_, N_, K, L_] = problem_shape_MNKL; + auto aux_coord = cutlass::make_Coord(M_ * L_, N_); + tensor_aux_store_.resize( + aux_coord, + cutlass::layout::Affine2Layout_Factory::layout_factory( + aux_coord, typename LayoutTagAux::Stride() + ) + ); + + reference_aux_store_.resize( + aux_coord, + cutlass::layout::Affine2Layout_Factory::layout_factory( + aux_coord, typename LayoutTagAux::Stride() + ) + ); + tensor_aux_store_.sync_device(); + stride_aux_ = cutlass::make_cute_packed_stride(StrideAux{}, cute::make_shape(M_, N_, L_)); + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc, ElementCompute child_0_result) { + + auto TensorAuxStore = cute::make_tensor(detail::make_iterator(static_cast(reference_aux_store_.host_data())), + cute::make_layout(cute::make_shape(M_, N_, L_), stride_aux_)); + if constexpr (isRelu) + TensorAuxStore(m + m_b, n + n_b, l) = destination_converter_(child_0_result >= 0); + else + TensorAuxStore(m + m_b, n + n_b, l) = destination_converter_(child_0_result); + return child_0_result; + } + + bool compare_reference(std::stringstream& error_ss) { + // Verify the store node + tensor_aux_store_.sync_host(); + + bool equal = this->equality_check(reference_aux_store_.host_view(), tensor_aux_store_.host_view()); + if (!equal) { + error_ss + << "\n\nReference =\n" << reference_aux_store_.host_view() + << "\n\nComputed =\n" << tensor_aux_store_.host_view() << "\n\n"; + } + return equal; + } + + void* get_tensor_D_ptr() { + if constexpr (isD) + return static_cast(tensor_aux_store_.device_data()); + else + return nullptr; + } + + Arguments get_arguments() { + if constexpr (isD) { + return {}; + } + else { + return {tensor_aux_store_.device_data(), stride_aux_}; + } + } + + auto get_flatten_arguments() { + if constexpr (isD) { + return cute::make_tuple(); + } + else { + return cute::make_tuple(tensor_aux_store_.device_data(), stride_aux_); + } + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Row Reduce + +template < + template class ReduceFn, + typename ElementReduce, + bool FinalReduction = true, // Should match the FinalReduction in Device type + typename CtaTileShapeMNK = cute::Shape, + typename ElementCompute = float +> +class HostRowReduce: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + using ElementDst = cute::conditional_t; + + static constexpr int TileM = cute::get<0>(CtaTileShapeMNK{}); + static constexpr int TileN = cute::get<1>(CtaTileShapeMNK{}); + + struct Arguments { + struct OpArgs { + ElementReduce* ptr_row = nullptr; + ElementCompute reduce_identity = 0; + cute::Stride dRow = {}; + } op; + }; + +private: + cutlass::NumericConverter destination_converter_; + cutlass::HostTensor tensor_row_reduce_; + cutlass::HostTensor reduce_buffer_; + cutlass::HostTensor reference_row_reduce_; + int N_; + ReduceFn reduce_fn_; + + int extent_m_; + int extent_n_; + int extent_l_; +public: + HostRowReduce(){} + template + HostRowReduce(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024): + Base(check_relative_equality) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + N_ = cute::get<1>(problem_shape_MNKL); + if constexpr (FinalReduction) { + tensor_row_reduce_.resize(cutlass::Coord<1>(N_)); + reference_row_reduce_.resize(cutlass::Coord<1>(N_)); + reduce_buffer_.resize(cutlass::Coord<1>(N_)); + } + else { + auto NumTile = cute::ceil_div(cute::select<0,1,3>(problem_shape_MNKL), cute::take<0,2>(CtaTileShapeMNK{})); + extent_m_ = cute::get<0>(NumTile); + extent_n_ = cute::get<1>(NumTile) * TileN; + extent_l_ = cute::get<2>(NumTile); + auto shape = cutlass::make_Coord(extent_m_ * extent_n_ * extent_l_); + tensor_row_reduce_.resize(shape); + reference_row_reduce_.resize(shape); + reduce_buffer_.resize(shape); + } + + cutlass::reference::host::TensorFill(reduce_buffer_.host_view()); + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc, ElementCompute child_0_result) { + if constexpr (FinalReduction) { + auto TensorRowReduce = cute::make_tensor(reduce_buffer_.host_data(), + cute::make_layout(cute::make_shape(cute::_1{}, N_))); + TensorRowReduce(1, n + n_b) = reduce_fn_(TensorRowReduce(1, n + n_b), child_0_result); + } + else { + auto TensorRowReduce = cute::make_tensor( + reduce_buffer_.host_data(), + cute::make_layout( + cute::make_shape(extent_m_, extent_n_, extent_l_), + cute::make_stride(extent_n_, 1, extent_m_ * extent_l_) + ) + ); + TensorRowReduce((m+m_b)/TileM, n+n_b, l) = reduce_fn_(TensorRowReduce((m+m_b)/TileM, n+n_b, l), child_0_result); + } + + return child_0_result; + } + + bool compare_reference(std::stringstream& error_ss) { + // Verify the store node + tensor_row_reduce_.sync_host(); + + auto TensorRowReduce = cute::make_tensor(reference_row_reduce_.host_data(), + cute::make_layout(cute::make_shape(reference_row_reduce_.size()))); + + auto TensorReduceBuffer = cute::make_tensor(reduce_buffer_.host_data(), + cute::make_layout(cute::make_shape(reduce_buffer_.size()))); + + // Filling the reference tensor with the reduce buffer + for (uint64_t n = 0; n < size(TensorRowReduce); n ++) { + TensorRowReduce(n) = destination_converter_(TensorReduceBuffer(n)); + } + + bool equal = this->equality_check(reference_row_reduce_.host_view(), tensor_row_reduce_.host_view()); + if (!equal) { + error_ss + << "\n\nRow Reduce Reference =\n" << reference_row_reduce_.host_view() + << "\n\nRow Reduce Computed =\n" << tensor_row_reduce_.host_view() << "\n\n"; + } + return equal; + } + + Arguments get_arguments() { + return {tensor_row_reduce_.device_data()}; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Column Reduce + +template < + template class ReduceFn, + typename ElementReduce, + bool FinalReduction = true, // Should match the FinalReduction in Device type + typename CtaTileShapeMNK = cute::Shape, + typename ElementCompute = float +> +class HostColumnReduce: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + using ElementDst = cute::conditional_t; + + static constexpr int TileM = cute::get<0>(CtaTileShapeMNK{}); + static constexpr int TileN = cute::get<1>(CtaTileShapeMNK{}); + + struct Arguments { + struct OpArgs { + ElementReduce* ptr_col = nullptr; + ElementCompute reduce_identity = 0; + cute::Stride dRow = {}; + } op; + }; + +private: + cutlass::NumericConverter destination_converter_; + cutlass::HostTensor tensor_column_reduce_; + cutlass::HostTensor reduce_buffer_; + cutlass::HostTensor reference_column_reduce_; + int M_; + ReduceFn reduce_fn_; + + int extent_m_; + int extent_n_; + int extent_l_; +public: + HostColumnReduce(){} + template + HostColumnReduce(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024): + Base(check_relative_equality) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + M_ = cute::get<0>(problem_shape_MNKL); + + if constexpr (FinalReduction) { + tensor_column_reduce_.resize(cutlass::Coord<1>(M_)); + reference_column_reduce_.resize(cutlass::Coord<1>(M_)); + reduce_buffer_.resize(cutlass::Coord<1>(M_)); + } + else { + auto NumTile = cute::ceil_div(cute::select<0,1,3>(problem_shape_MNKL), cute::take<0,2>(CtaTileShapeMNK{})); + extent_m_ = cute::get<0>(NumTile) * TileM; + extent_n_ = cute::get<1>(NumTile); + extent_l_ = cute::get<2>(NumTile); + auto shape = cutlass::make_Coord(extent_m_ * extent_n_ * extent_l_); + tensor_column_reduce_.resize(shape); + reference_column_reduce_.resize(shape); + reduce_buffer_.resize(shape); + } + + cutlass::reference::host::TensorFill(reduce_buffer_.host_view()); + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc, ElementCompute child_0_result) { + auto TensorColReduce = cute::make_tensor(reduce_buffer_.host_data(), + cute::make_layout(cute::make_shape(M_, cute::_1{}))); + if constexpr (FinalReduction) { + TensorColReduce(m + m_b, 1) = reduce_fn_(TensorColReduce(m + m_b, 1), child_0_result); + } + else { + auto shape = reduce_buffer_.extent(); + auto TensorColReduce = cute::make_tensor( + reduce_buffer_.host_data(), + cute::make_layout( + cute::make_shape(extent_m_, extent_n_, extent_l_), + cute::make_stride(1, extent_m_, extent_m_ * extent_l_) + ) + ); + TensorColReduce(m+m_b, (n+n_b)/TileN, l) = reduce_fn_(TensorColReduce(m+m_b, (n+n_b)/TileN, l), child_0_result); + } + return child_0_result; + } + + bool compare_reference(std::stringstream& error_ss) { + // Verify the store node + tensor_column_reduce_.sync_host(); + + auto TensorColReduce = cute::make_tensor(reference_column_reduce_.host_data(), + cute::make_layout(cute::make_shape(reference_column_reduce_.size()))); + + auto TensorReduceBuffer = cute::make_tensor(reduce_buffer_.host_data(), + cute::make_layout(cute::make_shape(reduce_buffer_.size()))); + + // Filling the reference tensor with the reduce buffer + for (uint64_t m = 0; m < size(TensorColReduce); m ++) { + TensorColReduce(m) = destination_converter_(TensorReduceBuffer(m)); + } + + bool equal = this->equality_check(reference_column_reduce_.host_view(), tensor_column_reduce_.host_view()); + if (!equal) { + error_ss + << "\n\nColumn Reduce Reference =\n" << reference_column_reduce_.host_view() + << "\n\nColumn Reduce Computed =\n" << tensor_column_reduce_.host_view() << "\n\n"; + } + return equal; + } + + Arguments get_arguments() { + return {tensor_column_reduce_.device_data()}; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Scalar Reduce + +template < + template class ReduceFn, + typename ElementReduce, + typename ElementCompute = float, + bool enabled = true +> +class HostScalarReduce: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + struct Arguments { + struct OpArgs { + ElementReduce* ptr_scalar = nullptr; + ElementCompute reduce_identity = 0; + cute::Stride dScalar = {}; + } op; + }; + +private: + cutlass::NumericConverter destination_converter_; + cutlass::HostTensor tensor_scalar_reduce_; + cutlass::HostTensor reduce_buffer_; + cutlass::HostTensor reference_scalar_reduce_; + ReduceFn reduce_fn_; +public: + HostScalarReduce(){} + template + HostScalarReduce(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024): + Base(check_relative_equality) { + tensor_scalar_reduce_.resize(cutlass::Coord<1>(1)); + reference_scalar_reduce_.resize(cutlass::Coord<1>(1)); + reduce_buffer_.resize(cutlass::Coord<1>(1)); + + tensor_scalar_reduce_.sync_device(); + cutlass::reference::host::TensorFill(reduce_buffer_.host_view()); + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc, ElementCompute child_0_result) { + auto TensorRowReduce = cute::make_tensor(reduce_buffer_.host_data(), + cute::make_layout(cute::make_shape(cute::_1{}))); + TensorRowReduce(0) = reduce_fn_(TensorRowReduce(0), child_0_result); + return child_0_result; + } + + bool compare_reference(std::stringstream& error_ss) { + if constexpr (enabled) { + // Verify the store node + tensor_scalar_reduce_.sync_host(); + + auto TensorRowReduce = cute::make_tensor(reference_scalar_reduce_.host_data(), + cute::make_layout(cute::make_shape(cute::_1{}))); + + auto TensorReduceBuffer = cute::make_tensor(reduce_buffer_.host_data(), + cute::make_layout(cute::make_shape(cute::_1{}))); + + // Filling the reference tensor with the reduce buffer + TensorRowReduce(0) = destination_converter_(TensorReduceBuffer(0)); + + bool equal = this->equality_check(reference_scalar_reduce_.host_view(), tensor_scalar_reduce_.host_view()); + if (!equal) { + error_ss + << "\n\nScalar Reduce Reference =\n" << reference_scalar_reduce_.host_view() + << "\n\nScalar Reduce Computed =\n" << tensor_scalar_reduce_.host_view() << "\n\n"; + } + return equal; + } + else { + return true; + } + + } + + Arguments get_arguments() { + return {tensor_scalar_reduce_.device_data()}; + } + + auto get_flatten_arguments() { + return cute::make_tuple(tensor_scalar_reduce_.device_data()); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Host EVT wrapper + +/// The ArgumentPack is used to model the alignment when num ops <= 4 +template +struct ArgumentPack; + +template +struct ArgumentPack { + T arg; + ArgumentPack(T first): + arg(first) {} +}; + +template +struct ArgumentPack { + First arg; + ArgumentPack rest_args; + + ArgumentPack(First first, Rest... rest) : + arg(first), rest_args(rest...) {} +}; + + +/// Base class for Host Visitor +template +struct HostVisitorBase: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + + using Arguments_struct = ArgumentPack; + using Arguments_tuple = cute::tuple; + + constexpr static int Rm1 = sizeof...(Ops); + constexpr static bool cond = Rm1 > 4; + using Arguments = cute::conditional_t; + + std::tuple ops; + + HostVisitorBase(){} + template + HostVisitorBase(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) + :Base(check_relative_equality), + ops(test::gemm::device::tapply(std::tuple{}, + [&] (auto&& op) { + using Op = cute::remove_cvref_t; + return Op(problem_size, check_relative_equality, seed); + }, + [] (auto&&... _ops) { + return std::make_tuple(_ops...); + }, + cute::make_seq{} + )){ } + + bool compare_reference(std::stringstream& error_ss) { + return cute::detail::tapply(ops, + [&](auto& op) { + return op.compare_reference(error_ss); + }, + [&] (auto&&... inputs) { + return arrayAnd(inputs...); + }, + cute::make_seq{} + ); + } + + void* get_tensor_C_ptr() { + return cute::detail::tapply(ops, + [&](auto& op) { + return op.get_tensor_C_ptr(); + }, + [&] (auto&&... inputs) { + return findNonNullPtr(inputs...); + }, + cute::make_seq{} + ); + } + + void* get_tensor_D_ptr() { + return cute::detail::tapply(ops, + [&](auto& op) { + return op.get_tensor_D_ptr(); + }, + [&] (auto&&... inputs) { + return findNonNullPtr(inputs...); + }, + cute::make_seq{} + ); + } + + Arguments get_arguments() { + return test::gemm::device::tapply(ops, + [&](auto& op) { + return op.get_arguments(); + }, + [&] (auto&&... args) { + if constexpr (Rm1 > 4) { + return cute::make_tuple(args...); + } + else { + return Arguments(args...); + } + }, + cute::make_seq{} + ); + } + + auto get_flatten_arguments() { + return test::gemm::device::tapply(ops, + [&](auto& op) { + return op.get_flatten_arguments(); + }, + [&] (auto&&... args) { + return flatten(cute::make_tuple(args...)); + }, + cute::make_seq{} + ); + } + + bool arrayAnd(bool passed) { + return passed; + } + + template + bool arrayAnd(bool first_passed, Args... passed) { + if (first_passed) { + return arrayAnd(passed...); + } + return first_passed; + } + +}; + + +/// Tree-struct visitor +template +struct HostTreeVisitor: public HostVisitorBase { +public: + using ElementCompute = typename NodeOp::Base::ElementCompute; + using Base = HostVisitorBase; + using Arguments = typename Base::Arguments; + + constexpr static int Rm1 = sizeof...(ChildOps); + + HostTreeVisitor(){} + template + HostTreeVisitor(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) + :Base(problem_size, check_relative_equality, seed){ } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + return cute::detail::tapply(this->ops, + [&] (auto& op) { + return op.visit(m, n, l, m_b, n_b, acc); + }, + [&] (auto&&... frg_inputs) { + return std::get(this->ops).visit(m, n, l, m_b, n_b, acc, frg_inputs...); + }, + cute::make_seq{} + ); + } +}; + + +/// General Graph visitor +template +struct HostTopoVisitor: public HostVisitorBase { +public: + using Base = HostVisitorBase; + constexpr static int Rm1 = Base::Rm1; + using Arguments = typename Base::Arguments; + +private: + ElementCompute frg_outputs_[Rm1]; +public: + HostTopoVisitor(){} + template + HostTopoVisitor(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) + :Base(problem_size, check_relative_equality, seed) { } + + template + ElementCompute visit_( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + frg_outputs_[I] = cute::transform_apply(cute::get(EdgeTuple{}), + [&] (auto&& _E) { + constexpr int e = cute::remove_cvref_t::value; + return frg_outputs_[e]; + }, + [&] (auto const&... frg_inputs) { + ElementCompute res = std::get(this->ops).visit(m, n, l, m_b, n_b, acc, frg_inputs...); + return res; + } + ); + + if constexpr (I < Rm1 - 1) { + return visit_(m, n, l, m_b, n_b, acc); + } + else { + return frg_outputs_[I]; + } + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + + return visit_(m, n, l, m_b, n_b, acc); + } + +}; + + +/// SplitTree visitor +template +struct HostSplitTreeVisitor: public HostVisitorBase { +public: + using Base = HostVisitorBase; + using Arguments = typename Base::Arguments; + + constexpr static int Rm2 = sizeof...(AuxOutTrees); + +private: + ElementCompute frg_input_; +public: + HostSplitTreeVisitor(){} + template + HostSplitTreeVisitor(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) + :Base(problem_size, check_relative_equality, seed) { } + + template + void visitAux( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator frag) { + std::get(this->ops).visit(m, n, l, m_b, n_b, frag); + + if constexpr (I < Rm2 - 1) { + return visitAux(m, n, l, m_b, n_b, frag); + } + else { + return; + } + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + + /// Compute the input tree + frg_input_ = std::get<0>(this->ops).visit(m, n, l, m_b, n_b, acc); + + /// Compute the aux out tree + visitAux(m, n, l, m_b, n_b, frg_input_); + /// Visit the output tree + return std::get(this->ops).visit(m, n, l, m_b, n_b, frg_input_); + } +}; + +/// Universal testbed for EVT w/o smem +template +class Testbed3xEVTnoSmem { +public: + // The EVT Module to test + using EVTModule = EVT; //typename EVT::EVTModule; + + using TestBedImpl = typename detail::TestbedImpl; + using Kernel = typename Gemm::GemmKernel; + using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; + using ElementAccumulator = typename Kernel::ElementAccumulator; + using ElementC = typename Kernel::ElementC; + using ElementD = typename Kernel::ElementD; + + using ProblemShapeType = typename Kernel::ProblemShape; + + using LayoutTagA = typename TestBedImpl::LayoutTagA; + using LayoutTagB = typename TestBedImpl::LayoutTagB; + + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + + // + // Methods + // + Testbed3xEVTnoSmem( + bool check_relative_equality_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed ) : + impl_((check_relative_equality_ ? CheckEquality::RELATIVE : CheckEquality::EXACT), ScalarLoc::ON_DEVICE, VectorScale::ENABLED, + init_A_, init_B_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_), + check_relative_equality(check_relative_equality_) { } + + Testbed3xEVTnoSmem( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed ) : + impl_(CheckEquality::EXACT, ScalarLoc::ON_DEVICE, VectorScale::ENABLED, + init_A_, init_B_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_), + check_relative_equality(false) { } + + /// Initializes data structures + void initialize(ProblemShapeType problem_size) { + // + // Allocate the GEMM workspace for A/B tensor + // + impl_.initialize(problem_size); + } + // Detail Implementation + TestBedImpl impl_; + + // Whether to use relative equality checks + bool check_relative_equality; + + bool verify(ProblemShapeType problem_size, EVTModule& host_reference) { + + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::get<0>(problem_shape_MNKL); + auto N = cute::get<1>(problem_shape_MNKL); + auto K = cute::get<2>(problem_shape_MNKL); + auto L = cute::get<3>(problem_shape_MNKL); + + auto A = cute::make_tensor(impl_.collective_mma_inputs.tensor_A.host_data(), + cute::make_layout(cute::make_shape(M, K, L), impl_.collective_mma_inputs.stride_a)); + auto B = cute::make_tensor(impl_.collective_mma_inputs.tensor_B.host_data(), + cute::make_layout(cute::make_shape(N, K, L), impl_.collective_mma_inputs.stride_b)); + auto LayoutD = cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_d); + + cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; + + /// Reference Kernel + static int constexpr kBlockM = 64; + static int constexpr kBlockN = 64; + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) { + for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) { + for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) { + ElementAccumulator acc[kBlockM][kBlockN]; + gett_mainloop(mainloop_params, m, n, l, acc); + /// Epilogue EVT + for (int n_b = 0; n_b < kBlockN; ++n_b) { + for (int m_b = 0; m_b < kBlockM; ++m_b) { + if (m + m_b < cute::size<0>(LayoutD) && n + n_b < cute::size<1>(LayoutD)) { + host_reference.visit(m, n, l, m_b, n_b, acc[m_b][n_b]); + } + } + } + } + } + } + + std::stringstream error_ss; + bool passed = host_reference.compare_reference(error_ss); + if (!passed) { + std::stringstream fname; + fname << "error_Gemm_device_" + << M << "x" << N << "x" << K << "x" << L << "_" + << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt"; + + std::ofstream file(fname.str()); + file + << "problem: " << ' ' << M << "x" << N << "x" << K + << ", Batch count = " << L << "\n\n"; + + file + << "A =\n" << impl_.collective_mma_inputs.tensor_A.host_view() + << "\nB =\n" << impl_.collective_mma_inputs.tensor_B.host_view(); + + file << error_ss.str(); + } + + return passed; + } + + bool run( + ProblemShapeType problem_size, + RasterOrderOptions raster_order = RasterOrderOptions::Heuristic, + detail::MaxSwizzleSize max_swizzle = detail::MaxSwizzleSize{}, + detail::Splits splits = detail::Splits{}, + DecompositionMode decomposition_mode = DecompositionMode::Heuristic, + int iterations = 20, + bool profiling = false) { + // Fail test if insufficient CUDA device + if (!impl_.sufficient()) { + std::cout << "Test failed due to insufficient CUDA device." << std::endl; + return false; + } + // + // Initialize the Gemm operator + // + + typename Gemm::Arguments arguments; + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + if (not profiling) { + impl_.sm_count = std::min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); + hw_info.sm_count = impl_.sm_count; + } + else { + impl_.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + hw_info.sm_count = impl_.sm_count; + } + + typename Gemm::GemmKernel::TileScheduler::Arguments scheduler_args; + if constexpr (cute::is_same_v) { + scheduler_args = { static_cast(splits), static_cast(max_swizzle), raster_order, decomposition_mode }; + } + else { + scheduler_args = { static_cast(max_swizzle), raster_order }; + } + + /// Initializes data structures + /// A/B/C/D Tensor + initialize(problem_size); + + /// Initialize the epilogue arguments + EVTModule host_reference(problem_size, check_relative_equality, 2024); + + arguments = typename Gemm::Arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + { + impl_.collective_mma_inputs.tensor_A.device_data(), impl_.collective_mma_inputs.stride_a, + impl_.collective_mma_inputs.tensor_B.device_data(), impl_.collective_mma_inputs.stride_b + }, + {}, + hw_info, + scheduler_args + }; + + // Filling in the thread arguments + if constexpr (FlatArgs) { + auto epilogue_args = host_reference.get_flatten_arguments(); + std::memcpy(&arguments.epilogue.thread, &epilogue_args, sizeof(epilogue_args)); + + arguments.epilogue.ptr_C = static_cast(host_reference.get_tensor_C_ptr()); + arguments.epilogue.dC = impl_.collective_epilogue.stride_c; + + arguments.epilogue.ptr_D = static_cast(host_reference.get_tensor_D_ptr()); + arguments.epilogue.dD = impl_.collective_epilogue.stride_d; + } + else { + auto epilogue_args = host_reference.get_arguments(); + std::memcpy(&arguments.epilogue, &epilogue_args, sizeof(epilogue_args)); + } + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.can_implement(arguments); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // + // Run the GEMM + // + if (profiling) { + return impl_.profile(problem_size, iterations, gemm_op, arguments, workspace); + } + else { + cudaError_t result; + status = gemm_op.initialize(arguments, workspace.get()); + status = gemm_op.run(); + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + bool passed = this->verify(problem_size, host_reference); + if (!passed) { + std::cout << "Error : Failed \n"; + } + + return passed; + } +}; + +/// Universal testbed for EVT +template +class Testbed3xEVT { +public: + // The EVT Module to test + using EVTModule = typename EVT::EVTModule; + + using TestBedImpl = typename detail::TestbedImpl; + using Kernel = typename Gemm::GemmKernel; + using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; + using ElementAccumulator = typename Kernel::ElementAccumulator; + using ElementC = typename Kernel::ElementC; + using ElementD = typename Kernel::ElementD; + + using ProblemShapeType = typename Kernel::ProblemShape; + + using LayoutTagA = typename TestBedImpl::LayoutTagA; + using LayoutTagB = typename TestBedImpl::LayoutTagB; + using LayoutTagC = typename TestBedImpl::LayoutTagC; + using LayoutTagD = typename TestBedImpl::LayoutTagD; + + // + // Methods + // + Testbed3xEVT( + bool check_relative_equality_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed + ) : + impl_((check_relative_equality_ ? CheckEquality::RELATIVE : CheckEquality::EXACT), ScalarLoc::ON_DEVICE, VectorScale::ENABLED, + init_A_, init_B_, init_C_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_), + check_relative_equality(check_relative_equality_) { } + + Testbed3xEVT( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed + ) : + impl_(CheckEquality::EXACT, ScalarLoc::ON_DEVICE, VectorScale::ENABLED, + init_A_, init_B_, init_C_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_), + check_relative_equality(false) { } + + Testbed3xEVT( + typename LayoutTagA::Stride stride_factor_A_, + typename LayoutTagB::Stride stride_factor_B_, + typename LayoutTagC::Stride stride_factor_C_, + typename LayoutTagD::Stride stride_factor_D_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed + ) : + impl_(stride_factor_A_, stride_factor_B_, stride_factor_C_, stride_factor_D_, + CheckEquality::EXACT, ScalarLoc::ON_DEVICE, VectorScale::ENABLED, + init_A_, init_B_, init_C_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_), + check_relative_equality(false) { } + + /// Initializes data structures + void initialize(ProblemShapeType problem_size) { + // + // Allocate the GEMM workspace for A/B tensor + // + impl_.initialize(problem_size); + } + // Detail Implementation + TestBedImpl impl_; + + // Whether to use relative equality checks + bool check_relative_equality; + + bool verify(ProblemShapeType problem_size, EVTModule& host_reference) { + + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::get<0>(problem_shape_MNKL); + auto N = cute::get<1>(problem_shape_MNKL); + auto K = cute::get<2>(problem_shape_MNKL); + auto L = cute::get<3>(problem_shape_MNKL); + + auto A = cute::make_tensor(impl_.collective_mma_inputs.tensor_A.host_data(), + cute::make_layout(cute::make_shape(M, K, L), impl_.collective_mma_inputs.stride_a)); + auto B = cute::make_tensor(impl_.collective_mma_inputs.tensor_B.host_data(), + cute::make_layout(cute::make_shape(N, K, L), impl_.collective_mma_inputs.stride_b)); + auto LayoutD = cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_d); + + cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; + + /// Reference Kernel + static int constexpr kBlockM = 64; + static int constexpr kBlockN = 64; + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) { + for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) { + for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) { + ElementAccumulator acc[kBlockM][kBlockN]; + gett_mainloop(mainloop_params, m, n, l, acc); + /// Epilogue EVT + for (int n_b = 0; n_b < kBlockN; ++n_b) { + for (int m_b = 0; m_b < kBlockM; ++m_b) { + if (m + m_b < cute::size<0>(LayoutD) && n + n_b < cute::size<1>(LayoutD)) { + host_reference.visit(m, n, l, m_b, n_b, acc[m_b][n_b]); + } + } + } + } + } + } + + std::stringstream error_ss; + bool passed = host_reference.compare_reference(error_ss); + if (!passed) { + std::stringstream fname; + fname << "error_Gemm_device_" + << M << "x" << N << "x" << K << "x" << L << "_" + << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt"; + + std::ofstream file(fname.str()); + file + << "problem: " << ' ' << M << "x" << N << "x" << K + << ", Batch count = " << L << "\n\n"; + + file + << "A =\n" << impl_.collective_mma_inputs.tensor_A.host_view() + << "\nB =\n" << impl_.collective_mma_inputs.tensor_B.host_view() + << "\nC =\n" << impl_.collective_epilogue.tensor_C.host_view() << "\n\n"; + + file << error_ss.str(); + } + + return passed; + } + + bool run( + ProblemShapeType problem_size, + bool profiling = false, + int iterations = 20, + int splits = 1) { + // Fail test if insufficient CUDA device + if (!impl_.sufficient()) { + std::cout << "Test failed due to insufficient CUDA device." << std::endl; + return false; + } + // + // Initialize the Gemm operator + // + + typename Gemm::Arguments arguments; + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + if (not profiling) { + impl_.sm_count = std::min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); + hw_info.sm_count = impl_.sm_count; + } + else { + impl_.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + hw_info.sm_count = impl_.sm_count; + } + + typename Gemm::GemmKernel::TileScheduler::Arguments scheduler_args; + if constexpr (cute::is_same_v) { + scheduler_args = { splits }; + } + + /// Initializes data structures + /// A/B/C/D Tensor + initialize(problem_size); + + /// Initialize the epilogue arguments + EVTModule host_reference(problem_size, check_relative_equality, 2024); + + arguments = typename Gemm::Arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + { + impl_.collective_mma_inputs.tensor_A.device_data(), impl_.collective_mma_inputs.stride_a, + impl_.collective_mma_inputs.tensor_B.device_data(), impl_.collective_mma_inputs.stride_b + }, + { // Epilogue arguments + {}, // thread + static_cast(host_reference.get_tensor_C_ptr()), + impl_.collective_epilogue.stride_c, + static_cast(host_reference.get_tensor_D_ptr()), + impl_.collective_epilogue.stride_d + }, // Epilogue arguments end + hw_info, + scheduler_args + }; + + // Filling in the thread arguments + typename EVTModule::Arguments epilogue_args = host_reference.get_arguments(); + std::memcpy(&arguments.epilogue.thread, &epilogue_args.arg, sizeof(epilogue_args.arg)); + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.can_implement(arguments); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // + // Run the GEMM + // + if (profiling) { + return impl_.profile(problem_size, iterations, gemm_op, arguments, workspace); + } + else { + cudaError_t result; + status = gemm_op.initialize(arguments, workspace.get()); + status = gemm_op.run(); + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + bool passed = this->verify(problem_size, host_reference); + if (!passed) { + std::cout << "Error : Failed \n"; + } + + return passed; + } +}; + +template +bool TestAllEVT(bool check_relative_equality = false) { + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + std::vector problem_size_m = {max_alignment, 512 - 3 * max_alignment}; + std::vector problem_size_n = {max_alignment, 512 - 2 * max_alignment}; + + if constexpr (cute::is_same_v) { + problem_size_m.push_back(768); + problem_size_n.push_back(768); + } + + constexpr int Stages = Gemm::GemmKernel::DispatchPolicy::Stages; + constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{}); + + std::vector problem_size_k = {max_alignment, TileShapeK * (Stages + 1) - max_alignment}; + + Testbed3xEVT testbed(check_relative_equality); + bool passed = true; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + ProblemShapeType problem_size; + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + problem_size = ProblemShapeType{m, n, k, /* l */ 1}; + } + else { + problem_size = ProblemShapeType{m, n, k}; + } + + passed = testbed.run(problem_size); + + if (!passed) { + return false; + } + } + } + } + + // if we do support batched GEMM, just run one test on it to save on test time + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + auto problem_size = ProblemShapeType{256 + max_alignment, 256 + max_alignment, 160 + max_alignment, /* l */ 3}; + passed = testbed.run( + problem_size + ); + + if (!passed) { + return false; + } + } + + return passed; +} + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cbc54ec582d88d9039968d8153cf6127a06ec274 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp @@ -0,0 +1,2409 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Testbed for Ptr-Array and Grouped GEMM interface +*/ + +#pragma once + +#include +#include +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/fusion/operations.hpp" +#include "cutlass/complex.h" +#include "testbed_utils.h" + +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/layout/matrix.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/gemm/gemm.h" + +#include "cute/int_tuple.hpp" +#include "cute/layout.hpp" +#include "cute/numeric/int.hpp" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +enum class ScalarLoc { + ON_HOST = 0, + ON_DEVICE = 1 +}; + +enum class VectorScale { + DISABLED = 0, + ENABLED = 1 +}; + +enum class CheckEquality { + EXACT = 0, + RELATIVE = 1 +}; + +namespace detail{ + +// Helper classes that take default data type when +// the Gemm::EpilogueOutputOp does not have ElementCompute +// and ElementScalar. +// (e.g. when Sm90TreeVisitor is used as FusionCallbacks) +template +struct ElementComputeType { + using Type = Default; +}; + +template +struct ElementComputeType> { + using Type = typename Gemm::EpilogueOutputOp::ElementCompute; +}; + +template +struct ElementScalarType { + using Type = Default; +}; + +template +struct ElementScalarType> { + using Type = typename Gemm::EpilogueOutputOp::ElementScalar; +}; + + +template +struct IsF8F6F4Kernel { + static constexpr bool value = false; +}; + +template +struct IsF8F6F4Kernel> { + static constexpr bool value = true; +}; + + +// The maximum swizzle size to use +// +// This class, like Splits above makes it harder to confuse +// the order of arguments of the various run(...) functions in this file. +class MaxSwizzleSize { +public: + MaxSwizzleSize() = default; + + template && + !cute::is_same_v)) > + explicit MaxSwizzleSize(IntegralNotBool max_swizzle_size) : max_swizzle_size_(max_swizzle_size) {} + explicit operator int() const { return max_swizzle_size_; } +private: + int max_swizzle_size_ = 1; +}; + +template +auto make_iterator(T* ptr) { + return cute::recast_ptr(ptr); +} + +template +struct IsDefaultEpilogue { + static constexpr bool value = false; +}; + +template +struct IsDefaultEpilogue> { + static constexpr bool value = true; +}; + +template +struct IsDefaultEpilogue> { + static constexpr bool value = true; +}; + +// The number of splits to test. +// +// This class makes it harder to confuse the order of arguments +// of the various run(...) functions in this file. The constructor +// is explicit, so one can't just type 42 (or false, which the +// compiler unhelpfully turns into 0); one has to type Splits(42). +// Splits() picks the default number of splits, 1. +// +// The conversion-to-int operator (operator int()) MUST be explicit! +// Conversion to int MUST require static_cast. +// Otherwise, that defeats a key purpose of this class, +// which is to catch common errors of confusing the order +// of function arguments. +class Splits { +public: + Splits() = default; + + template && + !cute::is_same_v)) > + explicit Splits(IntegralNotBool splits) : splits_(splits) {} + explicit operator int() const { return splits_; } +private: + int splits_ = 1; +}; + +// The number of iterations to test. +// +// This class, like Splits above makes it harder to confuse +// the order of arguments of the various run(...) functions in this file. +// Iterations() picks the default number of iterations, 20. +class Iterations { +public: + Iterations() = default; + + template && + !cute::is_same_v)) > + explicit Iterations(IntegralNotBool iterations) : iterations_(iterations) {} + explicit operator int() const { return iterations_; } +private: + int iterations_ = 20; +}; + +template +bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + + else if (bits_input <= 6) { + scope_max = 2; + scope_min = -2; + } + + else if (bits_input <= 8) { + + if constexpr ( + cute::is_same_v){ + scope_max = 4; + scope_min = 1; + } + else { + + scope_max = 1; + scope_min = -1; + + } + + } + else{ + scope_max = 4; + scope_min = -4; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + + else if (dist_kind == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(view); + } + + else if (dist_kind == cutlass::Distribution::Gaussian) { + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + + else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; +} + +// Looks at Cute Stride to check Row / Column Major +template +static constexpr bool is_row_or_col_major(){ + int stride_0 = int(cute::size<0>(Stride{})); + int stride_1 = int(cute::size<1>(Stride{})); + int depth = cute::depth(Stride{}); + return ((stride_0 == 1) || (stride_1 == 1)) && (depth == 1); +} + + +// +// Default MMA input Operands : A , B +// +template< + class ScheduleType_, + class Gemm, + class ElementA_ = typename Gemm::GemmKernel::ElementA, + class ElementB_ = typename Gemm::GemmKernel::ElementB> +struct HostCollectiveMainloop { + // Kernel data types + using ElementA = ElementA_; + using StrideA = typename Gemm::GemmKernel::StrideA; + using InternalStrideA = typename Gemm::GemmKernel::InternalStrideA; + using ElementB = ElementB_; + using StrideB = typename Gemm::GemmKernel::StrideB; + using InternalStrideB = typename Gemm::GemmKernel::InternalStrideB; + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using LayoutTagA = cutlass::detail::StrideToLayoutTagA_t; + using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t; + + static constexpr bool IsGroupGemm = !cute::is_same_v; + + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; + + using Arguments = typename Gemm::GemmKernel::MainloopArguments; + + cutlass::ComplexTransform TransformA = Gemm::kTransformA; + cutlass::ComplexTransform TransformB = Gemm::kTransformB; + + std::vector stride_a_host; + std::vector stride_b_host; + + cutlass::DeviceAllocation stride_a_device; + cutlass::DeviceAllocation stride_b_device; + + typename LayoutTagA::Stride stride_factor_A; + typename LayoutTagB::Stride stride_factor_B; + + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + + std::vector> tensors_A; + std::vector> tensors_B; + cutlass::DeviceAllocation device_tensors_A; + cutlass::DeviceAllocation device_tensors_B; + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + // Note: this limitation comes from testbed / not the library + static_assert(is_row_or_col_major(), + "ERROR : A Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : B Layout is neither Row / Column Major)"); + + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed, + typename LayoutTagA::Stride stride_factor_A_ = typename LayoutTagA::Stride(), + typename LayoutTagB::Stride stride_factor_B_ = typename LayoutTagB::Stride() + ): + stride_factor_A(stride_factor_A_), + stride_factor_B(stride_factor_B_), + init_A(init_A_), init_B(init_B_), seed(seed_), + check_relative_equality(check_relative_equality_) { } + + bool initialize(ProblemShapeType problem_shapes) { + // + // Allocate the GEMM workspace + // + // for pointer array problem_shapes.groups() is 1 + + tensors_A.clear(); + tensors_B.clear(); + stride_a_host.clear(); + stride_b_host.clear(); + + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = cutlass::platform::max(problem_shapes.groups(), L); + + for(int32_t i = 0; i < L; ++i) { + auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); + + stride_a_host.push_back(cutlass::make_cute_packed_stride(InternalStrideA{}, {M, K, 1})); + stride_b_host.push_back(cutlass::make_cute_packed_stride(InternalStrideB{}, {N, K, 1})); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto a_coord = cutlass::make_Coord(M, K); + // Cutlass has Row/Col major refers to MxK times KxN matrix product, + // so the HostTensorB should be treated as KxN in "coord"'s view + auto b_coord = cutlass::make_Coord(K, N); + + tensors_A.push_back(cutlass::HostTensor(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A))); + tensors_B.push_back(cutlass::HostTensor(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B))); + + EXPECT_TRUE(initialize_tensor(tensors_A[i].host_view(), init_A, seed + 2022 + i)); + EXPECT_TRUE(initialize_tensor(tensors_B[i].host_view(), init_B, seed + 2021 + i)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensors_A[i].host_view().at({0, 0}) = ElementA(1); + tensors_B[i].host_view().at({0, 0}) = ElementB(1); + + tensors_A[i].sync_device(); + tensors_B[i].sync_device(); + } + + return true; + } + + Arguments to_args(ProblemShapeType problem_shapes) { + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = cutlass::platform::max(problem_shapes.groups(), L); + + std::vector ptr_A_host(L); + std::vector ptr_B_host(L); + + for (int32_t i = 0; i < L; ++i) { + ptr_A_host.at(i) = tensors_A[i].device_data(); + ptr_B_host.at(i) = tensors_B[i].device_data(); + } + + device_tensors_A.reset(L); + device_tensors_A.copy_from_host(ptr_A_host.data()); + + device_tensors_B.reset(L); + device_tensors_B.copy_from_host(ptr_B_host.data()); + + stride_a_device.reset(problem_shapes.groups()); + stride_a_device.copy_from_host(stride_a_host.data()); + stride_b_device.reset(problem_shapes.groups()); + stride_b_device.copy_from_host(stride_b_host.data()); + + Arguments arguments; + + if constexpr (IsGroupGemm) { + arguments + = + { + device_tensors_A.get(), stride_a_device.get(), device_tensors_B.get(), stride_b_device.get() + }; + } + else { + arguments = + { + device_tensors_A.get(), stride_a_host[0], device_tensors_B.get(), stride_b_host[0] + }; + } + + return arguments; + } + + auto to_host_args(ProblemShapeType problem_shapes, int batch) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1); + auto A = make_tensor(make_iterator(tensors_A[batch].host_data()), + make_layout(make_shape(M, K, 1), stride_a_host[batch])); + auto B = make_tensor(make_iterator(tensors_B[batch].host_data()), + make_layout(make_shape(N, K, 1), stride_b_host[batch])); + + cutlass::reference::host::GettMainloopParams mainloop_params{}; + + mainloop_params.A = A; + mainloop_params.B = B; + mainloop_params.transform_A = TransformA; + mainloop_params.transform_B = TransformB; + + return mainloop_params; + } + + void print_tensors(std::ofstream& file, int batch) { + file << "A =\n" << tensors_A[batch].host_view() + << "\nB =\n" << tensors_B[batch].host_view(); + } + + template < + class Element, + class Layout + > + bool equality_check( + cutlass::TensorView const& lhs, + cutlass::TensorView const& rhs) const { + + // Factors used for calculating relative equality. CUTLASS's relative-equality + // checks in include/cutlass/relatively_equal.h are inspired by + // https://floating-point-gui.de/errors/comparison/. This reference suggests using + // the minimum normal value of a given type as the nonzero_floor. + Element epsilon(static_cast(0.1f)); + Element nonzero_floor(std::numeric_limits::min()); + + if constexpr (!cutlass::is_complex::value) { + if (check_relative_equality == CheckEquality::RELATIVE) { + return cutlass::reference::host::TensorRelativelyEquals( + lhs, rhs, epsilon, nonzero_floor); + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + + bool compare_reference( + ProblemShapeType problem_shapes, int batch) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_A[batch].host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_B[batch].host_view()), 0); + + bool passed = true; + return passed; + } +}; + + +// +// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + // Kernel data types + using ElementA = ElementA_; + using StrideA = typename Gemm::GemmKernel::StrideA; + using InternalStrideA = typename Gemm::GemmKernel::InternalStrideA; + using ElementB = ElementB_; + using StrideB = typename Gemm::GemmKernel::StrideB; + using InternalStrideB = typename Gemm::GemmKernel::InternalStrideB; + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using LayoutTagA = cutlass::detail::StrideToLayoutTagA_t; + using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t; + + static constexpr bool IsGroupGemm = !cute::is_same_v; + + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; + + static constexpr int SFVecSize = Gemm::GemmKernel::CollectiveMainloop::SFVecSize; + + using ElementSF = typename Gemm::GemmKernel::CollectiveMainloop::ElementSF; + using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; + using Blk_SF = typename Sm1xxBlkScaledConfig::Blk_SF; + using SfAtom = typename Sm1xxBlkScaledConfig::SfAtom; + using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; + using InternalLayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA; + using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; + using InternalLayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB; + + using Arguments = typename Gemm::GemmKernel::MainloopArguments; + + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + + std::vector stride_a_host; + std::vector stride_b_host; + cutlass::DeviceAllocation stride_a_device; + cutlass::DeviceAllocation stride_b_device; + + std::vector layout_sfa_host; + std::vector layout_sfb_host; + cutlass::DeviceAllocation layout_sfa_device; + cutlass::DeviceAllocation layout_sfb_device; + + typename LayoutTagA::Stride stride_factor_A; + typename LayoutTagB::Stride stride_factor_B; + + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + + std::vector> tensors_A; + std::vector> tensors_B; + std::vector> tensors_SFA; + std::vector> tensors_SFB; + + cutlass::DeviceAllocation device_tensors_A; + cutlass::DeviceAllocation device_tensors_B; + cutlass::DeviceAllocation device_tensors_SFA; + cutlass::DeviceAllocation device_tensors_SFB; + + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + // Note: this limitation comes from testbed / not the library + static_assert(is_row_or_col_major(), + "ERROR : A Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : B Layout is neither Row / Column Major)"); + + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed, + typename LayoutTagA::Stride stride_factor_A_ = typename LayoutTagA::Stride(), + typename LayoutTagB::Stride stride_factor_B_ = typename LayoutTagB::Stride() + ): + check_relative_equality(check_relative_equality_), + stride_factor_A(stride_factor_A_), + stride_factor_B(stride_factor_B_), + init_A(init_A_), init_B(init_B_), seed(seed_) { } + + template + bool initialize(ProblemShapeType problem_shapes) { + // + // Allocate the GEMM workspace + // + + tensors_A.clear(); + tensors_B.clear(); + stride_a_host.clear(); + stride_b_host.clear(); + tensors_SFA.clear(); + tensors_SFB.clear(); + layout_sfa_host.clear(); + layout_sfb_host.clear(); + + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = std::max(problem_shapes.groups(), L); + + for (int32_t i = 0; i < L; ++i) { + auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); + + stride_a_host.push_back(cutlass::make_cute_packed_stride(InternalStrideA{}, {M, K, 1})); + stride_b_host.push_back(cutlass::make_cute_packed_stride(InternalStrideB{}, {N, K, 1})); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto a_coord = cutlass::make_Coord(M, K); + // Cutlass has Row/Col major refers to MxK times KxN matrix product, + // so the HostTensorB should be treated as KxN in "coord"'s view + auto b_coord = cutlass::make_Coord(K, N); + + tensors_A.push_back(cutlass::HostTensor(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A))); + tensors_B.push_back(cutlass::HostTensor(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B))); + + EXPECT_TRUE(initialize_tensor(tensors_A[i].host_view(), init_A, seed + 2022 + i)); + EXPECT_TRUE(initialize_tensor(tensors_B[i].host_view(), init_B, seed + 2021 + i)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensors_A[i].host_view().at({0, 0}) = ElementA(1); + tensors_B[i].host_view().at({0, 0}) = ElementB(1); + + tensors_A[i].sync_device(); + tensors_B[i].sync_device(); + + using namespace cute; + + auto k_blks = cutlass::ceil_div(K, size<1>(shape(SfAtom{}))); + auto m_blks = cutlass::ceil_div(M, Blk_MN{}); + auto n_blks = cutlass::ceil_div(N, Blk_MN{}); + layout_sfa_host.push_back(Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1))); + layout_sfb_host.push_back(Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1))); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto sfa_coord = cutlass::make_Coord(m_blks * Blk_MN{}, k_blks * Blk_SF{}); + auto sfb_coord = cutlass::make_Coord(n_blks * Blk_MN{}, k_blks * Blk_SF{}); + + tensors_SFA.push_back(cutlass::HostTensor(sfa_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfa_coord, stride_factor_A))); + tensors_SFB.push_back(cutlass::HostTensor(sfb_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfb_coord, stride_factor_B))); + + EXPECT_TRUE(initialize_tensor(tensors_SFA[i].host_view(), init_A, seed + 2024 + i)); + EXPECT_TRUE(initialize_tensor(tensors_SFB[i].host_view(), init_B, seed + 2025 + i)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensors_SFA[i].host_view().at({0, 0}) = ElementSF(1); + tensors_SFB[i].host_view().at({0, 0}) = ElementSF(1); + + tensors_SFA[i].sync_device(); + tensors_SFB[i].sync_device(); + } + + return true; + } + + Arguments to_args(ProblemShapeType problem_shapes) { + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = std::max(problem_shapes.groups(), L); + + std::vector ptr_A_host(L); + std::vector ptr_B_host(L); + std::vector ptr_SFA_host(L); + std::vector ptr_SFB_host(L); + + for (int32_t i = 0; i < L; ++i) { + ptr_A_host.at(i) = tensors_A[i].device_data(); + ptr_B_host.at(i) = tensors_B[i].device_data(); + ptr_SFA_host.at(i) = tensors_SFA[i].device_data(); + ptr_SFB_host.at(i) = tensors_SFB[i].device_data(); + } + + device_tensors_A.reset(L); + device_tensors_A.copy_from_host(ptr_A_host.data()); + + device_tensors_B.reset(L); + device_tensors_B.copy_from_host(ptr_B_host.data()); + + device_tensors_SFA.reset(L); + device_tensors_SFA.copy_from_host(ptr_SFA_host.data()); + + device_tensors_SFB.reset(L); + device_tensors_SFB.copy_from_host(ptr_SFB_host.data()); + + stride_a_device.reset(problem_shapes.groups()); + stride_a_device.copy_from_host(stride_a_host.data()); + + stride_b_device.reset(problem_shapes.groups()); + stride_b_device.copy_from_host(stride_b_host.data()); + + layout_sfa_device.reset(problem_shapes.groups()); + layout_sfa_device.copy_from_host(layout_sfa_host.data()); + + layout_sfb_device.reset(problem_shapes.groups()); + layout_sfb_device.copy_from_host(layout_sfb_host.data()); + + if constexpr (IsGroupGemm) { + return Arguments{ + device_tensors_A.get(), stride_a_device.get(), + device_tensors_B.get(), stride_b_device.get(), + device_tensors_SFA.get(), layout_sfa_device.get(), + device_tensors_SFB.get(), layout_sfb_device.get() + }; + } + else { + return Arguments{ + device_tensors_A.get(), stride_a_host[0], + device_tensors_B.get(), stride_b_host[0], + device_tensors_SFA.get(), layout_sfa_host[0], + device_tensors_SFB.get(), layout_sfb_host[0] + }; + } + } + + auto to_host_args(ProblemShapeType problem_shapes, int batch) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1); + auto A = make_tensor(make_iterator(tensors_A[batch].host_data()), + make_layout(make_shape(M, K, 1), stride_a_host[batch])); + auto SfA = make_tensor(tensors_SFA[batch].host_data(), layout_sfa_host[batch]); + + auto B = make_tensor(make_iterator(tensors_B[batch].host_data()), + make_layout(make_shape(N, K, 1), stride_b_host[batch])); + auto SfB = make_tensor(tensors_SFB[batch].host_data(), layout_sfb_host[batch]); + + return cutlass::reference::host::GettMainloopParams + {A, SfA, B, SfB}; + } + + void print_tensors(std::ofstream& file, int batch) { + file << "A =\n" << tensors_A[batch].host_view() + << "\nB =\n" << tensors_B[batch].host_view() + << "\nSFA =\n" << tensors_SFA[batch].host_view() + << "\nSFB =\n" << tensors_SFB[batch].host_view(); + } + + bool compare_reference( + ProblemShapeType problem_shapes, int batch) { + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_A[batch].host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_B[batch].host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_SFA[batch].host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_SFB[batch].host_view()), 0); + return true; + } +}; + +// +// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_>; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, stride_factor_B_) {} +}; + +// +// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_>; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, stride_factor_B_) {} +}; + +// +// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_>; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, stride_factor_B_) {} +}; + +template +struct HostCollectiveDefaultEpilogue { + // fusion types are potentially void if the fusion is not supported + // helper so we don't try to construct HostTensor with void type + template + using non_void_t = cute::conditional_t, U, T>; + + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using kernel = typename Gemm::GemmKernel; + using Epilogue = typename kernel::CollectiveEpilogue; + + using ElementD = typename kernel::ElementD; + using StrideD = typename kernel::StrideD; + using InternalStrideD = typename kernel::InternalStrideD; + using ElementC = non_void_t; + using StrideC = typename kernel::StrideC; + using InternalStrideC = typename kernel::InternalStrideC; + + static constexpr bool IsGroupGemm = !cute::is_same_v; + + using FusionOp = typename Gemm::EpilogueOutputOp; + + static_assert(rank(InternalStrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(rank(InternalStrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + static_assert(is_row_or_col_major(), + "ERROR : C Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : D Layout is neither Row / Column Major)"); + + // Deduce Cutlass Layouts (RowMajor & ColumnMajor) + using LayoutTagC = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagD = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagScalar = cutlass::layout::PackedVectorLayout; // scalars are size-1 vectors + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + using ElementAccumulator = typename kernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename kernel::ProblemShape; + using ElementCompute = typename ElementComputeType::Type; + using ElementScalar = typename ElementScalarType::Type; + + using Arguments = typename Gemm::GemmKernel::EpilogueArguments; + + /// Initialization + cutlass::DeviceAllocation stride_c_device; + cutlass::DeviceAllocation stride_d_device; + + std::vector stride_c_host; + std::vector stride_d_host; + + typename LayoutTagC::Stride stride_factor_C; + typename LayoutTagD::Stride stride_factor_D; + + // Inputs + ElementScalar alpha; + ElementScalar beta; + + std::vector> tensors_C; + std::vector> tensors_D; + std::vector> references_D; + cutlass::DeviceAllocation device_tensors_C; + cutlass::DeviceAllocation device_tensors_D; + + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + // Are scalars copied to device memory before kernel launch + ScalarLoc use_device_scalars = ScalarLoc::ON_HOST; + // If per-row scale is enabled and this is disabled, alpha/beta are passed as a host or device scalar instead of device vector + VectorScale vector_scale_mode = VectorScale::DISABLED; + + cutlass::Distribution::Kind init_C; + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + HostCollectiveDefaultEpilogue( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): init_C(init_C_), seed(seed_), + stride_factor_C(typename LayoutTagC::Stride()), + stride_factor_D(typename LayoutTagD::Stride()), + check_relative_equality(check_relative_equality_), + use_device_scalars(use_device_scalars_){ } + + bool initialize(ProblemShapeType problem_shapes, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { + // Initialize Epilogue tensors + + tensors_C.clear(); + tensors_D.clear(); + references_D.clear(); + stride_c_host.clear(); + stride_d_host.clear(); + + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = cutlass::platform::max(problem_shapes.groups(), L); + + for (int32_t i = 0; i < L; ++i) { + auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); + + stride_c_host.push_back(cutlass::make_cute_packed_stride(InternalStrideC{}, {M, N, 1})); + stride_d_host.push_back(cutlass::make_cute_packed_stride(InternalStrideD{}, {M, N, 1})); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto c_coord = cutlass::make_Coord(M, N); + + tensors_C.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_C))); + tensors_D.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D))); + references_D.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D), false)); + EXPECT_TRUE(initialize_tensor(tensors_C[i].host_view(), init_C, seed + 2020)); + tensors_C[i].host_view().at({0, 0}) = ElementC(1); + + cutlass::reference::host::TensorCopy(references_D[i].host_view(), tensors_C[i].host_view()); + tensors_C[i].sync_device(); + tensors_D[i].sync_device(); + } + alpha = alpha_; + beta = beta_; + + return true; + } + + template < + class Element, + class Layout + > + bool equality_check( + cutlass::TensorView const& lhs, + cutlass::TensorView const& rhs) const { + + // Factors used for calculating relative equality. CUTLASS's relative-equality + // checks in include/cutlass/relatively_equal.h are inspired by + // https://floating-point-gui.de/errors/comparison/. This reference suggests using + // the minimum normal value of a given type as the nonzero_floor. + Element epsilon(static_cast(0.1f)); + Element nonzero_floor(std::numeric_limits::min()); + + if constexpr (!cutlass::is_complex::value) { + if (check_relative_equality == CheckEquality::RELATIVE) { + return cutlass::reference::host::TensorRelativelyEquals( + lhs, rhs, epsilon, nonzero_floor); + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + + bool compare_reference( + ProblemShapeType problem_shapes, + ElementScalar alpha, + ElementScalar beta, + int batch) { + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = cutlass::platform::max(problem_shapes.groups(), L); + + tensors_D[batch].sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_C[batch].host_view()), 0); + + if (tensors_D[batch].size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_D[batch].host_view()), 0); + } + + if (references_D[batch].size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(references_D[batch].host_view()), 0); + } + + bool passed = equality_check(references_D[batch].host_view(), tensors_D[batch].host_view()); + if(!passed) { + std::cout<<"D is incorrect"<(problem_shapes.get_host_problem_shape(0), 1); + L = cutlass::platform::max(problem_shapes.groups(), L); + + std::vector ptr_C_host(L); + std::vector ptr_D_host(L); + + for (int32_t i = 0; i < L; ++i) { + ptr_C_host.at(i) = tensors_C[i].device_data(); + ptr_D_host.at(i) = tensors_D[i].device_data(); + } + + device_tensors_C.reset(L); + device_tensors_C.copy_from_host(ptr_C_host.data()); + + device_tensors_D.reset(L); + device_tensors_D.copy_from_host(ptr_D_host.data()); + + stride_c_device.reset(problem_shapes.groups()); + stride_c_device.copy_from_host(stride_c_host.data()); + + stride_d_device.reset(problem_shapes.groups()); + stride_d_device.copy_from_host(stride_d_host.data()); + + Arguments arguments; + if constexpr (IsGroupGemm) { + arguments = + { + {alpha, beta}, + device_tensors_C.get(), stride_c_device.get(), device_tensors_D.get(), stride_d_device.get() + }; + } + else { + arguments = + { + {alpha, beta}, + device_tensors_C.get(), stride_c_host[0], device_tensors_D.get(), stride_d_host[0] + }; + } + + return arguments; + } + + auto to_host_args(ProblemShapeType problem_shapes, int batch) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1); + L = std::max(problem_shapes.groups(), L); + + auto coord_0 = cutlass::make_Coord(0); + auto C = cute::make_tensor(detail::make_iterator(tensors_C[batch].host_data()), + cute::make_layout(cute::make_shape(M, N, 1), stride_c_host[batch])); + auto D = cute::make_tensor(detail::make_iterator(references_D[batch].host_data()), + cute::make_layout(cute::make_shape(M, N, 1), stride_d_host[batch])); + + cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementScalar, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D)> + epilogue_params{}; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = alpha; + epilogue_params.beta = beta; + + return epilogue_params; + } +}; + +template +struct HostCollectiveEpilogue { + // fusion types are potentially void if the fusion is not supported + // helper so we don't try to construct HostTensor with void type + template + using non_void_t = cute::conditional_t, U, T>; + + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using kernel = typename Gemm::GemmKernel; + using Epilogue = typename kernel::CollectiveEpilogue; + static_assert(IsDefaultEpilogue::value == false, "Default Epilogue is not supported"); + + using ElementD = typename kernel::ElementD; + using StrideD = typename kernel::StrideD; + using InternalStrideD = typename kernel::InternalStrideD; + using ElementC = non_void_t; + using StrideC = typename kernel::StrideC; + using InternalStrideC = typename kernel::InternalStrideC; + + static constexpr bool IsGroupGemm = !cute::is_same_v; + + static_assert(rank(InternalStrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(rank(InternalStrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + static_assert(is_row_or_col_major(), + "ERROR : C Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : D Layout is neither Row / Column Major)"); + + // Deduce Cutlass Layouts (RowMajor & ColumnMajor) + using LayoutTagC = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagD = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagScalar = cutlass::layout::PackedVectorLayout; // scalars are size-1 vectors + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + using ElementAccumulator = typename kernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename kernel::ProblemShape; + + // + // FusionOperation derived types/queries + // + using EpiloguePolicy = typename Epilogue::DispatchPolicy; + static constexpr bool IsLegacy = + cute::is_same_v< + EpiloguePolicy, + cutlass::epilogue::Sm90TmaWarpSpecializedBiasElementwise< + EpiloguePolicy::StagesC, EpiloguePolicy::StagesD, EpiloguePolicy::FragmentSize> + >; + + using FusionOp = typename Gemm::EpilogueOutputOp; + static_assert(cute::is_base_of_v); + + + // Scale factor Generation related + using SfStrategy = cutlass::reference::host::SfStrategy; + static constexpr bool IsBlockScaleSupported = FusionOp::IsBlockScaleSupported; + static constexpr SfStrategy SfGenStrategy = (!IsBlockScaleSupported) ? SfStrategy::None : SfStrategy::SfDGen; + static constexpr int32_t SFD_VectorSize = IsBlockScaleSupported ? FusionOp::SFVecSize : 1; + using ElementSFD = non_void_t, ElementD>; + using Sm1xxBlockScaledOutputConfig= cutlass::detail::Sm1xxBlockScaledOutputConfig< + SFD_VectorSize + >; + using Blk_MN = typename Sm1xxBlockScaledOutputConfig::Blk_MN; + using Blk_SF = typename Sm1xxBlockScaledOutputConfig::Blk_SF; + using OutputSFAtom = typename Sm1xxBlockScaledOutputConfig::SfAtom; + std::vector> tensors_SFD; + std::vector> references_SFD; + cutlass::DeviceAllocation device_tensors_SFD; + + using ElementCompute = typename FusionOp::ElementCompute; + using ElementScalar = typename FusionOp::ElementScalar; + using ElementBias = non_void_t; + using ElementAux = non_void_t; + using ElementAmax = non_void_t; + using LayoutTagAux = non_void_t; + using ActivationFunctor = non_void_t>; + + static constexpr bool IsBiasEnabled = FusionOp::IsPerRowBiasSupported; + static constexpr bool IsDeBiasEnabled = FusionOp::IsDePerRowBiasSupported; + static constexpr bool IsPerRowScaleEnabled = FusionOp::IsPerRowScaleSupported; + static constexpr bool IsScaleFactorEnabled = FusionOp::IsScaleFactorSupported; + static constexpr bool IsAuxInEnabled = FusionOp::IsAuxInSupported; + static constexpr bool IsAuxOutEnabled = FusionOp::IsAuxOutSupported; + static constexpr bool IsAbsMaxEnabledD = FusionOp::IsAbsMaxSupported && + (cute::is_same_v || + cute::is_same_v); + static constexpr bool IsAbsMaxEnabledAux = IsAuxOutEnabled && FusionOp::IsAbsMaxSupported && + (cute::is_same_v || + cute::is_same_v); + + using Arguments = typename Gemm::GemmKernel::EpilogueArguments; + + /// Initialization + cutlass::DeviceAllocation stride_c_device; + cutlass::DeviceAllocation stride_d_device; + + std::vector stride_c_host; + std::vector stride_d_host; + + typename LayoutTagC::Stride stride_factor_C; + typename LayoutTagD::Stride stride_factor_D; + + // Inputs + cutlass::HostTensor alpha; + cutlass::HostTensor beta; + cutlass::HostTensor scale_A; + cutlass::HostTensor scale_B; + cutlass::HostTensor scale_C; + cutlass::HostTensor scale_D; + cutlass::HostTensor scale_Aux; + cutlass::HostTensor bias; + std::vector> tensors_C; + cutlass::DeviceAllocation device_tensors_C; + cutlass::HostTensor norm_constant; + + // Outputs + cutlass::HostTensor abs_max_Aux; + cutlass::HostTensor abs_max_D; + std::vector> tensors_Aux; + cutlass::DeviceAllocation device_tensors_Aux; + cutlass::gemm::TagToStrideC_t< LayoutTagAux > stride_Aux; + std::vector> tensors_D; + std::vector> references_D; + cutlass::DeviceAllocation device_tensors_D; + + // References + cutlass::HostTensor reference_dbias; + std::vector> references_Aux; + cutlass::HostTensor reference_abs_max_Aux; + cutlass::HostTensor reference_abs_max_D; + + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + // Are scalars copied to device memory before kernel launch + ScalarLoc use_device_scalars = ScalarLoc::ON_HOST; + // If per-row scale is enabled and this is disabled, alpha/beta are passed as a host or device scalar instead of device vector + VectorScale vector_scale_mode = VectorScale::DISABLED; + + // Random distribution with which to initialize the A/B/C/D/Aux scaling factors + cutlass::Distribution::Kind init_scale = cutlass::Distribution::Uniform; + // Random distribution with which to initialize the bias vector + cutlass::Distribution::Kind init_bias = cutlass::Distribution::Uniform; + cutlass::Distribution::Kind init_C; + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + HostCollectiveEpilogue( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): init_scale(init_scale_), init_bias(init_bias_), + init_C(init_C_), seed(seed_), + stride_factor_C(typename LayoutTagC::Stride()), + stride_factor_D(typename LayoutTagD::Stride()), + check_relative_equality(check_relative_equality_), + use_device_scalars(use_device_scalars_){ } + + bool initialize(ProblemShapeType problem_shapes, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { + // Initialize Epilogue tensors + + tensors_C.clear(); + tensors_D.clear(); + references_D.clear(); + stride_c_host.clear(); + stride_d_host.clear(); + + tensors_SFD.clear(); + references_SFD.clear(); + + + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = std::max(problem_shapes.groups(), L); + + for (int32_t i = 0; i < L; ++i) { + auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); + + stride_c_host.push_back(cutlass::make_cute_packed_stride(InternalStrideC{}, {M, N, 1})); + stride_d_host.push_back(cutlass::make_cute_packed_stride(InternalStrideD{}, {M, N, 1})); + + auto c_coord = cutlass::make_Coord(M, N); + tensors_C.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_C))); + tensors_D.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D))); + references_D.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D), false)); + EXPECT_TRUE(initialize_tensor(tensors_C[i].host_view(), init_C, seed + 2020)); + tensors_C[i].host_view().at({0, 0}) = ElementC(1); + + cutlass::reference::host::TensorCopy(references_D[i].host_view(), tensors_C[i].host_view()); + tensors_C[i].sync_device(); + tensors_D[i].sync_device(); + } + + auto scalar_coord = cutlass::make_Coord(1); + auto col_vector_coord = cutlass::make_Coord(M); + if constexpr (IsPerRowScaleEnabled) { + alpha.resize(col_vector_coord); + EXPECT_TRUE(initialize_tensor(alpha.host_view(), init_scale, seed + 2023)); + if (vector_scale_mode == VectorScale::DISABLED) { + beta.resize(scalar_coord, false); + cutlass::reference::host::TensorFill(beta.host_view(), beta_); + } + else { + beta.resize(col_vector_coord); + EXPECT_TRUE(initialize_tensor(beta.host_view(), init_scale, seed + 2024)); + } + } + else { + alpha.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + beta.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + cutlass::reference::host::TensorFill(alpha.host_view(), alpha_); + cutlass::reference::host::TensorFill(beta.host_view(), beta_); + } + alpha.sync_device(); + beta.sync_device(); + + if constexpr (IsScaleFactorEnabled) { + scale_A.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + scale_B.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + scale_C.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + scale_D.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + EXPECT_TRUE(initialize_tensor(scale_A.host_view(), init_scale, seed + 2023)); + EXPECT_TRUE(initialize_tensor(scale_B.host_view(), init_scale, seed + 2024)); + EXPECT_TRUE(initialize_tensor(scale_C.host_view(), init_scale, seed + 2025)); + EXPECT_TRUE(initialize_tensor(scale_D.host_view(), init_scale, seed + 2026)); + scale_A.sync_device(); + scale_B.sync_device(); + scale_C.sync_device(); + scale_D.sync_device(); + } + + if constexpr (IsBiasEnabled) { + bias.resize(col_vector_coord); + EXPECT_TRUE(initialize_tensor(bias.host_view(), init_bias, seed + 2023)); + bias.sync_device(); + } + + if constexpr (IsDeBiasEnabled) { + bias.resize(col_vector_coord); + reference_dbias.resize(col_vector_coord); + cutlass::reference::host::TensorFill(bias.host_view(), ElementBias(0)); + cutlass::reference::host::TensorFill(reference_dbias.host_view(), ElementBias(0)); + bias.sync_device(); + } + + if constexpr (IsAbsMaxEnabledD) { + abs_max_D.resize(scalar_coord); + // ensure in-place device reductions perform their own initialization + cutlass::reference::host::TensorFill(abs_max_D.host_view(), + CUTLASS_STL_NAMESPACE::numeric_limits::max()); + abs_max_D.sync_device(); + reference_abs_max_D.resize(scalar_coord); + cutlass::reference::host::TensorFill(reference_abs_max_D.host_view(), ElementAmax(0)); + } + + tensors_Aux.clear(); + references_Aux.clear(); + + static_assert(!IsGroupGemm or (IsGroupGemm and !IsAuxInEnabled)); + + if constexpr (IsAuxInEnabled) { + auto aux_coord = cutlass::make_Coord(M, N); + auto aux_layout = cutlass::layout::Affine2Layout_Factory::layout_factory(aux_coord, typename LayoutTagAux::Stride{}); + for (int32_t i = 0; i < L; ++i) { + tensors_Aux.push_back(cutlass::HostTensor(aux_coord, aux_layout)); + EXPECT_TRUE(initialize_tensor(tensors_Aux[i].host_view(), init_C, seed + 2023)); + tensors_Aux[i].sync_device(); + } + stride_Aux = cutlass::make_cute_packed_stride(cutlass::gemm::TagToStrideC_t{}, cute::make_shape(M, N, 1)); + } + + static_assert(!IsGroupGemm or (IsGroupGemm and !IsAuxOutEnabled)); + + if constexpr (IsAuxOutEnabled) { + for (int32_t i = 0; i < L; ++i) { + auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto aux_coord = cutlass::make_Coord(M, N); + auto aux_layout = cutlass::layout::Affine2Layout_Factory::layout_factory(aux_coord, typename LayoutTagAux::Stride{}); + tensors_Aux.push_back(cutlass::HostTensor(aux_coord, aux_layout)); + references_Aux.push_back(cutlass::HostTensor(aux_coord, aux_layout, false)); + tensors_Aux[i].sync_device(); + } + + stride_Aux = cutlass::make_cute_packed_stride(cutlass::gemm::TagToStrideC_t{}, cute::make_shape(M, N, 1)); + + if constexpr (IsScaleFactorEnabled) { + scale_Aux.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + EXPECT_TRUE(initialize_tensor(scale_Aux.host_view(), init_scale, seed + 2027)); + scale_Aux.sync_device(); + } + + if constexpr (IsAbsMaxEnabledAux) { + abs_max_Aux.resize(scalar_coord); + // ensure in-place device reductions perform their own initialization + cutlass::reference::host::TensorFill(abs_max_Aux.host_view(), + CUTLASS_STL_NAMESPACE::numeric_limits::max()); + abs_max_Aux.sync_device(); + reference_abs_max_Aux.resize(scalar_coord); + cutlass::reference::host::TensorFill(reference_abs_max_Aux.host_view(), ElementAmax(0)); + } + } + + + if constexpr (IsBlockScaleSupported) { + for (int32_t i = 0; i < L; ++i) { + auto [M, N, K, _] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); + // If block scaled output is supported we always have at least 1 SFD + auto m_blks = cutlass::ceil_div(M, cute::size<0>(cute::shape(OutputSFAtom{}))); + auto n_blks = cutlass::ceil_div(N, cute::size<1>(cute::shape(OutputSFAtom{}))); + auto sfd_coord = [&] () { + return cutlass::make_Coord(m_blks * Blk_MN{}, n_blks * Blk_SF{}); + }(); + tensors_SFD.push_back(cutlass::HostTensor(sfd_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfd_coord, stride_factor_D))); + references_SFD.push_back(cutlass::HostTensor(sfd_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfd_coord, stride_factor_D), false)); + tensors_SFD[i].sync_device(); + } + norm_constant.resize(scalar_coord, true); + EXPECT_TRUE(initialize_tensor(norm_constant.host_view(), init_scale, seed + 2023)); + norm_constant.sync_device(); + } + + + return true; + } + + template < + class Element, + class Layout + > + bool equality_check( + cutlass::TensorView const& lhs, + cutlass::TensorView const& rhs) const { + + // Factors used for calculating relative equality. CUTLASS's relative-equality + // checks in include/cutlass/relatively_equal.h are inspired by + // https://floating-point-gui.de/errors/comparison/. This reference suggests using + // the minimum normal value of a given type as the nonzero_floor. + Element epsilon(static_cast(0.1f)); + Element nonzero_floor(std::numeric_limits::min()); + + if constexpr (!cutlass::is_complex::value) { + if (check_relative_equality == CheckEquality::RELATIVE) { + return cutlass::reference::host::TensorRelativelyEquals( + lhs, rhs, epsilon, nonzero_floor); + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + + bool compare_reference( + ProblemShapeType problem_shapes, + ElementScalar alpha, + ElementScalar beta, + int batch) { + tensors_D[batch].sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_C[batch].host_view()), 0); + + if (tensors_D[batch].size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_D[batch].host_view()), 0); + } + + if (references_D[batch].size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(references_D[batch].host_view()), 0); + } + + bool passed = equality_check(references_D[batch].host_view(), tensors_D[batch].host_view()); + if(!passed) { + std::cout<<"D is incorrect"<(problem_shapes.get_host_problem_shape(0), 1); + L = std::max(problem_shapes.groups(), L); + + std::vector ptr_C_host(L); + std::vector ptr_D_host(L); + + for (int32_t i = 0; i < L; ++i) { + ptr_C_host.at(i) = tensors_C[i].device_data(); + ptr_D_host.at(i) = tensors_D[i].device_data(); + } + + device_tensors_C.reset(L); + device_tensors_C.copy_from_host(ptr_C_host.data()); + + device_tensors_D.reset(L); + device_tensors_D.copy_from_host(ptr_D_host.data()); + + stride_c_device.reset(problem_shapes.groups()); + stride_c_device.copy_from_host(stride_c_host.data()); + + stride_d_device.reset(problem_shapes.groups()); + stride_d_device.copy_from_host(stride_d_host.data()); + + std::vector ptr_Aux_host(L); + if constexpr (IsAuxInEnabled || IsAuxOutEnabled) { + for (int32_t i = 0; i < L; ++i) { + ptr_Aux_host.at(i) = tensors_Aux[i].device_data(); + } + device_tensors_Aux.reset(L); + device_tensors_Aux.copy_from_host(ptr_Aux_host.data()); + } + + auto device_tensors_C_ptr = cute::is_void_v ? nullptr : + reinterpret_cast(device_tensors_C.get()); + + Arguments arguments; + if constexpr (IsGroupGemm) { + arguments = + { + {}, + device_tensors_C_ptr, stride_c_device.get(), device_tensors_D.get(), stride_d_device.get() + }; + } + else { + arguments = + { + {}, + device_tensors_C_ptr, stride_c_host[0], device_tensors_D.get(), stride_d_host[0] + }; + } + + auto &fusion_args = arguments.thread; + if constexpr (IsLegacy) { + arguments.thread = { + alpha.at(coord_0), + beta.at(coord_0), + alpha.device_data(), + beta.device_data() + }; + arguments.ptr_Bias = bias.device_data(); + arguments.ptr_T = device_tensors_Aux.get(); + } + else { + fusion_args.alpha = alpha.at(coord_0); + fusion_args.beta = beta.at(coord_0); + + fusion_args.alpha_ptr = alpha.device_data(); + // can_implement requires beta_ptr to not be set if its voidC + fusion_args.beta_ptr = cute::is_void_v ? nullptr : + beta.device_data(); + + if constexpr (IsScaleFactorEnabled) { + fusion_args.scale_a = scale_A.at(coord_0); + fusion_args.scale_b = scale_B.at(coord_0); + fusion_args.scale_c = scale_C.at(coord_0); + fusion_args.scale_d = scale_D.at(coord_0); + fusion_args.scale_a_ptr = scale_A.device_data(); + fusion_args.scale_b_ptr = scale_B.device_data(); + fusion_args.scale_c_ptr = scale_C.device_data(); + fusion_args.scale_d_ptr = scale_D.device_data(); + } + + if constexpr (IsBiasEnabled) { + fusion_args.bias_ptr = bias.device_data(); + } + + if constexpr (IsDeBiasEnabled) { + fusion_args.dbias_ptr = bias.device_data(); + } + + // example of how to set kernel activation arguments + // see ActivationFunctor::Arguments in activation.h for definition + // if Arguments doesn't exist then fusion_args.activation is empty + if constexpr (cute::is_same_v>) { + fusion_args.activation.scale = ElementCompute(1); + } + + // Treat Clamp as ReLU + if constexpr (cute::is_same_v>) { + fusion_args.activation.lower_bound = 0; + fusion_args.activation.upper_bound = std::numeric_limits::max(); + } + + if constexpr (IsAbsMaxEnabledD) { + fusion_args.amax_D_ptr = abs_max_D.device_data(); + } + + if constexpr (IsAuxInEnabled) { + fusion_args.aux_ptr = device_tensors_Aux.get(); + fusion_args.dAux = stride_Aux; + } + + if constexpr (IsAuxOutEnabled) { + fusion_args.aux_ptr = device_tensors_Aux.get(); + fusion_args.dAux = stride_Aux; + if constexpr (IsScaleFactorEnabled) { + fusion_args.scale_aux = scale_Aux.at(coord_0); + fusion_args.scale_aux_ptr = scale_Aux.device_data(); + } + if constexpr (IsAbsMaxEnabledAux) { + fusion_args.amax_aux_ptr = abs_max_Aux.device_data(); + } + } + + if constexpr (IsBlockScaleSupported) { + std::vector ptr_SFD_host(L); + for (int32_t i = 0; i < L; ++i) { + ptr_SFD_host.at(i) = tensors_SFD[i].device_data(); + } + device_tensors_SFD.reset(L); + device_tensors_SFD.copy_from_host(ptr_SFD_host.data()); + + arguments.thread.block_scale_factor_ptr = device_tensors_SFD.get(); + arguments.thread.norm_constant_ptr = norm_constant.device_data(); + } + + } + + return arguments; + } + + auto to_host_args(ProblemShapeType problem_shapes, int batch) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1); + auto [M, N, K, L] = problem_shape_MNKL; + auto coord_0 = cutlass::make_Coord(0); + auto C = cute::make_tensor(detail::make_iterator(tensors_C[batch].host_data()), + cute::make_layout(cute::make_shape(M, N, 1), stride_c_host[batch])); + auto D = cute::make_tensor(detail::make_iterator(references_D[batch].host_data()), + cute::make_layout(cute::make_shape(M, N, 1), stride_d_host[batch])); + auto Bias = cute::make_tensor(detail::make_iterator(IsDeBiasEnabled ? reference_dbias.host_data() : bias.host_data()), + cute::make_layout(cute::make_shape(M, cute::_1{}))); + auto Aux_layout = cute::make_layout(cute::make_shape(M, N, 1), stride_Aux); + auto Aux = [&]() { + auto ptr = recast_ptr(nullptr); + if (IsAuxInEnabled) { + ptr = detail::make_iterator(tensors_Aux[batch].host_data()); + } else if (IsAuxOutEnabled) { + ptr = detail::make_iterator(references_Aux[batch].host_data()); + } + return cute::make_tensor(ptr, Aux_layout); + }(); + auto Valpha = cute::make_tensor(detail::make_iterator(alpha.host_data()), + cute::make_layout(cute::make_shape(M, N, cute::_1{}), cute::make_stride(cute::_1{}, cute::_0{}, M))); + auto Vbeta = cute::make_tensor(detail::make_iterator(beta.host_data()), + cute::make_layout(cute::make_shape(M, N, cute::_1{}), cute::make_stride(cute::_1{}, cute::_0{}, N))); + + auto SfD = [&](){ + if constexpr (IsBlockScaleSupported) { + auto tensor = make_tensor(detail::make_iterator(references_SFD[batch].host_data()), + Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(problem_shape_MNKL)); + return tensor; + } + else { + // Reference kernel has a logic to ignore scalefactor computation if we pass the tensor type same as output D tensor. + return D; + } + }(); + + + cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementScalar, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D), + decltype(Bias), + decltype(Aux), + decltype(Valpha), + decltype(Vbeta), + ActivationFunctor + , decltype(SfD) + , Int + , cutlass::plus + , false + , SfGenStrategy + > epilogue_params{}; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = alpha.at(coord_0); + epilogue_params.beta = beta.at(coord_0); + + if constexpr (IsScaleFactorEnabled) { + epilogue_params.scale_a = scale_A.at(coord_0); + epilogue_params.scale_b = scale_B.at(coord_0); + epilogue_params.scale_c = scale_C.at(coord_0); + epilogue_params.scale_d = scale_D.at(coord_0); + } + + if constexpr (IsBiasEnabled or IsDeBiasEnabled) { + epilogue_params.Bias = Bias; + } + + if constexpr (IsAbsMaxEnabledD) { + epilogue_params.abs_max_D = reference_abs_max_D.host_data(); + } + + if constexpr (IsAuxInEnabled) { + epilogue_params.Aux = Aux; + } + + if constexpr (IsAuxOutEnabled) { + epilogue_params.Aux = Aux; + if constexpr (IsScaleFactorEnabled) { + epilogue_params.scale_aux = scale_Aux.at(coord_0); + } + if constexpr (IsAbsMaxEnabledAux) { + epilogue_params.abs_max_Aux = reference_abs_max_Aux.host_data(); + } + } + + if constexpr (IsPerRowScaleEnabled) { + epilogue_params.Valpha = Valpha; + if (vector_scale_mode == VectorScale::ENABLED) { + epilogue_params.Vbeta = Vbeta; + } + } + + if constexpr (IsBlockScaleSupported) { + epilogue_params.SfD = SfD; + epilogue_params.st = norm_constant.at(coord_0); + } + + return epilogue_params; + } +}; + +template < + typename Gemm, + template class ActivationFunctor_ = cutlass::epilogue::thread::Identity, + bool force_legacy_epilogue = false, + typename ElementA = typename Gemm::GemmKernel::ElementA, + typename ElementB = typename Gemm::GemmKernel::ElementB +> +struct TestbedImpl { + // Kernel data types + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + // All Collective MMA operands are defined by HostCollectiveMainloopType based on the schedule type + using HostCollectiveMainloopType = HostCollectiveMainloop; + using CollectiveEpilogue = cute::conditional_t::value || force_legacy_epilogue, + HostCollectiveDefaultEpilogue, + HostCollectiveEpilogue>; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementCompute = typename ElementComputeType::Type; + using ElementScalar = typename ElementScalarType::Type; + + using LayoutTagA = typename HostCollectiveMainloopType::LayoutTagA; + using LayoutTagB = typename HostCollectiveMainloopType::LayoutTagB; + using LayoutTagC = typename CollectiveEpilogue::LayoutTagC; + using LayoutTagD = typename CollectiveEpilogue::LayoutTagD; + + uint32_t sm_count; + // Used to force multi-wave tests for persistent kernel schedules + constexpr static int MaxSmCount = 16; + static constexpr uint64_t kDefaultSeed = 4096; + static constexpr uint32_t mma_promotion_interval = 4; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + + HostCollectiveMainloopType collective_mma_inputs; + CollectiveEpilogue collective_epilogue; + + static constexpr bool IsGroupGemm = CollectiveEpilogue::IsGroupGemm; + + // + // Methods + // + + TestbedImpl( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): collective_mma_inputs(HostCollectiveMainloopType(check_relative_equality_, init_A_, init_B_, seed_)), + collective_epilogue(CollectiveEpilogue(check_relative_equality_, use_device_scalars_, vector_scale_mode_, init_C_, init_scale_, init_bias_, seed_)) { } + + TestbedImpl( + typename LayoutTagA::Stride stride_factor_A_, + typename LayoutTagB::Stride stride_factor_B_, + typename LayoutTagC::Stride stride_factor_C_, + typename LayoutTagD::Stride stride_factor_D_, + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): collective_mma_inputs(HostCollectiveMainloopType(check_relative_equality_, stride_factor_A_, stride_factor_B_, init_A_, init_B_, seed_)), + collective_epilogue(CollectiveEpilogue(check_relative_equality_, use_device_scalars_, vector_scale_mode_, init_C_, init_scale_, init_bias_, seed_)) { } + + /// Initializes data structures + bool initialize(ProblemShapeType problem_shapes, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { + collective_mma_inputs.initialize(problem_shapes); + collective_epilogue.initialize(problem_shapes, alpha_, beta_); + + return true; + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + ProblemShapeType problem_shapes, + ElementScalar alpha, + ElementScalar beta, + int batch) + { + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1); + + bool passed = collective_mma_inputs.compare_reference(problem_shapes, batch); + passed &= collective_epilogue.compare_reference(problem_shapes, alpha, beta, batch); + EXPECT_TRUE(passed); + if (!passed) { + std::stringstream fname; + fname << "error_Gemm_device_" + << M << "x" << N << "x" << K << "x" << batch << "_" + << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt"; + + std::ofstream file(fname.str()); + file + << "problem: " << ' ' << M << "x" << N << "x" << K << ", Batch count = " << batch + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + collective_mma_inputs.print_tensors(file, batch); + collective_epilogue.print_tensors(file, batch); + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + ProblemShapeType problem_shapes, + ElementScalar alpha, + ElementScalar beta) + { + using namespace cute; + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = std::max(problem_shapes.groups(), L); + + bool passed = true; + for (int32_t i = 0; i < L; ++i) { + auto mainloop_params = collective_mma_inputs.to_host_args(problem_shapes, i); + auto epilogue_params = collective_epilogue.to_host_args(problem_shapes, i); + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + passed &= compare_reference(problem_shapes, alpha, beta, i); + } + return passed; + } + + /// Determine if the CUDA device is sufficient to run the kernel + bool sufficient() { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = static_cast(Gemm::GemmKernel::SharedStorageSize); + + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + cudaDeviceProp properties; + result = cudaGetDeviceProperties(&properties, device_idx); + this->sm_count = properties.multiProcessorCount; + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + printf("failed due to smem_size\n"); + printf("hardware smem_size: %d, required smem_size: %d\n\n", int(properties.sharedMemPerBlockOptin), int(smem_size)); + return false; + } + + return true; + } + + /// Executes one test + bool run( + ProblemShapeType problem_shapes, + ElementScalar alpha = ElementScalar(1), + ElementScalar beta = ElementScalar(0), + detail::Iterations iterations = detail::Iterations{} + ) + { + + // Fail test if insufficient CUDA device + if (!sufficient()) { + std::cout << "Test failed due to insufficient CUDA device." << std::endl; + return false; + } + + if (!this->initialize(problem_shapes, alpha, beta)) { + std::cerr << "Initialization failed \n"; + return false; + } + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments; + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + this->sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + hw_info.sm_count = this->sm_count; + + typename HostCollectiveMainloopType::Arguments mainloop_args; + + mainloop_args = collective_mma_inputs.to_args(problem_shapes); + + if constexpr (IsGroupGemm) { + arguments = + { + cutlass::gemm::GemmUniversalMode::kGrouped, + problem_shapes, + mainloop_args, + collective_epilogue.to_args(problem_shapes), + hw_info + }; + } + else { + arguments = + { + cutlass::gemm::GemmUniversalMode::kArray, + problem_shapes, + mainloop_args, + collective_epilogue.to_args(problem_shapes), + hw_info + }; + } + + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.can_implement(arguments); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return false; + } + + // + // Run the GEMM + // + + cudaError_t result; + status = gemm_op.initialize(arguments, workspace.get()); + status = gemm_op.run(); + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + bool passed = this->verify(problem_shapes, alpha, beta); + if (!passed) { + std::cout << "Error : Failed : with alpha: " << alpha << ", beta: " << beta + << "\n"; + } + + return passed; + } +}; + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Gemm, + template class ActivationFunctor = cutlass::epilogue::thread::Identity, + bool force_legacy_epilogue = false, + typename ElementA = typename Gemm::GemmKernel::ElementA, + typename ElementB = typename Gemm::GemmKernel::ElementB +> +struct Testbed3x { + + using TestBedImpl = typename detail::TestbedImpl< + Gemm, + ActivationFunctor, + force_legacy_epilogue, + ElementA, + ElementB + >; + using Kernel = typename Gemm::GemmKernel; + using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; + + using ElementAccumulator = typename TestBedImpl::ElementAccumulator; + using ElementCompute = typename TestBedImpl::ElementCompute; + using ElementScalar = typename TestBedImpl::ElementScalar; + + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + + static constexpr bool IsGroupGemm = TestBedImpl::IsGroupGemm; + + // Detail Implementation + TestBedImpl impl_; + + // + // Methods + // + Testbed3x( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_DEVICE, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed) + : impl_(check_relative_equality_, use_device_scalars_, vector_scale_mode_, init_A_, init_B_, init_C_, init_scale_, init_bias_, seed_) {} + + /// Executes one test + bool run( + typename TestBedImpl::ProblemShapeType problem_shapes, + ElementScalar alpha = ElementScalar(1), + ElementScalar beta = ElementScalar(0), + detail::Iterations iterations = detail::Iterations{} + ) + { + return impl_.run( + problem_shapes, alpha, beta, iterations); + } +}; + +template < + typename Gemm, + template class ActivationFunctor = cutlass::epilogue::thread::Identity +> +bool TestAll(double alpha = 1.0, double beta = 0.0, CheckEquality check_relative_equality = CheckEquality::RELATIVE) { + using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + Testbed3x testbed(check_relative_equality, ScalarLoc::ON_DEVICE, VectorScale::DISABLED); + + int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + std::vector problem_size_m = {max_alignment, 512 - 3 * max_alignment}; + std::vector problem_size_n = {max_alignment, 512 - 2 * max_alignment}; + + constexpr int Stages = Gemm::GemmKernel::DispatchPolicy::Stages; + constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{}); + + std::vector problem_size_k = {max_alignment, TileShapeK * (Stages + 1) - max_alignment}; + + int batches[] = {5, 10}; + + bool passed = true; + + for (int batch : batches) { + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + + if constexpr (Testbed3x::IsGroupGemm) { + std::vector problem_sizes_host; + cutlass::DeviceAllocation problem_sizes_device; + + for (int i = 0; i < batch; ++i) { + problem_sizes_host.push_back({m * ((i % 3) + 1), n * ((i % 4) + 1), k * ((i % 5) + 1)}); + } + + problem_sizes_device.reset(problem_sizes_host.size()); + problem_sizes_device.copy_from_host(problem_sizes_host.data()); + + passed = testbed.run( + ProblemShapeType{static_cast(problem_sizes_host.size()), problem_sizes_device.get(), problem_sizes_host.data()}, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + } + else { + ProblemShapeType problem_size{{m, n, k, batch}}; + + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + } + + if (!passed) { + std::cout << __FILE__ << ':' << __LINE__ << " : GEMM MNKL " << m << " " << n << " " << k << " " << batch << " FAILED.\n"; + return false; + } + } // k + } // n + } // m + } // batch + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestSmall(double alpha = 1.0, double beta = 1.0, + CheckEquality check_relative_equality = CheckEquality::RELATIVE, + ScalarLoc use_device_scalars = ScalarLoc::ON_DEVICE, + VectorScale vector_scale_mode = VectorScale::ENABLED, + std::vector override_problem_size_k = {}) { + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar; + using ElementA = typename Gemm::GemmKernel::ElementA; + using ElementB = typename Gemm::GemmKernel::ElementB; + using TiledMma = typename Gemm::GemmKernel::TiledMma; + + static constexpr bool IsF8F6F4 = cutlass::gemm::collective::detail::is_sm100_mma_f8f6f4(); + // For fp4 and fp6 kernels, the min alignment_input is 128 elements, so we don't need to add alignment_input in test problem sizes. + int alignment_bits_a = cutlass::detail::get_input_alignment_bits(); + int alignment_input_a = (alignment_bits_a / cute::sizeof_bits::value == 128) ? 0 : (alignment_bits_a / cute::sizeof_bits::value); + + int alignment_bits_b = cutlass::detail::get_input_alignment_bits(); + int alignment_input_b = (alignment_bits_b / cute::sizeof_bits::value == 128) ? 0 : (alignment_bits_b / cute::sizeof_bits::value); + + int alignment_input = (alignment_input_a == 0 || alignment_input_b == 0) ? 0 : std::max(alignment_input_a, alignment_input_b); + + if constexpr (apply_alignment_offset) { + // If BlockScaled, then min alignment is SFVecSize + static constexpr bool IsBlockScaleSupported = Gemm::EpilogueOutputOp::IsBlockScaleSupported; + static constexpr int SFVecSize = Gemm::GemmKernel::CollectiveMainloop::SFVecSize; + if constexpr (IsBlockScaleSupported) { + alignment_input = cutlass::round_up(alignment_input, SFVecSize); + } + } + + + using CtaShape_MNK = typename Gemm::GemmKernel::CollectiveMainloop::CtaShape_MNK; + using DispatchPolicy = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy; + CtaShape_MNK cta_shape; + Testbed3x testbed(check_relative_equality, use_device_scalars, vector_scale_mode); + // For Ptr-Array and Grouped GEMM ideally we need to know SM count at runtime + static constexpr int SmCount = 16; + + float waves[] = {0.5, 2.5}; + int batches[] = {3}; + int cluster_m = 1; + int cluster_n = 1; + + std::vector problem_size_k; + if (override_problem_size_k.empty()) { + // this is to test with min alignment + problem_size_k = {256 - alignment_input, 512 + alignment_input}; + } + else { + problem_size_k = override_problem_size_k; + } + + if constexpr(DispatchPolicy::ArchTag::kMinComputeCapability >= 90) { + typename DispatchPolicy::ClusterShape cluster_shape; + cluster_m = cute::size<0>(cluster_shape); + cluster_n = cute::size<1>(cluster_shape); + } + + bool passed = true; + + for (int batch : batches) { + for (float wave : waves) { + for (int k : problem_size_k) { + int grid_m, grid_n = 0; + float num_grid = wave * SmCount; + + if (cluster_m >= cluster_n) { + grid_m = cluster_m; + grid_n = static_cast(num_grid) / grid_m; + // Align grid_n to cluster_n + grid_n = std::max((grid_n + cluster_n - 1 ) / cluster_n * cluster_n, 1); + } + else { + grid_n = cluster_n; + grid_m = static_cast(num_grid) / grid_n; + // Align grid_m to cluster_m + grid_m = std::max((grid_m + cluster_m - 1 ) / cluster_m * cluster_m, 1); + } + + int m = grid_m * cute::size<0>(cta_shape) - alignment_input; // this is just to test with unusual problem shapes + int n = grid_n * cute::size<1>(cta_shape) + alignment_input; + + if constexpr (Testbed3x::IsGroupGemm) { + std::vector problem_sizes_host; + cutlass::DeviceAllocation problem_sizes_device; + for (int i = 0; i < batch; ++i) { + problem_sizes_host.push_back({m * ((i % 2) + 1), n * ((i % 3) + 1), k * ((i % 2) + 1)}); + } + problem_sizes_device.reset(problem_sizes_host.size()); + problem_sizes_device.copy_from_host(problem_sizes_host.data()); + + ProblemShapeType problem_shapes{batch, problem_sizes_device.get(), problem_sizes_host.data()}; + + if (CUTLASS_DEBUG_TRACE_LEVEL > 0) { + for (int i = 0; i < batch; ++i) { + std::cout << "problem_shapes : " << problem_shapes.get_host_problem_shape(i) << " \n"; + } + } + passed = testbed.run( + problem_shapes, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + } + else { + ProblemShapeType problem_shapes{{m, n, k, batch}}; + if (CUTLASS_DEBUG_TRACE_LEVEL > 0) { + std::cout << "problem_shapes : " << problem_shapes.get_host_problem_shape() << " \n"; + } + passed = testbed.run( + problem_shapes, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + } + + if (!passed) { + std::cout << __FILE__ << ':' << __LINE__ << " : GEMM MNK " << m << " " << n << " " << k << " FAILED.\n"; + return false; + } + } // k + } // waves + } // batches + + return passed; +} + +template +bool TestSmallFusion(double alpha = 1.0, double beta = 0.0, + CheckEquality check_relative_equality = CheckEquality::RELATIVE, + ScalarLoc use_device_scalars = ScalarLoc::ON_DEVICE, + VectorScale vector_scale_mode = VectorScale::ENABLED) { + return TestSmall( + alpha, beta, check_relative_equality, use_device_scalars, vector_scale_mode); +} + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8b00f98a97846de175f1c6f95919c483ab4b81da --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp @@ -0,0 +1,515 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface with elementwise tensor-tensor broadcast epilogue +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "testbed_utils.h" +#include "gemm_testbed_3x.hpp" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Testbed3xTensorBroadcast { + + using TestBedImpl = typename detail::TestbedImpl; + using Kernel = typename Gemm::GemmKernel; + using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; + + using ElementA = typename Kernel::ElementA; + using StrideA = typename Kernel::StrideA; + using ElementB = typename Kernel::ElementB; + using StrideB = typename Kernel::StrideB; + using ElementC = typename Kernel::ElementC; + using StrideC = typename Kernel::StrideC; + using ElementD = typename Kernel::ElementD; + using StrideD = typename Kernel::StrideD; + + using ElementAccumulator = typename Kernel::ElementAccumulator; + using ElementCompute = typename Epilogue::ElementCompute; + using ElementScalar = typename Epilogue::ElementScalar; + using ProblemShapeType = typename Kernel::ProblemShape; + using ElementBias = typename Epilogue::ElementBias; + using ActivationFunctor = typename Epilogue::ActivationFunctor; + + static constexpr bool IsBinaryOp0Enabled = Epilogue::IsBinaryOp0Enabled; + static constexpr bool IsBinaryOp1Enabled = Epilogue::IsBinaryOp1Enabled; + static constexpr bool IsUnaryOpEnabled = Epilogue::IsUnaryOpEnabled; + + static constexpr bool PerColBias = Epilogue::PerColumnBias; + + using LayoutTagA = typename TestBedImpl::LayoutTagA; + using LayoutTagB = typename TestBedImpl::LayoutTagB; + using LayoutTagC = typename TestBedImpl::LayoutTagC; + using LayoutTagD = typename TestBedImpl::LayoutTagD; + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + cutlass::HostTensor bias; + cutlass::HostTensor tensor_C1; + // tensor_C0 is taken from TestbedImpl's tensor_C + + + // Detail Implementation + TestBedImpl impl_; + + // + // Methods + // + Testbed3xTensorBroadcast( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed + ) : + impl_(CheckEquality::EXACT, ScalarLoc::ON_DEVICE, VectorScale::ENABLED, + init_A_, init_B_, init_C_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_) { } + + Testbed3xTensorBroadcast( + typename LayoutTagA::Stride stride_factor_A_, + typename LayoutTagB::Stride stride_factor_B_, + typename LayoutTagC::Stride stride_factor_C_, + typename LayoutTagD::Stride stride_factor_D_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed + ) : + impl_(stride_factor_A_, + stride_factor_B_, + stride_factor_C_, + stride_factor_D_, + CheckEquality::EXACT, ScalarLoc::ON_HOST, VectorScale::ENABLED, + init_A_, + init_B_, + init_C_, + cutlass::Distribution::Uniform, + cutlass::Distribution::Uniform, + seed_) { } + + /// Initializes data structures + void initialize(ProblemShapeType problem_size) { + // + // Allocate the GEMM workspace for A/B/C/D tensor + // + impl_.initialize(problem_size); + } + + void initialize_bias(ProblemShapeType problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto bias_size = PerColBias ? cute::get<1>(problem_shape_MNKL) : cute::get<0>(problem_shape_MNKL); + bias.resize(cutlass::Coord<1>(bias_size)); + + EXPECT_TRUE(detail::initialize_tensor(bias.host_view(), cutlass::Distribution::Uniform, impl_.collective_mma_inputs.seed + 2023)); + bias.sync_device(); + } + + void initialize_c1(ProblemShapeType problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::get<0>(problem_shape_MNKL); + auto N = cute::get<1>(problem_shape_MNKL); + auto L = cute::get<3>(problem_shape_MNKL); + + auto c_coord = cutlass::make_Coord(M * L, N); + + tensor_C1.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, impl_.collective_epilogue.stride_factor_C)); + EXPECT_TRUE(detail::initialize_tensor(tensor_C1.host_view(), cutlass::Distribution::Uniform, impl_.collective_mma_inputs.seed + 2024)); + tensor_C1.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cute::Shape problem_shape_MNKL, + ElementScalar alpha, + ElementScalar beta, + bool use_bias) + { + auto [M, N, K, L] = problem_shape_MNKL; + + impl_.collective_epilogue.tensor_D.sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.collective_mma_inputs.tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.collective_mma_inputs.tensor_B.host_view()), 0); + + if (impl_.collective_epilogue.tensor_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.collective_epilogue.tensor_D.host_view()), 0); + } + + if (impl_.collective_epilogue.reference_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.collective_epilogue.reference_D.host_view()), 0); + } + + bool passed = cutlass::reference::host::TensorEquals(impl_.collective_epilogue.reference_D.host_view(), impl_.collective_epilogue.tensor_D.host_view()); + + EXPECT_TRUE(passed); + + if (!passed) { + std::stringstream fname; + fname << "error_Gemm_device_broadcast" + << M << "x" << N << "x" << K << "x" << L << "_" + << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt"; + + std::ofstream file(fname.str()); + file + << "problem: " << ' ' << M << "x" << N << "x" << K << ", Batch count = " << L + << ", alpha: " << float(alpha) << ", beta: " << float(beta) << ", use_bias: " << use_bias + << ", per-col bias: " << PerColBias << "\n\n"; + + if (use_bias){ + file << "Bias = \n" << bias.host_view()<< "\n\n"; + } + + file + << "A =\n" << impl_.collective_mma_inputs.tensor_A.host_view() + << "\nB =\n" << impl_.collective_mma_inputs.tensor_B.host_view() + << "\nC0 =\n" << impl_.collective_epilogue.tensor_C.host_view() + << "\nC1 =\n" << tensor_C1.host_view() + << "\n\nReference =\n" << impl_.collective_epilogue.reference_D.host_view() + << "\n\nComputed =\n" <(problem_size, 1); + auto M = cute::get<0>(problem_shape_MNKL); + auto N = cute::get<1>(problem_shape_MNKL); + auto K = cute::get<2>(problem_shape_MNKL); + auto L = cute::get<3>(problem_shape_MNKL); + + auto A = cute::make_tensor(impl_.collective_mma_inputs.tensor_A.host_data(), + cute::make_layout(cute::make_shape(M, K, L), impl_.collective_mma_inputs.stride_a)); + auto B = cute::make_tensor(impl_.collective_mma_inputs.tensor_B.host_data(), + cute::make_layout(cute::make_shape(N, K, L), impl_.collective_mma_inputs.stride_b)); + auto D = cute::make_tensor(impl_.collective_epilogue.reference_D.host_data(), + cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_d)); + auto Bias = cute::make_tensor(static_cast(use_bias ? bias.host_data() : nullptr), + cute::make_layout(PerColBias ? cute::make_shape(1, N) : cute::make_shape(M, 1))); + auto C0 = cute::make_tensor(impl_.collective_epilogue.tensor_C.host_data(), + cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_c)); + auto C1 = cute::make_tensor(tensor_C1.host_data(), + cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_c)); + + // Create host workspace for output of testbed. This computes a portion of the epilogue: + // ref_compute_out = Activation(alpha * (A @ B) + bias) + cutlass::HostTensor ref_compute_out; + auto c_coord = cutlass::make_Coord(M * L, N); + ref_compute_out.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, impl_.collective_epilogue.stride_factor_C), false); + auto RefComputeOut = cute::make_tensor(ref_compute_out.host_data(), + cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_c)); + + cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; + + // Use a dummy null tensor for operand C because the epilogue overrides C. + auto dummy_C = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_c)); + ElementCompute dummy_beta(0); + auto dummy_Aux = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_d)); + auto dummy_Valpha = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, N, 1), cute::make_stride(cute::_1{}, cute::_0{}, M))); + auto dummy_Vbeta = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, N, 1), cute::make_stride(cute::_1{}, cute::_0{}, M))); + + auto dummy_SFD = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_c)); + using DummySFDVectorSize = cute::Int<0>; + + + cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementScalar, + ElementAccumulator, + ElementCompute, + decltype(dummy_C), + decltype(RefComputeOut), + decltype(Bias), + decltype(dummy_Aux), + decltype(dummy_Valpha), + decltype(dummy_Vbeta), + ActivationFunctor, + decltype(dummy_SFD), + DummySFDVectorSize, + cutlass::plus, + PerColBias> epilogue_params{ + alpha, + dummy_beta, + dummy_C, + RefComputeOut, + Bias, + dummy_Aux, + dummy_Valpha, + dummy_Vbeta + }; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + cutlass::NumericConverter source_converter; + cutlass::NumericConverter destination_converter; + cutlass::multiplies mul; + + // Compute broadcast operations atop the reference + #pragma omp parallel for collapse(3) + for (int64_t l = 0; l < cute::size<2>(A.layout()); ++l) { + for (int64_t m = 0; m < cute::size<0>(A.layout()); ++m) { + for (int64_t n = 0; n < cute::size<0>(B.layout()); ++n) { + ElementCompute intermediate = RefComputeOut(m, n, l); + // Apply BinaryOp0, if needed + if constexpr (IsBinaryOp0Enabled) { + typename Epilogue::ThreadEpilogueOp::BinaryOp0 bin0; + ElementCompute converted_source = source_converter(C0(m, n, l)); + intermediate = bin0(intermediate, mul(beta, converted_source)); + } + + // Apply BinaryOp1, if needed + if constexpr (IsBinaryOp1Enabled) { + typename Epilogue::ThreadEpilogueOp::BinaryOp1 bin1; + ElementCompute converted_source = source_converter(C1(m, n, l)); + intermediate = bin1(intermediate, mul(beta, converted_source)); + } + + // Apply UnaryOp, if needed + if constexpr (IsUnaryOpEnabled) { + typename Epilogue::ThreadEpilogueOp::UnaryOp unary; + intermediate = unary(intermediate); + } + + D(m, n, l) = destination_converter(intermediate); + } + } + } + + return compare_reference(problem_shape_MNKL, alpha, beta, use_bias); + } + + /// Executes one test + bool run( + ProblemShapeType problem_size, + ElementScalar alpha = ElementScalar(1), + ElementScalar beta = ElementScalar(0), + bool profiling = false, + int iterations = 20, + bool use_bias = true) + { + // Fail test if insufficient CUDA device + if (!impl_.sufficient()) { + std::cout << "Test failed due to insufficient CUDA device." << std::endl; + return false; + } + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments; + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + if (not profiling) { + impl_.sm_count = std::min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); + hw_info.sm_count = impl_.sm_count; + } + else { + impl_.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + hw_info.sm_count = impl_.sm_count; + } + + /// Initializes data structures + /// A/B/C0/D Tensor + initialize(problem_size); + initialize_bias(problem_size); + + if constexpr (IsBinaryOp1Enabled) { + initialize_c1(problem_size); + } + + arguments = typename Gemm::Arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + { impl_.collective_mma_inputs.tensor_A.device_data(), impl_.collective_mma_inputs.stride_a, + impl_.collective_mma_inputs.tensor_B.device_data(), impl_.collective_mma_inputs.stride_b, + impl_.mma_promotion_interval + }, + { // Epilogue arguments + { alpha, beta }, // ThreadOp arguments + impl_.collective_epilogue.stride_c, + impl_.collective_epilogue.tensor_D.device_data(), + impl_.collective_epilogue.stride_d, + use_bias ? bias.device_data() : nullptr, + impl_.collective_epilogue.tensor_C.device_data(), + tensor_C1.device_data() + }, // Epilogue arguments end + hw_info + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.can_implement(arguments); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // + // Run the GEMM + // + + if (profiling) { + return impl_.profile(problem_size, iterations, gemm_op, arguments, workspace); + } + else { + cudaError_t result; + status = gemm_op.initialize(arguments, workspace.get()); + status = gemm_op.run(); + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + bool passed = this->verify(problem_size, alpha, beta, use_bias); + if (!passed) { + std::cout << "Error : Failed : with alpha: " << float(alpha) + << ", beta: " << float(beta) + << ", use_bias: " << use_bias + << "\n"; + } + + return passed; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestAllTensorBroadcast(bool use_bias=true) { + using ElementScalar = typename Gemm::GemmKernel::CollectiveEpilogue::ElementScalar; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + std::vector problem_size_m = {max_alignment, 512 - 3 * max_alignment}; + std::vector problem_size_n = {max_alignment, 512 - 2 * max_alignment}; + + if constexpr (cute::is_same_v) { + problem_size_m.push_back(768); + problem_size_n.push_back(768); + } + + constexpr int Stages = Gemm::GemmKernel::DispatchPolicy::Stages; + constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{}); + + std::vector problem_size_k = {max_alignment, TileShapeK * (Stages + 1) - max_alignment}; + + Testbed3xTensorBroadcast testbed; + bool passed = true; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + ProblemShapeType problem_size; + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + problem_size = ProblemShapeType{m, n, k, /* l */ 1}; + } + else { + problem_size = ProblemShapeType{m, n, k}; + } + + for (bool use_bias : {true, false}) { + passed = testbed.run( + problem_size, + cutlass::from_real(1), + cutlass::from_real(1), + false, // profiling + 20, // iterations + use_bias + ); + + if (!passed) { + return false; + } + } + } + } + } + + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + auto problem_size = ProblemShapeType{256 + max_alignment, 256 + max_alignment, 160 + max_alignment, /* l */ 3}; + passed = testbed.run( + problem_size, + cutlass::from_real(1), + cutlass::from_real(1), + false, // profiling + 20 // iterations + ); + if (!passed) { + return false; + } + } + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/multistage_testbed.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/multistage_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..6ae7b864cb272782da4920ffc038830d3b5984b2 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/multistage_testbed.h @@ -0,0 +1,300 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +//////////////////////////////////////////////////////////////////////////////// + +template +struct MultistageTestbed { + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = + typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + // + // Methods + // + + MultistageTestbed( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080) + : init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) {} + + /// Helper to initialize a tensor view + template + bool initialize_tensor(cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, uint64_t seed) { + if (dist_kind == cutlass::Distribution::Uniform) { + int scope = (cutlass::sizeof_bits::value == 8) ? 2 : 8; + cutlass::reference::host::TensorFillRandomUniform(view, seed, scope, + -scope, 0); + } else if (dist_kind == cutlass::Distribution::Gaussian) { + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5, -1); + } else if (dist_kind == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(view); + } else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(view.data(), + view.capacity()); + } else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Waives test if CUDA device is insufficient + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run(cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waives test if CUDA device is insufficient + if (!sufficient()) { + return true; + } + + // + // Allocate the GEMM workspace + // + + cutlass::HostTensor + tensor_A(problem_size.mk()); + + cutlass::HostTensor + tensor_B(problem_size.kn()); + + cutlass::HostTensor + tensor_C(problem_size.mn()); + + cutlass::HostTensor + tensor_D(problem_size.mn()); + + cutlass::HostTensor + reference_D(problem_size.mn(), false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), + tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + problem_size, tensor_A.device_ref(), tensor_B.device_ref(), + tensor_C.device_ref(), tensor_D.device_ref(), {alpha, beta}}; + + Gemm gemm_op; + + cutlass::Status status = gemm_op.initialize(arguments); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + + // + // Verify + // + + cutlass::reference::host::Gemm< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + typename Gemm::ElementC, typename Gemm::LayoutC, ElementCompute, + ElementAccumulator, typename Gemm::Operator> + reference_gemm; + + reference_gemm( + problem_size, alpha, tensor_A.host_ref(), tensor_B.host_ref(), beta, + reference_D.host_ref(), ElementAccumulator(0)); + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals( + reference_D.host_view(), tensor_D.host_view()); + + EXPECT_TRUE(passed); + if (!passed) { + std::stringstream fname; + + fname << "error_Gemm_device_" << problem_size.m() << "x" + << problem_size.n() << "x" << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" << Gemm::ThreadblockShape::kN + << "x" << Gemm::ThreadblockShape::kK << "_" << Gemm::WarpShape::kM + << "x" << Gemm::WarpShape::kN << "x" << Gemm::WarpShape::kK + << ".txt"; + + std::ofstream file(fname.str()); + + file << "problem: " << problem_size << ", alpha: " << alpha + << ", beta: " << beta << "\n\n"; + + file << "A =\n" + << tensor_A.host_view() << "\nB =\n" + << tensor_B.host_view() << "\nC =\n" + << tensor_C.host_view() << "\n\nReference =\n" + << reference_D.host_view() << "\nComputed =\n" + << tensor_D.host_view(); + } + + return passed; + } + + /// Runs a set of problem sizes + bool run_all() { + bool passed = true; + + int problem_size_m[] = {16, 528}; + + int problem_size_n[] = {16, 528}; + + int problem_size_k[] = {Gemm::InstructionShape::kK, + Gemm::ThreadblockShape::kK * Gemm::kStages + + Gemm::InstructionShape::kK}; + + double problem_alpha[] = {1.0}; + + // TODO Try non zero beta value after multistaged epilogue is implemented + double problem_beta[] = {0.0}; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (double alpha : problem_alpha) { + for (double beta : problem_beta) { + passed = + run({m, n, k}, ElementCompute(alpha), ElementCompute(beta)); + + if (!passed) { + return false; + } + } + } + } + } + } + + return true; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/multistage_testbed_interleaved.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/multistage_testbed_interleaved.h new file mode 100644 index 0000000000000000000000000000000000000000..e309208bb4311253be5b7366841164eb62748bab --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/multistage_testbed_interleaved.h @@ -0,0 +1,348 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/host_reorder.h" + +namespace test { +namespace gemm { +namespace device { + +//////////////////////////////////////////////////////////////////////////////// + +template +struct MultistageInterleavedTestbed { + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + // + // Methods + // + + MultistageInterleavedTestbed( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, 2, -2, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerMultiprocessor < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + // + // Allocate the GEMM workspace + // + + cutlass::HostTensor< + typename Gemm::ElementA, + typename Gemm::LayoutA> tensor_A(problem_size.mk()); + + cutlass::HostTensor< + typename Gemm::ElementB, + typename Gemm::LayoutB> tensor_B(problem_size.kn()); + + cutlass::HostTensor< + typename Gemm::ElementB, + typename Gemm::LayoutB> tensor_B_reordered(problem_size.kn()); + + cutlass::HostTensor< + typename Gemm::ElementC, + typename Gemm::LayoutC> tensor_C(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm::ElementC, + typename Gemm::LayoutC> tensor_D(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm::ElementC, + typename Gemm::LayoutC> reference_D(problem_size.mn(), false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + + cutlass::reorder_column( + tensor_B_reordered.host_ref(), tensor_B.host_ref(), problem_size); + + cutlass::reference::host::TensorCopy( + reference_D.host_view(), + tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B_reordered.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + problem_size, + tensor_A.device_ref(), + tensor_B_reordered.device_ref(), + tensor_C.device_ref(), + tensor_D.device_ref(), + {alpha, beta} + }; + + Gemm gemm_op; + + cutlass::Status status = gemm_op.initialize(arguments); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + + // + // Verify + // + + cutlass::reference::host::Gemm< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + typename Gemm::ElementC, typename Gemm::LayoutC, ElementCompute, + ElementAccumulator, typename Gemm::Operator> + reference_gemm; + + reference_gemm( + problem_size, + alpha, + tensor_A.host_ref(), + tensor_B.host_ref(), + beta, + reference_D.host_ref(), + ElementAccumulator(0) + ); + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals( + reference_D.host_view(), + tensor_D.host_view()); + + EXPECT_TRUE(passed); + if (!passed) { + + std::stringstream fname; + + fname << "error_Gemm_device_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" + << Gemm::ThreadblockShape::kN << "x" + << Gemm::ThreadblockShape::kK << "_" + << Gemm::WarpShape::kM << "x" + << Gemm::WarpShape::kN << "x" + << Gemm::WarpShape::kK << ".txt"; + + std::ofstream file(fname.str()); + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nB_reordered =\n" << tensor_B_reordered.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\n\nReference =\n" << reference_D.host_view() + << "\nComputed =\n" << tensor_D.host_view(); + } + + return passed; + } + + /// Runs a set of problem sizes + bool run_all() { + bool passed = true; + + int problem_size_m[] = { + InterleavedK, 512 + InterleavedK + }; + + int problem_size_n[] = { + InterleavedK, 512 + InterleavedK + }; + + int problem_size_k[] = { + InterleavedK, Gemm::ThreadblockShape::kK * Gemm::kStages + InterleavedK + }; + + double problem_alpha[] = { + 1.0 + }; + + double problem_beta[] = { + 0.0 + }; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (double alpha : problem_alpha) { + for (double beta : problem_beta) { + + passed = run( + {m, n, k}, + ElementCompute(alpha), + ElementCompute(beta) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + + return true; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/simt_sm50.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/simt_sm50.py new file mode 100644 index 0000000000000000000000000000000000000000..a180028205abb689436c73403eea82758ade7da9 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/simt_sm50.py @@ -0,0 +1,341 @@ +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# this file creates the test/unit/gemm/device simt tests + + +outputDir = "" + +################################################################################ +# parameters +# Edge - for tiles, the edges represent the length of one side +# Ratio - the maximum ratio between 2 edges, limits the skinnyness of tiles +# MaxEdge - maximum length of each edge +# Min/Max - minimum/maximum of the product of edge lengths +################################################################################ + +warpsPerThreadblockEdge = [1, 2, 4, 8, 16] +warpsPerThreadblockRatio = 2 +warpsPerThreadblockMax = 16 +# NOTE 1x32 and 2x16 warp tile shapes fail validation for ~10% of cases + +warpShapeEdges = [8, 16, 32, 64, 128, 256] +warpShapeRatio = 4 +warpShapeMax = 64*64 +warpShapeMin = 8*8 + +threadblockEdgeMax = 256 + +# char, type bits/elem, max tile, L0 threadblock tiles +precisions = [ + ["c", "cutlass::complex", 64, 64*128, [ [ 64, 128], [ 64, 32] ] ], + ["q", "cutlass::Quaternion", 64, 64*128, [ [ 64, 128], [ 64, 32] ] ], + ["d", "double", 64, 64*64, [ [ 64, 64], [ 32, 32] ] ], + ["h", "cutlass::half_t", 16, 128*256, [ [256, 128], [ 64, 128], [ 64, 32] ] ], + ["i", "int", 32, 128*128, [ [128, 64], [ 16, 32] ] ], + ["s", "float", 32, 128*128, [ [128, 256], [128, 128], [ 64, 64] ] ], + ["z", "cutlass::complex", 128, 64*64, [ [ 32, 64], [ 16, 32] ] ], + ] +# L1 will have a single kernel for every unique shape +# L2 will have everything else + +transposes = [ + [False, False], + [False, True], + [True, False], + [True, True] + ] + +################################################################################ +# warps per threadblock +################################################################################ +warpsPerThreadblocks = [] +for warpsPerThreadblock0 in warpsPerThreadblockEdge: + for warpsPerThreadblock1 in warpsPerThreadblockEdge: + if warpsPerThreadblock0 / warpsPerThreadblock1 <= warpsPerThreadblockRatio and warpsPerThreadblock1 / warpsPerThreadblock0 <= warpsPerThreadblockRatio and warpsPerThreadblock0 * warpsPerThreadblock1 <= warpsPerThreadblockMax: + warpsPerThreadblocks.append([warpsPerThreadblock0, + warpsPerThreadblock1]) +print("WarpsPerThreadblocks",warpsPerThreadblocks) + +################################################################################ +# warp shapes +################################################################################ +warpNumThreads = 32 +warpShapes = [] +for warp0 in warpShapeEdges: + for warp1 in warpShapeEdges: + if warp0 / warp1 <= warpShapeRatio and warp1 / warp0 <= warpShapeRatio and warp0*warp1 <= warpShapeMax and warp0*warp1 > warpShapeMin: + warpShapes.append([warp0, warp1]) +print("WarpShapes", warpShapes) + +numL0 = 0 +numL1 = 0 +numL2 = 0 + +################################################################################ +# create kernels +# create a file for each precision/transpose +# each file contains many tile sizes +################################################################################ + +# precisions +for precision in precisions: + + # get precision char + precisionChar = precision[0] + precisionType = precision[1] + precisionBits = precision[2] + threadblockMaxElements = precision[3] + threadblockTilesL0 = precision[4] + + # transposes + for transpose in transposes: + + # get transpose char + columnMajorA = transpose[0] + columnMajorB = transpose[1] + transCharA = "n" if columnMajorA else "t" + transCharB = "n" if columnMajorB else "t" + + # open file + fileName="simt_%sgemm_%s%s_sm50.cu" % (precisionChar, transCharA, transCharB) + print("\n", fileName) + filePath = "%s%s" % (outputDir, fileName) + out = open(filePath, "w+") + + # write file header + out.write("/***************************************************************************************************\n" +" * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n" +" * SPDX-License-Identifier: BSD-3-Clause \n" +" * \n" +" * Redistribution and use in source and binary forms, with or without \n" +" * modification, are permitted provided that the following conditions are met: \n" +" * \n" +" * 1. Redistributions of source code must retain the above copyright notice, this \n" +" * list of conditions and the following disclaimer. \n" +" * \n" +" * 2. Redistributions in binary form must reproduce the above copyright notice, \n" +" * this list of conditions and the following disclaimer in the documentation \n" +" * and/or other materials provided with the distribution. \n" +" * \n" +" * 3. Neither the name of the copyright holder nor the names of its \n" +" * contributors may be used to endorse or promote products derived from \n" +" * this software without specific prior written permission. \n" +" * \n" +" * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" \n" +" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE \n" +" * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE \n" +" * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE \n" +" * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL \n" +" * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR \n" +" * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER \n" +" * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, \n" +" * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE \n" +" * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \n" +" *\n" +" **************************************************************************************************/\n" +"/*! \\file\n" +" \\brief Tests for device-wide GEMM interface\n" +"*/\n" +"\n" +"#include \n" +"\n" +"#include \"cutlass/cutlass.h\"\n" +"#include \"cutlass/gemm/device/gemm.h\"\n" +"#include \"cutlass/numeric_types.h\"\n" +"\n" +"#include \"../../common/cutlass_unit_test.h\"\n" +"\n" +"#include \"cutlass/util/host_tensor.h\"\n" +"#include \"cutlass/util/tensor_view_io.h\"\n" +"#include \"cutlass/util/reference/host/tensor_fill.h\"\n" +"#include \"cutlass/util/reference/host/tensor_copy.h\"\n" +"#include \"cutlass/util/reference/host/tensor_compare.h\"\n" +"#include \"cutlass/util/reference/host/gemm.h\"\n" +"\n" +"#include \"testbed.h\"\n" +"\n") + foundThreadblockTilesL0 = {} + foundThreadblockTilesL1 = {} + + ######################################################################## + # for each combination of tile sizes + ######################################################################## + for warpsPerThreadblock in warpsPerThreadblocks: + for warpShape in warpShapes: + warpThreadsM = 0 + if warpShape[0] > warpShape[1]: + warpThreadsM = 8 + else: + warpThreadsM = 4 + warpThreadsN = warpNumThreads / warpThreadsM + + # skip shapes with conflicting rectangularity + # they are unlikely to be fastest + blockG = warpsPerThreadblock[0] > warpsPerThreadblock[1] + blockL = warpsPerThreadblock[0] < warpsPerThreadblock[1] + warpG = warpShape[0] > warpShape[1] + warpL = warpShape[0] < warpShape[1] + + blockG2 = warpsPerThreadblock[0] > warpsPerThreadblock[1]*2 + blockL2 = warpsPerThreadblock[0]*2 < warpsPerThreadblock[1] + warpG2 = warpShape[0] > warpShape[1]*2 + warpL2 = warpShape[0]*2 < warpShape[1] + + if blockG2 and warpL: continue + if blockL2 and warpG: continue + if warpG2 and blockL: continue + if warpL2 and blockG: continue + + # check threadblock ratios and max + threadblockTile = [warpShape[0]*warpsPerThreadblock[0], + warpShape[1]*warpsPerThreadblock[1]] + if threadblockTile[0] * threadblockTile[1] > threadblockMaxElements: continue + if threadblockTile[0] > threadblockEdgeMax: continue + if threadblockTile[1] > threadblockEdgeMax: continue + totalThreads = warpNumThreads*warpsPerThreadblock[0]*warpsPerThreadblock[1] + + # calculate unroll + # ensure that every iteration at least a full load of A,B are done + unrollMin = 8 + unrollMin0 = totalThreads / threadblockTile[0] + unrollMin1 = totalThreads / threadblockTile[1] + unroll = max(unrollMin, unrollMin0, unrollMin1) + + threadTileM = warpShape[0] / warpThreadsM + threadTileN = warpShape[1] / warpThreadsN + if threadTileM < 2 or threadTileN < 2: continue + if threadTileM*threadTileN*precisionBits > 8*8*32: continue + + # epilogue currently only supports N < WarpNumThreads + if threadblockTile[1] < warpNumThreads: continue + + # limit smem + smemBitsA = threadblockTile[0]*unroll*2*precisionBits + smemBitsB = threadblockTile[1]*unroll*2*precisionBits + smemKBytes = (smemBitsA+smemBitsB)/8/1024 + if (smemKBytes > 48): continue + + # test level 0 + testLevel = -1 + for tileId in range(0, len(threadblockTilesL0)): + tbTile = threadblockTilesL0[tileId] + if tbTile[0] == threadblockTile[0] and tbTile[1] == threadblockTile[1]: + if tuple(tbTile) not in foundThreadblockTilesL0: + testLevel = 0 + numL0 += 1 + foundThreadblockTilesL0[tuple(tbTile)] = True + + # test level 1 + if testLevel < 0: + threadblockTileAlreadyUsed = False + if tuple(threadblockTile) not in foundThreadblockTilesL1: + testLevel = 1 + numL1 += 1 + foundThreadblockTilesL1[tuple(threadblockTile)] = True + + # test level 2 + if testLevel < 0: + testLevel = 2 + numL2 += 1 + + ################################################################ + # write this tile to file + ################################################################ + + print("%ix%ix%i__%ix%i_%ix%i_%ix%i L%i" % ( + threadblockTile[0], threadblockTile[1], unroll, + threadTileM, threadTileN, + warpThreadsM, warpThreadsN, + warpsPerThreadblock[0], warpsPerThreadblock[1], testLevel)) + + out.write("////////////////////////////////////////////////////////////////////////////////\n" + "// Elements / Thread: %3i x %3i\n" + "// Threads / Warp: %3i x %3i\n" + "// Warps / Block: %3i x %3i\n" + "// Threadblock: %3i x %3i x %2i\n" + % ( threadTileM, threadTileN, + warpThreadsM, warpThreadsN, + warpsPerThreadblock[0], warpsPerThreadblock[1], + threadblockTile[0], threadblockTile[1], unroll + ) + ) + + out.write("CUTLASS_TEST_L%i(SM50_device_%sgemm_%s%s, %ix%ix%i_%ix%ix1_%ix%i_%ix%i_%ix%i, {\n" % ( + testLevel, + precisionChar, + transCharA, + transCharB, + threadblockTile[0], + threadblockTile[1], + unroll, + warpShape[0], + warpShape[1], + threadTileM, + threadTileN, + warpThreadsM, + warpThreadsN, + warpsPerThreadblock[0], + warpsPerThreadblock[1] + )) + out.write(" using precision = %s;\n" % precisionType) + out.write(" using ThreadblockShape = cutlass::gemm::GemmShape<%i, %i, %i>;\n" % ( + threadblockTile[0], + threadblockTile[1], + unroll)) + out.write(" using WarpShape = cutlass::gemm::GemmShape<%i, %i, %i>;\n\n" % ( + warpShape[0], + warpShape[1], + unroll)) + out.write(" static int const kEpilogueElementsPerAccess = 1;\n" + " using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;\n" + " using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<\n" + " precision, kEpilogueElementsPerAccess, precision, precision>;\n\n") + + out.write(" using Gemm = cutlass::gemm::device::Gemm<\n" + " precision, cutlass::layout::%sMajor,\n" + " precision, cutlass::layout::%sMajor,\n" + " precision, cutlass::layout::RowMajor,\n" + " precision,\n" + " cutlass::arch::OpClassSimt,\n" + " cutlass::arch::Sm50,\n" + " ThreadblockShape, WarpShape, InstructionShape,\n" + " EpilogueOutputOp,\n" + " cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,\n" + " 2 // Stages\n" + " >;\n" % ( + "Column" if columnMajorA else "Row", + "Column" if columnMajorB else "Row", + )) + out.write(" EXPECT_TRUE(test::gemm::device::TestAllGemm());\n" + "} )\n\n") + + + out.close() +print("NumKernels:", numL0, numL1, numL2) + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/sm90_evt_operations.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/sm90_evt_operations.hpp new file mode 100644 index 0000000000000000000000000000000000000000..63ffc3281dd2b9e9f74e0024c73da00628331dd4 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/sm90_evt_operations.hpp @@ -0,0 +1,545 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Host reference and operations for Sm90 EVT unit test +*/ +#pragma once +#include "gemm_testbed_3x_evt.hpp" + +////////////////////////////////////////////////////////////////////////////// +/// Host references used for testing +namespace test::gemm::device { +template +using HEVT = HostTreeVisitor; + +template +using HDAG = HostTopoVisitor; + +template +using HST = HostSplitTreeVisitor; + +/// D = alpha * acc + beta * C + AuxLoad +template +class HostEVTAuxLoad { +public: + using ElementC = typename Gemm::GemmKernel::ElementC; + using LayoutC = cutlass::detail::StrideToLayoutTagC_t; + using ElementD = typename Gemm::GemmKernel::ElementC; + using LayoutD = cutlass::detail::StrideToLayoutTagC_t; + + using ScalarAlpha = HostScalarBroadcast<1>; + using AccFetchNode = HostAccumulator<>; + using AuxLoadNode = HostAuxLoad; + using TernaryCompute0 = HEVT, ScalarAlpha, AccFetchNode, AuxLoadNode>; + using ScalarBeta = HostScalarBroadcast<1>; + using CLoadNode = HostAuxLoad; + using TernaryCompute1 = HEVT, ScalarBeta, CLoadNode, TernaryCompute0>; + using EVTModule = HEVT, TernaryCompute1>; +}; + +/// D = alpha * acc + beta * C + per-column bias +template +class HostPerColBias { +public: + using ElementC = typename Gemm::GemmKernel::ElementC; + using LayoutC = cutlass::detail::StrideToLayoutTagC_t; + using ElementD = typename Gemm::GemmKernel::ElementC; + using LayoutD = cutlass::detail::StrideToLayoutTagC_t; + + using ScalarAlpha = HostScalarBroadcast<1>; + using AccFetchNode = HostAccumulator<>; + using RowBroadcastNode = HostRowBroadcast; + using TernaryCompute0 = HEVT, ScalarAlpha, AccFetchNode, RowBroadcastNode>; + using ScalarBeta = HostScalarBroadcast<1>; + using CLoadNode = HostAuxLoad; + using TernaryCompute1 = HEVT, ScalarBeta, CLoadNode, TernaryCompute0>; + using EVTModule = HEVT, TernaryCompute1>; +}; + +/// D = beta * C + Graph(relu(alpha * acc + aux) + aux) +/// Testing EVT - DAG structure +template +class HostEVTDAG { +public: + using ElementC = typename Gemm::GemmKernel::ElementC; + using LayoutC = cutlass::detail::StrideToLayoutTagC_t; + using ElementD = typename Gemm::GemmKernel::ElementC; + using LayoutD = cutlass::detail::StrideToLayoutTagC_t; + + using ScalarAlpha = HostScalarBroadcast<1>; + using AccFetchNode = HostAccumulator<>; + using AuxLoadNode = HostAuxLoad; + using DAGNode = HDAG< + float, + cute::tuple< + cute::tuple<>, // 0. alpha + cute::tuple<>, // 1. acc + cute::tuple<>, // 2. aux load + cute::tuple, // 3. alpha * acc + aux load + cute::tuple, // relu(alpha * acc + aux load) + cute::tuple // relu(alpha * acc + aux load) + aux load + >, + ScalarAlpha, + AccFetchNode, + AuxLoadNode, + HostCompute, + HostCompute, + HostCompute + >; + using ScalarBeta = HostScalarBroadcast<1>; + using CLoadNode = HostAuxLoad; + using TernaryCompute1 = HEVT, ScalarBeta, CLoadNode, DAGNode>; + using EVTModule = HEVT, TernaryCompute1>; +}; + +/// EVT = alpha * acc + C +/// D = Graph(maximum(EVT + per-row bias, EVT)) +/// Testing DAG - EVT +template +class HostDAGEVT { +public: + using ElementC = typename Gemm::GemmKernel::ElementC; + using LayoutC = cutlass::detail::StrideToLayoutTagC_t; + using ElementD = typename Gemm::GemmKernel::ElementC; + using LayoutD = cutlass::detail::StrideToLayoutTagC_t; + + using EVTNode = HEVT< + HostAuxStore, + HEVT< + HostCompute, + HostScalarBroadcast<2>, + HostAccumulator<>, + HostAuxLoad + > + >; + using EVTModule = HEVT< + HostAuxStore, + HDAG< + float, + cute::tuple< + cute::tuple<>, // 0. EVT + cute::tuple<>, // 1. per-row bias + cute::tuple, // 2. EVT + per-row bias + cute::tuple // 3. maximum(EVT + per-row bias, EVT) + >, + EVTNode, + HostColBroadcast>, + HostCompute, + HostCompute + > + >; +}; + +/// Xreduce(alpha * acc + beta * C) +template +class HostReduce { +public: + using ElementC = typename Gemm::GemmKernel::ElementC; + using LayoutC = cutlass::detail::StrideToLayoutTagC_t; + using ElementD = typename Gemm::GemmKernel::ElementC; + using LayoutD = cutlass::detail::StrideToLayoutTagC_t; + + using ScalarAlpha = HostScalarBroadcast<1>; + using AccFetchNode = HostAccumulator<>; + using BinaryCompute0 = HEVT, ScalarAlpha, AccFetchNode>; + using ScalarBeta = HostScalarBroadcast<1>; + using CLoadNode = HostAuxLoad; + using TernaryCompute1 = HEVT, ScalarBeta, CLoadNode, BinaryCompute0>; + using ReduceNode = HEVT; + using EVTModule = HEVT, ReduceNode>; +}; + +// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias +// if D is fp8 +// D = scale_d * activation(Z) +// else +// D = activation(Z) +template class ActivationFn, class ElementD> +class HostScaledLinCombPerRowBiasEltAct { +public: + using ElementC = typename Gemm::GemmKernel::ElementC; + using LayoutC = cutlass::detail::StrideToLayoutTagC_t; + using LayoutD = cutlass::detail::StrideToLayoutTagC_t; + + using EVTModule = HEVT< + HostAuxStore, + HEVT< + HostCompute::template Op>, // activation(Z) * scaled_d + HEVT< + HostCompute, // activation(Z) + HEVT< + HostCompute, + HostScalarBroadcast<1, 2, cute::Stride>, // scale_c * beta + HostAuxLoad, // C + HEVT< + HostCompute, + HostScalarBroadcast<1, 3, cute::Stride>, // scale_a * scale_b * alpha + HostAccumulator<>, + HostColBroadcast> + > + > + >, + HostScalarBroadcast<1> // scale_d + > + >; +}; + +// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias +// if D is fp8 +// amax_d = max(abs(elements in activation(Z))) +// D = scale_d * activation(Z) +// else +// D = activation(Z) +// if Aux is fp8 +// amax_aux = max(abs(elements in Z)) +// Aux = scale_aux * Z +// else +// Aux = Z +template class ActivationFn, class ElementD, class ElementAux = ElementD> +class HostScaledLinCombPerRowBiasEltActAmaxAux { +public: + using ElementC = typename Gemm::GemmKernel::ElementC; + using LayoutC = cutlass::detail::StrideToLayoutTagC_t; + using LayoutD = cutlass::detail::StrideToLayoutTagC_t; + + template + using amax = cutlass::maximum_absolute_value_reduction; + using EVTModuleAuxFp8 = HEVT< + HostAuxStore, + HST, + HostScalarBroadcast<1, 2, cute::Stride>, // scale_c * beta + HostAuxLoad, // C + HEVT< + HostCompute, + HostScalarBroadcast<1, 3, cute::Stride>, // scale_a * scale_b * alpha + HostAccumulator<>, + HostColBroadcast> + > + >, + // D = activation(Z) * scaled_d, amax_d = max(abs(elements in D)) + HEVT< + HostCompute::template Op>, + HEVT< + HostScalarReduce, + HEVT< + HostCompute, //activation(Z) * scaled_d + HostAccumulator<> // Z + > + >, + HostScalarBroadcast<1> // scale_d + >, + // Aux = Z * scale_aux, amax_aux = max(abs(elements in Aux)) + HEVT< + HostAuxStore, + HEVT< + HostCompute, + HEVT< + HostScalarReduce, + HostAccumulator<> + >, + HostScalarBroadcast<1> + > + > + > + >; + + using EVTModuleAuxNotFp8 = HEVT< + // D = activation(Z) * scaled_d, amax_d = max(abs(elements in D)) + HostAuxStore, + HEVT< + HostCompute::template Op>, + HEVT< + HostScalarReduce, + HEVT< + HostCompute, //activation(Z) * scaled_d + HEVT< + // Aux = Z + HostAuxStore, + // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias + HEVT< + HostCompute, + HostScalarBroadcast<1, 2, cute::Stride>, // scale_c * beta + HostAuxLoad, // C + HEVT< + HostCompute, + HostScalarBroadcast<1, 3, cute::Stride>, // scale_a * scale_b * alpha + HostAccumulator<>, + HostColBroadcast> + > + > + > + > + >, + HostScalarBroadcast<1> // scale_d + > + >; + + using EVTModule = cute::conditional_t, EVTModuleAuxFp8, EVTModuleAuxNotFp8>; + +}; +} // namespace test::gemm::device + +////////////////////////////////////////////////////////////////////////////// +namespace cutlass::epilogue { +namespace fusion { + +namespace detail { + +template +struct maximum_with_default_nan_propagation : maximum {}; + +} // namespace detail + +////////////////////////////////////////////////////////////////////////////// +/// D = alpha * acc + beta * C + AuxLoad +template< + class EpilogueDescriptor, + class AuxLoadDescriptor, + class ElementOutput, + class ElementCompute, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombAuxLoad = + Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + bias + Sm90ScalarBroadcast, // alpha + Sm90AccFetch, // acc + Sm90AuxLoad< + AuxLoadDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile, + typename AuxLoadDescriptor::Element, + typename AuxLoadDescriptor::Stride, typename AuxLoadDescriptor::SmemLayoutAtom, + typename AuxLoadDescriptor::CopyOpS2R // aux load + > + > + >; + +////////////////////////////////////////////////////////////////////////////// +/// D = alpha * acc + beta * C + AuxLoadNoSmem +template< + class EpilogueDescriptor, + class ElementAux, + class StrideAux, + class ElementOutput, + class ElementCompute, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombAuxLoadNoSmem = + Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + bias + Sm90ScalarBroadcast, // alpha + Sm90AccFetch, // acc + Sm90AuxLoad<0, void, ElementAux, StrideAux, void, void> // aux load + > + >; + +////////////////////////////////////////////////////////////////////////////// +/// Example DAG +/// beta * C + Graph(alpha * acc + gamma + acc) +template< + typename EpilogueDescriptor, + typename AuxLoadDescriptor, + class ElementOutput, + class ElementCompute, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombEVTDAG = + Sm90EVT, // beta * C + (alpha * acc + aux) + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90TopologicalVisitor< + ElementCompute, + cute::tuple< + cute::seq<>, // 0. alpha + cute::seq<>, // 1. acc + cute::seq<>, // 2. aux load + cute::seq<1, 0, 2>, // 3. alpha * acc + aux load + cute::seq<3>, // relu(alpha & acc + aux load) + cute::seq<2, 4> // relu(alpha * acc + aux load) + aux load + >, + Sm90ScalarBroadcast, // alpha + Sm90AccFetch, // acc + Sm90AuxLoad< + AuxLoadDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile, + typename AuxLoadDescriptor::Element, typename AuxLoadDescriptor::Stride, + typename AuxLoadDescriptor::SmemLayoutAtom, typename AuxLoadDescriptor::CopyOpS2R>, + Sm90Compute, + Sm90Compute, + Sm90Compute + > + >; + + +////////////////////////////////////////////////////////////////////////////// +/// Example DAG +/// EVT = alpha * acc + C +/// D = Graph(maximum(EVT + per-row bias, EVT)) +template< + class EpilogueDescriptor, + class AuxStoreDescriptor, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombDAGEVT = + Sm90TopologicalVisitor< + ElementCompute, + cute::tuple< + cute::seq<>, + cute::seq<>, + cute::seq<1, 0>, + cute::seq<0, 2> + >, + Sm90EVT< + Sm90AuxStore< + AuxStoreDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile, + typename AuxStoreDescriptor::Element, RoundStyle, typename AuxStoreDescriptor::Stride, + typename AuxStoreDescriptor::SmemLayoutAtom, typename AuxStoreDescriptor::CopyOpR2S>, + Sm90EVT, + Sm90ScalarBroadcast, + Sm90AccFetch, + Sm90SrcFetch + > + >, + Sm90ColBroadcast<0, typename EpilogueDescriptor::TileShape, ElementBias, ElementCompute>, + Sm90Compute, + Sm90Compute + >; + + +////////////////////////////////////////////////////////////////////////////// +/// D = alpha * acc + beta * C + per-column bias +template< + class EpilogueDescriptor, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerColumnBias = + Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + bias + Sm90ScalarBroadcast, // alpha + Sm90AccFetch, // acc + Sm90RowBroadcast<0, typename EpilogueDescriptor::TileShape, ElementBias, ElementCompute> + > + >; + + +////////////////////////////////////////////////////////////////////////////// +/// D = per-column reduce(alpha * acc + beta * C) +template< + template class RegReduceFn, + template class GmemReduceFn, + class ElementReduce, + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerColumnReduce = + Sm90EVT, // per column reduce + Sm90EVT, // beta * C + alpha * acc + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + Sm90ScalarBroadcast, // alpha + Sm90AccFetch // acc + > + > + >; + + +////////////////////////////////////////////////////////////////////////////// +/// D = per-row reduce(alpha * acc + beta * C) +template< + template class RegReduceFn, + template class GmemReduceFn, + class ElementReduce, + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerRowReduce = + Sm90EVT, // per column reduce + Sm90EVT, // beta * C + alpha * acc + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + Sm90ScalarBroadcast, // alpha + Sm90AccFetch // acc + > + > + >; + + +////////////////////////////////////////////////////////////////////////////// +/// D = scalar reduce(alpha * acc + beta * C) +template< + template class RegReduceFn, + template class GmemReduceFn, + class ElementReduce, + class ElementOutput, + class ElementCompute, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombScalarReduce = + Sm90EVT, // per column reduce + Sm90EVT, // beta * C + alpha * acc + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + Sm90ScalarBroadcast, // alpha + Sm90AccFetch // acc + > + > + >; +} // namespace fusion + +} // namespace cutlass::epilogue diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..0007666cdd084f35015200e36fd47f75971f6c1c --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed.h @@ -0,0 +1,639 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed_utils.h" +#include "testbed_universal.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Testbed { + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + typename Gemm::LayoutA::Stride stride_factor_A; + typename Gemm::LayoutB::Stride stride_factor_B; + typename Gemm::LayoutC::Stride stride_factor_C; + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + + // + // Methods + // + + Testbed( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + stride_factor_A(typename Gemm::LayoutA::Stride()), + stride_factor_B(typename Gemm::LayoutB::Stride()), + stride_factor_C(typename Gemm::LayoutC::Stride()), + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + Testbed( + typename Gemm::LayoutA::Stride stride_factor_A_, + typename Gemm::LayoutB::Stride stride_factor_B_, + typename Gemm::LayoutC::Stride stride_factor_C_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + stride_factor_A(stride_factor_A_), + stride_factor_B(stride_factor_B_), + stride_factor_C(stride_factor_C_), + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 1; + scope_min = -1; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the GEMM workspace + // + + tensor_A.resize(problem_size.mk(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mk(), stride_factor_A)); + tensor_B.resize(problem_size.kn(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.kn(), stride_factor_B)); + tensor_C.resize(problem_size.mn(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mn(), stride_factor_C)); + tensor_D.resize(problem_size.mn(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mn(), stride_factor_C)); + reference_D.resize(problem_size.mn(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mn(), stride_factor_C), false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = typename Gemm::ElementA(1); + tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1); + tensor_C.host_view().at(cutlass::make_Coord(0, 0)) = typename Gemm::ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + if (tensor_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0) + << "tensor_D (size " << tensor_D.size() << ") has nonpositive norm"; + } + if (reference_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0) + << "reference_D (size " << reference_D.size() << ") has nonpositive norm"; + } + bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); + + EXPECT_TRUE(passed) << "reference_D does not equal tensor_D"; + + if (!passed) { + + std::stringstream fname; + + fname << "error_Gemm_device_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" + << Gemm::ThreadblockShape::kN << "x" + << Gemm::ThreadblockShape::kK << "_" + << Gemm::WarpShape::kM << "x" + << Gemm::WarpShape::kN << "x" + << Gemm::WarpShape::kK << ".txt"; + + std::ofstream file(fname.str()); + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\n\nReference =\n" << reference_D.host_view() + << "\nComputed =\n" << tensor_D.host_view(); + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + // + // Verify + // + + cutlass::reference::host::Gemm< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + typename Gemm::ElementC, typename Gemm::LayoutC, ElementCompute, + ElementAccumulator, typename Gemm::Operator> + reference_gemm; + + reference_gemm( + problem_size, + alpha, + tensor_A.host_ref(), + tensor_B.host_ref(), + beta, + reference_D.host_ref(), + ElementAccumulator(0) + ); + + if (Relu) { + for (int i = 0; i < problem_size.m(); ++i) { + for (int j = 0; j < problem_size.n(); ++j) { + reference_D.at(cutlass::MatrixCoord(i, j)) = + ((ElementCompute)reference_D.at(cutlass::MatrixCoord(i, j)) < (ElementCompute)0) + ? (typename Gemm::ElementC)0 + : reference_D.at(cutlass::MatrixCoord(i, j)); + } + } + } + + return compare_reference(problem_size, alpha, beta); + } + + /// Determine if the CUDA device is sufficient to run the kernel + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size, + int split_k_slices = 1, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) + { +/* + std::cout << "\n-----------------------\n"; + std::cout << "problem size: " << problem_size << "\n"; + std::cout << "split_k_slices: " << split_k_slices << "\n"; + std::cout << "alpha: " << alpha << "\n"; + std::cout << "beta: " << beta << "\n"; + std::cout << "-----------------------\n\n"; +*/ + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D.device_ref(), + {alpha, beta}, + split_k_slices + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) + << "gemm_op.initialize returned with error " << to_string(status) + << ", indicating that this test is not supported. Last CUDA error: " + << cudaGetErrorString(cudaGetLastError()); + if (status != cutlass::Status::kSuccess) { + return true; + } + + // + // Run the GEMM + // + + try { + status = gemm_op(); + } + catch (std::exception const& e) { + EXPECT_TRUE(false) << "gemm_op() threw a std::exception: " << e.what(); + throw; + } + catch (...) { + EXPECT_TRUE(false) << "gemm_op() threw an exception of unknown type"; + throw; + } + EXPECT_TRUE(status == cutlass::Status::kSuccess) + << "gemm_op failed with error " << to_string(status); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + EXPECT_TRUE(passed) << "Error: split_k_slices = " << split_k_slices + << ", alpha: " << alpha; + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestAllGemmBasic( + const typename Gemm::LayoutA::Stride& stride_factor_A = typename Gemm::LayoutA::Stride(), + const typename Gemm::LayoutB::Stride& stride_factor_B = typename Gemm::LayoutB::Stride(), + const typename Gemm::LayoutC::Stride& stride_factor_C = typename Gemm::LayoutC::Stride()) { + bool passed = true; + + int const kMinimumOperandElementSize = + std::min( + int(cutlass::sizeof_bits::value), + int(cutlass::sizeof_bits::value)); + + int const kAlignment = cutlass::platform::is_same< + typename Gemm::OperatorClass, + cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; + + // int8_t gemm alignment constraints + int const kAlignmentM = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value ? 4 : kAlignment; + + int const kAlignmentN = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value ? 4 : kAlignment; + + int const kAlignmentK = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + (cutlass::platform::is_same::value || + cutlass::platform::is_same::value) ? 4 : kAlignment; + + int problem_size_m[] = {kAlignmentM, 512 - 3 * kAlignmentM}; + + int problem_size_n[] = {kAlignmentN, 512 - 2 * kAlignmentN}; + + int problem_size_k[] = { + kAlignmentK, Gemm::ThreadblockShape::kK * (Gemm::kStages + 1) - kAlignmentK}; + + int split_k_slices[] = { + 1, 2, 3 + }; + + double problem_alpha[] = { + 1 + }; + + double problem_beta[] = { + 2.0 + }; + + Testbed testbed(stride_factor_A, stride_factor_B, stride_factor_C); + + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (int split_k : split_k_slices) { + + if (!Gemm::kSplitKSerial && split_k > 1) { + continue; + } + + if (split_k > 1 && k / Gemm::ThreadblockShape::kK < split_k) { + continue; + } + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + cutlass::gemm::GemmCoord problem_size(m, n, k); + try { + passed = testbed.run( + problem_size, + split_k, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + } + catch (std::exception const& e) { + EXPECT_TRUE(false) << "TestAllGemmBasic: testbed.run threw an " + "exception {alpha: " << alpha << ", beta: " << beta << ", m: " + << m << ", n: " << n << ", k: " << k << "}: " << e.what(); + throw; + } + catch (...) { + EXPECT_TRUE(false) << "TestAllGemmBasic: testbed.run threw an " + "exception {alpha: " << alpha << ", beta: " << beta << ", m: " + << m << ", n: " << n << ", k: " << k << "}: (unknown)"; + throw; + } + + if (!passed) { + return false; + } + } + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestAllGemm( + const typename Gemm::LayoutA::Stride& stride_factor_A, + const typename Gemm::LayoutB::Stride& stride_factor_B = typename Gemm::LayoutB::Stride(), + const typename Gemm::LayoutC::Stride& stride_factor_C = typename Gemm::LayoutC::Stride()) +{ + // Test basic GEMM with non-default stride factors + return TestAllGemmBasic(stride_factor_A, stride_factor_B, stride_factor_C); +} + +template +bool TestAllGemm() +{ +#ifdef NDEBUG + // Non-debug builds also test basic GEMM with default stride factors + if (!TestAllGemmBasic()) { + return false; + } +#endif // NDEBUG + + // Test universal GEMM +#if 0 + // Define the universal kernel + using UniversalKernel = cutlass::gemm::kernel::GemmUniversal< + typename Gemm::GemmKernel::Mma, // Mma + typename Gemm::GemmKernel::Epilogue, // Epilogue + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<> // ThreadblockSwizzle + >; +#else + // Define the streamk universal kernel + using UniversalKernel = cutlass::gemm::kernel::GemmUniversalStreamk< + typename Gemm::GemmKernel::Mma, // Mma + typename Gemm::GemmKernel::Epilogue, // Epilogue + cutlass::gemm::threadblock::ThreadblockSwizzleStreamK // ThreadblockSwizzle + >; +#endif + + // Define the universal adaptor + using UniversalGemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Test universal GEMM + return TestAllGemmUniversal(); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestGemmPerf(int iterations = 1) { + bool passed = true; + + int problem_size_m[] = { 2048 }; + + int problem_size_n[] = { 4352 }; + + int problem_size_k[] = { 4096 }; + + int split_k_slices[] = { 1 }; + double problem_alpha[] = { 1 }; + double problem_beta[] = { 0.0 }; + + Testbed testbed; + + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (int split_k : split_k_slices) { + + if (!Gemm::kSplitKSerial && split_k > 1) { + continue; + } + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + cutlass::gemm::GemmCoord problem_size(m, n, k); + + for (int i = 0; i < iterations; i++){ + try { + passed = testbed.run( + problem_size, + split_k, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + } + catch (std::exception const& e) { + EXPECT_TRUE(false) << "TestGemmPerf: testbed.run threw an " + "exception {alpha: " << alpha << ", beta: " << beta << ", m: " + << m << ", n: " << n << ", k: " << k << "}: " << e.what(); + throw; + } + catch (...) { + EXPECT_TRUE(false) << "TestGemmPerf: testbed.run threw an " + "exception {alpha: " << alpha << ", beta: " << beta << ", m: " + << m << ", n: " << n << ", k: " << k << "}: (unknown)"; + throw; + } + } + + if (!passed) { + return false; + } + } + } + } + } + } + } + + return passed; +} + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_complex.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..add984ca3b9a0c05325b93cf52cbadd710527ba6 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_complex.h @@ -0,0 +1,294 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm_complex.h" + +#include "testbed.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedComplex : public Testbed { + + using Base = Testbed; + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; + + + // + // Methods + // + + TestbedComplex( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + Base(init_A_, init_B_, init_C_, seed_) { } + + + /// Verifies the result is a GEMM + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + // + // Verify + // + + cutlass::reference::host::GemmComplex( + problem_size, + alpha, + this->tensor_A.host_ref(), + Gemm::kTransformA, + this->tensor_B.host_ref(), + Gemm::kTransformB, + beta, + this->tensor_C.host_ref(), + this->reference_D.host_ref(), + ElementAccumulator(0) + ); + + return this->compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size, + int split_k_slices = 1, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + // + // Initialize workspace + // + + this->initialize(problem_size); + + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + problem_size, + this->tensor_A.device_ref(), + this->tensor_B.device_ref(), + this->tensor_C.device_ref(), + this->tensor_D.device_ref(), + {alpha, beta}, + split_k_slices + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + if (!passed) { + std::cout << "Error with split_k_slices = " << split_k_slices << ", alpha: " << alpha << std::endl; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestAllGemmComplex() { + bool passed = true; + + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + + int const kMinimumOperandElementSize = + std::min( + int(cutlass::sizeof_bits::value), + int(cutlass::sizeof_bits::value)); + + int const kAlignment = + cutlass::platform::is_same< + typename Gemm::OperatorClass, + cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; + + int problem_size_m[] = { + kAlignment, 512 - 3*kAlignment + }; + + int problem_size_n[] = { + kAlignment, 512 - 2*kAlignment + }; + + int problem_size_k[] = { + kAlignment, 128 - kAlignment + }; + + int split_k_slices[] = { + 1, 2, 3 + }; + + double problem_alpha[] = { + 1 + }; + + double problem_beta[] = { + 2.0 + }; + + TestbedComplex testbed; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (int split_k : split_k_slices) { + + if (!Gemm::kSplitKSerial && split_k > 1) { + continue; + } + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + cutlass::gemm::GemmCoord problem_size(m, n, k); + + passed = testbed.run( + problem_size, + split_k, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_gemm_with_broadcast.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_gemm_with_broadcast.h new file mode 100644 index 0000000000000000000000000000000000000000..eca0b0ae0decf3293f6f73cb6ebbc5b5735a8e49 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_gemm_with_broadcast.h @@ -0,0 +1,670 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/gemm_complex.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct GemmWithBroadcastReferenceOp { + + using OutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; + + using ElementCompute = typename OutputOp::ElementCompute; + using ElementZ = typename OutputOp::ElementZ; + using ElementT = typename OutputOp::ElementT; + + typename OutputOp::BinaryOp binary_op; + typename OutputOp::ElementwiseOp elementwise_op; + + GemmWithBroadcastReferenceOp() { } + + void operator()(ElementZ &Z, ElementT &T, ElementCompute gemm, ElementCompute bias) { + + ElementCompute t_full = binary_op(gemm, bias); + + if (OutputOp::kStoreT) { + T = ElementT(t_full); + } + + if (OutputOp::kStoreZ) { + ElementCompute z_full = elementwise_op(t_full); + Z = ElementZ(z_full); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Fused testbed +// +// Y = GEMM(AB, C) +// +// T[i, j] = BinaryOp(Y[i, j], Broadcast[i]) +// +// Z[i, j] = Elementwise(T[i, j]) +// + +template < + typename Gemm, + typename ReferenceOp = GemmWithBroadcastReferenceOp +> +struct TestbedGemmWithBroadcast { + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using OutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename OutputOp::ElementCompute; + using ElementVector = typename OutputOp::ElementVector; + using ElementZ = typename OutputOp::ElementZ; + using ElementT = typename OutputOp::ElementT; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; // Input A + cutlass::HostTensor tensor_B; // Input B + cutlass::HostTensor tensor_C; // Input C + cutlass::HostTensor tensor_Broadcast; // Input Broadcast + + cutlass::HostTensor tensor_Z; + cutlass::HostTensor tensor_T; + + cutlass::HostTensor tensor_C_ref; + cutlass::HostTensor tensor_Y_ref; + cutlass::HostTensor tensor_Z_ref; + cutlass::HostTensor tensor_T_ref; + + + // + // Methods + // + + TestbedGemmWithBroadcast( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 1; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the GEMM workspace + // + + tensor_A.resize(problem_size.mk()); + tensor_B.resize(problem_size.kn()); + tensor_C.resize(problem_size.mn()); + tensor_Z.resize(problem_size.mn()); + tensor_T.resize(problem_size.mn()); + tensor_Broadcast.resize({ + problem_size.m(), + 1 + }); + + tensor_C_ref.resize(problem_size.mn()); + tensor_Y_ref.resize(problem_size.mn()); + tensor_Z_ref.resize(problem_size.mn()); + tensor_T_ref.resize(problem_size.mn()); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + EXPECT_TRUE(initialize_tensor(tensor_Broadcast.host_view(), init_C, seed + 2020)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = typename Gemm::ElementA(1); + tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1); + tensor_C.host_view().at({0, 0}) = typename Gemm::ElementC(1); + + for (int m = 0; m < tensor_C_ref.extent().row(); ++m) { + for (int n = 0; n < tensor_C_ref.extent().column(); ++n) { + tensor_C_ref.at({m, n}) = ElementAccumulator(tensor_C.at({m, n})); + } + } + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_Broadcast.sync_device(); + + tensor_Z.sync_device(); + tensor_T.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementAccumulator alpha, + ElementAccumulator beta) { + + tensor_Z.sync_host(); + tensor_T.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + if (OutputOp::kStoreZ) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Z.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Z_ref.host_view()), 0); + } + + if (OutputOp::kStoreT) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_T.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_T_ref.host_view()), 0); + } + + bool passed = true; + float norm_diff = 0; + + if (OutputOp::kStoreZ) { + norm_diff = cutlass::reference::host::TensorNormDiff(tensor_Z_ref.host_view(), tensor_Z.host_view(), float()); + passed = (norm_diff <= 0.1f); + EXPECT_LT(norm_diff, 0.1f) << " tensor_Z is incorrect"; + } + + if (OutputOp::kStoreT) { + + norm_diff = cutlass::reference::host::TensorNormDiff(tensor_T_ref.host_view(), tensor_T.host_view(), float()); + passed = (passed && (norm_diff <= 0.1f)); + + EXPECT_LT(norm_diff, 0.1f) << " tensor_T is incorrect"; + } + + + if (!passed) { + + /* + std::stringstream fname; + + fname << "error_Gemm_device_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" + << Gemm::ThreadblockShape::kN << "x" + << Gemm::ThreadblockShape::kK << "_" + << Gemm::WarpShape::kM << "x" + << Gemm::WarpShape::kN << "x" + << Gemm::WarpShape::kK << ".txt"; + + std::ofstream file(fname.str()); + */ + + std::ofstream file("errors_testbed_gemm_with_broadcast.txt"); + + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\nZ =\n" << tensor_Z.host_view() + << "\nT =\n" << tensor_T.host_view() + << "\n\n" + << "\nY_ref =\n" << tensor_Y_ref.host_view() + << "\nZ_ref =\n" << tensor_Z_ref.host_view() + << "\nT_ref =\n" << tensor_T_ref.host_view(); + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementAccumulator alpha, + ElementAccumulator beta) { + + // + // Verify + // + + cutlass::reference::host::GemmComplex< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + ElementAccumulator, typename Gemm::LayoutC, + ElementAccumulator, ElementAccumulator + >( + problem_size, + alpha, + tensor_A.host_ref(), + Gemm::kTransformA, + tensor_B.host_ref(), + Gemm::kTransformB, + beta, + tensor_C_ref.host_ref(), + tensor_Y_ref.host_ref(), + ElementAccumulator(0) + ); + + using ElementC = typename Gemm::ElementC; + + ReferenceOp reference_op; + + // compute tensor Z and tensor T + for (int m = 0; m < problem_size.m(); ++m) { + for (int n = 0; n < problem_size.n(); ++n) { + + ElementZ z; + ElementT t; + + reference_op(z, t, tensor_Y_ref.at({m, n}), tensor_Broadcast.at({m, 0})); + + if (OutputOp::kStoreZ) { + tensor_Z_ref.at({m, n}) = z; + } + + if (OutputOp::kStoreT) { + tensor_T_ref.at({m, n}) = t; + } + } + } + + return compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord problem_size, + int batch_count = 1, + ElementAccumulator alpha = ElementAccumulator(1), + ElementAccumulator beta = ElementAccumulator(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha, beta}, + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data(), + tensor_Z.device_data(), + tensor_Broadcast.device_data(), + tensor_T.device_data(), + problem_size.m() * problem_size.k(), + problem_size.n() * problem_size.k(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + problem_size.m(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_Z.layout().stride(0), + 0, // This must be zero + tensor_T.layout().stride(0), + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + + bool passed = true; + + passed = this->verify(problem_size, alpha, beta); + + if (!passed) { + std::cout << "Failed with batch_count/split_k_slices = " << batch_count << std::endl; + } + + // + // Profile + // + + #if 0 // profiling disabled for now. + + int const kWorkspaces = 100; + + cutlass::DeviceAllocation profiling_tensor_A(tensor_A.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_B(tensor_B.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_C(tensor_C.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_Broadcast(tensor_Broadcast.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_Z(tensor_Z.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_T(tensor_T.capacity() * kWorkspaces); + + cudaEvent_t events[2]; + for (auto & event : events) { + cudaError_t result = cudaEventCreate(&event); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << " cudaEventCreate() failed with error " << cudaGetErrorString(result); + return false; + break; + } + } + + int const kWarmupIterations = 5; + int const kProfilingIterations = 100; + + for (int i = 0; i < kWarmupIterations; ++i) { + status = gemm_op(); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + } + + + cudaError_t result = cudaEventRecord(events[0]); + EXPECT_EQ(result, cudaSuccess); + + for (int i = 0; i < kProfilingIterations; ++i) { + + typename Gemm::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha, beta}, + profiling_tensor_A.get() + tensor_A.capacity() * (i % kWorkspaces), + profiling_tensor_B.get() + tensor_B.capacity() * (i % kWorkspaces), + profiling_tensor_C.get() + tensor_C.capacity() * (i % kWorkspaces), + profiling_tensor_Z.get() + tensor_Z.capacity() * (i % kWorkspaces), + profiling_tensor_Broadcast.get() + tensor_Broadcast.capacity() * (i % kWorkspaces), + profiling_tensor_T.get() + tensor_T.capacity() * (i % kWorkspaces), + problem_size.m() * problem_size.k(), + problem_size.n() * problem_size.k(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + problem_size.m(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_Z.layout().stride(0), + 0, // This must be zero + tensor_T.layout().stride(0), + }; + + gemm_op.initialize(arguments, workspace.get()); + status = gemm_op(); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + } + + result = cudaEventRecord(events[1]); + EXPECT_EQ(result, cudaSuccess); + + result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess); + + float elapsed_time = 0; + result = cudaEventElapsedTime(&elapsed_time, events[0], events[1]); + EXPECT_EQ(result, cudaSuccess); + + double average_time = double(elapsed_time) / double(kProfilingIterations); + + std::cout << problem_size << ": " << average_time << " ms" << std::endl; + + for (auto & event : events) { + cudaEventDestroy(event); + } + #endif + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Gemm, + typename ReferenceOp = GemmWithBroadcastReferenceOp +> +bool TestGemmWithBroadcast( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmUniversalMode mode, + int batch_count, + double alpha = 1.0, + double beta = 2.0) { + + bool passed = true; + + TestbedGemmWithBroadcast testbed; + + using ElementAccumulator = typename Gemm::ElementAccumulator; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Gemm, + typename ReferenceOp = GemmWithBroadcastReferenceOp +> +bool TestAllGemmWithBroadcast() { + + int M_problems[] = {8, 136, 264, 520}; + int N_problems[] = {8, 136, 264, 520}; + int K_problems[] = {8, 136, 264, 520}; + double alpha_problems[] = {1.25, 2.25}; + double beta_problems[] = {0, 1, 2.0}; + + bool passed = true; + + for (int M : M_problems) { + for (int N : N_problems) { + for (int K : K_problems) { + for (double alpha : alpha_problems) { + for (double beta : beta_problems) { + + TestbedGemmWithBroadcast testbed; + + using ElementAccumulator = typename Gemm::ElementAccumulator; + + passed = testbed.run( + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K}, + 1, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + EXPECT_TRUE(passed) + << "M: " << M << ", N: " << N << ", K: " << K << ", alpha: " << alpha << ", beta: " << beta; + + if (!passed) { + + return passed; + } + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_gemm_with_reduction.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_gemm_with_reduction.h new file mode 100644 index 0000000000000000000000000000000000000000..af3629ccfb87e09e80b85af508379780d6428dc5 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_gemm_with_reduction.h @@ -0,0 +1,588 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/gemm_complex.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct GemmWithReductionReference { + + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::ElementCompute; + using ElementC = typename Gemm::ElementC; + using ElementT = typename Gemm::GemmKernel::Epilogue::ElementTensor; + // + // Data members + // + + BinaryOp binary_op; + + // + // Methods + // + + GemmWithReductionReference() { } + + ElementCompute operator()( + ElementAccumulator d_y, + ElementT t) { + + return binary_op(ElementCompute(d_y), ElementCompute(t)); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Gemm, + typename ReferenceOp +> +struct TestbedGemmWithReduction { + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementT = typename Gemm::GemmKernel::Epilogue::ElementTensor; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D; + cutlass::HostTensor tensor_Reduction; + cutlass::HostTensor tensor_Tensor; + cutlass::HostTensor tensor_C_ref; + cutlass::HostTensor reference_d_Y; + cutlass::HostTensor reference_D; + cutlass::HostTensor reference_Reduction; + + // + // Methods + // + + TestbedGemmWithReduction( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 1; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + for (int m = 0; m < view.extent().row(); ++m) { + for (int n = 0; n < view.extent().column(); ++n) { + //view.at({m, n}) = Element(float(((idx ++) % 17) - 8)); + view.at({m, n}) = (n == 0 ? Element(m) : Element()); + + } + } + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the GEMM workspace + // + + tensor_A.resize(problem_size.mk()); + tensor_B.resize(problem_size.kn()); + tensor_C.resize(problem_size.mn()); + tensor_D.resize(problem_size.mn()); + + tensor_Reduction.resize({ + problem_size.m(), + (problem_size.n() - 1 + Gemm::ThreadblockShape::kN) / Gemm::ThreadblockShape::kN + }); + + tensor_Tensor.resize(problem_size.mn()); + reference_D.resize(problem_size.mn(), false); + reference_d_Y.resize(problem_size.mn(), false); + tensor_C_ref.resize(problem_size.mn(), false); + reference_Reduction.resize({problem_size.m(), 1}, false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + EXPECT_TRUE(initialize_tensor(tensor_Tensor.host_view(), init_C, seed + 2020)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = typename Gemm::ElementA(1); + tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1); + tensor_C.host_view().at({0, 0}) = typename Gemm::ElementC(1); + + for (int m = 0; m < tensor_C_ref.extent().row(); ++m) { + for (int n = 0; n < tensor_C_ref.extent().column(); ++n) { + tensor_C_ref.at({m, n}) = ElementAccumulator(tensor_C.at({m, n})); + } + } + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + tensor_Reduction.sync_device(); + tensor_Tensor.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementAccumulator alpha, + ElementAccumulator beta) { + + tensor_Reduction.sync_host(); + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Reduction.host_view()), 0); + + bool passed = true; + for (int m = 0; m < tensor_Reduction.extent().row(); ++m) { + + ElementAccumulator reduced_value = ElementAccumulator(); + for (int j = 0; j < tensor_Reduction.extent().column(); ++j) { + reduced_value += tensor_Reduction.at({m, j}); + } + + if (reduced_value != reference_Reduction.at({m, 0})) { + std::cout << "Error in bias[" << m << "] - Expected: " << reference_Reduction.at({m, 0}) << ", got: " << reduced_value << std::endl; + passed = false; + break; + } + } + EXPECT_TRUE(passed) << "Reduction is incorect."; + + if (!cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view())) { + EXPECT_TRUE(false) << " mismatched reference"; + passed = false; + } + + if (!passed) { + + /* + std::stringstream fname; + + fname << "error_Gemm_device_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" + << Gemm::ThreadblockShape::kN << "x" + << Gemm::ThreadblockShape::kK << "_" + << Gemm::WarpShape::kM << "x" + << Gemm::WarpShape::kN << "x" + << Gemm::WarpShape::kK << ".txt"; + + std::ofstream file(fname.str()); + */ + + std::ofstream file("testbed_universal_errors_sm70.txt"); + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\nT = \n" << tensor_Tensor.host_view() + << "\n\nReference =\n" << reference_D.host_view() + << "\nComputed =\n" << tensor_D.host_view() + << "\n\nReduction =\n" << tensor_Reduction.host_view() << "\n" + << "\nReference reduction =\n" << reference_Reduction.host_view() << "\n"; + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementAccumulator alpha, + ElementAccumulator beta) { + + // + // Verify + // + + cutlass::reference::host::GemmComplex< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + ElementAccumulator, typename Gemm::LayoutC, + ElementAccumulator, ElementAccumulator + >( + problem_size, + alpha, + tensor_A.host_ref(), + Gemm::kTransformA, + tensor_B.host_ref(), + Gemm::kTransformB, + beta, + tensor_C_ref.host_ref(), + reference_d_Y.host_ref(), + ElementAccumulator(0) + ); + + using ElementC = typename Gemm::ElementC; + + ReferenceOp reference_op; + + // compute backwards + for (int m = 0; m < problem_size.m(); ++m) { + ElementAccumulator reduced_value = ElementAccumulator(); + for (int n = 0; n < problem_size.n(); ++n) { + ElementAccumulator d_full = reference_op(reference_d_Y.at({m, n}), tensor_Tensor.at({m, n})); + reduced_value += d_full; + reference_D.at({m, n}) = ElementC(d_full); + } + reference_Reduction.at({m, 0}) = reduced_value; + } + + return compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord problem_size, + int batch_count = 1, + ElementAccumulator alpha = ElementAccumulator(1), + ElementAccumulator beta = ElementAccumulator(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha, beta}, + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data(), + tensor_D.device_data(), + tensor_Reduction.device_data(), + tensor_Tensor.device_data(), + problem_size.m() * problem_size.k(), + problem_size.n() * problem_size.k(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + problem_size.m(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_D.layout().stride(0), + tensor_Reduction.layout().stride(0), + tensor_Tensor.layout().stride(0), + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + if (!passed) { + std::cout << "Failed with batch_count/split_k_slices = " << batch_count << std::endl; + } + + // + // Profile + // + + #if 0 // profiling disabled for now. + + int const kWorkspaces = 100; + + cutlass::DeviceAllocation profiling_tensor_A(tensor_A.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_B(tensor_B.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_C(tensor_C.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_D(tensor_D.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_Reduction(tensor_Reduction.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_Tensor(tensor_Tensor.capacity() * kWorkspaces); + + cudaEvent_t events[2]; + for (auto & event : events) { + cudaError_t result = cudaEventCreate(&event); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << " cudaEventCreate() failed with error " << cudaGetErrorString(result); + return false; + break; + } + } + + int const kWarmupIterations = 5; + int const kProfilingIterations = 100; + + for (int i = 0; i < kWarmupIterations; ++i) { + status = gemm_op(); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + } + + + cudaError_t result = cudaEventRecord(events[0]); + EXPECT_EQ(result, cudaSuccess); + + for (int i = 0; i < kProfilingIterations; ++i) { + + typename Gemm::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha, beta}, + profiling_tensor_A.get() + tensor_A.capacity() * (i % kWorkspaces), + profiling_tensor_B.get() + tensor_B.capacity() * (i % kWorkspaces), + profiling_tensor_C.get() + tensor_C.capacity() * (i % kWorkspaces), + profiling_tensor_D.get() + tensor_D.capacity() * (i % kWorkspaces), + profiling_tensor_Reduction.get() + tensor_Reduction.capacity() * (i % kWorkspaces), + profiling_tensor_Tensor.get() + tensor_Tensor.capacity() * (i % kWorkspaces), + problem_size.m() * problem_size.k(), + problem_size.n() * problem_size.k(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + problem_size.m(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_D.layout().stride(0), + tensor_Reduction.layout().stride(0), + tensor_Tensor.layout().stride(0), + }; + + gemm_op.initialize(arguments, workspace.get()); + status = gemm_op(); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + } + + result = cudaEventRecord(events[1]); + EXPECT_EQ(result, cudaSuccess); + + result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess); + + float elapsed_time = 0; + result = cudaEventElapsedTime(&elapsed_time, events[0], events[1]); + EXPECT_EQ(result, cudaSuccess); + + double average_time = double(elapsed_time) / double(kProfilingIterations); + + std::cout << problem_size << ": " << average_time << " ms" << std::endl; + + for (auto & event : events) { + cudaEventDestroy(event); + } + #endif + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestGemmWithReduction( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmUniversalMode mode, + int batch_count = 1, + double alpha = 1.0, + double beta = 2.0) { + + bool passed = true; + + TestbedGemmWithReduction testbed; + + using ElementAccumulator = typename Gemm::ElementAccumulator; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped.h new file mode 100644 index 0000000000000000000000000000000000000000..c7317eb855477e63fe19858ca51cd5722f236eb5 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped.h @@ -0,0 +1,501 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface + +*/ + +#pragma once + +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_grouped.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" +#include "cutlass/gemm/device/gemm_grouped.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm_complex.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/tensor_view_io.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedGrouped { + + // + // Type definitions + // + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + + using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; + using ElementCompute = typename EpilogueOutputOp::ElementCompute; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + + using MatrixCoord = typename LayoutC::TensorCoord; + + // + // Data members + // + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint32_t seed; + + int problem_count; + + std::vector problem_sizes_host; + cutlass::DeviceAllocation problem_sizes_device; + + std::vector offset_A; + std::vector offset_B; + std::vector offset_C; + std::vector offset_D; + + std::vector lda_host; + std::vector ldb_host; + std::vector ldc_host; + std::vector ldd_host; + + cutlass::DeviceAllocation lda; + cutlass::DeviceAllocation ldb; + cutlass::DeviceAllocation ldc; + cutlass::DeviceAllocation ldd; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + + cutlass::DeviceAllocation ptr_A; + cutlass::DeviceAllocation ptr_B; + cutlass::DeviceAllocation ptr_C; + cutlass::DeviceAllocation ptr_D; + + // + // Methods + // + + TestbedGrouped( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint32_t seed_ = 3080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint32_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + if (cutlass::sizeof_bits::value <= 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + // no fill - remain zero + } + + return true; + } + + /// Initializes data structures + void initialize() { + + // + // Choose random problem sizes + // + + // construct a few problems of random sizes + srand(seed); + + int64_t total_elements_A = 0; + int64_t total_elements_B = 0; + int64_t total_elements_C = 0; + int64_t total_elements_D = 0; + + + lda_host.resize(problem_count); + ldb_host.resize(problem_count); + ldc_host.resize(problem_count); + ldd_host.resize(problem_count); + + problem_sizes_host.clear(); + problem_sizes_host.resize(problem_count); + + for (int32_t i = 0; i < problem_count; ++i) { + + cutlass::gemm::GemmCoord problem( + 8 * (rand() % 64) + 24, + 8 * (rand() % 64) + 24, + 8 * (rand() % 64) + 24); + + if (!i) { + problem = cutlass::gemm::GemmCoord(48, 16, 8); + } + + problem_sizes_host.at(i) = problem; + + // std::cout << "Problem[" << i << "]: " << problem << std::endl; + + lda_host.at(i) = LayoutA::packed({problem.m(), problem.k()}).stride(0); + ldb_host.at(i) = LayoutB::packed({problem.k(), problem.n()}).stride(0); + ldc_host.at(i) = LayoutC::packed({problem.m(), problem.n()}).stride(0); + ldd_host.at(i) = LayoutC::packed({problem.m(), problem.n()}).stride(0); + + offset_A.push_back(total_elements_A); + offset_B.push_back(total_elements_B); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + + int64_t elements_A = problem.m() * problem.k(); + int64_t elements_B = problem.k() * problem.n(); + int64_t elements_C = problem.m() * problem.n(); + int64_t elements_D = problem.m() * problem.n(); + + total_elements_A += elements_A; + total_elements_B += elements_B; + total_elements_C += elements_C; + total_elements_D += elements_D; + + // Random strides between problems? + } + + problem_sizes_device.reset(problem_count); + problem_sizes_device.copy_from_host(problem_sizes_host.data()); + + lda.reset(problem_count); + ldb.reset(problem_count); + ldc.reset(problem_count); + ldd.reset(problem_count); + + lda.copy_from_host(lda_host.data()); + ldb.copy_from_host(ldb_host.data()); + ldc.copy_from_host(ldc_host.data()); + ldd.copy_from_host(ldd_host.data()); + + // + // Assign pointers + // + + block_A.reset(total_elements_A); + block_B.reset(total_elements_B); + block_C.reset(total_elements_C); + block_D.reset(total_elements_D); + + std::vector ptr_A_host(problem_count); + std::vector ptr_B_host(problem_count); + std::vector ptr_C_host(problem_count); + std::vector ptr_D_host(problem_count); + + for (int32_t i = 0; i < problem_count; ++i) { + ptr_A_host.at(i) = block_A.get() + offset_A.at(i); + ptr_B_host.at(i) = block_B.get() + offset_B.at(i); + ptr_C_host.at(i) = block_C.get() + offset_C.at(i); + ptr_D_host.at(i) = block_D.get() + offset_D.at(i); + } + + ptr_A.reset(problem_count); + ptr_A.copy_from_host(ptr_A_host.data()); + + ptr_B.reset(problem_count); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_C.reset(problem_count); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(problem_count); + ptr_D.copy_from_host(ptr_D_host.data()); + + // + // Initialize the problems of the workspace + // + + for (int32_t i = 0; i < problem_count; ++i) { + cutlass::gemm::GemmCoord problem = problem_sizes_host.at(i); + + LayoutA layout_A(lda_host.at(i)); + LayoutB layout_B(ldb_host.at(i)); + LayoutC layout_C(ldc_host.at(i)); + LayoutC layout_D(ldd_host.at(i)); + + MatrixCoord extent_A{problem.m(), problem.k()}; + MatrixCoord extent_B{problem.k(), problem.n()}; + MatrixCoord extent_C{problem.m(), problem.n()}; + + std::vector matrix_A(layout_A.capacity(extent_A)); + std::vector matrix_B(layout_B.capacity(extent_B)); + std::vector matrix_C(layout_C.capacity(extent_C)); + std::vector matrix_D(layout_D.capacity(extent_C)); + + initialize_tensor(cutlass::TensorView(matrix_A.data(), layout_A, extent_A), init_A, seed * 2021); + initialize_tensor(cutlass::TensorView(matrix_B.data(), layout_B, extent_B), init_B, seed * 2022); + initialize_tensor(cutlass::TensorView(matrix_C.data(), layout_C, extent_C), init_C, seed * 2023); + + cutlass::device_memory::copy_to_device(ptr_A_host.at(i), matrix_A.data(), matrix_A.size()); + cutlass::device_memory::copy_to_device(ptr_B_host.at(i), matrix_B.data(), matrix_B.size()); + cutlass::device_memory::copy_to_device(ptr_C_host.at(i), matrix_C.data(), matrix_C.size()); + cutlass::device_memory::copy_to_device(ptr_D_host.at(i), matrix_D.data(), matrix_D.size()); + } + } + + /// Verifies the result is a GEMM + bool verify( + ElementCompute alpha, + ElementCompute beta) { + + bool passed = true; + + for (int32_t i = 0; i < problem_count; ++i) { + cutlass::gemm::GemmCoord problem = problem_sizes_host.at(i); + + LayoutA layout_A(lda_host.at(i)); + LayoutB layout_B(ldb_host.at(i)); + LayoutC layout_C(ldc_host.at(i)); + LayoutC layout_D(ldd_host.at(i)); + + MatrixCoord extent_A{problem.m(), problem.k()}; + MatrixCoord extent_B{problem.k(), problem.n()}; + MatrixCoord extent_C{problem.m(), problem.n()}; + + std::vector matrix_A(layout_A.capacity(extent_A)); + std::vector matrix_B(layout_B.capacity(extent_B)); + std::vector matrix_C(layout_C.capacity(extent_C)); + std::vector matrix_D(layout_D.capacity(extent_C)); + std::vector matrix_Ref(layout_D.capacity(extent_C)); + + cutlass::device_memory::copy_to_host(matrix_A.data(), block_A.get() + offset_A.at(i), matrix_A.size()); + cutlass::device_memory::copy_to_host(matrix_B.data(), block_B.get() + offset_B.at(i), matrix_B.size()); + cutlass::device_memory::copy_to_host(matrix_C.data(), block_C.get() + offset_C.at(i), matrix_C.size()); + cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + offset_D.at(i), matrix_D.size()); + + cutlass::TensorView view_A(matrix_A.data(), layout_A, extent_A); + cutlass::TensorView view_B(matrix_B.data(), layout_B, extent_B); + cutlass::TensorView view_C(matrix_C.data(), layout_C, extent_C); + cutlass::TensorView view_D(matrix_D.data(), layout_D, extent_C); + cutlass::TensorView view_Ref(matrix_Ref.data(), layout_D, extent_C); + + // Reference GEMM + cutlass::reference::host::GemmComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, ElementAccumulator + >( + problem, + alpha, + view_A, + Gemm::kTransformA, + view_B, + Gemm::kTransformB, + beta, + view_C, + view_Ref, + ElementAccumulator(0) + ); + + // Ensure that no input or output is entirely zero + EXPECT_GT(cutlass::reference::host::TensorNorm(view_A), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(view_B), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(view_C), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(view_D), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(view_Ref), 0); + + // Compare against reference + passed = cutlass::reference::host::TensorEquals(view_D, view_Ref); + + if (!passed) { + std::ofstream file("testbed_grouped_errors.txt"); + + file + << "problem: " << problem << " [group: " << i << "]\n" + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << view_A + << "\nB =\n" << view_B + << "\nC =\n" << view_C + << "\n\nReference =\n" << view_Ref + << "\nComputed =\n" << view_D; + + return passed; + } + } + + return passed; + } + + /// Executes one test + bool run( + int problem_count, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + this->problem_count = problem_count; + + // Initialize the problem + initialize(); + + int threadblock_count = Gemm::sufficient(problem_sizes_host.data(), problem_count); + + // Early exit + if (!threadblock_count) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device resources." << std::endl; + } + return true; + } + + // Configure the GEMM arguments + typename EpilogueOutputOp::Params epilogue_op(alpha, beta); + + // Configure GEMM arguments + typename Gemm::Arguments args( + problem_sizes_device.get(), + problem_count, + threadblock_count, + epilogue_op, + ptr_A.get(), + ptr_B.get(), + ptr_C.get(), + ptr_D.get(), + lda.get(), + ldb.get(), + ldc.get(), + ldd.get(), + problem_sizes_host.data() + ); + + // Initialize the GEMM object + Gemm gemm; + + size_t workspace_size = gemm.get_workspace_size(args); + cutlass::DeviceAllocation workspace(workspace_size); + + cutlass::Status status = gemm.initialize(args, workspace.get()); + + if (status != cutlass::Status::kSuccess) { + return false; + } + + // Run the GEMM object + status = gemm.run(); + + if (status != cutlass::Status::kSuccess) { + return false; + } + + // Wait for completion + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) + << "Kernel execution error: " << cudaGetErrorString(result); + + if (result != cudaSuccess) { + return false; + } + + // Verify correctness + return verify(alpha, beta); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // device +} // gemm +} // test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k.h new file mode 100644 index 0000000000000000000000000000000000000000..f8f08f23c4477745648f1cf8f9e439ae6b5061e2 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k.h @@ -0,0 +1,502 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for grouped Rank2K interface + +*/ + +#pragma once + +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/rank_2k_grouped.h" +#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" +#include "cutlass/gemm/device/rank_2k_grouped.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/rank_2k_complex.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/tensor_view_io.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedGrouped { + + // + // Type definitions + // + + using ElementA = typename Rank2K::ElementA; + using ElementB = typename Rank2K::ElementB; + using ElementC = typename Rank2K::ElementC; + using ElementAccumulator = typename Rank2K::ElementAccumulator; + + using EpilogueOutputOp = typename Rank2K::EpilogueOutputOp; + using ElementCompute = typename EpilogueOutputOp::ElementCompute; + + using LayoutA = typename Rank2K::LayoutA; + using LayoutB = typename Rank2K::LayoutB; + using LayoutC = typename Rank2K::LayoutC; + + using MatrixCoord = typename LayoutC::TensorCoord; + + // + // Data members + // + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint32_t seed; + + int problem_count; + + std::vector problem_sizes_host; + cutlass::DeviceAllocation problem_sizes_device; + + std::vector offset_A; + std::vector offset_B; + std::vector offset_C; + std::vector offset_D; + + std::vector lda_host; + std::vector ldb_host; + std::vector ldc_host; + std::vector ldd_host; + + cutlass::DeviceAllocation lda; + cutlass::DeviceAllocation ldb; + cutlass::DeviceAllocation ldc; + cutlass::DeviceAllocation ldd; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + + cutlass::DeviceAllocation ptr_A; + cutlass::DeviceAllocation ptr_B; + cutlass::DeviceAllocation ptr_C; + cutlass::DeviceAllocation ptr_D; + + // + // Methods + // + + TestbedGrouped( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint32_t seed_ = 3080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint32_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + if (cutlass::sizeof_bits::value <= 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + // no fill - remain zero + } + + return true; + } + + /// Initializes data structures + void initialize() { + + // + // Choose random problem sizes + // + + // construct a few problems of random sizes + srand(seed); + + int64_t total_elements_A = 0; + int64_t total_elements_B = 0; + int64_t total_elements_C = 0; + int64_t total_elements_D = 0; + + + lda_host.resize(problem_count); + ldb_host.resize(problem_count); + ldc_host.resize(problem_count); + ldd_host.resize(problem_count); + + problem_sizes_host.clear(); + problem_sizes_host.resize(problem_count); + + for (int32_t i = 0; i < problem_count; ++i) { + + auto N = 8 * (rand() % 64) + 24; + auto K = 8 * (rand() % 64) + 24; + cutlass::gemm::GemmCoord problem(N, N, K); + + if (!i) { + problem = cutlass::gemm::GemmCoord(16, 16, 8); + } + + problem_sizes_host.at(i) = problem; + + lda_host.at(i) = LayoutA::packed({problem.n(), problem.k()}).stride(0); + ldb_host.at(i) = LayoutB::packed({problem.n(), problem.k()}).stride(0); + ldc_host.at(i) = LayoutC::packed({problem.n(), problem.n()}).stride(0); + ldd_host.at(i) = LayoutC::packed({problem.n(), problem.n()}).stride(0); + + offset_A.push_back(total_elements_A); + offset_B.push_back(total_elements_B); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + + int64_t elements_A = problem.n() * problem.k(); + int64_t elements_B = problem.n() * problem.k(); + int64_t elements_C = problem.n() * problem.n(); + int64_t elements_D = problem.n() * problem.n(); + + total_elements_A += elements_A; + total_elements_B += elements_B; + total_elements_C += elements_C; + total_elements_D += elements_D; + + // Random strides between problems? + } + + problem_sizes_device.reset(problem_count); + problem_sizes_device.copy_from_host(problem_sizes_host.data()); + + lda.reset(problem_count); + ldb.reset(problem_count); + ldc.reset(problem_count); + ldd.reset(problem_count); + + lda.copy_from_host(lda_host.data()); + ldb.copy_from_host(ldb_host.data()); + ldc.copy_from_host(ldc_host.data()); + ldd.copy_from_host(ldd_host.data()); + + // + // Assign pointers + // + + block_A.reset(total_elements_A); + block_B.reset(total_elements_B); + block_C.reset(total_elements_C); + block_D.reset(total_elements_D); + + std::vector ptr_A_host(problem_count); + std::vector ptr_B_host(problem_count); + std::vector ptr_C_host(problem_count); + std::vector ptr_D_host(problem_count); + + for (int32_t i = 0; i < problem_count; ++i) { + ptr_A_host.at(i) = block_A.get() + offset_A.at(i); + ptr_B_host.at(i) = block_B.get() + offset_B.at(i); + ptr_C_host.at(i) = block_C.get() + offset_C.at(i); + ptr_D_host.at(i) = block_D.get() + offset_D.at(i); + } + + ptr_A.reset(problem_count); + ptr_A.copy_from_host(ptr_A_host.data()); + + ptr_B.reset(problem_count); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_C.reset(problem_count); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(problem_count); + ptr_D.copy_from_host(ptr_D_host.data()); + + // + // Initialize the problems of the workspace + // + + for (int32_t i = 0; i < problem_count; ++i) { + cutlass::gemm::GemmCoord problem = problem_sizes_host.at(i); + + LayoutA layout_A(lda_host.at(i)); + LayoutB layout_B(ldb_host.at(i)); + LayoutC layout_C(ldc_host.at(i)); + LayoutC layout_D(ldd_host.at(i)); + + MatrixCoord extent_A{problem.n(), problem.k()}; + MatrixCoord extent_B{problem.n(), problem.k()}; + MatrixCoord extent_C{problem.n(), problem.n()}; + + std::vector matrix_A(layout_A.capacity(extent_A)); + std::vector matrix_B(layout_B.capacity(extent_B)); + std::vector matrix_C(layout_C.capacity(extent_C)); + std::vector matrix_D(layout_D.capacity(extent_C)); + + initialize_tensor(cutlass::TensorView(matrix_A.data(), layout_A, extent_A), init_A, seed * 2021); + initialize_tensor(cutlass::TensorView(matrix_B.data(), layout_B, extent_B), init_B, seed * 2022); + initialize_tensor(cutlass::TensorView(matrix_C.data(), layout_C, extent_C), init_C, seed * 2023); + + cutlass::device_memory::copy_to_device(ptr_A_host.at(i), matrix_A.data(), matrix_A.size()); + cutlass::device_memory::copy_to_device(ptr_B_host.at(i), matrix_B.data(), matrix_B.size()); + cutlass::device_memory::copy_to_device(ptr_C_host.at(i), matrix_C.data(), matrix_C.size()); + cutlass::device_memory::copy_to_device(ptr_D_host.at(i), matrix_D.data(), matrix_D.size()); + } + } + + /// Verifies the result is a Rank2K + bool verify( + ElementCompute alpha, + ElementCompute beta) { + + bool passed = true; + + for (int32_t i = 0; i < problem_count; ++i) { + cutlass::gemm::GemmCoord problem = problem_sizes_host.at(i); + + LayoutA layout_A(lda_host.at(i)); + LayoutB layout_B(ldb_host.at(i)); + LayoutC layout_C(ldc_host.at(i)); + LayoutC layout_D(ldd_host.at(i)); + + MatrixCoord extent_A{problem.n(), problem.k()}; + MatrixCoord extent_B{problem.n(), problem.k()}; + MatrixCoord extent_C{problem.n(), problem.n()}; + + std::vector matrix_A(layout_A.capacity(extent_A)); + std::vector matrix_B(layout_B.capacity(extent_B)); + std::vector matrix_C(layout_C.capacity(extent_C)); + std::vector matrix_D(layout_D.capacity(extent_C)); + std::vector matrix_Ref(layout_D.capacity(extent_C)); + + cutlass::device_memory::copy_to_host(matrix_A.data(), block_A.get() + offset_A.at(i), matrix_A.size()); + cutlass::device_memory::copy_to_host(matrix_B.data(), block_B.get() + offset_B.at(i), matrix_B.size()); + cutlass::device_memory::copy_to_host(matrix_C.data(), block_C.get() + offset_C.at(i), matrix_C.size()); + cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + offset_D.at(i), matrix_D.size()); + + cutlass::TensorView view_A(matrix_A.data(), layout_A, extent_A); + cutlass::TensorView view_B(matrix_B.data(), layout_B, extent_B); + cutlass::TensorView view_C(matrix_C.data(), layout_C, extent_C); + cutlass::TensorView view_D(matrix_D.data(), layout_D, extent_C); + cutlass::TensorView view_Ref(matrix_Ref.data(), layout_D, extent_C); + + // Reference Rank2K + cutlass::reference::host::Rank2KComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, ElementAccumulator + >( + problem, + alpha, + view_A, + Rank2K::kTransformA, + view_B, + Rank2K::kTransformB, + beta, + view_C, + view_Ref, + ElementAccumulator(0), + Rank2K::kFillModeC, + Rank2K::kBlasMode + ); + + // Ensure that no input or output is entirely zero + EXPECT_GT(cutlass::reference::host::TensorNorm(view_A), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(view_B), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(view_C), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(view_D), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(view_Ref), 0); + + // Compare against reference + passed = cutlass::reference::host::TensorEquals(view_D, view_Ref); + + if (!passed) { + std::ofstream file("testbed_grouped_errors.txt"); + + file + << "problem: " << problem << " [group: " << i << "]\n" + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << view_A + << "\nB =\n" << view_B + << "\nC =\n" << view_C + << "\n\nReference =\n" << view_Ref + << "\nComputed =\n" << view_D; + + return passed; + } + } + + return passed; + } + + /// Executes one test + bool run( + int problem_count, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + this->problem_count = problem_count; + + // Initialize the problem + initialize(); + + int threadblock_count = Rank2K::sufficient(problem_sizes_host.data(), problem_count); + + // Early exit + if (!threadblock_count) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device resources." << std::endl; + } + return true; + } + + // Configure the Rank2K arguments + typename EpilogueOutputOp::Params epilogue_op(alpha, beta); + + // Configure Rank2K arguments + typename Rank2K::Arguments args( + cutlass::gemm::GemmUniversalMode::kGemm, + problem_sizes_device.get(), + problem_count, + threadblock_count, + epilogue_op, + ptr_A.get(), + ptr_B.get(), + ptr_C.get(), + ptr_D.get(), + lda.get(), + ldb.get(), + ldc.get(), + ldd.get(), + problem_sizes_host.data() + ); + + // Initialize the Rank2K object + Rank2K rank2k; + + size_t workspace_size = rank2k.get_workspace_size(args); + cutlass::DeviceAllocation workspace(workspace_size); + + cutlass::Status status = rank2k.initialize(args, workspace.get()); + + if (status != cutlass::Status::kSuccess) { + return false; + } + + // Run the Rank2K object + status = rank2k.run(); + + if (status != cutlass::Status::kSuccess) { + return false; + } + + // Wait for completion + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) + << "Kernel execution error: " << cudaGetErrorString(result); + + if (result != cudaSuccess) { + return false; + } + + // Verify correctness + return verify(alpha, beta); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // device +} // gemm +} // test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k_scheduler.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k_scheduler.h new file mode 100644 index 0000000000000000000000000000000000000000..e9315e12e8711f50256e4cfe05666201acd614d3 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k_scheduler.h @@ -0,0 +1,461 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for grouped Rank2K problem visitors +*/ + +#pragma once + +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/device_kernel.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Use simple problem visitor as a baseline +template +struct BaselineProblemVisitor : public cutlass::gemm::kernel::BaseGroupedProblemVisitor { + using Base = cutlass::gemm::kernel::BaseGroupedProblemVisitor; + using Params = typename Base::Params; + static int const kThreadCount = ThreadCount; + static cutlass::FillMode const kFillModeC = FillModeC; + + struct SharedStorage {}; + + int32_t tile_count_sum; + SharedStorage &shared_storage; + + // + // Methods + // + CUTLASS_DEVICE + BaselineProblemVisitor( + Params const ¶ms_, + SharedStorage &shared_storage_, + int32_t block_idx + ): Base(params_, block_idx), + shared_storage(shared_storage_) + { + cutlass::gemm::GemmCoord problem = this->problem_size(); + cutlass::gemm::GemmCoord grid = this->grid_shape(problem); + tile_count_sum = this->tile_count(grid); + } + + CUTLASS_DEVICE + bool next_tile() { + if (this->tile_idx < tile_count_sum) { + return true; + } + + do { + ++this->problem_idx; + + if (this->problem_idx >= this->params.problem_count) { + return false; + } + + cutlass::gemm::GemmCoord problem = this->problem_size(); + cutlass::gemm::GemmCoord grid = this->grid_shape(problem); + + this->problem_tile_start = tile_count_sum; + tile_count_sum += this->tile_count(grid); + + } while (tile_count_sum <= this->tile_idx); + + return true; + } + + static size_t get_workspace_size(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, + int32_t problem_count, + int32_t block_count) { + return 0; + } + + static void host_precompute(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, + int32_t problem_count, + int32_t block_count, + void* host_workspace_ptr) {} + + CUTLASS_DEVICE + cutlass::gemm::GemmCoord threadblock_offset(int32_t threadblock_id) const { + int32_t macro_id = threadblock_id / ProblemSizeHelper::OffsetHelper::kThreadblockSkewRatio; + int32_t macro_row = ceil(cutlass::fast_sqrt((2*macro_id) + 2.25) - 0.5) - 1; + int32_t macro_col = macro_id - (((macro_row+1) * macro_row)/2); + + if (FillModeC == cutlass::FillMode::kUpper) { + cutlass::swap(macro_row, macro_col); + } + + int32_t row = ProblemSizeHelper::OffsetHelper::macro_row_to_row(macro_row, threadblock_id); + int32_t col = ProblemSizeHelper::OffsetHelper::macro_col_to_col(macro_col, threadblock_id); + + return cutlass::gemm::GemmCoord(row, col, 0); + } +}; + +template +struct ProblemVisitorKernel { + struct SharedStorage { + typename ProblemVisitor::SharedStorage problem_visitor; + }; + + struct Params { + typename ProblemVisitor::Params problem_visitor_params; + int32_t* visited_problems_ptr; + int32_t* visited_tiles_ptr; + int32_t visits_per_block; + + Params(): + visited_problems_ptr(nullptr), + visited_tiles_ptr(nullptr), + visits_per_block(0) {} + + Params(typename ProblemVisitor::Params problem_visitor_params_, + int32_t* visited_problems_ptr_, + int32_t* visited_tiles_ptr_, + int32_t visits_per_block_): + problem_visitor_params(problem_visitor_params_), + visited_problems_ptr(visited_problems_ptr_), + visited_tiles_ptr(visited_tiles_ptr_), + visits_per_block(visits_per_block_) {} + }; + + CUTLASS_DEVICE + void operator()(const Params& params, SharedStorage &shared_storage) { + int32_t store_offset = params.visits_per_block * blockIdx.x; + ProblemVisitor problem_visitor(params.problem_visitor_params, + shared_storage.problem_visitor, + blockIdx.x); + + while (problem_visitor.next_tile()) { + cutlass::gemm::GemmCoord problem_size = problem_visitor.problem_size(); + int32_t problem_idx = problem_visitor.problem_index(); + int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); + + cutlass::gemm::GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + cutlass::gemm::GemmCoord tile_offset = problem_visitor.threadblock_offset(threadblock_idx); + + problem_visitor.advance(gridDim.x); + + // + // Early exit conditions + // 1) Out of range + // 2) Upper-triangular block in lower-triangular problem + // 3) Lower-triangular block in upper-triangular problem + // + + if (grid_shape.m() <= tile_offset.m() || + grid_shape.n() <= tile_offset.n()) { + continue; + } + + if (ProblemVisitor::kFillModeC == cutlass::FillMode::kLower && + (tile_offset.m() + 1) * ProblemVisitor::ThreadblockShape::kM <= tile_offset.n() * ProblemVisitor::ThreadblockShape::kN) { + continue; + } + + if (ProblemVisitor::kFillModeC == cutlass::FillMode::kUpper && + tile_offset.m() * ProblemVisitor::ThreadblockShape::kM >= (tile_offset.n() + 1) * ProblemVisitor::ThreadblockShape::kN) { + continue; + } + + if (threadIdx.x == 0) { + params.visited_problems_ptr[store_offset] = problem_idx; + params.visited_tiles_ptr[store_offset] = threadblock_idx; + ++store_offset; + } + } + } +}; + +template +struct ProblemVisitorRunner { + using BaseKernel = ProblemVisitorKernel; + using Params = typename BaseKernel::Params; + + Params params; + std::vector host_problem_sizes; + int32_t problem_count; + int32_t threadblock_count; + int32_t visits_per_block; + cutlass::DeviceAllocation visited_problems; + cutlass::DeviceAllocation visited_tiles; + cutlass::DeviceAllocation device_problem_sizes; + cutlass::DeviceAllocation workspace; + std::vector host_visited_problems; + std::vector host_visited_tiles; + + ProblemVisitorRunner(const std::vector& host_problem_sizes_, + int32_t threadblock_count_): + host_problem_sizes(host_problem_sizes_), + problem_count(int32_t(host_problem_sizes_.size())), + threadblock_count(threadblock_count_) {} + + /// Initializes GEMM state from arguments. + cutlass::Status initialize() { + size_t workspace_bytes = ProblemVisitor::get_workspace_size( + host_problem_sizes.data(), + problem_count, + threadblock_count); + + workspace.reset(workspace_bytes); + std::vector host_workspace(workspace_bytes); + + int32_t tile_count = ProblemVisitor::group_tile_count(host_problem_sizes.data(), problem_count); + + ProblemVisitor::host_precompute(host_problem_sizes.data(), problem_count, + threadblock_count, host_workspace.data()); + + workspace.copy_from_host(host_workspace.data(), workspace_bytes); + + device_problem_sizes.reset(problem_count); + device_problem_sizes.copy_from_host(host_problem_sizes.data(), problem_count); + + visits_per_block = (tile_count - 1 + threadblock_count) / threadblock_count; + int32_t total_visits = visits_per_block * threadblock_count; + + visited_problems.reset(total_visits); + visited_tiles.reset(total_visits); + host_visited_problems.resize(total_visits); + host_visited_tiles.resize(total_visits); + + cudaError_t result = cudaMemset(visited_problems.get(), -1, sizeof(int32_t) * total_visits); + if (result != cudaSuccess) { + return cutlass::Status::kErrorInternal; + } + + result = cudaMemset(visited_tiles.get(), -1, sizeof(int32_t) * total_visits); + if (result != cudaSuccess) { + return cutlass::Status::kErrorInternal; + } + + typename ProblemVisitor::Params pv_params(device_problem_sizes.get(), problem_count, workspace.get(), tile_count); + params = Params(pv_params, visited_problems.get(), visited_tiles.get(), visits_per_block); + + return cutlass::Status::kSuccess; + } + + bool verify() { + // Sort by problem size and then by threadblock_idx + std::vector indices(host_visited_problems.size()); + std::iota(indices.begin(), indices.end(), 0); + + std::stable_sort(indices.begin(), indices.end(), + [&](int32_t i1, int32_t i2) { + if (host_visited_problems[i1] == host_visited_problems[i2]) { + return host_visited_tiles[i1] < host_visited_tiles[i2]; + } + return host_visited_problems[i1] < host_visited_problems[i2]; + }); + + int32_t idx = 0; + + // Skip any entries that were not visited + while (host_visited_problems[indices[idx]] == -1) { + ++idx; + } + + // Check that each problem visited has the tiles we expect + for (int32_t problem_idx = 0; problem_idx < problem_count; ++problem_idx) { + auto problem = host_problem_sizes[problem_idx]; + ProblemVisitor::possibly_transpose_problem(problem); + int32_t problem_tiles = ProblemVisitor::tile_count(ProblemVisitor::grid_shape(problem)); + for (int i = 0; i < problem_tiles; ++i) { + EXPECT_EQ(problem_idx, host_visited_problems[indices[idx]]); + EXPECT_EQ(i, host_visited_tiles[indices[idx]]); + ++idx; + } + } + + return true; + } + + bool run(bool skip_tile_check=false, cudaStream_t stream = nullptr) { + cutlass::Status status = initialize(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Initialization failed" << std::endl; + return false; + } + + dim3 grid(threadblock_count, 1, 1); + dim3 block(ProblemVisitor::kThreadCount, 1, 1); + int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); + + cutlass::Kernel<<>>(params); + + cudaError_t result = cudaGetLastError(); + if (result != cudaSuccess) { + std::cerr << "grid launch failed with error " << cudaGetErrorString(result) << std::endl; + return false; + } + + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "cudaDeviceSynchronize failed with error " << cudaGetErrorString(result) << std::endl; + return false; + } + + visited_problems.copy_to_host(host_visited_problems.data()); + visited_tiles.copy_to_host(host_visited_tiles.data()); + + if (skip_tile_check) { + return true; + } + + return verify(); + } +}; + +template +struct TestbedGroupedRank2KScheduler { + + using BaselinePV = BaselineProblemVisitor, + ThreadblockShape, + PrefetchTileCount, + ThreadCount, + FillModeC>; + + // + // Data members + // + + // Whether to skip checking that the tiles are visited as expected. This is useful + // in cases where ThreadblockShape::kM != ThreadblockShape::kN, for which the grouped + // Rank2K scheduler may assign out-of-bounds tiles that will cause a threadblock to + // exit early, but which are difficult to detect in tests without reimplementing + // this functionality. + bool skip_tile_check; + uint32_t seed; + int problem_count; + int threadblock_count; + std::vector problem_sizes_host; + + // + // Methods + // + + TestbedGroupedRank2KScheduler(bool skip_tile_check_=false, uint32_t seed_ = 3080): + skip_tile_check(skip_tile_check_), seed(seed_) { srand(seed); } + + /// Initializes data structures + void initialize(int32_t scale_factor) { + + // + // Choose random problem sizes + // + + problem_sizes_host.clear(); + problem_sizes_host.resize(problem_count); + + for (int32_t i = 0; i < problem_count; ++i) { + int n = scale_factor * (rand() % 64) + 24; + + cutlass::gemm::GemmCoord problem( + n, + n, + scale_factor * (rand() % 64) + 24); + + problem_sizes_host.at(i) = problem; + } + } + + template + void compare_visitors(const ProblemVisitorRunner& baseline_runner) { + using PV = cutlass::gemm::kernel::Rank2KGroupedProblemVisitor< + ThreadblockShape, + GroupScheduleMode_, + PrefetchTileCount, + ThreadCount, + FillModeC>; + ProblemVisitorRunner runner(problem_sizes_host, threadblock_count); + EXPECT_TRUE(runner.run(skip_tile_check)); + + // Check that this problem visitor visits the same problems and tiles as the baseline + EXPECT_EQ(baseline_runner.host_visited_problems, runner.host_visited_problems); + EXPECT_EQ(baseline_runner.host_visited_tiles, runner.host_visited_tiles); + } + + template + void compare_visitors(const ProblemVisitorRunner& baseline_runner) { + // Compare the next visitor with the baseline visitor + compare_visitors(baseline_runner); + + // Recurse to compare the next visitors + compare_visitors(baseline_runner); + } + + /// Executes the test on all scheduler modes + void run(int problem_count, int threadblock_count, int scale_factor=8) { + + this->problem_count = problem_count; + this->threadblock_count = threadblock_count; + + // Initialize the problem + initialize(scale_factor); + + // Run the baseline visitor to which we will compare all other visitors + ProblemVisitorRunner baseline_runner(problem_sizes_host, threadblock_count); + EXPECT_TRUE(baseline_runner.run(skip_tile_check)); + + compare_visitors(baseline_runner); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // device +} // gemm +} // test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_scheduler.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_scheduler.h new file mode 100644 index 0000000000000000000000000000000000000000..bda2704b517ea95052e2c2060b50712b686344f6 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_scheduler.h @@ -0,0 +1,407 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for grouped GEMM problem visitors +*/ + +#pragma once + +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" +#include "cutlass/gemm/kernel/grouped_problem_visitor.h" +#include "cutlass/util/device_memory.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Use simple problem visitor as a baseline +template +struct BaselineProblemVisitor : public cutlass::gemm::kernel::BaseGroupedProblemVisitor { + using Base = cutlass::gemm::kernel::BaseGroupedProblemVisitor; + using Params = typename Base::Params; + static int const kThreadCount = ThreadCount; + + struct SharedStorage {}; + + int32_t tile_count_sum; + SharedStorage &shared_storage; + + // + // Methods + // + CUTLASS_DEVICE + BaselineProblemVisitor( + Params const ¶ms_, + SharedStorage &shared_storage_, + int32_t block_idx + ): Base(params_, block_idx), + shared_storage(shared_storage_) + { + cutlass::gemm::GemmCoord problem = this->problem_size(); + cutlass::gemm::GemmCoord grid = this->grid_shape(problem); + tile_count_sum = this->tile_count(grid); + } + + CUTLASS_DEVICE + bool next_tile() { + if (this->tile_idx < tile_count_sum) { + return true; + } + + do { + ++this->problem_idx; + + if (this->problem_idx >= this->params.problem_count) { + return false; + } + + cutlass::gemm::GemmCoord problem = this->problem_size(); + cutlass::gemm::GemmCoord grid = this->grid_shape(problem); + + this->problem_tile_start = tile_count_sum; + tile_count_sum += this->tile_count(grid); + + } while (tile_count_sum <= this->tile_idx); + + return true; + } + + static size_t get_workspace_size(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, + int32_t problem_count, + int32_t block_count) { + return 0; + } + + static void host_precompute(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, + int32_t problem_count, + int32_t block_count, + void* host_workspace_ptr) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct ProblemVisitorKernel { + struct SharedStorage { + typename ProblemVisitor::SharedStorage problem_visitor; + }; + + struct Params { + typename ProblemVisitor::Params problem_visitor_params; + int32_t* visited_problems_ptr; + int32_t* visited_tiles_ptr; + int32_t visits_per_block; + + Params(): + visited_problems_ptr(nullptr), + visited_tiles_ptr(nullptr), + visits_per_block(0) {} + + Params(typename ProblemVisitor::Params problem_visitor_params_, + int32_t* visited_problems_ptr_, + int32_t* visited_tiles_ptr_, + int32_t visits_per_block_): + problem_visitor_params(problem_visitor_params_), + visited_problems_ptr(visited_problems_ptr_), + visited_tiles_ptr(visited_tiles_ptr_), + visits_per_block(visits_per_block_) {} + }; + + CUTLASS_DEVICE + void operator()(const Params& params, SharedStorage &shared_storage) { + int32_t store_offset = params.visits_per_block * blockIdx.x; + ProblemVisitor problem_visitor(params.problem_visitor_params, + shared_storage.problem_visitor, + blockIdx.x); + + while (problem_visitor.next_tile()) { + int32_t problem_idx = problem_visitor.problem_index(); + int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); + + if (threadIdx.x == 0) { + params.visited_problems_ptr[store_offset] = problem_idx; + params.visited_tiles_ptr[store_offset] = threadblock_idx; + ++store_offset; + } + problem_visitor.advance(gridDim.x); + } + } +}; + +template +struct ProblemVisitorRunner { + using BaseKernel = ProblemVisitorKernel; + using Params = typename BaseKernel::Params; + + Params params; + std::vector host_problem_sizes; + int32_t problem_count; + int32_t threadblock_count; + int32_t visits_per_block; + cutlass::DeviceAllocation visited_problems; + cutlass::DeviceAllocation visited_tiles; + cutlass::DeviceAllocation device_problem_sizes; + cutlass::DeviceAllocation workspace; + std::vector host_visited_problems; + std::vector host_visited_tiles; + + ProblemVisitorRunner(const std::vector& host_problem_sizes_, + int32_t threadblock_count_): + host_problem_sizes(host_problem_sizes_), + problem_count(int32_t(host_problem_sizes_.size())), + threadblock_count(threadblock_count_) {} + + /// Initializes GEMM state from arguments. + cutlass::Status initialize() { + size_t workspace_bytes = ProblemVisitor::get_workspace_size( + host_problem_sizes.data(), + problem_count, + threadblock_count); + + workspace.reset(workspace_bytes); + std::vector host_workspace(workspace_bytes); + + int32_t tile_count = ProblemVisitor::group_tile_count(host_problem_sizes.data(), problem_count); + + ProblemVisitor::host_precompute(host_problem_sizes.data(), problem_count, + threadblock_count, host_workspace.data()); + + workspace.copy_from_host(host_workspace.data(), workspace_bytes); + + device_problem_sizes.reset(problem_count); + device_problem_sizes.copy_from_host(host_problem_sizes.data(), problem_count); + + visits_per_block = (tile_count - 1 + threadblock_count) / threadblock_count; + int32_t total_visits = visits_per_block * threadblock_count; + + visited_problems.reset(total_visits); + visited_tiles.reset(total_visits); + host_visited_problems.resize(total_visits); + host_visited_tiles.resize(total_visits); + + cudaError_t result = cudaMemset(visited_problems.get(), -1, sizeof(int32_t) * total_visits); + if (result != cudaSuccess) { + return cutlass::Status::kErrorInternal; + } + + result = cudaMemset(visited_tiles.get(), -1, sizeof(int32_t) * total_visits); + if (result != cudaSuccess) { + return cutlass::Status::kErrorInternal; + } + + typename ProblemVisitor::Params pv_params(device_problem_sizes.get(), problem_count, workspace.get(), tile_count); + params = Params(pv_params, visited_problems.get(), visited_tiles.get(), visits_per_block); + + return cutlass::Status::kSuccess; + } + + bool verify() { + // Sort by problem size and then by threadblock_idx + std::vector indices(host_visited_problems.size()); + std::iota(indices.begin(), indices.end(), 0); + + std::stable_sort(indices.begin(), indices.end(), + [&](int32_t i1, int32_t i2) { + if (host_visited_problems[i1] == host_visited_problems[i2]) { + return host_visited_tiles[i1] < host_visited_tiles[i2]; + } + return host_visited_problems[i1] < host_visited_problems[i2]; + }); + + int32_t idx = 0; + + // Skip any entries that were not visited + while (host_visited_problems[indices[idx]] == -1) { + ++idx; + } + + // Check that each problem visited has the tiles we expect + for (int32_t problem_idx = 0; problem_idx < problem_count; ++problem_idx) { + auto problem = host_problem_sizes[problem_idx]; + ProblemVisitor::possibly_transpose_problem(problem); + int32_t problem_tiles = ProblemVisitor::tile_count(ProblemVisitor::grid_shape(problem)); + for (int i = 0; i < problem_tiles; ++i) { + EXPECT_EQ(problem_idx, host_visited_problems[indices[idx]]); + EXPECT_EQ(i, host_visited_tiles[indices[idx]]); + ++idx; + } + } + + return true; + } + + bool run(cudaStream_t stream = nullptr) { + cutlass::Status status = initialize(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Initialization failed" << std::endl; + return false; + } + + dim3 grid(threadblock_count, 1, 1); + dim3 block(ProblemVisitor::kThreadCount, 1, 1); + int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); + + cutlass::Kernel<<>>(params); + + cudaError_t result = cudaGetLastError(); + if (result != cudaSuccess) { + std::cerr << "grid launch failed with error " << cudaGetErrorString(result) << std::endl; + return false; + } + + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "cudaDeviceSynchronize failed with error " << cudaGetErrorString(result) << std::endl; + return false; + } + + visited_problems.copy_to_host(host_visited_problems.data()); + visited_tiles.copy_to_host(host_visited_tiles.data()); + + return verify(); + } +}; + +template +struct TestbedGroupedGemmScheduler { + + using PSHelper = cutlass::gemm::kernel::detail::GemmGroupedProblemSizeHelper; + using BaselinePV = BaselineProblemVisitor; + + // + // Data members + // + uint32_t seed; + int problem_count; + int threadblock_count; + std::vector problem_sizes_host; + + // + // Methods + // + + TestbedGroupedGemmScheduler(uint32_t seed_ = 3080): + seed(seed_) { srand(seed); } + + /// Initializes data structures + void initialize(int32_t scale_factor) { + + // + // Choose random problem sizes + // + + problem_sizes_host.clear(); + problem_sizes_host.resize(problem_count); + + for (int32_t i = 0; i < problem_count; ++i) { + + cutlass::gemm::GemmCoord problem( + scale_factor * (rand() % 64) + 24, + scale_factor * (rand() % 64) + 24, + scale_factor * (rand() % 64) + 24); + + problem_sizes_host.at(i) = problem; + } + } + + template + void compare_visitors(const ProblemVisitorRunner& baseline_runner) { + using PV = cutlass::gemm::kernel::GemmGroupedProblemVisitor< + ThreadblockShape, + GroupScheduleMode_, + PrefetchTileCount, + ThreadCount, + Transpose>; + ProblemVisitorRunner runner(problem_sizes_host, threadblock_count); + EXPECT_TRUE(runner.run()); + + // Check that this problem visitor visits the same problems and tiles as the baseline + EXPECT_EQ(baseline_runner.host_visited_problems, runner.host_visited_problems); + EXPECT_EQ(baseline_runner.host_visited_tiles, runner.host_visited_tiles); + } + + template + void compare_visitors(const ProblemVisitorRunner& baseline_runner) { + // Compare the next visitor with the baseline visitor + compare_visitors(baseline_runner); + + // Recurse to compare the next visitors + compare_visitors(baseline_runner); + } + + /// Executes the test on all scheduler modes + void run(int problem_count, int threadblock_count, int scale_factor=8) { + + this->problem_count = problem_count; + this->threadblock_count = threadblock_count; + + // Initialize the problem + initialize(scale_factor); + + // Run the baseline visitor to which we will compare all other visitors + ProblemVisitorRunner baseline_runner(problem_sizes_host, threadblock_count); + EXPECT_TRUE(baseline_runner.run()); + + compare_visitors(baseline_runner); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // device +} // gemm +} // test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_interleaved.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_interleaved.h new file mode 100644 index 0000000000000000000000000000000000000000..2a5956000db8e8c05ea22538e58149998b03e3fc --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_interleaved.h @@ -0,0 +1,346 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/host_reorder.h" + +namespace test { +namespace gemm { +namespace device { + +//////////////////////////////////////////////////////////////////////////////// + +template +struct InterleavedTestbed { + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + // + // Methods + // + + InterleavedTestbed( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, 2, -2, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Waives test if CUDA device is insufficient + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + // + // Allocate the GEMM workspace + // + + cutlass::HostTensor< + typename Gemm::ElementA, + typename Gemm::LayoutA> tensor_A(problem_size.mk()); + + cutlass::HostTensor< + typename Gemm::ElementB, + typename Gemm::LayoutB> tensor_B(problem_size.kn()); + + cutlass::HostTensor< + typename Gemm::ElementB, + typename Gemm::LayoutB> tensor_B_reordered(problem_size.kn()); + + cutlass::HostTensor< + typename Gemm::ElementC, + typename Gemm::LayoutC> tensor_C(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm::ElementC, + typename Gemm::LayoutC> tensor_D(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm::ElementC, + typename Gemm::LayoutC> reference_D(problem_size.mn(), false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + + cutlass::reorder_column( + tensor_B_reordered.host_ref(), tensor_B.host_ref(), problem_size); + + cutlass::reference::host::TensorCopy( + reference_D.host_view(), + tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B_reordered.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + problem_size, + tensor_A.device_ref(), + tensor_B_reordered.device_ref(), + tensor_C.device_ref(), + tensor_D.device_ref(), + {alpha, beta} + }; + + Gemm gemm_op; + + cutlass::Status status = gemm_op.initialize(arguments); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + + // + // Verify + // + + cutlass::reference::host::Gemm< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + typename Gemm::ElementC, typename Gemm::LayoutC, ElementCompute, + ElementAccumulator, typename Gemm::Operator> + reference_gemm; + + reference_gemm( + problem_size, + alpha, + tensor_A.host_ref(), + tensor_B.host_ref(), + beta, + reference_D.host_ref(), + ElementAccumulator(0) + ); + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals( + reference_D.host_view(), + tensor_D.host_view()); + + EXPECT_TRUE(passed); + if (!passed) { + + std::stringstream fname; + + fname << "error_Gemm_device_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" + << Gemm::ThreadblockShape::kN << "x" + << Gemm::ThreadblockShape::kK << "_" + << Gemm::WarpShape::kM << "x" + << Gemm::WarpShape::kN << "x" + << Gemm::WarpShape::kK << ".txt"; + + std::ofstream file(fname.str()); + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nB_reordered =\n" << tensor_B_reordered.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\n\nReference =\n" << reference_D.host_view() + << "\nComputed =\n" << tensor_D.host_view(); + } + + return passed; + } + + /// Runs a set of problem sizes + bool run_all() { + bool passed = true; + + int problem_size_m[] = { + InterleavedK, 256 + InterleavedK, 512 + InterleavedK + }; + + int problem_size_n[] = { + InterleavedK, 256 + InterleavedK, 512 + InterleavedK + }; + + int problem_size_k[] = { + InterleavedK, 256 + InterleavedK, 512 + InterleavedK + }; + + double problem_alpha[] = { + 1.0 + }; + + double problem_beta[] = { + 2.0 + }; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (double alpha : problem_alpha) { + for (double beta : problem_beta) { + + passed = run( + {m, n, k}, + ElementCompute(alpha), + ElementCompute(beta) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + + return true; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_planar_complex.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_planar_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..32452c30e05f64763a268195ae78138f26c09735 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_planar_complex.h @@ -0,0 +1,326 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/gemm_planar_complex.h" +#include "cutlass/util/host_tensor_planar_complex.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace gemm { +namespace device { + +//////////////////////////////////////////////////////////////////////////////// + +template +class TestbedPlanarComplex { +public: + + using ElementA = typename Gemm::ElementA; + using LayoutA = typename Gemm::LayoutA; + using ElementB = typename Gemm::ElementB; + using LayoutB = typename Gemm::LayoutB; + using ElementC = typename Gemm::ElementC; + using LayoutC = typename Gemm::LayoutC; + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + using ElementAccumulator = typename Gemm::ElementAccumulator; + + // + // Data members + // + + cutlass::gemm::GemmCoord problem_size; + cutlass::HostTensorPlanarComplex tensor_A; + cutlass::HostTensorPlanarComplex tensor_B; + cutlass::HostTensorPlanarComplex tensor_C; + cutlass::HostTensorPlanarComplex tensor_D; + cutlass::HostTensorPlanarComplex tensor_D_ref; + + // + // Methods + // + + TestbedPlanarComplex(cutlass::gemm::GemmCoord const & problem_size): problem_size(problem_size) { + + tensor_A.reset({problem_size.m(), problem_size.k()}); + tensor_B.reset({problem_size.k(), problem_size.n()}); + tensor_C.reset({problem_size.m(), problem_size.n()}); + tensor_D.reset({problem_size.m(), problem_size.n()}); + tensor_D_ref.reset({problem_size.m(), problem_size.n()}, false); + } + + void initialize() { + + uint64_t seed = 1073; + + int scope_max = 8; + int scope_min = -8; + + cutlass::reference::host::TensorFillRandomUniform( + tensor_A.host_view(), seed, scope_max, scope_min, 0); + + cutlass::reference::host::TensorFillRandomUniform( + tensor_B.host_view(), seed * 2019, scope_max, scope_min, 0); + + cutlass::reference::host::TensorFillRandomUniform( + tensor_C.host_view(), seed * 2020, scope_max, scope_min, 0); + + cutlass::reference::host::TensorFill(tensor_D.host_view(), cutlass::complex()); + cutlass::reference::host::TensorFill(tensor_D_ref.host_view(), cutlass::complex()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + bool run( + cutlass::complex alpha = {1, 0}, + cutlass::complex beta = {0, 0}) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + initialize(); + + int batch_count = 1; + + ElementA *ptr_A = tensor_A.device_data(); + ElementB *ptr_B = tensor_B.device_data(); + ElementC *ptr_C = tensor_C.device_data(); + ElementC *ptr_D = tensor_D.device_data(); + + typename LayoutA::Stride::Index lda = tensor_A.layout().stride(0); + typename LayoutB::Stride::Index ldb = tensor_B.layout().stride(0); + typename LayoutC::Stride::Index ldc = tensor_C.layout().stride(0); + typename LayoutC::Stride::Index ldd = tensor_D.layout().stride(0); + + int64_t imag_stride_A = tensor_A.imaginary_stride(); + int64_t imag_stride_B = tensor_B.imaginary_stride(); + int64_t imag_stride_C = tensor_C.imaginary_stride(); + int64_t imag_stride_D = tensor_D.imaginary_stride(); + + // + // Launch device kernel + // + + Gemm gemm_op; + + typename Gemm::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + batch_count, + {alpha, beta}, + ptr_A, + ptr_A + imag_stride_A, + ptr_B, + ptr_B + imag_stride_B, + ptr_C, + ptr_C + imag_stride_C, + ptr_D, + ptr_D + imag_stride_D, + lda, + lda, + ldb, + ldb, + ldc, + ldc, + ldd, + ldd + }; + + cutlass::Status status = gemm_op(args); + + EXPECT_EQ(status, cutlass::Status::kSuccess); + + cudaError_t error = cudaDeviceSynchronize(); + + tensor_D.sync_host(); + + // + // Compute reference + // + + cutlass::reference::host::GemmPlanarComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator + >( + problem_size, + alpha, + tensor_A.host_ref(), + Gemm::kTransformA, + tensor_B.host_ref(), + Gemm::kTransformB, + beta, + tensor_C.host_ref(), + tensor_D_ref.host_ref() + ); + + bool passed = cutlass::reference::host::TensorEquals( + tensor_D.host_view(), + tensor_D_ref.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + std::ofstream output("gemm_planar_complex.txt"); + + output + << "A:\n" << tensor_A.host_view() << "\n" + << "B:\n" << tensor_B.host_view() << "\n" + << "C:\n" << tensor_C.host_view() << "\n" + << "Reference:\n" + << tensor_D_ref.host_view() << "\n" + << "Computed:\n" + << tensor_D.host_view() << "\n"; + } + + return passed; + } +}; + +template +bool TestOneGemmPlanarComplex(cutlass::gemm::GemmCoord problem_size) { + + TestbedPlanarComplex testbed(problem_size); + + return testbed.run(); +} + +template +bool TestAllGemmPlanarComplex() { + + int M[] = { + 16, 64, 72, 144, 264, 520, + }; + + int N[] = { + 16, 64, 72, 144, 248, 264, 520 + }; + + int K[] = { + 8, 64, 72, 96, 264, 520 + }; + + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + + cutlass::complex alpha_values[] = { + {ElementCompute(1.25), ElementCompute(-0.5)} + }; + + cutlass::complex beta_values[] = { + {ElementCompute(-2.25), ElementCompute(1.5)} + }; + + for (int m : M) { + for (int n : N) { + for (int k : K) { + + test::gemm::device::TestbedPlanarComplex testbed({m, n, k}); + + for (auto const &alpha : alpha_values) { + for (auto const &beta : beta_values) { + + bool passed = testbed.run(alpha, beta); + if (!passed) { + return false; + } + } + } + } + } + } + + return true; +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_rank2k_universal.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_rank2k_universal.h new file mode 100644 index 0000000000000000000000000000000000000000..4d9f6743a45e5dc3a7b4ddd3e2a7b2abceffbb18 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_rank2k_universal.h @@ -0,0 +1,641 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide Rank 2k update interface + +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/blas3.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/error_metrics.h" +#include "cutlass/util/reference/host/rank_2k.h" +#include "cutlass/util/reference/host/rank_2k_complex.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedRank2KUniversal { + + using ElementA = typename Rank2K::ElementA; + using ElementB = typename Rank2K::ElementB; + using ElementC = typename Rank2K::ElementC; + using ElementAccumulator = typename Rank2K::ElementAccumulator; + using ElementCompute = typename Rank2K::Rank2Kkernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + + // + // Methods + // + + TestbedRank2KUniversal( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed, + int mantissa_in_bits) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + + EXPECT_TRUE(false) << "Input distribution not implemented"; + return false; + } + + return true; + } + + + /// Helper to initialize a tensor view + template + bool initialize_symmetric_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed, + int mantissa_in_bits) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillSymmetricRandomUniform( + view, seed, Rank2K::kFillModeC, scope_max, scope_min, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillSymmetricRandomGaussian( + view, seed, Rank2K::kFillModeC, 0, 0.5, mantissa_in_bits); + } + else { + + EXPECT_TRUE(false) << "Input distribution (symmetric tensor) not implemented"; + return false; + } + + return true; + } + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the Rank2K workspace + // + + tensor_A.resize(problem_size.mk()); + tensor_B.resize(problem_size.mk()); + tensor_C.resize(problem_size.mn()); + tensor_D.resize(problem_size.mn()); + reference_D.resize(problem_size.mn(), false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019, cutlass::MantissaInBits::bits)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018, cutlass::MantissaInBits::bits)); + EXPECT_TRUE(initialize_symmetric_tensor(tensor_C.host_view(), init_C, seed + 2017, cutlass::MantissaInBits::bits)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = typename Rank2K::ElementA(1); + tensor_B.host_view().at({0, 0}) = typename Rank2K::ElementB(1); + tensor_C.host_view().at({0, 0}) = typename Rank2K::ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + if (tensor_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + + if (reference_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + double l2_norm = cutlass::reference::host::TensorRelativeErrorMetric(reference_D.host_view(), tensor_D.host_view()); + + bool passed = l2_norm < cutlass::MantissaInBits::error; + + return passed; + } + + /// Verifies the result is a Rank2K + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + // + // Verify + // + cutlass::reference::host::Rank2KComplex< + typename Rank2K::ElementA, typename Rank2K::LayoutA, + typename Rank2K::ElementB, typename Rank2K::LayoutB, + typename Rank2K::ElementC, typename Rank2K::LayoutC, + ElementCompute, ElementAccumulator + >( + problem_size, + alpha, + tensor_A.host_ref(), + Rank2K::kTransformA, + tensor_B.host_ref(), + Rank2K::kTransformB, + beta, + tensor_C.host_ref(), + reference_D.host_ref(), + ElementAccumulator(0), + Rank2K::kFillModeC, + Rank2K::kBlasMode + ); + + return compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Rank2K::Rank2Kkernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord problem_size, + int batch_count = 1, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + +#if 0 + std::cout << "[TestbedRank2KUniversal::run()] problem(m, n, k): " << problem_size + << " alpha: " << ElementCompute(alpha) + << " beta: " << ElementCompute(beta) << std::endl; +#endif + + this->initialize(problem_size); + + // + // Initialize the Rank2K operator + // + + typename Rank2K::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha, beta}, + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data(), + tensor_D.device_data(), + problem_size.n() * problem_size.k(), + problem_size.n() * problem_size.k(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_D.layout().stride(0) + }; + + Rank2K rank2k_op; + + size_t workspace_size = Rank2K::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = rank2k_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the Rank2K + // + + status = rank2k_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + //if (true) { + if (!passed) { + std::stringstream fname; + + fname << "error_Rank2k_device_" + << "fill_mode_c_" + << (Rank2K::kFillModeC == cutlass::FillMode::kLower ? "lower_" : + (Rank2K::kFillModeC == cutlass::FillMode::kUpper ? "upper_" : "invalid_")) + << "mnk_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Rank2K::ThreadblockShape::kM << "x" + << Rank2K::ThreadblockShape::kN << "x" + << Rank2K::ThreadblockShape::kK << "_" + << Rank2K::WarpShape::kM << "x" + << Rank2K::WarpShape::kN << "x" + << Rank2K::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n" + << "\nD reference:\n" << reference_D.host_view() << "\n" + << "\nD computed:\n" << tensor_D.host_view() << "\n"; + + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestRank2kUniversal( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmUniversalMode mode, + int batch_count, + double alpha = 1.0, + double beta = 2.0) { + + bool passed = true; + + TestbedRank2KUniversal testbed; + + using ElementCompute = typename Rank2K::EpilogueOutputOp::ElementCompute; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + return passed; +} + +template +bool TestAllRank2KUniversal() { + bool passed = true; + + + int const kMinimumOperandElementSize = int(cutlass::sizeof_bits::value); + + int const kAlignment = cutlass::platform::is_same< + typename Rank2K::OperatorClass, + cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; + + // int8_t gemm alignment constraints + int const kAlignmentM = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value ? 4 : kAlignment; + + int const kAlignmentN = kAlignmentM; + + int const kAlignmentK = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value + ? 4 : kAlignment; + + cutlass::gemm::GemmUniversalMode modes[] = { + cutlass::gemm::GemmUniversalMode::kGemm, + }; + + int problem_size_n[] = { + kAlignmentN, 512 - 2*kAlignmentN + }; + + int problem_size_k[] = { + kAlignmentK, + Rank2K::ThreadblockShape::kK * Rank2K::kStages - kAlignmentK, + Rank2K::ThreadblockShape::kK * Rank2K::kStages * 3 - kAlignmentK + }; + + int batch_counts[] = { // may be interpretted as batch count or split-K slices + 1 // Just running one batch for now (removing 2, 3, 5, 7) + }; + + double problem_alpha[] = { + 1.0, 3.25 + }; + + double problem_beta[] = { + 0.0, 2.15 + }; + + using ElementCompute = typename Rank2K::EpilogueOutputOp::ElementCompute; + + for (cutlass::gemm::GemmUniversalMode mode : modes) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (int batch_count : batch_counts) { + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + if (mode == cutlass::gemm::GemmUniversalMode::kGemm || + mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { + + // skip very small K problems + //if (k / batch_count < 2 * Rank2K::ThreadblockShape::kK) { + // continue; + //} + } + + cutlass::gemm::GemmCoord problem_size(n, n, k); + + TestbedRank2KUniversal testbed; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + } + + return passed; +} + +template +bool TestAllRank2KHermitianUniversal() { + bool passed = true; + + using ElementCompute = typename Rank2K::EpilogueOutputOp::ElementCompute; + using ElementAccumulator = typename Rank2K::ElementAccumulator; + + int const kMinimumOperandElementSize = int(cutlass::sizeof_bits::value); + + int const kAlignment = cutlass::platform::is_same< + typename Rank2K::OperatorClass, + cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; + + // int8_t gemm alignment constraints + int const kAlignmentM = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value ? 4 : kAlignment; + + int const kAlignmentN = kAlignmentM; + + int const kAlignmentK = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value + ? 4 : kAlignment; + + cutlass::gemm::GemmUniversalMode modes[] = { + cutlass::gemm::GemmUniversalMode::kGemm, + }; + + int problem_size_n[] = { + kAlignmentN, 512 - 2*kAlignmentN + }; + + int problem_size_k[] = { + kAlignmentK, + Rank2K::ThreadblockShape::kK * Rank2K::kStages - kAlignmentK, + Rank2K::ThreadblockShape::kK * Rank2K::kStages * 3 - kAlignmentK + }; + + int batch_counts[] = { // may be interpretted as batch count or split-K slices + 1 // Just running one batch for now (removing 2, 3, 5, 7) + }; + + /* Complex alpha for HER2K */ + ElementAccumulator problem_alpha[] = { + {1.0}, + {1.25, 3.25}, + {-0.25, -2.25} + }; + + ElementAccumulator problem_beta[] = { + 0.0, -2.25 + }; + + for (cutlass::gemm::GemmUniversalMode mode : modes) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (int batch_count : batch_counts) { + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + if (mode == cutlass::gemm::GemmUniversalMode::kGemm || + mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { + + // skip very small K problems + //if (k / batch_count < 2 * Rank2K::ThreadblockShape::kK) { + // continue; + //} + } + + cutlass::gemm::GemmCoord problem_size(n, n, k); + + TestbedRank2KUniversal testbed; + + passed = testbed.run( + mode, + problem_size, + batch_count, + alpha, + beta + ); + + if (!passed) { + return false; + } + } + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_rank_k_universal.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_rank_k_universal.h new file mode 100644 index 0000000000000000000000000000000000000000..cb46528a049ae1254d0492b6235821210e47b957 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_rank_k_universal.h @@ -0,0 +1,511 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide Rank 2k update interface + +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/blas3.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/error_metrics.h" +#include "cutlass/util/reference/host/rank_k_complex.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedRank2KUniversal { + + using ElementA = typename RankK::ElementA; + using ElementC = typename RankK::ElementC; + using ElementAccumulator = typename RankK::ElementAccumulator; + using ElementCompute = typename RankK::RankKkernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + + // + // Methods + // + + TestbedRank2KUniversal( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed, + int mantissa_in_bits) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + + EXPECT_TRUE(false) << "Input distribution not implemented"; + return false; + } + + return true; + } + + + /// Helper to initialize a tensor view + template + bool initialize_symmetric_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed, + int mantissa_in_bits) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillSymmetricRandomUniform( + view, seed, RankK::kFillModeC, scope_max, scope_min, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillSymmetricRandomGaussian( + view, seed, RankK::kFillModeC, 0, 0.5, mantissa_in_bits); + } + else { + + EXPECT_TRUE(false) << "Input distribution (symmetric tensor) not implemented"; + return false; + } + + return true; + } + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the RankK workspace + // + + tensor_A.resize(problem_size.mk()); + tensor_C.resize(problem_size.mn()); + tensor_D.resize(problem_size.mn()); + reference_D.resize(problem_size.mn(), false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019, cutlass::MantissaInBits::bits)); + EXPECT_TRUE(initialize_symmetric_tensor(tensor_C.host_view(), init_C, seed + 2017, cutlass::MantissaInBits::bits)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = typename RankK::ElementA(1); + tensor_C.host_view().at({0, 0}) = typename RankK::ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + if (tensor_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + + if (reference_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + double l2_norm = cutlass::reference::host::TensorRelativeErrorMetric(reference_D.host_view(), tensor_D.host_view()); + + bool passed = l2_norm < cutlass::MantissaInBits::error; + + return passed; + } + + /// Verifies the result is a RankK + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + // + // Verify + // + cutlass::reference::host::Rank2KComplex< + typename RankK::ElementA, typename RankK::LayoutA, + typename RankK::ElementC, typename RankK::LayoutC, + ElementCompute, ElementAccumulator + >( + problem_size, + alpha, + tensor_A.host_ref(), + RankK::kTransformA, + beta, + tensor_C.host_ref(), + reference_D.host_ref(), + ElementAccumulator(0), + RankK::kFillModeC, + RankK::kBlasMode + ); + + return compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename RankK::RankKkernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord problem_size, + int batch_count = 1, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + +#if 0 + std::cout << "[TestbedRankKUniversal::run()] problem(m, n, k): " << problem_size + << " alpha: " << ElementCompute(alpha) + << " beta: " << ElementCompute(beta) << std::endl; +#endif + + this->initialize(problem_size); + + // + // Initialize the RankK operator + // + + typename RankK::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha, beta}, + tensor_A.device_data(), + tensor_C.device_data(), + tensor_D.device_data(), + problem_size.n() * problem_size.k(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + tensor_C.layout().stride(0), + tensor_D.layout().stride(0) + }; + + RankK rank2k_op; + + size_t workspace_size = RankK::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = rank2k_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the RankK + // + + status = rank2k_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + //if (true) { + if (!passed) { + std::stringstream fname; + + fname << "error_RankK_device_" + << "fill_mode_c_" + << (RankK::kFillModeC == cutlass::FillMode::kLower ? "lower_" : + (RankK::kFillModeC == cutlass::FillMode::kUpper ? "upper_" : "invalid_")) + << "mnk_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << RankK::ThreadblockShape::kM << "x" + << RankK::ThreadblockShape::kN << "x" + << RankK::ThreadblockShape::kK << "_" + << RankK::WarpShape::kM << "x" + << RankK::WarpShape::kN << "x" + << RankK::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n" + << "\nD reference:\n" << reference_D.host_view() << "\n" + << "\nD computed:\n" << tensor_D.host_view() << "\n"; + + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestRank2kUniversal( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmUniversalMode mode, + int batch_count, + double alpha = 1.0, + double beta = 2.0) { + + bool passed = true; + + TestbedRank2KUniversal testbed; + + using ElementCompute = typename RankK::EpilogueOutputOp::ElementCompute; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + return passed; +} + +template +bool TestAllRankKUniversal() { + bool passed = true; + + + int const kMinimumOperandElementSize = int(cutlass::sizeof_bits::value); + int const kAlignmentN = 128 / kMinimumOperandElementSize; + int const kAlignmentK = 128 / kMinimumOperandElementSize; + + cutlass::gemm::GemmUniversalMode modes[] = { + cutlass::gemm::GemmUniversalMode::kGemm, + }; + + int problem_size_n[] = { + kAlignmentN, 512 - 2*kAlignmentN + }; + + int problem_size_k[] = { + kAlignmentK, + RankK::ThreadblockShape::kK * RankK::kStages - kAlignmentK, + RankK::ThreadblockShape::kK * RankK::kStages * 3 - kAlignmentK + }; + + int batch_counts[] = { // may be interpretted as batch count or split-K slices + 1 // Just running one batch for now (removing 2, 3, 5, 7) + }; + + double problem_alpha[] = { + 1.0 + }; + + double problem_beta[] = { + 2.0 + }; + + + using ElementCompute = typename RankK::EpilogueOutputOp::ElementCompute; + + for (cutlass::gemm::GemmUniversalMode mode : modes) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (int batch_count : batch_counts) { + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + if (mode == cutlass::gemm::GemmUniversalMode::kGemm || + mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { + } + + cutlass::gemm::GemmCoord problem_size(n, n, k); + + TestbedRank2KUniversal testbed; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_sanity.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_sanity.h new file mode 100644 index 0000000000000000000000000000000000000000..0a01a6a32ee2db84f2e890059423cd6b8477f766 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_sanity.h @@ -0,0 +1,238 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/core_io.h" + +#include "testbed.h" + + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// +// List of Gemm internal paramters this testbed supports user verification +// +enum class ParameterID { + + // Threadblock-level parameters + kSmemASize, + kSmemBSize, + + // Warp-level parameters + kWarpFragmentASize, + kWarpFragmentBSize, + kWarpFragmentCSize, + kInvalid +}; + +struct Reference { + ParameterID parameter_id; + + union { + int value; + + struct { + int m, n, k; + } gemm_shape; + + struct { + int row, column; + } matrix_shape; + }; + + std::string error_msg; + + Reference( + ParameterID parameter_id_, + int value_=-1, + std::string const &error_msg_="") : parameter_id(parameter_id_), value(value_), error_msg(error_msg_) {} +}; + + +template +struct TestbedSanity { + + // + // Type definitions (All Gemm types top down) + // + + // Unpacking Gemm types in the following order + // Kernel-level > Threadblock-level > Warp-level > Instruction-level + + // kernel-level cutlass Gemm + using GemmKernel = typename Gemm::GemmKernel; + + // + // Threadblock-level gemm types + // + using MmaThreadBlock = typename GemmKernel::Mma; + + // Threadblock-level gemm shape covering one stage + using ThreadblockShape = typename MmaThreadBlock::Shape; + + // Shared memory size covering all stages + using SmemShapeA = typename MmaThreadBlock::Base::SharedStorage::ShapeA; + using SmemPaddingA = typename MmaThreadBlock::Policy::SmemPaddingA; + using SmemShapeB = typename MmaThreadBlock::Base::SharedStorage::ShapeB; + using SmemPaddingB = typename MmaThreadBlock::Policy::SmemPaddingB; + + + /// Number of stages + static int const kStages = MmaThreadBlock::Base::kStages; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = MmaThreadBlock::kWarpGemmIterations; + + + // + // Warp-level gemm types + // + + // Warp-level gemm operator + using MmaWarp = typename MmaThreadBlock::Operator; + + // Warp-level gemm shape covering all kgroups + using WarpShape = typename MmaWarp::Shape; + + // Warp-level framents holding operands A & B operand and destination C + using WarpFragmentA = typename MmaWarp::FragmentA; + using WarpFragmentB = typename MmaWarp::FragmentB; + using WarpFragmentC = typename MmaWarp::FragmentC; + + // + // Instruction-level gemm types + // + + // Instruction-level gemm operator + using MmaInstruction = typename MmaWarp::Policy::Operator; + + // Instruction shape + using InstructionShape = typename MmaInstruction::Shape; + + // Instruction-level framents holding operands A & B operand and destination C + using InstructionFragmentA = typename MmaInstruction::FragmentA; + using InstructionFragmentB = typename MmaInstruction::FragmentB; + using InstructionFragmentC = typename MmaInstruction::FragmentC; + + // + // Testbed types + // + + // Vector of values holding user provided reference + using ReferenceVector = std::vector; + + // + // Data members + // + ReferenceVector references; + + // + // Methods + // + + TestbedSanity(ReferenceVector const &references_ = ReferenceVector()) : references(references_){ } + + // verify all parameter in ReferenceVector + bool verify() { + for(auto ref : references) + verify_parameter(ref); + return true; + } + + // verify parameter of type Reference + void verify_parameter(Reference const& ref) { + switch(ref.parameter_id) { + case ParameterID::kWarpFragmentASize : EXPECT_TRUE(WarpFragmentA::kElements == ref.value) << *this; break; + case ParameterID::kWarpFragmentBSize : EXPECT_TRUE(WarpFragmentB::kElements == ref.value) << *this; break; + case ParameterID::kWarpFragmentCSize : EXPECT_TRUE(WarpFragmentC::kElements == ref.value) << *this; break; + } + } + +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// Overload output operators for TesbedSanity +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +std::ostream & operator<<(std::ostream &out, TestbedSanity const &test) { + + + out << "Gemm internal parameters" << std::endl + << " Threadblock-level parameters:" << std::endl + << " ThreadblockShape = " << typename TestbedSanity::ThreadblockShape() << std::endl + << " kStages = " << TestbedSanity::kStages << std::endl + << " kWarpGemmIterations = "<< TestbedSanity::kWarpGemmIterations << std::endl + <<" Shared memory sizes:" << std::endl + <<" SmemPaddingA = " << typename TestbedSanity::SmemPaddingA() << std::endl + <<" SmemPaddingB = " << typename TestbedSanity::SmemPaddingB() << std::endl + <<" SmemShapeA = " << typename TestbedSanity::SmemShapeA() << std::endl + <<" SmemShapeB = " << typename TestbedSanity::SmemShapeB() << std::endl + <<" Warp-level parameters" << std::endl + <<" WarpShape = " << typename TestbedSanity::WarpShape() << std::endl + <<" Fragment sizes:" << std::endl + <<" WarpFragmentA::kElements = " << TestbedSanity::WarpFragmentA::kElements << std::endl + <<" WarpFragmentB::kElements = " << TestbedSanity::WarpFragmentB::kElements << std::endl + <<" WarpFragmentC::kElements = " << TestbedSanity::WarpFragmentC::kElements << std::endl + <<" Instruction-level parameters" << std::endl + <<" InstructionShape = " << typename TestbedSanity::InstructionShape() << std::endl + <<" Fragment sizes:" << std::endl + <<" InstructionFragmentA::kElements = " << TestbedSanity::InstructionFragmentA::kElements << std::endl + <<" InstructionFragmentB::kElements = " << TestbedSanity::InstructionFragmentB::kElements << std::endl + <<" InstructionFragmentC::kElements = " << TestbedSanity::InstructionFragmentC::kElements << std::endl; + + return out; +} + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_sparse.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_sparse.h new file mode 100644 index 0000000000000000000000000000000000000000..a95bf996bac337b44da616dc9fbf9c9bdb2a625c --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_sparse.h @@ -0,0 +1,487 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface + + Testbed for sparse operations not to be released for CUDA 11.0 GA. Expected release is 11.1. +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/host_reorder.h" +#include "cutlass/util/host_uncompress.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SparseTestbed { + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; + + static int const kSparse = Gemm::GemmKernel::kSparse; + static int const kMetaSizeInBits = Gemm::GemmKernel::kMetaSizeInBits; + static int const kMaxID2 = Gemm::GemmKernel::kMaxID2; + static int const kElementsPerElementE = Gemm::GemmKernel::kElementsPerElementE; + + using ElementE = typename Gemm::GemmKernel::ElementE; + using LayoutE = cutlass::layout::RowMajor; + using ReorderedLayoutE = typename Gemm::GemmKernel::LayoutE; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_E; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_A_uncompressed; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + cutlass::HostTensor tensor_E; + cutlass::HostTensor tensor_E_reordered; + + // + // Methods + // + + SparseTestbed( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_E_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080) + : init_A(init_A_), + init_B(init_B_), + init_C(init_C_), + init_E(init_E_), + seed(seed_) {} + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 1; + scope_min = -1; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the GEMM workspace + // + tensor_A.resize(cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse)); + tensor_A_uncompressed.resize(problem_size.mk()); + tensor_B.resize(problem_size.kn()); + tensor_C.resize(problem_size.mn()); + tensor_D.resize(problem_size.mn()); + reference_D.resize(problem_size.mn(), false); + tensor_E.resize(cutlass::make_Coord( + problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE)); + tensor_E_reordered.resize(cutlass::make_Coord( + problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE)); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + + if (init_E == cutlass::Distribution::Uniform) { + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomSparseMeta( + tensor_E.host_view(), seed, kMetaSizeInBits); + } else if (init_E == cutlass::Distribution::Identity) { + uint32_t content = (kMaxID2 == 1) ? 0x44444444 : 0x4444; + cutlass::reference::host::TensorFill(tensor_E.host_view(), + (ElementE)(content)); + } else { + EXPECT_TRUE(false); + } + + cutlass::reorder_meta(tensor_E_reordered.host_ref(), tensor_E.host_ref(), + {problem_size.m(), problem_size.n(), + problem_size.k() / kSparse / kElementsPerElementE}); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = typename Gemm::ElementA(1); + tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1); + tensor_C.host_view().at({0, 0}) = typename Gemm::ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + tensor_E_reordered.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + if (tensor_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + + if (reference_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); + + EXPECT_TRUE(passed); + + if (!passed) { + + std::stringstream fname; + + fname << "error_Gemm_device_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" + << Gemm::ThreadblockShape::kN << "x" + << Gemm::ThreadblockShape::kK << "_" + << Gemm::WarpShape::kM << "x" + << Gemm::WarpShape::kN << "x" + << Gemm::WarpShape::kK << ".txt"; + + std::ofstream file(fname.str()); + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\nE =\n" << tensor_E.host_view() + << "\n\nReference =\n" << reference_D.host_view() + << "\nComputed =\n" << tensor_D.host_view(); + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + // + // Verify + // + + cutlass::uncompress(tensor_A_uncompressed.host_ref(), tensor_A.host_ref(), + tensor_E.host_ref(), problem_size.m(), problem_size.k()); + + cutlass::reference::host::Gemm< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + typename Gemm::ElementC, typename Gemm::LayoutC, + ElementCompute, + ElementAccumulator, typename Gemm::Operator> + reference_gemm; + + reference_gemm( + problem_size, + alpha, + tensor_A_uncompressed.host_ref(), + tensor_B.host_ref(), + beta, + reference_D.host_ref(), + ElementAccumulator(0) + ); + + return compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size, + int split_k_slices = 1, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + split_k_slices, + {alpha, beta}, + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data(), + tensor_D.device_data(), + tensor_E_reordered.device_data(), + int64_t(), + int64_t(), + int64_t(), + int64_t(), + int64_t(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_D.layout().stride(0), + tensor_E_reordered.layout().stride(0) + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + + // This failure is likely due to insufficient device capabilities. Waive the test. + if (status != cutlass::Status::kSuccess) { + return true; + } + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + if (!passed) { + std::cout << "Error with split_k_slices = " << split_k_slices << ", alpha: " << alpha << ", beta: " << beta << ", m: " << problem_size.m() << ", n: " << problem_size.n() << ", k:" < +bool TestAllSparseGemm() { + bool passed = true; + + int const kMinimumOperandElementSize = + std::min( + int(cutlass::sizeof_bits::value), + int(cutlass::sizeof_bits::value)); + + // M dimension has to be multiple of 32 (sparse float) or 16 (sparse int) + // because of the reordering of operand E + int const kAlignmentM = std::max(((sizeof(typename Gemm::ElementE) == 2) ? 32 : 16), + kMinimumOperandElementSize); + + int const kAlignmentN = 128 / kMinimumOperandElementSize; + + int problem_size_m[] = {kAlignmentM, 512 - 3 * kAlignmentM}; + + int problem_size_n[] = {kAlignmentN, 512 - 2 * kAlignmentN}; + + int problem_size_k[] = {Gemm::ThreadblockShape::kK * 8}; + + int split_k_slices[] = { + 1, 2 + }; + + double problem_alpha[] = { + 1 + }; + + double problem_beta[] = { + 2.0 + }; + + SparseTestbed testbed; + + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (int split_k : split_k_slices) { + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + cutlass::gemm::GemmCoord problem_size(m, n, k); + + passed = testbed.run( + problem_size, + split_k, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_splitk.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_splitk.h new file mode 100644 index 0000000000000000000000000000000000000000..8fa4a85505316d08f1d050702b78448f8fae8565 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_splitk.h @@ -0,0 +1,218 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "testbed.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedSplitK : public Testbed { + + using Base = Testbed; + + using ElementCompute = typename Base::ElementCompute; + + // + // Methods + // + + TestbedSplitK( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + Base(init_A_, init_B_, init_C_, seed_) { } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size, + int split_k_slices, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + problem_size, + this->tensor_A.device_ref(), + this->tensor_B.device_ref(), + this->tensor_C.device_ref(), + this->tensor_D.device_ref(), + {alpha, beta}, + split_k_slices + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + + // + // Verify + // + + return this->verify(problem_size, alpha, beta); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestAllGemmSplitK() { + bool passed = true; + + cutlass::gemm::GemmCoord problem_sizes[] = { + {8, 8, 2048}, + {8, 8, 2056}, + {264, 72, 520}, + {264, 520, 120}, + {264, 520, 264} + }; + + int split_k_slices[] = { + 1, 2, 4, 5, 7 + }; + + double problem_alpha[] = { + 0.5 + }; + + double problem_beta[] = { + 2.0 + }; + + using Testbed = TestbedSplitK; + using ElementCompute = typename Testbed::ElementCompute; + + Testbed testbed; + + for (auto problem_size : problem_sizes) { + for (int split_k_count : split_k_slices) { + for (double alpha : problem_alpha) { + for (double beta : problem_beta) { + + passed = testbed.run( + problem_size, + split_k_count, + ElementCompute(alpha), + ElementCompute(beta) + ); + + if (!passed) { + std::cout << "Failed on size " << problem_size << " with split_k_count " << split_k_count << std::endl; + return false; + } + } + } + } + } + + EXPECT_TRUE(passed); + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_symm_universal.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_symm_universal.h new file mode 100644 index 0000000000000000000000000000000000000000..b7a57f7eb0ca73c23460e5a9ce1301061c2cc286 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_symm_universal.h @@ -0,0 +1,592 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide Symm update interface + +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/blas3.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/error_metrics.h" +#include "cutlass/util/reference/host/symm.h" +#include "cutlass/util/reference/host/symm_complex.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedSymmUniversal { + + using ElementA = typename Symm::ElementA; + using ElementB = typename Symm::ElementB; + using ElementC = typename Symm::ElementC; + using ElementAccumulator = typename Symm::ElementAccumulator; + using ElementCompute = typename Symm::SymmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + + // + // Methods + // + + TestbedSymmUniversal( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed, + int mantissa_in_bits) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + + EXPECT_TRUE(false) << "Input distribution not implemented"; + return false; + } + + return true; + } + + + /// Helper to initialize a tensor view + template + bool initialize_symmetric_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed, + int mantissa_in_bits) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillSymmetricRandomUniform( + view, seed, Symm::kFillModeA, scope_max, scope_min, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillSymmetricRandomGaussian( + view, seed, Symm::kFillModeA, 0, 0.5, mantissa_in_bits); + } + else { + + EXPECT_TRUE(false) << "Input distribution (symmetric tensor) not implemented"; + return false; + } + + return true; + } + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the Symm workspace + // + + if (Symm::kSideModeA == cutlass::SideMode::kLeft) { + tensor_A.resize(cutlass::make_Coord(problem_size.m(),problem_size.m())); + } + else if (Symm::kSideModeA == cutlass::SideMode::kRight) { + tensor_A.resize(cutlass::make_Coord(problem_size.n(),problem_size.n())); + } + + tensor_B.resize(problem_size.mn()); + tensor_C.resize(problem_size.mn()); + tensor_D.resize(problem_size.mn()); + reference_D.resize(problem_size.mn(), false); + + EXPECT_TRUE(initialize_symmetric_tensor(tensor_A.host_view(), init_A, seed + 2019, cutlass::MantissaInBits::bits)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018, cutlass::MantissaInBits::bits)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017, cutlass::MantissaInBits::bits)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = typename Symm::ElementA(1); + tensor_B.host_view().at({0, 0}) = typename Symm::ElementB(1); + tensor_C.host_view().at({0, 0}) = typename Symm::ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + if (tensor_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + + if (reference_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + double l2_norm = cutlass::reference::host::TensorRelativeErrorMetric(reference_D.host_view(), tensor_D.host_view()); + + bool passed = l2_norm < cutlass::MantissaInBits::error; + + return passed; + } + + /// Verifies the result is a Symm + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + // + // Verify + // + + using HostReference = typename cutlass::platform::conditional< + (cutlass::platform::is_same + >::value || + cutlass::platform::is_same + >::value + ), + cutlass::reference::host::SymmComplex< + typename Symm::ElementA, typename Symm::LayoutA, + Symm::kSideModeA, Symm::kFillModeA, + typename Symm::ElementB, typename Symm::LayoutB, + typename Symm::ElementC, typename Symm::LayoutC, + ElementCompute, + ElementAccumulator, + Symm::kBlasMode>, + cutlass::reference::host::Symm< + typename Symm::ElementA, typename Symm::LayoutA, + Symm::kSideModeA, Symm::kFillModeA, + typename Symm::ElementB, typename Symm::LayoutB, + typename Symm::ElementC, typename Symm::LayoutC, + ElementCompute, + ElementAccumulator> + >::type; + + + HostReference reference_symm; + + reference_symm( + problem_size, + alpha, + tensor_A.host_ref(), + tensor_B.host_ref(), + beta, + tensor_C.host_ref(), + reference_D.host_ref(), + ElementAccumulator(0) + ); + + return compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Symm::SymmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord problem_size, + int batch_count = 1, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + +#if 0 + std::cout << "[TestbedSymmUniversal::run()] problem(m, n, k): " << problem_size + << " alpha: " << ElementCompute(alpha) + << " beta: " << ElementCompute(beta) << std::endl; +#endif + + this->initialize(problem_size); + + // + // Initialize the Symm operator + // + + int batch_stride_A; + if (Symm::kSideModeA == cutlass::SideMode::kLeft) + batch_stride_A = problem_size.m()*problem_size.m(); + if (Symm::kSideModeA == cutlass::SideMode::kRight) + batch_stride_A = problem_size.n()*problem_size.n(); + + typename Symm::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha, beta}, + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data(), + tensor_D.device_data(), + batch_stride_A, + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_D.layout().stride(0) + }; + + Symm symm_op; + + size_t workspace_size = Symm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = symm_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the Symm + // + + status = symm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + //if (true) { + if (!passed) { + std::stringstream fname; + + fname << "error_" + << (Symm::kBlasMode == cutlass::BlasMode::kSymmetric ? "symm_" : "hemm_" ) + << "device_" + << "fill_mode_a_" + << (Symm::kSideModeA == cutlass::SideMode::kLeft ? "leftside_" : + (Symm::kSideModeA == cutlass::SideMode::kRight ? "rightside_" : "invalid_")) + << (Symm::kFillModeA == cutlass::FillMode::kLower ? "lower_" : + (Symm::kFillModeA == cutlass::FillMode::kUpper ? "upper_" : "invalid_")) + << "mnk_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Symm::ThreadblockShape::kM << "x" + << Symm::ThreadblockShape::kN << "x" + << Symm::ThreadblockShape::kK << "_" + << Symm::WarpShape::kM << "x" + << Symm::WarpShape::kN << "x" + << Symm::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "alpha: " << ElementCompute(alpha) << "\n" + << "beta: " << ElementCompute(beta) << "\n" + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n" + << "\nD reference:\n" << reference_D.host_view() << "\n" + << "\nD computed:\n" << tensor_D.host_view() << "\n"; + + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestsymmUniversal( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmUniversalMode mode, + int batch_count, + double alpha = 1.0, + double beta = 2.0) { + + bool passed = true; + + TestbedSymmUniversal testbed; + + using ElementCompute = typename Symm::EpilogueOutputOp::ElementCompute; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + return passed; +} + +template +bool TestAllSymmUniversal() { + bool passed = true; + + + int const kMinimumOperandElementSize = int(cutlass::sizeof_bits::value); + + int const kAlignment = cutlass::platform::is_same< + typename Symm::OperatorClass, + cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; + + // int8_t gemm alignment constraints + int const kAlignmentM = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value ? 4 : kAlignment; + + int const kAlignmentN = kAlignmentM; + + int const kAlignmentK = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value + ? 4 : kAlignment; + + cutlass::gemm::GemmUniversalMode modes[] = { + cutlass::gemm::GemmUniversalMode::kGemm, + }; + + int problem_size_m[] = { + kAlignmentK, + Symm::ThreadblockShape::kK * Symm::kStages - kAlignmentK, + Symm::ThreadblockShape::kK * Symm::kStages * 3 - kAlignmentK + }; + + int problem_size_n[] = { + kAlignmentN, 512 - 2*kAlignmentN + }; + + int batch_counts[] = { // may be interpretted as batch count or split-K slices + 1 // Just running one batch for now (removing 2, 3, 5, 7) + }; + + double problem_alpha[] = { + 1.0, 3.0 + }; + + double problem_beta[] = { + 0, 2.0 + }; + + + using ElementCompute = typename Symm::EpilogueOutputOp::ElementCompute; + + for (cutlass::gemm::GemmUniversalMode mode : modes) { + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int batch_count : batch_counts) { + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + int k = 0; + if (Symm::kSideModeA == cutlass::SideMode::kLeft) + k = m; + else if (Symm::kSideModeA == cutlass::SideMode::kRight) + k = n; + + if (mode == cutlass::gemm::GemmUniversalMode::kGemm || + mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { + + #if 0 + // skip very small K problems + if (k / batch_count < 2 * Symm::ThreadblockShape::kK) { + continue; + } + #endif + } + + cutlass::gemm::GemmCoord problem_size(m, n, k); + + TestbedSymmUniversal testbed; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_trmm_universal.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_trmm_universal.h new file mode 100644 index 0000000000000000000000000000000000000000..b30acfed6bba547986efd3afa8eb829be2a255e4 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_trmm_universal.h @@ -0,0 +1,606 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide TRMM interface + + +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/blas3.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/error_metrics.h" +#include "cutlass/util/reference/host/trmm.h" +#include "cutlass/util/reference/host/trmm_complex.h" +#include "cutlass/core_io.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedTrmmUniversal { + + using ElementA = typename Trmm::ElementA; + using ElementB = typename Trmm::ElementB; + using ElementC = typename Trmm::ElementC; + using ElementAccumulator = typename Trmm::ElementAccumulator; + using ElementCompute = typename Trmm::TrmmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_D; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + + // + // Methods + // + + TestbedTrmmUniversal( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_D_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_D(init_D_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed, + int mantissa_in_bits) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + + /// Helper to initialize a tensor view + template + bool initialize_symmetric_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed, + int mantissa_in_bits) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillSymmetricRandomUniform( + view, seed, Trmm::kFillMode, scope_max, scope_min, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillSymmetricRandomGaussian( + view, seed, Trmm::kFillMode, 0, 0.5, mantissa_in_bits); + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Helper to initialize a tensor view (pad diagonal fill with zeros for up to alignment on wrong side of diagonal) + template + bool initialize_pad_diagonal_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed, + int alignment) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillPadDiagonalRandomUniform( + view, seed, Trmm::kFillMode, scope_max, scope_min, 0, alignment); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + EXPECT_TRUE(false) << "Gaussian distribution for pad diagonal not implemented"; + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the TRMM workspace + // + + if (Trmm::kSideMode == cutlass::SideMode::kLeft) { + tensor_A.resize(cutlass::make_Coord(problem_size.m(),problem_size.m())); + } + else if (Trmm::kSideMode == cutlass::SideMode::kRight) { + tensor_A.resize(cutlass::make_Coord(problem_size.n(),problem_size.n())); + } + + tensor_B.resize(problem_size.mn()); + tensor_D.resize(problem_size.mn()); + reference_D.resize(problem_size.mn(), false); + + //EXPECT_TRUE(initialize_symmetric_tensor(tensor_A.host_view(), init_A, seed + 2017)); + //EXPECT_TRUE(initialize_pad_diagonal_tensor(tensor_A.host_view(), init_A, seed + 2017, Trmm::kAlignmentA)); + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2017, cutlass::MantissaInBits::bits)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2019, cutlass::MantissaInBits::bits)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = typename Trmm::ElementA(1); + tensor_B.host_view().at({0, 0}) = typename Trmm::ElementB(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_D.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_D.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha) { + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + + if (tensor_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + + if (reference_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + double l2_norm = cutlass::reference::host::TensorRelativeErrorMetric(reference_D.host_view(), tensor_D.host_view()); + + bool passed = l2_norm < cutlass::MantissaInBits::error; + + return passed; + } + + /// Verifies the result is a TRMM + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha) { + + // + // Verify + // + + using HostReference = typename cutlass::platform::conditional< + (cutlass::platform::is_same + >::value || + cutlass::platform::is_same + >::value + ), + cutlass::reference::host::TrmmComplex< + typename Trmm::ElementA, typename Trmm::LayoutA, + Trmm::kTransformA, + Trmm::kSideMode, Trmm::kFillMode, Trmm::kDiagType, + typename Trmm::ElementB, typename Trmm::LayoutB, + Trmm::kTransformB, + typename Trmm::ElementC, typename Trmm::LayoutC, + ElementCompute, + ElementAccumulator>, + cutlass::reference::host::Trmm< + typename Trmm::ElementA, typename Trmm::LayoutA, + Trmm::kSideMode, Trmm::kFillMode, Trmm::kDiagType, + typename Trmm::ElementB, typename Trmm::LayoutB, + typename Trmm::ElementC, typename Trmm::LayoutC, + ElementCompute, + ElementAccumulator> + >::type; + + + HostReference reference_trmm; + + reference_trmm( + problem_size, + alpha, + tensor_A.host_ref(), + tensor_B.host_ref(), + reference_D.host_ref(), + ElementAccumulator(0) + ); + + return compare_reference(problem_size, alpha); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Trmm::TrmmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord problem_size, + int batch_count = 1, + ElementCompute alpha = ElementCompute(1)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + +#if 0 + std::cout << "[TestbedTrmmUniversal::run()] problem(m, n, k): " << problem_size + << " alpha: " << ElementCompute(alpha) << std::endl; +#endif + + this->initialize(problem_size); + + // + // Initialize the TRMM operator + // + + int batch_stride_A; + if (Trmm::kSideMode == cutlass::SideMode::kLeft) + batch_stride_A = problem_size.m()*problem_size.m(); + if (Trmm::kSideMode == cutlass::SideMode::kRight) + batch_stride_A = problem_size.n()*problem_size.n(); + + typename Trmm::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha}, + tensor_A.device_data(), + tensor_B.device_data(), + tensor_D.device_data(), + batch_stride_A, + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_D.layout().stride(0) + }; + + Trmm trmm_op; + + size_t workspace_size = Trmm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = trmm_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the TRMM + // + + status = trmm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + bool passed = this->verify(problem_size, alpha); + + if (!passed) { + std::stringstream fname; + + fname << "error_Trmm_device_" + << "fill_mode_" + << (Trmm::kFillMode == cutlass::FillMode::kLower ? "lower_" : + (Trmm::kFillMode == cutlass::FillMode::kUpper ? "upper_" : "invalid_")) + << "side_mode_" + << (Trmm::kSideMode == cutlass::SideMode::kLeft ? "left_" : + (Trmm::kSideMode == cutlass::SideMode::kRight ? "right_" : "invalid_")) + << "mnk_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Trmm::ThreadblockShape::kM << "x" + << Trmm::ThreadblockShape::kN << "x" + << Trmm::ThreadblockShape::kK << "_" + << Trmm::WarpShape::kM << "x" + << Trmm::WarpShape::kN << "x" + << Trmm::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nD reference:\n" << reference_D.host_view() << "\n" + << "\nD computed:\n" << tensor_D.host_view() << "\n"; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestTrmmUniversal( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmUniversalMode mode, + int batch_count, + double alpha = 1.0) { + + bool passed = true; + + TestbedTrmmUniversal testbed; + + using ElementCompute = typename Trmm::EpilogueOutputOp::ElementCompute; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha) + ); + + return passed; +} + +template +bool TestAllTrmmUniversal() { + bool passed = true; + + int const kMinimumOperandElementSize = int(cutlass::sizeof_bits::value); + + int const kAlignment = cutlass::platform::is_same< + typename Trmm::OperatorClass, + cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; + + // int8_t gemm alignment constraints + int const kAlignmentM = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value ? 4 : kAlignment; + + int const kAlignmentN = kAlignmentM; + + int const kAlignmentK = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value + ? 4 : kAlignment; + + cutlass::gemm::GemmUniversalMode modes[] = { + cutlass::gemm::GemmUniversalMode::kGemm, + }; + + int problem_size_m[] = { + kAlignmentK, + Trmm::ThreadblockShape::kK * Trmm::kStages - kAlignmentK, + Trmm::ThreadblockShape::kK * Trmm::kStages * 3 - kAlignmentK + }; + + int problem_size_n[] = { + kAlignmentN, 512 - 2*kAlignmentN + }; + + int batch_counts[] = { // may be interpretted as batch count or split-K slices + 1 // Just running one batch for now (removing 2, 3, 5, 7) + }; + + double problem_alpha[] = { + 1.0, 2.0 + }; + + using ElementCompute = typename Trmm::EpilogueOutputOp::ElementCompute; + + for (cutlass::gemm::GemmUniversalMode mode : modes) { + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int batch_count : batch_counts) { + for (auto alpha : problem_alpha) { + + int k = 0; + if (Trmm::kSideMode == cutlass::SideMode::kLeft) + k = m; + else if (Trmm::kSideMode == cutlass::SideMode::kRight) + k = n; + + if (mode == cutlass::gemm::GemmUniversalMode::kGemm || + mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { + +#if 0 + // skip very small K problems + if (k / batch_count < 2 * Trmm::ThreadblockShape::kK) { + continue; + } +#endif + } + + cutlass::gemm::GemmCoord problem_size(m, n, k); + + TestbedTrmmUniversal testbed; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_universal.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_universal.h new file mode 100644 index 0000000000000000000000000000000000000000..00368a5e8eebc128719f64069583010c83dc0c1f --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_universal.h @@ -0,0 +1,553 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/gemm_complex.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedUniversal { + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + + // + // Methods + // + + TestbedUniversal( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + bool is_unsigned_int = std::numeric_limits::is_integer && !std::numeric_limits::is_signed; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = is_unsigned_int ? 2 : 1; + scope_min = is_unsigned_int ? 0 : -1; + } else if (bits_output == 16) { + constexpr auto u8_bf16 = + (cutlass::platform::is_same::value && + cutlass::platform::is_same::value) || + (cutlass::platform::is_same::value && + cutlass::platform::is_same::value); + scope_max = is_unsigned_int ? 10 : (u8_bf16 ? 3 : 5); + scope_min = is_unsigned_int ? 0 : (u8_bf16 ? -3 : -5); + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the GEMM workspace + // + + tensor_A.resize(problem_size.mk()); + tensor_B.resize(problem_size.kn()); + tensor_C.resize(problem_size.mn()); + tensor_D.resize(problem_size.mn()); + reference_D.resize(problem_size.mn(), false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + cutlass::Coord<2> origin(0); + tensor_A.host_view().at(origin) = typename Gemm::ElementA(1); + tensor_B.host_view().at(origin) = typename Gemm::ElementB(1); + tensor_C.host_view().at(origin) = typename Gemm::ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); + + EXPECT_TRUE(passed) << " mismatched reference"; + + if (!passed) { + + /* + + std::stringstream fname; + + fname << "error_Gemm_device_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" + << Gemm::ThreadblockShape::kN << "x" + << Gemm::ThreadblockShape::kK << "_" + << Gemm::WarpShape::kM << "x" + << Gemm::WarpShape::kN << "x" + << Gemm::WarpShape::kK << ".txt"; + + std::ofstream file(fname.str()); + */ + + std::ofstream file("testbed_universal_errors.txt"); + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\n\nReference =\n" << reference_D.host_view() + << "\nComputed =\n" << tensor_D.host_view(); + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + // + // Verify + // + + cutlass::reference::host::GemmComplex< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + typename Gemm::ElementC, typename Gemm::LayoutC, + ElementCompute, ElementAccumulator + >( + problem_size, + alpha, + tensor_A.host_ref(), + Gemm::kTransformA, + tensor_B.host_ref(), + Gemm::kTransformB, + beta, + tensor_C.host_ref(), + reference_D.host_ref(), + ElementAccumulator(0) + ); + + if (Relu) { + for (int i = 0; i < problem_size.m(); ++i) { + for (int j = 0; j < problem_size.n(); ++j) { + reference_D.at(cutlass::MatrixCoord(i, j)) = + ((ElementCompute)reference_D.at(cutlass::MatrixCoord(i, j)) < (ElementCompute)0) + ? (typename Gemm::ElementC)0 + : reference_D.at(cutlass::MatrixCoord(i, j)); + } + } + } + + return compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord problem_size, + int batch_count = 1, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) + { +/* + std::cout << "\n-----------------------\n"; + std::cout << "mode: " << (int) mode << "\n"; + std::cout << "problem size: " << problem_size << "\n"; + std::cout << "batch_count: " << batch_count << "\n"; + std::cout << "alpha: " << alpha << "\n"; + std::cout << "beta: " << beta << "\n"; + std::cout << "-----------------------\n\n"; +*/ + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha, beta}, + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data(), + tensor_D.device_data(), + problem_size.m() * problem_size.k(), + problem_size.n() * problem_size.k(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_D.layout().stride(0) + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + if (!passed) { + std::cout << "Failed with batch_count/split_k_slices = " << batch_count << std::endl; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestGemmUniversal( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmUniversalMode mode, + int batch_count, + double alpha = 1.0, + double beta = 2.0) { + + bool passed = true; + + TestbedUniversal testbed; + + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + return passed; +} + +template +bool TestAllGemmUniversal() { + bool passed = true; + + + int const kMinimumOperandElementSize = + std::min( + int(cutlass::sizeof_bits::value), + int(cutlass::sizeof_bits::value)); + + int const kAlignment = cutlass::platform::is_same< + typename Gemm::OperatorClass, + cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; + + // int8_t gemm alignment constraints + int const kAlignmentM = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value ? 4 : kAlignment; + + int const kAlignmentN = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value ? 4 : kAlignment; + + int const kAlignmentK = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + (cutlass::platform::is_same::value || + cutlass::platform::is_same::value) ? 4 : kAlignment; + + + + cutlass::gemm::GemmUniversalMode modes[] = { + cutlass::gemm::GemmUniversalMode::kGemm, + }; + + int problem_size_m[] = { + kAlignmentM, 512 - 3*kAlignmentM + }; + + int problem_size_n[] = { + kAlignmentN, 512 - 2*kAlignmentN + }; + + int problem_size_k[] = { + kAlignmentK, + Gemm::ThreadblockShape::kK * Gemm::kStages - kAlignmentK, + Gemm::ThreadblockShape::kK * Gemm::kStages * 3 - kAlignmentK + }; + + int batch_counts[] = { // may be interpretted as batch count or split-K slices + 1, 2, 3, 5, 7 + }; + + double problem_alpha[] = { + 1 + }; + + double problem_beta[] = { + 2.0 + }; + + + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + + for (cutlass::gemm::GemmUniversalMode mode : modes) { + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (int batch_count : batch_counts) { + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + if (mode == cutlass::gemm::GemmUniversalMode::kGemm || + mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { + + // skip very small K problems + if (k / batch_count < 2 * Gemm::ThreadblockShape::kK) { + continue; + } + } + + cutlass::gemm::GemmCoord problem_size(m, n, k); + + TestbedUniversal testbed; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + } + } + + /* + // large problem with high coverage + for (int split_k_slices = 1; split_k_slices <= 3; ++split_k_slices) { + TestbedUniversal testbed; + + cutlass::gemm::GemmCoord problem_size(72, 56, 8192); + + passed = testbed.run( + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + split_k_slices, + cutlass::from_real(1.0), + cutlass::from_real(2.0) + ); + + if (!passed) { + break; + } + } + */ + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_utils.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..89ac33a1028061515d08d50fdb6cce7833ae88ce --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_utils.h @@ -0,0 +1,53 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +inline char const *to_string(cutlass::Status status) { + + switch (status) { + case cutlass::Status::kSuccess: return "kSuccess"; + case cutlass::Status::kErrorMisalignedOperand: return "kErrorMisalignedOperand"; + case cutlass::Status::kErrorInvalidLayout: return "kErrorInvalidLayout"; + case cutlass::Status::kErrorInvalidProblem: return "kErrorInvalidProblem"; + case cutlass::Status::kErrorNotSupported: return "kErrorNotSupported"; + case cutlass::Status::kErrorWorkspaceNull: return "kErrorWorkspaceNull"; + case cutlass::Status::kErrorInternal: return "kErrorInternal"; + case cutlass::Status::kInvalid: return "kInvalid"; + default: break; + } + return "invalid"; +} diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_with_absmax.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_with_absmax.h new file mode 100644 index 0000000000000000000000000000000000000000..8b5588f57c40c4e8f8d06adfa9f1e673350fb5e5 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_with_absmax.h @@ -0,0 +1,609 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Testbed for running device-level GEMMs with absolute maximum calculation and scaling +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/gemm_complex.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed.h" +#include "testbed_sparse.h" +#include "testbed_utils.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/matrix_coord.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Gemm, + typename GemmTestbed, + template class ActivationFunctor +> +struct TestbedWithAmax { + + static_assert(std::is_same_v> || std::is_same_v>); + static constexpr bool IsSparseTestbed = std::is_same_v>; + + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; + using ElementScalingFactor = typename Gemm::EpilogueOutputOp::ElementScalingFactor; + using ElementAbsmax = typename Gemm::EpilogueOutputOp::ElementAbsmax; + + static bool const kScaleAux = Gemm::EpilogueOutputOp::kIsScalingAndAmaxAuxOutputNeeded; + static bool const kScaleOutput = Gemm::EpilogueOutputOp::kIsScalingAndAmaxOutputNeeded; + bool doScaleA; + bool doScaleB; + bool doScaleC; + + GemmTestbed underlying_testbed; + + cutlass::HostTensor tensor_Aux; + cutlass::HostTensor tensor_Vector; + cutlass::HostTensor tmp_D; + cutlass::HostTensor reference_D; + cutlass::HostTensor reference_Aux; + cutlass::HostTensor scale_A; + cutlass::HostTensor scale_B; + cutlass::HostTensor scale_C; + cutlass::HostTensor scale_D; + cutlass::HostTensor scale_Aux; + cutlass::HostTensor abs_max_Aux; + cutlass::HostTensor abs_max_D; + cutlass::HostTensor reference_abs_max_Aux; + cutlass::HostTensor reference_abs_max_D; + + // + // Methods + // + + TestbedWithAmax( + bool scaleA = true, + bool scaleB = true, + bool scaleC = true, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform + ): + doScaleA(scaleA), doScaleB(scaleB), doScaleC(scaleC), + underlying_testbed(init_A_, init_B_, init_C_) { } + + /// Helper to initialize scaling factors + template + bool initialize_scale_factor(cutlass::TensorView view, uint64_t seed, int bits=0) { + cutlass::reference::host::TensorFillRandomUniform(view, seed, double(1.), double(0.), bits); + return true; + } + + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the GEMM workspace + // + underlying_testbed.initialize(problem_size); + + tensor_Vector.resize({1, problem_size.n()}); + reference_D.resize(problem_size.mn(), false); + tmp_D.resize(problem_size.mn(), false); + + EXPECT_TRUE( + underlying_testbed.initialize_tensor(tensor_Vector.host_view(), underlying_testbed.init_C, underlying_testbed.seed + 2020) + ); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + cutlass::Coord<2> origin(0); + tensor_Vector.host_view().at(origin) = typename Gemm::ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), underlying_testbed.tensor_C.host_view()); + + tensor_Vector.sync_device(); + + int scale_bits = 2; + if (doScaleA) { + scale_A.resize({1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_A.host_view(), underlying_testbed.seed + 2021, scale_bits)); + scale_A.sync_device(); + } + + if (doScaleB) { + scale_B.resize({1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_B.host_view(), underlying_testbed.seed + 2022, scale_bits)); + scale_B.sync_device(); + } + + if (doScaleC) { + scale_C.resize({1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_C.host_view(), underlying_testbed.seed + 2023, scale_bits)); + scale_C.sync_device(); + } + + if (kScaleOutput) { + scale_D.resize({1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_D.host_view(), underlying_testbed.seed + 2024, scale_bits)); + scale_D.sync_device(); + + abs_max_D.resize({1, 1}); + cutlass::reference::host::TensorFill(abs_max_D.host_view()); + abs_max_D.sync_device(); + + reference_abs_max_D.resize({1, 1}); + } + + if (kScaleAux) { + tensor_Aux.resize(problem_size.mn()); + cutlass::reference::host::TensorFill(tensor_Aux.host_view()); + tensor_Aux.sync_device(); + + scale_Aux.resize({1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_Aux.host_view(), underlying_testbed.seed + 2025, scale_bits)); + scale_Aux.sync_device(); + + abs_max_Aux.resize({1, 1}); + cutlass::reference::host::TensorFill(abs_max_Aux.host_view()); + abs_max_Aux.sync_device(); + + reference_Aux.resize(problem_size.mn(), false); + reference_abs_max_Aux.resize({1, 1}); + } + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + underlying_testbed.tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(underlying_testbed.tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(underlying_testbed.tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(underlying_testbed.tensor_C.host_view()), 0); + + EXPECT_GT(cutlass::reference::host::TensorNorm(underlying_testbed.tensor_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), underlying_testbed.tensor_D.host_view()); + if (!passed) { + std::cout << "Comparison of D failed" << std::endl; + } + + if (kScaleAux) { + tensor_Aux.sync_host(); + abs_max_Aux.sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Aux.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(abs_max_Aux.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_Aux.host_view()), 0); + if (!cutlass::reference::host::TensorEquals(reference_Aux.host_view(), tensor_Aux.host_view())) { + passed = false; + std::cout << "Comparison of Aux failed" << std::endl; + } + if (!cutlass::reference::host::TensorEquals(abs_max_Aux.host_view(), reference_abs_max_Aux.host_view())) { + passed = false; + std::cout << "Comparison of Aux absmax failed" << std::endl; + } + } + + if (kScaleOutput) { + abs_max_D.sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(abs_max_D.host_view()), 0); + if (!cutlass::reference::host::TensorEquals(abs_max_D.host_view(), reference_abs_max_D.host_view())) { + passed = false; + std::cout << "Comparison of D absmax failed" << std::endl; + } + } + + EXPECT_TRUE(passed) << " mismatched reference"; + + if (!passed) { + + std::ofstream file("testbed_with_amax_errors.txt"); + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << underlying_testbed.tensor_A.host_view() + << "\nB =\n" << underlying_testbed.tensor_B.host_view() + << "\nC =\n" << underlying_testbed.tensor_C.host_view() + << "\nVector =\n" << tensor_Vector.host_view() + << "\nScaleA = " << scale_A.host_view() + << "\nScaleB = " << scale_B.host_view() + << "\nScaleC = " << scale_C.host_view() + << "\nScaleD = " << scale_D.host_view() + << "\nScaleAux = " << scale_Aux.host_view() + << "\n\nReference D =\n" << reference_D.host_view() + << "\nComputed D =\n" << underlying_testbed.tensor_D.host_view(); + if (kScaleAux) { + file + << "\n\nReference Aux =\n" << reference_Aux.host_view() + << "\nComputed Aux =\n" << tensor_Aux.host_view() + << "\n\nReference Absmax Aux = " << reference_abs_max_Aux.host_view() + << "\nComputed Absmax Aux = " << abs_max_Aux.host_view(); + } + if (kScaleOutput) { + file + << "\n\nReference Absmax D = " << reference_abs_max_D.host_view() + << "\nComputed Absmax D = " << abs_max_D.host_view(); + } + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + cutlass::Coord<2> origin(0); + ElementCompute scaled_alpha = alpha; + if (doScaleA) { + scaled_alpha *= scale_A.host_view().at(origin); + } + if (doScaleB) { + scaled_alpha *= scale_B.host_view().at(origin); + } + + ElementCompute scaled_beta = beta; + if (doScaleC) { + scaled_beta *= scale_C.host_view().at(origin); + } + + // + // Verify + // + + auto ref_tA = [&](){ + if constexpr (IsSparseTestbed) { + cutlass::uncompress( + underlying_testbed.tensor_A_uncompressed.host_ref(), + underlying_testbed.tensor_A.host_ref(), + underlying_testbed.tensor_E.host_ref(), + problem_size.m(), + problem_size.k() + ); + return underlying_testbed.tensor_A_uncompressed.host_ref(); + } + else { + return underlying_testbed.tensor_A.host_ref(); + } + }(); + + // Run reference kernel with ElementOutput of type ElementAccumulator + // so that we can compute the absmax epilogue on data that is of type + // ElementAccumulator (which is what the GEMM we are testing will do). + cutlass::reference::host::GemmComplex< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + typename Gemm::ElementC, typename Gemm::LayoutC, + ElementCompute, ElementAccumulator, ElementAccumulator + >( + problem_size, + scaled_alpha, + ref_tA, + Gemm::kTransformA, + underlying_testbed.tensor_B.host_ref(), + Gemm::kTransformB, + scaled_beta, + underlying_testbed.tensor_C.host_ref(), + tmp_D.host_ref(), + ElementAccumulator(0) + ); + + ElementCompute tmp_abs_max_Aux(0.); + ElementCompute tmp_abs_max_D(0.); + + cutlass::NumericConverter cvt_c_to_compute; + cutlass::NumericConverter cvt_accum_to_compute; + cutlass::NumericConverter cvt_compute_to_absmax; + cutlass::NumericConverter cvt_compute_to_d; + cutlass::NumericConverter cvt_compute_to_aux; + + cutlass::absolute_value_op abs; + cutlass::maximum_with_nan_propogation max; + ActivationFunctor act; + + ElementScalingFactor d_scale = kScaleOutput ? scale_D.host_view().at(origin) : ElementScalingFactor(1.); + + for (int m = 0; m < problem_size.m(); ++m) { + for (int n = 0; n < problem_size.n(); ++n) { + ElementCompute intermediate = cvt_accum_to_compute(tmp_D.host_view().at({m, n})); + ElementCompute bias = cvt_c_to_compute(tensor_Vector.host_view().at({0, n})); + ElementCompute aux = intermediate + bias; + ElementCompute d = act(aux); + tmp_abs_max_Aux = max(abs(aux), tmp_abs_max_Aux); + tmp_abs_max_D = max(abs(d), tmp_abs_max_D); + reference_D.host_view().at({m, n}) = cvt_compute_to_d(d * d_scale); + + if (kScaleAux) { + reference_Aux.host_view().at({m, n}) = cvt_compute_to_aux(aux * scale_Aux.host_view().at(origin)); + } + } + } + + if (kScaleAux) { + reference_abs_max_Aux.host_view().at(origin) = cvt_compute_to_absmax(tmp_abs_max_Aux); + } + + if (kScaleOutput) { + reference_abs_max_D.host_view().at(origin) = cvt_compute_to_absmax(tmp_abs_max_D); + } + + return compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + return underlying_testbed.sufficient(); + } + + /// Executes one test + bool run( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord problem_size, + int batch_count = 1, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) + { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Gemm::EpilogueOutputOp::Params::ActivationParams activation_params{alpha, beta}; + typename Gemm::EpilogueOutputOp::Params epilogue_params{ + activation_params, + scale_A.device_data(), + scale_B.device_data(), + scale_C.device_data(), + scale_D.device_data(), + scale_Aux.device_data(), + abs_max_Aux.device_data(), + abs_max_D.device_data() + }; + + auto arguments = [&]() { + if constexpr (IsSparseTestbed) { + return typename Gemm::Arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + batch_count, + epilogue_params, + underlying_testbed.tensor_A.device_data(), + underlying_testbed.tensor_B.device_data(), + underlying_testbed.tensor_C.device_data(), + underlying_testbed.tensor_D.device_data(), + underlying_testbed.tensor_E_reordered.device_data(), + tensor_Aux.device_data(), + tensor_Vector.device_data(), + int64_t(), + int64_t(), + int64_t(), + int64_t(), + int64_t(), + int64_t(), + int64_t(), + underlying_testbed.tensor_A.layout().stride(0), + underlying_testbed.tensor_B.layout().stride(0), + underlying_testbed.tensor_C.layout().stride(0), + underlying_testbed.tensor_D.layout().stride(0), + underlying_testbed.tensor_E_reordered.layout().stride(0), + tensor_Aux.layout().stride(0), + 0 // stride vector + }; + } + else { + return typename Gemm::Arguments{ + mode, + problem_size, + batch_count, + epilogue_params, + underlying_testbed.tensor_A.device_data(), + underlying_testbed.tensor_B.device_data(), + underlying_testbed.tensor_C.device_data(), + underlying_testbed.tensor_D.device_data(), + tensor_Aux.device_data(), + tensor_Vector.device_data(), + problem_size.m() * problem_size.k(), + problem_size.n() * problem_size.k(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + 0, // stride vector + underlying_testbed.tensor_A.layout().stride(0), + underlying_testbed.tensor_B.layout().stride(0), + underlying_testbed.tensor_C.layout().stride(0), + underlying_testbed.tensor_D.layout().stride(0), + (int64_t)0 // Leading dimension of vector. This must be 0 + }; + } + }(); + + Gemm gemm_op; + + cutlass::Status status = gemm_op.can_implement(arguments); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + status = gemm_op.initialize(arguments, workspace.get()); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + cudaError_t cuda_error = cudaDeviceSynchronize(); + EXPECT_TRUE(cuda_error == cudaSuccess) << cudaGetErrorString(cuda_error); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + if (!passed) { + std::cout << "Failed with batch_count/split_k_slices = " << batch_count << std::endl; + } + + return passed; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Gemm, + typename GemmTestbed, + template class ActivationFunctor = cutlass::epilogue::thread::Identity +> +bool TestAllGemmWithAbsmax(bool scaleA=true, bool scaleB=true, bool scaleC=true) { + + int const kMinimumOperandElementSize = + std::min( + int(cutlass::sizeof_bits::value), + int(cutlass::sizeof_bits::value)); + + int constexpr kAlignmentM = [&]() { + if constexpr (std::is_same_v>) { + // M dimension has to be multiple of 32 (sparse float) or 16 (sparse int) + // because of the reordering of operand E + return std::max(((sizeof(typename Gemm::ElementE) == 2) ? 32 : 16), + kMinimumOperandElementSize); + } + else { + return 128 / kMinimumOperandElementSize; + } + }(); + + int const kAlignmentN = 128 / kMinimumOperandElementSize; + + int M_problems[] = {kAlignmentM, 128 + 32}; + int N_problems[] = {kAlignmentN, 512 - 2 * kAlignmentN}; + int K_problems[] = {Gemm::ThreadblockShape::kK * 2}; + double alpha_problems[] = {1.}; + double beta_problems[] = {0.}; + int split_k_slices[] = { + 1, 2 + }; + + bool passed = true; + + for (int M : M_problems) { + for (int N : N_problems) { + for (int K : K_problems) { + for (int split_k : split_k_slices) { + if (cutlass::sizeof_bits_v <= 8 && split_k > 1) { + // Don't test split-K with FP8 output. The kernel being tested will writie partial accumulations + // for different splits to global memory in FP8, while the reference kernel will not. This leads + // to mismatches that are difficult to capture without a permissive relative equality check threshold. + continue; + } + + for (double alpha : alpha_problems) { + for (double beta : beta_problems) { + TestbedWithAmax testbed(scaleA, scaleB, scaleC); + + using ElementAccumulator = typename Gemm::ElementAccumulator; + + passed = testbed.run( + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K}, + split_k, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + EXPECT_TRUE(passed) + << "M: " << M << ", N: " << N << ", K: " << K << ", alpha: " << alpha << ", beta: " << beta << ", split_k:" << split_k; + + if (!passed) { + + return passed; + } + } + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/kernel/testbed_gemv.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/kernel/testbed_gemv.h new file mode 100644 index 0000000000000000000000000000000000000000..8e939f9710403a5f5c3fd8c61e34c4e8021ff423 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/kernel/testbed_gemv.h @@ -0,0 +1,358 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/core_io.h" +#include "cutlass/numeric_types.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "cutlass/gemm/kernel/default_gemv.h" +#include "cutlass/gemm/kernel/gemv_batched_strided.h" + +namespace test { +namespace gemm { +namespace kernel { + +template +void batched_gemv_kernel_test(cutlass::gemm::BatchedGemmCoord problem_size, + ElementCD_ alpha = ElementCD_(1), + ElementCD_ beta = ElementCD_(0), + bool perf_test = false, + int perf_test_iter = 1) +{ + using ThreadBlockShape = ThreadBlockShape_; + using ThreadShape = ThreadShape_; + using ElementA = ElementAB_; + using LayoutA = LayoutA_; + using ElementB = ElementAB_; + using LayoutB = LayoutB_; + using ElementAccumulator = ElementCD_; + using ElementCD = ElementCD_; + using LayoutCD = LayoutCD_; + + using GemvKernel = cutlass::gemm::kernel::DefaultGemv; + + using ThreadBlockGemv = typename GemvKernel::ThreadBlockGemv; + using ThreadBlockSwizzle = typename GemvKernel::ThreadBlockSwizzle; + + if (DEBUG) + { + problem_size = cutlass::gemm::BatchedGemmCoord( + problem_size.m(), problem_size.n(), problem_size.k(), 1); + } + + // Create host tensors that will be the backing store for the batches + // Note that no device memory is initially allocated + cutlass::HostTensor matrix_A({problem_size.m(), problem_size.k()}, false); + cutlass::HostTensor matrix_B({problem_size.k(), problem_size.n()}, false); + cutlass::HostTensor matrix_C_computed({problem_size.m(), problem_size.n()}, false); + cutlass::HostTensor matrix_C_reference({problem_size.m(), problem_size.n()}, false); + + // Reserve memory for the batch of tensors + matrix_A.reserve(problem_size.m()*problem_size.k()*problem_size.batch()); + matrix_B.reserve(problem_size.n()*problem_size.k()*problem_size.batch()); + matrix_C_computed.reserve(problem_size.m()*problem_size.n()*problem_size.batch()); + matrix_C_reference.reserve(problem_size.m()*problem_size.n()*problem_size.batch(), false); + + // Fill eatch tensor batch + const int seed = 9876; + for (int b = 0; b < problem_size.batch(); b++) + { + if(DEBUG) + { + cutlass::reference::host::BlockFillSequential( + matrix_A.host_data_ptr_offset(b*matrix_A.capacity()), matrix_A.capacity()); + cutlass::reference::host::BlockFillSequential( + matrix_B.host_data_ptr_offset(b*matrix_B.capacity()), matrix_B.capacity()); + } + else + { + cutlass::reference::host::TensorFillRandomUniform( + matrix_A.host_view(b*matrix_A.capacity()), + seed + 1660, + 8, + -8, + 0 + ); + + cutlass::reference::host::TensorFillRandomUniform( + matrix_B.host_view(b*matrix_B.capacity()), + seed + 1880, + 8, + -8, + 0 + ); + } + + cutlass::reference::host::TensorFill(matrix_C_computed.host_view(b*matrix_C_computed.capacity())); + cutlass::reference::host::TensorFill(matrix_C_reference.host_view(b*matrix_C_reference.capacity())); + } + + matrix_A.sync_device(); + matrix_B.sync_device(); + matrix_C_computed.sync_device(); + + ThreadBlockSwizzle swizzle; + + cutlass::gemm::BatchedGemmCoord tiled_size{ThreadBlockShape::kM, + ThreadBlockShape::kN, + problem_size.k(), // no split-k + DEBUG ? 1 : THREAD_B }; + + cutlass::gemm::BatchedGemmCoord tiled_shape = swizzle.get_tiled_shape(problem_size, tiled_size); + + #if 0 + printf("tiled_size = %d %d %d %d\n", tiled_size.m(), tiled_size.n(), tiled_size.k(), tiled_size.batch()); + printf("tiled_shape = %d %d %d %d\n", tiled_shape.m(), tiled_shape.n(), tiled_shape.k(), tiled_shape.batch()); + #endif + + // No split-k + EXPECT_EQ(tiled_size.k(), problem_size.k()); + + dim3 grid = swizzle.get_grid_shape(tiled_shape); + dim3 block(tiled_size.n() / ThreadShape::kN, tiled_size.batch(), tiled_size.k() / problem_size.k()); + + // Some sanity checks + EXPECT_TRUE( block.x*block.y*block.z <= 1024 ); + EXPECT_TRUE( block.x <= 1024 ); + EXPECT_TRUE( block.y <= 1024 ); + EXPECT_TRUE( block.z <= 64 ); + + #if 0 + printf("grid dim = %d, %d, %d\n", grid.x, grid.y, grid.z); + printf("block dim = %d, %d, %d\n", block.x, block.y, block.z); + #endif + + cudaError_t result; + cudaEvent_t start_event, end_event; + + for (int iter = 0; iter < (perf_test ? (perf_test_iter+1) : 1); ++iter) + { + if (perf_test && iter == 1) + { + result = cudaEventCreate(&start_event); + EXPECT_EQ(result, cudaSuccess); + + result = cudaEventCreate(&end_event); + EXPECT_EQ(result, cudaSuccess); + + result = cudaEventRecord(start_event); + EXPECT_EQ(result, cudaSuccess); + } + + if (beta == ElementCD(0)) + { + if (alpha == ElementCD(1)) + { + cutlass::gemm::kernel::GemvBatchedStrided<<< grid, block >>>( + problem_size, + matrix_A.device_ref(), + matrix_A.capacity(), + matrix_B.device_ref(), + matrix_B.capacity(), + matrix_C_computed.device_ref(), + matrix_C_computed.capacity() + ); + } + else + { + cutlass::gemm::kernel::GemvBatchedStrided<<< grid, block >>>( + problem_size, + alpha, + matrix_A.device_ref(), + matrix_A.capacity(), + matrix_B.device_ref(), + matrix_B.capacity(), + matrix_C_computed.device_ref(), + matrix_C_computed.capacity() + ); + } + } + else + { + cutlass::gemm::kernel::GemvBatchedStrided<<< grid, block >>>( + problem_size, + alpha, + beta, + matrix_A.device_ref(), + matrix_A.capacity(), + matrix_B.device_ref(), + matrix_B.capacity(), + matrix_C_computed.device_ref(), + matrix_C_computed.capacity(), + matrix_C_computed.device_ref(), + matrix_C_computed.capacity() + ); + } + + if (iter == 0) + { + result = cudaGetLastError(); + EXPECT_EQ(result, cudaSuccess) << " kernel error: " << cudaGetErrorString(result); + } + } + + if (perf_test) + { + result = cudaEventRecord(end_event); + EXPECT_EQ(result, cudaSuccess); + } + + result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " kernel error: " << cudaGetErrorString(result); + + if (perf_test) + { + float ms; + result = cudaEventElapsedTime(&ms, start_event, end_event); + EXPECT_EQ(result, cudaSuccess); + + double flops = (double(problem_size.m()) * + double(problem_size.n()) * + double(problem_size.k()) * + double(problem_size.batch()) * 2); // 2 for MAC + + double read_bytes = double(problem_size.batch()) * (sizeof(ElementA)*double(problem_size.m())*double(problem_size.k()) + + sizeof(ElementB)*double(problem_size.k())*double(problem_size.n())); + + double write_bytes = double(problem_size.batch()) * (sizeof(ElementCD)*double(problem_size.m())*double(problem_size.n())); + + double avg_runtime = double(ms) / perf_test_iter; + double gflops_per_sec = flops / 1.0e6 / avg_runtime; + double read_bandwidth = read_bytes / 1.0e6 / avg_runtime; + double write_bandwidth = write_bytes / 1.0e6 / avg_runtime; + + std::cout << "\n\nProblem size: " + << problem_size.m() + << " x " << problem_size.n() + << " x " << problem_size.k() + << " x " << problem_size.batch() + << std::endl; + + std::cout << " GFLOPs: " << gflops_per_sec << std::endl; + std::cout << "BW (R/W): " << read_bandwidth << " / " << write_bandwidth << " GB/sec" << std::endl; + std::cout << " Runtime: " << avg_runtime << " ms" << std::endl; + } + else + { + matrix_C_computed.sync_host(); + + // Compute the batched gemms + for (int b = 0; b < problem_size.batch(); b++) + { + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm( + problem_size.mnk(), alpha, + matrix_A.host_ref(b * matrix_A.capacity()), + matrix_B.host_ref(b * matrix_B.capacity()), beta, + matrix_C_reference.host_ref(b * matrix_C_computed.capacity())); + + bool passed = cutlass::reference::host::TensorEquals( + matrix_C_computed.host_view(b * matrix_C_computed.capacity()), + matrix_C_reference.host_view(b * matrix_C_reference.capacity())); + + EXPECT_TRUE(passed) + //<< "A:\n" << matrix_A.host_view() << "\n" + //<< "B:\n" << matrix_B.host_view() << "\n" + << "Batch: " << b << "\n" + << "Reference:\n" + << matrix_C_reference.host_view(b * matrix_C_reference.capacity()) + << "\n" + << "Computed:\n" + << matrix_C_computed.host_view(b * matrix_C_computed.capacity()) + << "\n"; + } + } +} + +template +void batched_gemv_kernel_perf_test(cutlass::gemm::BatchedGemmCoord problem_size, + ElementCD_ alpha = ElementCD_(1), + ElementCD_ beta = ElementCD_(0), + int iter = 50) +{ + batched_gemv_kernel_test(problem_size, alpha, beta, true, iter); +} + +} // namespace threadblock +} // namespace kernel +} // namespace test diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/thread/host/testbed_host.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/thread/host/testbed_host.h new file mode 100644 index 0000000000000000000000000000000000000000..6e3d6ab079d44345f2f55f4126ba3efc1eba47cb --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/thread/host/testbed_host.h @@ -0,0 +1,232 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit tests for thread-level GEMM +*/ + +#pragma once + +#include "cutlass/gemm/thread/mma.h" +#include "cutlass/layout/vector.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +namespace test { +namespace gemm { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Thread-level matrix multiply-accumulate +template +void kernel( + typename Mma::ElementC *D, + typename Mma::ElementA const *A, + typename Mma::ElementB const *B, + typename Mma::ElementC const *C) { + + auto ptr_D = reinterpret_cast *>(D); + auto ptr_A = reinterpret_cast const *>(A); + auto ptr_B = reinterpret_cast const *>(B); + auto ptr_C = reinterpret_cast const *>(C); + + Mma mma; + + auto a = *ptr_A; + auto b = *ptr_B; + auto c = *ptr_C; + + using Btype = typename Mma::ElementB; + cutlass::Array d; + + mma(d, a, b, c); + + *ptr_D = d; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape, + /// Data type of A elements + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Data type of B elements + typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Element type of C matrix + typename ElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC +> +struct Testbed { + + /// Thread-level matrix multiply-accumulate operator + using Mma = cutlass::gemm::thread::Mma< + Shape, + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC + >; + + // + // Data members + // + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed() { + + tensor_A.reset(cutlass::make_Coord(Shape::kM, Shape::kK), false); + tensor_B.reset(cutlass::make_Coord(Shape::kK, Shape::kN), false); + tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + } + + /// Runs the test + bool run() { + + // + // initialize device memory + // + + cutlass::reference::host::detail::RandomUniformFunc< ElementA > tfill_rand_func( + 0, // seed + 10, // max + 0, // min + 0); // bits after decimal + + cutlass::reference::host::detail::TensorFillRandomUniformFunc< ElementA, LayoutA > tfill_rand( + tensor_A.host_view(), + tfill_rand_func); + + for (auto i=0; i< Shape::kM; i++) + for (auto j=0; j< Shape::kK; j++) + tfill_rand(cutlass::make_Coord(i,j)); + + cutlass::reference::host::BlockFillSequential( + tensor_B.host_data(), + tensor_B.capacity(), + ElementB(1), + ElementB(2) + ); + + cutlass::reference::host::TensorFill( + tensor_C.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_computed.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_reference.host_view(), + ElementC(0) + ); + + + // Host side call + kernel( + tensor_D_computed.host_data(), + tensor_A.host_data(), + tensor_B.host_data(), + tensor_C.host_data()); + + // + // Reference implementation + // + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm( + {Shape::kM, Shape::kN, Shape::kK}, + ElementC(1), + tensor_A.host_ref(), + tensor_B.host_ref(), + ElementC(0), + tensor_D_reference.host_ref() + ); + + // + // Verify equivalence + // + + // compare + bool passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view() + ); + + EXPECT_TRUE(passed) + << "A:\n" << tensor_A.host_view() << "\n\n" + << "B:\n" << tensor_B.host_view() << "\n\n" + << "C:\n" << tensor_C.host_view() << "\n\n" + << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" + << "Computed:\n" << tensor_D_computed.host_view() << std::endl; + + + return passed; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace gemm +} // namespace test diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/thread/testbed.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/thread/testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..8d34d7992b57cefa0eaf7300a5e1fb49f41a93e2 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/thread/testbed.h @@ -0,0 +1,236 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit tests for thread-level GEMM +*/ + +#pragma once + +#include "cutlass/gemm/thread/mma.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +namespace test { +namespace gemm { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Thread-level matrix multiply-accumulate +template +__global__ void kernel( + typename Mma::ElementC *D, + typename Mma::ElementA const *A, + typename Mma::ElementB const *B, + typename Mma::ElementC const *C) { + + auto ptr_D = reinterpret_cast *>(D); + auto ptr_A = reinterpret_cast const *>(A); + auto ptr_B = reinterpret_cast const *>(B); + auto ptr_C = reinterpret_cast const *>(C); + + Mma mma; + + auto a = *ptr_A; + auto b = *ptr_B; + auto c = *ptr_C; + + cutlass::Array d; + + mma(d, a, b, c); + + *ptr_D = d; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape, + /// Data type of A elements + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Data type of B elements + typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Element type of C matrix + typename ElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC +> +struct Testbed { + + /// Thread-level matrix multiply-accumulate operator + using Mma = cutlass::gemm::thread::Mma< + Shape, + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC + >; + + // + // Data members + // + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed() { + + tensor_A.reset(cutlass::make_Coord(Shape::kM, Shape::kK)); + tensor_B.reset(cutlass::make_Coord(Shape::kK, Shape::kN)); + tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + } + + /// Runs the test + bool run() { + + // + // initialize device memory + // + + cutlass::reference::host::BlockFillSequential( + tensor_A.host_data(), + tensor_A.capacity() + ); + + cutlass::reference::host::BlockFillSequential( + tensor_B.host_data(), + tensor_B.capacity(), + ElementB(1), + ElementB(2) + ); + + cutlass::reference::host::TensorFill( + tensor_C.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_computed.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_reference.host_view(), + ElementC(0) + ); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + + // launch kernel + kernel<<< dim3(1, 1), dim3(1, 1, 1) >>>( + tensor_D_computed.device_data(), + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data()); + + // verify no errors + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); + if (result != cudaSuccess) { + return false; + } + + tensor_D_computed.sync_host(); + + // + // Reference implementation + // + + //tensor_D_reference.fill(tensor_C.host_view()); + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm( + {Shape::kM, Shape::kN, Shape::kK}, + ElementC(1), + tensor_A.host_ref(), + tensor_B.host_ref(), + ElementC(0), + tensor_D_reference.host_ref() + ); + + // + // Verify equivalence + // + + // compare + bool passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view() + ); + + EXPECT_TRUE(passed) + << "A:\n" << tensor_A.host_view() << "\n\n" + << "B:\n" << tensor_B.host_view() << "\n\n" + << "C:\n" << tensor_C.host_view() << "\n\n" + << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" + << "Computed:\n" << tensor_D_computed.host_view() << std::endl; + + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace gemm +} // namespace test diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..1f3bc8cf114d7eb2ac00bd19ae92c984558b7228 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h @@ -0,0 +1,435 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit testbed for kernel-level GEMM +*/ + +#pragma once + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/core_io.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/host_reorder.h" +#include "cutlass/util/host_uncompress.h" + +namespace test { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template +__global__ void kernel_multistage_mma_sparse(cutlass::gemm::GemmCoord problem_size, + typename Mma::IteratorA::Params params_A, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::Params params_B, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::ElementC *ptr_C, + typename Mma::LayoutC::Stride::Index ldc, + typename Mma::IteratorE::Params params_E, + typename Mma::IteratorE::TensorRef ref_E) { + // Shared storage needed by threadblock-scoped matrix multiply- + // Dynamic shared memory base pointer + extern __shared__ int GemmSharedStorageBase[]; + + // Declare pointer to dynamic shared memory. + typename Mma::SharedStorage *shared_storage = + reinterpret_cast(GemmSharedStorageBase); + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), + 0}; + + cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, + tb_tile_offset.k() / Mma::kSparse}; + + cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), + tb_tile_offset.n() * Mma::Shape::kN}; + + cutlass::MatrixCoord tb_offset_E{tb_tile_offset.m() * Mma::Shape::kM, + tb_tile_offset.k() / Mma::kSparse}; + + // Compute position within threadblock + int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(params_A, ref_A.data(), + {problem_size.m(), problem_size.k() / Mma::kSparse}, + tb_thread_id, tb_offset_A); + + typename Mma::IteratorB iterator_B(params_B, ref_B.data(), + {problem_size.k(), problem_size.n()}, + tb_thread_id, tb_offset_B); + + typename Mma::IteratorE iterator_E( + params_E, ref_E.data(), + {problem_size.m(), + problem_size.k() / Mma::kSparse / Mma::kElementsPerElementE}, + tb_thread_id, tb_offset_E); + + int warp_id = __shfl_sync(0xffffffff, threadIdx.y, 0); + + // Construct thread-scoped matrix multiply + Mma mma(*shared_storage, tb_thread_id, warp_id, threadIdx.x); + + typename Mma::FragmentC accum; + + accum.clear(); + + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, iterator_E, accum); + + // Output results + typename Mma::Operator::IteratorC iterator_C({ptr_C, ldc}, threadIdx.x); + + iterator_C.add_tile_offset( + {(tb_tile_offset.m() * Mma::WarpCount::kM) + + (warp_id % Mma::WarpCount::kM), + (tb_tile_offset.n() * Mma::WarpCount::kN) + + (warp_id / Mma::WarpCount::kM)}); + + iterator_C.store(accum); +} + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Threadblock-level matrix multiply-accumulate + typename MmaCore_> +struct SparseTestbed { + /// Threadblock-level GEMM implementation + using MmaCore = MmaCore_; + using ThreadblockShape = typename MmaCore::Shape; + using WarpShape = typename MmaCore::WarpShape; + using InstructionShape = typename MmaCore::InstructionShape; + using ElementA = typename MmaCore::ElementA; + using LayoutA = typename MmaCore::LayoutA; + using ElementB = typename MmaCore::ElementB; + using LayoutB = typename MmaCore::LayoutB; + using ElementC = typename MmaCore::ElementC; + using LayoutC = typename MmaCore::LayoutC; + using ElementE = typename MmaCore::ElementE; + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using ThreadMapE = typename MmaCore::IteratorThreadMapE; + using AccessTypeA = cutlass::Array; + using AccessTypeB = cutlass::Array; + using AccessTypeE = cutlass::Array; + static int const Stages = MmaCore::kStages; + static cutlass::arch::CacheOperation::Kind const CacheOpA = + MmaCore::kCacheOpA; + static cutlass::arch::CacheOperation::Kind const CacheOpB = + MmaCore::kCacheOpB; + static cutlass::arch::CacheOperation::Kind const CacheOpE = + MmaCore::kCacheOpE; + + static int const Sparse = MmaCore::kSparse; + static int const MetaSizeInBits = MmaCore::kMetaSizeInBits; + static int const MaxID2 = MmaCore::kMaxID2; + + using LayoutE = cutlass::layout::RowMajor; + using ReorderedLayoutE = typename MmaCore::GmemLayoutE; + + static int const ElementsPerElementE = MmaCore::kElementsPerElementE; + + // Define iterators over tiles from the A operand + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; + + // Define iterators over tiles from the B operand + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; + + // Define iterators over tiles from the E operand + using IteratorE = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementE, ReorderedLayoutE, 1, ThreadMapE, AccessTypeE>; + + // Define the threadblock-scoped pipelined matrix multiply + using Mma = cutlass::gemm::threadblock::SparseMmaMultistage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + CacheOpA, IteratorB, typename MmaCore::SmemIteratorB, CacheOpB, ElementC, + LayoutC, IteratorE, typename MmaCore::SmemIteratorE, CacheOpE, + typename MmaCore::MmaPolicy, Stages>; + + // + // Data members + // + + cutlass::HostTensor matrix_A; + cutlass::HostTensor matrix_A_uncompressed; + cutlass::HostTensor matrix_B; + cutlass::HostTensor matrix_C_computed; + cutlass::HostTensor matrix_C_reference; + cutlass::HostTensor matrix_E; + cutlass::HostTensor matrix_E_reordered; + + cutlass::gemm::GemmCoord problem_size; + float alpha, beta; + + // + // Methods + // + + /// Allocates workspace in device memory + SparseTestbed(int m, int n, int k, float alpha_ = float(1), float beta_ = float(0)) + : problem_size(m, n, k), alpha(alpha_), beta(beta_) { + matrix_A.reset(cutlass::make_Coord(m, k / Sparse)); + matrix_A_uncompressed.reset(cutlass::make_Coord(m, k)); + matrix_B.reset(cutlass::make_Coord(k, n)); + matrix_C_computed.reset(cutlass::make_Coord(m, n)); + matrix_C_reference.reset(cutlass::make_Coord(m, n), false); + matrix_E.reset(cutlass::make_Coord(m, k / Sparse / ElementsPerElementE)); + matrix_E_reordered.reset( + cutlass::make_Coord(m, k / Sparse / ElementsPerElementE)); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + return true; + } + + /// Runs the test + bool run( + dim3 grid, dim3 block, + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_E = cutlass::Distribution::Uniform) { + + // Waive the test + if (!sufficient()) { + return true; + } + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_A.host_view(), seed, scope_max, scope_min, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), + matrix_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), + matrix_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); + } else { + return false; + } + + cutlass::reference::host::TensorFill(matrix_C_computed.host_view()); + + cutlass::reference::host::TensorFill(matrix_C_reference.host_view()); + + if (init_E == cutlass::Distribution::Uniform) { + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomSparseMeta( + matrix_E.host_view(), seed, MetaSizeInBits); + } else if (init_E == cutlass::Distribution::Identity) { + uint32_t content = (MaxID2 == 1) ? 0x44444444 : 0x4444; + cutlass::reference::host::TensorFill(matrix_E.host_view(), + (ElementE)(content)); + } else { + return false; + } + + cutlass::reorder_meta(matrix_E_reordered.host_ref(), matrix_E.host_ref(), + {problem_size.m(), problem_size.n(), + problem_size.k() / Sparse / ElementsPerElementE}); + + matrix_A.sync_device(); + matrix_B.sync_device(); + matrix_C_computed.sync_device(); + matrix_E_reordered.sync_device(); + + typename IteratorA::Params params_A(matrix_A.layout()); + typename IteratorB::Params params_B(matrix_B.layout()); + typename IteratorE::Params params_E(matrix_E_reordered.layout()); + + cudaError_t result; + + int smem_size = int(sizeof(typename Mma::SharedStorage)); + if (smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute( + test::gemm::threadblock::kernel_multistage_mma_sparse, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + if (result != cudaSuccess) { + return true; + } + + result = cudaFuncSetAttribute( + test::gemm::threadblock::kernel_multistage_mma_sparse, + cudaFuncAttributePreferredSharedMemoryCarveout, 100); + + if (result != cudaSuccess) { + return true; + } + } + + test::gemm::threadblock::kernel_multistage_mma_sparse + <<>>( + problem_size, params_A, matrix_A.device_ref(), params_B, + matrix_B.device_ref(), matrix_C_computed.device_data(), + matrix_C_computed.layout().stride(0), params_E, + matrix_E_reordered.device_ref()); + + // + // Check error code + // + + result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) + << " kernel error: " << cudaGetErrorString(result); + + matrix_C_computed.sync_host(); + + cutlass::uncompress(matrix_A_uncompressed.host_ref(), matrix_A.host_ref(), + matrix_E.host_ref(), problem_size.m(), + problem_size.k()); + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm(problem_size, ElementC(alpha), + matrix_A_uncompressed.host_view(), matrix_B.host_view(), + ElementC(beta), matrix_C_reference.host_view()); + + bool passed = cutlass::reference::host::TensorEquals( + matrix_C_computed.host_view(), matrix_C_reference.host_view()); + + EXPECT_TRUE(passed); + + if (!passed && CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + + std::cout + << __FILE__ << ":" << __LINE__ << " " + << "A:\n" << matrix_A.host_view() << "\n" + << "B:\n" << matrix_B.host_view() << "\n" + << "E:\n" << matrix_E.host_view() << "\n" + << "Reference:\n" + << matrix_C_reference.host_view() << "\n" + << "Computed:\n" + << matrix_C_computed.host_view() << "\n"; + } + + EXPECT_GT(cutlass::reference::host::TensorNorm(matrix_C_reference.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(matrix_C_computed.host_view()), 0); + + return passed; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace test diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..5caaf38ace92758bbc86970d8d4ff339d87348ab --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed.h @@ -0,0 +1,372 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit testbed for kernel-level GEMM +*/ + +#pragma once + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/core_io.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +namespace test { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template +__global__ void kernel_multistage_mma(cutlass::gemm::GemmCoord problem_size, + typename Mma::IteratorA::Params params_A, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::Params params_B, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::ElementC *ptr_C, + typename Mma::LayoutC::Stride::Index ldc) { + // Shared storage needed by threadblock-scoped matrix multiply-accumulate + + // Dynamic shared memory base pointer + extern __shared__ int GemmSharedStorageBase[]; + + // Declare pointer to dynamic shared memory. + typename Mma::SharedStorage *shared_storage = + reinterpret_cast(GemmSharedStorageBase); + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), + 0}; + + cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, + tb_tile_offset.k()}; + + cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), + tb_tile_offset.n() * Mma::Shape::kN}; + + // Compute position within threadblock + int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(params_A, ref_A.data(), + {problem_size.m(), problem_size.k()}, + tb_thread_id, tb_offset_A); + + typename Mma::IteratorB iterator_B(params_B, ref_B.data(), + {problem_size.k(), problem_size.n()}, + tb_thread_id, tb_offset_B); + + int warp_id = __shfl_sync(0xffffffff, threadIdx.y, 0); + + // Construct thread-scoped matrix multiply + Mma mma(*shared_storage, tb_thread_id, warp_id, threadIdx.x); + + typename Mma::FragmentC accum; + + accum.clear(); + + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + + // Output results + typename Mma::Operator::IteratorC iterator_C({ptr_C, ldc}, threadIdx.x); + + iterator_C.add_tile_offset( + {(tb_tile_offset.m() * Mma::WarpCount::kM) + + (warp_id % Mma::WarpCount::kM), + (tb_tile_offset.n() * Mma::WarpCount::kN) + + (warp_id / Mma::WarpCount::kM)}); + + iterator_C.store(accum); +} + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Threadblock-level matrix multiply-accumulate + typename MmaCore_> +struct Testbed { + /// Threadblock-level GEMM implementation + using MmaCore = MmaCore_; + using ThreadblockShape = typename MmaCore::Shape; + using WarpShape = typename MmaCore::WarpShape; + using InstructionShape = typename MmaCore::InstructionShape; + using ElementA = typename MmaCore::ElementA; + using LayoutA = typename MmaCore::LayoutA; + using ElementB = typename MmaCore::ElementB; + using LayoutB = typename MmaCore::LayoutB; + using ElementC = typename MmaCore::ElementC; + using LayoutC = typename MmaCore::LayoutC; + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeA = cutlass::Array; + using AccessTypeB = cutlass::Array; + static int const Stages = MmaCore::kStages; + static cutlass::arch::CacheOperation::Kind const CacheOpA = + MmaCore::kCacheOpA; + static cutlass::arch::CacheOperation::Kind const CacheOpB = + MmaCore::kCacheOpB; + + // Define iterators over tiles from the A operand + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; + + // Define iterators over tiles from the B operand + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; + + // Define the threadblock-scoped pipelined matrix multiply + using Mma = cutlass::gemm::threadblock::MmaMultistage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + CacheOpA, IteratorB, typename MmaCore::SmemIteratorB, CacheOpB, ElementC, + LayoutC, typename MmaCore::MmaPolicy, Stages>; + + // + // Data members + // + + cutlass::HostTensor matrix_A; + cutlass::HostTensor matrix_B; + cutlass::HostTensor matrix_C_computed; + cutlass::HostTensor matrix_C_reference; + + cutlass::gemm::GemmCoord problem_size; + float alpha, beta; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed(int m, int n, int k, float alpha_ = float(1), float beta_ = float(0)) + : problem_size(m, n, k), alpha(alpha_), beta(beta_) { + matrix_A.reset(cutlass::make_Coord(m, k)); + matrix_B.reset(cutlass::make_Coord(k, n)); + matrix_C_computed.reset(cutlass::make_Coord(m, n)); + matrix_C_reference.reset(cutlass::make_Coord(m, n), false); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + + // + // Determine SMEM requirements and waive if not satisfied + // + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + return true; + } + + /// Runs the test + bool run( + dim3 grid, dim3 block, + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + + if (!sufficient()) { + return true; + } + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_A.host_view(), seed, scope_max, scope_min, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), + matrix_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), + matrix_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); + } else { + return false; + } + + cutlass::reference::host::TensorFill(matrix_C_computed.host_view()); + + cutlass::reference::host::TensorFill(matrix_C_reference.host_view()); + + matrix_A.sync_device(); + matrix_B.sync_device(); + matrix_C_computed.sync_device(); + + typename IteratorA::Params params_A(matrix_A.layout()); + typename IteratorB::Params params_B(matrix_B.layout()); + + cudaError_t result; + + int smem_size = int(sizeof(typename Mma::SharedStorage)); + if (smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute( + test::gemm::threadblock::kernel_multistage_mma, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + if (result != cudaSuccess) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + result = cudaFuncSetAttribute( + test::gemm::threadblock::kernel_multistage_mma, + cudaFuncAttributePreferredSharedMemoryCarveout, 100); + + if (result != cudaSuccess) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + } + + test::gemm::threadblock::kernel_multistage_mma + <<>>( + problem_size, params_A, matrix_A.device_ref(), params_B, + matrix_B.device_ref(), matrix_C_computed.device_data(), + matrix_C_computed.layout().stride(0)); + + // + // Check error code + // + + result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) + << " kernel error: " << cudaGetErrorString(result); + + matrix_C_computed.sync_host(); + + cutlass::reference::host::Gemm reference_gemm; + + reference_gemm( + problem_size, ElementC(alpha), matrix_A.host_view(), + matrix_B.host_view(), ElementC(beta), matrix_C_reference.host_view()); + + bool passed = cutlass::reference::host::TensorEquals( + matrix_C_computed.host_view(), matrix_C_reference.host_view()); + + EXPECT_TRUE(passed); + + if (!passed && CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cout + << __FILE__ << ":" << __LINE__ << " " + << "A:\n" << matrix_A.host_view() << "\n" + << "B:\n" << matrix_B.host_view() << "\n" + << "Reference:\n" + << matrix_C_reference.host_view() << "\n" + << "Computed:\n" + << matrix_C_computed.host_view() << "\n"; + } + + EXPECT_GT(cutlass::reference::host::TensorNorm(matrix_C_reference.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(matrix_C_computed.host_view()), 0); + + return passed; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace test diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed_slicedk.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed_slicedk.h new file mode 100644 index 0000000000000000000000000000000000000000..4e617d6327594570b1a88a5b28f2ec4d0467b534 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed_slicedk.h @@ -0,0 +1,387 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Unit testbed for kernel-level GEMM +*/ + +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" +#include "cutlass/cutlass.h" +#include "cutlass/platform/platform.h" + +namespace test { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void kernel_multistage_mma(cutlass::gemm::GemmCoord problem_size, + typename Mma::IteratorA::Params params_A, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::Params params_B, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::ElementC **ptr_C, + typename Mma::LayoutC::Stride::Index ldc) { + // Shared storage needed by threadblock-scoped matrix multiply-accumulate + + // Dynamic shared memory base pointer + extern __shared__ int GemmSharedStorageBase[]; + + // Declare pointer to dynamic shared memory. + typename Mma::SharedStorage *shared_storage = + reinterpret_cast(GemmSharedStorageBase); + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), + 0}; + + cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, + tb_tile_offset.k()}; + + cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), + tb_tile_offset.n() * Mma::Shape::kN}; + + // Compute position within threadblock + int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(params_A, ref_A.data(), + {problem_size.m(), problem_size.k()}, + tb_thread_id, tb_offset_A); + + typename Mma::IteratorB iterator_B(params_B, ref_B.data(), + {problem_size.k(), problem_size.n()}, + tb_thread_id, tb_offset_B); + + int warp_id = __shfl_sync(0xffffffff, threadIdx.y, 0); + int lane_id = threadIdx.x; + + int partitionsK_idx = warp_id / (Mma::WarpCount::kM * Mma::WarpCount::kN); + + // Construct thread-scoped matrix multiply + Mma mma(*shared_storage, tb_thread_id, warp_id, threadIdx.x); + + typename Mma::FragmentC accum; + + accum.clear(); + + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + + // Output results + typename Mma::Operator::IteratorC iterator_C({ptr_C[partitionsK_idx], ldc}, lane_id); + + int warp_idx_mn = warp_id % (Mma::WarpCount::kM * Mma::WarpCount::kN); + iterator_C.add_tile_offset( + {(tb_tile_offset.m() * Mma::WarpCount::kM) + + (warp_idx_mn % Mma::WarpCount::kM), + (tb_tile_offset.n() * Mma::WarpCount::kN) + + (warp_idx_mn / Mma::WarpCount::kM)}); + + iterator_C.store(accum); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Threadblock-level matrix multiply-accumulate + typename MmaCore_> +struct Testbed { + /// Threadblock-level GEMM implementation + using MmaCore = MmaCore_; + using ThreadblockShape = typename MmaCore::Shape; + using WarpShape = typename MmaCore::WarpShape; + using InstructionShape = typename MmaCore::InstructionShape; + using ElementA = typename MmaCore::ElementA; + using LayoutA = typename MmaCore::LayoutA; + using ElementB = typename MmaCore::ElementB; + using LayoutB = typename MmaCore::LayoutB; + using ElementC = typename MmaCore::ElementC; + using LayoutC = typename MmaCore::LayoutC; + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeA = cutlass::Array; + using AccessTypeB = cutlass::Array; + static int const Stages = MmaCore::kStages; + static cutlass::arch::CacheOperation::Kind const CacheOpA = + MmaCore::kCacheOpA; + static cutlass::arch::CacheOperation::Kind const CacheOpB = + MmaCore::kCacheOpB; + + // Define iterators over tiles from the A operand + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; + + // Define iterators over tiles from the B operand + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; + + // Define the threadblock-scoped pipelined matrix multiply + using Mma = cutlass::gemm::threadblock::MmaMultistage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, CacheOpA, + IteratorB, typename MmaCore::SmemIteratorB, CacheOpB, ElementC, LayoutC, + typename MmaCore::MmaPolicy, Stages>; + + static int const kPartitionsK = MmaCore::MmaPolicy::kPartitionsK; + + // + // Data members + // + + cutlass::HostTensor matrix_A; + cutlass::HostTensor matrix_B; + cutlass::HostTensor matrix_C_computed[kPartitionsK]; + cutlass::HostTensor matrix_C_reference; + cutlass::HostTensor matrix_C_pointers; + + cutlass::gemm::GemmCoord problem_size; + float alpha, beta; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed(int m, int n, int k, float alpha_ = float(1), float beta_ = float(0)) + : problem_size(m, n, k), alpha(alpha_), beta(beta_) { + matrix_A.reset(cutlass::make_Coord(m, k)); + matrix_B.reset(cutlass::make_Coord(k, n)); + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + matrix_C_computed[k].reset(cutlass::make_Coord(m, n)); + + matrix_C_reference.reset(cutlass::make_Coord(m, n), false); + matrix_C_pointers.reset(cutlass::Coord<1>(kPartitionsK)); + } + + /// Runs the test + bool run( + dim3 grid, dim3 block, + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_A.host_view(), seed, scope_max, scope_min, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), + matrix_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), + matrix_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); + } else { + return false; + } + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + cutlass::reference::host::TensorFill(matrix_C_computed[k].host_view()); + + cutlass::reference::host::TensorFill(matrix_C_reference.host_view()); + + matrix_A.sync_device(); + matrix_B.sync_device(); + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + matrix_C_computed[k].sync_device(); + + typename IteratorA::Params params_A(matrix_A.layout()); + typename IteratorB::Params params_B(matrix_B.layout()); + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + matrix_C_pointers.at(cutlass::Coord<1>(k)) = matrix_C_computed[k].device_data(); + + matrix_C_pointers.sync_device(); + + cudaError_t result; + + int smem_size = int(sizeof(typename Mma::SharedStorage)); + if (smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute( + test::gemm::threadblock::kernel_multistage_mma, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + EXPECT_EQ(result, cudaSuccess) + << " cudaFuncSetAttribute " + "cudaFuncAttributeMaxDynamicSharedMemorySize error: " + << cudaGetErrorString(result); + + result = cudaFuncSetAttribute( + test::gemm::threadblock::kernel_multistage_mma, + cudaFuncAttributePreferredSharedMemoryCarveout, 100); + + EXPECT_EQ(result, cudaSuccess) + << " cudaFuncSetAttribute " + "cudaFuncAttributePreferredSharedMemoryCarveout error: " + << cudaGetErrorString(result); + } + + test::gemm::threadblock::kernel_multistage_mma<<>>( + problem_size, params_A, matrix_A.device_ref(), params_B, + matrix_B.device_ref(), matrix_C_pointers.device_data(), + matrix_C_computed[0].layout().stride(0)); + + // + // Check error code + // + + result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) + << " kernel error: " << cudaGetErrorString(result); + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + matrix_C_computed[k].sync_host(); + + // TODO: this is temporary. it will be removed after slicing can de + // reduction + // + // Reduce matrix_C_computed + // + CUTLASS_PRAGMA_UNROLL + for(int k = 1; k < kPartitionsK; k++) { + CUTLASS_PRAGMA_UNROLL + for(int m = 0; m < matrix_C_computed[0].extent().row(); m++){ + CUTLASS_PRAGMA_UNROLL + for(int n = 0; n < matrix_C_computed[0].extent().column(); n++){ + matrix_C_computed[0].at({m, n}) += matrix_C_computed[k].at({m, n}); + } + } + } + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm( + problem_size, ElementC(alpha), matrix_A.host_view(), + matrix_B.host_view(), ElementC(beta), matrix_C_reference.host_view()); + + bool passed = cutlass::reference::host::TensorEquals( + matrix_C_computed[0].host_view(), matrix_C_reference.host_view()); + + EXPECT_TRUE(passed); + + if (!passed) { + std::ofstream output("mma_multistage_testbed_errors.txt"); + + output + << "A:\n" << matrix_A.host_view() << "\n" + << "B:\n" << matrix_B.host_view() << "\n" + << "Reference:\n" + << matrix_C_reference.host_view() << "\n" + << "Computed:\n" + << matrix_C_computed[0].host_view() << "\n"; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace test diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..7eb62f9a39fe4472f77446efc591267001758c58 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed.h @@ -0,0 +1,353 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit testbed for kernel-level GEMM +*/ + +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" +#include "cutlass/cutlass.h" +#include "cutlass/platform/platform.h" + +namespace test { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void kernel_mma(cutlass::gemm::GemmCoord problem_size, + typename Mma::IteratorA::Params params_A, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::Params params_B, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::ElementC *ptr_C, + typename Mma::LayoutC::Stride::Index ldc) { + // Shared storage needed by threadblock-scoped matrix multiply-accumulate + __shared__ typename Mma::SharedStorage shared_storage; + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), + 0}; + + cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, + tb_tile_offset.k()}; + + cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), + tb_tile_offset.n() * Mma::Shape::kN}; + + // Compute position within threadblock + int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(params_A, ref_A.data(), + {problem_size.m(), problem_size.k()}, + tb_thread_id, tb_offset_A); + + typename Mma::IteratorB iterator_B(params_B, ref_B.data(), + {problem_size.k(), problem_size.n()}, + tb_thread_id, tb_offset_B); + + int warp_id = threadIdx.y; + int lane_id = threadIdx.x; + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage, tb_thread_id, warp_id, threadIdx.x); + + typename Mma::FragmentC accum; + + accum.clear(); + + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + + // Output results + typename Mma::Operator::IteratorC iterator_C({ptr_C, ldc}, lane_id); + + iterator_C.add_tile_offset( + {(tb_tile_offset.m() * Mma::WarpCount::kM) + + (warp_id % Mma::WarpCount::kM), + (tb_tile_offset.n() * Mma::WarpCount::kN) + + (warp_id / Mma::WarpCount::kM)}); + + iterator_C.store(accum); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Threadblock-level matrix multiply-accumulate + typename MmaCore_, + /// Number of stages + int Stages = 2> +struct Testbed { + /// Threadblock-level GEMM implementation + using MmaCore = MmaCore_; + using ThreadblockShape = typename MmaCore::Shape; + using WarpShape = typename MmaCore::WarpShape; + using InstructionShape = typename MmaCore::InstructionShape; + using ElementA = typename MmaCore::ElementA; + using LayoutA = typename MmaCore::LayoutA; + using ElementB = typename MmaCore::ElementB; + using LayoutB = typename MmaCore::LayoutB; + using ElementC = typename MmaCore::ElementC; + using LayoutC = typename MmaCore::LayoutC; + static const int kStages = Stages; + + // Define iterators over tiles from the A operand + static const bool use_idp4a = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value; + + static const bool transposeA = cutlass::platform::is_same< LayoutA, cutlass::layout::ColumnMajor >::value; + static const bool transposeB = cutlass::platform::is_same< LayoutB, cutlass::layout::RowMajor >::value; + + using IteratorA = typename cutlass::platform::conditional< use_idp4a, + cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< + cutlass::MatrixShape, + ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, transposeA> , + + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA> + >::type; + + // Define iterators over tiles from the B operand + using IteratorB = typename cutlass::platform::conditional< use_idp4a, + cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< + cutlass::MatrixShape, + ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, transposeB> , + + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB> + >::type; + + // Define MmaPipeline Single Stage + using MmaPipelineSingleStage = cutlass::gemm::threadblock::MmaSingleStage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + IteratorB, typename MmaCore::SmemIteratorB, ElementC, LayoutC, + typename MmaCore::MmaPolicy>; + + // Define MmaPipeline Two Stages + using MmaPipelineTwoStages = cutlass::gemm::threadblock::MmaPipelined< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + IteratorB, typename MmaCore::SmemIteratorB, ElementC, LayoutC, + typename MmaCore::MmaPolicy>; + + // Define the threadblock-scoped pipelined matrix multiply (Select between Single vs. Two stages) + using Mma = typename cutlass::platform::conditional<(kStages==1), MmaPipelineSingleStage, MmaPipelineTwoStages>::type; + // + // Data members + // + + cutlass::HostTensor matrix_A; + cutlass::HostTensor matrix_B; + cutlass::HostTensor matrix_C_computed; + cutlass::HostTensor matrix_C_reference; + + cutlass::gemm::GemmCoord problem_size; + float alpha, beta; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed(int m, int n, int k, float alpha_, float beta_) + : problem_size(m, n, k), alpha(alpha_), beta(beta_) { + matrix_A.reset(cutlass::make_Coord(m, k)); + matrix_B.reset(cutlass::make_Coord(k, n)); + matrix_C_computed.reset(cutlass::make_Coord(m, n)); + matrix_C_reference.reset(cutlass::make_Coord(m, n), false); + } + + bool sufficient() { + return true; + } + + /// Runs the test + bool run( + dim3 grid, dim3 block, + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_A.host_view(), seed, scope_max, scope_min, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), + matrix_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), + matrix_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); + } else { + return false; + } + + cutlass::reference::host::TensorFill(matrix_C_computed.host_view()); + + cutlass::reference::host::TensorFill(matrix_C_reference.host_view()); + + matrix_A.sync_device(); + matrix_B.sync_device(); + matrix_C_computed.sync_device(); + + typename IteratorA::Params params_A(matrix_A.layout()); + typename IteratorB::Params params_B(matrix_B.layout()); + + test::gemm::threadblock::kernel_mma<<>>( + problem_size, params_A, matrix_A.device_ref(), params_B, + matrix_B.device_ref(), matrix_C_computed.device_data(), + matrix_C_computed.layout().stride(0)); + + // + // Check error code + // + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) + << " kernel error: " << cudaGetErrorString(result) << " on device " << GetCudaDevice(); + + matrix_C_computed.sync_host(); + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm( + problem_size, ElementC(alpha), matrix_A.host_view(), + matrix_B.host_view(), ElementC(beta), matrix_C_reference.host_view()); + + bool passed = cutlass::reference::host::TensorEquals( + matrix_C_computed.host_view(), matrix_C_reference.host_view()); + + EXPECT_TRUE(passed) << "Failed on device " << GetCudaDevice(); + + if (!passed) { + std::ofstream output("mma_pipelined_testbed_errors.txt"); + + output + << "A:\n" << matrix_A.host_view() << "\n" + << "B:\n" << matrix_B.host_view() << "\n" + << "Reference:\n" + << matrix_C_reference.host_view() << "\n" + << "Computed:\n" + << matrix_C_computed.host_view() << "\n"; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace test diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed_slicedk.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed_slicedk.h new file mode 100644 index 0000000000000000000000000000000000000000..36e55b2542b2258542336a052cdd14bf4b85f78d --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed_slicedk.h @@ -0,0 +1,370 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Unit testbed for kernel-level GEMM +*/ + +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" +#include "cutlass/cutlass.h" +#include "cutlass/platform/platform.h" + +namespace test { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void kernel_mma(cutlass::gemm::GemmCoord problem_size, + typename Mma::IteratorA::Params params_A, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::Params params_B, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::ElementC **ptr_C, + typename Mma::LayoutC::Stride::Index ldc) { + // Shared storage needed by threadblock-scoped matrix multiply-accumulate + __shared__ typename Mma::SharedStorage shared_storage; + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), + 0}; + + cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, + tb_tile_offset.k()}; + + cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), + tb_tile_offset.n() * Mma::Shape::kN}; + + // Compute position within threadblock + int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(params_A, ref_A.data(), + {problem_size.m(), problem_size.k()}, + tb_thread_id, tb_offset_A); + + typename Mma::IteratorB iterator_B(params_B, ref_B.data(), + {problem_size.k(), problem_size.n()}, + tb_thread_id, tb_offset_B); + + int warp_id = threadIdx.y; + int lane_id = threadIdx.x; + + int partitionsK_idx = warp_id / (Mma::WarpCount::kM * Mma::WarpCount::kN); + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage, tb_thread_id, warp_id, threadIdx.x); + + typename Mma::FragmentC accum; + + accum.clear(); + + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + + // Output results + typename Mma::Operator::IteratorC iterator_C({ptr_C[partitionsK_idx], ldc}, lane_id); + + + int warp_idx_mn = warp_id % (Mma::WarpCount::kM * Mma::WarpCount::kN); + iterator_C.add_tile_offset( + {(tb_tile_offset.m() * Mma::WarpCount::kM) + + (warp_idx_mn % Mma::WarpCount::kM), + (tb_tile_offset.n() * Mma::WarpCount::kN) + + (warp_idx_mn / Mma::WarpCount::kM)}); + + iterator_C.store(accum); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Threadblock-level matrix multiply-accumulate + typename MmaCore_> +struct Testbed { + /// Threadblock-level GEMM implementation + using MmaCore = MmaCore_; + using ThreadblockShape = typename MmaCore::Shape; + using WarpShape = typename MmaCore::WarpShape; + using InstructionShape = typename MmaCore::InstructionShape; + using ElementA = typename MmaCore::ElementA; + using LayoutA = typename MmaCore::LayoutA; + using ElementB = typename MmaCore::ElementB; + using LayoutB = typename MmaCore::LayoutB; + using ElementC = typename MmaCore::ElementC; + using LayoutC = typename MmaCore::LayoutC; + + // Define iterators over tiles from the A operand + static const bool use_idp4a = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value; + + static const bool transposeA = cutlass::platform::is_same< LayoutA, cutlass::layout::ColumnMajor >::value; + static const bool transposeB = cutlass::platform::is_same< LayoutB, cutlass::layout::RowMajor >::value; + + using IteratorA = typename cutlass::platform::conditional< use_idp4a, + cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< + cutlass::MatrixShape, + ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, transposeA> , + + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA> + >::type; + + // Define iterators over tiles from the B operand + using IteratorB = typename cutlass::platform::conditional< use_idp4a, + cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< + cutlass::MatrixShape, + ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, transposeB> , + + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB> + >::type; + + // Define the threadblock-scoped pipelined matrix multiply + using Mma = cutlass::gemm::threadblock::MmaPipelined< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + IteratorB, typename MmaCore::SmemIteratorB, ElementC, LayoutC, + typename MmaCore::MmaPolicy>; + + static int const kPartitionsK = MmaCore::MmaPolicy::kPartitionsK; + + // + // Data members + // + + cutlass::HostTensor matrix_A; + cutlass::HostTensor matrix_B; + cutlass::HostTensor matrix_C_computed[kPartitionsK]; + cutlass::HostTensor matrix_C_reference; + cutlass::HostTensor matrix_C_pointers; + + cutlass::gemm::GemmCoord problem_size; + float alpha, beta; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed(int m, int n, int k, float alpha_, float beta_) + : problem_size(m, n, k), alpha(alpha_), beta(beta_) { + matrix_A.reset(cutlass::make_Coord(m, k)); + matrix_B.reset(cutlass::make_Coord(k, n)); + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + matrix_C_computed[k].reset(cutlass::make_Coord(m, n)); + + matrix_C_reference.reset(cutlass::make_Coord(m, n), false); + matrix_C_pointers.reset(cutlass::Coord<1>(kPartitionsK)); + } + + /// Runs the test + bool run( + dim3 grid, dim3 block, + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_A.host_view(), seed, scope_max, scope_min, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), + matrix_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), + matrix_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); + } else { + return false; + } + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + cutlass::reference::host::TensorFill(matrix_C_computed[k].host_view()); + + cutlass::reference::host::TensorFill(matrix_C_reference.host_view()); + + matrix_A.sync_device(); + matrix_B.sync_device(); + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + matrix_C_computed[k].sync_device(); + + typename IteratorA::Params params_A(matrix_A.layout()); + typename IteratorB::Params params_B(matrix_B.layout()); + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + matrix_C_pointers.at(cutlass::Coord<1>(k)) = matrix_C_computed[k].device_data(); + + matrix_C_pointers.sync_device(); + + test::gemm::threadblock::kernel_mma<<>>( + problem_size, params_A, matrix_A.device_ref(), params_B, + matrix_B.device_ref(), matrix_C_pointers.device_data(), + matrix_C_computed[0].layout().stride(0)); + + // + // Check error code + // + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) + << " kernel error: " << cudaGetErrorString(result); + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + matrix_C_computed[k].sync_host(); + + // TODO: this is temporary. it will be removed after slicing can de + // reduction + // + // Reduce matrix_C_computed + // + CUTLASS_PRAGMA_UNROLL + for(int k = 1; k < kPartitionsK; k++) { + CUTLASS_PRAGMA_UNROLL + for(int m = 0; m < matrix_C_computed[0].extent().row(); m++){ + CUTLASS_PRAGMA_UNROLL + for(int n = 0; n < matrix_C_computed[0].extent().column(); n++){ + matrix_C_computed[0].at({m, n}) += matrix_C_computed[k].at({m, n}); + } + } + } + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm( + problem_size, ElementC(alpha), matrix_A.host_view(), + matrix_B.host_view(), ElementC(beta), matrix_C_reference.host_view()); + + bool passed = cutlass::reference::host::TensorEquals( + matrix_C_computed[0].host_view(), matrix_C_reference.host_view()); + + EXPECT_TRUE(passed); + + if (!passed) { + std::ofstream output("mma_pipelined_testbed_errors.txt"); + + output + << "A:\n" << matrix_A.host_view() << "\n" + << "B:\n" << matrix_B.host_view() << "\n" + << "Reference:\n" + << matrix_C_reference.host_view() << "\n" + << "Computed:\n" + << matrix_C_computed[0].host_view() << "\n"; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace test diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_planar_complex_testbed.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_planar_complex_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..e5fdc07769726353b33c1a5da65dedfadb4ce1e7 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_planar_complex_testbed.h @@ -0,0 +1,350 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit testbed for kernel-level GEMM +*/ + +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cutlass/platform/platform.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/host_tensor_planar_complex.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/gemm_planar_complex.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void kernel_mma_planar_complex( + cutlass::gemm::GemmCoord problem_size, + typename Mma::IteratorA::Params params_A, + typename Mma::IteratorA::Element *ptr_A, + int64_t imaginary_stride_A, + typename Mma::IteratorB::Params params_B, + typename Mma::IteratorB::Element *ptr_B, + int64_t imaginary_stride_B, + typename Mma::ElementC *ptr_C, + typename Mma::LayoutC::Stride::Index ldc, int64_t imaginary_stride_C) { + + // Shared storage needed by threadblock-scoped matrix multiply-accumulate + __shared__ typename Mma::SharedStorage shared_storage; + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), + 0}; + + cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, + tb_tile_offset.k()}; + + cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), + tb_tile_offset.n() * Mma::Shape::kN}; + + // Compute position within threadblock + int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; + + // Construct iterators to A operand + typename Mma::IteratorA iterator_A_real(params_A, ptr_A, + {problem_size.m(), problem_size.k()}, + tb_thread_id, tb_offset_A); + + typename Mma::IteratorA iterator_A_imag(params_A, ptr_A + imaginary_stride_A, + {problem_size.m(), problem_size.k()}, + tb_thread_id, tb_offset_A); + + // Construct iterators to B operand + typename Mma::IteratorB iterator_B_real(params_B, ptr_B, + {problem_size.k(), problem_size.n()}, + tb_thread_id, tb_offset_B); + + typename Mma::IteratorB iterator_B_imag(params_B, ptr_B + imaginary_stride_B, + {problem_size.k(), problem_size.n()}, + tb_thread_id, tb_offset_B); + + int warp_id = threadIdx.y; + int lane_id = threadIdx.x; + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage, tb_thread_id, warp_id, threadIdx.x); + + typename Mma::FragmentC accum; + + accum.clear(); + + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A_real, iterator_A_imag, iterator_B_real, iterator_B_imag, accum); + + // Output results + typename Mma::Operator::IteratorC iterator_C({ptr_C, ldc}, lane_id); + + iterator_C.add_tile_offset( + {(tb_tile_offset.m() * Mma::WarpCount::kM) + + (warp_id % Mma::WarpCount::kM), + (tb_tile_offset.n() * Mma::WarpCount::kN) + + (warp_id / Mma::WarpCount::kM)}); + + iterator_C.store(accum.real); + + iterator_C.store_with_pointer_offset(accum.imag, imaginary_stride_C); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Threadblock-level matrix multiply-accumulate + typename Mma_> +struct TestbedPlanarComplex { + + using Mma = Mma_; + using ThreadblockShape = typename Mma::Shape; + using IteratorA = typename Mma::IteratorA; + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using IteratorB = typename Mma::IteratorB; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Mma::ElementC; + using ElementAccumulator = typename Mma::ElementC; + using LayoutC = typename Mma::LayoutC; + using ThreadMapA = typename Mma::IteratorA::ThreadMap; + using ThreadMapB = typename Mma::IteratorB::ThreadMap; + using AccessTypeA = cutlass::Array; + using AccessTypeB = cutlass::Array; + static int const Stages = Mma::kStages; + static cutlass::arch::CacheOperation::Kind const CacheOpA = + Mma::kCacheOpA; + static cutlass::arch::CacheOperation::Kind const CacheOpB = + Mma::kCacheOpB; + + // + // Data members + // + + cutlass::HostTensorPlanarComplex matrix_A; + cutlass::HostTensorPlanarComplex matrix_B; + cutlass::HostTensorPlanarComplex matrix_C_computed; + cutlass::HostTensorPlanarComplex matrix_C_reference; + + cutlass::gemm::GemmCoord problem_size; + + // + // Methods + // + + /// Allocates workspace in device memory + TestbedPlanarComplex(int m, int n, int k) + : problem_size(m, n, k) { + + matrix_A.reset(cutlass::make_Coord(m, k)); + matrix_B.reset(cutlass::make_Coord(k, n)); + matrix_C_computed.reset(cutlass::make_Coord(m, n)); + matrix_C_reference.reset(cutlass::make_Coord(m, n), false); + } + + /// Runs the test + bool run( + dim3 grid, dim3 block, + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_A.host_view(), seed, scope_max, scope_min, 0); + + } else if (init_A == cutlass::Distribution::Sequential) { + + for (int i = 0; i < matrix_A.capacity() * 2; ++i) { + matrix_A.host_data()[i] = cutlass::half_t(float(i % 5) - 2); + } + /* + cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), + matrix_A.capacity() * 2); + */ + } else if (init_A == cutlass::Distribution::Identity) { + //cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); + + + } else if (init_B == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), + matrix_B.capacity() * 2); + + for (int i = 0; i < matrix_B.capacity() * 2; ++i) { + matrix_B.host_data()[i] = cutlass::half_t(float((i + 3) % 5) - 2); + } + + + } else if (init_B == cutlass::Distribution::Identity) { + + //cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); + + } else { + return false; + } + + matrix_A.sync_device(); + matrix_B.sync_device(); + matrix_C_computed.sync_device(); + + typename IteratorA::Params params_A(matrix_A.layout()); + typename IteratorB::Params params_B(matrix_B.layout()); + + test::gemm::threadblock::kernel_mma_planar_complex<<>>( + problem_size, + params_A, + matrix_A.device_data(), + matrix_A.imaginary_stride(), + params_B, + matrix_B.device_data(), + matrix_B.imaginary_stride(), + matrix_C_computed.device_data(), + matrix_C_computed.layout().stride(0), + matrix_C_computed.imaginary_stride() + ); + + + // + // Check error code + // + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) + << " kernel error: " << cudaGetErrorString(result); + + matrix_C_computed.sync_host(); + + cutlass::reference::host::GemmPlanarComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator + >( + problem_size, + cutlass::complex(ElementAccumulator(1)), + matrix_A.host_ref(), + Mma::kTransformA, + matrix_B.host_ref(), + Mma::kTransformB, + cutlass::complex(ElementAccumulator(0)), + matrix_C_reference.host_ref(), + matrix_C_reference.host_ref() + ); + + bool passed = cutlass::reference::host::TensorEquals( + matrix_C_computed.host_view(), + matrix_C_reference.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + std::ofstream output("mma_pipelined_testbed_errors.txt"); + + output + << "A:\n" << matrix_A.host_view() << "\n" + << "B:\n" << matrix_B.host_view() << "\n" + << "Reference:\n" + << matrix_C_reference.host_view() << "\n" + << "Computed:\n" + << matrix_C_computed.host_view() << "\n"; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace test diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/warp/testbed.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/warp/testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..921d1abdc40c2040104815cfffb8b2ea32384136 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/gemm/warp/testbed.h @@ -0,0 +1,1543 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit tests for thread-level GEMM +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/numeric_types.h" +#include "cutlass/subbyte_reference.h" +#include "cutlass/platform/platform.h" +#include "cutlass/arch/arch.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/gemm_complex.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/host_reorder.h" +#include "cutlass/util/host_uncompress.h" + +namespace test { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Test kernel +template +__global__ void kernel( + typename Mma::ElementC *output_C, + typename Mma::ElementA const *input_A, + typename Mma::ElementB const *input_B, + typename Mma::ElementC const *input_C, + int iterations = 1) { + + // Use AlignedBuffer to store trivially copyable objects in unions and __shared__ buffers. + __shared__ cutlass::AlignedBuffer< + typename Mma::ElementA, ThreadblockShape::kM * ThreadblockShape::kK> smem_buffer_A; + + __shared__ cutlass::AlignedBuffer< + typename Mma::ElementB, ThreadblockShape::kN * ThreadblockShape::kK> smem_buffer_B; + + if (threadIdx.x == 0) { + typename Mma::ElementA *smem_ptr_A = smem_buffer_A.data(); + #pragma unroll 1 + for (size_t i = 0; i < smem_buffer_A.size(); ++i) { + cutlass::ReferenceFactory::get(smem_ptr_A, i) = + cutlass::ReferenceFactory::type>::get(input_A, i); + } + + typename Mma::ElementB *smem_ptr_B = smem_buffer_B.data(); + #pragma unroll 1 + for (size_t i = 0; i < smem_buffer_B.size(); ++i) { + cutlass::ReferenceFactory::get(smem_ptr_B, i) = + cutlass::ReferenceFactory::type>::get(input_B, i); + } + } + + __syncthreads(); + + // + // Construct warp-level matrix product + // + + using FragmentA = typename Mma::FragmentA; + using FragmentB = typename Mma::FragmentB; + using FragmentC = typename Mma::FragmentC; + + typename Mma::LayoutA layout_A = Mma::LayoutA::packed({ThreadblockShape::kM, ThreadblockShape::kK}); + typename Mma::LayoutB layout_B = Mma::LayoutB::packed({ThreadblockShape::kK, ThreadblockShape::kN}); + typename Mma::LayoutC layout_C = Mma::LayoutC::packed({Mma::Shape::kM, Mma::Shape::kN}); + + typename Mma::IteratorA iter_A({smem_buffer_A.data(), layout_A}, cutlass::arch::LaneId()); + + typename Mma::IteratorB iter_B({smem_buffer_B.data(), layout_B}, cutlass::arch::LaneId()); + + FragmentA frag_A; + FragmentB frag_B; + + FragmentC accum; + + Mma mma; + + accum.clear(); + + CUTLASS_PRAGMA_NO_UNROLL + for (int iter = 0; iter < iterations; ++iter) { // place in loop that is not unrolled + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < ThreadblockShape::kK; + k += Mma::Policy::MmaShape::kK) { + iter_A.load(frag_A); + iter_B.load(frag_B); + + ++iter_A; + ++iter_B; + + mma(accum, frag_A, frag_B, accum); + } + } + + typename Mma::IteratorC iter_C({output_C, layout_C}, cutlass::arch::LaneId()); + + iter_C.store(accum); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Warp-level matrix multiply-accumulate + typename Mma_, + /// Size of threadblock-scoped shape used to store SMEM + typename ThreadblockShape_, + /// The inner product operation performed by GEMM + typename Operator_ = cutlass::arch::OpMultiplyAdd +> +struct Testbed { + + /// Thread-level matrix multiply-accumulate operator + using Mma = Mma_; + using ThreadblockShape = ThreadblockShape_; + using Operator = Operator_; + + using Shape = typename Mma::Shape; + using ElementA = typename Mma::ElementA; + using LayoutA = typename Mma::LayoutA; + using ElementB = typename Mma::ElementB; + using LayoutB = typename Mma::LayoutB; + using ElementC = typename Mma::ElementC; + using LayoutC = typename Mma::LayoutC; + + // + // Data members + // + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed() { + + tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); + tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); + tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.major == 9) { + // NVIDIA Hopper drops support for several data types + if ( + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8) { + + return false; + } + } + + return true; + } + + + /// Runs the test + bool run( + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + + if (!sufficient()) { + return true; + } + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + + cutlass::reference::host::BlockFillRandomUniform(tensor_A.host_data(), + tensor_A.capacity(), seed, scope_max, scope_min, 0); + + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), + tensor_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + + cutlass::reference::host::BlockFillRandomUniform(tensor_B.host_data(), + tensor_B.capacity(), seed, scope_max, scope_min, 0); + + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), + tensor_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); + } else { + return false; + } + + cutlass::reference::host::TensorFill( + tensor_C.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_computed.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_reference.host_view(), + ElementC(0) + ); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + + // launch kernel + kernel<<< dim3(1, 1), dim3(32, 1, 1) >>>( + tensor_D_computed.device_data(), + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data()); + + // verify no errors + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); + if (result != cudaSuccess) { + return false; + } + + tensor_D_computed.sync_host(); + + // + // Reference implementation + // + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm( + {Shape::kM, Shape::kN, ThreadblockShape::kK}, + ElementC(1), + tensor_A.host_ref(), + tensor_B.host_ref(), + ElementC(0), + tensor_D_reference.host_ref() + ); + + // + // Verify equivalence + // + + // compare + bool passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + + cutlass::TensorView tensor_A_physical( + tensor_A.host_data(), + tensor_A.stride()[0], + tensor_A.extent()); + + cutlass::TensorView tensor_B_physical( + tensor_B.host_data(), + tensor_B.stride()[0], + tensor_B.extent()); + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "A:\n" << tensor_A.host_view() << "\n\n" + << "A(physical - stride: " << tensor_A.stride()[0] + << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "B:\n" << tensor_B.host_view() << "\n\n" + << "B(physical - stride: " << tensor_B.stride()[0] + << ", extent: " << tensor_B.extent() << "):\n" << tensor_B_physical << "\n\n"; + + std::cout + << "C:\n" << tensor_C.host_view() << "\n\n" + << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" + << "Computed:\n" << tensor_D_computed.host_view() << std::endl; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Warp-level matrix multiply-accumulate + typename Mma_, + /// Size of threadblock-scoped shape used to store SMEM + typename ThreadblockShape_ +> +struct TestbedComplex { + + /// Thread-level matrix multiply-accumulate operator + using Mma = Mma_; + using ThreadblockShape = ThreadblockShape_; + + using Shape = typename Mma::Shape; + using ElementA = typename Mma::ElementA; + using LayoutA = typename Mma::LayoutA; + using ElementB = typename Mma::ElementB; + using LayoutB = typename Mma::LayoutB; + using ElementC = typename Mma::ElementC; + using LayoutC = typename Mma::LayoutC; + + // + // Data members + // + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + TestbedComplex() { + + tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); + tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); + tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.major == 9) { + // NVIDIA Hopper drops support for several data types + if ( + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8) { + + return false; + } + } + + return true; + } + + /// Runs the test + bool run( + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + + if (!sufficient()) { + return true; + } + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform(tensor_A.host_view(), + seed, 8, -8, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), + tensor_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform(tensor_B.host_view(), + seed + 16, 8, -8, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), + tensor_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); + } else { + return false; + } + + cutlass::reference::host::TensorFill( + tensor_C.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_computed.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_reference.host_view(), + ElementC(0) + ); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + + // launch kernel + kernel<<< dim3(1, 1), dim3(32, 1, 1) >>>( + tensor_D_computed.device_data(), + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data()); + + // verify no errors + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); + if (result != cudaSuccess) { + return false; + } + + tensor_D_computed.sync_host(); + + // + // Reference implementation + // + + cutlass::reference::host::GemmComplex( + {Shape::kM, Shape::kN, ThreadblockShape::kK}, + ElementC(1), + tensor_A.host_ref(), + Mma::kTransformA, + tensor_B.host_ref(), + Mma::kTransformB, + ElementC(0), + tensor_C.host_ref(), + tensor_D_reference.host_ref() + ); + + // + // Verify equivalence + // + + // compare + bool passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + + cutlass::TensorView tensor_A_physical( + tensor_A.host_data(), + tensor_A.stride()[0], + tensor_A.extent()); + + cutlass::TensorView tensor_B_physical( + tensor_B.host_data(), + tensor_B.stride()[0], + tensor_B.extent()); + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "A:\n" << tensor_A.host_view() << "\n\n" + << "A(physical - stride: " << tensor_A.stride()[0] << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "B:\n" << tensor_B.host_view() << "\n\n" + << "B(physical - stride: " << tensor_B.stride()[0] << ", extent: " << tensor_B.extent() <<"):\n" << tensor_B_physical << "\n\n"; + + std::cout + << "C:\n" << tensor_C.host_view() << "\n\n" + << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" + << "Computed:\n" << tensor_D_computed.host_view() << std::endl; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Test kernel +template +__global__ void kernel_transform( + typename Mma::ElementC *output_C, + typename Mma::ElementA const *input_A, + typename Mma::ElementB const *input_B, + typename Mma::ElementC const *input_C, + int iterations = 1) { + + // Use AlignedBuffer to store trivially copyable objects in unions and __shared__ buffers. + __shared__ cutlass::AlignedBuffer< + typename Mma::ElementA, ThreadblockShape::kM * ThreadblockShape::kK> smem_buffer_A; + + __shared__ cutlass::AlignedBuffer< + typename Mma::ElementB, ThreadblockShape::kN * ThreadblockShape::kK> smem_buffer_B; + + if (threadIdx.x == 0) { + typename Mma::ElementA *smem_ptr_A = smem_buffer_A.data(); + #pragma unroll 1 + for (size_t i = 0; i < smem_buffer_A.size(); ++i) { + cutlass::ReferenceFactory::get(smem_ptr_A, i) = + cutlass::ReferenceFactory::type>::get(input_A, i); + } + + typename Mma::ElementB *smem_ptr_B = smem_buffer_B.data(); + #pragma unroll 1 + for (size_t i = 0; i < smem_buffer_B.size(); ++i) { + cutlass::ReferenceFactory::get(smem_ptr_B, i) = + cutlass::ReferenceFactory::type>::get(input_B, i); + } + } + + __syncthreads(); + + // + // Construct warp-level matrix product + // + + using FragmentA = typename Mma::FragmentA; + using FragmentB = typename Mma::FragmentB; + using FragmentC = typename Mma::FragmentC; + + using TransformedFragmentA = typename Mma::TransformedFragmentA; + using TransformedFragmentB = typename Mma::TransformedFragmentB; + + typename Mma::LayoutA layout_A = Mma::LayoutA::packed({ThreadblockShape::kM, ThreadblockShape::kK}); + typename Mma::LayoutB layout_B = Mma::LayoutB::packed({ThreadblockShape::kK, ThreadblockShape::kN}); + typename Mma::LayoutC layout_C = Mma::LayoutC::packed({Mma::Shape::kM, Mma::Shape::kN}); + + typename Mma::IteratorA iter_A({smem_buffer_A.data(), layout_A}, cutlass::arch::LaneId()); + + typename Mma::IteratorB iter_B({smem_buffer_B.data(), layout_B}, cutlass::arch::LaneId()); + + FragmentA loaded_frag_A; + FragmentB loaded_frag_B; + TransformedFragmentA transformed_frag_A; + TransformedFragmentB transformed_frag_B; + + FragmentC accum; + + Mma mma; + + accum.clear(); + + CUTLASS_PRAGMA_NO_UNROLL + for (int iter = 0; iter < iterations; ++iter) { // place in loop that is not unrolled + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < ThreadblockShape::kK; + k += Mma::Policy::MmaShape::kK) { + iter_A.load(loaded_frag_A); + iter_B.load(loaded_frag_B); + + ++iter_A; + ++iter_B; + + mma.transform(transformed_frag_A, transformed_frag_B, loaded_frag_A, + loaded_frag_B); + + mma(accum, transformed_frag_A, transformed_frag_B, accum); + } + } + + typename Mma::IteratorC iter_C({output_C, layout_C}, cutlass::arch::LaneId()); + + iter_C.store(accum); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Warp-level matrix multiply-accumulate + typename Mma_, + /// Size of threadblock-scoped shape used to store SMEM + typename ThreadblockShape_, + /// The innter product operation performed by GEMM + typename Operator_ = cutlass::arch::OpMultiplyAdd +> +struct TransformTestbed { + + /// Thread-level matrix multiply-accumulate operator + using Mma = Mma_; + using ThreadblockShape = ThreadblockShape_; + using Operator = Operator_; + + using Shape = typename Mma::Shape; + using ElementA = typename Mma::ElementA; + using LayoutA = typename Mma::LayoutA; + using ElementB = typename Mma::ElementB; + using LayoutB = typename Mma::LayoutB; + using ElementC = typename Mma::ElementC; + using LayoutC = typename Mma::LayoutC; + + // + // Data members + // + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + TransformTestbed() { + + tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); + tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); + tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.major == 9) { + // NVIDIA Hopper drops support for several data types + if ( + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8) { + + return false; + } + } + + return true; + } + + /// Runs the test + bool run( + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + + if (!sufficient()) { + return true; + } + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + tensor_A.host_view(), seed, scope_max, scope_min, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), + tensor_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + tensor_B.host_view(), seed + 16, scope_max, scope_min, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), + tensor_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); + } else { + return false; + } + + cutlass::reference::host::TensorFill( + tensor_C.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_computed.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_reference.host_view(), + ElementC(0) + ); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + + // launch kernel + kernel_transform<<>>( + tensor_D_computed.device_data(), tensor_A.device_data(), + tensor_B.device_data(), tensor_C.device_data()); + + // verify no errors + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); + if (result != cudaSuccess) { + return false; + } + + tensor_D_computed.sync_host(); + + // + // Reference implementation + // + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm( + {Shape::kM, Shape::kN, ThreadblockShape::kK}, + ElementC(1), + tensor_A.host_ref(), + tensor_B.host_ref(), + ElementC(0), + tensor_D_reference.host_ref() + ); + + // + // Verify equivalence + // + + // compare + bool passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + + cutlass::TensorView tensor_A_physical( + tensor_A.host_data(), + tensor_A.stride()[0], + tensor_A.extent()); + + cutlass::TensorView tensor_B_physical( + tensor_B.host_data(), + tensor_B.stride()[0], + tensor_B.extent()); + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "A:\n" << tensor_A.host_view() << "\n\n" + << "A(physical - stride: " << tensor_A.stride()[0] << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "B:\n" << tensor_B.host_view() << "\n\n" + << "B(physical - stride: " << tensor_B.stride()[0] << ", extent: " << tensor_B.extent() << "):\n" << tensor_B_physical << "\n\n"; + + std::cout + << "C:\n" << tensor_C.host_view() << "\n\n" + << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" + << "Computed:\n" << tensor_D_computed.host_view() << std::endl; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Warp-level matrix multiply-accumulate + typename Mma_, + /// Size of threadblock-scoped shape used to store SMEM + typename ThreadblockShape_ +> +struct TransformedTestbedComplex { + + /// Thread-level matrix multiply-accumulate operator + using Mma = Mma_; + using ThreadblockShape = ThreadblockShape_; + + using Shape = typename Mma::Shape; + using ElementA = typename Mma::ElementA; + using LayoutA = typename Mma::LayoutA; + using ElementB = typename Mma::ElementB; + using LayoutB = typename Mma::LayoutB; + using ElementC = typename Mma::ElementC; + using LayoutC = typename Mma::LayoutC; + + // + // Data members + // + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + TransformedTestbedComplex() { + + tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); + tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); + tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.major == 9) { + // NVIDIA Hopper drops support for several data types + if ( + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8) { + + return false; + } + } + + return true; + } + + /// Runs the test + bool run( + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + + if (!sufficient()) { + return true; + } + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform(tensor_A.host_view(), + seed, 8, -8, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), + tensor_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform(tensor_B.host_view(), + seed + 16, 8, -8, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), + tensor_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); + } else { + return false; + } + + cutlass::reference::host::TensorFill( + tensor_C.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_computed.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_reference.host_view(), + ElementC(0) + ); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + + // launch kernel + kernel_transform<<< dim3(1, 1), dim3(32, 1, 1) >>>( + tensor_D_computed.device_data(), + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data()); + + // verify no errors + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); + if (result != cudaSuccess) { + return false; + } + + tensor_D_computed.sync_host(); + + // + // Reference implementation + // + + cutlass::reference::host::GemmComplex( + {Shape::kM, Shape::kN, ThreadblockShape::kK}, + ElementC(1), + tensor_A.host_ref(), + Mma::kTransformA, + tensor_B.host_ref(), + Mma::kTransformB, + ElementC(0), + tensor_C.host_ref(), + tensor_D_reference.host_ref() + ); + + // + // Verify equivalence + // + + // compare + bool passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + + cutlass::TensorView tensor_A_physical( + tensor_A.host_data(), + tensor_A.stride()[0], + tensor_A.extent()); + + cutlass::TensorView tensor_B_physical( + tensor_B.host_data(), + tensor_B.stride()[0], + tensor_B.extent()); + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "A:\n" << tensor_A.host_view() << "\n\n" + << "A(physical - stride: " << tensor_A.stride()[0] << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "B:\n" << tensor_B.host_view() << "\n\n" + << "B(physical - stride: " << tensor_B.stride()[0] << ", extent: " << tensor_B.extent() <<"):\n" << tensor_B_physical << "\n\n"; + + std::cout + << "C:\n" << tensor_C.host_view() << "\n\n" + << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" + << "Computed:\n" << tensor_D_computed.host_view() << std::endl; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Test kernel +template +__global__ void sparse_kernel( + typename Mma::ElementC *output_C, + typename Mma::ElementA const *input_A, + typename Mma::ElementB const *input_B, + typename Mma::ElementC const *input_C, + typename Mma::ElementE const *input_E, + int iterations = 1) { + + // Use AlignedBuffer to store trivially copyable objects in unions and __shared__ buffers. + __shared__ cutlass::AlignedBuffer + smem_buffer_A; + + __shared__ cutlass::AlignedBuffer< + typename Mma::ElementB, ThreadblockShape::kN * ThreadblockShape::kK> smem_buffer_B; + + __shared__ cutlass::AlignedBuffer< + typename Mma::ElementE, Mma::Shape::kM * Mma::Shape::kK / + Mma::kSparse / Mma::kElementsPerElementE> + smem_buffer_E; + + __syncthreads(); + + if (threadIdx.x == 0) { + typename Mma::ElementA *smem_ptr_A = smem_buffer_A.data(); + #pragma unroll 1 + for (size_t i = 0; i < smem_buffer_A.size(); ++i) { + cutlass::ReferenceFactory::get(smem_ptr_A, i) = + cutlass::ReferenceFactory::type>::get(input_A, i); + } + + typename Mma::ElementB *smem_ptr_B = smem_buffer_B.data(); + #pragma unroll 1 + for (size_t i = 0; i < smem_buffer_B.size(); ++i) { + cutlass::ReferenceFactory::get(smem_ptr_B, i) = + cutlass::ReferenceFactory::type>::get(input_B, i); + } + + typename Mma::ElementE *smem_ptr_E = smem_buffer_E.data(); + #pragma unroll 1 + for (size_t i = 0; i < smem_buffer_E.size(); ++i) { + cutlass::ReferenceFactory::get(smem_ptr_E, i) = + cutlass::ReferenceFactory::type>::get(input_E, i); + } + } + + __syncthreads(); + + // + // Construct warp-level matrix product + // + + using FragmentA = typename Mma::FragmentA; + using FragmentB = typename Mma::FragmentB; + using FragmentC = typename Mma::FragmentC; + using FragmentE = typename Mma::FragmentE; + + typename Mma::LayoutA layout_A = Mma::LayoutA::packed( + {ThreadblockShape::kM, ThreadblockShape::kK / Mma::kSparse}); + typename Mma::LayoutB layout_B = + Mma::LayoutB::packed({ThreadblockShape::kK, ThreadblockShape::kN}); + typename Mma::LayoutC layout_C = Mma::LayoutC::packed({Mma::Shape::kM, Mma::Shape::kN}); + typename Mma::LayoutE layout_E = + Mma::LayoutE::packed({Mma::Shape::kM * Mma::kInterleaved, + Mma::Shape::kK / Mma::kSparse / + Mma::kElementsPerElementE / Mma::kInterleaved}); + + typename Mma::IteratorA iter_A({smem_buffer_A.data(), layout_A}, cutlass::arch::LaneId()); + + typename Mma::IteratorB iter_B({smem_buffer_B.data(), layout_B}, cutlass::arch::LaneId()); + + typename Mma::IteratorE iter_E({smem_buffer_E.data(), layout_E}, cutlass::arch::LaneId()); + + FragmentA frag_A; + FragmentB frag_B; + + FragmentC accum; + + FragmentE frag_E; + + Mma mma; + + accum.clear(); + + CUTLASS_PRAGMA_NO_UNROLL + for (int iter = 0; iter < iterations; ++iter) { // place in loop that is not unrolled + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < ThreadblockShape::kK; + k += Mma::Policy::MmaShape::kK) { + iter_A.load(frag_A); + iter_B.load(frag_B); + iter_E.load(frag_E); + + ++iter_A; + ++iter_B; + ++iter_E; + + mma(accum, frag_A, frag_B, accum, frag_E); + } + } + + typename Mma::IteratorC iter_C({output_C, layout_C}, cutlass::arch::LaneId()); + + iter_C.store(accum); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Warp-level matrix multiply-accumulate + typename Mma_, + /// Size of threadblock-scoped shape used to store SMEM + typename ThreadblockShape_, + /// The innter product operation performed by GEMM + typename Operator_ = cutlass::arch::OpMultiplyAdd +> +struct SparseTestbed { + + /// Thread-level matrix multiply-accumulate operator + using Mma = Mma_; + using ThreadblockShape = ThreadblockShape_; + using Operator = Operator_; + + using Shape = typename Mma::Shape; + using ElementA = typename Mma::ElementA; + using LayoutA = typename Mma::LayoutA; + using ElementB = typename Mma::ElementB; + using LayoutB = typename Mma::LayoutB; + using ElementC = typename Mma::ElementC; + using LayoutC = typename Mma::LayoutC; + + static int const Sparse = Mma::kSparse; + static int const MetaSizeInBits = Mma::kMetaSizeInBits; + static int const MaxID2 = Mma::kMaxID2; + static int const Interleaved = Mma::kInterleaved; + + using ElementE = typename Mma::ElementE; + + static int const ElementsPerElementE = Mma::kElementsPerElementE; + + using LayoutE = cutlass::layout::RowMajor; + using ReorderedLayoutE = + cutlass::layout::ColumnMajorInterleaved; + + // + // Data members + // + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_A_uncompressed; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + cutlass::HostTensor tensor_E; + cutlass::HostTensor tensor_E_reordered; + + // + // Methods + // + + /// Allocates workspace in device memory + SparseTestbed() { + + tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, + ThreadblockShape::kK / Sparse)); + tensor_A_uncompressed.reset( + cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); + tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); + tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + tensor_E.reset(cutlass::make_Coord( + Shape::kM, Shape::kK / Sparse / ElementsPerElementE)); + tensor_E_reordered.reset(cutlass::make_Coord( + Shape::kM, Shape::kK / Sparse / ElementsPerElementE)); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.major == 9) { + // NVIDIA Hopper drops support for several data types + if ( + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8) { + + return false; + } + } + + return true; + } + + /// Runs the test + bool run( + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_E = cutlass::Distribution::Uniform) { + + if (!sufficient()) { + return true; + } + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + tensor_A.host_view(), seed, scope_max, scope_min, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), + tensor_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + tensor_B.host_view(), seed + 16, scope_max, scope_min, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), + tensor_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); + } else { + return false; + } + + cutlass::reference::host::TensorFill( + tensor_C.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_computed.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_reference.host_view(), + ElementC(0) + ); + + if (init_E == cutlass::Distribution::Uniform) { + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomSparseMeta( + tensor_E.host_view(), seed, MetaSizeInBits); + } else if (init_E == cutlass::Distribution::Identity) { + uint32_t content = (MaxID2 == 1) ? 0x44444444 : 0x4444; + cutlass::reference::host::TensorFill(tensor_E.host_view(), + (ElementE)(content)); + } else { + return false; + } + + cutlass::reorder_meta( + tensor_E_reordered.host_ref(), tensor_E.host_ref(), + {Shape::kM, Shape::kN, Shape::kK / Sparse / ElementsPerElementE}); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + tensor_E_reordered.sync_device(); + + // launch kernel + sparse_kernel<<< dim3(1, 1), dim3(32, 1, 1) >>>( + tensor_D_computed.device_data(), + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data(), + tensor_E_reordered.device_data()); + + // verify no errors + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); + if (result != cudaSuccess) { + return false; + } + + tensor_D_computed.sync_host(); + + // + // Reference implementation + // + cutlass::uncompress(tensor_A_uncompressed.host_ref(), tensor_A.host_ref(), + tensor_E.host_ref(), Shape::kM, Shape::kK); + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm( + {Shape::kM, Shape::kN, ThreadblockShape::kK}, + ElementC(1), + tensor_A_uncompressed.host_ref(), + tensor_B.host_ref(), + ElementC(0), + tensor_D_reference.host_ref() + ); + + // + // Verify equivalence + // + + // compare + bool passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout << "A:\n" << tensor_A.host_view() << "\n\n"; + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout << "B:\n" << tensor_B.host_view() << "\n\n"; + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout << "E:\n" << tensor_E.host_view() << "\n\n"; + + std::cout + << "C:\n" << tensor_C.host_view() << "\n\n" + << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" + << "Computed:\n" << tensor_D_computed.host_view() << "\n"; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace test diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/nvrtc/cutlass/nvrtc/environment.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/nvrtc/cutlass/nvrtc/environment.h new file mode 100644 index 0000000000000000000000000000000000000000..3311e915db892466a9a4c52c82d100c2e1319966 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/nvrtc/cutlass/nvrtc/environment.h @@ -0,0 +1,43 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include "cutlass/cutlass.h" + +namespace cutlass { +namespace nvrtc { + +extern char const *kCutlassHeaders[]; +extern char const *kCutlassHeaderNames[]; +extern size_t const kCutlassHeaderCount; +} // namespace nvrtc +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/nvrtc/kernel/thread/contraction.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/nvrtc/kernel/thread/contraction.hpp new file mode 100644 index 0000000000000000000000000000000000000000..55df44379c847034ed38cfab23477331ee4a537c --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/nvrtc/kernel/thread/contraction.hpp @@ -0,0 +1,127 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cute/tensor.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" + + +namespace nvrtc { +namespace thread { + +template< + typename ElementA, typename ElementB, typename ElementC, + typename TileShape, typename ClusterShape, + bool kTransA, bool kTransB, + int RANK_M, int RANK_N, int RANK_K, int RANK_L +> +struct ContractionKernel { + +using ElementScalar = float; +using ElementAccum = float; +using EpilogueThread = cutlass::epilogue::thread::LinearCombination; + +static constexpr cute::GMMA::Major majorA = ! kTransA ? cute::GMMA::Major::MN : cute::GMMA::Major::K; +static constexpr cute::GMMA::Major majorB = ! kTransB ? cute::GMMA::Major::K : cute::GMMA::Major::MN; + +/// Kernel config +typedef int64_t stride_type; +typedef int32_t extent_type; + +static constexpr const stride_type* stride_null = nullptr; +static constexpr const extent_type* extent_null = nullptr; + +template +static constexpr +auto +make_stride_tuple(Indexable const& t, int n, int64_t init_default = 0) { + static_assert(Rank > 1); + if constexpr (IsMajor) { + return cute::transform(cute::make_seq{}, [&](auto i) { + if constexpr (i == 0) { + return cute::Int<1>{}; + } + else { + return i < n ? t[i] : init_default; + } + }); + } + else { + return cute::make_int_tuple(t, n, init_default); + } +} + +using StrideA = decltype(cute::make_stride( + make_stride_tuple(stride_null, 0, 0), + make_stride_tuple(stride_null, 0, 0), + cute::make_int_tuple(stride_null, 0, 0))); + +using StrideB = decltype(cute::make_stride( + make_stride_tuple(stride_null, 0, 0), + make_stride_tuple(stride_null, 0, 0), + cute::make_int_tuple(stride_null, 0, 0))); + +using StrideC = decltype(cute::make_stride( + cute::make_int_tuple(stride_null, 0, 0), + cute::make_int_tuple(stride_null, 0, 0), + cute::make_int_tuple(stride_null, 0, 0))); + +using ProblemShape = decltype(cute::make_shape( + cute::make_int_tuple(extent_null, 0, 0), + cute::make_int_tuple(extent_null, 0, 0), + cute::make_int_tuple(extent_null, 0, 0), + cute::make_int_tuple(extent_null, 0, 0))); + +using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, StrideA, 16 / sizeof(ElementA), + ElementB, StrideB, 16 / sizeof(ElementB), + ElementAccum, + TileShape, ClusterShape, cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecialized +>::CollectiveOp; + +using EpilogueOutputOp = cutlass::epilogue::collective::DefaultEpilogue; +using CollectiveEpilogue = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter; +using Kernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveOp, + CollectiveEpilogue>; + +}; + +} // namespace nvrtc +} // namespace thread diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/nvrtc/kernel/thread/testbed_kernel.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/nvrtc/kernel/thread/testbed_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..576f55cd868cd64c8c09c055d8b9a956e40c87ae --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/nvrtc/kernel/thread/testbed_kernel.h @@ -0,0 +1,76 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit tests for thread-level GEMM +*/ + +#pragma once + +#include "cutlass/array.h" + +namespace test { +namespace nvrtc { +namespace kernel { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Thread-level matrix multiply-accumulate +template +__global__ void testbed_kernel( + typename Mma::ElementC *D, + typename Mma::ElementA const *A, + typename Mma::ElementB const *B, + typename Mma::ElementC const *C) { + + auto ptr_D = reinterpret_cast *>(D); + auto ptr_A = reinterpret_cast const *>(A); + auto ptr_B = reinterpret_cast const *>(B); + auto ptr_C = reinterpret_cast const *>(C); + + Mma mma; + + auto a = *ptr_A; + auto b = *ptr_B; + auto c = *ptr_C; + + cutlass::Array d; + + mma(d, a, b, c); + + *ptr_D = d; +} + +} +} +} +} + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/nvrtc/stdlib/assert.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/nvrtc/stdlib/assert.h new file mode 100644 index 0000000000000000000000000000000000000000..c7e6e94691c82b2f343959421c884c8b0b06f9b4 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/nvrtc/stdlib/assert.h @@ -0,0 +1,30 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/nvrtc/stdlib/stdint.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/nvrtc/stdlib/stdint.h new file mode 100644 index 0000000000000000000000000000000000000000..5ba5432fd568af71e15b20b8cdab1571f303bcdf --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/nvrtc/stdlib/stdint.h @@ -0,0 +1,129 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +typedef char int8_t; +typedef unsigned char uint8_t; +typedef short int16_t; +typedef unsigned short uint16_t; +typedef int int32_t; +typedef unsigned int uint32_t; +typedef long long int int64_t; +typedef unsigned long long int uint64_t; + +#if defined __x86_64__ && !defined __ILP32__ +# define __WORDSIZE 64 +#else +# define __WORDSIZE 32 +#endif + + +/* Small types. */ + +/* Signed. */ +typedef signed char int_least8_t; +typedef short int int_least16_t; +typedef int int_least32_t; +#if __WORDSIZE == 64 +typedef long int int_least64_t; +#else +__extension__ +typedef long long int int_least64_t; +#endif + +/* Unsigned. */ +typedef unsigned char uint_least8_t; +typedef unsigned short int uint_least16_t; +typedef unsigned int uint_least32_t; +#if __WORDSIZE == 64 +typedef unsigned long int uint_least64_t; +#else +__extension__ +typedef unsigned long long int uint_least64_t; +#endif + + +/* Fast types. */ + +/* Signed. */ +typedef signed char int_fast8_t; +#if __WORDSIZE == 64 +typedef long int int_fast16_t; +typedef long int int_fast32_t; +typedef long int int_fast64_t; +#else +typedef int int_fast16_t; +typedef int int_fast32_t; +__extension__ +typedef long long int int_fast64_t; +#endif + +/* Unsigned. */ +typedef unsigned char uint_fast8_t; +#if __WORDSIZE == 64 +typedef unsigned long int uint_fast16_t; +typedef unsigned long int uint_fast32_t; +typedef unsigned long int uint_fast64_t; +#else +typedef unsigned int uint_fast16_t; +typedef unsigned int uint_fast32_t; +__extension__ +typedef unsigned long long int uint_fast64_t; +#endif + +/* Types for `void *' pointers. */ +#if __WORDSIZE == 64 +# ifndef __intptr_t_defined +typedef long int intptr_t; +# define __intptr_t_defined +# endif +typedef unsigned long int uintptr_t; +#else +# ifndef __intptr_t_defined +typedef int intptr_t; +# define __intptr_t_defined +# endif +typedef unsigned int uintptr_t; +#endif + + +/* Largest integral types. */ +#if __WORDSIZE == 64 +typedef long int intmax_t; +typedef unsigned long int uintmax_t; +#else +__extension__ +typedef long long int intmax_t; +__extension__ +typedef unsigned long long int uintmax_t; +#endif + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/nvrtc/thread/testbed.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/nvrtc/thread/testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..8fd6863e8fa003d3fbc4e0b498e3b9b454ade190 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/nvrtc/thread/testbed.h @@ -0,0 +1,398 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit tests for thread-level GEMM +*/ + +#pragma once + +#include +#include +#include + +#include "cutlass/gemm/thread/mma.h" +#include "../kernel/thread/testbed_kernel.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/trace.h" + +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include +#include +#include "../cutlass/nvrtc/environment.h" +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace nvrtc { +namespace thread { + +#define NVRTC_RETURN_IF_ERROR(api) \ + do { \ + nvrtcResult _result = api; \ + if (_result != NVRTC_SUCCESS) { \ + CUTLASS_TRACE_HOST("Nvrtc error: " << _result); \ + return false; \ + } \ + } while(0) + +inline const char * cuda_source_fmt = R"""( + +#include "kernel/thread/contraction.hpp" + +using Operator = %s; + +extern "C" __global__ void global_entry(__grid_constant__ Operator::Params const params) { + extern __shared__ char smem[]; + + Operator op; + op(params, smem); +} + +)"""; + +struct TestbedKernel { + static bool compile(std::string const &kernel, std::vector const &opts) { + int sz = std::snprintf(nullptr, 0, cuda_source_fmt, kernel.c_str()); + std::vector cuda_source(sz + 1); + std::snprintf(&cuda_source[0], cuda_source.size(), cuda_source_fmt, kernel.c_str()); + + nvrtcProgram program; + NVRTC_RETURN_IF_ERROR( + nvrtcCreateProgram( + &program, + cuda_source.data(), + nullptr, + static_cast(cutlass::nvrtc::kCutlassHeaderCount), + cutlass::nvrtc::kCutlassHeaders, + cutlass::nvrtc::kCutlassHeaderNames) + ); + + nvrtcResult compile_result = + nvrtcCompileProgram( + program, + static_cast(opts.size()), + opts.data()); + + size_t log_size; + NVRTC_RETURN_IF_ERROR( + nvrtcGetProgramLogSize(program, &log_size) + ); + + if (log_size > 1) { + auto log = std::make_unique(log_size); + + NVRTC_RETURN_IF_ERROR( + nvrtcGetProgramLog(program, log.get()) + ); + + std::cout << log.get() << std::endl; + } + + NVRTC_RETURN_IF_ERROR(compile_result); + + NVRTC_RETURN_IF_ERROR( + nvrtcDestroyProgram(&program) + ); + + return true; + } +}; + +/// Structure to compute the matrix product +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape, + /// Data type of A elements + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Data type of B elements + typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Element type of C matrix + typename ElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC +> +struct Testbed { + + /// Thread-level matrix multiply-accumulate operator + using Mma = cutlass::gemm::thread::Mma< + Shape, + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC + >; + + // + // Data members + // + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed() { + + tensor_A.reset(cutlass::make_Coord(Shape::kM, Shape::kK)); + tensor_B.reset(cutlass::make_Coord(Shape::kK, Shape::kN)); + tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + } + + static inline bool check_nvrtc_error(nvrtcResult error) { + if (error != NVRTC_SUCCESS) { + std::cerr << "failed to compile "; + return false; + } + return true; + } + + /// Runs the test + bool run(std::string const &gemm_traits) { + + // + // initialize device memory + // + + cutlass::reference::host::BlockFillSequential( + tensor_A.host_data(), + tensor_A.capacity() + ); + + cutlass::reference::host::BlockFillSequential( + tensor_B.host_data(), + tensor_B.capacity(), + ElementB(1), + ElementB(2) + ); + + cutlass::reference::host::TensorFill( + tensor_C.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_computed.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_reference.host_view(), + ElementC(0) + ); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + +#if 0 + // launch kernel + cutlass::gemm::kernel::testbed_kernel<<< dim3(1, 1), dim3(1, 1, 1) >>>( + tensor_D_computed.device_data(), + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data()); + +#else + // Instantiate gemm_kernel + nvrtcResult result_nvrtc; + nvrtcProgram program; + static char const *src = + "#include \"cutlass/gemm/thread/mma.h\"\n" + "#include \"cutlass/gemm/gemm.h\"\n" + "#include \"cutlass/layout/matrix.h\"\n" + "#include \"unit/nvrtc/kernel/thread/testbed_kernel.h\"\n" + ; + + std::string type_name; +#if 0 + // TODO Ideally we'd use nvrtcGetTypeName to determine the type, but it cannot resolve enum symbol names + // As altername solution we might want to implement to_string() to get the traits string. + nvrtcGetTypeName(&type_name); +#else + type_name = gemm_traits; +#endif + + result_nvrtc = nvrtcCreateProgram(&program, + src, + NULL, + (int)cutlass::nvrtc::kCutlassHeaderCount, + cutlass::nvrtc::kCutlassHeaders, + cutlass::nvrtc::kCutlassHeaderNames); + check_nvrtc_error(result_nvrtc); + + std::string gemm_kernel_instantiation = + "test::nvrtc::kernel::thread::testbed_kernel< " + type_name + " >"; + nvrtcAddNameExpression(program, gemm_kernel_instantiation.c_str()); + + const char *opts[] = {"--gpu-architecture=compute_75", + "--std=c++17", + "--include-path=/usr/local/cuda-10.1/include"}; + + result_nvrtc = nvrtcCompileProgram(program, 3, opts); + if (result_nvrtc != NVRTC_SUCCESS) { + size_t logSize; + nvrtcGetProgramLogSize(program, &logSize); + std::vector log(logSize); + nvrtcGetProgramLog(program, log.data()); + std::cout << "Compile log:" << std::endl << log.data() << std::endl; + } + if (!check_nvrtc_error(result_nvrtc)) { + assert(0); + } + + // The lowered name is the name of the template instantiation in the generated PTX code. + char const *gemm_kernel_lowered_name; + nvrtcGetLoweredName(program, gemm_kernel_instantiation.c_str(), &gemm_kernel_lowered_name); + if (!check_nvrtc_error(result_nvrtc)) { + assert(0); + } + + // Query the size of the genereated PTX so that we can allocate storage and retrieve it afterwards + size_t ptx_size; + result_nvrtc = nvrtcGetPTXSize(program, &ptx_size); + if (!check_nvrtc_error(result_nvrtc)) { + assert(0); + } + + std::vector ptx(ptx_size); + result_nvrtc = nvrtcGetPTX(program, ptx.data()); + if (!check_nvrtc_error(result_nvrtc)) { + assert(0); + } + + // we do not need the nvrtc program anymore + //nvrtcDestroyProgram(&program); + + CUmodule module; + CUresult result_cuda; + result_cuda = cuModuleLoadDataEx(&module, ptx.data(), 0, 0, 0); + if (result_cuda != CUDA_SUCCESS) { + assert(0); + } + + CUfunction kernel; + result_cuda = cuModuleGetFunction(&kernel, module, gemm_kernel_lowered_name); + if (result_cuda != CUDA_SUCCESS) { + assert(0); + } + + void* d_a = (void*)tensor_A.device_data(); + void* d_b = (void*)tensor_B.device_data(); + void* d_c = (void*)tensor_C.device_data(); + void* d_d = (void*)tensor_D_computed.device_data(); + void* args[] = { &d_d, &d_a, &d_b, &d_c }; + + // CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void** kernelParams, void** extra + result_cuda = cuLaunchKernel(kernel, 1, 1, 1, 1, 1, 1, 0, 0 /*cudaStreamDefault*/, args, 0); + if (result_cuda != CUDA_SUCCESS) { + assert(0); + } else { +} +#endif + + // verify no errors + cudaError_t result = cudaDeviceSynchronize(); + + if (result != cudaSuccess) { + std::cout << "CUDA ERROR: " << cudaGetErrorString(result); + return false; + } + + tensor_D_computed.sync_host(); + + // + // Reference implementation + // + + //tensor_D_reference.fill(tensor_C.host_view()); + + cutlass::reference::host::Gemm reference_gemm; + + reference_gemm( + {Shape::kM, Shape::kN, Shape::kK}, + ElementC(1), + tensor_A.host_ref(), + tensor_B.host_ref(), + ElementC(0), + tensor_D_reference.host_ref() + ); + + // + // Verify equivalence + // + + // compare + bool passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view() + ); + + if(!passed) std::cout + << "A:\n" << tensor_A.host_view() << "\n\n" + << "B:\n" << tensor_B.host_view() << "\n\n" + << "C:\n" << tensor_C.host_view() << "\n\n" + << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" + << "Computed:\n" << tensor_D_computed.host_view() << std::endl; + + std::cout << "passed " << passed << std::endl; + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace nvrtc +} // namespace test diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/pipeline/testbed.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/pipeline/testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..6cc2946a2c51cfb8c1971345c81c1910bd667208 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/pipeline/testbed.h @@ -0,0 +1,145 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Common Testbed file shared by Pipeline unit tests +*/ + +#include +#include +#include +#include + +#include "cutlass/util/command_line.h" +#include "../common/cutlass_unit_test.h" + +#if CUDA_12_0_SM90_FEATURES_SUPPORTED + #define CUTLASS_UNIT_TEST_PIPELINE true +#else + #define CUTLASS_UNIT_TEST_PIPELINE false +#endif + +// Command line test options +struct Options { + // + // Data Members + // + bool help; + bool verification_enabled; + int SM_count; + int clock_MHz; + + // + // Methods + // + Options(): + help(false), + verification_enabled(true), + SM_count(116), + clock_MHz(1477) + { } + + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + } + + cmd.get_cmd_line_argument("verification-enabled", verification_enabled, true); + cmd.get_cmd_line_argument("sm-count", SM_count, 116); + cmd.get_cmd_line_argument("clock", clock_MHz, 1477); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --verification-enabled= Enable/Disable verification\n" + << " --sm-count= Number of SMs on the chip\n" + << " --clock= Locked clock value in Mhz\n"; + + return out; + } +}; + +// +// Testbed +// + +template +struct Testbed { +private: + // Commandline options + Options options; + + void run_test(uint32_t const kNumIters) { + + // Run CuTe Gemm + Pipeline pipeline; + + cudaError_t result = pipeline.run(kNumIters); + + CUTE_CHECK_LAST(); + } + + +public: + Testbed(Options const &options_) : options(options_) { + int device_id = 0; + cudaDeviceProp device_prop; + CUTE_CHECK_ERROR(cudaSetDevice(device_id)); + CUTE_CHECK_ERROR(cudaGetDeviceProperties(&device_prop, device_id)); + + if (device_prop.major < 1) { + fprintf(stderr, "Device does not support CUDA.\n"); + exit(1); + } + } + + /// Run verification Gemm problem sizes + bool verification() { + + std::array kNumIters; + + for (size_t i = 0; i < kNumIters.size(); ++i) { + kNumIters[i] = static_cast( (rand() % 1000) + 1 ); + } + + for (int n : kNumIters) { + std::cout << "Stages = " << Pipeline::Stages << " kNumIters = " << n << "\n"; + run_test(n); + } + + return true; + } +}; diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/pipeline/testbed_cluster_launch_control.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/pipeline/testbed_cluster_launch_control.h new file mode 100644 index 0000000000000000000000000000000000000000..50a68a1437956c95aa4e7912e93adc8b1481c9cc --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/pipeline/testbed_cluster_launch_control.h @@ -0,0 +1,154 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Testbed file used by cluster launch control pipeline unit test +*/ + +// + +// + +#if CUDA_12_0_SM90_FEATURES_SUPPORTED + #define CUTLASS_UNIT_TEST_PIPELINE true +#else + #define CUTLASS_UNIT_TEST_PIPELINE false +#endif + +#include +#include +#include +#include + +#include "cutlass/util/command_line.h" + +// Command line test options +struct OptionsClusterLaunch { + // + // Data Members + // + bool help = false; + bool verification_enabled = true; + int SM_count = 116; + int clock_MHz = 1477; + dim3 grid_dim = {0,0,0}; + + // + // Methods + // + + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + } + + cmd.get_cmd_line_argument("verification-enabled", verification_enabled, verification_enabled); + cmd.get_cmd_line_argument("sm-count", SM_count, SM_count); + cmd.get_cmd_line_argument("clock", clock_MHz, clock_MHz); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --verification-enabled= Enable/Disable verification\n" + << " --sm-count= Number of SMs on the chip\n" + << " --clock= Locked clock value in Mhz\n"; + + return out; + } +}; + +// +// Testbed +// + +template +class TestbedClusterLaunch { +private: + // Commandline options + OptionsClusterLaunch options; + + bool run_test() { + + // Run CuTe Gemm + Pipeline pipeline; + + bool success = false; + cudaError_t result = pipeline.run(success, this->options.grid_dim); + + CUTE_CHECK_LAST(); + return success; + } + + +public: + TestbedClusterLaunch(OptionsClusterLaunch const &options_) : options(options_) { + int device_id = 0; + cudaDeviceProp device_prop; + CUTE_CHECK_ERROR(cudaSetDevice(device_id)); + CUTE_CHECK_ERROR(cudaGetDeviceProperties(&device_prop, device_id)); + + if (device_prop.major < 1) { + fprintf(stderr, "Device does not support CUDA.\n"); + exit(1); + } + } + + /// Run verification Gemm problem sizes + bool verification() { + +#if !defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + printf( + "CUTLASS_ARCH_MMA_SM100_SUPPORTED must be set, but it is not. \n" + "This test is waived.\n" + ); + return true; +#endif + +#if 0 + bool is_success = false; + for (int i = 0; i< 10; i++){ + printf("iteration = %d\n", i); + is_success = run_test(); + if ( not is_success ) + return is_success; + } + return is_success; +#else + // Run the test with single launch + return run_test(); +#endif + } +}; diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/reduction/kernel/reduce_splitk_testbed.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/reduction/kernel/reduce_splitk_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..e44a42463ae95e4f76388d791c661de875092c93 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/reduction/kernel/reduce_splitk_testbed.h @@ -0,0 +1,45 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit tests for thread-level Reduction +*/ + +#pragma once + +#include "cutlass/reduction/thread/reduce.h" + +#include "cutlass/layout/vector.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_compare.h" diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/reduction/thread/testbed.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/reduction/thread/testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..239f228831a25527106af1659383112535943df1 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/reduction/thread/testbed.h @@ -0,0 +1,242 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit tests for thread-level Reduction +*/ + +#pragma once + +#include "cutlass/reduction/thread/reduce.h" + +#include "cutlass/layout/vector.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_compare.h" + +namespace test { +namespace reduction { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the reduction +template < + /// Data type of elements + typename Element, + /// Number of elements + int N +> +struct Testbed_reduce_host { + + /// Thread-level reduction operator + using Reduce = cutlass::reduction::thread::Reduce< + cutlass::plus, + cutlass::Array + >; + + // + // Data members + // + + cutlass::Array tensor_in; + cutlass::Array reduced_tensor_computed; + cutlass::Array reduced_tensor_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed_reduce_host() { + tensor_in.clear(); + reduced_tensor_computed.clear(); + reduced_tensor_reference.clear(); + } + + /// Runs the test + bool run() { + + // + // initialize memory + // + + for(int i = 0; i < N; i++) + tensor_in.at(i) = Element(i); + + + Reduce reduce; + + cutlass::Array *out_ptr = &reduced_tensor_computed; + out_ptr[0] = reduce(tensor_in); + + // + // Reference implementation + // + Element e(0); + for (int i = 0; i < N; i++) + e = e + Element(i); + + reduced_tensor_reference.at(0) = e; + + // + // Verify equivalence + // + + // compare + bool passed = reduced_tensor_reference[0] == reduced_tensor_computed[0]; + + EXPECT_TRUE(passed) + << "Expected = " << float(reduced_tensor_reference.at(0)) << "\n\n" + << "Actual = " << float(reduced_tensor_computed.at(0)) << "\n\n" + << std::endl; + + return passed; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Thread-level reduction kernel +template +__global__ void kernel_reduce(Element const *array_in, Element *result) { + + /// Thread-level reduction operator + using Reduce = cutlass::reduction::thread::Reduce< + cutlass::plus, + cutlass::Array + >; + + Reduce reduce; + + auto ptr_in = reinterpret_cast const *>(array_in); + auto result_ptr = reinterpret_cast *>(result); + auto in = *ptr_in; + result_ptr[0] = reduce(in); +} + + +/// Structure to compute the reduction +template < + /// Data type of elements + typename Element, + /// Number of elements + int N +> +struct Testbed_reduce_device { + + using Layout = cutlass::layout::PackedVectorLayout; + + // + // Data members + // + + cutlass::HostTensor tensor_in; + cutlass::HostTensor reduced_tensor_computed; + cutlass::HostTensor reduced_tensor_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed_reduce_device() { + + tensor_in.reset(cutlass::make_Coord(N), true); + reduced_tensor_computed.reset(cutlass::make_Coord(1), true); + reduced_tensor_reference.reset(cutlass::make_Coord(1), true); + } + + + /// Runs the test + bool run() { + + // + // initialize memory + // + + cutlass::reference::host::TensorFill( + tensor_in.host_view(), + Element(1) + ); + + cutlass::reference::host::TensorFill( + reduced_tensor_computed.host_view(), + Element(0) + ); + + cutlass::reference::host::TensorFill( + reduced_tensor_reference.host_view(), + Element(N) + ); + + tensor_in.sync_device(); + reduced_tensor_computed.sync_device(); + reduced_tensor_reference.sync_device(); + + /// call the kernel + kernel_reduce<<< dim3(1, 1), dim3(1, 1, 1) >>> ( + tensor_in.device_data(), + reduced_tensor_computed.device_data() + ); + + // verify no errors + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); + if (result != cudaSuccess) { + return false; + } + + // Copy back results + reduced_tensor_computed.sync_host(); + + // Verify equivalence + bool passed = cutlass::reference::host::TensorEquals( + reduced_tensor_computed.host_view(), + reduced_tensor_reference.host_view() + ); + + EXPECT_TRUE(passed) + << "Expected = " << reduced_tensor_reference.host_view() << "\n\n" + << "Actual = " << reduced_tensor_computed.host_view() << "\n\n" + << std::endl; + + return passed; + } +}; + +} // namespace thread +} // namespace reduction +} // namespace test diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/transform/device/sm90_sparse_gemm_compressor_legacy.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/transform/device/sm90_sparse_gemm_compressor_legacy.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c4e7de4351076dba3a699b4cb1c8a6e01485bc20 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/transform/device/sm90_sparse_gemm_compressor_legacy.hpp @@ -0,0 +1,481 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Compress utils specific for SM90 structure sparse kernels +*/ + +#pragma once + +#include // std::fill +#include // std::array +#include +#include // std::mt19937 + +#include "cute/container/bit_field.hpp" // cute::bit_field +#include "cute/numeric/numeric_types.hpp" // cute::sizeof_bits_v +#include "cute/tensor.hpp" // cute::Tensor, cute::make_tensor, cute::print_tensor +#include "cutlass/arch/arch.h" // cutlass::arch::Sm90 +#include "cutlass/cutlass.h" // cutlass::Status +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/layout.hpp" // cutlass::TagToStrideA_t +#include "cutlass/fast_math.h" // cutlass::ceil_div, cutlass::round_up +#include "cutlass/kernel_hardware_info.h" // cutlass::KernelHardwareInfo +#include "cutlass/util/packed_stride.hpp" // cutlass::make_cute_packed_stride +#include "cutlass/numeric_size.h" // cutlass::bits_to_bytes +#include "cutlass/cuda_host_adapter.hpp" // cutlass::CudaHostAdapter + +namespace cutlass +{ +namespace transform +{ +namespace kernel +{ + +using namespace cute; + +namespace detail { + + template + CUTLASS_HOST_DEVICE + static uint8_t + encode_in_chunk_idx_legacy(int in_chunk_idx){ + if (sizeof(T) == 4) { + return in_chunk_idx == 0 ? 0b0100 : 0b1110; + } + else { + uint8_t res = 0; + if (in_chunk_idx == 0) { + res = 0b00; + } + else if (in_chunk_idx == 1) { + res = 0b01; + } + else if (in_chunk_idx == 2) { + res = 0b10; + } + else { + res = 0b11; + } + return res; + } + } + + template < + class SparseConfig, + class EngineA, + class LayoutA, + class EngineAc, + class LayoutAc + > + CUTLASS_HOST_DEVICE + static void + compress_two_chunks_legacy( + Tensor tensorA, + Tensor tensorAc, + uint8_t& meta_two_chunk, + int effective_elems) { + + using ElementA = typename EngineAc::value_type; + + static constexpr int LogicalElemsAPerChunk = typename SparseConfig::LogicalElemsAPerChunk{}; + static constexpr int PhysicalElemsAPerChunk = typename SparseConfig::PhysicalElemsAPerChunk{}; + static constexpr int ElemsARawPerElementAMmaRaw = typename SparseConfig::ElemsARawPerElementAMmaRaw{}; + static constexpr int ElementEBitsPerElementAMma = typename SparseConfig::ElementEBitsPerElementAMma{}; + static constexpr int LogicalSubChunk = ceil_div(LogicalElemsAPerChunk, ElemsARawPerElementAMmaRaw); + static constexpr int PhysicalSubChunk = ceil_div(PhysicalElemsAPerChunk, ElemsARawPerElementAMmaRaw); + + /* + Legal metadata chunk in SM90 + Index Bin HEX + 0, 1 0b0100 4 + 1, 2 0b1001 9 + 2, 3 0b1110 E + 0, 2 0b1000 8 + 1, 3 0b1101 D + 0, 3 0b1100 C + 2, 1 0b0110 6 (Not used) + ----------------------------------- + TF32 + 0 0b0100 4 + 1 0b1110 E + */ + + if (effective_elems <= 0) { + return; + } + + // initialize + // 0 is the initial value for this function while 0x44 is the initial value for hardware. + meta_two_chunk = 0; + + for (int chunk_idx = 0; chunk_idx < 2; ++chunk_idx) { + // If Only One Chunk within this Two Chunk + if ( effective_elems <= chunk_idx * ElemsARawPerElementAMmaRaw * LogicalSubChunk ) { + break; + } + /// init result; + int non_zero_cnt = 0; + int32_t nnz_chunk_idx[PhysicalSubChunk] = { 0 }; + ElementA Ac_chunk[PhysicalSubChunk][ElemsARawPerElementAMmaRaw] = { ElementA{0} }; + + for (int subchunk_idx = 0; subchunk_idx < LogicalSubChunk; ++subchunk_idx) { + bool is_nz = true; + ElementA subchunk_elems[ElemsARawPerElementAMmaRaw] = { ElementA{0} }; + /// Check if subchunk is non-zero + for(int elem_idx = 0; elem_idx < ElemsARawPerElementAMmaRaw; elem_idx++) { + int offset = chunk_idx * LogicalElemsAPerChunk + subchunk_idx * ElemsARawPerElementAMmaRaw + elem_idx; + subchunk_elems[elem_idx] = offset < effective_elems ? tensorA(offset) : ElementA(0); + + ElementA zero = static_cast(0); + ElementA minus_zero = static_cast(ElementA(1) << cutlass::sizeof_bits_v - 1); + if (subchunk_elems[elem_idx] != zero && subchunk_elems[elem_idx] != minus_zero) { + if (non_zero_cnt >= PhysicalSubChunk) { + #ifdef __CUDA_ARCH__ + asm volatile ("brkpt;\n" ::); + #else + throw std::runtime_error("Found extra non-zero elements in a chunk!\n"); + #endif + } + is_nz = false; + } + } + + /// There is non-zero element in the subchunk + if(!is_nz) { + nnz_chunk_idx[non_zero_cnt] = subchunk_idx; + memcpy(Ac_chunk[non_zero_cnt], subchunk_elems, sizeof(ElementA) * ElemsARawPerElementAMmaRaw); + non_zero_cnt++; + } + } + + /* + Special cases + nnz == 1 and non-tf32 and nnz_idx = 3 + */ + ElementA elementA_zeros[ElemsARawPerElementAMmaRaw] = { ElementA{0} }; + if constexpr (sizeof_bits_v < 32) { + if (non_zero_cnt == 1 && nnz_chunk_idx[0] == 3) { + memcpy(Ac_chunk[1], Ac_chunk[0], sizeof(ElementA) * ElemsARawPerElementAMmaRaw); + memcpy(Ac_chunk[0], elementA_zeros, sizeof(ElementA) * ElemsARawPerElementAMmaRaw); + nnz_chunk_idx[1] = 3; + nnz_chunk_idx[0] = 0; + } + else if (non_zero_cnt == 1) { + memcpy(Ac_chunk[1], elementA_zeros, sizeof(ElementA) * ElemsARawPerElementAMmaRaw); + nnz_chunk_idx[1] = 3; + } + } + + /// Setup metadata + uint8_t meta_chunk = 0; + for (int i = 0; i < PhysicalSubChunk; i++) { + meta_chunk = static_cast(meta_chunk | (encode_in_chunk_idx_legacy(nnz_chunk_idx[i]) << (i * ElementEBitsPerElementAMma))); + for(int j = 0; j < ElemsARawPerElementAMmaRaw; j++) { + tensorAc(chunk_idx * PhysicalElemsAPerChunk + i * ElemsARawPerElementAMmaRaw + j) = Ac_chunk[i][j]; + } + } + meta_two_chunk = uint8_t(meta_two_chunk | (meta_chunk << (chunk_idx * _4{}))); + } + } +} + +template< + class ProblemShape_, + class ElementA_, + class LayoutATag_, + class SparseConfig_ +> +class SM90StructuredSparseCompressorLegacy { +public: + using SparseConfig = SparseConfig_; + using ProblemShape = ProblemShape_; + + // * EltA + using ElementA = ElementA_; + using ElementAUint = cute::uint_bit_t>; + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + using ArrayElementA = cute::conditional_t>, + ElementA>; + using ElementAMma = typename SparseConfig::ElementAMma; + using ElementAMmaRaw = typename SparseConfig::ElementAMmaRaw; + using ElementASparsity = typename SparseConfig::ElementASparsity; + using ElementAMmaSparsity = typename SparseConfig::ElementAMmaSparsity; + using LayoutATag = LayoutATag_; + using LayoutA = LayoutATag; + using StrideA = cutlass::gemm::TagToStrideA_t; + + // * EltE + using ElementEMma = typename SparseConfig::ElementEMma; + using ElementEMmaRaw = typename SparseConfig::ElementEMmaRaw; + using ElementEMmaSparsity = typename SparseConfig::ElementEMmaSparsity; + + // * AtomE + using TensorEAtom = typename SparseConfig::TensorEAtom; + using TensorEAtomK = typename SparseConfig::TensorEAtomK; + using TensorEAtomM = typename SparseConfig::TensorEAtomM; + + static constexpr int ElemsARawPerElementAMmaRaw = typename SparseConfig::ElemsARawPerElementAMmaRaw{}; + static constexpr int LogicalElemsAPerChunk = typename SparseConfig::LogicalElemsAPerChunk{}; + static constexpr int PhysicalElemsAPerChunk = typename SparseConfig::PhysicalElemsAPerChunk{}; + static constexpr int LogicalElemsAMmaRawPerChunk = cutlass::ceil_div(LogicalElemsAPerChunk, ElemsARawPerElementAMmaRaw); + static constexpr int PhysicalElemsAMmaRawPerChunk = cutlass::ceil_div(PhysicalElemsAPerChunk, ElemsARawPerElementAMmaRaw); + + // * Alignment + static constexpr int TensorEAlignmentM = typename SparseConfig::TensorEAlignmentM{}; + static constexpr int TensorEAlignmentK = typename SparseConfig::TensorEAlignmentK{}; + static constexpr int TensorAAlignmentK = typename SparseConfig::TensorAAlignmentK{}; + static constexpr int TensorAAlignmentM = typename SparseConfig::TensorAAlignmentM{}; + + // Required by `device_kernel` + static constexpr int MaxThreadsPerBlock = 1; + static constexpr int MinBlocksPerMultiprocessor = 1; + using ArchTag = arch::Sm90; + + struct SharedStorage { + /* empty, no smem needed */ + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + struct TransformArguments { + ArrayElementA const* ptr_A{nullptr}; + StrideA dA{}; + ArrayElementA* ptr_ACompress{nullptr}; + ElementEMmaRaw* ptr_E{nullptr}; + }; + + using TransformParams = TransformArguments; + + struct Arguments { + ProblemShape problem_shape{}; + TransformArguments transform{}; + KernelHardwareInfo hw_info{}; + }; + + struct Params { + ProblemShape problem_shape{}; + TransformParams transform{}; + KernelHardwareInfo hw_info{}; + void* workspace = nullptr; + }; + + static Params + to_underlying_arguments(Arguments & args, void* workspace) { + return Params{{args.problem_shape}, + {args.transform.ptr_A, args.transform.dA, args.transform.ptr_ACompress, args.transform.ptr_E}, + {args.hw_info}, + workspace}; + } + + static Status + can_implement(Arguments const& args) { + auto [M, N, K, L] = args.problem_shape; + if (K % LogicalElemsAPerChunk != 0) { + CUTLASS_TRACE_HOST("SM90 Sparse Compressor CAN NOT IMPLEMENT: GemmK not multiplier of logical chunk size\n"); + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } + + static size_t + get_workspace_size(Arguments const& args) { + auto problem = args.problem_shape; + const int m = cute::size<0>(problem); + const int k = cute::size<2>(problem); + const int l = cute::size<3>(problem); + const int metadata_k = round_up(k, TensorEAlignmentK); + const int metadata_m = round_up(m, TensorEAlignmentM); + const int metadata_bytes = metadata_m * metadata_k / ElementEMmaSparsity{} * l; + return metadata_bytes; + } + + static Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + cudaError_t cuda_error; + + auto workspace_size = get_workspace_size(args); + if (workspace_size == 0) { + return Status::kSuccess; + } else if (workspace == nullptr) { + return Status::kErrorInternal; + } + + cudaPointerAttributes attri; + cuda_error = cudaPointerGetAttributes(&attri, workspace); + if (cuda_error != cudaSuccess) { + return Status::kErrorInternal; + } + + if ( attri.type == cudaMemoryTypeDevice ) { +#if defined(CUTLASS_ENABLE_CUDA_HOST_ADAPTER) && CUTLASS_ENABLE_CUDA_HOST_ADAPTER + CUTLASS_ASSERT(cuda_adapter); + if (Status::kSuccess != cuda_adapter->memsetDevice(workspace, static_cast(0), workspace_size, stream)) { + return Status::kErrorInternal; + } +#else + cudaMemsetAsync(workspace, 0, workspace_size, stream); + cuda_error = cudaGetLastError(); + if (cuda_error != cudaSuccess) { + return Status::kErrorInternal; + } +#endif + } else { + memset(workspace, 0, workspace_size); + } + + return Status::kSuccess; + } + + static dim3 + get_grid_shape(Params const& params) { + return dim3(1, 1, 1); + } + + static dim3 + get_block_shape() { + return dim3(1, 1, 1); + } + + CUTE_HOST_DEVICE + void + operator()(Params params, char* smem_buf = nullptr) { + run(params, smem_buf); + } + + CUTE_HOST_DEVICE + static void + run(Params params, char* smem_buf = nullptr) { + do_compress_device_host(params); + } + +private: + + CUTE_HOST_DEVICE + static void + do_compress_device_host(Params params) { + auto [m, n, k, l] = params.problem_shape; + auto [ptr_A, dA, ptr_ACompress, ptr_E] = params.transform; + auto workspace = params.workspace; + + const int aligned_k = (k + TensorAAlignmentK - 1) / TensorAAlignmentK * TensorAAlignmentK; + const int aligned_m = (m + TensorAAlignmentM - 1) / TensorAAlignmentM * TensorAAlignmentM; + const int metadata_k = (k + TensorEAlignmentK - 1) / TensorEAlignmentK * TensorEAlignmentK; + const int metadata_m = (m + TensorEAlignmentM - 1) / TensorEAlignmentM * TensorEAlignmentM; + const int k_compressed = aligned_k / ElementASparsity{}; + + // Convert to CuTe tensors. But don't want to use sparse_ptr, which is making everything complicated here. + cute::Tensor tensorA = make_tensor(recast_ptr(ptr_A), make_layout(make_shape(m, k, l), dA)); + + cute::Tensor tensorAc = make_tensor(recast_ptr(ptr_ACompress), + make_shape(aligned_m, k_compressed, l), + make_cute_packed_stride(StrideA{}, cute::make_shape(aligned_m, k_compressed, l))); + + cute::Tensor tensorE_raw_compress_logical = make_tensor(recast_ptr>(workspace), + make_shape(metadata_m, make_shape(TensorEAtomK{}, metadata_k / TensorEAtomK{}), l), + make_stride(TensorEAtomK{}, make_stride(_1{}, metadata_m*TensorEAtomK{}), metadata_m*metadata_k)); + + cute::Tensor tensorE_raw_compress = recast(tensorE_raw_compress_logical); + + // The following vars are all logical. + int atom_m = size<0>(TensorEAtom{}); + int atom_k = size<1>(TensorEAtom{}); + int tiled_m = metadata_m / atom_m; + int tiled_ke = metadata_k / atom_k; + // Col major when viewing atoms + int stride_tile_m = cosize(TensorEAtom{}); + int stride_tile_ke = atom_k * metadata_m; + + // Logical metadata tensor + cute::Tensor tensorE_logical = make_tensor(recast_ptr>(ptr_E), + make_layout(make_shape(append(shape<0>(TensorEAtom{}), tiled_m), + append(shape<1>(TensorEAtom{}), tiled_ke), + shape<2>(tensorE_raw_compress_logical)), + make_stride(append(stride<0>(TensorEAtom{}), stride_tile_m), + append(stride<1>(TensorEAtom{}), stride_tile_ke), + stride<2>(tensorE_raw_compress_logical)))); + // Physical metadata tensor + cute::Tensor tensorE = recast(tensorE_logical); + + // void do_init() + cute::clear(tensorAc); + cute::clear(tensorE_raw_compress); + + // void do_raw_compress() + using TileStepA = Int; + using TileStepAc = Int; + + cute::Tensor tensorATiled = logical_divide(tensorA, make_shape(_, TileStepA{}, _)); + cute::Tensor tensorAcTiled = logical_divide(tensorAc, make_shape(_, TileStepAc{}, _)); + + for (int batch_idx = 0; batch_idx < l; batch_idx++) { + for (int m_idx = 0; m_idx < m; m_idx++) { + for (int tiler_k_idx = 0; tiler_k_idx < size<1,1>(tensorATiled); tiler_k_idx++) { + int effective_elems = cute::min(TileStepA{}, k - (tiler_k_idx * TileStepA{})); + detail::compress_two_chunks_legacy(tensorATiled(m_idx, make_coord(_, tiler_k_idx), batch_idx), + tensorAcTiled(m_idx, make_coord(_, tiler_k_idx), batch_idx), + tensorE_raw_compress(m_idx, tiler_k_idx, batch_idx), + effective_elems); + } + } + } + + // void do_reorder() + // Fast path when we don't permute. + if constexpr (sizeof_bits_v <= 8) { + memcpy(tensorE.data(), tensorE_raw_compress.data(), tensorE.size()); + } + else { + cute::copy(tensorE_raw_compress, tensorE); + } + + #if 0 + print("--> TensorA\n"); + auto tensorA_eltA = cute::recast(tensorA); + cute::print_tensor(tensorA_eltA); printf("\n\n"); + + print("--> REF TensorAC\n"); + auto tensorAc_eltA = cute::recast(tensorAc); + cute::print_tensor(tensorAc_eltA); printf("\n\n"); + + print("--> REF TensorE\n"); + cute::print_tensor(tensorE); printf("\n\n"); + #endif + + } +}; + +} // namespace kernel +} // namespace transform +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/transform/device/testbed_sparse_gemm_compressor.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/transform/device/testbed_sparse_gemm_compressor.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f44458244e0d3c4c80ecc29a0115cd6906211559 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/test/unit/transform/device/testbed_sparse_gemm_compressor.hpp @@ -0,0 +1,877 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/* + * @brief Test for structured sparse gemm compressor device kernel + */ + +#pragma once + +#include // cudaGetLastError + +#include // uint64_t +#include // printf +#include // malloc +#include // std::cout +#include +#include + +#include "cute/layout.hpp" // cute::make_shape +#include "cute/util/type_traits.hpp" // cute::is_same_v +#include "cutlass/coord.h" // cutlass::make_Coord +#include "cutlass/cutlass.h" // cutlass::Status +#include "cutlass/kernel_hardware_info.hpp" // cutlass::KernelHardwareInfo +#include "cutlass/layout/matrix.h" // cutlass::layout::Affine2Layout_Factory +#include "cutlass/numeric_types.h" // cutlass::sizeof_bits, cutlass::float_ +#include "cutlass/tensor_view.h" // cutlass::TensorView +#include "cutlass/transform/device/transform_universal_adapter.hpp" // cutlass::transform::device::TransformUniversalAdapter +#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" // cutlass::transform::kernel::StructuredSparseCompressorUtility +#include "cutlass/util/device_memory.h" // cutlass::device_memory::allocation +#include "cutlass/util/distribution.h" // cutlass::Distribution +#include "cutlass/util/host_tensor.h" // cutlass::HostTensor +#include "cutlass/util/packed_stride.hpp" // cutlass::make_cute_packed_stride +#include "cutlass/util/reference/host/tensor_compare.h" // cutlass::reference::host::TensorEquals +#include "cutlass/util/reference/host/tensor_fill.h" // cutlass::reference::host::TensorFillRandomUniform, TensorFillIdentity, TensorFillRandomGaussian, BlockFillSequential, TensorFill +#include "cutlass/detail/collective.hpp" + +#include "sm90_sparse_gemm_compressor_legacy.hpp" // Legacy host compressor +#include "../../common/cutlass_unit_test.h" // CUTLASS UT, EXPECT_TRUE + + +#define CUDA_CHECK_FALSE(cuda_error) \ + { \ + if (cuda_error != cudaSuccess) { \ + printf("cudaError %s in %s:%d\n", cudaGetErrorString(cuda_error), __func__, __LINE__ ); \ + return false; \ + } \ + } + +#define CUDA_CHECK(cuda_error) \ + { \ + if (cuda_error != cudaSuccess) { \ + printf("cudaError %s in %s:%d\n", cudaGetErrorString(cuda_error), __func__, __LINE__ ); \ + return; \ + } \ + } + + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// * Test Bed +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test +{ +namespace transform +{ +namespace device +{ + +// Helper Functions +template +bool +initialize_tensor(cutlass::TensorView view, cutlass::Distribution::Kind dist_kind, uint64_t seed) +{ + if (dist_kind == cutlass::Distribution::Uniform) { + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if (bits_input <= 8) { + scope_max = 1; + scope_min = -1; + } else { + scope_max = 4; + scope_min = -4; + } + cutlass::reference::host::TensorFillRandomUniform(view, seed, scope_max, scope_min, 0); + } + + else if (dist_kind == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(view); + } + + else if (dist_kind == cutlass::Distribution::Gaussian) { + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + + else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; +} + +// Testbed +template +struct TestbedSparseGemmCompressor { +public: + using Compressor = Compressor_; + using CompressorKernel = typename Compressor::TransformKernel; + + using ElementA = typename CompressorKernel::ElementA; + using LayoutATag = typename CompressorKernel::LayoutATag; + using StrideA = typename CompressorKernel::StrideA; + using ArrayElementA = + ElementA + ; + + using ElementE = typename CompressorKernel::ElementEMmaRaw; + using LayoutETag = cutlass::layout::RowMajor; // We don't care about the major here, just to allocate tensor + + using SparseConfig = typename CompressorKernel::SparseConfig; + using ProblemShapeType = typename CompressorKernel::ProblemShape; + + using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility< + ProblemShapeType, + ElementA, + LayoutATag, + SparseConfig>; + + using CompressorKernelHost = cutlass::transform::kernel::SM90StructuredSparseCompressorLegacy< + ProblemShapeType, + ElementA, + LayoutATag, + SparseConfig>; + + using CompressorHost = cutlass::transform::device::TransformUniversalAdapter; + + static constexpr auto LogicalElemsAPerChunk = CompressorKernel::LogicalElemsAPerChunk; + static constexpr auto PhysicalElemsAPerChunk = CompressorKernel::PhysicalElemsAPerChunk; + + struct Data { + // Data Storage + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_A_Comp; + cutlass::HostTensor tensor_E; + cutlass::HostTensor tensor_A_Comp_ref; + cutlass::HostTensor tensor_E_ref; + }; + + struct CudaRAII { + cudaStream_t stream; + cudaEvent_t start; + cudaEvent_t stop; + + CudaRAII(){ + CUDA_CHECK(cudaStreamCreate( &stream )); + CUDA_CHECK(cudaEventCreate( &start )); + CUDA_CHECK(cudaEventCreate( &stop )); + }; + + CudaRAII(const CudaRAII&) = delete; + CudaRAII& operator=(const CudaRAII&) = delete; + CudaRAII(CudaRAII&&) = delete; + CudaRAII& operator=(CudaRAII&&) = delete; + + ~CudaRAII(){ + CUDA_CHECK(cudaStreamDestroy( stream )); + CUDA_CHECK(cudaEventDestroy( start )); + CUDA_CHECK(cudaEventDestroy( stop )); + } + }; + +public: + TestbedSparseGemmCompressor( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_E_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_A_Comp_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 7) + : init_A(init_A_) + , init_E(init_E_) + , init_A_Comp(init_A_Comp_) + , seed(seed_) + { + } + + bool valid_test(ProblemShapeType problem_shape_MNKL) + { + const int GemmK = cute::size<2>(problem_shape_MNKL); + + if ( GemmK % LogicalElemsAPerChunk != 0 ) { + printf("GemmK needs to be multiplier of LogicalElemsAPerChunk\n"); + return false; + } + + return true; + } + + bool initialize(ProblemShapeType problem_shape_MNKL, Data& datas) + { + CUDA_CHECK_FALSE(cudaGetLastError()); + + // In unit of ElementARaw + const int GemmM = cute::size<0>(problem_shape_MNKL); + const int GemmN = cute::size<1>(problem_shape_MNKL); + const int GemmK = cute::size<2>(problem_shape_MNKL); + const int GemmL = cute::size<3>(problem_shape_MNKL); + + // Compressor utility to get allocated data size + auto stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(GemmM, GemmK, GemmL)); + CompressorUtility compressor_utility(problem_shape_MNKL, stride_a); + + // TensorA + // In unit of ElementARaw, after alignment requirement + // M-dim: no alignment requirement + // K-dim: multiplier of chunk size + + // TensorA Compressed + // In unit of ElementARaw, after alignment requirement + // M-dim: TMA alignment + // K-dim: TMA alignment + const int GemmMAlignedAC = compressor_utility.get_tensorA_m_physical(); + const int GemmKAlignedAC = compressor_utility.get_tensorA_k_physical(); + + // TensorE + // In unit of ElementE (uint8_t), after alignment requirement + // M-dim: TensorEAtom_M alignment + // K-dim: TensorEAtom_K alignment + const int GemmMAlignedE = compressor_utility.get_metadata_m_physical(); + const int GemmKAlignedE = compressor_utility.get_metadata_k_physical(); + + auto a_coord = cutlass::make_Coord(GemmM * GemmL, GemmK); + auto e_coord = cutlass::make_Coord(GemmMAlignedE * GemmL, GemmKAlignedE); + auto a_comp_coord = cutlass::make_Coord(GemmMAlignedAC * GemmL, GemmKAlignedAC); + + typename LayoutATag::Stride stride_factor_A; + typename LayoutETag::Stride stride_factor_E; + + datas.tensor_A.resize(a_coord, + cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A)); + datas.tensor_A_Comp.resize(a_comp_coord, + cutlass::layout::Affine2Layout_Factory::layout_factory(a_comp_coord, stride_factor_A)); + datas.tensor_A_Comp_ref.resize(a_comp_coord, + cutlass::layout::Affine2Layout_Factory::layout_factory(a_comp_coord, stride_factor_A), + false); + datas.tensor_E.resize(e_coord, + cutlass::layout::Affine2Layout_Factory::layout_factory(e_coord, stride_factor_E)); + datas.tensor_E_ref.resize(e_coord, + cutlass::layout::Affine2Layout_Factory::layout_factory(e_coord, stride_factor_E), + false); + + EXPECT_TRUE(initialize_tensor(datas.tensor_A.host_view(), init_A, seed + 1)); + EXPECT_TRUE(initialize_tensor(datas.tensor_E.host_view(), init_E, seed + 2)); + EXPECT_TRUE(initialize_tensor(datas.tensor_E_ref.host_view(), init_E, seed + 3)); + EXPECT_TRUE(initialize_tensor(datas.tensor_A_Comp.host_view(), init_A_Comp, seed + 4)); + EXPECT_TRUE(initialize_tensor(datas.tensor_A_Comp_ref.host_view(), init_A_Comp, seed + 5)); + + compressor_utility.structure_sparse_zero_mask_fill(datas.tensor_A.host_data(), seed + 6); + + // Check for failed devide + CUDA_CHECK_FALSE(cudaGetLastError()); + + datas.tensor_A.sync_device(); + datas.tensor_A_Comp.sync_device(); + datas.tensor_E.sync_device(); + + // Check for failed devide + CUDA_CHECK_FALSE(cudaGetLastError()); + + return true; + } + + bool run_device(ProblemShapeType problem_shape_MNKL, Data& datas, float* time = nullptr) + { + CudaRAII cuda_raii; + + const int GemmM = cute::size<0>(problem_shape_MNKL); + const int GemmN = cute::size<1>(problem_shape_MNKL); + const int GemmK = cute::size<2>(problem_shape_MNKL); + const int GemmL = cute::size<3>(problem_shape_MNKL); + + StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(GemmM, GemmK, GemmL)); + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + typename Compressor::Arguments arguments{ + {GemmM, GemmN, GemmK, GemmL}, + {datas.tensor_A.device_data(), + stride_a, + datas.tensor_A_Comp.device_data(), + datas.tensor_E.device_data()}, + {hw_info} + }; + + Compressor compressor_op; + size_t workspace_size = Compressor::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status {cutlass::Status::kSuccess }; + + status = compressor_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + CUDA_CHECK_FALSE(cudaGetLastError()); + } + + status = compressor_op.initialize(arguments, workspace.get(), cuda_raii.stream); + if (status != cutlass::Status::kSuccess) { + CUDA_CHECK_FALSE(cudaGetLastError()); + } + + CUDA_CHECK_FALSE(cudaStreamSynchronize(cuda_raii.stream)); + CUDA_CHECK_FALSE(cudaEventRecord(cuda_raii.start, cuda_raii.stream)); + + status = compressor_op.run(cuda_raii.stream); + if (status != cutlass::Status::kSuccess) { + CUDA_CHECK_FALSE(cudaGetLastError()); + } + + CUDA_CHECK_FALSE(cudaEventRecord(cuda_raii.stop, cuda_raii.stream)); + CUDA_CHECK_FALSE(cudaEventSynchronize(cuda_raii.stop)); + CUDA_CHECK_FALSE(cudaStreamSynchronize(cuda_raii.stream)); + if ( time != nullptr ){ + CUDA_CHECK_FALSE(cudaEventElapsedTime(time, cuda_raii.start, cuda_raii.stop)); + } + + datas.tensor_A_Comp.sync_host(); + datas.tensor_E.sync_host(); + + #if 0 + { + printf("\n--> DEVICE OUTPUT\n"); + printf("datas.tensor_A\n"); + std::cout << datas.tensor_A.host_view() << std::endl << std::endl; + printf("datas.tensor_A_Comp\n"); + std::cout << datas.tensor_A_Comp.host_view() << std::endl << std::endl; + printf("datas.tensor_E\n"); + std::cout << datas.tensor_E.host_view() << std::endl << std::endl; + } + #endif + + return true; + } + + bool run_host_ref(ProblemShapeType problem_shape_MNKL, Data& datas) + { + const int GemmM = cute::size<0>(problem_shape_MNKL); + const int GemmN = cute::size<1>(problem_shape_MNKL); + const int GemmK = cute::size<2>(problem_shape_MNKL); + const int GemmL = cute::size<3>(problem_shape_MNKL); + + StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(GemmM, GemmK, GemmL)); + + typename CompressorKernelHost::Arguments arguments{ + {GemmM, GemmN, GemmK, GemmL}, + {datas.tensor_A.host_data(), + stride_a, + datas.tensor_A_Comp_ref.host_data(), + datas.tensor_E_ref.host_data()}, + {}}; + + const auto can_imp = CompressorKernelHost::can_implement(arguments); + if (can_imp != cutlass::Status::kSuccess) { + printf("can_implement() check failed\n"); + return false; + } + + // Relies on std::vector for RAII + auto workspace_size = + static_cast::size_type>(CompressorKernelHost::get_workspace_size(arguments)); + std::vector workspace_vector(workspace_size); + auto workspace = static_cast(workspace_vector.data()); + + cutlass::Status status = CompressorKernelHost::initialize_workspace(arguments, workspace); + if (status != cutlass::Status::kSuccess) { + printf("initialize_workspace() failed\n"); + return false; + } + + auto params = CompressorKernelHost::to_underlying_arguments(arguments, workspace); + CompressorKernelHost::run(params); + + return true; + } + + bool compare_reference(Data& datas) + { + bool check_tensor_a_compressed = + cutlass::reference::host::TensorEquals(datas.tensor_A_Comp_ref.host_view(), datas.tensor_A_Comp.host_view()); + if (!check_tensor_a_compressed) { + printf("A-Compressed Mismatch\n"); + } + + bool check_tensor_e = cutlass::reference::host::TensorEquals(datas.tensor_E_ref.host_view(), datas.tensor_E.host_view()); + if (!check_tensor_e) { + printf("E Mismatch\n"); + } + + return check_tensor_a_compressed && check_tensor_e; + } + + bool run_auto_small() + { + return run_auto(true); + } + + bool run_auto(bool run_small = false) + { + constexpr auto TensorEAlignmentM = typename SparseConfig::TensorEAlignmentM{}; + constexpr auto TensorEAlignmentK = typename SparseConfig::TensorEAlignmentK{}; + constexpr int LogicalElemsAPerChunk = typename SparseConfig::LogicalElemsAPerChunk{}; + + constexpr int GemmN = 1; + + using ProblemType = typename std::array; + + std::vector problems; + + const std::vector problems_multiplier_of_tensor_e_atom = { + // * Regular Cases (multiplier of TensorEAlignment) + {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 1}, + {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 1}, + {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 3, 1}, + + {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 1}, + {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 1}, + {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 3, 1}, + + {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 1}, + {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 1}, + {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 3, 1}, + + {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 2}, + {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 2}, + {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 3, 2}, + + {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 2}, + {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 2}, + {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 3, 2}, + + {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 2}, + {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 2}, + {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 3, 2}, + + {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 3}, + {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 3}, + {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 3, 3}, + + {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 3}, + {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 3}, + {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 3, 3}, + + {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 3}, + {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 3}, + {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 3, 3}, + }; + + const std::vector problems_multiplier_of_tensor_e_atom_large = { + // * Large Case (multiplier of TensorEAlignment) + {TensorEAlignmentM * 10, GemmN, TensorEAlignmentK * 13, 1}, + // {TensorEAlignmentM * 11, GemmN, TensorEAlignmentK * 14, 2}, + // {TensorEAlignmentM * 12, GemmN, TensorEAlignmentK * 15, 3}, + }; + + const std::vector problems_multiplier_of_twochunk { + // * Corner Cases + {4, GemmN, LogicalElemsAPerChunk * 2, 1}, + {4, GemmN, LogicalElemsAPerChunk * 4, 1}, + {4, GemmN, LogicalElemsAPerChunk * 6, 1}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 1}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 1}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 1}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 1}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 1}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 1}, + + {4, GemmN, LogicalElemsAPerChunk * 2, 2}, + {4, GemmN, LogicalElemsAPerChunk * 4, 2}, + {4, GemmN, LogicalElemsAPerChunk * 6, 2}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 2}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 2}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 2}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 2}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 2}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 2}, + + {4, GemmN, LogicalElemsAPerChunk * 2, 3}, + {4, GemmN, LogicalElemsAPerChunk * 4, 3}, + {4, GemmN, LogicalElemsAPerChunk * 6, 3}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 3}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 3}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 3}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 3}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 3}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 3}, + + {32 + 4, GemmN, LogicalElemsAPerChunk * 2, 1}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 4, 1}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 6, 1}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 1}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 1}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 1}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 1}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 1}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 1}, + + {32 + 4, GemmN, LogicalElemsAPerChunk * 2, 2}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 4, 2}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 6, 2}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 2}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 2}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 2}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 2}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 2}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 2}, + + {32 + 4, GemmN, LogicalElemsAPerChunk * 2, 3}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 4, 3}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 6, 3}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 3}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 3}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 3}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 3}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 3}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 3}, + + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 2, 1}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 4, 1}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 6, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 1}, + + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 2, 2}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 4, 2}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 6, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 2}, + + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 2, 3}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 4, 3}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 6, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 3}, + + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 2, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 4, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 6, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 1}, + + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 2, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 4, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 6, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 2}, + + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 2, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 4, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 6, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 3}, + }; + + const std::vector problems_multiplier_of_onechunk { + {4, GemmN, LogicalElemsAPerChunk * 1, 1}, + {4, GemmN, LogicalElemsAPerChunk * 3, 1}, + {4, GemmN, LogicalElemsAPerChunk * 5, 1}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 1}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 1}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 1}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 1}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 1}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 1}, + + {4, GemmN, LogicalElemsAPerChunk * 1, 2}, + {4, GemmN, LogicalElemsAPerChunk * 3, 2}, + {4, GemmN, LogicalElemsAPerChunk * 5, 2}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 2}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 2}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 2}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 2}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 2}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 2}, + + {4, GemmN, LogicalElemsAPerChunk * 1, 3}, + {4, GemmN, LogicalElemsAPerChunk * 3, 3}, + {4, GemmN, LogicalElemsAPerChunk * 5, 3}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 3}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 3}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 3}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 3}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 3}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 3}, + + {32 + 4, GemmN, LogicalElemsAPerChunk * 1, 1}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 3, 1}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 5, 1}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 1}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 1}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 1}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 1}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 1}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 1}, + + {32 + 4, GemmN, LogicalElemsAPerChunk * 1, 2}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 3, 2}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 5, 2}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 2}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 2}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 2}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 2}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 2}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 2}, + + {32 + 4, GemmN, LogicalElemsAPerChunk * 1, 3}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 3, 3}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 5, 3}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 3}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 3}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 3}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 3}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 3}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 3}, + + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 1, 1}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 3, 1}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 5, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 1}, + + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 1, 2}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 3, 2}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 5, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 2}, + + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 1, 3}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 3, 3}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 5, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 3}, + + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 1, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 3, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 5, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 1}, + + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 1, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 3, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 5, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 2}, + + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 1, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 3, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 5, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 3}, + }; + + // Run small only run multiplier of chunk size cases + if (run_small) { + problems.insert(problems.end(), problems_multiplier_of_tensor_e_atom.begin(), problems_multiplier_of_tensor_e_atom.end()); + } + // Run full run all corner cases + else { + problems.insert(problems.end(), problems_multiplier_of_tensor_e_atom_large.begin(), problems_multiplier_of_tensor_e_atom_large.end()); + problems.insert(problems.end(), problems_multiplier_of_tensor_e_atom.begin(), problems_multiplier_of_tensor_e_atom.end()); + problems.insert(problems.end(), problems_multiplier_of_twochunk.begin(), problems_multiplier_of_twochunk.end()); + problems.insert(problems.end(), problems_multiplier_of_onechunk.begin(), problems_multiplier_of_onechunk.end()); + } + + for (const auto& problem_shape_MNKL : problems) { + const auto [GemmM, GemmN, GemmK, GemmL] = problem_shape_MNKL; + bool passed = run({GemmM, GemmN, GemmK, GemmL}); + printf("run() (%.4d,%.4d,%.4d,%.4d) %s\n", GemmM, GemmN, GemmK, GemmL, passed ? "PASS" : "FAIL"); + CUTLASS_TRACE_HOST("run() " << GemmM << " " << GemmN << " " << GemmK << " " << GemmL << passed ? " PASS" : " FAIL"); + if (not passed) { + return false; + } + } + + return true; + } + + bool run(ProblemShapeType problem_shape_MNKL) + { + // Check if valid test + if (not valid_test(problem_shape_MNKL)) { + CUTLASS_TRACE_HOST("valid_test() fail\n"); + return false; + } + + // Data Storage + Data datas; + + // Initialize Data + if (not initialize(problem_shape_MNKL, datas)) { + CUTLASS_TRACE_HOST("initialize() fail\n"); + return false; + } + + // Run Compressor (Host Ref) + if (not run_host_ref(problem_shape_MNKL, datas)) { + CUTLASS_TRACE_HOST("run_host() fail\n"); + return false; + } + + // Run Compressor (Device) + if (not run_device(problem_shape_MNKL, datas)) { + CUTLASS_TRACE_HOST("run_device() fail\n"); + return false; + } + + // Verify + if (not compare_reference(datas)) { + CUTLASS_TRACE_HOST("compare_reference() DEVICE <-> LEGACY HOST fail\n"); + printf("compare_reference() DEVICE <-> LEGACY HOST fail\n"); + return false; + } + // else { + // printf("DEVICE <-> HOST PASS\n"); + // } + + return true; + } + + bool benchmark(ProblemShapeType problem_shape_MNKL) { + const auto [GemmM, GemmN, GemmK, GemmL] = problem_shape_MNKL; + printf("Benchmark() (%.4d,%.4d,%.4d,%.4d) START\n", GemmM, GemmN, GemmK, GemmL); + + // Check if valid test + if (valid_test(problem_shape_MNKL) == false) { + CUTLASS_TRACE_HOST("valid_test() fail\n"); + return false; + } + + // 2 warm-up iterations and 10 timing iterations + constexpr int num_warmup = 5; + constexpr int num_iter = 10; + + // Duplicate data to mimic cold cache + Data data[num_warmup + num_iter]; + double total_time_milliseconds{0.0}; + + for (int i = 0; i < num_warmup + num_iter; ++i ) { + printf("Benchmark() (%.4d,%.4d,%.4d,%.4d) ITER %d\n", GemmM, GemmN, GemmK, GemmL, i ); + + auto& datum_i = data[i]; + + // Initialize Data + if (initialize(problem_shape_MNKL, datum_i) == false) { + CUTLASS_TRACE_HOST("initialize() fail\n"); + return false; + } + + // Run Compressor (Device) + double time_i_milliseconds{0.0f}; + if (not run_device(problem_shape_MNKL, datum_i, &time_i_milliseconds)) { + CUTLASS_TRACE_HOST("run_device() fail\n"); + return false; + } + + if ( i >= num_warmup ) { + total_time_milliseconds += time_i_milliseconds; + } + } + + const double mean_time_milliseconds = total_time_milliseconds / num_iter; + printf("Mean time (ms): %.5f\n", mean_time_milliseconds); + + return true; + } + +public: + // Data Init Setting + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_A_Comp; + cutlass::Distribution::Kind init_E; + uint64_t seed; +}; + +} // namespace device +} // namespace transform +} // namespace test diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/arch_mappings.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/arch_mappings.h new file mode 100644 index 0000000000000000000000000000000000000000..df241e3ca6e6e584af7351402d990a8028e2abed --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/arch_mappings.h @@ -0,0 +1,156 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + + \brief CUTLASS Library is an object-oriented approach to managing operations implemented by CUTLASS. + + Generally, + + description - compile-time constant parameters used to instantiate an operation + + configuration - runtime parameters with computationally expensive initialization + + arguments - runtime parameters that may be passed to an initialized operation with low + computational overhead +*/ + +#pragma once + +#include "cutlass/arch/mma.h" +#include "cutlass/arch/arch.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template struct ArchMap; + +template <> struct ArchMap { + static int const kMin = 50; + static int const kMax = 1024; +}; + +template <> struct ArchMap { + static int const kMin = 60; + static int const kMax = 1024; +}; + +template <> struct ArchMap { + static int const kMin = 61; + static int const kMax = 1024; +}; + +template <> struct ArchMap { + static int const kMin = 70; + static int const kMax = 1024; +}; + +template <> struct ArchMap { + static int const kMin = 70; + static int const kMax = 75; +}; + +template struct ArchMap { + static int const kMin = 75; + static int const kMax = 1024; +}; + +template struct ArchMap { + static int const kMin = 80; + static int const kMax = 1024; +}; + +template struct ArchMap { + static int const kMin = 86; + static int const kMax = 1024; +}; + +template struct ArchMap { + static int const kMin = 89; + static int const kMax = 100; +}; + +template struct ArchMap { + static int const kMin = 90; + static int const kMax = 1024; +}; + +// Arch conditional WGMMA +template <> struct ArchMap { + static int const kMin = 90; + static int const kMax = 90; +}; + +// Arch conditional sparse WGMMA +template <> struct ArchMap { + static int const kMin = 90; + static int const kMax = 90; +}; + + +template struct ArchMap { + static int const kMin = 100; + static int const kMax = 1024; +}; + +template <> struct ArchMap { + static int const kMin = 100; + #if (__CUDACC_VER_MAJOR__ >= 13) + static int const kMax = 110; + #else + static int const kMax = 103; + #endif // __CUDACC_VER_MAJOR__ >= 13 +}; + +template struct ArchMap { + static int const kMin = 103; + static int const kMax = 1024; +}; +template <> struct ArchMap { + static int const kMin = 103; + static int const kMax = 103; +}; + +template struct ArchMap { + static int const kMin = 120; + static int const kMax = 121; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/descriptions.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/descriptions.h new file mode 100644 index 0000000000000000000000000000000000000000..5e80c124e59d24cd90c7c1b0c06bcc3bedfee62f --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/descriptions.h @@ -0,0 +1,815 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include + +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct MathInstructionDescription { + + /// Shape of the target math instruction + cutlass::gemm::GemmCoord instruction_shape; + + /// Describes the data type of the internal accumulator + NumericTypeID element_accumulator; + + /// Classification of math instruction + OpcodeClassID opcode_class; + + /// Type of math operation performed + MathOperationID math_operation; + + // + // Methods + // + + MathInstructionDescription( + cutlass::gemm::GemmCoord instruction_shape = cutlass::gemm::GemmCoord(), + NumericTypeID element_accumulator = NumericTypeID::kInvalid, + OpcodeClassID opcode_class = OpcodeClassID::kInvalid, + MathOperationID math_operation = MathOperationID::kMultiplyAdd + ): + instruction_shape(instruction_shape), + element_accumulator(element_accumulator), + opcode_class(opcode_class), + math_operation(math_operation) {} + + // Equality operator + inline + bool operator==(MathInstructionDescription const& rhs) const{ + return ( + (instruction_shape == rhs.instruction_shape) && + (element_accumulator == rhs.element_accumulator) && + (opcode_class == rhs.opcode_class) && + (math_operation == rhs.math_operation)); + } + + // Inequality operator + inline + bool operator!=(MathInstructionDescription const& rhs) const { + return !(*this == rhs); + } + +}; + +/// Structure describing the tiled structure of a GEMM-like computation +struct TileDescription { + + /// Describes the shape of a threadblock (in elements) + cutlass::gemm::GemmCoord threadblock_shape; + + /// Describes the number of pipeline stages in the threadblock-scoped mainloop + int threadblock_stages; + + /// Number of warps in each logical dimension + cutlass::gemm::GemmCoord warp_count; + + /// Core math instruction + MathInstructionDescription math_instruction; + + /// Minimum compute capability (e.g. 70, 75) of a device eligible to run the operation. + int minimum_compute_capability; + + /// Minimum compute capability (e.g. 70, 75) of a device eligible to run the operation. + int maximum_compute_capability; + + /// Describes the shape of a cluster (in blocks) + cutlass::gemm::GemmCoord cluster_shape; + + // + // Methods + // + + TileDescription( + cutlass::gemm::GemmCoord threadblock_shape = cutlass::gemm::GemmCoord(), + int threadblock_stages = 0, + cutlass::gemm::GemmCoord warp_count = cutlass::gemm::GemmCoord(), + MathInstructionDescription math_instruction = MathInstructionDescription(), + int minimum_compute_capability = 0, + int maximum_compute_capability = 0, + cutlass::gemm::GemmCoord cluster_shape = cutlass::gemm::GemmCoord(1,1,1) + ): + threadblock_shape(threadblock_shape), + threadblock_stages(threadblock_stages), + warp_count(warp_count), + math_instruction(math_instruction), + minimum_compute_capability(minimum_compute_capability), + maximum_compute_capability(maximum_compute_capability), + cluster_shape(cluster_shape) { } + + // Equality operator + inline + bool operator==(TileDescription const& rhs) const{ + return ( + (threadblock_shape == rhs.threadblock_shape) && + (threadblock_stages == rhs.threadblock_stages) && + (warp_count == rhs.warp_count) && + (math_instruction == rhs.math_instruction) && + (minimum_compute_capability == rhs.minimum_compute_capability) && + (maximum_compute_capability == rhs.maximum_compute_capability)); + } + + // Inequality operator + inline + bool operator!=(TileDescription const& rhs) const { + return !(*this == rhs); + } +}; + +/// High-level description of an operation +struct OperationDescription { + + /// Unique identifier describing the operation + char const * name; + + /// Operation provider + Provider provider; + + /// Kind of operation + OperationKind kind; + + /// Describes the tiled structure of a GEMM-like computation + TileDescription tile_description; + + // + // Methods + // + OperationDescription( + char const * name = "unknown", + Provider provider = Provider::kInvalid, + OperationKind kind = OperationKind::kInvalid, + TileDescription const& tile_description = TileDescription() + ): + name(name), provider(provider), kind(kind), tile_description(tile_description) { } +}; + +/// Structure describing the properties of a tensor +struct TensorDescription { + + /// Numeric type of an individual element + NumericTypeID element; + + /// Enumerant identifying the layout function for the tensor + LayoutTypeID layout; + + /// Alignment restriction on pointers, strides, and extents + int alignment; + + /// log2() of the maximum extent of each dimension + int log_extent_range; + + /// log2() of the maximum value each relevant stride may have + int log_stride_range; + + // + // Methods + // + + TensorDescription( + NumericTypeID element = NumericTypeID::kInvalid, + LayoutTypeID layout = LayoutTypeID::kInvalid, + int alignment = 1, + int log_extent_range = 24, + int log_stride_range = 24 + ): + element(element), + layout(layout), + alignment(alignment), + log_extent_range(log_extent_range), + log_stride_range(log_stride_range) { } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Description of all GEMM computations +struct GemmDescription : public OperationDescription { + + /// Indicates the kind of GEMM performed + GemmKind gemm_kind; + + /// Describes the A operand + TensorDescription A; + + /// Describes the B operand + TensorDescription B; + + /// Describes the source matrix + TensorDescription C; + + /// Describes the destination matrix + TensorDescription D; + + /// Describes the sparse meta matrices + TensorDescription E; + + /// Describes the data type of the scalars passed to the epilogue + NumericTypeID element_epilogue; + + /// Describes the structure of parallel reductions + SplitKMode split_k_mode; + + /// Transformation on A operand + ComplexTransform transform_A; + + /// Transformation on B operand + ComplexTransform transform_B; + + // + // Methods + // + + GemmDescription( + GemmKind gemm_kind = GemmKind::kGemm, + TensorDescription const& A = TensorDescription(), + TensorDescription const& B = TensorDescription(), + TensorDescription const& C = TensorDescription(), + TensorDescription const& D = TensorDescription(), + NumericTypeID element_epilogue = NumericTypeID::kInvalid, + SplitKMode split_k_mode = SplitKMode::kNone, + ComplexTransform transform_A = ComplexTransform::kNone, + ComplexTransform transform_B = ComplexTransform::kNone + ): + gemm_kind(gemm_kind), + A(A), + B(B), + C(C), + D(D), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A), + transform_B(transform_B) {} + + GemmDescription( + OperationDescription op_desc, + GemmKind gemm_kind, + TensorDescription const& A, + TensorDescription const& B, + TensorDescription const& C, + TensorDescription const& D, + NumericTypeID element_epilogue, + SplitKMode split_k_mode, + ComplexTransform transform_A, + ComplexTransform transform_B + ): + OperationDescription(op_desc), + gemm_kind(gemm_kind), + A(A), + B(B), + C(C), + D(D), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A), + transform_B(transform_B) {} +}; + +struct BlockScaleDescription { + /// Describes the SFA operand + TensorDescription SFA; + + /// Describes the SFB operand + TensorDescription SFB; + + /// Describes the SFD operand + TensorDescription SFD; + + /// Describes the input ScaleFactor VectorSize + int SFMVecSize; + int SFNVecSize; + int SFKVecSize; + + /// Describes the Output ScaleFactor VectorSize + int EpilogueSFVecSize; + + /// Describes the underlying kind of scaling: + /// Tensor Core supported (BlockScaled) or manual scaling (Blockwise) + OperationKind kind; +}; + +struct GroupedGemmDescription : public OperationDescription { + GemmDescription gemm; + std::optional block_scales; +}; + +/// Description of all GEMM computations +struct BlockScaledGemmDescription : public OperationDescription { + + /// Indicates the kind of GEMM performed + GemmKind gemm_kind; + + /// Describes the A operand + TensorDescription A; + + /// Describes the B operand + TensorDescription B; + + /// Describes the source matrix + TensorDescription C; + + /// Describes the destination matrix + TensorDescription D; + + /// Describes the SFA operand + TensorDescription SFA; + + /// Describes the SFB operand + TensorDescription SFB; + + /// Describes the SFD operand + TensorDescription SFD; + + /// Describes the data type of the scalars passed to the epilogue + NumericTypeID element_epilogue; + + /// Describes the structure of parallel reductions + SplitKMode split_k_mode; + + /// Transformation on A operand + ComplexTransform transform_A; + + /// Transformation on B operand + ComplexTransform transform_B; + + /// Describes the input ScaleFactor VectorSize + int SFVecSize; + + /// Describes the Output ScaleFactor VectorSize + int EpilogueSFVecSize; + + // + // Methods + // + + BlockScaledGemmDescription( + GemmKind gemm_kind = GemmKind::kGemm, + TensorDescription const& A = TensorDescription(), + TensorDescription const& B = TensorDescription(), + TensorDescription const& C = TensorDescription(), + TensorDescription const& D = TensorDescription(), + NumericTypeID element_epilogue = NumericTypeID::kInvalid, + SplitKMode split_k_mode = SplitKMode::kNone, + ComplexTransform transform_A = ComplexTransform::kNone, + ComplexTransform transform_B = ComplexTransform::kNone + ): + gemm_kind(gemm_kind), + A(A), + B(B), + C(C), + D(D), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A), + transform_B(transform_B) {} + + BlockScaledGemmDescription( + OperationDescription op_desc, + GemmKind gemm_kind, + TensorDescription const& A, + TensorDescription const& B, + TensorDescription const& C, + TensorDescription const& D, + NumericTypeID element_epilogue, + SplitKMode split_k_mode, + ComplexTransform transform_A, + ComplexTransform transform_B + ): + OperationDescription(op_desc), + gemm_kind(gemm_kind), + A(A), + B(B), + C(C), + D(D), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A), + transform_B(transform_B) {} +}; + +/// Description of all GEMM computations +struct BlockwiseGemmDescription : public OperationDescription { + + /// Indicates the kind of GEMM performed + GemmKind gemm_kind; + + /// Describes the A operand + TensorDescription A; + + /// Describes the B operand + TensorDescription B; + + /// Describes the source matrix + TensorDescription C; + + /// Describes the destination matrix + TensorDescription D; + + /// Describes the SFA operand + TensorDescription SFA; + + /// Describes the SFB operand + TensorDescription SFB; + + /// Describes the data type of the scalars passed to the epilogue + NumericTypeID element_epilogue; + + /// Describes the structure of parallel reductions + SplitKMode split_k_mode; + + /// Transformation on A operand + ComplexTransform transform_A; + + /// Transformation on B operand + ComplexTransform transform_B; + + /// Describes the input ScaleFactor VectorSize + int SFMVecSize; + int SFNVecSize; + int SFKVecSize; + + // + // Methods + // + + BlockwiseGemmDescription( + GemmKind gemm_kind = GemmKind::kGemm, + TensorDescription const& A = TensorDescription(), + TensorDescription const& B = TensorDescription(), + TensorDescription const& C = TensorDescription(), + TensorDescription const& D = TensorDescription(), + NumericTypeID element_epilogue = NumericTypeID::kInvalid, + SplitKMode split_k_mode = SplitKMode::kNone, + ComplexTransform transform_A = ComplexTransform::kNone, + ComplexTransform transform_B = ComplexTransform::kNone + ): + gemm_kind(gemm_kind), + A(A), + B(B), + C(C), + D(D), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A), + transform_B(transform_B) {} + + BlockwiseGemmDescription( + OperationDescription op_desc, + GemmKind gemm_kind, + TensorDescription const& A, + TensorDescription const& B, + TensorDescription const& C, + TensorDescription const& D, + NumericTypeID element_epilogue, + SplitKMode split_k_mode, + ComplexTransform transform_A, + ComplexTransform transform_B + ): + OperationDescription(op_desc), + gemm_kind(gemm_kind), + A(A), + B(B), + C(C), + D(D), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A), + transform_B(transform_B) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Description for structured sparse GEMMs. +struct SparseGemmDescription : public GemmDescription { + + /// Description structure for structured sparse GEMM + SparseGemmDescription( + GemmKind gemm_kind = GemmKind::kGemm, + TensorDescription const& A = TensorDescription(), + TensorDescription const& B = TensorDescription(), + TensorDescription const& C = TensorDescription(), + TensorDescription const& D = TensorDescription(), + TensorDescription const& E = TensorDescription(), + NumericTypeID element_epilogue = NumericTypeID::kInvalid, + SplitKMode split_k_mode = SplitKMode::kNone, + ComplexTransform transform_A = ComplexTransform::kNone, + ComplexTransform transform_B = ComplexTransform::kNone + ): + GemmDescription(gemm_kind, A, B, C, D, element_epilogue, split_k_mode, transform_A, transform_B) + {this->E = E;} +}; + +/// Description of all Reduction operations +struct ReductionDescription : public OperationDescription { + + /// Describes the data type of workspace + NumericTypeID element_workspace; + + /// Describes the data type of final output + NumericTypeID element_output; + + /// Describes the data type of the scalars passed to the epilogue + NumericTypeID element_epilogue; +}; + +/// Description of all Rank K update computations (SYRK, HERK, SYR2K, HER2K) +struct RankKDescription : public OperationDescription { + + /// Indicates which device template is used (universal or regular) + RankKKind rank_k_kind; + + /// Number of rank update (rank k or rank 2k) + int num_ranks; + + /// Describes the A operand + TensorDescription A; + + /// Describes the B operand (used only for SYR2K and HER2K) + TensorDescription B; + + /// Describes the source and destination matrices + TensorDescription C; + + /// Describes the fill mode for matrix C + FillMode fill_mode; + + /// Describes the blas mode (symmetric/hermitian) + BlasMode blas_mode; + + /// Describes the data type of the scalars passed to the epilogue + NumericTypeID element_epilogue; + + /// Describes the structure of parallel reductions + SplitKMode split_k_mode; + + /// Transformation on A operand + ComplexTransform transform_A; + + /// Transformation on B operand + ComplexTransform transform_B; + + // + // Methods + // + + RankKDescription( + RankKKind rank_k_kind = RankKKind::kUniversal, + int num_ranks = 1, + TensorDescription const& A = TensorDescription(), + TensorDescription const& B = TensorDescription(), + TensorDescription const& C = TensorDescription(), + FillMode fill_mode = FillMode::kInvalid, + BlasMode blas_mode = BlasMode::kInvalid, + NumericTypeID element_epilogue = NumericTypeID::kInvalid, + SplitKMode split_k_mode = SplitKMode::kNone, + ComplexTransform transform_A = ComplexTransform::kNone, + ComplexTransform transform_B = ComplexTransform::kNone + ): + rank_k_kind(rank_k_kind), + num_ranks(num_ranks), + A(A), + B(B), + C(C), + fill_mode(fill_mode), + blas_mode(blas_mode), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A), + transform_B(transform_B) {} +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Description of all TRMM computations +struct TrmmDescription : public OperationDescription { + + /// Indicates the kind of TRMM performed + TrmmKind trmm_kind; + + /// Describes the A operand + TensorDescription A; + + /// Describes the side mode for matrix A + SideMode side_mode; + + /// Describes the fill mode for matrix A + FillMode fill_mode; + + /// Describes the diag type for matrix A + DiagType diag_type; + + /// Describes the B operand + TensorDescription B; + + /// Describes the source and destination matrices + TensorDescription D; + + /// Describes the data type of the scalars passed to the epilogue + NumericTypeID element_epilogue; + + /// Describes the structure of parallel reductions + SplitKMode split_k_mode; + + /// Transformation on A operand + ComplexTransform transform_A; + + // + // Methods + // + + TrmmDescription( + TrmmKind trmm_kind = TrmmKind::kUniversal, + TensorDescription const& A = TensorDescription(), + SideMode side_mode = SideMode::kInvalid, + FillMode fill_mode = FillMode::kInvalid, + DiagType diag_type = DiagType::kInvalid, + TensorDescription const& B = TensorDescription(), + TensorDescription const& D = TensorDescription(), + NumericTypeID element_epilogue = NumericTypeID::kInvalid, + SplitKMode split_k_mode = SplitKMode::kNone, + ComplexTransform transform_A = ComplexTransform::kNone + ): + trmm_kind(trmm_kind), + A(A), + side_mode(side_mode), + fill_mode(fill_mode), + diag_type(diag_type), + B(B), + D(D), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Description of all SYMM/HEMM update computations +struct SymmDescription : public OperationDescription { + + /// Indicates which device template is used (universal or regular) + SymmKind symm_kind; + + /// Describes the A operand + TensorDescription A; + + /// Describes the B operand + TensorDescription B; + + /// Describes the source and destination matrices + TensorDescription C; + + /// Describes the side mode for matrix A + SideMode side_mode; + + /// Describes the fill mode for matrix A + FillMode fill_mode; + + /// Describes the blas mode (symmetric/hermitian) + BlasMode blas_mode; + + /// Describes the data type of the scalars passed to the epilogue + NumericTypeID element_epilogue; + + /// Describes the structure of parallel reductions + SplitKMode split_k_mode; + + /// Transformation on A operand + ComplexTransform transform_A; + + /// Transformation on B operand + ComplexTransform transform_B; + + // + // Methods + // + + SymmDescription( + SymmKind symm_kind = SymmKind::kUniversal, + TensorDescription const& A = TensorDescription(), + TensorDescription const& B = TensorDescription(), + TensorDescription const& C = TensorDescription(), + SideMode side_mode = SideMode::kInvalid, + FillMode fill_mode = FillMode::kInvalid, + BlasMode blas_mode = BlasMode::kInvalid, + NumericTypeID element_epilogue = NumericTypeID::kInvalid, + SplitKMode split_k_mode = SplitKMode::kNone, + ComplexTransform transform_A = ComplexTransform::kNone, + ComplexTransform transform_B = ComplexTransform::kNone + ): + symm_kind(symm_kind), + A(A), + B(B), + C(C), + side_mode(side_mode), + fill_mode(fill_mode), + blas_mode(blas_mode), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A), + transform_B(transform_B) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Description of all Conv2d operations +struct ConvDescription : public OperationDescription { + /// Describes the convolution dimension support (2D or 3D) + int conv_dim; + + /// Describes the kind of convolution + ConvKind conv_kind; + + /// Describes the type of iterator algorithm (analytic or precomputed) + IteratorAlgorithmID iterator_algorithm; + + /// Describes the A operand + TensorDescription A; + + /// Describes the B operand + TensorDescription B; + + /// Describes the C operand + TensorDescription C; + + /// Describes the data type of the scalars passed to the epilogue + NumericTypeID element_epilogue; + + // + // Methods + // + // Returns Activation TensorDescription + TensorDescription activation() const { + switch(conv_kind) { + case library::ConvKind::kFprop : return A; + case library::ConvKind::kDgrad : return C; + case library::ConvKind::kWgrad : return B; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns Filter TensorDescription + TensorDescription filter() const { + switch(conv_kind) { + case library::ConvKind::kFprop : return B; + case library::ConvKind::kDgrad : return B; + case library::ConvKind::kWgrad : return C; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns Output TensorDescription + TensorDescription output() const { + switch(conv_kind) { + case library::ConvKind::kFprop : return C; + case library::ConvKind::kDgrad : return A; + case library::ConvKind::kWgrad : return A; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/handle.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/handle.h new file mode 100644 index 0000000000000000000000000000000000000000..027944eb6ac8c6e8f250d83ed33c0899adfbd3e8 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/handle.h @@ -0,0 +1,365 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief BLAS-like handle used to launch operations on the CUDA device. +*/ + +#pragma once + +#include +#include "cutlass/library/library.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Handle object +class Handle { +private: + + /// Host workspace + static int const kHostWorkspaceSize = (4 << 10); + + /// Provider of operations + Provider provider_; + + /// CUDA device properties + cudaDeviceProp device_; + + /// CUDA stream + cudaStream_t stream_; + + /// Device workspace + void *workspace_; + + /// Size of device workspace in bytes + size_t workspace_size_; + + /// Indicates whether scalars are host or device pointers + ScalarPointerMode scalar_pointer_mode_; + + /// Pointer to the most recently executed operation + Operation const *last_operation_; + + int device_idx_; + +public: + + /// Constructor + Handle(cudaStream_t stream = nullptr, size_t workspace_size = (4<<20)); + + /// Destructor + ~Handle(); + + /// Move constructor + Handle(Handle && handle); + + /// Move assignment operator + Handle &operator=(Handle && handle); + + // + // Persistent state accessors + // + + /// Returns compute capability of the selected device + int compute_capability() const; + + /// Sets the current CUDA stream + void set_stream(cudaStream_t stream); + + /// Gets the current CUDA stream + cudaStream_t get_stream() const; + + /// Gets the current provider + Provider get_provider() const; + + /// Sets the provider of operations + void set_provider(Provider provider); + + /// Gets the device workspace size + size_t get_workspace_size() const; + + /// Gets a pointer to the device workspace allocation in Global Memory + void *get_workspace() const; + + /// Sets the size of device workspace, invalidating calls to get_device_workspace() + void set_workspace_size(size_t bytes); + + /// Gets the scalar pointer mode + ScalarPointerMode get_scalar_pointer_mode() const; + + /// Sets the scalar pointer mode + void set_scalar_pointer_mode(ScalarPointerMode mode); + + /// Gets the most recently executed operation + Operation const *get_last_operation() const; + + // + // Computations + // + + /// Executes a GEMM computation: D <= alpha * A*B + beta * C + Status gemm( + + int M, /// GEMM M dimension + int N, /// GEMM N dimension + int K, /// GEMM K dimension + + NumericTypeID element_compute, /// Data type of internal accumulation + + NumericTypeID element_scalar, /// Data type of alpha/beta scalars + + void const *alpha, /// Pointer to alpha scalar + + NumericTypeID element_A, /// Data type of A matrix elements + LayoutTypeID layout_A, /// Layout of A matrix + ComplexTransform transform_A, /// Complex transformation applied to A matrix - ignored for real-valued matrices + + void const * ptr_A, /// Pointer to A matrix in Global Memory + int64_t lda, /// Leading dimension of A matrix + + NumericTypeID element_B, /// Data type of B matrix elements + LayoutTypeID layout_B, /// Layout of B matrix + ComplexTransform transform_B, /// Complex transformation applied to B matrix - ignored for real-valued matrices + + void const * ptr_B, /// Pointer to B matrix in Global Memory + int64_t ldb, /// Leading dimension of B matrix + + void const * beta, /// Pointer to beta scalar + + NumericTypeID element_C, /// Data type of C and D matrices + + void const * ptr_C, /// Pointer to C matrix + int64_t ldc, /// Leading dimension of C matrix + + void * ptr_D, /// Pointer to D matrix + int64_t ldd /// Leading dimension of D matrix + ); + + /// Executes a GEMM computation: D <= alpha * A*B + beta * C. + // + // Supports batched-strided, batched array or split-K serial or split-K parallel. + // + Status gemm_universal( + + GemmUniversalMode mode, /// indicates the mode in which the kUniversal GEMM is launched + + int M, /// GEMM M dimension + int N, /// GEMM N dimension + int K, /// GEMM K dimension + + int cluster_m, /// cluster shape M dimension + int cluster_n, /// cluster shape N dimension + int cluster_k, /// cluster shape K dimension + int cluster_m_fallback, /// Fallback cluster shape M dimension + int cluster_n_fallback, /// Fallback cluster shape N dimension + int cluster_k_fallback, /// Fallback cluster shape K dimension + + + NumericTypeID element_compute, /// Data type of internal accumulation + + NumericTypeID element_scalar, /// Data type of alpha/beta scalars + + void const *alpha, /// Pointer to alpha scalar + + NumericTypeID element_A, /// Data type of A matrix elements + LayoutTypeID layout_A, /// Layout of A matrix + ComplexTransform transform_A, /// Complex transformation applied to A matrix - ignored for real-valued matrices + void const * ptr_A, /// Pointer to A matrix in Global Memory + int64_t lda, /// Leading dimension of A matrix + + NumericTypeID element_B, /// Data type of B matrix elements + LayoutTypeID layout_B, /// Layout of B matrix + ComplexTransform transform_B, /// Complex transformation applied to B matrix - ignored for real-valued matrices + void const * ptr_B, /// Pointer to B matrix in Global Memory + int64_t ldb, /// Leading dimension of B matrix + + void const * beta, /// Pointer to beta scalar + + NumericTypeID element_C, /// Data type of C matrix + LayoutTypeID layout_C, /// Layout of D matrix + void const * ptr_C, /// Pointer to C matrix + int64_t ldc, /// Leading dimension of C matrix + + NumericTypeID element_D, /// Data type of D matrix + LayoutTypeID layout_D, /// Layout of D matrix + void * ptr_D, /// Pointer to D matrix + int64_t ldd, /// Leading dimension of D matrix + + int batch_count = 1, /// Batch count or number of split-K slices + + int64_t batch_stride_A = 0, /// Batch stride of A operand + int64_t batch_stride_B = 0, /// Batch stride of B operand + int64_t batch_stride_C = 0, /// Batch stride of C operand + int64_t batch_stride_D = 0 /// Batch stride of D operand + ); + + /// Planar complex GEMM + /// + /// Note, all data types are the real-valued base types used by the planar-complex GEMM kernel. + /// + Status gemm_planar_complex( + + int M, /// GEMM M dimension + int N, /// GEMM N dimension + int K, /// GEMM K dimension + + NumericTypeID element_compute, /// Data type of internal accumulation + + NumericTypeID element_scalar, /// Data type of alpha/beta scalars + + void const *alpha, /// Pointer to alpha scalar + + NumericTypeID element_A, /// Data type of A matrix elements + LayoutTypeID layout_A, /// Layout of A matrix + ComplexTransform transform_A, /// Complex transformation applied to A matrix + + void const * ptr_A_real, /// Pointer to real part of A matrix + void const * ptr_A_imag, /// Pointer to imaginary part of A matrix + int64_t lda_real, /// Leading dimension of real part of A matrix + int64_t lda_imag, /// Leading dimension of imaginary part of A matrix + + NumericTypeID element_B, /// Data type of B matrix elements + LayoutTypeID layout_B, /// Layout of B matrix + ComplexTransform transform_B, /// Complex transformation applied to B matrix + + void const * ptr_B_real, /// Pointer to real part of B matrix + void const * ptr_B_imag, /// Pointer to imaginary part of B matrix + int64_t ldb_real, /// Leading dimension of real part of B matrix + int64_t ldb_imag, /// Leading dimension of imaginary part of B matrix + + void const * beta, /// Pointer to beta scalar + + NumericTypeID element_C, /// Data type of C and D matrix + + void const * ptr_C_real, /// Pointer to real part of C matrix + void const * ptr_C_imag, /// Pointer to imaginary part of C matrix + int64_t ldc_real, /// Leading dimension of real part of C matrix + int64_t ldc_imag, /// Leading dimension of imaginary part of C matrix + + void * ptr_D_real, /// Pointer to real part of D matrix + void * ptr_D_imag, /// Pointer to imaginary part of D matrix + int64_t ldd_real, /// Leading dimension of real part of D matrix + int64_t ldd_imag, /// Leading dimension of imaginary part of D matrix + + int batch_count = 1, /// Number of batched GEMMs to execute + + int64_t batch_stride_A_real = 0, + int64_t batch_stride_A_imag = 0, + + int64_t batch_stride_B_real = 0, + int64_t batch_stride_B_imag = 0, + + int64_t batch_stride_C_real = 0, + int64_t batch_stride_C_imag = 0, + + int64_t batch_stride_D_real = 0, + int64_t batch_stride_D_imag = 0 + ); + + /// Planar complex GEMM loading pointers from arrays in global memory + Status gemm_planar_complex_array( + + int expected_M, /// Expected GEMM M dimension (used for sizing CUDA grid) + int expected_N, /// Expected GEMM N dimension (used for sizing CUDA grid) + int expected_K, /// Expected GEMM K dimension + int batch_count, /// Number of independent GEMM computations to execute + + int const *M, /// Array containing the GEMM M dimension for each batch index + int const *N, /// Array containing the GEMM N dimension for each batch index + int const *K, /// Array containing the GEMM K dimension for each batch index + + NumericTypeID element_compute, /// Data type of internal accumulation + + NumericTypeID element_scalar, /// Data type of alpha/beta scalars + + void const *alpha, /// Pointer to alpha scalar + + NumericTypeID element_A, /// Data type of A matrix elements + LayoutTypeID layout_A, /// Layout of A matrix + ComplexTransform transform_A, /// Complex transformation applied to A matrix + + void const * const * ptr_A_real, /// Pointer to array containing pointers to real part of A matrices + void const * const * ptr_A_imag, /// Pointer to array containing pointers to imaginary part of A matrices + + int64_t lda_real, /// Leading dimension of real part of A matrix + int64_t lda_imag, /// Leading dimension of imaginary part of A matrix + + NumericTypeID element_B, /// Data type of B matrix elements + LayoutTypeID layout_B, /// Layout of B matrix + ComplexTransform transform_B, /// Complex transformation applied to B matrix + + void const * const * ptr_B_real, /// Pointer to array containing pointers to real part of B matrices + void const * const * ptr_B_imag, /// Pointer to array containing pointers to imaginary part of B matrices + + int64_t ldb_real, /// Leading dimension of real part of B matrix + int64_t ldb_imag, /// Leading dimension of imaginary part of B matrix + + void const * beta, /// Pointer to beta scalar + + NumericTypeID element_C, /// Data type of C and D matrix + + void const * const * ptr_C_real, /// Pointer to array containing pointers to real part of C matrices + void const * const * ptr_C_imag, /// Pointer to array containing pointers to imaginary part of C matrices + + int64_t ldc_real, /// Leading dimension of real part of C matrix + int64_t ldc_imag, /// Leading dimension of imaginary part of C matrix + + void * const * ptr_D_real, /// Pointer to array containing pointers to real part of D matrices + void * const * ptr_D_imag, /// Pointer to array containing pointers to imaginary part of D matrices + + int64_t ldd_real, /// Leading dimension of real part of D matrix + int64_t ldd_imag /// Leading dimension of imaginary part of D matrix + ); + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Unique pointer storing the handle +using HandlePtr = std::unique_ptr; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Finds conv2d operation instances with Conv2d::ElementC = Reduction::ElementWorkspace +Operation const* find_conv_operation_for_parallel_reduction(Operation const *operation); +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Finds gemm operation instances with ElementC = Reduction::ElementWorkspace +Operation const* find_gemm_operation_for_parallel_reduction(Operation const *operation); +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/library.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/library.h new file mode 100644 index 0000000000000000000000000000000000000000..6764d9a6d81286c8bba0f5184b17819bfae86978 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/library.h @@ -0,0 +1,995 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + + \brief CUTLASS Library is an object-oriented approach to managing operations implemented by CUTLASS. + + Generally, + + description - compile-time constant parameters used to instantiate an operation + + configuration - runtime parameters with computationally expensive initialization + + arguments - runtime parameters that may be passed to an initialized operation with low + computational overhead +*/ + +#ifndef CUTLASS_LIBRARY_LIBRARY_H +#define CUTLASS_LIBRARY_LIBRARY_H + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/library/types.h" +#include "cutlass/library/descriptions.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/blas3.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Mode of Universal GEMM +using GemmUniversalMode = cutlass::gemm::GemmUniversalMode; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Base class for all operations +class Operation { +public: + + virtual ~Operation() { } + + virtual OperationDescription const & description() const = 0; + + virtual Status can_implement( + void const *configuration, + void const *arguments) const = 0; + + virtual uint64_t get_host_workspace_size( + void const *configuration) const = 0; + + virtual uint64_t get_device_workspace_size( + void const *configuration, + void const *arguments = nullptr) const = 0; + + virtual Status initialize( + void const *configuration, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const = 0; + + // Originally designed for metadata, but should be useful for FP8/6/4 too. + virtual Status initialize_with_profiler_workspace( + void const *configuration, + void *host_workspace, + void *device_workspace, + uint8_t **profiler_workspace_ptrs, + int problem_count, + cudaStream_t stream = nullptr) { + return Status::kErrorNotSupported; + } + + virtual Status run( + void const *arguments, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const = 0; + + // Set arguments that should only be set once before verifying or profiling the kernel. + // This should encompass any expensive operations that don't vary from run to run + // (e.g., max_active_clusters). + virtual Status initialize_with_arguments(void* arguments_ptr) const { + return Status::kSuccess; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Configuration for basic GEMM operations +// +// OperationKind: Gemm +// GemmKind: Gemm +// +struct GemmConfiguration { + + /// GEMM problem size + gemm::GemmCoord problem_size{}; + + /// Leading dimension of A matrix + int64_t lda{0}; + + /// Leading dimension of B matrix + int64_t ldb{0}; + + /// Leading dimension of C matrix + int64_t ldc{0}; + + /// Leading dimension of D matrix + int64_t ldd{0}; + + /// Number of partitions of K dimension + int split_k_slices{0}; +}; + +/// Arguments for GEMM +struct GemmArguments { + + /// Pointer to A matrix + void const *A{nullptr}; + + /// Pointer to B matrix + void const *B{nullptr}; + + /// Pointer to C matrix + void const *C{nullptr}; + + /// Pointer to D matrix + void *D{nullptr}; + + /// Host or device pointer to alpha scalar + void const *alpha{nullptr}; + + /// Host or device pointer to beta scalar + void const *beta{nullptr}; + + /// Enumerant indicating whether alpha/beta point to host or device memory + ScalarPointerMode pointer_mode{}; + + /// Whether to use PDL when launching the kernel + bool use_pdl{false}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Configuration for batched GEMM in which multiple matrix products are computed +// +// OperationKind: Gemm +// GemmKind: Batched + +struct GemmBatchedConfiguration { + + /// GEMM problem size + gemm::GemmCoord problem_size{}; + + /// Leading dimension of A matrix + int64_t lda{0}; + + /// Leading dimension of B matrix + int64_t ldb{0}; + + /// Leading dimension of C matrix + int64_t ldc{0}; + + /// Leading dimension of D matrix + int64_t ldd{0}; + + /// Stride between instances of the A matrix in memory + int64_t batch_stride_A{0}; + + /// Stride between instances of the B matrix in memory + int64_t batch_stride_B{0}; + + /// Stride between instances of the C matrix in memory + int64_t batch_stride_C{0}; + + /// Stride between instances of the D matrix in memory + int64_t batch_stride_D{0}; + + /// Number of GEMMs in batch + int batch_count{1}; +}; + +/// Arguments to batched GEMM +using GemmBatchedArguments = GemmArguments; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Configuration for batched GEMM in which multiple matrix products are computed +// +// OperationKind: Gemm +// GemmKind: Array + +struct GemmArrayConfiguration { + + gemm::GemmCoord problem_size{}; + + /// Leading dimension of A matrix + int64_t lda{0}; + + /// Leading dimension of B matrix + int64_t ldb{0}; + + /// Leading dimension of C matrix + int64_t ldc{0}; + + /// Leading dimension of D matrix + int64_t ldd{0}; + + int batch_count{1}; +}; + +/// Arguments for GEMM - used by all the GEMM operations +struct GemmArrayArguments { + void const * const *A{nullptr}; + void const * const *B{nullptr}; + void const * const *C{nullptr}; + void * const *D{nullptr}; + void const *alpha{nullptr}; + void const *beta{nullptr}; + ScalarPointerMode pointer_mode{}; + bool use_pdl{false}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Universal GEMM supporting multiple split-K modes, multiple batched modes, real and complex +// +// OperationKind: Gemm +// GemmKind: Universal + +struct GemmUniversalConfiguration { + + GemmUniversalMode mode{GemmUniversalMode::kGemm}; + gemm::GemmCoord problem_size{}; + gemm::GemmCoord cluster_shape{}; + gemm::GemmCoord cluster_shape_fallback{}; + int batch_count{1}; + + int64_t lda{0}; + int64_t ldb{0}; + int64_t ldc{0}; + int64_t ldd{0}; + + int device_count{1}; +}; + +enum class Sm90MixedInputWiderOperand { + A = 0, + B = 1 +}; + +struct GemmUniversalArguments { + // NOTE: these are replicated for 3.0 interfaces + gemm::GemmCoord problem_size{}; + gemm::GemmCoord cluster_shape{}; + gemm::GemmCoord cluster_shape_fallback{}; + int batch_count{1}; + + void const *A{nullptr}; + void const *B{nullptr}; + void const *C{nullptr}; + void *D{nullptr}; + + void const *alpha{nullptr}; + void const *beta{nullptr}; + ScalarPointerMode pointer_mode{}; + + // NOTE: these are replicated for 3.0 interfaces + int64_t lda{0}; + int64_t ldb{0}; + int64_t ldc{0}; + int64_t ldd{0}; + + int64_t batch_stride_A{0}; + int64_t batch_stride_B{0}; + int64_t batch_stride_C{0}; + int64_t batch_stride_D{0}; + + // Needed for some 3.x kernels + int sm_count{0}; + library::RasterOrder raster_order{}; + library::RuntimeDatatype runtime_input_datatype_a{}; + library::RuntimeDatatype runtime_input_datatype_b{}; + int swizzle_size{1}; + int split_k_slices{1}; + + // For SM90 mixed input dtype kernels + bool is_sm90_mixed_dtype{false}; + Sm90MixedInputWiderOperand wider_operand{Sm90MixedInputWiderOperand::B}; + bool generate_scale_and_zero{false}; + bool generate_dequantized_AB{false}; + void *Scale{nullptr}; // Scale tensor + void *Zero{nullptr}; // Zero tensor + void *dequantized_AB{nullptr}; // Dequantized A or B tensor for verification + void *encoded_AB{nullptr}; // Encoded A or B in int4 x fp8 or shuffle + void *packed_Scale{nullptr}; // Packed scale for int4 * fp8 + + int device_index{0}; + + bool use_pdl{false}; +}; + +/// Block Scaled GEMM +// +// OperationKind: kBlockScaledGemm +// GemmKind: Universal + +struct BlockScaledGemmArguments { + // NOTE: these are replicated for 3.0 interfaces + gemm::GemmCoord problem_size{}; + gemm::GemmCoord cluster_shape{}; + gemm::GemmCoord cluster_shape_fallback{}; + int batch_count{1}; + + void const *A{nullptr}; + void const *B{nullptr}; + void const *SFA{nullptr}; + void const *SFB{nullptr}; + void const *C{nullptr}; + void *D{nullptr}; + void *SFD{nullptr}; + + void const *alpha{nullptr}; + void const *beta{nullptr}; + ScalarPointerMode pointer_mode{}; + + // NOTE: these are replicated for 3.0 interfaces + int64_t lda{0}; + int64_t ldb{0}; + int64_t ldc{0}; + int64_t ldd{0}; + + int64_t batch_stride_A{0}; + int64_t batch_stride_B{0}; + int64_t batch_stride_C{0}; + int64_t batch_stride_D{0}; + + // Needed for ScaleFactor Generation + void const *norm_constant{nullptr}; + + // Needed for some 3.x kernels + int sm_count{0}; + library::RasterOrder raster_order{}; + int swizzle_size{1}; + int split_k_slices{1}; + + library::RuntimeDatatype runtime_input_datatype_a{library::RuntimeDatatype::kStatic}; + library::RuntimeDatatype runtime_input_datatype_b{library::RuntimeDatatype::kStatic}; + + bool use_pdl{false}; +}; + +/// Blockwise GEMM +// +// OperationKind: kBlockwiseGemm +// GemmKind: Universal + +struct BlockwiseGemmArguments { + // NOTE: these are replicated for 3.0 interfaces + gemm::GemmCoord problem_size{}; + gemm::GemmCoord cluster_shape{}; + gemm::GemmCoord cluster_shape_fallback{}; + int batch_count{1}; + + void const *A{nullptr}; + void const *B{nullptr}; + void const *SFA{nullptr}; + void const *SFB{nullptr}; + void const *C{nullptr}; + void *D{nullptr}; + + void const *alpha{nullptr}; + void const *beta{nullptr}; + ScalarPointerMode pointer_mode{}; + + // NOTE: these are replicated for 3.0 interfaces + int64_t lda{0}; + int64_t ldb{0}; + int64_t ldc{0}; + int64_t ldd{0}; + + int64_t batch_stride_A{0}; + int64_t batch_stride_B{0}; + int64_t batch_stride_C{0}; + int64_t batch_stride_D{0}; + + int sf_m_vec_size{0}; + int sf_n_vec_size{0}; + int sf_k_vec_size{0}; + + // Needed for some 3.x kernels + int sm_count{0}; + library::RasterOrder raster_order{}; + int swizzle_size{1}; + int split_k_slices{1}; + + library::RuntimeDatatype runtime_input_datatype_a{library::RuntimeDatatype::kStatic}; + library::RuntimeDatatype runtime_input_datatype_b{library::RuntimeDatatype::kStatic}; + + bool use_pdl{false}; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Complex valued GEMM in which real and imaginary parts are separated by a stride +// +// OperationKind: Gemm +// GemmKind: Planar complex + +struct GemmPlanarComplexConfiguration { + + GemmUniversalMode mode{GemmUniversalMode::kGemm}; + gemm::GemmCoord problem_size{}; + int batch_count{1}; + int64_t lda_real{0}; + int64_t lda_imag{0}; + int64_t ldb_real{0}; + int64_t ldb_imag{0}; + int64_t ldc_real{0}; + int64_t ldc_imag{0}; + int64_t ldd_real{0}; + int64_t ldd_imag{0}; +}; + +/// Arguments for planar complex GEMMs +struct GemmPlanarComplexArguments { + + void const *A_real{nullptr}; + void const *A_imag{nullptr}; + void const *B_real{nullptr}; + void const *B_imag{nullptr}; + void const *C_real{nullptr}; + void const *C_imag{nullptr}; + void *D_real{nullptr}; + void *D_imag{nullptr}; + void const *alpha{nullptr}; + void const *beta{nullptr}; + ScalarPointerMode pointer_mode{}; + + int64_t batch_stride_A_real{0}; + int64_t batch_stride_A_imag{0}; + int64_t batch_stride_B_real{0}; + int64_t batch_stride_B_imag{0}; + int64_t batch_stride_C_real{0}; + int64_t batch_stride_C_imag{0}; + int64_t batch_stride_D_real{0}; + int64_t batch_stride_D_imag{0}; + bool use_pdl{false}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// This is a special form of planar complex which loads pointers and problem size +/// from memory. +struct GemmPlanarComplexArrayConfiguration { + + gemm::GemmCoord problem_size{}; + int batch_count{1}; + + int64_t lda_real{0}; + int64_t lda_imag{0}; + int64_t ldb_real{0}; + int64_t ldb_imag{0}; + int64_t ldc_real{0}; + int64_t ldc_imag{0}; + int64_t ldd_real{0}; + int64_t ldd_imag{0}; +}; + +/// Arguments for planar complex GEMMs +struct GemmPlanarComplexArrayArguments { + + int const *M{nullptr}; + int const *N{nullptr}; + int const *K{nullptr}; + + void const * const * A_real{nullptr}; + void const * const * A_imag{nullptr}; + void const * const * B_real{nullptr}; + void const * const * B_imag{nullptr}; + void const * const * C_real{nullptr}; + void const * const * C_imag{nullptr}; + void * const * D_real{nullptr}; + void * const * D_imag{nullptr}; + + void const * alpha{nullptr}; + void const * beta{nullptr}; + ScalarPointerMode pointer_mode{}; + bool use_pdl{false}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Grouped GEMM supporting +// +// OperationKind: Gemm +// GemmKind: Grouped + +struct GemmGroupedConfiguration { + int problem_count{0}; + // GemmGroupedConfiguration is passed to initialize(), which + // is responsible for allocating the device-side stride storage. + int64_t* lda; + int64_t* ldb; + int64_t* ldc; + + cute::Shape* problem_sizes_3x_host; +}; + +struct GemmGroupedArguments { + int problem_count{}; + gemm::GemmCoord* problem_sizes{nullptr}; + + void* ptr_A{nullptr}; + void* ptr_B{nullptr}; + void* ptr_C{nullptr}; + void* ptr_D{nullptr}; + + int64_t* lda{nullptr}; + int64_t* ldb{nullptr}; + int64_t* ldc{nullptr}; + int64_t* ldd{nullptr}; + + void const *alpha{nullptr}; + void const *beta{nullptr}; + ScalarPointerMode pointer_mode{}; + bool use_pdl{false}; + + gemm::GemmCoord cluster_shape{}; + gemm::GemmCoord cluster_shape_fallback{}; + + library::RasterOrder raster_order{}; + library::RuntimeDatatype runtime_input_datatype_a{library::RuntimeDatatype::kStatic}; + library::RuntimeDatatype runtime_input_datatype_b{library::RuntimeDatatype::kStatic}; + int swizzle_size{1}; + + // these should really be in the configuration but staying consistent with GEMM + int sm_count{0}; + int max_active_clusters{0}; + + // The user is responsible for allocating storage for problem sizes. + // Since GemmGroupedArguments is used by both the 2.x and 3.x APIs, we + // unfortunately need to have both options in this struct, and the + // underlying operation uses the one it needs. + cute::Shape* problem_sizes_3x; + cute::Shape* problem_sizes_3x_host; +}; + +struct GroupedGemmBlockScaledArguments : GemmGroupedArguments { + void* SFA{nullptr}; + void* SFB{nullptr}; + void* SFD{nullptr}; + void* norm_constant{nullptr}; +}; + +struct GroupedGemmBlockwiseArguments : GemmGroupedArguments { + void* SFA{nullptr}; + void* SFB{nullptr}; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// OperationKind: kSparseGemm +// + +/// Computes GEMM assuming one of the inputs has 2:4 structured sparsity. +struct SparseGemmConfiguration { + + GemmUniversalMode mode{GemmUniversalMode::kGemm}; + gemm::GemmCoord problem_size{}; + int batch_count{1}; /// number of sparse matrix products in batch + int64_t lda{0}; /// leading dimension of A operand + int64_t ldb{0}; /// leading dimension of B operand + int64_t ldc{0}; /// leading dimension of C operand + int64_t ldd{0}; /// leading dimension of D operand + int64_t lde{0}; /// leading dimension of E operand (metadata matrix) + int64_t batch_stride_A{0}; // stride between matrices + int64_t batch_stride_B{0}; // stride between matrices + int64_t batch_stride_C{0}; // stride between matrices + int64_t batch_stride_D{0}; // stride between matrices + int64_t batch_stride_E{0}; // stride between matrices +}; + +/// Arguments for sparse GEMMs +struct SparseGemmArguments { + void const *A{nullptr}; /// pointer to A matrix + void const *B{nullptr}; /// pointer to B matrix + void const *C{nullptr}; /// pointer to C matrix + void *D{nullptr}; /// pointer to D matrix + void const *E{nullptr}; /// pointer to E matrix (metadata) + void const *alpha{nullptr}; /// pointer to alpha scalar + void const *beta{nullptr}; /// pointer to beta scalar + ScalarPointerMode pointer_mode{}; /// enumerant indicating whether alpha/beta pointers are host + /// or device pointers. + bool use_pdl{false}; /// Whether to use PDL when launching the kernel +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Configuration for basic Rank K update operations +// +// OperationKind: (Syrk, Herk, Syr2k, Her2k) +// RankKKind: Universal +// +struct RankKConfiguration { + + /// SYRK problem size + gemm::GemmCoord problem_size{}; + + /// Leading dimension of A matrix + int64_t lda{0}; + + /// Leading dimension of B matrix + int64_t ldb{0}; + + /// Leading dimension of C matrix + int64_t ldc{0}; + + /// Leading dimension of D matrix + int64_t ldd{0}; + + /// Batch Count + int batch_count{1}; +}; + +/// Arguments for (Syrk, Herk, Syr2k, Her2k) +struct RankKArguments { + + /// Pointer to A matrix + void const *A{nullptr}; + + /// Pointer to B matrix (used only for Syr2k and Her2k) + void const *B{nullptr}; + + /// Pointer to C matrix + void const *C{nullptr}; + + /// Pointer to D matrix + void *D{nullptr}; + + /// Host or device pointer to alpha scalar + void const *alpha{nullptr}; + + /// Host or device pointer to beta scalar + void const *beta{nullptr}; + + /// Enumerant indicating whether alpha/beta point to host or device memory + ScalarPointerMode pointer_mode{}; + + int64_t batch_stride_A{0}; + int64_t batch_stride_B{0}; + int64_t batch_stride_C{0}; + int64_t batch_stride_D{0}; + bool use_pdl{false}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Configuration for basic TRMM operations +// +// OperationKind: Trmm +// TrmmKind: Universal +// +struct TrmmConfiguration { + + /// TRMM problem size + gemm::GemmCoord problem_size{}; + + /// Leading dimension of A matrix + int64_t lda{0}; + + /// Leading dimension of B matrix + int64_t ldb{0}; + + /// Leading dimension of D matrix + int64_t ldd{0}; + + /// Batch Count + int batch_count{1}; +}; + +/// Arguments for TRMM +struct TrmmArguments { + + /// Pointer to A matrix + void const *A{nullptr}; + + /// Pointer to B matrix + void const *B{nullptr}; + + /// Pointer to D matrix + void *D{nullptr}; + + /// Host or device pointer to alpha scalar + void const *alpha{nullptr}; + + /// Host or device pointer to beta scalar + void const *beta{nullptr}; + + /// Enumerant indicating whether alpha/beta point to host or device memory + ScalarPointerMode pointer_mode{}; + + int64_t batch_stride_A{0}; + int64_t batch_stride_B{0}; + int64_t batch_stride_D{0}; + bool use_pdl{false}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Configuration for basic SYMM/HEMM update operations +// +// OperationKind: (Symm, Hemm) +// SymmKind: Universal +// +struct SymmConfiguration { + + /// SYMM/HEMM problem size + gemm::GemmCoord problem_size{}; + + /// Leading dimension of A matrix + int64_t lda{0}; + + /// Leading dimension of B matrix + int64_t ldb{0}; + + /// Leading dimension of C matrix + int64_t ldc{0}; + + /// Leading dimension of D matrix + int64_t ldd{0}; + + /// Batch Count + int batch_count{1}; +}; + +/// Arguments for (Symm, Hemm) +struct SymmArguments { + + /// Pointer to A matrix + void const *A{nullptr}; + + /// Pointer to B matrix + void const *B{nullptr}; + + /// Pointer to C matrix + void const *C{nullptr}; + + /// Pointer to D matrix + void *D{nullptr}; + + /// Host or device pointer to alpha scalar + void const *alpha{nullptr}; + + /// Host or device pointer to beta scalar + void const *beta{nullptr}; + + /// Enumerant indicating whether alpha/beta point to host or device memory + ScalarPointerMode pointer_mode{}; + + int64_t batch_stride_A{0}; + int64_t batch_stride_B{0}; + int64_t batch_stride_C{0}; + int64_t batch_stride_D{0}; + bool use_pdl{false}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Two dimensional convolution +// +// OperationKind: Conv2d +// +struct Conv2dConfiguration { + + conv::SplitKMode split_k_mode; + + /// Conv2d problem size + // contains strictly conv2d size (N,H,W,C,K,R,S,P,Q,padding,stride,dilation,mode) + // also includes (split_k_slices, groups) + conv::Conv2dProblemSize problem_size{}; + + // stride of operand A + std::vector stride_a{}; + + // stride of operand B + std::vector stride_b{}; + + // stride of operand C + std::vector stride_c{}; +}; + + +/// Three dimensional convolution +// +// OperationKind: Conv3d +// +struct Conv3dConfiguration { + + conv::SplitKMode split_k_mode{}; + + /// Conv2d problem size + // contains strictly conv2d size (N,D,H,W,C,K,T,R,S,Z,P,Q,padding,stride,dilation,mode) + // also includes (split_k_slices, groups) + conv::Conv3dProblemSize problem_size{}; + + /// Layout object for activations tensor + layout::TensorNDHWC layout_activations{}; + + /// Layout object for filters tensor + layout::TensorNDHWC layout_filters{}; + + /// Layout object for source tensor + layout::TensorNDHWC layout_source{}; + + /// Layout object for output tensor + layout::TensorNDHWC layout_output{}; + + // + // Methods + // + + // Mapping functions (A,B,C -> activation,filter,output) + layout::TensorNDHWC layout_a(library::ConvKind const &conv_kind) const { + switch (conv_kind) { + case library::ConvKind::kFprop: return layout_activations; + case library::ConvKind::kDgrad: return layout_output; + case library::ConvKind::kWgrad: return layout_output; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + layout::TensorNDHWC layout_b(library::ConvKind const &conv_kind) const { + switch (conv_kind) { + case library::ConvKind::kFprop: return layout_filters; + case library::ConvKind::kDgrad: return layout_filters; + case library::ConvKind::kWgrad: return layout_activations; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + layout::TensorNDHWC layout_c(library::ConvKind const &conv_kind) const { + switch (conv_kind) { + case library::ConvKind::kFprop: return layout_output; + case library::ConvKind::kDgrad: return layout_activations; + case library::ConvKind::kWgrad: return layout_filters; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } +}; + +/// Arguments for CONV +struct ConvArguments { + + ///////////////////////////////////////////////////////// + /// ImplicitGemm matrices A, B, C, D + ///////////////////////////////////////////////////////// + /// pointer to implicit gemm matrix A + void const *A{nullptr}; + + /// pointer to implicit gemm matrix B + void const *B{nullptr}; + + /// pointer to reordered matrix B + void const *reordered_B{nullptr}; + + /// pointer to implicit gemm matrix C + void const *C{nullptr}; + + /// pointer to implicit gemm destination matrix D + void *D{nullptr}; + + /// Host or device pointer to alpha scalar + void const *alpha{nullptr}; + + /// Host or device pointer to beta scalar + void const *beta{nullptr}; + + /// Enumerant indicating whether alpha/beta point to host or device memory + ScalarPointerMode pointer_mode{}; + + /// Whether to use PDL when launching the kernel + bool use_pdl{false}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Configuration for Reduction operations +// +// OperationKind: Reduction +// +struct ReductionConfiguration { + + /// Reduction problem size + MatrixCoord problem_size{}; + + /// Number of partitions to reduce + int partitions{0}; + + /// Number of elements between each partition + int64_t partition_stride{0}; + + /// leading dimension of 'w'orkspace operand + int64_t ldw{0}; + + /// leading dimension of 's'ource operand + int64_t lds{0}; + + /// leading dimension of 'd'estination operand + int64_t ldd{0}; +}; + +/// Arguments for Reduction +struct ReductionArguments { + + /// Pointer to workspace matrix + void const *workspace{nullptr}; + + /// Pointer to source matrix + void const *source{nullptr}; + + /// Pointer to destination matrix + void *destination{nullptr}; + + /// pointer to reference matrix + void *reference{nullptr}; + + /// Host or device pointer to alpha scalar + void const *alpha{nullptr}; + + /// Host or device pointer to beta scalar + void const *beta{nullptr}; + + /// Enumerant indicating whether alpha/beta point to host or device memory + ScalarPointerMode pointer_mode{}; + + /// Whether to use PDL when launching the kernel + bool use_pdl{false}; +}; + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/manifest.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/manifest.h new file mode 100644 index 0000000000000000000000000000000000000000..c4fb0ee8ca32124450b1063cc3613078e600479d --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/manifest.h @@ -0,0 +1,114 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Manifest of CUTLASS Library + + This is the root of the data structure containing CUTLASS objects +*/ + +#pragma once + +#include +#include +#include + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#include "library.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// Forward declaration +class Manifest; + +// init and insert all cutlass gemm operations in manifest object (procedurally generated using generator.py) +void initialize_all(Manifest &manifest); + +// init and insert all reduction op in manifest object (manually instantiated in library/reduction) +void initialize_all_reduction_op(Manifest &manifest); + +///////////////////////////////////////////////////////////////////////////////////////////////////////// + +/// List of operations +using OperationVector = std::vector>; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Manifest of CUTLASS Library +class Manifest { +private: + + /// Operation provider + Provider provider_; + + /// Global list of operations + OperationVector operations_; + +public: + Manifest (Provider provider = library::Provider::kCUTLASS) : provider_(provider) { } + + /// Top-level initialization + Status initialize(); + + /// Used for initialization + void reserve(size_t operation_count); + + /// Graceful shutdown + Status release(); + + /// Appends an operation and takes ownership + void append(Operation *operation_ptr) {\ + // This function is inline s.t. it is present in generated libraries + // without having to compile or link in manifest.cpp + operations_.emplace_back(operation_ptr); + } + + /// Returns an iterator to the first operation + OperationVector const &operations() const; + + /// Returns a const iterator + OperationVector::const_iterator begin() const; + + /// Returns a const iterator + OperationVector::const_iterator end() const; +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/operation_table.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/operation_table.h new file mode 100644 index 0000000000000000000000000000000000000000..f36232c8dc833e2b24d681686f6662e79b7ecd0a --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/operation_table.h @@ -0,0 +1,905 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* + \file + \brief Defines a data structure in which a set of functionally equivalent library::Operation + instances may be queried. +*/ + +#pragma once +#include +#include +#include +#include + +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/util.h" +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Data Structures for Gemm Functional Maps +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tuple uniquely identifying Gemm functional behavior +struct GemmFunctionalKey { + + Provider provider; + GemmKind gemm_kind; + NumericTypeID element_compute; + NumericTypeID element_scalar; + NumericTypeID element_A; + LayoutTypeID layout_A; + ComplexTransform transform_A; + NumericTypeID element_B; + LayoutTypeID layout_B; + ComplexTransform transform_B; + NumericTypeID element_C; + LayoutTypeID layout_C; + NumericTypeID element_D; + LayoutTypeID layout_D; + + // + // Methods + // + + inline + GemmFunctionalKey( + Provider provider, + GemmKind gemm_kind = GemmKind::kGemm, + NumericTypeID element_compute = NumericTypeID::kF32, + NumericTypeID element_scalar = NumericTypeID::kF32, + NumericTypeID element_A = NumericTypeID::kF16, + LayoutTypeID layout_A = LayoutTypeID::kColumnMajor, + ComplexTransform transform_A = ComplexTransform::kNone, + NumericTypeID element_B = NumericTypeID::kF16, + LayoutTypeID layout_B = LayoutTypeID::kColumnMajor, + ComplexTransform transform_B = ComplexTransform::kNone, + NumericTypeID element_C = NumericTypeID::kF16, + LayoutTypeID layout_C = LayoutTypeID::kColumnMajor, + NumericTypeID element_D = NumericTypeID::kF16, + LayoutTypeID layout_D = LayoutTypeID::kColumnMajor + ): + provider(provider), + gemm_kind(gemm_kind), + element_compute(element_compute), + element_scalar(element_scalar), + element_A(element_A), + layout_A(layout_A), + transform_A(transform_A), + element_B(element_B), + layout_B(layout_B), + transform_B(transform_B), + element_C(element_C), + layout_C(layout_C), + element_D(element_D), + layout_D(layout_D) + { } + + inline + bool operator==(GemmFunctionalKey const &rhs) const { + return + (provider == rhs.provider) && + (gemm_kind == rhs.gemm_kind) && + (element_compute == rhs.element_compute) && + (element_scalar == rhs.element_scalar) && + (element_A == rhs.element_A) && + (layout_A == rhs.layout_A) && + (transform_A == rhs.transform_A) && + (element_B == rhs.element_B) && + (layout_B == rhs.layout_B) && + (transform_B == rhs.transform_B) && + (element_C == rhs.element_C) && + (layout_C == rhs.layout_C) && + (element_D == rhs.element_D) && + (layout_D == rhs.layout_D); + } + + inline + bool operator!=(GemmFunctionalKey const &rhs) const { + return !(*this == rhs); + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +inline +std::ostream & operator<<(std::ostream &out, cutlass::library::GemmFunctionalKey const &k) { + + out << "{\n" + << " provider: " << to_string(k.provider) << "\n" + << " gemm_kind: " << to_string(k.gemm_kind) << "\n" + << " element_compute: " << to_string(k.element_compute) << "\n" + << " element_scalar: " << to_string(k.element_scalar) << "\n" + << " element_A: " << to_string(k.element_A) << "\n" + << " layout_A: " << to_string(k.layout_A) << "\n" + << " transform_A: " << to_string(k.transform_A) << "\n" + << " element_B: " << to_string(k.element_B) << "\n" + << " layout_B: " << to_string(k.layout_B) << "\n" + << " transform_B: " << to_string(k.transform_B) << "\n" + << " element_C: " << to_string(k.element_C) << "\n" + << " layout_C: " << to_string(k.layout_C) << "\n" + << " element_D: " << to_string(k.element_D) << "\n" + << " layout_D: " << to_string(k.layout_D) << "\n" + << "}"; + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Hash function for GemmFunctionalKey +struct GemmFunctionalKeyHasher { + using IntHash = std::hash; + + inline + static size_t rotl(size_t key, int shl) { + return (key << shl) | (key >> (sizeof(key)*8u - static_cast(shl))); + } + + inline + size_t operator()(GemmFunctionalKey const &key) const { + IntHash hash; + + return + rotl(hash(int(key.provider)), 1) ^ + rotl(hash(int(key.gemm_kind)), 2) ^ + rotl(hash(int(key.element_compute)), 3) ^ + rotl(hash(int(key.element_scalar)), 4) ^ + rotl(hash(int(key.element_A)), 5) ^ + rotl(hash(int(key.layout_A)), 6) ^ + rotl(hash(int(key.transform_A)), 7) ^ + rotl(hash(int(key.element_B)), 8) ^ + rotl(hash(int(key.layout_B)), 9) ^ + rotl(hash(int(key.transform_B)), 10) ^ + rotl(hash(int(key.element_C)), 11) ^ + rotl(hash(int(key.layout_C)), 12) ^ + rotl(hash(int(key.element_D)), 13) ^ + rotl(hash(int(key.layout_D)), 14); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Establishes a partial ordering to search for GEMM operators +struct GemmPreferenceKey { + + int compute_capability; + int alignment; + + // + // Methods + // + + GemmPreferenceKey(): compute_capability(), alignment() { } + + GemmPreferenceKey(int cc, int alignment): compute_capability(cc), alignment(alignment) { } + + bool operator<(GemmPreferenceKey const &rhs) const { + return (compute_capability < rhs.compute_capability) || + ((compute_capability == rhs.compute_capability) && (alignment < rhs.alignment)); + } + + bool operator==(GemmPreferenceKey const &rhs) const { + return compute_capability == rhs.compute_capability; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +inline +std::ostream& operator<< (std::ostream& out, const cutlass::library::GemmPreferenceKey& key) { + out << "{\n" + << "compute_capability : " << key.compute_capability << std::endl + << "alignment : " << key.alignment << std::endl + << "}"; + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Maps minimum compute capability onto a vector of possible operations +using GemmOperationVectorMap = std::map< + GemmPreferenceKey, + std::vector +>; + +/// Maps a GemmFunctionalKey onto a vector of Operation * objects expected to be of kind kGemm +using GemmOperationFunctionalMap = std::unordered_map< + GemmFunctionalKey, + GemmOperationVectorMap, + GemmFunctionalKeyHasher +>; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Data Structures for BlockScaled Gemm Functional Maps +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tuple uniquely identifying Gemm functional behavior +struct BlockScaledGemmFunctionalKey { + + Provider provider; + GemmKind gemm_kind; + OperationKind kind; + NumericTypeID element_compute; + NumericTypeID element_scalar; + NumericTypeID element_A; + LayoutTypeID layout_A; + NumericTypeID element_SFA; + NumericTypeID element_B; + LayoutTypeID layout_B; + NumericTypeID element_SFB; + NumericTypeID element_C; + LayoutTypeID layout_C; + NumericTypeID element_D; + LayoutTypeID layout_D; + NumericTypeID element_SFD; + LayoutTypeID layout_SFD; + int SFVecSize; + int EpilogueSFVecSize; + // + // Methods + // + + inline + BlockScaledGemmFunctionalKey( + Provider provider, + GemmKind gemm_kind = GemmKind::kGemm, + OperationKind kind = OperationKind::kBlockScaledGemm, + NumericTypeID element_compute = NumericTypeID::kF32, + NumericTypeID element_scalar = NumericTypeID::kF32, + NumericTypeID element_A = NumericTypeID::kF16, + LayoutTypeID layout_A = LayoutTypeID::kColumnMajor, + NumericTypeID element_SFA = NumericTypeID::kF16, + NumericTypeID element_B = NumericTypeID::kF16, + LayoutTypeID layout_B = LayoutTypeID::kColumnMajor, + NumericTypeID element_SFB = NumericTypeID::kF16, + NumericTypeID element_C = NumericTypeID::kF16, + LayoutTypeID layout_C = LayoutTypeID::kColumnMajor, + NumericTypeID element_D = NumericTypeID::kF16, + LayoutTypeID layout_D = LayoutTypeID::kColumnMajor, + NumericTypeID element_SFD = NumericTypeID::kF16, + LayoutTypeID layout_SFD = LayoutTypeID::kRowMajor, + int sf_vec_size = 32 + , int epilogue_sf_vec_size = 32 + ): + provider(provider), + gemm_kind(gemm_kind), + kind(kind), + element_compute(element_compute), + element_scalar(element_scalar), + element_A(element_A), + layout_A(layout_A), + element_SFA(element_SFA), + element_B(element_B), + layout_B(layout_B), + element_SFB(element_SFB), + element_C(element_C), + layout_C(layout_C), + element_D(element_D), + layout_D(layout_D), + element_SFD(element_SFD), + layout_SFD(layout_SFD), + SFVecSize(sf_vec_size) + , EpilogueSFVecSize(epilogue_sf_vec_size) + { } + + inline + bool operator==(BlockScaledGemmFunctionalKey const &rhs) const { + return + (provider == rhs.provider) && + (gemm_kind == rhs.gemm_kind) && + (kind == rhs.kind) && + (element_compute == rhs.element_compute) && + (element_scalar == rhs.element_scalar) && + (element_A == rhs.element_A) && + (layout_A == rhs.layout_A) && + (element_SFA == rhs.element_SFA) && + (element_B == rhs.element_B) && + (layout_B == rhs.layout_B) && + (element_SFB == rhs.element_SFB) && + (element_C == rhs.element_C) && + (layout_C == rhs.layout_C) && + (element_D == rhs.element_D) && + (layout_D == rhs.layout_D) && + (element_SFD == rhs.element_SFD) && + (layout_SFD == rhs.layout_SFD) && + (SFVecSize == rhs.SFVecSize) + && (EpilogueSFVecSize == rhs.EpilogueSFVecSize) + ; + } + + inline + bool operator!=(BlockScaledGemmFunctionalKey const &rhs) const { + return !(*this == rhs); + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +inline +std::ostream & operator<<(std::ostream &out, cutlass::library::BlockScaledGemmFunctionalKey const &k) { + + out << "{\n" + << " provider: " << to_string(k.provider) << "\n" + << " gemm_kind: " << to_string(k.gemm_kind) << "\n" + << " kind: " << to_string(k.kind) << "\n" + << " element_compute: " << to_string(k.element_compute) << "\n" + << " element_scalar: " << to_string(k.element_scalar) << "\n" + << " element_A: " << to_string(k.element_A) << "\n" + << " layout_A: " << to_string(k.layout_A) << "\n" + << " element_SFA: " << to_string(k.element_SFA) << "\n" + << " element_B: " << to_string(k.element_B) << "\n" + << " layout_B: " << to_string(k.layout_B) << "\n" + << " element_SFB: " << to_string(k.element_SFB) << "\n" + << " element_C: " << to_string(k.element_C) << "\n" + << " layout_C: " << to_string(k.layout_C) << "\n" + << " element_D: " << to_string(k.element_D) << "\n" + << " layout_D: " << to_string(k.layout_D) << "\n" + << " element_SFD: " << to_string(k.element_SFD) << "\n" + << " layout_SFD: " << to_string(k.layout_SFD) << "\n" + << " SFVecSize: " << k.SFVecSize << "\n" + << "EpilogueSFVecSize: " << k.EpilogueSFVecSize << "\n" + << "}"; + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Hash function for BlockScaledGemmFunctionalKeyHasher +struct BlockScaledGemmFunctionalKeyHasher { + using IntHash = std::hash; + + inline + static size_t rotl(size_t key, int shl) { + return (key << shl) | (key >> (sizeof(key)*8u - static_cast(shl))); + } + + inline + size_t operator()(BlockScaledGemmFunctionalKey const &key) const { + IntHash hash; + + return + rotl(hash(int(key.provider)), 1) ^ + rotl(hash(int(key.gemm_kind)), 2) ^ + rotl(hash(int(key.kind)), 3) ^ + rotl(hash(int(key.element_compute)), 4) ^ + rotl(hash(int(key.element_scalar)), 5) ^ + rotl(hash(int(key.element_A)), 6) ^ + rotl(hash(int(key.layout_A)), 7) ^ + rotl(hash(int(key.element_SFA)), 8) ^ + rotl(hash(int(key.element_B)), 9) ^ + rotl(hash(int(key.layout_B)), 10) ^ + rotl(hash(int(key.element_SFB)), 11) ^ + rotl(hash(int(key.element_C)), 12) ^ + rotl(hash(int(key.layout_C)), 13) ^ + rotl(hash(int(key.element_D)), 14) ^ + rotl(hash(int(key.layout_D)), 15) ^ + rotl(hash(int(key.element_SFD)), 16) ^ + rotl(hash(int(key.layout_SFD)), 17) ^ + rotl(hash(int(key.SFVecSize)), 18) ^ + rotl(hash(int(key.EpilogueSFVecSize)), 19) + ; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Maps a GemmFunctionalKey onto a vector of Operation * objects expected to be of kind kGemm +using BlockScaledGemmOperationFunctionalMap = std::unordered_map< + BlockScaledGemmFunctionalKey, + GemmOperationVectorMap, + BlockScaledGemmFunctionalKeyHasher +>; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Data Structures for Blockwise Gemm Functional Maps +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tuple uniquely identifying Gemm functional behavior +struct BlockwiseGemmFunctionalKey { + + Provider provider; + GemmKind gemm_kind; + OperationKind kind; + NumericTypeID element_compute; + NumericTypeID element_scalar; + NumericTypeID element_A; + LayoutTypeID layout_A; + NumericTypeID element_SFA; + NumericTypeID element_B; + LayoutTypeID layout_B; + NumericTypeID element_SFB; + NumericTypeID element_C; + LayoutTypeID layout_C; + NumericTypeID element_D; + LayoutTypeID layout_D; + int SFMVecSize; + int SFNVecSize; + int SFKVecSize; + // + // Methods + // + + inline + BlockwiseGemmFunctionalKey( + Provider provider, + GemmKind gemm_kind = GemmKind::kGemm, + OperationKind kind = OperationKind::kBlockwiseGemm, + NumericTypeID element_compute = NumericTypeID::kF32, + NumericTypeID element_scalar = NumericTypeID::kF32, + NumericTypeID element_A = NumericTypeID::kF16, + LayoutTypeID layout_A = LayoutTypeID::kColumnMajor, + NumericTypeID element_SFA = NumericTypeID::kF16, + NumericTypeID element_B = NumericTypeID::kF16, + LayoutTypeID layout_B = LayoutTypeID::kColumnMajor, + NumericTypeID element_SFB = NumericTypeID::kF16, + NumericTypeID element_C = NumericTypeID::kF16, + LayoutTypeID layout_C = LayoutTypeID::kColumnMajor, + NumericTypeID element_D = NumericTypeID::kF16, + LayoutTypeID layout_D = LayoutTypeID::kColumnMajor, + int sfm_vec_size = 32, + int sfn_vec_size = 32, + int sfk_vec_size = 32 + ): + provider(provider), + gemm_kind(gemm_kind), + kind(kind), + element_compute(element_compute), + element_scalar(element_scalar), + element_A(element_A), + layout_A(layout_A), + element_SFA(element_SFA), + element_B(element_B), + layout_B(layout_B), + element_SFB(element_SFB), + element_C(element_C), + layout_C(layout_C), + element_D(element_D), + layout_D(layout_D), + SFMVecSize(sfm_vec_size), + SFNVecSize(sfn_vec_size), + SFKVecSize(sfk_vec_size) + { } + + inline + bool operator==(BlockwiseGemmFunctionalKey const &rhs) const { + return + (provider == rhs.provider) && + (gemm_kind == rhs.gemm_kind) && + (kind == rhs.kind) && + (element_compute == rhs.element_compute) && + (element_scalar == rhs.element_scalar) && + (element_A == rhs.element_A) && + (layout_A == rhs.layout_A) && + (element_SFA == rhs.element_SFA) && + (element_B == rhs.element_B) && + (layout_B == rhs.layout_B) && + (element_SFB == rhs.element_SFB) && + (element_C == rhs.element_C) && + (layout_C == rhs.layout_C) && + (element_D == rhs.element_D) && + (layout_D == rhs.layout_D) && + (SFMVecSize == rhs.SFMVecSize) && + (SFNVecSize == rhs.SFNVecSize) && + (SFKVecSize == rhs.SFKVecSize); + } + + inline + bool operator!=(BlockwiseGemmFunctionalKey const &rhs) const { + return !(*this == rhs); + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +inline +std::ostream & operator<<(std::ostream &out, cutlass::library::BlockwiseGemmFunctionalKey const &k) { + + out << "{\n" + << " provider: " << to_string(k.provider) << "\n" + << " gemm_kind: " << to_string(k.gemm_kind) << "\n" + << " kind: " << to_string(k.kind) << "\n" + << " element_compute: " << to_string(k.element_compute) << "\n" + << " element_scalar: " << to_string(k.element_scalar) << "\n" + << " element_A: " << to_string(k.element_A) << "\n" + << " layout_A: " << to_string(k.layout_A) << "\n" + << " element_SFA: " << to_string(k.element_SFA) << "\n" + << " element_B: " << to_string(k.element_B) << "\n" + << " layout_B: " << to_string(k.layout_B) << "\n" + << " element_SFB: " << to_string(k.element_SFB) << "\n" + << " element_C: " << to_string(k.element_C) << "\n" + << " layout_C: " << to_string(k.layout_C) << "\n" + << " element_D: " << to_string(k.element_D) << "\n" + << " layout_D: " << to_string(k.layout_D) << "\n" + << " SFMVecSize: " << k.SFMVecSize << "\n" + << " SFNVecSize: " << k.SFNVecSize << "\n" + << " SFKVecSize: " << k.SFKVecSize << "\n" + << "}"; + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Hash function for BlockwiseGemmFunctionalKeyHasher +struct BlockwiseGemmFunctionalKeyHasher { + using IntHash = std::hash; + + inline + static size_t rotl(size_t key, int shl) { + return (key << shl) | (key >> (sizeof(key)*8u - static_cast(shl))); + } + + inline + size_t operator()(BlockwiseGemmFunctionalKey const &key) const { + IntHash hash; + + return + rotl(hash(int(key.provider)), 1) ^ + rotl(hash(int(key.gemm_kind)), 2) ^ + rotl(hash(int(key.kind)), 3) ^ + rotl(hash(int(key.element_compute)), 4) ^ + rotl(hash(int(key.element_scalar)), 5) ^ + rotl(hash(int(key.element_A)), 6) ^ + rotl(hash(int(key.layout_A)), 7) ^ + rotl(hash(int(key.element_SFA)), 8) ^ + rotl(hash(int(key.element_B)), 9) ^ + rotl(hash(int(key.layout_B)), 10) ^ + rotl(hash(int(key.element_SFB)), 11) ^ + rotl(hash(int(key.element_C)), 12) ^ + rotl(hash(int(key.layout_C)), 13) ^ + rotl(hash(int(key.element_D)), 14) ^ + rotl(hash(int(key.layout_D)), 15) ^ + rotl(hash(int(key.SFMVecSize)), 16) ^ + rotl(hash(int(key.SFNVecSize)), 17) ^ + rotl(hash(int(key.SFKVecSize)), 18) + ; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Maps a GemmFunctionalKey onto a vector of Operation * objects expected to be of kind kGemm +using BlockwiseGemmOperationFunctionalMap = std::unordered_map< + BlockwiseGemmFunctionalKey, + GemmOperationVectorMap, + BlockwiseGemmFunctionalKeyHasher +>; + + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Data Structures for Conv Functional Maps +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tuple uniquely identifying conv2d functional behavior +struct ConvFunctionalKey { + library::Provider provider; + library::ConvKind conv_kind; + library::NumericTypeID element_A; + library::LayoutTypeID layout_A; + library::NumericTypeID element_B; + library::LayoutTypeID layout_B; + library::NumericTypeID element_C; + library::LayoutTypeID layout_C; + library::NumericTypeID element_accumulator; + library::NumericTypeID element_compute; + + + // + // Methods + // + + inline + ConvFunctionalKey( + library::Provider provider = library::Provider::kInvalid, + library::ConvKind conv_kind = library::ConvKind::kFprop, + library::NumericTypeID element_A = library::NumericTypeID::kF16, + library::LayoutTypeID layout_A = library::LayoutTypeID::kTensorNHWC, + library::NumericTypeID element_B = library::NumericTypeID::kF16, + library::LayoutTypeID layout_B = library::LayoutTypeID::kTensorNHWC, + library::NumericTypeID element_C = library::NumericTypeID::kF16, + library::LayoutTypeID layout_C = library::LayoutTypeID::kTensorNHWC, + library::NumericTypeID element_accumulator = library::NumericTypeID::kF32, + library::NumericTypeID element_compute = library::NumericTypeID::kF32 + ): + provider(provider), + conv_kind(conv_kind), + element_A(element_A), + layout_A(layout_A), + element_B(element_B), + layout_B(layout_B), + element_C(element_C), + layout_C(layout_C), + element_accumulator(element_accumulator), + element_compute(element_compute) + { } + + inline + bool operator==(ConvFunctionalKey const &rhs) const { + return + (provider == rhs.provider) && + (conv_kind == rhs.conv_kind) && + (element_A == rhs.element_A) && + (layout_A == rhs.layout_A) && + (element_B == rhs.element_B) && + (layout_B == rhs.layout_B) && + (element_C == rhs.element_C) && + (layout_C == rhs.layout_C) && + (element_accumulator == rhs.element_accumulator) && + (element_compute == rhs.element_compute); + } + + inline + bool operator!=(ConvFunctionalKey const &rhs) const { + return !(*this == rhs); + } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// +inline +std::ostream& operator<< (std::ostream& out, const cutlass::library::ConvFunctionalKey& key) { + out << "{\n" + << "provider: " << to_string(key.provider) << std::endl + << "conv_kind: " << to_string(key.conv_kind) << std::endl + << "element_A: " << to_string(key.element_A) << std::endl + << "layout_A: " << to_string(key.layout_A) << std::endl + << "element_B: " << to_string(key.element_B) << std::endl + << "layout_B: " << to_string(key.layout_B) << std::endl + << "element_C: " << to_string(key.element_C) << std::endl + << "layout_C: " << to_string(key.layout_C) << std::endl + << "element_accumulator: " << to_string(key.element_accumulator) << std::endl + << "element_compute: " << to_string(key.element_compute) << std::endl + << "}"; + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +struct ConvFunctionalKeyHasher { + using IntHash = std::hash; + + inline + static size_t rotl(size_t key, int shl) { + return (key << shl) | (key >> (sizeof(key)*8u - static_cast(shl))); + } + + inline + size_t operator()(ConvFunctionalKey const &key) const { + IntHash hash; + + return + rotl(hash(int(key.provider)), 1) ^ + rotl(hash(int(key.conv_kind)), 2) ^ + rotl(hash(int(key.element_A)), 3) ^ + rotl(hash(int(key.layout_A)), 4) ^ + rotl(hash(int(key.element_B)), 5) ^ + rotl(hash(int(key.layout_B)), 6) ^ + rotl(hash(int(key.element_C)), 7) ^ + rotl(hash(int(key.layout_C)), 8) ^ + rotl(hash(int(key.element_accumulator)), 9) ^ + rotl(hash(int(key.element_compute)), 10); + } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Establishes a partial ordering to search for Conv2d operators +struct ConvPreferenceKey { + + int compute_capability; + IteratorAlgorithmID iterator_algorithm; + + + // + // Methods + // + + ConvPreferenceKey(): compute_capability(), iterator_algorithm() { } + + ConvPreferenceKey(int cc, IteratorAlgorithmID iterator_algorithm): + compute_capability(cc), iterator_algorithm(iterator_algorithm) { } + + bool operator<(ConvPreferenceKey const &rhs) const { + return (compute_capability < rhs.compute_capability) || + ((compute_capability == rhs.compute_capability) && (iterator_algorithm < rhs.iterator_algorithm)); + } + + bool operator==(ConvPreferenceKey const &rhs) const { + return (compute_capability == rhs.compute_capability) && + (iterator_algorithm == rhs.iterator_algorithm); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Maps minimum compute capability onto a vector of possible operations +using ConvOperationVectorMap = std::map< + ConvPreferenceKey, + std::vector +>; + +/// Maps a GemmFunctionalKey onto a vector of Operation * objects expected to be of kind kGemm +using ConvOperationFunctionalMap = std::unordered_map< + ConvFunctionalKey, + ConvOperationVectorMap, + ConvFunctionalKeyHasher +>; +///////////////////////////////////////////////////////////////////////////////////////////////// + + +/// Tuple uniquely identifying conv2d functional behavior +struct ReductionFunctionalKey { + library::Provider provider; + library::NumericTypeID element_workspace; + library::NumericTypeID element_accumulator; + library::NumericTypeID element_output; + library::NumericTypeID element_compute; + library::MathOperationID reduce_math_op; + library::EpilogueKind epilogue_math_op; + + + // + // Methods + // + + inline + ReductionFunctionalKey( + library::Provider provider = library::Provider::kInvalid, + library::NumericTypeID element_workspace = library::NumericTypeID::kF16, + library::NumericTypeID element_accumulator = library::NumericTypeID::kF32, + library::NumericTypeID element_output = library::NumericTypeID::kF16, + library::NumericTypeID element_compute = library::NumericTypeID::kF32, + library::MathOperationID reduce_math_op = library::MathOperationID::kAdd, + library::EpilogueKind epilogue_math_op = library::EpilogueKind::kLinearCombination + ): + provider(provider), + element_workspace(element_workspace), + element_accumulator(element_accumulator), + element_output(element_output), + element_compute(element_compute), + reduce_math_op(reduce_math_op), + epilogue_math_op(epilogue_math_op) + { } + + inline + bool operator==(ReductionFunctionalKey const &rhs) const { + return + (provider == rhs.provider) && + (element_workspace == rhs.element_workspace) && + (element_accumulator == rhs.element_accumulator) && + (element_output == rhs.element_output) && + (element_compute == rhs.element_compute) && + (reduce_math_op == rhs.reduce_math_op) && + (epilogue_math_op == rhs.epilogue_math_op); + } + + inline + bool operator!=(ReductionFunctionalKey const &rhs) const { + return !(*this == rhs); + } +}; + + +struct ReductionFunctionalKeyHasher { + using IntHash = std::hash; + + inline + static size_t rotl(size_t key, int shl) { + return (key << shl) | (key >> (sizeof(key)*8u - static_cast(shl))); + } + + inline + size_t operator()(ReductionFunctionalKey const &key) const { + IntHash hash; + + return + rotl(hash(int(key.provider)), 1) ^ + rotl(hash(int(key.element_workspace)), 2) ^ + rotl(hash(int(key.element_accumulator)), 3) ^ + rotl(hash(int(key.element_output)), 4) ^ + rotl(hash(int(key.element_compute)), 5) ^ + rotl(hash(int(key.reduce_math_op)), 6) ^ + rotl(hash(int(key.epilogue_math_op)), 7); + } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +inline +std::ostream& operator<< (std::ostream& out, const ReductionFunctionalKey& key) { + out << "{\n" + << "provider: " << library::to_string(key.provider) << std::endl + << "element_workspace : " << library::to_string(key.element_workspace) << std::endl + << "element_accumulator : " << library::to_string(key.element_accumulator) << std::endl + << "element_output : " << library::to_string(key.element_output) << std::endl + << "element_compute : " << library::to_string(key.element_compute) << std::endl + << "}"; + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// ReductionOperationFunctionalMap has NO preference key and a single instance per functional key +// i.e. only one tile size configuration per functional key +using ReductionOperationFunctionalMap = std::unordered_map< + ReductionFunctionalKey, + library::Operation const *, + ReductionFunctionalKeyHasher +>; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Table of cutlass::library::Operation instances +class OperationTable { +public: + + /// Map of all operations of type kGemm + // provider (kCUTLASS) + GemmOperationFunctionalMap gemm_operations; + + // provider (kCUTLASS, kReferenceHost, kReferenceDevice) + BlockScaledGemmOperationFunctionalMap block_scaled_gemm_operations; + + // provider (kCUTLASS, kReferenceHost, kReferenceDevice) + BlockwiseGemmOperationFunctionalMap blockwise_gemm_operations; + + /// Map of all operations of type kConv2d + // provider (kCUTLASS, kReferenceHost, kReferenceDevice) + ConvOperationFunctionalMap conv2d_operations; + + /// Map of all operations of type kConv3d + // provider (kCUTLASS, kReferenceHost, kReferenceDevice) + ConvOperationFunctionalMap conv3d_operations; + + /// Map of all operations of type kConv2d + // provider (kCUTLASS) + ReductionOperationFunctionalMap reduction_operations; + +public: + + void append(Manifest const &manifest); + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +std::ostream & operator<<(std::ostream &out, cutlass::library::GemmFunctionalKey const &k); diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/singleton.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/singleton.h new file mode 100644 index 0000000000000000000000000000000000000000..9a757433f38fbf10d9a352e07c7f3084a99e4098 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/singleton.h @@ -0,0 +1,68 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/operation_table.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Singleton instance stores a Manifest and Operation table +class Singleton { +public: + + /// Manifest object + Manifest manifest; + + /// Operation table referencing the Manifest + OperationTable operation_table; + +public: + + Singleton(); + + static Singleton const &get(); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/types.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/types.h new file mode 100644 index 0000000000000000000000000000000000000000..9f8c4ff13ba543b4ec63997ba55e9278bfb357a6 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/types.h @@ -0,0 +1,295 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + #pragma once + + ///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Layout type identifier +enum class LayoutTypeID { + kUnknown, + kColumnMajor, + kRowMajor, + kBlockScalingTensor, + kColumnMajorInterleavedK2, + kRowMajorInterleavedK2, + kColumnMajorInterleavedK4, + kRowMajorInterleavedK4, + kColumnMajorInterleavedK16, + kRowMajorInterleavedK16, + kColumnMajorInterleavedK32, + kRowMajorInterleavedK32, + kColumnMajorInterleavedK64, + kRowMajorInterleavedK64, + kTensorNCHW, + kTensorNCDHW, + kTensorNHWC, + kTensorNDHWC, + kTensorNC32HW32, + kTensorC32RSK32, + kTensorNC64HW64, + kTensorC64RSK64, + kInvalid +}; + +/// Numeric data type +enum class NumericTypeID { + kUnknown, + kVoid, + kB1, + kU2, + kU4, + kU8, + kU16, + kU32, + kU64, + kS2, + kS4, + kS8, + kS16, + kS32, + kS64, + kFE4M3, + kFE5M2, + + kFE2M3, + kFE3M2, + kFE2M1, + kFUE8M0, + kFUE4M3, + kF8, + kF6, + kF4, + + kF16, + kBF16, + kTF32, + kF32, + kF64, + kCF16, + kCBF16, + kCF32, + kCTF32, + kCF64, + kCS2, + kCS4, + kCS8, + kCS16, + kCS32, + kCS64, + kCU2, + kCU4, + kCU8, + kCU16, + kCU32, + kCU64, + kInvalid +}; + +/// Enumerated type describing a transformation on a complex value. +enum class ComplexTransform { + kNone, + kConjugate, + kInvalid +}; + +/// Providers +enum class Provider { + kNone, + kCUTLASS, + kReferenceHost, + kReferenceDevice, + kCUBLAS, + kCUDNN, + kInvalid +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Enumeration indicating the kind of operation +enum class OperationKind { + kGemm, + kBlockScaledGemm, + kBlockwiseGemm, + kRankK, + kRank2K, + kTrmm, + kSymm, + kConv2d, + kConv3d, + kEqGemm, + kSparseGemm, + kReduction, + kGroupedGemm, + kInvalid +}; + +/// Enumeration indicating whether scalars are in host or device memory +enum class ScalarPointerMode { + kHost, + kDevice, + kInvalid +}; + +/// Describes how reductions are performed across threadblocks +enum class SplitKMode { + kNone, + kSerial, + kParallel, + kParallelSerial, + kInvalid +}; + +/// Indicates the classificaition of the math instruction +enum class OpcodeClassID { + kSimt, + kTensorOp, + kWmmaTensorOp, + kSparseTensorOp, + kBlockScaledOp, + kInvalid +}; + +enum class MathOperationID { + kAdd, + kMultiplyAdd, + kMultiplyAddSaturate, + kMultiplyAddMixedInputUpcast, + kMultiplyAddFastBF16, + kMultiplyAddFastF16, + kMultiplyAddFastF32, + kMultiplyAddComplex, + kMultiplyAddComplexFastF32, + kMultiplyAddGaussianComplex, + kXorPopc, + kInvalid +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Enumeration indicating what kind of GEMM operation to perform +enum class GemmKind { + kGemm, + kBlockScaledGemm, + kSparse, + kUniversal, + kPlanarComplex, + kPlanarComplexArray, + kGrouped, + kInvalid +}; + +/// Enumeration indicating what kind of RankK update operation to perform +enum class RankKKind { + kUniversal, + kInvalid +}; + +/// Enumeration indicating what kind of TRMM operation to perform +enum class TrmmKind { + kUniversal, + kInvalid +}; + +/// Enumeration indicating what kind of SYMM/HEMM operation to perform +enum class SymmKind { + kUniversal, + kInvalid +}; + +/// Enumeration indicating what kind of Conv2d operation to perform +enum class ConvKind { + kUnknown, + kFprop, + kDgrad, + kWgrad, + kInvalid +}; + +enum class ConvModeID { + kCrossCorrelation, + kConvolution, + kInvalid +}; + +// Iterator algorithm enum in order of general performance-efficiency +enum class IteratorAlgorithmID { + kNone, + kAnalytic, + kOptimized, + kFixedChannels, + kFewChannels, + kInvalid +}; + + +enum class EpilogueKind { + kUnknown, + kConversion, + kLinearCombination, + kLinearCombinationClamp, + kLinearCombinationPlanarComplex, + kLinearCombinationRelu, + kLinearCombinationSigmoid, + kInvalid +}; + + +enum class RuntimeDatatype { + kStatic, + kE4M3, + kE5M2, + kE3M2, + kE2M3, + kE2M1, + + kInvalid +}; + + +enum class RasterOrder { + kAlongN, + kAlongM, + kHeuristic, + kInvalid +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/util.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/util.h new file mode 100644 index 0000000000000000000000000000000000000000..f537421751c1f2af3b95a2e1951006af441b28e0 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/util.h @@ -0,0 +1,281 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + + \brief Utilities accompanying the CUTLASS library for interacting with Library types. +*/ + +#ifndef CUTLASS_LIBRARY_UTIL_H +#define CUTLASS_LIBRARY_UTIL_H + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Lexical cast from string +template T from_string(std::string const &); + +/// Converts a Provider enumerant to a string +char const *to_string(Provider provider, bool pretty = false); + +/// Parses a Provider enumerant from a string +template <> Provider from_string(std::string const &str); + +/// Converts a GemmKind enumerant to a string +char const *to_string(GemmKind type, bool pretty = false); + +/// Converts a RankKKind enumerant to a string +char const *to_string(RankKKind type, bool pretty = false); + +/// Converts a TrmmKind enumerant to a string +char const *to_string(TrmmKind type, bool pretty = false); + +/// Converts a SymmKind enumerant to a string +char const *to_string(SymmKind type, bool pretty = false); + +/// Converts a SideMode enumerant to a string +char const *to_string(SideMode type, bool pretty = false); + +/// Converts a FillMode enumerant to a string +char const *to_string(FillMode type, bool pretty = false); + +/// Converts a BlasMode enumerant to a string +char const *to_string(BlasMode type, bool pretty = false); + +/// Converts a DiagType enumerant to a string +char const *to_string(DiagType type, bool pretty = false); + +/// Converts a NumericType enumerant to a string +char const *to_string(OperationKind type, bool pretty = false); + +/// Parses a NumericType enumerant from a string +template <> OperationKind from_string(std::string const &str); + +/// Converts a NumericType enumerant to a string +char const *to_string(NumericTypeID type, bool pretty = false); + +/// Parses a NumericType enumerant from a string +template <> NumericTypeID from_string(std::string const &str); + +/// Returns the size of a data type in bits +int sizeof_bits(NumericTypeID type); + +/// Returns true if the numeric type is a complex data type or false if real-valued. +bool is_complex_type(NumericTypeID type); + +/// Returns the real-valued type underlying a type (only different from 'type' if complex) +NumericTypeID get_real_type(NumericTypeID type); + +/// Returns true if numeric type is integer +bool is_integer_type(NumericTypeID type); + +/// Returns true if numeric type is signed +bool is_signed_type(NumericTypeID type); + +/// Returns true if numeric type is a signed integer +bool is_signed_integer(NumericTypeID type); + +/// returns true if numeric type is an unsigned integer +bool is_unsigned_integer(NumericTypeID type); + +/// Returns true if numeric type is floating-point type +bool is_float_type(NumericTypeID type); + +/// To string method for cutlass::Status +char const *to_string(Status status, bool pretty = false); + +/// Converts a LayoutTypeID enumerant to a string +char const *to_string(LayoutTypeID layout, bool pretty = false); + +/// Parses a LayoutType enumerant from a string +template <> LayoutTypeID from_string(std::string const &str); + +/// Returns the rank of a layout's stride base on the LayoutTypeID +int get_layout_stride_rank(LayoutTypeID layout_id); + +/// Converts a OpcodeClassID enumerant to a string +char const *to_string(OpcodeClassID type, bool pretty = false); + +/// Converts a OpcodeClassID enumerant from a string +template <> +OpcodeClassID from_string(std::string const &str); + +/// Converts a ComplexTransform enumerant to a string +char const *to_string(ComplexTransform type, bool pretty = false); + +/// Converts a ComplexTransform enumerant from a string +template <> +ComplexTransform from_string(std::string const &str); + + +/// Converts a SplitKMode enumerant to a string +char const *to_string(SplitKMode split_k_mode, bool pretty = false); + +/// Converts a SplitKMode enumerant from a string +template <> +SplitKMode from_string(std::string const &str); + +/// Converts a ConvModeID enumerant to a string +char const *to_string(ConvModeID type, bool pretty = false); + +/// Converts a ConvModeID enumerant from a string +template <> +ConvModeID from_string(std::string const &str); + +/// Converts a IteratorAlgorithmID enumerant to a string +char const *to_string(IteratorAlgorithmID type, bool pretty = false); + +/// Converts a IteratorAlgorithmID enumerant from a string +template <> +IteratorAlgorithmID from_string(std::string const &str); + +/// Converts a ConvKind enumerant to a string +char const *to_string(ConvKind type, bool pretty = false); + +/// Converts a ConvKind enumerant from a string +template <> +ConvKind from_string(std::string const &str); + + +/// Converts a RuntimeDatatype enumerant to a string +char const *to_string(cutlass::library::RuntimeDatatype type, bool pretty = false); + +/// Convers a RuntimeDatatype enumerant from a string +template<> +cutlass::library::RuntimeDatatype from_string(std::string const &str); + + +/// Converts a RasterOrder enumerant to a string +char const *to_string(RasterOrder type, bool pretty = false); + +/// Convers a RasterOrder enumerant from a string +template<> +RasterOrder from_string(std::string const &str); + +/// Converts a bool to a string +char const *to_string(bool type, bool pretty = false); + +/// Convers a bool from a string +template<> +bool from_string(std::string const &str); + +/// Lexical cast from int64_t to string +std::string lexical_cast(int64_t int_value); + +/// Lexical cast a string to a byte array. Returns true if cast is successful or false if invalid. +bool lexical_cast(std::vector &bytes, NumericTypeID type, std::string const &str); + +/// Lexical cast TO a string FROM a byte array. Returns true if cast is successful or false if invalid. +std::string lexical_cast(std::vector &bytes, NumericTypeID type); + +/// Casts from a signed int64 to the destination type. Returns true if successful. +bool cast_from_int64(std::vector &bytes, NumericTypeID type, int64_t src); + +/// Casts from an unsigned int64 to the destination type. Returns true if successful. +bool cast_from_uint64(std::vector &bytes, NumericTypeID type, uint64_t src); + +/// Casts from a real value represented as a double to the destination type. Returns true if successful. +bool cast_from_double(std::vector &bytes, NumericTypeID type, double src); + +NumericTypeID dynamic_datatype_to_id(RuntimeDatatype type); + +#define CUDA_CHECK(call) \ + do { \ + cudaError_t err = (call); \ + if (err != cudaSuccess) { \ + std::cerr << "CUDA Error: " << cudaGetErrorString(err) << " in " << __func__ << " at " \ + << __FILE__ << ":" << __LINE__ << std::endl; \ + return Status::kInvalid; \ + } \ + } while (0) + +// RAII CUDA buffer container +class CudaBuffer { +public: + CudaBuffer() : size_(0), d_ptr_(nullptr) {} + + explicit CudaBuffer(size_t size) : size_(size), d_ptr_(nullptr) { + cudaError_t err = cudaMalloc(&d_ptr_, size_); + if (err != cudaSuccess) { + throw std::runtime_error("cudaMalloc failed: " + std::string(cudaGetErrorString(err))); + } + } + + ~CudaBuffer() { + if (d_ptr_) { + cudaFree(d_ptr_); + } + } + + CudaBuffer(CudaBuffer const&) = delete; + CudaBuffer& operator=(CudaBuffer const&) = delete; + + CudaBuffer(CudaBuffer&& other) noexcept : size_(other.size_), d_ptr_(other.d_ptr_) { + other.d_ptr_ = nullptr; + other.size_ = 0; + } + + CudaBuffer& operator=(CudaBuffer&& other) noexcept { + if (this != &other) { + if (d_ptr_) { + cudaFree(d_ptr_); + } + d_ptr_ = other.d_ptr_; + size_ = other.size_; + other.d_ptr_ = nullptr; + other.size_ = 0; + } + return *this; + } + + void* data() const noexcept { return d_ptr_; } + size_t size() const noexcept { return size_; } + +private: + size_t size_; + void* d_ptr_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/block_scaled_gemm_operation_3x.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/block_scaled_gemm_operation_3x.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c96b9a2212b42c191551ea70da3ac3baecbed487 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/block_scaled_gemm_operation_3x.hpp @@ -0,0 +1,450 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all GEMM operation kinds in CUTLASS Library. +*/ + + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/library/library.h" +#include "library_internal.h" +#include "gemm_operation_3x.hpp" +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class BlockScaledGemmUniversal3xOperation : public GemmOperation3xBase { +public: + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + using ElementA = typename Operator::CollectiveMainloop::ElementA; + using ElementSFA = typename Operator::CollectiveMainloop::ElementSF; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::CollectiveMainloop::ElementB; + using ElementSFB = typename Operator::CollectiveMainloop::ElementSF; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = typename Operator::ElementD; + using LayoutD = typename Operator::LayoutD; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using TiledMma = typename Operator::CollectiveMainloop::TiledMma; + constexpr static int SFVecSize = TiledMma::SFVecSize; + + using CollectiveMainloop = typename Operator::CollectiveMainloop; + using CollectiveEpilogue = typename Operator::CollectiveEpilogue; + using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + using Sm1xxBlkScaledConfig = typename CollectiveMainloop::Sm1xxBlkScaledConfig; + + static constexpr bool epilogue_scalefactor_generation = not cute::is_same_v; + static constexpr int32_t SFD_VectorSize = epilogue_scalefactor_generation ? ThreadEpilogueOp::SFVecSize : SFVecSize; + using ElementSFD = cute::conditional_t; + using LayoutSFD = cute::conditional_t; + + + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB in a GEMM kernel should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + using RuntimeDataTypeA = typename Operator::CollectiveMainloop::RuntimeDataTypeA; + using RuntimeDataTypeB = typename Operator::CollectiveMainloop::RuntimeDataTypeB; + + +private: + BlockScaledGemmDescription description_; + +public: + + /// Constructor + BlockScaledGemmUniversal3xOperation(char const *name = "unknown_gemm"): + GemmOperation3xBase(name, GemmKind::kUniversal) { + description_.kind = OperationKind::kBlockScaledGemm; + description_.SFA.element = NumericTypeMap::kId; + description_.SFA.layout = LayoutTypeID::kRowMajor; + description_.SFA.alignment = 128; + description_.SFA.log_extent_range = 32; + description_.SFA.log_stride_range = 32; + + description_.SFB.element = NumericTypeMap::kId; + description_.SFB.layout = LayoutTypeID::kRowMajor; + description_.SFB.alignment = 128; + description_.SFB.log_extent_range = 32; + description_.SFB.log_stride_range = 32; + + description_.SFVecSize = SFVecSize; + + description_.SFD = make_TensorDescription(128); + description_.EpilogueSFVecSize = SFD_VectorSize; + + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.gemm_kind = GemmKind::kUniversal; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) { + description_.tile_description.cluster_shape = make_Coord( + Operator::ClusterShape::kM, + Operator::ClusterShape::kN, + Operator::ClusterShape::kK); + } + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::WarpCount::kM, + Operator::WarpCount::kN, + Operator::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(Operator::kAlignmentA); + description_.B = make_TensorDescription(Operator::kAlignmentB); + description_.C = make_TensorDescription(Operator::kAlignmentC); + description_.D = make_TensorDescription(Operator::kAlignmentD); + description_.element_epilogue = NumericTypeMap::kId; + + description_.split_k_mode = SplitKMode::kNone; + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } + + /// Returns the description of the GEMM operation + BlockScaledGemmDescription const& get_gemm_description() const { + return description_; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, GemmUniversalConfiguration const *configuration) { + // NOTE: GemmUniversalConfiguration does not contain problem shapes or batch strides + // Do nothing here and construct kernel arguments in update_arguments_ instead + // We also cannot construct TMA descriptors without all the arguments available + + operator_args.mode = configuration->mode; + return Status::kSuccess; + } + + template + struct UpdateFusionArgs { + static Status update_(FusionArgs const& fusion_args, BlockScaledGemmArguments const &arguments) { + // If a custom EVT is instantiated then it is the users's responsibility + // to ensure alpha and beta are updated appropriately + return Status::kSuccess; + } + }; + + template + struct UpdateFusionArgs> { + static Status update_(FusionArgs& fusion_args, BlockScaledGemmArguments const &arguments) { + + if constexpr (epilogue_scalefactor_generation) { + fusion_args.block_scale_factor_ptr = static_cast(arguments.SFD); + fusion_args.norm_constant_ptr = static_cast(arguments.norm_constant); + } + + + if (arguments.pointer_mode == ScalarPointerMode::kHost) { + fusion_args.alpha = *static_cast(arguments.alpha); + fusion_args.beta = *static_cast(arguments.beta); + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + + return Status::kSuccess; + } + else if (arguments.pointer_mode == ScalarPointerMode::kDevice) { + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = static_cast(arguments.alpha); + fusion_args.beta_ptr = static_cast(arguments.beta); + + return Status::kSuccess; + } + else { + return Status::kErrorInvalidProblem; + } + } + }; + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + BlockScaledGemmArguments const *arguments) { + Status status = Status::kSuccess; + + status = UpdateFusionArgs::update_( + operator_args.epilogue.thread, *arguments); + if (status != Status::kSuccess) { + return status; + } + + operator_args.problem_shape = cute::make_shape( + arguments->problem_size.m(), + arguments->problem_size.n(), + arguments->problem_size.k(), + arguments->batch_count); + + // update arguments + + if constexpr (IsRuntimeDataType) { + using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB; + operator_args.mainloop.ptr_A = static_cast(arguments->A); + operator_args.mainloop.ptr_B = static_cast(arguments->B); + + using RuntimeDataTypeA = typename Operator::GemmKernel::CollectiveMainloop::RuntimeDataTypeA; + using RuntimeDataTypeB = typename Operator::GemmKernel::CollectiveMainloop::RuntimeDataTypeB; + + static_assert(cute::is_same_v, + "RuntimeDataTypeA/B should be identical, either MXF8F6F4Format or MXF4Format"); + using RuntimeDatatypeArg = RuntimeDataTypeA; + + auto mapping = [](RuntimeDatatype type) { + if constexpr (cute::is_same_v) { + if (type == RuntimeDatatype::kE3M2) { + return cute::UMMA::MXF8F6F4Format::E3M2; + } else if (type == RuntimeDatatype::kE2M3) { + return cute::UMMA::MXF8F6F4Format::E2M3; + } else if (type == RuntimeDatatype::kE2M1) { + return cute::UMMA::MXF8F6F4Format::E2M1; + } else { + assert("Invalid input datatype."); + } + } + else if constexpr (cute::is_same_v) { + if (type == RuntimeDatatype::kE2M1) { + return cute::UMMA::MXF4Format::E2M1; + } else { + assert("Invalid input datatype."); + } + } + // BlockScaled kernels receive either MXF4Format or MXF8F6F4Format runtime datatype + CUTE_GCC_UNREACHABLE; + }; + + operator_args.mainloop.runtime_data_type_a = mapping(arguments->runtime_input_datatype_a); + operator_args.mainloop.runtime_data_type_b = mapping(arguments->runtime_input_datatype_b); + + } + else { + + operator_args.mainloop.ptr_A = static_cast(arguments->A); + operator_args.mainloop.ptr_B = static_cast(arguments->B); + } + operator_args.mainloop.ptr_SFA = static_cast(arguments->SFA); + operator_args.mainloop.ptr_SFB = static_cast(arguments->SFB); + operator_args.epilogue.ptr_C = static_cast(arguments->C); + operator_args.epilogue.ptr_D = static_cast(arguments->D); + + operator_args.mainloop.dA = cute::make_int_tuple_from( + arguments->lda, arguments->batch_stride_A); + operator_args.mainloop.dB = cute::make_int_tuple_from( + arguments->ldb, arguments->batch_stride_B); + operator_args.epilogue.dC = cute::make_int_tuple_from( + arguments->ldc, arguments->batch_stride_C); + operator_args.epilogue.dD = operator_args.epilogue.dC; + + operator_args.mainloop.layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(operator_args.problem_shape); + operator_args.mainloop.layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(operator_args.problem_shape); + + /* Query device SM count to pass onto the kernel as an argument, where needed */ + operator_args.hw_info.sm_count = arguments->sm_count; + if constexpr (!std::is_const_v) { + operator_args.scheduler.max_swizzle_size = arguments->swizzle_size; + } + + if constexpr (!std::is_const_v) { + using Enum_t = decltype(operator_args.scheduler.raster_order); + switch (arguments->raster_order) { + case RasterOrder::kAlongN: + operator_args.scheduler.raster_order = Enum_t::AlongN; + break; + case RasterOrder::kAlongM: + operator_args.scheduler.raster_order = Enum_t::AlongM; + break; + default: + operator_args.scheduler.raster_order = Enum_t::Heuristic; + } + } + + if constexpr (std::is_same_v) { + operator_args.scheduler.splits = arguments->split_k_slices; + } + + + if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) { + operator_args.hw_info.cluster_shape = dim3( + arguments->cluster_shape.m(), + arguments->cluster_shape.n(), + arguments->cluster_shape.k()); + operator_args.hw_info.cluster_shape_fallback = dim3( + arguments->cluster_shape_fallback.m(), + arguments->cluster_shape_fallback.n(), + arguments->cluster_shape_fallback.k()); + } + + return status; + } + +public: + + /// Returns success if the operation can proceed + Status can_implement( + void const *configuration_ptr, void const *arguments_ptr) const override { + + GemmUniversalConfiguration const *configuration = + static_cast(configuration_ptr); + BlockScaledGemmArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + auto status = update_arguments_(args, arguments); + if (status != Status::kSuccess) { + return status; + } + + // can_implement rules may need access to problem shape + args.problem_shape = cute::make_shape( + configuration->problem_size.m(), + configuration->problem_size.n(), + configuration->problem_size.k(), + configuration->batch_count); + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + uint64_t get_host_workspace_size(void const *configuration) const override { + return sizeof(Operator); + } + + /// Gets the device-side workspace + uint64_t get_device_workspace_size( + void const *configuration_ptr,void const *arguments_ptr) const override { + + OperatorArguments args; + auto status = update_arguments_( + args, static_cast(arguments_ptr)); + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + return size; + } + + /// Initializes the workspace + Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const override { + Operator *op = new (host_workspace) Operator; + return Status::kSuccess; + } + + Status initialize_with_profiler_workspace( + void const *configuration, + void *host_workspace, + void *device_workspace, + uint8_t **profiler_workspaces, + int problem_count_from_profiler, + cudaStream_t stream = nullptr) { + return Status::kSuccess; + } + + /// Runs the kernel + Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const override { + + OperatorArguments args; + Status status = update_arguments_(args, static_cast(arguments_ptr)); + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + // We need to call initialize() since we have to rebuild TMA desc for every new set of args + status = op->run(args, device_workspace, stream, nullptr, static_cast(arguments_ptr)->use_pdl); + return status; + } +}; +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::library + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/blockwise_gemm_operation_3x.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/blockwise_gemm_operation_3x.hpp new file mode 100644 index 0000000000000000000000000000000000000000..00347a993e29035e58401e69698267045b399f7d --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/blockwise_gemm_operation_3x.hpp @@ -0,0 +1,429 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all GEMM operation kinds in CUTLASS Library. +*/ + + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/library/library.h" +#include "library_internal.h" +#include "gemm_operation_3x.hpp" +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class BlockwiseGemmUniversal3xOperation : public GemmOperation3xBase { +public: + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + using ElementA = typename Operator::CollectiveMainloop::ElementA; + using ElementSFA = typename Operator::ElementAccumulator; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::CollectiveMainloop::ElementB; + using ElementSFB = typename Operator::ElementAccumulator; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = typename Operator::ElementD; + using LayoutD = typename Operator::LayoutD; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using TiledMma = typename Operator::CollectiveMainloop::TiledMma; + + using CollectiveMainloop = typename Operator::CollectiveMainloop; + using CollectiveEpilogue = typename Operator::CollectiveEpilogue; + using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB in a GEMM kernel should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + +private: + BlockwiseGemmDescription description_; + +public: + + /// Constructor + BlockwiseGemmUniversal3xOperation(char const *name = "unknown_gemm"): + GemmOperation3xBase(name, GemmKind::kUniversal) { + description_.kind = OperationKind::kBlockwiseGemm; + description_.SFA.element = NumericTypeMap::kId; + description_.SFA.layout = size<0,1>(typename CollectiveMainloop::LayoutSFA{}.stride()) == 1 ? + LayoutTypeID::kColumnMajor : LayoutTypeID::kRowMajor; + description_.SFA.alignment = CollectiveMainloop::AlignmentSFA; + description_.SFA.log_extent_range = 32; + description_.SFA.log_stride_range = 32; + + description_.SFB.element = NumericTypeMap::kId; + description_.SFB.layout = size<0,1>(typename CollectiveMainloop::LayoutSFB{}.stride()) == 1 ? + LayoutTypeID::kRowMajor : LayoutTypeID::kColumnMajor; + description_.SFB.alignment = CollectiveMainloop::AlignmentSFA; + description_.SFB.log_extent_range = 32; + description_.SFB.log_stride_range = 32; + + description_.SFMVecSize = Operator::CollectiveMainloop::ScaleGranularityM; + description_.SFNVecSize = Operator::CollectiveMainloop::ScaleGranularityN; + description_.SFKVecSize = Operator::CollectiveMainloop::ScaleGranularityK; + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.gemm_kind = GemmKind::kUniversal; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) { + description_.tile_description.cluster_shape = make_Coord( + Operator::ClusterShape::kM, + Operator::ClusterShape::kN, + Operator::ClusterShape::kK); + } + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::WarpCount::kM, + Operator::WarpCount::kN, + Operator::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(Operator::kAlignmentA); + description_.B = make_TensorDescription(Operator::kAlignmentB); + description_.C = make_TensorDescription(Operator::kAlignmentC); + description_.D = make_TensorDescription(Operator::kAlignmentD); + description_.element_epilogue = NumericTypeMap::kId; + + description_.split_k_mode = SplitKMode::kNone; + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } + + /// Returns the description of the GEMM operation + BlockwiseGemmDescription const& get_gemm_description() const { + return description_; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, GemmUniversalConfiguration const *configuration) { + // NOTE: GemmUniversalConfiguration does not contain problem shapes or batch strides + // Do nothing here and construct kernel arguments in update_arguments_ instead + // We also cannot construct TMA descriptors without all the arguments available + + operator_args.mode = configuration->mode; + return Status::kSuccess; + } + + template + struct UpdateFusionArgs { + static Status update_(FusionArgs const& fusion_args, BlockwiseGemmArguments const &arguments) { + // If a custom EVT is instantiated then it is the users's responsibility + // to ensure alpha and beta are updated appropriately + return Status::kSuccess; + } + }; + + template + struct UpdateFusionArgs> { + static Status update_(FusionArgs& fusion_args, BlockwiseGemmArguments const &arguments) { + if (arguments.pointer_mode == ScalarPointerMode::kHost) { + fusion_args.alpha = *static_cast(arguments.alpha); + fusion_args.beta = *static_cast(arguments.beta); + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + + return Status::kSuccess; + } + else if (arguments.pointer_mode == ScalarPointerMode::kDevice) { + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = static_cast(arguments.alpha); + fusion_args.beta_ptr = static_cast(arguments.beta); + + return Status::kSuccess; + } + else { + return Status::kErrorInvalidProblem; + } + } + }; + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + BlockwiseGemmArguments const *arguments) { + Status status = Status::kSuccess; + + status = UpdateFusionArgs::update_( + operator_args.epilogue.thread, *arguments); + if (status != Status::kSuccess) { + return status; + } + + operator_args.problem_shape = cute::make_shape( + arguments->problem_size.m(), + arguments->problem_size.n(), + arguments->problem_size.k(), + arguments->batch_count); + + // update arguments + + if constexpr (IsRuntimeDataType) { + using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB; + operator_args.mainloop.ptr_A = static_cast(arguments->A); + operator_args.mainloop.ptr_B = static_cast(arguments->B); + + std::unordered_map mapping = { + {RuntimeDatatype::kE4M3, cute::UMMA::MXF8F6F4Format::E4M3}, + {RuntimeDatatype::kE5M2, cute::UMMA::MXF8F6F4Format::E5M2}, + {RuntimeDatatype::kE3M2, cute::UMMA::MXF8F6F4Format::E3M2}, + {RuntimeDatatype::kE2M1, cute::UMMA::MXF8F6F4Format::E2M1} + }; + + auto iter_runtime_a = mapping.find(arguments->runtime_input_datatype_a); + auto iter_runtime_b = mapping.find(arguments->runtime_input_datatype_b); + + if (iter_runtime_a != mapping.end()) { + operator_args.mainloop.runtime_data_type_a = iter_runtime_a->second; + } else { + assert("invalid runtime argument for datatype A!"); + } + + if (iter_runtime_b != mapping.end()) { + operator_args.mainloop.runtime_data_type_b = iter_runtime_b->second; + } else { + assert("invalid runtime argument for datatype B!"); + } + + } + else { + + operator_args.mainloop.ptr_A = static_cast(arguments->A); + operator_args.mainloop.ptr_B = static_cast(arguments->B); + } + operator_args.mainloop.ptr_SFA = static_cast(arguments->SFA); + operator_args.mainloop.ptr_SFB = static_cast(arguments->SFB); + operator_args.epilogue.ptr_C = static_cast(arguments->C); + operator_args.epilogue.ptr_D = static_cast(arguments->D); + + operator_args.mainloop.dA = cute::make_int_tuple_from( + arguments->lda, arguments->batch_stride_A); + operator_args.mainloop.dB = cute::make_int_tuple_from( + arguments->ldb, arguments->batch_stride_B); + operator_args.epilogue.dC = cute::make_int_tuple_from( + arguments->ldc, arguments->batch_stride_C); + operator_args.epilogue.dD = operator_args.epilogue.dC; + + operator_args.mainloop.layout_SFA = Operator::CollectiveMainloop::ScaleConfig::tile_atom_to_shape_SFA(operator_args.problem_shape); + operator_args.mainloop.layout_SFB = Operator::CollectiveMainloop::ScaleConfig::tile_atom_to_shape_SFB(operator_args.problem_shape); + + /* Query device SM count to pass onto the kernel as an argument, where needed */ + operator_args.hw_info.sm_count = arguments->sm_count; + if constexpr (!std::is_const_v) { + operator_args.scheduler.max_swizzle_size = arguments->swizzle_size; + } + + if constexpr (!std::is_const_v) { + using Enum_t = decltype(operator_args.scheduler.raster_order); + switch (arguments->raster_order) { + case RasterOrder::kAlongN: + operator_args.scheduler.raster_order = Enum_t::AlongN; + break; + case RasterOrder::kAlongM: + operator_args.scheduler.raster_order = Enum_t::AlongM; + break; + default: + operator_args.scheduler.raster_order = Enum_t::Heuristic; + } + } + + if constexpr (std::is_same_v) { + operator_args.scheduler.splits = arguments->split_k_slices; + } + + + if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) { + operator_args.hw_info.cluster_shape = dim3( + arguments->cluster_shape.m(), + arguments->cluster_shape.n(), + arguments->cluster_shape.k()); + operator_args.hw_info.cluster_shape_fallback = dim3( + arguments->cluster_shape_fallback.m(), + arguments->cluster_shape_fallback.n(), + arguments->cluster_shape_fallback.k()); + } + + return status; + } + +public: + + /// Returns success if the operation can proceed + Status can_implement( + void const *configuration_ptr, void const *arguments_ptr) const override { + + GemmUniversalConfiguration const *configuration = + static_cast(configuration_ptr); + BlockwiseGemmArguments const *arguments = + static_cast(arguments_ptr); + + if (arguments->sf_m_vec_size != description_.SFMVecSize && arguments->sf_m_vec_size != 0) { + return Status::kErrorInvalidProblem; + } + if (arguments->sf_n_vec_size != description_.SFNVecSize && arguments->sf_n_vec_size != 0) { + return Status::kErrorInvalidProblem; + } + if (arguments->sf_k_vec_size != description_.SFKVecSize && arguments->sf_k_vec_size != 0) { + return Status::kErrorInvalidProblem; + } + + OperatorArguments args; + auto status = update_arguments_(args, arguments); + if (status != Status::kSuccess) { + return status; + } + + // can_implement rules may need access to problem shape + args.problem_shape = cute::make_shape( + configuration->problem_size.m(), + configuration->problem_size.n(), + configuration->problem_size.k(), + configuration->batch_count); + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + uint64_t get_host_workspace_size(void const *configuration) const override { + return sizeof(Operator); + } + + /// Gets the device-side workspace + uint64_t get_device_workspace_size( + void const *configuration_ptr,void const *arguments_ptr) const override { + + OperatorArguments args; + auto status = update_arguments_( + args, static_cast(arguments_ptr)); + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + return size; + } + + /// Initializes the workspace + Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const override { + Operator *op = new (host_workspace) Operator; + return Status::kSuccess; + } + + Status initialize_with_profiler_workspace( + void const *configuration, + void *host_workspace, + void *device_workspace, + uint8_t **profiler_workspaces, + int problem_count_from_profiler, + cudaStream_t stream = nullptr) { + return Status::kSuccess; + } + + /// Runs the kernel + Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const override { + + OperatorArguments args; + Status status = update_arguments_(args, static_cast(arguments_ptr)); + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + // We need to call initialize() since we have to rebuild TMA desc for every new set of args + status = op->run(args, device_workspace, stream, nullptr, static_cast(arguments_ptr)->use_pdl); + return status; + } +}; +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::library + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/conv2d_operation.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/conv2d_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..3b1a1584db92c4379e04c84a2658f79313b3eaad --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/conv2d_operation.h @@ -0,0 +1,650 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all CONV operation kinds in CUTLASS Library. +*/ + +#pragma once +#include +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d_fprop.h" +#include "cutlass/conv/kernel/default_conv2d_group_fprop.h" +#include "cutlass/conv/kernel/default_depthwise_fprop.h" +#include "cutlass/conv/kernel/default_conv2d_dgrad.h" +#include "cutlass/conv/kernel/default_conv2d_wgrad.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/conv/device/direct_convolution.h" + +#include "cutlass/library/library.h" +#include "library_internal.h" +#include "cutlass/util/host_tensor.h" + +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/core_io.h" +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class Conv2dOperationBase : public Operation { +public: + + using Operator = Operator_; + + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = Operator::kIteratorAlgorithm; + static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; + + using OperatorArguments = typename Operator::Arguments; + +protected: + + /// + ConvDescription description_; + +public: + + /// Constructor + Conv2dOperationBase(char const *name = "unknown_conv2d") { + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.kind = OperationKind::kConv2d; + description_.conv_dim = Operator::kConvDim; + + description_.iterator_algorithm = IteratorAlgorithmMap::kId; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::UnderlyingKernel::WarpCount::kM, + Operator::UnderlyingKernel::WarpCount::kN, + Operator::UnderlyingKernel::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(); + description_.B = make_TensorDescription(); + description_.C = make_TensorDescription(); + description_.element_epilogue = NumericTypeMap::kId; + + // TODO: Add split k mode Serial and parallel to convolutions + // description_.split_k_mode = Operator::kSplitK ? SplitKMode::kSerial : SplitKMode::kNone; + + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } +}; + + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Conv2d library operation class for cutlass profiler +// +/////////////////////////////////////////////////////////////////////////////////////////////////// +template +class Conv2dOperation : public Conv2dOperationBase { +public: + + using Operator = Operator_; + + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; + + using OperatorArguments = typename Operator::Arguments; + +public: + /// Constructor + Conv2dOperation(char const *name = "unknown_conv2d_fprop") : Conv2dOperationBase(name) { + this->description_.conv_kind = ConvKindMap::kId; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + Conv2dConfiguration const *configuration) { + + + operator_args.problem_size = configuration->problem_size; + + operator_args.ref_A = + { + nullptr, + LayoutA::packed(implicit_gemm_tensor_a_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_B = + { + nullptr, + LayoutB::packed(implicit_gemm_tensor_b_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_C = + { + nullptr, + LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_D = + { + nullptr, + LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.split_k_mode = configuration->split_k_mode; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + ConvArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.output_op = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.output_op = params; + } + else { + return Status::kErrorInvalidProblem; + } + + operator_args.ref_A.reset(static_cast(const_cast(arguments->A))); + operator_args.ref_B.reset(static_cast(const_cast(arguments->B))); + operator_args.ref_C.reset(static_cast(const_cast(arguments->C))); + operator_args.ref_D.reset(static_cast(const_cast(arguments->D))); + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + Conv2dConfiguration const *configuration = + static_cast(configuration_ptr); + + ConvArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + return Operator::get_workspace_size(args); + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + //std::cout << "initialize library::Conv2dOperation" << std::endl; + //print_operator_args(args); + return op->initialize(args, device_workspace, stream); + + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args, device_workspace); + + if (status != Status::kSuccess) { + return status; + } + //std::cout << "run library::Conv2dOperation" << std::endl; + //print_operator_args(args); + return op->run(stream); + } + + /// Call print_operator_args from the Conv2dOperation::initialize() + // to dump arguments passed on to cutlass operator for debugging + void print_operator_args(OperatorArguments &operator_args) const { + std::cout << "Conv2dOperation::OperatorArguments" << std::endl + << " problem_size:" << std::endl + << operator_args.problem_size << std::endl + << " split_k_mode: " + << (operator_args.split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial" : "parallel") << std::endl + << " epilogue (alpha, beta): " + << operator_args.output_op.alpha << ", " + << operator_args.output_op.beta << std::endl + << " ref_A (ptr, {stride}): " + << operator_args.ref_A.data() << ", {" + << operator_args.ref_A.stride(0) << ", " + << operator_args.ref_A.stride(1) << ", " + << operator_args.ref_A.stride(2) << "}" << std::endl + << " ref_B (ptr, {stride}): " + << operator_args.ref_B.data() << ", {" + << operator_args.ref_B.stride(0) << ", " + << operator_args.ref_B.stride(1) << ", " + << operator_args.ref_B.stride(2) << "}" << std::endl + << " ref_C (ptr, {stride}): " + << operator_args.ref_C.data() << ", {" + << operator_args.ref_C.stride(0) << ", " + << operator_args.ref_C.stride(1) << ", " + << operator_args.ref_C.stride(2) << "}" << std::endl + << " ref_D (ptr, {stride}): " + << operator_args.ref_D.data() << ", {" + << operator_args.ref_D.stride(0) << ", " + << operator_args.ref_D.stride(1) << ", " + << operator_args.ref_D.stride(2) << "}" << std::endl; + } +}; + + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// DirectConv2d library operation class for cutlass profiler +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class DirectConv2dOperation : public Conv2dOperation { +public: + + using Operator = Operator_; + using Base = Conv2dOperation; + + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; + + using OperatorArguments = typename Operator::Arguments; + +public: + /// Constructor + DirectConv2dOperation(char const *name = "unknown_direct)conv2d_fprop") : Conv2dOperation(name) { + this->description_.conv_kind = ConvKindMap::kId; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + Conv2dConfiguration const *configuration) { + + + operator_args.problem_size = configuration->problem_size; + + operator_args.ref_A = + { + nullptr, + LayoutA::packed(implicit_gemm_tensor_a_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_B = + { + nullptr, + LayoutB::packed(implicit_gemm_tensor_b_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_reordered_B = + { + nullptr, + LayoutB::packed(implicit_gemm_tensor_b_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_C = + { + nullptr, + LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_D = + { + nullptr, + LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.split_k_mode = configuration->split_k_mode; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + ConvArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.output_op = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.output_op = params; + } + else { + return Status::kErrorInvalidProblem; + } + + operator_args.ref_A.reset(static_cast(const_cast(arguments->A))); + operator_args.ref_B.reset(static_cast(const_cast(arguments->B))); + operator_args.ref_C.reset(static_cast(const_cast(arguments->C))); + operator_args.ref_D.reset(static_cast(const_cast(arguments->D))); + operator_args.ref_reordered_B.reset(static_cast(const_cast(arguments->reordered_B))); + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + Conv2dConfiguration const *configuration = + static_cast(configuration_ptr); + + ConvArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + return Operator::get_workspace_size(args); + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + //std::cout << "initialize library::Conv2dOperation" << std::endl; + //print_operator_args(args); + return op->initialize(args, device_workspace, stream); + + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args, device_workspace); + + if (status != Status::kSuccess) { + return status; + } + //std::cout << "run library::Conv2dOperation" << std::endl; + //print_operator_args(args); + return op->run(stream); + } + + /// Call print_operator_args from the Conv2dOperation::initialize() + // to dump arguments passed on to cutlass operator for debugging + void print_operator_args(OperatorArguments &operator_args) const { + std::cout << "Conv2dOperation::OperatorArguments" << std::endl + << " problem_size:" << std::endl + << operator_args.problem_size << std::endl + << " split_k_mode: " + << (operator_args.split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial" : "parallel") << std::endl + << " epilogue (alpha, beta): " + << operator_args.output_op.alpha << ", " + << operator_args.output_op.beta << std::endl + << " ref_A (ptr, {stride}): " + << operator_args.ref_A.data() << ", {" + << operator_args.ref_A.stride(0) << ", " + << operator_args.ref_A.stride(1) << ", " + << operator_args.ref_A.stride(2) << "}" << std::endl + << " ref_B (ptr, {stride}): " + << operator_args.ref_B.data() << ", {" + << operator_args.ref_B.stride(0) << ", " + << operator_args.ref_B.stride(1) << ", " + << operator_args.ref_B.stride(2) << "}" << std::endl + << " ref_C (ptr, {stride}): " + << operator_args.ref_C.data() << ", {" + << operator_args.ref_C.stride(0) << ", " + << operator_args.ref_C.stride(1) << ", " + << operator_args.ref_C.stride(2) << "}" << std::endl + << " ref_D (ptr, {stride}): " + << operator_args.ref_D.data() << ", {" + << operator_args.ref_D.stride(0) << ", " + << operator_args.ref_D.stride(1) << ", " + << operator_args.ref_D.stride(2) << "}" << std::endl; + } +}; + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/conv3d_operation.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/conv3d_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..fe402c4494c27a882bf42f867a708e954ee87dc0 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/conv3d_operation.h @@ -0,0 +1,389 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all CONV operation kinds in CUTLASS Library. +*/ + +#pragma once +#include +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv3d_fprop.h" +#include "cutlass/conv/kernel/default_conv3d_dgrad.h" +#include "cutlass/conv/kernel/default_conv3d_wgrad.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" + +#include "cutlass/library/library.h" +#include "library_internal.h" +#include "cutlass/util/host_tensor.h" + +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/core_io.h" +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class Conv3dOperationBase : public Operation { +public: + + using Operator = Operator_; + + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = Operator::kIteratorAlgorithm; + static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; + + using OperatorArguments = typename Operator::Arguments; + +protected: + + /// + ConvDescription description_; + +public: + + /// Constructor + Conv3dOperationBase(char const *name = "unknown_conv3d") { + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.kind = OperationKind::kConv3d; + description_.conv_dim = Operator::kConvDim; + + description_.iterator_algorithm = IteratorAlgorithmMap::kId; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::UnderlyingKernel::WarpCount::kM, + Operator::UnderlyingKernel::WarpCount::kN, + Operator::UnderlyingKernel::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(); + description_.B = make_TensorDescription(); + description_.C = make_TensorDescription(); + description_.element_epilogue = NumericTypeMap::kId; + + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } +}; + + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Conv2d library operation class for cutlass profiler +// +/////////////////////////////////////////////////////////////////////////////////////////////////// +template +class Conv3dOperation : public Conv3dOperationBase { +public: + + using Operator = Operator_; + + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; + + using OperatorArguments = typename Operator::Arguments; + +public: + /// Constructor + Conv3dOperation(char const *name = "unknown_conv3d_fprop") : Conv3dOperationBase(name) { + this->description_.conv_kind = ConvKindMap::kId; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + Conv3dConfiguration const *configuration) { + + + operator_args.problem_size = configuration->problem_size; + + operator_args.ref_A = + { + nullptr, + LayoutA::packed(implicit_gemm_tensor_a_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_B = + { + nullptr, + LayoutB::packed(implicit_gemm_tensor_b_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_C = + { + nullptr, + LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_D = + { + nullptr, + LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.split_k_mode = configuration->split_k_mode; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + ConvArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.output_op = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.output_op = params; + } + else { + return Status::kErrorInvalidProblem; + } + + operator_args.ref_A.reset(static_cast(const_cast(arguments->A))); + operator_args.ref_B.reset(static_cast(const_cast(arguments->B))); + operator_args.ref_C.reset(static_cast(const_cast(arguments->C))); + operator_args.ref_D.reset(static_cast(const_cast(arguments->D))); + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + Conv3dConfiguration const *configuration = + static_cast(configuration_ptr); + + ConvArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + return Operator::get_workspace_size(args); + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + //std::cout << "initialize library::Conv3dOperation" << std::endl; + //print_operator_args(args); + return op->initialize(args, device_workspace, stream); + + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args, device_workspace); + + if (status != Status::kSuccess) { + return status; + } + //std::cout << "run library::Conv3dOperation" << std::endl; + //print_operator_args(args); + return op->run(stream); + } + + /// Call print_operator_args from the Conv3dOperation::initialize() + // to dump arguments passed on to cutlass operator for debugging + void print_operator_args(OperatorArguments &operator_args) const { + std::cout << "Conv3dOperation::OperatorArguments" << std::endl + << " problem_size: " + << operator_args.problem_size << std::endl + << " split_k_mode: " + << (operator_args.split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial" : "parallel") << std::endl + << " epilogue (alpha, beta): " + << operator_args.output_op.alpha << ", " + << operator_args.output_op.beta << std::endl + << " ref_A (ptr, {stride}): " + << operator_args.ref_A.data() << ", {" + << operator_args.ref_A.stride(0) << ", " + << operator_args.ref_A.stride(1) << ", " + << operator_args.ref_A.stride(2) << ", " + << operator_args.ref_A.stride(3) << "}" << std::endl + << " ref_B (ptr, {stride}): " + << operator_args.ref_B.data() << ", {" + << operator_args.ref_B.stride(0) << ", " + << operator_args.ref_B.stride(1) << ", " + << operator_args.ref_B.stride(2) << ", " + << operator_args.ref_B.stride(3) << "}" << std::endl + << " ref_C (ptr, {stride}): " + << operator_args.ref_C.data() << ", {" + << operator_args.ref_C.stride(0) << ", " + << operator_args.ref_C.stride(1) << ", " + << operator_args.ref_C.stride(2) << ", " + << operator_args.ref_C.stride(3) << "}" << std::endl + << " ref_D (ptr, {stride}): " + << operator_args.ref_D.data() << ", {" + << operator_args.ref_D.stride(0) << ", " + << operator_args.ref_D.stride(1) << ", " + << operator_args.ref_D.stride(2) << ", " + << operator_args.ref_D.stride(3) << "}" << std::endl; + } +}; + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/conv_operation_3x.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/conv_operation_3x.hpp new file mode 100644 index 0000000000000000000000000000000000000000..86c1513e9c934c22e281cf37e1c5e7783e23d305 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/conv_operation_3x.hpp @@ -0,0 +1,980 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all CONV operation kinds in CUTLASS Library. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "library_internal.h" +#include "cutlass/conv/convnd_problem_shape.hpp" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/trace.h" +#include +#include +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) +#include +#endif + +namespace cutlass::library { + +namespace detail { + +template +constexpr cute::array +vector_to_array_strides_helper(const std::vector& v, + std::index_sequence) +{ + return {v[(sizeof...(Indices) - 1u) - Indices]..., ValueType(1)}; +} + +template +cute::array +vector_to_array_strides(const std::vector& v, std::integral_constant) +{ + static_assert(Size != 0); + CUTLASS_ASSERT(v.size() + 1u == Size); + return vector_to_array_strides_helper(v, std::make_index_sequence{}); +} + +template +constexpr cute::array +coord_to_array_strides_helper( + const ::cutlass::Coord coord, + std::index_sequence) +{ + return {int64_t(coord[(sizeof...(Indices) - 1u) - Indices])..., int64_t(1)}; +} + +template +cute::array +coord_to_array_strides(const ::cutlass::Coord& coord) +{ + static_assert(Rank >= 0); + return coord_to_array_strides_helper(coord, std::make_index_sequence{}); +} + +} // namespace detail + +// Tells the profiler about CUTLASS 3's 2-D and 3-D convolutions. +// For CUTLASS 2's 2-D convolutions, see Conv2dOperation. +// For CUTLASS 2's 3-D convolutions, see Conv3dOperation. +template +class ConvOperation3x : public Operation { +public: + using Operator = Operator_; + + static_assert(Operator::NumSpatialDimensions == 2 || + Operator::NumSpatialDimensions == 3, + "The profiler currently only supports convolutions with 2 or 3 spatial dimensions."); + using LayoutA = cute::conditional_t + >; + using LayoutB = LayoutA; + using LayoutC = LayoutA; + + using ElementA = typename Operator::ElementA; + using ElementB = typename Operator::ElementB; + using ElementC = typename Operator::ElementC; + using ElementD = typename Operator::ElementD; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; + + ConvOperation3x(const char* name = "unknown_cutlass_3_conv") { + // Initialize OperationDescription (the base class) + description_.name = name; + description_.provider = Provider::kCUTLASS; + + if constexpr (Operator::NumSpatialDimensions == 2) { + description_.kind = OperationKind::kConv2d; + } + else if constexpr (Operator::NumSpatialDimensions == 3) { + description_.kind = OperationKind::kConv3d; + } + else { + static_assert(::cutlass::detail::dependent_false, + "This class currently only supports 2-D and 3-D convolutions."); + } + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::WarpCount::kM, + Operator::WarpCount::kN, + Operator::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationID::kMultiplyAdd; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + // Initialize ConvDescription (the subclass) + + // kConvDim does not exist in Operator for CUTLASS 3 convolutions. + // For CUTLASS 2 convolutions, it is the number of spatial dimensions. + description_.conv_dim = Operator::NumSpatialDimensions; + description_.conv_kind = ConvKindMap::kId; + + description_.iterator_algorithm = {}; + + description_.A = make_TensorDescription(); + description_.B = make_TensorDescription(); + description_.C = make_TensorDescription(); + description_.element_epilogue = NumericTypeMap::kId; + } + + ~ConvOperation3x() override = default; + + OperationDescription const& description() const override { + return static_cast(description_); + } + +private: + Status update_operator_arguments_from_configuration_2d_or_3d( + typename Operator::Arguments& out_args, + void const* configuration) const { + Status status = Status::kInvalid; + + CUTLASS_ASSERT(configuration != nullptr); + + if constexpr (Operator::NumSpatialDimensions == 2) { + CUTLASS_ASSERT(description_.kind == OperationKind::kConv2d); + // tools/library/include/cutlass/library/library.h + // defines Conv2dConfiguration. + // tools/profiler/include/cutlass/profiler/conv2d_operation_profiler.h + // uses Conv2dConfiguration. + auto* conf_ptr = reinterpret_cast(configuration); + status = update_operator_arguments_from_configuration(out_args, *conf_ptr); + } + else if constexpr (Operator::NumSpatialDimensions == 3) { + CUTLASS_ASSERT(description_.kind == OperationKind::kConv3d); + auto* conf_ptr = reinterpret_cast(configuration); + status = update_operator_arguments_from_configuration(out_args, *conf_ptr); + } + else { + static_assert(::cutlass::detail::dependent_false, + "This class currently only supports 2-D and 3-D convolutions."); + } + + return status; + } + +public: + Status can_implement( + void const* configuration, + void const* arguments) const override { + Status status = Status::kInvalid; + + // gemm_operation_3x.hpp accesses "configuration" as + // GemmUniversalConfiguration (which lives in + // tools/library/include/cutlass/library/library.h) and + // "arguments" as GemmUniversalArguments (which lives in + // tools/library/include/cutlass/library/library.h). + // Those things don't apply to convolutions. + // Despite the existence of ConvUniversal, there's no + // corresponding "ConvUniversalConfiguration" or + // "ConvUniversalArguments." + + CUTLASS_ASSERT(configuration != nullptr); + CUTLASS_ASSERT(arguments != nullptr); + + typename Operator::Arguments out_args{}; + status = update_operator_arguments_from_configuration_2d_or_3d(out_args, configuration); + if (status != Status::kSuccess) { + CUTLASS_TRACE_HOST("*** can_implement: update_operator_arguments_from_configuration_2d_or_3d failed"); + return status; + } + + auto* in_args_ptr = reinterpret_cast(arguments); + status = update_operator_arguments_from_arguments(out_args, *in_args_ptr); + if (status != Status::kSuccess) { + CUTLASS_TRACE_HOST("*** can_implement: update_operator_arguments_from_arguments failed"); + return status; + } + + return Operator::can_implement(out_args); + } + + uint64_t get_host_workspace_size(void const* /* configuration */) const override { + return sizeof(Operator); + } + + uint64_t get_device_workspace_size( + void const* configuration, + void const* arguments = nullptr) const override + { + // This presumes that at least one of configuration or arguments is nonnull. + Status status = Status::kInvalid; + + // gemm_operation_3x.hpp has get_device_workspace_size return 0 on + // error. It's not clear that this is what we want -- perhaps we + // should return something like expected? -- but + // it's the only option that preserves the current interface. + constexpr uint64_t error_indication = 0; + + typename Operator::Arguments out_args{}; + if (configuration != nullptr) { + status = update_operator_arguments_from_configuration_2d_or_3d(out_args, configuration); + if (status != Status::kSuccess) { + return error_indication; + } + } + if (arguments != nullptr) { + auto* in_args_ptr = reinterpret_cast(arguments); + status = update_operator_arguments_from_arguments(out_args, *in_args_ptr); + if (status != Status::kSuccess) { + return error_indication; + } + } + + if (status == Status::kSuccess) { + return static_cast(Operator::get_workspace_size(out_args)); + } + else { + return error_indication; + } + } + + Status initialize( + void const* configuration, + void* host_workspace, + void* /* device_workspace */ = nullptr, + cudaStream_t stream = nullptr) const override + { + Status status = Status::kInvalid; + + if (configuration == nullptr) { + CUTLASS_TRACE_HOST("Input configuration is null."); + return Status::kInvalid; + } + + typename Operator::Arguments out_args{}; + status = update_operator_arguments_from_configuration_2d_or_3d(out_args, configuration); + if (status != Status::kSuccess) { + // Any kind of failure invalidates the last successful configuration. + clear_last_successful_config(); + return status; + } + else { + set_last_successful_config(configuration); + } + + if (host_workspace == nullptr) { + CUTLASS_TRACE_HOST("host_workspace is null."); + return Status::kInvalid; + } + (void) new (host_workspace) Operator; + return status; + + // CUTLASS 2 convolutions call the Operator's initialize function + // here, like this. + // + //return op->initialize(args, device_workspace, stream); + // + // CUTLASS 3 convolutions (ConvUniversal), like CUTLASS 3 Gemms + // (GemmUniversal), lack an "initialize" member function. + } + + Status run( + void const* arguments, + void* host_workspace, + void* device_workspace = nullptr, + cudaStream_t stream = nullptr) const override + { + auto status = Status::kInvalid; + + // The Operator doesn't appear to save the last configuration (it + // doesn't have a way to do that, since it lacks an initialize() + // member function), so we have to use the stored configuration + // from the last successful initialize() call (if any). + typename Operator::Arguments out_args{}; + status = update_operator_arguments_from_stored_configuration(out_args); + if (status != Status::kSuccess) { + CUTLASS_TRACE_HOST("Updating from previous successful configuration failed."); + return status; + } + + if (arguments == nullptr) { + CUTLASS_TRACE_HOST("Input argument 'arguments' is null."); + return Status::kInvalid; + } + auto* in_args_ptr = reinterpret_cast(arguments); + status = update_operator_arguments_from_arguments(out_args, *in_args_ptr); + if (status != Status::kSuccess) { + return status; + } + + auto* op = reinterpret_cast(host_workspace); + return op->run(out_args, device_workspace, stream, nullptr, in_args_ptr->use_pdl); + } + +private: + ConvDescription description_; + // Result of initialize() calling + // update_operator_arguments_from_configuration() successfully. + // This is needed because run() doesn't take a configuration, just + // arguments, and the kernel doesn't appear to save the + // configuration from the last initialize() call. + // + // Unfortunately, this must be declared mutable, because it must be + // set in initialize(), and initialize() is inherited as const. + mutable std::variant< + std::monostate, + Conv2dConfiguration, + Conv3dConfiguration> last_successful_config_{std::monostate{}}; + + // Clear the last configuration resulting from a successful initialize() call. + // + // Unfortunately, this must be declared const, because initialize() is. + void clear_last_successful_config() const { + last_successful_config_ = std::monostate{}; + } + + // Set the last configuration resulting from a successful initialize() call. + // + // Unfortunately, this must be declared const, because initialize() is. + void set_last_successful_config(void const* configuration) const { + CUTLASS_ASSERT(configuration != nullptr); + + if constexpr (Operator::NumSpatialDimensions == 2) { + CUTLASS_ASSERT(description_.kind == OperationKind::kConv2d); + auto* conf_ptr = reinterpret_cast(configuration); + last_successful_config_ = *conf_ptr; + } else if constexpr (Operator::NumSpatialDimensions == 3) { + CUTLASS_ASSERT(description_.kind == OperationKind::kConv3d); + auto* conf_ptr = reinterpret_cast(configuration); + last_successful_config_ = *conf_ptr; + } + else { + static_assert(::cutlass::detail::dependent_false, + "This class currently only supports 2-D and 3-D convolutions."); + } + } + + // Whether a configuration from a successful initialize() call exists. + bool last_successful_config_exists() const { + return not std::holds_alternative(last_successful_config_); + } + + // Visitor for update_operator_arguments_from_stored_configuration. + struct ConfigurationVisitor { + typename Operator::Arguments& out_args; + + Status operator() (std::monostate const&) const { + CUTLASS_TRACE_HOST("No successful previous configuration exists. " + "One cause is calling run() before a successful initialize() call."); + return Status::kInvalid; + } + Status operator() (Conv2dConfiguration const& conf2d) const { + return update_operator_arguments_from_configuration(out_args, conf2d); + } + Status operator() (Conv3dConfiguration const& conf3d) const { + return update_operator_arguments_from_configuration(out_args, conf3d); + } + }; + + // Like update_operator_arguments_from_configuration, but on the + // stored configuration from the last successful initialize() call, + // if any. If there was no last successful initialize() call, + // then return Status::kInvalid. + // + // Unfortunately, this must be declared const, because run() is. + Status update_operator_arguments_from_stored_configuration( + typename Operator::Arguments& out_args) const + { + return std::visit(ConfigurationVisitor{out_args}, last_successful_config_); + } + + template + struct UpdateFusionArgs { + static Status update_( + FusionArgs const&, + ConvArguments const&) + { + // For custom EVT, it is the user's responsibility to ensure + // that alpha and beta are updated appropriately. + return Status::kSuccess; + } + }; + + template + struct UpdateFusionArgs> { + static Status update_( + FusionArgs& fusion_args, + ConvArguments const& arguments) + { + if (arguments.pointer_mode == ScalarPointerMode::kHost) { + fusion_args.alpha = *static_cast(arguments.alpha); + fusion_args.beta = *static_cast(arguments.beta); + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + + return Status::kSuccess; + } + else if (arguments.pointer_mode == ScalarPointerMode::kDevice) { + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = static_cast(arguments.alpha); + fusion_args.beta_ptr = static_cast(arguments.beta); + + return Status::kSuccess; + } + else { + return Status::kErrorInvalidProblem; + } + } + }; + + static Status update_operator_arguments_from_configuration( + typename Operator::Arguments& out_args, + Conv2dConfiguration const& config) + { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("ConvOperator3x::" + "update_operator_arguments_from_configuration" + "(Conv2dConfiguration)\n"); +#endif + using detail::vector_to_array_strides; + + constexpr int num_spatial_dims = Operator::NumSpatialDimensions; + if constexpr (num_spatial_dims != 2) { + CUTLASS_TRACE_HOST("You can only use Conv2dConfiguration " + "with an Operator whose NumSpatialDimensions is exactly 2."); + return Status::kInvalid; + } + else { + // Convolutions split the metadata (in Conv2dConfiguration) from + // the data (ConvArguments, which only has pointers and a single + // enum value). Thus, this class will need both the + // configuration and the (user's input) arguments to set up the + // kernel's arguments. This function can fill in what the + // configuration has now, but the class will need the user's + // input arguments later. + if (config.split_k_mode != conv::SplitKMode::kSerial) { + CUTLASS_TRACE_HOST("CUTLASS 3 convolutions currently only support split_k_mode = kSerial."); + return Status::kInvalid; + } + // config.problem_size.split_k_slices is only meaningful if + // split_k_mode != kSerial. If this code later supports other + // split_k_mode values, then it will also need to read + // split_k_slices. + + const int N = config.problem_size.N; + const int H = config.problem_size.H; + const int W = config.problem_size.W; + const int C = config.problem_size.C; + const int K = config.problem_size.K; + const int R = config.problem_size.R; + const int S = config.problem_size.S; + const int pad_h = config.problem_size.pad_h; + const int pad_w = config.problem_size.pad_w; + const int traversal_stride_h = config.problem_size.stride_h; + const int traversal_stride_w = config.problem_size.stride_w; + const int dilation_h = config.problem_size.dilation_h; + const int dilation_w = config.problem_size.dilation_w; + + // CUTLASS 3's implicit GEMM convolution kernels currently only + // support cross correlation (passing over the activation and + // filter tensors in the same order). The convolution mode is + // future work. + const auto mode = config.problem_size.mode; + if (mode != cutlass::conv::Mode::kCrossCorrelation) { + CUTLASS_TRACE_HOST("Convolution modes other than kCrossCorrelation " + "are not currently supported."); + return Status::kInvalid; + } + + constexpr int num_spatial_dims = Operator::NumSpatialDimensions; + constexpr size_t stride_size = size_t(num_spatial_dims) + 2u; + constexpr auto the_stride_size = std::integral_constant{}; + +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + std::cerr << " num_spatial_dims = " << num_spatial_dims << "\n" + << " stride_size = " << stride_size << "\n"; + auto print_stride = [] (auto const& stride, char const variable_name[]) { + std::cerr << " " << variable_name << ": ["; + for (size_t k = 0; k < stride.size(); ++k) { + std::cerr << stride[k]; + if (k + 1u < stride.size()) { + std::cerr << ", "; + } + } + std::cerr << "]\n"; + }; + print_stride(config.stride_a, "config.stride_a"); + print_stride(config.stride_b, "config.stride_b"); + print_stride(config.stride_c, "config.stride_c"); +#endif + + // Conv2dConfiguration stores the strides as std::vector, + // so the code needs to check the run-time vector lengths. + if (config.stride_a.size() + 1u != stride_size) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) + std::ostringstream os; + os << "config.stride_a.size() + 1u = " + << (config.stride_a.size() + 1u) + << " != num_spatial_dims + 2u = " << stride_size; + CUTLASS_TRACE_HOST( os.str() ); +#endif + return Status::kInvalid; + } + if (config.stride_b.size() + 1u != stride_size) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) + std::ostringstream os; + os << "config.stride_b.size() + 1u = " + << (config.stride_b.size() + 1u) + << " != num_spatial_dims + 2u = " << stride_size; + CUTLASS_TRACE_HOST( os.str() ); +#endif + return Status::kInvalid; + } + if (config.stride_c.size() + 1u != stride_size) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) + std::ostringstream os; + os << "config.stride_c.size() + 1u = " + << (config.stride_c.size() + 1u) + << " != num_spatial_dims + 2u = " << stride_size; + CUTLASS_TRACE_HOST( os.str() ); +#endif + return Status::kInvalid; + } + + constexpr cutlass::conv::Operator conv_op = Operator::DispatchPolicy::ConvOp; + using problem_shape_type = + cutlass::conv::ConvProblemShape; + // cute::array; must convert to the kernel's native strides + using TensorStride = typename problem_shape_type::TensorStride; + + const TensorStride stride_A = vector_to_array_strides(config.stride_a, the_stride_size); + const TensorStride stride_B = vector_to_array_strides(config.stride_b, the_stride_size); + const TensorStride stride_C = vector_to_array_strides(config.stride_c, the_stride_size); + + // cutlass::library::Conv2dConfiguration has no member stride_d. + // The code below imitates the testbed, + // which just sets D's strides to C's strides. + + const int num_groups = config.problem_size.groups; + if (num_groups != 1) { + CUTLASS_TRACE_HOST("CUTLASS 3 kernels currently only support groups = 1."); + return Status::kInvalid; + } + // ConvProblemShape is how CUTLASS 3 kernels represent + // convolution problems. ConvProblemShape's constructors take + // shape_act, stride_act, shape_flt, and stride_flt, and set + // shape_A, stride_A, shape_B, stride_B, shape_C, and stride_C + // according to Fprop / Dgrad / Wgrad. + // + // This means that stride_act isn't always config.stride_A, + // depending on Fprop / Dgrad / Wgrad. The code here "undoes" + // the logic in Conv2dWorkspace::set_stride_vector so that we + // can recover the strides of the activation and filter tensors. + // It doesn't need to worry about the so-called "output" tensor + // (which might not be C), as ConvProblemShape's constructor + // figures out its shapes and strides. + using TensorExtent = typename problem_shape_type::TensorExtent; + TensorExtent shape_act{N, H, W, C}; + auto stride_act = [&] () { + // Some compilers consider conv_op (defined above), as + // captured by this lambda, as "not a constant expression." + constexpr auto conv_kind = Operator::DispatchPolicy::ConvOp; + if constexpr (conv_kind == cutlass::conv::Operator::kFprop) { + return stride_A; + } + else if constexpr (conv_kind == cutlass::conv::Operator::kDgrad) { + return stride_C; + } + else { // conv_kind == cutlass::conv::Operator::kWgrad + return stride_B; + } + } (); + TensorExtent shape_flt{K, R, S, C}; + auto stride_flt = [&] () { + // Some compilers consider conv_op (defined above), as + // captured by this lambda, as "not a constant expression." + constexpr auto conv_kind = Operator::DispatchPolicy::ConvOp; + if constexpr (conv_kind == cutlass::conv::Operator::kFprop) { + return stride_B; + } + else if constexpr (conv_kind == cutlass::conv::Operator::kDgrad) { + return stride_B; + } + else { // conv_kind == cutlass::conv::Operator::kWgrad + return stride_C; + } + } (); + + problem_shape_type problem_shape( + /* mode = */ mode, + /* shape_act = */ shape_act, + /* stride_act = */ stride_act, + /* shape_flt = */ shape_flt, + /* stride_flt = */ stride_flt, + /* lower_padding = */ {pad_h, pad_w}, + /* upper_padding = */ {pad_h, pad_w}, + /* traversal_stride = */ {traversal_stride_h, traversal_stride_w}, + /* dilation = */ {dilation_h, dilation_w}, + num_groups); + out_args.problem_shape = problem_shape; + + // ConvProblemShape's constructor sets its shape_C member. +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + printf("\n problem_shape.shape_C: "); + print(problem_shape.shape_C); + printf("\n problem_shape.stride_C: "); + print(problem_shape.stride_C); + printf("\n"); +#endif + // Initialization of C's and D's strides follows the CUTLASS 3 + // convolutions testbed (test/unit/conv/device_3x/testbed_conv.hpp). + { + using StrideC = typename Operator::ConvKernel::StrideC; + using StrideD = typename Operator::ConvKernel::StrideD; + auto stride_C = StrideC{}; + auto stride_D = StrideD{}; + + if constexpr (conv_op == cutlass::conv::Operator::kWgrad) { + stride_C = cutlass::make_cute_packed_stride( + StrideC{}, problem_shape.shape_C, problem_shape.stride_C, conv_op); + stride_D = cutlass::make_cute_packed_stride( + StrideD{}, problem_shape.shape_C, problem_shape.stride_C, conv_op); +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + std::cerr << " Wgrad: stride_C: " << stride_C << "\n"; +#endif + } + else { + cute::for_each(cute::make_seq(StrideC{})>{}, [&](auto i) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + const auto stride_C_i = problem_shape.stride_C[problem_shape_type::RankT-2-i]; + std::cerr << " Fprop or Dgrad: get<0, " << i << ">(stride_C): " + << stride_C_i << "\n"; +#endif + cute::get<0, i>(stride_C) = problem_shape.stride_C[problem_shape_type::RankT-2-i]; + }); + cute::for_each(cute::make_seq(StrideD{})>{}, [&](auto i) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + const auto stride_D_i = problem_shape.stride_C[problem_shape_type::RankT-2-i]; + std::cerr << " Fprop or Dgrad: get<0, " << i << ">(stride_D): " + << stride_D_i << "\n"; +#endif + cute::get<0, i>(stride_D) = problem_shape.stride_C[problem_shape_type::RankT-2-i]; + }); + } + out_args.epilogue.dC = stride_C; + out_args.epilogue.dD = stride_D; + } + return Status::kSuccess; + } + } + + static Status update_operator_arguments_from_configuration( + typename Operator::Arguments& out_args, + Conv3dConfiguration const& config) + { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("ConvOperator3x::" + "update_operator_arguments_from_configuration" + "(Conv3dConfiguration)\n"); +#endif + using detail::coord_to_array_strides; + + constexpr int num_spatial_dims = Operator::NumSpatialDimensions; + if constexpr (num_spatial_dims != 3) { + CUTLASS_TRACE_HOST("You can only use Conv3dConfiguration " + "with an Operator whose NumSpatialDimensions is exactly 3."); + return Status::kInvalid; + } + else { + // Convolutions split the metadata (in Conv3dConfiguration) from + // the data (ConvArguments, which only has pointers and a single + // enum value). Thus, this class will need both the + // configuration and the (user's input) arguments to set up the + // kernel's arguments. This function can fill in what the + // configuration has now, but the class will need the user's + // input arguments later. + if (config.split_k_mode != conv::SplitKMode::kSerial) { + CUTLASS_TRACE_HOST("CUTLASS 3 convolutions currently only support split_k_mode = kSerial."); + return Status::kInvalid; + } + // config.problem_size.split_k_slices is only meaningful if + // split_k_mode != kSerial. If this code later supports other + // split_k_mode values, then it will also need to read + // split_k_slices. + + const int N = config.problem_size.N; + const int D = config.problem_size.D; + const int H = config.problem_size.H; + const int W = config.problem_size.W; + const int C = config.problem_size.C; + const int K = config.problem_size.K; + const int T = config.problem_size.T; + const int R = config.problem_size.R; + const int S = config.problem_size.S; + const int pad_d = config.problem_size.pad_d; + const int pad_h = config.problem_size.pad_h; + const int pad_w = config.problem_size.pad_w; + const int traversal_stride_d = config.problem_size.stride_d; + const int traversal_stride_h = config.problem_size.stride_h; + const int traversal_stride_w = config.problem_size.stride_w; + const int dilation_d = config.problem_size.dilation_d; + const int dilation_h = config.problem_size.dilation_h; + const int dilation_w = config.problem_size.dilation_w; + + // CUTLASS 3's implicit GEMM convolution kernels currently only + // support cross correlation (passing over the activation and + // filter tensors in the same order). The convolution mode is + // future work. + const auto mode = config.problem_size.mode; + if (mode != cutlass::conv::Mode::kCrossCorrelation) { + CUTLASS_TRACE_HOST("Convolution modes other than kCrossCorrelation " + "are not currently supported."); + return Status::kInvalid; + } + + using Stride = cutlass::layout::TensorNDHWC::Stride; + static_assert(std::is_same_v>); + + const cutlass::library::ConvKind conv_kind = [] () { + constexpr cutlass::conv::Operator op = Operator::DispatchPolicy::ConvOp; + if constexpr (op == cutlass::conv::Operator::kFprop) { + return library::ConvKind::kFprop; + } + else if constexpr (op == cutlass::conv::Operator::kDgrad) { + return library::ConvKind::kDgrad; + } + else /* if constexpr (op == cutlass::conv::Operator::kWgrad) */ { + return library::ConvKind::kWgrad; + } + } (); + const Stride input_stride_a = config.layout_a(conv_kind).stride(); + const Stride input_stride_b = config.layout_b(conv_kind).stride(); + const Stride input_stride_c = config.layout_c(conv_kind).stride(); + +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + constexpr size_t stride_size = size_t(num_spatial_dims) + 2u; + std::cerr << " num_spatial_dims = " << num_spatial_dims << "\n" + << " stride_size = " << stride_size << "\n"; + auto print_stride = [] (Stride const& stride, char const variable_name[]) { + std::cerr << " " << variable_name << ": ["; + for (size_t k = 0; k < Stride::kRank; ++k) { + std::cerr << stride[static_cast(k)]; + if (k + 1u < Stride::kRank) { + std::cerr << ", "; + } + } + std::cerr << "]\n"; + }; + print_stride(input_stride_a, "input_stride_a"); + print_stride(input_stride_b, "input_stride_b"); + print_stride(input_stride_c, "input_stride_c"); +#endif + // Conv3dConfiguration stores the strides as Coord (with + // compile-time size), so there's no need to check sizes here + // (unlike Conv2dConfiguration, which stores strides as + // std::vector). + + constexpr cutlass::conv::Operator conv_op = Operator::DispatchPolicy::ConvOp; + using problem_shape_type = + cutlass::conv::ConvProblemShape; + // cute::array; must convert to the kernel's native strides + using TensorStride = typename problem_shape_type::TensorStride; + + const TensorStride stride_A = coord_to_array_strides(input_stride_a); + const TensorStride stride_B = coord_to_array_strides(input_stride_b); + const TensorStride stride_C = coord_to_array_strides(input_stride_c); + + const int num_groups = config.problem_size.groups; + if (num_groups != 1) { + CUTLASS_TRACE_HOST("CUTLASS 3 kernels currently only support groups = 1."); + return Status::kInvalid; + } + // ConvProblemShape is how CUTLASS 3 kernels represent + // convolution problems. ConvProblemShape's constructors take + // shape_act, stride_act, shape_flt, and stride_flt, and set + // shape_A, stride_A, shape_B, stride_B, shape_C, and stride_C + // according to Fprop / Dgrad / Wgrad. + // + // Conv3dConfiguration differs a bit from Conv2dConfiguration, + // but the idea is the same: the "input_stride_a" from config + // depends on conv_kind (Fprop, Dgrad, or Wgrad), so stride_act + // isn't always input_stride_a. Analogously, stride_flt isn't + // always input_stride_b. The code here "undoes" the logic in + // config.layout_a(conv_kind) and config.layout_b(conv_kind) + // (analogous to Conv2dWorkspace::set_stride_vector) so that we + // can recover the strides of the activation and filter tensors. + // It doesn't need to worry about the so-called "output" tensor + // (which might not be C), as ConvProblemShape's constructor + // figures out its shapes and strides. + using TensorExtent = typename problem_shape_type::TensorExtent; + TensorExtent shape_act{N, D, H, W, C}; + auto stride_act = [&] () { + // Some compilers consider conv_op (defined above), as + // captured by this lambda, as "not a constant expression." + constexpr auto conv_kind = Operator::DispatchPolicy::ConvOp; + if constexpr (conv_kind == cutlass::conv::Operator::kFprop) { + return stride_A; + } + else if constexpr (conv_kind == cutlass::conv::Operator::kDgrad) { + return stride_C; + } + else { // conv_kind == cutlass::conv::Operator::kWgrad + return stride_B; + } + } (); + TensorExtent shape_flt{K, T, R, S, C}; + auto stride_flt = [&] () { + // Some compilers consider conv_op (defined above), as + // captured by this lambda, as "not a constant expression." + constexpr auto conv_kind = Operator::DispatchPolicy::ConvOp; + if constexpr (conv_kind == cutlass::conv::Operator::kFprop) { + return stride_B; + } + else if constexpr (conv_kind == cutlass::conv::Operator::kDgrad) { + return stride_B; + } + else { // conv_kind == cutlass::conv::Operator::kWgrad + return stride_C; + } + } (); + + problem_shape_type problem_shape( + /* mode = */ mode, + /* shape_act = */ shape_act, + /* stride_act = */ stride_act, + /* shape_flt = */ shape_flt, + /* stride_flt = */ stride_flt, + /* lower_padding = */ {pad_d, pad_h, pad_w}, + /* upper_padding = */ {pad_d, pad_h, pad_w}, + /* traversal_stride = */ {traversal_stride_d, traversal_stride_h, traversal_stride_w}, + /* dilation = */ {dilation_d, dilation_h, dilation_w}, + num_groups); + out_args.problem_shape = problem_shape; + + // ConvProblemShape's constructor sets its shape_C member. +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + printf("\n problem_shape.shape_C: "); + print(problem_shape.shape_C); + printf("\n problem_shape.stride_C: "); + print(problem_shape.stride_C); + printf("\n"); +#endif + // Initialization of C's and D's strides follows the CUTLASS 3 + // convolutions testbed (test/unit/conv/device_3x/testbed_conv.hpp). + { + using StrideC = typename Operator::ConvKernel::StrideC; + using StrideD = typename Operator::ConvKernel::StrideD; + auto stride_C = StrideC{}; + auto stride_D = StrideD{}; + + if constexpr (conv_op == cutlass::conv::Operator::kWgrad) { + stride_C = cutlass::make_cute_packed_stride( + StrideC{}, problem_shape.shape_C, problem_shape.stride_C, conv_op); + stride_D = cutlass::make_cute_packed_stride( + StrideD{}, problem_shape.shape_C, problem_shape.stride_C, conv_op); +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + std::cerr << " Wgrad: stride_C: " << stride_C << "\n"; +#endif + } + else { + cute::for_each(cute::make_seq(StrideC{})>{}, [&](auto i) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + const auto stride_C_i = problem_shape.stride_C[problem_shape_type::RankT-2-i]; + std::cerr << " Fprop or Dgrad: get<0, " << i << ">(stride_C): " + << stride_C_i << "\n"; +#endif + cute::get<0, i>(stride_C) = problem_shape.stride_C[problem_shape_type::RankT-2-i]; + }); + cute::for_each(cute::make_seq(StrideD{})>{}, [&](auto i) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + const auto stride_D_i = problem_shape.stride_C[problem_shape_type::RankT-2-i]; + std::cerr << " Fprop or Dgrad: get<0, " << i << ">(stride_D): " + << stride_D_i << "\n"; +#endif + cute::get<0, i>(stride_D) = problem_shape.stride_C[problem_shape_type::RankT-2-i]; + }); + } + out_args.epilogue.dC = stride_C; + out_args.epilogue.dD = stride_D; + } + return Status::kSuccess; + } + } + + Status update_operator_arguments_from_arguments( + typename Operator::Arguments& out_args, + ConvArguments const& in_args) const + { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("ConvOperation3x::update_operator_arguments_from_arguments\n"); +#endif + auto status = UpdateFusionArgs::update_( + out_args.epilogue.thread, in_args); + if (status != Status::kSuccess) { + return status; + } + + out_args.mainloop.ptr_A = reinterpret_cast(in_args.A); + out_args.mainloop.ptr_B = reinterpret_cast(in_args.B); + + out_args.epilogue.ptr_C = reinterpret_cast(in_args.C); + out_args.epilogue.ptr_D = reinterpret_cast(in_args.D); + + return Status::kSuccess; + } +}; + +} // namespace cutlass::library diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/gemm_operation.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/gemm_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..880cb4bf34b1f3d946e1dc86b80806309bb2b3c1 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/gemm_operation.h @@ -0,0 +1,1408 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all GEMM operation kinds in CUTLASS Library. +*/ + +#pragma once +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/gemm/device/gemm_sparse.h" +#include "cutlass/gemm/device/gemm_complex.h" +#include "cutlass/gemm/device/gemm_batched.h" +#include "cutlass/gemm/device/gemm_array.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/default_gemm_universal.h" +#include "cutlass/gemm/kernel/default_gemm_planar_complex_universal.h" + +#include "cutlass/library/library.h" +#include "library_internal.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmOperationBase : public Operation { +public: + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = ElementC; + using LayoutD = LayoutC; + // assuming all tensors use same type for StrideIndex + using StrideIndex = typename Operator::LayoutA::Index; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +protected: + + /// + GemmDescription description_; + +public: + + /// Constructor + GemmOperationBase(char const *name = "unknown_gemm") { + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.kind = OperationKind::kGemm; + description_.gemm_kind = GemmKind::kGemm; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::GemmKernel::WarpCount::kM, + Operator::GemmKernel::WarpCount::kN, + Operator::GemmKernel::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(Operator::kAlignmentA); + description_.B = make_TensorDescription(Operator::kAlignmentB); + description_.C = make_TensorDescription(Operator::kAlignmentC); + description_.D = make_TensorDescription(Operator::kAlignmentC); + description_.element_epilogue = NumericTypeMap::kId; + + description_.split_k_mode = SplitKMode::kNone; + description_.transform_A = ComplexTransformMap::kId; + description_.transform_B = ComplexTransformMap::kId; + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmOperation : public GemmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = ElementC; + using LayoutD = LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + GemmOperation(char const *name = "unknown_gemm"): GemmOperationBase(name) { + + this->description_.gemm_kind = GemmKind::kGemm; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + GemmConfiguration const *configuration) { + + operator_args.problem_size = configuration->problem_size; + + operator_args.ref_A = {nullptr, configuration->lda}; + operator_args.ref_B = {nullptr, configuration->ldb}; + operator_args.ref_C = {nullptr, configuration->ldc}; + operator_args.ref_D = {nullptr, configuration->ldd}; + + operator_args.split_k_slices = configuration->split_k_slices; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + GemmArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + operator_args.ref_A.reset(static_cast(arguments->A)); + operator_args.ref_B.reset(static_cast(arguments->B)); + operator_args.ref_C.reset(static_cast(arguments->C)); + operator_args.ref_D.reset(static_cast(arguments->D)); + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + GemmConfiguration const *configuration = + static_cast(configuration_ptr); + + GemmArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + return Operator::get_workspace_size(args); + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + return op->initialize(args, device_workspace, stream); + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args); + + if (status != Status::kSuccess) { + return status; + } + + return op->run(stream); + } + + void print_operator_args(OperatorArguments &operator_args) const { +#if 0 + std::cout << "GemmOperation::OperatorArguments" << std::endl; + std::cout << " problem_size: " << operator_args.problem_size.m() << ", "<< operator_args.problem_size.n() << "," << operator_args.problem_size.k() << std::endl; + std::cout << " alpha: " << operator_args.epilogue.alpha << std::endl; + std::cout << " alpha_ptr: " << operator_args.epilogue.alpha_ptr << std::endl; + std::cout << " beta: " << operator_args.epilogue.beta << std::endl; + std::cout << " beta_ptr: " << operator_args.epilogue.beta_ptr << std::endl; + std::cout << " ref_A.data(): " << operator_args.ref_A.data() << std::endl; + std::cout << " ref_A.stride: " << operator_args.ref_A.stride(0) << std::endl; + std::cout << " ref_B.data(): " << operator_args.ref_B.data() << std::endl; + std::cout << " ref_B.stride: " << operator_args.ref_B.stride(0) << std::endl; + std::cout << " ref_C.data(): " << operator_args.ref_C.data() << std::endl; + std::cout << " ref_C.stride: " << operator_args.ref_C.stride(0) << std::endl; +#endif + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmSparseOperation : public GemmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = ElementC; + using LayoutD = LayoutC; + using ElementE = typename Operator::ElementE; + using LayoutE = typename Operator::LayoutE; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + GemmSparseOperation(char const *name = "unknown_gemm"): GemmOperationBase(name) { + + this->description_.kind = OperationKind::kSparseGemm; + this->description_.gemm_kind = GemmKind::kSparse; + this->description_.E = make_TensorDescription(Operator::kAlignmentE); + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + SparseGemmConfiguration const *configuration) { + + operator_args.problem_size = configuration->problem_size; + operator_args.ref_A = {nullptr, configuration->lda}; + operator_args.ref_B = {nullptr, configuration->ldb}; + operator_args.ref_C = {nullptr, configuration->ldc}; + operator_args.ref_D = {nullptr, configuration->ldd}; + operator_args.ref_E = {nullptr, configuration->lde}; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + SparseGemmArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + operator_args.ref_A.reset(static_cast(arguments->A)); + operator_args.ref_B.reset(static_cast(arguments->B)); + operator_args.ref_C.reset(static_cast(arguments->C)); + operator_args.ref_D.reset(static_cast(arguments->D)); + operator_args.ref_E.reset(static_cast(arguments->E)); + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + SparseGemmConfiguration const *configuration = + static_cast(configuration_ptr); + + SparseGemmArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + return Operator::get_workspace_size(args); + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + return op->initialize(args, device_workspace, stream); + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args); + + if (status != Status::kSuccess) { + return status; + } + + return op->run(stream); + } + + void print_operator_args(OperatorArguments &operator_args) const { +#if 0 + std::cout << "GemmOperation::OperatorArguments" << std::endl; + std::cout << " problem_size: " << operator_args.problem_size.m() << ", "<< operator_args.problem_size.n() << "," << operator_args.problem_size.k() << std::endl; + std::cout << " alpha: " << operator_args.epilogue.alpha << std::endl; + std::cout << " alpha_ptr: " << operator_args.epilogue.alpha_ptr << std::endl; + std::cout << " beta: " << operator_args.epilogue.beta << std::endl; + std::cout << " beta_ptr: " << operator_args.epilogue.beta_ptr << std::endl; + std::cout << " ref_A.data(): " << operator_args.ref_A.data() << std::endl; + std::cout << " ref_A.stride: " << operator_args.ref_A.stride(0) << std::endl; + std::cout << " ref_B.data(): " << operator_args.ref_B.data() << std::endl; + std::cout << " ref_B.stride: " << operator_args.ref_B.stride(0) << std::endl; + std::cout << " ref_C.data(): " << operator_args.ref_C.data() << std::endl; + std::cout << " ref_C.stride: " << operator_args.ref_C.stride(0) << std::endl; +#endif + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmUniversalOperation : public GemmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = ElementC; + using LayoutD = LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + GemmUniversalOperation(char const *name = "unknown_gemm"): + GemmOperationBase(name) { + + this->description_.gemm_kind = GemmKind::kUniversal; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + GemmUniversalConfiguration const *configuration) { + + operator_args.mode = configuration->mode; + + operator_args.problem_size = configuration->problem_size; + operator_args.batch_count = configuration->batch_count; + + operator_args.lda = (configuration->lda); + operator_args.ldb = (configuration->ldb); + operator_args.ldc = (configuration->ldc); + operator_args.ldd = (configuration->ldd); + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + GemmUniversalArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + // update arguments + operator_args.ptr_A = arguments->A; + operator_args.ptr_B = arguments->B; + operator_args.ptr_C = arguments->C; + operator_args.ptr_D = arguments->D; + + operator_args.batch_stride_A = arguments->batch_stride_A; + operator_args.batch_stride_B = arguments->batch_stride_B; + operator_args.batch_stride_C = arguments->batch_stride_C; + operator_args.batch_stride_D = arguments->batch_stride_D; + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + GemmUniversalConfiguration const *configuration = + static_cast(configuration_ptr); + + GemmUniversalArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + + return size; + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + status = op->initialize(args, device_workspace, stream); + + return status; + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args); + + if (status != Status::kSuccess) { + return status; + } + + status = op->run(stream); + + return status; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmPlanarComplexOperation : public GemmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = ElementC; + using LayoutD = LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + GemmPlanarComplexOperation(char const *name = "unknown_gemm"): GemmOperationBase(name) { + + this->description_.gemm_kind = GemmKind::kPlanarComplex; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + GemmPlanarComplexConfiguration const *configuration) { + + operator_args.mode = cutlass::gemm::GemmUniversalMode::kBatched; + operator_args.problem_size = configuration->problem_size; + operator_args.batch_count = configuration->batch_count; + + + operator_args.lda_real = configuration->lda_real; + operator_args.lda_imag = configuration->lda_imag; + operator_args.ldb_real = configuration->ldb_real; + operator_args.ldb_imag = configuration->ldb_imag; + operator_args.ldc_real = configuration->ldc_real; + operator_args.ldc_imag = configuration->ldc_imag; + operator_args.ldd_real = configuration->ldd_real; + operator_args.ldd_imag = configuration->ldd_imag; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + GemmPlanarComplexArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast const *>(arguments->alpha), + *static_cast const *>(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast const *>(arguments->alpha), + static_cast const *>(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + // update arguments + operator_args.ptr_A_real = arguments->A_real; + operator_args.ptr_A_imag = arguments->A_imag; + operator_args.ptr_B_real = arguments->B_real; + operator_args.ptr_B_imag = arguments->B_imag; + operator_args.ptr_C_real = arguments->C_real; + operator_args.ptr_C_imag = arguments->C_imag; + operator_args.ptr_D_real = arguments->D_real; + operator_args.ptr_D_imag = arguments->D_imag; + + operator_args.batch_stride_A = arguments->batch_stride_A_real; + operator_args.batch_stride_A_imag = arguments->batch_stride_A_imag; + operator_args.batch_stride_B = arguments->batch_stride_B_real; + operator_args.batch_stride_B_imag = arguments->batch_stride_B_imag; + operator_args.batch_stride_C = arguments->batch_stride_C_real; + operator_args.batch_stride_C_imag = arguments->batch_stride_C_imag; + operator_args.batch_stride_D = arguments->batch_stride_D_real; + operator_args.batch_stride_D_imag = arguments->batch_stride_D_imag; + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + GemmPlanarComplexConfiguration const *configuration = + static_cast(configuration_ptr); + + GemmPlanarComplexArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + + return size; + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + status = op->initialize(args, device_workspace, stream); + + return status; + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args); + + if (status != Status::kSuccess) { + return status; + } + + status = op->run(stream); + + return status; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmPlanarComplexArrayOperation : public GemmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = ElementC; + using LayoutD = LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + GemmPlanarComplexArrayOperation(char const *name = "unknown_gemm"): GemmOperationBase(name) { + + this->description_.gemm_kind = GemmKind::kPlanarComplexArray; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + GemmPlanarComplexArrayConfiguration const *configuration) { + + operator_args.mode = cutlass::gemm::GemmUniversalMode::kArray; + operator_args.problem_size = configuration->problem_size; + operator_args.batch_count = configuration->batch_count; + + operator_args.lda_real = configuration->lda_real; + operator_args.lda_imag = configuration->lda_imag; + operator_args.ldb_real = configuration->ldb_real; + operator_args.ldb_imag = configuration->ldb_imag; + operator_args.ldc_real = configuration->ldc_real; + operator_args.ldc_imag = configuration->ldc_imag; + operator_args.ldd_real = configuration->ldd_real; + operator_args.ldd_imag = configuration->ldd_imag; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + GemmPlanarComplexArrayArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast const *>(arguments->alpha), + *static_cast const *>(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast const *>(arguments->alpha), + static_cast const *>(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + // update arguments + operator_args.ptr_A_real = arguments->A_real; + operator_args.ptr_A_imag = arguments->A_imag; + operator_args.ptr_B_real = arguments->B_real; + operator_args.ptr_B_imag = arguments->B_imag; + operator_args.ptr_C_real = arguments->C_real; + operator_args.ptr_C_imag = arguments->C_imag; + operator_args.ptr_D_real = arguments->D_real; + operator_args.ptr_D_imag = arguments->D_imag; + + operator_args.ptr_M = arguments->M; + operator_args.ptr_N = arguments->N; + operator_args.ptr_K = arguments->K; + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + GemmPlanarComplexArrayConfiguration const *configuration = + static_cast(configuration_ptr); + + GemmPlanarComplexArrayArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + + return size; + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + status = op->initialize(args, device_workspace, stream); + + return status; + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args); + + if (status != Status::kSuccess) { + return status; + } + + status = op->run(stream); + + return status; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmGroupedOperation : public GemmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = ElementC; + using LayoutD = LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + GemmGroupedOperation(char const *name = "unknown_gemm"): + GemmOperationBase(name) { + + this->description_.kind = OperationKind::kGroupedGemm; + this->description_.provider = Provider::kCUTLASS; + this->threadblock_count = Operator::sufficient(); + + this->description_.gemm = GemmOperationBase::description_; + this->description_.gemm.gemm_kind = GemmKind::kGrouped; + this->description_.tile_description = this->description_.gemm.tile_description; + } + + /// Returns the description of the GroupedGEMM operation + virtual OperationDescription const & description() const override final { + return description_; + } + + +private: + int threadblock_count; + GroupedGemmDescription description_; + +protected: + + /// Constructs the arguments structure given the configuration and arguments + Status construct_arguments_( + OperatorArguments &op_args, + GemmGroupedConfiguration const *config) const { + + op_args.problem_count = config->problem_count; + op_args.threadblock_count = threadblock_count; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + Status update_arguments_( + OperatorArguments &op_args, + GemmGroupedArguments const *arguments) const { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + + op_args.output_op = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice) { + + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + + op_args.output_op = params; + } + else { + return Status::kErrorInvalidProblem; + } + + op_args.threadblock_count = threadblock_count; + op_args.problem_count = arguments->problem_count; + op_args.problem_sizes = arguments->problem_sizes; + + op_args.ptr_A = static_cast(arguments->ptr_A); + op_args.ptr_B = static_cast(arguments->ptr_B); + op_args.ptr_C = static_cast(arguments->ptr_C); + op_args.ptr_D = static_cast(arguments->ptr_D); + + op_args.lda = arguments->lda; + op_args.ldb = arguments->ldb; + op_args.ldc = arguments->ldc; + op_args.ldd = arguments->ldd; + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + GemmGroupedConfiguration const *configuration = + static_cast(configuration_ptr); + + GemmGroupedArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + + return size; + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + status = op->initialize(args, device_workspace, stream); + + return status; + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args); + + if (status != Status::kSuccess) { + return status; + } + + status = op->run(stream); + + return status; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/gemm_operation_3x.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/gemm_operation_3x.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2c1d17943f11fe8126b3070c3fcead5598e2d207 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/gemm_operation_3x.hpp @@ -0,0 +1,714 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all GEMM operation kinds in CUTLASS Library. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/array.h" +#include "cutlass/array_subbyte.h" +#include "cutlass/library/library.h" +#include "library_internal.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/mixed_dtype_utils.hpp" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cute/tensor.hpp" +#include + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmOperation3xBase : public Operation { +public: + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = typename Operator::ElementD; + using LayoutD = typename Operator::LayoutD; + // assuming all tensors use same type for StrideIndex + using StrideIndex = typename Operator::LayoutA::Index; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + +protected: + GemmDescription description_; + +public: + + /// Constructor + GemmOperation3xBase(char const *name = "unknown_gemm", GemmKind gemm_kind_ = GemmKind::kGemm) { + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.kind = OperationKind::kGemm; + description_.gemm_kind = gemm_kind_; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) { + description_.tile_description.cluster_shape = make_Coord( + Operator::ClusterShape::kM, + Operator::ClusterShape::kN, + Operator::ClusterShape::kK); + } + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::WarpCount::kM, + Operator::WarpCount::kN, + Operator::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(Operator::kAlignmentA); + description_.B = make_TensorDescription(Operator::kAlignmentB); + description_.C = make_TensorDescription(Operator::kAlignmentC); + description_.D = make_TensorDescription(Operator::kAlignmentD); + description_.element_epilogue = NumericTypeMap::kId; + + description_.split_k_mode = SplitKMode::kNone; + description_.transform_A = ComplexTransformMap::kId; + description_.transform_B = ComplexTransformMap::kId; + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } + + /// Returns the description of the GEMM operation + GemmDescription const& get_gemm_description() const { + return description_; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmUniversal3xOperation : public GemmOperation3xBase { +public: + + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = typename Operator::ElementD; + using LayoutD = typename Operator::LayoutD; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using CollectiveMainloop = typename Operator::CollectiveMainloop; + using CollectiveEpilogue = typename Operator::CollectiveEpilogue; + using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB in a GEMM kernel should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + +public: + + /// Constructor + GemmUniversal3xOperation(char const *name = "unknown_gemm"): + GemmOperation3xBase(name, GemmKind::kUniversal) { + if constexpr (Operator::ArchTag::kMinComputeCapability == 90) { + dim3 cluster_dims( + cute::size<0>(typename Operator::GemmKernel::ClusterShape{}), + cute::size<1>(typename Operator::GemmKernel::ClusterShape{}), + cute::size<2>(typename Operator::GemmKernel::ClusterShape{})); + uint32_t threads_per_block = Operator::GemmKernel::MaxThreadsPerBlock; + void const* kernel_ptr = (void*)(device_kernel); + max_active_clusters = cutlass::KernelHardwareInfo::query_device_max_active_clusters( + cluster_dims, + threads_per_block, + kernel_ptr); + } + } + +private: + int max_active_clusters{}; + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, GemmUniversalConfiguration const *configuration) { + // NOTE: GemmUniversalConfiguration does not contain problem shapes or batch strides + // Do nothing here and construct kernel arguments in update_arguments_ instead + // We also cannot construct TMA descriptors without all the arguments available + + operator_args.mode = configuration->mode; + return Status::kSuccess; + } + + template + struct UpdateFusionArgs { + static Status update_(FusionArgs const& fusion_args, GemmUniversalArguments const &arguments) { + // If a custom EVT is instantiated then it is the users's responsibility + // to ensure alpha and beta are updated appropriately + return Status::kSuccess; + } + }; + + template + struct UpdateFusionArgs> { + static Status update_(FusionArgs& fusion_args, GemmUniversalArguments const &arguments) { + if (arguments.pointer_mode == ScalarPointerMode::kHost) { + fusion_args.alpha = *static_cast(arguments.alpha); + fusion_args.beta = *static_cast(arguments.beta); + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + + return Status::kSuccess; + } + else if (arguments.pointer_mode == ScalarPointerMode::kDevice) { + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = static_cast(arguments.alpha); + fusion_args.beta_ptr = static_cast(arguments.beta); + + return Status::kSuccess; + } + else { + return Status::kErrorInvalidProblem; + } + } + }; + + template class Policy, int Stages, class ClusterShape, class KernelSchedule> + static constexpr bool is_sm90_mixed_dtype_mainloop_(Policy policy) { + return (cute::is_same_v, + cutlass::gemm::MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput>); + } + + template + static constexpr bool is_sm90_mixed_dtype_mainloop_(DispatchPolicy) { + return false; + } + + template < + typename ElementWide, + typename ElementNarrow, + typename ElementScaleMainloop, + class ActualStrideAB, + Sm90MixedInputWiderOperand wider_operand, + bool is_n4w8, + typename ElementScale, + typename ElementZero, + class Layout_SZ> + static void dequantize_encode_( + OperatorArguments &operator_args, + GemmUniversalArguments const *arguments, + cudaStream_t stream, + const int &problem_mn, + const int &problem_k, + const int &options_l, + const int &options_g, + ElementScale *ptr_S, + ElementZero *ptr_Z, + const size_t &SZ_size, + Layout_SZ layout_SZ + ) { + + auto shape_AB = cute::make_shape(problem_mn, problem_k, options_l); + auto stride_AB = cutlass::make_cute_packed_stride(ActualStrideAB{}, shape_AB); + auto layout_AB = cute::make_layout(shape_AB, stride_AB); + auto *ptr_dequantized_AB = static_cast(arguments->dequantized_AB); + const ElementNarrow *ptr_AB = nullptr; + if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) { + ptr_AB = static_cast(arguments->B); + } + else { + ptr_AB = static_cast(arguments->A); + } + dequantize(ptr_dequantized_AB, ptr_AB, layout_AB, ptr_S, ptr_Z, layout_SZ, options_g, stream); + if constexpr(is_n4w8) { + size_t AB_size = cute::size(layout_AB); + cutlass::int4b_t *encoded_AB = static_cast(arguments->encoded_AB); + unified_encode_int4b(ptr_AB, encoded_AB, AB_size); + if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) { + operator_args.mainloop.ptr_B = static_cast(encoded_AB); + } + else { + operator_args.mainloop.ptr_A = static_cast(encoded_AB); + } + ElementScaleMainloop *ptr_packed_Scale = static_cast(arguments->packed_Scale); + pack_scale_fp8(ptr_S, ptr_packed_Scale, SZ_size); + } + } + + template < + typename ElementAB, + class ActualStrideAB, + class LayoutAB_Reordered, + class LayoutAtomQuant, + Sm90MixedInputWiderOperand wider_operand> + static void handle_shuffle_tensor_( + OperatorArguments &operator_args, + GemmUniversalArguments const *arguments, + const int &problem_mn, + const int &problem_k, + const int &options_l) { + + auto shape_AB = cute::make_shape(problem_mn, problem_k, options_l); + auto stride_AB = cutlass::make_cute_packed_stride(ActualStrideAB{}, shape_AB); + auto layout_AB = cute::make_layout(shape_AB, stride_AB); + LayoutAB_Reordered layout_AB_reordered = cute::tile_to_shape(LayoutAtomQuant{}, shape_AB); + if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) { + operator_args.mainloop.dB = layout_AB_reordered; + } + else { + operator_args.mainloop.dA = layout_AB_reordered; + } + if (arguments->generate_dequantized_AB) { + size_t AB_size = cute::size(layout_AB); + ElementAB *AB_reordered = cutlass::device_memory::allocate(AB_size); + const ElementAB *AB_src = nullptr; + if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) { + AB_src = static_cast(operator_args.mainloop.ptr_B); + } + else { + AB_src = static_cast(operator_args.mainloop.ptr_A); + } + reorder_tensor(AB_src, layout_AB, AB_reordered, layout_AB_reordered); + ElementAB *AB_dst = static_cast(arguments->encoded_AB); + cutlass::device_memory::copy_device_to_device(AB_dst, AB_reordered, AB_size); + cutlass::device_memory::free(AB_reordered); + if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) { + operator_args.mainloop.ptr_B = AB_dst; + } + else { + operator_args.mainloop.ptr_A = AB_dst; + } + } + } + + /// Constructs the arguments structure given the configuration and arguments + Status update_arguments_( + OperatorArguments& operator_args, + GemmUniversalArguments const* arguments, + cudaStream_t stream = nullptr) const { + Status status = Status::kSuccess; + + status = UpdateFusionArgs::update_( + operator_args.epilogue.thread, *arguments); + if (status != Status::kSuccess) { + return status; + } + + // TODO: type erase Arguments structure in 3.0 GEMM + operator_args.problem_shape = cute::make_shape( + arguments->problem_size.m(), + arguments->problem_size.n(), + arguments->problem_size.k(), + arguments->batch_count); + + // update arguments + + if constexpr (IsRuntimeDataType) { + using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB; + operator_args.mainloop.ptr_A = static_cast(arguments->A); + operator_args.mainloop.ptr_B = static_cast(arguments->B); + + std::unordered_map mapping = { + {RuntimeDatatype::kE4M3, cute::UMMA::MXF8F6F4Format::E4M3}, + {RuntimeDatatype::kE5M2, cute::UMMA::MXF8F6F4Format::E5M2}, + {RuntimeDatatype::kE3M2, cute::UMMA::MXF8F6F4Format::E3M2}, + {RuntimeDatatype::kE2M1, cute::UMMA::MXF8F6F4Format::E2M1} + }; + + auto iter_runtime_a = mapping.find(arguments->runtime_input_datatype_a); + auto iter_runtime_b = mapping.find(arguments->runtime_input_datatype_b); + + if (iter_runtime_a != mapping.end()) { + operator_args.mainloop.runtime_data_type_a = iter_runtime_a->second; + } else { + assert("invalid runtime argument for datatype A!"); + } + + if (iter_runtime_b != mapping.end()) { + operator_args.mainloop.runtime_data_type_b = iter_runtime_b->second; + } else { + assert("invalid runtime argument for datatype B!"); + } + + } + else { + operator_args.mainloop.ptr_A = static_cast(arguments->A); + operator_args.mainloop.ptr_B = static_cast(arguments->B); + } + operator_args.epilogue.ptr_C = static_cast(arguments->C); + operator_args.epilogue.ptr_D = static_cast(arguments->D); + + // Stride{A,B} is a Layout if and only if: + // (1) This is a mixed dtype kernel, and + // (2) This mixed dtype kernel is using shuffling, and + // (3) sizeof(narrow_type) == 4 or 8 bits, and + // (4) sizeof(wide_type) == 16 bits. + // If A/B has the narrow data type, Stride{A/B} will be a Layout + constexpr bool is_StrideA_Layout = cute::is_layout::value; + constexpr bool is_StrideB_Layout = cute::is_layout::value; + static_assert(!(is_StrideA_Layout && is_StrideB_Layout), "Incorrect kernel configuration: StrideA and StrideB are both cute::Layout"); + if constexpr(!is_StrideA_Layout) { + operator_args.mainloop.dA = cute::make_int_tuple_from( + arguments->lda, arguments->batch_stride_A); + } + if constexpr(!is_StrideB_Layout) { + operator_args.mainloop.dB = cute::make_int_tuple_from( + arguments->ldb, arguments->batch_stride_B); + } + operator_args.epilogue.dC = cute::make_int_tuple_from( + arguments->ldc, arguments->batch_stride_C); + operator_args.epilogue.dD = operator_args.epilogue.dC; + + using MainloopPolicy = typename CollectiveMainloop::DispatchPolicy; + if constexpr(is_sm90_mixed_dtype_mainloop_(MainloopPolicy{})) { + const int problem_m = arguments->problem_size.m(); + const int problem_n = arguments->problem_size.n(); + const int problem_k = arguments->problem_size.k(); + const int options_l = arguments->batch_count; + + constexpr Sm90MixedInputWiderOperand wider_operand = + (cutlass::sizeof_bits::value > cutlass::sizeof_bits::value) ? + Sm90MixedInputWiderOperand::A : Sm90MixedInputWiderOperand::B; + using ElementWide = std::conditional_t; + using ElementNarrow = std::conditional_t; + + constexpr bool has_scale = !std::is_same_v; + constexpr bool has_zero = !std::is_same_v; + + const int options_g = problem_k; + const int scale_k = (problem_k + options_g - 1) / options_g; + + constexpr bool is_A4B8 = ( + cutlass::is_same_v && + (cutlass::is_same_v || + cutlass::is_same_v)); + constexpr bool is_A8B4 = ( + cutlass::is_same_v && + (cutlass::is_same_v || + cutlass::is_same_v)); + constexpr bool is_int4_x_fp8 = is_A4B8 || is_A8B4; + + // If this is a convert-only kernel, we still need to generate dequantized A or B for verification, + // and in this case ElementScale is the same as ElementWide + // In int4 * fp8, ElementScale is a cutlass::Array, need to take out it's real element + using DummyElementScaleMainloop = std::conditional_t< + is_int4_x_fp8, + typename cutlass::Array, + ElementWide + >; + using ElementScaleMainloop = std::conditional_t< + has_scale, + typename CollectiveMainloop::ElementScale, + DummyElementScaleMainloop + >; + using ElementScale = std::conditional_t< + has_scale, + typename UnderlyingElement::type, + ElementWide + >; + using StrideScale = typename CollectiveMainloop::StrideScale; + // In ScaleOnly mode, we have allocated the same size of memory for arguments->Z and arguments->S + using ElementZero = std::conditional_t< + has_zero, + typename CollectiveMainloop::ElementZero, + ElementScale + >; + const int SZ_1st_dim = (wider_operand == Sm90MixedInputWiderOperand::A) ? problem_n : problem_m; + const size_t SZ_size = static_cast(SZ_1st_dim * scale_k * options_l); + auto shape_SZ = cute::make_shape(SZ_1st_dim, scale_k, options_l); + ElementScale *ptr_S = static_cast(arguments->Scale); + ElementZero *ptr_Z = static_cast(arguments->Zero); + + // 1. If arguments is initialized in profiler, S and Z needs to be allocated and filled + if (arguments->generate_scale_and_zero) { + float scale_min = 1.0f, scale_max = 1.0f; + if constexpr(has_scale) { + const float elt_max_f = float(cutlass::platform::numeric_limits::max()); + // Need to fix max_dequant_val and min_dequant_val? + const float max_dequant_val = elt_max_f * 0.25f; + const float min_dequant_val = 0.5f; + scale_max = max_dequant_val / elt_max_f; + scale_min = min_dequant_val / elt_max_f; + } + uint64_t seed = 2023; + cutlass::reference::device::BlockFillRandomUniform( + ptr_S, SZ_size, seed, ElementScale(scale_max), ElementScale(scale_min)); + + // In ScaleOnly mode, set Z as zero for generating dequantized A or B + const float zero_max = has_zero ? 2.0f : 0.0f; + const float zero_min = has_zero ? -2.0f : 0.0f; + cutlass::reference::device::BlockFillRandomUniform( + ptr_Z, SZ_size, seed, ElementZero(zero_max), ElementZero(zero_min)); + } // End of "if (arguments->generate_scale_and_zero)" + + // 2. Generate the dequantized A or B for verification + if (arguments->generate_dequantized_AB) { + StrideScale stride_SZ = cutlass::make_cute_packed_stride(StrideScale{}, shape_SZ); + auto layout_SZ = cute::make_layout(shape_SZ, stride_SZ); + if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) { + if constexpr(is_StrideB_Layout) { + // The generator only generates row-major A and col-major B at the moment + // Need a way to read out the actual layout of B later + using ActualLayoutB = cutlass::layout::ColumnMajor; + using ActualStrideB = cutlass::detail::TagToStrideB_t; + dequantize_encode_( + operator_args, arguments, stream, problem_m, problem_k, options_l, options_g, ptr_S, ptr_Z, SZ_size, layout_SZ); + } + else { + using ActualStrideB = typename CollectiveMainloop::StrideB; + dequantize_encode_( + operator_args, arguments, stream, problem_m, problem_k, options_l, options_g, ptr_S, ptr_Z, SZ_size, layout_SZ); + } + } + else { + if constexpr(is_StrideA_Layout) { + // The generator only generates row-major A and col-major B at the moment + // Need a way to read out the actual layout of A later + using ActualLayoutA = cutlass::layout::RowMajor; + using ActualStrideA = cutlass::detail::TagToStrideA_t; + dequantize_encode_( + operator_args, arguments, stream, problem_m, problem_k, options_l, options_g, ptr_S, ptr_Z, SZ_size, layout_SZ); + } + else { + using ActualStrideA = typename CollectiveMainloop::StrideA; + dequantize_encode_( + operator_args, arguments, stream, problem_m, problem_k, options_l, options_g, ptr_S, ptr_Z, SZ_size, layout_SZ); + } + } // End of "if constexpr(wider_operand == Sm90MixedInputWiderOperand::A)" + } // End of "if (arguments->generate_dequantized_AB)" + + // 3. Put Scale and Zero in mainloop + if constexpr(has_scale) { + if constexpr(is_int4_x_fp8) { + operator_args.mainloop.ptr_S = static_cast(arguments->packed_Scale); + } + else { + operator_args.mainloop.ptr_S = static_cast(arguments->Scale); + } + operator_args.mainloop.dS = cutlass::make_cute_packed_stride(StrideScale{}, shape_SZ); + operator_args.mainloop.group_size = options_g; + if constexpr(has_zero) { + operator_args.mainloop.ptr_Z = static_cast(arguments->Zero); + } + } // End of "if constexpr(has_scale)" + + // Handle the shuffling + using ValueShuffle = std::conditional_t< + cutlass::sizeof_bits::value == 4, + cute::Layout, cute::Stride>, + cute::Layout, cute::Stride> + >; + constexpr int NumShuffleAtoms = 1; + using MmaAtomShape = cute::Layout>>; + using LayoutAtomQuant = decltype(compute_memory_reordering_atom()); + // The generator only generates row-major A and col-major B at the moment + // Need a way to read out the actual layout and stride of A/B later + if constexpr(wider_operand == Sm90MixedInputWiderOperand::A && is_StrideB_Layout) { + using ActualLayoutB = cutlass::layout::ColumnMajor; + using ActualStrideB = cutlass::detail::TagToStrideB_t; + using LayoutB_Reordered = typename CollectiveMainloop::StrideB; + handle_shuffle_tensor_( + operator_args, arguments, problem_n, problem_k, options_l); + } + if constexpr(wider_operand == Sm90MixedInputWiderOperand::B && is_StrideA_Layout) { + using ActualLayoutA = cutlass::layout::RowMajor; + using ActualStrideA = cutlass::detail::TagToStrideA_t; + using LayoutA_Reordered = typename CollectiveMainloop::StrideA; + handle_shuffle_tensor_( + operator_args, arguments, problem_m, problem_k, options_l); + } + } // End of "if constexpr(is_sm90_mixed_dtype_mainloop_(MainloopPolicy{}))" + + /* Query device SM count and max active clusters to pass onto the kernel as an argument, where needed */ + operator_args.hw_info.sm_count = arguments->sm_count; + if constexpr (Operator::ArchTag::kMinComputeCapability == 90) { + operator_args.hw_info.max_active_clusters = max_active_clusters; + } + if constexpr (!std::is_const_v) { + operator_args.scheduler.max_swizzle_size = arguments->swizzle_size; + } + + if constexpr (!std::is_const_v) { + using Enum_t = decltype(operator_args.scheduler.raster_order); + switch (arguments->raster_order) { + case RasterOrder::kAlongN: + operator_args.scheduler.raster_order = Enum_t::AlongN; + break; + case RasterOrder::kAlongM: + operator_args.scheduler.raster_order = Enum_t::AlongM; + break; + default: + operator_args.scheduler.raster_order = Enum_t::Heuristic; + } + } + + if constexpr (std::is_same_v) { + operator_args.scheduler.splits = arguments->split_k_slices; + } + + if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) { + operator_args.hw_info.cluster_shape = dim3( + arguments->cluster_shape.m(), + arguments->cluster_shape.n(), + arguments->cluster_shape.k()); + operator_args.hw_info.cluster_shape_fallback = dim3( + arguments->cluster_shape_fallback.m(), + arguments->cluster_shape_fallback.n(), + arguments->cluster_shape_fallback.k()); + } + return status; + } + +public: + + /// Returns success if the operation can proceed + Status can_implement( + [[maybe_unused]] void const *configuration_ptr, void const *arguments_ptr) const override { + GemmUniversalArguments const *arguments = + static_cast(arguments_ptr); + OperatorArguments args; + + auto status = update_arguments_(args, arguments); + if (status != Status::kSuccess) { + return status; + } + + Status can_impl = Operator::can_implement(args); + + //return Operator::can_implement(args); + return can_impl; + } + + /// Gets the host-side workspace + uint64_t get_host_workspace_size(void const *configuration) const override { + return sizeof(Operator); + } + + /// Gets the device-side workspace + uint64_t get_device_workspace_size( + void const *configuration_ptr,void const *arguments_ptr) const override { + + OperatorArguments args; + auto status = update_arguments_( + args, static_cast(arguments_ptr)); + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + return size; + } + + /// Initializes the workspace + Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const override { + Operator *op = new (host_workspace) Operator; + return Status::kSuccess; + } + + /// Runs the kernel + Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const override { + + OperatorArguments args; + Status status = update_arguments_(args, static_cast(arguments_ptr), stream); + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + // We need to call initialize() since we have to rebuild TMA desc for every new set of args + status = op->run(args, device_workspace, stream, nullptr, + static_cast(arguments_ptr)->use_pdl); + return status; + } +}; +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::library + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/grouped_gemm_operation_3x.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/grouped_gemm_operation_3x.hpp new file mode 100644 index 0000000000000000000000000000000000000000..91f618d4fab74a6d43e2d82c572d215d5bea5a1c --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/grouped_gemm_operation_3x.hpp @@ -0,0 +1,873 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all grouped GEMM operations in CUTLASS Library. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "gemm_operation_3x.hpp" +#include "library_internal.h" + +namespace cutlass::library { + +template +class GroupedGemmOperation3xBase : public GemmOperation3xBase { +public: + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = typename Operator::ElementD; + using LayoutD = typename Operator::LayoutD; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using CollectiveMainloop = typename Operator::CollectiveMainloop; + using CollectiveEpilogue = typename Operator::CollectiveEpilogue; + using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB in a GEMM kernel should be both runtime or both static."); + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + GroupedGemmOperation3xBase(char const* name = "unknown_gemm") + : GemmOperation3xBase(name, GemmKind::kGrouped) { + this->description_.kind = OperationKind::kGroupedGemm; + this->description_.name = name; + this->description_.provider = Provider::kCUTLASS; + + this->description_.gemm = GemmOperation3xBase::description_; + this->description_.tile_description = this->description_.gemm.tile_description; + }; + +public: + mutable CudaBuffer strideA_device; + mutable CudaBuffer strideB_device; + mutable CudaBuffer strideC_device; + mutable CudaBuffer strideD_device; + + /// Returns the description of the GEMM operation + virtual OperationDescription const& description() const override final { return description_; } + /// Gets the host-side workspace + uint64_t get_host_workspace_size(void const* configuration) const override final { + return sizeof(Operator); + } + +protected: + library::GroupedGemmDescription description_; + + Status initialize_strides(GemmGroupedConfiguration const& config) const { + auto const num_groups = config.problem_count; + this->strideA_device = + CudaBuffer(sizeof(typename Operator::GemmKernel::InternalStrideA) * num_groups); + this->strideB_device = + CudaBuffer(sizeof(typename Operator::GemmKernel::InternalStrideB) * num_groups); + this->strideC_device = + CudaBuffer(sizeof(typename Operator::GemmKernel::InternalStrideC) * num_groups); + this->strideD_device = + CudaBuffer(sizeof(typename Operator::GemmKernel::InternalStrideD) * num_groups); + + std::vector strideA_host(num_groups); + std::vector strideB_host(num_groups); + std::vector strideC_host(num_groups); + std::vector strideD_host(num_groups); + for (int group_idx = 0; group_idx < num_groups; group_idx++) { + strideA_host[group_idx] = + cute::make_int_tuple_from( + config.lda[group_idx]); + strideB_host[group_idx] = + cute::make_int_tuple_from( + config.ldb[group_idx]); + strideC_host[group_idx] = + cute::make_int_tuple_from( + config.ldc[group_idx]); + strideD_host[group_idx] = + cute::make_int_tuple_from( + config.ldc[group_idx]); + } + CUDA_CHECK(cudaMemcpy( + this->strideA_device.data(), + strideA_host.data(), + sizeof(typename Operator::GemmKernel::InternalStrideA) * num_groups, + cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy( + this->strideB_device.data(), + strideB_host.data(), + sizeof(typename Operator::GemmKernel::InternalStrideB) * num_groups, + cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy( + this->strideC_device.data(), + strideC_host.data(), + sizeof(typename Operator::GemmKernel::InternalStrideC) * num_groups, + cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy( + this->strideD_device.data(), + strideD_host.data(), + sizeof(typename Operator::GemmKernel::InternalStrideD) * num_groups, + cudaMemcpyHostToDevice)); + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + Status update_arguments_base( + OperatorArguments& operator_args, + GemmGroupedArguments const& arguments) const { + operator_args.mode = cutlass::gemm::GemmUniversalMode::kGrouped; + operator_args.problem_shape = { + arguments.problem_count, + arguments.problem_sizes_3x, + arguments.pointer_mode == ScalarPointerMode::kHost ? arguments.problem_sizes_3x_host + : nullptr}; + + if constexpr (IsRuntimeDataType) { + using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB; + operator_args.mainloop.ptr_A = static_cast(arguments.ptr_A); + operator_args.mainloop.ptr_B = static_cast(arguments.ptr_B); + + using RuntimeDataTypeA = typename Operator::GemmKernel::CollectiveMainloop::RuntimeDataTypeA; + using RuntimeDataTypeB = typename Operator::GemmKernel::CollectiveMainloop::RuntimeDataTypeB; + + static_assert(cute::is_same_v, + "RuntimeDataTypeA/B should be identical, either MXF8F6F4Format or MXF4Format"); + using RuntimeDatatypeArg = RuntimeDataTypeA; + + auto mapping = [](RuntimeDatatype type) { + if constexpr (cute::is_same_v) { + if (type == RuntimeDatatype::kE5M2) { + return cute::UMMA::MXF8F6F4Format::E5M2; + } + else if (type == RuntimeDatatype::kE4M3) { + return cute::UMMA::MXF8F6F4Format::E4M3; + } + else if (type == RuntimeDatatype::kE3M2) { + return cute::UMMA::MXF8F6F4Format::E3M2; + } + else if (type == RuntimeDatatype::kE2M3) { + return cute::UMMA::MXF8F6F4Format::E2M3; + } + else if (type == RuntimeDatatype::kE2M1) { + return cute::UMMA::MXF8F6F4Format::E2M1; + } + else { + #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && CUTLASS_DEBUG_TRACE_LEVEL >= 1 + std::cerr << "Invalid input datatype specified. Running with e4m3." << std::endl; + #endif + return cute::UMMA::MXF8F6F4Format::E4M3; + } + } + else if constexpr (cute::is_same_v) { + if (type == RuntimeDatatype::kE2M1) { + return cute::UMMA::MXF4Format::E2M1; + } + else { + #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && CUTLASS_DEBUG_TRACE_LEVEL >= 1 + std::cerr << "Invalid input datatype specified. Running with e2m1." << std::endl; + #endif + return cute::UMMA::MXF4Format::E2M1; + } + } + // BlockScaled kernels receive either MXF4Format or MXF8F6F4Format runtime datatype + CUTE_GCC_UNREACHABLE; + }; + operator_args.mainloop.runtime_data_type_a = mapping(arguments.runtime_input_datatype_a); + operator_args.mainloop.runtime_data_type_b = mapping(arguments.runtime_input_datatype_b); + } + else { + operator_args.mainloop.ptr_A = static_cast(arguments.ptr_A); + operator_args.mainloop.ptr_B = static_cast(arguments.ptr_B); + } + operator_args.epilogue.ptr_C = static_cast(arguments.ptr_C); + operator_args.epilogue.ptr_D = static_cast(arguments.ptr_D); + + operator_args.mainloop.dA = + static_cast(this->strideA_device.data()); + operator_args.mainloop.dB = + static_cast(this->strideB_device.data()); + operator_args.epilogue.dC = + static_cast(this->strideC_device.data()); + operator_args.epilogue.dD = + static_cast(this->strideD_device.data()); + + /* Query device SM count and max active clusters to pass onto the kernel as an argument, where needed */ + operator_args.hw_info.sm_count = arguments.sm_count; + if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) { + operator_args.hw_info.max_active_clusters = arguments.max_active_clusters; + } + if constexpr (!std::is_const_v) { + operator_args.scheduler.max_swizzle_size = arguments.swizzle_size; + } + + if constexpr (!std::is_const_v) { + using Enum_t = decltype(operator_args.scheduler.raster_order); + switch (arguments.raster_order) { + case RasterOrder::kAlongN: + operator_args.scheduler.raster_order = Enum_t::AlongN; + break; + case RasterOrder::kAlongM: + operator_args.scheduler.raster_order = Enum_t::AlongM; + break; + default: + operator_args.scheduler.raster_order = Enum_t::Heuristic; + } + } + + if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) { + operator_args.hw_info.cluster_shape = + dim3(arguments.cluster_shape.m(), arguments.cluster_shape.n(), arguments.cluster_shape.k()); + operator_args.hw_info.cluster_shape_fallback = dim3( + arguments.cluster_shape_fallback.m(), + arguments.cluster_shape_fallback.n(), + arguments.cluster_shape_fallback.k()); + } + return Status::kSuccess; + } + + template + static Status update_fusion_args(FusionArgs& fusion_args, GemmGroupedArguments const& arguments) { + if (arguments.pointer_mode == ScalarPointerMode::kHost) { + fusion_args.alpha = *static_cast(arguments.alpha); + fusion_args.beta = *static_cast(arguments.beta); + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = nullptr; + fusion_args.beta_ptr_array = nullptr; + + return Status::kSuccess; + } + else if (arguments.pointer_mode == ScalarPointerMode::kDevice) { + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = static_cast(arguments.alpha); + fusion_args.beta_ptr = static_cast(arguments.beta); + fusion_args.alpha_ptr_array = nullptr; + fusion_args.beta_ptr_array = nullptr; + return Status::kSuccess; + } + else { + return Status::kErrorInvalidProblem; + } + } +}; + +/// **** CAUTION **** +/// Unlike other operations, initialize() must be called when +/// certain arguments change. See initialize() for details. +template +class GroupedGemmUniversal3xOperation : public GroupedGemmOperation3xBase { +public: + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + +public: + GroupedGemmUniversal3xOperation(char const* name = "unknown_gemm") + : GroupedGemmOperation3xBase(name) {} + + ~GroupedGemmUniversal3xOperation() override = default; + +private: + int max_active_clusters{}; + +protected: + template struct UpdateFusionArgs { + static Status update_(FusionArgs const& fusion_args, GemmGroupedArguments const& arguments) { + // If a custom EVT is instantiated then it is the users's responsibility + // to ensure alpha and beta are updated appropriately + return Status::kSuccess; + } + }; + + template + struct UpdateFusionArgs> { + static Status update_(FusionArgs& fusion_args, GemmGroupedArguments const& arguments) { + return GroupedGemmOperation3xBase::update_fusion_args(fusion_args, arguments); + } + }; + + /// Constructs the arguments structure given the configuration and arguments + Status + update_arguments_(OperatorArguments& operator_args, GemmGroupedArguments const* arguments) const { + + Status status = UpdateFusionArgs::update_( + operator_args.epilogue.thread, + *arguments); + if (status != Status::kSuccess) { + return status; + } + + status = this->update_arguments_base(operator_args, *arguments); + return status; + } + +public: + /// Returns success if the operation can proceed + Status can_implement([[maybe_unused]] void const* configuration_ptr, void const* arguments_ptr) + const override { + GemmGroupedArguments const* arguments = static_cast(arguments_ptr); + OperatorArguments args; + auto status = update_arguments_(args, arguments); + if (status != Status::kSuccess) { + return status; + } + + status = Operator::can_implement(args); + return status; + } + + /// Gets the device-side workspace + uint64_t get_device_workspace_size(void const* configuration_ptr, void const* arguments_ptr) + const override { + + OperatorArguments args; + auto status = update_arguments_(args, static_cast(arguments_ptr)); + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + return size; + } + + /// Initializes the workspace + /// **** CAUTION **** + /// Must be called when lda, ldb, ldc, or ldd change. + /// The CUTLASS library stores the operations in a type- + /// erased manifest. Therefore, only this class knows + /// the type of strideA, strideB, strideC, and strideD. + /// Since grouped GEMM needs to allocate storage for + /// the strides on device, the concrete type of the stride + /// must be known in order to copy in the correct memory + /// layout on device. + Status initialize( + void const* configuration_ptr, + void* host_workspace, + void* device_workspace, + cudaStream_t stream = nullptr) const override { + + Operator* op = new (host_workspace) Operator; + + auto const& config = *static_cast(configuration_ptr); + return this->initialize_strides(config); + } + + /// **** CAUTION **** + /// initialize() must be called if lda, ldb, ldc, or ldd change. + Status run( + void const* arguments_ptr, + void* host_workspace, + void* device_workspace = nullptr, + cudaStream_t stream = nullptr) const override { + + OperatorArguments operator_args; + auto const& args = *static_cast(arguments_ptr); + + Status status = update_arguments_(operator_args, &args); + if (status != Status::kSuccess) { + return status; + } + + Operator* op = static_cast(host_workspace); + // We need to call initialize() since we have to rebuild TMA desc for every new set of args + status = op->run(operator_args, device_workspace, stream, nullptr, args.use_pdl); + return status; + } + + // Set arguments that should only be set once before verifying or profiling the kernel. + // This should encompass any expensive operations that don't vary from run to run + // (e.g., max_active_clusters). + Status initialize_with_arguments(void* arguments_ptr) const override { + if constexpr (Operator::ArchTag::kMinComputeCapability < 90) { + return Status::kSuccess; + } + + GemmGroupedArguments* args = static_cast(arguments_ptr); + + dim3 cluster_dims; + if constexpr (cute::is_static_v) { + cluster_dims = dim3( + cute::size<0>(typename Operator::GemmKernel::ClusterShape{}), + cute::size<1>(typename Operator::GemmKernel::ClusterShape{}), + cute::size<2>(typename Operator::GemmKernel::ClusterShape{}) + ); + } + else { + cluster_dims = dim3( + args->cluster_shape.m(), + args->cluster_shape.n(), + args->cluster_shape.k() + ); + } + + uint32_t threads_per_block = Operator::GemmKernel::MaxThreadsPerBlock; + void const* kernel_ptr = (void*)(device_kernel); + args->max_active_clusters = cutlass::KernelHardwareInfo::query_device_max_active_clusters( + cluster_dims, + threads_per_block, + kernel_ptr); + + if (args->max_active_clusters == 0) { + std::cerr << "Max Active Clusters could not be queried. " + << "Falling back to heuristics mode (static cluster shape) or preferred cluster mode.\n"; + } + + return Status::kSuccess; + } +}; + +template +class GroupedBlockScaledGemmUniversal3xOperation : public GroupedGemmOperation3xBase { +public: + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + using ElementD = typename Operator::ElementD; + using LayoutD = typename Operator::LayoutD; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using CollectiveMainloop = typename Operator::CollectiveMainloop; + using CollectiveEpilogue = typename Operator::CollectiveEpilogue; + using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + using ElementSFA = typename Operator::CollectiveMainloop::ElementSF; + using ElementSFB = typename Operator::CollectiveMainloop::ElementSF; + + using TiledMma = typename Operator::CollectiveMainloop::TiledMma; + constexpr static int SFVecSize = TiledMma::SFVecSize; + + + static constexpr bool epilogue_scalefactor_generation = not cute::is_same_v; + static constexpr int32_t SFD_VectorSize = epilogue_scalefactor_generation ? ThreadEpilogueOp::SFVecSize : SFVecSize; + using ElementSFD = cute::conditional_t; + using LayoutSFD = cute::conditional_t; + + GroupedBlockScaledGemmUniversal3xOperation(char const* name = "unknown_gemm") + : GroupedGemmOperation3xBase(name) { + + BlockScaleDescription block_scaled_desc{}; + block_scaled_desc.kind = OperationKind::kBlockScaledGemm; + block_scaled_desc.SFA.element = NumericTypeMap::kId; + block_scaled_desc.SFA.layout = LayoutTypeID::kRowMajor; + block_scaled_desc.SFA.alignment = 128; + block_scaled_desc.SFA.log_extent_range = 32; + block_scaled_desc.SFA.log_stride_range = 32; + + block_scaled_desc.SFB.element = NumericTypeMap::kId; + block_scaled_desc.SFB.layout = LayoutTypeID::kRowMajor; + block_scaled_desc.SFB.alignment = 128; + block_scaled_desc.SFB.log_extent_range = 32; + block_scaled_desc.SFB.log_stride_range = 32; + + block_scaled_desc.SFMVecSize = 1; + block_scaled_desc.SFNVecSize = 1; + block_scaled_desc.SFKVecSize = SFVecSize; + + block_scaled_desc.SFD = make_TensorDescription(128); + block_scaled_desc.EpilogueSFVecSize = SFD_VectorSize; + + this->description_.block_scales = block_scaled_desc; + } + + ~GroupedBlockScaledGemmUniversal3xOperation() override = default; + + mutable CudaBuffer layout_SFA_device; + mutable CudaBuffer layout_SFB_device; + +protected: + template struct UpdateFusionArgs { + static Status update_(FusionArgs const& fusion_args, GemmGroupedArguments const& arguments) { + // If a custom EVT is instantiated then it is the users's responsibility + // to ensure alpha and beta are updated appropriately + return Status::kSuccess; + } + }; + + template + struct UpdateFusionArgs> { + static Status + update_(FusionArgs& fusion_args, GroupedGemmBlockScaledArguments const& arguments) { + + if constexpr (epilogue_scalefactor_generation) { + fusion_args.block_scale_factor_ptr = static_cast(arguments.SFD); + fusion_args.norm_constant_ptr = static_cast(arguments.norm_constant); + } + + return GroupedGemmOperation3xBase::update_fusion_args(fusion_args, arguments); + } + }; + +public: + /// Returns success if the operation can proceed + Status can_implement([[maybe_unused]] void const* configuration_ptr, void const* arguments_ptr) + const override { + GroupedGemmBlockScaledArguments const* arguments = + static_cast(arguments_ptr); + OperatorArguments args; + auto status = update_arguments_(args, arguments); + if (status != Status::kSuccess) { + return status; + } + + status = Operator::can_implement(args); + return status; + } + + Status update_arguments_( + OperatorArguments& operator_args, + GroupedGemmBlockScaledArguments const* arguments) const { + Status status = UpdateFusionArgs::update_( + operator_args.epilogue.thread, + *arguments); + if (status != Status::kSuccess) { + return status; + } + + operator_args.mainloop.ptr_SFA = + static_cast(arguments->SFA); + operator_args.mainloop.ptr_SFB = + static_cast(arguments->SFB); + + operator_args.mainloop.layout_SFA = + static_cast(this->layout_SFA_device.data()); + operator_args.mainloop.layout_SFB = + static_cast(this->layout_SFB_device.data()); + + return this->update_arguments_base(operator_args, *arguments); + } + + uint64_t get_device_workspace_size(void const* configuration_ptr, void const* arguments_ptr) + const override { + + OperatorArguments args; + auto status = + update_arguments_(args, static_cast(arguments_ptr)); + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + return size; + } + + /// Initializes the workspace + /// **** CAUTION **** + /// Must be called when lda, ldb, ldc, or ldd change. + /// The CUTLASS library stores the operations in a type- + /// erased manifest. Therefore, only this class knows + /// the type of strideA, strideB, strideC, and strideD. + /// Since grouped GEMM needs to allocate storage for + /// the strides on device, the concrete type of the stride + /// must be known in order to copy in the correct memory + /// layout on device. + Status initialize( + void const* configuration_ptr, + void* host_workspace, + void* device_workspace, + cudaStream_t stream = nullptr) const override { + + auto const& config = *static_cast(configuration_ptr); + auto status = this->initialize_strides(config); + if (status != Status::kSuccess) { + return status; + } + + auto num_groups = config.problem_count; + this->layout_SFA_device = + CudaBuffer(sizeof(typename CollectiveMainloop::InternalLayoutSFA) * num_groups); + this->layout_SFB_device = + CudaBuffer(sizeof(typename CollectiveMainloop::InternalLayoutSFB) * num_groups); + auto layout_SFA_host = std::vector(num_groups); + auto layout_SFB_host = std::vector(num_groups); + + for (int group_idx = 0; group_idx < num_groups; group_idx++) { + auto const& shape = config.problem_sizes_3x_host[group_idx]; + auto M = get<0>(shape); + auto N = get<1>(shape); + auto K = get<2>(shape); + + auto layout_SFA = CollectiveMainloop::Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1)); + auto layout_SFB = CollectiveMainloop::Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1)); + layout_SFA_host[group_idx] = layout_SFA; + layout_SFB_host[group_idx] = layout_SFB; + } + + CUDA_CHECK(cudaMemcpy( + this->layout_SFA_device.data(), + layout_SFA_host.data(), + sizeof(typename CollectiveMainloop::InternalLayoutSFA) * num_groups, + cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy( + this->layout_SFB_device.data(), + layout_SFB_host.data(), + sizeof(typename CollectiveMainloop::InternalLayoutSFB) * num_groups, + cudaMemcpyHostToDevice)); + + Operator* op = new (host_workspace) Operator; + return status; + } + + /// **** CAUTION **** + /// initialize() must be called if lda, ldb, ldc, or ldd change. + Status run( + void const* arguments_ptr, + void* host_workspace, + void* device_workspace = nullptr, + cudaStream_t stream = nullptr) const override { + + OperatorArguments operator_args; + auto const& args = *static_cast(arguments_ptr); + + Status status = update_arguments_(operator_args, &args); + if (status != Status::kSuccess) { + return status; + } + + Operator* op = static_cast(host_workspace); + status = op->run(operator_args, device_workspace, stream, nullptr); + return status; + } +}; + +template +class GroupedBlockwiseGemmUniversal3xOperation : public GroupedGemmOperation3xBase { +public: + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + using ElementD = typename Operator::ElementD; + using LayoutD = typename Operator::LayoutD; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using CollectiveMainloop = typename Operator::CollectiveMainloop; + using CollectiveEpilogue = typename Operator::CollectiveEpilogue; + using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + using ElementSFA = typename Operator::ElementAccumulator; + using ElementSFB = typename Operator::ElementAccumulator; + + using TiledMma = typename Operator::CollectiveMainloop::TiledMma; + + GroupedBlockwiseGemmUniversal3xOperation(char const* name = "unknown_gemm") + : GroupedGemmOperation3xBase(name) { + + BlockScaleDescription blockwise_desc{}; + blockwise_desc.kind = OperationKind::kBlockwiseGemm; + blockwise_desc.SFA.element = NumericTypeMap::kId; + blockwise_desc.SFA.layout = size<0,1>(typename CollectiveMainloop::InternalLayoutSFA{}.stride()) == 1 ? + LayoutTypeID::kColumnMajor : LayoutTypeID::kRowMajor; + blockwise_desc.SFA.alignment = CollectiveMainloop::AlignmentSFA; + blockwise_desc.SFA.log_extent_range = 32; + blockwise_desc.SFA.log_stride_range = 32; + + blockwise_desc.SFB.element = NumericTypeMap::kId; + blockwise_desc.SFB.layout = size<0,1>(typename CollectiveMainloop::InternalLayoutSFB{}.stride()) == 1 ? + LayoutTypeID::kRowMajor : LayoutTypeID::kColumnMajor; + blockwise_desc.SFB.alignment = CollectiveMainloop::AlignmentSFA; + blockwise_desc.SFB.log_extent_range = 32; + blockwise_desc.SFB.log_stride_range = 32; + + blockwise_desc.SFMVecSize = Operator::CollectiveMainloop::ScaleGranularityM; + blockwise_desc.SFNVecSize = Operator::CollectiveMainloop::ScaleGranularityN; + blockwise_desc.SFKVecSize = Operator::CollectiveMainloop::ScaleGranularityK; + + blockwise_desc.EpilogueSFVecSize = 0; + + this->description_.block_scales = blockwise_desc; + } + + ~GroupedBlockwiseGemmUniversal3xOperation() override = default; + + mutable CudaBuffer layout_SFA_device; + mutable CudaBuffer layout_SFB_device; + +protected: + template struct UpdateFusionArgs { + static Status update_(FusionArgs const& fusion_args, GemmGroupedArguments const& arguments) { + // If a custom EVT is instantiated then it is the users's responsibility + // to ensure alpha and beta are updated appropriately + return Status::kSuccess; + } + }; + + template + struct UpdateFusionArgs> { + static Status + update_(FusionArgs& fusion_args, GroupedGemmBlockwiseArguments const& arguments) { + return GroupedGemmOperation3xBase::update_fusion_args(fusion_args, arguments); + } + }; + +public: + /// Returns success if the operation can proceed + Status can_implement([[maybe_unused]] void const* configuration_ptr, void const* arguments_ptr) + const override { + GroupedGemmBlockwiseArguments const* arguments = + static_cast(arguments_ptr); + OperatorArguments args; + auto status = update_arguments_(args, arguments); + if (status != Status::kSuccess) { + return status; + } + + status = Operator::can_implement(args); + return status; + } + + Status update_arguments_( + OperatorArguments& operator_args, + GroupedGemmBlockwiseArguments const* arguments) const { + Status status = UpdateFusionArgs::update_( + operator_args.epilogue.thread, + *arguments); + if (status != Status::kSuccess) { + return status; + } + + operator_args.mainloop.ptr_SFA = + static_cast(arguments->SFA); + operator_args.mainloop.ptr_SFB = + static_cast(arguments->SFB); + + operator_args.mainloop.layout_SFA = + static_cast(this->layout_SFA_device.data()); + operator_args.mainloop.layout_SFB = + static_cast(this->layout_SFB_device.data()); + + return this->update_arguments_base(operator_args, *arguments); + } + + uint64_t get_device_workspace_size(void const* configuration_ptr, void const* arguments_ptr) + const override { + + OperatorArguments args; + auto status = + update_arguments_(args, static_cast(arguments_ptr)); + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + return size; + } + + /// Initializes the workspace + /// **** CAUTION **** + /// Must be called when lda, ldb, ldc, or ldd change. + /// The CUTLASS library stores the operations in a type- + /// erased manifest. Therefore, only this class knows + /// the type of strideA, strideB, strideC, and strideD. + /// Since grouped GEMM needs to allocate storage for + /// the strides on device, the concrete type of the stride + /// must be known in order to copy in the correct memory + /// layout on device. + Status initialize( + void const* configuration_ptr, + void* host_workspace, + void* device_workspace, + cudaStream_t stream = nullptr) const override { + + auto const& config = *static_cast(configuration_ptr); + auto status = this->initialize_strides(config); + if (status != Status::kSuccess) { + return status; + } + + auto num_groups = config.problem_count; + this->layout_SFA_device = + CudaBuffer(sizeof(typename CollectiveMainloop::InternalLayoutSFA) * num_groups); + this->layout_SFB_device = + CudaBuffer(sizeof(typename CollectiveMainloop::InternalLayoutSFB) * num_groups); + auto layout_SFA_host = std::vector(num_groups); + auto layout_SFB_host = std::vector(num_groups); + + for (int group_idx = 0; group_idx < num_groups; group_idx++) { + auto const& shape = config.problem_sizes_3x_host[group_idx]; + auto M = get<0>(shape); + auto N = get<1>(shape); + auto K = get<2>(shape); + + auto layout_SFA = CollectiveMainloop::ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1)); + auto layout_SFB = CollectiveMainloop::ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1)); + layout_SFA_host[group_idx] = layout_SFA; + layout_SFB_host[group_idx] = layout_SFB; + } + + CUDA_CHECK(cudaMemcpy( + this->layout_SFA_device.data(), + layout_SFA_host.data(), + sizeof(typename CollectiveMainloop::InternalLayoutSFA) * num_groups, + cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy( + this->layout_SFB_device.data(), + layout_SFB_host.data(), + sizeof(typename CollectiveMainloop::InternalLayoutSFB) * num_groups, + cudaMemcpyHostToDevice)); + + Operator* op = new (host_workspace) Operator; + return status; + } + + /// **** CAUTION **** + /// initialize() must be called if lda, ldb, ldc, or ldd change. + Status run( + void const* arguments_ptr, + void* host_workspace, + void* device_workspace = nullptr, + cudaStream_t stream = nullptr) const override { + + OperatorArguments operator_args; + auto const& args = *static_cast(arguments_ptr); + + Status status = update_arguments_(operator_args, &args); + if (status != Status::kSuccess) { + return status; + } + + Operator* op = static_cast(host_workspace); + status = op->run(operator_args, device_workspace, stream, nullptr); + return status; + } +}; + + +} // namespace cutlass::library diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/library_internal.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/library_internal.h new file mode 100644 index 0000000000000000000000000000000000000000..e8bd77397f3b85cce2da2a7a8e447ab6ccb48aea --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/library_internal.h @@ -0,0 +1,427 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + + \brief CUTLASS Library is an object-oriented approach to managing operations implemented by CUTLASS. + + Generally, + + description - compile-time constant parameters used to instantiate an operation + + configuration - runtime parameters with computationally expensive initialization + + arguments - runtime parameters that may be passed to an initialized operation with low + computational overhead +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/layout/matrix.h" + +#include "cutlass/library/library.h" +#include "cutlass/library/arch_mappings.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template struct NumericTypeMap; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kVoid; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kB1; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kS2; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kS4; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kS8; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kS16; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kS32; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kS64; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kU2; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kU4; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kU8; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kFE4M3; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kFE5M2; +}; + + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kFE2M3; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kFE3M2; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kFE2M1; +}; +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kFUE8M0; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kFUE4M3; +}; + + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kU16; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kU32; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kU64; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kF16; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kF32; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kF64; +}; + +template <> struct NumericTypeMap > { + static NumericTypeID const kId = NumericTypeID::kCF16; +}; + +template <> struct NumericTypeMap > { + static NumericTypeID const kId = NumericTypeID::kCF32; +}; + +template <> struct NumericTypeMap > { + static NumericTypeID const kId = NumericTypeID::kCF64; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kBF16; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kTF32; +}; + + + + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kF8; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kF6; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kF4; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kInvalid; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAdd; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddFastBF16; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddFastF16; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddSaturate; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddMixedInputUpcast; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddComplex; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddGaussianComplex; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kXorPopc; +}; + + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddFastF32; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddComplexFastF32; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template struct LayoutMap; + +template <> struct LayoutMap { + static LayoutTypeID const kId = LayoutTypeID::kColumnMajor; +}; + +template <> struct LayoutMap { + static LayoutTypeID const kId = LayoutTypeID::kRowMajor; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK2; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK2; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK4; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK4; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK16; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK16; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK32; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK32; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK64; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK64; +}; + +template <> struct LayoutMap { + static LayoutTypeID const kId = LayoutTypeID::kTensorNHWC; +}; + +template <> struct LayoutMap { + static LayoutTypeID const kId = LayoutTypeID::kTensorNDHWC; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kTensorNC32HW32; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kTensorNC64HW64; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kTensorC32RSK32; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kTensorC64RSK64; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template struct OpcodeClassMap; + +template <> struct OpcodeClassMap { + static OpcodeClassID const kId = OpcodeClassID::kSimt; +}; + +template <> struct OpcodeClassMap { + static OpcodeClassID const kId = OpcodeClassID::kTensorOp; +}; + +template <> struct OpcodeClassMap { + static OpcodeClassID const kId = OpcodeClassID::kSparseTensorOp; +}; + + +template <> struct OpcodeClassMap { + static OpcodeClassID const kId = OpcodeClassID::kBlockScaledOp; +}; + + +template <> struct OpcodeClassMap { + static OpcodeClassID const kId = OpcodeClassID::kWmmaTensorOp; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template struct ComplexTransformMap; + +template <> struct ComplexTransformMap { + static cutlass::library::ComplexTransform const kId = cutlass::library::ComplexTransform::kNone; +}; + +template <> struct ComplexTransformMap { + static cutlass::library::ComplexTransform const kId = cutlass::library::ComplexTransform::kConjugate; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template struct ConvModeMap; + +template <> struct ConvModeMap { + static ConvModeID const kId = ConvModeID::kCrossCorrelation; +}; + +template <> struct ConvModeMap { + static ConvModeID const kId = ConvModeID::kConvolution; +}; + + +template struct ConvKindMap; + +template <> struct ConvKindMap { + static ConvKind const kId = ConvKind::kFprop; +}; + +template <> struct ConvKindMap { + static ConvKind const kId = ConvKind::kDgrad; +}; + +template <> struct ConvKindMap { + static ConvKind const kId = ConvKind::kWgrad; +}; + + +template struct IteratorAlgorithmMap; + +template <> struct IteratorAlgorithmMap { + static IteratorAlgorithmID const kId = IteratorAlgorithmID::kAnalytic; +}; + +template <> struct IteratorAlgorithmMap { + static IteratorAlgorithmID const kId = IteratorAlgorithmID::kOptimized; +}; + +template <> struct IteratorAlgorithmMap { + static IteratorAlgorithmID const kId = IteratorAlgorithmID::kFixedChannels; +}; + +template <> struct IteratorAlgorithmMap { + static IteratorAlgorithmID const kId = IteratorAlgorithmID::kFewChannels; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +TensorDescription make_TensorDescription(int alignment = 1) { + TensorDescription desc; + + desc.element = NumericTypeMap::kId; + desc.layout = LayoutMap::kId; + desc.alignment = alignment; + desc.log_extent_range = int(sizeof(typename Layout::TensorCoord::Index) - 1) * 8; + desc.log_stride_range = int(sizeof(typename Layout::Stride::Index) - 1) * 8; + + return desc; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/rank_2k_operation.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/rank_2k_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..76d8d0dfdb1aa6ed0324b9d6299b06ebf3f436d9 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/rank_2k_operation.h @@ -0,0 +1,377 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all Rank 2K operation kinds (Syr2k, Her2k) + in CUTLASS Library. + + +*/ + +#pragma once +#include +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/rank_2k.h" +#include "cutlass/gemm/kernel/default_rank_2k_universal.h" + +#include "cutlass/library/library.h" +#include "library_internal.h" +#include "cutlass/core_io.h" +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class Rank2KOperationBase : public Operation { +public: + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + static BlasMode const kBlasMode = Operator::kBlasMode; + static int const kUpdateRank = Operator::kUpdateRank; + static FillMode const kFillModeC = Operator::kFillModeC; + + using OperatorArguments = typename Operator::Arguments; + +protected: + + /// + RankKDescription description_; + +public: + + /// Constructor + Rank2KOperationBase(char const *name = "unknown_rank_k") { + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.rank_k_kind = RankKKind::kUniversal; + description_.fill_mode = kFillModeC; + description_.blas_mode = kBlasMode; + description_.num_ranks = kUpdateRank; + + description_.kind = OperationKind::kRank2K; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::Rank2Kkernel::WarpCount::kM, + Operator::Rank2Kkernel::WarpCount::kN, + Operator::Rank2Kkernel::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(Operator::kAlignmentA); + description_.B = make_TensorDescription(Operator::kAlignmentB); + description_.C = make_TensorDescription(Operator::kAlignmentC); + description_.element_epilogue = NumericTypeMap::kId; + + description_.split_k_mode = SplitKMode::kNone; + description_.transform_A = ComplexTransformMap::kId; + description_.transform_B = ComplexTransformMap::kId; + } + + /// Returns the description of the SYRK operation + virtual OperationDescription const & description() const { + return description_; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class Rank2KOperation : public Rank2KOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + static BlasMode const kBlasMode = Operator::kBlasMode; + static int const kUpdateRank = Operator::kUpdateRank; + static FillMode const kFillModeC = Operator::kFillModeC; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + Rank2KOperation(char const *name = "unknown_rank_2k"): + Rank2KOperationBase(name) { + + this->description_.rank_k_kind = RankKKind::kUniversal; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + RankKConfiguration const *configuration) { + + //operator_args.mode = configuration->mode; + + operator_args.problem_size = configuration->problem_size; + operator_args.batch_count = configuration->batch_count; + + operator_args.lda = int(configuration->lda); + operator_args.ldb = int(configuration->ldb); + operator_args.ldc = int(configuration->ldc); + operator_args.ldd = int(configuration->ldd); + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + RankKArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + // update arguments + operator_args.ptr_A = arguments->A; + operator_args.ptr_B = arguments->B; + operator_args.ptr_C = arguments->C; + operator_args.ptr_D = arguments->D; + + operator_args.batch_stride_A = arguments->batch_stride_A; + operator_args.batch_stride_B = arguments->batch_stride_B; + operator_args.batch_stride_C = arguments->batch_stride_C; + operator_args.batch_stride_D = arguments->batch_stride_D; + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + RankKConfiguration const *configuration = + static_cast(configuration_ptr); + + RankKArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + + return size; + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + //std::cout << "initialize() library::Rank2KOperation" << std::endl; + //print_operator_args(args); + status = op->initialize(args, device_workspace, stream); + + return status; + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args, device_workspace); + + if (status != Status::kSuccess) { + return status; + } + + //std::cout << "run() library::Rank2KOperation" << std::endl; + //print_operator_args(args); + status = op->run(stream); + + return status; + } + + /// Call print_operator_args from the Conv2dOperation::initialize() + // to dump arguments passed on to cutlass operator for debugging + void print_operator_args(OperatorArguments &operator_args) const { + std::cout << "Rank2KOperation::OperatorArguments" << std::endl + << " problem_size:" << std::endl + << operator_args.problem_size << std::endl + << " epilogue (alpha, beta): " + << operator_args.epilogue.alpha << ", " + << operator_args.epilogue.beta << std::endl + << " ref_A (ptr, {stride}): " + << operator_args.ptr_A << ", {" + << operator_args.lda << "}" << std::endl + << " ref_B (ptr, {stride}): " + << operator_args.ptr_B << ", {" + << operator_args.ldb << "}" << std::endl + << " ref_C (ptr, {stride}): " + << operator_args.ptr_C << ", {" + << operator_args.ldc << "}" << std::endl + << " ref_D (ptr, {stride}): " + << operator_args.ptr_D << ", {" + << operator_args.ldd << "}" << std::endl; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/rank_k_operation.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/rank_k_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..021f7f03fcc4449bdc2ef2c97e29fe0fead09a64 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/rank_k_operation.h @@ -0,0 +1,348 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all Rank K operation kinds (Syrk, Herk) + in CUTLASS Library. + + +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/rank_k.h" +#include "cutlass/gemm/kernel/default_rank_k_universal.h" + +#include "cutlass/library/library.h" +#include "library_internal.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class RankKOperationBase : public Operation { +public: + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementA; + using LayoutB = typename Operator::LayoutA; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + static BlasMode const kBlasMode = Operator::kBlasMode; + static int const kUpdateRank = Operator::kUpdateRank; + static FillMode const kFillModeC = Operator::kFillModeC; + + using OperatorArguments = typename Operator::Arguments; + +protected: + + /// + RankKDescription description_; + +public: + + /// Constructor + RankKOperationBase(char const *name = "unknown_rank_k") { + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.rank_k_kind = RankKKind::kUniversal; + description_.fill_mode = kFillModeC; + description_.blas_mode = kBlasMode; + description_.num_ranks = kUpdateRank; + + description_.kind = OperationKind::kRankK; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::RankKkernel::WarpCount::kM, + Operator::RankKkernel::WarpCount::kN, + Operator::RankKkernel::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(Operator::kAlignmentA); + description_.B = make_TensorDescription(Operator::kAlignmentA); + description_.C = make_TensorDescription(Operator::kAlignmentC); + description_.element_epilogue = NumericTypeMap::kId; + + description_.split_k_mode = SplitKMode::kNone; + description_.transform_A = ComplexTransformMap::kId; + description_.transform_B = ComplexTransformMap::kId; + } + + /// Returns the description of the SYRK operation + virtual OperationDescription const & description() const { + return description_; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class RankKOperation : public RankKOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementA; + using LayoutB = typename Operator::LayoutA; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + static BlasMode const kBlasMode = Operator::kBlasMode; + static int const kUpdateRank = Operator::kUpdateRank; + static FillMode const kFillModeC = Operator::kFillModeC; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + RankKOperation(char const *name = "unknown_rank_k"): + RankKOperationBase(name) { + + this->description_.rank_k_kind = RankKKind::kUniversal; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + RankKConfiguration const *configuration) { + + //operator_args.mode = configuration->mode; + + operator_args.problem_size = configuration->problem_size; + operator_args.batch_count = configuration->batch_count; + + operator_args.lda = int(configuration->lda); + operator_args.ldb = int(configuration->lda); + operator_args.ldc = int(configuration->ldc); + operator_args.ldd = int(configuration->ldd); + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + RankKArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + // update arguments + operator_args.ptr_A = arguments->A; + operator_args.ptr_C = arguments->C; + operator_args.ptr_D = arguments->D; + + operator_args.batch_stride_A = arguments->batch_stride_A; + operator_args.batch_stride_C = arguments->batch_stride_C; + operator_args.batch_stride_D = arguments->batch_stride_D; + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + RankKConfiguration const *configuration = + static_cast(configuration_ptr); + + RankKArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + + return size; + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + status = op->initialize(args, device_workspace, stream); + + return status; + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args, device_workspace); + + if (status != Status::kSuccess) { + return status; + } + + status = op->run(stream); + + return status; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/reduction/reduction_operation.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/reduction/reduction_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..6e948540e3f29dceace42b5e8ef3f91118c01b37 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/reduction/reduction_operation.h @@ -0,0 +1,294 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for reduction operation in CUTLASS Library. +*/ + +#pragma once +#include +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_clamp.h" +#include "cutlass/reduction/thread/reduction_operators.h" +#include "cutlass/reduction/device/reduce_split_k.h" + +#include "cutlass/library/library.h" +#include "library_internal.h" +#include "cutlass/core_io.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class ReductionOperation : public Operation { +public: + using Operator = Operator_; + + using ElementWorkspace = typename Operator::ElementWorkspace; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementOutput = typename Operator::ElementOutput; + + using ElementCompute = typename Operator::OutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +protected: + + /// + ReductionDescription description_; + +public: + + /// Constructor + ReductionOperation(char const *name = "unknown_reduction") { + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.kind = OperationKind::kReduction; + + description_.tile_description.threadblock_shape = make_Coord(Operator::Shape::kRow, Operator::Shape::kColumn, 1); + + description_.tile_description.math_instruction.instruction_shape = make_Coord(1, 1, 1); + description_.tile_description.math_instruction.element_accumulator = NumericTypeMap::kId; + description_.tile_description.math_instruction.opcode_class = OpcodeClassID::kSimt; + description_.tile_description.math_instruction.math_operation = MathOperationID::kAdd; + + description_.tile_description.minimum_compute_capability = 50; + description_.tile_description.maximum_compute_capability = 1024; + + description_.element_workspace = NumericTypeMap::kId; + description_.element_output = NumericTypeMap::kId; + description_.element_epilogue = NumericTypeMap::kId; + + } + + /// Returns the description of the Reduction operation + virtual OperationDescription const & description() const { + return description_; + } + + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + ReductionConfiguration const *configuration) { + + operator_args.problem_size = configuration->problem_size; + operator_args.partitions = configuration->partitions; + operator_args.partition_stride = configuration->partition_stride; + + operator_args.workspace = {nullptr, int(configuration->ldw)}; + operator_args.source = {nullptr, int(configuration->lds)}; + operator_args.destination = {nullptr, int(configuration->ldd)}; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + ReductionArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::OutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.output = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::OutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.output = params; + } + else { + return Status::kErrorInvalidProblem; + } + + operator_args.workspace.reset(static_cast(const_cast(arguments->workspace))); + operator_args.source.reset(static_cast(const_cast(arguments->source))); + operator_args.destination.reset(static_cast(const_cast(arguments->destination))); + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + ReductionConfiguration const *configuration = + static_cast(configuration_ptr); + + ReductionArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + return Operator::get_workspace_size(args); + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + //std::cout << "initialize library::Reduction" << std::endl; + //print_operator_args(args); + return op->initialize(args, device_workspace, stream); + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args, device_workspace); + + if (status != Status::kSuccess) { + return status; + } + + //std::cout << "run library::Reduction" << std::endl; + //print_operator_args(args); + return op->run(stream); + } + + /// Call print_operator_args from the Reduction::initialize() + // to dump arguments passed on to cutlass operator for debugging + void print_operator_args(OperatorArguments &operator_args) const { + std::cout << "Reduction::OperatorArguments" << std::endl + << " problem_size: " + << operator_args.problem_size << std::endl + << " partitions: " + << operator_args.partitions << std::endl + << " partition_stride: " + << operator_args.partition_stride << std::endl + << " epilogue (alpha, beta): " + << operator_args.output.alpha << ", " + << operator_args.output.beta << std::endl + << " workspace (ptr, stride): " + << operator_args.workspace.data() << ", " + << operator_args.workspace.stride(0) << std::endl + << " source (ptr, stride): " + << operator_args.source.data() << ", " + << operator_args.source.stride(0) << std::endl + << " destination (ptr, stride): " + << operator_args.destination.data() << ", " + << operator_args.destination.stride(0) << std::endl; + } +}; + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/reference/block_scaled_gemm_reference_operation.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/reference/block_scaled_gemm_reference_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..769da1c8515877536fd9b9fd72c836fd43ebd5d8 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/reference/block_scaled_gemm_reference_operation.h @@ -0,0 +1,453 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines reference operations for block-scaled GEMM operation kinds in CUTLASS Library +*/ + + + +#pragma once + +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/util.h" +#include "cutlass/util/packed_stride.hpp" +#include "library_internal.h" + +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +namespace detail { +template +auto make_iterator(T* ptr) { + return cute::recast_ptr(ptr); +} +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + Provider Provider_, + typename ElementA_, + typename LayoutA_, + typename ElementSFA_, + typename ElementB_, + typename LayoutB_, + typename ElementSFB_, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + typename ElementSFD_ = void, + typename LayoutSFD_ = LayoutC_, + int SFVecSize_ = 32, + int EpilogueSFVecSize_ = 0, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +class BlockScaledGemmReferenceOperation : public Operation { +public: + static Provider const kProvider = Provider_; + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using ElementSFA = ElementSFA_; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using ElementSFB = ElementSFB_; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using ElementD = ElementD_; + using ElementSFD = ElementSFD_; + using LayoutSFD = LayoutSFD_; + using ElementCompute = ElementCompute_; + using ElementAccumulator = ElementAccumulator_; + using ConvertOp = ConvertOp_; + using InnerProductOp = InnerProductOp_; + constexpr static int SFVecSize = SFVecSize_; + constexpr static int EpilogueSFVecSize = EpilogueSFVecSize_; + +protected: + + /// Storage for the name string + std::string name_; + + /// + BlockScaledGemmDescription description_; + +public: + + /// Constructor + BlockScaledGemmReferenceOperation() { + + // Basic information + description_.provider = kProvider; + description_.kind = OperationKind::kBlockScaledGemm; + description_.gemm_kind = GemmKind::kUniversal; + + // Tensor description + description_.A = make_TensorDescription(); + description_.SFA = make_TensorDescription(); + description_.B = make_TensorDescription(); + description_.SFB = make_TensorDescription(); + description_.C = make_TensorDescription(); + description_.D = make_TensorDescription(); + description_.SFD = make_TensorDescription(); + + // Epilogue compute and accumulator type description + description_.element_epilogue = NumericTypeMap::kId; + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + // Compute capability for gemm reference + description_.tile_description.minimum_compute_capability = + (kProvider == Provider::kReferenceDevice ? 50 : 0); + + description_.tile_description.maximum_compute_capability = 1024; + + description_.SFVecSize = SFVecSize; + description_.EpilogueSFVecSize = EpilogueSFVecSize; + + // Procedural name + std::stringstream ss; + + ss << "gemm" + << "_reference_" << to_string(description_.provider) + << "_" << to_string(description_.A.element) << to_string(description_.A.layout) + << "_" << to_string(description_.SFA.element) << to_string(description_.SFA.layout) + << "_" << to_string(description_.B.element) << to_string(description_.B.layout) + << "_" << to_string(description_.SFB.element) << to_string(description_.SFB.layout) + << "_" << to_string(description_.C.element) << to_string(description_.C.layout) + << "_" << to_string(description_.SFD.element) << to_string(description_.SFD.layout) + << "_" << to_string(description_.tile_description.math_instruction.element_accumulator); + + name_ = ss.str(); + + description_.name = name_.c_str(); + + // Epilogue compute and accumulator type description + description_.element_epilogue = NumericTypeMap::kId; + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } + + virtual Status can_implement( + void const *configuration, + void const *arguments) const { + + return Status::kSuccess; + } + + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(GemmUniversalConfiguration); + } + + virtual uint64_t get_device_workspace_size( + void const *configuration, + void const *arguments = nullptr) const { + + return 0; + } + + virtual Status initialize( + void const *configuration, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + return Status::kSuccess; + } + + virtual Status run( + void const *arguments, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + using namespace cute; + + BlockScaledGemmArguments const &args = *static_cast(arguments); + + // Construct cute::Tensor A/B/C + + int M = args.problem_size.m(); + int N = args.problem_size.n(); + int K = args.problem_size.k(); + int L = args.batch_count; + + auto problem_shape_MNKL = cute::make_shape(M, N, K, L); + + auto alpha = *(static_cast(args.alpha)); + auto beta = *(static_cast(args.beta)); + + using StrideA = cutlass::gemm::TagToStrideA_t; + using StrideB = cutlass::gemm::TagToStrideB_t; + using StrideC = cutlass::gemm::TagToStrideC_t; + using StrideD = cutlass::gemm::TagToStrideC_t; + + auto stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + auto stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + auto stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + auto stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + using Sm1xxBlockScaledConfig = cutlass::detail::Sm1xxBlockScaledConfig; + auto A = cute::make_tensor(detail::make_iterator(static_cast(args.A)), + cute::make_layout(cute::make_shape(M, K, L), stride_a)); + auto SfA = make_tensor(static_cast(args.SFA), Sm1xxBlockScaledConfig::tile_atom_to_shape_SFA(problem_shape_MNKL)); + + auto B = cute::make_tensor(detail::make_iterator(static_cast(args.B)), + cute::make_layout(cute::make_shape(N, K, L), stride_b)); + auto SfB = make_tensor(static_cast(args.SFB), Sm1xxBlockScaledConfig::tile_atom_to_shape_SFB(problem_shape_MNKL)); + + auto C = [&]() { + if constexpr (not is_same_v) { + return cute::make_tensor(detail::make_iterator(static_cast(args.C)), + cute::make_layout(cute::make_shape(M, N, L), stride_c)); + } + else { + return cute::make_tensor(detail::make_iterator(static_cast(nullptr)), + cute::make_layout(cute::make_shape(M, N, L), stride_c)); + } + }(); + + auto D = cute::make_tensor(detail::make_iterator(static_cast(args.D)), + cute::make_layout(cute::make_shape(M, N, L), stride_d)); + + cutlass::reference::host::GettBlockScalingMainloopParams + mainloop_params{A, SfA, B, SfB}; + + if constexpr (not is_same_v) { + + using Sm1xxBlockScaledOutputConfig= cutlass::detail::Sm1xxBlockScaledOutputConfig< + EpilogueSFVecSize + >; + + auto SfD = cute::make_tensor(detail::make_iterator(static_cast(args.SFD)), Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(problem_shape_MNKL)); + + cutlass::reference::host::GettBlockScalingEpilogueParams< + ElementCompute, ElementAccumulator, ElementCompute, + decltype(C), decltype(D), decltype(SfD), Int, cutlass::reference::host::SfStrategy::SfDGen> + epilogue_params{alpha, beta, C, D, SfD, *(static_cast(args.norm_constant))}; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + } + else { + // W/O SF generation + auto SfD = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, N, L))); // not used. + cutlass::reference::host::GettBlockScalingEpilogueParams< + ElementCompute, ElementAccumulator, ElementCompute, + decltype(C), decltype(D), decltype(SfD)> + epilogue_params{alpha, beta, C, D, SfD}; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + } + + return Status::kSuccess; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA_, + typename ElementSFA_, + typename ElementB_, + typename ElementSFB_, + typename ElementC_, + typename ElementCompute_, + typename ElementSFD_ = void, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + int SFVecSize = 32, + int EpilogueSFVecSize = SFVecSize, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_block_scaled_gemm_tn(Manifest &manifest) { +#if !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) + manifest.append(new BlockScaledGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ElementSFD_, + cutlass::layout::RowMajor, + SFVecSize, + EpilogueSFVecSize, + ConvertOp_, + InnerProductOp_ + >); +#endif // !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA_, + typename ElementSFA_, + typename ElementB_, + typename ElementSFB_, + typename ElementC_, + typename ElementCompute_, + typename ElementSFD_ = void, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + int SFVecSize = 32, + int EpilogueSFVecSize = SFVecSize, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_block_scaled_gemm(Manifest &manifest) { + /// + /// A is Row , B is Col + /// + manifest.append(new BlockScaledGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ElementSFD_, + cutlass::layout::RowMajor, + SFVecSize, + EpilogueSFVecSize, + ConvertOp_, + InnerProductOp_ + >); + manifest.append(new BlockScaledGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ElementSFD_, + cutlass::layout::RowMajor, + SFVecSize, + EpilogueSFVecSize, + ConvertOp_, + InnerProductOp_ + >); + /// + /// A is Col , B is Row + /// + manifest.append(new BlockScaledGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ElementSFD_, + cutlass::layout::RowMajor, + SFVecSize, + EpilogueSFVecSize, + ConvertOp_, + InnerProductOp_ + >); + manifest.append(new BlockScaledGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ElementSFD_, + cutlass::layout::RowMajor, + SFVecSize, + EpilogueSFVecSize, + ConvertOp_, + InnerProductOp_ + >); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/reference/blockwise_gemm_reference_operation.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/reference/blockwise_gemm_reference_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..fd988f899f563acfc6f8003bdb49523bca51d6d9 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/reference/blockwise_gemm_reference_operation.h @@ -0,0 +1,807 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines reference operations for blockwise/groupwise GEMM operation kinds in CUTLASS Library +*/ + + + +#pragma once + +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/util.h" +#include "cutlass/util/packed_stride.hpp" +#include "library_internal.h" + +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/detail/blockwise_scale_layout.hpp" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + Provider Provider_, + typename ElementA_, + typename LayoutA_, + typename LayoutSFA_, + typename ElementSFA_, + typename ElementB_, + typename LayoutB_, + typename LayoutSFB_, + typename ElementSFB_, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +class BlockwiseGemmReferenceOperation : public Operation { +public: + static Provider const kProvider = Provider_; + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using ElementSFA = ElementSFA_; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using ElementSFB = ElementSFB_; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using ElementD = ElementD_; + using ElementCompute = ElementCompute_; + using ElementAccumulator = ElementAccumulator_; + using ConvertOp = ConvertOp_; + using InnerProductOp = InnerProductOp_; + +protected: + + /// Storage for the name string + std::string name_; + + /// + BlockwiseGemmDescription description_; + +public: + + /// Constructor + BlockwiseGemmReferenceOperation(int SFMVecSize_, int SFNVecSize_, int SFKVecSize_) + : SFMVecSize(SFMVecSize_), SFNVecSize(SFNVecSize_), SFKVecSize(SFKVecSize_) { + + // Basic information + description_.provider = kProvider; + description_.kind = OperationKind::kBlockwiseGemm; + description_.gemm_kind = GemmKind::kUniversal; + + // Tensor description + description_.A = make_TensorDescription(); + description_.SFA = make_TensorDescription(); + description_.B = make_TensorDescription(); + description_.SFB = make_TensorDescription(); + description_.C = make_TensorDescription(); + description_.D = make_TensorDescription(); + + // Epilogue compute and accumulator type description + description_.element_epilogue = NumericTypeMap::kId; + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + // Compute capability for gemm reference + description_.tile_description.minimum_compute_capability = + (kProvider == Provider::kReferenceDevice ? 50 : 0); + + description_.tile_description.maximum_compute_capability = 1024; + + description_.SFMVecSize = SFMVecSize; + description_.SFNVecSize = SFNVecSize; + description_.SFKVecSize = SFKVecSize; + + // Procedural name + std::stringstream ss; + + ss << "gemm" + << "_reference_" << to_string(description_.provider) + << "_" << to_string(description_.A.element) << to_string(description_.A.layout) + << "_" << to_string(description_.SFA.element) << SFMVecSize << "x" << SFKVecSize << to_string(description_.SFA.layout) + << "_" << to_string(description_.B.element) << to_string(description_.B.layout) + << "_" << to_string(description_.SFB.element) << SFNVecSize << "x" << SFKVecSize << to_string(description_.SFB.layout) + << "_" << to_string(description_.C.element) << to_string(description_.C.layout) + << "_" << to_string(description_.tile_description.math_instruction.element_accumulator); + + name_ = ss.str(); + + description_.name = name_.c_str(); + + // Epilogue compute and accumulator type description + description_.element_epilogue = NumericTypeMap::kId; + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } + + virtual Status can_implement( + void const *configuration, + void const *arguments) const { + + return Status::kSuccess; + } + + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(GemmUniversalConfiguration); + } + + virtual uint64_t get_device_workspace_size( + void const *configuration, + void const *arguments = nullptr) const { + + return 0; + } + + virtual Status initialize( + void const *configuration, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + return Status::kSuccess; + } + + virtual Status run( + void const *arguments, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + using namespace cute; + + BlockwiseGemmArguments const &args = *static_cast(arguments); + + // Construct cute::Tensor A/B/C + + int M = args.problem_size.m(); + int N = args.problem_size.n(); + int K = args.problem_size.k(); + int L = args.batch_count; + + auto problem_shape_MNKL = cute::make_shape(M, N, K, L); + + auto alpha = *(static_cast(args.alpha)); + auto beta = *(static_cast(args.beta)); + + using StrideA = cutlass::gemm::TagToStrideA_t; + using StrideB = cutlass::gemm::TagToStrideB_t; + using StrideC = cutlass::gemm::TagToStrideC_t; + using StrideD = cutlass::gemm::TagToStrideC_t; + + auto stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + auto stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + auto stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + auto stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + using BlockwiseConfig = cutlass::detail::RuntimeBlockwiseScaleConfig<>; + auto A = cute::make_tensor(static_cast(args.A), + cute::make_layout(cute::make_shape(M, K, L), stride_a)); + auto SfA = make_tensor(static_cast(args.SFA), BlockwiseConfig::tile_atom_to_shape_SFA(problem_shape_MNKL, cute::make_tuple(SFMVecSize, SFNVecSize, SFKVecSize))); + + auto B = cute::make_tensor(static_cast(args.B), + cute::make_layout(cute::make_shape(N, K, L), stride_b)); + auto SfB = make_tensor(static_cast(args.SFB), BlockwiseConfig::tile_atom_to_shape_SFB(problem_shape_MNKL, cute::make_tuple(SFMVecSize, SFNVecSize, SFKVecSize))); + + auto C = [&]() { + if constexpr (not is_same_v) { + return cute::make_tensor(static_cast(args.C), + cute::make_layout(cute::make_shape(M, N, L), stride_c)); + } + else { + return cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, N, L), stride_c)); + } + }(); + + auto D = cute::make_tensor(static_cast(args.D), + cute::make_layout(cute::make_shape(M, N, L), stride_d)); + + cutlass::reference::host::GettBlockScalingMainloopParams + mainloop_params{A, SfA, B, SfB}; + + // W/O SF generation + cutlass::reference::host::GettEpilogueParams< + ElementCompute, ElementAccumulator, ElementAccumulator, ElementCompute, + decltype(C), decltype(D)> + epilogue_params{alpha, beta, C, D}; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + return Status::kSuccess; + } + +private: + int SFMVecSize; + int SFNVecSize; + int SFKVecSize; +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA_, + typename ElementSFA_, + typename ElementB_, + typename ElementSFB_, + typename ElementC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_blockwise_gemm(Manifest &manifest, int SFMVecSize, int SFNVecSize, int SFKVecSize) { + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + +} + +template +void initialize_blockwise_gemm_reference_operations_given_C_and_D(Manifest &manifest) { + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 1 , 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 128, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 1, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 128, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 1, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 128, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 32, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 32, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 64, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 64, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 256, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 256, 128); + + + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 1 , 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 128, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 1, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 128, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 1 , 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 128, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 32, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 32, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 64, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 64, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 256, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 256, 128); + + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 1 , 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 128, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 1, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 128, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 1, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 128, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 32, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 32, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 64, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 64, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 256, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 256, 128); + + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 1 , 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 128, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 1, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 128, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 1 , 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 128, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 32, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 32, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 64, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 64, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 256, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 256, 128); + +} + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/reference/conv_reference_operation.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/reference/conv_reference_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..240fe18d16a27778bf75e0c02f99d251c096353f --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/reference/conv_reference_operation.h @@ -0,0 +1,636 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all CONV operation kinds in CUTLASS Library +*/ + +#pragma once + +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/util.h" +#include "library_internal.h" + +#include "cutlass/conv/convolution.h" +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/device/convolution.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template < + Provider kProvider, + cutlass::conv::Operator ConvolutionalOperator, + int ConvDim, + typename ElementA_, + typename LayoutA_, + typename ElementB_, + typename LayoutB_, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +struct ConvReferenceDispatcher; + +/// Dispatcher for Conv2d (partially specialized for kConvDim == 2) +template < + Provider kProvider, + cutlass::conv::Operator kConvolutionalOperator, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator, + typename ConvertOp, + typename InnerProductOp +> +struct ConvReferenceDispatcher< + kProvider, + kConvolutionalOperator, + 2, + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp> { + + static Status dispatch( + void const *configuration, + ElementA *ptr_A, + ElementB *ptr_B, + ElementC *ptr_C, + ElementC *ptr_D, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr + ) { + + Conv2dConfiguration const &config = + *static_cast(configuration); + + // TODO: make below code more general. It is fixed for NHWC now. + layout::TensorNHWC layout_a; + layout::TensorNHWC layout_b; + layout::TensorNHWC layout_c; + + layout_a.stride() = + make_Coord(int32_t(config.stride_a[0]), + int32_t(config.stride_a[1]), + int32_t(config.stride_a[2])); + + layout_b.stride() = + make_Coord(int32_t(config.stride_b[0]), + int32_t(config.stride_b[1]), + int32_t(config.stride_b[2])); + + layout_c.stride() = + make_Coord(int32_t(config.stride_c[0]), + int32_t(config.stride_c[1]), + int32_t(config.stride_c[2])); + + if (kProvider == Provider::kReferenceHost) { + + cutlass::reference::host::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC , + LayoutC, + ElementCompute, + ElementAccumulator, + ElementC, + ConvertOp, + InnerProductOp + >( + kConvolutionalOperator, + config.problem_size, + {ptr_A, layout_a}, + {ptr_B, layout_b}, + {ptr_C, layout_c}, + {ptr_D, layout_c}, + alpha, + beta + ); + + return Status::kSuccess; + } + else if (kProvider == Provider::kReferenceDevice) { + return cutlass::reference::device::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp + >( + kConvolutionalOperator, + config.problem_size, + {ptr_A, layout_a}, + {ptr_B, layout_b}, + {ptr_C, layout_c}, + {ptr_D, layout_c}, + alpha, + beta, + stream + ); + } + return Status::kErrorNotSupported; + } +}; + +/// Dispatcher for Conv3d (partially specialized for kConvDim == 3) +template < + Provider kProvider, + cutlass::conv::Operator kConvolutionalOperator, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator, + typename ConvertOp, + typename InnerProductOp +> +struct ConvReferenceDispatcher< + kProvider, + kConvolutionalOperator, + 3, + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp> { + + static Status dispatch( + void const *configuration, + ElementA *ptr_A, + ElementB *ptr_B, + ElementC *ptr_C, + ElementC *ptr_D, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr + ) { + + Conv3dConfiguration const &config = + *static_cast(configuration); + + ConvKind const conv_kind = ConvKindMap::kId; + + if (kProvider == Provider::kReferenceHost) { + cutlass::reference::host::Conv3d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC , + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp + >( + kConvolutionalOperator, + config.problem_size, + {ptr_A, config.layout_a(conv_kind)}, + {ptr_B, config.layout_b(conv_kind)}, + {ptr_C, config.layout_c(conv_kind)}, + {ptr_D, config.layout_c(conv_kind)}, + alpha, + beta + ); + + return Status::kSuccess; + } + else if (kProvider == Provider::kReferenceDevice) { + return cutlass::reference::device::Conv3d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp + >( + kConvolutionalOperator, + config.problem_size, + {ptr_A, config.layout_a(conv_kind)}, + {ptr_B, config.layout_b(conv_kind)}, + {ptr_C, config.layout_c(conv_kind)}, + {ptr_D, config.layout_c(conv_kind)}, + alpha, + beta, + stream + ); + } + return Status::kErrorNotSupported; + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + Provider Provider_, + cutlass::conv::Operator ConvolutionalOperator, + int ConvDim, + typename ElementA_, + typename LayoutA_, + typename ElementB_, + typename LayoutB_, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +class ConvReferenceOperation : public Operation { +public: + static Provider const kProvider = Provider_; + static cutlass::conv::Operator const kConvolutionalOperator = ConvolutionalOperator; + static int const kConvDim = ConvDim; + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using ElementCompute = ElementCompute_; + using ElementAccumulator = ElementAccumulator_; + using ConvertOp = ConvertOp_; + using InnerProductOp = InnerProductOp_; + +protected: + + /// Storage for the name string + std::string name_; + + /// + ConvDescription description_; + +public: + + /// Constructor + ConvReferenceOperation() { + + // Basic information + description_.provider = kProvider; + description_.kind = (kConvDim == 2 ? OperationKind::kConv2d : OperationKind::kConv3d); + description_.conv_kind = ConvKindMap::kId; + description_.conv_dim = kConvDim; + + // Tensor description + description_.A = make_TensorDescription(); + description_.B = make_TensorDescription(); + description_.C = make_TensorDescription(); + + // Epilogue compute and accumulator type description + description_.element_epilogue = NumericTypeMap::kId; + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + // Iterator algorithm for convolution reference + description_.iterator_algorithm = IteratorAlgorithmID::kNone; + + // Compute capability for convolution reference + description_.tile_description.minimum_compute_capability = + (kProvider == Provider::kReferenceDevice ? 50 : 0); + + description_.tile_description.maximum_compute_capability = 1024; + + // Procedural name + std::stringstream ss; + + ss << "conv" << kConvDim << "d_" << to_string(description_.conv_kind) + << "_reference_" << to_string(description_.provider) + << "_" << to_string(description_.A.element) << to_string(description_.A.layout) + << "_" << to_string(description_.B.element) << to_string(description_.B.layout) + << "_" << to_string(description_.C.element) << to_string(description_.C.layout) + << "_" << to_string(description_.tile_description.math_instruction.element_accumulator); + + name_ = ss.str(); + + description_.name = name_.c_str(); + + // Epilogue compute and accumulator type description + description_.element_epilogue = NumericTypeMap::kId; + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } + + virtual Status can_implement( + void const *configuration, + void const *arguments) const { + + return Status::kSuccess; + } + + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + switch (kConvDim) { + case 2: + return sizeof(Conv2dConfiguration); + case 3: + return sizeof(Conv3dConfiguration); + default: + break; + } + + return 0; + } + + virtual uint64_t get_device_workspace_size( + void const *configuration, + void const *arguments = nullptr) const { + + return 0; + } + + virtual Status initialize( + void const *configuration, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + std::memcpy(host_workspace, configuration, get_host_workspace_size(configuration)); + + return Status::kSuccess; + } + + virtual Status run( + void const *arguments, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + ConvArguments const &args = *static_cast(arguments); + + ElementCompute alpha; + ElementCompute beta; + + alpha = *static_cast(args.alpha); + beta = *static_cast(args.beta); + + // TODO - respect pointer mode + + // Invoke 2D or 3D convolution + return detail::ConvReferenceDispatcher< + kProvider, + kConvolutionalOperator, + kConvDim, + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp + >::dispatch( + host_workspace, + static_cast(const_cast(args.A)), + static_cast(const_cast(args.B)), + static_cast(const_cast(args.C)), + static_cast(args.D), + alpha, + beta, + stream + ); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Constructs Fprop reference operators. +template < + int kConvDim, + typename ElementA_, + typename LayoutA_, + typename ElementB_, + typename LayoutB_, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_conv_fprop(Manifest &manifest) { +#if !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) + manifest.append(new ConvReferenceOperation< + Provider::kReferenceHost, + cutlass::conv::Operator::kFprop, + kConvDim, + ElementA_, LayoutA_, + ElementB_, LayoutB_, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >); + + manifest.append(new ConvReferenceOperation< + Provider::kReferenceDevice, + cutlass::conv::Operator::kFprop, + kConvDim, + ElementA_, LayoutA_, + ElementB_, LayoutB_, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >); +#endif // !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) +} + +/// Constructs Dgrad and Wgrad reference operators. +template < + int kConvDim, + typename ElementA_, + typename LayoutA_, + typename ElementB_, + typename LayoutB_, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_conv_backwards(Manifest &manifest) { +#if !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) + manifest.append(new ConvReferenceOperation< + Provider::kReferenceHost, + cutlass::conv::Operator::kDgrad, + kConvDim, + ElementA_, LayoutA_, + ElementB_, LayoutB_, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >); + + manifest.append(new ConvReferenceOperation< + Provider::kReferenceDevice, + cutlass::conv::Operator::kDgrad, + kConvDim, + ElementA_, LayoutA_, + ElementB_, LayoutB_, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >); + + manifest.append(new ConvReferenceOperation< + Provider::kReferenceHost, + cutlass::conv::Operator::kWgrad, + kConvDim, + ElementA_, LayoutA_, + ElementB_, LayoutB_, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >); + + manifest.append(new ConvReferenceOperation< + Provider::kReferenceDevice, + cutlass::conv::Operator::kWgrad, + kConvDim, + ElementA_, LayoutA_, + ElementB_, LayoutB_, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >); +#endif // !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) +} + +/// Six operators for the price of one. +template < + int kConvDim, + typename ElementA_, + typename LayoutA_, + typename ElementB_, + typename LayoutB_, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_conv_all(Manifest &manifest) { + + make_conv_fprop< + kConvDim, + ElementA_, LayoutA_, + ElementB_, LayoutB_, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_conv_backwards< + kConvDim, + ElementA_, LayoutA_, + ElementB_, LayoutB_, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >(manifest); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/reference/gemm_reference_operation.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/reference/gemm_reference_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..e07158b0602eef1d71cfdca95323b3da60553747 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/reference/gemm_reference_operation.h @@ -0,0 +1,543 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines reference operations for GEMM operation kinds in CUTLASS Library +*/ + +#pragma once + +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/util.h" +#include "library_internal.h" + +#include "cutlass/util/reference/host/gemm_complex.h" +#include "cutlass/util/reference/device/gemm_complex.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + Provider Provider_, + typename ElementA_, + typename LayoutA_, + cutlass::ComplexTransform TransformA, + typename ElementB_, + typename LayoutB_, + cutlass::ComplexTransform TransformB, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +class GemmReferenceOperation : public Operation { +public: + static Provider const kProvider = Provider_; + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + static cutlass::ComplexTransform const kTransformA = TransformA; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + static cutlass::ComplexTransform const kTransformB = TransformB; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using ElementD = ElementD_; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementCompute = ElementCompute_; + using ElementAccumulator = ElementAccumulator_; + using ConvertOp = ConvertOp_; + using InnerProductOp = InnerProductOp_; + +protected: + + /// Storage for the name string + std::string name_; + + /// + GemmDescription description_; + +public: + + /// Constructor + GemmReferenceOperation() { + + // Basic information + description_.provider = kProvider; + description_.kind = OperationKind::kGemm; + description_.gemm_kind = GemmKind::kUniversal; + + // Tensor description + description_.A = make_TensorDescription(); + description_.transform_A = ComplexTransformMap::kId; + description_.B = make_TensorDescription(); + description_.transform_B = ComplexTransformMap::kId; + description_.C = make_TensorDescription(); + description_.D = make_TensorDescription(); + + // Epilogue compute and accumulator type description + description_.element_epilogue = NumericTypeMap::kId; + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + // Compute capability for gemm reference + description_.tile_description.minimum_compute_capability = + (kProvider == Provider::kReferenceDevice ? 50 : 0); + + description_.tile_description.maximum_compute_capability = 1024; + + // Procedural name + std::stringstream ss; + + ss << "gemm" + << "_reference_" << to_string(description_.provider) + << "_" << to_string(description_.A.element) << to_string(description_.A.layout) + << "_" << to_string(description_.B.element) << to_string(description_.B.layout) + << "_" << to_string(description_.C.element) << to_string(description_.C.layout) + << "_" << to_string(description_.tile_description.math_instruction.element_accumulator); + + name_ = ss.str(); + + description_.name = name_.c_str(); + + // Epilogue compute and accumulator type description + description_.element_epilogue = NumericTypeMap::kId; + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } + + virtual Status can_implement( + void const *configuration, + void const *arguments) const { + + return Status::kSuccess; + } + + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(GemmUniversalConfiguration); + } + + virtual uint64_t get_device_workspace_size( + void const *configuration, + void const *arguments = nullptr) const { + + return 0; + } + + virtual Status initialize( + void const *configuration, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + std::memcpy(host_workspace, configuration, get_host_workspace_size(configuration)); + + return Status::kSuccess; + } + + virtual Status run( + void const *arguments, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + GemmUniversalConfiguration const &config = *static_cast(host_workspace); + GemmUniversalArguments const &args = *static_cast(arguments); + + TensorRefA ref_A{static_cast(const_cast(args.A)), LayoutA(int(config.lda))}; + TensorRefB ref_B{static_cast(const_cast(args.B)), LayoutB(int(config.ldb))}; + TensorRefC ref_C{static_cast(const_cast(args.C)), LayoutC(int(config.ldc))}; + TensorRefD ref_D{static_cast(args.D), LayoutC(int(config.ldd))}; + + if (kProvider == Provider::kReferenceHost) { + + cutlass::reference::host::GemmComplex< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ElementD, + ConvertOp, + InnerProductOp + >( + config.problem_size, + *static_cast(args.alpha), + ref_A, + kTransformA, + ref_B, + kTransformB, + *static_cast(args.beta), + ref_C, + ref_D, + ElementAccumulator(), + ((config.mode == library::GemmUniversalMode::kBatched) ? config.batch_count : 1), + args.batch_stride_A, + args.batch_stride_B, + args.batch_stride_C, + args.batch_stride_D + ); + + return Status::kSuccess; + } + else if (kProvider == Provider::kReferenceDevice) { + + cutlass::reference::device::GemmComplex< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ElementD, + ConvertOp, + InnerProductOp + >( + config.problem_size, + *static_cast(args.alpha), + ref_A, + kTransformA, + ref_B, + kTransformB, + *static_cast(args.beta), + ref_C, + ref_D, + ElementAccumulator(), + ((config.mode == library::GemmUniversalMode::kBatched) ? config.batch_count : 1), + args.batch_stride_A, + args.batch_stride_B, + args.batch_stride_C, + args.batch_stride_D + ); + + return Status::kSuccess; + } + + return Status::kErrorNotSupported; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA_, + typename LayoutA_, + cutlass::ComplexTransform TransformA, + typename ElementB_, + typename LayoutB_, + cutlass::ComplexTransform TransformB, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_gemm(Manifest &manifest) { +#if !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) + manifest.append(new GemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, LayoutA_, TransformA, + ElementB_, LayoutB_, TransformB, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >); + + manifest.append(new GemmReferenceOperation< + Provider::kReferenceDevice, + ElementA_, LayoutA_, TransformA, + ElementB_, LayoutB_, TransformB, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >); +#endif +} + +/// Helper to create NN, NT, TN, and TT GEMM layouts. +template < + typename ElementA_, cutlass::ComplexTransform TransformA, + typename ElementB_, cutlass::ComplexTransform TransformB, + typename ElementC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_gemm_canonical_layouts(Manifest &manifest) { + + // M Major outputs + make_gemm< + ElementA_, cutlass::layout::ColumnMajor, TransformA, + ElementB_, cutlass::layout::ColumnMajor, TransformB, + ElementC_, cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm< + ElementA_, cutlass::layout::ColumnMajor, TransformA, + ElementB_, cutlass::layout::RowMajor, TransformB, + ElementC_, cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm< + ElementA_, cutlass::layout::RowMajor, TransformA, + ElementB_, cutlass::layout::ColumnMajor, TransformB, + ElementC_, cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm< + ElementA_, cutlass::layout::RowMajor, TransformA, + ElementB_, cutlass::layout::RowMajor, TransformB, + ElementC_, cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + // N Major outputs + make_gemm< + ElementA_, cutlass::layout::ColumnMajor, TransformA, + ElementB_, cutlass::layout::ColumnMajor, TransformB, + ElementC_, cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm< + ElementA_, cutlass::layout::ColumnMajor, TransformA, + ElementB_, cutlass::layout::RowMajor, TransformB, + ElementC_, cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm< + ElementA_, cutlass::layout::RowMajor, TransformA, + ElementB_, cutlass::layout::ColumnMajor, TransformB, + ElementC_, cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm< + ElementA_, cutlass::layout::RowMajor, TransformA, + ElementB_, cutlass::layout::RowMajor, TransformB, + ElementC_, cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); +} + + +/// Helper to create TN and interleaved layouts GEMM layouts. +template < + int InterleaveK, + typename ElementA_, + typename ElementB_, + typename ElementC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_gemm_interleaved_layouts(Manifest &manifest) { + + make_gemm< + ElementA_, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, + ElementB_, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, + ElementC_, cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + +} + +/// Helper to real-valued GEMM with canonical layouts +template < + typename ElementA_, + typename ElementB_, + typename ElementC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_gemm_real_canonical_layouts(Manifest &manifest) { + make_gemm_canonical_layouts< + ElementA_, cutlass::ComplexTransform::kNone, + ElementB_, cutlass::ComplexTransform::kNone, + ElementC_, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); +} + +// Helper to create all complex transformation permutations +template < + typename ElementA_, + typename ElementB_, + typename ElementC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_gemm_complex_canonical_layouts(Manifest &manifest) { + + make_gemm_canonical_layouts< + ElementA_, cutlass::ComplexTransform::kNone, + ElementB_, cutlass::ComplexTransform::kNone, + ElementC_, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm_canonical_layouts< + ElementA_, cutlass::ComplexTransform::kConjugate, + ElementB_, cutlass::ComplexTransform::kConjugate, + ElementC_, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm_canonical_layouts< + ElementA_, cutlass::ComplexTransform::kNone, + ElementB_, cutlass::ComplexTransform::kConjugate, + ElementC_, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm_canonical_layouts< + ElementA_, cutlass::ComplexTransform::kConjugate, + ElementB_, cutlass::ComplexTransform::kNone, + ElementC_, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/sparse_gemm_operation_3x.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/sparse_gemm_operation_3x.hpp new file mode 100644 index 0000000000000000000000000000000000000000..01caa11e229ffd9109b0973dcca01064df448fa3 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/sparse_gemm_operation_3x.hpp @@ -0,0 +1,504 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all GEMM operation kinds in CUTLASS Library. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/array.h" +#include "cutlass/array_subbyte.h" +#include "cutlass/library/library.h" +#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" // StructuredSparseCompressor +#include "cutlass/transform/device/transform_universal_adapter.hpp" // TransformUniversalAdapter +#include "cutlass/util/packed_stride.hpp" // make_cute_packed_stride +#include "gemm_operation_3x.hpp" +#include "library_internal.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/mixed_dtype_utils.hpp" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cute/tensor.hpp" +#include + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Limitation & Assumptions: +// 1. The tensor must be densely packed. That is, lda is k if the tensor is k-major, +// and lda is m if the tensor is m-major. +// 2. Circular buffer for tensorA and tensorE may have a less count compared to tensorB and others. +// This is because we can not get the problem_count information in the get_device_workspace_size(). +// But I can promise it will use at least 192MB memory if we enable circular buffer. +template +class SparseGemmUniversal3xOperation : public GemmOperation3xBase { +public: + + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = typename Operator::ElementD; + using LayoutD = typename Operator::LayoutD; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using CollectiveMainloop = typename Operator::CollectiveMainloop; + using CollectiveEpilogue = typename Operator::CollectiveEpilogue; + using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB in a GEMM kernel should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + using ElementE = typename CollectiveMainloop::ElementE; + using LayoutE = typename CollectiveMainloop::LayoutE; + using SparseConfig = typename CollectiveMainloop::SparseConfig; + using LayoutATag = decltype(SparseConfig::deduce_layoutA_tag(typename CollectiveMainloop::LayoutA{})); + using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility< + cute::Shape, + ElementA, + LayoutATag, + SparseConfig>; + using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor< + cute::Shape, + ElementA, + LayoutATag, + SparseConfig, + typename Operator::ArchTag>; + + using Compressor = cutlass::transform::device::TransformUniversalAdapter; + +public: + + /// Constructor + SparseGemmUniversal3xOperation(char const *name = "unknown_gemm"): + GemmOperation3xBase(name, GemmKind::kUniversal) {} + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, GemmUniversalConfiguration const *configuration) { + // NOTE: GemmUniversalConfiguration does not contain problem shapes or batch strides + // Do nothing here and construct kernel arguments in update_arguments_ instead + // We also cannot construct TMA descriptors without all the arguments available + + operator_args.mode = configuration->mode; + return Status::kSuccess; + } + + template + struct UpdateFusionArgs { + static Status update_(FusionArgs const& fusion_args, GemmUniversalArguments const &arguments) { + // If a custom EVT is instantiated then it is the users's responsibility + // to ensure alpha and beta are updated appropriately + return Status::kSuccess; + } + }; + + template + struct UpdateFusionArgs> { + static Status update_(FusionArgs& fusion_args, GemmUniversalArguments const &arguments) { + if (arguments.pointer_mode == ScalarPointerMode::kHost) { + fusion_args.alpha = *static_cast(arguments.alpha); + fusion_args.beta = *static_cast(arguments.beta); + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + + return Status::kSuccess; + } + else if (arguments.pointer_mode == ScalarPointerMode::kDevice) { + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = static_cast(arguments.alpha); + fusion_args.beta_ptr = static_cast(arguments.beta); + + return Status::kSuccess; + } + else { + return Status::kErrorInvalidProblem; + } + } + }; + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + GemmUniversalArguments const *arguments, + CompressorUtility const& compressor_utility, + void* device_a_compressed_ptr = nullptr, + void* device_e_ptr = nullptr) { + Status status = Status::kSuccess; + + status = UpdateFusionArgs::update_( + operator_args.epilogue.thread, *arguments); + if (status != Status::kSuccess) { + return status; + } + + operator_args.problem_shape = cute::make_shape( + arguments->problem_size.m(), + arguments->problem_size.n(), + arguments->problem_size.k(), + arguments->batch_count); + + // update arguments + + if constexpr (IsRuntimeDataType) { + using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB; + operator_args.mainloop.ptr_A = static_cast(device_a_compressed_ptr); + operator_args.mainloop.ptr_B = static_cast(arguments->B); + + std::unordered_map mapping = { + {RuntimeDatatype::kE4M3, cute::UMMA::MXF8F6F4Format::E4M3}, + {RuntimeDatatype::kE5M2, cute::UMMA::MXF8F6F4Format::E5M2}, + {RuntimeDatatype::kE3M2, cute::UMMA::MXF8F6F4Format::E3M2}, + {RuntimeDatatype::kE2M1, cute::UMMA::MXF8F6F4Format::E2M1} + }; + + auto iter_runtime_a = mapping.find(arguments->runtime_input_datatype_a); + auto iter_runtime_b = mapping.find(arguments->runtime_input_datatype_b); + + if (iter_runtime_a != mapping.end()) { + operator_args.mainloop.runtime_data_type_a = iter_runtime_a->second; + } else { + assert("invalid runtime argument for datatype A!"); + } + + if (iter_runtime_b != mapping.end()) { + operator_args.mainloop.runtime_data_type_b = iter_runtime_b->second; + } else { + assert("invalid runtime argument for datatype B!"); + } + + } + else { + operator_args.mainloop.ptr_A = static_cast(device_a_compressed_ptr); + operator_args.mainloop.ptr_B = static_cast(arguments->B); + } + operator_args.mainloop.ptr_E = static_cast(device_e_ptr); + operator_args.epilogue.ptr_C = static_cast(arguments->C); + operator_args.epilogue.ptr_D = static_cast(arguments->D); + + operator_args.mainloop.layout_a = compressor_utility.fill_layoutA_from_compressor(); + operator_args.mainloop.layout_e = compressor_utility.fill_layoutE_from_compressor(); + operator_args.mainloop.dB = cute::make_int_tuple_from( + arguments->ldb, arguments->batch_stride_B); + operator_args.epilogue.dC = cute::make_int_tuple_from( + arguments->ldc, arguments->batch_stride_C); + operator_args.epilogue.dD = operator_args.epilogue.dC; + + /* Query device SM count and max active clusters to pass onto the kernel as an argument, where needed */ + operator_args.hw_info.sm_count = arguments->sm_count; + if constexpr (!std::is_const_v) { + operator_args.scheduler.max_swizzle_size = arguments->swizzle_size; + } + + if constexpr (!std::is_const_v) { + using Enum_t = decltype(operator_args.scheduler.raster_order); + switch (arguments->raster_order) { + case RasterOrder::kAlongN: + operator_args.scheduler.raster_order = Enum_t::AlongN; + break; + case RasterOrder::kAlongM: + operator_args.scheduler.raster_order = Enum_t::AlongM; + break; + default: + operator_args.scheduler.raster_order = Enum_t::Heuristic; + } + } + + if constexpr (std::is_same_v) { + operator_args.scheduler.splits = arguments->split_k_slices; + } + + if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) { + operator_args.hw_info.cluster_shape = dim3( + arguments->cluster_shape.m(), + arguments->cluster_shape.n(), + arguments->cluster_shape.k()); + operator_args.hw_info.cluster_shape_fallback = dim3( + arguments->cluster_shape_fallback.m(), + arguments->cluster_shape_fallback.n(), + arguments->cluster_shape_fallback.k()); + } + return status; + } + +public: + + /// Returns success if the operation can proceed + Status can_implement( + void const *configuration_ptr, void const *arguments_ptr) const override { + + GemmUniversalConfiguration const *configuration = + static_cast(configuration_ptr); + GemmUniversalArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + auto problem_shape_MNKL = cute::make_shape( + configuration->problem_size.m(), + configuration->problem_size.n(), + configuration->problem_size.k(), + configuration->batch_count); + + const int M = configuration->problem_size.m(); + const int N = configuration->problem_size.n(); + const int K = configuration->problem_size.k(); + const int L = configuration->batch_count; + using StrideA = typename CompressorUtility::StrideA; + auto dA = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + compressor_utility.set_problem_size(problem_shape_MNKL, dA); + auto status = update_arguments_(args, arguments, compressor_utility); + if (status != Status::kSuccess) { + return status; + } + + // can_implement rules may need access to problem shape + args.problem_shape = problem_shape_MNKL; + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + uint64_t get_host_workspace_size(void const *) const override { + // Memory to hold operator + host_op_workspace_size = sizeof(Operator); + + // Memory to hold result of `.structure_sparse_zero_mask_fill()` + tensor_a_size = compressor_utility.get_raw_tensor_A_bytes(); + + // NOTE: order here is the order of workspace partition + const uint64_t size = host_op_workspace_size + tensor_a_size; + + return size; + } + + /// Gets the device-side workspace + uint64_t get_device_workspace_size( + void const *configuration_ptr,void const *arguments_ptr) const override { + + OperatorArguments args; + auto status = update_arguments_( + args, static_cast(arguments_ptr), compressor_utility); + if (status != Status::kSuccess) { + return 0; + } + + typename Compressor::Arguments compress_arguments { + {compressor_utility.M, 0, compressor_utility.K, compressor_utility.L}, + {/*Empty Not Use*/}, + {/*Empty Not Use*/} }; + + // Size for one iteration + // For multi-iteration, will need to multiply result of this function w/ actual problem_count + tensor_ac_size = compressor_utility.get_compressed_tensor_A_bytes(); + tensor_e_size = compressor_utility.get_tensor_E_bytes(); + device_op_workspace_size = Operator::get_workspace_size(args); + device_compress_workspace_size = Compressor::get_workspace_size(compress_arguments); + + // NOTE: order here is the order of workspace partition + device_per_iter_workspace_size = device_op_workspace_size + device_compress_workspace_size + tensor_ac_size + tensor_e_size; + + return device_per_iter_workspace_size; + } + + /// Initializes the workspace + Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const override { + return Status::kErrorInternal; + } + + Status initialize_with_profiler_workspace( + void const *configuration, + void *host_workspace, + void *device_workspace, + uint8_t **profiler_workspaces, + int problem_count_from_profiler, + cudaStream_t stream = nullptr) { + + iter_idx.resize(static_cast(configuration)->device_count, 0); + + // Set problem_count. + problem_count = problem_count_from_profiler; + + // * Host Ptr + auto* host_op_workspace_ptr = reinterpret_cast(host_workspace); + auto* host_a_raw_ptr = host_op_workspace_ptr + host_op_workspace_size; + + // * Construct Op + Operator *op = new (host_op_workspace_ptr) Operator; + + // * Device Ptr (1st iteration) + // Device workspace : | iter1 | iter2 | iter3 | .. | iterx | + // iteri : op_workspace | tensor_ac | tensor_e + auto* device_ptr_iter1 = static_cast(device_workspace); + auto* device_op_workspace_ptr_iter1 = device_ptr_iter1; + auto* device_compressor_workspace_ptr_iter1 = device_op_workspace_ptr_iter1 + device_op_workspace_size; + auto* device_a_compressed_ptr_iter1 = device_compressor_workspace_ptr_iter1 + device_compress_workspace_size; + auto* device_e_ptr_iter1 = device_a_compressed_ptr_iter1 + tensor_ac_size; + + // * Device A Raw Ptr + auto* device_a_raw_ptr = profiler_workspaces[0]; + + // * Random fill 50% of TensorA w/ zero following the structured sparse requirement + CUDA_CHECK(cudaMemcpyAsync(host_a_raw_ptr, device_a_raw_ptr, tensor_a_size, cudaMemcpyDeviceToHost, stream)); + compressor_utility.structure_sparse_zero_mask_fill(host_a_raw_ptr, 2000); + CUDA_CHECK(cudaMemcpyAsync(device_a_raw_ptr, host_a_raw_ptr, tensor_a_size, cudaMemcpyHostToDevice, stream)); + + CUDA_CHECK(cudaGetLastError()); + + // * Compress DTensorA and get DTensorAC & DTensorE + cutlass::KernelHardwareInfo hw_info; + CUDA_CHECK(cudaGetDevice(&hw_info.device_id)); + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + typename Compressor::Arguments arguments{ + {compressor_utility.M, 0, compressor_utility.K, compressor_utility.L}, + {device_a_raw_ptr, + compressor_utility.dA, + device_a_compressed_ptr_iter1, + device_e_ptr_iter1}, + {hw_info} + }; + + cutlass::Status status {cutlass::Status::kSuccess }; + + Compressor compressor_op; + status = compressor_op.can_implement(arguments); + if (status != Status::kSuccess) { + return status; + } + + status = compressor_op.initialize(arguments, device_compressor_workspace_ptr_iter1, stream); + if (status != Status::kSuccess) { + return status; + } + + status = compressor_op.run(stream); + if (status != Status::kSuccess) { + return status; + } + + // * Copy Iter1's DTensorAC DTensorE to each iteration's DTensorAC DTensorE + for (int iter_i = 1; iter_i < problem_count; iter_i++) { + // * Device AC E Ptr per iteration + // Device workspace : | iter1 | iter2 | iter3 | .. | iterx | + // iteri : op_workspace | tensor_ac | tensor_e + auto* device_ptr_iteri = static_cast(device_workspace) + device_per_iter_workspace_size * iter_i; + auto* device_op_workspace_ptr = device_ptr_iteri; + auto* device_compressor_workspace_ptr = device_op_workspace_ptr + device_op_workspace_size; + auto* device_a_compressed_ptr = device_compressor_workspace_ptr + device_compress_workspace_size; + auto* device_e_ptr = device_a_compressed_ptr + tensor_ac_size; + + CUDA_CHECK(cudaMemcpyAsync(device_a_compressed_ptr, device_a_compressed_ptr_iter1, tensor_ac_size, cudaMemcpyDeviceToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync(device_e_ptr, device_e_ptr_iter1, tensor_e_size, cudaMemcpyDeviceToDevice, stream)); + } + + CUDA_CHECK(cudaStreamSynchronize(stream)); + + CUDA_CHECK(cudaGetLastError()); + + return Status::kSuccess; + } + + /// Runs the kernel + Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const override { + + OperatorArguments operator_args; + + + const auto device_index = static_cast(arguments_ptr)->device_index; + + auto* device_ptr_iteri = static_cast(device_workspace) + device_per_iter_workspace_size * iter_idx[device_index]; + auto* device_op_workspace_ptr = device_ptr_iteri; + auto* device_compressor_workspace_ptr = device_op_workspace_ptr + device_op_workspace_size; + auto* device_a_compressed_ptr = device_compressor_workspace_ptr + device_compress_workspace_size; + auto* device_e_ptr = device_a_compressed_ptr + tensor_ac_size; + iter_idx[device_index] = (iter_idx[device_index] + 1) % problem_count; + + Status status = update_arguments_(operator_args, static_cast(arguments_ptr), compressor_utility, device_a_compressed_ptr, device_e_ptr ); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + // We need to call initialize() since we have to rebuild TMA desc for every new set of args + status = op->run(operator_args, device_op_workspace_ptr, stream, nullptr, + static_cast(arguments_ptr)->use_pdl); + return status; + } + +private: + // Variables that must change in the const functions. + mutable CompressorUtility compressor_utility; + mutable int problem_count = 1; + mutable std::vector iter_idx; + + mutable uint64_t tensor_ac_size = 0; + mutable uint64_t tensor_e_size = 0; + mutable uint64_t tensor_a_size = 0; + mutable uint64_t host_op_workspace_size = 0; + mutable uint64_t device_compress_workspace_size = 0; + mutable uint64_t device_op_workspace_size = 0; + mutable uint64_t device_per_iter_workspace_size = 0; +}; +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::library + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/symm_operation.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/symm_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..c95d238a81f825dbbeae689ec452467cc8ca3afa --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/symm_operation.h @@ -0,0 +1,382 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all Symm operation kinds (Symm, Hemm) + in CUTLASS Library. + + +*/ + +#pragma once +#include +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/symm.h" +#include "cutlass/gemm/kernel/default_symm_universal.h" + +#include "cutlass/library/library.h" +#include "library_internal.h" +#include "cutlass/core_io.h" +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class SymmOperationBase : public Operation { +public: + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + static BlasMode const kBlasMode = Operator::kBlasMode; + static SideMode const kSideModeA = Operator::kSideModeA; + static FillMode const kFillModeA = Operator::kFillModeA; + + using OperatorArguments = typename Operator::Arguments; + +protected: + + /// + SymmDescription description_; + +public: + + /// Constructor + SymmOperationBase(char const *name = "unknown_symm") { + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.symm_kind = SymmKind::kUniversal; + description_.side_mode = kSideModeA; + description_.fill_mode = kFillModeA; + description_.blas_mode = kBlasMode; + + description_.kind = OperationKind::kSymm; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::SymmKernel::WarpCount::kM, + Operator::SymmKernel::WarpCount::kN, + Operator::SymmKernel::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(Operator::kAlignmentA); + description_.B = make_TensorDescription(Operator::kAlignmentB); + description_.C = make_TensorDescription(Operator::kAlignmentC); + description_.element_epilogue = NumericTypeMap::kId; + + description_.split_k_mode = SplitKMode::kNone; + } + + /// Returns the description of the SYMM operation + virtual OperationDescription const & description() const { + return description_; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class SymmOperation : public SymmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + static BlasMode const kBlasMode = Operator::kBlasMode; + static SideMode const kSideModeA = Operator::kSideModeA; + static FillMode const kFillModeA = Operator::kFillModeA; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + SymmOperation(char const *name = "unknown_symm"): + SymmOperationBase(name) { + + this->description_.symm_kind = SymmKind::kUniversal; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + SymmConfiguration const *configuration) { + + //operator_args.mode = configuration->mode; + + operator_args.problem_size = configuration->problem_size; + operator_args.batch_count = configuration->batch_count; + + operator_args.lda = int(configuration->lda); + operator_args.ldb = int(configuration->ldb); + operator_args.ldc = int(configuration->ldc); + operator_args.ldd = int(configuration->ldd); + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + SymmArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + // update arguments + operator_args.ptr_A = arguments->A; + operator_args.ptr_B = arguments->B; + operator_args.ptr_C = arguments->C; + operator_args.ptr_D = arguments->D; + + operator_args.batch_stride_A = arguments->batch_stride_A; + operator_args.batch_stride_B = arguments->batch_stride_B; + operator_args.batch_stride_C = arguments->batch_stride_C; + operator_args.batch_stride_D = arguments->batch_stride_D; + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + SymmConfiguration const *configuration = + static_cast(configuration_ptr); + + SymmArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + + return size; + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + //std::cout << "initialize() library::SymmOperation" << std::endl; + //print_operator_args(args); + status = op->initialize(args, device_workspace, stream); + + return status; + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + bool need_swapped_matrices = (kSideModeA == SideMode::kLeft && + std::is_same::value) || + (kSideModeA == SideMode::kRight && + std::is_same::value); + if (need_swapped_matrices) { + status = op->update(args.swapped_matrices(), device_workspace); + } else { + status = op->update(args, device_workspace); + } + + if (status != Status::kSuccess) { + return status; + } + + //std::cout << "run() library::SymmOperation" << std::endl; + //print_operator_args(args); + status = op->run(stream); + + return status; + } + + /// Call print_operator_args from the Conv2dOperation::initialize() + // to dump arguments passed on to cutlass operator for debugging + void print_operator_args(OperatorArguments &operator_args) const { + std::cout << "SymmOperation::OperatorArguments" << std::endl + << " problem_size:" << std::endl + << operator_args.problem_size << std::endl + << " epilogue (alpha, beta): " + << operator_args.epilogue.alpha << ", " + << operator_args.epilogue.beta << std::endl + << " ref_A (ptr, {stride}): " + << operator_args.ptr_A << ", {" + << operator_args.lda << "}" << std::endl + << " ref_B (ptr, {stride}): " + << operator_args.ptr_B << ", {" + << operator_args.ldb << "}" << std::endl + << " ref_C (ptr, {stride}): " + << operator_args.ptr_C << ", {" + << operator_args.ldc << "}" << std::endl + << " ref_D (ptr, {stride}): " + << operator_args.ptr_D << ", {" + << operator_args.ldd << "}" << std::endl; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/trmm_operation.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/trmm_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..d419723791ace5d90eb7955223be9db72bbc2c3c --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/library/src/trmm_operation.h @@ -0,0 +1,350 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all TRMM operation kinds in CUTLASS Library. + + +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/trmm.h" +#include "cutlass/gemm/kernel/default_trmm_universal.h" +#include "cutlass/gemm/kernel/trmm_universal.h" + +#include "cutlass/library/library.h" +#include "library_internal.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class TrmmOperationBase : public Operation { +public: + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + static SideMode const kSideMode = Operator::kSideMode; + static FillMode const kFillMode = Operator::kFillMode; + static DiagType const kDiagType = Operator::kDiagType; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +protected: + + /// + TrmmDescription description_; + +public: + + /// Constructor + TrmmOperationBase(char const *name = "unknown_trmm") { + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.kind = OperationKind::kTrmm; + description_.trmm_kind = TrmmKind::kUniversal; + description_.side_mode = kSideMode; + description_.fill_mode = kFillMode; + description_.diag_type = kDiagType; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::TrmmKernel::WarpCount::kM, + Operator::TrmmKernel::WarpCount::kN, + Operator::TrmmKernel::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(Operator::kAlignmentA); + description_.B = make_TensorDescription(Operator::kAlignmentB); + description_.D = make_TensorDescription(Operator::kAlignmentC); + description_.element_epilogue = NumericTypeMap::kId; + + description_.split_k_mode = SplitKMode::kNone; + description_.transform_A = ComplexTransformMap::kId; + } + + /// Returns the description of the TRMM operation + virtual OperationDescription const & description() const { + return description_; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class TrmmOperation : public TrmmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + static SideMode const kSideMode = Operator::kSideMode; + static FillMode const kFillMode = Operator::kFillMode; + static DiagType const kDiagType = Operator::kDiagType; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + TrmmOperation(char const *name = "unknown_trmm"): + TrmmOperationBase(name) { + + this->description_.trmm_kind = TrmmKind::kUniversal; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + TrmmConfiguration const *configuration) { + + //operator_args.mode = configuration->mode; + + operator_args.problem_size = configuration->problem_size; + operator_args.batch_count = configuration->batch_count; + + operator_args.lda = int(configuration->lda); + operator_args.ldb = int(configuration->ldb); + operator_args.ldd = int(configuration->ldd); + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + TrmmArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + // update arguments + operator_args.ptr_A = arguments->A; + operator_args.ptr_B = arguments->B; + operator_args.batch_stride_A = arguments->batch_stride_A; + operator_args.batch_stride_B = arguments->batch_stride_B; + operator_args.ptr_D = arguments->D; + operator_args.batch_stride_D = arguments->batch_stride_D; + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + TrmmConfiguration const *configuration = + static_cast(configuration_ptr); + + TrmmArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + + return size; + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + status = op->initialize(args, device_workspace, stream); + + return status; + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + bool need_swapped_matrices = (kSideMode == SideMode::kLeft && + std::is_same::value) || + (kSideMode == SideMode::kRight && + std::is_same::value); + if (need_swapped_matrices) { + status = op->update(args.swapped_matrices(), device_workspace); + } else { + status = op->update(args, device_workspace); + } + + if (status != Status::kSuccess) { + return status; + } + + status = op->run(stream); + + return status; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/block_scaled_gemm_operation_profiler.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/block_scaled_gemm_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..5d500d9149bf645eadf8110d98612c40882d742c --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/block_scaled_gemm_operation_profiler.h @@ -0,0 +1,330 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Blockscale Gemm Profiler +*/ + + + +#pragma once + +#include +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" +#include "reduction_operation_profiler.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class BlockScaledGemmOperationProfiler : public OperationProfiler { +public: + + /// Problem structure obtained from problem space + struct GemmProblem { + + cutlass::library::GemmUniversalMode mode{library::GemmUniversalMode::kGemm}; + + /// For profiling purposes + std::vector problem_sizes; + std::vector> leading_dims; + std::vector> preferred_clusters; + std::vector> fallback_clusters; + std::vector raster_orders; + std::vector swizzle_sizes; + + int64_t m{16}; + int64_t n{16}; + int64_t k{16}; + + + int cluster_m{1}; + int cluster_n{1}; + int cluster_k{1}; + int cluster_m_fallback{1}; + int cluster_n_fallback{1}; + int cluster_k_fallback{1}; + + + int64_t lda{0}; + int64_t ldb{0}; + int64_t ldc{0}; + std::vector alpha; + std::vector beta; + + cutlass::library::SplitKMode split_k_mode{library::SplitKMode::kNone}; + int split_k_slices{1}; + int batch_count{1}; + + cutlass::library::RasterOrder raster_order{cutlass::library::RasterOrder::kHeuristic}; + int swizzle_size{1}; + cutlass::library::RuntimeDatatype runtime_input_datatype_a{}; + cutlass::library::RuntimeDatatype runtime_input_datatype_b{}; + + + // gemm with parallel interleaved reduction + // gemm epilogue (alpha, beta) = (1.0, 0.0) + // reduction epilogue (alpha, beta) = (GemmProblem::alpha, GemmProblem::beta) + std::vector alpha_one; + std::vector beta_zero; + + bool use_pdl{false}; + // + // Methods + // + + /// Parses the problem + Status parse( + library::BlockScaledGemmDescription const &operation_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + int64_t bytes_with_problem_shape( + library::BlockScaledGemmDescription const &operation_desc, + gemm::GemmCoord const &problem_shape) const; + + int64_t flops_with_problem_shape( + library::BlockScaledGemmDescription const &operation_desc, + gemm::GemmCoord const &problem_shape) const; + + /// Total number of bytes loaded + int64_t bytes(library::BlockScaledGemmDescription const &operation_desc) const; + + /// Total number of flops computed + int64_t flops(library::BlockScaledGemmDescription const &operation_desc) const; + + /// Initializes a performance result + void initialize_result( + PerformanceResult &result, + library::BlockScaledGemmDescription const &operation_desc, + ProblemSpace const &problem_space); + }; + + /// Workspace used + struct GemmWorkspace { + + DeviceAllocation *A{nullptr}; + DeviceAllocation *SFA{nullptr}; + DeviceAllocation *B{nullptr}; + DeviceAllocation *SFB{nullptr}; + DeviceAllocation *C{nullptr}; + DeviceAllocation *Computed{nullptr}; + DeviceAllocation *Reference{nullptr}; + DeviceAllocation *Computed_SFD{nullptr}; + DeviceAllocation *Reference_SFD{nullptr}; + DeviceAllocation *Norm_constant{nullptr}; + + /// Number of copies of the problem workspace which are visited sequentially during + /// profiling to avoid camping in the last level cache. + int problem_count{1}; + + library::GemmUniversalConfiguration configuration; + library::BlockScaledGemmArguments arguments; + + /// Buffer used for the operation's host workspace + std::vector host_workspace; + + /// Buffer used for the operations' device workspace + DeviceAllocation device_workspace; + + /// Library configuration and arguments for reduction operator + library::ReductionConfiguration reduction_configuration; + library::ReductionArguments reduction_arguments; + + /// Buffer used for the cutlass reduction operations' host workspace + std::vector reduction_host_workspace; + + cudaStream_t stream; + }; + +protected: + + // + // Data members + // + + /// GEMM problem obtained from problem space + GemmProblem problem_; + + /// Device memory allocations + GemmWorkspace gemm_workspace_; + + /// CUTLASS parallel reduction operation to follow this* gemm operation + library::Operation const *reduction_op_; + +public: + // + // Methods + // + + /// Ctor + BlockScaledGemmOperationProfiler(Options const &options); + + /// Destructor + virtual ~BlockScaledGemmOperationProfiler(); + + GemmProblem const& problem() const { return problem_; } + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + + /// Update workspace configuration according to flexible user setups + void update_workspace_( + GemmWorkspace &gemm_workspace, + gemm::GemmCoord const &problem_shape, + std::array const &leading_dim, + std::array const &preferred_cluster, + std::array const &fallback_cluster, + cutlass::library::RasterOrder const &raster_order, + int swizzle_size, + bool is_dynamic_cluster_enabled); + + /// Update performance result configuration according to flexible user setups + void update_result_( + PerformanceResult &result, + library::BlockScaledGemmDescription const &operation_desc, + ProblemSpace const &problem_space, + gemm::GemmCoord const &problem_shape, + cutlass::library::RasterOrder const &raster_order, + std::array const &preferred_cluster, + std::array const &fallback_cluster, + int swizzle_size, + bool is_dynamic_cluster_enabled); + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::BlockScaledGemmDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Verifies CUTLASS against references + bool verify_with_cublas_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against host and device references + bool verify_with_reference_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem, + cutlass::library::NumericTypeID element_A, + cutlass::library::NumericTypeID element_B); + + /// Method to profile a CUTLASS Operation + Status profile_cutlass_( + PerformanceResult &result, + Options const &options, + library::Operation const *operation, + void *arguments, + void *host_workspace, + void *device_workspace); + + /// Initialize reduction problem dimensions and library::Operation + bool initialize_reduction_configuration_( + library::Operation const *operation, + ProblemSpace::Problem const &problem); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/blockwise_gemm_operation_profiler.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/blockwise_gemm_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..c110de278cac640c1cedd8dd29d1b8ac09de81ef --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/blockwise_gemm_operation_profiler.h @@ -0,0 +1,305 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Blockscale Gemm Profiler +*/ + + + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" +#include "reduction_operation_profiler.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class BlockwiseGemmOperationProfiler : public OperationProfiler { +public: + + /// Problem structure obtained from problem space + struct GemmProblem { + + cutlass::library::GemmUniversalMode mode{library::GemmUniversalMode::kGemm}; + + int64_t m{16}; + int64_t n{16}; + int64_t k{16}; + + int64_t sf_vec_m{0}; + int64_t sf_vec_n{0}; + int64_t sf_vec_k{0}; + + int cluster_m{1}; + int cluster_n{1}; + int cluster_k{1}; + int cluster_m_fallback{1}; + int cluster_n_fallback{1}; + int cluster_k_fallback{1}; + + + int64_t lda{0}; + int64_t ldb{0}; + int64_t ldc{0}; + std::vector alpha; + std::vector beta; + + cutlass::library::SplitKMode split_k_mode{library::SplitKMode::kNone}; + int split_k_slices{1}; + int batch_count{1}; + + cutlass::library::RasterOrder raster_order{cutlass::library::RasterOrder::kHeuristic}; + int swizzle_size{1}; + + /// For profiling purposes + std::vector problem_sizes; + std::vector> leading_dims; + std::vector> preferred_clusters; + std::vector> fallback_clusters; + std::vector raster_orders; + std::vector swizzle_sizes; + + cutlass::library::RuntimeDatatype runtime_input_datatype_a{}; + cutlass::library::RuntimeDatatype runtime_input_datatype_b{}; + + + // gemm with parallel interleaved reduction + // gemm epilogue (alpha, beta) = (1.0, 0.0) + // reduction epilogue (alpha, beta) = (GemmProblem::alpha, GemmProblem::beta) + std::vector alpha_one; + std::vector beta_zero; + + bool use_pdl{false}; + // + // Methods + // + + /// Parses the problem + Status parse( + library::BlockwiseGemmDescription const &operation_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + int64_t bytes_with_problem_shape( + library::BlockwiseGemmDescription const &operation_desc, + gemm::GemmCoord const &problem_shape) const; + + int64_t flops_with_problem_shape( + library::BlockwiseGemmDescription const &operation_desc, + gemm::GemmCoord const &problem_shape) const; + + /// Total number of bytes loaded + int64_t bytes(library::BlockwiseGemmDescription const &operation_desc) const; + + /// Total number of flops computed + int64_t flops(library::BlockwiseGemmDescription const &operation_desc) const; + + /// Initializes a performance result + void initialize_result( + PerformanceResult &result, + library::BlockwiseGemmDescription const &operation_desc, + ProblemSpace const &problem_space); + }; + + /// Workspace used + struct GemmWorkspace { + + DeviceAllocation *A{nullptr}; + DeviceAllocation *SFA{nullptr}; + DeviceAllocation *B{nullptr}; + DeviceAllocation *SFB{nullptr}; + DeviceAllocation *C{nullptr}; + DeviceAllocation *Computed{nullptr}; + DeviceAllocation *Reference{nullptr}; + + /// Number of copies of the problem workspace which are visited sequentially during + /// profiling to avoid camping in the last level cache. + int problem_count{1}; + + library::GemmUniversalConfiguration configuration; + library::BlockwiseGemmArguments arguments; + + /// Buffer used for the operation's host workspace + std::vector host_workspace; + + /// Buffer used for the operations' device workspace + DeviceAllocation device_workspace; + + /// Library configuration and arguments for reduction operator + library::ReductionConfiguration reduction_configuration; + library::ReductionArguments reduction_arguments; + + /// Buffer used for the cutlass reduction operations' host workspace + std::vector reduction_host_workspace; + }; + +protected: + + // + // Data members + // + + /// GEMM problem obtained from problem space + GemmProblem problem_; + + /// Device memory allocations + GemmWorkspace gemm_workspace_; + + /// CUTLASS parallel reduction operation to follow this* gemm operation + library::Operation const *reduction_op_; + +public: + // + // Methods + // + + /// Ctor + BlockwiseGemmOperationProfiler(Options const &options); + + /// Destructor + virtual ~BlockwiseGemmOperationProfiler(); + + GemmProblem const& problem() const { return problem_; } + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::BlockwiseGemmDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Verifies CUTLASS against references + bool verify_with_cublas_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against host and device references + bool verify_with_reference_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem, + cutlass::library::NumericTypeID element_A, + cutlass::library::NumericTypeID element_B); + + /// Method to profile a CUTLASS Operation + Status profile_cutlass_( + PerformanceResult &result, + Options const &options, + library::Operation const *operation, + void *arguments, + void *host_workspace, + void *device_workspace); + + /// Initialize reduction problem dimensions and library::Operation + bool initialize_reduction_configuration_( + library::Operation const *operation, + ProblemSpace::Problem const &problem); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/conv2d_operation_profiler.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/conv2d_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..683465f50cda19c8d505f2e66bcb60173d7e942d --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/conv2d_operation_profiler.h @@ -0,0 +1,495 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines profiling functionality for convolution + +*/ + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/handle.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/singleton.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" +#include "reduction_operation_profiler.h" +#if CUTLASS_ENABLE_CUDNN +#include "cudnn_helpers.h" +#endif //#if CUTLASS_ENABLE_CUDNN +#include "debug.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class Conv2dOperationProfiler : public OperationProfiler { +public: + + /// Problem structure obtained from problem space + struct Conv2dProblem { + + int64_t n, h, w, c, p, q, k, r, s; + int64_t groups; + int64_t pad_h, pad_w; + int64_t stride_h, stride_w; + int64_t dilation_h, dilation_w; + + std::vector alpha; + std::vector beta; + + library::SplitKMode split_k_mode; + int64_t split_k_slices; + + library::ConvModeID conv_mode; + + library::Provider eq_gemm_provider; + + // convolution with parallel interleaved reduction + // convolution epilogue (alpha, beta) = (1.0, 0.0) + // reduction epilogue (alpha, beta) = (Conv2dProblem::alpha, Conv2dProblem::beta) + std::vector alpha_one; + std::vector beta_zero; + + // + // Methods + // + + /// Total number of bytes loaded + int64_t bytes(library::ConvDescription const &operation_desc) const; + + /// Total number of flops computed + int64_t flops(library::ConvDescription const &operation_desc) const; + + void set_default_output_size() { + p = ((h + pad_h - r * dilation_h) / stride_h) + 1; + q = ((w + pad_w - s * dilation_w) / stride_w) + 1; + } + + // Returns equivalent gemm problem size for convolution + cutlass::gemm::GemmCoord eq_gemm_size(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return cutlass::gemm::GemmCoord(int(n * p * q), int(k), int(r * s * c / groups)); + case library::ConvKind::kDgrad: return cutlass::gemm::GemmCoord(int(n * h * w), int(c), int(k * r * s)); + case library::ConvKind::kWgrad: return cutlass::gemm::GemmCoord(int(k), int(r * s * c), int(n * p * q)); + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns extent for tensor A + std::vector extent_a(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return {int(n), int(h), int(w), int(c)}; + case library::ConvKind::kDgrad: return {int(n), int(p), int(q), int(k)}; + case library::ConvKind::kWgrad: return {int(n), int(p), int(q), int(k)}; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns extent for tensor B + std::vector extent_b(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return {int(k), int(r), int(s), int(c / groups)}; + case library::ConvKind::kDgrad: return {int(k), int(r), int(s), int(c)}; + case library::ConvKind::kWgrad: return {int(n), int(h), int(w), int(c)}; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns extent for tensor C + std::vector extent_c(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return {int(n), int(p), int(q), int(k)}; + case library::ConvKind::kDgrad: return {int(n), int(h), int(w), int(c)}; + case library::ConvKind::kWgrad: return {int(k), int(r), int(s), int(c)}; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns layout for equivalent gemm matrix A + library::LayoutTypeID eq_gemm_layout_a(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return library::LayoutTypeID::kRowMajor; // TN Gemm + case library::ConvKind::kDgrad: return library::LayoutTypeID::kRowMajor; // TT Gemm + case library::ConvKind::kWgrad: return library::LayoutTypeID::kColumnMajor; // NT Gemm + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns layout for equivalent gemm matrix B + library::LayoutTypeID eq_gemm_layout_b(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return library::LayoutTypeID::kColumnMajor; // TN Gemm + case library::ConvKind::kDgrad: return library::LayoutTypeID::kRowMajor; // TT Gemm + case library::ConvKind::kWgrad: return library::LayoutTypeID::kRowMajor; // NT Gemm + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns layout for equivalent gemm matrix C + library::LayoutTypeID eq_gemm_layout_c(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + // Gemm operator assumes column-major output + case library::ConvKind::kFprop: + case library::ConvKind::kDgrad: + case library::ConvKind::kWgrad: return library::LayoutTypeID::kColumnMajor; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns leading dimension for equivalent gemm matrix A + int64_t eq_gemm_lda(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return eq_gemm_size(conv_kind).k(); + case library::ConvKind::kDgrad: return eq_gemm_size(conv_kind).k(); + case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).m(); + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns leading dimension for equivalent gemm matrix B + int64_t eq_gemm_ldb(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return eq_gemm_size(conv_kind).k(); + case library::ConvKind::kDgrad: return eq_gemm_size(conv_kind).n(); + case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).n(); + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns leading dimension for equivalent gemm matrix C + int64_t eq_gemm_ldc(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: + case library::ConvKind::kDgrad: + case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).m(); + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + }; + + /// Workspace used + struct Conv2dWorkspace { + + /// Conv device allocations + DeviceAllocation *A; + DeviceAllocation *B; + DeviceAllocation *reordered_B; + DeviceAllocation *C; + DeviceAllocation *Computed; + DeviceAllocation *Reference; + + /// Library configuration and arguments for convolution operator + library::Conv2dConfiguration configuration; + library::ConvArguments arguments; + + /// Number of copies of the problem workspace which are visited sequentially during + /// profiling to avoid camping in the last level cache. + int problem_count; + + /// Buffer used for the cutlass conv2d operations' host workspace + std::vector host_workspace; + + /// Buffer used for the cutlass operations' device workspace + DeviceAllocation device_workspace; + + /// Library configuration and arguments for reduction operator + library::ReductionConfiguration reduction_configuration; + library::ReductionArguments reduction_arguments; + + /// Buffer used for the cutlass reduction operations' host workspace + std::vector reduction_host_workspace; + + /// Host data buffers for host reference operation + /// host buffer for tensor + std::vector host_tensor_a; + + /// host buffer for tensor b + std::vector host_tensor_b; + + /// host buffer for tensor c + std::vector host_tensor_c; + + // + // Methods + // + + Conv2dWorkspace() + : A(nullptr), + B(nullptr), + reordered_B(nullptr), + C(nullptr), + Computed(nullptr), + Reference(nullptr) {} + + // Set stride vector for tensor activations, filters, output + void set_stride_vector(Conv2dProblem const &problem, + library::ConvKind const &conv_kind, + library::LayoutTypeID const &layout_a, + library::LayoutTypeID const &layout_b, + library::LayoutTypeID const &layout_c) { + std::vector stride_activations; + std::vector stride_filters; + std::vector stride_output; + + // Strides for interleaved fprop + if (conv_kind == library::ConvKind::kFprop && + ((layout_a == library::LayoutTypeID::kTensorNC32HW32 && + layout_b == library::LayoutTypeID::kTensorC32RSK32 && + layout_c == library::LayoutTypeID::kTensorNC32HW32) || + (layout_a == library::LayoutTypeID::kTensorNC64HW64 && + layout_b == library::LayoutTypeID::kTensorC64RSK64 && + layout_c == library::LayoutTypeID::kTensorNC64HW64))) { + int interleave = + (layout_a == library::LayoutTypeID::kTensorNC32HW32) ? 32 : 64; + + stride_activations.push_back(int(problem.w) * interleave); + stride_activations.push_back(int(problem.w) * int(problem.h) * + interleave); + stride_activations.push_back(int(problem.h) * int(problem.w) * + int(problem.c)); + + stride_filters.push_back(int(problem.k) * interleave); + stride_filters.push_back(int(problem.k) * int(problem.s) * interleave); + stride_filters.push_back(int(problem.k) * int(problem.s) * + int(problem.r) * interleave); + + stride_output.push_back(int(problem.q) * interleave); + stride_output.push_back(int(problem.q) * int(problem.p) * interleave); + stride_output.push_back(int(problem.q) * int(problem.p) * + int(problem.k)); + } else { + // Strides for the rest cases + stride_activations.push_back(int(problem.c)); + stride_activations.push_back(int(problem.w) * int(problem.c)); + stride_activations.push_back(int(problem.h) * int(problem.w) * + int(problem.c)); + + stride_filters.push_back(int(problem.c / problem.groups)); + stride_filters.push_back(int(problem.s) * int(problem.c / problem.groups)); + stride_filters.push_back(int(problem.r) * int(problem.s) * + int(problem.c / problem.groups)); + + stride_output.push_back(int(problem.k)); + stride_output.push_back(int(problem.q) * int(problem.k)); + stride_output.push_back(int(problem.q) * int(problem.p) * + int(problem.k)); + } + + switch (conv_kind) { + case library::ConvKind::kFprop: + configuration.stride_a = stride_activations; + configuration.stride_b = stride_filters; + configuration.stride_c = stride_output; + + break; + case library::ConvKind::kDgrad: + configuration.stride_a = stride_output; + configuration.stride_b = stride_filters; + configuration.stride_c = stride_activations; + + break; + case library::ConvKind::kWgrad: + configuration.stride_a = stride_output; + configuration.stride_b = stride_activations; + configuration.stride_c = stride_filters; + + break; + default: + throw std::runtime_error( + "Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + }; + +protected: + + // + // Data members + // + + /// CONV problem obtained from problem space + Conv2dProblem problem_; + + /// Device memory allocations + Conv2dWorkspace conv_workspace_; + + /// CUTLASS parallel reduction operation to follow this* conv2d operation + library::Operation const *reduction_op_; + +public: + // + // Methods + // + + /// Ctor + Conv2dOperationProfiler(Options const &options); + + /// Destructor + virtual ~Conv2dOperationProfiler(); + + Conv2dProblem const& problem() const { return problem_; } + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + /// Method to profile an initialized CUTLASS operation + virtual Status profile_cutlass_( + PerformanceResult &result, + Options const &options, + library::Operation const *operation, + void *arguments, + void *host_workspace, + void *device_workspace); + + + /// Initialize reduction problem dimensions and library::Operation + bool initialize_reduction_configuration_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::ConvDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Verifies CUTLASS against host reference + bool verify_with_host_reference_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against device reference + bool verify_with_device_reference_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +#if CUTLASS_ENABLE_CUDNN + + /// Verifies CUTLASS against cudnn reference + bool verify_with_cudnn_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +#endif //#if CUTLASS_ENABLE_CUDNN + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/conv3d_operation_profiler.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/conv3d_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..ac4abdef238b00f216053419620a60dfccfd5316 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/conv3d_operation_profiler.h @@ -0,0 +1,449 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines profiling functionality for convolution + +*/ + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/handle.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/singleton.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" +#include "reduction_operation_profiler.h" +#if CUTLASS_ENABLE_CUDNN +#include "cudnn_helpers.h" +#endif //#if CUTLASS_ENABLE_CUDNN +#include "debug.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class Conv3dOperationProfiler : public OperationProfiler { +public: + + /// Problem structure obtained from problem space + struct Conv3dProblem { + + int64_t n, d, h, w, c, z, p, q, k, t, r, s; + int64_t pad_d, pad_h, pad_w; + int64_t stride_d, stride_h, stride_w; + int64_t dilation_d, dilation_h, dilation_w; + + std::vector alpha; + std::vector beta; + + library::SplitKMode split_k_mode; + int64_t split_k_slices; + + library::ConvModeID conv_mode; + + library::Provider eq_gemm_provider; + + // convolution with parallel interleaved reduction + // convolution epilogue (alpha, beta) = (1.0, 0.0) + // reduction epilogue (alpha, beta) = (Conv3dProblem::alpha, Conv3dProblem::beta) + std::vector alpha_one; + std::vector beta_zero; + + // + // Methods + // + + /// Total number of bytes loaded + int64_t bytes(library::ConvDescription const &operation_desc) const; + + /// Total number of flops computed + int64_t flops(library::ConvDescription const &operation_desc) const; + + /// Infers output size from the input size, padding, stride, and dilation + void set_default_output_size() { + z = ((d + pad_d - t * dilation_d) / stride_d) + 1; + p = ((h + pad_h - r * dilation_h) / stride_h) + 1; + q = ((w + pad_w - s * dilation_w) / stride_w) + 1; + } + + // Returns equivalent gemm problem size for convolution + cutlass::gemm::GemmCoord eq_gemm_size(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return cutlass::gemm::GemmCoord(int(n * z * p * q), int(k), int(t * r * s * c)); + case library::ConvKind::kDgrad: return cutlass::gemm::GemmCoord(int(n * d * h * w), int(c), int(t * r * s * k)); + case library::ConvKind::kWgrad: return cutlass::gemm::GemmCoord(int(k), int(t * r * s * c), int(n * z * p * q)); + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns extent for tensor A + std::vector extent_a(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return {int(n), int(d), int(h), int(w), int(c)}; + case library::ConvKind::kDgrad: return {int(n), int(z), int(p), int(q), int(k)}; + case library::ConvKind::kWgrad: return {int(n), int(z), int(p), int(q), int(k)}; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns extent for tensor B + std::vector extent_b(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return {int(k), int(t), int(r), int(s), int(c)}; + case library::ConvKind::kDgrad: return {int(k), int(t), int(r), int(s), int(c)}; + case library::ConvKind::kWgrad: return {int(n), int(d), int(h), int(w), int(c)}; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns extent for tensor C + std::vector extent_c(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return {int(n), int(z), int(p), int(q), int(k)}; + case library::ConvKind::kDgrad: return {int(n), int(d), int(h), int(w), int(c)}; + case library::ConvKind::kWgrad: return {int(k), int(t), int(r), int(s), int(c)}; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns layout for equivalent gemm matrix A + library::LayoutTypeID eq_gemm_layout_a(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return library::LayoutTypeID::kRowMajor; // TN Gemm + case library::ConvKind::kDgrad: return library::LayoutTypeID::kRowMajor; // TT Gemm + case library::ConvKind::kWgrad: return library::LayoutTypeID::kColumnMajor; // NT Gemm + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns layout for equivalent gemm matrix B + library::LayoutTypeID eq_gemm_layout_b(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return library::LayoutTypeID::kColumnMajor; // TN Gemm + case library::ConvKind::kDgrad: return library::LayoutTypeID::kRowMajor; // TT Gemm + case library::ConvKind::kWgrad: return library::LayoutTypeID::kRowMajor; // NT Gemm + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns layout for equivalent gemm matrix C + library::LayoutTypeID eq_gemm_layout_c(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + // Gemm operator assumes column-major output + case library::ConvKind::kFprop: + case library::ConvKind::kDgrad: + case library::ConvKind::kWgrad: return library::LayoutTypeID::kColumnMajor; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns leading dimension for equivalent gemm matrix A + int64_t eq_gemm_lda(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return eq_gemm_size(conv_kind).k(); + case library::ConvKind::kDgrad: return eq_gemm_size(conv_kind).k(); + case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).m(); + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns leading dimension for equivalent gemm matrix B + int64_t eq_gemm_ldb(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return eq_gemm_size(conv_kind).k(); + case library::ConvKind::kDgrad: return eq_gemm_size(conv_kind).n(); + case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).n(); + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns leading dimension for equivalent gemm matrix C + int64_t eq_gemm_ldc(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: + case library::ConvKind::kDgrad: + case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).m(); + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + }; + + /// Workspace used + struct Conv2dWorkspace { + + /// Conv device allocations + DeviceAllocation *A; + DeviceAllocation *B; + DeviceAllocation *C; + DeviceAllocation *Computed; + DeviceAllocation *Reference; + + /// Library configuration and arguments for convolution operator + library::Conv3dConfiguration configuration; + library::ConvArguments arguments; + + /// Number of copies of the problem workspace which are visited sequentially during + /// profiling to avoid camping in the last level cache. + int problem_count; + + /// Buffer used for the cutlass conv2d operations' host workspace + std::vector host_workspace; + + /// Buffer used for the cutlass operations' device workspace + DeviceAllocation device_workspace; + + /// Library configuration and arguments for reduction operator + library::ReductionConfiguration reduction_configuration; + library::ReductionArguments reduction_arguments; + + /// Buffer used for the cutlass reduction operations' host workspace + std::vector reduction_host_workspace; + + /// Host data buffers for host reference operation + /// host buffer for tensor + std::vector host_tensor_a; + + /// host buffer for tensor b + std::vector host_tensor_b; + + /// host buffer for tensor c + std::vector host_tensor_c; + + + // + // Methods + // + + Conv2dWorkspace(): + A(nullptr), B(nullptr), C(nullptr), Computed(nullptr), Reference(nullptr) { } + + // Returns stride vector for tensor A + std::vector stride_a(library::ConvKind const &conv_kind) { + return { + configuration.layout_a(conv_kind).stride()[0], + configuration.layout_a(conv_kind).stride()[1], + configuration.layout_a(conv_kind).stride()[2], + configuration.layout_a(conv_kind).stride()[3] + }; + } + + // Returns stride vector for tensor B + std::vector stride_b(library::ConvKind const &conv_kind) { + + return { + configuration.layout_b(conv_kind).stride()[0], + configuration.layout_b(conv_kind).stride()[1], + configuration.layout_b(conv_kind).stride()[2], + configuration.layout_b(conv_kind).stride()[3] + }; + } + + // Returns stride vector for tensor C + std::vector stride_c(library::ConvKind const &conv_kind) { + + return { + configuration.layout_c(conv_kind).stride()[0], + configuration.layout_c(conv_kind).stride()[1], + configuration.layout_c(conv_kind).stride()[2], + configuration.layout_c(conv_kind).stride()[3] + }; + } + }; + +protected: + + // + // Data members + // + + /// CONV problem obtained from problem space + Conv3dProblem problem_; + + /// Device memory allocations + Conv2dWorkspace conv_workspace_; + + /// CUTLASS parallel reduction operation to follow this* conv2d operation + library::Operation const *reduction_op_; + +public: + // + // Methods + // + + /// Ctor + Conv3dOperationProfiler(Options const &options); + + /// Destructor + virtual ~Conv3dOperationProfiler(); + + Conv3dProblem const& problem() const { return problem_; } + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + + /// Updates the arguments structure for the CUTLASS operator based on + /// the problem index. + void set_cutlass_operator_arguments_(int problem_idx = 0); + + /// Method to profile an initialized CUTLASS operation + virtual Status profile_cutlass_( + PerformanceResult &result, + Options const &options, + library::Operation const *operation, + void *arguments, + void *host_workspace, + void *device_workspace); + + /// Initialize reduction problem dimensions and library::Operation + bool initialize_reduction_configuration_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::ConvDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Verifies CUTLASS against host reference + bool verify_with_host_reference_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against device reference + bool verify_with_device_reference_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +#if CUTLASS_ENABLE_CUDNN + + /// Verifies CUTLASS against cudnn reference + bool verify_with_cudnn_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +#endif //#if CUTLASS_ENABLE_CUDNN + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cublas_helpers.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cublas_helpers.h new file mode 100644 index 0000000000000000000000000000000000000000..873ba1abe03c05df29edc032ea3f1ffd2f19c3ee --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cublas_helpers.h @@ -0,0 +1,456 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Helper functions for mapping CUTLASS concepts to cuBLAS. +*/ + +#pragma once + +#if CUTLASS_ENABLE_CUBLAS +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/blas3.h" + +#include "options.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Converts a cuBLAS status to cutlass::Status +Status get_cutlass_status(cublasStatus_t cublas); + +/// Converts a cuBLAS status to cutlass::profiler::Disposition +Disposition get_cutlass_disposition(cublasStatus_t cublas_status); + +/// Maps a CUTLASS tensor layout to a cuBLAS transpose operation +bool get_cublas_transpose_operation( + cublasOperation_t &operation, + library::LayoutTypeID layout, + library::ComplexTransform transform = library::ComplexTransform::kNone); + +/// Maps a CUTLASS numeric type to a cuBLAS data type enumeration +bool get_cublas_datatype(cublasDataType_t &data_type, library::NumericTypeID element_type); + +/// Gets the cublas algorithm given threadblock tile dimensions and math opcode class +cublasGemmAlgo_t get_cublas_gemm_algo( + int cta_m, + int cta_n, + int cta_k, + library::OpcodeClassID opcode_class); + +/// Returns a status if cuBLAS can satisfy a particular GEMM description +Status cublas_satisfies(library::GemmDescription const &desc); + +/// Returns a status if cuBLAS can satisfy a particular RankK description +Status cublas_satisfies(library::RankKDescription const &desc); + +/// Returns a status if cuBLAS can satisfy a particular TRMM description +Status cublas_satisfies(library::TrmmDescription const &desc); + +/// Returns a status if cuBLAS can satisfy a particular SYMM/HEMM description +Status cublas_satisfies(library::SymmDescription const &desc); + +/// This is a helper class to create cublasHandle_t automatically on CublasCreate object creation and +/// to destroy cublasHandle_t on CublasCreate object destruction. +/// Additionally, it provides implicit cast from CublasCreate's object to cublasHandle_t's object +class CublasCreate { +private: + cublasHandle_t handle; + cublasStatus_t status; + +public: + CublasCreate() { + status = cublasCreate(&handle); + } + + ~CublasCreate() { + cublasDestroy(handle); + } + + /// Implicit cast CublasCreate object to cublasHandle_t + operator cublasHandle_t() const { return handle; } + + /// returns cublasStatus_t for handle creation + cublasStatus_t get_cublas_create_status() { return status; } +}; + +/// This is a helper class to create cublasLtHandle_t automatically on CublasLtCreate object creation and +/// to destroy cublasLtHandle_t on CublasLtCreate object destruction. +/// Additionally, it provides implicit cast from CublasLtCreate's object to cublasLtHandle_t's object +class CublasLtCreate { +private: + cublasLtHandle_t handle; + cublasStatus_t status; + +public: + CublasLtCreate() { + status = cublasLtCreate(&handle); + } + + ~CublasLtCreate() { + cublasLtDestroy(handle); + } + + /// Implicit cast CublasLtCreate object to cublasLtHandle_t + operator cublasLtHandle_t() const { return handle; } + + /// returns cublasLtStatus_t for handle creation + cublasStatus_t get_cublaslt_create_status() { return status; } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Selects one or more cuBLAS algorithms. +static void select_cublas_algorithms( + std::vector &algorithms, + Options const &options, + library::GemmDescription const &op_desc) { + + library::OpcodeClassID const & opcode_class = + op_desc.tile_description.math_instruction.opcode_class; + + switch (options.library.algorithm_mode) { + case AlgorithmMode::kMatching: + { + algorithms.push_back(get_cublas_gemm_algo( + op_desc.tile_description.threadblock_shape.m(), + op_desc.tile_description.threadblock_shape.n(), + op_desc.tile_description.threadblock_shape.k(), + opcode_class)); + break; + } + + case AlgorithmMode::kBest: + { + // Choose first enumerated mode. If none are enumerated, choose based on opcode class + // and evaluate all of them. + + if (options.library.algorithms.empty()) { + // Enumerate all algorithms + if (opcode_class == library::OpcodeClassID::kSimt) { + + for (int algo = CUBLAS_GEMM_DEFAULT; + algo <= CUBLAS_GEMM_ALGO23; + ++algo) { + + algorithms.push_back(cublasGemmAlgo_t(algo)); + } + } + else { + + for (int algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP; + algo <= CUBLAS_GEMM_ALGO15_TENSOR_OP; + ++algo) { + + algorithms.push_back(cublasGemmAlgo_t(algo)); + } + } + } + else { + // Use the listed algorithms + algorithms.reserve(options.library.algorithms.size()); + + for (int algo : options.library.algorithms) { + algorithms.push_back(reinterpret_cast(algo)); + } + } + + break; + } + + case AlgorithmMode::kDefault: + { + + // Use the library's default algorithm + algorithms.push_back((opcode_class == library::OpcodeClassID::kSimt ? + CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + + break; + } + default: + { + break; + } + } +} + +/// Dispatcher to cublasGemmEx() +struct cublasGemmExDispatcher { + + // + // Data members + // + library::GemmUniversalConfiguration configuration; + library::GemmUniversalArguments arguments; + + // cublas-specific data structures to fill cublas API call arguments + cublasOperation_t trans_A; + cublasOperation_t trans_B; + cudaDataType_t data_type_A; + cudaDataType_t data_type_B; + cudaDataType_t data_type_C; + cudaDataType_t compute_data_type; + +#if (__CUDACC_VER_MAJOR__ >= 11) + cublasComputeType_t compute_type; +#endif + + cublasGemmAlgo_t algo; + Status status; + + // + // Methods + // + + cublasGemmExDispatcher( + library::GemmDescription const &op_desc, + library::GemmUniversalConfiguration configuration_, + library::GemmUniversalArguments arguments_, + cublasGemmAlgo_t algorithm = CUBLAS_GEMM_DFALT + ); + + /// Executes GEMM using these arguments + cublasStatus_t operator()(cublasHandle_t handle); +}; + +/// Dispatcher to cublaslt kernels +// +struct cublasLtGemmExDispatcher { + + // + // Data members + // + library::GemmDescription const &op_desc; + library::GemmUniversalConfiguration configuration; + library::GemmUniversalArguments arguments; + + // cublas-specific data structures to fill cublas API call arguments + cublasOperation_t trans_A; + cublasOperation_t trans_B; + cudaDataType_t data_type_A; + cudaDataType_t data_type_B; + cudaDataType_t data_type_C; + cudaDataType_t compute_data_type = CUDA_R_32F; + + //cublasLt-specific data structures + cublasLtMatmulDesc_t operationDesc = NULL; + cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL, Ddesc = NULL; + cublasLtMatmulPreference_t preference = NULL; + + //is set by call to get_cublaslt_algo() + cublasLtMatmulHeuristicResult_t heuristicResult_; + void *workspace = nullptr; + + Status status; + +#if (__CUDACC_VER_MAJOR__ >= 11) + cublasComputeType_t compute_type; +#endif + + // + // Methods + // + + cublasLtGemmExDispatcher( + library::GemmDescription const &op_desc, + library::GemmUniversalConfiguration configuration_, + library::GemmUniversalArguments arguments_ + ); + + /// Initialize the cublasLt variables + void initialize_cublaslt(); + + + /// Runs auto-tuning for the cublas heuristics + bool get_cublaslt_algo(cublasLtHandle_t handle, + AlgorithmMode algorithm_mode + ); + + /// Executes GEMM using these arguments + cublasStatus_t operator()(cublasLtHandle_t handle, cudaStream_t stream = nullptr); + + ~cublasLtGemmExDispatcher(){ + + // descriptors are no longer needed as all GPU work was already enqueued + if (preference) cublasLtMatmulPreferenceDestroy(preference); + if (Ddesc) cublasLtMatrixLayoutDestroy(Ddesc); + if (Cdesc) cublasLtMatrixLayoutDestroy(Cdesc); + if (Bdesc) cublasLtMatrixLayoutDestroy(Bdesc); + if (Adesc) cublasLtMatrixLayoutDestroy(Adesc); + if (operationDesc) cublasLtMatmulDescDestroy(operationDesc); + + if (workspace) { + cudaFree(workspace); + } + + } + +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Dispatcher to cublas rank k update kernels +struct cublasRankKDispatcher { + + // + // Data members + // + library::RankKConfiguration configuration; + library::RankKArguments arguments; + + // cublas-specific data structures to fill cublas API call arguments + cublasOperation_t trans_A; + cublasFillMode_t uplo; + cudaDataType_t data_type_A; + cudaDataType_t data_type_C; + cudaDataType_t compute_data_type; + +#if (__CUDACC_VER_MAJOR__ >= 11) + cublasComputeType_t compute_type; +#endif + + int num_ranks; //(rank-k or rank-2k) + BlasMode blas_mode; //(symmetric or hermitian) + Status status; + + // + // Methods + // + + cublasRankKDispatcher( + library::RankKDescription const &op_desc, + library::RankKConfiguration configuration_, + library::RankKArguments arguments_ + ); + + /// Executes RankK using these arguments + cublasStatus_t operator()(cublasHandle_t handle); +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Dispatcher to cublasTrmm() +struct cublasTrmmDispatcher { + + // + // Data members + // + library::TrmmConfiguration configuration; + library::TrmmArguments arguments; + + // cublas-specific data structures to fill cublas API call arguments + cublasOperation_t trans_A; + cublasSideMode_t side; + cublasFillMode_t uplo; + cublasDiagType_t diag; + cudaDataType_t data_type_A; + cudaDataType_t data_type_B; + cudaDataType_t data_type_D; + cudaDataType_t compute_data_type; + +#if (__CUDACC_VER_MAJOR__ >= 11) + cublasComputeType_t compute_type; +#endif + + Status status; + + // + // Methods + // + + cublasTrmmDispatcher( + library::TrmmDescription const &op_desc, + library::TrmmConfiguration configuration_, + library::TrmmArguments arguments_ + ); + + /// Executes TRMM using these arguments + cublasStatus_t operator()(cublasHandle_t handle); +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Dispatcher to cublas symm/hemm update kernels +struct cublasSymmDispatcher { + + // + // Data members + // + library::SymmConfiguration configuration; + library::SymmArguments arguments; + + // cublas-specific data structures to fill cublas API call arguments + cublasSideMode_t side; + cublasFillMode_t uplo; + cudaDataType_t data_type_A; + cudaDataType_t data_type_B; + cudaDataType_t data_type_C; + cudaDataType_t compute_data_type; + +#if (__CUDACC_VER_MAJOR__ >= 11) + cublasComputeType_t compute_type; +#endif + + BlasMode blas_mode; //(symmetric or hermitian) + Status status; + + // + // Methods + // + + cublasSymmDispatcher( + library::SymmDescription const &op_desc, + library::SymmConfiguration configuration_, + library::SymmArguments arguments_ + ); + + /// Executes Symm using these arguments + cublasStatus_t operator()(cublasHandle_t handle); +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace detail + +} // namespace profiler +} // namespace cutlass + + +#endif // #if CUTLASS_ENABLE_CUBLAS diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cudnn_helpers.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cudnn_helpers.h new file mode 100644 index 0000000000000000000000000000000000000000..7ce9eea5a883fa4c5732f5d8aec120a99064bac0 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cudnn_helpers.h @@ -0,0 +1,590 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Helper functions for mapping CUTLASS concepts to cuDNN. + +*/ + +#pragma once +#if CUTLASS_ENABLE_CUDNN +#include +#include +#include +#include "cutlass/cutlass.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/library/library.h" +#include "enumerated_types.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Converts a cuDNN status to cutlass::Status +Status get_cutlass_status(cudnnStatus_t cudnn_status); + +/// Converts a cuDNN status to cutlass::profiler::Disposition +Disposition get_cutlass_disposition(cudnnStatus_t cudnn_status); + +/// Checks cudnnStatus_t converts to cutlas status and returns if Status::kSuccess o.w. throws exception +Status checkCudnnErr(cudnnStatus_t cudnn_status); + +/// Maps a CUTLASS conv mode to a cuDNN conv mode enumeration +bool get_cudnn_conv_mode(cudnnConvolutionMode_t &cudnn_conv_mode, conv::Mode conv_mode); + +/// Maps a CUTLASS layout type to a cuDNN data type enumeration +bool get_cudnn_layout(cudnnTensorFormat_t &cudnn_layout, library::LayoutTypeID layout); + +/// Maps a CUTLASS numeric type to a cuDNN data type enumeration +bool get_cudnn_datatype(cudnnDataType_t &cudnn_element_type, library::NumericTypeID element_type); + +/// Maps CUTLASS math OpcodeClassID and MathOperationID to cuDNN math_type +bool get_cudnn_mathtype(cudnnMathType_t &cudnn_math_type, library::ConvDescription const &conv_desc); + +/// Returns a status if cudnn can satisfy a particular Conv2d description +Status cudnn_satisfies(library::ConvDescription const &desc, library::Conv2dConfiguration const &configuration); + +/// Returns a status if cudnn can satisfy a particular Conv3d description +Status cudnn_satisfies(library::ConvDescription const &desc, library::Conv3dConfiguration const &configuration); + +/// Cudnn compute type seems to be hardcoded to float (To handle a possible cudnn issue) +float cast_cudnn_compute_type_to_float(library::NumericTypeID type, void const * src); + + +/// This is a helper class to create cudnnHandle_t automatically on CudnnCreate object creation and +/// to destroy cudnnHandle_t on CudnnCreate object destruction. +/// Additionally, it provides implicit cast from CudnnCreate's object to cudnnHandle_t's object +class CudnnCreate { +private: + cudnnHandle_t handle; + cudnnStatus_t status; + +public: + CudnnCreate() { + status = cudnnCreate(&handle); + } + + ~CudnnCreate() { + cudnnDestroy(handle); + } + + /// Implicit cast CudnnCreate object to cudnnHandle_t + operator cudnnHandle_t() const { return handle; } + + /// returns cudnnStatus_t for handle creation + cudnnStatus_t get_cudnn_create_status() { return status; } +}; + + +namespace detail { + +/// Dispatcher to cudnn convolution operators +struct cudnnConvDispatcher { + + // + // Data members + // + //library::Conv2dConfiguration configuration; + library::ConvArguments arguments; + library::ConvKind conv_kind; + + // cudnn-specific data structures to fill cudnn API call arguments + // cudnn activation, filter, and output descriptors + cudnnTensorDescriptor_t activation_desc; + cudnnFilterDescriptor_t filter_desc; + cudnnTensorDescriptor_t output_desc; + cudnnConvolutionDescriptor_t conv_desc; + + // cudnn datatypes + cudnnDataType_t data_type_activation; + cudnnDataType_t data_type_filter; + cudnnDataType_t data_type_output; + + // cudnn layouts + cudnnTensorFormat_t layout_activation; + cudnnTensorFormat_t layout_filter; + cudnnTensorFormat_t layout_output; + + // cudnn convolution mode + cudnnConvolutionMode_t conv_mode; + + // cudnn math type (tensorop, tensorop with conversion, simt) + cudnnMathType_t math_type; + + // cudnn compute data type + cudnnDataType_t compute_type; + + // cudnn compute type seems to be hardcoded to float (to handle a possible a cudnn issue) + float alpha; + float beta; + + // cudnn workspace + size_t workspace_size_in_bytes = 0; + cutlass::device_memory::allocation workspace; + + // select cudnn's implicit gemm precomputed algorithm with tensor operations + static cudnnConvolutionFwdAlgo_t const fprop_algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; + static cudnnConvolutionBwdDataAlgo_t const dgrad_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; + static cudnnConvolutionBwdFilterAlgo_t const wgrad_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; + + Status status; + + // + // Methods + // + + // TODO: unify ctor cudnnConvDispatcher for conv2d and conv3d by unifying Conv2dConfiguration + + // ctor for conv2d + cudnnConvDispatcher( + library::ConvDescription const &op_desc, + library::Conv2dConfiguration configuration, + library::ConvArguments arguments_, + cudnnHandle_t handle + ): + //configuration(configuration_), + arguments(arguments_), + conv_kind(op_desc.conv_kind), + status(Status::kSuccess) { + + bool good = true; + + // Get cudnn datatype, layout, and convolution mode from library::ConvDescription + good = (good && get_cudnn_datatype(data_type_activation, op_desc.A.element)); + good = (good && get_cudnn_datatype(data_type_filter, op_desc.B.element)); + good = (good && get_cudnn_datatype(data_type_output, op_desc.C.element)); + good = (good && get_cudnn_layout(layout_activation, op_desc.A.layout)); + good = (good && get_cudnn_layout(layout_filter, op_desc.B.layout)); + good = (good && get_cudnn_layout(layout_output, op_desc.C.layout)); + good = (good && get_cudnn_conv_mode(conv_mode, configuration.problem_size.mode)); + // Get cudnn mathtype (cudnnMathType_t) + good = (good && get_cudnn_mathtype(math_type, op_desc)); + good = (good && get_cudnn_datatype( + compute_type, + op_desc.tile_description.math_instruction.element_accumulator)); + // Check cutlass Conv2d description has equivalent operator in cudnn + if (!good) { + status = Status::kErrorNotSupported; + return; + } + // cudnn compute type seems to be hardcoded to float (to handle a possible a cudnn issue) + alpha = cast_cudnn_compute_type_to_float(op_desc.element_epilogue, arguments.alpha); + beta = cast_cudnn_compute_type_to_float(op_desc.element_epilogue, arguments.beta); + + // Create convolution descriptor object + status = get_cutlass_status(cudnnCreateConvolutionDescriptor(&conv_desc)); + + // Configure convolution operator + std::vector padding {configuration.problem_size.pad_h, configuration.problem_size.pad_w}; + std::vector stride {configuration.problem_size.stride_h, configuration.problem_size.stride_w}; + std::vector dilation {configuration.problem_size.dilation_h, configuration.problem_size.dilation_w}; + + status = get_cutlass_status( + cudnnSetConvolutionNdDescriptor( + conv_desc, + op_desc.conv_dim, + padding.data(), + stride.data(), + dilation.data(), + conv_mode, + compute_type + )); + + // Set groups + status = get_cutlass_status(cudnnSetConvolutionGroupCount(conv_desc, configuration.problem_size.groups)); + + // Create activation, filter, and output descriptor objects + status = get_cutlass_status(cudnnCreateTensorDescriptor(&activation_desc)); + status = get_cutlass_status(cudnnCreateFilterDescriptor(&filter_desc)); + status = get_cutlass_status(cudnnCreateTensorDescriptor(&output_desc)); + + // Set activation, filter, and output descriptor + status = get_cutlass_status( + cudnnSetTensor4dDescriptor( + activation_desc, + layout_activation, + data_type_activation, + configuration.problem_size.N, + configuration.problem_size.C, + configuration.problem_size.H, + configuration.problem_size.W + )); + + status = get_cutlass_status( + cudnnSetFilter4dDescriptor( + filter_desc, + data_type_filter, + layout_filter, + configuration.problem_size.K, + configuration.problem_size.C / configuration.problem_size.groups, + configuration.problem_size.R, + configuration.problem_size.S + )); + + status = get_cutlass_status( + cudnnSetTensor4dDescriptor( + output_desc, + layout_output, + data_type_output, + configuration.problem_size.N, + configuration.problem_size.K, + configuration.problem_size.P, + configuration.problem_size.Q + )); + + // Set math instruction to tensor op + status = get_cutlass_status( + cudnnSetConvolutionMathType(conv_desc, math_type)); + + // Initialize workspace + switch (conv_kind) { + case library::ConvKind::kFprop: + status = get_cutlass_status( + cudnnGetConvolutionForwardWorkspaceSize( + handle, + activation_desc, + filter_desc, + conv_desc, + output_desc, + fprop_algo, + &workspace_size_in_bytes + )); break; + case library::ConvKind::kDgrad: + status = get_cutlass_status( + cudnnGetConvolutionBackwardDataWorkspaceSize( + handle, + filter_desc, + output_desc, + conv_desc, + activation_desc, + dgrad_algo, + &workspace_size_in_bytes + )); break; + case library::ConvKind::kWgrad: + status = get_cutlass_status( + cudnnGetConvolutionBackwardFilterWorkspaceSize( + handle, + activation_desc, + output_desc, + conv_desc, + filter_desc, + wgrad_algo, + &workspace_size_in_bytes + )); break; + + } + + workspace = cutlass::device_memory::allocation(workspace_size_in_bytes); + } + + + // ctor for conv3d + cudnnConvDispatcher( + library::ConvDescription const &op_desc, + library::Conv3dConfiguration configuration, + library::ConvArguments arguments_, + cudnnHandle_t handle + ): + //configuration(configuration_), + arguments(arguments_), + conv_kind(op_desc.conv_kind), + status(Status::kSuccess) { + + bool good = true; + + // Get cudnn datatype, layout, and convolution mode from library::ConvDescription + good = (good && get_cudnn_datatype(data_type_activation, op_desc.A.element)); + good = (good && get_cudnn_datatype(data_type_filter, op_desc.B.element)); + good = (good && get_cudnn_datatype(data_type_output, op_desc.C.element)); + + good = (good && get_cudnn_layout(layout_activation, op_desc.A.layout)); + good = (good && get_cudnn_layout(layout_filter, op_desc.B.layout)); + good = (good && get_cudnn_layout(layout_output, op_desc.C.layout)); + + good = (good && get_cudnn_conv_mode(conv_mode, configuration.problem_size.mode)); + + // cudnn compute type seems to be hardcoded to float (to handle a possible a cudnn issue) + alpha = cast_cudnn_compute_type_to_float(op_desc.element_epilogue, arguments.alpha); + beta = cast_cudnn_compute_type_to_float(op_desc.element_epilogue, arguments.beta); + + good = (good && get_cudnn_datatype( + compute_type, + op_desc.tile_description.math_instruction.element_accumulator)); + + // Check cutlass Conv2d description has equivalent operator in cudnn + if (!good) { + status = Status::kErrorNotSupported; + } + + // Create convolution descriptor object + status = get_cutlass_status(cudnnCreateConvolutionDescriptor(&conv_desc)); + + // Configure convolution operator + std::vector padding {configuration.problem_size.pad_d, configuration.problem_size.pad_h, configuration.problem_size.pad_w}; + std::vector stride {configuration.problem_size.stride_d, configuration.problem_size.stride_h, configuration.problem_size.stride_w}; + std::vector dilation {configuration.problem_size.dilation_d, configuration.problem_size.dilation_h, configuration.problem_size.dilation_w}; + + status = get_cutlass_status( + cudnnSetConvolutionNdDescriptor( + conv_desc, + op_desc.conv_dim, + padding.data(), + stride.data(), + dilation.data(), + conv_mode, + compute_type + )); + + // Set groups + status = get_cutlass_status(cudnnSetConvolutionGroupCount(conv_desc, configuration.problem_size.groups)); + + // Create activation, filter, and output descriptor objects + status = get_cutlass_status(cudnnCreateTensorDescriptor(&activation_desc)); + status = get_cutlass_status(cudnnCreateFilterDescriptor(&filter_desc)); + status = get_cutlass_status(cudnnCreateTensorDescriptor(&output_desc)); + + // Set activation descriptor + std::vector activation_extent { + configuration.problem_size.N, + configuration.problem_size.C, + configuration.problem_size.D, + configuration.problem_size.H, + configuration.problem_size.W + }; + + std::vector activation_stride { + configuration.layout_activations.stride()[3], + 1, + configuration.layout_activations.stride()[2], + configuration.layout_activations.stride()[1], + configuration.layout_activations.stride()[0] + }; + + status = get_cutlass_status( + cudnnSetTensorNdDescriptor( + activation_desc, + data_type_activation, + op_desc.conv_dim + 2, + activation_extent.data(), + activation_stride.data() + )); + + // Set filter descriptor + std::vector filter_extent { + configuration.problem_size.K, + configuration.problem_size.C, + configuration.problem_size.T, + configuration.problem_size.R, + configuration.problem_size.S + }; + + std::vector filter_stride { + configuration.layout_filters.stride()[3], + 1, + configuration.layout_filters.stride()[2], + configuration.layout_filters.stride()[1], + configuration.layout_filters.stride()[0] + }; + + status = get_cutlass_status( + cudnnSetFilterNdDescriptor( + filter_desc, + data_type_filter, + layout_filter, + op_desc.conv_dim + 2, + filter_extent.data() + )); + + + // Set output descriptor + std::vector output_extent { + configuration.problem_size.N, + configuration.problem_size.K, + configuration.problem_size.Z, + configuration.problem_size.P, + configuration.problem_size.Q + }; + + std::vector output_stride { + configuration.layout_output.stride()[3], + 1, + configuration.layout_output.stride()[2], + configuration.layout_output.stride()[1], + configuration.layout_output.stride()[0] + }; + + status = get_cutlass_status( + cudnnSetTensorNdDescriptor( + output_desc, + data_type_output, + op_desc.conv_dim + 2, + output_extent.data(), + output_stride.data() + )); + + // Set math instruction to tensor op + status = get_cutlass_status( + cudnnSetConvolutionMathType(conv_desc, math_type)); + + // Initialize workspace + switch (conv_kind) { + case library::ConvKind::kFprop: + status = get_cutlass_status( + cudnnGetConvolutionForwardWorkspaceSize( + handle, + activation_desc, + filter_desc, + conv_desc, + output_desc, + fprop_algo, + &workspace_size_in_bytes + )); break; + case library::ConvKind::kDgrad: + status = get_cutlass_status( + cudnnGetConvolutionBackwardDataWorkspaceSize( + handle, + filter_desc, + output_desc, + conv_desc, + activation_desc, + dgrad_algo, + &workspace_size_in_bytes + )); break; + case library::ConvKind::kWgrad: + status = get_cutlass_status( + cudnnGetConvolutionBackwardFilterWorkspaceSize( + handle, + activation_desc, + output_desc, + conv_desc, + filter_desc, + wgrad_algo, + &workspace_size_in_bytes + )); break; + + } + + workspace = cutlass::device_memory::allocation(workspace_size_in_bytes); + } + + /// Executes Conv2d operator from cudnn library + cudnnStatus_t operator()(cudnnHandle_t handle) { + + switch (conv_kind) { + case library::ConvKind::kFprop: + return cudnnConvolutionForward( + handle, + &alpha, + activation_desc, + activation(), + filter_desc, + filter(), + conv_desc, + fprop_algo, + workspace.get(), + workspace_size_in_bytes, + &beta, + output_desc, + arguments.D + ); + case library::ConvKind::kDgrad: + return cudnnConvolutionBackwardData( + handle, + &alpha, + filter_desc, + filter(), + output_desc, + output(), + conv_desc, + dgrad_algo, + workspace.get(), + workspace_size_in_bytes, + &beta, + activation_desc, + arguments.D + ); + case library::ConvKind::kWgrad: + return cudnnConvolutionBackwardFilter( + handle, + &alpha, + activation_desc, + activation(), + output_desc, + output(), + conv_desc, + wgrad_algo, + workspace.get(), + workspace_size_in_bytes, + &beta, + filter_desc, + arguments.D + ); + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns Activation Tensor + void const * activation() const { + switch(conv_kind) { + case library::ConvKind::kFprop : return arguments.A; + case library::ConvKind::kDgrad : return arguments.C; + case library::ConvKind::kWgrad : return arguments.B; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns Filter Tensor + void const *filter() const { + switch(conv_kind) { + case library::ConvKind::kFprop : return arguments.B; + case library::ConvKind::kDgrad : return arguments.B; + case library::ConvKind::kWgrad : return arguments.C; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns Output Tensor + void const *output() const { + switch(conv_kind) { + case library::ConvKind::kFprop : return arguments.C; + case library::ConvKind::kDgrad : return arguments.A; + case library::ConvKind::kWgrad : return arguments.A; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } +}; + +} // namespace detail +///////////////////////////////////////////////////////////////////////////////////////////////// +#endif //#if CUTLASS_ENABLE_CUDNN +} // namespace profiler +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cutlass_profiler.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cutlass_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..be82245325cebb147e2c801965a52ece91395cb2 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cutlass_profiler.h @@ -0,0 +1,93 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Execution environment +*/ + +#pragma once +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/singleton.h" + +#include "options.h" +#include "operation_profiler.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// CUTLASS Profiler application +class CutlassProfiler { +private: + + // + // Data members + // + + /// Performance testbench options + Options options_; + + /// Entry points for each operation + OperationProfilerVector operation_profilers_; + +private: + + /// Prints usage + void print_usage_(std::ostream &); + + /// Prints usage + void print_options_(std::ostream &); + + /// Enumerates all operations + void enumerate_(); + + /// Profiles all operations + int profile_(); + +public: + + CutlassProfiler(Options const &options); + ~CutlassProfiler(); + + /// Invokes profiling operations + int operator()(); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/debug.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/debug.h new file mode 100644 index 0000000000000000000000000000000000000000..98f1fdc3044501e456c927471b30d74b09eafd39 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/debug.h @@ -0,0 +1,56 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief +*/ + +#pragma once + +#include + +//#define report(x) { std::cout << "\033[31m" << __FILE__ << ":" << __LINE__ << " " << x << "\033[0m" << std::endl; } +//#define report(x) {} + +// Enable/Disable Profiler debug prints +//#define DEBUG_PROFILER + +//RED 31m // profiler prints debug messages in red +//YELLOW 33m // ir prints debug messages in yellow + +#ifndef DEBUG_PROFILER +#define debugprof(...) +#else +#define debugprof(...) do { \ + printf("\033[33m[DEBUG PROF] %s:%d | ", __FILE__, __LINE__); \ + printf(__VA_ARGS__); \ + printf("\033[0m\n"); \ + } while (0) +#endif diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/device_allocation.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/device_allocation.h new file mode 100644 index 0000000000000000000000000000000000000000..488b635c2ec233e3027303bbf15a34f375a438fd --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/device_allocation.h @@ -0,0 +1,246 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Execution environment +*/ + +#pragma once + +#include +#include +#include + +#include "cutlass/library/library.h" +#include "cutlass/util/distribution.h" + +#include "enumerated_types.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Device memory allocation +class DeviceAllocation { +private: + + /// Data type of contained elements + library::NumericTypeID type_; + + /// Gets the stride between elements + size_t batch_stride_; + + /// Capacity in elements of device allocation + size_t capacity_; + + /// Pointer to device memory + void *pointer_; + + /// Layout type ID + library::LayoutTypeID layout_; + + /// Stride vector + std::vector stride_; + + /// Extent vector + std::vector extent_; + + /// Support allocating a 'batch' of non-overlapping tensors in contiguous memory + int batch_count_; + + /// Buffer holding TensorRef instance to recently allocated memory + std::vector tensor_ref_buffer_; + + /// The device ID where the allocation is made + int device_; + +public: + // + // Static member functions + // + + /// Determines the number of bytes needed to represent this numeric type + static size_t bytes(library::NumericTypeID type, size_t capacity); + + /// Returns the stride of a packed layout + static std::vector get_packed_layout( + library::LayoutTypeID layout_id, + std::vector const &extent); + + /// returns the capacity needed + static size_t construct_layout( + void *bytes, + library::LayoutTypeID layout_id, + std::vector const &extent, + std::vector &stride); + + /// Returns true if two blocks have exactly the same value + static bool block_compare_equal( + library::NumericTypeID numeric_type, + void const *ptr_A, + void const *ptr_B, + size_t capacity); + + /// Returns true if two blocks have approximately the same value + static bool block_compare_relatively_equal( + library::NumericTypeID numeric_type, + void const *ptr_A, + void const *ptr_B, + size_t capacity, + double epsilon, + double nonzero_floor); + +public: + // + // Methods + // + + DeviceAllocation(); + + DeviceAllocation( + library::NumericTypeID type, + size_t capacity, + int device = -1); + + DeviceAllocation( + library::NumericTypeID type, + library::LayoutTypeID layout_id, + std::vector const &extent, + std::vector const &stride = std::vector(), + int batch_count = 1, + int device = -1); + + ~DeviceAllocation(); + + DeviceAllocation &reset(); + + /// Allocates device memory of a given type and capacity + DeviceAllocation &reset(library::NumericTypeID type, size_t capacity); + + /// Allocates memory for a given layout and tensor + DeviceAllocation &reset( + library::NumericTypeID type, + library::LayoutTypeID layout_id, + std::vector const &extent, + std::vector const &stride = std::vector(), + int batch_count = 1); + + /// Returns a buffer owning the tensor reference + std::vector &tensor_ref() { + return tensor_ref_buffer_; + } + + bool good() const; + + /// Data type of contained elements + library::NumericTypeID type() const; + + /// Pointer to start of device memory allocation + void *data() const; + + /// Pointer to the first element of a batch + void *batch_data(int batch_idx) const; + + /// Gets the layout type + library::LayoutTypeID layout() const; + + /// Gets the stride vector + std::vector const & stride() const; + + /// Gets the extent vector + std::vector const & extent() const; + + /// Gets the number of adjacent tensors in memory + int batch_count() const; + + /// Gets the stride (in units of elements) between items + int64_t batch_stride() const; + + /// Gets the stride (in units of bytes) between items + int64_t batch_stride_bytes() const; + + /// Capacity of allocation in number of elements + size_t capacity() const; + + /// Capacity of allocation in bytes + size_t bytes() const; + + /// Initializes a device allocation to a random distribution using cuRAND + void initialize_random_device(int seed, Distribution dist); + + /// Initializes a host allocation to a random distribution using std::cout + void initialize_random_host(int seed, Distribution dist); + + /// Initializes a device allocation to a sequential distribution + void initialize_sequential_device(Distribution dist); + + /// Initializes a host allocation to a sequential distribution + void initialize_sequential_host(Distribution dist); + + /// Initializes a device allocation to a random distribution using cuRAND + void initialize_random_sparsemeta_device(int seed, int MetaSizeInBits); + + /// Initializes a host allocation to a random distribution using std::cout + void initialize_random_sparsemeta_host(int seed, int MetaSizeInBits); + + /// Uniformly fills a tensor with a value when provided o.w. zero + void fill_device(double value); + + /// Uniformly fills a host allocation with a value when provided o.w. zero + void fill_host(double value); + + /// Copies from an equivalent-sized tensor in device memory + void copy_from_device(void const *ptr); + + /// Copies from an equivalent-sized tensor in device memory + void copy_from_host(void const *ptr); + + /// Copies from an equivalent-sized tensor in device memory + void copy_to_host(void *ptr); + + /// Writes a tensor to csv + void write_tensor_csv(std::ostream &out); + +private: + /// A wrapper that sets the device, performs malloc, and sets back + cudaError_t malloc(void** ptr, size_t size); +}; + +using DeviceAllocationList = std::list; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/device_context.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/device_context.h new file mode 100644 index 0000000000000000000000000000000000000000..0443b340397426bfafc812c1a4b9179fc6af0de4 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/device_context.h @@ -0,0 +1,136 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief +*/ + +#pragma once + +#include +#include + + +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" + +#include "options.h" +#include "device_allocation.h" + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Collection of allocations on the device +class DeviceContext { +public: + + // + // Type definitions + // + using AllocationMap = std::map; + +private: + // + // Data members + // + + /// Memory allocations that exist (owning) + DeviceAllocationList device_memory_; + + /// Non-owning set of named allocations + AllocationMap allocations_; + +public: + + /// Allocates memory of a given type, capacity (elements), and name + DeviceAllocation *allocate_block( + Options const &options, + std::string const &name, + library::NumericTypeID type, + size_t capacity, + size_t device_index); + + /// Allocates memory of a given type, capacity (elements), and name + DeviceAllocation *allocate_tensor( + Options const &options, + std::string const &name, + library::NumericTypeID type, + library::LayoutTypeID layout_id, + std::vector const &extent, + std::vector const &stride, + int batch_count, + size_t device_index); + + /// Allocates memory of a given type, capacity (elements), and name + DeviceAllocation *allocate_and_initialize_tensor( + Options const &options, + std::string const &name, + library::NumericTypeID type, + library::LayoutTypeID layout_id, + std::vector const &extent, + std::vector const &stride, + int batch_count, + int seed_shift, + size_t device_index); + + /// Allocates memory for sparse meta data + DeviceAllocation *allocate_and_initialize_sparsemeta_tensor( + Options const &options, + std::string const &name, + library::NumericTypeID type, + library::LayoutTypeID layout_id, + library::NumericTypeID type_a, + std::vector const &extent, + std::vector const &stride, + int batch_count, + int seed_shift, + size_t device_index); + + /// Clears named allocations (but does not necessarily free memory) + void clear(); + + /// Frees all device memory allocations + void free(); + + /// Gets the allocation by name + DeviceAllocation &at(std::string const &name); + + size_t size() const; + + AllocationMap::iterator begin(); + AllocationMap::iterator end(); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/enumerated_types.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/enumerated_types.h new file mode 100644 index 0000000000000000000000000000000000000000..897311c228ce76c4e8814ce996929561d44d2465 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/enumerated_types.h @@ -0,0 +1,169 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Provides several functions for filling tensors with data. +*/ + +#pragma once + +#include +#include +#include +#include +#include "cutlass/library/library.h" + +#define TRACE(x) { std::cout << __FILE__ << ":" << __LINE__ << " " << x << std::endl; } + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +T from_string(std::string const &); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Enumerated type describing how the performance testbench evaluates kernels. +enum class ExecutionMode { + kProfile, ///< regular verification and profiling + kDryRun, ///< no kernels are launched or workspaces allocated; used to assess what operators might be launched + kEnumerate, ///< no kernels launched or workspaces allocated; lists all operation kind and operations + kTrace, ///< executes a single device-side computation with no other kernel launches + kInvalid +}; + +/// Converts a ExecutionMode enumerant to a string +char const *to_string(ExecutionMode mode, bool pretty = false); + +/// Parses a ExecutionMode enumerant from a string +template <> +ExecutionMode from_string(std::string const &str); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Library algorithm mode +enum class AlgorithmMode { + kMatching, ///< compare against best matching algorithm + kBest, ///< evaluate all library algorithms and report best + kDefault, ///< use the library's default algorithm option + kInvalid +}; + +/// Converts a ExecutionMode enumerant to a string +char const *to_string(AlgorithmMode mode, bool pretty = false); + +/// Parses a ExecutionMode enumerant from a string +template <> +AlgorithmMode from_string(std::string const &str); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Outcome of a performance test +enum class Disposition { + kPassed, + kFailed, // kernel itself reported an error + kNotRun, + kIncorrect, // kernel finished without a detected error, but result does not equal expected result + kNotVerified, + kInvalidProblem, + kNotSupported, + kInvalid +}; + +/// Converts a Disposition enumerant to a string +char const *to_string(Disposition disposition, bool pretty = false); + +/// Parses a Disposition enumerant from a string +template <> +Disposition from_string(std::string const &str); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Indicates when to save +enum class SaveWorkspace { + kNever, + kIncorrect, + kAlways, + kInvalid +}; + +/// Converts a SaveWorkspace enumerant to a string +char const *to_string(SaveWorkspace save_option, bool pretty = false); + +/// Parses a SaveWorkspace enumerant from a string +template <> +SaveWorkspace from_string(std::string const &str); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Indicates the type of kernel argument +// ArgumentType can be both ScalarType or NumericType. Thus, enums kScalar and kNumeric +// 1) kScalar: e.g. of a Scalar ArgumentType is u32 is a Scalar type. +// Its c++ equivalent as "type name = initializer" is "u32 m = 32" +// 2) kNumeric: e.g. of a Numeric ArgumentType is NumericTypeID is a Numeric type. +// Its c++ equivalent as "type name = initializer" is "NumericTypeID numeric_type = u32" +enum class ArgumentTypeID { + kScalar, + kInteger, + kTensor, + kBatchedTensor, + kStructure, + kEnumerated, + kInvalid +}; + +/// Converts a ArgumentTypeID enumerant to a string +char const *to_string(ArgumentTypeID type, bool pretty = false); + +/// Parses a ArgumentTypeID enumerant from a string +template <> +ArgumentTypeID from_string(std::string const &str); + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Profiler typedefs +using ProviderVector = std::vector; +using DispositionMap = std::map; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Print vector for the report +template +std::ostream& operator<< (std::ostream& out, const std::vector& v) { + for (size_t i = 0; i < v.size(); ++i) { + out << to_string(v[i], true) << (i + 1u != v.size() ? "," : ""); + } + return out; +} +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..faf317152473cac6dc62ecf8970cd1acfb2c1622 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h @@ -0,0 +1,333 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Gemm Profiler +*/ + +#pragma once + +#include +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" +#include "reduction_operation_profiler.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class GemmOperationProfiler : public OperationProfiler { +public: + + /// Problem structure obtained from problem space + struct GemmProblem { + + cutlass::library::GemmUniversalMode mode{library::GemmUniversalMode::kGemm}; + + /// For profiling purposes + std::vector problem_sizes; + std::vector> leading_dims; + std::vector> preferred_clusters; + std::vector> fallback_clusters; + std::vector raster_orders; + std::vector swizzle_sizes; + + int64_t m{16}; + int64_t n{16}; + int64_t k{16}; + + + int cluster_m{1}; + int cluster_n{1}; + int cluster_k{1}; + int cluster_m_fallback{1}; + int cluster_n_fallback{1}; + int cluster_k_fallback{1}; + + + int64_t lda{0}; + int64_t ldb{0}; + int64_t ldc{0}; + std::vector alpha; + std::vector beta; + + cutlass::library::SplitKMode split_k_mode{library::SplitKMode::kNone}; + int split_k_slices{1}; + int batch_count{1}; + + cutlass::library::RasterOrder raster_order{cutlass::library::RasterOrder::kHeuristic}; + int swizzle_size{1}; + cutlass::library::RuntimeDatatype runtime_input_datatype_a{}; + cutlass::library::RuntimeDatatype runtime_input_datatype_b{}; + + + // gemm with parallel interleaved reduction + // gemm epilogue (alpha, beta) = (1.0, 0.0) + // reduction epilogue (alpha, beta) = (GemmProblem::alpha, GemmProblem::beta) + std::vector alpha_one; + std::vector beta_zero; + + bool use_pdl{false}; + + bool enable_sm90_mixed_dtype_shuffle_test{false}; + + // + // Methods + // + + /// Parses the problem + Status parse( + library::GemmDescription const &operation_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + int64_t bytes_with_problem_shape( + library::GemmDescription const &operation_desc, + gemm::GemmCoord const &problem_shape) const; + + int64_t flops_with_problem_shape( + library::GemmDescription const &operation_desc, + gemm::GemmCoord const &problem_shape) const; + + /// Total number of bytes loaded + int64_t bytes(library::GemmDescription const &operation_desc) const; + + /// Total number of flops computed + int64_t flops(library::GemmDescription const &operation_desc) const; + + /// Initializes a performance result + void initialize_result( + PerformanceResult &result, + library::GemmDescription const &operation_desc, + ProblemSpace const &problem_space); + }; + + /// Workspace used + struct GemmWorkspace { + + DeviceAllocation *A{nullptr}; + DeviceAllocation *B{nullptr}; + DeviceAllocation *C{nullptr}; + DeviceAllocation *Computed{nullptr}; + DeviceAllocation *Reference{nullptr}; + + /// Number of copies of the problem workspace which are visited sequentially during + /// profiling to avoid camping in the last level cache. + int problem_count{1}; + + library::GemmUniversalConfiguration configuration; + library::GemmUniversalArguments arguments; + + /// Buffer used for the operation's host workspace + std::vector host_workspace; + + /// Buffer used for the operations' device workspace + DeviceAllocation device_workspace; + + /// Library configuration and arguments for reduction operator + library::ReductionConfiguration reduction_configuration; + library::ReductionArguments reduction_arguments; + + /// Buffer used for the cutlass reduction operations' host workspace + std::vector reduction_host_workspace; + + /// For mixed input dtype kernels + DeviceAllocation *Scale{nullptr}; // Scale tensor + DeviceAllocation *Zero{nullptr}; // Zero tensor + DeviceAllocation *dequantized_AB{nullptr}; // Dequantized A or B tensor for verification + DeviceAllocation *encoded_AB{nullptr}; // Encoded A or B in int4 x fp8 or shuffle + DeviceAllocation *packed_Scale{nullptr}; // Packed scale for int4 * fp8 + + cudaStream_t stream; + }; + +protected: + + // + // Data members + // + + /// GEMM problem obtained from problem space + GemmProblem problem_; + + /// Device memory allocations + std::vector gemm_workspace_; + + /// CUTLASS parallel reduction operation to follow this* gemm operation + library::Operation const *reduction_op_; + +public: + // + // Methods + // + + /// Ctor + GemmOperationProfiler(Options const &options); + + /// Destructor + virtual ~GemmOperationProfiler(); + + GemmProblem const& problem() const { return problem_; } + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + /// Update workspace configuration according to flexible user setups + void update_workspace_( + GemmWorkspace &gemm_workspace, + gemm::GemmCoord const &problem_shape, + std::array const &leading_dim, + std::array const &preferred_cluster, + std::array const &fallback_cluster, + cutlass::library::RasterOrder const &raster_order, + int swizzle_size, + bool is_dynamic_cluster_enabled); + + /// Update performance result configuration according to flexible user setups + void update_result_( + PerformanceResult &result, + library::GemmDescription const &operation_desc, + ProblemSpace const &problem_space, + gemm::GemmCoord const &problem_shape, + cutlass::library::RasterOrder const &raster_order, + std::array const &preferred_cluster, + std::array const &fallback_cluster, + int swizzle_size, + bool is_dynamic_cluster_enabled); + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::GemmDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Verifies CUTLASS against references + bool verify_with_cublas_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem, + GemmWorkspace &gemm_workspace); + + /// Verifies CUTLASS against host and device references + bool verify_with_reference_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem, + cutlass::library::NumericTypeID element_A, + cutlass::library::NumericTypeID element_B); + + /// Method to profile a CUTLASS Operation + Status profile_cutlass_( + PerformanceResult &result, + Options const &options, + library::Operation const *operation, + void *arguments, + void *host_workspace, + void *device_workspace); + + /// Initialize reduction problem dimensions and library::Operation + bool initialize_reduction_configuration_( + library::Operation const *operation, + ProblemSpace::Problem const &problem); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/gpu_timer.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/gpu_timer.h new file mode 100644 index 0000000000000000000000000000000000000000..154045295d6443d930ba53387366f4b8abe408a4 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/gpu_timer.h @@ -0,0 +1,77 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines a math function +*/ + +#pragma once + +#include +#include "cutlass/cutlass.h" + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct GpuTimer { + + cudaEvent_t events[2]; + + // + // Methods + // + + GpuTimer(); + + GpuTimer(GpuTimer const&) = delete; + + GpuTimer(GpuTimer &&gpu_timer) noexcept; + + ~GpuTimer(); + + /// Records a start event in the stream, the flag is for cudaEventRecordWithFlags + void start(cudaStream_t stream = nullptr, unsigned int flag = cudaEventRecordDefault); + + /// Records a stop event in the stream, the flag is for cudaEventRecordWithFlags + void stop(cudaStream_t stream = nullptr, unsigned int flag = cudaEventRecordDefault); + + /// Records a stop event in the stream and synchronizes on the stream, the flag is for cudaEventRecordWithFlags + void stop_and_wait(cudaStream_t stream = nullptr, unsigned int flag = cudaEventRecordDefault); + + /// Returns the duration in milliseconds + double duration(int iterations = 1) const; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/grouped_gemm_operation_profiler.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/grouped_gemm_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..62d47990584cbb984935a00a267cff15dbb4f4e5 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/grouped_gemm_operation_profiler.h @@ -0,0 +1,344 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/* \file + \brief GroupedGemm Profiler +*/ + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/library/library.h" + +// Profiler includes +#include "device_context.h" +#include "operation_profiler.h" +#include "options.h" +#include "performance_result.h" +#include "problem_space.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class GroupedGemmOperationProfiler : public OperationProfiler { +public: + /// Problem structure obtained from problem space + struct GroupedGemmProblem { + + cutlass::library::GemmUniversalMode mode{library::GemmUniversalMode::kGrouped}; + + std::vector problem_sizes; + std::vector> problem_sizes_3x; + + /// For exploration purposes + std::vector> preferred_clusters; + std::vector> fallback_clusters; + std::vector raster_orders; + std::vector swizzle_sizes; + + int cluster_m{1}; + int cluster_n{1}; + int cluster_k{1}; + int cluster_m_fallback{1}; + int cluster_n_fallback{1}; + int cluster_k_fallback{1}; + + std::vector lda{0}; + std::vector ldb{0}; + std::vector ldc{0}; + + std::vector alpha; + std::vector beta; + + cutlass::library::RasterOrder raster_order{cutlass::library::RasterOrder::kHeuristic}; + int swizzle_size{1}; + + cutlass::library::RuntimeDatatype runtime_input_datatype_a{}; + cutlass::library::RuntimeDatatype runtime_input_datatype_b{}; + + bool use_pdl{false}; + + /// Parses the problem + Status parse( + library::GroupedGemmDescription const& operation_desc, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem); + + int64_t m(int group_idx) const { return problem_sizes[group_idx].m(); }; + int64_t n(int group_idx) const { return problem_sizes[group_idx].n(); }; + int64_t k(int group_idx) const { return problem_sizes[group_idx].k(); }; + + /// Total number of bytes loaded + int64_t bytes(library::GroupedGemmDescription const& operation_desc) const; + + /// Total number of flops computed + int64_t flops(library::GroupedGemmDescription const& operation_desc) const; + + /// Initializes a performance result + void initialize_result( + PerformanceResult& result, + library::GroupedGemmDescription const& operation_desc, + ProblemSpace const& problem_space); + }; + + struct BlockScalingWorkspace { + // host vector (per L2 workspace) of device vectors (per group) of device pointers + std::vector SFA_ptr_array_device; + std::vector SFB_ptr_array_device; + std::vector SFC_ptr_array_device; + std::vector SFD_ptr_array_device; + + // host vector (per group) of device tensors + // (where each batch of device allocation is for a L2 workspace) + std::vector SFA_ptr_array_host; + std::vector SFB_ptr_array_host; + std::vector SFC_ptr_array_host; + std::vector SFD_ptr_array_host; + std::vector SFD_reference_ptr_array_host; + + // matrix wide constant, not per-batch or per-group + DeviceAllocation* norm_constant; + }; + + // workspace contains the allocated blocks, arguments just contain the raw + // pointers + struct GroupedGemmWorkspace { + + // host vector (per L2 workspace) of device vectors (per group) of device pointers + std::vector A_ptr_array_device; + std::vector B_ptr_array_device; + std::vector C_ptr_array_device; + std::vector D_ptr_array_device; + std::vector reference_ptr_array_host; + + // host vector (per group) of device tensors + // (where each batch of device allocation is for a L2 workspace) + std::vector A_ptr_array_host; + std::vector B_ptr_array_host; + std::vector C_ptr_array_host; + std::vector D_ptr_array_host; + + /// Number of copies of the problem workspace which are visited sequentially during + /// profiling to avoid camping in the last level cache. + /// *NOT* the number of groups in the grouped GEMM (we use `num_groups` in the profiler) + int problem_count{1}; + + DeviceAllocation* problem_sizes_array_device{nullptr}; + DeviceAllocation* problem_sizes_3x_array_device{nullptr}; + DeviceAllocation* lda_array_device{nullptr}; + DeviceAllocation* ldb_array_device{nullptr}; + DeviceAllocation* ldc_array_device{nullptr}; + DeviceAllocation* ldd_array_device{nullptr}; + + std::optional block_scales; + + library::GemmGroupedConfiguration configuration; + library::GroupedGemmBlockScaledArguments arguments; + + std::vector host_workspace; + DeviceAllocation device_workspace; + + cudaStream_t stream; + }; + +private: + void init_arguments(Options const& options) { + auto& arguments = gemm_workspace_.arguments; + // these get updated in each profiler run to ensure L2 cycling + arguments.ptr_A = gemm_workspace_.A_ptr_array_device[0]->data(); + arguments.ptr_B = gemm_workspace_.B_ptr_array_device[0]->data(); + arguments.ptr_C = gemm_workspace_.C_ptr_array_device[0]->data(); + arguments.ptr_D = gemm_workspace_.D_ptr_array_device[0]->data(); + + arguments.alpha = problem_.alpha.data(); + arguments.beta = problem_.beta.data(); + arguments.pointer_mode = library::ScalarPointerMode::kHost; + arguments.lda = static_cast(gemm_workspace_.lda_array_device->data()); + arguments.ldb = static_cast(gemm_workspace_.ldb_array_device->data()); + arguments.ldc = static_cast(gemm_workspace_.ldc_array_device->data()); + arguments.ldd = static_cast(gemm_workspace_.ldc_array_device->data()); + arguments.problem_sizes = + static_cast(gemm_workspace_.problem_sizes_array_device->data()); + arguments.problem_sizes_3x = static_cast*>( + gemm_workspace_.problem_sizes_3x_array_device->data()); + gemm_workspace_.arguments.problem_sizes_3x_host = problem_.problem_sizes_3x.data(); + gemm_workspace_.arguments.problem_count = problem_.problem_sizes.size(); + gemm_workspace_.arguments.cluster_shape = {int(problem_.cluster_m), int(problem_.cluster_n), int(problem_.cluster_k)}; + gemm_workspace_.arguments.cluster_shape_fallback = {int(problem_.cluster_m_fallback), int(problem_.cluster_n_fallback), int(problem_.cluster_k_fallback)}; + + /* Query device SM count to pass onto the kernel as an argument, where needed */ + arguments.sm_count = options.device.get_sm_count(0); + if (is_block_scaled) { + auto& block_scaled_ws = gemm_workspace_.block_scales.value(); + arguments.SFA = block_scaled_ws.SFA_ptr_array_device[0]->data(); + arguments.SFB = block_scaled_ws.SFB_ptr_array_device[0]->data(); + arguments.SFD = block_scaled_ws.SFD_ptr_array_device[0]->data(); + arguments.norm_constant = block_scaled_ws.norm_constant->data(); + } + else if (is_blockwise) { + auto& block_scaled_ws = gemm_workspace_.block_scales.value(); + arguments.SFA = block_scaled_ws.SFA_ptr_array_device[0]->data(); + arguments.SFB = block_scaled_ws.SFB_ptr_array_device[0]->data(); + } + } + +protected: + /// GEMM problem obtained from problem space + GroupedGemmProblem problem_; + + /// Device memory allocations + GroupedGemmWorkspace gemm_workspace_; + + bool is_block_scaled{false}; + bool is_blockwise{false}; + +public: + GroupedGemmOperationProfiler(Options const& options); + + virtual ~GroupedGemmOperationProfiler(); + + GroupedGemmProblem const& problem() const { return problem_; } + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream& out) const; + + /// Prints examples + virtual void print_examples(std::ostream& out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const& options, + PerformanceReport& report, + DeviceContext& device_context, + library::Operation const* operation, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const& options, + PerformanceReport& report, + DeviceContext& device_context, + library::Operation const* operation, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const& options, + PerformanceReport& report, + DeviceContext& device_context, + library::Operation const* operation, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem); + + /// Measures performance results + virtual bool profile( + Options const& options, + PerformanceReport& report, + DeviceContext& device_context, + library::Operation const* operation, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem); + +protected: + /// Initializes the performance result + void initialize_result_( + PerformanceResult& result, + Options const& options, + library::GroupedGemmDescription const& operation_desc, + ProblemSpace const& problem_space); + + /// Update workspace configuration according to flexible user setups + void update_workspace_( + GroupedGemmWorkspace &gemm_workspace, + std::array const &preferred_cluster, + std::array const &fallback_cluster, + cutlass::library::RasterOrder const &raster_order, + int swizzle_size, + bool is_dynamic_cluster_enabled); + + /// Update performance result configuration for exploration parameters + void update_workspace_and_result_( + GroupedGemmWorkspace &gemm_workspace, + PerformanceResult &result, + ProblemSpace const &problem_space, + cutlass::library::RasterOrder const &raster_order, + std::array const &preferred_cluster, + std::array const &fallback_cluster, + int swizzle_size, + bool is_dynamic_cluster_enabled); + + /// Verifies CUTLASS against host and device references + bool verify_with_reference_( + Options const& options, + PerformanceReport& report, + DeviceContext& device_context, + library::Operation const* operation, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem, + cutlass::library::NumericTypeID element_A, + cutlass::library::NumericTypeID element_B); + + /// Method to profile a CUTLASS Operation + Status profile_cutlass_( + PerformanceResult& result, + Options const& options, + library::Operation const* operation, + void* arguments, + void* host_workspace, + void* device_workspace) override; + + /// Method to profile a CUTLASS Operation for the best configuration for a fixed shape + bool profile_cutlass_for_fixed_shape_( + Options const& options, + library::Operation const* operation, + ProblemSpace const& problem_space); + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/operation_profiler.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..446ef2c16739b28aaf038ca62bad6e3cdf667813 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/operation_profiler.h @@ -0,0 +1,287 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines a math function +*/ + +#pragma once + +#include +#include +#include +#include + +// CUTLASS includes +#include "cutlass/trace.h" + +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "performance_result.h" +#include "performance_report.h" +#include "problem_space.h" +#include "debug.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class OperationProfiler { +public: + + +protected: + // + // Data members + // + + /// Top-level operation kind + library::OperationKind kind_; + + /// Human readable description + std::string description_; + + /// Arguments parsed from command line + ArgumentDescriptionVector arguments_; + + /// List of providers used to verify and compare each result + ProviderVector verification_providers_; + + /// Model performance result initialized by the operation profiler with workload statistics + /// and reasonable default state. + PerformanceResult model_result_; + + /// Performance result vector constructed by profiling the operation + PerformanceResultVector results_; + +public: + + // + // Methods + // + + /// Ctor + OperationProfiler(); + + OperationProfiler( + Options const &options, + library::OperationKind kind, + ArgumentDescriptionVector const &arguments = ArgumentDescriptionVector(), + ProviderVector const & verification_providers = ProviderVector()); + + /// Destructor + virtual ~OperationProfiler(); + + /// Obtains the operation kind + library::OperationKind kind() const { return kind_; } + + /// Gets the schema description + std::string const &description() const; + + /// Returns a reference to the arguments + ArgumentDescriptionVector const &arguments() const { return arguments_; } + +public: + + // + // Basic overrides + // + + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const =0; + + /// Entry point to profile all operations in the manifest + virtual int profile_all( + Options const &options, + library::Manifest const &manifest, + DeviceContext &device_context); + +public: + + // + // Operation-specific phases of verification and profiling + // + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem) = 0; + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem) = 0; + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem) = 0; + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem) = 0; + +public: + + // + // Static helpers + // + + /// Sleep for a given duration in ms + static void sleep(int sleep_duration); + + /// Returns true if the current operation description satisfies the problem space + static bool satisfies( + library::OperationDescription const &op_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Compares tensors for equality + static Disposition compare_tensors( + Options const &options, + DeviceAllocation &experimental, + DeviceAllocation &reference, + int64_t count = 0); + + static void save_workspace( + DeviceContext &device_context, + Options const &options, + library::OperationDescription const &desc, + library::Provider provider, + library::Provider verification_provider = library::Provider::kInvalid); + + /// Helper to set a performance result member + static void set_argument( + PerformanceResult &result, + char const *name, + ProblemSpace const &problem_space, + std::string const &value); + + /// Helper to set a performance result member + static void set_argument( + PerformanceResult &result, + char const *name, + ProblemSpace const &problem_space, + int64_t value); + +protected: + + /// Sets operation description + static void initialize_result_( + PerformanceResult &result, + library::OperationDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Method to profile an initialized CUTLASS operation + virtual Status profile_cutlass_( + PerformanceResult &result, + Options const &options, + library::Operation const *operation, + void *arguments, + void *host_workspace, + void *device_workspace); + + /// Profiles the GPU kernel launched in `func` running simultaneously on all + /// requested devices. + Status profile_kernel_w_cuda_graphs_( + PerformanceResult& result, + Options const& options, + std::function const& func, + std::vector const& streams); + + Status profile_kernel_( + PerformanceResult& result, + Options const& options, + std::function const& func, + std::vector const& streams); + + /// Profiles the GPU kernel launched in `func` on the `stream` + Status profile_kernel_( + PerformanceResult& result, + Options const& options, + std::function const& func, + cudaStream_t stream = nullptr); + + /// Profiles the GPU kernel launched in `func` on the `stream` + Status profile_kernel_no_cuda_graphs_( + PerformanceResult& result, + Options const& options, + std::function const& func, + cudaStream_t stream = nullptr); + +private: + /// finds string matches filter_string in operation_name + bool find_string_matches_( + std::string const &filter_string, + std::string const &operation_name); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Vector of owning operation profilers +using OperationProfilerVector = std::vector>; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/options.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/options.h new file mode 100644 index 0000000000000000000000000000000000000000..1a957b36eea35f7c0a5366645c3a62298ca56dea --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/options.h @@ -0,0 +1,384 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Command line options for performance test program +*/ + +#pragma once + +#include +#include +#include + +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/library/library.h" + +#include "enumerated_types.h" + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Global options +class Options { +public: + + /// Cublas and cuDNN options + struct Library { + + // + // Data members + // + + /// Algorithm mode + AlgorithmMode algorithm_mode; + + /// Algorithm enumerants + std::vector algorithms; + + // + // Methods + // + + explicit Library(CommandLine const &cmdline); + + void print_usage(std::ostream &out) const; + void print_options(std::ostream &out, int indent = 0) const; + }; + + /// Options related to the selected device + struct Device { + + /// Device ID + std::vector devices; + + /// Number of total devices + /// This is not set by the user, it is set by automatically + int num_devices; + + /// CUDA Device properties + std::vector properties; + + /// Total memory allocation on each device + size_t maximum_capacity; + + private: + /// SM Count + /// Limits the number of SMs to use on each device + int sm_count; + + // + // Methods + // + public: + explicit Device(CommandLine const &cmdline); + + void print_usage(std::ostream &out) const; + void print_options(std::ostream &out, int indent = 0) const; + void print_device_info(std::ostream &out) const; + + /// Returns the device ID from a device index + int device_id(size_t device_index) const; + + /// Returns the sm_count if set, otherwise returns the number of SMs on the device + int get_sm_count(int device_index) const; + + /// Returns the compute capability of the listed devices (e.g. 70, 75, 80, etc.) + int compute_capability(int device_index) const; + }; + + /// Options related to initializing input tensors + struct Initialization { + + /// If true, data is initialized randomly. If false, no initialization is performed after + /// allocating tensors. + bool enabled; + + /// If true, data distribution is set by the user and is not allowed to change + /// If false, data distribution is allowed to change based on element_type (library::NumericTypeID) + bool fix_data_distribution; + + /// Data distribution for input tensors + Distribution data_distribution; + + /// Source of random tensor elements + library::Provider provider; + + /// Random number generator seed. + int seed; + + // + // Methods + // + + explicit Initialization(CommandLine const &cmdline); + + void print_usage(std::ostream &out) const; + void print_options(std::ostream &out, int indent = 0) const; + + /// Helper to parse a Distribution object from the command line parser + static void get_distribution( + cutlass::CommandLine const &args, + std::string const &arg, + cutlass::Distribution &dist); + }; + + /// Options related to verification of the result + struct Verification { + + // + // Data members + // + + /// If true, kernels are verified before they are profiled + bool enabled; + + /// If true, causes profiler to return an error code if no reference check is run. + /// Only valid when verification is enabled. + bool required; + + /// Relative error threshold - zero to require bit-level consistency + double epsilon; + + /// Values smaller than this are assumed to be zero + double nonzero_floor; + + /// List of providers used to verify each result + ProviderVector providers; + + /// Indicates when to save the workspace + SaveWorkspace save_workspace; + + // + // Methods + // + + explicit Verification(CommandLine const &cmdline); + + void print_usage(std::ostream &out) const; + void print_options(std::ostream &out, int indent = 0) const; + + /// Returns true if a provider is enabled + bool provider_enabled(library::Provider provider) const; + + /// Returns the index of a provider if its enabled + size_t index(library::Provider provider) const; + }; + + /// Options related to profiling + struct Profiling { + + /// Number of workspaces to rotate through to avoid cache-resident working sets + int workspace_count{0}; + + /// Number of iterations to warmup each kernel prior to profiling + int warmup_iterations{10}; + + /// Number of iterations to profile each kernel - if 0, kernels are launched up to the profiling duration + /// This will always override profiling-duration and min-iterations. + int iterations{100}; + + /// Time to spend profiling each kernel (ms) + int duration{10}; + + /// Minimum number of iterations to profile + int min_iterations{10}; + + /// If true, profiling with cuda graph enabled. + bool use_cuda_graphs{false}; + + /// If enabled, the CUTLASS profiler searches for the best-performing kernel + /// within the subset of kernels matching a kernel filter regex. The best + /// performance is determined by screening over a set of predefined M/N/K + /// sizes and performance-related parameters, including cluster shapes, + /// swizzle sizes, and rasterization orders. + /// For now, it only supports legacy GEMM and blockscaled GEMM. + bool enable_kernel_performance_search{false}; + + /// If enabled, the CUTLASS profiler searches for the best-performing kernel + /// for a given M/N/K problem size by evaluating various performance-related + /// parameters such as cluster shapes, swizzle sizes, and rasterization orders. + /// For now, it only supports legacy GEMM and blockscaled GEMM. + bool enable_best_kernel_for_fixed_shape{false}; + + /// Number of ms to sleep between profiling periods (ms) + int sleep_duration{50}; + + /// If true, profiling is actually conducted. + bool enabled{true}; + + /// If true, profiling returns an error code if no kernels are found to match the filters. + bool error_on_no_match{false}; + + /// If true, profiling returns an error code if no kernel are profiled + // Sometimes the kernel matches but failed to profile (e.g. can_implement() error) + bool error_if_nothing_is_profiled{false}; + + /// List of providers of each functionality to be profiled + ProviderVector providers; + + // + // Methods + // + + explicit Profiling(CommandLine const &cmdline); + + void print_usage(std::ostream &out) const; + void print_options(std::ostream &out, int indent = 0) const; + + /// Returns true if a provider is enabled + bool provider_enabled(library::Provider provider) const; + + /// Returns the index of a provider if its enabled + size_t index(library::Provider provider) const; + }; + + /// Options related to reporting + struct Report { + + /// If true, result is appended to possibly existing file + bool append; + + /// Path to a file containing results + std::string output_path; + + /// Path to a file containing junit xml results + std::string junit_output_path; + + /// Sequence of tags to attach to each result + std::vector> pivot_tags; + + /// If true, reports status of all kernels including those that were + /// not run for the given arguments + bool report_not_run; + + /// Prints human-readable text to stdout. If false, nothing is written to stdout + bool verbose; + + /// Sort results by flops-per-byte + bool sort_flops_per_byte; + + /// Sort results by flops-per-second + bool sort_flops_per_sec; + + /// Prints the name of the kernel being profiled before running the kernel. + /// This is useful for determining which kernel is causing a run of the profiler to hang + bool print_kernel_before_running; + + // + // Methods + // + + explicit Report(CommandLine const &cmdline); + + void print_usage(std::ostream &out) const; + void print_options(std::ostream &out, int indent = 0) const; + }; + + /// Options related to printing usage and version information + struct About { + + /// If true, usage is printed and the program ends. + bool help; + + /// Prints version string + bool version; + + /// Print information about devices + bool device_info; + + // + // Methods + // + + explicit About(CommandLine const &cmdline); + + void print_usage(std::ostream &out) const; + void print_options(std::ostream &out, int indent = 0) const; + + static void print_version(std::ostream &out); + }; + +public: + + // + // Data members + // + + /// Top-level execution mode + ExecutionMode execution_mode; + + /// Name of math function to profile + library::OperationKind operation_kind; + + /// Vector of operation name substrings + std::vector operation_names; + + /// Map of problems to run for each operation + /// [operation_name] -> vector of problems, each problem specified as a vector of [argument name] -> [argument value] + std::unordered_map> operation_problems; + + /// Vector of operation name substrings + std::vector excluded_operation_names; + + + // + // Detailed configuration options + // + + /// Configuration + CommandLine cmdline; + Device device; + Initialization initialization; + Library library; + Verification verification; + Profiling profiling; + Report report; + About about; + +public: + + explicit Options(CommandLine const &cmdline); + + void print_usage(std::ostream &out) const; + void print_options(std::ostream &out) const; + + static std::string indent_str(int indent); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/performance_report.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/performance_report.h new file mode 100644 index 0000000000000000000000000000000000000000..07102c99bc0f38a071e1ab828aab30678a3e2d44 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/performance_report.h @@ -0,0 +1,128 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Class performing output during profiling +*/ + +#pragma once + +#include +#include + +// CUTLASS Profiler includes +#include "options.h" +#include "enumerated_types.h" +#include "performance_result.h" + +// CUTLASS Library includes +#include "cutlass/library/library.h" + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +class PerformanceReport { +private: + + /// Reference to options + Options const &options_; + + /// Operation kind + library::OperationKind op_kind_; + + /// Operation file name containing performance report of op_kind + std::string op_file_name_; + + /// Output file containing results + std::ofstream output_file_; + + /// Operation file name containing junit performance report of op_kind + std::string op_junit_file_name_; + + /// Output file containing junit results + std::ofstream junit_output_file_; + + /// Flag indicating the performance report is valid + bool good_; + + /// Vector of argument names + std::vector argument_names_; + + /// Counter uniquely identifying problem within the report + size_t problem_index_; + + /// Collection of all results + PerformanceResultVector concatenated_results_; + +public: + + PerformanceReport(Options const &options, std::vector const &argument_names, library::OperationKind const &op_kind); + ~PerformanceReport(); + + bool good() const { return good_; } + + void next_problem(); + void append_result(PerformanceResult result); + void sort_flops_per_byte(PerformanceResultVector &results); + void sort_flops_per_sec(PerformanceResultVector &results); + void append_results(PerformanceResultVector const &results); + +public: + + /// Prints the CSV header + std::ostream & print_csv_header_(std::ostream &out); + + /// Prints the CSV + std::ostream & print_result_csv_(std::ostream &out, PerformanceResult const &result); + + /// @defgroup jUnit Result Generation + /// Functions related to generation of the jUnit results + /// @{ + + std::ostream & print_junit_header_(std::ostream &out); + std::ostream & print_junit_result_(std::ostream &out, PerformanceResult const &result); + std::ostream & print_junit_footer_(std::ostream &out); + + /// @} + + /// Prints the result in human readable form + std::ostream & print_result_pretty_( + std::ostream &out, + PerformanceResult const &result, + bool use_shell_coloring = true); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/performance_result.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/performance_result.h new file mode 100644 index 0000000000000000000000000000000000000000..986ac89bc86a267ce8fb181a986f28f3f0936566 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/performance_result.h @@ -0,0 +1,137 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines a math function +*/ + +#pragma once + +#include + +#include "cutlass/cutlass.h" + +// CUTLASS Profiler includes +#include "enumerated_types.h" + +// CUTLASS Library includes +#include "cutlass/library/library.h" + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Performance result object +struct PerformanceResult { + + /// Index of problem + size_t problem_index; + + /// library::Provider + library::Provider provider; + + /// Operation kind + library::OperationKind op_kind; + + /// CUTLASS status result from kernels (success or failure) + // Status does information on verification + Status status; + + /// Outcome of verification (worst case verification result) + Disposition disposition; + + /// Outcome of verification (all verification results) + DispositionMap verification_map; + + /// Operation name + std::string operation_name; + + /// Stringified vector of argument values + std::vector > arguments; + + /// Number of bytes read or written + int64_t bytes; + + /// Number of DL flops performed by the math function + int64_t flops; + + /// Average runtime in ms + double runtime; + + /// Average runtime in ms per device + std::vector runtime_vector; + + // + // Members + // + + /// Ctor + PerformanceResult(): + problem_index(0), + op_kind(library::OperationKind::kInvalid), + provider(library::Provider::kInvalid), + disposition(Disposition::kNotRun), + status(Status::kInvalid), + bytes(0), + flops(0), + runtime(0) + { } + + // Copy constructor for deep copy + PerformanceResult(const PerformanceResult& other) = default; + + // Explicitly define copy assignment operator + PerformanceResult& operator=(const PerformanceResult& other) = default; + + /// Returns true if the runtime is valid + bool good() const { + return runtime > 0; + } + + /// Math throughput in units of GFLOP/s + double gflops_per_sec() const { + return double(flops) / runtime / 1.0e6; + } + + /// memory bandwidth in units of GiB/s + double gbytes_per_sec() const { + return double(bytes) / double(1 << 30) / runtime * 1000.0; + } + +}; + +using PerformanceResultVector = std::vector; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/problem_space.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/problem_space.h new file mode 100644 index 0000000000000000000000000000000000000000..9bdbec657c10cff0dafebd2cb6cd52057f3695c9 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/problem_space.h @@ -0,0 +1,1039 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief + + "Any sufficiently complicated C or Fortran program contains an ad-hoc, informally-specified, + bug-ridden, slow implementation of half of Common Lisp." + + - Greenspun's Tenth Rule of Programming + + + cutlass::profiler::ProblemSpace defines a set of data structures which represent the Cartesian + product of sequences defined by integer ranges, lists of scalars, and sets of enumerated types. + + These permit a single invocation of the CUTLASS Profiler to iterate over a large set of problems, + verify and profile various operations when they are compatible with the command line, and + construct data tables of results that are convenient inputs to post processing in Excel or Pandas. + + By executing multiple problems per invocation, startup overheads may be amortized across many + kernel launches. +*/ + +#pragma once + +// Standard Library includes +#include +#include +#include +#include +#include + +// CUTLASS Utility includes +#include "cutlass/util/command_line.h" + +// CUTLASS Library includes +#include "cutlass/library/library.h" + +// Profiler includes +#include "enumerated_types.h" + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines the argument schema +struct ArgumentDescription { + + /// Type of argument + ArgumentTypeID type; + + /// Prioritized array of aliases used in command line parsing + std::vector aliases; + + /// Description of argument + std::string description; + + // + // Methods + // + + /// Default ctor + ArgumentDescription(): + type(ArgumentTypeID::kInvalid) { } + + /// Constructor with aliases + ArgumentDescription( + ArgumentTypeID type_, + std::vector const &aliases_, + std::string const &description_ + ): + type(type_), aliases(aliases_), description(description_) { } +}; + +/// Vector of arguments +using ArgumentDescriptionVector = std::vector; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Base class for kernel arguments +struct KernelArgument { + + // + // Type definitions + // + + /// Value base class + struct Value { + + KernelArgument const *argument; + bool not_null; + + // + // Methods + // + + Value( + KernelArgument const *argument_ = nullptr, + bool not_null_ = true + ): argument(argument_), not_null(not_null_) { } + + virtual ~Value() { } + + virtual std::ostream &print(std::ostream &out) const =0; + }; + + /// Abstract base class to iterate over values within arguments + struct ValueIterator { + + /// Indicates type of kernel argument + KernelArgument const *argument; + + /// If the iterator points to an argument that is null, it needs to be distinguished + /// from end. + bool null_argument; + + // + // Methods + // + + /// Constructs a value iterator - no methods are valid if argument_ == nullptr + ValueIterator( + KernelArgument const *argument_ = nullptr, + bool null_argument_ = false): + argument(argument_), null_argument(null_argument_) { + + if (!argument_->not_null()) { + null_argument = true; + } + } + + virtual ~ValueIterator() { } + + /// Advances to next point in range + virtual void operator++() = 0; + + /// Compares against another value iterator - must be of the same KernelArgument type + virtual bool operator==(ValueIterator const &it) const = 0; + + /// Returns a unique_ptr object pointing to a newly created value object + virtual std::unique_ptr at() const = 0; + + /// Gets the type of the iterator + ArgumentTypeID type() const { + return argument->description->type; + } + + /// Helper to compute inequality + bool operator!=(ValueIterator const &it) const { + return !(*this == it); + } + + std::ostream &print(std::ostream &out) const; + }; + + // + // Data members + // + + /// Describes the argument + ArgumentDescription const *description; + + /// Parent node + KernelArgument *parent; + + /// Sequence in which the kernel argument is to be iterated over. + /// Smaller means faster changing. -1 is don't care + int ordinal; + + // + // Methods + // + + /// Default ctor + KernelArgument( + ArgumentDescription const *description_ = nullptr, + KernelArgument *parent_ = nullptr, + int ordinal_ = -1 + ): description(description_), parent(parent_), ordinal(ordinal_) { } + + virtual ~KernelArgument(); + + /// Returns true if the kernel argument iself is empty + virtual bool not_null() const =0; + + /// Returns a string name for debugging + std::string qualified_name() const { + if (description) { + if (description->aliases.empty()) { + return ""; + } + return description->aliases.front(); + } + return ""; + } + + virtual std::unique_ptr begin() const =0; + virtual std::unique_ptr end() const =0; +}; + +using KernelArgumentVector = std::vector>; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a scalar argument type as a string that is lexically cast to the appropriate kernel +/// type. +struct ScalarArgument : public KernelArgument { + + // + // Type definitions + // + + /// Value type + struct ScalarValue : public KernelArgument::Value { + + std::string value; + + // + // Methods + // + + ScalarValue( + std::string const &value_ = "", + ScalarArgument const *argument = nullptr, + bool not_null_ = true + ); + + virtual std::ostream &print(std::ostream &out) const; + }; + + using ValueCollection = std::vector; + + /// Abstract base class to iterate over values within arguments + struct ScalarValueIterator : public KernelArgument::ValueIterator { + + // + // Data members + // + + ValueCollection::const_iterator value_it; + + // + // Methods + // + + explicit ScalarValueIterator(ScalarArgument const *argument = nullptr); + + virtual void operator++(); + virtual bool operator==(ValueIterator const &it) const; + + /// Gets the value pointed to + virtual std::unique_ptr at() const; + }; + + // + // Data members + // + + /// Set of possible values + ValueCollection values; + + // + // Methods + // + + /// Default ctor + explicit ScalarArgument( + ArgumentDescription const *description + ): + KernelArgument(description) { } + + virtual bool not_null() const { + return !values.empty(); + } + + virtual std::unique_ptr begin() const; + virtual std::unique_ptr end() const; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Closed range supporting additive increment +struct Range { + + // + // Type definitions + // + + enum class Mode { + kSequence, + kRandom, + kRandomLog2, + kInvalid + }; + + struct Iterator { + + int64_t value; + int64_t increment; + Range const *range; + + // + // Methods + // + + Iterator( + int64_t value_ = 0, + int64_t increment_ = 1, + Range const *range_ = nullptr + ): + value(value_), increment(increment_), range(range_) { } + + Iterator & operator++() { + value += increment; + return *this; + } + + Iterator operator++(int) { + Iterator self(*this); + ++(*this); + return self; + } + + bool operator==(Iterator const &it) const { + return value == it.value; + } + + bool operator!=(Iterator const &it) const { + return !(*this == it); + } + + static int64_t round(int64_t value, int64_t divisible) { + int64_t rem = (value % divisible); + + // Round either up or down + if (rem > divisible / 2) { + value += (divisible - rem); + } + else { + value -= rem; + } + + return value; + } + + int64_t at() const { + if (!range) { + return value; + } + + switch (range->mode) { + case Mode::kSequence: return value; + + case Mode::kRandom: { + double rnd = double(range->minimum) + + double(std::rand()) / double(RAND_MAX) * (double(range->maximum) - double(range->minimum)); + + int64_t value = int64_t(rnd); + + return round(value, range->divisible); + } + break; + + case Mode::kRandomLog2: { + double lg2_minimum = std::log(double(range->minimum)) / std::log(2.0); + double lg2_maximum = std::log(double(range->maximum)) / std::log(2.0); + double rnd = lg2_minimum + double(std::rand()) / double(RAND_MAX) * (lg2_maximum - lg2_minimum); + + int64_t value = int64_t(std::pow(2.0, rnd)); + + return round(value, range->divisible); + } + break; + default: break; + } + return value; + } + + int64_t operator*() const { + return at(); + } + }; + + // + // Data members + // + + int64_t first; ///< first element in range + int64_t last; ///< last element in range + int64_t increment; ///< additive increment between values + + Mode mode; ///< mode selection enables alternative values + int64_t minimum; ///< minimum value to return + int64_t maximum; ///< maximum value to return + int64_t divisible; ///< rounds value down to an integer multiple of this value + + // + // Methods + // + + /// Default constructor - range acts as a scalar + Range(int64_t first_ = 0): first(first_), last(first_), increment(1), mode(Mode::kSequence), minimum(0), maximum(0), divisible(1) { } + + /// Range acts as a range + Range( + int64_t first_, + int64_t last_, + int64_t increment_ = 1, + Mode mode_ = Mode::kSequence, + int64_t minimum_ = 0, + int64_t maximum_ = 0, + int64_t divisible_ = 1 + ): first(first_), last(last_), increment(increment_), mode(mode_), minimum(minimum_), maximum(maximum_), divisible(divisible_) { + + // Helpers to avoid constructing invalid ranges + if (increment > 0) { + if (last < first) { + std::swap(last, first); + } + } + else if (increment < 0) { + if (first < last) { + std::swap(last, first); + } + } + else if (last != first) { + last = first; + increment = 1; + } + } + + /// Helper to construct a sequence range + static Range Sequence(int64_t first_, int64_t last_, int64_t increment_ = 1) { + return Range(first_, last_, increment_, Mode::kSequence); + } + + /// Helper to construct a range that is a random distribution + static Range Random(int64_t minimum_, int64_t maximum_, int64_t count_, int64_t divisible_ = 1) { + return Range(1, count_, 1, Mode::kRandom, minimum_, maximum_, divisible_); + } + + /// Helper to construct a range that is a random distribution over a log scale + static Range RandomLog2(int64_t minimum_, int64_t maximum_, int64_t count_, int64_t divisible_ = 1) { + return Range(1, count_, 1, Mode::kRandomLog2, minimum_, maximum_, divisible_); + } + + /// Returns an iterator to the first element within the range + Iterator begin() const { + return Iterator(first, increment, this); + } + + /// Returns an iterator to the first element *after* the range + Iterator end() const { + return Iterator(first + ((last - first)/increment + 1) * increment, increment, this); + } +}; + +/// Integer-valued argument - represented as a list of integer-valued ranges +struct IntegerArgument : public KernelArgument { + + // + // Type definitions + // + + /// Value type + struct IntegerValue : public KernelArgument::Value { + + int64_t value; + + // + // Methods + // + + IntegerValue( + int64_t value_ = 0, + IntegerArgument const *argument_ = nullptr, + bool not_null_ = true + ); + + /// Pretty printer for debugging + virtual std::ostream &print(std::ostream &out) const; + }; + + /// Collection of ranges represent the IntegerArgument's state + using RangeCollection = std::vector; + + /// Abstract base class to iterate over values within arguments + struct IntegerValueIterator : public KernelArgument::ValueIterator { + + // + // Data members + // + + RangeCollection::const_iterator range_it; + Range::Iterator value_it; + + // + // Methods + // + + IntegerValueIterator(); + IntegerValueIterator(IntegerArgument const *argument); + + virtual void operator++(); + virtual bool operator==(ValueIterator const &it) const; + + /// Gets the value pointed to + virtual std::unique_ptr at() const; + }; + + // + // Data members + // + + /// Set of possible values + RangeCollection ranges; + + // + // Methods + // + + /// Default ctor + IntegerArgument( + ArgumentDescription const *description + ): + KernelArgument(description) { } + + virtual bool not_null() const { + bool _not_null = !ranges.empty(); + return _not_null; + } + + virtual std::unique_ptr begin() const; + virtual std::unique_ptr end() const; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure defining the data type of tensors +struct TensorArgument : public KernelArgument { + + // + // Type definitions + // + + struct TensorDescription { + + /// Data type of elements + library::NumericTypeID element; + + /// Layout definition + library::LayoutTypeID layout; + + /// Computed extent + std::vector extent; + + /// Enables directly specifying stride value used to size tensor + std::vector stride; + + // + // Methods + // + + TensorDescription( + library::NumericTypeID element_ = library::NumericTypeID::kUnknown, + library::LayoutTypeID layout_ = library::LayoutTypeID::kUnknown, + std::vector extent_ = std::vector(), + std::vector stride_ = std::vector() + ): + element(element_), layout(layout_), extent(extent_), stride(stride_) {} + }; + + using ValueCollection = std::vector; + + /// Value structure + struct TensorValue : public KernelArgument::Value { + + TensorDescription desc; + + // + // Methods + // + + TensorValue( + TensorDescription const &desc_ = TensorDescription(), + TensorArgument const *argument_ = nullptr, + bool not_null_ = true + ); + + /// Pretty printer for debugging + virtual std::ostream &print(std::ostream &out) const; + }; + + /// Abstract base class to iterate over values within arguments + struct TensorValueIterator : public KernelArgument::ValueIterator { + + // + // Data members + // + + ValueCollection::const_iterator value_it; + + // + // Methods + // + + explicit TensorValueIterator(TensorArgument const *argument_); + + virtual void operator++(); + virtual bool operator==(ValueIterator const &it) const; + + /// Gets the value pointed to + virtual std::unique_ptr at() const; + }; + + /// Set of possible values + ValueCollection values; + + // + // Methods + // + + /// Default ctor + explicit TensorArgument( + ArgumentDescription const *description + ): + KernelArgument(description) { } + + virtual bool not_null() const { + return !values.empty(); + } + + virtual std::unique_ptr begin() const; + virtual std::unique_ptr end() const; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Numeric data type +struct EnumeratedTypeArgument : public KernelArgument { + + // + // Type definitions + // + + struct EnumeratedTypeValue : public KernelArgument::Value { + + /// Data type of element + std::string element; + + // + // Methods + // + + EnumeratedTypeValue( + std::string const &element_ = std::string(), + EnumeratedTypeArgument const *argument_ = nullptr, + bool not_null_ = true + ); + + /// Pretty printer for debugging + virtual std::ostream &print(std::ostream &out) const; + }; + + using ValueCollection = std::vector; + + /// Abstract base class to iterate over values within arguments + struct EnumeratedTypeValueIterator : public KernelArgument::ValueIterator { + + // + // Data members + // + + ValueCollection::const_iterator value_it; + + // + // Methods + // + + explicit EnumeratedTypeValueIterator(EnumeratedTypeArgument const *argument_ = nullptr); + + virtual void operator++(); + virtual bool operator==(ValueIterator const &it) const; + + /// Gets the value pointed to + virtual std::unique_ptr at() const; + }; + + // + // Data members + // + + ValueCollection values; + + // + // Members + // + + /// Default ctor + explicit EnumeratedTypeArgument(ArgumentDescription const *description): + KernelArgument(description) {} + + virtual bool not_null() const { + return !values.empty(); + } + + virtual std::unique_ptr begin() const; + virtual std::unique_ptr end() const; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Object storing the space argument values +class ProblemSpace { +public: + + /// Tuple of arguments + using Problem = std::vector>; + + /// Type used to iterator over things + using IteratorVector = std::vector>; + + /// Iterates over points in the design space + class Iterator { + private: + + /// One iterator per argument + IteratorVector iterators; + + public: + + // + // Methods + // + + explicit Iterator(); + Iterator(ProblemSpace const &problem_space); + Iterator(Iterator &&it); + + // Rule of three + Iterator(Iterator const &) = delete; + Iterator &operator=(Iterator const &it) = delete; + ~Iterator() = default; + + /// Pre-increment - advances to next point in argument range + void operator++(); + + /// Gets the current argument value + Problem at() const; + + /// Moves iterator to end + void move_to_end(); + + /// Equality operator + bool operator==(Iterator const &it) const; + + /// Inequality operator + bool operator!=(Iterator const &it) const { + return !(*this == it); + } + + /// Helper to call at() method + Problem operator*() const { + return at(); + } + + /// Helper to print iterator state + std::ostream & print(std::ostream &out) const; + + private: + + /// Helper for recursively constructing iterators + void construct_(KernelArgument const *argument); + }; + +public: + + // + // Data members + // + + KernelArgumentVector arguments; + + /// Map of argument names to their position within the argument vector + std::unordered_map argument_index_map; + +public: + + // + // Methods + // + + /// Default ctor + ProblemSpace() = default; + + /// Constructs a problem space from a vector of arguments. This vector must outlive + /// the ProblemSpace object, which stores pointers to objects within the + /// ArgumentDescriptionVector. + ProblemSpace(ArgumentDescriptionVector const &schema, CommandLine const &cmdline); + + Iterator begin() const; // returns an iterator to the first point in the range + Iterator end() const; // returns an iterator to the first point after the range + + /// Returns the index of an argument by name + size_t argument_index(char const *name) const; + + /// Gets all argument names as an ordered vector + std::vector argument_names() const; + + /// Returns the number of dimensions of the problem space + size_t rank() const { return arguments.size(); } + +private: + + /// Helper for recursively cloning + void clone_( + KernelArgumentVector &kernel_args, + ArgumentDescription const *arg_desc); + + /// Parses command line argument + void parse_( + KernelArgument *arg, + CommandLine const &cmdline); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Lexically casts an argument to an int if it is defined. Returns true if not null. +bool arg_as_int(int &int_value, KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_int(int64_t &int_value, KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_int( + int &int_value, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_int( + int64_t &int_value, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +bool arg_as_bool(bool &bool_value, KernelArgument::Value const *value_ptr); + +bool arg_as_bool(bool &bool_value, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_NumericTypeID(library::NumericTypeID &numeric_type, KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_NumericTypeID( + library::NumericTypeID &numeric_type, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_LayoutTypeID(library::LayoutTypeID &layout_type, KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_LayoutTypeID( + library::LayoutTypeID &layout_type, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_OpcodeClassID(library::OpcodeClassID &opcode_class, KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_OpcodeClassID( + library::OpcodeClassID &opcode_class, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_SplitKModeID(library::SplitKMode &split_k_mode, KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_SplitKModeID( + library::SplitKMode &split_k_mode, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_ConvModeID(library::ConvModeID &conv_mode, KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_ConvModeID( + library::ConvModeID &conv_mode, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_IteratorAlgorithmID(library::IteratorAlgorithmID &iterator_algorithm, KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_IteratorAlgorithmID( + library::IteratorAlgorithmID &iterator_algorithm, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_RuntimeDatatype(library::RuntimeDatatype &runtime_datatype, KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_RuntimeDatatype( + library::RuntimeDatatype &runtime_datatype, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_RasterOrder(library::RasterOrder &raster_order, KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_RasterOrder( + library::RasterOrder &raster_order, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_ProviderID(library::Provider &provider, KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_ProviderID( + library::Provider &provider, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +/// Lexically casts an argument to a given type stored in a byte array. Returns true if not null. +bool arg_as_scalar( + std::vector &bytes, + library::NumericTypeID numeric_type, + KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to a given type stored in a byte array. Returns true if not null. +bool arg_as_scalar( + std::vector &bytes, + library::NumericTypeID numeric_type, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +bool arg_as_string( + std::string& arg, + char const* name, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem); + +/// Returns true if a tensor description satisfies a `tensor` value +bool tensor_description_satisfies( + library::TensorDescription const &tensor_desc, + TensorArgument::TensorValue const *value_ptr); + +/// Returns true if a tensor description satisfies a `tensor` value +bool tensor_description_satisfies( + library::TensorDescription const &tensor_desc, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + +/// Returns true if a conv kind satisfies the value +bool conv_kind_satisfies( + library::ConvKind const &conv_kind, + EnumeratedTypeArgument::EnumeratedTypeValue const *value_ptr); + +/// Returns true if a conv kind satisfies the value +bool conv_kind_satisfies( + library::ConvKind const &conv_kind, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +/// Returns true if a iterator algorithm satisfies the value +bool iterator_algorithm_satisfies( + library::IteratorAlgorithmID const &iterator_algorithm, + EnumeratedTypeArgument::EnumeratedTypeValue const *value_ptr); + +/// Returns true if a iterator algorithm satisfies the value +bool iterator_algorithm_satisfies( + library::IteratorAlgorithmID const &iterator_algorithm, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/rank_2k_operation_profiler.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/rank_2k_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..ba47a6832077984c334a5467257a151735b088b3 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/rank_2k_operation_profiler.h @@ -0,0 +1,229 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines a math function + + +*/ + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/blas3.h" +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +/// Abstract base class for each math function +class Rank2KOperationProfiler : public OperationProfiler { +public: + + /// Problem structure obtained from problem space + struct RankKProblem { + int64_t n; + int64_t k; + int64_t lda; + int64_t ldb; + int64_t ldc; + FillMode fill_mode; + BlasMode blas_mode; + std::vector alpha; + std::vector beta; + int64_t split_k_slices; + int64_t batch_count; + + // + // Methods + // + + RankKProblem(): + n(16), k(16), lda(0), ldc(0), + fill_mode(FillMode::kInvalid), blas_mode(BlasMode::kInvalid), + split_k_slices(1), batch_count(1) { } + + /// Parses the problem + Status parse( + library::RankKDescription const &operation_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Total number of bytes loaded + int64_t bytes(library::RankKDescription const &operation_desc) const; + + /// Total number of flops computed + int64_t flops(library::RankKDescription const &operation_desc) const; + + /// Initializes a performance result + void initialize_result( + PerformanceResult &result, + library::RankKDescription const &operation_desc, + ProblemSpace const &problem_space); + }; + + /// Workspace used + struct RankKWorkspace { + + DeviceAllocation *A; + DeviceAllocation *B; + DeviceAllocation *C; + DeviceAllocation *Computed; + DeviceAllocation *Reference; + + library::RankKConfiguration configuration; + library::RankKArguments arguments; + + /// Buffer used for the operation's host workspace + std::vector host_workspace; + + /// Buffer used for the operations' device workspace + DeviceAllocation device_workspace; + + // + // Methods + // + + RankKWorkspace(): + A(nullptr), B(nullptr), C(nullptr), Computed(nullptr), Reference(nullptr) { } + }; + +protected: + + // + // Data members + // + + /// GEMM problem obtained from problem space + RankKProblem problem_; + + /// Device memory allocations + RankKWorkspace rank_k_workspace_; + + +public: + // + // Methods + // + + /// Ctor + Rank2KOperationProfiler(Options const &options); + + /// Destructor + virtual ~Rank2KOperationProfiler(); + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::RankKDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Verifies CUTLASS against references + bool verify_with_cublas_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/rank_k_operation_profiler.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/rank_k_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..fff190a7570cd5811c6e5de6284bf96e40c404b7 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/rank_k_operation_profiler.h @@ -0,0 +1,227 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines a math function + + +*/ + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/blas3.h" +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +/// Abstract base class for each math function +class RankKOperationProfiler : public OperationProfiler { +public: + + /// Problem structure obtained from problem space + struct RankKProblem { + int64_t n; + int64_t k; + int64_t lda; + int64_t ldc; + FillMode fill_mode; + BlasMode blas_mode; + std::vector alpha; + std::vector beta; + int64_t split_k_slices; + int64_t batch_count; + + // + // Methods + // + + RankKProblem(): + n(16), k(16), lda(0), ldc(0), + fill_mode(FillMode::kInvalid), blas_mode(BlasMode::kInvalid), + split_k_slices(1), batch_count(1) { } + + /// Parses the problem + Status parse( + library::RankKDescription const &operation_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Total number of bytes loaded + int64_t bytes(library::RankKDescription const &operation_desc) const; + + /// Total number of flops computed + int64_t flops(library::RankKDescription const &operation_desc) const; + + /// Initializes a performance result + void initialize_result( + PerformanceResult &result, + library::RankKDescription const &operation_desc, + ProblemSpace const &problem_space); + }; + + /// Workspace used + struct RankKWorkspace { + + DeviceAllocation *A; + DeviceAllocation *C; + DeviceAllocation *Computed; + DeviceAllocation *Reference; + + library::RankKConfiguration configuration; + library::RankKArguments arguments; + + /// Buffer used for the operation's host workspace + std::vector host_workspace; + + /// Buffer used for the operations' device workspace + DeviceAllocation device_workspace; + + // + // Methods + // + + RankKWorkspace(): + A(nullptr), C(nullptr), Computed(nullptr), Reference(nullptr) { } + }; + +protected: + + // + // Data members + // + + /// GEMM problem obtained from problem space + RankKProblem problem_; + + /// Device memory allocations + RankKWorkspace rank_k_workspace_; + + +public: + // + // Methods + // + + /// Ctor + RankKOperationProfiler(Options const &options); + + /// Destructor + virtual ~RankKOperationProfiler(); + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::RankKDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Verifies CUTLASS against references + bool verify_with_cublas_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/reduction_operation_profiler.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/reduction_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..0c81ef4637175a6de1f44cedddf319436aaff24d --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/reduction_operation_profiler.h @@ -0,0 +1,173 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines profiling functionality for reduction operation + +*/ + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" +#if CUTLASS_ENABLE_CUDNN +#include "cudnn_helpers.h" +#endif //#if CUTLASS_ENABLE_CUDNN +#include "debug.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class ReductionOperationProfiler : public OperationProfiler { +public: + + + /// Workspace used + struct ReductionWorkspace { + + /// Conv device allocations + DeviceAllocation *Workspace; + DeviceAllocation *Source; + DeviceAllocation *Destination; + DeviceAllocation *Reference; + + /// Library configuration and arguments + library::ReductionConfiguration configuration; + library::ReductionArguments arguments; + + /// Buffer used for the cutlass operations' host workspace + std::vector host_workspace; + + /// Buffer used for the cutlass operations' device workspace + DeviceAllocation device_workspace; + + // + // Methods + // + + ReductionWorkspace(): + Workspace(nullptr), Source(nullptr), Destination(nullptr), Reference(nullptr) { } + }; + +protected: + + // + // Data members + // + + /// Reduction problem obtained from problem space + MatrixCoord problem_; + + /// Device memory allocations + ReductionWorkspace conv_workspace_; + + +public: + // + // Methods + // + + /// Ctor + ReductionOperationProfiler(Options const &options); + + /// Destructor + virtual ~ReductionOperationProfiler(); + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/sparse_gemm_operation_profiler.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/sparse_gemm_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..60204d8c9d458ab12020a6492de23174739aa584 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/sparse_gemm_operation_profiler.h @@ -0,0 +1,214 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief + +*/ + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" +#include "gemm_operation_profiler.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class SparseGemmOperationProfiler : public OperationProfiler { +public: + + /// Problem structure obtained from problem space + struct SparseGemmProblem { + int64_t m; + int64_t n; + int64_t k; + int64_t lda; + int64_t ldb; + int64_t ldc; + int64_t lde; + std::vector alpha; + std::vector beta; + int64_t split_k_slices; + int64_t batch_count; + static int const sparse = 2; + // every 128b ElementA uses one elementE + int elements_per_128b; + + // + // Methods + // + + SparseGemmProblem(): + m(16), n(16), k(16), lda(0), ldb(0), ldc(0), lde(0), split_k_slices(1), batch_count(1) { } + + /// Parses the problem + Status parse( + library::SparseGemmDescription const &operation_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes a performance result + void initialize_result( + PerformanceResult &result, + library::SparseGemmDescription const &operation_desc, + ProblemSpace const &problem_space); + }; + + /// Workspace used + struct SparseGemmWorkspace { + + DeviceAllocation *A; + DeviceAllocation *B; + DeviceAllocation *C; + DeviceAllocation *E; + DeviceAllocation *Computed; + DeviceAllocation *Reference; + + library::SparseGemmConfiguration configuration; + library::SparseGemmArguments arguments; + + /// Buffer used for the operation's host workspace + std::vector host_workspace; + + /// Buffer used for the operations' device workspace + DeviceAllocation device_workspace; + + // + // Methods + // + + SparseGemmWorkspace(): + A(nullptr), B(nullptr), C(nullptr), E(nullptr), Computed(nullptr), Reference(nullptr) { } + }; + +protected: + + // + // Data members + // + + // GEMM problem + SparseGemmProblem problem_; + + /// Device memory allocations + SparseGemmWorkspace gemm_workspace_; + + +public: + // + // Methods + // + + /// Ctor + SparseGemmOperationProfiler(Options const &options); + + /// Destructor + virtual ~SparseGemmOperationProfiler(); + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::SparseGemmDescription const &operation_desc, + ProblemSpace const &problem_space); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/symm_operation_profiler.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/symm_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..94ded5e803bf914e5ae8c4ebb867cfe42ef829bc --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/symm_operation_profiler.h @@ -0,0 +1,230 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines a math function + + +*/ + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/blas3.h" +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +/// Abstract base class for each math function +class SymmOperationProfiler : public OperationProfiler { +public: + + /// Problem structure obtained from problem space + struct SymmProblem { + int64_t m; + int64_t n; + int64_t lda; + int64_t ldb; + int64_t ldc; + SideMode side_mode; + FillMode fill_mode; + BlasMode blas_mode; + std::vector alpha; + std::vector beta; + int64_t split_k_slices; + int64_t batch_count; + + // + // Methods + // + + SymmProblem(): + m(16), n(16), lda(0), ldb(0), ldc(0), + side_mode(SideMode::kInvalid), fill_mode(FillMode::kInvalid), blas_mode(BlasMode::kInvalid), + split_k_slices(1), batch_count(1) { } + + /// Parses the problem + Status parse( + library::SymmDescription const &operation_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Total number of bytes loaded + int64_t bytes(library::SymmDescription const &operation_desc) const; + + /// Total number of flops computed + int64_t flops(library::SymmDescription const &operation_desc) const; + + /// Initializes a performance result + void initialize_result( + PerformanceResult &result, + library::SymmDescription const &operation_desc, + ProblemSpace const &problem_space); + }; + + /// Workspace used + struct SymmWorkspace { + + DeviceAllocation *A; + DeviceAllocation *B; + DeviceAllocation *C; + DeviceAllocation *Computed; + DeviceAllocation *Reference; + + library::SymmConfiguration configuration; + library::SymmArguments arguments; + + /// Buffer used for the operation's host workspace + std::vector host_workspace; + + /// Buffer used for the operations' device workspace + DeviceAllocation device_workspace; + + // + // Methods + // + + SymmWorkspace(): + A(nullptr), B(nullptr), C(nullptr), Computed(nullptr), Reference(nullptr) { } + }; + +protected: + + // + // Data members + // + + /// GEMM problem obtained from problem space + SymmProblem problem_; + + /// Device memory allocations + SymmWorkspace symm_workspace_; + + +public: + // + // Methods + // + + /// Ctor + SymmOperationProfiler(Options const &options); + + /// Destructor + virtual ~SymmOperationProfiler(); + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::SymmDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Verifies CUTLASS against references + bool verify_with_cublas_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/trmm_operation_profiler.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/trmm_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..9f21dafa0ecc869840fdba0a9c4414a89bbf4a7d --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/trmm_operation_profiler.h @@ -0,0 +1,222 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines a math function + + +*/ + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/blas3.h" +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class TrmmOperationProfiler : public OperationProfiler { +public: + + /// Problem structure obtained from problem space + struct TrmmProblem { + int64_t m; + int64_t n; + int64_t lda; + int64_t ldb; + int64_t ldd; + SideMode side_mode; + FillMode fill_mode; + DiagType diag_type; + std::vector alpha; + std::vector beta; + int64_t split_k_slices; + int64_t batch_count; + + // + // Methods + // + + TrmmProblem(): + m(16), n(16), lda(0), ldb(0), ldd(0), split_k_slices(1), batch_count(1) { } + + /// Parses the problem + Status parse( + library::TrmmDescription const &operation_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes a performance result + void initialize_result( + PerformanceResult &result, + library::TrmmDescription const &operation_desc, + ProblemSpace const &problem_space); + }; + + /// Workspace used + struct TrmmWorkspace { + + DeviceAllocation *A; + DeviceAllocation *B; + DeviceAllocation *D; + DeviceAllocation *Computed; + DeviceAllocation *Reference; + + library::TrmmConfiguration configuration; + library::TrmmArguments arguments; + + /// Buffer used for the operation's host workspace + std::vector host_workspace; + + /// Buffer used for the operations' device workspace + DeviceAllocation device_workspace; + + // + // Methods + // + + TrmmWorkspace(): + A(nullptr), B(nullptr), D(nullptr), Computed(nullptr), Reference(nullptr) { } + }; + +protected: + + // + // Data members + // + + /// GEMM problem obtained from problem space + TrmmProblem problem_; + + /// Device memory allocations + TrmmWorkspace trmm_workspace_; + + +public: + // + // Methods + // + + /// Ctor + TrmmOperationProfiler(Options const &options); + + /// Destructor + virtual ~TrmmOperationProfiler(); + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::TrmmDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Verifies CUTLASS against references + bool verify_with_cublas_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/GPU_Clock.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/GPU_Clock.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c2727c989e645eca8e67a5d8d50391ced803cffa --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/GPU_Clock.hpp @@ -0,0 +1,67 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include + +struct GPU_Clock +{ + GPU_Clock() { + cudaEventCreate(&start_); + cudaEventCreate(&stop_); + cudaEventRecord(start_); + } + + ~GPU_Clock() { + cudaEventDestroy(start_); + cudaEventDestroy(stop_); + } + + void start() { + cudaEventRecord(start_); + } + + float milliseconds() { + cudaEventRecord(stop_); + cudaEventSynchronize(stop_); + float time; + cudaEventElapsedTime(&time, start_, stop_); + return time; + } + + float seconds() { + return milliseconds() * float(1e-3); + } + + private: + cudaEvent_t start_, stop_; +}; diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/command_line.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/command_line.h new file mode 100644 index 0000000000000000000000000000000000000000..c95bd1cbeb56cc566394b155ea7ac24f07c28162 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/command_line.h @@ -0,0 +1,324 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +/** + * \file + * Utility for parsing command line arguments + */ + +#include +#include +#include +#include +#include +#include + +#include + +#include "cutlass/cutlass.h" + +namespace cutlass { + +/****************************************************************************** + * command_line + ******************************************************************************/ + +/** + * Utility for parsing command line arguments + */ +struct CommandLine { + std::vector keys; + std::vector values; + std::vector args; + + /** + * Constructor + */ + CommandLine(int argc, const char** argv) { + using namespace std; + + for (int i = 1; i < argc; i++) { + string arg = argv[i]; + + if ((arg[0] != '-') || (arg[1] != '-')) { + args.push_back(arg); + continue; + } + + string::size_type pos; + string key, val; + if ((pos = arg.find('=')) == string::npos) { + key = string(arg, 2, arg.length() - 2); + val = ""; + } else { + key = string(arg, 2, pos - 2); + val = string(arg, pos + 1, arg.length() - 1); + } + + keys.push_back(key); + values.push_back(val); + } + } + + /** + * Constructor to represent a command line from a map of [argument] -> [value] + */ + CommandLine(std::unordered_map& arg_map) { + for (const auto& [key, value] : arg_map) { + keys.push_back(key); + values.push_back(value); + } + } + + /** + * Checks whether a flag "--" is present in the commandline + */ + bool check_cmd_line_flag(const char* arg_name) const { + using namespace std; + + for (int i = 0; i < int(keys.size()); ++i) { + if (keys[i] == string(arg_name)) return true; + } + return false; + } + + /** + * Returns number of naked (non-flag and non-key-value) commandline parameters + */ + size_t num_naked_args() const { + return args.size(); + } + + /** + * Print naked (non-flag and non-key-value) commandline parameters + */ + void print_naked_args(std::ostream &out) const { + for (auto arg : args) { + out << " " << arg <<"\n"; + } + } + + /** + * Returns the commandline parameter for a given index (not including flags) + */ + template + void get_cmd_line_argument(size_t index, value_t& val) const { + using namespace std; + if (index < args.size()) { + istringstream str_stream(args[index]); + str_stream >> val; + } + } + + /** + * Obtains the boolean value specified for a given commandline parameter --= + */ + void get_cmd_line_argument(const char* arg_name, bool& val, bool _default) const { + val = _default; + if (check_cmd_line_flag(arg_name)) { + std::string value; + get_cmd_line_argument(arg_name, value); + + val = !(value == "0" || value == "false"); + } + } + + /** + * Obtains the value specified for a given commandline parameter --= + */ + template + void get_cmd_line_argument(const char* arg_name, + value_t& val) const { + + get_cmd_line_argument(arg_name, val, val); + } + + /** + * Obtains the value specified for a given commandline parameter --= + */ + template + void get_cmd_line_argument(const char* arg_name, + value_t& val, + value_t const& _default) const { + using namespace std; + + val = _default; + + for (int i = 0; i < int(keys.size()); ++i) { + if (keys[i] == string(arg_name)) { + istringstream str_stream(values[i]); + str_stream >> val; + } + } + } + + /** + * Returns the values specified for a given commandline parameter --=,* + */ + template + void get_cmd_line_arguments(const char* arg_name, + std::vector& vals, + char sep = ',') const { + using namespace std; + + if (check_cmd_line_flag(arg_name)) { + // Clear any default values + vals.clear(); + + // Recover from multi-value string + for (size_t i = 0; i < keys.size(); ++i) { + if (keys[i] == string(arg_name)) { + string val_string(values[i]); + separate_string(val_string, vals, sep); + } + } + } + } + + /** + * Returns the values specified for a given commandline parameter + * --=,* + */ + void get_cmd_line_argument_pairs(const char* arg_name, + std::vector >& tokens, + char delim = ',', + char sep = ':') const { + if (check_cmd_line_flag(arg_name)) { + std::string value; + get_cmd_line_argument(arg_name, value); + + tokenize(tokens, value, delim, sep); + } + } + + /** + * Returns a list of ranges specified for a given commandline parameter + * --=,* + */ + void get_cmd_line_argument_ranges(const char* arg_name, + std::vector >& vals, + char delim = ',', + char sep = ':') const { + std::vector ranges; + get_cmd_line_arguments(arg_name, ranges, delim); + + for (std::vector::const_iterator range = ranges.begin(); + range != ranges.end(); ++range) { + + std::vector range_vals; + separate_string(*range, range_vals, sep); + vals.push_back(range_vals); + } + } + + /** + * The number of pairs parsed + */ + int parsed_argc() const { return (int)keys.size(); } + + //------------------------------------------------------------------------- + // Utility functions + //------------------------------------------------------------------------- + + /// Tokenizes a comma-delimited list of string pairs delimited by ':' + static void tokenize(std::vector >& tokens, + std::string const& str, + char delim = ',', + char sep = ':') { + // Home-built to avoid Boost dependency + size_t s_idx = 0; + size_t d_idx = std::string::npos; + while (s_idx < str.size()) { + d_idx = str.find_first_of(delim, s_idx); + + size_t end_idx = (d_idx != std::string::npos ? d_idx : str.size()); + size_t sep_idx = str.find_first_of(sep, s_idx); + size_t offset = 1; + if (sep_idx == std::string::npos || sep_idx >= end_idx) { + sep_idx = end_idx; + offset = 0; + } + + std::pair item( + str.substr(s_idx, sep_idx - s_idx), + str.substr(sep_idx + offset, end_idx - sep_idx - offset)); + + tokens.push_back(item); + s_idx = end_idx + 1; + } + } + + /// Tokenizes a comma-delimited list of string pairs delimited by ':' + static void tokenize(std::vector& tokens, + std::string const& str, + char delim = ',', + char sep = ':') { + typedef std::vector > TokenVector; + typedef TokenVector::const_iterator token_iterator; + + std::vector > token_pairs; + tokenize(token_pairs, str, delim, sep); + for (token_iterator tok = token_pairs.begin(); tok != token_pairs.end(); ++tok) { + tokens.push_back(tok->first); + } + } + + template + static void separate_string(std::string const& str, + std::vector& vals, + char sep = ',') { + std::istringstream str_stream(str); + std::string::size_type old_pos = 0; + std::string::size_type new_pos = 0; + + // Iterate -delimited values + value_t val; + while ((new_pos = str.find(sep, old_pos)) != std::string::npos) { + if (new_pos != old_pos) { + str_stream.width(new_pos - old_pos); + str_stream >> val; + vals.push_back(val); + } + + // skip over delimiter + str_stream.ignore(1); + old_pos = new_pos + 1; + } + + // Read last value + str_stream >> val; + vals.push_back(val); + } +}; + +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/cublas_wrappers.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/cublas_wrappers.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8ace1e0a232ea7cccbb2089ec8432783c49410dd --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/cublas_wrappers.hpp @@ -0,0 +1,528 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include + +//-- BLAM_DEBUG_OUT --------------------------------------------------------- +#ifdef BLAM_DEBUG +# include +# ifndef BLAM_DEBUG_OUT +# define BLAM_DEBUG_OUT(msg) std::cerr << "BLAM: " << msg << std::endl +# define BLAM_DEBUG_OUT_2(msg) std::cerr << msg << std::endl +# endif // BLAM_DEBUG_OUT +#else +# ifndef BLAM_DEBUG_OUT +# define BLAM_DEBUG_OUT(msg) +# define BLAM_DEBUG_OUT_2(msg) +# endif // BLAM_DEBUG_OUT +#endif // BLAM_DEBUG + +// User could potentially define ComplexFloat/ComplexDouble instead of std:: +#ifndef BLAM_COMPLEX_TYPES +#define BLAM_COMPLEX_TYPES 1 +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(complex) + +namespace blam { +template +using Complex = cuda::std::complex; +using ComplexFloat = cuda::std::complex; +using ComplexDouble = cuda::std::complex; +} +#endif // BLAM_COMPLEX_TYPES + +// User could potentially define Half instead of cute:: +#ifndef BLAM_HALF_TYPE +#define BLAM_HALF_TYPE 1 +#include +namespace blam { +using Half = cute::half_t; +} +#endif // BLAM_HALF_TYPE + +namespace blam +{ +namespace cublas +{ + +inline const char* +cublas_get_error(cublasStatus_t status) +{ + switch (status) { + case CUBLAS_STATUS_SUCCESS: + return "CUBLAS_STATUS_SUCCESS"; + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED -- The cuBLAS library was not initialized."; + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED -- Resource allocation failed inside the cuBLAS library."; + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE -- An unsupported value or parameter was passed to the function."; + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH -- The function requires a feature absent from the device architecture."; + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR -- An access to GPU memory space failed."; + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED -- The GPU program failed to execute."; + case CUBLAS_STATUS_INTERNAL_ERROR: + return "CUBLAS_STATUS_INTERNAL_ERROR -- An internal cuBLAS operation failed."; + case CUBLAS_STATUS_NOT_SUPPORTED: + return "CUBLAS_STATUS_NOT_SUPPORTED -- The functionality requested is not supported."; + case CUBLAS_STATUS_LICENSE_ERROR: + return "CUBLAS_STATUS_LICENSE_ERROR -- An error was detected when checking the current licensing."; + default: + return "CUBLAS_ERROR -- "; + } +} + +inline bool +cublas_is_error(cublasStatus_t status) +{ + return status != CUBLAS_STATUS_SUCCESS; +} + + +// hgemm +inline cublasStatus_t +gemm(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const Half* alpha, + const Half* A, int ldA, + const Half* B, int ldB, + const Half* beta, + Half* C, int ldC) +{ + BLAM_DEBUG_OUT("cublasHgemm"); + + return cublasGemmEx(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(A), CUDA_R_16F, ldA, + reinterpret_cast(B), CUDA_R_16F, ldB, + reinterpret_cast(beta), + reinterpret_cast< __half*>(C), CUDA_R_16F, ldC, + CUDA_R_16F, CUBLAS_GEMM_DEFAULT_TENSOR_OP); +} + +// mixed hf gemm +inline cublasStatus_t +gemm(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const float* alpha, + const Half* A, int ldA, + const Half* B, int ldB, + const float* beta, + float* C, int ldC) +{ + BLAM_DEBUG_OUT("cublasGemmEx mixed half-float"); + + return cublasGemmEx(handle, transA, transB, + m, n, k, + alpha, + reinterpret_cast(A), CUDA_R_16F, ldA, + reinterpret_cast(B), CUDA_R_16F, ldB, + beta, + C, CUDA_R_32F, ldC, + CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP); +} + +// igemm +inline cublasStatus_t +gemm(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const int32_t* alpha, + const int8_t* A, int ldA, + const int8_t* B, int ldB, + const int32_t* beta, + int32_t* C, int ldC) +{ + BLAM_DEBUG_OUT("cublasIgemm"); + + return cublasGemmEx(handle, transA, transB, + m, n, k, + alpha, + A, CUDA_R_8I, ldA, + B, CUDA_R_8I, ldB, + beta, + C, CUDA_R_32I, ldC, + CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP); +} + +// sgemm +inline cublasStatus_t +gemm(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const float* alpha, + const float* A, int ldA, + const float* B, int ldB, + const float* beta, + float* C, int ldC) +{ + BLAM_DEBUG_OUT("cublasSgemm"); + + return cublasSgemm(handle, transA, transB, + m, n, k, + alpha, + A, ldA, + B, ldB, + beta, + C, ldC); +} + +// dgemm +inline cublasStatus_t +gemm(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const double* alpha, + const double* A, int ldA, + const double* B, int ldB, + const double* beta, + double* C, int ldC) +{ + BLAM_DEBUG_OUT("cublasDgemm"); + + return cublasDgemm(handle, transA, transB, + m, n, k, + alpha, + A, ldA, + B, ldB, + beta, + C, ldC); +} + +// cgemm +inline cublasStatus_t +gemm(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const ComplexFloat* alpha, + const ComplexFloat* A, int ldA, + const ComplexFloat* B, int ldB, + const ComplexFloat* beta, + ComplexFloat* C, int ldC) +{ + BLAM_DEBUG_OUT("cublasCgemm"); + + return cublasCgemm(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(A), ldA, + reinterpret_cast(B), ldB, + reinterpret_cast(beta), + reinterpret_cast(C), ldC); +} + +// zgemm +inline cublasStatus_t +gemm(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const ComplexDouble* alpha, + const ComplexDouble* A, int ldA, + const ComplexDouble* B, int ldB, + const ComplexDouble* beta, + ComplexDouble* C, int ldC) +{ + BLAM_DEBUG_OUT("cublasZgemm"); + + return cublasZgemm(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(A), ldA, + reinterpret_cast(B), ldB, + reinterpret_cast(beta), + reinterpret_cast(C), ldC); +} + +// hgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const Half* alpha, + const Half* A, int ldA, int loA, + const Half* B, int ldB, int loB, + const Half* beta, + Half* C, int ldC, int loC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasHgemmStridedBatched"); + + return cublasHgemmStridedBatched(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(A), ldA, loA, + reinterpret_cast(B), ldB, loB, + reinterpret_cast(beta), + reinterpret_cast<__half*>(C), ldC, loC, + batch_size); +} + +// sgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const float* alpha, + const float* A, int ldA, int loA, + const float* B, int ldB, int loB, + const float* beta, + float* C, int ldC, int loC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasSgemmStridedBatched"); + + return cublasSgemmStridedBatched(handle, transA, transB, + m, n, k, + alpha, + A, ldA, loA, + B, ldB, loB, + beta, + C, ldC, loC, + batch_size); +} + +// dgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const double* alpha, + const double* A, int ldA, int loA, + const double* B, int ldB, int loB, + const double* beta, + double* C, int ldC, int loC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasDgemmStridedBatched"); + + return cublasDgemmStridedBatched(handle, transA, transB, + m, n, k, + alpha, + A, ldA, loA, + B, ldB, loB, + beta, + C, ldC, loC, + batch_size); +} + +// cgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const ComplexFloat* alpha, + const ComplexFloat* A, int ldA, int loA, + const ComplexFloat* B, int ldB, int loB, + const ComplexFloat* beta, + ComplexFloat* C, int ldC, int loC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasCgemmStridedBatched"); + + return cublasCgemmStridedBatched(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(A), ldA, loA, + reinterpret_cast(B), ldB, loB, + reinterpret_cast(beta), + reinterpret_cast(C), ldC, loC, + batch_size); +} + +// zgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const ComplexDouble* alpha, + const ComplexDouble* A, int ldA, int loA, + const ComplexDouble* B, int ldB, int loB, + const ComplexDouble* beta, + ComplexDouble* C, int ldC, int loC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasZgemmStridedBatched"); + + return cublasZgemmStridedBatched(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(A), ldA, loA, + reinterpret_cast(B), ldB, loB, + reinterpret_cast(beta), + reinterpret_cast(C), ldC, loC, + batch_size); +} + +// hgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const Half* alpha, + const Half* const A[], int ldA, + const Half* const B[], int ldB, + const Half* beta, + Half* const C[], int ldC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasHgemmBatched"); + + return cublasHgemmBatched(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(const_cast(A)), ldA, + // A, ldA, // cuBLAS 9.2 + reinterpret_cast(const_cast(B)), ldB, + // B, ldB, // cuBLAS 9.2 + reinterpret_cast(beta), + reinterpret_cast<__half**>(const_cast(C)), ldC, + // C, ldC, // cuBLAS 9.2 + batch_size); +} + +// sgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const float* alpha, + const float* const A[], int ldA, + const float* const B[], int ldB, + const float* beta, + float* const C[], int ldC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasSgemmBatched"); + + return cublasSgemmBatched(handle, transA, transB, + m, n, k, + alpha, + const_cast(A), ldA, + // A, ldA, // cuBLAS 9.2 + const_cast(B), ldB, + // B, ldB, // cuBLAS 9.2 + beta, + const_cast(C), ldC, + // C, ldC, // cuBLAS 9.2 + batch_size); +} + +// dgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const double* alpha, + const double* const A[], int ldA, + const double* const B[], int ldB, + const double* beta, + double* const C[], int ldC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasDgemmBatched"); + + return cublasDgemmBatched(handle, transA, transB, + m, n, k, + alpha, + const_cast(A), ldA, + // A, ldA, // cuBLAS 9.2 + const_cast(B), ldB, + // B, ldB, // cuBLAS 9.2 + beta, + const_cast(C), ldC, + // C, ldC, // cuBLAS 9.2 + batch_size); +} + +// cgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const ComplexFloat* alpha, + const ComplexFloat* const A[], int ldA, + const ComplexFloat* const B[], int ldB, + const ComplexFloat* beta, + ComplexFloat* const C[], int ldC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasCgemmBatched"); + + return cublasCgemmBatched(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + const_cast(reinterpret_cast(A)), ldA, + //reinterpret_cast(A), ldA, // cuBLAS 9.2 + const_cast(reinterpret_cast(B)), ldB, + //reinterpret_cast(B), ldB, // cuBLAS 9.2 + reinterpret_cast(beta), + const_cast(reinterpret_cast(C)), ldC, + //reinterpret_cast(C), ldC, // cuBLAS 9.2 + batch_size); +} + +// zgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const ComplexDouble* alpha, + const ComplexDouble* const A[], int ldA, + const ComplexDouble* const B[], int ldB, + const ComplexDouble* beta, + ComplexDouble* const C[], int ldC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasZgemmBatched"); + + return cublasZgemmBatched(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + const_cast(reinterpret_cast(A)), ldA, + //reinterpret_cast(A), ldA, // cuBLAS 9.2 + const_cast(reinterpret_cast(B)), ldB, + //reinterpret_cast(B), ldB, // cuBLAS 9.2 + reinterpret_cast(beta), + const_cast(reinterpret_cast(C)), ldC, + //reinterpret_cast(C), ldC, // cuBLAS 9.2 + batch_size); +} + +} // end namespace cublas +} // end namespace blam diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/debug.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/debug.h new file mode 100644 index 0000000000000000000000000000000000000000..88481a82e0e08f06b54c07c946d28160d41f9f07 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/debug.h @@ -0,0 +1,143 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Contains code for debugging cutlass code +*/ + +#pragma once + +#include "device_dump.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/****************************************************************************** + * Debug and logging macros + ******************************************************************************/ + +/** + * Formats and prints the given message to stdout + */ +#if !defined(CUDA_LOG) +#if !defined(__CUDA_ARCH__) +#define CUDA_LOG(format, ...) printf(format, __VA_ARGS__) +#else +#define CUDA_LOG(format, ...) \ + printf("[block (%d,%d,%d), thread (%d,%d,%d)]: " format, \ + blockIdx.x, \ + blockIdx.y, \ + blockIdx.z, \ + threadIdx.x, \ + threadIdx.y, \ + threadIdx.z, \ + __VA_ARGS__); +#endif +#endif + +/** + * Formats and prints the given message to stdout only if DEBUG is defined + */ +#if !defined(CUDA_LOG_DEBUG) +#ifdef DEBUG +#define CUDA_LOG_DEBUG(format, ...) CUDA_LOG(format, __VA_ARGS__) +#else +#define CUDA_LOG_DEBUG(format, ...) +#endif +#endif + +/** + * \brief The corresponding error message is printed to \p stderr (or \p stdout in device code) + * along with the supplied source context. + * + * \return The CUDA error. + */ +__host__ CUTLASS_DEVICE cudaError_t cuda_perror_impl(cudaError_t error, + const char* expression, + const char* filename, + int line) { + (void)filename; + (void)line; + if (error) { +#if !defined(__CUDA_ARCH__) + fprintf( + stderr, "CUDA error %d [%s, %d] in expression '%s': %s\n", error, filename, line, expression, cudaGetErrorString(error)); + fflush(stderr); +#else + printf("CUDA error %d [%s, %d] in expression '%s'\n", error, filename, line, expression); +#endif + } + return error; +} + +/** + * \brief Perror macro + */ +#ifndef CUDA_PERROR +#define CUDA_PERROR(e) cuda_perror_impl((cudaError_t)(e), #e, __FILE__, __LINE__) +#endif + +/** + * \brief Perror macro with exit + */ +#ifndef CUDA_PERROR_EXIT +#define CUDA_PERROR_EXIT(e) \ + do { if (cuda_perror_impl((cudaError_t)(e), #e, __FILE__, __LINE__)) { \ + exit(1); \ + } } while (0) +#endif + +/** + * \brief Perror macro only if DEBUG is defined + */ +#ifndef CUDA_PERROR_DEBUG +#ifdef DEBUG +#define CUDA_PERROR_DEBUG(e) CUDA_PERROR(e) +#else +#define CUDA_PERROR_DEBUG(e) (e) +#endif +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// A small helper class to dump a type at compile time +// Usage:: DumpType::Class +template +struct DebugType {}; + +template +void DebugTypeFunc(T const& t) { + T::t; +} + +// A small helper class to dump a compile time constant at compile time +// Usage: DumpValue::kConstant +template +struct DebugValue {}; diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_dump.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_dump.h new file mode 100644 index 0000000000000000000000000000000000000000..a73a8cfe79dd22c2d298fcb3be8cf25d5e3f5734 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_dump.h @@ -0,0 +1,187 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include "cutlass/cutlass.h" + +/** + * \file + * \brief C++ interface to dump fragments and shared memory contents for + * debugging. + */ + +namespace cutlass { +namespace debug { + +/****************************************************************************** + * Dump the fragments + ******************************************************************************/ + +/// The first N threads dump the first M elements from their fragments with a +/// stride of S elements. If N is not specified, dump the data of all the +/// threads. If M is not specified, dump all the elements of the fragment. +template +CUTLASS_DEVICE void dump_fragment(Fragment const& frag, int N = 0, int M = 0, + int S = 1) { + int total_threads = blockDim.x * blockDim.y * blockDim.z; + int block_id = + blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z; + int thread_id = (threadIdx.z * (blockDim.x * blockDim.y)) + + (threadIdx.y * blockDim.x) + threadIdx.x; + + if (N < 0 || N > total_threads) { + if (thread_id == 0 && block_id == 0) + printf("Thread number N = %d should between [1, %d].\n", N, + total_threads); + + __syncthreads(); + + return; + } + + int total_elements = int(frag.size()); + + if (M < 0 || M > total_elements) { + if (thread_id == 0 && block_id == 0) + printf("Element number M = %d should between [1, %d].\n", M, + total_elements); + + __syncthreads(); + + return; + } + + if (N == 0) N = total_threads; + + if (M == 0) M = total_elements; + + if (S < 1 || S > M) { + if (thread_id == 0 && block_id == 0) + printf("Stride S = %d should between [1, %d].\n", S, M); + + __syncthreads(); + + return; + } + + if (thread_id == 0 && block_id == 0) + printf("\n*******************Dumping the fragments*******************\n\n"); + + CUTLASS_PRAGMA_NO_UNROLL + for (int tid = 0; tid < N; ++tid) { + if (tid == thread_id) { + printf("TB%d W%d T%d: ", block_id, tid / 32, tid & 31); + CUTLASS_PRAGMA_NO_UNROLL + for (int i = 0; i < M; i += S) { + printf("%.0f ", float(typename Fragment::value_type(frag[i]))); + } + printf("\n"); + } + + __syncthreads(); + } + + if (thread_id == 0 && block_id == 0) + printf("\n***********************************************************\n\n"); + + __syncthreads(); + + return; +} + +/****************************************************************************** + * Dump the shared memory + ******************************************************************************/ + +#define SHMEM_ROW_SIZE 128 + +/// Dump the shared memory contents. ptr is the begin address, size specifies +/// the number of elements that need to be dumped, and S specifies the stride. +template +CUTLASS_DEVICE void dump_shmem(Element const* ptr, size_t size, int S = 1) { + int block_id = + blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z; + int thread_id = (threadIdx.z * (blockDim.x * blockDim.y)) + + (threadIdx.y * blockDim.x) + threadIdx.x; + + if (ptr == nullptr) { + if (thread_id == 0 && block_id == 0) printf("ptr is null.\n"); + + __syncthreads(); + return; + } + + if (size < 1) { + if (thread_id == 0 && block_id == 0) + printf("Element size is less than 1\n"); + + __syncthreads(); + + return; + } + + int row_elements = SHMEM_ROW_SIZE / sizeof(Element); + + if (S < 1 || S > row_elements) { + if (thread_id == 0 && block_id == 0) + printf("Stride S = %d should between [1, %d].\n", S, row_elements); + + __syncthreads(); + + return; + } + + __syncthreads(); + + if (thread_id == 0) + printf("\n********Dumping the shared memory of TB %d*******\n\n", block_id); + + if (thread_id == 0) { + for (int i = 0; i < size; i += row_elements) { + for (int j = 0; j < row_elements; j += S) { + printf("%.0f ", float(ptr[i + j])); + } + + printf("\n"); + } + } + + if (thread_id == 0) + printf("\n***********************************************************\n\n"); + + __syncthreads(); + + return; +} +} // namespace debug +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_groupnorm.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_groupnorm.h new file mode 100644 index 0000000000000000000000000000000000000000..59457b2e8122f46e443844fe276b2c7fb35f3f56 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_groupnorm.h @@ -0,0 +1,402 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +/** + * \file + * \brief cuda kernels to do group norm on a device memory tensor with NHWC layout. The tensor will be divided into [N, H, W, G, C'] and then we do normalization on [H, W, C']. + */ + +#include "cutlass/cutlass.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/tensor_ref.h" +#include "device_utils.h" +#include + +namespace cutlass { + +/** \brief interface to do group norm on a device memory tensor with NHWC layout. + * \tparam T: data type + */ +template +void groupnorm(cutlass::Tensor4DCoord input_size, + const int num_groups, + const float eps, + TensorRef ref_output, + TensorRef ref_input, + TensorRef ref_gamma, + TensorRef ref_beta, + cudaStream_t stream); + +extern __shared__ char groupnorm_shm[]; + +// For small prod_dim1_to_last_dim/num_groups, to avoid multiple loads from global memory, +// we store the input in the shared memory. +// grid(num_groups, dim0) +// block(BLOCKSIZE) +// BLOCKSIZE * TVecs_PER_THREAD <= prod_dim1_to_last_dim/num_group +template +__global__ void groupnorm_twopass_store_locally(T* output, + const T* input, + const T* gamma, + const T* beta, + int num_groups, + int prod_dim1_to_last_dim, + int last_dim, + const float eps, + const int TVecs_PER_THREAD) +{ + const int bid = blockIdx.y; // index of batch + const int gid = blockIdx.x; // index of group + const int tid = threadIdx.x; // index of thread + const int bdimx = blockDim.x; + const int s_reduce_elements = prod_dim1_to_last_dim / num_groups; + const int v_reduce_elements = s_reduce_elements / T_PER_TVec; + const int s_group_stride = last_dim / num_groups; + const int v_group_stride = s_group_stride / T_PER_TVec; + const int offset_of_group = (bid * prod_dim1_to_last_dim + gid * s_group_stride) / T_PER_TVec; + const TVec* input_TVec_ptr = (const TVec*)(input) + offset_of_group; + TVec* output_TVec_ptr = (TVec*)(output) + offset_of_group; + T* local_val = ((T*)groupnorm_shm) + TVecs_PER_THREAD * T_PER_TVec * tid; + float local_sum[1] = {0.0f}; + +// load from global memory into shared memory +#pragma unroll + for (int i = 0; i < TVecs_PER_THREAD; i += 1) { + const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; + const int offset_in_group = + ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride)) + / T_PER_TVec; + if (current_load_start_idx < s_reduce_elements) { + TVec tmp_vec = input_TVec_ptr[offset_in_group]; + T* tmp_vec_ptr = (T*)(&tmp_vec); + const int local_val_offset = i * T_PER_TVec; +#pragma unroll + for (int j = 0; j < T_PER_TVec; j++) { + float tmp = static_cast(tmp_vec_ptr[j]); + local_sum[0] += tmp; + local_val[local_val_offset + j] = tmp_vec_ptr[j]; + } + } + } + __shared__ float s_mean, s_variance; + + // reduction for mean + if (bdimx <= 32) { + warpReduceSum(local_sum); + } + else { + blockReduceSum(local_sum); + } + if (tid == 0) { + s_mean = local_sum[0] / s_reduce_elements; + } + __syncthreads(); + + // reduction for std + local_sum[0] = 0.0f; +#pragma unroll + for (int i = 0; i < TVecs_PER_THREAD; i += 1) { + const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; + if (current_load_start_idx < s_reduce_elements) { + const int local_val_offset = i * T_PER_TVec; +#pragma unroll + for (int j = 0; j < T_PER_TVec; j++) { + float tmp = static_cast(local_val[local_val_offset + j]); + tmp -= s_mean; + local_sum[0] += tmp * tmp; + } + } + } + if (bdimx <= 32) { + warpReduceSum(local_sum); + } + else { + blockReduceSum(local_sum); + } + if (tid == 0) { + s_variance = rsqrtf(local_sum[0] / s_reduce_elements + eps); + } + __syncthreads(); + + // normalize + const int gamma_offset_of_group = gid * v_group_stride; + const TVec* gamma_TVec_ptr = (const TVec*)gamma + gamma_offset_of_group; + const TVec* beta_TVec_ptr = (const TVec*)beta + gamma_offset_of_group; +#pragma unroll + for (int i = 0; i < TVecs_PER_THREAD; i += 1) { + const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; + const int offset_in_group = + ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride)) + / T_PER_TVec; + const int gamma_offset_in_group = (current_load_start_idx % s_group_stride) / T_PER_TVec; + const int local_val_offset = i * T_PER_TVec; + if (current_load_start_idx < s_reduce_elements) { + TVec gamma_val = gamma_TVec_ptr[gamma_offset_in_group]; + TVec beta_val = beta_TVec_ptr[gamma_offset_in_group]; + T* gamma_val_ptr = (T*)(&gamma_val); + T* beta_val_ptr = (T*)(&beta_val); + TVec tmp_vec; + T* tmp_vec_ptr = (T*)(&tmp_vec); +#pragma unroll + for (int j = 0; j < T_PER_TVec; j++) { + float tmp = (static_cast(local_val[local_val_offset + j]) - s_mean) * s_variance + * static_cast(gamma_val_ptr[j]) + + static_cast(beta_val_ptr[j]); + if (sizeof(T) == sizeof(half)) { + tmp_vec_ptr[j] = T(__float2half_rn(tmp)); + } + else { + tmp_vec_ptr[j] = T(tmp); + } + } + output_TVec_ptr[offset_in_group] = tmp_vec; + } + } +} + +// For large prod_dim1_to_last_dim/num_groups, +// in which the data cannot be stored locally, +// we will load from global memory multiple times, +// grid(num_groups, dim0) +// block(BLOCKSIZE) +// BLOCKSIZE * TVecs_PER_THREAD <= prod_dim1_to_last_dim/num_group +template +__global__ void groupnorm_twopass_multiple_load(T* output, + const T* input, + const T* gamma, + const T* beta, + int num_groups, + int prod_dim1_to_last_dim, + int last_dim, + const float eps, + const int TVecs_PER_THREAD) +{ + const int bid = blockIdx.y; // index of batch + const int gid = blockIdx.x; // index of group + const int tid = threadIdx.x; // index of thread + const int bdimx = blockDim.x; + const int s_reduce_elements = prod_dim1_to_last_dim / num_groups; + const int v_reduce_elements = s_reduce_elements / T_PER_TVec; + const int s_group_stride = last_dim / num_groups; + const int v_group_stride = s_group_stride / T_PER_TVec; + const int offset_of_group = (bid * prod_dim1_to_last_dim + gid * s_group_stride) / T_PER_TVec; + const TVec* input_TVec_ptr = (const TVec*)(input) + offset_of_group; + TVec* output_TVec_ptr = (TVec*)(output) + offset_of_group; + float local_sum[1] = {0.0f}; + +#pragma unroll + for (int i = 0; i < TVecs_PER_THREAD; i += 1) { + const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; + if (current_load_start_idx < s_reduce_elements) { + const int offset_in_group = + ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride)) + / T_PER_TVec; + TVec tmp_vec = input_TVec_ptr[offset_in_group]; + T* tmp_vec_ptr = (T*)(&tmp_vec); +#pragma unroll + for (int j = 0; j < T_PER_TVec; j++) { + float tmp = static_cast(tmp_vec_ptr[j]); + local_sum[0] += tmp; + } + } + } + __shared__ float s_mean, s_variance; + + // reduction for mean + if (bdimx <= 32) { + warpReduceSum(local_sum); + } + else { + blockReduceSum(local_sum); + } + if (tid == 0) { + s_mean = local_sum[0] / s_reduce_elements; + } + __syncthreads(); + + // reduction for std + local_sum[0] = 0.0f; +#pragma unroll + for (int i = 0; i < TVecs_PER_THREAD; i += 1) { + const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; + if (current_load_start_idx < s_reduce_elements) { + const int offset_in_group = + ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride)) + / T_PER_TVec; + TVec tmp_vec = input_TVec_ptr[offset_in_group]; + T* tmp_vec_ptr = (T*)(&tmp_vec); +#pragma unroll + for (int j = 0; j < T_PER_TVec; j++) { + float tmp = static_cast(tmp_vec_ptr[j]); + tmp -= s_mean; + local_sum[0] += tmp * tmp; + } + } + } + if (bdimx <= 32) { + warpReduceSum(local_sum); + } + else { + blockReduceSum(local_sum); + } + if (tid == 0) { + s_variance = rsqrtf(local_sum[0] / s_reduce_elements + eps); + } + __syncthreads(); + + // normalize + const int gamma_offset_of_group = gid * v_group_stride; + const TVec* gamma_TVec_ptr = (const TVec*)gamma + gamma_offset_of_group; + const TVec* beta_TVec_ptr = (const TVec*)beta + gamma_offset_of_group; +#pragma unroll + for (int i = 0; i < TVecs_PER_THREAD; i += 1) { + const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; + if (current_load_start_idx < s_reduce_elements) { + const int offset_in_group = + ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride)) + / T_PER_TVec; + const int gamma_offset_in_group = (current_load_start_idx % s_group_stride) / T_PER_TVec; + TVec gamma_val = gamma_TVec_ptr[gamma_offset_in_group]; + TVec beta_val = beta_TVec_ptr[gamma_offset_in_group]; + T* gamma_val_ptr = (T*)(&gamma_val); + T* beta_val_ptr = (T*)(&beta_val); + TVec tmp_vec = input_TVec_ptr[offset_in_group]; + T* tmp_vec_ptr = (T*)(&tmp_vec); + TVec output_tmp_vec; + T* output_tmp_vec_ptr = (T*)(&output_tmp_vec); +#pragma unroll + for (int j = 0; j < T_PER_TVec; j++) { + float tmp = + (static_cast(tmp_vec_ptr[j]) - s_mean) * s_variance * static_cast(gamma_val_ptr[j]) + + static_cast(beta_val_ptr[j]); + if (sizeof(T) == sizeof(half)) { + output_tmp_vec_ptr[j] = T(__float2half_rn(tmp)); + } + else { + output_tmp_vec_ptr[j] = T(tmp); + } + } + output_TVec_ptr[offset_in_group] = output_tmp_vec; + } + } +} + +//ref_input & ref_output should be [N, H, W, C] +//ref_gamma & ref_beta should be [1, 1, 1, C] +template +void groupnorm(cutlass::Tensor4DCoord input_size, + const int num_groups, + const float eps, + TensorRef ref_output, + TensorRef ref_input, + TensorRef ref_gamma, + TensorRef ref_beta, + cudaStream_t stream){ + const int N = input_size.n(); + const int H = input_size.h(); + const int W = input_size.w(); + const int C = input_size.c(); + if (C % num_groups != 0){ + printf("[ERROR] C should be a multiple of num_groups.\n"); + } + T* output = ref_output.data(); + const T* input = ref_input.data(); + const T* gamma = ref_gamma.data(); + const T* beta = ref_beta.data(); + + const int dim0 = N; + const int last_dim = C; + const int prod_dim1_to_last_dim = H*W*C; + const int s_reduce_elements = prod_dim1_to_last_dim / num_groups; + const int s_group_stride = last_dim / num_groups; + dim3 grid(num_groups, dim0); + int threadblock_size = 32; + if (s_group_stride % 2 == 0) { + const int T_PER_TVec = 2; + while (threadblock_size < 1024) { + if (s_reduce_elements / T_PER_TVec / threadblock_size <= 8) + break; + threadblock_size *= 2; + } + dim3 block(threadblock_size); + const int TVec_PER_THREAD = (s_reduce_elements / T_PER_TVec + threadblock_size - 1) / threadblock_size; + const int shm_size = T_PER_TVec * TVec_PER_THREAD * threadblock_size * sizeof(T); + // for small s_reduce_elements, specific case for H=W=22, C=1280, num_groups=32; + // the size of grid & block may have better choice for different cases. + // ensure shared memory is smaller than 48KB + if (std::is_same::value){ + if (shm_size < 48 * 1024) { + groupnorm_twopass_store_locally<<>>( + output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); + } + else { + groupnorm_twopass_multiple_load<<>>( + output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); + } + } + else{ + if (shm_size < 48 * 1024) { + groupnorm_twopass_store_locally<<>>( + output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); + } + else { + groupnorm_twopass_multiple_load<<>>( + output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); + } + } + } + else { + const int T_PER_TVec = 1; + while (threadblock_size < 1024) { + if (s_reduce_elements / T_PER_TVec / threadblock_size <= 8) + break; + threadblock_size *= 2; + } + dim3 block(threadblock_size); + const int TVec_PER_THREAD = (s_reduce_elements / T_PER_TVec + threadblock_size - 1) / threadblock_size; + const int shm_size = T_PER_TVec * TVec_PER_THREAD * threadblock_size * sizeof(T); + if (shm_size < 48 * 1024) { + groupnorm_twopass_store_locally<<>>( + output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); + } + else { + groupnorm_twopass_multiple_load<<>>( + output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); + } + } + +} + +} //namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_layernorm.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_layernorm.h new file mode 100644 index 0000000000000000000000000000000000000000..0fcbf5cb0f4bf3152a708c6e3845e89fd214cfac --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_layernorm.h @@ -0,0 +1,644 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +/** + * \file + * \brief cuda kernels to do layernorm on a device memory tensor with RowMajor layout. + */ + +#include "cutlass/cutlass.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/tensor_ref.h" +#include "device_utils.h" +#include + +namespace cutlass { + +/** \brief interface to do layernorm on a device memory tensor with RowMajor layout. + * \tparam T: data type + */ +template +void layernorm(cutlass::MatrixCoord tensor_size, + TensorRef ref_output, + TensorRef ref_input, + TensorRef ref_gamma, + TensorRef ref_beta, + cudaStream_t stream); + +/** + * output [m, n] row-major + * input [m, n] row-major + * gamma [n] + * beta [n] + * grid(m) + * block(block_size) -- each block deals with n elements ; each thread deals with ITEM_PER_THREAD elements +*/ +template +__global__ void layernorm_twoPassAlgo_stored_locally_e1(T* output, + const T* input, + const T* gamma, + const T* beta, + const int m, + const int n) +{ + const int m_idx = blockIdx.x; + const int tid = threadIdx.x; + const int bdimx = blockDim.x; + __shared__ float s_mean, s_variance; + T local_val[ITEM_PER_THREAD]; + float local_sums[1] = {0.0f}; + int offset = m_idx * n; + input += offset; + output += offset; + + const T zero = T(0.0f); + #pragma unroll + for (int i = 0 ; i < ITEM_PER_THREAD ; i++){ + int index = tid + i*bdimx; + local_val[i] = index < n ? input[index] : zero; + local_sums[0] += static_cast(local_val[i]); + } + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_mean = local_sums[0] / n; + } + __syncthreads(); + + local_sums[0] = 0.0f; + #pragma unroll + for (int i = 0 ; i < ITEM_PER_THREAD ; i++){ + int index = tid + i*bdimx; + if (index < n){ + const float tmp = static_cast(local_val[i]) - s_mean; + local_sums[0] += tmp * tmp; + } + } + + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_variance = rsqrtf(local_sums[0] / n + 1e-5); + } + __syncthreads(); + + #pragma unroll + for (int i = 0 ; i < ITEM_PER_THREAD ; i++){ + int index = tid + i*bdimx; + if (index < n) { + const T gamma_val = gamma[index]; + const T beta_val = beta[index]; + output[index] = T((static_cast(local_val[i]) - s_mean) * s_variance * static_cast(gamma_val) + static_cast(beta_val)); + } + } +} + +/** + * output [m, n] row-major + * input [m, n] row-major + * gamma [n] + * beta [n] + * grid(m) + * block(block_size) -- each block deals with block_size*ITEM_PER_THREAD*2 elements; +*/ +template +__global__ void layernorm_twoPassAlgo_stored_locally_e2(T2* output, + const T2* input, + const T2* gamma, + const T2* beta, + const int m, + const int n) +{ + const int m_idx = blockIdx.x; + const int tid = threadIdx.x; + const int bdimx = blockDim.x; + __shared__ float s_mean, s_variance; + float local_sums[1] = {0.0f}; + T2 local_val[ITEM_PER_THREAD]; + const int n_2 = n / 2; + int offset = m_idx * n_2; + input += offset; + output += offset; + + const T2 zero = {T(0.0f), T(0.0f)}; + #pragma UNROLL + for (int i = 0; i < ITEM_PER_THREAD; i += 1) { + const int index = i*bdimx + tid; + local_val[i] = index < n_2 ? input[index] : zero; + local_sums[0] += static_cast(local_val[i].x) + static_cast(local_val[i].y); + } + + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_mean = local_sums[0] / n; + } + __syncthreads(); + + local_sums[0] = 0.0f; + #pragma UNROLL + for (int i = 0; i < ITEM_PER_THREAD; i += 1) { + const int index = i*bdimx + tid; + if (index < n_2){ + const float2 tmp = {static_cast(local_val[i].x) - s_mean, + static_cast(local_val[i].y) - s_mean}; + local_sums[0] += tmp.x * tmp.x + tmp.y * tmp.y; + } + } + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_variance = rsqrtf(local_sums[0] / n + 1e-5); + } + __syncthreads(); + + #pragma UNROLL + for (int i = 0; i < ITEM_PER_THREAD; i += 1) { + const int index = i*bdimx + tid; + if (index < n_2){ + const T2 gamma_val = gamma[index]; + const T2 beta_val = beta[index]; + T2 tmp; + tmp.x = T((static_cast(local_val[i].x) - s_mean)*s_variance*static_cast(gamma_val.x) + static_cast(beta_val.x)); + tmp.y = T((static_cast(local_val[i].y) - s_mean)*s_variance*static_cast(gamma_val.y) + static_cast(beta_val.y)); + output[index] = tmp; + } + } +} + +/** + * output [m, n] row-major + * input [m, n] row-major + * gamma [n] + * beta [n] + * grid(m) + * block(block_size) -- each block deals with block_size*ITEM_PER_THREAD*4 elements; +*/ +template +__global__ void layernorm_twoPassAlgo_stored_locally_e4(T4* output, + const T4* input, + const T4* gamma, + const T4* beta, + const int m, + const int n) +{ + const int m_idx = blockIdx.x; + const int tid = threadIdx.x; + const int bdimx = blockDim.x; + __shared__ float s_mean, s_variance; + float local_sums[1] = {0.0f}; + T4 local_val[ITEM_PER_THREAD]; + const int n_4 = n / 4; + int offset = m_idx * n_4; + input += offset; + output += offset; + + const T4 zero = {T(0.0f), T(0.0f), T(0.0f), T(0.0f)}; + #pragma UNROLL + for (int i = 0; i < ITEM_PER_THREAD; i += 1) { + const int index = i*bdimx + tid; + local_val[i] = index < n_4 ? input[index] : zero; + local_sums[0] += static_cast(local_val[i].x) + static_cast(local_val[i].y) + + static_cast(local_val[i].z) + static_cast(local_val[i].w); + } + + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_mean = local_sums[0] / n; + } + __syncthreads(); + + local_sums[0] = 0.0f; + #pragma UNROLL + for (int i = 0; i < ITEM_PER_THREAD; i += 1) { + const int index = i*bdimx + tid; + if (index < n_4){ + const float4 tmp = {static_cast(local_val[i].x) - s_mean, + static_cast(local_val[i].y) - s_mean, + static_cast(local_val[i].z) - s_mean, + static_cast(local_val[i].w) - s_mean}; + local_sums[0] += tmp.x * tmp.x + tmp.y * tmp.y + tmp.z * tmp.z + tmp.w * tmp.w; + } + } + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_variance = rsqrtf(local_sums[0] / n + 1e-5); + } + __syncthreads(); + + #pragma UNROLL + for (int i = 0; i < ITEM_PER_THREAD; i += 1) { + const int index = i*bdimx + tid; + if (index < n_4){ + const T4 gamma_val = gamma[index]; + const T4 beta_val = beta[index]; + T4 tmp; + tmp.x = T((static_cast(local_val[i].x) - s_mean)*s_variance*static_cast(gamma_val.x) + static_cast(beta_val.x)); + tmp.y = T((static_cast(local_val[i].y) - s_mean)*s_variance*static_cast(gamma_val.y) + static_cast(beta_val.y)); + tmp.z = T((static_cast(local_val[i].z) - s_mean)*s_variance*static_cast(gamma_val.z) + static_cast(beta_val.z)); + tmp.w = T((static_cast(local_val[i].w) - s_mean)*s_variance*static_cast(gamma_val.w) + static_cast(beta_val.w)); + output[index] = tmp; + } + } +} + +/** + * output [m, n] row-major + * input [m, n] row-major + * gamma [n] + * beta [n] + * grid(m) + * block(block_size) -- each block deals with n elements ; each thread deals with ITEM_PER_THREAD elements +*/ +template +__global__ void layernorm_twoPassAlgo_e1(T* output, + const T* input, + const T* gamma, + const T* beta, + const int m, + const int n) +{ + const int m_idx = blockIdx.x; + const int tid = threadIdx.x; + const int bdimx = blockDim.x; + __shared__ float s_mean, s_variance; + float local_sums[1] = {0.0f}; + int offset = m_idx * n; + input += offset; + output += offset; + + for (int index = tid ; index < n ; index += bdimx){ + float local_val = static_cast(input[index]); + local_sums[0] += local_val; + } + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_mean = local_sums[0] / n; + } + __syncthreads(); + + local_sums[0] = 0.0f; + for (int index = tid ; index < n ; index += bdimx){ + float local_val = static_cast(input[index]); + local_val = local_val - s_mean; + local_sums[0] += local_val * local_val; + } + + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_variance = rsqrtf(local_sums[0] / n + 1e-5); + } + __syncthreads(); + + for (int index = tid ; index < n ; index += bdimx){ + const T gamma_val = gamma[index]; + const T beta_val = beta[index]; + const T local_val = input[index]; + output[index] = T((static_cast(local_val) - s_mean) * s_variance * static_cast(gamma_val) + static_cast(beta_val)); + } +} + +/** + * output [m, n] row-major + * input [m, n] row-major + * gamma [n] + * beta [n] + * grid(m) + * block(block_size) -- each block deals with block_size*ITEM_PER_THREAD*2 elements; +*/ +template +__global__ void layernorm_twoPassAlgo_e2(T2* output, + const T2* input, + const T2* gamma, + const T2* beta, + const int m, + const int n) +{ + const int m_idx = blockIdx.x; + const int tid = threadIdx.x; + const int bdimx = blockDim.x; + __shared__ float s_mean, s_variance; + float local_sums[1] = {0.0f}; + const int n_2 = n / 2; + int offset = m_idx * n_2; + input += offset; + output += offset; + + for (int index = tid; index < n_2; index += bdimx) { + const T2 local_val = input[index]; + local_sums[0] += static_cast(local_val.x) + static_cast(local_val.y); + } + + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_mean = local_sums[0] / n; + } + __syncthreads(); + + local_sums[0] = 0.0f; + for (int index = tid; index < n_2; index += bdimx) { + const T2 local_val = input[index]; + const float2 tmp = {static_cast(local_val.x) - s_mean, + static_cast(local_val.y) - s_mean}; + local_sums[0] += tmp.x * tmp.x + tmp.y * tmp.y; + } + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_variance = rsqrtf(local_sums[0] / n + 1e-5); + } + __syncthreads(); + + for (int index = tid; index < n_2; index += bdimx) { + const T2 local_val = input[index]; + const T2 gamma_val = gamma[index]; + const T2 beta_val = beta[index]; + T2 tmp; + tmp.x = T((static_cast(local_val.x) - s_mean)*s_variance*static_cast(gamma_val.x) + static_cast(beta_val.x)); + tmp.y = T((static_cast(local_val.y) - s_mean)*s_variance*static_cast(gamma_val.y) + static_cast(beta_val.y)); + output[index] = tmp; + } +} + +template +void layernorm(cutlass::MatrixCoord tensor_size, + TensorRef ref_output, + TensorRef ref_input, + TensorRef ref_gamma, + TensorRef ref_beta, + cudaStream_t stream){ + const int m = tensor_size.row(); + const int n = tensor_size.column(); + T* output = ref_output.data(); + const T* input = ref_input.data(); + const T* gamma = ref_gamma.data(); + const T* beta = ref_beta.data(); + dim3 grid(m); + dim3 block((n + 31)/32*32); + if (block.x > 1024){ + block.x = 1024; + } + // TODO : There should be better configs for different cases, we only use several samples to show how to use here + // TODO : using registers to store values locally can reduce the loads from global memory and speedup the kernels. + if ((n % 4 == 0) && (n >= 128) && (n <= 4096)) { + block.x = (n/4 + 31)/32*32; + if (std::is_same::value) { + layernorm_twoPassAlgo_stored_locally_e4<<>>( + (float4*)output, + (const float4*)input, + (const float4*)gamma, + (const float4*)beta, + m, + n); + } // if (std::is_same::value) + else { + layernorm_twoPassAlgo_stored_locally_e4<<>>( + (half4*)output, + (const half4*)input, + (const half4*)gamma, + (const half4*)beta, + m, + n); + } + } //if ((n % 4 == 0) && (n >= 128) && (n <= 4096)) + else if (n % 2 == 0) { + if (n / 2 <= 1024) { + block.x = (n/2 + 31)/32*32; + if (std::is_same::value) { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (float2*)output, + (const float2*)input, + (const float2*)gamma, + (const float2*)beta, + m, + n); + } //if (std::is_same::value) + else { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (half2*)output, + (const half2*)input, + (const half2*)gamma, + (const half2*)beta, + m, + n); + } + } // if (n / 2 <= 1024) + else if (n <= 8192) { + block.x = ((n + 7)/8 + 31)/32*32; + if (std::is_same::value) { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (float2*)output, + (const float2*)input, + (const float2*)gamma, + (const float2*)beta, + m, + n); + } // if (std::is_same::value) + else { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (half2*)output, + (const half2*)input, + (const half2*)gamma, + (const half2*)beta, + m, + n); + } + } // if (n <= 8192) + else if (n <= 16384) { + block.x = ((n + 15)/ 16 + 31)/32*32; + if (std::is_same::value) { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (float2*)output, + (const float2*)input, + (const float2*)gamma, + (const float2*)beta, + m, + n); + } // if (std::is_same::value) + else { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (half2*)output, + (const half2*)input, + (const half2*)gamma, + (const half2*)beta, + m, + n); + } + } // if (n <= 16384) + else if (n <= 32768) { + block.x = ((n + 31)/32 + 31)/32*32; + if (std::is_same::value) { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (float2*)output, + (const float2*)input, + (const float2*)gamma, + (const float2*)beta, + m, + n); + } // if (std::is_same::value) + else { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (half2*)output, + (const half2*)input, + (const half2*)gamma, + (const half2*)beta, + m, + n); + } + } // if (n <= 32768) + else { + if (block.x > 512) + block.x = 512; + if (std::is_same::value) { + layernorm_twoPassAlgo_e2<<>>( + (float2 *)output, + (const float2 *)input, + (const float2 *)gamma, + (const float2 *)beta, + m, + n); + } // if (std::is_same::value) + else { + layernorm_twoPassAlgo_e2<<>>( + (half2 *)output, + (const half2 *)input, + (const half2 *)gamma, + (const half2 *)beta, + m, + n); + } + } + } // if (n % 2 == 0) + else { + if (n <= 1024) { + layernorm_twoPassAlgo_stored_locally_e1<<>>( + output, + input, + gamma, + beta, + m, + n); + } // if (n <= 1024) + else if (n <= 8192) { + block.x = ((n + 7)/8 + 31)/32*32; + layernorm_twoPassAlgo_stored_locally_e1<<>>( + output, + input, + gamma, + beta, + m, + n); + } // if (n <= 8192) + else if (n <= 16384) { + block.x = ((n + 15)/16 + 32)/32*32; + layernorm_twoPassAlgo_stored_locally_e1<<>>( + output, + input, + gamma, + beta, + m, + n); + } // if (n <= 16384) + else if (n <= 32768) { + block.x = ((n + 31)/32 + 31)/32*32; + layernorm_twoPassAlgo_stored_locally_e1<<>>( + output, + input, + gamma, + beta, + m, + n); + } // if (n <= 32768) + else{ + if (block.x > 512) { + block.x = 512; + } + layernorm_twoPassAlgo_e1<<>>( + output, + input, + gamma, + beta, + m, + n); + } + } +} + +} //namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_memory.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_memory.h new file mode 100644 index 0000000000000000000000000000000000000000..44f6a467a5d0938289e4bc127cddc13b9aeabdf3 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_memory.h @@ -0,0 +1,375 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +/** + * \file + * \brief C++ interface to CUDA device memory management functions. + */ + +#include +#include + +#include "cutlass/platform/platform.h" +#include "cutlass/numeric_types.h" +#include "cutlass/trace.h" +#include "exceptions.h" + +namespace cutlass { +namespace device_memory { + +/****************************************************************************** + * Allocation lifetime + ******************************************************************************/ + +/// Allocate a buffer of \p count elements of type \p T on the current CUDA device +template +T* allocate(size_t count = 1) { + + T* ptr = 0; + size_t bytes = count * sizeof_bits::value / 8; + + cudaError_t cuda_error = cudaMalloc((void**)&ptr, bytes); + + if (cuda_error != cudaSuccess) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 0) + std::ostringstream os; + os << "cutlass::device_memory::allocate: cudaMalloc failed: bytes=" << bytes; + CUTLASS_TRACE_HOST(os.str()); +#endif + throw cuda_exception("Failed to allocate memory", cuda_error); + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + else { + std::ostringstream os; + os << "cutlass::device_memory::allocate: Successful cudaMalloc: bytes=" << bytes; + CUTLASS_TRACE_HOST(os.str()); + } +#endif + + return ptr; +} + +/// Free the buffer pointed to by \p ptr +template +void free(T* ptr) { + if (ptr) { + cudaError_t cuda_error = (cudaFree(ptr)); + if (cuda_error != cudaSuccess) { + throw cuda_exception("Failed to free device memory", cuda_error); + } + } +} + +/****************************************************************************** + * Data movement + ******************************************************************************/ + +template +void copy(T* dst, T const* src, size_t count, cudaMemcpyKind kind) { + size_t bytes = count * sizeof_bits::value / 8; + if (bytes == 0 && count > 0) { + bytes = 1; + } + cudaError_t cuda_error = (cudaMemcpy(dst, src, bytes, kind)); + if (cuda_error != cudaSuccess) { + std::ostringstream os; + os << "cutlass::device_memory::copy: cudaMemcpy() failed: " + << "dst=" << dst << ", src=" << src + << ", bytes=" << bytes << ", count=" << count; + if (kind == cudaMemcpyHostToDevice) { + os << ", kind=cudaMemcpyHostToDevice"; + } + else if (kind == cudaMemcpyDeviceToHost) { + os << ", kind=cudaMemcpyDeviceToHost"; + } + else if (kind == cudaMemcpyDeviceToDevice) { + os << ", kind=cudaMemcpyDeviceToDevice"; + } + else if (kind == cudaMemcpyHostToHost) { + os << ", kind=cudaMemcpyHostToHost"; + } + else if (kind == cudaMemcpyDefault) { + os << ", kind=cudaMemcpyDefault"; + } + else { + os << ", kind=Unknown"; + } + os << ", error: " << cudaGetErrorString(cuda_error); + + throw cuda_exception(os.str().c_str(), cuda_error); + } +} + +template +void copy_to_device(T* dst, T const* src, size_t count = 1) { + copy(dst, src, count, cudaMemcpyHostToDevice); +} + +template +void copy_to_host(T* dst, T const* src, size_t count = 1) { + copy(dst, src, count, cudaMemcpyDeviceToHost); +} + +template +void copy_device_to_device(T* dst, T const* src, size_t count = 1) { + copy(dst, src, count, cudaMemcpyDeviceToDevice); +} + +template +void copy_host_to_host(T* dst, T const* src, size_t count = 1) { + copy(dst, src, count, cudaMemcpyHostToHost); +} + +/// Copies elements from device memory to host-side range +template +void insert_to_host(OutputIterator begin, OutputIterator end, T const* device_begin) { + size_t elements = end - begin; + copy_to_host(&*begin, device_begin, elements); +} + +/// Copies elements to device memory from host-side range +template +void insert_to_device(T* device_begin, InputIterator begin, InputIterator end) { + size_t elements = end - begin; + copy_to_device(device_begin, &*begin, elements); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device_memory + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class DeviceAllocation { +public: + + /// Delete functor for CUDA device memory + struct deleter { + void operator()(T* ptr) { + cudaError_t cuda_error = (cudaFree(ptr)); + if (cuda_error != cudaSuccess) { + // noexcept + // throw cuda_exception("cudaFree() failed", cuda_error); + return; + } + } + }; + +public: + // + // Data members + // + + /// Number of elements of T allocated on the current CUDA device + size_t capacity; + + /// Smart pointer + platform::unique_ptr smart_ptr; + +public: + + // + // Static methods + // + + /// Static member to compute the number of bytes needed for a given number of elements + static size_t bytes(size_t elements) { + if (sizeof_bits::value < 8) { + size_t const kElementsPerByte = 8 / sizeof_bits::value; + return elements / kElementsPerByte; + } + else { + size_t const kBytesPerElement = sizeof_bits::value / 8; + return elements * kBytesPerElement; + } + } + +public: + + // + // Methods + // + + /// Constructor: allocates no memory + DeviceAllocation() : capacity(0) {} + + /// Constructor: allocates \p capacity elements on the current CUDA device + DeviceAllocation(size_t _capacity) : + smart_ptr(device_memory::allocate(_capacity)), capacity(_capacity) {} + + /// Constructor: allocates \p capacity elements on the current CUDA device taking ownership of the allocation + DeviceAllocation(T *ptr, size_t _capacity) : smart_ptr(ptr), capacity(_capacity) {} + + /// Copy constructor + DeviceAllocation(DeviceAllocation const &p): + smart_ptr(device_memory::allocate(p.capacity)), capacity(p.capacity) { + + device_memory::copy_device_to_device(smart_ptr.get(), p.get(), capacity); + } + + /// Move constructor + DeviceAllocation(DeviceAllocation &&p): capacity(0) { + std::swap(smart_ptr, p.smart_ptr); + std::swap(capacity, p.capacity); + } + + /// Destructor + ~DeviceAllocation() { reset(); } + + /// Returns a pointer to the managed object + T* get() const { return smart_ptr.get(); } + + /// Releases the ownership of the managed object (without deleting) and resets capacity to zero + T* release() { + capacity = 0; + return smart_ptr.release(); + } + + /// Deletes the managed object and resets capacity to zero + void reset() { + capacity = 0; + smart_ptr.reset(); + } + + /// Deletes managed object, if owned, and allocates a new object + void reset(size_t _capacity) { + reset(device_memory::allocate(_capacity), _capacity); + } + + /// Deletes managed object, if owned, and replaces its reference with a given pointer and capacity + void reset(T* _ptr, size_t _capacity) { + smart_ptr.reset(_ptr); + capacity = _capacity; + } + + /// Allocates a new buffer and copies the old buffer into it. The old buffer is then released. + void reallocate(size_t new_capacity) { + + platform::unique_ptr new_allocation(device_memory::allocate(new_capacity)); + + device_memory::copy_device_to_device( + new_allocation.get(), + smart_ptr.get(), + std::min(new_capacity, capacity)); + + std::swap(smart_ptr, new_allocation); + std::swap(new_capacity, capacity); + } + + /// Returns the number of elements + size_t size() const { + return capacity; + } + + /// Returns the number of bytes needed to store the allocation + size_t bytes() const { + return bytes(capacity); + } + + /// Returns a pointer to the object owned by *this + T* operator->() const { return smart_ptr.get(); } + + /// Returns the deleter object which would be used for destruction of the managed object. + deleter& get_deleter() { return smart_ptr.get_deleter(); } + + /// Returns the deleter object which would be used for destruction of the managed object (const) + const deleter& get_deleter() const { return smart_ptr.get_deleter(); } + + /// Copies a device-side memory allocation + DeviceAllocation & operator=(DeviceAllocation const &p) { + if (capacity != p.capacity) { + smart_ptr.reset(device_memory::allocate(p.capacity)); + capacity = p.capacity; + } + device_memory::copy_device_to_device(smart_ptr.get(), p.get(), capacity); + return *this; + } + + /// Move assignment + DeviceAllocation & operator=(DeviceAllocation && p) { + std::swap(smart_ptr, p.smart_ptr); + std::swap(capacity, p.capacity); + return *this; + } + + /// Copies the entire allocation from another location in device memory. + void copy_from_device(T const *ptr) const { + copy_from_device(ptr, capacity); + } + + /// Copies a given number of elements from device memory + void copy_from_device(T const *ptr, size_t elements) const { + device_memory::copy_device_to_device(get(), ptr, elements); + } + + void copy_to_device(T *ptr) const { + copy_to_device(ptr, capacity); + } + + void copy_to_device(T *ptr, size_t elements) const { + device_memory::copy_device_to_device(ptr, get(), elements); + } + + void copy_from_host(T const *ptr) const { + copy_from_host(ptr, capacity); + } + + void copy_from_host(T const *ptr, size_t elements) const { + device_memory::copy_to_device(get(), ptr, elements); + } + + void copy_to_host(T *ptr) const { + copy_to_host(ptr, capacity); + } + + void copy_to_host(T *ptr, size_t elements) const { + device_memory::copy_to_host(ptr, get(), elements); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace device_memory { + +/// Device allocation abstraction that tracks size and capacity +template +using allocation = cutlass::DeviceAllocation; + +} // namespace device_memory + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nchw_to_nhwc.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nchw_to_nhwc.h new file mode 100644 index 0000000000000000000000000000000000000000..8e38029951d27c0be8da059b59d2a83fe2762ef1 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nchw_to_nhwc.h @@ -0,0 +1,141 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +/** + * \file + * \brief cuda kernels to transform a device memory tensor from NCHW layout to NHWC layout. + */ + +#include "cutlass/cutlass.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/tensor_ref.h" + +namespace cutlass { + +/** \brief interface to transform a device memory tensor from NCHW layout to NHWC layout. + * \tparam T: data type + */ +template +void nchw_to_nhwc(cutlass::Tensor4DCoord input_tensor_size, + cutlass::Tensor4DCoord output_tensor_size, + TensorRef ref_input, + TensorRef ref_output, + cudaStream_t stream); + +template +__global__ void nchw_to_nhwc_kernel(T *output, + const T *input, + const int n, + const int h, + const int w, + const int c) { + const int hw = h*w; + const int chw = c*hw; + __shared__ T shbuf[32 * (32 + 1)]; + const int32_t tid = threadIdx.y*blockDim.x + threadIdx.x; + const int32_t wid = tid / 32; + const int32_t lid = tid % 32; + const int32_t ni = blockIdx.z; + const int32_t ci0 = blockIdx.y * 32; + const int32_t hwi0 = blockIdx.x * 32; + + const size_t input_idx = ni * chw + (ci0 + wid) * hw + hwi0; + const T *A = input + input_idx; + if (hwi0 + lid < hw) { + const int lid_x_33 = lid * 33; + if ((ci0 + 32) <= c) { + int ci = wid; // between 0 and 7 + CUTLASS_PRAGMA_UNROLL + for (int cLoopIdx = 0; cLoopIdx < 4; cLoopIdx++) { + shbuf[lid_x_33 + ci] = A[lid]; + A = &A[8 * hw]; + ci += 8; + } + } else { + for (int ci = wid; ci < 32; ci += 8) { + if ((ci + ci0) < c) { + shbuf[lid_x_33 + ci] = A[lid]; + } + A = &A[8 * hw]; + } + } + } + __syncthreads(); + + const int32_t ciOut = ci0 + lid; + output = &output[ni * chw + ciOut]; + if (ciOut < c) { + if (hwi0 + 32 < hw) { + int hwI = wid; + CUTLASS_PRAGMA_UNROLL + for (int hwLoopIdx = 0; hwLoopIdx < 4; ++hwLoopIdx) { + output[(hwi0 + hwI) * c] = shbuf[(hwI)*33 + lid]; + hwI += 8; + } + } else { + for (int hwI = wid; hwI < 32; hwI += 8) { + if (hwi0 + hwI < hw) { + output[(hwi0 + hwI) * c] = shbuf[(hwI)*33 + lid]; + } + } + } + } +} + +template +void nchw_to_nhwc(cutlass::Tensor4DCoord input_tensor_size, + cutlass::Tensor4DCoord output_tensor_size, + TensorRef ref_input, + TensorRef ref_output, + cudaStream_t stream) { + + assert( + input_tensor_size.n() == output_tensor_size.n() && + input_tensor_size.c() == output_tensor_size.h() && + input_tensor_size.h() == output_tensor_size.w() && + input_tensor_size.w() == output_tensor_size.c()); + + int n = output_tensor_size.n(); + int h = output_tensor_size.h(); + int w = output_tensor_size.w(); + int c = output_tensor_size.c(); + + dim3 grid((h*w + 31)/32, (c + 31)/32, n); + dim3 block(32, 8); + nchw_to_nhwc_kernel<<>>(ref_output.data(), ref_input.data(), + n, h, w, c); +} + +} //namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_padding.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_padding.h new file mode 100644 index 0000000000000000000000000000000000000000..f58da62a35350b4a865f4521ec1cbb76ae87e874 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_padding.h @@ -0,0 +1,276 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +/** + * \file + * \brief cuda kernels for padding in device memory with NHWC layout. + */ + +#include "cutlass/cutlass.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/tensor_ref.h" + +namespace cutlass { + +/** \brief interface for padding in a device memory tensor with NHWC layout + * \tparam T: data type + */ +template +void nhwc_padding(cutlass::Tensor4DCoord input_tensor_size, + cutlass::Tensor4DCoord output_tensor_size, + TensorRef ref_input, + TensorRef ref_output, + cudaStream_t stream); + + +template +__global__ void nhwc_padding_kernel(const int32_t n, + const int32_t h, + const int32_t w, + const int32_t c_in, + const int32_t c_out, + const T zero, + const T *input, + T *output){ + + const int32_t idx_jump = blockDim.x * gridDim.x; + const int32_t total_elements = n * h * w * c_out; + + int32_t c_idx, w_idx, h_idx, n_idx, resudial; + + T value; + for (int32_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total_elements; idx += idx_jump) { + + c_idx = idx%c_out; + if (c_idx >= c_in){ + value = zero; + } + else{ + resudial = idx/c_out; + w_idx = resudial%w; + resudial = resudial/w; + h_idx = resudial%h; + n_idx = resudial/h; + resudial = ((n_idx * h + h_idx) * w + w_idx) * c_in + c_idx; + value = input[resudial]; + } + output[idx] = value; + } +} + + +// fast kernel for c_in = 3 & c_out = 4 +template +__global__ void nhwc_padding_channel_3To4_kernel(const int32_t n, + const int32_t h, + const int32_t w, + const Tio *input, + Tio *output, + const int32_t max_output_element, + const int32_t max_input_element, + const Tio zero_io, + const Telement zero_element){ + __shared__ Tio shm[192]; + const int tidx = blockIdx.x * 192 + threadIdx.x; + const int threadidx = threadIdx.x; + + shm[threadIdx.x] = tidx >= max_input_element ? zero_io : input[tidx]; + __syncthreads(); + + const int output_offset = blockIdx.x * 256; + const int lower_bound = max_output_element < output_offset + 256 ? max_output_element : output_offset + 256; + for (int i = output_offset + threadidx, j = threadidx ; i < lower_bound ; i+=192, j+=192) + { + const Telement* shm_element = (const Telement*)shm + j*3*element_in_Tio/4; + Telement array[element_in_Tio]; + CUTLASS_PRAGMA_UNROLL + for (int k = 0 ; k < element_in_Tio ; k++) + array[k] = ((k+1)%4 == 0) ? zero_element : shm_element[(k > 3) ? (k - 1) : k]; + output[i] = *((const Tio *)array); + } +} + +// fast kernel for c_in = 3 & c_out = 8 +template +__global__ void nhwc_padding_channel_3To8_kernel(const int32_t n, + const int32_t h, + const int32_t w, + const Tio *input, + Tio *output, + const int32_t max_output_element, + const int32_t max_input_element, + const Tio zero_io, + const Telement zero_element){ + __shared__ Tio shm[192]; + const int tidx = blockIdx.x * 192 + threadIdx.x; + const int threadidx = threadIdx.x; + + shm[threadIdx.x] = tidx >= max_input_element ? zero_io : input[tidx]; + __syncthreads(); + + const int output_offset = blockIdx.x * 512; + const int lower_bound = max_output_element < output_offset + 512 ? max_output_element : output_offset + 512; + for (int i = output_offset + threadidx, j = threadidx ; i < lower_bound ; i+=192, j+=192) + { + const Telement* shm_element = (const Telement*)shm + (element_in_Tio == 4 ? j/2 : j)*3; + Telement array[element_in_Tio]; + //float + if (element_in_Tio == 4){ + CUTLASS_PRAGMA_UNROLL + for (int k = 0 ; k < element_in_Tio ; k++) + array[k] = ((j % 2) == 1) ? zero_element : ((k >= 3) ? zero_element : shm_element[k]); + } + //half + else{ + CUTLASS_PRAGMA_UNROLL + for (int k = 0 ; k < element_in_Tio ; k++) + array[k] = (k >= 3) ? zero_element : shm_element[k]; + } + output[i] = *((const Tio *)array); + } +} + +template +void nhwc_padding(cutlass::Tensor4DCoord input_tensor_size, + cutlass::Tensor4DCoord output_tensor_size, + TensorRef ref_input, + TensorRef ref_output, + cudaStream_t stream){ + assert( + input_tensor_size.n() == output_tensor_size.n() && + input_tensor_size.h() == output_tensor_size.h() && + input_tensor_size.w() == output_tensor_size.w() && + input_tensor_size.c() <= output_tensor_size.c()); + + int n = input_tensor_size.n(); + int h = input_tensor_size.h(); + int w = input_tensor_size.w(); + int c_in = input_tensor_size.c(); + int c_out = output_tensor_size.c(); + + //case 1 : channel == 3 padding to 4 or 8 + if ((c_out == 4 || c_out == 8) && c_in == 3 && (n*h*w % 8 == 0)){ + dim3 block(192); + const int nhw = n*h*w; + const int nhwc = nhw*c_in; + //for half_t + if (cutlass::sizeof_bits::value == 16){ + const int element_in_Tio = 8; + const int max_input_element = nhwc/element_in_Tio; + const int max_output_element = nhw*c_out/element_in_Tio; + const int4 zero_io = {0, 0, 0, 0}; + const half_t zero_element = static_cast(0.0f); + dim3 grid((nhwc + 192*element_in_Tio - 1)/(192*element_in_Tio)); + if (c_out == 4){ + nhwc_padding_channel_3To4_kernel<<>> + (n, h, w, + (const int4 *)ref_input.data(), + (int4 *)ref_output.data(), + max_output_element, + max_input_element, + zero_io, + zero_element); + } + else if (c_out == 8){ + nhwc_padding_channel_3To8_kernel<<>> + (n, h, w, + (const int4 *)ref_input.data(), + (int4 *)ref_output.data(), + max_output_element, + max_input_element, + zero_io, + zero_element); + } + } + //for float + else{ + const int element_in_Tio = 4; + const int max_input_element = nhwc/element_in_Tio; + const int max_output_element = nhw*c_out/element_in_Tio; + const float4 zero_io = {0.0f, 0.0f, 0.0f, 0.0f}; + const float zero_element = 0.0f; + dim3 grid((nhwc + 192*element_in_Tio - 1)/(192*element_in_Tio)); + if (c_out == 4){ + nhwc_padding_channel_3To4_kernel<<>> + (n, h, w, + (const float4 *)ref_input.data(), + (float4 *)ref_output.data(), + max_output_element, + max_input_element, + zero_io, + zero_element); + } + else if (c_out == 8){ + nhwc_padding_channel_3To8_kernel<<>> + (n, h, w, + (const float4 *)ref_input.data(), + (float4 *)ref_output.data(), + max_output_element, + max_input_element, + zero_io, + zero_element); + } + } + } + //case 2 : even channel + else if ((c_out % 2) == 0 && (c_in % 2) == 0){ + int32_t total_elements = n * h * w * c_out / 2; + int block_size = 256; + dim3 grid((total_elements + 255)/256); + dim3 block(block_size); + //for half_t + if (cutlass::sizeof_bits::value == 16){ + const __half2 zero = {0.0f, 0.0f}; + nhwc_padding_kernel<<>>(n, h, w, c_in/2, c_out/2, zero, (const __half2*)ref_input.data(), (__half2*)ref_output.data()); + } + //for float + else{ + const float2 zero = {0.0f, 0.0f}; + nhwc_padding_kernel<<>>(n, h, w, c_in/2, c_out/2, zero, (const float2*)ref_input.data(), (float2*)ref_output.data()); + } + } + //case 3 : odd channel + else{ + int32_t total_elements = n * h * w * c_out; + int block_size = 256; + dim3 grid((total_elements + 255)/256); + dim3 block(block_size); + const T zero = static_cast(0.0f); + nhwc_padding_kernel<<>>(n, h, w, c_in, c_out, zero, ref_input.data(), ref_output.data()); + } +} + + +} //namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_pooling.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_pooling.h new file mode 100644 index 0000000000000000000000000000000000000000..5633456c1412ff41366ec4c6ec5c3e6e3a2d6c19 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_pooling.h @@ -0,0 +1,573 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +/** + * \file + * \brief cuda kernels to do avg/max pooling on a device memory tensor with NHWC layout. + */ + +#include "cutlass/cutlass.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/tensor_ref.h" +#include "device_utils.h" +#include + +namespace cutlass { + +/** \brief interface to do avg/max pooling on a device memory tensor with NHWC layout. + * \tparam T: data type + */ +template +void pooling_nhwc(cutlass::Tensor4DCoord input_tensor_size, + cutlass::Tensor4DCoord filter_tensor_size, + cutlass::Tensor4DCoord output_tensor_size, + cutlass::MatrixCoord padding, + cutlass::MatrixCoord stride, + TensorRef ref_input, + TensorRef ref_output, + int poolingType, //0 for avg pooling ; 1 for max pooling + cudaStream_t stream); + +/** get the output size of pooling + */ +inline int getOutputSize(int H_W, int padding, int kernel_size, int stride) +{ + return (H_W + 2 * padding - kernel_size) / stride + 1; +} + +/** + * input is [N, H, W, C] + * assume stride == kernel_size + * output_h = (H + 2*padding_H - kernel_H)/stride_H + * output_w = (W + 2*padding_W - kernel_W)/stride_W + * output is [N, output_h, output_w, C] + * grid(N, output_h, output_w) + * block(min(C, 256)) : + * each block deals with C elements of output when each thread deals with ((C + 255)/256 element of output) +*/ +template +__global__ void pooling_nhwc_element1_kernel(T* output, + const T* input, + const int N, + const int H, + const int W, + const int C, + const int output_H, + const int output_W, + const int kernel_H, + const int kernel_W, + const int stride_H, + const int stride_W, + const int padding_H, + const int padding_W) +{ + const int tid = threadIdx.x; + const int n_idx = blockIdx.x; + const int output_h_idx = blockIdx.y; + const int output_w_idx = blockIdx.z; + + int h_start_idx = output_h_idx * stride_H - padding_H; + int h_end_idx = h_start_idx + kernel_H; + h_start_idx = (h_start_idx < 0) ? 0 : h_start_idx; + h_end_idx = h_end_idx > H ? H : h_end_idx; + + int w_start_idx = output_w_idx * stride_W - padding_W; + int w_end_idx = w_start_idx + kernel_W; + w_start_idx = (w_start_idx < 0) ? 0 : w_start_idx; + w_end_idx = w_end_idx > W ? W : w_end_idx; + + input += n_idx * H * W * C; + output += ((n_idx * output_H + output_h_idx) * output_W + output_w_idx) * C; + const int kernel_size2 = kernel_H * kernel_W; + for (int c_idx = tid; c_idx < C; c_idx += blockDim.x) { + float pooling; + if (IS_AVG_POOLING){ + pooling = 0.0f; + } + else{ + pooling = -FLT_MAX; + } + for (int h = h_start_idx; h < h_end_idx; h++) { + for (int w = w_start_idx; w < w_end_idx; w++) { + const int idx = (h * W + w) * C; + const float tmp = static_cast(input[idx + c_idx]); + if (IS_AVG_POOLING){ + pooling = pooling + tmp; + } + else{ + pooling = pooling > tmp ? pooling : tmp; + } + } + } + + T output_val; + if (IS_AVG_POOLING){ + output_val = T(pooling/kernel_size2); + } + else{ + output_val = T(pooling); + } + output[c_idx] = output_val; + } +} + +template +__global__ void pooling_nhwc_element2_kernel(T2* output, + const T2* input, + const int N, + const int H, + const int W, + const int C, + const int output_H, + const int output_W, + const int kernel_H, + const int kernel_W, + const int stride_H, + const int stride_W, + const int padding_H, + const int padding_W) +{ + const int tid = threadIdx.x; + const int n_idx = blockIdx.x; + const int output_h_idx = blockIdx.y; + const int output_w_idx = blockIdx.z; + + int h_start_idx = output_h_idx * stride_H - padding_H; + int h_end_idx = h_start_idx + kernel_H; + h_start_idx = (h_start_idx < 0) ? 0 : h_start_idx; + h_end_idx = h_end_idx > H ? H : h_end_idx; + + int w_start_idx = output_w_idx * stride_W - padding_W; + int w_end_idx = w_start_idx + kernel_W; + w_start_idx = (w_start_idx < 0) ? 0 : w_start_idx; + w_end_idx = w_end_idx > W ? W : w_end_idx; + + input += n_idx * H * W * C; + output += ((n_idx * output_H + output_h_idx) * output_W + output_w_idx) * C; + const int kernel_size2 = kernel_H * kernel_W; + for (int c_idx = tid; c_idx < C; c_idx += blockDim.x) { + float2 pooling; + if (IS_AVG_POOLING) { + pooling = {0.0f, 0.0f}; + } + else { + pooling = {-FLT_MAX, -FLT_MAX}; + } + for (int h = h_start_idx; h < h_end_idx; h++) { + for (int w = w_start_idx; w < w_end_idx; w++) { + const int idx = (h * W + w) * C; + const T2 tmp = input[idx + c_idx]; + const float2 tmp_flt2 = {static_cast(tmp.x), static_cast(tmp.y)}; + if (IS_AVG_POOLING) { + pooling.x += tmp_flt2.x; + pooling.y += tmp_flt2.y; + } + else { + pooling.x = pooling.x > tmp_flt2.x ? pooling.x : tmp_flt2.x; + pooling.y = pooling.y > tmp_flt2.y ? pooling.y : tmp_flt2.y; + } + } + } + + T2 output_val; + if (IS_AVG_POOLING) { + output_val.x = T(pooling.x/kernel_size2); + output_val.y = T(pooling.y/kernel_size2); + } + else { + output_val.x = T(pooling.x); + output_val.y = T(pooling.y); + } + output[c_idx] = output_val; + } +} + +/** + * output [N, 1, 1, C] + * input [N, H, W, C] + * grid(C, N) + * block(block_size) -- each block deals with H*W/block_size elements; +*/ +template +__global__ void pooling_nxhTo1x1_element1_kernel( + T* output, const T* input, const int N, const int HW, const int C) +{ + const int c_idx = blockIdx.x; + const int n_idx = blockIdx.y; + float pooling[1]; + if (IS_AVG_POOLING) { + pooling[0] = 0.0f; + } + else { + pooling[0] = -FLT_MAX; + } + const size_t input_offset = n_idx * HW * C + c_idx; + input += input_offset; + const size_t output_offset = n_idx * C + c_idx; + output += output_offset; + int tid = threadIdx.x; + + for (int index = tid; index < HW; index += blockDim.x) { + float val = static_cast(input[index * C]); + if (IS_AVG_POOLING) { + pooling[0] += val; + } + else { + pooling[0] = pooling[0] > val ? pooling[0] : val; + } + } + if (blockDim.x <= 32) { + if (IS_AVG_POOLING) { + warpReduceSum(pooling); + } + else { + warpReduceMax(pooling); + } + } + else { + if (IS_AVG_POOLING) { + blockReduceSum(pooling); + } + else { + blockReduceMax(pooling); + } + } + __syncthreads(); + if (threadIdx.x == 0) { + T output_val; + if (IS_AVG_POOLING) { + output_val = T(pooling[0] / HW); + } + else { + output_val = T(pooling[0]); + } + output[0] = output_val; + } +} + + +/** + * output [N, 1, 1, C] + * input [N, H, W, C] + * grid(C/2, N) + * block(block_size) -- each thread deals with H*W/block_size * 2 elements; +*/ +template +__global__ void pooling_nxhTo1x1_element2_kernel( + T2* output, const T2* input, const int N, const int HW, const int C) +{ + const int c_idx = blockIdx.x; + const int n_idx = blockIdx.y; + float pooling[2]; + if (IS_AVG_POOLING) { + pooling[0] = pooling[1] = 0.0f; + } + else { + pooling[0] = pooling[1] = -FLT_MAX; + } + const int C_2 = C / 2; + const size_t input_offset = n_idx * HW * C_2 + c_idx; + input += input_offset; + const size_t output_offset = n_idx * C_2 + c_idx; + output += output_offset; + int tid = threadIdx.x; + + for (int index = tid; index < HW; index += blockDim.x) { + T2 val = input[index * C_2]; + float2 val_flt2 = {static_cast(val.x), static_cast(val.y)}; + if (IS_AVG_POOLING) { + pooling[0] += val_flt2.x; + pooling[1] += val_flt2.y; + } + else { + pooling[0] = pooling[0] > val_flt2.x ? pooling[0] : val_flt2.x; + pooling[1] = pooling[1] > val_flt2.y ? pooling[1] : val_flt2.y; + } + } + if (blockDim.x <= 32) { + if (IS_AVG_POOLING) { + warpReduceSum(pooling); + } + else { + warpReduceMax(pooling); + } + } + else { + if (IS_AVG_POOLING) { + blockReduceSum(pooling); + } + else { + blockReduceMax(pooling); + } + } + __syncthreads(); + if (threadIdx.x == 0) { + T2 output_val; + if (IS_AVG_POOLING) { + output_val.x = T(pooling[0] / HW); + output_val.y = T(pooling[1] / HW); + } + else { + output_val.x = T(pooling[0]); + output_val.y = T(pooling[1]); + } + output[0] = output_val; + } +} + +template +void pooling_nhwc(cutlass::Tensor4DCoord input_tensor_size, + cutlass::Tensor4DCoord filter_tensor_size, + cutlass::Tensor4DCoord output_tensor_size, + cutlass::Tensor4DCoord padding, + cutlass::MatrixCoord stride, + TensorRef ref_input, + TensorRef ref_output, + int poolingType, //0 for avg pooling ; 1 for max pooling + cudaStream_t stream) { + + assert(input_tensor_size.n() == output_tensor_size.n() && + input_tensor_size.c() == output_tensor_size.c()); + + const int N = input_tensor_size.n(); + const int H = input_tensor_size.h(); + const int W = input_tensor_size.w(); + const int C = input_tensor_size.c(); + const int padding_H = padding.h(); + const int padding_W = padding.w(); + const int kernel_H = filter_tensor_size.h(); + const int kernel_W = filter_tensor_size.w(); + const int stride_H = stride.row(); + const int stride_W = stride.column(); + + const int output_H = getOutputSize(H, padding_H, kernel_H, stride_H); + const int output_W = getOutputSize(W, padding_W, kernel_W, stride_W); + + assert(output_tensor_size.h() == output_H && + output_tensor_size.w() == output_W); + + if (C % 2 != 0) { + if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0)) { + dim3 grid(C, N); + dim3 block(256); + if (H*W < block.x){ + block.x = (H*W + 31)/32*32; + } + if (poolingType == 0) { + pooling_nxhTo1x1_element1_kernel<<>>( + ref_output.data(), + ref_input.data(), + N, + H*W, + C); + } // if (poolingType == 0) + else { + pooling_nxhTo1x1_element1_kernel<<>>( + ref_output.data(), + ref_input.data(), + N, + H*W, + C); + } + } // if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0)) + else { + dim3 grid(N, output_H, output_W); + dim3 block(256); + if (C < block.x) { + block.x = C; + } + if (poolingType == 0) { + pooling_nhwc_element1_kernel<<>>( + ref_output.data(), + ref_input.data(), + N, + H, + W, + C, + output_H, + output_W, + kernel_H, + kernel_W, + stride_H, + stride_W, + padding_H, + padding_W); + } // if (poolingType == 0) + else { + pooling_nhwc_element1_kernel<<>>( + ref_output.data(), + ref_input.data(), + N, + H, + W, + C, + output_H, + output_W, + kernel_H, + kernel_W, + stride_H, + stride_W, + padding_H, + padding_W); + } + } + } // if (C % 2 != 0)) + else { + if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0)) { + dim3 grid(C/2, N); + dim3 block(256); + if (H*W < block.x){ + block.x = (H*W + 31)/32*32; + } + if (poolingType == 0) { + if (std::is_same::value) { + pooling_nxhTo1x1_element2_kernel<<>>( + (float2*)(ref_output.data()), + (const float2*)(ref_input.data()), + N, + H*W, + C); + } // if (std::is_same::value) + else { + pooling_nxhTo1x1_element2_kernel<<>>( + (half2*)(ref_output.data()), + (const half2*)(ref_input.data()), + N, + H*W, + C); + } + } // if (poolingType == 0) + else { + if (std::is_same::value) { + pooling_nxhTo1x1_element2_kernel<<>>( + (float2*)(ref_output.data()), + (const float2*)(ref_input.data()), + N, + H*W, + C); + } // if (std::is_same::value) + else { + pooling_nxhTo1x1_element2_kernel<<>>( + (half2*)(ref_output.data()), + (const half2*)(ref_input.data()), + N, + H*W, + C); + } + } + } // if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0)) + else { + dim3 grid(N, output_H, output_W); + dim3 block(256); + if (C/2 < block.x) { + block.x = C/2; + } + if (poolingType == 0) { + if (std::is_same::value) { + pooling_nhwc_element2_kernel<<>>( + (float2*)(ref_output.data()), + (const float2*)(ref_input.data()), + N, + H, + W, + C/2, + output_H, + output_W, + kernel_H, + kernel_W, + stride_H, + stride_W, + padding_H, + padding_W); + } // if (std::is_same::value) + else { + pooling_nhwc_element2_kernel<<>>( + (half2*)(ref_output.data()), + (const half2*)(ref_input.data()), + N, + H, + W, + C/2, + output_H, + output_W, + kernel_H, + kernel_W, + stride_H, + stride_W, + padding_H, + padding_W); + } + } // if (poolingType == 0) + else { + if (std::is_same::value) { + pooling_nhwc_element2_kernel<<>>( + (float2*)(ref_output.data()), + (const float2*)(ref_input.data()), + N, + H, + W, + C/2, + output_H, + output_W, + kernel_H, + kernel_W, + stride_H, + stride_W, + padding_H, + padding_W); + } // if (std::is_same::value) + else { + pooling_nhwc_element2_kernel<<>>( + (half2*)(ref_output.data()), + (const half2*)(ref_input.data()), + N, + H, + W, + C/2, + output_H, + output_W, + kernel_H, + kernel_W, + stride_H, + stride_W, + padding_H, + padding_W); + } + } + } + } +} + +} //namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_to_nchw.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_to_nchw.h new file mode 100644 index 0000000000000000000000000000000000000000..babfecd39205ebff39794133868e4a95b7e9525c --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_to_nchw.h @@ -0,0 +1,144 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +/** + * \file + * \brief cuda kernels to transform a device memory tensor from NHWC layout to NCHW layout. + */ + +#include "cutlass/cutlass.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/tensor_ref.h" + +namespace cutlass { + +/** \brief interface to transform a device memory tensor from NHWC layout to NCHW layout. + * \tparam T: data type + */ +template +void nhwc_to_nchw(cutlass::Tensor4DCoord input_tensor_size, + cutlass::Tensor4DCoord output_tensor_size, + TensorRef ref_input, + TensorRef ref_output, + cudaStream_t stream); + + +template +__global__ void nhwc_to_nchw_kernel(T *output, + const T *input, + const int n, + const int h, + const int w, + const int c) { + + const int hw = h*w; + const int hwc = hw*c; + __shared__ T shbuf[32 * (32 + 1)]; + const int32_t tid = threadIdx.y*blockDim.x + threadIdx.x; + const int32_t wid = tid / 32; + const int32_t lid = tid % 32; + const int32_t ni = blockIdx.z; + const int32_t hwi0 = blockIdx.y * 32; + const int32_t ci0 = blockIdx.x * 32; + + const size_t input_idx = ni * hwc + (hwi0 + wid) * c + ci0; + const T *A = input + input_idx; + if (ci0 + lid < c) { + const int lid_x_33 = lid * 33; + if ((hwi0 + 32) <= hw) { + int hwi = wid; // between 0 and 7 + CUTLASS_PRAGMA_UNROLL + for (int cLoopIdx = 0; cLoopIdx < 4; cLoopIdx++) { + shbuf[lid_x_33 + hwi] = A[lid]; + A = &A[8 * c]; + hwi += 8; + } + } else { + for (int hwi = wid; hwi < 32; hwi += 8) { + if ((hwi + hwi0) < hw) { + shbuf[lid_x_33 + hwi] = A[lid]; + } + A = &A[8 * c]; + } + } + } + __syncthreads(); + + const int32_t hwiOut = hwi0 + lid; + output = &output[ni * hwc + hwiOut]; + if (hwiOut < hw) { + if (ci0 + 32 < c) { + int cI = wid; + CUTLASS_PRAGMA_UNROLL + for (int hwLoopIdx = 0; hwLoopIdx < 4; ++hwLoopIdx) { + output[(ci0 + cI) * hw] = shbuf[(cI)*33 + lid]; + cI += 8; + } + } else { + for (int cI = wid; cI < 32; cI += 8) { + if (ci0 + cI < c) { + output[(ci0 + cI) * hw] = shbuf[(cI)*33 + lid]; + } + } + } + } +} + +template +void nhwc_to_nchw(cutlass::Tensor4DCoord input_tensor_size, + cutlass::Tensor4DCoord output_tensor_size, + TensorRef ref_input, + TensorRef ref_output, + cudaStream_t stream) { + + assert( + input_tensor_size.n() == output_tensor_size.n() && + input_tensor_size.h() == output_tensor_size.c() && + input_tensor_size.w() == output_tensor_size.h() && + input_tensor_size.c() == output_tensor_size.w()); + + int n = input_tensor_size.n(); + int h = input_tensor_size.h(); + int w = input_tensor_size.w(); + int c = input_tensor_size.c(); + + dim3 grid((c + 31)/32, (h*w + 31)/32, n); + dim3 block(32, 8); + nhwc_to_nchw_kernel<<>>(ref_output.data(), ref_input.data(), + n, h, w, c); + +} + +} //namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_rmsnorm.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_rmsnorm.h new file mode 100644 index 0000000000000000000000000000000000000000..0d1b1af56e4463640edc3e9c82533baf815c9b27 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_rmsnorm.h @@ -0,0 +1,186 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/util/device_utils.h" +#include + +namespace cutlass { + +__global__ void rmsnorm_twoPassAlgo_e8(float4 *output, const float4 *input, + const float4 *weight, + const int m, const int n, float epsilon) { + const int m_idx = blockIdx.x; + const int tid = threadIdx.x; + const int bdimx = blockDim.x; + __shared__ float s_mean; + float local_sums[1] = {0.0f}; + const int n_8 = n / 8; + int offset = m_idx * n_8; + input += offset; + output += offset; + + for (int index = tid; index < n_8; index += bdimx) { + const float4 local_val = input[index]; + const half2 *h1 = (half2 *)&local_val.x; + const half2 *h2 = (half2 *)&local_val.y; + const half2 *h3 = (half2 *)&local_val.z; + const half2 *h4 = (half2 *)&local_val.w; + local_sums[0] += static_cast(h1->x) * static_cast(h1->x) + + static_cast(h1->y) * static_cast(h1->y) + + static_cast(h2->x) * static_cast(h2->x) + + static_cast(h2->y) * static_cast(h2->y) + + static_cast(h3->x) * static_cast(h3->x) + + static_cast(h3->y) * static_cast(h3->y) + + static_cast(h4->x) * static_cast(h4->x) + + static_cast(h4->y) * static_cast(h4->y); + } + + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_mean = rsqrtf(local_sums[0] / n + epsilon); + } + __syncthreads(); + + for (int index = tid; index < n_8; index += bdimx) { + const float4 local_val = input[index]; + const float4 weight_val = weight[index]; + + const half2 *l1 = (half2 *)&local_val.x; + const half2 *l2 = (half2 *)&local_val.y; + const half2 *l3 = (half2 *)&local_val.z; + const half2 *l4 = (half2 *)&local_val.w; + + const half2 *g1 = (half2 *)&weight_val.x; + const half2 *g2 = (half2 *)&weight_val.y; + const half2 *g3 = (half2 *)&weight_val.z; + const half2 *g4 = (half2 *)&weight_val.w; + + float4 tmp; + half2 *h1 = (half2 *)&tmp.x; + half2 *h2 = (half2 *)&tmp.y; + half2 *h3 = (half2 *)&tmp.z; + half2 *h4 = (half2 *)&tmp.w; + + h1->x = half(static_cast(l1->x) * s_mean * static_cast(g1->x)); + h1->y = half(static_cast(l1->y) * s_mean * static_cast(g1->y)); + h2->x = half(static_cast(l2->x) * s_mean * static_cast(g2->x)); + h2->y = half(static_cast(l2->y) * s_mean * static_cast(g2->y)); + h3->x = half(static_cast(l3->x) * s_mean * static_cast(g3->x)); + h3->y = half(static_cast(l3->y) * s_mean * static_cast(g3->y)); + h4->x = half(static_cast(l4->x) * s_mean * static_cast(g4->x)); + h4->y = half(static_cast(l4->y) * s_mean * static_cast(g4->y)); + + output[index] = tmp; + } +} + +template +__global__ void rmsnorm_twoPassAlgo_e1(T* output, + const T* input, + const T* weight, + const int m, const int n, + float epsilon) +{ + const int m_idx = blockIdx.x; + const int tid = threadIdx.x; + const int bdimx = blockDim.x; + __shared__ float s_mean; + float local_sums[1] = {0.0f}; + int offset = m_idx * n; + input += offset; + output += offset; + + for (int index = tid ; index < n ; index += bdimx){ + float local_val = static_cast(input[index]); + local_sums[0] += local_val * local_val; + } + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_mean = rsqrtf(local_sums[0] / n + epsilon); + } + __syncthreads(); + + for (int index = tid ; index < n ; index += bdimx){ + const T weight_val = weight[index]; + const T local_val = input[index]; + output[index] = T(static_cast(local_val) * s_mean * static_cast(weight_val)); + } +} + +template +void rmsnorm(cutlass::MatrixCoord tensor_size, + TensorRef ref_output, + TensorRef ref_input, + TensorRef ref_weight, + cudaStream_t stream, float epsilon = 1e-5f){ + const int m = tensor_size.row(); + const int n = tensor_size.column(); + T* output = ref_output.data(); + const T* input = ref_input.data(); + const T* weight = ref_weight.data(); + dim3 grid(m); + + if (n % 8 == 0 && std::is_same::value) { + dim3 block(cutlass::platform::min(1024, (n / 8 + 31) / 32 * 32)); + + rmsnorm_twoPassAlgo_e8<<>>( + (float4 *)output, (const float4 *)input, (const float4 *)weight, m, n, epsilon); + } else { + dim3 block(cutlass::platform::min(1024, ((n + 31)/32 + 31)/32*32)); + + rmsnorm_twoPassAlgo_e1<<>>( + output, input, weight, m, n, epsilon); + } + + auto result = cudaGetLastError(); + if (result != cudaSuccess) { + std::cerr << "CUDA error: " << cudaGetErrorString(result) << std::endl; + abort(); + } +} + +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_utils.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..9747d50975d7d35df287f6b056aedc489adb317c --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_utils.h @@ -0,0 +1,127 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief utils code for device cutlass code +*/ + +#pragma once + +#include +#include +#define FINAL_MASK 0xffffffff + +struct half4 { + half x, y, z, w; +}; + +template +__inline__ __device__ T warpReduceSum(T* val) +{ +#pragma unroll + for (int i = 0; i < NUM; i++) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32); + } + return (T)(0.0f); +} + +template +__inline__ __device__ T blockReduceSum(T* val) +{ + __shared__ T shared[NUM][33]; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + warpReduceSum(val); + + if (lane == 0) { +#pragma unroll + for (int i = 0; i < NUM; i++) { + shared[i][wid] = val[i]; + } + } + + __syncthreads(); + + bool is_mask = threadIdx.x < (blockDim.x / 32.f); +#pragma unroll + for (int i = 0; i < NUM; i++) { + val[i] = is_mask ? shared[i][lane] : (T)(0.0f); + } + warpReduceSum(val); + return (T)0.0f; +} + +template +__inline__ __device__ T warpReduceMax(T* val) +{ +#pragma unroll + for (int i = 0; i < NUM; i++) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32)); + } + return (T)(0.0f); +} + +template +__inline__ __device__ T blockReduceMax(T* val) +{ + static __shared__ T shared[32][NUM]; + int lane = threadIdx.x & 0x1f; // in-warp idx + int wid = threadIdx.x >> 5; // warp idx + + warpReduceMax(val); // get maxx in each warp + + if (lane == 0) // record in-warp maxx by warp Idx + { +#pragma unroll + for (int i = 0; i < NUM; i++) { + shared[wid][i] = val[i]; + } + } + + __syncthreads(); + + // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent + // blockDim.x is not divided by 32 + bool is_mask = threadIdx.x < (blockDim.x / 32.f); +#pragma unroll + for (int i = 0; i < NUM; i++) { + val[i] = is_mask ? shared[lane][i] : (T)(-FLT_MAX); + } + warpReduceMax(val); + + return (T)0.0f; +} + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/distribution.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/distribution.h new file mode 100644 index 0000000000000000000000000000000000000000..6565aba9607ad68defacb6e98d9f9bbc944cd48d --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/distribution.h @@ -0,0 +1,157 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +/*! \file + \brief This header contains a class to parametrize a statistical distribution function. +*/ + +#include + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Distribution type +struct Distribution { + /// Variant types + enum Kind { Invalid, Uniform, Gaussian, Identity, Sequential, AllZeros, AllOnes }; + + /// Distribution state + union { + /// Uniform distribution + struct { + double min; + double max; + // Percent elements set to NaN + double pnan; + } uniform; + + /// Gaussian distribution + struct { + double mean; + double stddev; + double pnz; + double pnzA; + double pnzB; + double pnzC; + } gaussian; + + /// Elements are linear combination of row and column index + struct { + double start; + double delta; + } sequential; + }; + + /// Active variant kind + Kind kind; + + /// Random values are cast to integer after scaling by this power of two + int int_scale; + + // + // Methods + // + + Distribution() : kind(Invalid), int_scale(0) {} + +/// Configures distribution as uniform random + Distribution &set_uniform(double _min, double _max, int _int_scale = 0, double _pnan = 0) { + kind = Uniform; + uniform.min = _min; + uniform.max = _max; + int_scale = _int_scale; + uniform.pnan = _pnan; + return *this; + } + + /// Configures distribution as Gaussian distribution + Distribution &set_gaussian(double _mean, double _stddev, int _int_scale = 0, double _pnz = 1.0) { + kind = Gaussian; + gaussian.mean = _mean; + gaussian.stddev = _stddev; + gaussian.pnz = _pnz; + gaussian.pnzA = _pnz; + gaussian.pnzB = _pnz; + gaussian.pnzC = _pnz; + int_scale = _int_scale; + return *this; + } + + /// Sets identity + Distribution &set_identity() { + kind = Identity; + return *this; + } + + /// Sets sequential + Distribution &set_sequential(double start, double delta, int _int_scale = 0) { + kind = Sequential; + sequential.start = start; + sequential.delta = delta; + int_scale = _int_scale; + return *this; + } +}; + +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Prints a Distribution to ostream +inline std::ostream &operator<<(std::ostream &out, cutlass::Distribution const &dist) { + switch (dist.kind) { + case cutlass::Distribution::Uniform: + out << "uniform, min: " << dist.uniform.min << ", max: " << dist.uniform.max + << ", pnan: " << dist.uniform.pnan; + break; + case cutlass::Distribution::Gaussian: + out << "gaussian, mean: " << dist.gaussian.mean << ", stddev: " << dist.gaussian.stddev + << ", pnzA: " << dist.gaussian.pnzA << ", pnzB: " + << dist.gaussian.pnzB << ", pnzC: " << dist.gaussian.pnzC; + break; + case cutlass::Distribution::Identity: + out << "identity"; + break; + case cutlass::Distribution::Sequential: + out << "sequential"; + break; + default: + out << "unknown"; + } + + out << ", int_scale: " << dist.int_scale; + + return out; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/exceptions.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/exceptions.h new file mode 100644 index 0000000000000000000000000000000000000000..f2b7df6cb1c465a312d76566768cb79fcdfffee4 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/exceptions.h @@ -0,0 +1,69 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +/** + * \file + * \brief C++ exception semantics for CUDA error codes + */ + +#include +#include +#include + +#include "cutlass/platform/platform.h" + +namespace cutlass { + +/// C++ exception wrapper for CUDA \p cudaError_t +class cuda_exception : public std::exception { + public: + /// Constructor + cuda_exception(const char* msg = "", cudaError_t err = cudaErrorUnknown) : msg(msg), err(err) {} + + /// Returns the underlying CUDA \p cudaError_t + cudaError_t cudaError() const { return err; } + + protected: + /// Explanatory string + const char* msg; + + /// Underlying CUDA \p cudaError_t + cudaError_t err; +}; + +/// Writes a cuda_exception instance to an output stream +inline std::ostream& operator<<(std::ostream& out, cuda_exception const& e) { + return out << e.what() << ": " << cudaGetErrorString(e.cudaError()); +} + +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/gett_commandline.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/gett_commandline.hpp new file mode 100644 index 0000000000000000000000000000000000000000..be2264466e350c062900a50e27e923847186d084 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/gett_commandline.hpp @@ -0,0 +1,369 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief GETT command line parser to gather semantic modes, their stride order, and extents. +*/ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "cutlass/util/command_line.h" + +namespace cutlass { + +// Output shortcuts +std::ostream& operator<<(std::ostream& os, std::vector data) { + for (auto& a : data) os << a; + return os; +} + +template +std::ostream& operator<<(std::ostream& os, std::vector data) { + for (auto& a : data) os << a << " "; + return os; +} + +struct GettCommandLine { + struct GettProblem { + using extent_type = int; + using stride_type = int64_t; + + // Row modes: appear in A and C/D + std::vector M; + std::vector ldAm; + std::vector ldCm; + + // Column modes: appear in B and C/D + std::vector N; + std::vector ldBn; + std::vector ldCn; + + // Reduction modes: appear in A and B + std::vector K; + std::vector ldAk; + std::vector ldBk; + + // Batch modes: appear in all in/out tensors + std::vector L; + std::vector ldAl; + std::vector ldBl; + std::vector ldCl; + }; + + static GettProblem + parse(int argc, char const* argv[], bool parse_verbose = false) { + using extent_type = typename GettProblem::extent_type; + using stride_type = typename GettProblem::stride_type; + + cutlass::CommandLine cmd(argc, argv); + + // modeA + std::vector a_mode; + cmd.get_cmd_line_arguments("modeA", a_mode); + + // modeB + std::vector b_mode; + cmd.get_cmd_line_arguments("modeB", b_mode); + + // modeC + std::vector c_mode; + cmd.get_cmd_line_arguments("modeC", c_mode); + + + // mode_sizes + std::map mode_size; + // First, initialize all modes in a, b, c to make sure they're in map + for (char a : a_mode) mode_size[a] = 1; + for (char b : b_mode) mode_size[b] = 1; + for (char c : c_mode) mode_size[c] = 1; + + // Then, overwrite the ones in -extent + std::vector > extent_tokens; + cmd.get_cmd_line_argument_pairs("extents", extent_tokens); + for (auto e : extent_tokens) { + if (std::get<0>(e).size() > 1) { + std::cerr << "ERROR: Mode name must only be 1 character long.\n"; + print_usage(); + exit(1); + } + char label = std::get<0>(e)[0]; + int size = std::stoi(std::get<1>(e)); + mode_size[label] = size; + } + + // Print out symbolic modes and their extents + if (parse_verbose) { + std::cout << "C_" << c_mode << " = A_" << a_mode << " * B_" << b_mode << "\n"; + for (auto e : mode_size) std::cout << " " << std::get<0>(e) << " : " << std::get<1>(e) << "\n"; + } + + // + // Collect/Compute strides + // + + std::map mode_ldA; + std::map mode_ldB; + std::map mode_ldC; + + { + stride_type current; + + current = 1; + for (char a : a_mode) { mode_ldA[a] = current; current *= mode_size[a]; } + + current = 1; + for (char b : b_mode) { mode_ldB[b] = current; current *= mode_size[b]; } + + current = 1; + for (char c : c_mode) { mode_ldC[c] = current; current *= mode_size[c]; } + } + + // + // Collect mode categories + // + + std::vector row_mode; // rows + std::vector col_mode; // columns + std::vector red_mode; // reductions + std::vector bat_mode; // batches + + { + std::vector a_label = a_mode; + std::vector b_label = b_mode; + std::vector c_label = c_mode; + + std::sort(std::begin(a_label), std::end(a_label)); + std::sort(std::begin(b_label), std::end(b_label)); + std::sort(std::begin(c_label), std::end(c_label)); + + // std::set_intersections to find semantic category of each symbolic mode + std::set_intersection(std::begin(a_label), std::end(a_label), + std::begin(c_label), std::end(c_label), + std::back_inserter(row_mode)); + + std::set_intersection(std::begin(b_label), std::end(b_label), + std::begin(c_label), std::end(c_label), + std::back_inserter(col_mode)); + + std::set_intersection(std::begin(a_label), std::end(a_label), + std::begin(b_label), std::end(b_label), + std::back_inserter(red_mode)); + + std::set_intersection(std::begin(row_mode), std::end(row_mode), + std::begin(col_mode), std::end(col_mode), + std::back_inserter(bat_mode)); + + // std::set_difference to remove batch modes from other semantic modes + for (char l : bat_mode) { + row_mode.erase(std::remove(std::begin(row_mode), std::end(row_mode), l), std::end(row_mode)); + col_mode.erase(std::remove(std::begin(col_mode), std::end(col_mode), l), std::end(col_mode)); + red_mode.erase(std::remove(std::begin(red_mode), std::end(red_mode), l), std::end(red_mode)); + } + } + + // Print out the semantic association of each symbolic mode + if (parse_verbose) { + std::cout << " rows : " << row_mode << '\n'; + std::cout << " cols : " << col_mode << '\n'; + std::cout << " reds : " << red_mode << '\n'; + std::cout << " bats : " << bat_mode << '\n'; + } + + // + // Permute modes + // + + // Permute the batched modes to promote coalescing + // Sort the batched modes by min(ldAl,ldBl) and in case of a tie by the size + std::sort(std::begin(bat_mode), std::end(bat_mode), [&](char l1, char l2) { + return std::tie(std::min(mode_ldA[l1],mode_ldB[l1]),mode_size[l1]) + < std::tie(std::min(mode_ldA[l2],mode_ldB[l2]),mode_size[l2]); + }); + // Compute sizes and strides of ordered reduction modes + std::vector L; + std::vector ldAl; + std::vector ldBl; + std::vector ldCl; + for (char l : bat_mode) { + L.push_back(mode_size[l]); + ldAl.push_back(mode_ldA[l]); + ldBl.push_back(mode_ldB[l]); + ldCl.push_back(mode_ldC[l]); + } + + // Permute the reduction modes to promote coalescing + // Sort the reduction modes by min(ldAk,ldBk) and in case of a tie by the size + std::sort(std::begin(red_mode), std::end(red_mode), [&](char k1, char k2) { + return std::tie(std::min(mode_ldA[k1],mode_ldB[k1]),mode_size[k1]) + < std::tie(std::min(mode_ldA[k2],mode_ldB[k2]),mode_size[k2]); + }); + // Compute sizes and strides of ordered reduction modes + std::vector K; + std::vector ldAk; + std::vector ldBk; + for (char k : red_mode) { + K.push_back(mode_size[k]); + ldAk.push_back(mode_ldA[k]); + ldBk.push_back(mode_ldB[k]); + } + + // Permute the row modes to promote coalescing + // Sort the row modes by min(ldAm,ldCm) and in case of a tie by ldAm + std::sort(std::begin(row_mode), std::end(row_mode), [&](char m1, char m2) { + return std::tie(std::min(mode_ldA[m1],mode_ldC[m1]),mode_ldA[m1]) + < std::tie(std::min(mode_ldA[m2],mode_ldC[m2]),mode_ldA[m2]); + }); + // Compute sizes and strides of ordered row modes + std::vector M; + std::vector ldAm; + std::vector ldCm; + for (char m : row_mode) { + M.push_back(mode_size[m]); + ldAm.push_back(mode_ldA[m]); + ldCm.push_back(mode_ldC[m]); + } + + // Permute the col modes to promote coalescing + // Sort the col modes by min(ldBn,ldCn) and in case of a tie by ldBn + std::sort(std::begin(col_mode), std::end(col_mode), [&](char n1, char n2) { + return std::tie(std::min(mode_ldB[n1],mode_ldC[n1]),mode_ldB[n1]) + < std::tie(std::min(mode_ldB[n2],mode_ldC[n2]),mode_ldB[n2]); + }); + // Compute sizes and strides of ordered col modes + std::vector N; + std::vector ldBn; + std::vector ldCn; + for (char n : col_mode) { + N.push_back(mode_size[n]); + ldBn.push_back(mode_ldB[n]); + ldCn.push_back(mode_ldC[n]); + } + + if (parse_verbose) { + std::cout << "C_"; + if (! row_mode.empty()) { + std::cout << "(" << row_mode << ")"; + } + if (! col_mode.empty()) { + std::cout << "(" << col_mode << ")"; + } + if (! bat_mode.empty()) { + std::cout << "(" << bat_mode << ")"; + } + std::cout << " = A_"; + if (! row_mode.empty()) { + std::cout << "(" << row_mode << ")"; + } + if (! red_mode.empty()) { + std::cout << "(" << red_mode << ")"; + } + if (! bat_mode.empty()) { + std::cout << "(" << bat_mode << ")"; + } + std::cout << " * B_"; + if (! col_mode.empty()) { + std::cout << "(" << col_mode << ")"; + } + if (! red_mode.empty()) { + std::cout << "(" << red_mode << ")"; + } + if (! bat_mode.empty()) { + std::cout << "(" << bat_mode << ")"; + } + std::cout << '\n'; + + int M_size = std::accumulate(std::begin(M), std::end(M), 1, std::multiplies<>{}); + int N_size = std::accumulate(std::begin(N), std::end(N), 1, std::multiplies<>{}); + int K_size = std::accumulate(std::begin(K), std::end(K), 1, std::multiplies<>{}); + int L_size = std::accumulate(std::begin(L), std::end(L), 1, std::multiplies<>{}); + + std::cout << " M : (" << M_size << ") "; + for (char m : row_mode) std::cout << m << ":" << mode_size[m] << " "; + std::cout << '\n'; + std::cout << " N : (" << N_size << ") "; + for (char n : col_mode) std::cout << n << ":" << mode_size[n] << " "; + std::cout << '\n'; + std::cout << " K : (" << K_size << ") "; + for (char k : red_mode) std::cout << k << ":" << mode_size[k] << " "; + std::cout << '\n'; + std::cout << " L : (" << L_size << ") "; + for (char l : bat_mode) std::cout << l << ":" << mode_size[l] << " "; + std::cout << '\n'; + + std::cout << " ldAm : " << ldAm << '\n'; + std::cout << " ldAk : " << ldAk << '\n'; + std::cout << " ldAl : " << ldAl << '\n'; + std::cout << " ldBn : " << ldBn << '\n'; + std::cout << " ldBk : " << ldBk << '\n'; + std::cout << " ldBl : " << ldBl << '\n'; + std::cout << " ldCm : " << ldCm << '\n'; + std::cout << " ldCn : " << ldCn << '\n'; + std::cout << " ldCl : " << ldCl << '\n'; + } + + return {M, ldAm, ldCm, + N, ldBn, ldCn, + K, ldAk, ldBk, + L, ldAl, ldBl, ldCl}; + } + + static void + print_usage() { + std::cout << + "GETT problem command line parser:\n" + " --modeA=\n" + " A comma delimited list of characters that correspond to the row, reduction, and batch modes in A tensor.\n" + " The semantic association of each symbolic mode is determined automatically.\n\n" + + " --modeB=\n" + " A comma delimited list of characters that correspond to the column, reduction, and batch modes in B tensor.\n" + " The semantic association of each symbolic mode is determined automatically.\n\n" + + " --modeC=\n" + " A comma delimited list of characters that correspond to the row, column, and batch modes in B tensor.\n" + " The semantic association of each symbolic mode is determined automatically.\n\n" + + " --extents=\n" + " A command delimited list of symbolic mode and its corresponding extent.\n" + " Extents are defaulted to 1 if any are not provided.\n\n" + + "Example usage: gett.exe --modeC=m,n,l --modeA=m,k,l --modeB=k,n,l --extents=m:4096,n:4096,k:4096\n"; + } +}; + +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/helper_cuda.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/helper_cuda.hpp new file mode 100644 index 0000000000000000000000000000000000000000..58d08b860c9e665d170fd022ed0d95875e029019 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/helper_cuda.hpp @@ -0,0 +1,116 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include + +#include + +namespace cute +{ + +void +device_init(int device_id, bool quiet = false) +{ + cudaDeviceProp device_prop; + std::size_t device_free_physmem; + std::size_t device_total_physmem; + + CUTE_CHECK_ERROR(cudaSetDevice(device_id)); + CUTE_CHECK_ERROR(cudaMemGetInfo(&device_free_physmem, &device_total_physmem)); + CUTE_CHECK_ERROR(cudaGetDeviceProperties(&device_prop, device_id)); + + if (device_prop.major < 1) { + fprintf(stderr, "Device does not support CUDA.\n"); + exit(1); + } + + //float device_giga_bandwidth = float(device_prop.memoryBusWidth) * device_prop.memoryClockRate * 2 / 8 / 1000 / 1000; + + if (!quiet) { + printf("Using device %d: %s (SM%d, %d SMs)\n", + device_id, device_prop.name, + device_prop.major * 10 + device_prop.minor, + device_prop.multiProcessorCount); + fflush(stdout); + } +} + +/** + * Convert the SM version (e.g. v7.0, v7.5) to the physical number of cores. + */ +inline int +_ConvertSMVer2Cores(int major, int minor) +{ + // Defines for GPU Architecture types (using the SM version to determine + // the # of cores per SM + typedef struct { + int SM; // 0xMm (hexadecimal notation), M = SM Major version, + // and m = SM minor version + int Cores; + } sSMtoCores; + + sSMtoCores nGpuArchCoresPerSM[] = { + {0x30, 192}, + {0x32, 192}, + {0x35, 192}, + {0x37, 192}, + {0x50, 128}, + {0x52, 128}, + {0x53, 128}, + {0x60, 64}, + {0x61, 128}, + {0x62, 128}, + {0x70, 64}, + {0x72, 64}, + {0x75, 64}, + {-1, -1}}; + + int index = 0; + + while (nGpuArchCoresPerSM[index].SM != -1) { + if (nGpuArchCoresPerSM[index].SM == ((major << 4) + minor)) { + return nGpuArchCoresPerSM[index].Cores; + } + index++; + } + + // If we don't find the values, we default use the previous one + // to run properly + printf("MapSMtoCores for SM %d.%d is undefined." + " Default to use %d Cores/SM\n", + major, minor, nGpuArchCoresPerSM[index - 1].Cores); + + return nGpuArchCoresPerSM[index - 1].Cores; +} + +} // end namespace cute diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_reorder.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_reorder.h new file mode 100644 index 0000000000000000000000000000000000000000..4e7718059dfaea0c77d7ebf67789f307b4ca0cf6 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_reorder.h @@ -0,0 +1,111 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief reorder data from the host side +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/tensor_view.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/gemm.h" + +namespace cutlass { + +/// This is needed for the interleaved integer tensor core kernels. The purpose +/// is to use skip the shared memory part in the epilogue. +template +void reorder_column(TensorRef dest, + TensorRef src, + cutlass::gemm::GemmCoord problem_size) { + const int InstructionShapeCol = 8; + // 4 threads per Quad + const int ElementsPerThread = InstructionShapeCol / 4; + // 4 threads per Quad + const int ReorderedElementsPerThread = + Interleaved / 4; + + for (int n = 0; n < problem_size.n(); n++) { + for (int k = 0; k < problem_size.k(); k++) { + dest.at({k, (n / Interleaved) * Interleaved + + ((n % ReorderedElementsPerThread) / ElementsPerThread) * + InstructionShapeCol + + ((n % Interleaved) / ReorderedElementsPerThread) * + ElementsPerThread + + (n % ElementsPerThread)}) = src.at({k, n}); + } + } +} + +template +void reorder_convK(TensorRef dest, + TensorRef src, + cutlass::gemm::GemmCoord problem_size) { + + TensorRef> mappedDest(dest.data(), dest.stride(0)); + TensorRef> mappedSrc(src.data(), src.stride(0)); + + reorder_column( + mappedDest, mappedSrc, problem_size); +} + +/// This is needed for the sparse tensor core kernels. The purpose +/// is to use ldmatrix to load from shared memory to the register file. +template +void reorder_meta(TensorRef dest, + TensorRef src, + cutlass::gemm::GemmCoord problem_size) { + for (int m = 0; m < problem_size.m(); m++) { + for (int k = 0; k < problem_size.k(); k++) { + // First reorder the rows. + int group = (sizeof(Element) == 2) ? 32 : 16; + int interweave = (sizeof(Element) == 2) ? 4 : 2; + + int dest_row = m / group * group + (m % 8) * interweave + (m % group) / 8; + int dest_col = k; + + // Next swizzle the 2x2 blocks from Z to N. + if (((dest_row % 2) == 0) && ((dest_col % 2) == 1)) { + ++dest_row; + --dest_col; + } else if (((dest_row % 2) == 1) && ((dest_col % 2) == 0)) { + --dest_row; + ++dest_col; + } + + dest.at({dest_row, dest_col}) = src.at({m, k}); + } + } +} +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor.h new file mode 100644 index 0000000000000000000000000000000000000000..3226055ad0836e7a3059340ff16d54594987e0c8 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor.h @@ -0,0 +1,541 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +/*! \file + \brief HostTensor contributes management for both host and device memory. + + HostTensor allocates host and device memory upon construction. Basic element-wise operations on + host memory synchronize device memory automatically. Explicit copy operations provide abstractions + for CUDA memcpy operations. + + Call {host, device}_{data, ref, view}() for accessing host or device memory. + + See cutlass/tensor_ref.h and cutlass/tensor_view.h for more details. +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/fast_math.h" + +#include "device_memory.h" + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Host tensor +template < + /// Data type of element stored within tensor (concept: NumericType) + typename Element_, + /// Defines a mapping from logical coordinate to linear memory (concept: Layout) + typename Layout_ +> +class HostTensor { +public: + + /// Data type of individual access + using Element = Element_; + + /// Mapping function from logical coordinate to linear memory + using Layout = Layout_; + + /// Logical rank of tensor index space + static int const kRank = Layout::kRank; + + /// Index type + using Index = typename Layout::Index; + + /// Long index used for pointer offsets + using LongIndex = typename Layout::LongIndex; + + /// Coordinate in logical tensor space + using TensorCoord = typename Layout::TensorCoord; + + /// Layout's stride vector + using Stride = typename Layout::Stride; + + /// Tensor reference to device memory + using TensorRef = TensorRef; + + /// Tensor reference to constant device memory + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + /// Tensor reference to device memory + using TensorView = TensorView; + + /// Tensor reference to constant device memory + using ConstTensorView = typename TensorView::ConstTensorView; + + /// Reference to element in tensor + using Reference = typename TensorRef::Reference; + + /// Constant reference to element in tensor + using ConstReference = typename ConstTensorRef::Reference; + +private: + using StorageUnit = typename platform::conditional_t, uint8_t, // Avoid the std::vector specialization + typename platform::conditional_t::value % 8 == 0, // Handle subbyte types + Element, uint8_t>>; + using StorageContainerCalculator = cutlass::detail::StorageContainerCalculator; + static constexpr int kContainerTypeNumBits = StorageContainerCalculator::kContainerTypeNumBits; + static constexpr int kContainerTypeNumLogicalElements = StorageContainerCalculator::kContainerTypeNumLogicalElements; + static constexpr int kContainerTypeNumBytes = StorageContainerCalculator::kContainerTypeNumBytes; + static constexpr int kContainerTypeNumStorageUnit = StorageContainerCalculator::kContainerTypeNumStorageUnit; + + // + // Data members + // + + /// Extent of tensor in logical dimensions + TensorCoord extent_; + + /// Layout object + Layout layout_; + + /// Host-side memory allocation + std::vector host_; + + /// Device-side memory + device_memory::allocation device_; + + /// number of containers + size_t count_to_container_storage_unit_count(size_t count) { + return (count + kContainerTypeNumLogicalElements - 1) / kContainerTypeNumLogicalElements * kContainerTypeNumStorageUnit; + } + +public: + // + // Device and Host Methods + // + + /// Default constructor + HostTensor() {} + + /// Constructs a tensor given an extent. Assumes a packed layout + HostTensor( + TensorCoord const &extent, + bool device_backed = true + ) { + + this->reset(extent, Layout::packed(extent), device_backed); + } + + /// Constructs a tensor given an extent and layout + HostTensor( + TensorCoord const &extent, + Layout const &layout, + bool device_backed = true + ) { + + this->reset(extent, layout, device_backed); + } + + ~HostTensor() { } + + /// Clears the HostTensor allocation to size/capacity = 0 + void reset() { + extent_ = TensorCoord(); + layout_ = Layout::packed(extent_); + + host_.clear(); + device_.reset(); + } + + /// Resizes internal memory allocations without affecting layout or extent + void reserve( + size_t count, ///< size of tensor in elements + bool device_backed_ = true) { ///< if true, device memory is also allocated +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("cutlass::HostTensor::reserve(count=" << count << ", device_backed_=" << (device_backed_ ? "true" : "false") << ")"); +#endif + + device_.reset(); + host_.clear(); + + size_t count_container = count_to_container_storage_unit_count(count); +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("cutlass::HostTensor::reserve: host_.resize(" << count_container << ")"); +#endif + host_.resize(count_container); + + // Allocate memory + StorageUnit* device_memory = nullptr; + if (device_backed_) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("cutlass::HostTensor::reserve: device_memory::allocate(" << count_container << ")"); +#endif + device_memory = device_memory::allocate(count_container); + } + device_.reset(device_memory, device_backed_ ? count_container : 0); + } + + /// Updates the extent and layout of the HostTensor. Allocates memory according to the new + /// extent and layout. + void reset( + TensorCoord const &extent, ///< extent of logical tensor + Layout const &layout, ///< layout object of tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + extent_ = extent; + layout_ = layout; + + reserve(size_t(layout_.capacity(extent_)), device_backed_); + } + + /// Updates the extent and layout of the HostTensor. Allocates memory according to the new + /// extent and layout. Assumes a packed tensor configuration. + void reset( + TensorCoord const &extent, ///< extent of logical tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + reset(extent, Layout::packed(extent), device_backed_); + } + + /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity. + /// To force allocation, call reset(). + void resize( + TensorCoord const &extent, ///< extent of logical tensor + Layout const &layout, ///< layout object of tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + extent_ = extent; + layout_ = layout; + + LongIndex new_size = size_t(layout_.capacity(extent_)); + LongIndex new_size_container = count_to_container_storage_unit_count((layout_.capacity(extent_))); + + if (static_cast(new_size_container) > host_.size()) { + reserve(new_size, device_backed_); + } + } + + /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity. + /// To force allocation, call reset(). Note, this form of resize() assumes a packed tensor configuration. + void resize( + TensorCoord const &extent, ///< extent of logical tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + resize(extent, Layout::packed(extent), device_backed_); + } + + /// Returns the logical number of elements stored in the host tensor + size_t size() const { + return layout_.capacity(extent_); + } + + /// Returns the logical capacity in terms of number of elements. May be larger than the size(). + LongIndex capacity() const { + return host_.size() / kContainerTypeNumStorageUnit * kContainerTypeNumLogicalElements; + } + + /// Gets pointer to host data + Element * host_data() { return reinterpret_cast(host_.data()); } + + /// Gets pointer to host data with a pointer offset + Element * host_data_ptr_offset(LongIndex ptr_element_offset) { return &ReferenceFactory::get(host_data(), ptr_element_offset); } + + /// Gets a reference to an element in host memory + Reference host_data(LongIndex idx) { + return ReferenceFactory::get(host_data(), idx); + } + + /// Gets pointer to host data + Element const * host_data() const { return reinterpret_cast(host_.data()); } + + /// Gets pointer to host data with a pointer offset + Element const * host_data_ptr_offset(LongIndex ptr_element_offset) const { return &ReferenceFactory::get(host_data(), ptr_element_offset); } + + /// Gets a constant reference to an element in host memory + ConstReference host_data(LongIndex idx) const { + return ReferenceFactory::get(host_data(), idx); + } + + /// Gets pointer to device data + Element * device_data() { return reinterpret_cast(device_.get()); } + + /// Gets pointer to device data + Element const * device_data() const { return reinterpret_cast(device_.get()); } + + /// Gets pointer to device data with a pointer offset + Element * device_data_ptr_offset(LongIndex ptr_element_offset) { return &ReferenceFactory::get(device_data(), ptr_element_offset); } + + /// Gets pointer to device data with a pointer offset + Element const * device_data_ptr_offset(LongIndex ptr_element_offset) const { return &ReferenceFactory::get(device_data(), ptr_element_offset); } + + /// Accesses the tensor reference pointing to data + TensorRef host_ref(LongIndex ptr_element_offset=0) { return TensorRef(host_data_ptr_offset(ptr_element_offset), layout_); } + + /// Accesses the tensor reference pointing to data + ConstTensorRef host_ref(LongIndex ptr_element_offset=0) const { return ConstTensorRef(host_data_ptr_offset(ptr_element_offset), layout_); } + + /// Accesses the tensor reference pointing to data + TensorRef device_ref(LongIndex ptr_element_offset=0) { + return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_); + } + + /// Accesses the tensor reference pointing to data + ConstTensorRef device_ref(LongIndex ptr_element_offset=0) const { + return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_); + } + + /// Accesses the tensor reference pointing to data + TensorView host_view(LongIndex ptr_element_offset=0) { + return TensorView(host_data_ptr_offset(ptr_element_offset), layout_, extent_); + } + + /// Accesses the tensor reference pointing to data + ConstTensorView host_view(LongIndex ptr_element_offset=0) const { + return ConstTensorView(host_data_ptr_offset(ptr_element_offset), layout_, extent_); + } + + /// Accesses the tensor reference pointing to data + TensorView device_view(LongIndex ptr_element_offset=0) { + return TensorView(device_data_ptr_offset(ptr_element_offset), layout_, extent_); + } + + /// Accesses the tensor reference pointing to data + ConstTensorView device_view(LongIndex ptr_element_offset=0) const { + return ConstTensorView(device_data_ptr_offset(ptr_element_offset), layout_, extent_); + } + + /// Returns true if device memory is allocated + bool device_backed() const { + return (device_.get() == nullptr) ? false : true; + } + + + /// Returns the layout object + Layout & layout() { + return layout_; + } + + /// Returns the layout object + Layout layout() const { + return layout_; + } + + /// Returns the layout object's stride vector + Stride stride() const { + return layout_.stride(); + } + + /// Returns the layout object's stride vector + Stride & stride() { + return layout_.stride(); + } + + /// Returns the layout object's stride in a given physical dimension + LongIndex stride(int dim) const { + return layout_.stride().at(dim); + } + + /// Returns the layout object's stride in a given physical dimension + LongIndex & stride(int dim) { + return layout_.stride().at(dim); + } + + /// Computes the offset of an index from the origin of the tensor + LongIndex offset(TensorCoord const& coord) const { + return layout_(coord); + } + + /// Returns a reference to the element at the logical Coord in host memory + Reference at(TensorCoord const& coord) { + return host_data(offset(coord)); + } + + /// Returns a const reference to the element at the logical Coord in host memory + ConstReference at(TensorCoord const& coord) const { + return host_data(offset(coord)); + } + + /// Returns the extent of the tensor + TensorCoord extent() const { + return extent_; + } + + /// Returns the extent of the tensor + TensorCoord & extent() { + return extent_; + } + + /// Copies data from device to host + void sync_host() { + if (device_backed()) { + device_memory::copy_to_host( + host_.data(), device_.get(), device_.size()); + } + } + + /// Copies data from host to device + void sync_device() { + if (device_backed()) { + device_memory::copy_to_device( + device_.get(), host_.data(), host_.size()); + } + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_device_to_host( + Element const* ptr_device, ///< source device memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_to_host( + host_.data(), reinterpret_cast(ptr_device), container_count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_device_to_device( + Element const* ptr_device, ///< source device memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_device_to_device( + device_.get(), reinterpret_cast(ptr_device), container_count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_host_to_device( + Element const* ptr_host, ///< source host memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_to_device( + device_.get(), reinterpret_cast(ptr_host), container_count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_host_to_host( + Element const* ptr_host, ///< source host memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_host_to_host( + host_.data(), reinterpret_cast(ptr_host), container_count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_device_to_host( + Element * ptr_host, ///< source device memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_to_host( + reinterpret_cast(ptr_host), device_.get(), container_count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_device_to_device( + Element * ptr_device, ///< source device memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_device_to_device( + reinterpret_cast(ptr_device), device_.get(), container_count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_host_to_device( + Element * ptr_device, ///< source host memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_to_device( + reinterpret_cast(ptr_device), host_.data(), container_count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_host_to_host( + Element * ptr_host, ///< source host memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_host_to_host( + reinterpret_cast(ptr_host), host_.data(), container_count); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor_planar_complex.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor_planar_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..ca770e4d76cfe2df16309baca0b2de8ab6de98c4 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor_planar_complex.h @@ -0,0 +1,591 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +/*! \file + \brief HostTensor contributes management for both host and device memory. + + HostTensor allocates host and device memory upon construction. Basic element-wise operations on + host memory synchronize device memory automatically. Explicit copy operations provide abstractions + for CUDA memcpy operations. + + Call {host, device}_{data, ref, view}() for accessing host or device memory. + + See cutlass/tensor_ref.h and cutlass/tensor_view.h for more details. +*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cutlass/tensor_ref_planar_complex.h" +#include "cutlass/tensor_view_planar_complex.h" + +#include "device_memory.h" + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Host tensor +template < + /// Data type of element stored within tensor (concept: NumericType) + typename Element_, + /// Defines a mapping from logical coordinate to linear memory (concept: Layout) + typename Layout_ +> +class HostTensorPlanarComplex { +public: + + /// Data type of individual access + using Element = Element_; + + /// Mapping function from logical coordinate to linear memory + using Layout = Layout_; + + /// Logical rank of tensor index space + static int const kRank = Layout::kRank; + + /// Index type + using Index = typename Layout::Index; + + /// Long index used for pointer offsets + using LongIndex = typename Layout::LongIndex; + + /// Coordinate in logical tensor space + using TensorCoord = typename Layout::TensorCoord; + + /// Layout's stride vector + using Stride = typename Layout::Stride; + + /// Tensor reference to device memory + using TensorRef = TensorRefPlanarComplex; + + /// Tensor reference to constant device memory + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + /// Tensor reference to device memory + using TensorView = TensorViewPlanarComplex; + + /// Tensor reference to constant device memory + using ConstTensorView = typename TensorView::ConstTensorView; + + /// Reference to element in tensor + using Reference = typename TensorRef::Reference; + + /// Constant reference to element in tensor + using ConstReference = typename ConstTensorRef::Reference; + + private: + + // + // Data members + // + + /// Extent of tensor in logical dimensions + TensorCoord extent_; + + /// Layout object + Layout layout_; + + /// Host-side memory allocation + std::vector host_; + + /// Device-side memory + device_memory::allocation device_; + + public: + // + // Device and Host Methods + // + + /// Default constructor + HostTensorPlanarComplex() {} + + /// Constructs a tensor given an extent. Assumes a packed layout + HostTensorPlanarComplex( + TensorCoord const &extent, + bool device_backed = true + ) { + + this->reset(extent, Layout::packed(extent), device_backed); + } + + /// Constructs a tensor given an extent and layout + HostTensorPlanarComplex( + TensorCoord const &extent, + Layout const &layout, + bool device_backed = true + ) { + + this->reset(extent, layout, device_backed); + } + + ~HostTensorPlanarComplex() { } + + /// Clears the HostTensor allocation to size/capacity = 0 + void reset() { + extent_ = TensorCoord(); + layout_ = Layout::packed(extent_); + + host_.clear(); + device_.reset(); + } + + /// Resizes internal memory allocations without affecting layout or extent + void reserve( + size_t count, ///< size of tensor in elements + bool device_backed_ = true) { ///< if true, device memory is also allocated + + device_.reset(); + host_.clear(); + + host_.resize(count * 2); + + // Allocate memory + Element* device_memory = nullptr; + if (device_backed_) { + device_memory = device_memory::allocate(count * 2); + } + device_.reset(device_memory, device_backed_ ? count * 2 : 0); + } + + /// Updates the extent and layout of the HostTensor. Allocates memory according to the new + /// extent and layout. + void reset( + TensorCoord const &extent, ///< extent of logical tensor + Layout const &layout, ///< layout object of tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + extent_ = extent; + layout_ = layout; + + reserve(size_t(layout_.capacity(extent_)), device_backed_); + } + + /// Updates the extent and layout of the HostTensor. Allocates memory according to the new + /// extent and layout. Assumes a packed tensor configuration. + void reset( + TensorCoord const &extent, ///< extent of logical tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + reset(extent, Layout::packed(extent), device_backed_); + } + + /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity. + /// To force allocation, call reset(). + void resize( + TensorCoord const &extent, ///< extent of logical tensor + Layout const &layout, ///< layout object of tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + extent_ = extent; + layout_ = layout; + + LongIndex new_size = size_t(layout_.capacity(extent_)); + + if (static_cast(new_size * 2) > host_.size()) { + reserve(new_size); + } + } + + /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity. + /// To force allocation, call reset(). Note, this form of resize() assumes a packed tensor configuration. + void resize( + TensorCoord const &extent, ///< extent of logical tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + resize(extent, Layout::packed(extent), device_backed_); + } + + /// Returns the number of elements stored in the host tensor + size_t size() const { + return host_.size() / 2; + } + + /// Returns the logical capacity based on extent and layout. May differ from size(). + LongIndex capacity() const { + return layout_.capacity(extent_); + } + + /// Stride between real and imaginary parts + LongIndex imaginary_stride() const { + return host_.size() / 2; + } + + /// Gets pointer to host data + Element * host_data() { return host_.data(); } + + /// Gets pointer to host data imaginary part + Element * host_data_imag() { return host_.data() + imaginary_stride(); } + + /// Gets pointer to host data with a pointer offset + Element * host_data_ptr_offset(LongIndex ptr_element_offset) { return host_data() + ptr_element_offset; } + + /// Gets pointer to host data with a pointer offset + Element * host_data_imag_ptr_offset(LongIndex ptr_element_offset) { return host_data_imag() + ptr_element_offset; } + + /// Gets a reference to an element in host memory + Reference host_data(LongIndex idx) { + return PlanarComplexReference(host_data() + idx, host_data_imag() + idx); + } + + /// Gets pointer to host data + Element const * host_data() const { return host_.data(); } + + /// Gets pointer to host data imaginary part + Element const * host_data_imag() const { return host_.data() + imaginary_stride(); } + + /// Gets a constant reference to an element in host memory + ConstReference host_data(LongIndex idx) const { + return PlanarComplexReference(host_data() + idx, host_data_imag() + idx); + } + + /// Gets pointer to device data + Element * device_data() { return device_.get(); } + + /// Gets pointer to device data with a pointer offset + Element * device_data_ptr_offset(LongIndex ptr_element_offset) { return device_.get() + ptr_element_offset; } + + /// Gets pointer to device data + Element const * device_data() const { return device_.get(); } + + /// Gets pointer to device data with a pointer offset + Element const * device_data_ptr_offset(LongIndex ptr_element_offset) const { return device_.get() + ptr_element_offset; } + + /// Gets a pointer to the device data imaginary part + Element * device_data_imag() { return device_.get() + imaginary_stride(); } + + /// Accesses the tensor reference pointing to data + TensorRef host_ref(LongIndex ptr_element_offset=0) { + return TensorRef(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride()); + } + + /// Returns a tensor reference to the real part of the tensor + cutlass::TensorRef host_ref_real() { + return cutlass::TensorRef(host_data(), layout_); + } + + /// Returns a tensor reference to the real part of the tensor + cutlass::TensorRef host_ref_imag() { + return cutlass::TensorRef(host_data_ptr_offset(imaginary_stride()), layout_); + } + + /// Accesses the tensor reference pointing to data + ConstTensorRef host_ref(LongIndex ptr_element_offset=0) const { + return ConstTensorRef(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride()); + } + + /// Accesses the tensor reference pointing to data + TensorRef device_ref(LongIndex ptr_element_offset=0) { + return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride()); + } + + /// Accesses the tensor reference pointing to data + ConstTensorRef device_ref(LongIndex ptr_element_offset=0) const { + return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride()); + } + + /// Returns a tensor reference to the real part of the tensor + cutlass::TensorRef device_ref_real() { + return cutlass::TensorRef(device_data(), layout_); + } + + /// Returns a tensor reference to the real part of the tensor + cutlass::TensorRef device_ref_imag() { + return cutlass::TensorRef(device_data_ptr_offset(imaginary_stride()), layout_); + } + + /// Accesses the tensor reference pointing to data + TensorView host_view(LongIndex ptr_element_offset=0) { + return TensorView(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_); + } + + /// Accesses the tensor reference pointing to data + ConstTensorView host_view(LongIndex ptr_element_offset=0) const { + return ConstTensorView(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_); + } + + /// Accesses the tensor reference pointing to data + cutlass::TensorView host_view_real() { + return cutlass::TensorView(host_data(), layout_, extent_); + } + + /// Accesses the tensor reference pointing to data + cutlass::TensorView host_view_imag() { + return cutlass::TensorView(host_data_ptr_offset(imaginary_stride()), layout_, extent_); + } + + /// Accesses the tensor reference pointing to data + TensorView device_view(LongIndex ptr_element_offset=0) { + return TensorView(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_); + } + + /// Accesses the tensor reference pointing to data + ConstTensorView device_view(LongIndex ptr_element_offset=0) const { + return ConstTensorView(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_); + } + + /// Accesses the tensor reference pointing to data + cutlass::TensorView device_view_real() { + return cutlass::TensorView(device_data(), layout_, extent_); + } + + /// Accesses the tensor reference pointing to data + cutlass::TensorView device_view_imag() { + return cutlass::TensorView(device_data_ptr_offset(imaginary_stride()), layout_, extent_); + } + + /// Returns true if device memory is allocated + bool device_backed() const { + return (device_.get() == nullptr) ? false : true; + } + + /// Returns the layout object + Layout layout() const { + return layout_; + } + + /// Returns the layout object's stride vector + Stride stride() const { + return layout_.stride(); + } + + /// Returns the layout object's stride in a given physical dimension + Index stride(int dim) const { + return layout_.stride().at(dim); + } + + /// Computes the offset of an index from the origin of the tensor + LongIndex offset(TensorCoord const& coord) const { + return layout_(coord); + } + + /// Returns a reference to the element at the logical Coord in host memory + Reference at(TensorCoord const& coord) { + return host_data(offset(coord)); + } + + /// Returns a const reference to the element at the logical Coord in host memory + ConstReference at(TensorCoord const& coord) const { + return host_data(offset(coord)); + } + + /// Returns the extent of the tensor + TensorCoord extent() const { + return extent_; + } + + /// Returns the extent of the tensor + TensorCoord & extent() { + return extent_; + } + + /// Copies data from device to host + void sync_host() { + if (device_backed()) { + device_memory::copy_to_host( + host_data(), device_data(), imaginary_stride() * 2); + } + } + + /// Copies data from host to device + void sync_device() { + if (device_backed()) { + device_memory::copy_to_device( + device_data(), host_data(), imaginary_stride() * 2); + } + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_device_to_host( + Element const* ptr_device_real, ///< source device memory + Element const* ptr_device_imag, ///< source device memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_to_host( + host_data(), ptr_device_real, count); + + device_memory::copy_to_host( + host_data_imag(), ptr_device_imag, count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_device_to_device( + Element const* ptr_device_real, ///< source device memory + Element const* ptr_device_imag, ///< source device memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_device_to_device( + device_data(), ptr_device_real, count); + + device_memory::copy_device_to_device( + device_data_imag(), ptr_device_imag, count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_host_to_device( + Element const* ptr_host_real, ///< source host memory + Element const* ptr_host_imag, ///< source host memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_to_device( + device_data(), ptr_host_real, count); + + device_memory::copy_to_device( + device_data_imag(), ptr_host_imag, count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_host_to_host( + Element const* ptr_host_real, ///< source host memory + Element const* ptr_host_imag, ///< source host memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_host_to_host( + host_data(), ptr_host_real, count); + + device_memory::copy_host_to_host( + host_data_imag(), ptr_host_imag, count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_device_to_host( + Element * ptr_host_real, ///< source device memory + Element * ptr_host_imag, ///< source device memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_to_host( + ptr_host_real, device_data(), count); + + device_memory::copy_to_host( + ptr_host_imag, device_data_imag(), count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_device_to_device( + Element * ptr_device_real, ///< source device memory + Element * ptr_device_imag, ///< source device memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_device_to_device( + ptr_device_real, device_data(), count); + + device_memory::copy_device_to_device( + ptr_device_imag, device_data_imag(), count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_host_to_device( + Element * ptr_device_real, ///< source device memory + Element * ptr_device_imag, ///< source device memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_to_device( + ptr_device_real, host_data(), count); + + device_memory::copy_to_device( + ptr_device_imag, host_data_imag(), count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_host_to_host( + Element * ptr_host_real, ///< source host memory + Element * ptr_host_imag, ///< source host memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_host_to_host( + ptr_host_real, host_data(), count); + + device_memory::copy_host_to_host( + ptr_host_imag, host_data_imag(), count); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_uncompress.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_uncompress.h new file mode 100644 index 0000000000000000000000000000000000000000..9cd62927432c65ce1f0187f46306f7e1198a1182 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_uncompress.h @@ -0,0 +1,157 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief uncompress sparse matrix from the host side +*/ +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/tensor_view.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/gemm.h" + +namespace cutlass { + +// uncompress sparse tensor core A matrix +template +void uncompress(TensorRef uncompressed_tensor_a, + TensorRef tensor_a, + TensorRef tensor_e, int row, int col) { + // How many uncompressed data we can get with ElementE meta data + int DecompressedElementsPerElementE = + 256 / cutlass::sizeof_bits::value; + + // Process 4bit meta data a time + int step; + + // 1:2 or 2:4 or 4:8 + int a, b; + + if (cutlass::sizeof_bits::value == 4) { + step = 8; + a = 4; + b = 8; + } else if (cutlass::sizeof_bits::value == 8) { + step = 4; + a = 2; + b = 4; + } else if (cutlass::sizeof_bits::value == 16) { + step = 4; + a = 2; + b = 4; + } else if (cutlass::sizeof_bits::value == 32) { + step = 2; + a = 1; + b = 2; + } + + int ElementsPerE = (cutlass::sizeof_bits::value == 4) ? 2 : 1; + + for (int r = 0; r < row; ++r) { + for (int c = 0; c < (col / DecompressedElementsPerElementE); ++c) { + + ElementE meta = tensor_e.at(MatrixCoord(r, c)); + + for (int i = 0; i < DecompressedElementsPerElementE; i += step) { + int e = (meta >> (i / step * 4)) & 0xf; + int idx0 = e & 0x3; + int idx1 = e >> 2; + + if (a == 1) idx0 = idx0 / 2; + + for (int ii = 0; ii < step; ii += ElementsPerE) { + int real_col = + c * DecompressedElementsPerElementE + i + ii; + int compressed_col = (real_col / b) * a; + + if (ii == (idx0 * ElementsPerE)) { + uncompressed_tensor_a.at(MatrixCoord(r, real_col)) = + tensor_a.at(MatrixCoord(r, compressed_col)); + if (ElementsPerE == 2) + uncompressed_tensor_a.at(MatrixCoord(r, real_col + 1)) = + tensor_a.at(MatrixCoord(r, compressed_col + 1)); + } else if ((ii == (idx1 * ElementsPerE)) && (a != 1)) { + uncompressed_tensor_a.at(MatrixCoord(r, real_col)) = + tensor_a.at(MatrixCoord(r, compressed_col + ElementsPerE)); + if (ElementsPerE == 2) + uncompressed_tensor_a.at(MatrixCoord(r, real_col + 1)) = + tensor_a.at( + MatrixCoord(r, compressed_col + ElementsPerE + 1)); + } else { + uncompressed_tensor_a.at(MatrixCoord(r, real_col)) = + ElementA(0); + if (ElementsPerE == 2) + uncompressed_tensor_a.at(MatrixCoord(r, real_col + 1)) = + ElementA(0); + } + } + } + } + } +} + +// uncompress ELL block sparse matrix +template +void uncompress_ell_block_sparse( + TensorRef uncompressed_tensor_a, + TensorRef tensor_a, + TensorRef ell_idx, + int rows, int cols, + int ell_num_cols, int ell_blocksize) { + + for (int r = 0; r < rows / ell_blocksize; ++r) { + for (int c = 0; c < ell_num_cols / ell_blocksize; ++c) { + + ElementE idx = ell_idx.at(MatrixCoord(r, c)); + + if (idx != -1) { + int row_begin = r * ell_blocksize; + int col_begin_real = idx * ell_blocksize; + int col_begin = c * ell_blocksize; + + for (int i = 0; i < ell_blocksize; ++i) { + for (int j = 0; j < ell_blocksize; ++j) { + uncompressed_tensor_a.at(MatrixCoord(row_begin + i, col_begin_real + j)) = + tensor_a.at( + MatrixCoord(row_begin + i, col_begin +j)); + } + } + } + } + } +} + +} // namespace cutlass + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/index_sequence.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/index_sequence.h new file mode 100644 index 0000000000000000000000000000000000000000..6b72b043fc0c1271cf9f12e5cb9a81d29659cb0a --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/index_sequence.h @@ -0,0 +1,38 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +// integer_sequence moved to cutlass/numeric_types.h + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/mixed_dtype_utils.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/mixed_dtype_utils.hpp new file mode 100644 index 0000000000000000000000000000000000000000..43f5a3f92d29f229703cc4c5f9071c11d0f89df4 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/mixed_dtype_utils.hpp @@ -0,0 +1,472 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Utilities for mixed input data type kernels. +*/ + +#pragma once + +#include +#include "cute/layout.hpp" +#include "cute/tensor.hpp" +#include "cute/arch/mma_sm90.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cute/util/type_traits.hpp" + +namespace cutlass { + +#define CUDA_CHECK(status) \ + { \ + cudaError_t error = status; \ + if (error != cudaSuccess) { \ + std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ + << " at line: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + +template < + class QuantizedElement, + class DequantizedElement, + class OperandLayout, + class ElementScale, + class ElementZero, + class ScaleBroadCastLayout, + class ThrLayout> +__global__ void dequantize_kernel(DequantizedElement* dq_buffer, + QuantizedElement const* q_buffer, + OperandLayout const operand_layout, + ElementScale const* scale_buffer, + ElementZero const* zero_buffer, + ScaleBroadCastLayout const broadcasted_scale_layout, + ThrLayout thr_layout) { + using namespace cute; + + // Represent the full tensors to gmem elements. + // These are expected to have shape [MN, K, L] + cute::Tensor gmem_op_dq = cute::make_tensor(cute::make_gmem_ptr(dq_buffer), operand_layout); + cute::Tensor gmem_op_q = cute::make_tensor(cute::make_gmem_ptr(q_buffer), operand_layout); + // While the scales are expected to have shape [MN, G, L] but with a stride to allow broadcasting + // It is expected that K % G == 0 + cute::Tensor gmem_scale_broadcasted = cute::make_tensor(make_gmem_ptr(scale_buffer), broadcasted_scale_layout); + cute::Tensor gmem_zero_broadcasted = cute::make_tensor(make_gmem_ptr(zero_buffer), broadcasted_scale_layout); + + // Assign 1 thread per element in the thread block + auto blk_shape = cute::make_shape(size<0>(thr_layout), _1{}, _1{}); // + auto blk_coord = cute::make_coord(_, blockIdx.x, blockIdx.y); // (MN, K, L) + + // Tile across the block + auto gOp_dq = cute::local_tile(gmem_op_dq, blk_shape, blk_coord); + auto gScale = cute::local_tile(gmem_scale_broadcasted, blk_shape, blk_coord); + auto gZero = cute::local_tile(gmem_zero_broadcasted, blk_shape, blk_coord); + auto gOp_q = cute::local_tile(gmem_op_q, blk_shape, blk_coord); + + auto tOpDq_gOpDq = cute::local_partition(gOp_dq, thr_layout, threadIdx.x); + auto tScale_gScale = cute::local_partition(gScale, thr_layout, threadIdx.x); + auto tZero_gZero = cute::local_partition(gZero, thr_layout, threadIdx.x); + auto tOpQ_gOpQ = cute::local_partition(gOp_q, thr_layout, threadIdx.x); + + // Make a fragment of registers to hold gmem loads + cute::Tensor rmem_op_q = cute::make_fragment_like(tOpQ_gOpQ(_, _, _, 0)); + cute::Tensor rmem_scale = cute::make_fragment_like(tScale_gScale(_, _, _, 0)); + cute::Tensor rmem_zero = cute::make_fragment_like(tZero_gZero(_, _, _, 0)); + cute::Tensor rmem_op_dq = cute::make_fragment_like(tOpDq_gOpDq(_, _, _, 0)); + cute::Tensor rmem_op_scaled = cute::make_fragment_like(rmem_op_dq); + cute::Tensor rmem_zero_buf = cute::make_fragment_like(rmem_zero); + + cute::Tensor pred_id = cute::make_identity_tensor(shape(operand_layout)); + auto pred_blk_tile = cute::local_tile(pred_id, blk_shape, blk_coord); + auto pred_thr_partition = cute::local_partition(pred_blk_tile, thr_layout, threadIdx.x); + + const auto num_iters = cute::size<3>(tOpDq_gOpDq); + + for (int ii = 0; ii < num_iters; ++ii) { + const auto thread_offset = cute::get<0>(pred_thr_partition(0, 0, 0, ii)); + if (thread_offset < cute::size<0>(operand_layout)) { + cute::copy(tOpQ_gOpQ(_, _, _, ii), rmem_op_q); + cute::copy(tScale_gScale(_, _, _, ii), rmem_scale); + cute::copy(tZero_gZero(_, _, _, ii), rmem_zero); + cute::transform(rmem_op_q, rmem_op_scaled, [] (const QuantizedElement& elt) { return ElementScale(elt); } ); + cute::transform(rmem_zero, rmem_zero_buf, [] (const ElementZero& elt) { return ElementScale(elt); } ); + cute::transform(rmem_op_scaled, rmem_scale, rmem_op_scaled, cute::multiplies{}); + cute::transform(rmem_op_scaled, rmem_zero_buf, rmem_op_scaled, cute::plus{}); + cute::transform(rmem_op_scaled, rmem_op_dq, [] (const ElementScale& elt) { return DequantizedElement(elt); } ); + cute::copy(rmem_op_dq, tOpDq_gOpDq(_, _, _, ii)); + } + } +} + +template < + class QuantizedElement, + class DequantizedElement, + class OperandLayout, + class ElementScale, + class ElementZero, + class ScaleLayout> +static void dequantize(DequantizedElement* dq_buffer, + QuantizedElement const* q_buffer, + OperandLayout const operand_layout, + ElementScale const* scale_buffer, + ElementZero const* zero_buffer, + ScaleLayout const scale_layout, + int const group_size, + cudaStream_t &stream) { + using namespace cute; + + constexpr int tpb = 128; + auto thr_layout = make_layout(make_shape(Int{})); + + const auto num_rows = get<0>(shape(operand_layout)); + const auto gemm_k = get<1>(shape(operand_layout)); // [MN, K, L] + const auto batches = get<2>(shape(operand_layout)); // [MN, K, L] + const auto scale_k = get<1>(shape(scale_layout)); // [MN, Scale_K, L] + + if (num_rows != size<0>(scale_layout)) { + std::cerr << "Invalid first dimension for scales. Must match first dim for weights." + << " But got shapes " << shape(operand_layout) << " " << shape(scale_layout) + << std::endl; + exit(-1); + } + + const auto scale_stride0 = get<0>(stride(scale_layout)); + const auto scale_stride1 = get<1>(stride(scale_layout)); + const auto scale_stride2 = get<2>(stride(scale_layout)); + + auto scale_shape_bcast = make_shape(num_rows, make_shape(group_size, scale_k), batches); + auto scale_stride_bcast = make_stride(scale_stride0, make_stride(0, scale_stride1), scale_stride2); + auto scale_layout_bcast = make_layout(scale_shape_bcast, scale_stride_bcast); + + const auto blocks_x = gemm_k; + const auto blocks_y = batches; + + dim3 blocks(blocks_x, blocks_y, 1); + dequantize_kernel<<>>(dq_buffer, q_buffer, operand_layout, scale_buffer, zero_buffer, scale_layout_bcast, thr_layout); + CUDA_CHECK(cudaStreamSynchronize(stream)); +} + +template +class packed_scale_t { +public: + static_assert(cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v, + "only 8 bit arithmetic types are supported."); + CUTLASS_HOST_DEVICE + explicit packed_scale_t(T val) { + if constexpr (!cute::is_unsigned_v) { + // Only pack negative values. The positive values are generated in flight in the mainloop. + storage[0] = pack4(T(float(val) * -8.f), T(float(val) * -7.f), T(float(val) * -6.f), T(float(val) * -5.f)); + storage[1] = pack4(T(float(val) * -4.f), T(float(val) * -3.f), T(float(val) * -2.f), -val); + } + else { + storage[0] = pack4(T(float(val) * 8.f), T(float(val) * 7.f), T(float(val) * 6.f), T(float(val) * 5.f)); + storage[1] = pack4(T(float(val) * 4.f), T(float(val) * 3.f), T(float(val) * 2.f), val); + } + } + CUTLASS_HOST_DEVICE + packed_scale_t() = default; + CUTLASS_HOST_DEVICE + explicit operator float() const { + return float(get()); + } + CUTLASS_HOST_DEVICE + bool operator==(packed_scale_t const& rhs) const { + return storage[0] == rhs.storage[0] && storage[1] == rhs.storage[1]; + } + CUTLASS_HOST_DEVICE + bool operator!=(packed_scale_t const& rhs) const { + return !(*this == rhs); + } + CUTLASS_HOST_DEVICE + friend packed_scale_t operator+(packed_scale_t const& lhs, packed_scale_t const& rhs) { + return packed_scale_t(lhs.get() + rhs.get()); + } + CUTLASS_HOST_DEVICE + friend packed_scale_t operator-(packed_scale_t const& lhs, packed_scale_t const& rhs) { + return packed_scale_t(lhs.get() - rhs.get()); + } + CUTLASS_HOST_DEVICE + friend packed_scale_t operator*(packed_scale_t const& lhs, packed_scale_t const& rhs) { + return packed_scale_t(lhs.get() * rhs.get()); + } + CUTLASS_HOST_DEVICE + friend packed_scale_t operator/(packed_scale_t const& lhs, packed_scale_t const& rhs) { + return packed_scale_t(lhs.get() / rhs.get()); + } + +private: + using Storage = uint32_t; + using Stage = uint8_t; + + Storage storage[2] {}; + + CUTLASS_HOST_DEVICE + static Storage pack4(T c1, T c2, T c3, T c4) { + Storage result = 0; + result |= (static_cast(reinterpret_cast(c4)) << 24); + result |= (static_cast(reinterpret_cast(c3)) << 16); + result |= (static_cast(reinterpret_cast(c2)) << 8); + result |= static_cast(reinterpret_cast(c1)); + return result; + } + CUTLASS_HOST_DEVICE + T get() const { + auto stage = static_cast(storage[0] >> 8); + #if defined(__CUDA_ARCH__) + return reinterpret_cast(stage); + #else + T tmp; + std::memcpy(&tmp, &stage, sizeof(Stage)); + return tmp; + #endif + } + CUTLASS_HOST_DEVICE + T get(int idx) const { + Stage stage; + if (idx < 4) stage = static_cast(storage[0] >> (8 * idx)); + else stage = static_cast(storage[1] >> (8 * idx - 32)); + #if defined(__CUDA_ARCH__) + return reinterpret_cast(stage); + #else + T tmp; + std::memcpy(&tmp, &stage, sizeof(Stage)); + return tmp; + #endif + } +}; + +// In the mainloop, PRMT selects 1 byte from only 8 bytes so the sign bit is handled in an extra PRMT. +// Here the encodings of positive values and negative values are unified (except for the sign bit). +// For instance, 1 becomes 0b0111, which is the same encoding as -1 (0b1111). +static bool unified_encode_int4b(cutlass::int4b_t const *block_in, cutlass::int4b_t *block_out, const size_t block_size) { + + using StorageType = cutlass::int4b_t::Storage; + constexpr int pack = cute::sizeof_bits_v / 4; + const size_t host_buf_size = block_size / pack; + std::vector host_buf(host_buf_size); + cutlass::device_memory::copy_to_host(host_buf.data(), (StorageType *) block_in, host_buf_size); + + for (auto&& d : host_buf) { + StorageType out = 0; + StorageType mask = 0x0f; + for (int i = 0; i < pack; i++) { + cutlass::int4b_t curr; + curr.storage = (d >> (i * 4)) & 0x0f; + switch (curr) { + case 1: curr.storage = StorageType(0b0111); break; // 2's complement + case 2: curr.storage = StorageType(0b0110); break; // 2's complement + case 3: curr.storage = StorageType(0b0101); break; // 2's complement + case 4: curr.storage = StorageType(0b0100); break; // 2's complement + case 5: curr.storage = StorageType(0b0011); break; // 2's complement + case 6: curr.storage = StorageType(0b0010); break; // 2's complement + case 7: curr.storage = StorageType(0b0001); break; // 2's complement + default: break; + } + out |= (curr.storage << (4 * i)) & mask; + mask <<= 4; + } + d = out; + } + + cutlass::device_memory::copy_to_device((StorageType*) block_out, host_buf.data(), host_buf_size); + return true; +} + +template +static bool pack_scale_fp8(ElementScale const *block_in, cutlass::Array *block_out, const size_t block_size) { + std::vector data_in(block_size); + std::vector> data_out(block_size); + + try { + cutlass::device_memory::copy_to_host(data_in.data(), block_in, block_size); + } + catch (cutlass::cuda_exception const& e) { + std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl; + return false; + } + + for (size_t i = 0; i < block_size; i++) { + cutlass::packed_scale_t tmp(data_in[i]); + data_out[i] = reinterpret_cast const&>(tmp); + } + + try { + cutlass::device_memory::copy_to_device(block_out, data_out.data(), block_size); + } + catch (cutlass::cuda_exception const& e) { + std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl; + return false; + } + return true; +} + +template +struct UnderlyingElement { + using type = T; +}; + +template +struct UnderlyingElement> { + using type = typename T::Element; +}; + +// Given a type of MMA instruction, compute a memory reordering atom that places all values +// owned by each thread in contiguous memory locations. This improves smem load vectorization, +// particularly for mixed dtype GEMMs where a narrow type is loaded in the thread/value order +// of the wider type and may result in inefficient sub-bank (8-bit or 16-bit) accesses. +// In addition, we can reorder the values across several MMA instructions to get even wider +// vectorization (AtomLayout parameter) and permute the values within each instruction to get +// more optimal conversion instruction sequences (ValLayout parameter). +template , + class ValLayout = cute::Layout> +constexpr auto compute_memory_reordering_atom(AtomLayout atom_layout = {}, ValLayout val_layout = {}) +{ + using namespace cute; + + static_assert(is_static_v, "ValLayout must be static"); + static_assert(is_static_v, "AtomLayout must be static"); + + // 1. Choose an MMA atom to access TV layout and MN shape + // Note: parameters like GMMA Major, TileShape, ElementC don't affect TV layout of A, use arbitrary + using MmaAtom = decltype(SM90::GMMA::rs_op_selector>()); + using MmaTraits = MMA_Traits; + auto mk_shape_mma = select<0,2>(typename MmaTraits::Shape_MNK{}); + auto tv_layout_mma = typename MmaTraits::ALayout{}; + static_assert(size<1>(tv_layout_mma) % size(val_layout) == 0, "Value layout must evenly divide the MMA value layout"); + + // 2. Create a single warp's TV layout from that of the whole MMA and invert to get (m,k -> thr,val) + // Note: this assumes A is partitioned between warps along M mode + auto tv_tiler_warp = make_shape(Int<32>{}, size<1>(tv_layout_mma)); + auto mk_shape_warp = shape_div(mk_shape_mma, size(typename MmaTraits::ThrID{}) / Int<32>{}); + auto tv_layout_mma_warp = make_layout_like(composition(tv_layout_mma, tv_tiler_warp)); + auto mk_layout_mma_warp = right_inverse(tv_layout_mma_warp).with_shape(mk_shape_warp); + + // 3. Repeat the warp layout NumAtoms times along K mode to get wider vectorization + auto mk_layout_mma_trgt = blocked_product(mk_layout_mma_warp, atom_layout); + + // 4. Compose with a contiguous layout of values in each thread (required for smem vectorization) + auto val_to_offset = logical_product(val_layout, size<1>(tv_layout_mma) / size(val_layout) * size(atom_layout)); + auto thr_to_offset = make_layout(size<0>(tv_layout_mma_warp)); + auto tv_to_offset = select<1,0>(logical_product(val_to_offset, thr_to_offset)); + auto layout_atom = composition(tv_to_offset, mk_layout_mma_trgt); + + return layout_atom; +} + +template +__global__ void reorder_tensor_kernel( + cute::Tensor S, + cute::Tensor D, + TiledCopy tiled_copy) +{ + using namespace cute; + + using T = typename EngineDst::value_type; + + Tensor gS = local_tile(S, TileShape{}, make_coord(blockIdx.x, _, blockIdx.z)); + Tensor gD = local_tile(D, TileShape{}, make_coord(blockIdx.x, _, blockIdx.z)); + + auto thread_copy = tiled_copy.get_slice(threadIdx.x); + Tensor tS = thread_copy.partition_S(gS); + Tensor tD = thread_copy.partition_D(gD); + + copy(tiled_copy, tS, tD); +} + +template +void reorder_tensor( + cute::Tensor S, + cute::Tensor D) +{ + using namespace cute; + + using T = typename EngineDst::value_type; + static_assert(is_same_v, T>, "Type mismatch"); + + // Construct a value layout that assigns at least 8 bits of contiguous elements in destination tensor to a thread + // This avoids a race condition when writing out subbyte types (e.g. int4b_t). + auto has_major_mode = [](auto s) { + return any_of(flatten(s), [](auto a){ return is_constant<1, decltype(a)>{}; }); + }; + static_assert(has_major_mode(stride<0>(LayoutDst{})) ^ has_major_mode(stride<1>(LayoutDst{})), + "Could not find stride-1 mode in destination layout"); + constexpr int N = shape_div(Int<8>{}, Int>{}); + auto val_layout = conditional_return(LayoutDst{}))>( + make_layout(make_shape(Int{}, Int<1>{}), GenColMajor{}), + make_layout(make_shape(Int<1>{}, Int{}), GenRowMajor{})); + + // Make a tiled copy with a simple row-major thread order and above layout + int constexpr NumThreads = 128; + auto const thr_layout = make_layout(make_shape(Int<1>{}, Int{})); + auto tiled_copy = make_tiled_copy(Copy_Atom{}, thr_layout, val_layout); + + // Assign a group of 16 rows to a threadblock; this matches the shuffle atom size for Hopper + using TileShape = Shape<_16>; + auto tiled_D = group_modes<3,rank_v>(tiled_divide(D, TileShape{})); + dim3 blocks{unsigned(size<1>(tiled_D)), 1u, unsigned(size<3>(tiled_D))}; + + reorder_tensor_kernel<<>>(S, D, tiled_copy); + CUDA_CHECK(cudaDeviceSynchronize()); +} + +// In-place version +template +void reorder_tensor( + T const* src, + LayoutSrc const& layout_src, + T * dst, + LayoutDst const& layout_dst) +{ + using namespace cute; + reorder_tensor(make_tensor(make_gmem_ptr(src), layout_src), + make_tensor(make_gmem_ptr(dst), layout_dst)); +} + +// In-place version +template +void reorder_tensor( + T * data, + LayoutSrc const& layout_src, + LayoutDst const& layout_dst) +{ + using namespace cute; + cutlass::DeviceAllocation temp(size(layout_src)); + reorder_tensor(data, layout_src, temp.get(), layout_dst); + cutlass::device_memory::copy_device_to_device(data, temp.get(), static_cast(size(layout_src))); +} + +#undef CUDA_CHECK + +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/packed_stride.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/packed_stride.hpp new file mode 100644 index 0000000000000000000000000000000000000000..811ba152ab7c6e8fafc1cebdbb3726798fd16b3c --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/packed_stride.hpp @@ -0,0 +1,570 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Utilities for packing constructing canonical CuTe stride types for 3.x mainloop params. +*/ + +#pragma once + +#include "cute/layout.hpp" +#include "cute/container/array.hpp" // cute::array +#include "cutlass/conv/convolution.h" // cutlass::conv::Operator + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Strides without batch mode + +template +CUTLASS_HOST_DEVICE +cute::Stride> +make_cute_packed_stride(cute::Stride> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<0>(s_copy) = static_cast(cute::get<1>(shape_MKL)); + return s_copy; +} + +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT> +make_cute_packed_stride(cute::Stride, IntT> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<1>(s_copy) = static_cast(cute::get<0>(shape_MKL)); + return s_copy; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Strides with batch mode + +template +CUTLASS_HOST_DEVICE +cute::Stride, int64_t> +make_cute_packed_stride(cute::Stride, int64_t> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<0>(s_copy) = static_cast(cute::get<1>(shape_MKL)); + int batch_count = cute::get<2>(shape_MKL); + if (batch_count > 1) { + cute::get<2>(s_copy) = static_cast(cute::get<0>(shape_MKL) * cute::get<1>(shape_MKL)); + } + else { + cute::get<2>(s_copy) = static_cast(0); + } + return s_copy; +} + +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, int64_t> +make_cute_packed_stride(cute::Stride, IntT, int64_t> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<1>(s_copy) = static_cast(cute::get<0>(shape_MKL)); + int batch_count = cute::get<2>(shape_MKL); + if (batch_count > 1) { + cute::get<2>(s_copy) = static_cast(cute::get<0>(shape_MKL) * cute::get<1>(shape_MKL)); + } + else { + cute::get<2>(s_copy) = static_cast(0); + } + return s_copy; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Strides with group mode + +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Int<0>> +make_cute_packed_stride(cute::Stride, cute::Int<0>> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<0>(s_copy) = static_cast(cute::get<1>(shape_MKL)); + return s_copy; +} + +template +CUTLASS_HOST_DEVICE +cute::Stride, StrideIntT, cute::Int<0>> +make_cute_packed_stride(cute::Stride, StrideIntT, cute::Int<0>> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<1>(s_copy) = static_cast(cute::get<0>(shape_MKL)); + return s_copy; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Strides for convolutions + +// Output cutlass::layout::TensorNDHWC -> rank-3 stride (InT,_1,_0) +// Note: For fprop/dgrad kernel, strides are assumed to be layout right in NZPQK/NDHWC order +// and therefore can be coalesced to just q/w. For wgrad kernel, strides are assumed to be layout +// right in KTRSC order and can be coalesced to just k. +// We enforce this condition here with asserts. +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, cute::Int<0>> s, + cute::array shape_output, + cute::array stride_output, + cutlass::conv::Operator conv_op) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + static_assert(RankT_ >= 3u); + constexpr static int RankT = static_cast(RankT_); + + assert(stride_output[RankT-1] == 1); + cute::for_each(cute::make_seq{}, [&](auto i) { + assert(stride_output[i] == shape_output[i+1] * stride_output[i+1]); + }); + + auto s_copy = s; + cute::get<0>(s_copy) = (conv_op == cutlass::conv::Operator::kWgrad) ? + stride_output[0] : + stride_output[RankT-2]; + return s_copy; +} + +// +// Activation tensor ((w, h, d, n), _1) for fprop kernel +// + +// Activation cutlass::layout::TensorNWC -> rank-2 stride ((W,N),_1) +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Int<1>> +make_cute_packed_stride( + cute::Stride, cute::Int<1>> s, + cute::array stride_nwc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + assert(stride_nwc[2] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_nwc[1]; + cute::get<0,1>(s_copy) = stride_nwc[0]; + return s_copy; +} + +// Activation cutlass::layout::TensorNHWC -> rank-2 stride ((W,H,N),_1) +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Int<1>> +make_cute_packed_stride( + cute::Stride, cute::Int<1>> s, + cute::array stride_nhwc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + assert(stride_nhwc[3] == 1); + auto s_copy = s; + cute::for_each(cute::make_seq<3>{}, [&](auto i) { + cute::get<0,i>(s_copy) = stride_nhwc[2-i]; + }); + return s_copy; +} + +// Activation cutlass::layout::TensorNDHWC -> rank-2 stride ((W,H,D,N),_1) +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Int<1>> +make_cute_packed_stride( + cute::Stride, cute::Int<1>> s, + cute::array stride_ndhwc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ndhwc[4] == 1); + auto s_copy = s; + cute::for_each(cute::make_seq<4>{}, [&](auto i) { + cute::get<0,i>(s_copy) = stride_ndhwc[3-i]; + }); + return s_copy; +} + +// +// Filter tensor (k, (_1, s, r, t)) for fprop kernel +// + +// Filter cutlass::layout::TensorNWC -> rank-2 stride (k, (_1, s)) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT>> +make_cute_packed_stride( + cute::Stride, IntT>> s, + cute::array stride_ksc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ksc[2] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_ksc[0]; + cute::get<1,1>(s_copy) = stride_ksc[1]; + return s_copy; +} + +// Filter cutlass::layout::TensorNHWC -> rank-2 stride (k, (_1, s, r)) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, IntT>> +make_cute_packed_stride( + cute::Stride, IntT, IntT>> s, + cute::array stride_krsc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_krsc[3] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_krsc[0]; + cute::for_each(cute::make_seq<2>{}, [&](auto i) { + cute::get<1,2-i>(s_copy) = stride_krsc[i+1]; + }); + return s_copy; +} + +// Filter cutlass::layout::TensorNDHWC -> rank-2 stride (k, (_1, s, r, t)) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, IntT, IntT>> +make_cute_packed_stride( + cute::Stride, IntT, IntT, IntT>> s, + cute::array stride_ktrsc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ktrsc[4] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_ktrsc[0]; + cute::for_each(cute::make_seq<3>{}, [&](auto i) { + cute::get<1,3-i>(s_copy) = stride_ktrsc[i+1]; + }); + return s_copy; +} + +// +// Activation tensor (_1, (w, h, d, n)) for wgrad kernel +// +// It is also Filter tensor ((_1), (k, s, r, t)) for dgrad kernel +// + +// Activation cutlass::layout::TensorNWC -> rank-2 stride (_1, (W,N)) in wgrad +// Filter cutlass::layout::TensorNWC -> rank-2 stride ((_1), (k, s)) in dgrad +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Stride> +make_cute_packed_stride( + cute::Stride, cute::Stride> s, + cute::array stride_nwc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_nwc[2] == 1); + auto s_copy = s; + if (ConvOp == cutlass::conv::Operator::kWgrad) { + cute::get<1,0>(s_copy) = stride_nwc[1]; + cute::get<1,1>(s_copy) = stride_nwc[0]; + } + else if (ConvOp == cutlass::conv::Operator::kDgrad) { + // stride_nwc in dgrad is ksc. + cute::get<1,0>(s_copy) = stride_nwc[0]; + cute::get<1,1>(s_copy) = stride_nwc[1]; + } + return s_copy; +} + +// Activation cutlass::layout::TensorNHWC -> rank-2 stride (_1, (W,H,N)) in wgrad +// Filter cutlass::layout::TensorNHWC -> rank-2 stride ((_1), (k, s, r)) in dgrad +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Stride> +make_cute_packed_stride( + cute::Stride, cute::Stride> s, + cute::array stride_nhwc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_nhwc[3] == 1); + auto s_copy = s; + if (ConvOp == cutlass::conv::Operator::kWgrad) { + cute::for_each(cute::make_seq<3>{}, [&](auto i) { + cute::get<1,i>(s_copy) = stride_nhwc[2-i]; + }); + } + else if (ConvOp == cutlass::conv::Operator::kDgrad) { + // stride_nhwc in dgrad is krsc. + cute::get<1,0>(s_copy) = stride_nhwc[0]; + cute::for_each(cute::make_seq<2>{}, [&](auto i) { + cute::get<1,2-i>(s_copy) = stride_nhwc[i+1]; + }); + } + return s_copy; +} + +// Activation cutlass::layout::TensorNDHWC -> rank-2 stride (_1, (W,H,D,N)) in wgrad +// Filter cutlass::layout::TensorNDHWC -> rank-2 stride ((_1), (k, s, r, t)) in dgrad +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Stride> +make_cute_packed_stride( + cute::Stride, cute::Stride> s, + cute::array stride_ndhwc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ndhwc[4] == 1); + auto s_copy = s; + if (ConvOp == cutlass::conv::Operator::kWgrad) { + cute::for_each(cute::make_seq<4>{}, [&](auto i) { + cute::get<1,i>(s_copy) = stride_ndhwc[3-i]; + }); + } + else if (ConvOp == cutlass::conv::Operator::kDgrad) { + // stride_ndhwc in dgrad is ktrsc. + cute::get<1,0>(s_copy) = stride_ndhwc[0]; + cute::for_each(cute::make_seq<3>{}, [&](auto i) { + cute::get<1,3-i>(s_copy) = stride_ndhwc[i+1]; + }); + } + return s_copy; +} + +// +// NZPQ tensor (_1, nzpq) for wgrad kernel +// + +// cutlass::layout::TensorNWC -> rank-2 stride (_1, nzpq) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT> +make_cute_packed_stride( + cute::Stride, IntT> s, + cute::array stride_nqk, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_nqk[2] == 1); + auto s_copy = s; + cute::get<1>(s_copy) = stride_nqk[1]; + return s_copy; +} + +// cutlass::layout::TensorNHWC -> rank-2 stride (_1, nzpq) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT> +make_cute_packed_stride( + cute::Stride, IntT> s, + cute::array stride_npqk, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_npqk[3] == 1); + auto s_copy = s; + cute::get<1>(s_copy) = stride_npqk[2]; + return s_copy; +} + +// cutlass::layout::TensorNDHWC -> rank-2 stride (_1, nzpq) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT> +make_cute_packed_stride( + cute::Stride, IntT> s, + cute::array stride_nzpqk, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_nzpqk[4] == 1); + auto s_copy = s; + cute::get<1>(s_copy) = stride_nzpqk[3]; + return s_copy; +} + + + +// +// Wgrad output tensor (k, (_1, s, r, t), _0) +// + +// Filter cutlass::layout::TensorKCS -> rank-3 stride (k, (_1, s), _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT>, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT>, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_ksc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ksc[2] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_ksc[0]; + cute::get<1,1>(s_copy) = stride_ksc[1]; + return s_copy; +} + +// Filter cutlass::layout::TensorKCSR -> rank-3 stride (k, (_1, s, r), _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, IntT>, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT, IntT>, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_krsc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_krsc[3] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_krsc[0]; + cute::for_each(cute::make_seq<2>{}, [&](auto i) { + cute::get<1,2-i>(s_copy) = stride_krsc[i+1]; + }); + return s_copy; +} + +// Filter cutlass::layout::TensorKCSRT -> rank-3 stride (k, (_1, s, r, t), _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, IntT, IntT>, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT, IntT, IntT>, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_ktrsc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ktrsc[4] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_ktrsc[0]; + cute::for_each(cute::make_seq<3>{}, [&](auto i) { + cute::get<1,3-i>(s_copy) = stride_ktrsc[i+1]; + }); + return s_copy; +} + + +// +// Wgrad output tensor ((_1, s, r, t), k, _0) +// + +// Filter cutlass::layout::TensorCSK -> rank-3 stride ((_1, s), k, _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT>, IntT, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT>, IntT, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_ksc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ksc[2] == 1); + auto s_copy = s; + cute::get<1,0>(s_copy) = stride_ksc[0]; + cute::get<0,1>(s_copy) = stride_ksc[1]; + return s_copy; +} + +// Filter cutlass::layout::TensorCSRK -> rank-3 stride ((_1, s, r), k, _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, IntT>, IntT, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT, IntT>, IntT, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_krsc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_krsc[3] == 1); + auto s_copy = s; + cute::get<1,0>(s_copy) = stride_krsc[0]; + cute::for_each(cute::make_seq<2>{}, [&](auto i) { + cute::get<0,2-i>(s_copy) = stride_krsc[i+1]; + }); + return s_copy; +} + +// Filter cutlass::layout::TensorCSRTK -> rank-3 stride ((_1, s, r, t), k, _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, IntT, IntT>, IntT, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT, IntT, IntT>, IntT, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_ktrsc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ktrsc[4] == 1); + auto s_copy = s; + cute::get<1,0>(s_copy) = stride_ktrsc[0]; + cute::for_each(cute::make_seq<3>{}, [&](auto i) { + cute::get<0,3-i>(s_copy) = stride_ktrsc[i+1]; + }); + return s_copy; +} +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/print_error.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/print_error.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c38ad3f710c18e5be1bb7e01dc66d7efcd2646d9 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/print_error.hpp @@ -0,0 +1,341 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include + +#include + +// The computed infinity norm does not include +// any NaN column absolute-value sums. +struct matrix_inf_norm_result { + // Accumulate errors in double, as this is generally + // the highest precision that the examples use. + double inf_norm = 0.0; + bool found_nan = false; +}; + +// In theory, cute::Tensor, T> could be treated as a view type, +// and thus passed by value (as std::span or std::string_view would be). +// However, generic cute::Tensor are more like containers +// and thus are best passed by reference or const reference. +template +matrix_inf_norm_result +matrix_inf_norm(cute::Tensor const& host_matrix) +{ + using error_type = decltype(std::declval().inf_norm); + using element_type = typename EngineType::value_type; + + error_type inf_norm = 0.0; + bool found_nan = false; + + // Computing the infinity norm requires that we be able + // to treat the input as a matrix, with rows and columns. + const int64_t num_rows = cute::size<0>(host_matrix); + const int64_t num_cols = cute::size<1>(host_matrix); + + auto abs_fn = [] (element_type A_ij) { + if constexpr (not std::is_unsigned_v) { + using std::abs; + return abs(A_ij); + } + else { + return A_ij; + } + }; + + for (int64_t i = 0; i < num_rows; ++i) { + error_type row_abs_sum = 0.0; + for(int64_t j = 0; j < num_cols; ++j) { + row_abs_sum += abs_fn(host_matrix(i, j)); + } + if (std::isnan(row_abs_sum)) { + found_nan = true; + } + else { + inf_norm = row_abs_sum > inf_norm ? row_abs_sum : inf_norm; + } + } + + return {inf_norm, found_nan}; +} + +// Infinity norm of (X - Y). +template +matrix_inf_norm_result +matrix_diff_inf_norm(cute::Tensor const& X, + cute::Tensor const& Y) +{ + using error_type = decltype(std::declval().inf_norm); + using element_type = typename EngineType::value_type; + + auto abs_fn = [] (element_type A_ij) { + if constexpr (not std::is_unsigned_v) { + using std::abs; + return abs(A_ij); + } + else { + return A_ij; + } + }; + + assert(cute::size<0>(X) == cute::size<0>(Y)); + assert(cute::size<1>(X) == cute::size<1>(Y)); + + // Computing the infinity norm requires that we be able + // to treat the input as a matrix, with rows and columns. + const int64_t num_rows = cute::size<0>(X); + const int64_t num_cols = cute::size<1>(X); + + error_type inf_norm = 0.0; + bool found_nan = false; + + for (int64_t i = 0; i < num_rows; ++i) { + error_type row_abs_sum = 0.0; + for (int64_t j = 0; j < num_cols; ++j) { + row_abs_sum += error_type(abs_fn(element_type(X(i,j)) - + element_type(Y(i,j)))); + } + if (std::isnan(row_abs_sum)) { + found_nan = true; + } + else { + inf_norm = row_abs_sum > inf_norm ? row_abs_sum : inf_norm; + } + } + + return {inf_norm, found_nan}; +} + +template +auto +print_matrix_multiply_mollified_relative_error( + char const A_value_type_name[], + cute::Tensor const& A, + char const B_value_type_name[], + cute::Tensor const& B, + char const C_value_type_name[], + cute::Tensor const& C, + cute::Tensor const& C_ref) +{ + const auto [A_norm, A_has_nan] = matrix_inf_norm(A); + const auto [B_norm, B_has_nan] = matrix_inf_norm(B); + const auto [C_norm, C_has_nan] = matrix_inf_norm(C_ref); + const auto [diff_norm, diff_has_nan] = matrix_diff_inf_norm(C, C_ref); + + const auto A_norm_times_B_norm = A_norm * B_norm; + const auto relative_error = A_norm_times_B_norm == 0.0 ? + diff_norm : (diff_norm / A_norm_times_B_norm); + + // For expected error bounds, please refer to the LAPACK Users' Guide, + // in particular https://netlib.org/lapack/lug/node108.html . + // Printing the infinity norm of C is a way to check + // that both the function being tested (C) + // and the reference implementation (C_ref) + // don't just do nothing (or fill with zeros). + using std::cout; + using cute::shape; + cout << "Matrix A: " << shape<0>(A) << "x" << shape<1>(A) << " of " << A_value_type_name << '\n' + << "Matrix B: " << shape<0>(B) << "x" << shape<1>(B) << " of " << B_value_type_name << '\n' + << "Matrix C: " << shape<0>(C) << "x" << shape<1>(C) << " of " << C_value_type_name << '\n' + << std::scientific + << "Infinity norm of A: " << A_norm << '\n' + << "Infinity norm of B: " << B_norm << '\n' + << "Infinity norm of C: " << C_norm << '\n' + << "Infinity norm of (C - C_ref): " << diff_norm << '\n'; + + if(A_norm_times_B_norm == 0.0) { + cout << "Mollified relative error: " << relative_error << '\n'; + } else { + cout << "Relative error: " << relative_error << '\n'; + } + + if (A_has_nan || B_has_nan || C_has_nan || diff_has_nan) { + cout << "Did we encounter NaN in A? " << (A_has_nan ? "yes" : "no") << '\n' + << "Did we encounter NaN in B? " << (B_has_nan ? "yes" : "no") << '\n' + << "Did we encounter NaN in C? " << (C_has_nan ? "yes" : "no") << '\n' + << "Did we encounter NaN in (C - C_ref)? " << (diff_has_nan ? "yes" : "no") << '\n'; + } + return relative_error; +} + +template +auto +print_matrix_multiply_mollified_relative_error( + const char value_type_name[], + const cute::Tensor& A, + const cute::Tensor& B, + const cute::Tensor& C_computed, + const cute::Tensor& C_expected) +{ + return print_matrix_multiply_mollified_relative_error(value_type_name, A, value_type_name, B, + value_type_name, C_computed, C_expected); +} + +// Take a CUTLASS HostTensor (or the like) as input, +// and return a const CuTe Tensor. +// This is useful for use with the above error printing functions. +// This implicitly "transposes" if the layout is RowMajor. +// Note that the HostTensor must be captured by nonconst reference +// in order for X.host_ref().data() to compile. +// (CUTLASS is a bit more container-y than CuTe.) +template +auto host_matrix_to_const_cute_tensor(CutlassHostTensorType& X) +{ + // The tensors were created with post-transposed extents. + const auto extents = X.extent(); + const auto shape = cute::Shape{extents[0], extents[1]}; + // Both RowMajor and ColumnMajor only store one stride. + const int LDX = X.stride(0); + const auto strides = [&]() { + using input_layout_type = typename std::decay_t::Layout; + if constexpr (std::is_same_v) { + return cute::Stride{1, LDX}; + } + else { + static_assert(std::is_same_v); + return cute::Stride{LDX, 1}; + } + }(); + const auto layout = cute::make_layout(shape, strides); + auto X_data = X.host_ref().data(); + auto X_data_const = const_cast >(X_data); + return cute::make_tensor(X_data_const, layout); +}; + + +// Returns EXIT_SUCCESS if the 2-norm relative error is exactly zero, else returns EXIT_FAILURE. +// This makes the return value suitable as the return value of main(). +template +int +print_relative_error( + std::size_t n, + T1 const& data, + T2 const& reference, + bool print_verbose = false, + bool print_error = true, + double error_margin = 0.00001) { + using std::abs; using std::sqrt; + + // Use either double or complex for error computation + using value_type = cute::remove_cvref_t; + using error_type = std::conditional_t::value, + cute::complex, + double>; + + if (print_verbose) { + std::cout << "Idx:\t"<< "Val\t" << "RefVal\t" << "RelError" << std::endl; + } + + double eps = 1e-200; + + double tot_error_sq = 0; + double tot_norm_sq = 0; + double tot_ind_rel_err = 0; + double max_ind_rel_err = 0; + double max_diff = 0; + for (std::size_t i = 0; i < n; ++i) { + error_type val = data[i]; + error_type ref = reference[i]; + + double aref = abs(ref); + double diff = abs(ref - val); + double rel_error = diff / (aref + eps); + + // Individual relative error + tot_ind_rel_err += rel_error; + + // Maximum relative error + max_ind_rel_err = std::max(max_ind_rel_err, rel_error); + + // Maximum delta in value error + max_diff = std::max(max_diff, diff); + + // Total relative error + tot_error_sq += diff * diff; + tot_norm_sq += aref * aref; + + if (print_verbose) { + std::cout << i << ":\t" << val << "\t" << ref << "\t" << rel_error << std::endl; + } + } + + double ave_rel_err = tot_ind_rel_err / double(n); + if (print_error) { + printf("Average relative error: %.3e\n", ave_rel_err); + } + + if (print_error) { + printf("Maximum relative error: %.3e\n", max_ind_rel_err); + } + + if (print_error) { + printf("Maximum difference : %.3e\n", max_diff); + } + + double tot_rel_err = sqrt(tot_error_sq/(tot_norm_sq+eps)); + if (print_error) { + printf("Vector relative error: %.3e\n", tot_rel_err); + } + + printf("Vector reference norm: %.3e\n", sqrt(tot_norm_sq)); + + return (tot_rel_err <= error_margin) ? EXIT_SUCCESS : EXIT_FAILURE; +} + +// Overload for cute::Tensor<> +template +int +print_relative_error( + cute::Tensor data, + cute::Tensor reference, + bool print_verbose = false, + bool print_error = true, + double error_margin = 0.00001) { + assert(size(data) == size(reference)); + return print_relative_error(static_cast(size(data)), + data, reference, + print_verbose, print_error, error_margin); +} diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/inner_product.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/inner_product.h new file mode 100644 index 0000000000000000000000000000000000000000..8167c91bf2330d160a78ba210449357b395964ca --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/inner_product.h @@ -0,0 +1,135 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for GEMM in host-side code. +*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" + +namespace cutlass { +namespace reference { +namespace detail { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Template function to compute an inner product. +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate with a + // host-only type +template +CUTLASS_HOST_DEVICE +Ctype inner_product(Atype a, Btype b, Ctype c) { + return Ctype(a) * Ctype(b) + c; +} + +/// Specialization for matrix multiplication with binary operands +template <> +CUTLASS_HOST_DEVICE +int inner_product, Array, int>( + Array a, + Array b, + int c) { + + int accum = 0; + for (int bit = 0; bit < 32; bit++) { + accum += a[bit] ^ b[bit]; + } + return accum + c; +} + +/* +/// Specialization for matrix multiplication with signed 4-bit integer operands +template <> +CUTLASS_HOST_DEVICE +int inner_product, Array, int>( + Array a, + Array b, + int c) { + + int accum = 0; + for (int k = 0; k < 8; k++) { + accum += a[k] * b[k]; + } + return accum + c; +} + +/// Specialization for matrix multiplication with unsigned 4-bit integer operands +template <> +CUTLASS_HOST_DEVICE +int inner_product, Array, int>( + Array a, + Array b, + int c) { + + int accum = 0; + for (int k = 0; k < 8; k++) { + accum += a[k] * b[k]; + } + return accum + c; +} +*/ + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Cast { + // Default behavior: convert to the destination type +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type + CUTLASS_HOST_DEVICE + static DstType apply(SrcType src) { return static_cast(src); }; +}; + +template <> +struct Cast { + CUTLASS_HOST_DEVICE + static int8_t apply(float src) { + // Clamp to the range of signed 8-bit integers. + return static_cast(fmaxf(-128.f, fminf(127.f, src))); + }; +}; + +template <> +struct Cast { + CUTLASS_HOST_DEVICE + static uint8_t apply(float src) { + // Clamp to the range of signed 8-bit integers. + return static_cast(fmaxf(0.f, fminf(255.f, src))); + }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace detail +} // namespace reference +} // namespace cutlass + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/linear_to_coordinate.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/linear_to_coordinate.h new file mode 100644 index 0000000000000000000000000000000000000000..652d622586cb202ecfe69ac892978b649b5d1be7 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/linear_to_coordinate.h @@ -0,0 +1,94 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for GEMM in host-side code. +*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/coord.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace detail { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct LinearToCoordinateHelper { + + CUTLASS_HOST_DEVICE + void operator()(Coord &coord, int64_t idx, Coord const &extent) const { + + int64_t prod = 1; + + CUTLASS_PRAGMA_UNROLL + for (int i = Rank - Index; i < Rank; ++i) { + prod *= int64_t(extent[i]); + } + + coord[Rank - Index - 1] = int(idx / prod); + + int64_t residual = idx % prod; + LinearToCoordinateHelper()(coord, residual, extent); + } +}; + +template +struct LinearToCoordinateHelper { + + CUTLASS_HOST_DEVICE + void operator()(Coord &coord, int64_t idx, Coord const &) const { + coord[Rank - 1] = int(idx); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct LinearToCoordinate { + + CUTLASS_HOST_DEVICE + void operator()(Coord &coord, int64_t idx, Coord const &extent) const { + LinearToCoordinateHelper()(coord, idx, extent); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace detail +} // namespace reference +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/convolution.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/convolution.h new file mode 100644 index 0000000000000000000000000000000000000000..7c6f803c47f5c407cf058d40bc8274a448a36dc4 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/convolution.h @@ -0,0 +1,1549 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Reference implementation for convolution in device-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/functional.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" + +namespace cutlass { +namespace reference { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Conv2d device reference kernel +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Conv2d Fprop kernel - y = fprop(x, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension + int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension + int kCtaShapeM = 16, // shape of a threadblock in units of threads + int kCtaShapeN = 8 // shape of a threadblock in units of threads +> +__global__ void Conv2dFprop( + conv::Conv2dProblemSize problem_size, + TensorRef tensor_x, + TensorRef tensor_w, + TensorRef tensor_y_in, + TensorRef tensor_y_out, + ElementCompute alpha, + ElementCompute beta + ) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + ElementAccumulator element_A[kThreadM]; + ElementAccumulator element_B[kThreadN]; + ElementAccumulator accum[kThreadM][kThreadN]; + + int64_t npq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; + int k_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; + + int thread_n[kThreadM]; + int thread_p[kThreadM]; + int thread_q[kThreadM]; + + // Compute N, P, Q coordinates for each row of a thread's tile + int64_t PQ = int64_t(problem_size.P) * problem_size.Q; + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + int64_t npq = npq_start + m; + + thread_n[m] = int(npq / PQ); + + int64_t residual = npq % PQ; + thread_p[m] = int(residual / problem_size.Q); + thread_q[m] = int(residual % problem_size.Q); + } + + // Clear accumulators + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = ElementAccumulator(); + } + } + + int c_per_group = problem_size.C / problem_size.groups; + int k_per_group = problem_size.K / problem_size.groups; + + // Compute convolution + for (int R = 0; R < problem_size.R; ++R) { + for (int S = 0; S < problem_size.S; ++S) { + for (int C = 0; C < problem_size.C; ++C) { + + // Get group id of currnet channel + int c_group_idx = C / c_per_group; + + // Load from activations tensor + int filter_r = R; + int filter_s = S; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_r = problem_size.R - 1 - R; + filter_s = problem_size.S - 1 - S; + } + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + int h = thread_p[m] * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; + int w = thread_q[m] * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; + + if (thread_n[m] < problem_size.N && h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W) { + element_A[m] = ElementAccumulator(tensor_x.at({thread_n[m], h, w, C})); + } + else { + element_A[m] = ElementAccumulator(); + } + } + + // Load from filters tensor + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_k = k_start + n; + int k_group_idx = thread_k / k_per_group; + + if (thread_k < problem_size.K && k_group_idx == c_group_idx) { + element_B[n] = ElementAccumulator(tensor_w.at({thread_k, R, S, C % c_per_group})); + } + else { + element_B[n] = ElementAccumulator(); + } + } + + // Accumulate matrix product + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); + } + } + } + } + } + + // Write out the results + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + if (thread_n[m] < problem_size.N && thread_p[m] < problem_size.P && thread_q[m] < problem_size.Q) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_k = k_start + n; + if (thread_k < problem_size.K) { + + ElementCompute c_ref = ElementCompute(); + if (beta != ElementCompute()) { + c_ref = ElementCompute(tensor_y_in.at({thread_n[m], thread_p[m], thread_q[m], thread_k})); + } + + tensor_y_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op( + alpha * ElementCompute(accum[m][n]) + beta * c_ref); + } + } + } + } +} + +// Conv3d Fprop kernel - y = fprop(x, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension + int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension + int kCtaShapeM = 16, // shape of a threadblock in units of threads + int kCtaShapeN = 8 // shape of a threadblock in units of threads +> +__global__ void Conv3dFprop( + conv::Conv3dProblemSize problem_size, + TensorRef tensor_x, + TensorRef tensor_w, + TensorRef tensor_y_in, + TensorRef tensor_y_out, + ElementCompute alpha, + ElementCompute beta + ) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + ElementAccumulator element_A[kThreadM]; + ElementAccumulator element_B[kThreadN]; + ElementAccumulator accum[kThreadM][kThreadN]; + + int64_t nzpq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; + int k_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; + + int thread_n[kThreadM]; + int thread_z[kThreadM]; + int thread_p[kThreadM]; + int thread_q[kThreadM]; + + // Compute N, Z, P, Q coordinates for each row of a thread's tile + int64_t PQ = int64_t(problem_size.P) * problem_size.Q; + int64_t ZPQ = PQ * problem_size.Z; + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + int64_t nzpq = nzpq_start + m; + + thread_n[m] = int(nzpq / ZPQ); + + int64_t residual = nzpq % ZPQ; + thread_z[m] = int(residual / PQ); + + residual = residual % PQ; + thread_p[m] = int(residual / problem_size.Q); + thread_q[m] = int(residual % problem_size.Q); + } + + // Clear accumulators + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = ElementAccumulator(); + } + } + + // Compute convolution + for (int T = 0; T < problem_size.T; ++T) { + for (int R = 0; R < problem_size.R; ++R) { + for (int S = 0; S < problem_size.S; ++S) { + for (int C = 0; C < problem_size.C; ++C) { + + // Load from activations tensor + int filter_t = T; + int filter_r = R; + int filter_s = S; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_t = problem_size.T - 1 - T; + filter_r = problem_size.R - 1 - R; + filter_s = problem_size.S - 1 - S; + } + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + int d = thread_z[m] * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d; + int h = thread_p[m] * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; + int w = thread_q[m] * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; + + if (thread_n[m] < problem_size.N && + d >= 0 && d < problem_size.D && + h >= 0 && h < problem_size.H && + w >= 0 && w < problem_size.W) { + + element_A[m] = ElementAccumulator(tensor_x.at({thread_n[m], d, h, w, C})); + } + else { + element_A[m] = ElementAccumulator(); + } + } + + // Load from filters tensor + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_k = k_start + n; + + if (thread_k < problem_size.K) { + element_B[n] = ElementAccumulator(tensor_w.at({thread_k, T, R, S, C})); + } + else { + element_B[n] = ElementAccumulator(); + } + } + + // Accumulate matrix product + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); + } + } + + } // for (C) + } // for (S) + } // for (R) + } // for (T) + + // Write out the results + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + if (thread_n[m] < problem_size.N && + thread_z[m] < problem_size.Z && + thread_p[m] < problem_size.P && + thread_q[m] < problem_size.Q) { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_k = k_start + n; + if (thread_k < problem_size.K) { + + ElementCompute c_ref = ElementCompute(); + if (beta != ElementCompute()) { + c_ref = ElementCompute(tensor_y_in.at({thread_n[m], thread_z[m], thread_p[m], thread_q[m], thread_k})); + } + + tensor_y_out.at({thread_n[m], thread_z[m], thread_p[m], thread_q[m], thread_k}) = convert_op( + alpha * ElementCompute(accum[m][n]) + beta * c_ref); + } + } // for (n) + + } + } // for (m) +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Conv2d dgrad kernel - dx = dgrad(dy, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension + int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension + int kCtaShapeM = 16, // shape of a threadblock in units of threads + int kCtaShapeN = 8 // shape of a threadblock in units of threads +> +__global__ void Conv2dDgrad( + conv::Conv2dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_w, + TensorRef tensor_dx_in, + TensorRef tensor_dx_out, + ElementCompute alpha, + ElementCompute beta + ) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + ElementAccumulator element_A[kThreadM]; + ElementAccumulator element_B[kThreadN]; + ElementAccumulator accum[kThreadM][kThreadN]; + + int64_t nhw_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; + int c_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; + + int thread_n[kThreadM]; + int thread_h[kThreadM]; + int thread_w[kThreadM]; + + // Compute N, H, W coordinates for each row of a thread's tile + int64_t HW = int64_t(problem_size.H) * problem_size.W; + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + int64_t nhw = nhw_start + m; + + thread_n[m] = int(nhw / HW); + + int64_t residual = nhw % HW; + thread_h[m] = int(residual / problem_size.W); + thread_w[m] = int(residual % problem_size.W); + } + + // Clear accumulators + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = ElementAccumulator(); + } + } + + // Compute convolution + for (int R = 0; R < problem_size.R; ++R) { + for (int S = 0; S < problem_size.S; ++S) { + for (int K = 0; K < problem_size.K; ++K) { + + // Load from activations tensor + int filter_r = R; + int filter_s = S; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_r = problem_size.R - 1 - R; + filter_s = problem_size.S - 1 - S; + } + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + int p = thread_h[m] + problem_size.pad_h - filter_r * problem_size.dilation_h; + int q = thread_w[m] + problem_size.pad_w - filter_s * problem_size.dilation_w; + + element_A[m] = ElementAccumulator(); + + if (p >= 0 && !(p % problem_size.stride_h) && q >= 0 && !(q % problem_size.stride_w)) { + + p = p / problem_size.stride_h; + q = q / problem_size.stride_w; + + if (thread_n[m] < problem_size.N && p < problem_size.P && q < problem_size.Q) { + element_A[m] = ElementAccumulator(tensor_dy.at({thread_n[m], p, q, K})); + } + } + } + + // Load from filters tensor + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_c = c_start + n; + + if (thread_c < problem_size.C) { + element_B[n] = ElementAccumulator(tensor_w.at({K, R, S, thread_c})); + } + else { + element_B[n] = ElementAccumulator(); + } + } + + // Accumulate matrix product + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); + } + } + } + } + } + + // Write out the results + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + if (thread_n[m] < problem_size.N && thread_h[m] < problem_size.H && thread_w[m] < problem_size.W) { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_c = c_start + n; + if (thread_c < problem_size.C) { + + ElementCompute c_ref = ElementCompute(); + if (beta != ElementCompute()) { + c_ref = ElementCompute(tensor_dx_in.at({thread_n[m], thread_h[m], thread_w[m], thread_c})); + } + + tensor_dx_out.at({thread_n[m], thread_h[m], thread_w[m], thread_c}) = convert_op( + alpha * ElementCompute(accum[m][n]) + beta * c_ref); + } + } + } + } +} + +// Conv3d dgrad kernel - dx = dgrad(dy, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension + int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension + int kCtaShapeM = 16, // shape of a threadblock in units of threads + int kCtaShapeN = 8 // shape of a threadblock in units of threads +> +__global__ void Conv3dDgrad( + conv::Conv3dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_w, + TensorRef tensor_dx_in, + TensorRef tensor_dx_out, + ElementCompute alpha, + ElementCompute beta + ) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + ElementAccumulator element_A[kThreadM]; + ElementAccumulator element_B[kThreadN]; + ElementAccumulator accum[kThreadM][kThreadN]; + + int64_t ndhw_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; + int c_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; + + int thread_n[kThreadM]; + int thread_d[kThreadM]; + int thread_h[kThreadM]; + int thread_w[kThreadM]; + + // Compute N, H, W coordinates for each row of a thread's tile + int64_t HW = int64_t(problem_size.H) * problem_size.W; + int64_t DHW = HW * problem_size.D; + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + int64_t ndhw = ndhw_start + m; + + thread_n[m] = int(ndhw / DHW); + + int64_t residual = ndhw % DHW; + thread_d[m] = int(residual / HW); + + residual = residual % HW; + thread_h[m] = int(residual / problem_size.W); + thread_w[m] = int(residual % problem_size.W); + } + + // Clear accumulators + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = ElementAccumulator(); + } + } + + // Compute convolution + for (int T = 0; T < problem_size.T; ++T) { + for (int R = 0; R < problem_size.R; ++R) { + for (int S = 0; S < problem_size.S; ++S) { + for (int K = 0; K < problem_size.K; ++K) { + + // Load from activations tensor + int filter_t = T; + int filter_r = R; + int filter_s = S; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_t = problem_size.T - 1 - T; + filter_r = problem_size.R - 1 - R; + filter_s = problem_size.S - 1 - S; + } + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + int z = thread_d[m] + problem_size.pad_d - filter_t * problem_size.dilation_d; + int p = thread_h[m] + problem_size.pad_h - filter_r * problem_size.dilation_h; + int q = thread_w[m] + problem_size.pad_w - filter_s * problem_size.dilation_w; + + element_A[m] = ElementAccumulator(); + + if (z >= 0 && !(z % problem_size.stride_d) && + p >= 0 && !(p % problem_size.stride_h) && + q >= 0 && !(q % problem_size.stride_w)) { + + z = z / problem_size.stride_d; + p = p / problem_size.stride_h; + q = q / problem_size.stride_w; + + if (thread_n[m] < problem_size.N && z < problem_size.Z && p < problem_size.P && q < problem_size.Q) { + element_A[m] = ElementAccumulator(tensor_dy.at({thread_n[m], z, p, q, K})); + } + } + } + + // Load from filters tensor + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_c = c_start + n; + + if (thread_c < problem_size.C) { + element_B[n] = ElementAccumulator(tensor_w.at({K, T, R, S, thread_c})); + } + else { + element_B[n] = ElementAccumulator(); + } + } + + // Accumulate matrix product + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); + } + } + + } // for (C) + } // for (S) + } // for (R) + } // for (T) + + // Write out the results + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + if (thread_n[m] < problem_size.N && + thread_d[m] < problem_size.D && + thread_h[m] < problem_size.H && + thread_w[m] < problem_size.W) { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_c = c_start + n; + if (thread_c < problem_size.C) { + + ElementCompute c_ref = ElementCompute(); + if (beta != ElementCompute()) { + c_ref = ElementCompute(tensor_dx_in.at({thread_n[m], thread_d[m], thread_h[m], thread_w[m], thread_c})); + } + + tensor_dx_out.at({thread_n[m], thread_d[m], thread_h[m], thread_w[m], thread_c}) = convert_op( + alpha * ElementCompute(accum[m][n]) + beta * c_ref); + } + } + } + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Conv2d wgrad kernel - dw = wgrad(dy, x) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension + int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension + int kCtaShapeM = 8, // shape of a threadblock in units of threads + int kCtaShapeN = 16 // shape of a threadblock in units of threads +> +__global__ void Conv2dWgrad( + conv::Conv2dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_x, + TensorRef tensor_dw_in, + TensorRef tensor_dw_out, + ElementCompute alpha, + ElementCompute beta + ) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + ElementAccumulator element_A[kThreadM]; + ElementAccumulator element_B[kThreadN]; + ElementAccumulator accum[kThreadM][kThreadN]; + + int k_start = blockIdx.x * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; + int64_t rsc_start = int64_t(blockIdx.y) * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; + + int thread_r[kThreadN]; + int thread_s[kThreadN]; + int thread_c[kThreadN]; + + // Compute R, S, C coordinates for each row of a thread's tile + int64_t SC = int64_t(problem_size.S) * problem_size.C; + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + + int64_t rsc = rsc_start + n; + int64_t residual = rsc % SC; + + thread_r[n] = int(rsc / SC); + thread_s[n] = int(residual / problem_size.C); + thread_c[n] = int(residual % problem_size.C); + } + + // Clear accumulators + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = ElementAccumulator(); + } + } + + // Compute convolution + for (int N = 0; N < problem_size.N; ++N) { + for (int P = 0; P < problem_size.P; ++P) { + for (int Q = 0; Q < problem_size.Q; ++Q) { + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + int thread_k = k_start + m; + + element_A[m] = ElementAccumulator(); + + if (thread_k < problem_size.K) { + element_A[m] = ElementAccumulator(tensor_dy.at({N, P, Q, thread_k})); + } + } + + // Load from filters tensor + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + + // Load from activations tensor + int filter_r = thread_r[n]; + int filter_s = thread_s[n]; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_r = problem_size.R - 1 - filter_r; + filter_s = problem_size.S - 1 - filter_s; + } + + int h = P * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; + int w = Q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; + + element_B[n] = ElementAccumulator(); + + if (h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W && thread_c[n] < problem_size.C) { + element_B[n] = ElementAccumulator(tensor_x.at({N, h, w, thread_c[n]})); + } + } + + // Accumulate matrix product + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); + } + } + } + } + } + + // Write out the results + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + int thread_k = k_start + m; + + if (thread_k < problem_size.K) { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + + if (thread_r[n] < problem_size.R && thread_s[n] < problem_size.S && thread_c[n] < problem_size.C) { + + ElementCompute c_ref = ElementCompute(); + + if (beta != ElementCompute()) { + c_ref = ElementCompute(tensor_dw_in.at({thread_k, thread_r[n], thread_s[n], thread_c[n]})); + } + + tensor_dw_out.at({thread_k, thread_r[n], thread_s[n], thread_c[n]}) = convert_op( + alpha * ElementCompute(accum[m][n]) + beta * c_ref); + } + } + } + } +} + +// Conv3d wgrad kernel - dw = wgrad(dy, x) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension + int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension + int kCtaShapeM = 8, // shape of a threadblock in units of threads + int kCtaShapeN = 16 // shape of a threadblock in units of threads +> +__global__ void Conv3dWgrad( + conv::Conv3dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_x, + TensorRef tensor_dw_in, + TensorRef tensor_dw_out, + ElementCompute alpha, + ElementCompute beta + ) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + ElementAccumulator element_A[kThreadM]; + ElementAccumulator element_B[kThreadN]; + ElementAccumulator accum[kThreadM][kThreadN]; + + int k_start = blockIdx.x * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; + int64_t trsc_start = int64_t(blockIdx.y) * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; + + int thread_t[kThreadN]; + int thread_r[kThreadN]; + int thread_s[kThreadN]; + int thread_c[kThreadN]; + + // Compute R, S, C coordinates for each row of a thread's tile + int64_t SC = int64_t(problem_size.S) * problem_size.C; + int64_t RSC = SC * problem_size.R; + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + + int64_t trsc = trsc_start + n; + + thread_t[n] = int(trsc / RSC); + + int64_t residual = trsc % RSC; + thread_r[n] = int(residual / SC); + + residual = residual % SC; + thread_s[n] = int(residual / problem_size.C); + thread_c[n] = int(residual % problem_size.C); + } + + // Clear accumulators + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = ElementAccumulator(); + } + } + + // Compute convolution + for (int N = 0; N < problem_size.N; ++N) { + for (int Z = 0; Z < problem_size.Z; ++Z) { + for (int P = 0; P < problem_size.P; ++P) { + for (int Q = 0; Q < problem_size.Q; ++Q) { + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + int thread_k = k_start + m; + + element_A[m] = ElementAccumulator(); + + if (thread_k < problem_size.K) { + element_A[m] = ElementAccumulator(tensor_dy.at({N, Z, P, Q, thread_k})); + } + } + + // Load from filters tensor + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + + // Load from activations tensor + int filter_t = thread_t[n]; + int filter_r = thread_r[n]; + int filter_s = thread_s[n]; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_t = problem_size.T - 1 - filter_t; + filter_r = problem_size.R - 1 - filter_r; + filter_s = problem_size.S - 1 - filter_s; + } + + int d = Z * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d; + int h = P * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; + int w = Q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; + + element_B[n] = ElementAccumulator(); + + if (d >= 0 && d < problem_size.D && + h >= 0 && h < problem_size.H && + w >= 0 && w < problem_size.W && + thread_c[n] < problem_size.C) { + + element_B[n] = ElementAccumulator(tensor_x.at({N, d, h, w, thread_c[n]})); + } + } + + // Accumulate matrix product + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); + } + } + + } // for (Q) + } // for (P) + } // for (Z) + } // for (N) + + // Write out the results + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + int thread_k = k_start + m; + + if (thread_k < problem_size.K) { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + + if (thread_t[n] < problem_size.T && + thread_r[n] < problem_size.R && + thread_s[n] < problem_size.S && + thread_c[n] < problem_size.C) { + + ElementCompute c_ref = ElementCompute(); + + if (beta != ElementCompute()) { + c_ref = ElementCompute(tensor_dw_in.at({thread_k, thread_t[n], thread_r[n], thread_s[n], thread_c[n]})); + } + + tensor_dw_out.at({thread_k, thread_t[n], thread_r[n], thread_s[n], thread_c[n]}) = convert_op( + alpha * ElementCompute(accum[m][n]) + beta * c_ref); + } + } + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Conv2d Fprop dispatcher - y = fprop(x, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv2dFprop( + conv::Conv2dProblemSize problem_size, + TensorRef tensor_x, + TensorRef tensor_w, + TensorRef tensor_y_in, + TensorRef tensor_y_out, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr) { + + // + // Blocking factors improve performance of reference implementation + // + + int const kThreadM = 4; // shape of a thread's tile in the GEMM M dimension + int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension + int const kCtaShapeM = 16; // shape of a threadblock in units of threads + int const kCtaShapeN = 8; // shape of a threadblock in units of threads + + int64_t npq = int64_t(problem_size.N) * problem_size.P * problem_size.Q; + int64_t blocks_m = (npq + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM); + + dim3 block(kCtaShapeM, kCtaShapeN); + dim3 grid(uint32_t(blocks_m), (problem_size.K + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN)); + + kernel::Conv2dFprop< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp, + kThreadM, + kThreadN, + kCtaShapeM, + kCtaShapeN + ><<< grid, block, 0, stream >>>( + problem_size, + tensor_x, + tensor_w, + tensor_y_in, + tensor_y_out, + alpha, + beta + ); + + cudaError_t result = cudaPeekAtLastError(); + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + return Status::kSuccess; +} + +/// Conv3d Fprop dispatcher - y = fprop(x, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv3dFprop( + conv::Conv3dProblemSize problem_size, + TensorRef tensor_x, + TensorRef tensor_w, + TensorRef tensor_y_in, + TensorRef tensor_y_out, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr) { + + // + // Blocking factors improve performance of reference implementation + // + + int const kThreadM = 4; // shape of a thread's tile in the GEMM M dimension + int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension + int const kCtaShapeM = 16; // shape of a threadblock in units of threads + int const kCtaShapeN = 8; // shape of a threadblock in units of threads + + int64_t nzpq = int64_t(problem_size.N) * problem_size.Z * problem_size.P * problem_size.Q; + int64_t blocks_m = (nzpq + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM); + + dim3 block(kCtaShapeM, kCtaShapeN); + dim3 grid(uint32_t(blocks_m), (problem_size.K + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN)); + + kernel::Conv3dFprop< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp, + kThreadM, + kThreadN, + kCtaShapeM, + kCtaShapeN + ><<< grid, block, 0, stream >>>( + problem_size, + tensor_x, + tensor_w, + tensor_y_in, + tensor_y_out, + alpha, + beta + ); + + cudaError_t result = cudaPeekAtLastError(); + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + return Status::kSuccess; +} + +/// Conv2d Dgrad dispatcher - dx = dgrad(dy, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv2dDgrad( + conv::Conv2dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_w, + TensorRef tensor_dx_in, + TensorRef tensor_dx_out, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr) { + + // + // Blocking factors improve performance of reference implementation + // + + int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension + int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension + int const kCtaShapeM = 16; // shape of a threadblock in units of threads + int const kCtaShapeN = 8; // shape of a threadblock in units of threads + + int64_t nhw = int64_t(problem_size.N) * problem_size.H * problem_size.W; + int64_t blocks_m = (nhw + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM); + + dim3 block(kCtaShapeM, kCtaShapeN); + dim3 grid(uint32_t(blocks_m), (problem_size.C + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN)); + + kernel::Conv2dDgrad< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp, + kThreadM, + kThreadN, + kCtaShapeM, + kCtaShapeN + ><<< grid, block, 0, stream >>>( + problem_size, + tensor_dy, + tensor_w, + tensor_dx_in, + tensor_dx_out, + alpha, + beta + ); + + cudaError_t result = cudaPeekAtLastError(); + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + return Status::kSuccess; +} + +/// Conv3d Dgrad dispatcher - dx = dgrad(dy, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv3dDgrad( + conv::Conv3dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_w, + TensorRef tensor_dx_in, + TensorRef tensor_dx_out, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr) { + + // + // Blocking factors improve performance of reference implementation + // + + int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension + int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension + int const kCtaShapeM = 16; // shape of a threadblock in units of threads + int const kCtaShapeN = 8; // shape of a threadblock in units of threads + + int64_t ndhw = int64_t(problem_size.N) * problem_size.D * problem_size.H * problem_size.W; + int64_t blocks_m = (ndhw + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM); + + dim3 block(kCtaShapeM, kCtaShapeN); + dim3 grid(uint32_t(blocks_m), (problem_size.C + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN)); + + kernel::Conv3dDgrad< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp, + kThreadM, + kThreadN, + kCtaShapeM, + kCtaShapeN + ><<< grid, block, 0, stream >>>( + problem_size, + tensor_dy, + tensor_w, + tensor_dx_in, + tensor_dx_out, + alpha, + beta + ); + + cudaError_t result = cudaPeekAtLastError(); + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + return Status::kSuccess; +} + +/// Conv2d Wgrad dispatcher - dw = wgrad(dy, x) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv2dWgrad( + conv::Conv2dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_x, + TensorRef tensor_dw_in, + TensorRef tensor_dw_out, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr) { + + // + // Blocking factors improve performance of reference implementation + // + + int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension + int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension + int const kCtaShapeM = 8; // shape of a threadblock in units of threads + int const kCtaShapeN = 16; // shape of a threadblock in units of threads + + int64_t rsc = int64_t(problem_size.R) * problem_size.S * problem_size.C; + int64_t blocks_n = (rsc + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN); + + dim3 block(kCtaShapeM, kCtaShapeN); + dim3 grid((problem_size.K + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM), uint32_t(blocks_n)); + + kernel::Conv2dWgrad< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp, + kThreadM, + kThreadN, + kCtaShapeM, + kCtaShapeN + ><<< grid, block, 0, stream >>>( + problem_size, + tensor_dy, + tensor_x, + tensor_dw_in, + tensor_dw_out, + alpha, + beta + ); + + cudaError_t result = cudaPeekAtLastError(); + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + return Status::kSuccess; +} + +/// Conv3d Wgrad dispatcher - dw = wgrad(dy, x) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv3dWgrad( + conv::Conv3dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_x, + TensorRef tensor_dw_in, + TensorRef tensor_dw_out, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr) { + + // + // Blocking factors improve performance of reference implementation + // + + int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension + int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension + int const kCtaShapeM = 8; // shape of a threadblock in units of threads + int const kCtaShapeN = 16; // shape of a threadblock in units of threads + + int64_t trsc = int64_t(problem_size.T) * problem_size.R * problem_size.S * problem_size.C; + int64_t blocks_n = (trsc + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN); + + dim3 block(kCtaShapeM, kCtaShapeN); + dim3 grid((problem_size.K + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM), uint32_t(blocks_n)); + + kernel::Conv3dWgrad< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp, + kThreadM, + kThreadN, + kCtaShapeM, + kCtaShapeN + ><<< grid, block, 0, stream >>>( + problem_size, + tensor_dy, + tensor_x, + tensor_dw_in, + tensor_dw_out, + alpha, + beta + ); + + cudaError_t result = cudaPeekAtLastError(); + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + return Status::kSuccess; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Generic 2D convolution targeting Conv2dFprop, Conv2dDgrad, and Conv2dWgrad. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv2d( + conv::Operator convolutional_operator, + conv::Conv2dProblemSize problem_size, + TensorRef tensor_A, + TensorRef tensor_B, + TensorRef tensor_C, + TensorRef tensor_D, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr) { + + switch (convolutional_operator) { + case conv::Operator::kFprop: + return Conv2dFprop< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); + break; + + case conv::Operator::kDgrad: + return Conv2dDgrad< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); + break; + + case conv::Operator::kWgrad: + return Conv2dWgrad< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); + break; + + default: break; + } + + return Status::kErrorNotSupported; +} + +/// Generic 3D convolution targeting Conv3dFprop, Conv3dDgrad, and Conv3dWgrad. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv3d( + conv::Operator convolutional_operator, + conv::Conv3dProblemSize problem_size, + TensorRef tensor_A, + TensorRef tensor_B, + TensorRef tensor_C, + TensorRef tensor_D, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr) { + + switch (convolutional_operator) { + case conv::Operator::kFprop: + return Conv3dFprop< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); + + case conv::Operator::kDgrad: + return Conv3dDgrad< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); + + case conv::Operator::kWgrad: + return Conv3dWgrad< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); + + default: break; + } + + return Status::kErrorNotSupported; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..7d575d522c1dd87d51f9bc58d09786393c5cfea3 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm.h @@ -0,0 +1,385 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for GEMM in device-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" + +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/util/reference/device/kernel/gemm.h" + +namespace cutlass { +namespace reference { +namespace device { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename AccumulatorType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_gemm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + AccumulatorType initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + // Blocking structure potentially improves performance of reference implementation + // with a minor increase in complexity. + // + // Note, this reference implementation is NOT expected to approach peak performance. + using OutputTile = MatrixShape<4, 4>; + + dim3 block(16, 8); + + dim3 grid( + (problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow), + (problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn) + ); + + // Launch a GEMM kernel + kernel::Gemm< + TensorRef, + TensorRef, + TensorRef, + ScalarType, + AccumulatorType, + OutputTile, + InnerProductOp, + ConvertOp + ><<< grid, block >>>( + problem_size, + alpha, + tensor_a, + tensor_b, + beta, + tensor_c, + tensor_d, + initial_accum + ); +} +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename AccumulatorType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_gemm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + AccumulatorType initial_accum) { + + compute_gemm( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c, + initial_accum); +} + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename AccumulatorType, + typename InnerProductOp = cutlass::arch::OpMultiplyAdd +> +struct Gemm; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + AccumulatorType initial_accum = AccumulatorType(0)) { + + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + AccumulatorType initial_accum = AccumulatorType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add-saturate +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + AccumulatorType initial_accum = AccumulatorType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm, + NumericConverterClamp>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + AccumulatorType initial_accum = AccumulatorType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm, + NumericConverterClamp>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for XOR-popc +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + AccumulatorType initial_accum = AccumulatorType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + AccumulatorType initial_accum = AccumulatorType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Batched GEMM +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a batch of GEMMs over a set of matrices of common dimension. +// +// TensorRefCollection* is a type satisfying the TensorRefCollection concept. +// +template < + typename TensorRefCollectionA, + typename TensorRefCollectionB, + typename TensorRefCollectionC, + typename ScalarType, + typename AccumulatorType, + typename InnerProductOp, + typename ConvertOp +> +void BatchedGemm( + gemm::GemmCoord problem_size, + int batch_count, + ScalarType alpha, + TensorRefCollectionA const& tensor_a, + TensorRefCollectionB const& tensor_b, + ScalarType beta, + TensorRefCollectionC &tensor_c, + AccumulatorType initial_accum) { + + static_assert( + TensorRefCollectionA::kRank == 2 && + TensorRefCollectionB::kRank == 2 && + TensorRefCollectionC::kRank == 2, "Tensors must be of rank 2"); + + // Blocking structure potentially improves performance of reference implementation + // with a minor increase in complexity. + // + // Note, this reference implementation is NOT expected to approach peak performance. + using OutputTile = MatrixShape<4, 4>; + + dim3 block(16, 8); + dim3 grid( + (problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow), + (problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn), + batch_count + ); + + // Launch a GEMM kernel + kernel::BatchedGemm< + TensorRefCollectionA, + TensorRefCollectionB, + TensorRefCollectionC, + ScalarType, + AccumulatorType, + OutputTile, + InnerProductOp, + ConvertOp + ><<< grid, block >>>( + problem_size, + alpha, + tensor_a, + tensor_b, + beta, + tensor_c, + initial_accum + ); +} + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +// +// TensorRefCollection* is a type satisfying the TensorRefCollection concept. +// +template < + typename TensorRefCollectionA, + typename TensorRefCollectionB, + typename TensorRefCollectionC, + typename ScalarType, + typename AccumulatorType +> +void BatchedGemm( + gemm::GemmCoord problem_size, + int batch_count, + ScalarType alpha, + TensorRefCollectionA const& tensor_a, + TensorRefCollectionB const& tensor_b, + ScalarType beta, + TensorRefCollectionC &tensor_c) { + + BatchedGemm(problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, ScalarType(0)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_complex.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..bddf596214da62a7aa3177f758db3710dc1d2516 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_complex.h @@ -0,0 +1,350 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued GEMM in device-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +namespace cutlass { +namespace reference { +namespace device { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kMblock = 4, + int kNblock = 4 +> +__global__ void GemmComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock; + int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock; + int batch_idx = blockIdx.z; + + tensor_a.add_pointer_offset(batch_idx * batch_stride_A); + tensor_b.add_pointer_offset(batch_idx * batch_stride_B); + tensor_c.add_pointer_offset(batch_idx * batch_stride_C); + tensor_d.add_pointer_offset(batch_idx * batch_stride_D); + + for (; batch_idx < batch_count; batch_idx += gridDim.z) { + + // Compute matrix product using blocks + ComputeType accum[kMblock][kNblock]; + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + ElementA a = tensor_a.at(MatrixCoord(row, k_block)); + ElementB b = tensor_b.at(MatrixCoord(k_block, col)); + + ComputeType a_ik = ComputeType(a); + ComputeType b_kj = ComputeType(b); + + if (transform_a == ComplexTransform::kConjugate) { + a_ik = conj(a_ik); + } + + if (transform_b == ComplexTransform::kConjugate) { + b_kj = conj(b_kj); + } + + accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]); + } + } + } + } + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * ScalarType(tensor_c.at(coord))); + } + } + } + + tensor_a.add_pointer_offset(batch_stride_A * gridDim.z); + tensor_b.add_pointer_offset(batch_stride_B * gridDim.z); + tensor_c.add_pointer_offset(batch_stride_C * gridDim.z); + tensor_d.add_pointer_offset(batch_stride_D * gridDim.z); + + } // for (batch_idx) +} + +} // namespace kernel + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void GemmComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + int const kMblock = 4; + int const kNblock = 4; + + dim3 block(16, 8); + dim3 grid( + (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock), + (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock), + batch_count % std::numeric_limits::max() + ); + + if (grid.y <= std::numeric_limits::max()) { + kernel::GemmComplex< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ScalarType, + ComputeType, + ElementD, + ConvertOp, + InnerProductOp, + kMblock, + kNblock + ><<< grid, block >>>( + problem_size, + alpha, + tensor_a, + transform_a, + tensor_b, + transform_b, + beta, + tensor_c, + tensor_d, + initial_accum, + batch_count, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_stride_D + ); + } else { + // Using bigger thread tile size + int const kBigMblock = 4; + int const kBigNblock = 16; + + dim3 Bigblock(16, 8); + dim3 Biggrid( + (problem_size.m() + block.x * kBigMblock - 1) / (block.x * kBigMblock), + (problem_size.n() + block.y * kBigNblock - 1) / (block.y * kBigNblock), + batch_count % std::numeric_limits::max() + ); + + kernel::GemmComplex< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ScalarType, + ComputeType, + ElementD, + ConvertOp, + InnerProductOp, + kBigMblock, + kBigNblock + ><<< Biggrid, Bigblock >>>( + problem_size, + alpha, + tensor_a, + transform_a, + tensor_b, + transform_b, + beta, + tensor_c, + tensor_d, + initial_accum, + batch_count, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_stride_D + ); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ElementD = ElementC +> +void GemmComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d) { + + GemmComplex(problem_size, alpha, tensor_a, transform_a, tensor_b, transform_b, beta, tensor_c, tensor_d, ScalarType(0)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..48819cf6eaa565b3ec41dbbf78ae244666fd8a65 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h @@ -0,0 +1,311 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued GEMM in device code. +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/complex.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_ref_planar_complex.h" + +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +namespace cutlass { +namespace reference { +namespace device { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static int const kGemmPlanarComplexBlockSize = 4; + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add> +> +__global__ void GemmPlanarComplex( + gemm::GemmCoord problem_size, + complex alpha, + TensorRefPlanarComplex tensor_a, + ComplexTransform transform_a, + TensorRefPlanarComplex tensor_b, + ComplexTransform transform_b, + complex beta, + TensorRefPlanarComplex tensor_c, + TensorRefPlanarComplex tensor_d, + complex initial_accum) { + + int const kMblock = kGemmPlanarComplexBlockSize; + int const kNblock = kGemmPlanarComplexBlockSize; + + using ComplexA = typename TensorRefPlanarComplex::ComplexElement; + using ComplexB = typename TensorRefPlanarComplex::ComplexElement; + using ComplexC = typename TensorRefPlanarComplex::ComplexElement; + + // Note: batch is ignored. + int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + complex accum[kMblock][kNblock]; + + int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock; + int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock; + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + accum[i][j] = initial_accum; + } + } + + CUTLASS_PRAGMA_NO_UNROLL + for (int k_block = 0; k_block < K; ++k_block) { + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + + ComplexA a_ik = tensor_a.at(MatrixCoord(row, k_block)); + ComplexB b_kj = tensor_b.at(MatrixCoord(k_block, col)); + + complex a = complex{ + ComputeType(a_ik.real()), + ComputeType(a_ik.imag()) + }; + + complex b = complex{ + ComputeType(b_kj.real()), + ComputeType(b_kj.imag()) + }; + + if (transform_a == ComplexTransform::kConjugate) { + a = conj(a); + } + + if (transform_b == ComplexTransform::kConjugate) { + b = conj(b); + } + + accum[i][j] = inner_product_op(a, b, accum[i][j]); + } + } + } + } + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + + complex acc{ + ScalarType(accum[i][j].real()), + ScalarType(accum[i][j].imag()) + }; + + ComplexC c_ij = ComplexC(); + + if (beta.real() != ScalarType() || beta.imag() != ScalarType()) { + c_ij = tensor_c.at(coord); + } + + complex src{ + ScalarType(c_ij.real()), + ScalarType(c_ij.imag()) + }; + + complex result = alpha * acc + beta * src; + + ComplexC d_ij; + + d_ij.real() = convert_op(result.real()); + d_ij.imag() = convert_op(result.imag()); + + tensor_d.at(coord) = d_ij; + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add> +> +void GemmPlanarComplex( + gemm::GemmCoord problem_size, + complex alpha, + TensorRefPlanarComplex tensor_a, + ComplexTransform transform_a, + TensorRefPlanarComplex tensor_b, + ComplexTransform transform_b, + complex beta, + TensorRefPlanarComplex tensor_c, + TensorRefPlanarComplex tensor_d, + complex initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + int const kMblock = kernel::kGemmPlanarComplexBlockSize; + int const kNblock = kernel::kGemmPlanarComplexBlockSize; + + dim3 block(16, 8); + + dim3 grid( + (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock), + (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock), + 1); + + kernel::GemmPlanarComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ScalarType, + ComputeType, + ConvertOp, + InnerProductOp + ><<< grid, block >>>( + problem_size, + alpha, + tensor_a, + transform_a, + tensor_b, + transform_b, + beta, + tensor_c, + tensor_d, + initial_accum + ); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType +> +void GemmPlanarComplex( + gemm::GemmCoord problem_size, + complex alpha, + TensorRefPlanarComplex tensor_a, + ComplexTransform transform_a, + TensorRefPlanarComplex tensor_b, + ComplexTransform transform_b, + complex beta, + TensorRefPlanarComplex tensor_c, + TensorRefPlanarComplex tensor_d) { + + GemmPlanarComplex( + problem_size, + alpha, + tensor_a, transform_a, + tensor_b, transform_b, + beta, + tensor_c, + tensor_d, + complex()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gett.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gett.hpp new file mode 100644 index 0000000000000000000000000000000000000000..497a257d170c411d891942f62fa2c960453d03d5 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gett.hpp @@ -0,0 +1,146 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief GETT device reference code +*/ +#pragma once + +#include + +namespace cutlass::reference::device { + +template < + class ATensor, + class BTensor, + class CTensor, + class DTensor, + class ElementAccumulator, + class ElementEpilogue> +__global__ static +void +gett_kernel( + DTensor D, + ATensor const A, + BTensor const B, + CTensor const C, + ElementEpilogue alpha, ElementEpilogue beta, + ElementAccumulator acc_init) +{ + using namespace cute; + + static_assert(DTensor::rank == 3, "(M,N,L)"); + static_assert(ATensor::rank == 3, "(M,K,L)"); + static_assert(BTensor::rank == 3, "(N,K,L)"); + static_assert(CTensor::rank == 3, "(M,N,L)"); + + assert(size<0>(A) == size<0>(D)); // M + assert(size<0>(C) == size<0>(D)); // M + assert(size<0>(B) == size<1>(D)); // N + assert(size<1>(C) == size<1>(D)); // N + assert(size<1>(A) == size<1>(B)); // K + assert(size<2>(A) == size<2>(D)); // L + assert(size<2>(B) == size<2>(D)); // L + assert(size<2>(C) == size<2>(D)); // L + + NumericConverter a_converter; + NumericConverter b_converter; + NumericConverter acc_converter; + NumericConverter source_converter; + NumericConverter output_converter; + + // Thread id to each element of D + for (int tid = threadIdx.x + blockDim.x * blockIdx.x; + tid < size(D); + tid += blockDim.x * gridDim.x) { + // (m,n,l) coordinate + auto mnl_coord = idx2crd(tid, product_each(shape(D))); + auto m = get<0>(mnl_coord); + auto n = get<1>(mnl_coord); + auto l = get<2>(mnl_coord); + + auto A_ml = A(m,_,l); + auto B_nl = B(n,_,l); + + ElementAccumulator accum = ElementAccumulator(0); + for (int k = 0; k < size<1>(A); ++k) { + ElementAccumulator a = a_converter(A_ml(k)); + ElementAccumulator b = b_converter(B_nl(k)); + accum += a * b; + } + + ElementEpilogue scaled_output = (alpha * acc_converter(accum)) + (beta * source_converter(C(m,n,l))); + D(m,n,l) = output_converter(scaled_output); + } +} + +// Most general version +template < + class ProblemShapeMNKL, + class ElementA, + class StrideA, + class ElementB, + class StrideB, + class ElementAccumulator, + class ElementC, + class StrideC, + class ElementD, + class StrideD, + class ElementEpilogue> +void +gett( + ProblemShapeMNKL problem_shape_mnkl, + ElementA const* ptr_A, StrideA stride_a_mkl, + ElementB const* ptr_B, StrideB stride_b_nkl, + ElementAccumulator _, + ElementC const* ptr_C, StrideC stride_c_mnl, + ElementD * ptr_D, StrideD stride_d_mnl, + ElementEpilogue alpha, ElementEpilogue beta, + cudaStream_t stream = 0) { + using namespace cute; + + static_assert(cute::rank(ProblemShapeMNKL{}) == 4); + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto K = get<2>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + + // Represent the full tensors + auto A = make_tensor(make_gmem_ptr(ptr_A), make_shape(M,K,L), stride_a_mkl); // (M,K,L) + auto B = make_tensor(make_gmem_ptr(ptr_B), make_shape(N,K,L), stride_b_nkl); // (N,K,L) + auto C = make_tensor(make_gmem_ptr(ptr_C), make_shape(M,N,L), stride_c_mnl); // (M,N,L) + auto D = make_tensor(make_gmem_ptr(ptr_D), make_shape(M,N,L), stride_d_mnl); // (M,N,L) + + dim3 dimBlock(256); + dim3 dimGrid(240); + gett_kernel<<< dimGrid, dimBlock, 0, stream >>>(D, A, B, C, alpha, beta, ElementAccumulator(0)); +} + +} // namespace cutlass::reference::device diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/gemm.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..6e131126a336420a2b0e843e3ead3d89fce637fa --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/gemm.h @@ -0,0 +1,162 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for GEMM in host-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/util/reference/device/thread/gemm.h" + +namespace cutlass { +namespace reference { +namespace device { +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename TensorRefA, + typename TensorRefB, + typename TensorRefC, + typename ScalarType, + typename AccumulatorType, + typename OutputTile, + typename InnerProductOp, + typename ConvertOp +> +__global__ void Gemm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRefA tensor_a, + TensorRefB tensor_b, + ScalarType beta, + TensorRefC tensor_c, + TensorRefC tensor_d, + AccumulatorType initial_accum) { + + // Map each thread to a unique tile of the output matrix + MatrixCoord output_coord( + MatrixCoord::Index((threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kRow), + MatrixCoord::Index((threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kColumn) + ); + + // Compute the general matrix product + thread::Gemm< + TensorRefA, + TensorRefB, + TensorRefC, + ScalarType, + AccumulatorType, + OutputTile, + InnerProductOp, + ConvertOp + > gemm(initial_accum); + + gemm.multiply_add( + problem_size, + tensor_a, + tensor_b, + output_coord); + + gemm.epilogue(problem_size, alpha, beta, tensor_c, tensor_d, output_coord); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename TensorRefCollectionA, + typename TensorRefCollectionB, + typename TensorRefCollectionC, + typename ScalarType, + typename AccumulatorType, + typename OutputTile, + typename InnerProductOp, + typename ConvertOp +> +__global__ void BatchedGemm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRefCollectionA tensor_collection_a, + TensorRefCollectionB tensor_collection_b, + ScalarType beta, + TensorRefCollectionC tensor_collection_c, + AccumulatorType initial_accum) { + + // Obtain batch ID + int batch_id = blockIdx.z; + + // Dereference based on batch_id + typename TensorRefCollectionA::TensorRef tensor_a = tensor_collection_a.at(batch_id); + typename TensorRefCollectionB::TensorRef tensor_b = tensor_collection_b.at(batch_id); + typename TensorRefCollectionC::TensorRef tensor_c = tensor_collection_c.at(batch_id); + + // Map each thread to a unique tile of the output matrix + MatrixCoord output_coord( + (threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kColumn, + (threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kRow + ); + + // Compute the general matrix product + thread::Gemm< + typename TensorRefCollectionA::TensorRef, + typename TensorRefCollectionB::TensorRef, + typename TensorRefCollectionC::TensorRef, + ScalarType, + AccumulatorType, + OutputTile, + InnerProductOp, + ConvertOp + > gemm(initial_accum); + + gemm.multiply_add( + problem_size, + tensor_a, + tensor_b, + output_coord); + + gemm.epilogue(problem_size, alpha, beta, tensor_c, output_coord); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_elementwise.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_elementwise.h new file mode 100644 index 0000000000000000000000000000000000000000..149e4b2e00e2ac8130cee9dc189a539ba3a70297 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_elementwise.h @@ -0,0 +1,168 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include + +#include "cutlass/cutlass.h" + +namespace cutlass { +namespace reference { +namespace device { +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Kernel to initialize tensor to uniform random distribution +template +__global__ void TensorInitializeUniform( + Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) { + __shared__ curandState_t rng_state[1024]; + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x; + + curand_init(seed, gtid, 0, &rng_state[threadIdx.x]); + + int c_idx = blockIdx.x * blockDim.x + threadIdx.x; + int s_idx = blockIdx.y * blockDim.x; + + tensor += s_idx * ldm + c_idx; + + for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) { + if (s_idx < dim_strided && c_idx < dim_contiguous) { + double range = dist.uniform.max - dist.uniform.min; + + double rnd = curand_uniform(&rng_state[threadIdx.x]); + + rnd = dist.uniform.min + range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + if (dist.int_scale >= 0) { + rnd = double(int(rnd * double(1 << dist.int_scale))); + *tensor = T(rnd / double(1 << dist.int_scale)); + } else { + *tensor = T(rnd); + } + + tensor += ldm; + } + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Kernel to initialize tensor to uniform distribution +template +__global__ void TensorInitializeGaussian( + Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) { + __shared__ curandState_t rng_state[1024]; + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x; + + curand_init(seed, gtid, 0, &rng_state[threadIdx.x]); + + int c_idx = blockIdx.x * blockDim.x + threadIdx.x; + int s_idx = blockIdx.y * blockDim.x; + + tensor += s_idx * ldm + c_idx; + + for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) { + if (s_idx < dim_strided && c_idx < dim_contiguous) { + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + + double rnd = curand_normal(&rng_state[threadIdx.x]); + + rnd = dist.gaussian.mean + dist.gaussian.stddev * rnd; + + if (dist.int_scale >= 0) { + rnd = double(int(rnd * double(1 << dist.int_scale))); + *tensor = T(rnd / double(1 << dist.int_scale)); + } else { + *tensor = T(rnd); + } + } + } +} + +/// Kernel to initialize tensor to an identity matrix +template +__global__ void TensorInitializeLinear( + Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) { + __shared__ curandState_t rng_state[1024]; + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x; + + curand_init(seed, gtid, 0, &rng_state[threadIdx.x]); + + int c_idx = blockIdx.x * blockDim.x + threadIdx.x; + int s_idx = blockIdx.y * blockDim.x; + + tensor += s_idx * ldm + c_idx; + + for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) { + if (s_idx < dim_strided && c_idx < dim_contiguous) { + *tensor = + dist.linear.offset + dist.linear.delta_row * c_idx + dist.linear.delta_column * s_idx; + } + } +} + +/// Kernel to initialize tensor to an identity matrix +template +__global__ void TensorInitializeIdentity( + Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) { + __shared__ curandState_t rng_state[1024]; + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x; + + curand_init(seed, gtid, 0, &rng_state[threadIdx.x]); + + int c_idx = blockIdx.x * blockDim.x + threadIdx.x; + int s_idx = blockIdx.y * blockDim.x; + + tensor += s_idx * ldm + c_idx; + + for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) { + if (s_idx < dim_strided && c_idx < dim_contiguous) { + *tensor = (c_idx == s_idx ? T(1) : T(0)); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_foreach.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_foreach.h new file mode 100644 index 0000000000000000000000000000000000000000..3223cb2056ba6d88f47f7b117392a56e325d0ce7 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_foreach.h @@ -0,0 +1,159 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/coord.h" +#include "cutlass/subbyte_reference.h" +#include "cutlass/fast_math.h" + +namespace cutlass { +namespace reference { +namespace device { +namespace kernel { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines several helpers +namespace detail { + +/// Helper to perform for-each operation +template +struct TensorForEachHelper { + + /// Constructor for general rank + __inline__ __device__ + TensorForEachHelper(Func &func, Coord const &size, Coord &coord, int64_t index) { + + int64_t product = 1; + + CUTLASS_PRAGMA_UNROLL + for (int i = Rank - RankRemaining; i < Rank; ++i) { + product *= size[i]; + } + + coord[Rank - 1 - RankRemaining] = index / product; + int64_t remaining = index % product; + + TensorForEachHelper(func, size, coord, remaining); + } +}; + +/// Helper to perform for-each operation +template +struct TensorForEachHelper { + + /// Constructor for fastest changing rank + __inline__ __device__ + TensorForEachHelper(Func &func, Coord const &size, Coord &coord, int64_t index) { + + coord[Rank - 1] = index; + + if (coord < size) { + func(coord); + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Kernel calls a functor for each element in a tensor's index space +template +__global__ void TensorForEach(Coord size, Params params = Params()) { + + Func func(params); + + int64_t index = threadIdx.x + blockIdx.x * blockDim.x; + int64_t max_index = 1; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Rank; ++i) { + max_index *= size[i]; + } + + CUTLASS_PRAGMA_NO_UNROLL + while (index < max_index) { + Coord coord; + + detail::TensorForEachHelper(func, size, coord, index); + index += blockDim.x * gridDim.x; + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Kernel calls a functor for each element along a tensor's diagonal +template +__global__ void TensorDiagonalForEach(Coord size, Params params, int start, int end) { + + Func func(params); + + int64_t index = threadIdx.x + blockIdx.x * blockDim.x + start; + + if (index < end) { + Coord coord; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Rank; ++i) { + coord[i] = index; + } + + func(coord); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void BlockForEach( + Element *ptr, + size_t capacity, + typename Func::Params params) { + + Func func(params); + + size_t index = threadIdx.x + blockIdx.x * blockDim.x; + + for (; index < capacity; index += blockDim.x * gridDim.x) { + ReferenceFactory::get(ptr, index) = func(); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace device +} // namespace reference +} // namespace cutlass + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/rank_2k_complex.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/rank_2k_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..2e76fe52b06f9bb1a033c736f94fa01961ce664d --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/rank_2k_complex.h @@ -0,0 +1,355 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued GEMM in device-side code. +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +namespace cutlass { +namespace reference { +namespace device { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kMblock = 4, + int kNblock = 4 +> +__global__ void Rank2KComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + FillMode fill_mode_c, + BlasMode blas_mode, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + assert(M=N); + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock; + int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock; + int batch_idx = blockIdx.z; + + tensor_a.add_pointer_offset(batch_idx * batch_stride_A); + tensor_b.add_pointer_offset(batch_idx * batch_stride_B); + tensor_c.add_pointer_offset(batch_idx * batch_stride_C); + tensor_d.add_pointer_offset(batch_idx * batch_stride_D); + + for (; batch_idx < batch_count; batch_idx += gridDim.z) { + + // Compute matrix product using blocks + ComputeType accum[kMblock][kNblock]; + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N && + ( (fill_mode_c == FillMode::kLower && row >= col) || + (fill_mode_c == FillMode::kUpper && row <= col) ) + ) { + + // A x B^T (Symmetric) or A x B^H (Hermitian) + // complex conjugation on operandB (b_t) is function of blas3 computation + ElementA a = tensor_a.at(MatrixCoord(row, k_block)); + ElementB b_t = (blas_mode == BlasMode::kHermitian) ? + conj(tensor_b.at(MatrixCoord(col, k_block))) : + tensor_b.at(MatrixCoord(col, k_block)); + + ComputeType a_ik = ComputeType(a); + ComputeType b_jk = ComputeType(b_t); + + // complex conjugation is a function of operand layouts + if (transform_a == ComplexTransform::kConjugate) { + a_ik = conj(a_ik); + } + // complex conjugation is a function of operand layouts + if (transform_b == ComplexTransform::kConjugate) { + b_jk = conj(b_jk); + } + + accum[i][j] = inner_product_op(a_ik, b_jk, accum[i][j]); + + // B x A^T (Symmetric) or B x A^H (Hermitian) + // complex conjugation on operandB (a_t) is function of blas3 computation + ElementB b = tensor_b.at(MatrixCoord(row, k_block)); + ElementA a_t = (blas_mode == BlasMode::kHermitian) ? + conj(tensor_a.at(MatrixCoord(col, k_block))): + tensor_a.at(MatrixCoord(col, k_block)); + + ComputeType b_ik = ComputeType(b); + ComputeType a_jk = ComputeType(a_t); + + // complex conjugation here is a function of operand layouts + if (transform_b == ComplexTransform::kConjugate) { + b_ik = conj(b_ik); + } + // complex conjugation here is a function of operand layouts + if (transform_a == ComplexTransform::kConjugate) { + a_jk = conj(a_jk); + } + + accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]); + } + } + } + } + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N && + ((fill_mode_c == FillMode::kLower && row >= col) || + (fill_mode_c == FillMode::kUpper && row <= col)) + ) { + + ScalarType c = tensor_c.at(coord); + // The imaginary parts of the diagonal elements of + // a complex data type are assumed and set to zero + if (blas_mode == BlasMode::kHermitian) { + c = (row == col) ? real(c) : c; + } + + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * c); + } + } + } + + tensor_a.add_pointer_offset(batch_stride_A * gridDim.z); + tensor_b.add_pointer_offset(batch_stride_B * gridDim.z); + tensor_c.add_pointer_offset(batch_stride_C * gridDim.z); + tensor_d.add_pointer_offset(batch_stride_D * gridDim.z); + + } // for (batch_idx) +} + +} // namespace kernel + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Rank2KComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + FillMode fill_mode_c, + BlasMode blas_mode, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + int const kMblock = 4; + int const kNblock = 4; + + dim3 block(16, 8); + dim3 grid( + (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock), + (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock), + batch_count % std::numeric_limits::max() + ); + + kernel::Rank2KComplex< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ScalarType, + ComputeType, + ConvertOp, + InnerProductOp, + kMblock, + kNblock + ><<< grid, block >>>( + problem_size, + alpha, + tensor_a, + transform_a, + tensor_b, + transform_b, + beta, + tensor_c, + tensor_d, + initial_accum, + fill_mode_c, + blas_mode, + batch_count, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_stride_D + ); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType +> +void Rank2KComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + FillMode fill_mode_c, + BlasMode blas_mode) { + + Rank2KComplex( + problem_size, alpha, + tensor_a, transform_a, + tensor_b, transform_b, + beta, tensor_c, tensor_d, + ScalarType(0), + fill_mode_c, + blas_mode); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_compare.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_compare.h new file mode 100644 index 0000000000000000000000000000000000000000..1999730f6d24e69aef152aa332fae68af57a9c40 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_compare.h @@ -0,0 +1,250 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines host-side elementwise operations on TensorView. +*/ + +#pragma once +// Standard Library includes +#include + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/relatively_equal.h" + +#include "cutlass/util/distribution.h" + +#include "tensor_foreach.h" + +namespace cutlass { +namespace reference { +namespace device { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +template +__global__ void BlockCompareEqual( + int *equal, + Element const *ptr_A, + Element const *ptr_B, + size_t capacity) { + + size_t idx = threadIdx.x + blockDim.x * blockIdx.x; + + for (; idx < capacity; idx += gridDim.x * blockDim.x) { + + Element a = cutlass::ReferenceFactory::get(ptr_A, idx); + Element b = cutlass::ReferenceFactory::get(ptr_B, idx); + + if (a != b) { + *equal = 0; + + return; + } + } +} + +template +__global__ void BlockCompareRelativelyEqual( + int *equal, + Element const *ptr_A, + Element const *ptr_B, + size_t capacity, + Element epsilon, + Element nonzero_floor) { + + size_t idx = threadIdx.x + blockDim.x * blockIdx.x; + + for (; idx < capacity; idx += gridDim.x * blockDim.x) { + + Element a = cutlass::ReferenceFactory::get(ptr_A, idx); + Element b = cutlass::ReferenceFactory::get(ptr_B, idx); + + if (!relatively_equal(a, b, epsilon, nonzero_floor)) { + *equal = 0; + return; + } + } +} + +} // namespace kernel + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Performs a bit-level equality check between two blocks +template +bool BlockCompareEqual( + Element const *ptr_A, + Element const *ptr_B, + size_t capacity, + int grid_size = 0, + int block_size = 0, + cudaStream_t stream = nullptr) { + + int equal_flag = 1; + int *device_equal_flag = nullptr; + + if (cudaMalloc((void **)&device_equal_flag, sizeof(int)) != cudaSuccess) { + throw std::runtime_error("Failed to allocate device flag."); + } + + if (cudaMemcpy( + device_equal_flag, + &equal_flag, + sizeof(int), + cudaMemcpyHostToDevice) != cudaSuccess) { + + throw std::runtime_error("Failed to copy equality flag to device."); + } + + if (!grid_size || !block_size) { + + // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API + cudaError_t result = cudaOccupancyMaxPotentialBlockSize( + &grid_size, + &block_size, + reinterpret_cast(kernel::BlockCompareEqual)); + + if (result != cudaSuccess) { + throw std::runtime_error("Failed to query occupancy."); + } + // Limit block size. This has the effect of increasing the number of items processed by a + // single thread and reduces the impact of initialization overhead. + block_size = (block_size < 128 ? block_size : 128); + } + + dim3 grid(grid_size, 1, 1); + dim3 block(block_size, 1, 1); + + kernel::BlockCompareEqual<<< grid, block, 0, stream >>>(device_equal_flag, ptr_A, ptr_B, capacity); + + cudaStreamSynchronize(stream); + + if (cudaMemcpy( + &equal_flag, + device_equal_flag, + sizeof(int), + cudaMemcpyDeviceToHost) != cudaSuccess) { + + cudaFree(device_equal_flag); + + throw std::runtime_error("Failed to copy equality flag from device."); + } + + cudaFree(device_equal_flag); + + return equal_flag; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Performs a bit-level equality check between two blocks +template +bool BlockCompareRelativelyEqual( + Element const *ptr_A, + Element const *ptr_B, + size_t capacity, + Element epsilon, + Element nonzero_floor, + int grid_size = 0, + int block_size = 0, + cudaStream_t stream = nullptr) { + + int equal_flag = 1; + int *device_equal_flag = nullptr; + + if (cudaMalloc((void **)&device_equal_flag, sizeof(int)) != cudaSuccess) { + throw std::runtime_error("Failed to allocate device flag."); + } + + if (cudaMemcpy( + device_equal_flag, + &equal_flag, + sizeof(int), + cudaMemcpyHostToDevice) != cudaSuccess) { + + throw std::runtime_error("Failed to copy equality flag to device."); + } + + if (!grid_size || !block_size) { + + // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API + cudaError_t result = cudaOccupancyMaxPotentialBlockSize( + &grid_size, + &block_size, + reinterpret_cast(kernel::BlockCompareRelativelyEqual)); + + if (result != cudaSuccess) { + throw std::runtime_error("Failed to query occupancy."); + } + // Limit block size. This has the effect of increasing the number of items processed by a + // single thread and reduces the impact of initialization overhead. + block_size = (block_size < 128 ? block_size : 128); + } + + dim3 grid(grid_size, 1, 1); + dim3 block(block_size, 1, 1); + + kernel::BlockCompareRelativelyEqual<<< grid, block, 0, stream >>>( + device_equal_flag, + ptr_A, + ptr_B, + capacity, + epsilon, + nonzero_floor + ); + + cudaStreamSynchronize(stream); + + if (cudaMemcpy( + &equal_flag, + device_equal_flag, + sizeof(int), + cudaMemcpyDeviceToHost) != cudaSuccess) { + + cudaFree(device_equal_flag); + + throw std::runtime_error("Failed to copy equality flag from device."); + } + + cudaFree(device_equal_flag); + + return equal_flag; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // device +} // reference +} // cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_fill.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_fill.h new file mode 100644 index 0000000000000000000000000000000000000000..a19b42825f6efb4a39466fe1cfc182ab7d831079 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_fill.h @@ -0,0 +1,2075 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines device-side elementwise operations on TensorView. Note, the operations defined + in this header are not specialized for any particular data layout and are therefore not + intended to offer the best possible performance. Rather, they are intended to be generic + reference implementations to support the CUTLASS unit tests. +*/ + +#pragma once + +#if !defined(__CUDACC_RTC__) + +// Standard Library includes +#include +#include +#include +#include +#include + +#endif + +// CUDA includes +#include + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/complex.h" +#include "cutlass/tensor_view.h" +#include "cutlass/blas3.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/layout/vector.h" + +#include "cutlass/util/reference/device/tensor_foreach.h" +#include "cutlass/util/distribution.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace device { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +CUTLASS_DEVICE +FloatType random_normal_float(curandState_t *state) { + return curand_normal(state); +} + +template <> +CUTLASS_DEVICE +double random_normal_float(curandState_t *state) { + return curand_normal_double(state); +} + +template +CUTLASS_DEVICE +FloatType random_uniform_float(curandState_t *state) { + return curand_uniform(state); +} + +template <> +CUTLASS_DEVICE +double random_uniform_float(curandState_t *state) { + return curand_uniform_double(state); +} + +template +struct RandomGaussianFunc { + + using FloatType = typename std::conditional<(sizeof(Element) > 4), double, float>::type; + using IntType = typename std::conditional<(sizeof(Element) > 4), int64_t, int>::type; + + /// Parameters structure + struct Params { + + // + // Data members + // + + uint64_t seed; + FloatType mean; + FloatType stddev; + int int_scale; + FloatType float_scale_up; + FloatType float_scale_down; + int exclude_zero; ///< If non-negative, excludes zeros + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + uint64_t seed_ = 0, + Element mean_ = 0, + Element stddev_ = 1, + int int_scale_ = -1, + int exclude_zero_ = -1 + ): + seed(seed_), + mean(static_cast(mean_)), + stddev(static_cast(stddev_)), + int_scale(int_scale_), + exclude_zero(exclude_zero_) { + + float_scale_up = FloatType(IntType(1) << int_scale); // scale up to clamp low order bits + float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + /// RNG state object + curandState_t rng_state; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + RandomGaussianFunc(Params const ¶ms): params(params) { + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; + + curand_init(params.seed, gtid, 0, &rng_state); + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + Element operator()() { + + FloatType rnd = random_normal_float(&rng_state); + rnd = params.mean + params.stddev * rnd; + + Element result; + if (params.int_scale >= 0) { + rnd = FloatType(std::llround(rnd * params.float_scale_up)); + result = Element(rnd * params.float_scale_down); + } + else { + result = Element(rnd); + } + + if (params.exclude_zero >=0 && result == Element(0.0)) { + if (rnd > FloatType(0)) { + rnd += FloatType(1); + } else { + rnd -= FloatType(1); + } + result = Element(rnd); + } + + return result; + } +}; + + +template +struct RandomGaussianFunc> { + + using Element = complex; + using FloatType = typename std::conditional<(sizeof(Real) > 4), double, float>::type; + using IntType = typename std::conditional<(sizeof(Real) > 4), int64_t, int>::type; + + /// Parameters structure + struct Params { + + // + // Data members + // + + uint64_t seed; + FloatType mean; + FloatType stddev; + int int_scale; + FloatType float_scale_up; + FloatType float_scale_down; + int exclude_zero; ///< If non-negative, excludes zeros + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + uint64_t seed_ = 0, + Real mean_ = 0, + Real stddev_ = 1, + int int_scale_ = -1, + int exclude_zero_ = -1 + ): + seed(seed_), + mean(static_cast(mean_)), + stddev(static_cast(stddev_)), + int_scale(int_scale_), + exclude_zero(exclude_zero_) { + + float_scale_up = FloatType(IntType(1) << int_scale); + float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + /// RNG state object + curandState_t rng_state; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + RandomGaussianFunc(Params const ¶ms): params(params) { + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; + + curand_init(params.seed, gtid, 0, &rng_state); + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + Element operator()() { + + FloatType rnd_r = random_normal_float(&rng_state); + FloatType rnd_i = random_normal_float(&rng_state); + rnd_r = params.mean + params.stddev * rnd_r; + rnd_i = params.mean + params.stddev * rnd_i; + + Element result; + if (params.int_scale >= 0) { + rnd_r = FloatType(std::llround(rnd_r * params.float_scale_up)); + rnd_i = FloatType(std::llround(rnd_i * params.float_scale_up)); + + result = { + Real(rnd_r * params.float_scale_down), + Real(rnd_i * params.float_scale_down) + }; + } + else { + result = Element(Real(rnd_r), Real(rnd_i)); + } + + if (params.exclude_zero >= 0 && + result.real() == Real(0.0) && + result.imag() == Real(0.0)) { + + if (rnd_r > FloatType(0)) { + rnd_r += FloatType(1); + } else { + rnd_r -= FloatType(1); + } + result = Element(Real(rnd_r), Real(rnd_i)); + } + + return result; + } +}; + +/// Computes a random Gaussian distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillRandomGaussianFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + using RandomFunc = RandomGaussianFunc; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + typename RandomFunc::Params random; + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + TensorView view_ = TensorView(), + typename RandomFunc::Params random_ = typename RandomFunc::Params() + ): + view(view_), random(random_) { + + } + }; + + // + // Data members + // + + Params params; + RandomFunc random; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorFillRandomGaussianFunc(Params const ¶ms): params(params), random(params.random) { + + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + params.view.at(coord) = random(); + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a Gaussian distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomGaussian( + TensorView view, ///< destination tensor + uint64_t seed, ///< seed for RNG + typename RealType::Type mean = Element(0), ///< Gaussian distribution's mean + typename RealType::Type stddev = Element(1), ///< Gaussian distribution's standard deviation + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + int exclude_zero = -1, ///< If non-negative, excludes zeros from tensor init + cudaStream_t stream = nullptr) { + + using RandomFunc = detail::RandomGaussianFunc; + using Func = detail::TensorFillRandomGaussianFunc; + using Params = typename Func::Params; + + TensorForEach( + view.extent(), + Params(view, typename RandomFunc::Params(seed, mean, stddev, bits, exclude_zero)), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a Gaussian distribution. +template ///< Element type +void BlockFillRandomGaussian( + Element *ptr, + size_t capacity, + uint64_t seed, ///< seed for RNG + typename RealType::Type mean, ///< Gaussian distribution's mean + typename RealType::Type stddev, ///< Gaussian distribution's standard deviation + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + cudaStream_t stream = nullptr) { + + using RandomFunc = detail::RandomGaussianFunc; + + typename RandomFunc::Params params(seed, mean, stddev, bits); + + BlockForEach(ptr, capacity, params, /*grid_size*/0, /*block_size*/0, stream); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Computes a random uniform distribution +template ///< Element type +struct RandomUniformFunc { + + using FloatType = typename std::conditional< + (sizeof(Element) > 4), + double, + float>::type; + + using IntType = typename std::conditional< + (sizeof(Element) > 4), + int64_t, + int>::type; + + /// Parameters structure + struct Params { + + // + // Data members + // + + uint64_t seed; + FloatType range; + FloatType max; + int int_scale; + double pnan; + FloatType float_scale_up; + FloatType float_scale_down; + int exclude_zero; ///< If non-negative, excludes zeros + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + uint64_t seed_ = 0, + Element max_ = 1, + Element min = 0, + int int_scale_ = -1, + double pnan_ = 0, + int exclude_zero_ = -1 + ): + seed(seed_), + range(static_cast(max_) - static_cast(min)), + max(static_cast(max_)), + int_scale(int_scale_), + pnan(pnan_), + exclude_zero(exclude_zero_) { + + float_scale_up = FloatType(IntType(1) << int_scale); // scale up to clamp low order bits + float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); + + // Handle cases where min = 0 or max = 0 for excluding zeros + if (exclude_zero >= 0) { + range = (min == Element(0)) ? range - FloatType(1): range; + max = (max_ == Element(0)) ? max - FloatType(1): max; + } + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + /// RNG state object + curandState_t rng_state; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + RandomUniformFunc(Params const ¶ms): params(params) { + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; + + curand_init(params.seed, gtid, 0, &rng_state); + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + Element operator()() { + + // Draw random float in [0.0, 1.0] to determine if element should be NaN. + if constexpr (std::numeric_limits::has_quiet_NaN) { + if (params.pnan > 0 && (curand_uniform(&rng_state) < (params.pnan))) { + return Element(NAN); + } + } + + FloatType rnd = random_uniform_float(&rng_state); + rnd = params.max - params.range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + Element result; + + if (params.int_scale >= 0) { + rnd = FloatType(std::llround(rnd * params.float_scale_up)); + result = Element(rnd * params.float_scale_down); + } + else { + result = Element(rnd); + } + + if (params.exclude_zero >=0 && result == Element(0.0)) { + if (rnd > FloatType(0)) { + rnd = std::min(params.max, rnd + FloatType(1)); + } else { + rnd = std::max((params.max - params.range), rnd - FloatType(1)); + } + result = Element(rnd); + } + + return result; + } +}; + +/// Computes a random Gaussian distribution +template +struct RandomUniformFunc> { + + using Element = complex; + + using FloatType = typename std::conditional< + (sizeof(Real) > 4), + double, + float>::type; + + using IntType = typename std::conditional< + (sizeof(Real) > 4), + int64_t, + int>::type; + + /// Parameters structure + struct Params { + + // + // Data members + // + + uint64_t seed; + FloatType range; + FloatType min; + int int_scale; + double pnan; + FloatType float_scale_up; + FloatType float_scale_down; + int exclude_zero; ///< If non-negative, excludes zeros + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + uint64_t seed_ = 0, + FloatType max = 1, + FloatType min_ = 0, + int int_scale_ = -1, + double pnan_ = 0, + int exclude_zero_ = -1 + ): + seed(seed_), + range(static_cast(max - min_)), + min(static_cast(min_)), + int_scale(int_scale_), + pnan(pnan_), + exclude_zero(exclude_zero_) { + + float_scale_up = FloatType(IntType(1) << int_scale); + float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); + + // Handle cases where min = 0 or max = 0 for excluding zeros + if (exclude_zero >= 0) { + min = (min == FloatType(0)) ? min + FloatType(1): min; + range = (max == FloatType(0)) ? range - FloatType(1): range; + } + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + /// RNG state object + curandState_t rng_state; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + RandomUniformFunc(Params const ¶ms): params(params) { + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; + + curand_init(params.seed, gtid, 0, &rng_state); + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + Element operator()() { + + // Draw random float in [0.0, 1.0] to determine if element should be NaN. + if constexpr (std::numeric_limits::has_quiet_NaN) { + if (params.pnan > 0 && (curand_uniform(&rng_state) < (params.pnan))) { + return Element(Real(NAN), Real(NAN)); + } + } + + FloatType rnd_r = random_uniform_float(&rng_state); + FloatType rnd_i = random_uniform_float(&rng_state); + + rnd_r = params.min + params.range * rnd_r; + rnd_i = params.min + params.range * rnd_i; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + Element result; + + if (params.int_scale >= 0) { + rnd_r = FloatType(std::llround(rnd_r * params.float_scale_up)); + rnd_i = FloatType(std::llround(rnd_i * params.float_scale_up)); + + result = { + Real(rnd_r * params.float_scale_down), + Real(rnd_i * params.float_scale_down) + }; + } + else { + result = Element(Real(rnd_r), Real(rnd_i)); + } + + if (params.exclude_zero >= 0 && + result.real() == Real(0.0) && + result.imag() == Real(0.0)) { + + if (rnd_r > FloatType(0)) { + rnd_r = std::min(params.min + params.range, rnd_r + FloatType(1)); + } else { + rnd_r = std::max((params.min), rnd_r - FloatType(1)); + } + result = Element(Real(rnd_r), Real(rnd_i)); + } + + return result; + } +}; + +/// Computes a random uniform distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillRandomUniformFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + using RandomFunc = RandomUniformFunc; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + typename RandomFunc::Params random; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + TensorView view_ = TensorView(), + typename RandomFunc::Params random_ = RandomFunc::Params() + ): + view(view_), random(random_) { + + } + }; + + // + // Data members + // + + Params params; + RandomFunc random; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorFillRandomUniformFunc(Params const ¶ms): params(params), random(params.random) { + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + params.view.at(coord) = random(); + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomUniform( + TensorView view, ///< destination tensor + uint64_t seed, ///< seed for RNG + typename RealType::Type max = Element(1), ///< upper bound of distribution + typename RealType::Type min = Element(0), ///< lower bound for distribution + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + double pnan = 0, ///< Percentage of NaN elements. + int exclude_zero = -1, ///< If non-negative, excludes zeros from tensor init + cudaStream_t stream = nullptr) { + + using RandomFunc = detail::RandomUniformFunc; + using Func = detail::TensorFillRandomUniformFunc; + using Params = typename Func::Params; + + typename RandomFunc::Params random(seed, max, min, bits, pnan, exclude_zero); + + TensorForEach( + view.extent(), + Params(view, random), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. +template +void BlockFillRandomUniform( + Element *ptr, + size_t capacity, + uint64_t seed, ///< seed for RNG + typename RealType::Type max, ///< upper bound of distribution + typename RealType::Type min, ///< lower bound for distribution + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + double pnan = 0, ///< Percentage of NaN elements. + cudaStream_t stream = nullptr) { + + using RandomFunc = detail::RandomUniformFunc; + + typename RandomFunc::Params params(seed, max, min, bits, pnan); + + BlockForEach(ptr, capacity, params, /*grid_size*/0, /*block_size*/0, stream); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Computes a random sparse meta +template ///< Element type +struct RandomSparseMetaFunc { + + using FloatType = float; + + using IntType = int32_t; + + /// Parameters structure + struct Params { + + // + // Data members + // + + uint64_t seed; + FloatType range; + int MetaSizeInBits; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + uint64_t seed_ = 0, + int MetaSizeInBits_ = 2 + ): + seed(seed_), + MetaSizeInBits(MetaSizeInBits_) { + if (MetaSizeInBits_ == 2) { + range = 6; + } + else if (MetaSizeInBits_ == 4) { + range = 2; + } + else { + throw std::invalid_argument("Invalid MetaSizeInBits"); + } + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + /// RNG state object + curandState_t rng_state; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + RandomSparseMetaFunc(Params const ¶ms): params(params) { + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; + + curand_init(params.seed, gtid, 0, &rng_state); + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + Element operator()() { + Element FourToTwoMeta[6] = {0x4, 0x8, 0x9, 0xc, 0xd, 0xe}; + Element TwoToOneMeta[2] = {0x4, 0xe}; + + Element *MetaArray = + (params.MetaSizeInBits == 2) ? FourToTwoMeta : TwoToOneMeta; + + Element result = 0x0; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < cutlass::sizeof_bits::value / 4; ++i) { + FloatType rnd = random_uniform_float(&rng_state); + rnd = params.range * rnd; + Element meta = MetaArray[(int)rnd]; + + result = (Element)(result | ((Element)(meta << (i * 4)))); + } + + return result; + } +}; + +/// Computes a random Gaussian distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillRandomSparseMetaFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + using RandomFunc = RandomSparseMetaFunc; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + typename RandomFunc::Params random; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + TensorView view_ = TensorView(), + typename RandomFunc::Params random_ = RandomFunc::Params() + ): + view(view_), random(random_) { + + } + }; + + // + // Data members + // + + Params params; + RandomFunc random; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorFillRandomSparseMetaFunc(Params const ¶ms): params(params), random(params.random) { + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + params.view.at(coord) = random(); + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomSparseMeta( + TensorView view, ///< destination tensor + uint64_t seed, ///< seed for RNG + int MetaSizeInBits = 2, ///< meta data size + cudaStream_t stream = nullptr) { + + using RandomFunc = detail::RandomSparseMetaFunc; + using Func = detail::TensorFillRandomUniformFunc; + using Params = typename Func::Params; + + typename RandomFunc::Params random(seed, MetaSizeInBits); + + TensorForEach( + view.extent(), + Params(view, random), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. +template +void BlockFillRandomSparseMeta( + Element *ptr, + size_t capacity, + uint64_t seed, ///< seed for RNG + int MetaSizeInBits = 2, ///< meta data size + cudaStream_t stream = nullptr) { + + using RandomFunc = detail::RandomSparseMetaFunc; + + typename RandomFunc::Params params(seed, MetaSizeInBits); + + BlockForEach(ptr, capacity, params, /*grid_size*/0, /*block_size*/0, stream); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Functor to fill a tensor with zeros off the diagonal and a uniform value on the diagonal. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillDiagonalFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + Element diag; + Element other; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + Params( + TensorView view_ = TensorView(), + Element diag_ = Element(1), + Element other_ = Element(0) + ): + view(view_), diag(diag_), other(other_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorFillDiagonalFunc(Params const ¶ms): params(params) { + + } + + /// Updates the tensor + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + bool is_diag = true; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i] != coord[i - 1]) { + is_diag = false; + break; + } + } + + params.view.at(coord) = (is_diag ? params.diag : params.other); + } +}; + +// Overwrites the elements of a tensor with a uniform value depending on fill mode +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillPartialFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + Element element; + FillMode fill_mode; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params(): fill_mode(FillMode::kNone) { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + TensorView view_, + Element element_, + FillMode fill_mode_ + ): + view(view_), element(element_), fill_mode(fill_mode_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + CUTLASS_DEVICE + TensorFillPartialFunc(Params const ¶ms): params(params) { + + } + + /// Overwrites the element if it is within the covered region. + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + bool predicate = true; + + switch (params.fill_mode) { + case FillMode::kFull: + predicate = true; + break; + + case FillMode::kLower: + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i - 1] < coord[i]) { + predicate = false; + break; + } + } + break; + + case FillMode::kUpper: + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i - 1] > coord[i]) { + predicate = false; + break; + } + } + break; + + case FillMode::kDiagonal: + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i - 1] != coord[i]) { + predicate = false; + break; + } + } + break; + + case FillMode::kNone: // fall-through + + default: + predicate = false; + break; + } + + if (predicate) { + params.view.at(coord) = params.element; + } + } +}; + + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorClearPartialFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// + static_assert((Layout::kRank == 2), "TensorClearPartial is only supported for matrices"); + + /// Parameters structure + struct Params { + TensorView view{}; + Element element{}; + FillMode fill_mode{FillMode::kNone}; + int alignment{0}; + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + CUTLASS_DEVICE + TensorClearPartialFunc(Params const ¶ms): params(params) { + + } + + /// Overwrites the element if it is within the covered region. + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + bool predicate = true; + + switch (params.fill_mode) { + + case FillMode::kLower: + if ((coord[0] >= coord[1]) || + ((coord[1] - coord[0]) >= params.alignment)) { + predicate = false; + break; + } + break; + + case FillMode::kUpper: + if ((coord[0] <= coord[1]) || + ((coord[0] - coord[1]) >= params.alignment)) { + predicate = false; + break; + } + break; + + case FillMode::kNone: // fall-through + + default: + predicate = false; + break; + } + + if (predicate) { + params.view.at(coord) = params.element; + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor everywhere with a unique value for its diagonal. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillDiagonal( + TensorView view, ///< destination tensor + Element diag = Element(1), ///< value to write in the diagonal + Element other = Element(0), ///< value to write off the diagonal + cudaStream_t stream = nullptr) { + + typedef detail::TensorFillDiagonalFunc Func; + typedef typename Func::Params Params; + + TensorForEach( + view.extent(), + Params(view, diag, other), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/// Fills a tensor partially depending on fill mode. Elements not covered by the fillmode are +/// not written. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillPartial( + TensorView view, ///< destination tensor + Element element, + FillMode fill_mode, + cudaStream_t stream = nullptr) { + + typedef detail::TensorFillPartialFunc Func; + typedef typename Func::Params Params; + + TensorForEach( + view.extent(), + Params(view, element, fill_mode), + stream + ); +} + +/// Clears a tensor partially depending on fill mode and alignment. Elements on the wrong-side +/// of fillmode (upto the alignment) are overwritten with the user supplied element (typically zeros) +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorClearPartial( + TensorView view, ///< destination tensor + Element element, + FillMode fill_mode, + int alignment, + cudaStream_t stream = nullptr) { + + typedef detail::TensorClearPartialFunc Func; + typedef typename Func::Params Params; + + TensorForEach( + view.extent(), + Params{view, element, fill_mode, alignment}, + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with a uniform value +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFill( + TensorView view, ///< destination tensor + Element val = Element(0), ///< value to uniformly fill it with + cudaStream_t stream = nullptr) { + + TensorFillDiagonal(view, val, val, stream); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor's diagonal with 1 and 0 everywhere else. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillIdentity( + TensorView view, ///< destination tensor + cudaStream_t stream = nullptr) { + + TensorFillDiagonal(view, Element(1), Element(0), stream); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Computes a random Gaussian distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorUpdateDiagonalFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + Element diag; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + TensorView view_ = TensorView(), + Element diag_ = Element(1) + ): + view(view_), diag(diag_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorUpdateDiagonalFunc(Params const ¶ms): params(params) { + + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + bool is_diag = true; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i] != coord[i - 1]) { + is_diag = false; + break; + } + } + + if (is_diag) { + params.view.at(coord) = params.diag; + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Writes a uniform value to the diagonal of a tensor without modifying off-diagonal elements. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorUpdateDiagonal( + TensorView view, ///< destination tensor + Element diag = Element(1), + cudaStream_t stream = nullptr) { + + typedef detail::TensorUpdateDiagonalFunc Func; + typedef typename Func::Params Params; + + TensorForEach( + view.extent(), + Params(view, diag), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Computes a random Gaussian distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorUpdateOffDiagonalFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + Element other; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + TensorView view_ = TensorView(), + Element other_ = Element(0) + ): + view(view_), other(other_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorUpdateOffDiagonalFunc(Params const ¶ms): params(params) { + + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + bool is_diag = true; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i] != coord[i - 1]) { + is_diag = false; + break; + } + } + + if (!is_diag) { + params.view.at(coord) = params.other; + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Writes a uniform value to all elements in the tensor without modifying diagonal elements. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorUpdateOffDiagonal( + TensorView view, ///< destination tensor + Element other = Element(1), + cudaStream_t stream = nullptr) { + + typedef detail::TensorUpdateOffDiagonalFunc Func; + typedef typename Func::Params Params; + + TensorForEach( + view.extent(), + Params(view, other), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Computes a random Gaussian distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillLinearFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + Array v; + Element s; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + TensorView view_, ///< destination tensor + Array const & v_, + Element s_ = Element(0) + ): + view(view_), v(v_), s(s_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorFillLinearFunc(Params const ¶ms): params(params) { + + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + Element sum = params.s; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Layout::kRank; ++i) { + if constexpr (is_complex::value) { + if constexpr (sizeof_bits::value <= 32) { + sum = Element(static_cast>(sum) + + static_cast>(params.v[i]) * static_cast>(coord[i])); + } + } + else if constexpr (sizeof_bits::value <= 32) { + if constexpr (std::numeric_limits::is_integer) { + sum = Element(static_cast(sum) + + static_cast(params.v[i]) * static_cast(coord[i])); + } + else { + sum = Element(static_cast(sum) + + static_cast(params.v[i]) * static_cast(coord[i])); + } + } + else { + sum += params.v[i] * coord[i]; + } + } + + params.view.at(coord) = sum; + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills tensor with a linear combination of its coordinate and another vector +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillLinear( + TensorView view, ///< destination tensor + Array const & v, + Element s = Element(0), + cudaStream_t stream = nullptr) { + + using Func = detail::TensorFillLinearFunc; + using Params = typename Func::Params; + + TensorForEach( + view.extent(), + Params(view, v, s), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values from a distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandom( + TensorView view, ///< destination tensor + uint64_t seed, + Distribution dist, + cudaStream_t stream = nullptr, + int exclude_zero = -1 ///< If non-negative, excludes 0. + /// Note that setting this flag will result in more 1's, + /// as we use a simple mechanism to replace 0's by adding/subtracting 1's. + ) { + + using Real = typename RealType::Type; + + if (dist.kind == Distribution::Gaussian) { + TensorFillRandomGaussian( + view, + seed, + static_cast(dist.gaussian.mean), + static_cast(dist.gaussian.stddev), + dist.int_scale, + exclude_zero, + stream); + } else if (dist.kind == Distribution::Uniform) { + TensorFillRandomUniform( + view, + seed, + static_cast(dist.uniform.max), + static_cast(dist.uniform.min), + dist.int_scale, + dist.uniform.pnan, + exclude_zero, + stream); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a block of data with sequential elements +template < + typename Element +> +void BlockFillSequential( + Element *ptr, + int64_t capacity, + Element v = Element(1), + Element s = Element(0)) { + + using Layout = layout::PackedVectorLayout; + Layout::TensorCoord size(static_cast(capacity)); // -Wconversion + Layout layout = Layout::packed(size); + TensorView view(ptr, layout, size); + + Array c{}; + c[0] = v; + + TensorFillLinear(view, c, s); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a block of data with sequential elements +template < + typename Element +> +void BlockFillRandom( + Element *ptr, + size_t capacity, + uint64_t seed, + Distribution dist, + cudaStream_t stream = nullptr) { + + using Real = typename RealType::Type; + + if (dist.kind == Distribution::Gaussian) { + BlockFillRandomGaussian( + ptr, + capacity, + seed, + static_cast(dist.gaussian.mean), + static_cast(dist.gaussian.stddev), + dist.int_scale, + stream); + } + else if (dist.kind == Distribution::Uniform) { + BlockFillRandomUniform( + ptr, + capacity, + seed, + static_cast(dist.uniform.max), + static_cast(dist.uniform.min), + dist.int_scale, + dist.uniform.pnan, + stream); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Computes a random Gaussian distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorCopyDiagonalInFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + Element const *ptr; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + TensorView view_, ///< destination tensor + Element const *ptr_ + ): + view(view_), ptr(ptr_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorCopyDiagonalInFunc(Params const ¶ms): params(params) { + + } + + /// Only update the diagonal element + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + bool is_diagonal = true; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i] != coord[0]) { + is_diagonal = false; + } + } + if (is_diagonal) { + params.view.at(coord) = params.ptr[coord[0]]; + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies a diagonal in from host memory without modifying off-diagonal elements. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorCopyDiagonalIn( + TensorView view, ///< destination tensor + Element const *ptr, ///< dense buffer of elements + cudaStream_t stream = nullptr) { + + using Func = detail::TensorCopyDiagonalInFunc; + using Params = typename Func::Params; + + TensorForEach( + view.extent(), + Params(view, ptr), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + + +namespace detail { + +/// Computes a random Gaussian distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorCopyDiagonalOutFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + Element *ptr; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + TensorView view_, ///< destination tensor + Element *ptr_ + ): + view(view_), ptr(ptr_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorCopyDiagonalOutFunc(Params const ¶ms): params(params) { + + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + bool is_diagonal = true; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i] != coord[0]) { + is_diagonal = false; + } + } + if (is_diagonal) { + params.ptr[coord[0]] = params.view.at(coord); + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies the diagonal of a tensor into a dense buffer in host memory. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorCopyDiagonalOut( + Element *ptr, ///< dense buffer of elements + TensorView view, ///< source tensor + cudaStream_t stream = nullptr) { + + using Func = detail::TensorCopyDiagonalOutFunc; + using Params = typename Func::Params; + + TensorForEach( + view.extent(), + Params(view, ptr), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_foreach.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_foreach.h new file mode 100644 index 0000000000000000000000000000000000000000..ba2dfd85c47b8c9450c348de32dccb7f1be9c3c1 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_foreach.h @@ -0,0 +1,142 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include "cutlass/cutlass.h" +#include "cutlass/util/reference/device/kernel/tensor_foreach.h" + +namespace cutlass { +namespace reference { +namespace device { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Launches a kernel calling a functor for each element in a tensor's index space. +template +struct TensorForEach { + + /// Constructor performs the operation. + TensorForEach( + Coord size, Params params = Params(), + int grid_size = 0, int block_size = 0, + cudaStream_t stream = nullptr) { + + if (!grid_size || !block_size) { + + // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API + cudaError_t result = cudaOccupancyMaxPotentialBlockSize( + &grid_size, + &block_size, + reinterpret_cast(kernel::TensorForEach)); + + if (result != cudaSuccess) { + throw std::runtime_error("Failed to query occupancy."); + } + // Limit block size. This has the effect of increasing the number of items processed by a + // single thread and reduces the impact of initialization overhead. + block_size = (block_size < 128 ? block_size : 128); + } + + dim3 grid(grid_size, 1, 1); + dim3 block(block_size, 1, 1); + + kernel::TensorForEach<<< grid, block, 0, stream >>>(size, params); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Launches a kernel calling a functor for each element along a tensor's diagonal +template +struct TensorDiagonalForEach { + + /// Constructor performs the operation + TensorDiagonalForEach( + Coord size, Params params = Params(), + int start = 0, int end = -1, + int block_size = 128, cudaStream_t stream = nullptr) { + + if (end < 0) { + end = size.min(); + } + + dim3 block(block_size, 1, 1); + dim3 grid((end - start + block_size - 1) / block_size, 1, 1); + + kernel::TensorDiagonalForEach<<< grid, block, 0, stream >>>( + size, params, start, end); + } +}; + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct BlockForEach { + + /// Constructor performs the operation. + BlockForEach( + Element *ptr, + size_t capacity, + typename Func::Params params = typename Func::Params(), + int grid_size = 0, + int block_size = 0, + cudaStream_t stream = nullptr) { + + if (!grid_size || !block_size) { + + // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API + cudaError_t result = cudaOccupancyMaxPotentialBlockSize( + &grid_size, + &block_size, + reinterpret_cast(kernel::BlockForEach)); + + if (result != cudaSuccess) { + throw std::runtime_error("Failed to query occupancy."); + } + // Limit block size. This has the effect of increasing the number of items processed by a + // single thread and reduces the impact of initialization overhead. + block_size = (block_size < 128 ? block_size : 128); + } + + dim3 grid(grid_size, 1, 1); + dim3 block(block_size, 1, 1); + + kernel::BlockForEach<<< grid, block, 0, stream >>>(ptr, capacity, params); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_reduce.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_reduce.h new file mode 100644 index 0000000000000000000000000000000000000000..3e6d7b300f34fec6aec96e72f78427cf677936b4 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_reduce.h @@ -0,0 +1,514 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/reference/detail/linear_to_coordinate.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp, + int kBlockSize = 128 +> +__global__ void TensorTransformReducePartial( + TensorView view, /// View of the tensor to reduce over + ComputeType identity, /// Identity element of the reduction operation + ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType + TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType + ComputeType *workspace) { /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0] + + int64_t idx = threadIdx.x + blockIdx.x * blockDim.x; + int64_t size = view.size(); + + __shared__ ComputeType scratchpad[kBlockSize]; + + for (; idx < size; idx += blockDim.x * gridDim.x) { + + // Map linear thread ID onto tensor coordinate + typename Layout::TensorCoord coord; + + cutlass::reference::detail::LinearToCoordinate()(coord, idx, view.extent()); + + if (view.contains(coord)) { + + // Fetch element + Element x = view.at(coord); + + // Transform + identity = reduce(identity, transform(x)); + } + } + + scratchpad[threadIdx.x] = identity; + + __syncthreads(); + + // One thread performs the final reduction and stores out. This could be enhanced via + // a tree reduction and pipelining. + if (threadIdx.x == 0) { + + for (int i = 1; i < kBlockSize; ++i) { + identity = reduce(identity, scratchpad[i]); + } + + workspace[blockIdx.x] = identity; + } +} + +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp, + int kBlockSize = 128 +> +__global__ void TensorTransformReducePartial( + TensorView view_A, /// View of the tensor to reduce over + TensorView view_B, /// View of the tensor to reduce over + ComputeType identity, /// Identity element of the reduction operation + ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType + TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType + ComputeType *workspace) { /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0] + + int64_t idx = threadIdx.x + blockIdx.x * blockDim.x; + auto size = static_cast(view_A.size()); + + __shared__ ComputeType scratchpad[kBlockSize]; + + for (; idx < size; idx += blockDim.x * gridDim.x) { + + // Map linear thread ID onto tensor coordinate + typename Layout::TensorCoord coord; + + cutlass::reference::detail::LinearToCoordinate()(coord, idx, view_A.extent()); + + if (view_A.contains(coord)) { + + // Fetch element + Element a = view_A.at(coord); + Element b = view_B.at(coord); + + // Transform + identity = reduce(identity, transform(a, b)); + } + } + + scratchpad[threadIdx.x] = identity; + + __syncthreads(); + + // One thread performs the final reduction and stores out. This could be enhanced via + // a tree reduction and pipelining. + if (threadIdx.x == 0) { + + for (int i = 1; i < kBlockSize; ++i) { + identity = reduce(identity, scratchpad[i]); + } + + workspace[blockIdx.x] = identity; + } +} + + +template < + typename ComputeType, + typename ReduceOp, + int kBlockSize = 32 +> +__global__ void TensorTransformReduceFinalize( + ComputeType *workspace, + ComputeType identity, + int workspace_size, + ReduceOp reduce) { + + __shared__ ComputeType scratchpad[kBlockSize]; + + for (int idx = threadIdx.x; idx < workspace_size; idx += kBlockSize) { + identity = reduce(identity, workspace[idx]); + } + + scratchpad[threadIdx.x] = identity; + + __syncthreads(); + + if (threadIdx.x == 0) { + + for (int i = 1; i < kBlockSize; ++i) { + identity = reduce(identity, scratchpad[i]); + } + + workspace[0] = identity; + } +} + +} // namespace kernel + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Transform-reduce operation over the elements of a tensor +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorView view, /// View of the tensor to reduce over + ComputeType identity, /// Identity element of the reduction operation + ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType + TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType + ComputeType *workspace, /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0] + int workspace_size, /// Number of elements in workspace + cudaStream_t stream = nullptr, /// CUDA stream to launch into + bool copy_out = true /// If true, the value of workspace[0] is copied to host and returned. Otherwise, `identity` is returned. +) { + + int const kBlockSize = 128; + + dim3 block(kBlockSize, 1); + dim3 grid(workspace_size, 1); + + kernel::TensorTransformReducePartial< + Element, Layout, ComputeType, ReduceOp, TransformOp, kBlockSize + ><<< grid, block, 0, stream >>>( + view, identity, reduce, transform, workspace + ); + + int const kFinalizeBlockSize = 32; + + kernel::TensorTransformReduceFinalize< + ComputeType, ReduceOp, kFinalizeBlockSize + ><<< dim3(1, 1), dim3(kFinalizeBlockSize, 1), 0, stream >>>( + workspace, identity, workspace_size, reduce + ); + + cudaStreamSynchronize(stream); + + if (copy_out) { + cudaError_t result = cudaMemcpy(&identity, workspace, sizeof(identity), cudaMemcpyDeviceToHost); + if (result != cudaSuccess) { + throw std::runtime_error("cudaMemcpy() failed"); + } + } + + return identity; +} + +/// Transform-reduce operation over the elements of two tensors, zipped together +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorView view_A, /// View of the tensor to reduce over + TensorView view_B, /// View of the tensor to reduce over + ComputeType identity, /// Identity element of the reduction operation + ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType + TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType + ComputeType *workspace, /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0] + int workspace_size, /// Number of elements in workspace + cudaStream_t stream = nullptr, /// CUDA stream to launch into + bool copy_out = true /// If true, the value of workspace[0] is copied to host and returned. Otherwise, `identity` is returned. +) { + + if (view_A.extent() != view_B.extent()) { + throw std::runtime_error("Extents must be equal."); + } + + int const kBlockSize = 128; + + dim3 block(kBlockSize, 1); + dim3 grid(workspace_size, 1); + + kernel::TensorTransformReducePartial< + Element, Layout, ComputeType, ReduceOp, TransformOp, kBlockSize + ><<< grid, block, 0, stream >>>( + view_A, view_B, identity, reduce, transform, workspace + ); + + int const kFinalizeBlockSize = 32; + + kernel::TensorTransformReduceFinalize< + ComputeType, ReduceOp, kFinalizeBlockSize + ><<< dim3(1, 1), dim3(kFinalizeBlockSize, 1), 0, stream >>>( + workspace, identity, workspace_size, reduce + ); + + cudaStreamSynchronize(stream); + + if (copy_out) { + cudaError_t result = cudaMemcpy(&identity, workspace, sizeof(identity), cudaMemcpyDeviceToHost); + if (result != cudaSuccess) { + throw std::runtime_error("cudaMemcpy() failed"); + } + } + + return identity; +} + +/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side +/// workspace +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorView view, + ComputeType identity, + ReduceOp reduce, + TransformOp transform, + cudaStream_t stream = nullptr, + int workspace_size = 0 +) { + + // Optionally query for the SM count to size the workspace. + if (!workspace_size) { + + int device_idx = 0; + cudaDeviceProp prop; + + cudaError_t result = cudaGetDevice(&device_idx); + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() failed"); + } + + result = cudaGetDeviceProperties(&prop, device_idx); + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProp() failed"); + } + + workspace_size = int(prop.multiProcessorCount); + } + + DeviceAllocation workspace(workspace_size); + + ComputeType output = TensorTransformReduce( + view, + identity, + reduce, + transform, + workspace.get(), + workspace_size, + stream, + true); + + return output; +} + + +/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side +/// workspace +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorView view_A, + TensorView view_B, + ComputeType identity, + ReduceOp reduce, + TransformOp transform, + cudaStream_t stream = nullptr, + int workspace_size = 0 +) { + + // Optionally query for the SM count to size the workspace. + if (!workspace_size) { + + int device_idx = 0; + cudaDeviceProp prop; + + cudaError_t result = cudaGetDevice(&device_idx); + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() failed"); + } + + result = cudaGetDeviceProperties(&prop, device_idx); + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProp() failed"); + } + + workspace_size = int(prop.multiProcessorCount); + } + + DeviceAllocation workspace(workspace_size); + + ComputeType output = TensorTransformReduce( + view_A, + view_B, + identity, + reduce, + transform, + workspace.get(), + workspace_size, + stream, + true); + + return output; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to compute the sum of the elements of a tensor +template < + typename Element, + typename Layout, + typename ComputeType = Element +> +ComputeType TensorSum( + TensorView view, + ComputeType identity = ComputeType(), + cudaStream_t stream = nullptr, + int workspace_size = 0 +) { + + plus reduce; + NumericConverter transform; + + return TensorTransformReduce( + view, identity, reduce, transform, stream, workspace_size); +} + +/// Helper to compute the sum of the squares of the elements of a tensor +template < + typename Element, + typename Layout, + typename ComputeType = Element +> +ComputeType TensorSumSq( + TensorView view, + ComputeType identity = ComputeType(), + cudaStream_t stream = nullptr, + int workspace_size = 0 +) { + + plus reduce; + magnitude_squared transform; + + return TensorTransformReduce( + view, identity, reduce, transform, stream, workspace_size); +} + +/// Helper to compute the norm of the elements of a tensor. +template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorNorm( + TensorView view, + ComputeType identity = ComputeType(), + cudaStream_t stream = nullptr, + int workspace_size = 0 +) { + + return std::sqrt(TensorSumSq(view, identity, stream, workspace_size)); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to compute the sum of the squares of the differences of two tensors +template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorSumSqDiff( + TensorView view_A, + TensorView view_B, + ComputeType identity = ComputeType(), + cudaStream_t stream = nullptr, + int workspace_size = 0 +) { + + plus reduce; + magnitude_squared_difference transform; + + return TensorTransformReduce( + view_A, view_B, identity, reduce, transform, stream, workspace_size); +} + + +/// Helper to compute the norm of the tensor computed as the difference of two tensors in memory +template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorNormDiff( + TensorView view_A, + TensorView view_B, + ComputeType identity = ComputeType(), + cudaStream_t stream = nullptr, + int workspace_size = 0 +) { + + return std::sqrt(TensorSumSqDiff(view_A, view_B, identity, stream, workspace_size)); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_relu.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_relu.h new file mode 100644 index 0000000000000000000000000000000000000000..0e3d99ddf845810249f909fbdee4505a0a732c4f --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_relu.h @@ -0,0 +1,141 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines device-side elementwise operations on TensorView. Note, the operations defined + in this header are not specialized for any particular data layout and are therefore not + intended to offer the best possible performance. Rather, they are intended to be generic + reference implementations to support the CUTLASS unit tests. +*/ + +#pragma once + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/tensor_view.h" + +#include "cutlass/util/reference/device/tensor_foreach.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace device { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorReLuFunc { + + /// View type + using TensorView = TensorView; + + /// Coordinate in tensor's index space + using TensorCoord = typename TensorView::TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + Element threshold; + + + // + // Methods + // + + Params( + TensorView view_ = TensorView(), + Element threshold_ = Element(0) + ): + view(view_), threshold(threshold_) { + + } + }; + + // + // Data members + // + + Params params; + + // + // Methods + // + + CUTLASS_DEVICE + TensorReLuFunc(Params const ¶ms): params(params) { + + } + + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + Element const & value = params.view.at(coord); + params.view.at(coord) = (value < params.threshold) ? params.threshold : value; + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Apply ReLu on a tensor +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorReLu( + TensorView view, ///< destination tensor + Element threshold = Element(0)) { ///< ReLu threshold + + using Func = detail::TensorReLuFunc; + using Params = typename Func::Params; + + TensorForEach( + view.extent(), + Params(view, threshold) + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/thread/gemm.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/thread/gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..dd11f96bd92f6995590e61665e41a3e830bceacd --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/thread/gemm.h @@ -0,0 +1,186 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for GEMM in host-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +namespace cutlass { +namespace reference { +namespace device { +namespace thread { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Thread-level blocked general matrix product. +// +// Note, this is a reference implementation. Performance is not expected to approach peak. +// +template < + typename TensorRefA, + typename TensorRefB, + typename TensorRefC, + typename ScalarType, + typename AccumulatorType, + typename OutputTile, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +struct Gemm { + + using ElementA = typename TensorRefA::Element; + using ElementB = typename TensorRefB::Element; + using ElementC = typename TensorRefC::Element; + + // + // Data members + // + + /// Tile for A operand + ElementA A_tile[OutputTile::kColumn]; + + /// Tile for B operand + ElementB B_tile[OutputTile::kRow]; + + /// Tile for Accumulator + AccumulatorType accum[OutputTile::kColumn][OutputTile::kRow]; + + // + // Methods + // + + /// Constructor + CUTLASS_HOST_DEVICE + Gemm(AccumulatorType initial_accum = AccumulatorType(0)) { + + // Clear fetch registers + for (int i = 0; i < OutputTile::kColumn; ++i) { + A_tile[i] = ElementA(0); + } + + for (int j = 0; j < OutputTile::kRow; ++j) { + B_tile[j] = ElementB(0); + } + + // Clear accumulators + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < OutputTile::kColumn; ++j) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < OutputTile::kRow; ++i) { + accum[j][i] = initial_accum; + } + } + } + + /// Computes a matrix product + CUTLASS_HOST_DEVICE + Gemm & multiply_add( + gemm::GemmCoord problem_size, + TensorRefA tensor_a, + TensorRefB tensor_b, + MatrixCoord output_coord = MatrixCoord()) { + + InnerProductOp inner_product_op; + + // Loop over the GEMM K dimension + CUTLASS_PRAGMA_NO_UNROLL + for (int k = 0; k < problem_size.k(); ++k) { + + // Fetch a slice of the A matrix + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < OutputTile::kColumn; ++i) { + if (output_coord.row() + i < problem_size.m()) { + A_tile[i] = tensor_a.at(make_Coord(output_coord.row() + i, k)); + } + } + + // Fetch a slice of the B matrix + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < OutputTile::kRow; ++j) { + if (output_coord.column() + j < problem_size.n()) { + B_tile[j] = tensor_b.at(make_Coord(k, output_coord.column() + j)); + } + } + + // Compute an accumulated matrix product + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < OutputTile::kRow; ++j) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < OutputTile::kColumn; ++i) { + accum[j][i] = inner_product_op(A_tile[i], B_tile[j], accum[j][i]); + } + } + } + + return *this; + } + + /// Performs linear scaling of matrix product and updates output tensor + CUTLASS_HOST_DEVICE + Gemm & epilogue( + gemm::GemmCoord problem_size, + ScalarType alpha, + ScalarType beta, + TensorRefC tensor_c, + TensorRefC tensor_d, + MatrixCoord output_coord = MatrixCoord()) { + + ConvertOp convert_op; + + // Update the output tensor + for (int j = 0; j < OutputTile::kRow; ++j) { + for (int i = 0; i < OutputTile::kColumn; ++i) { + MatrixCoord coord = output_coord + MatrixCoord(i, j); + if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) { + + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[j][i]) + + beta * ScalarType(tensor_c.at(coord)) + ); + } + } + } + + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/conv.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/conv.hpp new file mode 100644 index 0000000000000000000000000000000000000000..57443325629ea4e5d855fe18f94c73b10a71a73a --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/conv.hpp @@ -0,0 +1,782 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for CONV in host-side code. +*/ +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass/complex.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/activation.h" + +#include "cute/tensor.hpp" + +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::reference::host { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +bool +is_activation_in_bounds( + cute::Tensor const& activation, + int32_t n_, int32_t d_, int32_t h_, int32_t w_, int32_t c_, int32_t g_) { + return ((g_ >= 0 && g_ < size<5>(activation)) && + (n_ >= 0 && n_ < size<4>(activation)) && + (d_ >= 0 && d_ < size<3>(activation)) && + (h_ >= 0 && h_ < size<2>(activation)) && + (w_ >= 0 && w_ < size<1>(activation)) && + (c_ >= 0 && c_ < size<0>(activation))); +} + +template +bool +is_activation_in_bounds( + cute::Tensor const& activation, + int32_t n_, int32_t h_, int32_t w_, int32_t c_, int32_t g_) { + return ((g_ >= 0 && g_ < size<4>(activation)) && + (n_ >= 0 && n_ < size<3>(activation)) && + (h_ >= 0 && h_ < size<2>(activation)) && + (w_ >= 0 && w_ < size<1>(activation)) && + (c_ >= 0 && c_ < size<0>(activation))); +} + +template +bool +is_activation_in_bounds( + cute::Tensor const& activation, + int32_t n_, int32_t w_, int32_t c_, int32_t g_) { + return ((g_ >= 0 && g_ < size<3>(activation)) && + (n_ >= 0 && n_ < size<2>(activation)) && + (w_ >= 0 && w_ < size<1>(activation)) && + (c_ >= 0 && c_ < size<0>(activation))); +} + +} // namespace detail + +template< + class ElementAcc_, + class ElementScalar_, + class ElementCompute_, + class ElementC_, + class ElementOut_, + bool ResidualAdd_, + class TensorAlpha_, + class TensorBeta_, + class TensorBias_, + class ActivationFunctor_ = cutlass::epilogue::thread::Identity +> +struct ConvEpilogueFusionParams { + using ElementAcc = ElementAcc_; + using ElementScalar = ElementScalar_; + using ElementCompute = ElementCompute_; + using ElementC = ElementC_; + using ElementOut = ElementOut_; + using TensorAlpha = TensorAlpha_; + using TensorBeta = TensorBeta_; + using TensorBias = TensorBias_; + using ActivationFunctor = ActivationFunctor_; + static constexpr bool ResidualAdd = ResidualAdd_; // Source added after activation + + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + + TensorAlpha tensor_alpha{}; + TensorBeta tensor_beta{}; + TensorBias tensor_bias{}; +}; + +template< + cutlass::conv::Operator ConvOp, + int NumSpatialDims, + class TensorA, + class TensorB, + class TensorC, + class TensorD, + class ShapePadding, + class StrideTraversal, + class ShapeDilation, + class EpilogueFusionParams +> +struct ConvReferenceImpl { + // Hard code accumlulator type to float to avoid data lost in accumulating add. + using ElementAcc = cutlass::platform::conditional_t, double, float>; + using ElementC = typename EpilogueFusionParams::ElementC; + using ElementOut = typename EpilogueFusionParams::ElementOut; + using ElementScalar = typename EpilogueFusionParams::ElementScalar; + using ElementCompute = typename EpilogueFusionParams::ElementCompute; + using ElementBias = typename EpilogueFusionParams::TensorBias::value_type; + using ActivationFunctor = typename EpilogueFusionParams::ActivationFunctor; + + // Input related converter + NumericConverter acc_converter; + NumericConverter residual_converter; + NumericConverter bias_converter; + // Scale related converter + NumericConverter scale_converter; + // Output related converter + NumericConverter output_converter; + + EpilogueFusionParams& epi_fusion_params_; + TensorA const& tensor_a_; + TensorB const& tensor_b_; + TensorC const& tensor_c_; + TensorD& tensor_d_; + + ShapePadding const& padding_; + StrideTraversal const& tstride_; + ShapeDilation const& dilation_; + + // Epilogue activation operation + ActivationFunctor epi_activation; + + ConvReferenceImpl( + TensorA const& tensor_a, + TensorB const& tensor_b, + TensorC const& tensor_c, + TensorD& tensor_d, + ShapePadding const& padding, + StrideTraversal const& tstride, + ShapeDilation const& dilation, + EpilogueFusionParams& epi_fusion_params) + : tensor_a_(tensor_a), + tensor_b_(tensor_b), + tensor_c_(tensor_c), + tensor_d_(tensor_d), + padding_(padding), + tstride_(tstride), + dilation_(dilation), + epi_fusion_params_(epi_fusion_params) + { + static_assert(rank(ShapePadding{}) == rank(ShapeDilation{})); + static_assert(rank(ShapePadding{}) == rank(StrideTraversal{})); + } + + void compute_reference() { + if constexpr (ConvOp == cutlass::conv::Operator::kFprop) { + fprop_reference(cute::Int{}); + } + else if constexpr (ConvOp == cutlass::conv::Operator::kDgrad) { + dgrad_reference(cute::Int{}); + } + else { + wgrad_reference(cute::Int{}); + } + } + +private: + // Specialization for 1D fprop kernel + void fprop_reference(cute::Int<1> spatial_dims) { + int32_t G = size<3>(tensor_d_); + int32_t N = size<2>(tensor_d_); + int32_t Q = size<1>(tensor_d_); + int32_t K = size<0>(tensor_d_); + int32_t S = size<1>(tensor_b_); + int32_t C = size<0>(tensor_b_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(2) +#endif + for (int32_t g = 0; g < G; ++g) { + for (int32_t n = 0; n < N; ++n) { + for (int32_t q = 0; q < Q; ++q) { + for (int32_t k = 0; k < K; ++k) { + auto accumulator = ElementAcc(0); + for (int32_t s = 0; s < S; ++s) { + for (int32_t c = 0; c < C; ++c) { + int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); + if (detail::is_activation_in_bounds(tensor_a_, n, w, c, g)) { + auto a = tensor_a_(c, w, n, g); + auto b = tensor_b_(c, s, k, g); + accumulator += ElementAcc(a * b); + } + } + } + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ? + epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? + epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta; + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(k, q, n, g)); + } + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[k]); + } + output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(k, q, n, g)); + } + tensor_d_(k, q, n, g) = output_converter(output); + } + } + } + } + + } + + // Specialization for 2D fprop kernel + void fprop_reference(cute::Int<2> spatial_dims) { + int32_t G = size<4>(tensor_d_); + int32_t N = size<3>(tensor_d_); + int32_t P = size<2>(tensor_d_); + int32_t Q = size<1>(tensor_d_); + int32_t K = size<0>(tensor_d_); + int32_t R = size<2>(tensor_b_); + int32_t S = size<1>(tensor_b_); + int32_t C = size<0>(tensor_b_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int32_t g = 0; g < G; ++g) { + for (int32_t n = 0; n < N; ++n) { + for (int32_t p = 0; p < P; ++p) { + for (int32_t q = 0; q < Q; ++q) { + for (int32_t k = 0; k < K; ++k) { + auto accumulator = ElementAcc(0); + for (int32_t r = 0; r < R; ++r) { + for (int32_t s = 0; s < S; ++s) { + for (int32_t c = 0; c < C; ++c) { + int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); + int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_); + if (detail::is_activation_in_bounds(tensor_a_, n, h, w, c, g)) { + auto a = tensor_a_(c, w, h, n, g); + auto b = tensor_b_(c, s, r, k, g); + accumulator += ElementAcc(a * b); + } + } + } + } + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ? + epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? + epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta; + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, n, g)); + } + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[k]); + } + output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, n, g)); + } + tensor_d_(k, q, p, n, g) = output_converter(output); + } + } + } + } + } + + } + + // Specialization for 3D fprop kernel + void fprop_reference(cute::Int<3> spatial_dims) { + int32_t G = size<5>(tensor_d_); + int32_t N = size<4>(tensor_d_); + int32_t Z = size<3>(tensor_d_); + int32_t P = size<2>(tensor_d_); + int32_t Q = size<1>(tensor_d_); + int32_t K = size<0>(tensor_d_); + int32_t T = size<3>(tensor_b_); + int32_t R = size<2>(tensor_b_); + int32_t S = size<1>(tensor_b_); + int32_t C = size<0>(tensor_b_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int32_t g = 0; g < G; ++g) { + for (int32_t n = 0; n < N; ++n) { + for (int32_t z = 0; z < Z; ++z) { + for (int32_t p = 0; p < P; ++p) { + for (int32_t q = 0; q < Q; ++q) { + for (int32_t k = 0; k < K; ++k) { + auto accumulator = ElementAcc(0); + for (int32_t t = 0; t < T; ++t) { + for (int32_t r = 0; r < R; ++r) { + for (int32_t s = 0; s < S; ++s) { + for (int32_t c = 0; c < C; ++c) { + int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); + int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_); + int32_t d = z * cute::get<2>(tstride_) - cute::get<2>(padding_) + t * cute::get<2>(dilation_); + if (detail::is_activation_in_bounds(tensor_a_, n, d, h, w, c, g)) { + auto a = tensor_a_(c, w, h, d, n, g); + auto b = tensor_b_(c, s, r, t, k, g); + accumulator += ElementAcc(a * b); + } + } + } + } + } + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ? + epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? + epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta; + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, z, n, g)); + } + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[k]); + } + output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, z, n, g)); + } + tensor_d_(k, q, p, z, n, g) = output_converter(output); + } + } + } + } + } + } + + } + + // Specialization for 1D dgrad kernel + void dgrad_reference(cute::Int<1> spatial_dims) { + int32_t G = size<3>(tensor_d_); + int32_t N = size<2>(tensor_d_); + int32_t W = size<1>(tensor_d_); + int32_t C = size<0>(tensor_d_); + int32_t K = size<2>(tensor_b_); + int32_t S = size<1>(tensor_b_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(2) +#endif + for (int32_t g = 0; g < G; ++g) { + for (int32_t n = 0; n < N; ++n) { + for (int32_t w = 0; w < W; ++w) { + for (int32_t c = 0; c < C; ++c) { + auto accumulator = ElementAcc(0); + for (int32_t k = 0; k < K; ++k) { + for (int32_t s = 0; s < S; ++s) { + int32_t q = w + cute::get<0>(padding_) - s * cute::get<0>(dilation_); + + if (q % cute::get<0>(tstride_) == 0) { + q /= cute::get<0>(tstride_); + } else { + continue; + } + + if (detail::is_activation_in_bounds(tensor_a_, n, q, k, g)) { + accumulator += ElementAcc(tensor_a_(k, q, n, g) * tensor_b_(c, s, k, g)); + } + } + } + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) + ? epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) + ? epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, w, n, g)); + } + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[c]); + } + output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, w, n, g)); + } + tensor_d_(c, w, n, g) = output_converter(output); + } + } + } + } + + } + + // Specialization for 2D dgrad kernel + void dgrad_reference(cute::Int<2> spatial_dims) { + int32_t G = size<4>(tensor_d_); + int32_t N = size<3>(tensor_d_); + int32_t H = size<2>(tensor_d_); + int32_t W = size<1>(tensor_d_); + int32_t C = size<0>(tensor_d_); + int32_t K = size<3>(tensor_b_); + int32_t R = size<2>(tensor_b_); + int32_t S = size<1>(tensor_b_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int32_t g = 0; g < G; ++g) { + for (int32_t n = 0; n < N; ++n) { + for (int32_t h = 0; h < H; ++h) { + for (int32_t w = 0; w < W; ++w) { + for (int32_t c = 0; c < C; ++c) { + auto accumulator = ElementAcc(0); + for (int32_t k = 0; k < K; ++k) { + for (int32_t r = 0; r < R; ++r) { + for (int32_t s = 0; s < S; ++s) { + int32_t q = w + cute::get<0>(padding_) - s * cute::get<0>(dilation_); + int32_t p = h + cute::get<1>(padding_) - r * cute::get<1>(dilation_); + + if (q % cute::get<0>(tstride_) == 0) { + q /= cute::get<0>(tstride_); + } else { + continue; + } + + if (p % cute::get<1>(tstride_) == 0) { + p /= cute::get<1>(tstride_); + } else { + continue; + } + + if (detail::is_activation_in_bounds(tensor_a_, n, p, q, k, g)) { + accumulator += ElementAcc(tensor_a_(k, q, p, n, g) * tensor_b_(c, s, r, k, g)); + } + } + } + } + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) + ? epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) + ? epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, n, g)); + } + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[c]); + } + output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, n, g)); + } + + tensor_d_(c, w, h, n, g) = output_converter(output); + } + } + } + } + } + + } + + // Specialization for 3D dgrad kernel + void dgrad_reference(cute::Int<3> spatial_dims) { + int32_t G = size<5>(tensor_d_); + int32_t N = size<4>(tensor_d_); + int32_t D = size<3>(tensor_d_); + int32_t H = size<2>(tensor_d_); + int32_t W = size<1>(tensor_d_); + int32_t C = size<0>(tensor_d_); + int32_t K = size<4>(tensor_b_); + int32_t T = size<3>(tensor_b_); + int32_t R = size<2>(tensor_b_); + int32_t S = size<1>(tensor_b_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int32_t g = 0; g < G; ++g) { + for (int32_t n = 0; n < N; ++n) { + for (int32_t d = 0; d < D; ++d) { + for (int32_t h = 0; h < H; ++h) { + for (int32_t w = 0; w < W; ++w) { + for (int32_t c = 0; c < C; ++c) { + auto accumulator = ElementAcc(0); + for (int32_t k = 0; k < K; ++k) { + for (int32_t t = 0; t < T; ++t) { + for (int32_t r = 0; r < R; ++r) { + for (int32_t s = 0; s < S; ++s) { + int32_t q = w + cute::get<0>(padding_) - s * cute::get<0>(dilation_); + int32_t p = h + cute::get<1>(padding_) - r * cute::get<1>(dilation_); + int32_t z = d + cute::get<2>(padding_) - t * cute::get<2>(dilation_); + + if (q % cute::get<0>(tstride_) == 0) { + q /= cute::get<0>(tstride_); + } else { + continue; + } + + if (p % cute::get<1>(tstride_) == 0) { + p /= cute::get<1>(tstride_); + } else { + continue; + } + + if (z % cute::get<2>(tstride_) == 0) { + z /= cute::get<2>(tstride_); + } else { + continue; + } + + if (detail::is_activation_in_bounds(tensor_a_, n, z, p, q, k, g)) { + accumulator += ElementAcc(tensor_a_(k, q, p, z, n, g) * tensor_b_(c, s, r, t, k, g)); + } + } + } + } + } + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) + ? epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) + ? epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, d, n, g)); + } + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[c]); + } + output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, d, n, g)); + } + tensor_d_(c, w, h, d, n, g) = output_converter(output); + } + } + } + } + } + } + + } + + // Specialization for 1D wgrad kernel + void wgrad_reference(cute::Int<1> spatial_dims) { + int32_t G = size<3>(tensor_d_); + int32_t N = + size<2>(tensor_a_); + int32_t Q = + size<1>(tensor_a_); + int32_t K = + size<0>(tensor_a_); + int32_t S = size<1>(tensor_d_); + int32_t C = size<0>(tensor_d_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(2) +#endif + for (int32_t g = 0; g < G; ++g) { + for (int32_t k = 0; k < K; ++k) { + for (int32_t s = 0; s < S; ++s) { + for (int32_t c = 0; c < C; ++c) { + auto accumulator = ElementAcc(0); + for (int32_t n = 0; n < N; ++n) { + for (int32_t q = 0; q < Q; ++q) { + int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); + bool is_in_bounds = + detail::is_activation_in_bounds(tensor_b_, n, w, c, g); + if (is_in_bounds) { + auto act = + tensor_b_(c, w, n, g); + auto xformed_act = + tensor_a_(k, q, n, g); + accumulator += ElementAcc(act * xformed_act); + } + } + } + + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ? + epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? + epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; + + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, s, k, g)); + } + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[c]); + } + output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, s, k, g)); + } + tensor_d_(c, s, k, g) = output_converter(output); + } + } + } + } + } + + // Specialization for 2D wgrad kernel + void wgrad_reference(cute::Int<2> spatial_dims) { + int32_t G = size<4>(tensor_d_); + int32_t N = + size<3>(tensor_a_); + int32_t P = + size<2>(tensor_a_); + int32_t Q = + size<1>(tensor_a_); + int32_t K = + size<0>(tensor_a_); + int32_t R = size<2>(tensor_d_); + int32_t S = size<1>(tensor_d_); + int32_t C = size<0>(tensor_d_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int32_t g = 0; g < G; ++g) { + for (int32_t k = 0; k < K; ++k) { + for (int32_t r = 0; r < R; ++r) { + for (int32_t s = 0; s < S; ++s) { + for (int32_t c = 0; c < C; ++c) { + auto accumulator = ElementAcc(0); + for (int32_t n = 0; n < N; ++n) { + for (int32_t p = 0; p < P; ++p) { + for (int32_t q = 0; q < Q; ++q) { + int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); + int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_); + bool is_in_bounds = + detail::is_activation_in_bounds(tensor_b_, n, h, w, c, g); + if (is_in_bounds) { + auto act = + tensor_b_(c, w, h, n, g); + auto xformed_act = + tensor_a_(k, q, p, n, g); + accumulator += ElementAcc(act * xformed_act); + } + } + } + } + + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ? + epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? + epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; + + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, k, g)); + } + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[c]); + } + output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, k, g)); + } + tensor_d_(c, s, r, k, g) = output_converter(output); + } + } + } + } + } + } + + // Specialization for 3D wgrad kernel + void wgrad_reference(cute::Int<3> spatial_dims) { + int32_t G = size<5>(tensor_d_); + int32_t N = + size<4>(tensor_a_); + int32_t Z = + size<3>(tensor_a_); + int32_t P = + size<2>(tensor_a_); + int32_t Q = + size<1>(tensor_a_); + int32_t K = + size<0>(tensor_a_); + int32_t T = size<3>(tensor_d_); + int32_t R = size<2>(tensor_d_); + int32_t S = size<1>(tensor_d_); + int32_t C = size<0>(tensor_d_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int32_t g = 0 ; g < G; ++g) { + for (int32_t k = 0; k < K; ++k) { + for (int32_t t = 0; t < T; ++t) { + for (int32_t r = 0; r < R; ++r) { + for (int32_t s = 0; s < S; ++s) { + for (int32_t c = 0; c < C; ++c) { + auto accumulator = ElementAcc(0); + for (int32_t n = 0; n < N; ++n) { + for (int32_t z = 0; z < Z; ++z) { + for (int32_t p = 0; p < P; ++p) { + for (int32_t q = 0; q < Q; ++q) { + int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); + int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_); + int32_t d = z * cute::get<2>(tstride_) - cute::get<2>(padding_) + t * cute::get<2>(dilation_); + bool is_in_bounds = + detail::is_activation_in_bounds(tensor_b_, n, d, h, w, c, g); + if (is_in_bounds) { + auto act = + tensor_b_(c, w, h, d, n, g); + auto xformed_act = + tensor_a_(k, q, p, z, n, g); + accumulator += ElementAcc(act * xformed_act); + } + } + } + } + } + + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ? + epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? + epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; + + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, t, k, g)); + } + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[c]); + } + output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, t, k, g)); + } + tensor_d_(c, s, r, t, k, g) = output_converter(output); + } + } + } + } + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // cutlass::reference::host + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/convolution.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/convolution.h new file mode 100644 index 0000000000000000000000000000000000000000..73298e5794f0f2658ef18fb3f46466c400fc831e --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/convolution.h @@ -0,0 +1,802 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Reference implementation for convolution in host-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/functional.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Forward propagation +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// y = conv2d(x, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv2dFprop( + conv::Conv2dProblemSize problem_size, + TensorRef tensor_x, + TensorRef tensor_w, + TensorRef tensor_y_in, + TensorRef tensor_y_out, + ElementCompute alpha, + ElementCompute beta) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + // Apply MMA and accumulate ElementAccumulator + for (int n = 0; n < problem_size.N; ++n) { + for (int p = 0; p < problem_size.P; ++p) { + for (int q = 0; q < problem_size.Q; ++q) { + for (int k = 0; k < problem_size.K; ++k) { + + int group_idx = k / (problem_size.K / problem_size.groups); + int channels_per_group = problem_size.C / problem_size.groups; + + ElementAccumulator acc = ElementAccumulator(); + + for (int r = 0; r < problem_size.R; ++r) { + for (int s = 0; s < problem_size.S; ++s) { + for (int c = 0; c < channels_per_group; ++c) { + + int filter_r = r; + int filter_s = s; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_r = problem_size.R - 1 - r; + filter_s = problem_size.S - 1 - s; + } + + int h = p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; + int w = q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; + + if (h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W) { + + ElementA a = tensor_x.at({n, h, w, c + group_idx * channels_per_group}); + ElementB b = tensor_w.at({k, r, s, c}); + + acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); + + } + } + } + } + + // Apply Epilogue, compute ElementCompute, convert and store ElementC + ElementC c_ref = ElementC(); + + if (beta != ElementCompute()) { + c_ref = tensor_y_in.at(cutlass::make_Coord(n, p, q, k)); + } + + tensor_y_out.at(cutlass::make_Coord(n, p, q, k)) = + convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); + } + } + } + } +} + +/// Depthwise-separable convolution +template , + typename InnerProductOp = multiply_add> +void Depsep_Fprop(cutlass::TensorView tensor_A, + cutlass::TensorView tensor_B, + cutlass::TensorView tensor_C, + cutlass::TensorView tensor_D, + ElementCompute alpha, + ElementCompute beta, + cutlass::Tensor4DCoord padding = cutlass::Tensor4DCoord(), + cutlass::Coord<2> conv_stride = cutlass::Coord<2>(), + cutlass::Coord<2> dilation = cutlass::Coord<2>(), + cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + // Apply MMA and accumulate ElementAccumulator + for (int n = 0; n < tensor_C.extent().n(); ++n) { + for (int p = 0; p < tensor_C.extent().h(); ++p) { + for (int q = 0; q < tensor_C.extent().w(); ++q) { + for (int g = 0; g < tensor_C.extent().c(); ++g) { + ElementAccumulator acc = ElementAccumulator(); + for (int r = 0; r < tensor_B.extent().h(); ++r) { + for (int s = 0; s < tensor_B.extent().w(); ++s) { + + // input activation H and W + int h = p * conv_stride[0] - padding[0] + r * dilation[0]; + int w = q * conv_stride[1] - padding[2] + s * dilation[1]; + + if (h < tensor_A.extent().h() && h >= 0 && w < tensor_A.extent().w() && w >= 0) { + ElementA a = tensor_A.at(cutlass::make_Coord(n, h, w, g)); + + ElementB b = (mode == cutlass::conv::Mode::kCrossCorrelation) + ? tensor_B.at(cutlass::make_Coord(g, r, s, 0)) + : tensor_B.at(cutlass::make_Coord( + g, tensor_B.extent().h() - r - 1, tensor_B.extent().w() - s - 1, 0)); + + acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); + } + } + } + + // Apply Epilogue, compute ElementCompute, convert and store ElementC + ElementC c_ref = tensor_C.at(cutlass::make_Coord(n, p, q, g)); + tensor_D.at(cutlass::make_Coord(n, p, q, g)) = + convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Dgrad / Deconv +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// dx = dgrad(dy, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv2dDgrad( + cutlass::conv::Conv2dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_w, + TensorRef tensor_dx_in, + TensorRef tensor_dx_out, + ElementCompute alpha, + ElementCompute beta, + bool is_deconv = false) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + // Apply MMA and accumulate ElementAccumulator + for (int n = 0; n < problem_size.N; ++n) { + for (int h = 0; h < problem_size.H; ++h) { + for (int w = 0; w < problem_size.W; ++w) { + for (int c = 0; c < problem_size.C; ++c) { + + ElementAccumulator acc = ElementAccumulator(); + + for (int r = 0; r < problem_size.R; ++r) { + for (int s = 0; s < problem_size.S; ++s) { + for (int k = 0; k < problem_size.K; ++k) { + + int filter_r = r; + int filter_s = s; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_r = problem_size.R - 1 - r; + filter_s = problem_size.S - 1 - s; + } + + int p = h + problem_size.pad_h - filter_r * problem_size.dilation_h; + int q = w + problem_size.pad_w - filter_s * problem_size.dilation_w; + + if (p >= 0 && (p % problem_size.stride_h) == 0 && + q >= 0 && (q % problem_size.stride_w) == 0) { + + p = p / problem_size.stride_h; + q = q / problem_size.stride_w; +#if 0 + std::cout << "row:" + << n * problem_size.H * problem_size.W + + h * problem_size.W + + w << " " + << "n, p, q: (" + << n << ", " + << p << ", " + << q << ") * " + << "r, s: (" + << r << ", " + << s << ") [" + << ((p < problem_size.P && q < problem_size.Q) ? "true":"false") << "]" + << std::endl; +#endif + if (p < problem_size.P && q < problem_size.Q) { + + ElementA a = tensor_dy.at(cutlass::make_Coord(n, p, q, k)); + ElementB b = is_deconv ? tensor_w.at(cutlass::make_Coord(c, r, s, k)) + : tensor_w.at(cutlass::make_Coord(k, r, s, c)); + + acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); + } + } + + } // for (K) + } // for (S) + } // for (R) + + // Apply Epilogue, compute ElementCompute, convert and store ElementC + ElementC c_ref = ElementC(); + + if (beta != ElementCompute()) { + c_ref = tensor_dx_in.at(cutlass::make_Coord(n, h, w, c)); + } + + tensor_dx_out.at(cutlass::make_Coord(n, h, w, c)) = + convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); + + } // for (C) + } // for (W) + } // for (H) + } // for (N) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Wgrad +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// dw = wgrad(dy, x) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv2dWgrad( + cutlass::conv::Conv2dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_x, + TensorRef tensor_dw_in, + TensorRef tensor_dw_out, + ElementCompute alpha, + ElementCompute beta) { + + InnerProductOp inner_product_op; + ConvertOp convert_op; + + // Apply MMA and accumulate ElementAccumulator + for (int k = 0; k < problem_size.K; ++k) { + for (int r = 0; r < problem_size.R; ++r) { + for (int s = 0; s < problem_size.S; ++s) { + for (int c = 0; c < problem_size.C; ++c) { + + ElementAccumulator acc = ElementAccumulator(); + + for (int n = 0; n < problem_size.N; ++n) { + for (int p = 0; p < problem_size.P; ++p) { + for (int q = 0; q < problem_size.Q; ++q) { + + cutlass::Tensor4DCoord b_coord; + + int filter_r = r; + int filter_s = s; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_r = problem_size.R - 1 - r; + filter_s = problem_size.S - 1 - s; + } + + b_coord = make_Coord( + n, + p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h, + q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w, + c); + + if (b_coord.h() < problem_size.H && b_coord.h() >= 0 && + b_coord.w() < problem_size.W && b_coord.w() >= 0) { + + ElementAccumulator a = ElementAccumulator(tensor_dy.at(cutlass::make_Coord(n, p, q, k))); + ElementAccumulator b = ElementAccumulator(tensor_x.at(b_coord)); + acc = inner_product_op(a, b, acc); + } + } + } + } + + // Apply Epilogue, compute ElementCompute, convert and store ElementC + ElementC c_ref = ElementC(); + + if (beta != ElementCompute()) { + c_ref = tensor_dw_in.at(cutlass::make_Coord(k, r, s, c)); + } + + tensor_dw_out.at(cutlass::make_Coord(k, r, s, c)) = + convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); + + } // for (C) + } // for (S) + } // for (R) + } // for (K) +} + +/// Generic 2D convolution targeting Conv2dFprop, Conv2dDgrad, and Conv2dWgrad. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv2d( + conv::Operator convolutional_operator, + conv::Conv2dProblemSize problem_size, + TensorRef tensor_A, + TensorRef tensor_B, + TensorRef tensor_C, + TensorRef tensor_D, + ElementCompute alpha, + ElementCompute beta) { + + switch (convolutional_operator) { + case conv::Operator::kFprop: + Conv2dFprop< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ElementD, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); + break; + + case conv::Operator::kDeconv: + case conv::Operator::kDgrad: + Conv2dDgrad< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ElementD, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, (convolutional_operator == conv::Operator::kDeconv)); + break; + + case conv::Operator::kWgrad: + Conv2dWgrad< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ElementD, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); + break; + + default: + break; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// 3D convolution +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// y = conv3d(x, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv3dFprop( + conv::Conv3dProblemSize problem_size, + TensorRef tensor_x, + TensorRef tensor_w, + TensorRef tensor_y_in, + TensorRef tensor_y_out, + ElementCompute alpha, + ElementCompute beta) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + // Apply MMA and accumulate ElementAccumulator + for (int n = 0; n < problem_size.N; ++n) { + for (int z = 0; z < problem_size.Z; ++z) { + for (int p = 0; p < problem_size.P; ++p) { + for (int q = 0; q < problem_size.Q; ++q) { + for (int k = 0; k < problem_size.K; ++k) { + + ElementAccumulator acc = ElementAccumulator(); + + for (int t = 0; t < problem_size.T; ++t) { + for (int r = 0; r < problem_size.R; ++r) { + for (int s = 0; s < problem_size.S; ++s) { + for (int c = 0; c < problem_size.C; ++c) { + + int filter_t = t; + int filter_r = r; + int filter_s = s; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_t = problem_size.T - 1 - t; + filter_r = problem_size.R - 1 - r; + filter_s = problem_size.S - 1 - s; + } + + int d = z * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d; + int h = p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; + int w = q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; + + if (d >= 0 && d < problem_size.D && + h >=0 && h < problem_size.H && + w >= 0 && w < problem_size.W) { + + ElementA a = tensor_x.at({n, d, h, w, c}); + ElementB b = tensor_w.at({k, t, r, s, c}); + + acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); + } + } + } + } + } + + // Apply Epilogue, compute ElementCompute, convert and store ElementC + ElementC c_ref = ElementC(); + + if (beta != ElementCompute()) { + c_ref = tensor_y_in.at(cutlass::make_Coord(n, z, p, q, k)); + } + + tensor_y_out.at(cutlass::make_Coord(n, z, p, q, k)) = + convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); + } + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Dgrad / Deconv +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// dx = dgrad(dy, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv3dDgrad( + cutlass::conv::Conv3dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_w, + TensorRef tensor_dx_in, + TensorRef tensor_dx_out, + ElementCompute alpha, + ElementCompute beta, + bool is_deconv = false) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + // Apply MMA and accumulate ElementAccumulator + for (int n = 0; n < problem_size.N; ++n) { + for (int d = 0; d < problem_size.D; ++d) { + for (int h = 0; h < problem_size.H; ++h) { + for (int w = 0; w < problem_size.W; ++w) { + for (int c = 0; c < problem_size.C; ++c) { + + ElementAccumulator acc = ElementAccumulator(); + + for (int t = 0; t < problem_size.T; ++t) { + for (int r = 0; r < problem_size.R; ++r) { + for (int s = 0; s < problem_size.S; ++s) { + for (int k = 0; k < problem_size.K; ++k) { + + int filter_t = t; + int filter_r = r; + int filter_s = s; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_t = problem_size.T - 1 - t; + filter_r = problem_size.R - 1 - r; + filter_s = problem_size.S - 1 - s; + } + + int z = d + problem_size.pad_d - filter_t * problem_size.dilation_d; + int p = h + problem_size.pad_h - filter_r * problem_size.dilation_h; + int q = w + problem_size.pad_w - filter_s * problem_size.dilation_w; + + if (z >= 0 && (z % problem_size.stride_d) == 0 && + p >= 0 && (p % problem_size.stride_h) == 0 && + q >= 0 && (q % problem_size.stride_w) == 0) { + + z = z / problem_size.stride_d; + p = p / problem_size.stride_h; + q = q / problem_size.stride_w; + + if (z < problem_size.Z && p < problem_size.P && q < problem_size.Q) { + + ElementA a = tensor_dy.at(cutlass::make_Coord(n, z, p, q, k)); + ElementB b = is_deconv ? tensor_w.at(cutlass::make_Coord(c, t, r, s, k)) + : tensor_w.at(cutlass::make_Coord(k, t, r, s, c)); + acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); + } + } + + } // for (K) + } // for (S) + } // for (R) + } // for (T) + + // Apply Epilogue, compute ElementCompute, convert and store ElementC + ElementC c_ref = ElementC(); + + if (beta != ElementCompute()) { + c_ref = tensor_dx_in.at(cutlass::make_Coord(n, d, h, w, c)); + } + + tensor_dx_out.at(cutlass::make_Coord(n, d, h, w, c)) = + convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); + + } // for (C) + } // for (W) + } // for (H) + } // for (D) + } // for (N) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Wgrad +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// dw = wgrad(dy, x) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv3dWgrad( + cutlass::conv::Conv3dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_x, + TensorRef tensor_dw_in, + TensorRef tensor_dw_out, + ElementCompute alpha, + ElementCompute beta) { + + InnerProductOp inner_product_op; + ConvertOp convert_op; + + // Apply MMA and accumulate ElementAccumulator + for (int k = 0; k < problem_size.K; ++k) { + for (int t = 0; t < problem_size.T; ++t) { + for (int r = 0; r < problem_size.R; ++r) { + for (int s = 0; s < problem_size.S; ++s) { + for (int c = 0; c < problem_size.C; ++c) { + + ElementAccumulator acc = ElementAccumulator(); + + for (int n = 0; n < problem_size.N; ++n) { + for (int z = 0; z < problem_size.Z; ++z) { + for (int p = 0; p < problem_size.P; ++p) { + for (int q = 0; q < problem_size.Q; ++q) { + + int filter_t = t; + int filter_r = r; + int filter_s = s; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_t = problem_size.T - 1 - t; + filter_r = problem_size.R - 1 - r; + filter_s = problem_size.S - 1 - s; + } + + Tensor5DCoord b_coord = make_Coord( + n, + z * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d, + p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h, + q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w, + c); + + if (b_coord.d() < problem_size.D && b_coord.d() >= 0 && + b_coord.h() < problem_size.H && b_coord.h() >= 0 && + b_coord.w() < problem_size.W && b_coord.w() >= 0) { + + ElementAccumulator a = ElementAccumulator(tensor_dy.at(cutlass::make_Coord(n, z, p, q, k))); + ElementAccumulator b = ElementAccumulator(tensor_x.at(b_coord)); + + acc = inner_product_op(a, b, acc); + } + } + } + } + } + + // Apply Epilogue, compute ElementCompute, convert and store ElementC + ElementC c_ref = ElementC(); + + if (beta != ElementCompute()) { + c_ref = tensor_dw_in.at(cutlass::make_Coord(k, t, r, s, c)); + } + + tensor_dw_out.at(cutlass::make_Coord(k, t, r, s, c)) = + convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); + + } // for (C) + } // for (S) + } // for (R) + } // for (T) + } // for (K) +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Generic 3D convolution targeting Conv2dFprop, Conv2dDgrad, and Conv2dWgrad. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv3d( + conv::Operator convolutional_operator, + conv::Conv3dProblemSize problem_size, + TensorRef tensor_A, + TensorRef tensor_B, + TensorRef tensor_C, + TensorRef tensor_D, + ElementCompute alpha, + ElementCompute beta) { + + switch (convolutional_operator) { + case conv::Operator::kFprop: + Conv3dFprop< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); + break; + + case conv::Operator::kDeconv: + case conv::Operator::kDgrad: + Conv3dDgrad< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, (convolutional_operator == conv::Operator::kDeconv)); + break; + + case conv::Operator::kWgrad: + Conv3dWgrad< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); + break; + + default: + break; + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/error_metrics.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/error_metrics.h new file mode 100644 index 0000000000000000000000000000000000000000..12ead83354b785096e8029b49f1ac353d5ce5f82 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/error_metrics.h @@ -0,0 +1,66 @@ + +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/util/reference/host/tensor_reduce.h" +#include "cutlass/core_io.h" + +namespace cutlass { +namespace reference { +namespace host { + +/// Helper to compute the relative error metric for tensor A_computed w.r.t. to tensor A_reference +template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorRelativeErrorMetric( + TensorView view_A_computed, + TensorView view_B_reference, + ComputeType identity = ComputeType() +) { + + return cutlass::reference::host::TensorNormDiff(view_A_computed, view_B_reference, identity) / + cutlass::reference::host::TensorNorm(view_B_reference, identity); +} + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..2afee7b36d9822cc196f0f167f9dbec4c295d1a6 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm.h @@ -0,0 +1,531 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for GEMM in host-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/arch/mma.h" +#include "cutlass/util/host_tensor.h" + +namespace cutlass { +namespace reference { +namespace host { + +template +struct CastIfScalar { + static Out cast(In in) { + return Out(in); + } +}; + +template +struct CastIfScalar, In> { + typedef cutlass::complex Out; + static Out cast(In in) { + return Out(static_cast(in)); + } +}; + +template +struct CastIfScalar, cutlass::complex> { + typedef cutlass::complex Out; + typedef cutlass::complex In; + static Out cast(In in) { + return Out(in); + } +}; + +template +Out cast_if_scalar(In in) { + return CastIfScalar::cast(in); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_gemm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + + // Note: batch is ignored. + int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + ElementA a = tensor_a.at(MatrixCoord(row, k_block)); + ElementB b = tensor_b.at(MatrixCoord(k_block, col)); + + ComputeType compute_a(cast_if_scalar(a)); + ComputeType compute_b(cast_if_scalar(b)); + + accum[i][j] = inner_product_op(compute_a, compute_b, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * ScalarType(tensor_c.at(coord))); + } + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_gemm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum) { + compute_gemm( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c, + initial_accum); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = cutlass::arch::OpMultiplyAdd +> +struct Gemm; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add-saturate +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm, + NumericConverterClamp>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm, + NumericConverterClamp>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for XOR-popc +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +/// Partial specialization for AND-popc +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Batched GEMM +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a batch of GEMMs over a set of matrices of common dimension. +// +// TensorRefCollection* is a type satisfying the TensorRefCollection concept. +// +template < + typename TensorRefCollectionA, + typename TensorRefCollectionB, + typename TensorRefCollectionC, + typename ScalarType, + typename AccumulatorType +> +void BatchedGemm( + gemm::GemmCoord problem_size, + int batch_count, + ScalarType alpha, + TensorRefCollectionA const& tensor_a, + TensorRefCollectionB const& tensor_b, + ScalarType beta, + TensorRefCollectionC &tensor_c, + AccumulatorType initial_accum) { + + typename TensorRefCollectionA::ConstIterator tensor_a_it = tensor_a.begin(); + typename TensorRefCollectionB::ConstIterator tensor_b_it = tensor_b.begin(); + typename TensorRefCollectionC::ConstIterator tensor_c_it = tensor_c.begin(); + + for (int batch = 0; + batch < batch_count; + ++batch, ++tensor_a_it, ++tensor_b_it, ++tensor_c_it) { + + Gemm + gemm; + + gemm(problem_size, alpha, *tensor_a_it, *tensor_b_it, beta, *tensor_c_it, + initial_accum); + } +} + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +// +// TensorRefCollection* is a type satisfying the TensorRefCollection concept. +// +template < + typename TensorRefCollectionA, + typename TensorRefCollectionB, + typename TensorRefCollectionC, + typename ScalarType, + typename AccumulatorType +> +void BatchedGemm( + gemm::GemmCoord problem_size, + int batch_count, + ScalarType alpha, + TensorRefCollectionA const& tensor_a, + TensorRefCollectionB const& tensor_b, + ScalarType beta, + TensorRefCollectionC &tensor_c) { + + BatchedGemm(problem_size, batch_count, alpha, tensor_a, tensor_b, beta, tensor_c, ScalarType(0)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_complex.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..221a6040854a74ce465af7b021bbbfae9b96a90b --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_complex.h @@ -0,0 +1,210 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued GEMM in host-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/matrix_coord.h" + +#include "cutlass/tensor_view.h" + +#include "cutlass/gemm/gemm.h" + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void GemmComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + // Note: batch is ignored. + int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) { + + // Compute matrix product using blocks + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + ElementA a = tensor_a.at(MatrixCoord(row, k_block)); + ElementB b = tensor_b.at(MatrixCoord(k_block, col)); + + ComputeType a_ik = ComputeType(a); + ComputeType b_kj = ComputeType(b); + + if (transform_a == ComplexTransform::kConjugate) { + a_ik = conj(a_ik); + } + + if (transform_b == ComplexTransform::kConjugate) { + b_kj = conj(b_kj); + } + + accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * ScalarType(tensor_c.at(coord))); + } + } + } + + } // for (col_block) + } // for (row_block) + + tensor_a.add_pointer_offset(batch_stride_A); + tensor_b.add_pointer_offset(batch_stride_B); + tensor_c.add_pointer_offset(batch_stride_C); + tensor_d.add_pointer_offset(batch_stride_D); + + } // for (batch_idx) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ElementD = ElementC +> +void GemmComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d) { + + GemmComplex(problem_size, alpha, tensor_a, transform_a, tensor_b, transform_b, beta, tensor_c, tensor_d, ScalarType(0)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..507c37d9eb5a8c998f1075d547e8430b2edc5685 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h @@ -0,0 +1,228 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued GEMM in host-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_ref_planar_complex.h" + +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add> +> +void GemmPlanarComplex( + gemm::GemmCoord problem_size, + complex alpha, + TensorRefPlanarComplex tensor_a, + ComplexTransform transform_a, + TensorRefPlanarComplex tensor_b, + ComplexTransform transform_b, + complex beta, + TensorRefPlanarComplex tensor_c, + TensorRefPlanarComplex tensor_d, + complex initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + using ComplexA = typename TensorRefPlanarComplex::ComplexElement; + using ComplexB = typename TensorRefPlanarComplex::ComplexElement; + using ComplexC = typename TensorRefPlanarComplex::ComplexElement; + + // Note: batch is ignored. + int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + complex accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + + ComplexA a_ik = tensor_a.at(MatrixCoord(row, k_block)); + ComplexB b_kj = tensor_b.at(MatrixCoord(k_block, col)); + + complex a = complex{ + ComputeType(a_ik.real()), + ComputeType(a_ik.imag()) + }; + + complex b = complex{ + ComputeType(b_kj.real()), + ComputeType(b_kj.imag()) + }; + + if (transform_a == ComplexTransform::kConjugate) { + a = conj(a); + } + + if (transform_b == ComplexTransform::kConjugate) { + b = conj(b); + } + + accum[i][j] = inner_product_op(a, b, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + + complex acc{ + ScalarType(accum[i][j].real()), + ScalarType(accum[i][j].imag()) + }; + + ComplexC d_ij = tensor_c.at(coord); + + complex src{ + ScalarType(d_ij.real()), + ScalarType(d_ij.imag()) + }; + + complex result = alpha * acc + beta * src; + + d_ij.real() = convert_op(result.real()); + d_ij.imag() = convert_op(result.imag()); + + tensor_d.at(coord) = d_ij; + } + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType +> +void GemmPlanarComplex( + gemm::GemmCoord problem_size, + complex alpha, + TensorRefPlanarComplex tensor_a, + ComplexTransform transform_a, + TensorRefPlanarComplex tensor_b, + ComplexTransform transform_b, + complex beta, + TensorRefPlanarComplex tensor_c, + TensorRefPlanarComplex tensor_d) { + + GemmPlanarComplex( + problem_size, + alpha, + tensor_a, transform_a, + tensor_b, transform_b, + beta, + tensor_c, + tensor_d, + complex()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gett.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gett.hpp new file mode 100644 index 0000000000000000000000000000000000000000..dd54dc6e378d0d0f0549ec922da8357841ac558f --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gett.hpp @@ -0,0 +1,916 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for GETT in host-side code. +*/ + +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// +#include "cutlass/gemm/gemm.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/relatively_equal.h" + +#include "cute/tensor.hpp" +#include "cute/pointer.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::reference::host { + +template +struct ElementTraits { + using type = T; +}; + +template +struct ElementTraits().get()), void> > > { + using type = decltype(std::declval().get()); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/////////////////////////////////////////////////////////// +// +// Gett Mainloop Parameters +// +/////////////////////////////////////////////////////////// + +template< + class ElementAccumulator_, + class TensorA_, // (M, K, L) + class TensorB_ // (N, K, L) + + , class TensorSfA_ = TensorA_, + class TensorSfB_ = TensorB_ + +> +struct GettMainloopParams { + using ElementAccumulator = ElementAccumulator_; + using TensorA = TensorA_; + using TensorB = TensorB_; + using EngineA = typename TensorA::engine_type; + using LayoutA = typename TensorA::layout_type; + using EngineB = typename TensorB::engine_type; + using LayoutB = typename TensorB::layout_type; + + TensorA A{}; + TensorB B{}; + + ComplexTransform transform_A = ComplexTransform::kNone; + ComplexTransform transform_B = ComplexTransform::kNone; + + + using TensorSfA = TensorSfA_; + using TensorSfB = TensorSfB_; + using EngineSfA = typename TensorSfA::engine_type; + using LayoutSfA = typename TensorSfA::layout_type; + using EngineSfB = typename TensorSfB::engine_type; + using LayoutSfB = typename TensorSfB::layout_type; + TensorSfA_ SfA{}; + TensorSfB_ SfB{}; + + + GettMainloopParams() {} + + GettMainloopParams(TensorA tensor_A, TensorB tensor_B) + : A(tensor_A), B(tensor_B) {} + + + GettMainloopParams(TensorA tensor_A, TensorSfA tensor_SfA, TensorB tensor_B, TensorSfB tensor_SfB) + : A(tensor_A), SfA(tensor_SfA), + B(tensor_B), SfB(tensor_SfB) {} + + +}; + + + +//////////////////////////////////////////////////////////////////////// +// +// Gett Mainloop Parameter Specialization for Block Scaled GEMM kernels +// +//////////////////////////////////////////////////////////////////////// + +template< + class ElementAccumulator_, + class TensorA_, // (M, K, L) + class TensorSfA_, // (M, K, L) + class TensorB_, // (N, K, L) + class TensorSfB_ // (N, K, L) +> +struct GettBlockScalingMainloopParams : public GettMainloopParams { + using Base = GettMainloopParams; + using ElementAccumulator = typename Base::ElementAccumulator; + using TensorA = typename Base::TensorA; + using TensorB = typename Base::TensorB; + using EngineA = typename Base::EngineA; + using LayoutA = typename Base::LayoutA; + using EngineB = typename Base::EngineB; + using LayoutB = typename Base::LayoutB; + ComplexTransform transform_A = Base::transform_A; + ComplexTransform transform_B = Base::transform_B; + + using TensorSfA = typename Base::TensorSfA; + using TensorSfB = typename Base::TensorSfB; + using EngineSfA = typename Base::EngineSfA; + using LayoutSfA = typename Base::LayoutSfA; + using EngineSfB = typename Base::EngineSfB; + using LayoutSfB = typename Base::LayoutSfB; + + GettBlockScalingMainloopParams() {} + + GettBlockScalingMainloopParams(TensorA tensor_A, TensorSfA tensor_SfA, TensorB tensor_B, TensorSfB tensor_SfB) + : Base(tensor_A, tensor_SfA, tensor_B, tensor_SfB) {} + + +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +enum class SfStrategy { + None = 0, + SfDGen = 1 +}; + + +/////////////////////////////////////////////////////////// +// +// Gett Epilogue Parameters +// +/////////////////////////////////////////////////////////// + +template< + class ElementScalar_, + class ElementScalingFactor_, + class ElementAccumulator_, + class ElementCompute_, + class TensorC_, // (M, N, L) + class TensorD_, // (M, N, L) + class VectorBias_ = decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // (M, 1) + class TensorAux_ = decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // (M, N, L) + class VectorAlpha_ = decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // (M, 1) + class VectorBeta_ = VectorAlpha_, // (M, 1) + class ActivationFunctor_ = cutlass::epilogue::thread::Identity, + class TensorSFD_ = TensorD_, + class SFD_VectorSize_ = cute::Int<0>, + class BiasBinaryOp_ = cutlass::plus, + bool PerColumnBias_ = false + , + SfStrategy SfGenStrategy_ = SfStrategy::None +> +struct GettEpilogueParams { + using ElementScalar = ElementScalar_; + using ElementScalingFactor = ElementScalingFactor_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + using TensorC = TensorC_; + using TensorD = TensorD_; + using TensorAux = TensorAux_; + using VectorBias = VectorBias_; + using VectorAlpha = VectorAlpha_; + using VectorBeta = VectorBeta_; + using TensorSFD = TensorSFD_; + using SFD_VectorSize = SFD_VectorSize_; + using ActivationFunctor = ActivationFunctor_; + using BiasBinaryOp = BiasBinaryOp_; + + using EngineC = typename TensorC::engine_type; + using LayoutC = typename TensorC::layout_type; + using EngineD = typename TensorD::engine_type; + using LayoutD = typename TensorD::layout_type; + using EngineSfD = typename TensorSFD::engine_type; + using LayoutSfD = typename TensorSFD::layout_type; + static constexpr bool PerColumnBias = PerColumnBias_; + static constexpr SfStrategy SfGenStrategy = SfGenStrategy_; + + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + + TensorC C{}; + TensorD D{}; + VectorBias Bias{}; + TensorAux Aux{}; + VectorAlpha Valpha{}; + VectorBeta Vbeta{}; + TensorSFD SfD{}; + ElementCompute st = ElementCompute(1); + + ElementAccumulator* abs_max_D = nullptr; + ElementAccumulator* abs_max_Aux = nullptr; + + ElementScalingFactor scale_a = ElementScalingFactor(1); + ElementScalingFactor scale_b = ElementScalingFactor(1); + ElementScalingFactor scale_c = ElementScalingFactor(1); + ElementScalingFactor scale_d = ElementScalingFactor(1); + ElementScalingFactor scale_aux = ElementScalingFactor(1); + + bool beta_per_channel_scaling = false; + GettEpilogueParams() {} + + GettEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D) + : alpha(alpha), beta(beta), C(tensor_C), D(tensor_D) {} + + + GettEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D, TensorSFD tensor_SfD, ElementCompute epilogue_st) + : alpha(alpha), beta(beta), C(tensor_C), D(tensor_D), SfD(tensor_SfD), st(epilogue_st) {} + + + GettEpilogueParams( + ElementScalar alpha, ElementScalar beta, + TensorC tensor_C, TensorD tensor_D, + VectorBias bias, TensorAux tensor_aux, + VectorAlpha vector_alpha, VectorBeta vector_beta) + : alpha(alpha), beta(beta), + C(tensor_C), D(tensor_D), + Bias(bias), Aux(tensor_aux), + Valpha(vector_alpha), Vbeta(vector_beta) {} +}; + + + +//////////////////////////////////////////////////////////////////////// +// +// Gett Epilogue Parameters Specialization for Block Scaled GEMM kernels +// +//////////////////////////////////////////////////////////////////////// + +template< + class ElementScalar_, + class ElementAccumulator_, + class ElementCompute_, + class TensorC_, + class TensorD_, + class TensorSfD_ = TensorD_, + class SFD_VectorSize_ = cute::Int<0>, + SfStrategy SfGenStrategy_ = SfStrategy::None +> +struct GettBlockScalingEpilogueParams : public GettEpilogueParams< + ElementScalar_, // ElementScalar + ElementScalar_, // ElementScalingFactor + ElementAccumulator_, // ElementAccumulator + ElementCompute_, // ElementCompute + TensorC_, // TensorC (M, N, L) + TensorD_, // TensorD (M, N, L) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // VectorBias (M, 1) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // TensorAux (M, N, L) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // VectorAlpha (M, 1) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // VectorBeta (M, 1) + cutlass::epilogue::thread::Identity, // + TensorSfD_, // TensorSfD + SFD_VectorSize_, // SFD_VectorSize + cutlass::plus, // class BiasBinaryOp_ = + false, //PerColumnBias_ + SfGenStrategy_ // SfGenStrategy + > { + using Base = GettEpilogueParams< + ElementScalar_, // ElementScalar + ElementScalar_, // ElementScalingFactor + ElementAccumulator_, // ElementAccumulator + ElementCompute_, // ElementCompute + TensorC_, // TensorC (M, N, L) + TensorD_, // TensorD (M, N, L) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // VectorBias (M, 1) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // TensorAux (M, N, L) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // VectorAlpha (M, 1) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // VectorBeta (M, 1) + cutlass::epilogue::thread::Identity, // + TensorSfD_, // TensorSfD + SFD_VectorSize_, // SFD_VectorSize + cutlass::plus, // BiasBinaryOp + false, // PerColumnBias + SfGenStrategy_ // SfGenStrategy + >; + using ElementScalar = typename Base::ElementScalar; + using ElementScalingFactor = typename Base::ElementScalingFactor; + using ElementAccumulator = typename Base::ElementAccumulator; + using ElementCompute = typename Base::ElementCompute; + using TensorC = typename Base::TensorC; + using TensorD = typename Base::TensorD; + using TensorAux = typename Base::TensorAux; + using VectorBias = typename Base::VectorBias; + using VectorAlpha = typename Base::VectorAlpha; + using VectorBeta = typename Base::VectorBeta; + using TensorSFD = typename Base::TensorSFD; + using SFD_VectorSize = typename Base::SFD_VectorSize; + using ActivationFunctor = typename Base::ActivationFunctor; + using BiasBinaryOp = typename Base::BiasBinaryOp; + + using EngineC = typename Base::EngineC; + using LayoutC = typename Base::LayoutC; + using EngineD = typename Base::EngineD; + using LayoutD = typename Base::LayoutD; + using EngineSfD = typename Base::EngineSfD; + using LayoutSfD = typename Base::LayoutSfD; + static constexpr bool PerColumnBias = Base::PerColumnBias; + static constexpr SfStrategy SfGenStrategy = Base::SfGenStrategy; + + GettBlockScalingEpilogueParams() {} + + GettBlockScalingEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D) + : Base(alpha, beta, tensor_C, tensor_D) {} + + GettBlockScalingEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D, TensorSFD tensor_SfD) + : Base(alpha, beta, tensor_C, tensor_D, tensor_SfD, ElementCompute{0}) {} + + GettBlockScalingEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D, TensorSFD tensor_SfD, ElementCompute epilogue_st) + : Base(alpha, beta, tensor_C, tensor_D, tensor_SfD, epilogue_st) {} +}; + + + + + +/////////////////////////////////////////////////////////// +// +// Generic Gett 3x Implementation +// +/////////////////////////////////////////////////////////// + + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +void compute_1d_scaling_factor_and_quantized_output( + EpilogueParams const& epilogue_params, + TensorD &tensor_D, + TensorSFD &tensor_SfD, + int64_t m, + int64_t n, + int64_t l, + ElementCompute (&acc)[kBlockM][kBlockN]) +{ + using ElementD = typename ElementTraits::type; + using ElementSfD = typename ElementTraits::type; + + int const M = cute::size<0>(tensor_D.layout()); + int const N = cute::size<1>(tensor_D.layout()); + int const L = cute::size<2>(tensor_D.layout()); + + auto mul = cutlass::multiplies{}; + auto div = divides{}; + // Get FP max + ElementCompute fp_max = ElementCompute(std::numeric_limits::max()); + float scale_down_factor = div(1.0f, fp_max); + // Get st' = st / FP max + ElementCompute st_scaled_down = mul(epilogue_params.st, scale_down_factor); + + absolute_value_op abs_op; + maximum_with_nan_propogation max_op; + + if constexpr (cute::is_constant<1, decltype(cute::stride<0,0,1>(tensor_SfD))>::value) { + // MN major output + int const NumVecPerBlock = ceil_div(kBlockM, kVectorSize); + // Col major output + for (int n_b = 0; n_b < kBlockN; ++n_b) { + for (int v_b = 0; v_b < NumVecPerBlock; ++v_b) { + int64_t col = n + n_b; + + /// Step1: get max across a vector + ElementCompute accum_max = ElementCompute(0); + for (int v = 0; v < kVectorSize; v++) { + int accum_row = v_b * kVectorSize + v; + int64_t output_row = accum_row + m; + if (output_row < M && col < N) { + accum_max = max_op(accum_max, abs_op(acc[accum_row][n_b])); + } + } + + /// Step2: Compute Scale + ElementCompute pvscale = mul(accum_max, st_scaled_down); + ElementSfD qpvscale = static_cast(pvscale); + // Store the Scaling Factors + int64_t sf_row = m + kVectorSize * v_b; + if (sf_row < M && col < N) { + tensor_SfD(sf_row, col, l) = qpvscale; + } + + /// Step3: Compute quantized output values + ElementCompute qpvscale_up = NumericConverter{}(qpvscale); + // Get float reciprocal + ElementCompute qpvscale_rcp = div(1.0f, qpvscale_up); + ElementCompute acc_scale = mul(epilogue_params.st, qpvscale_rcp); + // Map INF to fp32::max + acc_scale = cutlass::minimum_with_nan_propagation{}(acc_scale, cutlass::platform::numeric_limits::max()); + // Store the intermediate_accum + for (int v = 0; v < kVectorSize; v++) { + int accum_row = v_b * kVectorSize + v; + int64_t output_row = accum_row + m; + if (output_row < M && col < N) { + acc[accum_row][n_b] = mul(acc[accum_row][n_b], acc_scale); + } + } + } + } + } + else { + int const NumVecPerBlock = ceil_div(kBlockN, kVectorSize); + // row major output + for (int m_b = 0; m_b < kBlockM; ++m_b) { + for (int v_b = 0; v_b < NumVecPerBlock; ++v_b) { + int64_t row = m + m_b; + + /// Step1: get max across a vector + ElementCompute accum_max = ElementCompute(0); + for (int v = 0; v < kVectorSize; v++) { + int accum_col = v_b * kVectorSize + v; + int64_t output_col = accum_col + n; + if (row < M && output_col < N) { + accum_max = max_op(accum_max, abs_op(acc[m_b][accum_col])); + } + } + + /// Step2: Compute Scale + ElementCompute pvscale = mul(accum_max, st_scaled_down); + ElementSfD qpvscale = static_cast(pvscale); + // Store the Scaling Factors + int64_t sf_col = n + kVectorSize * v_b; + + if (row < M && sf_col < N) { + tensor_SfD(row, sf_col, l) = qpvscale; + } + + /// Step3: Compute quantized output values + ElementCompute qpvscale_up = NumericConverter{}(qpvscale); + // Get float reciprocal + ElementCompute qpvscale_rcp = div(1.0f, qpvscale_up); + ElementCompute acc_scale = mul(epilogue_params.st, qpvscale_rcp); + // Map INF to fp32::max + acc_scale = cutlass::minimum_with_nan_propagation{}(acc_scale, cutlass::platform::numeric_limits::max()); + // Store the intermediate_accum + for (int v = 0; v < kVectorSize; v++) { + int accum_col = v_b * kVectorSize + v; + int64_t output_col = accum_col + n; + if (row < M && output_col < N) { + acc[m_b][accum_col] = mul(acc[m_b][accum_col], acc_scale); + } + } + } + } + } +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GETT - General Tensor-Tensor contraction reference kernel +template < + class MainloopParams, + class EpilogueParams +> +void Gett( + MainloopParams const& mainloop_params, + EpilogueParams const& epilogue_params) +{ + + static int constexpr kBlockM = 64; + static int constexpr kBlockN = 64; + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) { + for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) { + for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) { + typename MainloopParams::ElementAccumulator acc[kBlockM][kBlockN]; + gett_mainloop(mainloop_params, m, n, l, acc); + gett_epilogue(epilogue_params, m, n, l, acc); + } + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GETT - Mainloop +template +void gett_mainloop( + MainloopParams const& mainloop_params, + int64_t m, + int64_t n, + int64_t l, + ElementAccumulator (&acc)[kBlockM][kBlockN]) +{ + + static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "M, K, B"); + static_assert(cute::rank(typename MainloopParams::LayoutB{}) == 3, "N, K, B"); + + using cute::raw_pointer_cast; + + using ElementA = typename ElementTraits::type; + using ElementB = typename ElementTraits::type; + + + using ElementSFA = typename ElementTraits::type; + using ElementSFB = typename ElementTraits::type; + + + using RingOp = multiply_add; + RingOp fma_op; + + // Zero out accumulators + for (int m_b = 0; m_b < kBlockM; ++m_b) { + for (int n_b = 0; n_b < kBlockN; ++n_b) { + acc[m_b][n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity + } + } + + // Compute on this k-block + for (int64_t k = 0; k < cute::size<1>(mainloop_params.A.layout()); ++k) { + // Load A + ElementAccumulator a_frag[kBlockM]; + for (int m_b = 0; m_b < kBlockM; ++m_b) { + if (m + m_b < cute::size<0>(mainloop_params.A.layout())) { + // Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type. + a_frag[m_b] = static_cast(ElementA(mainloop_params.A(m + m_b, k, l))); + + + if constexpr (not cute::is_same_v){ + // Load SFA + auto sfa = static_cast(mainloop_params.SfA(m + m_b, k, l)); + a_frag[m_b] *= sfa; + } + + + if (mainloop_params.transform_A == ComplexTransform::kConjugate) { + a_frag[m_b] = conj(a_frag[m_b]); + } + } else { + a_frag[m_b] = ElementAccumulator(0); // RingOp::AdditionIdentity + } + } + + // Load B + ElementAccumulator b_frag[kBlockN]; + for (int n_b = 0; n_b < kBlockN; ++n_b) { + if (n + n_b < cute::size<0>(mainloop_params.B.layout())) { + // Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type. + b_frag[n_b] = static_cast(ElementB(mainloop_params.B(n + n_b, k, l))); + + + if constexpr (not cute::is_same_v){ + // Load SFB + auto sfb = static_cast(mainloop_params.SfB(n + n_b, k, l)); + b_frag[n_b] *= sfb; + } + + + if (mainloop_params.transform_B == ComplexTransform::kConjugate) { + b_frag[n_b] = conj(b_frag[n_b]); + } + } else { + b_frag[n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity + } + } + + // do compute + for (int m_b = 0; m_b < kBlockM; ++m_b) { + for (int n_b = 0; n_b < kBlockN; ++n_b) { + acc[m_b][n_b] = fma_op(a_frag[m_b], b_frag[n_b], acc[m_b][n_b]); + } + } + + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GETT - Epilogue +template +void gett_epilogue( + EpilogueParams const& epilogue_params, + int64_t m, + int64_t n, + int64_t l, + ElementAccumulator (&acc)[kBlockM][kBlockN]) +{ + static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == 3, "M, K, B"); + static_assert(cute::rank(typename EpilogueParams::LayoutD{}) == 3, "N, K, B"); + + using cute::raw_pointer_cast; + + using ElementCompute = typename EpilogueParams::ElementCompute; + using ElementC = typename EpilogueParams::TensorC::value_type; + using ElementD = typename EpilogueParams::TensorD::value_type; + using ElementSfD = typename EpilogueParams::TensorSFD::value_type; + using ElementAux = typename EpilogueParams::TensorAux::value_type; + using ElementBias = typename EpilogueParams::VectorBias::value_type; + using ElementScalar = typename EpilogueParams::ElementScalar; + using ElementScalingFactor = typename EpilogueParams::ElementScalingFactor; + using ActivationFunctor = typename EpilogueParams::ActivationFunctor; + using BiasBinaryOp = typename EpilogueParams::BiasBinaryOp; + + constexpr bool PerColBias = EpilogueParams::PerColumnBias; + constexpr SfStrategy SfGenStrategy = EpilogueParams::SfGenStrategy; + + constexpr bool IsScalingAndAmaxOutputNeeded = + cute::is_same_v or + cute::is_same_v; + + constexpr bool IsScalingAndAmaxAuxOutputNeeded = + cute::is_same_v or + cute::is_same_v; + + constexpr bool IsReLUAuxNeeded = + (cute::is_same_v> or + cute::is_same_v>) and + cute::is_same_v; + constexpr bool UseReLU = + cute::is_same_v>; // Treat Clamp as ReLU + + constexpr bool IsBackpropFusion = + cute::is_same_v> or + cute::is_same_v>; + + // Input related converter + NumericConverter accumulator_converter; + NumericConverter source_converter; + NumericConverter bias_converter; + [[maybe_unused]] NumericConverter aux_source_converter; + + // Scale related converter + NumericConverter scale_converter; + NumericConverter scaling_factor_converter; + + // Abs max converter + [[maybe_unused]] NumericConverter abs_max_output_converter; + + // Output related converter + NumericConverter destination_converter; + [[maybe_unused]] NumericConverter aux_destination_converter; + NumericConverter dBias_converter; + + // Epilogue operations + multiply_add epilogue_fma; + multiplies mul; + plus add; + + // Activation operation + ActivationFunctor activation; + + // Bias binary operation + BiasBinaryOp bias_op; + + // Do conversion + ElementCompute converted_alpha = scale_converter(epilogue_params.alpha); + ElementCompute converted_beta = scale_converter(epilogue_params.beta); + ElementCompute converted_scale_a = scaling_factor_converter(epilogue_params.scale_a); + ElementCompute converted_scale_b = scaling_factor_converter(epilogue_params.scale_b); + ElementCompute converted_scale_c = scaling_factor_converter(epilogue_params.scale_c); + ElementCompute converted_scale_d = scaling_factor_converter(epilogue_params.scale_d); + ElementCompute converted_scale_aux = scaling_factor_converter(epilogue_params.scale_aux); + + // Init local var + [[maybe_unused]] ElementCompute local_abs_max_output = ElementCompute(0); + [[maybe_unused]] ElementCompute local_abs_max_aux_output = ElementCompute(0); + + converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b)); + converted_beta = mul(converted_beta, converted_scale_c); + + ElementCompute inter_accum[kBlockM][kBlockN]; + + for (int m_b = 0; m_b < kBlockM; ++m_b) { + ElementCompute local_dBias = ElementCompute(0); + + for (int n_b = 0; n_b < kBlockN; ++n_b) { + if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) { + // Convert every type to ElementCompute first, do compute, convert to output type, write it out + ElementCompute converted_acc = accumulator_converter(acc[m_b][n_b]); + // vector alpha + if (raw_pointer_cast(epilogue_params.Valpha.data())) { + converted_alpha = scale_converter(epilogue_params.Valpha(m + m_b, n + n_b, l)); + converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b)); + } + ElementCompute output = mul(converted_alpha, converted_acc); + + if (raw_pointer_cast(epilogue_params.Bias.data()) && not IsBackpropFusion) { + ElementCompute converted_bias = bias_converter(epilogue_params.Bias(PerColBias ? n + n_b : m + m_b)); + output = bias_op(output, converted_bias); + } + + if (raw_pointer_cast(epilogue_params.C.data())) { + ElementCompute converted_src = source_converter(epilogue_params.C(m + m_b, n + n_b, l)); + // vector beta + if (epilogue_params.Vbeta.data()) { + converted_beta = scale_converter(epilogue_params.Vbeta(m + m_b, n + n_b, l)); + converted_beta = mul(converted_beta, converted_scale_c); + } + output = epilogue_fma(converted_beta, converted_src, output); + } + + if constexpr (IsBackpropFusion) { + ElementAux aux_input = ElementAux(0); + if (raw_pointer_cast(epilogue_params.Aux.data())) { + aux_input = epilogue_params.Aux(m + m_b, n + n_b, l); + } + + output = activation(output, aux_source_converter(aux_input)); + local_dBias = add(local_dBias, output); + } + else { + if (raw_pointer_cast(epilogue_params.Aux.data())) { + auto aux_output = output; + if constexpr (IsScalingAndAmaxAuxOutputNeeded) { + maximum_absolute_value_reduction amax_op; + local_abs_max_aux_output = amax_op(local_abs_max_aux_output, aux_output); + aux_output = epilogue_fma(converted_scale_aux, aux_output, ElementCompute(0)); + } + + if constexpr (IsReLUAuxNeeded) { + epilogue_params.Aux(m + m_b, n + n_b, l) = not (aux_output < 0) ? uint1b_t(1) : uint1b_t(0); + } else { + epilogue_params.Aux(m + m_b, n + n_b, l) = aux_destination_converter(aux_output); + } + } + + if constexpr (UseReLU) { + cutlass::epilogue::thread::ReLU relu; + output = relu(output); + } + else { + output = activation(output); + } + } + + if constexpr (IsScalingAndAmaxOutputNeeded) { + maximum_absolute_value_reduction amax_op; + local_abs_max_output = amax_op(local_abs_max_output, output); + output = epilogue_fma(converted_scale_d, output, ElementCompute(0)); + } + + inter_accum[m_b][n_b] = ElementCompute(output); + } + } // n_b + + if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n < cute::size<1>(epilogue_params.D.layout())) { + if (raw_pointer_cast(epilogue_params.Bias.data()) && IsBackpropFusion) { + ElementCompute converted_dBias = bias_converter(epilogue_params.Bias(m + m_b)); + local_dBias = add(local_dBias, converted_dBias); + epilogue_params.Bias(m + m_b) = dBias_converter(local_dBias); + } + } + } // m_b + + if constexpr ( + SfGenStrategy == SfStrategy::SfDGen + ) { + // 1d scale factor generation + constexpr int kVectorSize = typename EpilogueParams::SFD_VectorSize{}; + if (epilogue_params.SfD.data() != nullptr) { + compute_1d_scaling_factor_and_quantized_output(epilogue_params, epilogue_params.D, epilogue_params.SfD, m, n, l, inter_accum); + } + } + + for (int m_b = 0; m_b < kBlockM; ++m_b) { + for (int n_b = 0; n_b < kBlockN; ++n_b) { + if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) { + epilogue_params.D(m + m_b, n + n_b, l) = destination_converter(inter_accum[m_b][n_b]); + } + } + } + +#if defined(_OPENMP) + #pragma omp critical(Abs_Max_Data_Update) +#endif + { + if constexpr (IsScalingAndAmaxOutputNeeded) { + if (epilogue_params.abs_max_D) { + *epilogue_params.abs_max_D = maximum_with_nan_propogation{}( + *epilogue_params.abs_max_D, abs_max_output_converter(local_abs_max_output)); + } + } + + if constexpr (IsScalingAndAmaxAuxOutputNeeded) { + if (epilogue_params.abs_max_Aux) { + *epilogue_params.abs_max_Aux = maximum_with_nan_propogation{}( + *epilogue_params.abs_max_Aux, abs_max_output_converter(local_abs_max_aux_output)); + } + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +auto make_layout_rank3(const TensorType& tensor) { + // append a batch mode of size 1 if we do not have tensors that are rank 3 + return make_layout( + make_shape(cute::get<0>(tensor.shape()), cute::get<1>(tensor.shape()), cute::Int<1>{}), + make_stride(cute::get<0>(tensor.stride()), cute::get<1>(tensor.stride()), int64_t(cosize(tensor.layout())))); +} + +/// GEMM - General Matrix-Matrix contraction without conjugation options +template < + class MainloopParams, + class EpilogueParams +> +void Gemm3x( + MainloopParams const& mainloop_params, + EpilogueParams const& epilogue_params) +{ + using namespace cute; + + static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename MainloopParams::LayoutB{})); + static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == cute::rank(typename EpilogueParams::LayoutD{})); + static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename EpilogueParams::LayoutC{})); + + if constexpr (cute::rank(typename MainloopParams::LayoutA{}) == 2) { + cute::Layout layout_A = make_layout_rank3(mainloop_params.A); + cute::Layout layout_B = make_layout_rank3(mainloop_params.B); + cute::Layout layout_C = make_layout_rank3(epilogue_params.C); + cute::Layout layout_D = make_layout_rank3(epilogue_params.D); + cute::Layout layout_Aux = make_layout_rank3(epilogue_params.Aux); + cute::Layout layout_Bias = make_layout_rank3(epilogue_params.Bias); + cute::Layout layout_Valpha = make_layout_rank3(epilogue_params.Valpha); + cute::Layout layout_Vbeta = make_layout_rank3(epilogue_params.Vbeta); + + auto TensorA = make_tensor(mainloop_params.A.data(), layout_A); + auto TensorB = make_tensor(mainloop_params.B.data(), layout_B); + auto TensorC = make_tensor(epilogue_params.C.data(), layout_C); + auto TensorD = make_tensor(epilogue_params.D.data(), layout_D); + auto TensorAux = make_tensor(epilogue_params.Aux.data(), layout_Aux); + auto VectorBias = make_tensor(epilogue_params.Bias.data(), layout_Bias); + auto VectorAlpha = make_tensor(epilogue_params.Valpha.data(), layout_Valpha); + auto VectorBeta = make_tensor(epilogue_params.Vbeta.data(), layout_Vbeta); + + // Reconstruct mainloop params + GettMainloopParams + mainloop_params_converted{TensorA, + TensorB, + mainloop_params.transform_A, + mainloop_params.transform_B}; + + // Reconstruct epilogue params + GettEpilogueParams + epilogue_params_converted{epilogue_params.alpha, + epilogue_params.beta, + TensorC, + TensorD, + VectorBias, + TensorAux, + VectorAlpha, + VectorBeta, + epilogue_params.abs_amax_D, + epilogue_params.abs_amax_Aux, + epilogue_params.scale_a, + epilogue_params.scale_b, + epilogue_params.scale_c, + epilogue_params.scale_d, + epilogue_params.scale_aux + }; + + Gett(mainloop_params_converted, epilogue_params_converted); + } + else { + // if we already have a batch mode, just pass it through + Gett(mainloop_params, epilogue_params); + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // cutlass::reference::host + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k.h new file mode 100644 index 0000000000000000000000000000000000000000..67867533d5783b6e0047ac2110dc47adaa277e25 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k.h @@ -0,0 +1,261 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for Rank 2k update in host-side code. + + + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/arch/mma.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + FillMode FillModeC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_rank2k( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + static_assert( + FillModeC == FillMode::kLower || + FillModeC == FillMode::kUpper, + "Fill Mode can either be Lower or Upper."); + + using CompareOp = typename platform::conditional<(FillModeC == FillMode::kLower), + std::greater_equal, + std::less_equal>::type; + + // Note: batch is ignored. + // Note: M is same as N for Rank 2k update + int const N = problem_size.n(); + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + CompareOp compare_op; + + for (int row_block = 0; row_block < N; row_block += Nblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Nblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Nblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Nblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < N && col < N && compare_op(row, col)) + { + + // A x B^T + ElementA a = tensor_a.at(MatrixCoord(row, k_block)); + ElementB b_t = tensor_b.at(MatrixCoord(col, k_block)); + + ComputeType compute_a(cast_if_scalar(a)); + ComputeType compute_b_t(cast_if_scalar(b_t)); + + accum[i][j] = inner_product_op(compute_a, compute_b_t, accum[i][j]); + + // B x A^T + ElementB b = tensor_b.at(MatrixCoord(row, k_block)); + ElementA a_t = tensor_a.at(MatrixCoord(col, k_block)); + + ComputeType compute_b(cast_if_scalar(b)); + ComputeType compute_a_t(cast_if_scalar(a_t)); + + accum[i][j] = inner_product_op(compute_b, compute_a_t, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Nblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < N && col < N && + ( (FillModeC == FillMode::kLower && row >= col) || + (FillModeC == FillMode::kUpper && row <= col) ) + ) { + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * ScalarType(tensor_c.at(coord))); + } + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general Rank 2k update (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + FillMode FillModeC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_rank2k( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum) { + compute_rank2k( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c, + initial_accum); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + FillMode FillModeC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = cutlass::arch::OpMultiplyAdd +> +struct Rank2K; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct Rank2K { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_rank2k>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_rank2k>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k_complex.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..a738101660f7ebbdd7c7796d46df244f1e3f5f70 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k_complex.h @@ -0,0 +1,318 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued Rank 2K update in host-side code. + + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" +#include + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Rank2KComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + FillMode fill_mode_c, + BlasMode blas_mode, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + // Note: batch is ignored. + int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + // Rank2K update operates on A=NxK, B=NxK, and C=NxN + assert(M==N); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) { + + // Compute matrix product using blocks + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N && + ( (fill_mode_c == FillMode::kLower && row >= col) || + (fill_mode_c == FillMode::kUpper && row <= col) ) + ) { + + // A x B^T (Symmetric) or A x B^H (Hermitian) + // complex conjugation on operandB (b_t) is function of blas3 computation + ElementA a = tensor_a.at(MatrixCoord(row, k_block)); + ElementB b_t = (blas_mode == BlasMode::kHermitian) ? + conj(tensor_b.at(MatrixCoord(col, k_block))) : + tensor_b.at(MatrixCoord(col, k_block)); + + ComputeType a_ik = ComputeType(a); + ComputeType b_jk = ComputeType(b_t); + + // complex conjugation is a function of operand layouts + if (transform_a == ComplexTransform::kConjugate) { + a_ik = conj(a_ik); + } + // complex conjugation is a function of operand layouts + if (transform_b == ComplexTransform::kConjugate) { + b_jk = conj(b_jk); + } + + accum[i][j] = inner_product_op(a_ik, b_jk, accum[i][j]); + } + } + } + } + + /* HER2K need two epilogues to handle complex alpha value */ + if ( blas_mode == BlasMode::kHermitian ) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N && + ((fill_mode_c == FillMode::kLower && row >= col) || + (fill_mode_c == FillMode::kUpper && row <= col)) + ) { + + ScalarType c = tensor_c.at(coord); + // The imaginary parts of the diagonal elements of + // a complex data type are assumed and set to zero + if (blas_mode == BlasMode::kHermitian) { + c = (row == col) ? real(c) : c; + } + + tensor_d.at(coord) = convert_op(alpha * + ScalarType(accum[i][j]) + + beta * c); + } + } + } + + /* Zeoring out accum for second HERK */ + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N && + ( (fill_mode_c == FillMode::kLower && row >= col) || + (fill_mode_c == FillMode::kUpper && row <= col) ) + ) { + + // B x A^T (Symmetric) or B x A^H (Hermitian) + // complex conjugation on operandB (a_t) is function of blas3 computation + ElementB b = tensor_b.at(MatrixCoord(row, k_block)); + ElementA a_t = (blas_mode == BlasMode::kHermitian) ? + conj(tensor_a.at(MatrixCoord(col, k_block))): + tensor_a.at(MatrixCoord(col, k_block)); + + ComputeType b_ik = ComputeType(b); + ComputeType a_jk = ComputeType(a_t); + + // complex conjugation here is a function of operand layouts + if (transform_b == ComplexTransform::kConjugate) { + b_ik = conj(b_ik); + } + // complex conjugation here is a function of operand layouts + if (transform_a == ComplexTransform::kConjugate) { + a_jk = conj(a_jk); + } + + accum[i][j] = inner_product_op(b_ik, a_jk, accum[i][j]); + } + } + } + } + + ScalarType alpha_hermitian = (blas_mode == BlasMode::kHermitian) ? + conj(alpha) : alpha; + ScalarType beta_hermitian = (blas_mode == BlasMode::kHermitian) ? + 1 : beta; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N && + ((fill_mode_c == FillMode::kLower && row >= col) || + (fill_mode_c == FillMode::kUpper && row <= col)) + ) { + + ScalarType d = (blas_mode == BlasMode::kHermitian) ? + tensor_d.at(coord) : tensor_c.at(coord); + + ScalarType tmp_d = convert_op( + alpha_hermitian * ScalarType(accum[i][j]) + + beta_hermitian * d); + + if (blas_mode == BlasMode::kHermitian && row == col ) { + tensor_d.at(coord) = real(tmp_d); + } else { + tensor_d.at(coord) = tmp_d; + } + } + } + } + + } // for (col_block) + } // for (row_block) + + tensor_a.add_pointer_offset(batch_stride_A); + tensor_b.add_pointer_offset(batch_stride_B); + tensor_c.add_pointer_offset(batch_stride_C); + tensor_d.add_pointer_offset(batch_stride_D); + + } // for (batch_idx) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType +> +void Rank2KComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + FillMode fill_mode_c, + BlasMode blas_mode) { + + Rank2KComplex( + problem_size, alpha, + tensor_a, transform_a, + tensor_b, transform_b, + beta, tensor_c, tensor_d, + ScalarType(0), + fill_mode_c, + blas_mode); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_k_complex.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_k_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..1aad33fd643b60752bc0845e403cebc43ad7d047 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_k_complex.h @@ -0,0 +1,234 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued Rank 2K update in host-side code. + + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" +#include + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Rank2KComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + FillMode fill_mode_c, + BlasMode blas_mode, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static_assert( + LayoutA::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + // Note: batch is ignored. + int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + // Rank2K update operates on A=NxK, B=NxK, and C=NxN + assert(M==N); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) { + + // Compute matrix product using blocks + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N && + ( (fill_mode_c == FillMode::kLower && row >= col) || + (fill_mode_c == FillMode::kUpper && row <= col) ) + ) { + + // A x A^T (Symmetric) or A x A^H (Hermitian) + // complex conjugation on operandB (a_t) (function of blas3 computation) + ElementA a = tensor_a.at(MatrixCoord(row, k_block)); + ElementA a_t = (blas_mode == BlasMode::kHermitian) ? + conj(tensor_a.at(MatrixCoord(col, k_block))) : + tensor_a.at(MatrixCoord(col, k_block)); + + ComputeType a_ik = ComputeType(a); + ComputeType b_jk = ComputeType(a_t); + + // complex conjugation (function of input layouts) + if (transform_a == ComplexTransform::kConjugate) { + a_ik = conj(a_ik); + } + // complex conjugation (function of input layouts) + if (transform_a == ComplexTransform::kConjugate) { + b_jk = conj(b_jk); + } + + accum[i][j] = inner_product_op(a_ik, b_jk, accum[i][j]); + + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N && + ((fill_mode_c == FillMode::kLower && row >= col) || + (fill_mode_c == FillMode::kUpper && row <= col)) + ) { + + ScalarType c = tensor_c.at(coord); + // The imaginary parts of the diagonal elements of + // a complex data type are assumed and set to zero + if (blas_mode == BlasMode::kHermitian) { + c = (row == col) ? real(c) : c; + } + + ScalarType tmp_d = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * c); + + if (blas_mode == BlasMode::kHermitian && row == col ) { + tensor_d.at(coord) = real(tmp_d); + } else { + tensor_d.at(coord) = tmp_d; + } + } + } + } + + } // for (col_block) + } // for (row_block) + + tensor_a.add_pointer_offset(batch_stride_A); + tensor_c.add_pointer_offset(batch_stride_C); + tensor_d.add_pointer_offset(batch_stride_D); + + } // for (batch_idx) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. +template < + typename ElementA, + typename LayoutA, + typename ElementC, + typename LayoutC, + typename ScalarType +> +void RankKComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + FillMode fill_mode_c, + BlasMode blas_mode) { + + Rank2KComplex( + problem_size, alpha, + tensor_a, transform_a, + beta, tensor_c, tensor_d, + ScalarType(0), + fill_mode_c, + blas_mode); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/symm.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/symm.h new file mode 100644 index 0000000000000000000000000000000000000000..34f9648f25f8965f6730999b7763220c360683a8 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/symm.h @@ -0,0 +1,285 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for SYMM update in host-side code. + + + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/arch/mma.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename ElementA, + typename LayoutA, + SideMode SideModeA, + FillMode FillModeA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_symm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + static_assert(SideModeA != SideMode::kInvalid + , "Side Mode can either be Left or Right."); + + static_assert( + FillModeA == FillMode::kLower || + FillModeA == FillMode::kUpper, + "Fill Mode can either be Lower or Upper."); + + using CompareOp_w_diag = typename TrMatrixCompareOp::Type; + using CompareOp_wo_diag = typename TrMatrixCompareOp::Type; + + // Note: batch is ignored. + int const M = problem_size.m(); + int const N = problem_size.n(); + // Assuming correct k-dimension value is passed + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + CompareOp_w_diag compare_op_1; + CompareOp_wo_diag compare_op_2; + + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + ElementA a_1 = ElementA(); + ElementB b_1 = ElementB(); + ElementA a_2 = ElementA(); + ElementB b_2 = ElementB(); + + // A x B or B x A (with diagonal) + if (SideModeA == SideMode::kLeft) { + a_1 = (compare_op_1(row, k_block)) ? + (tensor_a.at(MatrixCoord(row, k_block))) : ElementA(); + b_1 = tensor_b.at(MatrixCoord(k_block, col)); + } else if (SideModeA == SideMode::kRight) { + a_1 = tensor_b.at(MatrixCoord(row, k_block)); + b_1 = (compare_op_1(k_block, col)) ? + tensor_a.at(MatrixCoord(k_block, col)) : ElementA(); + } + + ComputeType compute_a_1(cast_if_scalar(a_1)); + ComputeType compute_b_1(cast_if_scalar(b_1)); + + accum[i][j] = inner_product_op(compute_a_1, compute_b_1, accum[i][j]); + + // A^T x B or B x A^T (without diagonal) + if (SideModeA == SideMode::kLeft) { + a_2 = (compare_op_2(k_block, row)) ? + (tensor_a.at(MatrixCoord(k_block, row))) : ElementA(); + b_2 = tensor_b.at(MatrixCoord(k_block, col)); + } else if (SideModeA == SideMode::kRight) { + a_2 = tensor_b.at(MatrixCoord(row, k_block)); + b_2 = (compare_op_2(col, k_block)) ? + tensor_a.at(MatrixCoord(col, k_block)) : ElementA(); + } + + ComputeType compute_a_2(cast_if_scalar(a_2)); + ComputeType compute_b_2(cast_if_scalar(b_2)); + + accum[i][j] = inner_product_op(compute_a_2, compute_b_2, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * ScalarType(tensor_c.at(coord))); + } + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general Symm update (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename ElementA, + typename LayoutA, + SideMode SideModeA, + FillMode FillModeA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_symm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum) { + compute_symm( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c, + initial_accum); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + SideMode SideModeA, + FillMode FillModeA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = cutlass::arch::OpMultiplyAdd +> +struct Symm; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct Symm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_symm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_symm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/symm_complex.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/symm_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..79e146f69b784a92ce61a093f410e93a66005cf8 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/symm_complex.h @@ -0,0 +1,319 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued SYMM update in host-side code. + + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" +#include + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + SideMode SideModeA, + FillMode FillModeA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + BlasMode BlasMode_ = BlasMode::kSymmetric, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_symm_complex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static SideMode const kSideModeA = SideModeA; + static FillMode const kFillModeA = FillModeA; + static BlasMode const kBlasMode = BlasMode_; + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + static_assert(kSideModeA != SideMode::kInvalid + , "Side Mode can either be Left or Right."); + + static_assert( + kFillModeA == FillMode::kLower || + kFillModeA == FillMode::kUpper, + "Fill Mode can either be Lower or Upper."); + + using CompareOp_w_diag = typename TrMatrixCompareOp::Type; + using CompareOp_wo_diag = typename TrMatrixCompareOp::Type; + + // Note: batch is ignored. + int const M = problem_size.m(); + int const N = problem_size.n(); + // Assuming correct k-dimension value is passed + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + CompareOp_w_diag compare_op_1; + CompareOp_wo_diag compare_op_2; + + for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) { + + // Compute matrix product using blocks + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) + { + ElementA a_1 = ElementA(); + ElementB b_1 = ElementB(); + ElementA a_2 = ElementA(); + ElementB b_2 = ElementB(); + + // A x B or B x A (with diagonal) + if (kSideModeA == SideMode::kLeft) { + a_1 = (compare_op_1(row, k_block)) ? + (tensor_a.at(MatrixCoord(row, k_block))) : ElementA(); + b_1 = tensor_b.at(MatrixCoord(k_block, col)); + } else if (kSideModeA == SideMode::kRight) { + a_1 = tensor_b.at(MatrixCoord(row, k_block)); + b_1 = (compare_op_1(k_block, col)) ? + tensor_a.at(MatrixCoord(k_block, col)) : ElementA(); + } + ComputeType compute_a_1 = ComputeType(a_1); + ComputeType compute_b_1 = ComputeType(b_1); + + // The imaginary parts of the diagonal elements of + // a complex data type are assumed and set to zero + if (kBlasMode == BlasMode::kHermitian && kSideModeA == SideMode::kLeft && row == k_block) { + compute_a_1 = real(compute_a_1); + } else if (kBlasMode == BlasMode::kHermitian && kSideModeA == SideMode::kRight && k_block == col) { + compute_b_1 = real(compute_b_1); + } + + accum[i][j] = inner_product_op(compute_a_1, compute_b_1, accum[i][j]); + + // A^T x B or B x A^T (without diagonal) + if (kSideModeA == SideMode::kLeft) { + a_2 = (compare_op_2(k_block, row)) ? + (tensor_a.at(MatrixCoord(k_block, row))) : ElementA(); + b_2 = tensor_b.at(MatrixCoord(k_block, col)); + if (kBlasMode == BlasMode::kHermitian) + a_2 = conj(a_2); + } else if (kSideModeA == SideMode::kRight) { + a_2 = tensor_b.at(MatrixCoord(row, k_block)); + b_2 = (compare_op_2(col, k_block)) ? + tensor_a.at(MatrixCoord(col, k_block)) : ElementA(); + if (kBlasMode == BlasMode::kHermitian) + b_2 = conj(b_2); + } + + ComputeType compute_a_2 = ComputeType(a_2); + ComputeType compute_b_2 = ComputeType(b_2); + + accum[i][j] = inner_product_op(compute_a_2, compute_b_2, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + + ScalarType c = tensor_c.at(coord); + + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * c); + } + } + } + + } // for (col_block) + } // for (row_block) + + tensor_a.add_pointer_offset(batch_stride_A); + tensor_b.add_pointer_offset(batch_stride_B); + tensor_c.add_pointer_offset(batch_stride_C); + tensor_d.add_pointer_offset(batch_stride_D); + + } // for (batch_idx) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + SideMode SideModeA, + FillMode FillModeA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + BlasMode BlasMode_ = cutlass::BlasMode::kSymmetric, + typename InnerProductOp = cutlass::arch::OpMultiplyAddComplex +> +struct SymmComplex; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct SymmComplex { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_symm_complex>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for gaussian multiply-add +template +struct SymmComplex { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_symm_complex>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.h new file mode 100644 index 0000000000000000000000000000000000000000..d6b85ca1baf65ba811b7c8b3a224ca90bbce1680 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.h @@ -0,0 +1,616 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines host-side elementwise operations on TensorView. +*/ + +#pragma once + +// Standard Library includes +#include + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/relatively_equal.h" +#include "cutlass/tensor_view.h" +#include "cutlass/tensor_view_planar_complex.h" + +#include "cutlass/util/distribution.h" +#include "tensor_foreach.h" + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorGreatestErrorFunc { + + // + // Data members + // + + TensorView lhs; + TensorView rhs; + double result; + + /// Ctor + TensorGreatestErrorFunc( + TensorView const &lhs_, + TensorView const &rhs_ + ) : + lhs(lhs_), + rhs(rhs_), + result(0.0) { } + + /// Visits a coordinate + void operator()(Coord const &coord) { + + Element lhs_ = lhs.at(coord); + Element rhs_ = rhs.at(coord); + + result = std::max(result, std::abs(double(lhs_) - double(rhs_))); + } + + /// Returns true if equal + operator double() const { + return result; + } +}; + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorMREFunc { + + // + // Data members + // + + TensorView lhs; + TensorView rhs; + double sum; + uint64_t count; + static constexpr double epsilon = 1e-6; + + /// Ctor + TensorMREFunc( + TensorView const &lhs_, + TensorView const &rhs_ + ) : + lhs(lhs_), + rhs(rhs_), + sum(0.0), + count(0) { } + + /// Visits a coordinate + void operator()(Coord const &coord) { + + Element lhs_ = lhs.at(coord); + Element rhs_ = rhs.at(coord); + + sum += std::abs(double(lhs_) - double(rhs_) / (double(rhs_) + epsilon)); + ++count; + } + + /// Returns true if equal + operator double() const { + return sum / double(count); + } +}; + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorMSEFunc { + + // + // Data members + // + + TensorView lhs; + TensorView rhs; + double sum; + uint64_t count; + + /// Ctor + TensorMSEFunc( + TensorView const &lhs_, + TensorView const &rhs_ + ) : + lhs(lhs_), + rhs(rhs_), + sum(0.0), + count(0) { } + + /// Visits a coordinate + void operator()(Coord const &coord) { + + Element lhs_ = lhs.at(coord); + Element rhs_ = rhs.at(coord); + + sum += std::pow((double(lhs_) - double(rhs_)), 2); + ++count; + } + + /// Returns true if equal + operator double() const { + return sum / double(count); + } +}; + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorEqualsFunc { + + // + // Data members + // + + TensorView lhs; + TensorView rhs; + bool result; + + /// Ctor + TensorEqualsFunc(): result(true) { } + + /// Ctor + TensorEqualsFunc( + TensorView const &lhs_, + TensorView const &rhs_ + ) : + lhs(lhs_), rhs(rhs_), result(true) { } + + /// Visits a coordinate + void operator()(Coord const &coord) { + + Element lhs_ = lhs.at(coord); + Element rhs_ = rhs.at(coord); + + if (lhs_ != rhs_) { + result = false; + } + } + + /// Returns true if equal + operator bool() const { + return result; + } +}; + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorRelativelyEqualsFunc { + + // + // Data members + // + + TensorView lhs; + TensorView rhs; + Element epsilon; + Element nonzero_floor; + bool result; + + /// Ctor + TensorRelativelyEqualsFunc( + TensorView const &lhs_, + TensorView const &rhs_, + Element epsilon_, + Element nonzero_floor_ + ) : + lhs(lhs_), + rhs(rhs_), + epsilon(epsilon_), + nonzero_floor(nonzero_floor_), + result(true) { } + + /// Visits a coordinate + void operator()(Coord const &coord) { + + Element lhs_ = lhs.at(coord); + Element rhs_ = rhs.at(coord); + + if (!relatively_equal(lhs_, rhs_, epsilon, nonzero_floor)) { + result = false; + } + } + + /// Returns true if equal + operator bool() const { + return result; + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns the Mean Squared Error between two tensors. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +double TensorMSE( + TensorView const &lhs, + TensorView const &rhs) { + + // Extents must be identical + if (lhs.extent() != rhs.extent()) { + return -1; + } + + detail::TensorMSEFunc func(lhs, rhs); + TensorForEach( + lhs.extent(), + func + ); + + return double(func); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns the Mean Relative Error between two tensors. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +double TensorMRE( + TensorView const &lhs, + TensorView const &rhs) { + + // Extents must be identical + if (lhs.extent() != rhs.extent()) { + return -1; + } + + detail::TensorMREFunc func(lhs, rhs); + TensorForEach( + lhs.extent(), + func + ); + + return double(func); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns the greatest error between two tensors. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +double TensorGreatestError( + TensorView const &lhs, + TensorView const &rhs) { + + // Extents must be identical + if (lhs.extent() != rhs.extent()) { + return -1; + } + + detail::TensorGreatestErrorFunc func(lhs, rhs); + TensorForEach( + lhs.extent(), + func + ); + + return double(func); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns true if two tensor views are equal. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +bool TensorEquals( + TensorView const &lhs, + TensorView const &rhs) { + + // Extents must be identical + if (lhs.extent() != rhs.extent()) { + return false; + } + + detail::TensorEqualsFunc func(lhs, rhs); + TensorForEach( + lhs.extent(), + func + ); + + return bool(func); +} + +/// Returns true if two tensor views are equal. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +bool TensorEquals( + TensorViewPlanarComplex const &lhs, + TensorViewPlanarComplex const &rhs) { + + // Extents must be identical + if (lhs.extent() != rhs.extent()) { + return false; + } + + detail::TensorEqualsFunc real_func( + {lhs.data(), lhs.layout(), lhs.extent()}, + {rhs.data(), rhs.layout(), rhs.extent()} + ); + + TensorForEach( + lhs.extent(), + real_func + ); + + if (!bool(real_func)) { + return false; + } + + detail::TensorEqualsFunc imag_func( + {lhs.data() + lhs.imaginary_stride(), lhs.layout(), lhs.extent()}, + {rhs.data() + rhs.imaginary_stride(), rhs.layout(), rhs.extent()} + ); + + TensorForEach( + lhs.extent(), + imag_func + ); + + return bool(imag_func); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns true if two tensor views are relatively equal. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +bool TensorRelativelyEquals( + TensorView const &lhs, + TensorView const &rhs, + Element epsilon, + Element nonzero_floor) { + + // Extents must be identical + if (lhs.extent() != rhs.extent()) { + return false; + } + + detail::TensorRelativelyEqualsFunc func(lhs, rhs, epsilon, nonzero_floor); + TensorForEach( + lhs.extent(), + func + ); + + return bool(func); +} + +/// Returns true if two tensor views are relatively equal. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +bool TensorRelativelyEquals( + TensorViewPlanarComplex const &lhs, + TensorViewPlanarComplex const &rhs, + Element epsilon, + Element nonzero_floor) { + + // Extents must be identical + if (lhs.extent() != rhs.extent()) { + return false; + } + + detail::TensorRelativelyEqualsFunc real_func( + {lhs.data(), lhs.layout(), lhs.extent()}, + {rhs.data(), rhs.layout(), rhs.extent()}, + epsilon, + nonzero_floor + ); + + TensorForEach( + lhs.extent(), + real_func + ); + + if (!bool(real_func)) { + return false; + } + + detail::TensorEqualsFunc imag_func( + {lhs.data() + lhs.imaginary_stride(), lhs.layout(), lhs.extent()}, + {rhs.data() + rhs.imaginary_stride(), rhs.layout(), rhs.extent()}, + epsilon, + nonzero_floor + ); + + TensorForEach( + lhs.extent(), + imag_func + ); + + return bool(imag_func); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns true if two tensor views are NOT equal. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +bool TensorNotEquals( + TensorView const &lhs, + TensorView const &rhs) { + + // Extents must be identical + if (lhs.extent() != rhs.extent()) { + return true; + } + + detail::TensorEqualsFunc func(lhs, rhs); + TensorForEach( + lhs.extent(), + func + ); + + return !bool(func); +} + +/// Returns true if two tensor views are equal. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +bool TensorNotEquals( + TensorViewPlanarComplex const &lhs, + TensorViewPlanarComplex const &rhs) { + + return !TensorEquals(lhs, rhs); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorContainsFunc { + + // + // Data members + // + + TensorView view; + Element value; + bool contains; + Coord location; + + // + // Methods + // + + /// Ctor + TensorContainsFunc(): contains(false) { } + + /// Ctor + TensorContainsFunc( + TensorView const &view_, + Element value_ + ) : + view(view_), value(value_), contains(false) { } + + /// Visits a coordinate + void operator()(Coord const &coord) { + + if (view.at(coord) == value) { + if (!contains) { + location = coord; + } + contains = true; + } + } + + /// Returns true if equal + operator bool() const { + return contains; + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns true if a value is present in a tensor +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +bool TensorContains( + TensorView const & view, + Element value) { + + detail::TensorContainsFunc func( + view, + value + ); + + TensorForEach( + view.extent(), + func + ); + + return bool(func); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns a pair containing a boolean of whether a value exists in a tensor and the location of +/// of the first occurrence. If the value is not contained in the tensor, the second element of the +/// pair is undefined. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +std::pair > TensorFind( + TensorView const & view, + Element value) { + + detail::TensorContainsFunc func( + view, + value + ); + + TensorForEach( + view.extent(), + func + ); + + return std::make_pair(bool(func), func.location); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.hpp new file mode 100644 index 0000000000000000000000000000000000000000..27ef969b4ff2b6d8f3a53f3d1a3e5ec3e5203ec3 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.hpp @@ -0,0 +1,101 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Provides several functions for filling tensors with data. +*/ + +#pragma once + +// Standard Library includes +#include +#include +#include + +// Cute includes +#include "cute/tensor.hpp" + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/quaternion.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns true if two tensor views are equal. +template < + typename TensorL, + typename TensorR +> +bool TensorEquals( + TensorL lhs, + TensorR rhs) { + + // Extents must be identical + if (cute::size(lhs) != cute::size(rhs)) { + return false; + } + + for (int64_t idx = 0; idx < cute::size(lhs); ++idx) { + if (lhs(idx) != rhs(idx)) { + return false; + } + } + + return true; +} + +/// Returns true if two tensor views are NOT equal. +template < + typename TensorL, + typename TensorR +> +bool TensorNotEquals( + TensorL lhs, + TensorR rhs) { + + return TensorEquals(lhs, rhs); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_copy.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_copy.h new file mode 100644 index 0000000000000000000000000000000000000000..d2a43b1295c8ab18c7d649c79b0364b6d3e7c48c --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_copy.h @@ -0,0 +1,256 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines host-side elementwise operations on TensorView. +*/ + +#pragma once + +// Standard Library includes +#include + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "tensor_foreach.h" + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Helper to convert between types +template < + typename DstElement, + typename SrcElement +> +struct TrivialConvert { + + TrivialConvert() { } + + DstElement operator()(SrcElement src) const { + return DstElement(src); + } +}; + +/// Helper to conditionally copy between tensor views. +template < + typename DstElement, + typename DstLayout, + typename SrcElement, + typename SrcLayout, + typename F +> +struct TensorCopyIf { + + using DstTensorView = TensorView; + using SrcTensorView = TensorView; + + // + // Data members + // + + DstTensorView dst; + SrcTensorView src; + F convert; + + // + // Methods + // + + TensorCopyIf() { } + + TensorCopyIf( + DstTensorView const &dst_, + SrcTensorView const &src_, + F const &convert_): dst(dst_), src(src_), convert(convert_) {} + + /// Copies based on destination and source bounds + void operator()(Coord const &coord) { + if (dst.contains(coord) && src.contains(coord)) { + dst.at(coord) = convert(src.at(coord)); + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies elements from one tensor view into another, satisfying bounds of each tensor. +template < + typename DstElement, /// Destination tensor's element type + typename DstLayout, /// Destination tensor's layout + typename SrcElement, /// Source tensor's element type + typename SrcLayout, /// Source tensor's layout + typename F /// Transformation functor +> +void TensorCopy( + TensorView dst, + TensorView src, + F const &transform) { + + using CopyIf = detail::TensorCopyIf< + DstElement, + DstLayout, + SrcElement, + SrcLayout, + F>; + + CopyIf copy_if(dst, src, transform); + + TensorForEach(dst.extent(), copy_if); +} + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies elements from a TensorRef into a TensorView. Assumes source tensor has sufficient extent +/// to avoid out of bounds accesses. +template < + typename DstElement, /// Destination tensor's element type + typename DstLayout, /// Destination tensor's layout + typename SrcElement, /// Source tensor's element type + typename SrcLayout, /// Source tensor's layout + typename F /// Transformation functor +> +void TensorCopy( + TensorView dst, + TensorRef src, + F const &transform) { + + using CopyIf = detail::TensorCopyIf< + DstElement, + DstLayout, + SrcElement, + SrcLayout, + F>; + + TensorView src_view(src, dst.extent()); + + CopyIf copy_if(dst, src_view, transform); + + TensorForEach(dst.extent(), copy_if); +} + +/// Copies elements from a TensorRef into a TensorView. Assumes source tensor has sufficient extent +/// to avoid out of bounds accesses. +template < + typename DstElement, /// Destination tensor's element type + typename DstLayout, /// Destination tensor's layout + typename SrcElement, /// Source tensor's element type + typename SrcLayout, /// Source tensor's layout + typename F /// Transformation functor +> +void TensorCopy( + TensorRef dst, + TensorView src, + F const &transform) { + + using CopyIf = detail::TensorCopyIf< + DstElement, + DstLayout, + SrcElement, + SrcLayout, + F>; + + TensorView dst_view(dst, src.extent()); + + CopyIf copy_if(dst_view, src, transform); + + TensorForEach(src.extent(), copy_if); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies elements from one tensor view into another, satisfying bounds of each tensor. Succeeds +/// if SrcElement can be converted to DstElement. +template < + typename DstElement, /// Destination tensor's element type + typename DstLayout, /// Destination tensor's layout + typename SrcElement, /// Source tensor's element type + typename SrcLayout /// Source tensor's layout +> +void TensorCopy( + TensorView dst, + TensorView src) { + + detail::TrivialConvert convert; + + TensorCopy(dst, src, convert); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies elements from one tensor view into another, satisfying bounds of each tensor. Succeeds +/// if SrcElement can be converted to DstElement. +template < + typename DstElement, /// Destination tensor's element type + typename DstLayout, /// Destination tensor's layout + typename SrcElement, /// Source tensor's element type + typename SrcLayout, /// Source tensor's layout + typename F /// Transformation functor +> +void TensorCopy( + TensorView dst, + TensorRef src) { + + detail::TrivialConvert convert; + + TensorCopy(dst, src, convert); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies elements from one tensor view into another, satisfying bounds of each tensor. Succeeds +/// if SrcElement can be converted to DstElement. +template < + typename DstElement, /// Destination tensor's element type + typename DstLayout, /// Destination tensor's layout + typename SrcElement, /// Source tensor's element type + typename SrcLayout /// Source tensor's layout +> +void TensorCopy( + TensorRef dst, + TensorView src) { + + detail::TrivialConvert convert; + + TensorCopy(dst, src, convert); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_elementwise.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_elementwise.h new file mode 100644 index 0000000000000000000000000000000000000000..5470df29358799f6d5e6628e8722f0e3dc05485f --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_elementwise.h @@ -0,0 +1,341 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines host-side elementwise operations on TensorView. +*/ + +#pragma once + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/functional.h" + +#include "tensor_foreach.h" + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to apply a binary operator in place +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementD, + typename LayoutD, + typename BinaryFunc> +struct TensorFuncBinaryOp { + + // + // Data members + // + + /// View of left-hand-side tensor + TensorView view_d; + TensorRef view_a; + TensorRef view_b; + BinaryFunc func; + + // + // Methods + // + + /// Constructor + TensorFuncBinaryOp() { } + + /// Constructor + TensorFuncBinaryOp( + TensorView const & view_d_, + TensorRef const & view_a_, + TensorRef const & view_b_, + BinaryFunc func = BinaryFunc() + ): + view_d(view_d_), view_a(view_a_), view_b(view_b_), func(func) { } + + /// Equality check + void operator()(Coord const &coord) const { + view_d.at(coord) = func( + ElementD(view_a.at(coord)), + ElementD(view_b.at(coord)) + ); + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Adds two tensors and stores in the destination tensor: d = a + b +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB +> +void TensorAdd( + TensorView d, ///< destination tensor view + TensorRef a, ///< A tensor reference + TensorRef b ///< B tensor reference +) { + + detail::TensorFuncBinaryOp< + ElementD, + LayoutD, + ElementA, + LayoutA, + ElementB, + LayoutB, + cutlass::plus + > func(d, a, b); + + TensorForEach( + d.extent(), + func); +} + +/// Adds a tensor in place: d = d .+ a +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA +> +void TensorAdd( + TensorView d, ///< destination tensor view + TensorRef a ///< A tensor reference +) { + TensorAdd(d, d, a); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Subtracts two tensors and stores in the destination tensor: d = a - b +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB +> +void TensorSub( + TensorView d, ///< destination tensor view + TensorRef a, ///< A tensor reference + TensorRef b ///< B tensor reference + ) { + + detail::TensorFuncBinaryOp< + ElementD, + LayoutD, + ElementA, + LayoutA, + ElementB, + LayoutB, + cutlass::minus + > func(d, a, b); + + TensorForEach( + d.extent(), + func); +} + +/// Subtracts two tensors in place: d = d .- a +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB +> +void TensorSub( + TensorView d, ///< destination tensor view + TensorRef a ///< A tensor reference + ) { + + TensorSub(d, d, a); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Multiplies two tensors and stores in the destination tensor: d = a .* b +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB +> +void TensorMul( + TensorView d, ///< destination tensor view + TensorRef a, ///< A tensor reference + TensorRef b ///< B tensor reference +) { + + detail::TensorFuncBinaryOp< + ElementD, + LayoutD, + ElementA, + LayoutA, + ElementB, + LayoutB, + cutlass::multiplies + > func(d, a, b); + + TensorForEach( + d.extent(), + func); +} + +/// Multiplies tensors in place: d = d .* a +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA +> +void TensorMul( + TensorView d, ///< destination tensor view + TensorRef a ///< A tensor reference +) { + TensorMul(d, d, a); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Divides two tensors and stores in the destination tensor: d = a ./ b +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB +> +void TensorDiv( + TensorView d, ///< destination tensor view + TensorRef a, ///< A tensor reference + TensorRef b ///< B tensor reference +) { + + detail::TensorFuncBinaryOp< + ElementD, + LayoutD, + ElementA, + LayoutA, + ElementB, + LayoutB, + cutlass::divides + > func(d, a, b); + + TensorForEach( + d.extent(), + func); +} + +/// Divides tensors in place: d = d ./ a +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA +> +void TensorDiv( + TensorView d, ///< destination tensor view + TensorRef a ///< A tensor reference +) { + TensorDiv(d, d, a); +} + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Divides two tensors and stores in the destination tensor: d = a ./ b +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB +> +void TensorModulus( + TensorView d, ///< destination tensor view + TensorRef a, ///< A tensor reference + TensorRef b ///< B tensor reference +) { + + detail::TensorFuncBinaryOp< + ElementD, + LayoutD, + ElementA, + LayoutA, + ElementB, + LayoutB, + cutlass::divides + > func(d, a, b); + + TensorForEach( + d.extent(), + func); +} + +/// Divides tensors in place: d = d ./ a +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA +> +void TensorModulus( + TensorView d, ///< destination tensor view + TensorRef a ///< A tensor reference +) { + TensorDiv(d, d, a); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.h new file mode 100644 index 0000000000000000000000000000000000000000..645902f7dd7b62bc98a479e4956dfb4b437d46a7 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.h @@ -0,0 +1,1718 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Provides several functions for filling tensors with data. +*/ + +#pragma once + +// Standard Library includes +#include +#include +#include +#include +#include + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/quaternion.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/subbyte_reference.h" +#include "cutlass/tensor_view.h" +#include "cutlass/tensor_view_planar_complex.h" +#include "cutlass/blas3.h" + +#include "cutlass/util/distribution.h" +#include "tensor_foreach.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + Element value; + + // + // Methods + // + + TensorFillFunc( + TensorView const &view_ = TensorView(), + Element value_ = Element(0) + ): view(view_), value(value_) { } + + void operator()(Coord const & coord) const { + view.at(coord) = value; + } +}; + +/// Returns a pair of values of the Gaussian distribution generated by the Box Muller method +struct BoxMullerFunc { + + BoxMullerFunc() {} + + void operator()( + double* rnd, ///< Size-2 vector to be filled with random values + double mean = 0, ///< Mean of the Gaussian distribution + double stddev = 1, ///< Standard deviation of the Gaussian distribution + double pi = std::acos(-1)) const { + + double u1 = double(std::rand()) / double(RAND_MAX); + double u2 = double(std::rand()) / double(RAND_MAX); + rnd[0] = std::sqrt(-2 * std::log(u1)) * std::cos(2 * pi * u2); + rnd[1] = std::sqrt(-2 * std::log(u1)) * std::sin(2 * pi * u2); + rnd[0] = mean + stddev * rnd[0]; + rnd[1] = mean + stddev * rnd[1]; + } +}; +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with a uniform value +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFill( + TensorView dst, ///< destination tensor + Element val = Element(0)) { ///< value to uniformly fill it with + + detail::TensorFillFunc func(dst, val); + + TensorForEach( + dst.extent(), + func + ); +} + +/// Fills a tensor with a uniform value +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFill( + TensorViewPlanarComplex dst, ///< destination tensor + cutlass::complex val = cutlass::complex(0)) { ///< value to uniformly fill it with + + TensorFill(dst.view_real(), val.real()); + TensorFill(dst.view_imag(), val.imag()); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct RandomGaussianFunc { + + uint64_t seed; + double mean; + double stddev; + int int_scale; + double pi; + double pnz; + bool exclude_zero; + + // + // Methods + // + RandomGaussianFunc( + uint64_t seed_ = 0, + double mean_ = 0, + double stddev_ = 1, + int int_scale_ = -1, + double pnz_ = 1.0, + bool exclude_zero_ = false + ): + seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)), pnz(pnz_), exclude_zero(exclude_zero_) { + std::srand((unsigned)seed); + } + + /// Compute random value and update RNG state + Element operator()() const { + + // Box-Muller transform to generate random numbers with Normal distribution + double u1 = double(std::rand()) / double(RAND_MAX); + double u2 = double(std::rand()) / double(RAND_MAX); + + // Compute Gaussian random value + double rnd = std::sqrt(-2 * std::log(u1)) * std::cos(2 * pi * u2); + rnd = mean + stddev * rnd; + + // Scale and convert final result + Element result; + + // Sample from the Bernoulli distribution, and use the result to sample from the Gaussian + std::random_device rnd_device; + std::mt19937 bernoulli_rnd(rnd_device()); + std::bernoulli_distribution bernoulli_dist(pnz); + bool bernoulli_result = bernoulli_dist(bernoulli_rnd); + + // Sample from the Gaussian distribution for a nonzero element + if (bernoulli_result) { + if (int_scale >= 0) { + rnd = double(std::llround(rnd * double(1 << int_scale))) / double(1 << int_scale); + result = static_cast(rnd); + } + else { + result = static_cast(rnd); + } + } + else { + result = static_cast(0); + } + + // Note that exclude_zero = true will disable the bernoulli_result above by unsetting zeros + if (exclude_zero && result == Element(0)) { + if (rnd > 0) { + rnd += 1; + } else { + rnd -= 1; + } + result = Element(rnd); + } + + return result; + } +}; + +/// Partial specialization for initializing a complex value. +template +struct RandomGaussianFunc > { + + uint64_t seed; + double mean; + double stddev; + int int_scale; + double pi; + double pnz; + bool exclude_zero; + + // + // Methods + // + RandomGaussianFunc( + uint64_t seed_ = 0, + double mean_ = 0, + double stddev_ = 1, + int int_scale_ = -1, + double pnz_ = 1.0, + bool exclude_zero_ = false + ): + seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)), pnz(pnz_), exclude_zero(exclude_zero_) { + std::srand((unsigned)seed); + } + + /// Compute random value and update RNG state + complex operator()() const { + + Element reals[2]; + + double rnd[2]; + detail::BoxMullerFunc func; + func(rnd, mean, stddev, pi); + + // Sample from the Bernoulli distribution, and use the result to sample from the Gaussian + std::random_device rnd_device; + std::mt19937 bernoulli_rnd(rnd_device()); + std::bernoulli_distribution bernoulli_dist(pnz); + bool bernoulli_result = bernoulli_dist(bernoulli_rnd); + + // Sample from the Gaussian distribution for a nonzero element + if (bernoulli_result) { + if (int_scale >= 0) { + rnd[0] = double(std::llround(rnd[0] * double(1 << int_scale))); + rnd[1] = double(std::llround(rnd[1] * double(1 << int_scale))); + reals[0] = from_real(rnd[0] / double(1 << int_scale)); + reals[1] = from_real(rnd[1] / double(1 << int_scale)); + } + else { + reals[0] = from_real(rnd[0]); + reals[1] = from_real(rnd[1]); + } + } + else { + reals[0] = from_real(0); + reals[1] = from_real(0); + } + + // Note that this will invalidate the above else statement because it unsets zero elements + if (exclude_zero && + reals[0] == from_real(0.0) && + reals[1] == from_real(0.0)) { + + if (rnd[0] > 0.0) { + rnd[0] += 1.0; + } else { + rnd[0] -= 1.0; + } + reals[0] = from_real(rnd[0]); + } + + return complex(reals[0], reals[1]); + } +}; + +/// Partial specialization for initializing a complex value. +template +struct RandomGaussianFunc > { + + uint64_t seed; + double mean; + double stddev; + int int_scale; + double pi; + double pnz; + bool exclude_zero; + + // + // Methods + // + RandomGaussianFunc( + uint64_t seed_ = 0, + double mean_ = 0, + double stddev_ = 1, + int int_scale_ = -1, + double pnz_ = 1.0, + bool exclude_zero_ = false + ): + seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)), pnz(pnz_), exclude_zero(exclude_zero_) { + std::srand((unsigned)seed); + } + + /// Compute random value and update RNG state + Quaternion operator()() const { + + Element reals[4]; + + double rnd1[2]; + double rnd2[2]; + detail::BoxMullerFunc func; + func(rnd1, mean, stddev, pi); + func(rnd2, mean, stddev, pi); + + // Sample from the Bernoulli distribution, and use the result to sample from the Gaussian + std::random_device rnd_device; + std::mt19937 bernoulli_rnd(rnd_device()); + std::bernoulli_distribution bernoulli_dist(pnz); + bool bernoulli_result = bernoulli_dist(bernoulli_rnd); + + // Sample from the Gaussian distribution for a nonzero element + if (bernoulli_result) { + if (int_scale >= 0) { + rnd1[0] = double(std::llround(rnd1[0] * double(1 << int_scale))); + rnd1[1] = double(std::llround(rnd1[1] * double(1 << int_scale))); + rnd2[0] = double(std::llround(rnd2[0] * double(1 << int_scale))); + rnd2[1] = double(std::llround(rnd2[1] * double(1 << int_scale))); + + reals[0] = from_real(rnd1[0] / double(1 << int_scale)); + reals[1] = from_real(rnd1[1] / double(1 << int_scale)); + reals[2] = from_real(rnd2[0] / double(1 << int_scale)); + reals[3] = from_real(rnd2[1] / double(1 << int_scale)); + } + else { + reals[0] = from_real(rnd1[0]); + reals[1] = from_real(rnd1[1]); + reals[2] = from_real(rnd2[0]); + reals[3] = from_real(rnd2[1]); + } + } + else { + reals[0] = from_real(0); + reals[1] = from_real(0); + reals[2] = from_real(0); + reals[3] = from_real(0); + } + + // Note that this will invalidate the above else statement because it unsets zero elements + if (exclude_zero && + reals[0] == from_real(0) && + reals[1] == from_real(0) && + reals[2] == from_real(0) && + reals[3] == from_real(0)) { + + if (rnd1[0] > 0.0) { + rnd1[0] += 1.0; + } else { + rnd1[0] -= 1.0; + } + reals[0] = from_real(rnd1[0]); + } + + return Quaternion(reals[0], reals[1], reals[2], reals[3]); + } +}; + +/// Computes a random Gaussian distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillGaussianFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + RandomGaussianFunc func; + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + TensorFillGaussianFunc( + TensorView view_ = TensorView(), + RandomGaussianFunc func_ = RandomGaussianFunc() + ): + view(view_), func(func_) { + + } + + /// Compute random value and update RNG state + void operator()(Coord const &coord) const { + view.at(coord) = func(); + } +}; + +/// Computes a random Gaussian distribution for a rank-2 tensor +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillSymmetricGaussianFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + RandomGaussianFunc func; + cutlass::FillMode fill_mode; + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + TensorFillSymmetricGaussianFunc( + TensorView view_ = TensorView(), + RandomGaussianFunc func_ = RandomGaussianFunc(), + cutlass::FillMode fill_mode_ = cutlass::FillMode::kInvalid + ): + view(view_), func(func_), fill_mode(fill_mode_) { + + } + + /// Compute random value and update RNG state + void operator()(Coord const &coord) const { + // Fill half of matrix based on FillMode + if (Layout::kRank == 2 && + fill_mode == cutlass::FillMode::kLower && + coord[0] >= coord[1]) { + view.at(coord) = func(); + } else if (Layout::kRank == 2 && + fill_mode == cutlass::FillMode::kUpper && + coord[0] <= coord[1]) { + view.at(coord) = func(); + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a Gaussian distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomGaussian( + TensorView dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double mean = 0, ///< Gaussian distribution's mean + double stddev = 1, ///< Gaussian distribution's standard deviation + int bits = -1, ///< If non-negative, specifies number of fractional bits that + double pnz = 1.0, /// are not truncated to zero. Permits reducing precision of + /// data. + bool exclude_zero = false) { ///< Exclude zeros from tensor init. + + detail::RandomGaussianFunc random_func(seed, mean, stddev, bits, pnz, exclude_zero); + + detail::TensorFillGaussianFunc func( + dst, + random_func + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/// Fills a tensor with random values with a Gaussian distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomGaussian( + TensorViewPlanarComplex dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double mean = 0, ///< Gaussian distribution's mean + double stddev = 1, ///< Gaussian distribution's standard deviation + int bits = -1, ///< If non-negative, specifies number of fractional bits that + double pnz = 1.0, /// are not truncated to zero. Permits reducing precision of + /// data. + bool exclude_zero = false) { ///< Exclude zeros from tensor init. + + TensorFillRandomGaussian(dst.view_real(), seed, mean, stddev, bits, pnz); + TensorFillRandomGaussian(dst.view_imag(), ~seed, mean, stddev, bits, pnz); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/// Fills the upper or lower part of a symmetric rank-2 tensor with random values of a Gaussian distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillSymmetricRandomGaussian( + TensorView dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + cutlass::FillMode fill_mode, ///< FillMode for symmetric matrices + double mean = 0, ///< Gaussian distribution's mean + double stddev = 1, ///< Gaussian distribution's standard deviation + int bits = -1, ///< If non-negative, specifies number of fractional bits that + double pnz = 1.0) { /// are not truncated to zero. Permits reducing precision of + /// data. + + detail::RandomGaussianFunc random_func(seed, mean, stddev, bits, pnz); + + detail::TensorFillSymmetricGaussianFunc func( + dst, + random_func, + fill_mode + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values of a Gaussian distribution. +template < + typename Element ///< Element type +> +void BlockFillRandomGaussian( + Element *ptr, ///< destination buffer + size_t capacity, ///< number of elements + uint64_t seed, ///< seed for RNG + double mean = 0, ///< Gaussian distribution's mean + double stddev = 1, ///< Gaussian distribution's standard deviation + int bits = -1, ///< If non-negative, specifies number of fractional bits that + double pnz = 1.0) { /// are not truncated to zero. Permits reducing precision of + /// data. + + + detail::RandomGaussianFunc random_func(seed, mean, stddev, bits, pnz); + + for (size_t i = 0; i < capacity; ++i) { + ReferenceFactory::get(ptr, i) = random_func(); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct RandomUniformFunc { + + using Real = typename RealType::Type; + + uint64_t seed; + double range; + double min; + int int_scale; + + double pnan; +private: + using engine_type = std::mt19937; +public: + engine_type bernoulli_rnd; + std::bernoulli_distribution bernoulli_dist; + + bool exclude_zero; + + RandomUniformFunc( + uint64_t seed_ = 0, + double max = 1, + double min_ = 0, + int int_scale_ = -1, + double pnan_ = 0, + bool exclude_zero_ = false + ): + seed(seed_), range(max - min_), min(min_), int_scale(int_scale_), pnan(pnan_) + , bernoulli_rnd{static_cast(seed_)} + , bernoulli_dist(pnan_) + , exclude_zero(exclude_zero_) + { + std::srand((unsigned)seed); + + // Handle cases where min = 0 or max = 0 for excluding zeros + if (exclude_zero) { + min = (min == 0.0) ? min + 1: min; + range = (max == 0.0) ? range - 1: range; + } + } + + + /// Compute random value and update RNG state + Element operator()() { + + // Sample from NaN distribution. + if constexpr (std::numeric_limits::has_quiet_NaN) { + if (pnan > 0 && bernoulli_dist(bernoulli_rnd)) { + return Element(NAN); + } + } + + double rnd = double(std::rand()) / double(RAND_MAX); + + rnd = min + range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + Element result; + if (int_scale >= 0) { + rnd = double(std::llround(rnd * double(1 << int_scale))) / double(1 << int_scale); + result = static_cast(Real(rnd)); + } + else { + result = static_cast(Real(rnd)); + } + + if (exclude_zero && result == Element(0)) { + if (rnd > 0.0) { + rnd = std::min(min + range, rnd + 1.0); + } else { + rnd = std::max(min, rnd - 1.0); + } + result = static_cast(Real(rnd)); + } + + return result; + } +}; + +/// Partial specialization for initializing a complex value. +template +struct RandomUniformFunc > { + + using Real = typename RealType::Type; + + uint64_t seed; + double range; + double min; + int int_scale; + + double pnan; +private: + using engine_type = std::mt19937; +public: + engine_type bernoulli_rnd; + std::bernoulli_distribution bernoulli_dist; + + bool exclude_zero; + + // + // Methods + // + + RandomUniformFunc( + uint64_t seed_ = 0, + double max = 1, + double min_ = 0, + int int_scale_ = -1, + double pnan_ = 0, + bool exclude_zero_ = false + ): + seed(seed_), range(max - min_), min(min_), int_scale(int_scale_), pnan(pnan_) + , bernoulli_rnd{static_cast(seed_)} + , bernoulli_dist(pnan_) + , exclude_zero(exclude_zero_) { + std::srand((unsigned)seed); + + // Handle cases where min = 0 or max = 0 for excluding zeros + if (exclude_zero) { + min = (min == 0.0) ? min + 1: min; + range = (max == 0.0) ? range - 1: range; + } + } + + + /// Compute random value and update RNG state + complex operator()() { + + // Sample from NaN distribution. + if constexpr (std::numeric_limits::has_quiet_NaN) { + if (pnan > 0 && bernoulli_dist(bernoulli_rnd)) { + return Element(NAN); + } + } + + Element reals[2]; + + for (int i = 0; i < 2; ++i) { + double rnd = double(std::rand()) / double(RAND_MAX); + + rnd = min + range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + + if (int_scale >= 0) { + rnd = double(std::llround(rnd * double(1 << int_scale))); + reals[i] = from_real(Real(rnd / double(1 << int_scale))); + } + else { + reals[i] = from_real(Real(rnd)); + } + + if (exclude_zero && + i == 0 && + reals[0] == from_real(0.0)) { + + if (rnd > 0.0) { + rnd = std::min(min + range, rnd + 1.0); + } else { + rnd = std::max(min, rnd - 1.0); + } + reals[0] = from_real(Real(rnd)); + } + + } + + return complex(reals[0], reals[1]); + } +}; + +/// Partial specialization for initializing a Quaternion value. +template +struct RandomUniformFunc > { + + using Real = typename RealType::Type; + + uint64_t seed; + double range; + double min; + int int_scale; + + double pnan; +private: + using engine_type = std::mt19937; +public: + engine_type bernoulli_rnd; + std::bernoulli_distribution bernoulli_dist; + + // + // Methods + // + + RandomUniformFunc( + uint64_t seed_ = 0, + double max = 1, + double min_ = 0, + int int_scale_ = -1, + double pnan_ = 0 + ): + seed(seed_), range(max - min_), min(min_), int_scale(int_scale_), pnan(pnan_), + bernoulli_rnd{static_cast(seed_)}, + bernoulli_dist(pnan_) + { + std::srand((unsigned)seed); + } + + + /// Compute random value and update RNG state + Quaternion operator()() { + + // Sample from NaN distribution. + if constexpr (std::numeric_limits::has_quiet_NaN) { + if (pnan > 0 && bernoulli_dist(bernoulli_rnd)) { + return Element(NAN); + } + } + + Element reals[4]; + + for (int i = 0; i < 4; ++i) { + double rnd = double(std::rand()) / double(RAND_MAX); + + rnd = min + range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + + if (int_scale >= 0) { + rnd = double(std::llround(rnd * double(1 << int_scale))); + reals[i] = from_real(Real(rnd / double(1 << int_scale))); + } + else { + reals[i] = from_real(Real(rnd)); + } + } + + return make_Quaternion(reals[0], reals[1], reals[2], reals[3]); + } +}; + +/// Computes a random uniform distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillRandomUniformFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + RandomUniformFunc func; + + // + // Methods + // + + /// Construction of uniform RNG functor. + TensorFillRandomUniformFunc( + TensorView view_ = TensorView(), + RandomUniformFunc func_ = RandomUniformFunc() + ): + view(view_), func(func_) { + + } + + /// Compute random value and update RNG state + void operator()(Coord const &coord) { + + view.at(coord) = func(); + } +}; + +/// Fills the upper or lower part of a symmetric rank-2 tensor with random values of a uniform distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillSymmetricRandomUniformFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + RandomUniformFunc func; + cutlass::FillMode fill_mode; + + // + // Methods + // + + /// Construction of uniform RNG functor. + TensorFillSymmetricRandomUniformFunc( + TensorView view_ = TensorView(), + RandomUniformFunc func_ = RandomUniformFunc(), + cutlass::FillMode fill_mode_ = cutlass::FillMode::kInvalid + ): + view(view_), func(func_), fill_mode(fill_mode_) { + + } + + /// Compute random value and update RNG state + void operator()(Coord const &coord) { + // Fill half of matrix based on FillMode + if (Layout::kRank == 2 && + fill_mode == cutlass::FillMode::kLower && + coord[0] >= coord[1]) { + view.at(coord) = func(); + } else if (Layout::kRank == 2 && + fill_mode == cutlass::FillMode::kUpper && + coord[0] <= coord[1]) { + view.at(coord) = func(); + } + } +}; + +/// Computes a random Uniform distribution and pads diagonal with zeros +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillPadDiagonalRandomUniformFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + RandomUniformFunc func; + cutlass::FillMode fill_mode; + int alignment; + + // + // Methods + // + + /// Construction of uniform RNG functor. + TensorFillPadDiagonalRandomUniformFunc( + TensorView view_ = TensorView(), + RandomUniformFunc func_ = RandomUniformFunc(), + cutlass::FillMode fill_mode_ = cutlass::FillMode::kInvalid, + int alignment_ = 1 + ): + view(view_), func(func_), fill_mode(fill_mode_), alignment(alignment_) { + + } + + /// Compute random value and update RNG state + void operator()(Coord const &coord) { + // Fill half of matrix based on FillMode + if (Layout::kRank == 2 && + (fill_mode == cutlass::FillMode::kLower) && + (coord[0] >= coord[1]) || + ((coord[1] - coord[0]) >= alignment)) { + view.at(coord) = func(); + } else if (Layout::kRank == 2 && + fill_mode == cutlass::FillMode::kUpper && + (coord[0] <= coord[1]) || + ((coord[0] - coord[1]) >= alignment)) { + view.at(coord) = func(); + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values of a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomUniform( + TensorView dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + double pnan = 0, ///< Percentage of NaN elements. + bool exclude_zero = false) { ///< Exclude zero from tensor init + detail::RandomUniformFunc random_func(seed, max, min, bits, pnan, exclude_zero); + + detail::TensorFillRandomUniformFunc func( + dst, + random_func + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/// Fills a tensor with random values of a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomUniform( + TensorViewPlanarComplex dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + double pnan = 0, ///< Percentage of NaN elements. + bool exclude_zero = false) { ///< Exclude zero from tensor init + + TensorFillRandomUniform(dst.view_real(), seed, max, min, bits, pnan, exclude_zero); + TensorFillRandomUniform(dst.view_imag(), ~seed, max, min, bits, pnan, exclude_zero); +} + + +/// Fills a tensor with random values with a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomUniform( + TensorView, Layout> dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + detail::RandomUniformFunc> random_func(seed, max, min, bits); + + detail::TensorFillRandomUniformFunc, Layout> func( + dst, + random_func + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillSymmetricRandomUniform( + TensorView dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + cutlass::FillMode fill_mode, ///< FillMode for symmetric matrices + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + + detail::RandomUniformFunc random_func(seed, max, min, bits); + + detail::TensorFillSymmetricRandomUniformFunc func( + dst, + random_func, + fill_mode + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/// Fills a tensor with random values with a uniform random distribution pads zeros along diagonal +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillPadDiagonalRandomUniform( + TensorView dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + cutlass::FillMode fill_mode, ///< FillMode for symmetric matrices + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + int alignment = 1 +) { + + detail::RandomUniformFunc random_func(seed, max, min, bits); + + detail::TensorFillPadDiagonalRandomUniformFunc func( + dst, + random_func, + fill_mode, + alignment + ); + + TensorForEach( + dst.extent(), + func + ); +} +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with a uniform value +template < + typename Element ///< Element type +> +void BlockFill( + Element *ptr, + size_t capacity, + Element val + ) { + for (size_t i = 0; i < capacity; ++i) { + ReferenceFactory::get(ptr, i) = val; + } +} + +/// Fills a tensor with random values with a uniform random distribution. +template < + typename Element ///< Element type +> +void BlockFillRandomUniform( + Element *ptr, + size_t capacity, + uint64_t seed, ///< seed for RNG + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + double pnan = 0) { ///< Percentage of NaN elements. + detail::RandomUniformFunc random_func(seed, max, min, bits, pnan); + + for (size_t i = 0; i < capacity; ++i) { + ReferenceFactory::get(ptr, i) = random_func(); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillDiagonalFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + Element diag; + Element other; + + // + // Methods + // + + TensorFillDiagonalFunc( + TensorView const &view_ = TensorView(), + Element diag_ = Element(1), + Element other_ = Element(0) + ): + view(view_), diag(diag_), other(other_) { } + + void operator()(Coord const & coord) const { + bool is_diag = true; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i] != coord[i - 1]) { + is_diag = false; + break; + } + } + + view.at(coord) = (is_diag ? diag : other); + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor everywhere with a unique value for its diagonal. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillDiagonal( + TensorView dst, ///< destination tensor + Element diag = Element(1), ///< value to write in the diagonal + Element other = Element(0)) { ///< value to write off the diagonal + + detail::TensorFillDiagonalFunc func( + dst, + diag, + other + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to fill a tensor's diagonal with 1 and 0 everywhere else. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillIdentity( + TensorView dst) { ///< destination tensor + + TensorFillDiagonal(dst, Element(1), Element(0)); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Writes a uniform value to the diagonal of a tensor without modifying off-diagonal elements. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorUpdateDiagonal( + TensorView dst, ///< destination tensor + Element val = Element(1)) { + + typename Layout::Index extent = dst.extent().min(); + + for (typename Layout::Index i = 0; i < extent; ++i) { + Coord coord(i); + dst.at(coord) = val; + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorUpdateOffDiagonalFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + Element other; + + // + // Methods + // + + TensorUpdateOffDiagonalFunc( + TensorView const &view_ = TensorView(), + Element other_ = Element(0) + ): + view(view_), other(other_) { } + + void operator()(Coord const & coord) const { + bool is_diag = true; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i] != coord[i - 1]) { + is_diag = false; + break; + } + } + + if (!is_diag) { + view.at(coord) = other; + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Writes a uniform value to all elements in the tensor without modifying diagonal elements. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorUpdateOffDiagonal( + TensorView dst, ///< destination tensor + Element other = Element(1)) { + + detail::TensorUpdateOffDiagonalFunc func( + dst, + other + ); + + TensorForEach( + dst.extent(), + func + ); +} + + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillLinearFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + Array v; + Element s; + + // + // Methods + // + + TensorFillLinearFunc() { } + + /// Constructs functor + TensorFillLinearFunc( + TensorView const &view_, + Array const & v_, + Element s_ = Element(0) + ): + view(view_), v(v_), s(s_) { } + + /// Updates the tensor + void operator()(Coord const & coord) const { + + Element sum(s); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Layout::kRank; ++i) { + sum += Element(coord[i]) * v[i]; + } + + view.at(coord) = sum; + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills tensor with a linear combination of its coordinate and another vector +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillLinear( + TensorView dst, ///< destination tensor + Array const & v, + Element s = Element(0)) { + + detail::TensorFillLinearFunc func( + dst, + v, + s + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills tensor with a linear combination of its coordinate and another vector +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillSequential( + TensorView dst, ///< destination tensor + Element s = Element(0)) { + + Array stride; + + stride[0] = Element(1); + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + stride[i] = stride[i - 1] * Element(dst.extent()[i - 1]); + } + + TensorFillLinear(dst, stride, s); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values from a distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandom( + TensorView view, ///< destination tensor + uint64_t seed, + Distribution dist, + bool exclude_zero = false ///< If true, excludes 0. + /// Note that setting this flag will result in more 1's, + /// as we use a simple mechanism to replace 0's by adding/subtracting 1's. +) { + + using Real = typename RealType::Type; + + if (dist.kind == Distribution::Gaussian) { + TensorFillRandomGaussian( + view, + seed, + dist.gaussian.mean, + dist.gaussian.stddev, + dist.int_scale, + dist.gaussian.pnz, + exclude_zero); + } else if (dist.kind == Distribution::Uniform) { + TensorFillRandomUniform( + view, + seed, + dist.uniform.max, + dist.uniform.min, + dist.int_scale, + dist.uniform.pnan, + exclude_zero); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a block of data with sequential elements +template < + typename Element +> +void BlockFillSequential( + Element *ptr, + int64_t capacity, + Element v = Element(1), + Element s = Element(0)) { + int i = 0; + + while (i < capacity) { + cutlass::ReferenceFactory::value < + 8)>::get(ptr, i) = s; + + s = Element(s + v); + ++i; + } +} + +/// Fills a block of data with sequential elements +template < + typename Element +> +void BlockFillSequentialModN( + Element *ptr, + int64_t capacity, + int64_t mod, + int64_t v = int64_t(1), + int64_t s = int64_t(0)) { + int i = 0; + + while (i < capacity) { + cutlass::ReferenceFactory::value < + 8)>::get(ptr, i) = Element(s); + + s = int64_t(s + v) % mod; + ++i; + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a block of data with sequential elements +template < + typename Element +> +void BlockFillRandom( + Element *ptr, + size_t capacity, + uint64_t seed, + Distribution dist) { + + if (dist.kind == Distribution::Gaussian) { + BlockFillRandomGaussian( + ptr, + capacity, + seed, + dist.gaussian.mean, + dist.gaussian.stddev, + dist.int_scale, + dist.gaussian.pnz); + } + else if (dist.kind == Distribution::Uniform) { + BlockFillRandomUniform( + ptr, + capacity, + seed, + dist.uniform.max, + dist.uniform.min, + dist.int_scale, + dist.uniform.pnan); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct RandomSparseMetaFunc { + + uint64_t seed; + int range; + int MetaSizeInBits; + + // + // Methods + // + + RandomSparseMetaFunc( + uint64_t seed_ = 0, + int MetaSizeInBits_ = 2 + ): + seed(seed_), MetaSizeInBits(MetaSizeInBits_) { + std::srand((unsigned)seed); + if (MetaSizeInBits_ == 2) { + range = 6; + } + else if (MetaSizeInBits_ == 4) { + range = 2; + } + else { + throw std::invalid_argument("Invalid MetaSizeInBits"); + } + } + + /// Compute random value and update RNG state + Element operator()() const { + Element FourToTwoMeta[6] = {0x4, 0x8, 0x9, 0xc, 0xd, 0xe}; + Element TwoToOneMeta[2] = {0x4, 0xe}; + + Element * MetaArray = (MetaSizeInBits == 2) ? FourToTwoMeta : TwoToOneMeta; + + Element result = 0x0; + + for (int i = 0; i < cutlass::sizeof_bits::value / 4; ++i) { + int rnd = std::rand() % range; + Element meta = MetaArray[rnd]; + + result = (Element)(result | ((Element)(meta << (i * 4)))); + } + + return result; + } +}; + +/// Computes a random sparse meta +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillRandomSparseMetaFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + RandomSparseMetaFunc func; + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + TensorFillRandomSparseMetaFunc( + TensorView view_ = TensorView(), + RandomSparseMetaFunc func_ = RandomSparseMetaFunc() + ): + view(view_), func(func_) { + + } + + /// Compute random value and update RNG state + void operator()(Coord const &coord) const { + + view.at(coord) = func(); + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomSparseMeta( + TensorView dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + int MetaSizeInBits) { ///< 2 bit or 4 bit + + detail::RandomSparseMetaFunc random_func(seed, MetaSizeInBits); + + detail::TensorFillRandomSparseMetaFunc func( + dst, + random_func + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. +template < + typename Element ///< Element type +> +void BlockFillRandomSparseMeta( + Element *ptr, + size_t capacity, + uint64_t seed, ///< seed for RNG + int MetaSizeInBits) { ///< 2 bit or 4bit + + detail::RandomSparseMetaFunc random_func(seed, MetaSizeInBits); + + for (size_t i = 0; i < capacity; ++i) { + ptr[i] = random_func(); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a ell block index matrix with random values with a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomEllIdx( + TensorView dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + int rows, int ell_cols, int cols) { ///< dimension of the matrix + + std::srand((unsigned)seed); + + for (int i = 0; i < rows; ++i) { + int col_idx = std::rand() % cols; + + for (int j = 0; j < ell_cols; ++j) { + dst.at({i, j}) = col_idx; + + if (col_idx != -1) { + if (col_idx == (cols - 1)) { + col_idx = -1; + } else { + col_idx = std::rand() % (cols - col_idx - 1) + col_idx + 1; + } + } + } + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies a diagonal in from host memory without modifying off-diagonal elements. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorCopyDiagonalIn( + TensorView dst, ///< destination tensor + Element const *ptr) { ///< dense buffer of elements + + typename Layout::Index extent = dst.extent().min(); + + for (typename Layout::Index i = 0; i < extent; ++i) { + Coord coord(i); + dst.at(coord) = ReferenceFactory::get(ptr, i); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies the diagonal of a tensor into a dense buffer in host memory. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorCopyDiagonalOut( + Element *ptr, ///< dense buffer of elements + TensorView src) { ///< source tensor + + typename Layout::Index extent = src.extent().min(); + + for (typename Layout::Index i = 0; i < extent; ++i) { + Coord coord(i); + ReferenceFactory::get(ptr, i) = src.at(coord); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1b3df239a1b9d69fc12e7ec4be2de6f87b3a0e3c --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.hpp @@ -0,0 +1,432 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Provides several functions for filling tensors with data. +*/ + +#pragma once + +// Standard Library includes +#include +#include +#include + +// Cute includes +#include "cute/tensor.hpp" + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/quaternion.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Uniform and procedural tensor fills +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with a scalar element +template +void TensorFill(Tensor dst, typename Tensor::value_type element) { + + for (int64_t idx = 0; idx < cute::size(dst); ++idx) { + dst(idx) = element; + } +} + +/// Fills a tensor with the contents of its layout +template +void TensorFillSequential(Tensor dst) { + + auto layout = dst.layout(); + + for (int64_t idx = 0; idx < cute::size(dst); ++idx) { + dst(idx) = layout(idx); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Random uniform values +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct RandomUniformFunc { + + using Real = typename RealType::Type; + + uint64_t seed; + double range; + double min; + int int_scale; + + // + // Methods + // + + RandomUniformFunc( + uint64_t seed_ = 0, + double max = 1, + double min_ = 0, + int int_scale_ = -1 + ): + seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) { + std::srand((unsigned)seed); + } + + + /// Compute random value and update RNG state + Element operator()() const { + + double rnd = double(std::rand()) / double(RAND_MAX); + + rnd = min + range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + Element result; + + if (int_scale >= 0) { + rnd = double(int64_t(rnd * double(1 << int_scale))) / double(1 << int_scale); + result = static_cast(Real(rnd)); + } + else { + result = static_cast(Real(rnd)); + } + + return result; + } +}; + +/// Partial specialization for initializing a complex value. +template +struct RandomUniformFunc > { + + using Real = typename RealType::Type; + + uint64_t seed; + double range; + double min; + int int_scale; + + // + // Methods + // + + RandomUniformFunc( + uint64_t seed_ = 0, + double max = 1, + double min_ = 0, + int int_scale_ = -1 + ): + seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) { + std::srand((unsigned)seed); + } + + + /// Compute random value and update RNG state + complex operator()() const { + + Element reals[2]; + + for (int i = 0; i < 2; ++i) { + double rnd = double(std::rand()) / double(RAND_MAX); + + rnd = min + range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + + if (int_scale >= 0) { + rnd = double(int(rnd * double(1 << int_scale))); + reals[i] = from_real(Real(rnd / double(1 << int_scale))); + } + else { + reals[i] = from_real(Real(rnd)); + } + } + + return complex(reals[0], reals[1]); + } +}; + +/// Partial specialization for initializing a Quaternion value. +template +struct RandomUniformFunc > { + + using Real = typename RealType::Type; + + uint64_t seed; + double range; + double min; + int int_scale; + + // + // Methods + // + + RandomUniformFunc( + uint64_t seed_ = 0, + double max = 1, + double min_ = 0, + int int_scale_ = -1 + ): + seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) { + std::srand((unsigned)seed); + } + + + /// Compute random value and update RNG state + Quaternion operator()() const { + + Element reals[4]; + + for (int i = 0; i < 4; ++i) { + double rnd = double(std::rand()) / double(RAND_MAX); + + rnd = min + range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + + if (int_scale >= 0) { + rnd = double(int(rnd * double(1 << int_scale))); + reals[i] = from_real(Real(rnd / double(1 << int_scale))); + } + else { + reals[i] = from_real(Real(rnd)); + } + } + + return make_Quaternion(reals[0], reals[1], reals[2], reals[3]); + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. +template ///< Tensor object +void TensorFillRandomUniform( + Tensor dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + + detail::RandomUniformFunc random_func(seed, max, min, bits); + + for (int64_t idx = 0; idx < cute::size(dst); ++idx) { + dst(idx) = random_func(); + } +} + +/// Fills a block with random values with a uniform random distribution. +template < + typename Element ///< Element type +> +void BlockFillRandomUniform( + Element *ptr, + size_t capacity, + uint64_t seed, ///< seed for RNG + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + detail::RandomUniformFunc random_func(seed, max, min, bits); + + for (size_t i = 0; i < capacity; ++i) { + ptr[i] = random_func(); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Random Gaussian +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct RandomGaussianFunc { + + uint64_t seed; + double mean; + double stddev; + int int_scale; + double pi; + + // + // Methods + // + RandomGaussianFunc( + uint64_t seed_ = 0, + double mean_ = 0, + double stddev_ = 1, + int int_scale_ = -1 + ): + seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)) { + std::srand((unsigned)seed); + } + + /// Compute random value and update RNG state + Element operator()() const { + + // Box-Muller transform to generate random numbers with Normal distribution + double u1 = double(std::rand()) / double(RAND_MAX); + double u2 = double(std::rand()) / double(RAND_MAX); + + // Compute Gaussian random value + double rnd = std::sqrt(-2 * std::log(u1)) * std::cos(2 * pi * u2); + rnd = mean + stddev * rnd; + + // Scale and convert final result + Element result; + + if (int_scale >= 0) { + rnd = double(int64_t(rnd * double(1 << int_scale))) / double(1 << int_scale); + result = static_cast(rnd); + } + else { + result = static_cast(rnd); + } + + return result; + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a Gaussian distribution. +template < + typename Tensor +> +void TensorFillRandomGaussian( + Tensor dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double mean = 0, ///< Gaussian distribution's mean + double stddev = 1, ///< Gaussian distribution's standard deviation + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + + detail::RandomGaussianFunc random_func(seed, mean, stddev, bits); + + for (int64_t idx = 0; idx < cute::size(dst); ++idx) { + dst(idx) = random_func(); + } +} + +/// Fills a block with random values with a Gaussian distribution. +template < + typename Element ///< Element type +> +void BlockFillRandomGaussian( + Element *ptr, ///< destination buffer + size_t capacity, ///< number of elements + uint64_t seed, ///< seed for RNG + double mean = 0, ///< Gaussian distribution's mean + double stddev = 1, ///< Gaussian distribution's standard deviation + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + + detail::RandomGaussianFunc random_func(seed, mean, stddev, bits); + + for (size_t i = 0; i < capacity; ++i) { + ptr[i] = random_func(); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a block of data with sequential elements +template < + typename Element +> +void BlockFillSequential( + Element *ptr, + int64_t capacity, + Element v = Element(1), + Element s = Element(0)) { + int i = 0; + + while (i < capacity) { + + ptr[i] = Element(s + v); + ++i; + } +} + +/// Fills a block of data with sequential elements +template < + typename Element +> +void BlockFillSequentialModN( + Element *ptr, + int64_t capacity, + int64_t mod, + int64_t v = int64_t(1), + int64_t s = int64_t(0)) { + int i = 0; + + while (i < capacity) { + + ptr[i] = static_cast(int32_t(int64_t(s + v) % mod)); + ++i; + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_foreach.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_foreach.h new file mode 100644 index 0000000000000000000000000000000000000000..bcb1af995805e3fbcbdbf398ce7191ea2f0dbe8d --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_foreach.h @@ -0,0 +1,134 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include "cutlass/cutlass.h" + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines several helpers +namespace detail { + +/// Helper to perform for-each operation +template +struct TensorForEachHelper { + + /// Index of the active rank + static int const kActiveRank = Rank - RankRemaining - 1; + + /// Constructor for general rank + TensorForEachHelper( + Func &func, + Coord const &extent, + Coord &coord) { + + for (int i = 0; i < extent.at(kActiveRank); ++i) { + coord[kActiveRank] = i; + TensorForEachHelper(func, extent, coord); + } + } +}; + +/// Helper to perform for-each operation +template +struct TensorForEachHelper { + + /// Index of the active rank + static int const kActiveRank = Rank - 1; + + /// Constructor for fastest changing rank + TensorForEachHelper( + Func &func, + Coord const &extent, + Coord &coord) { + + for (int i = 0; i < extent.at(kActiveRank); ++i) { + coord[kActiveRank] = i; + func(coord); + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Iterates over the index space of a tensor +template < + typename Func, ///< function applied to each point in a tensor's index space + int Rank> ///< rank of index space +void TensorForEach(Coord extent, Func & func) { + Coord coord; + detail::TensorForEachHelper(func, extent, coord); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Iterates over the index space of a tensor and calls a C++ lambda +template < + typename Func, ///< function applied to each point in a tensor's index space + int Rank> ///< rank of index space +void TensorForEachLambda(Coord extent, Func func) { + Coord coord; + detail::TensorForEachHelper(func, extent, coord); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct BlockForEach { + + /// Constructor performs the operation. + BlockForEach( + Element *ptr, + size_t capacity, + typename Func::Params params = typename Func::Params()) { + + Func func(params); + + for (size_t index = 0; index < capacity; ++index) { + ptr[index] = func(); + } + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_norm.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_norm.h new file mode 100644 index 0000000000000000000000000000000000000000..d44dda1f5472f13b7212f7e2e4020e254ff92f88 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_norm.h @@ -0,0 +1,42 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + + +#include "cutlass/cutlass.h" + +// The contents of this file have been moved to 'tensor_reduce' to cover other types of reductions. + +#include "cutlass/util/reference/host/tensor_reduce.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + + diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.h new file mode 100644 index 0000000000000000000000000000000000000000..887c568059a90f749fc0ac75dd211ce77085a5a9 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.h @@ -0,0 +1,203 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/util/reference/detail/linear_to_coordinate.h" +#include "cutlass/core_io.h" + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side +/// workspace +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorView view, + ComputeType identity, + ReduceOp reduce, + TransformOp transform +) { + + for (int64_t idx = 0; idx < int64_t(view.size()); ++idx) { + typename Layout::TensorCoord coord; + cutlass::reference::detail::LinearToCoordinate()(coord, idx, view.extent()); + + if (view.contains(coord)) { + Element x = view.at(coord); + identity = reduce(identity, transform(x)); + } + } + + return identity; +} + +/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side +/// workspace +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorView view_A, + TensorView view_B, + ComputeType identity, + ReduceOp reduce, + TransformOp transform) { + + if (view_A.extent() != view_B.extent()) { + throw std::runtime_error("Tensor extents must match."); + } + + for (int64_t idx = 0; idx < int64_t(view_A.size()); ++idx) { + + typename Layout::TensorCoord coord; + cutlass::reference::detail::LinearToCoordinate()(coord, idx, view_A.extent()); + + if (view_A.contains(coord)) { + Element a = view_A.at(coord); + Element b = view_B.at(coord); + identity = reduce(identity, transform(a, b)); + } + } + + return identity; +} + +/// Helper to compute the sum of the elements of a tensor +template < + typename Element, + typename Layout, + typename ComputeType = Element +> +ComputeType TensorSum( + TensorView view, + ComputeType identity = ComputeType() +) { + + plus reduce; + NumericConverter transform; + + return TensorTransformReduce( + view, identity, reduce, transform); +} + +/// Helper to compute the sum of the squares of the elements of a tensor +template < + typename Element, + typename Layout, + typename ComputeType = Element +> +ComputeType TensorSumSq( + TensorView view, + ComputeType identity = ComputeType() +) { + + plus reduce; + magnitude_squared transform; + + return TensorTransformReduce( + view, identity, reduce, transform); +} + +/// Helper to compute the norm of the elements of a tensor. +template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorNorm( + TensorView view, + ComputeType identity = ComputeType() +) { + + return std::sqrt(TensorSumSq(view, identity)); +} + +/// Helper to compute the sum of the squares of the differences of two tensors +template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorSumSqDiff( + TensorView view_A, + TensorView view_B, + ComputeType identity = ComputeType() +) { + + plus reduce; + magnitude_squared_difference transform; + + return TensorTransformReduce( + view_A, view_B, identity, reduce, transform); +} + + +/// Helper to compute the norm of the tensor computed as the difference of two tensors in memory +template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorNormDiff( + TensorView view_A, + TensorView view_B, + ComputeType identity = ComputeType() +) { + + return std::sqrt(TensorSumSqDiff(view_A, view_B, identity)); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.hpp b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ea711466df86703aae1702605a928754c9f4e944 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.hpp @@ -0,0 +1,203 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Provides several functions for filling tensors with data. +*/ + +#pragma once + +// Standard Library includes +#include +#include +#include + +// Cute includes +#include "cute/tensor.hpp" + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/quaternion.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Tensor reductions +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side +/// workspace +template < + typename Tensor, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + Tensor view, + ComputeType identity, + ReduceOp reduce, + TransformOp transform +) { + + for (int64_t idx = 0; idx < cute::size(view); ++idx) { + identity = reduce(identity, transform(view(idx))); + } + + return identity; +} + +/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side +/// workspace +template < + typename TensorA, + typename TensorB, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorA view_A, + TensorB view_B, + ComputeType identity, + ReduceOp reduce, + TransformOp transform) { + + if (cute::size(view_A) != cute::size(view_B)) { + throw std::runtime_error("Tensor sizes must match."); + } + + for (int64_t idx = 0; idx < cute::size(view_A); ++idx) { + identity = reduce(identity, transform(view_A(idx), view_B(idx))); + } + + return identity; +} + +/// Helper to compute the sum of the elements of a tensor +template < + typename Tensor, + typename ComputeType = typename Tensor::value_type +> +ComputeType TensorSum( + Tensor view, + ComputeType identity = ComputeType() +) { + + plus reduce; + NumericConverter transform; + + return TensorTransformReduce( + view, identity, reduce, transform); +} + +/// Helper to compute the sum of the squares of the elements of a tensor +template < + typename Tensor, + typename ComputeType = typename Tensor::value_type +> +ComputeType TensorSumSq( + Tensor view, + ComputeType identity = ComputeType() +) { + + plus reduce; + magnitude_squared transform; + + return TensorTransformReduce( + view, identity, reduce, transform); +} + +/// Helper to compute the norm of the elements of a tensor. +template < + typename Tensor, + typename ComputeType = double +> +ComputeType TensorNorm( + Tensor view, + ComputeType identity = ComputeType() +) { + + return std::sqrt(TensorSumSq(view, identity)); +} + +/// Helper to compute the sum of the squares of the differences of two tensors +template < + typename TensorA, + typename TensorB, + typename ComputeType = double +> +ComputeType TensorSumSqDiff( + TensorA view_A, + TensorB view_B, + ComputeType identity = ComputeType() +) { + + plus reduce; + magnitude_squared_difference transform; + + return TensorTransformReduce( + view_A, view_B, identity, reduce, transform); +} + + +/// Helper to compute the norm of the tensor computed as the difference of two tensors in memory +template < + typename TensorA, + typename TensorB, + typename ComputeType = double +> +ComputeType TensorNormDiff( + TensorA view_A, + TensorB view_B, + ComputeType identity = ComputeType() +) { + + return std::sqrt(TensorSumSqDiff(view_A, view_B, identity)); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/trmm.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/trmm.h new file mode 100644 index 0000000000000000000000000000000000000000..09b1aff9c0ea9922af46c928a3dd61595be2e4cd --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/trmm.h @@ -0,0 +1,215 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for TRMM in host-side code. + + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/arch/mma.h" +#include "cutlass/util/host_tensor.h" + +#include "cutlass/util/reference/host/gemm.h" + +namespace cutlass { +namespace reference { +namespace host { + +/// Computes a Triangular Matrix Multiplication (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename ElementA, + typename LayoutA, + SideMode SideModeA, + FillMode FillModeA, + DiagType DiagTypeA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_trmm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + TensorRef tensor_d, + ComputeType initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + static_assert(SideModeA != SideMode::kInvalid + , "Side Mode can either be Left or Right."); + + static_assert(FillModeA == FillMode::kLower || FillModeA == FillMode::kUpper + , "Fill Mode can either be Lower or Upper."); + + using CompareOp = typename TrMatrixCompareOp::Type; + + // Note: batch is ignored. + int const M = problem_size.m(); + int const N = problem_size.n(); + // Assuming correct k-dimension value is passed + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + CompareOp compare_op; + + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + ElementA a = ElementA(); + ElementB b = ElementB(); + + if (SideModeA == SideMode::kLeft) { + a = (compare_op(row, k_block)) ? + (tensor_a.at(MatrixCoord(row, k_block))) : ElementA(0); + if (row == k_block && DiagTypeA == DiagType::kUnit) { + a = ElementA(1); + } + b = tensor_b.at(MatrixCoord(k_block, col)); + } else if (SideModeA == SideMode::kRight) { + a = tensor_b.at(MatrixCoord(row, k_block)); + b = (compare_op(k_block, col)) ? + tensor_a.at(MatrixCoord(k_block, col)) : ElementA(0); + if (k_block == col && DiagTypeA == DiagType::kUnit) { + b = ElementA(1); + } + } + + ComputeType compute_a(cast_if_scalar(a)); + ComputeType compute_b(cast_if_scalar(b)); + + accum[i][j] = inner_product_op(compute_a, compute_b, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j])); + } + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + SideMode SideModeA, + FillMode FillModeA, + DiagType DiagTypeA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = cutlass::arch::OpMultiplyAdd +> +struct Trmm; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct Trmm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_trmm>( + problem_size, alpha, tensor_a, tensor_b, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/trmm_complex.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/trmm_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..e8db2a4deaf8608882595d68e611f8ae79e134e8 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/trmm_complex.h @@ -0,0 +1,262 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued TRMM in host-side code. + + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/util/reference/host/gemm.h" + +namespace cutlass { +namespace reference { +namespace host { + +/// Computes a Triangular Matrix Multiplication (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename ElementA, + typename LayoutA, + ComplexTransform TransformA, + SideMode SideModeA, + FillMode FillModeA, + DiagType DiagTypeA, + typename ElementB, + typename LayoutB, + ComplexTransform TransformB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_trmm_complex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + TensorRef tensor_d, + ComputeType initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + static_assert(SideModeA != SideMode::kInvalid + , "Side Mode can either be Left or Right."); + + static_assert(FillModeA == FillMode::kLower || FillModeA == FillMode::kUpper + , "Fill Mode can either be Lower or Upper."); + + using CompareOp = typename TrMatrixCompareOp::Type; + + // Note: batch is ignored. + int const M = problem_size.m(); + int const N = problem_size.n(); + // Assuming correct k-dimension value is passed + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + CompareOp compare_op; + + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + ElementA a = ElementA(); + ElementB b = ElementB(); + + if (SideModeA == SideMode::kLeft) { + a = (compare_op(row, k_block)) ? + (tensor_a.at(MatrixCoord(row, k_block))) : ElementA(0); + if (row == k_block && DiagTypeA == DiagType::kUnit) { + a = ElementA(1); + } + b = tensor_b.at(MatrixCoord(k_block, col)); + } else if (SideModeA == SideMode::kRight) { + a = tensor_b.at(MatrixCoord(row, k_block)); + b = (compare_op(k_block, col)) ? + tensor_a.at(MatrixCoord(k_block, col)) : ElementA(0); + if (k_block == col && DiagTypeA == DiagType::kUnit) { + b = ElementA(1); + } + } + + ComputeType a_ik = ComputeType(a); + ComputeType b_kj = ComputeType(b); + + // Conjugate, and hence hermitian, is only allowed for the triangular matrix + if (SideModeA == SideMode::kLeft && TransformA == ComplexTransform::kConjugate) { + a_ik = conj(a_ik); + } else if (SideModeA == SideMode::kRight && TransformA == ComplexTransform::kConjugate) { + b_kj = conj(b_kj); + } + + accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j])); + } + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + ComplexTransform TransformA, + SideMode SideModeA, + FillMode FillModeA, + DiagType DiagTypeA, + typename ElementB, + typename LayoutB, + ComplexTransform TransformB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = cutlass::arch::OpMultiplyAddComplex +> +struct TrmmComplex; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct TrmmComplex { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_trmm_complex>( + problem_size, alpha, tensor_a, tensor_b, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for gaussian multiply-add +template +struct TrmmComplex { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_trmm_complex>( + problem_size, alpha, tensor_a, tensor_b, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/tensor_view_io.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/tensor_view_io.h new file mode 100644 index 0000000000000000000000000000000000000000..0ce1d8a65fdd66ace69f91525b678dd6ad132d24 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/tensor_view_io.h @@ -0,0 +1,270 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +* +**************************************************************************************************/ +#pragma once + +#include "cutlass/core_io.h" +#include "cutlass/tensor_view.h" +#include "cutlass/tensor_view_planar_complex.h" +#include "cutlass/complex.h" + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Helper to write the least significant rank of a TensorView +template < + typename Element, + typename Layout +> +inline std::ostream & TensorView_WriteLeastSignificantRank( + std::ostream& out, + TensorView const& view, + Coord const &start_coord, + int rank, + std::streamsize width) { + + for (int idx = 0; idx < view.extent(rank); ++idx) { + + Coord coord(start_coord); + coord[rank] = idx; + + if (idx) { + out.width(0); + out << ", "; + } + if (idx || coord) { + out.width(width); + } + out << ScalarIO(view.at(coord)); + } + + return out; +} + +/// Helper to write a rank of a TensorView +template < + typename Element, + typename Layout +> +inline std::ostream & TensorView_WriteRank( + std::ostream& out, + TensorView const& view, + Coord const &start_coord, + int rank, + std::streamsize width) { + + // If called on the least significant rank, write the result as a row + if (rank + 1 == Layout::kRank) { + return TensorView_WriteLeastSignificantRank(out, view, start_coord, rank, width); + } + + // Otherwise, write a sequence of rows and newlines + for (int idx = 0; idx < view.extent(rank); ++idx) { + + Coord coord(start_coord); + coord[rank] = idx; + + if (rank + 2 == Layout::kRank) { + // Write least significant ranks asa matrix with rows delimited by "\n" + if (idx) { + out << ",\n"; + } + TensorView_WriteLeastSignificantRank(out, view, coord, rank + 1, width); + } + else { + // Higher ranks are separated by newlines + if (idx) { + out << ",\n\n"; + } + TensorView_WriteRank(out, view, coord, rank + 1, width); + } + } + + return out; +} + +/// Helper to write the least significant rank of a TensorView +template < + typename Element, + typename Layout +> +inline std::ostream & TensorViewPlanarComplex_WriteLeastSignificantRank( + std::ostream& out, + TensorViewPlanarComplex const& view, + Coord const &start_coord, + int rank, + std::streamsize width) { + + for (int idx = 0; idx < view.extent(rank); ++idx) { + + Coord coord(start_coord); + coord[rank] = idx; + + if (idx) { + out.width(0); + out << ", "; + } + if (idx || coord) { + out.width(width); + } + + complex x = view.at(coord); + out << x; + } + + return out; +} + +/// Helper to write a rank of a TensorView +template < + typename Element, + typename Layout +> +inline std::ostream & TensorViewPlanarComplex_WriteRank( + std::ostream& out, + TensorViewPlanarComplex const& view, + Coord const &start_coord, + int rank, + std::streamsize width) { + + // If called on the least significant rank, write the result as a row + if (rank + 1 == Layout::kRank) { + return TensorViewPlanarComplex_WriteLeastSignificantRank(out, view, start_coord, rank, width); + } + + // Otherwise, write a sequence of rows and newlines + for (int idx = 0; idx < view.extent(rank); ++idx) { + + Coord coord(start_coord); + coord[rank] = idx; + + if (rank + 2 == Layout::kRank) { + // Write least significant ranks asa matrix with rows delimited by ";\n" + if (idx) { + out << ";\n"; + } + TensorViewPlanarComplex_WriteLeastSignificantRank(out, view, coord, rank + 1, width); + } + else { + // Higher ranks are separated by newlines + if (idx) { + out << "\n"; + } + TensorViewPlanarComplex_WriteRank(out, view, coord, rank + 1, width); + } + } + + return out; +} + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Prints human-readable representation of a TensorView to an ostream +template < + typename Element, + typename Layout +> +inline std::ostream& TensorViewWrite( + std::ostream& out, + TensorView const& view) { + + // Prints a TensorView according to the following conventions: + // - least significant rank is printed as rows separated by ";\n" + // - all greater ranks are delimited with newlines + // + // The result is effectively a whitespace-delimited series of 2D matrices. + + return detail::TensorView_WriteRank(out, view, Coord(), 0, out.width()); +} + +/// Prints human-readable representation of a TensorView to an ostream +template < + typename Element, + typename Layout +> +inline std::ostream& operator<<( + std::ostream& out, + TensorView const& view) { + + // Prints a TensorView according to the following conventions: + // - least significant rank is printed as rows separated by ";\n" + // - all greater ranks are delimited with newlines + // + // The result is effectively a whitespace-delimited series of 2D matrices. + + return TensorViewWrite(out, view); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Prints human-readable representation of a TensorView to an ostream +template < + typename Element, + typename Layout +> +inline std::ostream& TensorViewWrite( + std::ostream& out, + TensorViewPlanarComplex const& view) { + + // Prints a TensorView according to the following conventions: + // - least significant rank is printed as rows separated by ";\n" + // - all greater ranks are delimited with newlines + // + // The result is effectively a whitespace-delimited series of 2D matrices. + + return detail::TensorViewPlanarComplex_WriteRank(out, view, Coord(), 0, out.width()); +} + +/// Prints human-readable representation of a TensorView to an ostream +template < + typename Element, + typename Layout +> +inline std::ostream& operator<<( + std::ostream& out, + TensorViewPlanarComplex const& view) { + + // Prints a TensorView according to the following conventions: + // - least significant rank is printed as rows separated by ";\n" + // - all greater ranks are delimited with newlines + // + // The result is effectively a whitespace-delimited series of 2D matrices. + + return TensorViewWrite(out, view); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/type_traits.h b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/type_traits.h new file mode 100644 index 0000000000000000000000000000000000000000..5dfbfe274dec368cfac291a1c78ece6ffb203c72 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/type_traits.h @@ -0,0 +1,238 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Type traits for common CUDA types +*/ + +#pragma once + +#include +#include +#include + +#include "cutlass/numeric_types.h" +#include "cutlass/complex.h" + +namespace cutlass { +struct half_t; + +template +struct TypeTraits { + typedef T host_type; + typedef T device_type; + static inline T remove_negative_zero(T x) { return x; } + static inline T to_print(T x) { return x; } + static inline device_type to_device(host_type x) { return x; } +}; + +template <> +struct TypeTraits { + static cudaDataType_t const cublas_type = CUDA_R_8I; + typedef int8_t host_type; + typedef int8_t device_type; + typedef int8_t integer_type; + typedef uint8_t unsigned_type; + static inline int8_t remove_negative_zero(int8_t x) { return x; } + static inline int to_print(int8_t x) { return (int)x; } + static inline device_type to_device(host_type x) { return x; } +}; + +template <> +struct TypeTraits { + static cudaDataType_t const cublas_type = CUDA_R_8I; + typedef uint8_t host_type; + typedef uint8_t device_type; + typedef uint8_t integer_type; + typedef uint8_t unsigned_type; + static inline uint8_t remove_negative_zero(uint8_t x) { return x; } + static inline uint32_t to_print(uint8_t x) { return (uint32_t)x; } + static inline device_type to_device(host_type x) { return x; } +}; + +template <> +struct TypeTraits { + static cudaDataType_t const cublas_type = CUDA_R_32I; + typedef int host_type; + typedef int device_type; + typedef int32_t integer_type; + typedef uint32_t unsigned_type; + static inline int32_t remove_negative_zero(int32_t x) { return x; } + static inline int to_print(int x) { return x; } + static inline device_type to_device(host_type x) { return x; } +}; + +template <> +struct TypeTraits { + static cudaDataType_t const cublas_type = CUDA_R_32I; + typedef unsigned host_type; + typedef unsigned device_type; + typedef uint32_t integer_type; + typedef uint32_t unsigned_type; + static inline uint32_t remove_negative_zero(uint32_t x) { return x; } + static inline uint32_t to_print(uint32_t x) { return x; } + static inline device_type to_device(host_type x) { return x; } +}; + +template <> +struct TypeTraits { + static cudaDataType_t const cublas_type = CUDA_R_8I; + typedef int64_t host_type; + typedef int64_t device_type; + typedef int64_t integer_type; + typedef uint64_t unsigned_type; + static inline int64_t remove_negative_zero(int64_t x) { return x; } + static inline int64_t to_print(int64_t x) { return x; } + static inline device_type to_device(host_type x) { return x; } +}; + +template <> +struct TypeTraits { + static cudaDataType_t const cublas_type = CUDA_R_8I; + typedef uint64_t host_type; + typedef uint64_t device_type; + typedef uint64_t integer_type; + typedef uint64_t unsigned_type; + static inline uint64_t remove_negative_zero(uint64_t x) { return x; } + static inline uint64_t to_print(uint64_t x) { return x; } + static inline device_type to_device(host_type x) { return x; } +}; + +template <> +struct TypeTraits { + static cudaDataType_t const cublas_type = CUDA_R_16F; + typedef half_t host_type; + typedef half_t device_type; + typedef int16_t integer_type; + typedef uint16_t unsigned_type; + static inline half_t remove_negative_zero(half_t x) { + return (x.raw() == 0x8000 ? half_t::bitcast(0) : x); + } + static inline half_t to_print(half_t x) { return x; } + static inline device_type to_device(half_t x) { return reinterpret_cast(x); } +}; + +template <> +struct TypeTraits { + static cudaDataType_t const cublas_type = CUDA_R_32F; + typedef float host_type; + typedef float device_type; + typedef int32_t integer_type; + typedef uint32_t unsigned_type; + static inline float remove_negative_zero(float x) { return x == -0.f ? 0.f : x; } + static inline float to_print(float x) { return x; } + static inline device_type to_device(host_type x) { return x; } +}; + +template <> +struct TypeTraits { + static cudaDataType_t const cublas_type = CUDA_R_64F; + typedef double host_type; + typedef double device_type; + typedef int64_t integer_type; + typedef uint64_t unsigned_type; + static inline double remove_negative_zero(double x) { return x == -0.0 ? 0.0 : x; } + static inline double to_print(double x) { return x; } + static inline device_type to_device(host_type x) { return x; } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Complex types +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct TypeTraits > { + static cudaDataType_t const cublas_type = CUDA_C_16F; + typedef complex host_type; + typedef complex device_type; + typedef int16_t integer_type; + typedef uint16_t unsigned_type; + static inline device_type to_device(complex x) { return reinterpret_cast(x); } +}; + +template <> +struct TypeTraits > { + static cudaDataType_t const cublas_type = CUDA_C_16F; + typedef complex host_type; + typedef complex device_type; + typedef int16_t integer_type; + typedef uint16_t unsigned_type; + static inline complex remove_negative_zero(complex x) { + return complex( + real(x) == -0_hf ? 0_hf : real(x), + imag(x) == -0_hf ? 0_hf : imag(x) + ); + } + static inline complex to_print(complex x) { return x; } + static inline device_type to_device(complex x) { return reinterpret_cast(x); } +}; + +template <> +struct TypeTraits > { + + static cudaDataType_t const cublas_type = CUDA_C_32F; + typedef complex host_type; + typedef complex device_type; + typedef int64_t integer_type; + typedef uint64_t unsigned_type; + + static inline complex remove_negative_zero(complex x) { + return complex( + real(x) == -0.f ? 0.f : real(x), + imag(x) == -0.f ? 0.f : imag(x) + ); + } + + static inline complex to_print(complex x) { return x; } + static inline device_type to_device(complex x) { return reinterpret_cast(x); } +}; + +template <> +struct TypeTraits > { + static cudaDataType_t const cublas_type = CUDA_C_64F; + typedef complex host_type; + typedef complex device_type; + struct integer_type { int64_t real, imag; }; + struct unsigned_type { uint64_t real, imag; }; + static inline complex remove_negative_zero(complex x) { + return complex( + real(x) == -0.0 ? 0.0 : real(x), + imag(x) == -0.0 ? 0.0 : imag(x) + ); + } + static inline complex to_print(complex x) { return x; } + static inline device_type to_device(complex x) { return reinterpret_cast(x); } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/scripts/split_test_cmake.py b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/scripts/split_test_cmake.py new file mode 100644 index 0000000000000000000000000000000000000000..6541ce1b26722ff1f0dba0b4c034067a62f9b96d --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/tools/util/scripts/split_test_cmake.py @@ -0,0 +1,356 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + + +""" +Given a set of test files to be included in a CMake target, this script extracts +the TEST definitions from each file, writes them into new files, and prints the names +of the new files so that they can be processed as part of a new CMake target. + +For example, given a set of --src_files test_a.cu test_b.cu containing 3 and 2 TEST +definitions, respectively, this script would produce: + test_a_000.cu + test_a_001.cu + test_a_002.cu + test_b_000.cu + test_b_001.cu + +The splitting follows a fairly rudimentary algorithm that does not support all valid C++ programs. +We walk through a given input test file line by line. Any lines that are not within a TEST definition is added to a running +"filler" text. When a TEST definition is encountered, the current filler text becomes the prefix +for that test. All subsequent lines are considered to be part of the TEST definition until the +number of starting function braces ('{') match the number of closing function braces ('}'). When +these counts are equal, the TEST definition is considered to be completed. At this point, we return +to adding lines to the "filler" text until a new TEST definition is encountered. Any "filler" text +following a TEST definition is added to the suffix of that TEST definition (this is useful for finishing +off #if statements, as is common in unit tests.). + +A state machine illustrating this algorithm at a high level is provided in the source below. + +Example: Suppose an input test `test.cu` has the following source: + // COPYRIGHT + #include + + #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + + // Test #1 + TEST(SM90_a, 256x128x64_2x2x1) { + std::cout << "Test #1" << std::endl; + } + + // Test #2 + TEST(SM90_b, 256x128x64_1x1x1) { + std::cout << "Test #2" << std::endl; + } + + #endif defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +The contents of the two resulting test files will be: + $ cat test_000.cu + // COPYRIGHT + #include + + #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + + // Test #1 + TEST(SM90_a, 256x128x64_2x2x1) { + std::cout << "Test #1" << std::endl; + } + + // Test #2 + + #endif defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + $ cat test_001.cu + // COPYRIGHT + #include + + #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + + // Test #1 + + // Test #2 + TEST(SM90_b, 256x128x64_1x1x1) { + std::cout << "Test #2" << std::endl; + } + + #endif defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +Notice that each of test_000.cu and test_001.cu contain comments that appear outside +the TEST definitions not included in each file. This is by design, as these +would be considered "filler" text. + +As expected, some cases can't be handled. Below is a non-exhaustive list: + 1. New TEST following the closing '}' of a TEST case on the same line: + TEST(x, y) { + // Do stuff + } TEST(a, b) { + + In this case, "TEST(a, b) {" will be ignored + + 2. Preprocessor macros that occur midway through a test case and extend + beyond the conclusion of a testcase + + Example: + TEST(a, b) { + // Do stuff + #if X + // Do more stuff + } + #else + // Do other stuff + } + #endif +""" + + +import argparse +import enum +import os + + +parser = argparse.ArgumentParser() +parser.add_argument("cmake_target", type=str, + help="Name of the CMake target being generated.") +parser.add_argument("src_dir", type=str, + help="Path to the directory containing test files.") +parser.add_argument("--src_files", nargs='+', + help="Files containing TEST instances to split.") +parser.add_argument("--max_tests_per_file", type=int, default=1, + help="Maximum number of TEST instances per file.") +parser.add_argument("--dst_dir", type=str, + help="Path to the directory to which to write new test files. If not set, uses src_dir.") +args = parser.parse_args() + + +if args.dst_dir == None: + args.dst_dir = args.src_dir + + +class Testcase: + """ + Lightweight tracker of test-case processing status + """ + def __init__(self, prefix_text): + # Any text that preceded the TEST definition that was + # not part of another TEST definition + self.prefix = prefix_text + + # Any text within the TEST definition + self.test = "" + + # Any text that follows the completion of the TEST definition + # and is not included in other TEST definitions + self.suffix = "" + + # Whether the test's definition has concluded + self.completed = False + + # Current balance of opening and closing curly brackets in + # the TEST definition. '{' increments the count and '}' decrements it. + # A value of 0 (when self.completed == False) indicates that the test + # has completed. + self.curly_bracket_balance = 0 + + +class ParseState(enum.Enum): + """ + State machine for processing. + Transitions occur on each line encountered in the soruce file + + + Line does not contain 'TEST(' + +----+ + | | + | v 'TEST(' + +--------+ encountered +--------------------------+ + ------>| Filler | -----------------------> | TestDeclaredWaitingStart | + +--------+ +--------------------------+ + ^ | + Number of '{' | | First '{' encountered + equals number of | +--------+ | + '}' encountered +-----------| InTest | <------------------+ + +--------+ + | ^ + | | + +----+ + Number of '{' encountered + exceeds number of '}' encountered + """ + + + # Any text that is not part of a TEST case + Filler = 0 + + # Processing text within the first { of the TEST case + # and before the en of the final } of the TEST case + InTest = 1 + + # Processing text from the start of the TEST definition + # but before the first {. This could occur if the opening { + # occurs on a separate line than the TEST definition. + TestDeclaredWaitingStart = 2 + + +cmake_src_list = [] +for filename in args.src_files: + if '.' not in filename: + # Add any non-filename arguments to the command list by default + cmake_src_list.append(filename) + continue + + if '/' in filename: + raise Exception( + f"Source files passed to {__file__} must be within the same directory " + "as the CMakeLists defining the target using the files. " + f"Provided path {filename} is in a different directory.") + + full_filename = os.path.join(args.src_dir, filename) + with open(full_filename, 'r') as infile: + lines = infile.readlines() + + # Find the number of instances of "TEST(" + ntest = sum([1 for line in lines if "TEST(" in line]) + + if ntest <= args.max_tests_per_file: + # File contains fewer than max_tests_per_file TEST instances. It does + # not need to be split + cmake_src_list.append(filename) + continue + + # Current state of the parsing state machine. We start with filler text + state = ParseState.Filler + + # List of individual TESTs found + tests = [] + + # Ongoing text that is not included in a TEST definition. This will serve + # as the prefix for any yet-to-be encountered TEST definitions. + filler_text = "" + + def add_filler_text(text): + global filler_text + # Add new text to the ongoing filler text and to the suffixes of + # any completed tests + filler_text += text + for i in range(len(tests)): + if tests[i].completed: + tests[i].suffix += text + + for line in lines: + if state == ParseState.Filler: + # We are not currently within a TEST definition. + + if 'TEST(' in line: + # We have encountered a new TEST( case. Any text preceding this + # must be added to the filler text (e.g., if we have a line of the form: + # "static constexpr int Val = 4; TEST(blah) {" + # then "static constexpr int Val = 4;" needs to be included in filler + # text, as it could be used by subsequent tests.) + splits = line.split('TEST') + + # There should not be more than one TEST definition on a given line + assert len(splits) <= 2 + + if len(splits) > 1: + if not splits[0].isspace(): + # Only add text to filler if there are non-whitespace charcters + # preceding the TEST definition in the line + filler_text += splits[0] + + # The new line is just the TEST-related line + line = 'TEST' + splits[-1] + + # Add tests and transtion to TestDeclaredWaitingStart state. + # Do not add the line to the test text of the new test case; this + # will be done in either the TestDeclaredWaitingStart state processing + # below or in the InTest state processing below. + tests.append(Testcase(filler_text)) + state = ParseState.TestDeclaredWaitingStart + else: + # Any remaining filler text is added to the running filler_text + # which will be used as the prefix for any new tests, and to the + # suffix of any completed tests + add_filler_text(line) + + if state == ParseState.TestDeclaredWaitingStart: + # We have seen a TEST definition but have not yet seen its opening {. + + if '{' in line: + # The first curly bracket for the TEST definition has been found. + # Advance to state InTests. Do not add the line to the test's text + # or change the curly-brace balance of the test; these will be done + # when processing the state == ParseState.InTest condition below. + state = ParseState.InTest + else: + tests[-1].test += line + + if state == ParseState.InTest: + # We are currently within a TEST definition. + # Process lines character-by-character looking for opening and closing + # braces. If we reach parity between opening and closing braces, the + # test is considered done. + filler_text_to_add = "" + for char in line: + if not tests[-1].completed: + tests[-1].test += char + if char == '{': + tests[-1].curly_bracket_balance += 1 + elif char == '}': + tests[-1].curly_bracket_balance -= 1 + if tests[-1].curly_bracket_balance == 0: + tests[-1].completed = True + else: + filler_text_to_add += char + + if filler_text_to_add != "" and (not filler_text_to_add.isspace() or '\n' in filler_text_to_add): + add_filler_text('\n' + filler_text_to_add) + + if tests[-1].completed: + state = ParseState.Filler + + # Write out the new files for tests + filename_prefix, filename_suffix = filename.split('.') + for i, test in enumerate(tests): + assert test.completed + new_filename = filename_prefix + '_' + str(i).zfill(3) + '.' + filename_suffix + full_new_filename = os.path.join(args.dst_dir, new_filename) + + # Replace any '\' with '/'. CMake doesn't like '\'. + full_new_filename = full_new_filename.replace('\\', '/') + + with open(full_new_filename, 'w') as outfile: + outfile.write(test.prefix + test.test + test.suffix) + cmake_src_list.append(full_new_filename) + + +for cmake_file in cmake_src_list: + print(cmake_file) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/metadata.json b/build/torch29-cxx11-cu129-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..4899badb63d45293425e2164944268b6058af95d --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/metadata.json @@ -0,0 +1,11 @@ +{ + "version": 1, + "license": "MIT", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "9.0a" + ] + } +} diff --git a/build/torch29-cxx11-cu129-aarch64-linux/testing/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/testing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..13a9d78dea58a6492183f9ddc50f1510a679cbe6 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/testing/__init__.py @@ -0,0 +1,4 @@ +from . import bench, numeric, utils +from .bench import * +from .numeric import * +from .utils import * diff --git a/build/torch29-cxx11-cu129-aarch64-linux/testing/bench.py b/build/torch29-cxx11-cu129-aarch64-linux/testing/bench.py new file mode 100644 index 0000000000000000000000000000000000000000..2c752da2d3bb0aba7e03ef1921428432b396917a --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/testing/bench.py @@ -0,0 +1,137 @@ +import os +import sys +import torch + + +def bench(fn, num_warmups: int = 5, num_tests: int = 10, + high_precision: bool = False): + # Flush L2 cache with 256 MB data + torch.cuda.synchronize() + cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda') + cache.zero_() + + # Warmup + for _ in range(num_warmups): + fn() + + # Add a large kernel to eliminate the CPU launch overhead + if high_precision: + x = torch.randn((8192, 8192), dtype=torch.float, device='cuda') + y = torch.randn((8192, 8192), dtype=torch.float, device='cuda') + x @ y + + # Testing + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for i in range(num_tests): + fn() + end_event.record() + torch.cuda.synchronize() + + return start_event.elapsed_time(end_event) / num_tests / 1e3 + + +class empty_suppress: + def __enter__(self): + return self + + def __exit__(self, *_): + pass + + +class suppress_stdout_stderr: + def __enter__(self): + self.outnull_file = open(os.devnull, 'w') + self.errnull_file = open(os.devnull, 'w') + + self.old_stdout_fileno_undup = sys.stdout.fileno() + self.old_stderr_fileno_undup = sys.stderr.fileno() + + self.old_stdout_fileno = os.dup(sys.stdout.fileno()) + self.old_stderr_fileno = os.dup(sys.stderr.fileno()) + + self.old_stdout = sys.stdout + self.old_stderr = sys.stderr + + os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup) + os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup) + + sys.stdout = self.outnull_file + sys.stderr = self.errnull_file + return self + + def __exit__(self, *_): + sys.stdout = self.old_stdout + sys.stderr = self.old_stderr + + os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup) + os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup) + + os.close(self.old_stdout_fileno) + os.close(self.old_stderr_fileno) + + self.outnull_file.close() + self.errnull_file.close() + + +def bench_kineto(fn, kernel_names, num_tests: int = 30, + suppress_kineto_output: bool = False, + trace_path: str = None, flush_l2: bool = True, + with_multiple_kernels: bool = False): + assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple) + is_tuple = isinstance(kernel_names, tuple) + + # Skip profiling + # Conflict with Nsight Systems, Nsight Compute and Compute Sanitizer + if int(os.environ.get('DG_USE_NVIDIA_TOOLS', 0)): + return (1, ) * len(kernel_names) if is_tuple else 1 + + # By default, flush L2 with an excessive 8 GB memset to give the GPU some (literal) chill time without full idle + flush_l2_size = int(8e9 // 4) + + # For some auto-tuning kernels with prints + fn() + + # Profile + suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress + with suppress(): + schedule = torch.profiler.schedule(wait=1, warmup=0, active=1, repeat=1) + profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) + with profiler: + for i in range(2): + for _ in range(num_tests): + if flush_l2: + torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_() + fn() + profiler.step() + + # Parse the profiling table + prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n') + kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names + if not with_multiple_kernels: + for name in kernel_names: + assert sum([name in line for line in prof_lines]) <= 1, f'Errors of the kernel {name} in the profiling table' + + # Save chrome traces + if trace_path is not None: + profiler.export_chrome_trace(trace_path) + + # Return average kernel times + units = {'ms': 1e3, 'us': 1e6} + kernel_times = [] + for name in kernel_names: + total_time = 0 + total_num = 0 + for line in prof_lines: + if name in line: + time_str = line.split()[-2] + num_str = line.split()[-1] + for unit, scale in units.items(): + if unit in time_str: + total_time += float(time_str.replace(unit, '')) / scale * int(num_str) + total_num += int(num_str) + break + kernel_times.append(total_time / total_num if total_num > 0 else 0) + + return tuple(kernel_times) if is_tuple else kernel_times[0] diff --git a/build/torch29-cxx11-cu129-aarch64-linux/testing/numeric.py b/build/torch29-cxx11-cu129-aarch64-linux/testing/numeric.py new file mode 100644 index 0000000000000000000000000000000000000000..a42c4318db47593c47a4ea89fbdbcb1ffb5cd30e --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/testing/numeric.py @@ -0,0 +1,21 @@ +import torch +from typing import Iterable + + +def calc_diff(x: torch.Tensor, y: torch.Tensor): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + if denominator == 0: # Which means that all elements in x and y are 0 + return 0.0 + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + +def count_bytes(*tensors): + total = 0 + for t in tensors: + if isinstance(t, (tuple, list)): + total += count_bytes(*t) + elif t is not None: + total += t.numel() * t.element_size() + return total diff --git a/build/torch29-cxx11-cu129-aarch64-linux/testing/utils.py b/build/torch29-cxx11-cu129-aarch64-linux/testing/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2d202d4192ed385f986ac5cc216acc69378d8ea9 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/testing/utils.py @@ -0,0 +1,38 @@ +import functools +import os +import torch +from typing import Callable + +def get_arch_major() -> int: + major, minor = torch.cuda.get_device_capability() + return major + + +def test_filter(condition: Callable): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if condition(): + func(*args, **kwargs) + else: + print(f'{func.__name__}:') + print(f' > Filtered by {condition}') + print() + return wrapper + return decorator + + +def ignore_env(name: str, condition: Callable): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if condition(): + saved = os.environ.pop(name, None) + func(*args, **kwargs) + if saved is not None: + os.environ[name] = saved + else: + func(*args, **kwargs) + + return wrapper + return decorator diff --git a/build/torch29-cxx11-cu129-aarch64-linux/utils/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e8f859a20726fcc0ea32c54ed8df37b19b3960a4 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/utils/__init__.py @@ -0,0 +1,3 @@ +from . import math, layout +from .layout import * +from .math import * diff --git a/build/torch29-cxx11-cu129-aarch64-linux/utils/layout.py b/build/torch29-cxx11-cu129-aarch64-linux/utils/layout.py new file mode 100644 index 0000000000000000000000000000000000000000..a6bc29d9aaae296a83b8c3546b832a083ade6b28 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/utils/layout.py @@ -0,0 +1,25 @@ +from .._ops import ops + + +def get_mk_alignment_for_contiguous_layout(): + return ops.get_mk_alignment_for_contiguous_layout() + + +def get_tma_aligned_size(mn: int, element_size: int): + return ops.get_tma_aligned_size(mn, element_size).item() + + +def get_mn_major_tma_aligned_tensor(sf): + return ops.get_mn_major_tma_aligned_tensor(sf) + + +def get_mn_major_tma_aligned_packed_ue8m0_tensor(sf): + return ops.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf) + + +def get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks): + return ops.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks) + + +get_m_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout +get_k_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout diff --git a/build/torch29-cxx11-cu129-aarch64-linux/utils/math.py b/build/torch29-cxx11-cu129-aarch64-linux/utils/math.py new file mode 100644 index 0000000000000000000000000000000000000000..c65026e54b87faf34b498d14d3f81a94759615f4 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/utils/math.py @@ -0,0 +1,107 @@ +import torch +from typing import Tuple + + +def ceil_div(x: int, y: int) -> int: + return (x + y - 1) // y + + +def align(x: int, y: int) -> int: + return ceil_div(x, y) * y + + +def ceil_to_ue8m0(x: torch.Tensor): + assert x.view(-1).amax().item() > 0 + return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) + + +def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + padded_n = align(n, gran_k) + x_padded = torch.empty((m, padded_n), dtype=x.dtype, device=x.device).fill_(0) + x_padded[:, :n] = x + x_view = x_padded.view(m, -1, gran_k) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, padded_n)[:, :n].contiguous(), sf + + +def per_channel_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 and x.size(0) % gran_k == 0 + m, n = x.shape + x_view = x.view(-1, gran_k, n) + x_amax = x_view.abs().float().amax(dim=1).view(-1, n).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + return (x_view * (1.0 / sf.unsqueeze(1))).to(torch.float8_e4m3fn).view(m, n), sf + + +def per_block_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + x_padded = torch.zeros((align(m, gran_k), align(n, gran_k)), dtype=x.dtype, device=x.device) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, gran_k, x_padded.size(1) // gran_k, gran_k) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(x_view.size(0), x_view.size(2)) + + +def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: + excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)]) + x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) + return x_scaled, sf.squeeze() + + +def _quantize_to_fp4_e2m1(x: torch.Tensor) -> torch.Tensor: + ax = x.abs().clamp_max(6.0) + # {0, 0.5, 1, 1.5, 2, 3, 4, 6} + # midpoints: 0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0 + boundaries = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0], + device=x.device, dtype=ax.dtype) + idx = torch.bucketize(ax, boundaries) + code = idx.to(torch.uint8) + sign = (x < 0) & (idx != 0) + code = code | (sign.to(torch.uint8) << 3) + return code # uint8, 0..15 + + +def per_token_cast_to_fp4(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + assert n % 2 == 0 + padded_n = align(n, gran_k) + x_padded = torch.zeros((m, padded_n), dtype=x.dtype, device=x.device) + x_padded[:, :n] = x + x_view = x_padded.view(m, -1, gran_k) + x_amax = x_view.abs().float().amax(dim=2).clamp_min(1e-4) + sf = x_amax / 6.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + x_scaled = x_view * (1.0 / sf.unsqueeze(2)) + codes = _quantize_to_fp4_e2m1(x_scaled).view(m, padded_n) # uint8, (m, padded_n) + codes2 = codes.view(m, padded_n // 2, 2) + packed = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4) # uint8 + return packed[:, :n // 2].contiguous(), sf + + +def transpose_packed_fp4(a: torch.Tensor) -> torch.Tensor: + assert a.dtype == torch.uint8 + assert a.dim() == 2 + m, n2 = a.shape + n = n2 * 2 + assert (m % 2) == 0 + lo = a & 0x0F + hi = (a >> 4) & 0x0F + codes = torch.empty((m, n), device=a.device, dtype=torch.uint8) + codes[:, 0::2], codes[:, 1::2] = lo, hi + codes_t = codes.transpose(0, 1).contiguous() + codes2 = codes_t.view(n, m // 2, 2) + out = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4) + return out.contiguous() \ No newline at end of file diff --git a/build/torch29-cxx11-cu130-aarch64-linux/__init__.py b/build/torch29-cxx11-cu130-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9cd2069a1bc6bcd76e84fbe093d7f16d394efa05 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/__init__.py @@ -0,0 +1,684 @@ +import os +import subprocess +import torch + +# Import the compiled extension +from ._ops import ops +from . import utils + +__version__ = "2.3.0" + + +# Runtime + + +def set_num_sms(num_sms: int): + ops.set_num_sms(num_sms) + + +def get_num_sms() -> int: + return ops.get_num_sms() + + +def set_tc_util(tc_util: int): + ops.set_tc_util(tc_util) + + +def get_tc_util() -> int: + return ops.get_tc_util() + + +def get_mk_alignment_for_contiguous_layout() -> int: + return ops.get_mk_alignment_for_contiguous_layout() + + +# Layout utilities + + +def get_tma_aligned_size(mn: int, element_size: int) -> int: + return ops.get_tma_aligned_size(mn, element_size).item() + + +def get_mn_major_tma_aligned_tensor(sf): + return ops.get_mn_major_tma_aligned_tensor(sf) + + +def get_mn_major_tma_aligned_packed_ue8m0_tensor(sf): + return ops.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf) + + +def get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks): + ks_int = torch.tensor(ks, dtype=torch.int32, device="cpu") + return ops.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor( + sf, ks_tensor, ks_int + ) + + +def transform_sf_into_required_layout( + sf, + mn, + k, + recipe=None, + recipe_ab=None, + num_groups=None, + is_sfa=False, + disable_ue8m0_cast=False, +): + has_recipe = recipe is not None + r0, r1, r2 = recipe if has_recipe else (0, 0, 0) + has_recipe_ab = recipe_ab is not None + rab0, rab1 = recipe_ab if has_recipe_ab else (0, 0) + has_ng = num_groups is not None + ng = num_groups if has_ng else 0 + return ops.transform_sf_into_required_layout( + sf, + mn, + k, + r0, + r1, + r2, + has_recipe, + rab0, + rab1, + has_recipe_ab, + ng, + has_ng, + is_sfa, + disable_ue8m0_cast, + ) + + +# Aliases for contiguous layout alignment +get_m_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout +get_k_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout + + +# Helper to flatten recipe args + + +def _flatten_recipe(recipe, recipe_a=None, recipe_b=None): + has_recipe = recipe is not None + r0, r1, r2 = recipe if has_recipe else (0, 0, 0) + has_ra = recipe_a is not None + ra0, ra1 = recipe_a if has_ra else (0, 0) + has_rb = recipe_b is not None + rb0, rb1 = recipe_b if has_rb else (0, 0) + return r0, r1, r2, has_recipe, ra0, ra1, has_ra, rb0, rb1, has_rb + + +# FP8/FP4 GEMM ops + + +def fp8_fp4_gemm_nt( + a, + b, + d, + c=None, + recipe=None, + recipe_a=None, + recipe_b=None, + compiled_dims="nk", + disable_ue8m0_cast=False, +): + a_data, a_sf = a + b_data, b_sf = b + r0, r1, r2, hr, ra0, ra1, hra, rb0, rb1, hrb = _flatten_recipe( + recipe, recipe_a, recipe_b + ) + ops.fp8_fp4_gemm_nt( + a_data, + a_sf, + b_data, + b_sf, + d, + c, + r0, + r1, + r2, + hr, + ra0, + ra1, + hra, + rb0, + rb1, + hrb, + compiled_dims, + disable_ue8m0_cast, + ) + + +def fp8_fp4_gemm_nn( + a, + b, + d, + c=None, + recipe=None, + recipe_a=None, + recipe_b=None, + compiled_dims="nk", + disable_ue8m0_cast=False, +): + a_data, a_sf = a + b_data, b_sf = b + r0, r1, r2, hr, ra0, ra1, hra, rb0, rb1, hrb = _flatten_recipe( + recipe, recipe_a, recipe_b + ) + ops.fp8_fp4_gemm_nn( + a_data, + a_sf, + b_data, + b_sf, + d, + c, + r0, + r1, + r2, + hr, + ra0, + ra1, + hra, + rb0, + rb1, + hrb, + compiled_dims, + disable_ue8m0_cast, + ) + + +def fp8_fp4_gemm_tn( + a, + b, + d, + c=None, + recipe=None, + recipe_a=None, + recipe_b=None, + compiled_dims="mn", + disable_ue8m0_cast=False, +): + a_data, a_sf = a + b_data, b_sf = b + r0, r1, r2, hr, ra0, ra1, hra, rb0, rb1, hrb = _flatten_recipe( + recipe, recipe_a, recipe_b + ) + ops.fp8_fp4_gemm_tn( + a_data, + a_sf, + b_data, + b_sf, + d, + c, + r0, + r1, + r2, + hr, + ra0, + ra1, + hra, + rb0, + rb1, + hrb, + compiled_dims, + disable_ue8m0_cast, + ) + + +def fp8_fp4_gemm_tt( + a, + b, + d, + c=None, + recipe=None, + recipe_a=None, + recipe_b=None, + compiled_dims="mn", + disable_ue8m0_cast=False, +): + a_data, a_sf = a + b_data, b_sf = b + r0, r1, r2, hr, ra0, ra1, hra, rb0, rb1, hrb = _flatten_recipe( + recipe, recipe_a, recipe_b + ) + ops.fp8_fp4_gemm_tt( + a_data, + a_sf, + b_data, + b_sf, + d, + c, + r0, + r1, + r2, + hr, + ra0, + ra1, + hra, + rb0, + rb1, + hrb, + compiled_dims, + disable_ue8m0_cast, + ) + + +# FP8 aliases (same as FP8/FP4) +fp8_gemm_nt = fp8_fp4_gemm_nt +fp8_gemm_nn = fp8_fp4_gemm_nn +fp8_gemm_tn = fp8_fp4_gemm_tn +fp8_gemm_tt = fp8_fp4_gemm_tt + + +# M-grouped FP8/FP4 GEMM ops + + +def m_grouped_fp8_fp4_gemm_nt_contiguous( + a, + b, + d, + grouped_layout, + recipe=None, + recipe_a=None, + recipe_b=None, + compiled_dims="nk", + disable_ue8m0_cast=False, + use_psum_layout=False, + expected_m_for_psum_layout=None, +): + a_data, a_sf = a + b_data, b_sf = b + r0, r1, r2, hr, ra0, ra1, hra, rb0, rb1, hrb = _flatten_recipe( + recipe, recipe_a, recipe_b + ) + has_em = expected_m_for_psum_layout is not None + em = expected_m_for_psum_layout if has_em else 0 + ops.m_grouped_fp8_fp4_gemm_nt_contiguous( + a_data, + a_sf, + b_data, + b_sf, + d, + grouped_layout, + r0, + r1, + r2, + hr, + ra0, + ra1, + hra, + rb0, + rb1, + hrb, + compiled_dims, + disable_ue8m0_cast, + use_psum_layout, + em, + has_em, + ) + + +def m_grouped_fp8_fp4_gemm_nn_contiguous( + a, + b, + d, + grouped_layout, + recipe=None, + recipe_a=None, + recipe_b=None, + compiled_dims="nk", + disable_ue8m0_cast=False, + use_psum_layout=False, +): + a_data, a_sf = a + b_data, b_sf = b + r0, r1, r2, hr, ra0, ra1, hra, rb0, rb1, hrb = _flatten_recipe( + recipe, recipe_a, recipe_b + ) + ops.m_grouped_fp8_fp4_gemm_nn_contiguous( + a_data, + a_sf, + b_data, + b_sf, + d, + grouped_layout, + r0, + r1, + r2, + hr, + ra0, + ra1, + hra, + rb0, + rb1, + hrb, + compiled_dims, + disable_ue8m0_cast, + use_psum_layout, + ) + + +def m_grouped_fp8_fp4_gemm_nt_masked( + a, + b, + d, + masked_m, + expected_m, + recipe=None, + recipe_a=None, + recipe_b=None, + compiled_dims="nk", + disable_ue8m0_cast=False, +): + a_data, a_sf = a + b_data, b_sf = b + r0, r1, r2, hr, ra0, ra1, hra, rb0, rb1, hrb = _flatten_recipe( + recipe, recipe_a, recipe_b + ) + ops.m_grouped_fp8_fp4_gemm_nt_masked( + a_data, + a_sf, + b_data, + b_sf, + d, + masked_m, + expected_m, + r0, + r1, + r2, + hr, + ra0, + ra1, + hra, + rb0, + rb1, + hrb, + compiled_dims, + disable_ue8m0_cast, + ) + + +# M-grouped FP8 aliases +m_grouped_fp8_gemm_nt_contiguous = m_grouped_fp8_fp4_gemm_nt_contiguous +m_grouped_fp8_gemm_nn_contiguous = m_grouped_fp8_fp4_gemm_nn_contiguous +m_grouped_fp8_gemm_nt_masked = m_grouped_fp8_fp4_gemm_nt_masked + +# Legacy aliases +fp8_m_grouped_gemm_nt_masked = m_grouped_fp8_fp4_gemm_nt_masked + + +# K-grouped FP8 GEMM ops + + +def k_grouped_fp8_gemm_tn_contiguous( + a, b, d, ks, ks_tensor, c=None, recipe=(1, 1, 128), compiled_dims="mn" +): + a_data, a_sf = a + b_data, b_sf = b + r0, r1, r2 = recipe + ops.k_grouped_fp8_gemm_tn_contiguous( + a_data, a_sf, b_data, b_sf, d, ks_tensor, c, r0, r1, r2, compiled_dims + ) + + +def k_grouped_fp8_gemm_nt_contiguous( + a, b, d, ks, ks_tensor, c=None, recipe=(1, 1, 128), compiled_dims="mn" +): + a_data, a_sf = a + b_data, b_sf = b + r0, r1, r2 = recipe + ops.k_grouped_fp8_gemm_nt_contiguous( + a_data, a_sf, b_data, b_sf, d, ks_tensor, c, r0, r1, r2, compiled_dims + ) + + +# BF16 GEMM ops + + +def bf16_gemm_nt(a, b, d, c=None, compiled_dims="nk"): + ops.bf16_gemm_nt(a, b, d, c, compiled_dims) + + +def bf16_gemm_nn(a, b, d, c=None, compiled_dims="nk"): + ops.bf16_gemm_nn(a, b, d, c, compiled_dims) + + +def bf16_gemm_tn(a, b, d, c=None, compiled_dims="mn"): + ops.bf16_gemm_tn(a, b, d, c, compiled_dims) + + +def bf16_gemm_tt(a, b, d, c=None, compiled_dims="mn"): + ops.bf16_gemm_tt(a, b, d, c, compiled_dims) + + +# M-grouped BF16 GEMM ops + + +def m_grouped_bf16_gemm_nt_contiguous( + a, + b, + d, + grouped_layout, + compiled_dims="nk", + use_psum_layout=False, + expected_m_for_psum_layout=None, +): + has_em = expected_m_for_psum_layout is not None + em = expected_m_for_psum_layout if has_em else 0 + ops.m_grouped_bf16_gemm_nt_contiguous( + a, b, d, grouped_layout, compiled_dims, use_psum_layout, em, has_em + ) + + +def m_grouped_bf16_gemm_nn_contiguous( + a, b, d, grouped_layout, compiled_dims="nk", use_psum_layout=False +): + ops.m_grouped_bf16_gemm_nn_contiguous( + a, b, d, grouped_layout, compiled_dims, use_psum_layout + ) + + +def m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m, compiled_dims="nk"): + ops.m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m, compiled_dims) + + +# Legacy alias +bf16_m_grouped_gemm_nt_masked = m_grouped_bf16_gemm_nt_masked + + +# K-grouped BF16 GEMM ops + + +def k_grouped_bf16_gemm_tn_contiguous( + a, b, d, ks, ks_tensor, c=None, compiled_dims="mn" +): + ops.k_grouped_bf16_gemm_tn_contiguous(a, b, d, ks_tensor, c, compiled_dims) + + +# cuBLASLt GEMM ops + + +def cublaslt_gemm_nt(a, b, d, c=None): + ops.cublaslt_gemm_nt(a, b, d, c) + + +def cublaslt_gemm_nn(a, b, d, c=None): + ops.cublaslt_gemm_nn(a, b, d, c) + + +def cublaslt_gemm_tn(a, b, d, c=None): + ops.cublaslt_gemm_tn(a, b, d, c) + + +def cublaslt_gemm_tt(a, b, d, c=None): + ops.cublaslt_gemm_tt(a, b, d, c) + + +# Attention ops + + +def fp8_gemm_nt_skip_head_mid( + a, b, d, head_splits, recipe=None, compiled_dims="nk", disable_ue8m0_cast=False +): + a_data, a_sf = a + b_data, b_sf = b + left, mid, right = head_splits + has_recipe = recipe is not None + r0, r1, r2 = recipe if has_recipe else (0, 0, 0) + ops.fp8_gemm_nt_skip_head_mid( + a_data, + a_sf, + b_data, + b_sf, + d, + left, + mid, + right, + r0, + r1, + r2, + has_recipe, + compiled_dims, + disable_ue8m0_cast, + ) + + +def fp8_mqa_logits( + q, + kv, + weights, + cu_seq_len_k_start, + cu_seq_len_k_end, + clean_logits=True, + max_seqlen_k=0, +): + kv_data, kv_sf = kv + return ops.fp8_mqa_logits( + q, + kv_data, + kv_sf, + weights, + cu_seq_len_k_start, + cu_seq_len_k_end, + clean_logits, + max_seqlen_k, + ) + + +def get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms): + return ops.get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms) + + +def fp8_paged_mqa_logits( + q, + kv_cache, + weights, + context_lens, + block_table, + schedule_meta, + max_context_len, + clean_logits=False, +): + return ops.fp8_paged_mqa_logits( + q, + kv_cache, + weights, + context_lens, + block_table, + schedule_meta, + max_context_len, + clean_logits, + ) + + +# Einsum ops + + +def einsum(expr, a, b, d, c=None, use_cublaslt=False): + ops.einsum(expr, a, b, d, c, use_cublaslt) + + +def fp8_einsum(expr, a, b, d, c=None, recipe=(1, 128, 128)): + a_data, a_sf = a + b_data, b_sf = b + r0, r1, r2 = recipe + ops.fp8_einsum(expr, a_data, a_sf, b_data, b_sf, d, c, r0, r1, r2) + + +# Hyperconnection ops + + +def tf32_hc_prenorm_gemm(a, b, d, sqr_sum, num_splits=None): + has_ns = num_splits is not None + ns = num_splits if has_ns else 0 + ops.tf32_hc_prenorm_gemm(a, b, d, sqr_sum, ns, has_ns) + + +# Initialize the C++ runtime + + +def _find_cuda_home() -> str: + cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") + if cuda_home is None: + try: + with open(os.devnull, "w") as devnull: + nvcc = ( + subprocess.check_output(["which", "nvcc"], stderr=devnull) + .decode() + .rstrip("\r\n") + ) + cuda_home = os.path.dirname(os.path.dirname(nvcc)) + except Exception: + cuda_home = "/usr/local/cuda" + if not os.path.exists(cuda_home): + cuda_home = None + assert cuda_home is not None, "Could not find CUDA installation" + return cuda_home + + +# Find the library root for JIT headers +# In development: use the repo's deep_gemm/ directory +# In installed wheel: use this package's directory +_lib_root = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "deep_gemm" +) +if not os.path.isdir(os.path.join(_lib_root, "include")): + # Fallback: try the parent package + _lib_root = os.path.dirname(os.path.abspath(__file__)) + +_initialized = False + +# Set DG_CUTLASS_INCLUDE for JIT kernel compilation (if not already set by user) +if "DG_CUTLASS_INCLUDE" not in os.environ: + _include = os.path.join(_lib_root, "include") + _cutlass_include_candidates = [ + _include, # legacy layout: include/cutlass + os.path.join(_include, "third-party", "cutlass", "include"), # submodule layout + ] + for _cutlass_include in _cutlass_include_candidates: + if os.path.isdir(os.path.join(_cutlass_include, "cutlass")): + os.environ["DG_CUTLASS_INCLUDE"] = _cutlass_include + break + else: + # Fall back to nvidia-cutlass pip package + try: + import nvidia.cutlass as _nc + os.environ["DG_CUTLASS_INCLUDE"] = os.path.join( + os.path.dirname(_nc.__file__), "include" + ) + except ImportError: + pass + +def _ensure_initialized(): + global _initialized + if _initialized: + return + _initialized = True + ops.init(_lib_root, _find_cuda_home()) + + +# Try to initialize eagerly, but don't fail if CUDA is not found +# (e.g., during build-time import checks). init() will be called +# lazily on first actual kernel use. +try: + _ensure_initialized() +except (AssertionError, RuntimeError): + pass diff --git a/build/torch29-cxx11-cu130-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so b/build/torch29-cxx11-cu130-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..2bcc47787dd4ad8cd873d376e29d7424560f848c --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:61d39794e075faf37aefde0723de136f5792e7b6a1951fc2b05c7047611dec13 +size 2891904 diff --git a/build/torch29-cxx11-cu130-aarch64-linux/_ops.py b/build/torch29-cxx11-cu130-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a1598d972cf45498451b415d7a5869693e9b4a96 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _deep_gemm_cuda_a68a39f +ops = torch.ops._deep_gemm_cuda_a68a39f + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_deep_gemm_cuda_a68a39f::{op_name}" diff --git a/build/torch29-cxx11-cu130-aarch64-linux/deep_gemm/__init__.py b/build/torch29-cxx11-cu130-aarch64-linux/deep_gemm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/deep_gemm/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/common/cute_tie.cuh b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/common/cute_tie.cuh new file mode 100644 index 0000000000000000000000000000000000000000..cd2aace7a8b8dd642f4c149bfc974c3d21e5f5b5 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/common/cute_tie.cuh @@ -0,0 +1,48 @@ +#pragma once + +namespace cute { + +struct ignore_t { + template + constexpr const ignore_t& operator=(T&&) const noexcept { + return *this; + } +}; + +inline constexpr ignore_t ignore{}; + +} // namespace cute + +#define CUTE_TIE_CONCAT_IMPL(A, B) A##B +#define CUTE_TIE_CONCAT(A, B) CUTE_TIE_CONCAT_IMPL(A, B) + +#define CUTE_TIE_GET_NTH_ARG(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, N, ...) N +#define CUTE_TIE_COUNT_ARGS(...) \ + CUTE_TIE_GET_NTH_ARG(__VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0) + +#define CUTE_TIE_OP_DECL(I, TUPLE, VAR) auto VAR = ::cute::get(TUPLE) +#define CUTE_TIE_OP_ASSIGN(I, TUPLE, VAR) VAR = ::cute::get(TUPLE) + +#define CUTE_TIE_APPLY_OP_1(OP, T, V1) OP(0, T, V1); +#define CUTE_TIE_APPLY_OP_2(OP, T, V1, V2) OP(0, T, V1); OP(1, T, V2); +#define CUTE_TIE_APPLY_OP_3(OP, T, V1, V2, V3) OP(0, T, V1); OP(1, T, V2); OP(2, T, V3); +#define CUTE_TIE_APPLY_OP_4(OP, T, V1, V2, V3, V4) OP(0, T, V1); OP(1, T, V2); OP(2, T, V3); OP(3, T, V4); +#define CUTE_TIE_APPLY_OP_5(OP, T, V1, V2, V3, V4, V5) OP(0, T, V1); OP(1, T, V2); OP(2, T, V3); OP(3, T, V4); OP(4, T, V5); + +#define CUTE_TIE_DECL(TUPLE_EXPR, ...) \ + auto&& CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__) = (TUPLE_EXPR); \ + CUTE_TIE_CONCAT(CUTE_TIE_APPLY_OP_, CUTE_TIE_COUNT_ARGS(__VA_ARGS__)) ( \ + CUTE_TIE_OP_DECL, \ + CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__), \ + __VA_ARGS__ \ + ) + +#define CUTE_TIE(TUPLE_EXPR, ...) \ + do { \ + auto&& CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__) = (TUPLE_EXPR); \ + CUTE_TIE_CONCAT(CUTE_TIE_APPLY_OP_, CUTE_TIE_COUNT_ARGS(__VA_ARGS__)) ( \ + CUTE_TIE_OP_ASSIGN, \ + CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__), \ + __VA_ARGS__ \ + ); \ + } while (0) diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/common/epilogue_utils.cuh b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/common/epilogue_utils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..5f6a7a1956b64771ee5294035bead6bc2e3dd850 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/common/epilogue_utils.cuh @@ -0,0 +1,27 @@ +#pragma once + +#include +#include + +namespace deep_gemm { + +struct EpilogueIdentity { + template + __device__ __forceinline__ static uint32_t apply_index_n(const uint32_t &n_idx) { + return n_idx; + } +}; + +template +struct EpilogueHeadSplits: EpilogueIdentity { + template + __device__ __forceinline__ static uint32_t apply_index_n(const uint32_t &n_idx) { + DG_STATIC_ASSERT(kLeft % STORE_BLOCK_N == 0 and kMid % STORE_BLOCK_N == 0 + and kRight % STORE_BLOCK_N == 0, "Invalid head splits config"); + return n_idx + (n_idx + kRight) / (kLeft + kRight) * kMid; + } +}; + +#pragma clang diagnostic pop + +} // namespace deep_gemm diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/common/reduction.cuh b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/common/reduction.cuh new file mode 100644 index 0000000000000000000000000000000000000000..d9e35f73128427d4c842c5382798c3866c27124f --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/common/reduction.cuh @@ -0,0 +1,44 @@ +#pragma once + +#include +#include +#include +#include + +#include + +// Operation functors +template struct ReduceSum { __device__ T operator()(T a, T b) const { return a + b; } }; +template struct ReduceMax { __device__ T operator()(T a, T b) const { return a > b ? a : b; } }; +template struct ReduceMin { __device__ T operator()(T a, T b) const { return a < b ? a : b; } }; +template struct ReduceAnd { __device__ T operator()(T a, T b) const { return a & b; } }; +template struct ReduceOr { __device__ T operator()(T a, T b) const { return a | b; } }; + +// Unified reduction function +template +__forceinline__ __device__ T warp_reduce(T value, Op op) { + DG_STATIC_ASSERT(kNumLanesPerGroup == 32 or kNumLanesPerGroup == 16 or kNumLanesPerGroup == 8 or + kNumLanesPerGroup == 4 or kNumLanesPerGroup == 2 or kNumLanesPerGroup == 1, + "Invalid number of lanes"); + constexpr uint32_t mask = 0xffffffff; + if constexpr (kIntergroupReduce) { + if constexpr (kNumLanesPerGroup <= 1) value = op(value, __shfl_xor_sync(mask, value, 1)); + if constexpr (kNumLanesPerGroup <= 2) value = op(value, __shfl_xor_sync(mask, value, 2)); + if constexpr (kNumLanesPerGroup <= 4) value = op(value, __shfl_xor_sync(mask, value, 4)); + if constexpr (kNumLanesPerGroup <= 8) value = op(value, __shfl_xor_sync(mask, value, 8)); + if constexpr (kNumLanesPerGroup <= 16) value = op(value, __shfl_xor_sync(mask, value, 16)); + } else { + if constexpr (kNumLanesPerGroup >= 32) value = op(value, __shfl_xor_sync(mask, value, 16)); + if constexpr (kNumLanesPerGroup >= 16) value = op(value, __shfl_xor_sync(mask, value, 8)); + if constexpr (kNumLanesPerGroup >= 8) value = op(value, __shfl_xor_sync(mask, value, 4)); + if constexpr (kNumLanesPerGroup >= 4) value = op(value, __shfl_xor_sync(mask, value, 2)); + if constexpr (kNumLanesPerGroup >= 2) value = op(value, __shfl_xor_sync(mask, value, 1)); + } + return value; +} + +// Convenience aliases +template +__forceinline__ __device__ T warp_reduce_sum(T value) { + return warp_reduce(value, ReduceSum{}); +} diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/common/scheduler.cuh b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/common/scheduler.cuh new file mode 100644 index 0000000000000000000000000000000000000000..f93b96ee64b934cd24e2567dd2b6b54be913408b --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/common/scheduler.cuh @@ -0,0 +1,288 @@ +#pragma once + +#include +#include + +namespace deep_gemm { + +enum class IndexType { + MN, + K, + SF_K, +}; + +template +static constexpr uint32_t get_num_1d_blocks_per_group() { + // Select the best from candidates + uint32_t num_best_blocks = 0, min_usage = cute::numeric_limits::max(); + for (const auto& candidate: {8u, 16u}) { + const auto& usage = kIsMulticastOnA ? + candidate * BLOCK_N + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_M: // Grouping on N + candidate * BLOCK_M + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_N; // Grouping on M + if (usage < min_usage) + min_usage = usage, num_best_blocks = candidate; + } + return num_best_blocks; +} + +#pragma clang diagnostic push +#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init" +template ()> +struct Scheduler { + int current_iter = -1; + + // Block configs + uint32_t num_blocks; + uint32_t num_m_blocks; + uint32_t num_n_blocks; + + // For SM90 multicast checks + uint32_t num_blocks_in_group; + bool is_peer_cta_alive = true; + + // For grouped GEMM + int* grouped_layout; + uint32_t current_group_idx = 0; + // Only used for masked layout + uint32_t current_m_cumsum = 0; + // Only used for countiguous psum layout + uint32_t last_psum_m = 0, current_psum_m, current_m_block_cumsum = 0; + // Only used for k-grouped layout + uint32_t current_shape_k, current_num_valid_groups = 0, current_k_cumsum = 0, current_sf_k_cumsum = 0; + uint32_t next_group_idx, next_shape_k; + + // Only used for k-grouped gemm + __device__ __forceinline__ void get_next_k_group(uint32_t &group_idx, uint32_t &shape_k) const { + for (; group_idx < kNumGroups; ++ group_idx) { + shape_k = __ldg(grouped_layout + group_idx); + if (shape_k > 0) + break; + } + } + + // ReSharper disable once CppPossiblyUninitializedMember + __device__ __forceinline__ explicit Scheduler(const uint32_t& shape_m, const uint32_t& shape_n, const uint32_t& shape_k, + int* grouped_layout = nullptr) { + num_m_blocks = ceil_div(shape_m, BLOCK_M); + num_n_blocks = ceil_div(shape_n, BLOCK_N); + current_shape_k = shape_k; + if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::Batched) { + num_blocks = num_m_blocks * num_n_blocks; + } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { + num_blocks = num_m_blocks * num_n_blocks; + this->grouped_layout = grouped_layout; + } else if constexpr (kGemmType == GemmType::MGroupedMasked) { + this->grouped_layout = grouped_layout; + } else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + this->grouped_layout = grouped_layout; + current_psum_m = __ldg(grouped_layout); + num_m_blocks = ceil_div(current_psum_m, BLOCK_M); + } else if constexpr (kGemmType == GemmType::KGroupedContiguous) { + this->grouped_layout = grouped_layout; + get_next_k_group(current_group_idx, current_shape_k); + next_group_idx = current_group_idx + 1; + get_next_k_group(next_group_idx, next_shape_k); + } + } + + __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t& block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) { + DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumMulticast == 0, "Invalid group size"); + + // Swizzle for better L2 usages + const auto& primary_num_blocks = kIsMulticastOnA ? num_n_blocks : num_m_blocks; + const auto& secondary_num_blocks = kIsMulticastOnA ? num_m_blocks : num_n_blocks; + const auto& num_blocks_per_group = secondary_num_blocks * kNum1DBlocksPerGroup; + const auto& group_idx = block_idx / num_blocks_per_group; + auto first_block_idx = group_idx * kNum1DBlocksPerGroup; + auto in_group_idx = block_idx % num_blocks_per_group; + num_blocks_in_group = min(kNum1DBlocksPerGroup, primary_num_blocks - first_block_idx); + + // Fix unaligned TMA multicast + // NOTES: for SM90 only, as SM90 can dynamically disable TMA multicast + // while SM100 uses 2-CTA, which can not be dynamically disabled +#if __CUDA_ARCH__ < 1000 + if (kNumMulticast > 1 and num_blocks_in_group % 2 != 0) { + if (in_group_idx < (num_blocks_in_group ^ 1) * secondary_num_blocks) { + num_blocks_in_group = num_blocks_in_group ^ 1; + } else { + in_group_idx = in_group_idx - (num_blocks_in_group ^ 1) * secondary_num_blocks; + first_block_idx += num_blocks_in_group ^ 1; + num_blocks_in_group = 1; + } + } +#endif + + // Convert to final M/N block indices + // `kIsMulticastOnA == true` leads to groups on N + if constexpr (kIsMulticastOnA) { + m_block_idx = in_group_idx / num_blocks_in_group; + n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group; + } else { + m_block_idx = first_block_idx + in_group_idx % num_blocks_in_group; + n_block_idx = in_group_idx / num_blocks_in_group; + } + } + + template + __device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size, + const uint32_t& block_idx, const uint32_t& m_block_idx = 0) { + if constexpr (kGemmType == GemmType::Normal) { + return block_idx * block_size; + } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { + const auto offset = kWithGroupOffset ? cute::max(0, __ldg(grouped_layout + m_block_idx * BLOCK_M)) : 0; + return offset * shape_dim + block_idx * block_size; + } else if constexpr (kGemmType == GemmType::MGroupedMasked or kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + const auto offset = kWithGroupOffset ? current_group_idx : 0; + return offset * shape_dim + block_idx * block_size; + } else if constexpr (kGemmType == GemmType::KGroupedContiguous) { + auto offset = 0; + if constexpr (kWithGroupOffset) { + if constexpr (kIndexType == IndexType::MN) + offset = current_group_idx * shape_dim; + else if constexpr (kIndexType == IndexType::K) + offset = current_k_cumsum; + else if constexpr (kIndexType == IndexType::SF_K) + offset = current_sf_k_cumsum; + } + return offset + block_idx * block_size; + } else if constexpr (kGemmType == GemmType::Batched) { + // Ignore kWithGroupOffset, and apply offset for IndexType::SF_K + const auto offset = kIndexType == IndexType::SF_K ? current_group_idx : 0; + return offset * shape_dim + block_idx * block_size; + } + } + + __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) { + const auto next_block_idx = (++ current_iter) * kNumSMs + blockIdx.x; + + if constexpr (kGemmType == GemmType::MGroupedMasked) { + while (true) { + // End of the task + if (current_group_idx == kNumGroups) + return false; + + // Within current group + num_m_blocks = ceil_div(static_cast(__ldg(grouped_layout + current_group_idx)), BLOCK_M); + const auto current_m_block_cumsum = current_m_cumsum + num_m_blocks; + if (next_block_idx < current_m_block_cumsum * num_n_blocks) + break; + + // Move to check the next group + current_group_idx ++, current_m_cumsum = current_m_block_cumsum; + } + + get_swizzled_block_idx(next_block_idx - current_m_cumsum * num_n_blocks, m_block_idx, n_block_idx); + } else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + while (true) { + // Within current group + if (next_block_idx < (current_m_block_cumsum + num_m_blocks) * num_n_blocks) + break; + + // Move to check the next group + if (++ current_group_idx == kNumGroups) + return false; + + // NOTES: `num_m_blocks` varies with the increase of the group index + last_psum_m = align(current_psum_m, 128u); + current_psum_m = __ldg(grouped_layout + current_group_idx); + current_m_block_cumsum += num_m_blocks; + num_m_blocks = ceil_div(current_psum_m - last_psum_m, BLOCK_M); + } + + get_swizzled_block_idx(next_block_idx - current_m_block_cumsum * num_n_blocks, m_block_idx, n_block_idx); + + // NOTES: `last_psum_m` is aligned with 128 + m_block_idx += last_psum_m / BLOCK_M; + DG_STATIC_ASSERT(128 % BLOCK_M == 0, "Invalid BLOCK_M"); + } else if constexpr (kGemmType == GemmType::KGroupedContiguous) { + while (true) { + // End of the task + if (current_group_idx == kNumGroups) + return false; + + // Within current group + if (next_block_idx < (current_num_valid_groups + 1) * num_m_blocks * num_n_blocks) + break; + + // Move to check the next group + current_k_cumsum += current_shape_k; + current_sf_k_cumsum += ceil_div(current_shape_k, SF_K_ALIGNMENT); + current_num_valid_groups ++; + + current_group_idx = next_group_idx ++; + current_shape_k = next_shape_k; + get_next_k_group(next_group_idx, next_shape_k); + } + + get_swizzled_block_idx(next_block_idx - current_num_valid_groups * num_m_blocks * num_n_blocks, m_block_idx, n_block_idx); + } else if constexpr (kGemmType == GemmType::Batched) { + if (next_block_idx >= num_blocks * kNumGroups) + return false; + + current_group_idx = next_block_idx / num_blocks; + const auto& block_idx = next_block_idx - current_group_idx * num_blocks; + if constexpr (kIsMulticastOnA) { + m_block_idx = block_idx / num_n_blocks; + n_block_idx = block_idx % num_n_blocks; + } else { + m_block_idx = block_idx % num_m_blocks; + n_block_idx = block_idx / num_m_blocks; + } + } else { + if (next_block_idx >= num_blocks) + return false; + + // For SM90 only + // NOTES: we don't have to set `is_peer_cta_alive` for masked grouped GEMM, as it must be aligned + is_peer_cta_alive = num_n_blocks % kNumMulticast == 0 or // Always aligned on N (constant bypass) + num_m_blocks % kNumMulticast == 0 or // Always aligned on M (constant bypass) + (next_block_idx ^ 1) < num_blocks; // Peer CTA in bound + get_swizzled_block_idx(next_block_idx, m_block_idx, n_block_idx); + } + return true; + } + + // For SM90 only + __device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) const { + if (num_blocks_in_group == 1) + return false; + if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::MGroupedMasked or + kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched) { + return true; + } else { + DG_STATIC_ASSERT(kGemmType == GemmType::MGroupedContiguous, "Invalid Gemm type"); + if constexpr (kIsMulticastOnA) { + return true; + } else { + const auto& group_idx = __ldg(grouped_layout + m_block_idx * BLOCK_M); + const auto& peer_group_idx = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M); + return group_idx == peer_group_idx; + } + } + } + + // For SM90 only + // ReSharper disable once CppNotAllPathsReturnValue + __device__ __forceinline__ bool is_computation_valid(const uint32_t& m_block_idx, const uint32_t& m_offset) const { + if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::Batched) { + return true; + } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { + return __ldg(grouped_layout + m_offset + m_block_idx * BLOCK_M) >= 0; + } else if constexpr (kGemmType == GemmType::MGroupedMasked) { + return m_offset + m_block_idx * BLOCK_M < __ldg(grouped_layout + current_group_idx); + } else { + // Unreachable + DG_TRAP_ONLY_DEVICE_ASSERT(false); + } + } +}; + +#pragma clang diagnostic pop + +} // namespace deep_gemm diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/common/sm100_utils.cuh b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/common/sm100_utils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..537cbe0818bf0c2c8b693e5d49862dcfbc7adb11 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/common/sm100_utils.cuh @@ -0,0 +1,266 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include + +namespace deep_gemm::sm100 { + +__device__ __forceinline__ +cute::UMMA::SmemDescriptor make_smem_desc(cute::UMMA::LayoutType layout, void* smem_ptr, + uint32_t stride_byte_offset, uint32_t leading_byte_offset) { + cute::UMMA::SmemDescriptor desc; + + // Set the version for SM100 + desc.version_ = 1; + + // Legacy mode + desc.lbo_mode_ = 0; + + // Layout + desc.layout_type_ = static_cast(layout); + + // Start address + const auto uint_ptr = cute::cast_smem_ptr_to_uint(smem_ptr); + desc.start_address_ = static_cast(uint_ptr >> 4); + + // Base offset + desc.base_offset_ = 0; + + // SBO and LBO + desc.stride_byte_offset_ = stride_byte_offset >> 4; + desc.leading_byte_offset_ = leading_byte_offset >> 4; + + return desc; +} + +__device__ __forceinline__ +cute::UMMA::SmemDescriptor make_sf_desc(void* smem_ptr) { + // NOTES: the UTCCP layout is K-major by default + // Atom size: 8 x 128 bits + // {SBO, LBO} means the byte stride between atoms on {MN, K} + // Since the UTCCP we used is 128b-wide (only 1 atom on K), so LBO can be zero + return make_smem_desc(cute::UMMA::LayoutType::SWIZZLE_NONE, smem_ptr, 8 * 16, 0); +} + +__device__ __forceinline__ +void replace_smem_desc_addr(cute::UMMA::SmemDescriptor& desc, const void* smem_ptr) { + const auto uint_ptr = cute::cast_smem_ptr_to_uint(smem_ptr); + desc.start_address_ = static_cast(uint_ptr >> 4); +} + +__device__ __forceinline__ +static uint32_t get_atom_base(const cute::UMMA::LayoutType& layout_type) { + return layout_type == cute::UMMA::LayoutType::SWIZZLE_128B_BASE32B ? 32 : 16; +} + +// ReSharper disable once CppNotAllPathsReturnValue +template +constexpr static cute::UMMA::LayoutType to_umma_layout_type() { + DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or + kSwizzleMode == 32 or kSwizzleMode == 64 or + kSwizzleMode == 128, "Invalid swizzling mode"); + // A special case + if constexpr ((cute::is_same_v and kMajorMode == cute::UMMA::Major::MN) or kUseBase32) { + DG_STATIC_ASSERT(kUseBase32, "Invalid swizzling base"); + return cute::UMMA::LayoutType::SWIZZLE_128B_BASE32B; + } + + // Normal cases + if constexpr (kSwizzleMode == 0) return cute::UMMA::LayoutType::SWIZZLE_NONE; + if constexpr (kSwizzleMode == 16) return cute::UMMA::LayoutType::SWIZZLE_NONE; + if constexpr (kSwizzleMode == 32) return cute::UMMA::LayoutType::SWIZZLE_32B; + if constexpr (kSwizzleMode == 64) return cute::UMMA::LayoutType::SWIZZLE_64B; + if constexpr (kSwizzleMode == 128) return cute::UMMA::LayoutType::SWIZZLE_128B; +} + +template +__device__ __forceinline__ +constexpr uint32_t get_umma_desc_stride_k() { + return kMajorMode == cute::UMMA::Major::K ? 1 : get_inner_block_atom_size(); +} + +template +__device__ __forceinline__ +uint32_t advance_umma_desc_lo(const uint32_t& base, const uint32_t& offset, const uint32_t& k_idx) { + return base + (((offset + k_idx * get_umma_desc_stride_k()) * static_cast(sizeof(dtype_t))) >> 4u); +} + +template +__device__ __forceinline__ +cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) { + const uint32_t stride_k = get_umma_desc_stride_k(); + const auto& layout_type = to_umma_layout_type(); + const auto& num_non_contiguous = 128 / get_atom_base(layout_type); + if constexpr (kMajorMode == cute::UMMA::Major::K) { + // NOTES: for K-major layout, the swizzle must be the same as `BLOCK_K * sizeof(dtype_t)` + // also, atom index must be 0, so that each block has exactly one swizzle atom on the K axis + DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value"); + + // Atom size: 8 x `kSwizzleMode` (in bytes, on K) + // {SBO, LBO} means the byte stride between atoms on {MN, K} + // NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0 + const uint32_t stride_byte_offset = num_non_contiguous * BLOCK_K * sizeof(dtype_t); + const uint32_t leading_byte_offset = 0; + return make_smem_desc(layout_type, + base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, + stride_byte_offset, leading_byte_offset); + } else { + constexpr uint32_t BLOCK_MN_ATOM = get_inner_block_atom_size(); + + // Must have no in-atom MN-idx + // NOTES: no worries for the runtime assert, the `mn_idx` are constants at compilation time + DG_DEVICE_ASSERT(mn_idx % BLOCK_MN_ATOM == 0); + DG_STATIC_ASSERT(kSwizzleMode > 0, "Invalid swizzling"); + + // Atom size: `kSwizzleMode` (in bytes, on MN) x 8 + // NOTES: `kSwizzleMode == 16` mean non-swizzling but interleaving + // {SBO, LBO} means the byte stride between atoms on {K, MN} for swizzling + // {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling + uint32_t stride_byte_offset = num_non_contiguous * BLOCK_MN_ATOM * sizeof(dtype_t); + uint32_t leading_byte_offset = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t); + if constexpr (kSwizzleMode == 16) + swap(stride_byte_offset, leading_byte_offset); + return make_smem_desc(layout_type, + base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, + stride_byte_offset, leading_byte_offset); + } +} + +__device__ __forceinline__ +uint64_t make_runtime_instr_desc_with_sf_id(cute::UMMA::InstrDescriptorBlockScaled desc, const uint32_t& sfa_id, const uint32_t& sfb_id) { + desc.a_sf_id_ = sfa_id, desc.b_sf_id_ = sfb_id; + return static_cast(static_cast(desc)) << 32; +} + +template +__device__ constexpr uint32_t get_num_aligned_tmem_cols() { + DG_STATIC_ASSERT(kNumCols <= 512, "Too many tensor memory columns"); + if (kNumCols <= 32) return 32; + if (kNumCols <= 64) return 64; + if (kNumCols <= 128) return 128; + if (kNumCols <= 256) return 256; + return 512; +} + +__device__ __forceinline__ void tcgen05_before_thread_sync() { + asm volatile("tcgen05.fence::before_thread_sync;"); +} + +__device__ __forceinline__ void tcgen05_after_thread_sync() { + asm volatile("tcgen05.fence::after_thread_sync;"); +} + +__device__ __forceinline__ +void tma_gather4(const void* desc_ptr, cutlass::arch::ClusterTransactionBarrier &mbarrier, void* smem_ptr, int col_idx, int4 row_idxs, uint64_t cache_hint) { + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + uint32_t mbarrier_addr = cute::cast_smem_ptr_to_uint(&mbarrier); + asm volatile( + "cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::1.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n" + : + : "r"(smem_addr), "l"(desc_ptr), "r"(col_idx), + "r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w), + "r"(mbarrier_addr), "l"(cache_hint) + : "memory" + ); +} + +// UMMA versions with relaxed assertions +struct SM100_MMA_F16BF16_SS { + __device__ static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t" + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +struct SM100_MMA_F16BF16_2x1SM_SS { + __device__ static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, p; \n\t" + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +struct SM100_MMA_MXF8F6F4_SS { + __device__ static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc, + uint32_t const& tmem_sfa, + uint32_t const& tmem_sfb) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c), + "r"(tmem_sfa), "r"(tmem_sfb)); + } +}; + +struct SM100_MMA_MXF8F6F4_2x1SM_SS { + __device__ static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc, + uint32_t const& tmem_sfa, + uint32_t const& tmem_sfb) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c), + "r"(tmem_sfa), "r"(tmem_sfb)); + } +}; + +struct SM100_MMA_F16BF16_WS_SS { + __device__ static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t" + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +} // namespace `deep_gemm::sm100` diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/common/sm90_utils.cuh b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/common/sm90_utils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..0874b675bc3e2eaaa90c5ff508a7da0bdd8f0830 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/common/sm90_utils.cuh @@ -0,0 +1,332 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace deep_gemm::sm90 { + +template +struct FP8MMA { + + template + __forceinline__ __device__ static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence) { + using namespace cute::SM90::GMMA; + MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero)); + } + + __forceinline__ __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + call_fma_impl(desc_a, desc_b, d, scale_d, cute::make_index_sequence{}); + } + + static constexpr int M = 64; + static constexpr int N = N_; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +template +struct FP8MMASelector { + + static constexpr auto select_mma() { + using namespace cute::SM90::GMMA; + if constexpr (N == 8) return MMA_64x8x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 16) return MMA_64x16x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 24) return MMA_64x24x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 32) return MMA_64x32x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 40) return MMA_64x40x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 48) return MMA_64x48x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 56) return MMA_64x56x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 64) return MMA_64x64x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 72) return MMA_64x72x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 80) return MMA_64x80x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 88) return MMA_64x88x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 96) return MMA_64x96x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 104) return MMA_64x104x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 112) return MMA_64x112x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 120) return MMA_64x120x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 128) return MMA_64x128x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 136) return MMA_64x136x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 144) return MMA_64x144x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 152) return MMA_64x152x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 160) return MMA_64x160x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 168) return MMA_64x168x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 176) return MMA_64x176x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 184) return MMA_64x184x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 192) return MMA_64x192x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 200) return MMA_64x200x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 208) return MMA_64x208x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 216) return MMA_64x216x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 224) return MMA_64x224x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 232) return MMA_64x232x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 240) return MMA_64x240x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 248) return MMA_64x248x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 256) return MMA_64x256x32_F32E4M3E4M3_SS_TN(); + } + + static constexpr auto select_type() { + return FP8MMA(); + } + + using type = decltype(select_type()); +}; + +template +struct BF16MMA { + + template + __forceinline__ __device__ static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence) { + using namespace cute::SM90::GMMA; + MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero)); + } + + __forceinline__ __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + call_fma_impl(desc_a, desc_b, d, scale_d, cute::make_index_sequence{}); + } + + static constexpr int M = 64; + static constexpr int N = N_; + static constexpr int K = 16; + static constexpr int kNumAccum = M * N / 128; +}; + +template +constexpr cute::SM90::GMMA::Major to_sm90_major() { + DG_STATIC_ASSERT(kMajor == cute::UMMA::Major::K or kMajor == cute::UMMA::Major::MN, "Invalid major-ness"); + return kMajor == cute::UMMA::Major::K ? cute::SM90::GMMA::Major::K : cute::SM90::GMMA::Major::MN; +} + +template +struct BF16MMASelector { + + static constexpr auto select_mma() { + using namespace cute::SM90::GMMA; + constexpr auto kGMMAMajorA = to_sm90_major(); + constexpr auto kGMMAMajorB = to_sm90_major(); + if constexpr (N == 8) return MMA_64x8x16_F32BF16BF16_SS(); + if constexpr (N == 16) return MMA_64x16x16_F32BF16BF16_SS(); + if constexpr (N == 24) return MMA_64x24x16_F32BF16BF16_SS(); + if constexpr (N == 32) return MMA_64x32x16_F32BF16BF16_SS(); + if constexpr (N == 40) return MMA_64x40x16_F32BF16BF16_SS(); + if constexpr (N == 48) return MMA_64x48x16_F32BF16BF16_SS(); + if constexpr (N == 56) return MMA_64x56x16_F32BF16BF16_SS(); + if constexpr (N == 64) return MMA_64x64x16_F32BF16BF16_SS(); + if constexpr (N == 72) return MMA_64x72x16_F32BF16BF16_SS(); + if constexpr (N == 80) return MMA_64x80x16_F32BF16BF16_SS(); + if constexpr (N == 88) return MMA_64x88x16_F32BF16BF16_SS(); + if constexpr (N == 96) return MMA_64x96x16_F32BF16BF16_SS(); + if constexpr (N == 104) return MMA_64x104x16_F32BF16BF16_SS(); + if constexpr (N == 112) return MMA_64x112x16_F32BF16BF16_SS(); + if constexpr (N == 120) return MMA_64x120x16_F32BF16BF16_SS(); + if constexpr (N == 128) return MMA_64x128x16_F32BF16BF16_SS(); + if constexpr (N == 136) return MMA_64x136x16_F32BF16BF16_SS(); + if constexpr (N == 144) return MMA_64x144x16_F32BF16BF16_SS(); + if constexpr (N == 152) return MMA_64x152x16_F32BF16BF16_SS(); + if constexpr (N == 160) return MMA_64x160x16_F32BF16BF16_SS(); + if constexpr (N == 168) return MMA_64x168x16_F32BF16BF16_SS(); + if constexpr (N == 176) return MMA_64x176x16_F32BF16BF16_SS(); + if constexpr (N == 184) return MMA_64x184x16_F32BF16BF16_SS(); + if constexpr (N == 192) return MMA_64x192x16_F32BF16BF16_SS(); + if constexpr (N == 200) return MMA_64x200x16_F32BF16BF16_SS(); + if constexpr (N == 208) return MMA_64x208x16_F32BF16BF16_SS(); + if constexpr (N == 216) return MMA_64x216x16_F32BF16BF16_SS(); + if constexpr (N == 224) return MMA_64x224x16_F32BF16BF16_SS(); + if constexpr (N == 232) return MMA_64x232x16_F32BF16BF16_SS(); + if constexpr (N == 240) return MMA_64x240x16_F32BF16BF16_SS(); + if constexpr (N == 248) return MMA_64x248x16_F32BF16BF16_SS(); + if constexpr (N == 256) return MMA_64x256x16_F32BF16BF16_SS(); + } + + static constexpr auto select_type() { + return BF16MMA(); + } + + using type = decltype(select_type()); +}; + +template +struct TF32MMARS { + + template + __forceinline__ __device__ static void call_fma_impl(uint32_t* a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence) { + using namespace cute::SM90::GMMA; + MMA::fma(a[0], a[1], a[2], a[3], desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero)); + } + + __forceinline__ __device__ static void wgmma(float* a, uint64_t const& desc_b, float* d, bool scale_d) { + call_fma_impl(reinterpret_cast(a), desc_b, d, scale_d, cute::make_index_sequence{}); + } + + static constexpr int M = 64; + static constexpr int N = N_; + static constexpr int K = 8; + static constexpr int kNumAccum = M * N / 128; +}; + +template +struct TF32MMASelector { + + static constexpr auto select_mma() { + using namespace cute::SM90::GMMA; + if constexpr (kUseRS) { + if constexpr (N == 8) return MMA_64x8x8_F32TF32TF32_RS_TN(); + if constexpr (N == 16) return MMA_64x16x8_F32TF32TF32_RS_TN(); + if constexpr (N == 32) return MMA_64x32x8_F32TF32TF32_RS_TN(); + if constexpr (N == 64) return MMA_64x64x8_F32TF32TF32_RS_TN(); + if constexpr (N == 128) return MMA_64x128x8_F32TF32TF32_RS_TN(); + if constexpr (N == 256) return MMA_64x256x8_F32TF32TF32_RS_TN(); + DG_STATIC_ASSERT(N == 8 or N == 16 or N == 32 or N == 64 or N == 128 or N == 256, "Invalid N"); + } + } + + static constexpr auto select_type() { + if constexpr (kUseRS) { + return TF32MMARS(); + } else { + DG_STATIC_ASSERT(kUseRS, "SS mode is not supported for TF32MMASelector for now"); + } + } + + using type = decltype(select_type()); +}; + +template +struct SM90_U32x2_STSM_N { + __device__ __forceinline__ static void + copy(dtype_t src_0, dtype_t src_1, void* smem_dst) { + const uint32_t src[2] = {*reinterpret_cast(&src_0), *reinterpret_cast(&src_1)}; + asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" + :: "l"(__cvta_generic_to_shared(smem_dst)), "r"(src[0]), "r"(src[1])); + } +}; + +struct SM90_U32x2_LDSM_N { + __device__ __forceinline__ static void + copy(uint32_t& dst_0, uint32_t& dst_1, void* smem_src) { + asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" + : "=r"(dst_0), "=r"(dst_1) + : "l"(__cvta_generic_to_shared(smem_src))); + } +}; + +struct SM90_U32x4_LDSM_N { + __device__ __forceinline__ static void + copy(uint32_t& dst_0, uint32_t& dst_1, uint32_t& dst_2, uint32_t& dst_3, void* smem_src) { + asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst_0), "=r"(dst_1), "=r"(dst_2), "=r"(dst_3) + : "l"(__cvta_generic_to_shared(smem_src))); + } +}; + +__forceinline__ __device__ void warpgroup_arrive() { + asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory"); +} + +__forceinline__ __device__ void warpgroup_commit_batch() { + asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory"); +} + +__forceinline__ __device__ void warpgroup_fence_operand(float& reg) { + asm volatile("" : "+f"(reg) :: "memory"); +} + +template +__forceinline__ __device__ void warpgroup_wait() { + DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]"); + asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory"); +} + +template +__device__ cute::GmmaDescriptor make_smem_desc(PointerType smem_ptr, const int& layout_type, + const int& leading_byte_offset = 0, + const int& stride_byte_offset = 1024) { + // NOTES: the default LBO and SBO are for K-major types + cute::GmmaDescriptor desc; + const auto& uint_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + desc.bitfield.start_address_ = uint_ptr >> 4; + desc.bitfield.layout_type_ = layout_type; + desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4; + desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4; + desc.bitfield.base_offset_ = 0; + return desc; +} + +template +constexpr uint32_t get_inner_block_atom_size() { + return kSwizzleMode == 0 ? BLOCK_INNER : kSwizzleMode / sizeof(dtype_t); +} + +template +__device__ __forceinline__ +constexpr uint32_t get_gmma_desc_stride_k() { + return kMajorMode == cute::UMMA::Major::K ? 1 : get_inner_block_atom_size(); +} + +// ReSharper disable once CppNotAllPathsReturnValue +template +constexpr static cute::SM90::GMMA::LayoutType to_gmma_layout_type() { + DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or + kSwizzleMode == 32 or kSwizzleMode == 64 or + kSwizzleMode == 128, "Invalid swizzling mode"); + + // Normal cases + if constexpr (kSwizzleMode == 0) return cute::SM90::GMMA::LayoutType::INTERLEAVE; + if constexpr (kSwizzleMode == 16) return cute::SM90::GMMA::LayoutType::INTERLEAVE; + if constexpr (kSwizzleMode == 32) return cute::SM90::GMMA::LayoutType::B32; + if constexpr (kSwizzleMode == 64) return cute::SM90::GMMA::LayoutType::B64; + if constexpr (kSwizzleMode == 128) return cute::SM90::GMMA::LayoutType::B128; +} + +template +__device__ __forceinline__ +uint32_t advance_gmma_desc_lo(const uint32_t& base, const uint32_t& mn_idx, const uint32_t& k_idx, const uint32_t& offset = 0) { + return base + (((offset + mn_idx * BLOCK_K + k_idx * get_gmma_desc_stride_k()) * static_cast(sizeof(dtype_t))) >> 4u); +} + +template +__device__ __forceinline__ +cute::GmmaDescriptor make_gmma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) { + const uint32_t stride_k = get_gmma_desc_stride_k(); + const auto& layout_type = to_gmma_layout_type(); + constexpr uint32_t num_non_contiguous = 128 / 16; + if constexpr (kMajorMode == cute::UMMA::Major::K) { + // NOTES: for K-major layout, the swizzle must be 128B (also, atom index must be 0), as `BLOCK_K` is always 128 + DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value"); + + // Atom size: 8 x `kSwizzleMode` (in bytes, on K) + // {SBO, LBO} means the byte stride between atoms on {MN, K} + // NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0 + const uint32_t stride_byte_offset = num_non_contiguous * BLOCK_K * sizeof(dtype_t); + const uint32_t leading_byte_offset = 0; + return make_smem_desc(base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, static_cast(layout_type), + leading_byte_offset, stride_byte_offset); + } else { + constexpr uint32_t BLOCK_MN_ATOM = get_inner_block_atom_size(); + + // Must have no in-atom MN-idx + // NOTES: no worries for the runtime assert, the `mn_idx` are constants at compilation time + DG_DEVICE_ASSERT(mn_idx % BLOCK_MN_ATOM == 0); + DG_STATIC_ASSERT(kSwizzleMode > 0, "Invalid swizzling"); + + // Atom size: `kSwizzleMode` (in bytes, on MN) x 8 + // NOTES: `kSwizzleMode == 16` mean non-swizzling but interleaving + // {SBO, LBO} means the byte stride between atoms on {K, MN} for swizzling + // {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling + uint32_t stride_byte_offset = num_non_contiguous * BLOCK_MN_ATOM * sizeof(dtype_t); + uint32_t leading_byte_offset = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t); + if constexpr (kSwizzleMode == 16) + swap(stride_byte_offset, leading_byte_offset); + return make_smem_desc(base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, static_cast(layout_type), + leading_byte_offset, stride_byte_offset); + } +} + +} // namespace `deep_gemm::sm90` diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/common/tma_utils.cuh b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/common/tma_utils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..bd54adc231812a0c0e2fbbc130e18a27697a497c --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/common/tma_utils.cuh @@ -0,0 +1,116 @@ +#pragma once + +#include +#include +#include + +namespace deep_gemm { + +template +constexpr uint32_t get_inner_block_atom_size() { + return kSwizzleMode == 0 ? BLOCK_INNER : kSwizzleMode / sizeof(dtype_t); +} + +template +__device__ __forceinline__ void +tma_copy(void const* desc_ptr, cutlass::arch::ClusterTransactionBarrier* barrier_ptr, + dtype_t* smem_ptr, const uint32_t& inner_idx, const uint32_t& outer_idx, + const uint32_t& num_tma_multicast = 1, const uint32_t& batch_idx = 0) { + DG_STATIC_ASSERT(static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL) == + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), "Invalid cache hint"); + constexpr uint32_t BLOCK_INNER_ATOM = get_inner_block_atom_size(); + + if constexpr (not kIs3DTMA) { + if (num_tma_multicast == 1) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_2D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx); + } + } else { + #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) + // 2-CTA function will send signals to the leader CTA only + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM100_TMA_2SM_LOAD_2D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx); + } + #elif (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) + if (cute::block_rank_in_cluster() == 0) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + (1 << num_tma_multicast) - 1, static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx); + } + } + #endif + } + } else { + if (num_tma_multicast == 1) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_3D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx); + } + } else { + #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) + // 2-CTA function will send signals to the leader CTA only + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM100_TMA_2SM_LOAD_3D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx); + } + #elif (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) + if (cute::block_rank_in_cluster() == 0) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_MULTICAST_3D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + (1 << num_tma_multicast) - 1, static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx); + } + } + #endif + } + } +} + +// Tensormap related +__device__ __forceinline__ void tensor_map_release_cta() { + asm volatile ("fence.proxy.tensormap::generic.release.cta;"); +} + +__device__ __forceinline__ void tensor_map_acquire_cta(const cute::TmaDescriptor* gmem_desc_ptr) { + auto gmem_int_desc = reinterpret_cast(gmem_desc_ptr); + asm volatile ("fence.proxy.tensormap::generic.acquire.cta [%0], 128;" :: "l"(gmem_int_desc) : "memory"); +} + +__device__ __forceinline__ void tensor_map_replace_global_addr_in_smem(cute::TmaDescriptor* smem_desc, const void* new_addr) { + auto smem_int_desc = static_cast(__cvta_generic_to_shared(smem_desc)); + const auto new_int64_addr = reinterpret_cast(new_addr); + asm volatile ("tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;" :: "r"(smem_int_desc), "l"(new_int64_addr)); +} + +__device__ __forceinline__ void tensor_map_replace_global_inner_dim_stride_in_smem(cute::TmaDescriptor* smem_desc, const uint32_t& new_dim, const uint64_t& new_stride) { + auto smem_int_desc = __cvta_generic_to_shared(smem_desc); + asm volatile ("tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 0, %1;" :: "l"(smem_int_desc), "r"(new_dim)); +#if ((__CUDACC_VER_MAJOR__ > 12) or ((__CUDACC_VER_MAJOR__ == 12) and (__CUDACC_VER_MINOR__ >= 3))) + asm volatile("tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" :: "l"(smem_int_desc), "l"(new_stride)); +#else + DG_STATIC_ASSERT(false, "Invalid CUDA version"); +#endif +} + +} // namespace `deep_gemm` diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/common/types.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/common/types.hpp new file mode 100644 index 0000000000000000000000000000000000000000..410c5469ef348500f339ae0ced5d7bf171fdd42c --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/common/types.hpp @@ -0,0 +1,41 @@ +#pragma once + +namespace deep_gemm { + +enum class MmaKind { + BF16 = 0, + MXFP8FP4 = 1, +}; + +constexpr __host__ __device__ int get_element_size(const MmaKind& mma_kind) { + switch (mma_kind) { + case MmaKind::BF16: return 2; + case MmaKind::MXFP8FP4: return 1; + default: return 0; + } +} + +enum class GemmType { + Normal = 0, + MGroupedContiguous = 1, + MGroupedMasked = 2, + KGroupedContiguous = 3, + Batched = 4, + MGroupedContiguousWithPsumLayout = 5, +}; + +constexpr __host__ __device__ bool is_m_grouped_contiguous(const GemmType& gemm_type) { + switch (gemm_type) { + case GemmType::MGroupedContiguous: return true; + case GemmType::MGroupedContiguousWithPsumLayout: return true; + default: return false; + } +} + +enum class KernelType { + Kernel1D1D = 0, + Kernel1D2D = 1, + KernelNoSF = 2 +}; + +} // namespace deep_gemm diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/common/utils.cuh b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/common/utils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..8fb6c2fc53b6d1eb067d13c113462a9f7de4133a --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/common/utils.cuh @@ -0,0 +1,183 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "cute_tie.cuh" + +#ifdef __CLION_IDE__ + +__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { + asm volatile("trap;"); +} + +#define printf host_device_printf +#endif + +#ifndef DG_DEVICE_ASSERT +#define DG_DEVICE_ASSERT(cond) \ +do { \ + if (not (cond)) { \ + printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \ + asm("trap;"); \ + } \ +} while (0) +#endif + +#ifndef DG_TRAP_ONLY_DEVICE_ASSERT +#define DG_TRAP_ONLY_DEVICE_ASSERT(cond) \ +do { \ + if (not (cond)) \ + asm("trap;"); \ +} while (0) +#endif + +#ifndef DG_STATIC_ASSERT +#define DG_STATIC_ASSERT(cond, ...) static_assert(cond, __VA_ARGS__) +#endif + +namespace deep_gemm { + +template +struct PatternVisitor { + FuncT func; + + __device__ __host__ + explicit PatternVisitor(FuncT&& func): func(std::forward(func)) {} + + __device__ __host__ + auto operator [](const uint32_t& i) { + return func(i); + } +}; + +template +__device__ __host__ T ceil_div(T a, T b) { + return (a + b - 1) / b; +} + +template +__device__ __host__ constexpr T constexpr_ceil_div(T a, T b) { + return (a + b - 1) / b; +} + +template +__device__ __host__ T align(T a, T b) { + return ceil_div(a, b) * b; +} + +template +__device__ __host__ constexpr T constexpr_align(T a, T b) { + return constexpr_ceil_div(a, b) * b; +} + +template +__device__ __host__ constexpr T constexpr_gcd(T a, T b) { + return b == 0 ? a : constexpr_gcd(b, a % b); +} + +template +__forceinline__ __device__ void swap(T& a, T& b) { + T temp = a; + a = b; + b = temp; +} + +__forceinline__ __device__ uint32_t get_sm_idx() { + uint32_t sm_idx; + asm ("mov.u32 %0, %%smid;" : "=r"(sm_idx)); + return sm_idx; +} + +__forceinline__ __device__ uint32_t get_lane_idx() { + uint32_t lane_id; + asm ("mov.u32 %0, %laneid;" : "=r"(lane_id)); + return lane_id; +} + +__device__ __forceinline__ uint32_t ld_shared(const uint32_t* ptr) { + uint32_t ret; + asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +__device__ __forceinline__ float2 ld_shared(const float2* ptr) { + float2 ret; + asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +__device__ __forceinline__ float4 ld_shared(const float4* ptr) { + float4 ret; + asm volatile("ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];" : "=f"(ret.x), "=f"(ret.y), "=f"(ret.z), "=f"(ret.w) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +__device__ __forceinline__ uint4 ld_shared(const uint4* ptr) { + uint4 ret; + asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +__device__ __forceinline__ float ld_shared(const float* ptr) { + float ret; + asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +__device__ __forceinline__ void st_shared(const float* ptr, float val) { + asm volatile("st.shared.f32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val)); +} + +__device__ __forceinline__ void st_shared(const float2* ptr, float2 val) { + asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val.x), "f"(val.y)); +} + +__device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) { + asm volatile("st.shared.u32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "r"(val)); +} + +__device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y) { + asm volatile("st.shared.v2.u32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y)); +} + +__device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y, uint32_t z, uint32_t w) { + asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y), "r"(z), "r"(w)); +} + +__device__ __forceinline__ void st_shared(const __int128_t* ptr, __int128_t val) { + asm volatile("st.shared.b128 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "q"(val)); +} + +template +__device__ __forceinline__ int cast_into_bf16_and_pack(old_t& x, old_t& y) { + auto bf16x2 = __float22bfloat162_rn({*reinterpret_cast(&x), *reinterpret_cast(&y)}); + return *reinterpret_cast(&bf16x2); +} + +__device__ __forceinline__ void prefetch_l1(void *ptr) { + asm volatile("prefetch.global.L1 [%0];" :: "l"(ptr)); +} + +template +struct Vectorized { + static auto zeros() { + // TODO: add `ulonglong4` for SM100 once `__ldg` support this + if constexpr (kNumBytes > 0 and kNumBytes % 16 == 0) { + return make_uint4(0, 0, 0, 0); + } else if constexpr (kNumBytes > 0 and kNumBytes % 8 == 0) { + return make_uint2(0, 0); + } else if constexpr (kNumBytes > 0 and kNumBytes % 4 == 0) { + return 0; + } else { + DG_STATIC_ASSERT(kNumBytes > 0 and kNumBytes % 4 == 0, "Invalid vectorization"); + } + } + + using vec_t = decltype(zeros()); +}; + +} // namespace `deep_gemm` diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm100_bf16_gemm.cuh b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm100_bf16_gemm.cuh new file mode 100644 index 0000000000000000000000000000000000000000..0227b3e80061409c4dcf89f3f402ce408751246f --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm100_bf16_gemm.cuh @@ -0,0 +1,482 @@ +#pragma once +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include + +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm100; + +template +__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) +sm100_bf16_gemm_impl(int* grouped_layout, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_cd) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + // Enlarge `BLOCK_K` for some cases + // NOTES: this is for reducing the `umma_arrive()` overhead + constexpr bool kDoMergeStages = + kNumStages_ >= 8 and kGemmType == GemmType::Normal and + kMajorA == cute::UMMA::Major::K and kMajorB == cute::UMMA::Major::K; + // Ensure there are at least `kNumMinStages` stages after merge + constexpr uint32_t kNumMinStages = 8; + constexpr uint32_t kNumStagesPerMerge = kDoMergeStages ? kNumStages_ / kNumMinStages : 1; + constexpr uint32_t BLOCK_K = BLOCK_K_ * kNumStagesPerMerge; + constexpr uint32_t kNumStages = kNumStages_ / kNumStagesPerMerge; + + using Barrier = cutlass::arch::ClusterTransactionBarrier; + using Allocator = cute::conditional_t; + + // GEMM with accumulation must have FP32 output + if constexpr (kWithAccumulation) + DG_STATIC_ASSERT(cute::is_same_v, "Invalid C/D data dtype"); + + // Configs + constexpr uint32_t LAYOUT_AD_M = 128; + constexpr uint32_t WAVE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); + constexpr uint32_t kNumMWaves = BLOCK_M / WAVE_BLOCK_M; + constexpr uint32_t kNumTMAStoreStages = 2; + DG_STATIC_ASSERT(BLOCK_K_ == 64, "Invalid block K"); + DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0 and 2 % kNumMWaves == 0, "Invalid block M"); + DG_STATIC_ASSERT(sizeof(cutlass::bfloat16_t) * LAYOUT_AD_M % kSwizzleAMode == 0, "Invalid swizzle A mode"); + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + + // Utils + bool is_leader_cta = cute::block_rank_in_cluster() == 0; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = get_lane_idx(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // 2-CTA MMA + constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1); + constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast); + constexpr uint32_t STORE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); + constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t); + constexpr uint32_t kNumUMMAStoreThreads = STORE_BLOCK_M; + DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast"); + DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M, "Only support tensor memory layout A/D"); + DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast"); + DG_STATIC_ASSERT(kNumUMMAStoreThreads % 32 == 0, "Invalid store block M"); + + // Share memory sizes + constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(cutlass::bfloat16_t); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(cutlass::bfloat16_t); + DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0 and SMEM_A_SIZE_PER_STAGE % 1024 == 0 and SMEM_B_SIZE_PER_STAGE % 1024 == 0, + "Shared memory of A/B must be aligned to 1024 bytes"); + DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages"); + + // NOTES: Make sure we have enough shared memory for UMMA padding + static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(nv_bfloat16); + DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for UMMA"); + + // Automatically deduce the number of epilogue stages (1 or 2), according to the tensor memory size + // TODO: test cases of `kNumMWaves == 2 and kNumEpilogueStages == 2` + constexpr uint32_t kNumEpilogueStages = (2 * kNumMWaves * BLOCK_N) > 512 ? 1 : 2; + + // Real tensor memory size and offsets + constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * kNumMWaves * BLOCK_N; + constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_cd); + } + + // D/A/B shared memory + auto smem_cd = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE); + }); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); + auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + kNumEpilogueStages + i); }); + auto tensor_core_full_barrier = barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2; + + // Fill the tensor memory pointer + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2 + 1); + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + // Arrive only at the leader CTA + full_barriers[i]->init(kNumMulticast); + // Arrive at all CTAs + empty_barriers[i]->init(1); + } + #pragma unroll + for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) { + // Arrive at all CTAs + tmem_full_barriers[i]->init(1); + // Arrive only at the leader CTA + tmem_empty_barriers[i]->init(kNumMulticast * kNumUMMAStoreThreads); + } + if constexpr (kTensorCoreUtilControl < 100) + tensor_core_full_barrier->init(1); + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 2) { + // Allocate tensor memory + Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads(); + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + + // Pipeline and TMA phases + uint32_t stage_idx = 0, phase = 0, tensor_core_phase = 0; + auto advance_pipeline = [&](uint32_t& k_block_idx) { + ++ k_block_idx; + + // Flip phases only if reach the next first stage + stage_idx = (stage_idx + 1) % kNumStages; + phase ^= stage_idx == 0; + }; + + // Dispatch warps into different roles + if (warp_idx == 0 and cute::elect_one_sync()) { + // TMA load warp + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Compute offsets + // NOTES: the group is always concatenated with the outer dimension + uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), IndexType::MN> ( + shape_m, BLOCK_M, m_block_idx); + uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), IndexType::MN> ( + shape_n, BLOCK_N, n_block_idx, m_block_idx); + + // NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major + // And for all m-grouped GEMMs, A must be K-majored + DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or + kMajorA == cute::UMMA::Major::K, "Invalid major"); + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + + // Add 2 CTA offsets + if constexpr (kNumMulticast > 1) { + m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0; + n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N); + } + + // Issue TMAs + constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched); + const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0); + if constexpr (kMajorA == cute::UMMA::Major::K) + tma_copy( + &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, kNumMulticast, batch_idx); + if constexpr (kMajorA == cute::UMMA::Major::MN) + tma_copy( + &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, kNumMulticast, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::K) + tma_copy( + &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, kNumMulticast, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::MN) + tma_copy( + &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, kNumMulticast, batch_idx); + + // Arrive at full barriers + constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; + if (is_leader_cta) { + full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes * kNumMulticast); + } else { + full_barriers[stage_idx]->arrive(0u); + } + } + } + } else if (warp_idx == 1 and is_leader_cta) { + // MMA issue warp + // NOTES: only the leader CTA will do this + // Make instruction descriptor + // TODO: refactor `UMMA_M` calculation + constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast); + constexpr uint32_t UMMA_N = BLOCK_N * (kIsMulticastOnA ? kNumMulticast : 1); + constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::bfloat16_t); + auto instr_desc = cute::UMMA::make_instr_desc(); + + DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); + // Merged stages only happens in NT normal GEMM cases + constexpr uint32_t BLOCK_ATOM_K = BLOCK_K / kNumStagesPerMerge; + auto a_desc = make_umma_desc(smem_a[0], 0, 0); + auto b_desc = make_umma_desc(smem_b[0], 0, 0); + uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; + uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + // Checks for MMA instructions + // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Wait tensor memory empty barrier arrival + auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1); + tcgen05_after_thread_sync(); + + // UMMA and empty barrier arrival alias + auto umma_arrive = [](const uint64_t* barrier) { + if constexpr (kNumMulticast == 1) { + cutlass::arch::umma_arrive(barrier); + } else { + constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1; + cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask); + } + }; + auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) { + umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + + // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting + if (do_tmem_full_arrive) + umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); + }; + + // Launch MMAs + const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait TMA arrival + full_barriers[stage_idx]->wait(phase); + tcgen05_after_thread_sync(); + + // Issue UMMA in the leader CTA + using mma_t = cute::conditional_t; + const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast(stage_idx)); + const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast(stage_idx)); + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + uint32_t atom_k_idx = k * UMMA_K / BLOCK_ATOM_K; + b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, atom_k_idx * LOAD_BLOCK_N * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K); + #pragma unroll + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + DG_STATIC_ASSERT((WAVE_BLOCK_M * BLOCK_K) % 128 == 0, "Invalid swizzling offset"); + a_desc.lo = advance_umma_desc_lo(a_desc_base_lo, atom_k_idx * LOAD_BLOCK_M * BLOCK_ATOM_K + w * WAVE_BLOCK_M * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K); + mma_t::fma(a_desc, b_desc, + accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N, + k_block_idx > 0 or k > 0, + runtime_instr_desc); + } + } + } + + // Commit to the mbarrier object + // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` + empty_barrier_arrive(k_block_idx == num_total_k_blocks - 1); + + // Let tensor cores relax for lower possibility of frequency drop + DG_STATIC_ASSERT(kTensorCoreUtilControl > 0, "Invalid tensor utilization control"); + if constexpr (kTensorCoreUtilControl < 100) { + // For utilization control + umma_arrive(reinterpret_cast(tensor_core_full_barrier)); + + // Wait for last UMMA to be done + tensor_core_full_barrier->wait(tensor_core_phase); + tensor_core_phase ^= 1; + + // Sleep for certain cycles + constexpr static uint64_t kNumUMMACycles = (2ull * LAYOUT_AD_M * kNumMWaves * BLOCK_N * BLOCK_K) / 8192ull; + constexpr static uint64_t kNumDummyCycles = (100ull - kTensorCoreUtilControl) * kNumUMMACycles / kTensorCoreUtilControl; + const auto& start_clock = clock64(); + if (cute::elect_one_sync()) + while (clock64() - start_clock < kNumDummyCycles) {} + __syncwarp(); + } + } + } + + // To safely deconstruct barriers, we need another round of waits + const auto& iter_idx = scheduler.current_iter - 1; + if (kNumMulticast > 1 and iter_idx >= 0) { + const auto& accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1; + tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx); + } + } else if (warp_idx >= kNumNonEpilogueThreads / 32 and warp_idx < (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32) { + // Epilogue warp groups + const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32); + + // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, + // i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`. + // NOTES: we also forbid two CTAs to share the same SM and its tensor memory + DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + + // TMA checks + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t); + DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); + DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + + // Share store pipeline between blocks + uint32_t tma_stage_idx = 0; + auto advance_store_pipeline = [&]() { + tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages; + }; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + + // Wait UMMA arrival + tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx); + tcgen05_after_thread_sync(); + + // Load from tensor memory into registers, and write shared memory with STSM + DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough"); + DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); + + // Iterate over M waves + #pragma unroll + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + // Issue every swizzled atom and pipeline STSM and TMA store + constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N; + #pragma unroll + for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) { + // Wait shared memory to be released + if (epilogue_warp_idx == 0) + cute::tma_store_wait(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + + // The pipeline stage + const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M; + const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N; + + // Store into shared memory + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)` + // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)` + // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern + constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (i) : (bank_group_index % 8); + col ^= row % (kSwizzleCDMode / 16); + + // Source and destination memory address + uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset + w * BLOCK_N + // Wave offset + s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset + auto smem_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + // Base pointer + epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + + // Load from tensor memory, store into shared memory + uint32_t values[kNumElemsPerBankGroup]; + if constexpr (cute::is_same_v) { + // For FP32 output, read and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cutlass::arch::fence_view_async_tmem_load(); + st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + } else { + // For BF16 output, read, cast and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, + values[0], values[1], values[2], values[3], + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + st_shared(smem_ptr, + cast_into_bf16_and_pack(values[0], values[1]), + cast_into_bf16_and_pack(values[2], values[3]), + cast_into_bf16_and_pack(values[4], values[5]), + cast_into_bf16_and_pack(values[6], values[7])); + } + } + + // Notify tensor memory empty (only at the leader CTA) arrival ASAP + // NOTES: only the last stage needs to do this + if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) { + tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + } + __syncwarp(); + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { + if constexpr (kGemmType == GemmType::Batched) { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], + n_idx, m_idx, scheduler.current_group_idx); + } else { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], n_idx, m_idx); + } + cute::tma_store_arrive(); + } + } + } + } + + // Deallocate tensor memory by the last UMMA store warp + // NOTES: warp 0 is waiting TMA store + if (epilogue_warp_idx == kNumUMMAStoreThreads / 32 - 1) + Allocator().free(0, kNumTmemCols); + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh new file mode 100644 index 0000000000000000000000000000000000000000..86303347d9c7a3a93b65a16d6ad4a7b73eb2ad1a --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh @@ -0,0 +1,265 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm100; + +template +__global__ void __launch_bounds__(kNumThreads, 1) +sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_d) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Configs + constexpr uint32_t LAYOUT_AD_M = 128; + constexpr uint32_t kNumTMAStoreStages = 2; + + // Utils + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = get_lane_idx(); + DG_STATIC_ASSERT(BLOCK_M == LAYOUT_AD_M and BLOCK_N == 128 and BLOCK_K == 64, "Invalid block size"); + DG_STATIC_ASSERT(kSwizzleABMode == 128 and kSwizzleCDMode == 128, "Invalid swizzle mode"); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // Shared memory sizes + constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(cutlass::bfloat16_t); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(cutlass::bfloat16_t); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_d); + } + + // Real tensor memory size and offsets + constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); + + // Fill D/A/B + auto smem_cd = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (i * SMEM_CD_SIZE_PER_STAGE)); + }); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE)); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto tmem_full_barrier = barrier_start_ptr + (kNumStages * 2); + + // Fill the tensor memory pointer + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 2 + 1); + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(1); + } + tmem_full_barrier->init(1); + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 2) { + // Allocate tensor memory + cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + __syncthreads(); + + // Block indices + const uint32_t num_n_blocks = ceil_div(SHAPE_N, BLOCK_N); + const uint32_t num_mn_blocks = num_n_blocks * ceil_div(SHAPE_M, BLOCK_M); + const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks; + const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks; + const uint32_t n_block_idx = mn_block_idx % num_n_blocks; + const uint32_t m_block_idx = mn_block_idx / num_n_blocks; + const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor); + + if (warp_idx == 0) { + // TMA load warp + for (uint32_t s = 0; s < num_total_stages; ++ s) { + const auto& stage_idx = s % kNumStages; + empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1); + + uint32_t m_idx = BLOCK_M * m_block_idx; + uint32_t n_idx = BLOCK_N * n_block_idx; + uint32_t sk_idx = (sk_block_idx * kSplitFactor + s) * BLOCK_K; + uint32_t k_idx = sk_idx % SHAPE_K; + uint32_t s_idx = sk_idx / SHAPE_K; + + // Issue TMAs + if (cute::elect_one_sync()) { + tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx + s_idx * SHAPE_M); + tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx + s_idx * SHAPE_N); + } + + // Arrive at full barriers + constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; + if (cute::elect_one_sync()) + full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes); + } + } else if (warp_idx == 1) { + // MMA issue warp + // NOTES: only the leader CTA will do this + // Make instruction descriptor + constexpr uint32_t UMMA_M = LAYOUT_AD_M; + constexpr uint32_t UMMA_N = BLOCK_N; + constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::bfloat16_t); + auto instr_desc = cute::UMMA::make_instr_desc(); + + DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); + auto a_desc = make_umma_desc(smem_a[0], 0, 0); + auto b_desc = make_umma_desc(smem_b[0], 0, 0); + uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; + uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + // Checks for MMA instructions + // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Wait tensor memory empty barrier arrival + tcgen05_after_thread_sync(); + + // Launch MMAs + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait TMA arrival + const auto& stage_idx = s % kNumStages; + full_barriers[stage_idx]->wait((s / kNumStages) & 1); + tcgen05_after_thread_sync(); + + // Issue UMMA in the leader CTA + const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, stage_idx); + const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, stage_idx); + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + a_desc.lo = advance_umma_desc_lo(a_desc_base_lo, 0, k * UMMA_K); + b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, 0, k * UMMA_K); + SM100_MMA_F16BF16_SS::fma(a_desc, b_desc, 0, s > 0 or k > 0, runtime_instr_desc); + } + } + + // Commit to the mbarrier object + // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` + cutlass::arch::umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + } + cutlass::arch::umma_arrive(reinterpret_cast(tmem_full_barrier)); + } + + // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, + // i.e., no need for `tmem_ptr |= (warp_idx * 32) << 16`. + // NOTES: we also forbid two CTAs to share the same SM and its tensor memory + if (warp_idx == 2) + DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + + // TMA checks + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(float); + constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(float); + DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); + DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + + // Wait UMMA arrival + tmem_full_barrier->wait(0); + tcgen05_after_thread_sync(); + + // Load from tensor memory into registers, and write shared memory with STSM + DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); + + // Issue every swizzled atom and pipeline STSM and TMA store + constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N; + #pragma unroll + for (uint32_t s = 0; s < kNumStores; ++ s) { + // Wait shared memory to be released + if (s >= kNumTMAStoreStages) { + if (warp_idx == 0 and cute::elect_one_sync()) + cute::tma_store_wait(); + cutlass::arch::NamedBarrier(kNumThreads).sync(); + } + + // The pipeline stage + const auto tma_stage_idx = s % kNumTMAStoreStages; + const auto m_idx = m_block_idx * BLOCK_M; + const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N; + + // Store into shared memory + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)` + // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)` + // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern + constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (i) : (bank_group_index % 8); + col ^= row % (kSwizzleCDMode / 16); + + // Source and destination memory address + uint32_t tmem_addr = s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset + auto smem_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + // Base pointer + warp_idx * 32 * kSwizzleCDMode + // Warp offset + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + + // Load from tensor memory, store into shared memory + uint32_t values[kNumElemsPerBankGroup]; + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cutlass::arch::fence_view_async_tmem_load(); + st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + } + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumThreads).sync(); + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::SM90_TMA_REDUCE_ADD_2D::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx); + cute::tma_store_arrive(); + } + } + + // Deallocate tensor memory by warp 1 + // NOTES: warp 0 is doing TMA stores + if (warp_idx == 1) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); +#endif +} + +} diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh new file mode 100644 index 0000000000000000000000000000000000000000..45a603add3f494aed51dce7aec53b5545bdc23f4 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh @@ -0,0 +1,563 @@ +#pragma once +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include + +#include +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm100; + +template +__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) +sm100_fp8_gemm_1d1d_impl(int* grouped_layout, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfa, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfb, + const __grid_constant__ cute::TmaDescriptor tensor_map_cd) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + using Allocator = cute::conditional_t; + + // GEMM with accumulation must have FP32 output + if constexpr (kWithAccumulation) + DG_STATIC_ASSERT(cute::is_same_v, "Invalid C/D data dtype"); + + // Configs + constexpr uint32_t LAYOUT_AD_M = 128; + constexpr uint32_t WAVE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); + constexpr uint32_t kNumMWaves = BLOCK_M / WAVE_BLOCK_M; + constexpr uint32_t kNumTMAStoreStages = 2; + constexpr uint32_t kNumUTCCPAlignedElems = 128; + DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K"); + DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0 and 2 % kNumMWaves == 0, "Invalid block M"); + + constexpr uint32_t kNumSFAStagesPerLoad = kGranKA == 32 ? 1 : 4; + constexpr uint32_t kNumSFBStagesPerLoad = kGranKB == 32 ? 1 : 4; + DG_STATIC_ASSERT(kGranKA == 32 or kGranKA == 128, "Invalid granularity K for A"); + DG_STATIC_ASSERT(kGranKB == 32 or kGranKB == 128, "Invalid granularity K for B"); + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + const uint32_t shape_sfa_k = ceil_div(shape_k, kGranKA * 4); + const uint32_t shape_sfb_k = ceil_div(shape_k, kGranKB * 4); + + // Utils + bool is_leader_cta = cute::block_rank_in_cluster() == 0; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = get_lane_idx(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // 2-CTA MMA + constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1); + constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast); + constexpr uint32_t STORE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); + constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t); + constexpr uint32_t kNumUMMAStoreThreads = STORE_BLOCK_M; + DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast"); + DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M, "Only support tensor memory layout A/D"); + DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast"); + DG_STATIC_ASSERT(kNumUMMAStoreThreads % 32 == 0, "Invalid store block M"); + + // Share memory sizes + constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t); + constexpr uint32_t SF_BLOCK_M = constexpr_align(BLOCK_M, kNumUTCCPAlignedElems); + constexpr uint32_t SF_BLOCK_N = constexpr_align(BLOCK_N, kNumUTCCPAlignedElems); + constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t); + constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t); + DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0 and SMEM_A_SIZE_PER_STAGE % 1024 == 0 and SMEM_B_SIZE_PER_STAGE % 1024 == 0, + "Shared memory of A/B must be aligned to 1024 bytes"); + DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages"); + + // NOTES: Make sure we have enough shared memory for UMMA padding + static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(a_dtype_t); + DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for UMMA"); + + // Automatically deduce the number of epilogue stages (1 or 2), according to the tensor memory size + // TODO: test cases of `kNumMWaves == 2 and kNumEpilogueStages == 2` + constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32; + constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32; + constexpr uint32_t kNumEpilogueStages = (2 * kNumMWaves * BLOCK_N + kNumSFATmemCols + kNumSFBTmemCols) > 512 ? 1 : 2; + + // Real tensor memory size and offsets + constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * kNumMWaves * BLOCK_N; + constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); + constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols; + constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols; + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_sfa); + cute::prefetch_tma_descriptor(&tensor_map_sfb); + cute::prefetch_tma_descriptor(&tensor_map_cd); + } + + // D/A/B shared memory + auto smem_cd = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE); + }); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); + + // SFA/SFB shared memory + auto sf_start_ptr = smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + auto smem_sfa = PatternVisitor([=](const uint32_t& i) { + return reinterpret_cast(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE); + }); + auto smem_sfb = PatternVisitor([=](const uint32_t& i) { + return reinterpret_cast(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + + SMEM_CD_SIZE + + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + + kNumStages * (SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto with_sf_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); + auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); }); + auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + kNumEpilogueStages + i); }); + + // Fill the tensor memory pointer + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2); + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + // Arrive at all CTAs + full_barriers[i]->init(1); + empty_barriers[i]->init(1); + // Arrive only at the leader CTA + with_sf_full_barriers[i]->init(kNumMulticast * 32); + } + #pragma unroll + for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) { + // Arrive at all CTAs + tmem_full_barriers[i]->init(1); + // Arrive only at the leader CTA + tmem_empty_barriers[i]->init(kNumMulticast * kNumUMMAStoreThreads); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 2) { + // Allocate tensor memory + Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads(); + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + + // Pipeline and TMA phases + uint32_t stage_idx = 0, phase = 0; + auto advance_pipeline = [&](uint32_t& k_block_idx) { + ++ k_block_idx; + + // Flip phases only if reach the next first stage + stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1; + phase ^= stage_idx == 0; + }; + + // Dispatch warps into different roles + if (warp_idx == 0 and cute::elect_one_sync()) { + // TMA load warp + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Compute offsets + // NOTES: the group is always concatenated with the outer dimension + uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), IndexType::MN> ( + shape_m, BLOCK_M, m_block_idx); + uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), IndexType::MN> ( + shape_n, BLOCK_N, n_block_idx, m_block_idx); + + // NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major + // And for all m-grouped GEMMs, A must be K-majored + DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or + kMajorA == cute::UMMA::Major::K, "Invalid major"); + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + + // Add 2 CTA offsets + if constexpr (kNumMulticast > 1) { + m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0; + n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N); + } + + // Issue TMAs + constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched); + const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0); + if constexpr (kMajorA == cute::UMMA::Major::K) + tma_copy( + &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, 1, batch_idx); + if constexpr (kMajorA == cute::UMMA::Major::MN) + tma_copy( + &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, 1, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::K) + tma_copy( + &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, 1, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::MN) + tma_copy( + &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, 1, batch_idx); + auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE / (std::is_same_v ? 1 : 2) + + SMEM_B_SIZE_PER_STAGE / (std::is_same_v ? 1 : 2); + + // Issue SFA and SFB TMAs at certain stages + // No swizzling, so one TMA for one SF is enough + if (k_block_idx % kNumSFAStagesPerLoad == 0) { + tma_copy(&tensor_map_sfa, full_barriers[stage_idx], smem_sfa[stage_idx], m_block_idx * BLOCK_M, + scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::SF_K>(shape_sfa_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFAStagesPerLoad))); + num_arrival_bytes += BLOCK_M * sizeof(uint32_t); + } + if (k_block_idx % kNumSFBStagesPerLoad == 0) { + tma_copy(&tensor_map_sfb, full_barriers[stage_idx], smem_sfb[stage_idx], n_block_idx * BLOCK_N, + scheduler.template get_global_idx(shape_sfb_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFBStagesPerLoad), m_block_idx)); + num_arrival_bytes += BLOCK_N * sizeof(uint32_t); + } + + // Arrive at full barriers + full_barriers[stage_idx]->arrive_and_expect_tx(num_arrival_bytes); + } + } + } else if (warp_idx == 1 and is_leader_cta) { + // MMA issue warp + // NOTES: only the leader CTA will do this + // Make instruction descriptor + // TODO: refactor `UMMA_M` calculation + constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast); + constexpr uint32_t UMMA_N = BLOCK_N * (kIsMulticastOnA ? kNumMulticast : 1); + constexpr uint32_t UMMA_K = 32; + auto instr_desc = cute::UMMA::make_instr_desc_block_scaled(); + auto sf_desc = make_sf_desc(nullptr); + + DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); + auto a_desc = make_umma_desc(smem_a[0], 0, 0); + auto b_desc = make_umma_desc(smem_b[0], 0, 0); + uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; + uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + // Checks for MMA instructions + // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Wait tensor memory empty barrier arrival + auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1); + tcgen05_after_thread_sync(); + + // Empty barrier arrival + auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) { + auto umma_arrive = [](const uint64_t* barrier) { + if constexpr (kNumMulticast == 1) { + cutlass::arch::umma_arrive(barrier); + } else { + constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1; + cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask); + } + }; + umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + + // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting + if (do_tmem_full_arrive) + umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); + }; + + // Launch MMAs + const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait TMA and SF-transpose arrival + with_sf_full_barriers[stage_idx]->wait(phase); + tcgen05_after_thread_sync(); + + // Do SF copy at certain stages + // NOTES: CUTLASS UTCCP's interface does not have `elect_one_sync`, we must do it by ourselves + // TODO: process shared memory descriptor by addition + using cute_utccp_t = cute::conditional_t; + const uint32_t sfa_stage_in_group_idx = k_block_idx % kNumSFAStagesPerLoad; + if (sfa_stage_in_group_idx == 0 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems; + replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4); + } + } + const uint32_t sfb_stage_in_group_idx = k_block_idx % kNumSFBStagesPerLoad; + if (sfb_stage_in_group_idx == 0 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems; + replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4); + } + } + __syncwarp(); + + // Issue UMMA in the leader CTA + using mma_t = cute::conditional_t; + const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast(stage_idx)); + const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast(stage_idx)); + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + const uint32_t sfa_id = (kGranKA == 32 ? k : sfa_stage_in_group_idx); + const uint32_t sfb_id = (kGranKB == 32 ? k : sfb_stage_in_group_idx); + const auto& runtime_instr_desc = make_runtime_instr_desc_with_sf_id(instr_desc, sfa_id, sfb_id); + + b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, 0, k * UMMA_K); + #pragma unroll + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + DG_STATIC_ASSERT((WAVE_BLOCK_M * BLOCK_K) % 128 == 0, "Invalid swizzling offset"); + a_desc.lo = advance_umma_desc_lo(a_desc_base_lo, w * WAVE_BLOCK_M * BLOCK_K, k * UMMA_K); + mma_t::fma(a_desc, b_desc, + accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N, + k_block_idx > 0 or k > 0, + runtime_instr_desc, + kTmemStartColOfSFA + w * (kNumUTCCPAlignedElems / 32), + kTmemStartColOfSFB); + } + } + } + + // Commit to the mbarrier object + // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` + empty_barrier_arrive(k_block_idx == num_total_k_blocks - 1); + } + } + + // To safely deconstruct barriers, we need another round of waits + const auto& iter_idx = scheduler.current_iter - 1; + if (kNumMulticast > 1 and iter_idx >= 0) { + const auto& accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1; + tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx); + } + } else if (warp_idx == 2) { + // UTCCP transposer + auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) { + DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements"); + uint32_t values[4]; + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + values[i] = ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx); + __syncwarp(); + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]); + }; + + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait TMA arrival + full_barriers[stage_idx]->wait(phase); + + // Transpose for UTCCP at certain stages + if (k_block_idx % kNumSFAStagesPerLoad == 0) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) + utccp_required_smem_warp_transpose(smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems); + // TODO: figure out whether the proxy fence is valid for 2-CTA cases + cutlass::arch::fence_view_async_shared(); + } + if (k_block_idx % kNumSFBStagesPerLoad == 0) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) + utccp_required_smem_warp_transpose(smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems); + // TODO: figure out whether the proxy fence is valid for 2-CTA cases + cutlass::arch::fence_view_async_shared(); + } + + // Arrive + with_sf_full_barriers[stage_idx]->arrive(0u); + } + } + } else if (warp_idx >= kNumNonEpilogueThreads / 32 and warp_idx < (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32) { + // Epilogue warp groups + const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32); + + // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, + // i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`. + // NOTES: we also forbid two CTAs to share the same SM and its tensor memory + DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + + // TMA checks + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t); + DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); + DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + + // Share store pipeline between blocks + uint32_t tma_stage_idx = 0; + auto advance_store_pipeline = [&]() { + tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages; + }; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + + // Wait UMMA arrival + tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx); + tcgen05_after_thread_sync(); + + // Load from tensor memory into registers, and write shared memory with STSM + DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough"); + DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); + + // Iterate over M waves + #pragma unroll + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + // Issue every swizzled atom and pipeline STSM and TMA store + constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N; + #pragma unroll + for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) { + // Wait shared memory to be released + if (epilogue_warp_idx == 0) + cute::tma_store_wait(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + + // The pipeline stage + const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M; + const auto n_idx = epilogue_type_t::apply_index_n(n_block_idx * BLOCK_N + s * STORE_BLOCK_N); + + // Store into shared memory + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)` + // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)` + // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern + constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (i) : (bank_group_index % 8); + col ^= row % (kSwizzleCDMode / 16); + + // Source and destination memory address + uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset + w * BLOCK_N + // Wave offset + s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset + auto smem_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + // Base pointer + epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + + // Load from tensor memory, store into shared memory + uint32_t values[kNumElemsPerBankGroup]; + if constexpr (cute::is_same_v) { + // For FP32 output, read and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cutlass::arch::fence_view_async_tmem_load(); + st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + } else { + // For BF16 output, read, cast and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, + values[0], values[1], values[2], values[3], + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + st_shared(smem_ptr, + cast_into_bf16_and_pack(values[0], values[1]), + cast_into_bf16_and_pack(values[2], values[3]), + cast_into_bf16_and_pack(values[4], values[5]), + cast_into_bf16_and_pack(values[6], values[7])); + } + } + + // Notify tensor memory empty (only at the leader CTA) arrival ASAP + // NOTES: only the last stage needs to do this + if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) { + tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + } + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { + if constexpr (kGemmType == GemmType::Batched) { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], + n_idx, m_idx, scheduler.current_group_idx); + } else { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], n_idx, m_idx); + } + cute::tma_store_arrive(); + } + } + } + } + + // Deallocate tensor memory by the last UMMA store warp + // NOTES: warp 0 is waiting TMA store + if (epilogue_warp_idx == kNumUMMAStoreThreads / 32 - 1) + Allocator().free(0, kNumTmemCols); + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh new file mode 100644 index 0000000000000000000000000000000000000000..180a308b3279b38827741942917a31e103b15b52 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh @@ -0,0 +1,404 @@ +#pragma once + +#include +#include + +#include +#include + +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; +using namespace deep_gemm::sm100; + +template +__global__ __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) +void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, + const uint32_t max_seqlen_k, const uint64_t stride_logits, + uint32_t* cu_seq_len_k_start, + uint32_t* cu_seq_len_k_end, + float* logits, + const __grid_constant__ cute::TmaDescriptor tensor_map_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, + const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { + // TODO: consider TMA multicast + // Normally, `h (kNumHeads) == 32` and `d (kHeadDim) == 64` + // For one block, we process `[q_start:q_end, h, d] @ [kv_start:kv_end, d] -> [q_start:q_end, kv_start:kv_end]` + // Q should be load only at once for a block + const auto& num_q_blocks = ceil_div(seq_len, BLOCK_Q); + + // Types + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const auto& warp_in_group_idx = warp_idx % 4; + const auto& warpgroup_idx = warp_idx / 4; + const auto& lane_idx = get_lane_idx(); + + // Prefetch TMA descriptors + DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_q); + cute::prefetch_tma_descriptor(&tensor_map_kv); + cute::prefetch_tma_descriptor(&tensor_map_kv_scales); + cute::prefetch_tma_descriptor(&tensor_map_weights); + } + __syncwarp(); + + // Shared memory configs + // NOTES: weight may be unaligned + static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * sizeof(float); + static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float); + static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, 512u); + + // Align to 512 bytes for swizzle-64B + extern __shared__ __align__(512) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % 512 == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_WEIGHT_SIZE_PER_STAGE % 512 == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % 512 == 0, "Unaligned TMA swizzling"); + + // TMA configs + constexpr uint32_t kNumTmemCols = BLOCK_Q * kNumHeads * kNumMathWarpGroups; + DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory"); + + // Data on shared memory + auto smem_q = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + + SMEM_Q_SIZE_PER_STAGE * i); + }); + auto smem_weights = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * i); + }); + auto smem_kv = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + ( + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i)); + }); + auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i); + }); + + // TMA barriers + auto barrier_ptr = reinterpret_cast(smem_kv_scales[kNumKVStages]); + auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); }); + auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); }); + auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); }); + auto full_umma_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + i); }); + auto empty_umma_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups + i); }); + + // Tensor memory allocation + auto tmem_ptr_in_smem = reinterpret_cast(barrier_ptr + kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups * 2); + + // Initialize barriers + DG_STATIC_ASSERT(kNumSpecializedThreads % 128 == 0 and kNumSpecializedThreads >= 64, "Invalid threads"); + const bool& is_tma_load_warp = (warp_idx == (kNumMathThreads / 32)); + const bool& is_umma_warp = (warp_idx == (kNumMathThreads / 32 + 1)); + if (is_tma_load_warp and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumQStages; ++ i) { + full_q_barriers[i]->init(1); + empty_q_barriers[i]->init(kNumMathThreads); + } + #pragma unroll + for (uint32_t i = 0; i < kNumKVStages; ++ i) { + full_kv_barriers[i]->init(1); + empty_kv_barriers[i]->init(kNumMathThreads); + } + #pragma unroll + for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { + full_umma_barriers[i]->init(1); + empty_umma_barriers[i]->init(128); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } else if (is_umma_warp) { + // Allocate tensor memory + cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumSpecializedRegisters = 24; + constexpr uint32_t kNumMathRegisters = 240; + + // Block scheduler + uint32_t block_q_idx = blockIdx.x, q_iter_idx = 0; + const auto& get_next_block_q_idx = [&]() -> cute::tuple { + return {block_q_idx + gridDim.x, q_iter_idx + 1}; + }; + uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q]; + const auto& load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple { + uint32_t start = cute::numeric_limits::max(); + uint32_t end = cute::numeric_limits::min(); + + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + const auto& q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1); + seq_k_start[i] = __ldg(cu_seq_len_k_start + q_idx); + seq_k_end[i] = __ldg(cu_seq_len_k_end + q_idx); + start = min(start, min(seq_k_start[i], seq_len_kv)); + end = max(end, min(seq_k_end[i], seq_len_kv)); + } + start = start / 4 * 4; + return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage + ((q_iter_idx + q_iter_offset) / kNumQStages) & 1, // Q pipeline phase + start, ceil_div(end - start, BLOCK_KV)}; // Task info + }; + + // KV pipeline + uint32_t num_total_kv_blocks = 0; + const auto& get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple { + return { + (num_total_kv_blocks + kv_block_idx) % kNumKVStages, // KV pipeline stage + ((num_total_kv_blocks + kv_block_idx) / kNumKVStages) & 1 // KV pipeline phase + }; + }; + + // UMMA settings + // Construct instruction with layout D + constexpr uint32_t UMMA_M = 128; + constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t); + constexpr uint32_t UMMA_N = BLOCK_Q * kNumHeads; + + if (is_tma_load_warp) { + cutlass::arch::warpgroup_reg_dealloc(); + + // Prefetch + const auto& issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) { + tma_copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads); + tma_copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q); + full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); + }; + if (cute::elect_one_sync() and block_q_idx < num_q_blocks) + issue_tma_q(0, block_q_idx); + + // Only the first lane persistently schedules over blocks + if (cute::elect_one_sync()) { + while (block_q_idx < num_q_blocks) { + CUTE_TIE_DECL(load_schedule(1), q_stage_idx, q_phase, kv_start, num_kv_blocks); + + // Wait Q consumer release + empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); + + // Issue TMA Q + if (const auto& next_block_q_idx = cute::get<0>(get_next_block_q_idx()); next_block_q_idx < num_q_blocks) + issue_tma_q(q_stage_idx, next_block_q_idx); + + // Issue TMA KV + #pragma unroll + for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) { + // Wait consumer release + CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase); + empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); + + // Issue TMA KV + tma_copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], + smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV); + tma_copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], + smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0); + full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); + } + num_total_kv_blocks += num_kv_blocks; + + // Jump to the next block + CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); + } + } + } else if (is_umma_warp) { + cutlass::arch::warpgroup_reg_dealloc(); + + // Require full allocation + DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + + // Make UMMA desc + auto instr_desc = cute::UMMA::make_instr_desc(); + auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + + while (block_q_idx < num_q_blocks) { + CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks); + + // Wait TMA Q arrival + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Compute over KV blocks + #pragma unroll + for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) { + // Compute `[BLOCK_Q * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [BLOCK_Q, BLOCK_KV]` + // Wait TMA KV arrival + CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Issue UMMA + DG_STATIC_ASSERT(BLOCK_KV == kNumMathThreads, "Invalid block size"); + DG_STATIC_ASSERT(kHeadDim % UMMA_K == 0, "Invalid head dim"); + #pragma unroll + for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { + empty_umma_barriers[i]->wait(((num_total_kv_blocks + kv_block_idx) & 1) ^ 1); + tcgen05_after_thread_sync(); + #pragma unroll + for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) { + auto a_desc = make_umma_desc( + smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K); + auto b_desc = make_umma_desc( + smem_q[q_stage_idx], 0, k * UMMA_K); + cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc); + } + cutlass::arch::umma_arrive(reinterpret_cast(full_umma_barriers[i])); + } + } + num_total_kv_blocks += num_kv_blocks; + + // Jump to the next block + CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); + } + } else if (warp_idx >= kNumMathThreads / 32) { + cutlass::arch::warpgroup_reg_dealloc(); + } else if (warp_idx < kNumMathThreads / 32) { + cutlass::arch::warpgroup_reg_alloc(); + + // Offsets + const auto& tmem_start = __shfl_sync(0xffffffff, warpgroup_idx * UMMA_N, 0); + const auto& warp_offset = warp_idx * 32; + const auto& v_offset = lane_idx; + + // Preload weights + constexpr uint32_t kNumWeightsInReg = cute::min(52, kNumHeads); + float weights[BLOCK_Q][kNumWeightsInReg]; + DG_STATIC_ASSERT(kNumWeightsInReg % 4 == 0, "Invalid number of weights in registers"); + + while (block_q_idx < num_q_blocks) { + CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks); + + // Wait TMA Q arrival + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Read weights + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + for (uint32_t j = 0; j < kNumWeightsInReg; ++ j) { + weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j); + } + } + + // Compute over KV blocks + #pragma unroll + for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) { + // Compute `[BLOCK_Q * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [BLOCK_Q, BLOCK_KV]` + // Wait TMA KV arrival + CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Read per-KV scales + float scale_kv = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_offset); + + // Wait UMMA arrival + full_umma_barriers[warpgroup_idx]->wait((num_total_kv_blocks + kv_block_idx) & 1); + tcgen05_after_thread_sync(); + + // Release KV empty + empty_kv_barriers[kv_stage_idx]->arrive(); + + // Reduce over the head dim and store + const auto& kv_offset = kv_start + kv_block_idx * BLOCK_KV + warp_offset; + static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2; + DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head"); + + constexpr uint32_t kNumLDTMElems = kNumHeads * BLOCK_Q; + DG_STATIC_ASSERT(kNumLDTMElems == 32 or kNumLDTMElems == 64 or kNumLDTMElems == 128, "Invalid kNumLDTMElems"); + uint32_t shifted_accum[kNumLDTMElems]; + auto tmem_load = [&](auto... Is) { + if constexpr (kNumLDTMElems == 32) { + cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, shifted_accum[Is]...); + } else if constexpr (kNumLDTMElems == 64) { + cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, shifted_accum[Is]...); + } else if constexpr (kNumLDTMElems == 128) { + cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, shifted_accum[Is]...); + } + }; + [&](cute::index_sequence) { tmem_load(Is...); }(cute::make_index_sequence{}); + cutlass::arch::fence_view_async_tmem_load(); + + tcgen05_before_thread_sync(); + empty_umma_barriers[warpgroup_idx]->arrive(); + + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + auto accum = reinterpret_cast(shifted_accum + i * kNumHeads); + + auto sum_0 = make_float2(0, 0); + auto sum_1 = make_float2(0, 0); + + const auto& transform_reg = [&](const uint32_t& j, const float2& sum) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(weights[i][j], weights[i][j + 1]); + return __ffma2_rn(a, b, sum); + }; + + #pragma unroll + for (int j = 0; j < kNumWeightsInReg; j += 4) { + sum_0 = transform_reg(j, sum_0); + sum_1 = transform_reg(j + 2, sum_1); + } + + const auto& transform_smem = [&](const uint32_t& j, const float2& sum) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j), + ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j + 1)); + return __ffma2_rn(a, b, sum); + }; + + #pragma unroll + for (int j = kNumWeightsInReg; j < kNumHeads; j += 4) { + sum_0 = transform_smem(j, sum_0); + sum_1 = transform_smem(j + 2, sum_1); + } + + auto sum = __fadd2_rn(sum_0, sum_1); + float result = scale_kv * (sum.x + sum.y); + + // Store into the global memory + // NOTES: we have redundant writes here, consider more carefully + const uint32_t& q_idx = block_q_idx * BLOCK_Q + i; + if constexpr (kIsCompressedLogits) { + if (seq_k_start[i] <= kv_offset + v_offset and kv_offset + v_offset < seq_k_end[i]) + logits[q_idx * stride_logits + kv_offset + v_offset - seq_k_start[i]] = result; + } else { + logits[q_idx * stride_logits + kv_offset + v_offset] = result; + } + } + } + num_total_kv_blocks += num_kv_blocks; + + // Release Q empty + empty_q_barriers[q_stage_idx]->arrive(); + + // Jump to the next block + CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); + } + } + + // Free tensor memory + __syncthreads(); + if (is_tma_load_warp) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); +} + +} // namespace deep_gemm diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh new file mode 100644 index 0000000000000000000000000000000000000000..7058c40f4f195de94184d3e7ebc6f9aa2eb3670f --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh @@ -0,0 +1,398 @@ +#pragma once + +#include +#include + +#include +#include + +#include +#include +#include + +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; +using namespace deep_gemm::sm100; + +template +__global__ __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) +void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, + const uint64_t logits_stride, const uint64_t block_table_stride, + const uint32_t* context_lens, float* logits, + const uint32_t* block_table, const uint32_t* schedule_meta, + const __grid_constant__ cute::TmaDescriptor tensor_map_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, + const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const auto& warpgroup_idx = warp_idx / 4; + const auto& lane_idx = get_lane_idx(); + + // Prefetch TMA descriptors + DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_q); + cute::prefetch_tma_descriptor(&tensor_map_kv); + cute::prefetch_tma_descriptor(&tensor_map_kv_scales); + cute::prefetch_tma_descriptor(&tensor_map_weights); + } + __syncwarp(); + + // Shared memory configs + static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8; + static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = SPLIT_KV * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = SPLIT_KV * sizeof(float); + static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float); + + // Align to swizzling alignment bytes + extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + + // Q and KV data on shared memory + auto smem_q = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i); + }); + auto smem_kv = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i); + }); + constexpr auto smem_offset = SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages; + auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * i); + }); + auto smem_weights = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i); + }); + + // Barriers and TMEM pointer on shared memory + const auto barrier_ptr = reinterpret_cast(smem_weights[kNumQStages]); + auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; }); + auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; }); + auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; }); + const auto umma_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2; + auto full_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + i; }); + auto empty_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + kNumMathWarpGroups + i; }); + auto tmem_ptr_in_smem = reinterpret_cast(umma_barrier_ptr + kNumMathWarpGroups * 2); + + constexpr uint32_t kNumTmemCols = kNextN * kNumHeads * kNumMathWarpGroups; + DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory"); + const bool& is_math_warp = (warp_idx < kNumMathWarpGroups * 4); + const bool& is_tma_load_warp = (warp_idx == kNumMathWarpGroups * 4); + const bool& is_umma_warp = (warp_idx == kNumMathWarpGroups * 4 + 1); + + // Initialize barriers + if (is_tma_load_warp and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumQStages; ++ i) { + full_q_barriers[i]->init(1); + empty_q_barriers[i]->init(kNumMathThreads); + } + #pragma unroll + for (uint32_t i = 0; i < kNumKVStages; ++ i) { + full_kv_barriers[i]->init(1); + empty_kv_barriers[i]->init(kNumMathThreads); + } + cutlass::arch::fence_barrier_init(); + } + if (is_umma_warp) { + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumMathWarpGroups; ++i) { + full_umma_barriers[i]->init(1); + empty_umma_barriers[i]->init(128); + } + cutlass::arch::fence_barrier_init(); + } + // Allocate tensor memory + cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumSpecializedRegisters = 40; + constexpr uint32_t kNumMathRegisters = 232; + + // Scheduler + constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV; + auto scheduler = PagedMQALogitsScheduler(batch_size, blockIdx.x, context_lens, schedule_meta); + DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`"); + + // Q and KV pipeline + const auto& get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple { + return {q_iter_idx % kNumQStages, (q_iter_idx / kNumQStages) & 1}; // Q pipeline stage and phase + }; + const auto& get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple { + return {kv_iter_idx % kNumKVStages, (kv_iter_idx / kNumKVStages) & 1}; // KV pipeline stage and phase + }; + uint32_t q_iter_idx = 0, kv_iter_idx = 0; + + // UMMA settings + // Construct instruction with layout D + constexpr uint32_t UMMA_M = 128; + constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t); + constexpr uint32_t UMMA_N = kNextN * kNumHeads; + DG_STATIC_ASSERT(SPLIT_KV == UMMA_M * kNumMathWarpGroups, "Invalid `SPLIT_KV`"); + + if (is_tma_load_warp) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + const auto& issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) { + if (cute::elect_one_sync()) { + tma_copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads); + tma_copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_idx); + full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); + } + }; + + // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none + uint32_t q_idx = batch_size, kv_idx, num_kv; + uint32_t next_q_idx, next_kv_idx, next_num_kv; + bool fetched_next_task; + + // Prefetch the first Q + if ((fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv))) + issue_tma_q(0, next_q_idx), q_iter_idx = 1; + + int kv_block_idx_ptr = 32; + uint32_t kv_block_idx_storage; + + while (fetched_next_task) { + // Prefetch next Q when current Q changes + bool prefetch_q = (q_idx != next_q_idx and scheduler.exist_q_idx(next_q_idx + 1)); + q_idx = next_q_idx; + kv_idx = next_kv_idx; + num_kv = next_num_kv; + + // Read KV block index + // TODO: deal with `-1`? + if (kv_idx == 0 or kv_block_idx_ptr == 32) { + kv_block_idx_ptr = 0; + kv_block_idx_storage = (kv_idx + lane_idx < num_kv ? __ldg(block_table + q_idx * block_table_stride + (kv_idx + lane_idx)) : 0); + } + DG_STATIC_ASSERT(32 % kNumBlocksPerSplit == 0, "Invalid `UMMA_M`"); + + // Wait Q consumer release and issue TMA Q + if (prefetch_q) { + CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); + empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); + issue_tma_q(q_stage_idx, q_idx + 1); + } + + int kv_block_idx[kNumBlocksPerSplit]; + #pragma unroll + for (int i = 0; i < kNumBlocksPerSplit; ++ i) + kv_block_idx[i] = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr + i); + kv_block_idx_ptr += kNumBlocksPerSplit; + + // Wait KV consumer release + CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); + empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); + + if (cute::elect_one_sync()) { + #pragma unroll + for (int i = 0; i < kNumBlocksPerSplit; ++ i) { + tma_copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], + smem_kv[kv_stage_idx] + (BLOCK_KV * kHeadDim) * i, + 0, 0, 1, kv_block_idx[i]); + tma_copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], + smem_kv_scales[kv_stage_idx] + BLOCK_KV * i, + 0, kv_block_idx[i]); + } + full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); + } + + // Fetch next task + fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv); + } + } else if (is_umma_warp) { + cutlass::arch::warpgroup_reg_dealloc(); + + // Require full allocation + DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + + // Make UMMA desc + auto instr_desc = cute::UMMA::make_instr_desc(); + auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + + uint32_t q_idx = batch_size, kv_idx; + uint32_t next_q_idx, next_kv_idx, next_num_kv; + uint32_t q_stage_idx, q_phase; + uint32_t umma_phase = 1; + + while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) { + if (q_idx != next_q_idx) { + CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); + full_q_barriers[q_stage_idx]->wait(q_phase); + } + + q_idx = next_q_idx; + kv_idx = next_kv_idx; + + CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + DG_STATIC_ASSERT(kHeadDim % UMMA_K == 0, "Invalid head dim"); + #pragma unroll + for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { + empty_umma_barriers[i]->wait(umma_phase); + tcgen05_after_thread_sync(); + #pragma unroll + for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) { + auto a_desc = make_umma_desc( + smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K); + auto b_desc = make_umma_desc( + smem_q[q_stage_idx], 0, k * UMMA_K); + cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc); + } + cutlass::arch::umma_arrive(reinterpret_cast(full_umma_barriers[i])); + } + umma_phase ^= 1; + } + } else if (is_math_warp) { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // Offsets + const auto& tmem_start = __shfl_sync(0xffffffff, warpgroup_idx * UMMA_N, 0); + const uint32_t thread_idx = threadIdx.x; + + // Weights + constexpr uint32_t kNumWeightsInReg = (kNextN == 1 ? kNumHeads : cute::min(48, kNumHeads)); + float weights[kNextN][kNumWeightsInReg]; + DG_STATIC_ASSERT(kNumWeightsInReg % 4 == 0, "Invalid number of weights in registers"); + + // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none + uint32_t q_idx = batch_size, kv_idx; + uint32_t next_q_idx, next_kv_idx, next_num_kv; + uint32_t q_stage_idx, q_phase; + uint32_t umma_phase = 0; + + while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) { + // Current Q changes + if (q_idx != next_q_idx) { + // Release Last Q empty + if (q_iter_idx > 0) + empty_q_barriers[(q_iter_idx - 1) % kNumQStages]->arrive(); + + // Wait TMA Q arrival + CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Read weights + #pragma unroll + for (uint32_t i = 0; i < kNextN; ++ i) { + for (uint32_t j = 0; j < kNumWeightsInReg; ++ j) + weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j); + } + } + + // Get current Q and KV index + q_idx = next_q_idx; + kv_idx = next_kv_idx; + + // Calculate KV offset in advance + auto kv_offset = q_idx * kNextN * logits_stride + kv_idx * BLOCK_KV; + + // Compute `[kNextN * kNumHeads, kHeadDim] @ [SPLIT_KV, kHeadDim] -> [kNextN, SPLIT_KV]` + // Wait TMA KV arrival + CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Read per-KV scales + float scale_kv = ld_shared(smem_kv_scales[kv_stage_idx] + thread_idx); + + // Wait UMMA arrival + full_umma_barriers[warpgroup_idx]->wait(umma_phase); + tcgen05_after_thread_sync(); + umma_phase ^= 1; + + // Release KV empty + empty_kv_barriers[kv_stage_idx]->arrive(); + + // Reduce over the head dim and store + DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head"); + constexpr uint32_t kNumLDTMElems = kNumHeads * kNextN; + uint32_t shifted_accum[kNumLDTMElems]; + DG_STATIC_ASSERT(kNumLDTMElems == 32 or kNumLDTMElems == 64 or kNumLDTMElems == 128, "Invalid LDTM"); + auto tmem_load = [&](auto... Is) { + if constexpr (kNumLDTMElems == 32) { + cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, shifted_accum[Is]...); + } else if constexpr (kNumLDTMElems == 64) { + cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, shifted_accum[Is]...); + } else if constexpr (kNumLDTMElems == 128) { + cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, shifted_accum[Is]...); + } + }; + [&](cute::index_sequence) { tmem_load(Is...); }(cute::make_index_sequence{}); + cutlass::arch::fence_view_async_tmem_load(); + + tcgen05_before_thread_sync(); + empty_umma_barriers[warpgroup_idx]->arrive(); + + #pragma unroll + for (uint32_t i = 0; i < kNextN; ++ i) { + auto accum = reinterpret_cast(shifted_accum + i * kNumHeads); + + auto sum_0 = make_float2(0, 0); + auto sum_1 = make_float2(0, 0); + + const auto& transform_reg = [&](const uint32_t& j, const float2& sum) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(weights[i][j], weights[i][j + 1]); + return __ffma2_rn(a, b, sum); + }; + + #pragma unroll + for (int j = 0; j < kNumWeightsInReg; j += 4) { + sum_0 = transform_reg(j, sum_0); + sum_1 = transform_reg(j + 2, sum_1); + } + + const auto& transform_smem = [&](const uint32_t& j, const float2& sum) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j), + ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j + 1)); + return __ffma2_rn(a, b, sum); + }; + + #pragma unroll + for (int j = kNumWeightsInReg; j < kNumHeads; j += 4) { + sum_0 = transform_smem(j, sum_0); + sum_1 = transform_smem(j + 2, sum_1); + } + + auto sum = __fadd2_rn(sum_0, sum_1); + float result = scale_kv * (sum.x + sum.y); + + // Store into the global memory + // NOTES: we have redundant writes here, consider more carefully + logits[kv_offset + i * logits_stride + thread_idx] = result; + } + } + } else { + cutlass::arch::warpgroup_reg_dealloc(); + } + + // Free tensor memory + __syncthreads(); + if (is_umma_warp) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); +} + +} // namespace deep_gemm diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh new file mode 100644 index 0000000000000000000000000000000000000000..4e4ff21d0746cff7bc7ecaf23a49278a2f5810cc --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh @@ -0,0 +1,345 @@ +#pragma once +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include + +#include +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm100; + +template +__device__ __forceinline__ +uint32_t get_swizzled_smem_offset(const uint32_t& offset, const uint32_t& lane_idx) { + // Calculate the index of the bank group to be written in the atom + const auto& bank_group_idx = offset + lane_idx * (kSwizzleMode / kSwizzleBase); + + // Reshape the atom in another view and swizzle + // - original: `(BLOCK_N, kSwizzleMode / kSwizzleBase)` + // - new: `(BLOCK_N * kSwizzleMode / kSwizzleBase / kNumBankGroups, kNumBankGroups)` + constexpr uint32_t kNumBankGroups = 128 / kSwizzleBase; + constexpr bool kHasShortcut = (kSwizzleMode / kSwizzleBase) == kNumBankGroups; + auto row = kHasShortcut ? (offset / kNumBankGroups + lane_idx) : (bank_group_idx / kNumBankGroups); + auto col = kHasShortcut ? (offset) : (bank_group_idx % kNumBankGroups); + col ^= row % (kSwizzleMode / kSwizzleBase); + + return row * 128 + col * kSwizzleBase; +} + +template +__global__ void __launch_bounds__(kNumMMAThreads + kNumCastAndReduceThreads, 1) +sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_d, + float* sqr_sum) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Configs + constexpr uint32_t kNumCastStages = 2; + constexpr uint32_t kSwizzleAMode = cute::min(BLOCK_K * sizeof(nv_bfloat16), 128); + constexpr uint32_t kSwizzleBMode = cute::min(BLOCK_K * sizeof(float), 128); + constexpr auto kMajorA = cute::UMMA::Major::K; + constexpr auto kMajorB = cute::UMMA::Major::K; + DG_STATIC_ASSERT(kNumCastStages <= kNumStages, "Invalid cast stages"); + DG_STATIC_ASSERT(kSwizzleCDMode / sizeof(float) == BLOCK_N, "Invalid block N"); + DG_STATIC_ASSERT(kNumMMAThreads == 128, "Invalid MMA threads"); + + // Utils + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = get_lane_idx(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // Share memory sizes + constexpr uint32_t SMEM_CD_SIZE = BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(nv_bfloat16); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(float); + DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + // Real tensor memory size and offsets + constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_d); + } + + // Data on shared memory (layout as ordered below) + // Fill D/A/B pointers + auto smem_cd = reinterpret_cast(smem_buffer); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE)); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto full_cast_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); + auto empty_cast_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); }); + auto tmem_full_barrier = barrier_start_ptr + kNumStages * 4; + + // Fill the tensor memory pointer + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 4 + 1); + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + full_cast_barriers[i]->init(kNumCastAndReduceThreads); + empty_barriers[i]->init(1); + empty_cast_barriers[i]->init(1); + } + tmem_full_barrier->init(1); + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 2) { + // Allocate tensor memory + cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + __syncthreads(); + + constexpr uint32_t kNumKBlocks = constexpr_ceil_div(SHAPE_K, BLOCK_K); + constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits; + constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits; + const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0); + const uint32_t m_block_idx = block_idx / kNumSplits; + const uint32_t k_split_idx = block_idx % kNumSplits; + const uint32_t k_offset = (k_split_idx * kNumKBlocksPerSplit + cute::min(k_split_idx, kRemainKBlocks)) * BLOCK_K; + const uint32_t m_offset = shape_m * k_split_idx; + const uint32_t num_total_stages = kNumKBlocksPerSplit + (k_split_idx < kRemainKBlocks); + + // Dispatch warps into different roles + if (warp_idx < kNumMMAThreads / 32) { + // TMA load warp + if (warp_idx == 0 and cute::elect_one_sync()) { + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait consumer release + const auto& stage_idx = s % kNumStages; + empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1); + + // Compute offsets + uint32_t m_idx = m_block_idx * BLOCK_M; + uint32_t k_idx = k_offset + s * BLOCK_K; + + // Issue TMAs + tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx); + tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0); + + // Arrive at full barriers + constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; + full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes); + } + } + + // MMA issue warp + if (warp_idx == 1) { + // Make instruction descriptor + constexpr uint32_t UMMA_M = BLOCK_M; + constexpr uint32_t UMMA_N = BLOCK_N; + constexpr uint32_t UMMA_K = 32 / sizeof(float); + constexpr uint32_t BLOCK_SWIZZLED_BK = kSwizzleBMode / sizeof(float); + using umma_t = cute::SM100_MMA_TF32_TS; + auto instr_desc = cute::UMMA::make_instr_desc(); + const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + + DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); + auto b_desc = make_umma_desc(smem_b[0], 0, 0); + const uint32_t& b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + // Checks for MMA instructions + // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Launch MMAs + // We can not unroll this part + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait TMA arrival + const auto& stage_idx = s % kNumStages; + const auto& cast_stage_idx = s % kNumCastStages; + full_cast_barriers[cast_stage_idx]->wait((s / kNumCastStages) & 1); + tcgen05_after_thread_sync(); + + // Issue UMMA + const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast(stage_idx)); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + const uint32_t& atom_idx = (k * UMMA_K) / BLOCK_SWIZZLED_BK; + const uint32_t& in_atom_idx = (k * UMMA_K) % BLOCK_SWIZZLED_BK; + const uint32_t& offset = atom_idx * BLOCK_N * BLOCK_SWIZZLED_BK; + b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, offset, in_atom_idx); + umma_t::fma(BLOCK_K * cast_stage_idx + k * UMMA_K, b_desc, BLOCK_K * kNumCastStages, s > 0 or k > 0, runtime_instr_desc); + } + + // Commit + cutlass::arch::umma_arrive(reinterpret_cast(empty_cast_barriers[cast_stage_idx])); + cutlass::arch::umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + } + + // Commit to epilogue threads + cutlass::arch::umma_arrive(reinterpret_cast(tmem_full_barrier)); + } + + // TMA checks + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(float); + DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); + DG_STATIC_ASSERT(BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + + // Only support layout F (M = 64) and D (M = 128) + DG_STATIC_ASSERT(BLOCK_M == 64 or BLOCK_M == 128, "Invalid block M"); + + // Wait UMMA arrival + tmem_full_barrier->wait(0); + tcgen05_after_thread_sync(); + + // Load from tensor memory into registers, and write shared memory with STSM + DG_STATIC_ASSERT(kNumMMAThreads == 128, "Epilogue threads not enough"); + + // Store into shared memory + #pragma unroll + for (uint32_t i = 0; i < BLOCK_N / kNumElemsPerBankGroup; ++ i) { + // Source and destination memory address + uint32_t tmem_addr = BLOCK_K * kNumCastStages + i * kNumElemsPerBankGroup; + auto smem_ptr = reinterpret_cast(smem_cd) + // Base pointer + warp_idx * BLOCK_M / 4 * kSwizzleCDMode + // Warp offset + get_swizzled_smem_offset(i, lane_idx); // In-atom offset + + // Load from tensor memory, store into shared memory + uint32_t values[kNumElemsPerBankGroup]; + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cutlass::arch::fence_view_async_tmem_load(); + if (BLOCK_M == 128 or (BLOCK_M == 64 and lane_idx < 16)) + st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + if constexpr (BLOCK_M == 64) + __syncwarp(); + } + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumMMAThreads, 0); + if (warp_idx == 0 and cute::elect_one_sync()) { + if constexpr (kNumSplits == 1) { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M); + } else { + cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M, k_split_idx); + } + cute::tma_store_arrive(); + } + + // Deallocate tensor memory by warp 1 + // NOTES: warp 0 is waiting TMA store + if (warp_idx == 1) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + } else { + DG_STATIC_ASSERT(BLOCK_M == 64, "Invalid block M"); + DG_STATIC_ASSERT(kNumCastAndReduceThreads == 128, "Invalid cast-and-reduce threads"); + constexpr uint32_t BLOCK_M_PER_WARP = BLOCK_M / 4; + const uint32_t sub_warp_idx = warp_idx - kNumMMAThreads / 32; + + // TODO: make even larger block K + DG_STATIC_ASSERT(BLOCK_K * sizeof(nv_bfloat16) == kSwizzleAMode, "Invalid block K"); + + // Launch reductions + float2 sum[2] = {float2{0, 0}, float2{0, 0}}; + #pragma unroll kNumStages + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait TMA arrival + const auto& stage_idx = s % kNumStages; + full_barriers[stage_idx]->wait((s / kNumStages) & 1); + + // Load from shared memory into tensor memory using movement shape `.16x256b` (shared memory part is 128b) + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(nv_bfloat16); + constexpr uint32_t kNumLoads = BLOCK_K / kNumElemsPerBankGroup; + const auto& smem_base_ptr = reinterpret_cast(smem_a[stage_idx]) + // Base pointer + sub_warp_idx * BLOCK_M_PER_WARP * kSwizzleAMode; // Warp offset + + // 4 lanes shared a bank group + uint32_t uint32_values[2][kNumLoads]; + DG_STATIC_ASSERT(kNumLoads % 2 == 0, "Invalid number of loads"); + #pragma unroll + for (uint32_t i = 0; i < kNumLoads; i += 2) { + auto smem_ptr = smem_base_ptr + get_swizzled_smem_offset(i + lane_idx / 16, lane_idx % 16); + sm90::SM90_U32x4_LDSM_N::copy(uint32_values[0][i + 0], uint32_values[1][i + 0], + uint32_values[0][i + 1], uint32_values[1][i + 1], + smem_ptr); + } + + // Wait tensor memory empty + const auto& cast_stage_idx = s % kNumCastStages; + empty_cast_barriers[cast_stage_idx]->wait(((s / kNumCastStages) & 1) ^ 1); + + // Cast, reduce and store into tensor memory + float2 fp32x2_values[2][kNumLoads]; + const auto& upper_view = reinterpret_cast(&fp32x2_values[0]); + const auto& lower_view = reinterpret_cast(&fp32x2_values[1]); + #pragma unroll + for (uint32_t i = 0; i < kNumLoads; ++ i) { + #pragma unroll + for (uint32_t u = 0; u < 2; ++ u) { + fp32x2_values[u][i] = __bfloat1622float2(*reinterpret_cast(&uint32_values[u][i])); + sum[u] = __ffma2_rn(fp32x2_values[u][i], fp32x2_values[u][i], sum[u]); + } + + // Store upper and lower part at the same time + const auto idx_0 = i * 2, idx_1 = i * 2 + 1; + cute::SM100_TMEM_STORE_16dp256b1x::copy( + upper_view[idx_0], upper_view[idx_1], + lower_view[idx_0], lower_view[idx_1], + cast_stage_idx * BLOCK_K + i * 8); + } + cutlass::arch::fence_view_async_tmem_store(); + + // Arrive for issuing MMAs + tcgen05_before_thread_sync(); + full_cast_barriers[cast_stage_idx]->arrive(); + } + + // Intra-warp reduction and write back + #pragma unroll + for (uint32_t u = 0; u < 2; ++ u) { + const auto& reduced_sum = warp_reduce_sum<4>(sum[u].x + sum[u].y); + const auto& m_idx = m_block_idx * BLOCK_M + sub_warp_idx * BLOCK_M_PER_WARP + lane_idx / 4 + u * 8; + if (lane_idx % 4 == 0 and m_idx < shape_m) + sqr_sum[m_offset + m_idx] = reduced_sum; + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); +#endif +} + +} // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm90_bf16_gemm.cuh b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm90_bf16_gemm.cuh new file mode 100644 index 0000000000000000000000000000000000000000..7a77e4e8fbbbffa56e8c8632ade7ae7938b30ee9 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm90_bf16_gemm.cuh @@ -0,0 +1,381 @@ +#pragma once + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include +#include + +#include +#include +#include +#include + +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; + +template +__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +sm90_bf16_gemm_impl(int* grouped_layout, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_cd) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + // Enlarge `BLOCK_K` for some cases + // NOTES: this is for reducing the `warpgroup_wait<0>()` overhead + constexpr uint32_t kDoMergeStages = + kNumStages_ >= 10 and + kGemmType == GemmType::Normal and + kMajorA == cute::UMMA::Major::K and kMajorB == cute::UMMA::Major::K and + kNumMathThreads == 128; + // Ensure there are at least `kNumMinStages` stages after merge + constexpr uint32_t kNumMinStages = 5; + constexpr uint32_t kNumStagesPerMerge = kDoMergeStages ? kNumStages_ / kNumMinStages : 1; + constexpr uint32_t BLOCK_K = BLOCK_K_ * kNumStagesPerMerge; + constexpr uint32_t kNumStages = kNumStages_ / kNumStagesPerMerge; + + // Types + using WGMMA = typename BF16MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0 or BLOCK_M < WGMMA::M, "Invalid block size"); + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + + // Shared memory + static constexpr uint32_t SMEM_D_SIZE = constexpr_align(BLOCK_M * BLOCK_N * static_cast(sizeof(cd_dtype_t)), 1024u); + static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16); + static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16); + + // NOTES: Make sure we have enough shared memory for WGMMA padding + static constexpr uint32_t WGMMA_A_SIZE_PER_STAGE = WGMMA::M * BLOCK_K * sizeof(__nv_fp8_e4m3); + DG_STATIC_ASSERT(WGMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for WGMMA"); + + // Configs + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const uint32_t lane_idx = get_lane_idx(); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_cd); + } + __syncwarp(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0 and SMEM_A_SIZE_PER_STAGE % 1024 == 0 and SMEM_B_SIZE_PER_STAGE % 1024 == 0, + "Shared memory of A/B/D must be aligned to 1024 bytes"); + + // D/A/B shared memory + auto smem_d = reinterpret_cast(smem_buffer); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + + // Initialize barriers + if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + + // Synchronize all threads to make barrier visible in normal memory model + (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumTMARegisters = 48; + constexpr uint32_t kNumMathRegisters = kNumMathThreads == 128 ? 248 : 224; + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + + // Pipeline and TMA phases + uint32_t stage_idx = 0, phase = 0; + auto advance_pipeline = [&](uint32_t& k_block_idx) { + ++ k_block_idx; + + // Flip phases only if reach the next first stage + stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1; + phase ^= stage_idx == 0; + }; + + if (warp_idx >= kNumMathThreads / 32) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + // We use the third warp, as warp 0/1 may be doing WGMMA with `BLOCK_M == 32` + if (warp_idx == kNumMathThreads / 32 + 2 and cute::elect_one_sync()) { + DG_STATIC_ASSERT(kNumTMAThreads >= 128, "Need at least 128 threads for TMA warp-group"); + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Assign TMA multicast number into A and B + // NOTES: there may be additional odd rows/columns or cases where multicast is not possible. + const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx); + const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); + + const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked; + auto& full_barrier = *full_barriers[stage_idx]; + + const auto m_idx = scheduler.template get_global_idx(shape_m, BLOCK_M, m_block_idx); + const auto n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), IndexType::MN>(shape_n, BLOCK_N, n_block_idx, m_block_idx); + + DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major"); + uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + + // Issue TMAs + constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched); + const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0); + if constexpr (kMajorA == cute::UMMA::Major::K) + tma_copy( + &tensor_map_a, &full_barrier, smem_a[stage_idx], k_a_idx, m_idx, num_tma_multicast_a, batch_idx); + if constexpr (kMajorA == cute::UMMA::Major::MN) + tma_copy( + &tensor_map_a, &full_barrier, smem_a[stage_idx], m_idx, k_a_idx, num_tma_multicast_a, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::K) + tma_copy( + &tensor_map_b, &full_barrier, smem_b[stage_idx], k_b_idx, n_idx, num_tma_multicast_b, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::MN) + tma_copy( + &tensor_map_b, &full_barrier, smem_b[stage_idx], n_idx, k_b_idx, num_tma_multicast_b, batch_idx); + + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + } + } + + // To safely deconstruct distributed shared barriers, we need another round of empty waits + if constexpr (kNumTMAMulticast > 1) { + for (uint32_t i = 0; i < kNumStages; advance_pipeline(i)) + empty_barriers[stage_idx]->wait(phase ^ 1); + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); + + // Merged stages only happens in NT normal GEMM cases + constexpr uint32_t BLOCK_ATOM_K = BLOCK_K / kNumStagesPerMerge; + auto a_desc = make_gmma_desc(smem_a[0], math_wg_idx * WGMMA::M, 0); + auto b_desc = make_gmma_desc(smem_b[0], 0, 0); + const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0); + const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0); + + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + constexpr uint32_t WAVE_BLOCK_M = BLOCK_M <= WGMMA::M ? BLOCK_M : WGMMA::M * 2; + DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes"); + float accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0}; + + // Pick threads whose WGMMA results are to be stored in shared memory + DG_STATIC_ASSERT(BLOCK_M >= 64 or kNumMathThreads == 128, "Only one math warp group for `BLOCK_M < 64`"); + constexpr uint32_t kNumWGMMAStoreThreads = WAVE_BLOCK_M * (128 / WGMMA::M); + const bool do_wgmma_store = BLOCK_M >= 64 or warp_idx < kNumWGMMAStoreThreads / 32; + + // Empty barrier arrival + auto empty_barrier_arrive = [&](uint32_t s) { + if constexpr (kNumTMAMulticast == 1) { + lane_idx == 0 ? empty_barriers[s]->arrive() : void(); + } else { + auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster(); + lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void(); + } + }; + + // TODO: remove some useless computation for unaligned Ms + const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + const auto& a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16); + const auto& b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16); + + // Wait TMA arrivals + full_barriers[stage_idx]->wait(phase); + + // Commit WGMMA instructions + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M); ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto shifted_accum = accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + const uint32_t& atom_k_idx = k * WGMMA::K / BLOCK_ATOM_K; + a_desc.reg32_[0] = advance_gmma_desc_lo( + a_desc_base_lo, local_idx * WAVE_BLOCK_M, (k * WGMMA::K) % BLOCK_ATOM_K, atom_k_idx * BLOCK_M * BLOCK_ATOM_K); + b_desc.reg32_[0] = advance_gmma_desc_lo( + b_desc_base_lo, 0, (k * WGMMA::K) % BLOCK_ATOM_K, atom_k_idx * BLOCK_N * BLOCK_ATOM_K); + WGMMA::wgmma(a_desc, b_desc, shifted_accum, 1); + } + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M); ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival + empty_barrier_arrive(stage_idx); + } + + // TMA checks + constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16); + constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? BLOCK_N : (kSwizzleDMode / kNumElemBytes); + constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4; + DG_STATIC_ASSERT(BLOCK_M % 8 == 0, "Invalid swizzling atom"); + DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32, + "Unaligned TMA store or too many TMA store instructions"); + DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N"); + + // Skip WGMMA store for the unfilled parts + if (not do_wgmma_store) + continue; + + // Wait last TMA store to be finished + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) + cute::tma_store_wait<0>(); + cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 0); + + if constexpr (cute::is_same_v) { + // Write back to shared memory using STSM and issue TMA stores + DG_STATIC_ASSERT(kSwizzleDMode > 0, "Invalid swizzling type"); + DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + auto shifted_accum = accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + // Swizzle or padding into the correct address + uint8_t* smem_ptr = nullptr; + if constexpr (kSwizzleDMode > 0) { + // Calculate the swizzling atom offset and in-atom offset + constexpr uint32_t kNumBankGroupBytes = 16; + auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8); + + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)` + // - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)` + constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (in_atom_offset) : (bank_group_index % 8); + col ^= row % (kSwizzleDMode / 16); + + // Add back into the base pointer + // NOTES: think twice before modifying this, as changes may affect the number of instructions + smem_ptr = reinterpret_cast(smem_d) + // Base pointer + warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset + m_offset * kSwizzleDMode + // Wave offset + atom_offset * BLOCK_M * kSwizzleDMode + // Swizzle atom offset (constants) + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + } else { + // No swizzling + smem_ptr = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * BLOCK_N + i * 8); + } + + // NOTES: only 16 lanes' addresses are used + SM90_U32x2_STSM_N::copy( + __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}), + __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}), + smem_ptr + ); + } + } + } else { + // Use `st.shared` if STSM is not available + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + auto shifted_accum = accum + WGMMA::kNumAccum * local_idx; + auto smem_d_0 = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx / 4 + 0) * BLOCK_N + (lane_idx % 4) * 2); + auto smem_d_1 = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx / 4 + 8) * BLOCK_N + (lane_idx % 4) * 2); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + st_shared(smem_d_0 + i * 4, make_float2(shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1])); + st_shared(smem_d_1 + i * 4, make_float2(shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3])); + } + } + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 0); + + // Use TMA store to write back to global memory + const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx); + DG_STATIC_ASSERT(kNumWGMMAStoreThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks"); + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { + auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N; + auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M; + if constexpr (kGemmType == GemmType::Batched) { + cute::SM90_TMA_STORE_3D::copy(&tensor_map_cd, smem_ptr, + n_block_idx * BLOCK_N + in_block_n_offset, + m_idx, scheduler.current_group_idx); + } else { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_ptr, + n_block_idx * BLOCK_N + in_block_n_offset, m_idx); + } + cute::tma_store_arrive(); + } + __syncwarp(); + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh new file mode 100644 index 0000000000000000000000000000000000000000..191a4fe2c4ccf66b0743affedcbfd17950e2618f --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh @@ -0,0 +1,174 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; + +template +__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + float *d) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + // Types + using WGMMA = typename BF16MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size"); + + // Shared memory + static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16); + static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16); + + // Configs + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const uint32_t lane_idx = get_lane_idx(); + DG_STATIC_ASSERT(BLOCK_M == 128, "Invalid block M"); + DG_STATIC_ASSERT(kNumTMAThreads == 128, "Invalid number of TMA threads"); + DG_STATIC_ASSERT(kNumMathThreads == 256, "Invalid number of math threads"); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + } + __syncwarp(); + + // Align to 1024 bytes for swizzle-128B + // Fill shared memory pointers + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (i * SMEM_A_SIZE_PER_STAGE)); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumMathThreads); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + + // Synchronize all threads to make barrier visible in normal memory model + __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumTMARegisters = 40; + constexpr uint32_t kNumMathRegisters = 232; + + // Block indices + const uint32_t num_n_blocks = ceil_div(SHAPE_N, BLOCK_N); + const uint32_t num_mn_blocks = num_n_blocks * ceil_div(SHAPE_M, BLOCK_M); + const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks; + const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks; + const uint32_t n_block_idx = mn_block_idx % num_n_blocks; + const uint32_t m_block_idx = mn_block_idx / num_n_blocks; + const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor); + + if (warp_idx >= kNumMathThreads / 32) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + // Persistently schedule over blocks + #pragma unroll + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait consumer release + const auto& stage_idx = s % kNumStages; + empty_barriers[stage_idx]->wait((s / kNumStages + 1) & 1); + + auto& full_barrier = *full_barriers[stage_idx]; + const uint32_t& sk_idx = (sk_block_idx * kSplitFactor + s) * BLOCK_K; + const uint32_t& k_idx = sk_idx % SHAPE_K; + const uint32_t& s_idx = sk_idx / SHAPE_K; + + constexpr uint32_t kSwizzle = BLOCK_K * sizeof(nv_bfloat16); + tma_copy( + &tensor_map_a, &full_barrier, smem_a[stage_idx], k_idx, m_block_idx * BLOCK_M + s_idx * SHAPE_M, 1); + tma_copy( + &tensor_map_b, &full_barrier, smem_b[stage_idx], k_idx, n_block_idx * BLOCK_N + s_idx * SHAPE_N, 1); + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); + float accum[WGMMA::kNumAccum] = {0}; + + // Launch MMAs + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait TMA arrivals + const auto& stage_idx = s % kNumStages; + full_barriers[stage_idx]->wait((s / kNumStages) & 1); + + // Commit WGMMA instructions + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_a[stage_idx] + (math_wg_idx * WGMMA::M) * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, 1); + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival at the last warpgroup wave + empty_barriers[stage_idx]->arrive(); + } + + const auto& row = m_block_idx * BLOCK_M + warp_idx * 16 + lane_idx / 4; + const auto& col = n_block_idx * BLOCK_N + (lane_idx % 4) * 2; + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + if (col + i * 8 >= SHAPE_N) + break; + if (row < SHAPE_M) { + atomicAdd(reinterpret_cast(d + (row + 0) * SHAPE_N + col + i * 8), + make_float2(accum[i * 4 + 0], accum[i * 4 + 1])); + } + if (row + 8 < SHAPE_M) { + atomicAdd(reinterpret_cast(d + (row + 8) * SHAPE_N + col + i * 8), + make_float2(accum[i * 4 + 2], accum[i * 4 + 3])); + } + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + +}; // namespace deep_gemm diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh new file mode 100644 index 0000000000000000000000000000000000000000..cdd28fcb59d3b038c84c007ef1da1477d7ca263a --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh @@ -0,0 +1,349 @@ +#pragma once + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include +#include + +#include +#include +#include + +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; + +template +__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, + int* grouped_layout, + cute::TmaDescriptor* tensor_map_buffer, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a_base, + const __grid_constant__ cute::TmaDescriptor tensor_map_b_base, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfa, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfb, + const __grid_constant__ cute::TmaDescriptor tensor_map_cd) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + // Scaling checks + DG_STATIC_ASSERT(kNumTMAThreads == 128 and kNumMathThreads % 128 == 0, "Invalid Threads"); + DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); + DG_STATIC_ASSERT(cute::is_same_v, "Invalid C/D data dtype"); + DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous, "Invalid GEMM type"); + + // Types + using WGMMA = typename FP8MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size"); + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + + // Shared memory + static constexpr uint32_t SMEM_TENSOR_MAP_SIZE = (kGemmType == GemmType::KGroupedContiguous ? sizeof(cute::TmaDescriptor) * 4 : 0); + static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(float); + static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float); + static constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = BLOCK_N * sizeof(float); + static constexpr uint32_t ALIGNED_SMEM_SFB_SIZE_PER_STAGE = constexpr_align(SMEM_SFB_SIZE_PER_STAGE, 128u); + DG_STATIC_ASSERT(SMEM_SFA_SIZE_PER_STAGE % 128 == 0, "Invalid TMA alignment"); + + // Configs + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const uint32_t lane_idx = threadIdx.x % 32; + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a_base); + cute::prefetch_tma_descriptor(&tensor_map_b_base); + cute::prefetch_tma_descriptor(&tensor_map_sfa); + cute::prefetch_tma_descriptor(&tensor_map_sfb); + cute::prefetch_tma_descriptor(&tensor_map_cd); + } + __syncwarp(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + // Tensor maps on shared and global memory + auto smem_tensor_map_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + static_cast(sizeof(cute::TmaDescriptor)) * i); + }); + auto smem_tensor_map_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + static_cast(sizeof(cute::TmaDescriptor)) * (2 + i)); + }); + auto gmem_tensor_map_a = PatternVisitor([=](const uint32_t& i) { return tensor_map_buffer + blockIdx.x * 4 + i; }); + auto gmem_tensor_map_b = PatternVisitor([=](const uint32_t& i) { return tensor_map_buffer + blockIdx.x * 4 + 2 + i; }); + + // Data on shared memory + auto smem_d = reinterpret_cast(smem_buffer + SMEM_TENSOR_MAP_SIZE); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE)); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); + }); + constexpr auto SMEM_SF_OFFSET = SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + auto smem_sfa = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_SF_OFFSET + i * SMEM_SFA_SIZE_PER_STAGE)); + }); + auto smem_sfb = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_SF_OFFSET + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * ALIGNED_SMEM_SFB_SIZE_PER_STAGE)); + }); + + // Barriers on shared memory + constexpr auto SMEM_BARRIER_OFFSET = SMEM_SF_OFFSET + kNumStages * (SMEM_SFA_SIZE_PER_STAGE + ALIGNED_SMEM_SFB_SIZE_PER_STAGE); + auto full_barriers = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_BARRIER_OFFSET + i * static_cast(sizeof(Barrier)))); + }); + auto empty_barriers = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_BARRIER_OFFSET + (kNumStages + i) * static_cast(sizeof(Barrier)))); + }); + + if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) { + // Load tensormap A/B to shared memory + if constexpr (kGemmType == GemmType::KGroupedContiguous) { + *smem_tensor_map_a[0] = tensor_map_a_base; + *smem_tensor_map_a[1] = tensor_map_a_base; + *smem_tensor_map_b[0] = tensor_map_b_base; + *smem_tensor_map_b[1] = tensor_map_b_base; + } + + // Initialize barriers + // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster, + // even with TMA multicast disabled, we want to make the behavior aligned + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + + // Synchronize all threads to make barrier visible in normal memory model + (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); + + // Pipeline unroll control + constexpr uint32_t kNumPipelineUnrolls = (kGemmType == GemmType::KGroupedContiguous ? 0 : kNumStages); + + // Register reconfigurations (more math registers are needed with unrolling) + constexpr uint32_t kNumTMARegisters = (kNumPipelineUnrolls == 0 ? 40 : 24); + constexpr uint32_t kNumMathRegisters = (kNumPipelineUnrolls == 0 ? 232 : 240); + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + + // TMA and MMA pipeline + const auto& get_pipeline = [=](const uint32_t& iter_idx) -> cute::tuple { + return {iter_idx % kNumStages, (iter_idx / kNumStages) & 1}; // Pipeline stage and phase + }; + uint32_t iter_idx = 0; + + if (warp_idx >= kNumMathThreads / 32) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + const cute::TmaDescriptor* current_tensor_map_a = &tensor_map_a_base; + const cute::TmaDescriptor* current_tensor_map_b = &tensor_map_b_base; + uint32_t last_group_idx = kNumGroups, sum_k = 0; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Assign TMA multicast number into A and B + // NOTES: there may be additional odd rows/columns or cases where multicast is not possible. + const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx); + const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); + + const uint32_t& num_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + const uint32_t& m_idx = m_block_idx * BLOCK_M; + const uint32_t& n_idx = n_block_idx * BLOCK_N; + + if (kGemmType == GemmType::KGroupedContiguous and last_group_idx != scheduler.current_group_idx) { + const uint32_t& stage_idx = scheduler.current_num_valid_groups & 1; + const uint32_t& next_stage_idx = stage_idx ^ 1; + last_group_idx = scheduler.current_group_idx; + + // Prepare next tensor map + sum_k += scheduler.current_shape_k; + if (scheduler.next_group_idx < kNumGroups) { + tensor_map_replace_global_addr_in_smem(smem_tensor_map_a[next_stage_idx], gmem_a_ptr + static_cast(sum_k) * shape_m); + tensor_map_replace_global_addr_in_smem(smem_tensor_map_b[next_stage_idx], gmem_b_ptr + static_cast(sum_k) * shape_n); + tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_a[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k); + tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_b[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k); + *(gmem_tensor_map_a[next_stage_idx]) = *(smem_tensor_map_a[next_stage_idx]); + *(gmem_tensor_map_b[next_stage_idx]) = *(smem_tensor_map_b[next_stage_idx]); + tensor_map_release_cta(); + } + + // Get current tensor map + if (scheduler.current_num_valid_groups > 0) { + tensor_map_acquire_cta(gmem_tensor_map_a[stage_idx]); + tensor_map_acquire_cta(gmem_tensor_map_b[stage_idx]); + current_tensor_map_a = gmem_tensor_map_a[stage_idx]; + current_tensor_map_b = gmem_tensor_map_b[stage_idx]; + } + } + + #pragma unroll kNumPipelineUnrolls + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; ++ k_block_idx) { + // Wait consumer release + CUTE_TIE_DECL(get_pipeline(iter_idx ++), stage_idx, phase); + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Issue TMA + auto& full_barrier = *full_barriers[stage_idx]; + const uint32_t& k_idx = k_block_idx * BLOCK_K; + const uint32_t& sf_k_idx = scheduler.current_sf_k_cumsum + k_block_idx; + tma_copy(&tensor_map_sfa, &full_barrier, smem_sfa[stage_idx], m_idx, sf_k_idx, num_tma_multicast_a); + tma_copy(&tensor_map_sfb, &full_barrier, smem_sfb[stage_idx], n_idx, sf_k_idx, num_tma_multicast_b); + tma_copy(current_tensor_map_a, &full_barrier, smem_a[stage_idx], k_idx, m_idx, num_tma_multicast_a); + tma_copy(current_tensor_map_b, &full_barrier, smem_b[stage_idx], k_idx, n_idx, num_tma_multicast_b); + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE); + } + } + + // To safely deconstruct distributed shared barriers, we need another round of empty waits + if constexpr (kNumTMAMulticast > 1) { + #pragma unroll + for (uint32_t s = 0; s < kNumStages; ++ s) { + CUTE_TIE_DECL(get_pipeline(iter_idx ++), stage_idx, phase); + empty_barriers[stage_idx]->wait(phase ^ 1); + } + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); + const auto row_idx = lane_idx / 4, col_idx = lane_idx % 4; + const auto r_0 = warp_idx * 16 + row_idx, r_1 = r_0 + 8; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Accumulation for WGMMA or CUDA promotion + DG_STATIC_ASSERT(BLOCK_M == WGMMA::M * (BLOCK_M <= 64 ? 1 : 2), "Invalid block sizes"); + const uint32_t& current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k); + const uint32_t& current_group_idx = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_group_idx : 0); + const uint32_t& num_k_blocks = ceil_div(current_shape_k, BLOCK_K); + float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0}; + float2 scales_b[WGMMA::kNumAccum / 4]; + + // Empty barrier arrival + auto empty_barrier_arrive = [&](uint32_t s) { + if constexpr (kNumTMAMulticast == 1) { + lane_idx == 0 ? empty_barriers[s]->arrive() : void(); + } else { + auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster(); + lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void(); + } + }; + + #pragma unroll kNumPipelineUnrolls + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; ++ k_block_idx) { + // Wait TMA arrivals + CUTE_TIE_DECL(get_pipeline(iter_idx ++), stage_idx, phase); + full_barriers[stage_idx]->wait(phase); + + // Read A scales + // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results + auto scale_a_0 = ld_shared(smem_sfa[stage_idx] + r_0); + auto scale_a_1 = ld_shared(smem_sfa[stage_idx] + r_1); + + // Read B scales + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum / 4; ++i) + scales_b[i] = ld_shared(reinterpret_cast(smem_sfb[stage_idx] + i * 8 + col_idx * 2)); + + // Commit WGMMA instructions + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_a[stage_idx] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival + empty_barrier_arrive(stage_idx); + + // Promote with scales + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + const float &scale_b_0 = scales_b[i].x; + const float &scale_b_1 = scales_b[i].y; + final_accum[i * 4 + 0] += scale_a_0 * scale_b_0 * accum[i * 4 + 0]; + final_accum[i * 4 + 1] += scale_a_0 * scale_b_1 * accum[i * 4 + 1]; + final_accum[i * 4 + 2] += scale_a_1 * scale_b_0 * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += scale_a_1 * scale_b_1 * accum[i * 4 + 3]; + } + } + + // Flush previous stores + if (warp_idx % 4 == 0 and cute::elect_one_sync()) + cute::tma_store_wait<0>(); + cutlass::arch::NamedBarrier::sync(128, math_wg_idx); + + // Store to D shared memory + const auto& smem_d_0 = reinterpret_cast(smem_d + r_0 * BLOCK_N + col_idx * 2); + const auto& smem_d_1 = reinterpret_cast(smem_d + r_1 * BLOCK_N + col_idx * 2); + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + st_shared(smem_d_0 + i * 4, {final_accum[i * 4 + 0], final_accum[i * 4 + 1]}); + st_shared(smem_d_1 + i * 4, {final_accum[i * 4 + 2], final_accum[i * 4 + 3]}); + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(128, math_wg_idx); + + // Use TMA store to write back to global memory + if (warp_idx % 4 == 0 and cute::elect_one_sync()) { + cute::SM90_TMA_REDUCE_ADD_2D::copy( + &tensor_map_cd, smem_d_0, n_block_idx * BLOCK_N, + current_group_idx * shape_m + m_block_idx * BLOCK_M + r_0); + cute::tma_store_arrive(); + } + __syncwarp(); + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh new file mode 100644 index 0000000000000000000000000000000000000000..9247304cdd17d8e2c3a5cdb31c78c191ae6b76ec --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh @@ -0,0 +1,440 @@ +#pragma once + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; + +template +__device__ void dispatch_num_former_iters(uint32_t num_former_iters, const func_t& func) { + if (num_former_iters == kNumFormerIters) { + func(cute::Int{}); + return; + } + + if constexpr (kNumFormerIters + kGap <= kEnd) + dispatch_num_former_iters(num_former_iters, func); +} + +template +__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_d, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfa) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + // Scaling checks + DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); + DG_STATIC_ASSERT(constexpr_ceil_div(BLOCK_N, BLOCK_K) == 1 or (constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block"); + + // Types + using WGMMA = typename FP8MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0 or BLOCK_M < WGMMA::M, "Invalid block size"); + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + + // Shared memory + static constexpr bool kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0); + static constexpr uint32_t SMEM_D_SIZE = constexpr_align(BLOCK_M * BLOCK_N * static_cast(sizeof(__nv_bfloat16)), 1024u); + static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float); + static constexpr uint32_t ALIGNED_SMEM_SFA_SIZE_PER_STAGE = constexpr_align(SMEM_SFA_SIZE_PER_STAGE, 128u); + const uint32_t& shape_k_scales = ceil_div(shape_k, BLOCK_K); + const uint32_t& shape_n_sfb = ceil_div(shape_n, BLOCK_K); + const uint32_t& smem_sfb_size = align(shape_k_scales * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier)); + + // NOTES: Make sure we have enough shared memory for WGMMA padding + static constexpr uint32_t WGMMA_A_SIZE_PER_STAGE = WGMMA::M * BLOCK_K * sizeof(__nv_fp8_e4m3); + DG_STATIC_ASSERT(WGMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for WGMMA"); + + // Configs + const uint32_t num_total_k_blocks = ceil_div(shape_k, BLOCK_K); + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const uint32_t lane_idx = get_lane_idx(); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_sfa); + cute::prefetch_tma_descriptor(&tensor_map_d); + } + __syncwarp(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + // Data on shared memory + auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); + constexpr uint32_t SMEM_SF_OFFSET = SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + auto smem_sfa = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_SF_OFFSET + i * ALIGNED_SMEM_SFA_SIZE_PER_STAGE); + }); + auto smem_sfb = reinterpret_cast(smem_buffer + SMEM_SF_OFFSET + kNumStages * ALIGNED_SMEM_SFA_SIZE_PER_STAGE); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(reinterpret_cast(smem_sfb) + smem_sfb_size); + auto full_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + i; }); + auto empty_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + kNumStages + i; }); + + // Initialize barriers + DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast"); + if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) { + // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster, + // even with TMA multicast disabled, we want to make the behavior aligned + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + + // Synchronize all threads to make barrier visible in normal memory model + (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumTMARegisters = 40; + constexpr uint32_t kNumMathRegisters = kNumMathThreads == 128 ? 248 : 232; + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + + // Pipeline and TMA phases + uint32_t stage_idx = 0, phase = 0; + auto advance_pipeline = [&](uint32_t& k_block_idx) { + ++ k_block_idx; + + // Flip phases only if reach the next first stage + stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1; + phase ^= stage_idx == 0; + }; + + if (warp_idx >= kNumMathThreads / 32) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + // We use the third warp, as warp 0/1 may be doing WGMMA with `BLOCK_M == 32` + if (warp_idx == kNumMathThreads / 32 + 2 and cute::elect_one_sync()) { + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Assign TMA multicast number into A and B + // NOTES: there may be additional odd rows/columns or cases where multicast is not possible. + const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx); + const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); + + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Issue TMA A + constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched); + const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0); + + constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked; + auto& full_barrier = *full_barriers[stage_idx]; + const uint32_t k_idx = k_block_idx * BLOCK_K; + tma_copy(&tensor_map_a, &full_barrier, + smem_a[stage_idx], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx), + num_tma_multicast_a, batch_idx); + tma_copy(&tensor_map_sfa, &full_barrier, + smem_sfa[stage_idx], m_block_idx * BLOCK_M, scheduler.template get_global_idx(shape_k_scales, 1, k_block_idx), + num_tma_multicast_a); + + // Issue TMA B + tma_copy(&tensor_map_b, &full_barrier, + smem_b[stage_idx], k_idx, scheduler.get_global_idx(shape_n, BLOCK_N, n_block_idx, m_block_idx), + num_tma_multicast_b, batch_idx); + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE); + } + } + + // To safely deconstruct distributed shared barriers, we need another round of empty waits + if constexpr (kNumTMAMulticast > 1) { + for (uint32_t i = 0; i < kNumStages; advance_pipeline(i)) + empty_barriers[stage_idx]->wait(phase ^ 1); + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); + const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8; + + auto a_desc = make_smem_desc(smem_a[0] + math_wg_idx * WGMMA::M * BLOCK_K, 1); + auto b_desc = make_smem_desc(smem_b[0], 1); + const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0); + const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0); + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Decide the number of scales B to load + DG_TRAP_ONLY_DEVICE_ASSERT(shape_n % 8 == 0); + uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters; + if constexpr (not kMustUseUniformedScaleB) { + num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8; + num_full_iters = min(shape_n - n_block_idx * BLOCK_N, BLOCK_N) / 8; + } + uint32_t num_sfb = shape_k_scales * (num_former_iters >= num_full_iters ? 1 : 2); + + // Load B scales with math warp-groups + // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks + if (threadIdx.x >= 32) { + auto previous_group_offset = scheduler.template get_global_idx(shape_n_sfb * shape_k_scales, 0, 0, m_block_idx); + const uint32_t stride_n_sfb = kMajorSFB == cute::UMMA::Major::MN ? 1 : shape_k_scales; + const uint32_t stride_k_sfb = kMajorSFB == cute::UMMA::Major::MN ? shape_n_sfb : 1; + auto local_sfb = sfb + previous_group_offset + ((n_block_idx * BLOCK_N) / BLOCK_K) * stride_n_sfb; + + #pragma unroll + for (uint32_t i = threadIdx.x - 32; i < num_sfb; i += kNumMathThreads - 32) + st_shared(smem_sfb + i, __ldg(i < shape_k_scales ? local_sfb + i * stride_k_sfb : local_sfb + (i - shape_k_scales) * stride_k_sfb + stride_n_sfb)); + } + cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0); + + // Accumulation for WGMMA or CUDA promotion + constexpr uint32_t WAVE_BLOCK_M = BLOCK_M <= WGMMA::M ? BLOCK_M : WGMMA::M * 2; + DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes"); + float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0}; + + // Pick threads whose WGMMA results are to be stored in shared memory + DG_STATIC_ASSERT(BLOCK_M >= 64 or kNumMathThreads == 128, "Only one math warp group for `BLOCK_M < 64`"); + constexpr uint32_t kNumWGMMAStoreThreads = WAVE_BLOCK_M * (128 / WGMMA::M); + const bool do_wgmma_store = BLOCK_M >= WGMMA::M or warp_idx < kNumWGMMAStoreThreads / 32; + + // Empty barrier arrival + auto empty_barrier_arrive = [&]() { + if constexpr (kNumTMAMulticast == 1) { + lane_idx == 0 ? empty_barriers[stage_idx]->arrive() : void(); + } else { + auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster(); + lane_idx < kNumTMAMulticast ? empty_barriers[stage_idx]->arrive(target_cta) : void(); + } + }; + + // Skip useless computations + if (scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M)) { + // The compiler must know the dynamic variable `num_former_iters`'s real value + constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB; + constexpr uint32_t kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8; + constexpr uint32_t kEnd = kShouldOptimize ? BLOCK_K / 8 : 0; + + // Dispatch `num_former_iters` and launch MMAs + dispatch_num_former_iters<0, kGap, kEnd>(kShouldOptimize ? num_former_iters : 0, [&](auto _) { + #pragma unroll 8 + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + const auto& a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16); + const auto& b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16); + + // Read B scales + float scale_b_0 = ld_shared(smem_sfb + k_block_idx), scale_b_1; + // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks + if constexpr (not kMustUseUniformedScaleB) + scale_b_1 = ld_shared(smem_sfb + k_block_idx + shape_k_scales); + + // Wait TMA arrivals + full_barriers[stage_idx]->wait(phase); + + // TODO: remove some useless computation for unaligned Ms + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + + // Read A scales + // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results + auto scale_a_0 = do_wgmma_store ? ld_shared(smem_sfa[stage_idx] + r_0 + m_offset) : 0; + auto scale_a_1 = do_wgmma_store ? ld_shared(smem_sfa[stage_idx] + r_1 + m_offset) : 0; + + // Commit WGMMA instructions + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + a_desc.reg32_[0] = a_desc_base_lo + (m_offset * BLOCK_K + k * WGMMA::K) / 16; + b_desc.reg32_[0] = b_desc_base_lo + k * WGMMA::K / 16; + WGMMA::wgmma(a_desc, b_desc, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival at the last warpgroup wave + if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1) + empty_barrier_arrive(); + + // Skip promotion for the unfilled parts + if (not do_wgmma_store) + continue; + + // Promote with scales + // NOTES: making it as predicates is very important for performance, comparing to two loops + float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; + float scale_0_1, scale_1_1; + if constexpr (not kMustUseUniformedScaleB) + scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; + + auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant + const bool& predicate = kMustUseUniformedScaleB or i < num_former_iters; + shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; + shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; + shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; + shifted_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3]; + } + } + } + }); + } else { + #pragma unroll + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + full_barriers[stage_idx]->wait(phase); + empty_barrier_arrive(); + } + } + + // TMA checks + constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16); + constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? BLOCK_N : (kSwizzleDMode / kNumElemBytes); + constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4; + DG_STATIC_ASSERT(BLOCK_M % 8 == 0, "Invalid swizzling atom"); + DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32, + "Unaligned TMA store or too many TMA store instructions"); + DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N"); + + // Skip WGMMA store for the unfilled parts + if (not do_wgmma_store) + continue; + + // Wait last TMA store to be finished + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) + cute::tma_store_wait<0>(); + cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 1); + + // Write back to shared memory using STSM and issue TMA stores + DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + // Swizzle or padding into the correct address + uint8_t* smem_ptr = nullptr; + if constexpr (kSwizzleDMode > 0) { + // Calculate the swizzling atom offset and in-atom offset + constexpr uint32_t kNumBankGroupBytes = 16; + auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8); + + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)` + // - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)` + constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (in_atom_offset) : (bank_group_index % 8); + col ^= row % (kSwizzleDMode / 16); + + // Add back into the base pointer + // NOTES: think twice before modifying this, as changes may affect the number of instructions + smem_ptr = reinterpret_cast(smem_d) + // Base pointer + warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset + m_offset * kSwizzleDMode + // Wave offset + atom_offset * BLOCK_M * kSwizzleDMode + // Swizzle atom offset (constants) + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + } else { + // No swizzling, just padding + smem_ptr = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * BLOCK_N + i * 8); + } + + // NOTES: only 16 lanes' addresses are used + SM90_U32x2_STSM_N::copy( + __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}), + __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}), + smem_ptr + ); + } + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 1); + + // Use TMA store to write back to global memory + // TODO: compatible with FP32 output + constexpr bool kWithGroupOffsetD = kGemmType == GemmType::MGroupedMasked; + DG_STATIC_ASSERT(kNumWGMMAStoreThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks"); + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { + auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N; + auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M; + auto n_idx = epilogue_type_t::apply_index_n(n_block_idx * BLOCK_N + in_block_n_offset); + auto m_idx = scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx); + if constexpr (kGemmType == GemmType::Batched) { + cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_ptr, + n_idx, m_idx, scheduler.current_group_idx); + } else { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr, n_idx, m_idx); + } + cute::tma_store_arrive(); + } + __syncwarp(); + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh new file mode 100644 index 0000000000000000000000000000000000000000..d58c716242a09922157aa13e16cb8afac477904c --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh @@ -0,0 +1,329 @@ +#pragma once + +#include +#include + +#include +#include +#include + +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; + +// ReSharper disable once CppNotAllPathsReturnValue +template +static constexpr int to_swizzle_cute_type() { + DG_STATIC_ASSERT(kHeadDim == 32 or kHeadDim == 64 or kHeadDim == 128, "Invalid swizzling"); + if constexpr (kHeadDim == 32) + return static_cast(cute::SM90::GMMA::LayoutType::B32); + if constexpr (kHeadDim == 64) + return static_cast(cute::SM90::GMMA::LayoutType::B64); + if constexpr (kHeadDim == 128) + return static_cast(cute::SM90::GMMA::LayoutType::B128); +} + +template +__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) +void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, + const uint32_t max_seqlen_k, const uint64_t stride_logits, + uint32_t* cu_seq_len_k_start, + uint32_t* cu_seq_len_k_end, + float* logits, + const __grid_constant__ cute::TmaDescriptor tensor_map_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, + const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { + // TODO: consider TMA multicast + // For one block, we process `[q_start:q_end, h, d] @ [kv_start:kv_end, d] -> [q_start:q_end, kv_start:kv_end]` + // Q should be load only at once for a block + const auto& num_q_blocks = ceil_div(seq_len, BLOCK_Q); + + // Types + using WGMMA = typename FP8MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Prefetch TMA descriptors + DG_STATIC_ASSERT(kNumTMAThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); + if (threadIdx.x / 32 == kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_q); + cute::prefetch_tma_descriptor(&tensor_map_kv); + cute::prefetch_tma_descriptor(&tensor_map_kv_scales); + cute::prefetch_tma_descriptor(&tensor_map_weights); + } + __syncwarp(); + + // Shared memory configs + // NOTES: weight may be unaligned + static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8; + static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * sizeof(float); + static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float); + + // Align to swizzling alignment bytes + extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + + // Data on shared memory + auto smem_q = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + + SMEM_Q_SIZE_PER_STAGE * i); + }); + auto smem_kv = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + ( + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i)); + }); + auto smem_weights = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i); + }); + auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SCALE_SIZE_PER_STAGE * i); + }); + + // TMA barriers + auto barrier_ptr = reinterpret_cast(smem_kv_scales[kNumKVStages]); + auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); }); + auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); }); + auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); }); + + // Initialize barriers + const bool& is_tma_load_warp = kNumMathThreads <= threadIdx.x and threadIdx.x < kNumMathThreads + 32; + if (is_tma_load_warp and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumQStages; ++ i) { + full_q_barriers[i]->init(1); + empty_q_barriers[i]->init(kNumMathThreads); + } + #pragma unroll + for (uint32_t i = 0; i < kNumKVStages; ++ i) { + full_kv_barriers[i]->init(1); + empty_kv_barriers[i]->init(kNumMathThreads); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumTMARegisters = 32; + constexpr uint32_t kNumMathRegisters = 112; + + // Block scheduler + uint32_t block_q_idx = blockIdx.x, q_iter_idx = 0; + const auto& get_next_block_q_idx = [&]() -> cute::tuple { + return {block_q_idx + gridDim.x, q_iter_idx + 1}; + }; + uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q]; + const auto& load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple { + uint32_t start = cute::numeric_limits::max(); + uint32_t end = cute::numeric_limits::min(); + + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + const auto& q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1); + seq_k_start[i] = __ldg(cu_seq_len_k_start + q_idx); + seq_k_end[i] = __ldg(cu_seq_len_k_end + q_idx); + start = min(start, min(seq_k_start[i], seq_len_kv)); + end = max(end, min(seq_k_end[i], seq_len_kv)); + } + start = start / 4 * 4; + return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage + ((q_iter_idx + q_iter_offset) / kNumQStages) & 1, // Q pipeline phase + start, ceil_div(end - start, BLOCK_KV)}; // Task info + }; + + // KV pipeline + uint32_t num_total_kv_blocks = 0; + const auto& get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple { + return { + (num_total_kv_blocks + kv_block_idx) % kNumKVStages, // KV pipeline stage + ((num_total_kv_blocks + kv_block_idx) / kNumKVStages) & 1 // KV pipeline phase + }; + }; + + if (threadIdx.x >= kNumMathThreads) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // Only the first warp remains + if (not is_tma_load_warp) + return; + + // Prefetch + const auto& issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) { + tma_copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads); + tma_copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q); + full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); + }; + if (cute::elect_one_sync() and block_q_idx < num_q_blocks) + issue_tma_q(0, block_q_idx); + + // Only the first lane persistently schedules over blocks + if (cute::elect_one_sync()) { + while (block_q_idx < num_q_blocks) { + CUTE_TIE_DECL(load_schedule(1), q_stage_idx, q_phase, kv_start, num_kv_blocks); + + // Wait Q consumer release + empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); + + // Issue TMA Q + if (const auto& next_block_q_idx = cute::get<0>(get_next_block_q_idx()); next_block_q_idx < num_q_blocks) + issue_tma_q(q_stage_idx, next_block_q_idx); + + // Issue TMA KV + #pragma unroll + for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) { + // Wait consumer release + CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase); + empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); + + // Issue TMA KV + tma_copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], + smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV); + tma_copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], + smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0); + full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); + } + num_total_kv_blocks += num_kv_blocks; + + // Jump to the next block + CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto& thread_idx = threadIdx.x % kNumMathThreads; + const auto& warp_idx = __shfl_sync(0xffffffff, thread_idx / 32, 0); + const auto& warpgroup_idx = warp_idx / 4; + const auto& lane_idx = get_lane_idx(); + float accum[WGMMA::kNumAccum], weights[BLOCK_Q][kNumHeads / 4]; + + const auto& warp_offset = warp_idx * 16; + const auto& v_0_offset = lane_idx / 4 + 0; + const auto& v_1_offset = lane_idx / 4 + 8; + + while (block_q_idx < num_q_blocks) { + CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks); + + // Wait TMA Q arrival + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Read weights + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + #pragma unroll + for (uint32_t j = 0; j < kNumHeads / 4; ++ j) + weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2); + } + + // Compute over KV blocks + #pragma unroll + for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) { + // Compute `[BLOCK_Q * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [BLOCK_Q, BLOCK_KV]` + // Wait TMA KV arrival + CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Read per-KV scales + float scale_kv_0 = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_0_offset); + float scale_kv_1 = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_1_offset); + + // Issue WGMMA + DG_STATIC_ASSERT(BLOCK_KV == kNumMathThreads / 2, "Invalid block size"); + DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim"); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_kv[kv_stage_idx] + (warpgroup_idx * WGMMA::M) * kHeadDim + k * WGMMA::K, + to_swizzle_cute_type(), 0, kHeadDim * 8); + auto desc_b = make_smem_desc(smem_q[q_stage_idx] + k * WGMMA::K, + to_swizzle_cute_type(), 0, kHeadDim * 8); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Release KV empty + empty_kv_barriers[kv_stage_idx]->arrive(); + + // Reduce over the head dim and store + const auto& kv_offset = kv_start + kv_block_idx * BLOCK_KV + warp_offset; + static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2; + DG_STATIC_ASSERT(WGMMA::kNumAccum % kNumAccumPerReduce == 0, "Invalid accumulation"); + DG_STATIC_ASSERT(WGMMA::kNumAccum / kNumAccumPerReduce == BLOCK_Q, "Invalid accumulation"); + DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head"); + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + auto shifted_accum = accum + i * kNumAccumPerReduce; + const auto& transform = [&](const uint32_t& j) { + return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)]; + }; + + // Intra-thread reduction + float sum[4] = {transform(0), transform(1), transform(2), transform(3)}; + #pragma unroll + for (uint32_t j = 1; j < kNumHeads / 8; ++ j) { + #pragma unroll + for (uint32_t k = 0; k < 4; k ++) + sum[k] += transform(j * 4 + k); + } + float v_0 = (sum[0] + sum[1]) * scale_kv_0; + float v_1 = (sum[2] + sum[3]) * scale_kv_1; + + // Inter-thread reduction + #pragma unroll + for (uint32_t j = 0; j < 2; ++ j) { + const auto& offset = static_cast(1u << j); + v_0 += __shfl_xor_sync(0xffffffffu, v_0, offset); + v_1 += __shfl_xor_sync(0xffffffffu, v_1, offset); + } + + // Store into the global memory + // NOTES: we have redundant writes here, consider more carefully + const uint32_t& q_idx = block_q_idx * BLOCK_Q + i; + if constexpr (kIsCompressedLogits) { + if (seq_k_start[i] <= kv_offset + v_0_offset and kv_offset + v_0_offset < seq_k_end[i]) + logits[q_idx * stride_logits + kv_offset + v_0_offset - seq_k_start[i]] = v_0; + if (seq_k_start[i] <= kv_offset + v_1_offset and kv_offset + v_1_offset < seq_k_end[i]) + logits[q_idx * stride_logits + kv_offset + v_1_offset - seq_k_start[i]] = v_1; + } else { + logits[q_idx * stride_logits + kv_offset + v_0_offset] = v_0; + logits[q_idx * stride_logits + kv_offset + v_1_offset] = v_1; + } + } + } + num_total_kv_blocks += num_kv_blocks; + + // Release Q empty + empty_q_barriers[q_stage_idx]->arrive(); + + // Jump to the next block + CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); + } + } +} + +} // namespace deep_gemm diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh new file mode 100644 index 0000000000000000000000000000000000000000..482a85a80fce29aa949b464070b0b20fb55ae030 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh @@ -0,0 +1,413 @@ +#pragma once + +#include +#include + +#include +#include + +#include +#include +#include + +namespace deep_gemm { + +template +__global__ __launch_bounds__(32, 1) +void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t next_n, const bool is_context_lens_2d, + const uint32_t* context_lens, uint32_t* schedule_metadata) { + DG_STATIC_ASSERT(kAlignedBatchSize % 32 == 0, "Invalid aligned batch size"); + const uint32_t lane_idx = get_lane_idx(); + + uint32_t num_segs[kAlignedBatchSize / 32]; + #pragma unroll + for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) { + const uint32_t q_idx = k * 32 + lane_idx; + const uint32_t lens_idx = (is_context_lens_2d ? q_idx * next_n + next_n - 1 : q_idx); + const uint32_t& context_len = (q_idx < batch_size ? __ldg(context_lens + lens_idx) : 0); + num_segs[k] = ceil_div(context_len, SPLIT_KV); + } + + __shared__ uint32_t prefix_sum[kAlignedBatchSize]; + uint32_t sum = 0; + #pragma unroll + for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) { + uint32_t x = num_segs[k]; + #pragma unroll + for (uint32_t offset = 1; offset < 32; offset <<= 1) { + const uint32_t& y = __shfl_up_sync(0xffffffff, x, offset); + x += (lane_idx >= offset ? y : 0); + } + x += sum; + prefix_sum[k * 32 + lane_idx] = x; + sum = __shfl_sync(0xffffffff, x, 31); + } + + const uint32_t& q = sum / kNumSMs, r = sum % kNumSMs; + for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) { + uint32_t seg_starts = sm_idx * q + min(sm_idx, r); + uint32_t q_idx = 0; + while (q_idx < batch_size and prefix_sum[q_idx] <= seg_starts) + ++ q_idx; + const uint32_t& kv_split_idx = (q_idx == 0 ? seg_starts : seg_starts - prefix_sum[q_idx - 1]); + __syncwarp(); + + schedule_metadata[sm_idx * 2] = q_idx; + schedule_metadata[sm_idx * 2 + 1] = kv_split_idx; + } +} + +template +struct PagedMQALogitsScheduler { + uint32_t batch_size; + const uint32_t* context_lens; + + uint32_t current_q_idx, current_kv_idx; + uint32_t end_q_idx, end_kv_idx; + uint32_t current_num_kv; + + __device__ __forceinline__ uint32_t get_num_kv(const uint32_t& q_idx) { + const auto& lens_idx = (kIsContextLens2D ? q_idx * kNextN + kNextN - 1 : q_idx); + return q_idx < batch_size ? ceil_div(__ldg(context_lens + lens_idx), BLOCK_KV) : 0; + } + + __device__ __forceinline__ explicit PagedMQALogitsScheduler(const uint32_t& batch_size, const uint32_t& sm_idx, + const uint32_t* context_lens, const uint32_t* schedule_meta) { + this->batch_size = batch_size; + this->context_lens = context_lens; + + const auto& current_pack = __ldg(reinterpret_cast(schedule_meta) + sm_idx); + const auto& end_pack = __ldg(reinterpret_cast(schedule_meta) + sm_idx + 1); + current_q_idx = current_pack.x, current_kv_idx = current_pack.y * kNumBlocksPerSplit; + end_q_idx = end_pack.x, end_kv_idx = end_pack.y * kNumBlocksPerSplit; + + current_num_kv = get_num_kv(current_q_idx); + } + + __device__ __forceinline__ bool fetch_next_task(uint32_t &q_idx, uint32_t &kv_idx, uint32_t &num_kv) { + q_idx = current_q_idx; + kv_idx = current_kv_idx; + num_kv = current_num_kv; + + if (q_idx == end_q_idx and kv_idx == end_kv_idx) + return false; + + current_kv_idx += kNumBlocksPerSplit; + if (current_kv_idx >= current_num_kv) { + ++ current_q_idx; + current_kv_idx = 0; + current_num_kv = get_num_kv(current_q_idx); + } + + return true; + } + + __device__ __forceinline__ bool exist_q_idx(const uint32_t& q_idx) const { + return q_idx < end_q_idx or q_idx == end_q_idx and 0 < end_kv_idx; + } +}; + +using namespace deep_gemm::sm90; + +template +__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) +void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, + const uint64_t logits_stride, const uint64_t block_table_stride, + const uint32_t* context_lens, float* logits, + const uint32_t* block_table, const uint32_t* schedule_meta, + const __grid_constant__ cute::TmaDescriptor tensor_map_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, + const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { + // Types + using WGMMA = typename FP8MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const auto& warpgroup_idx = warp_idx / 4; + const auto& lane_idx = get_lane_idx(); + + // Prefetch TMA descriptors + static constexpr uint32_t kNumMathWarpGroups = kNumMathThreads / 128; + DG_STATIC_ASSERT(kNumTMAThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); + DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumMathWarpGroups, "Invalid `SPLIT_KV`"); + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_q); + cute::prefetch_tma_descriptor(&tensor_map_kv); + cute::prefetch_tma_descriptor(&tensor_map_kv_scales); + cute::prefetch_tma_descriptor(&tensor_map_weights); + } + __syncwarp(); + + // Shared memory configs + static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8; + static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float); + static constexpr uint32_t ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE = constexpr_align(SMEM_WEIGHT_SIZE_PER_STAGE, kSwizzleAlignment); + static constexpr uint32_t SMEM_Q_PIPE_SIZE = kNumQStages * (SMEM_Q_SIZE_PER_STAGE + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE) + + constexpr_align(kNumQStages * 8 * 2, kSwizzleAlignment); + + static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float); + static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, kSwizzleAlignment); + static constexpr uint32_t SMEM_KV_PIPE_SIZE = kNumKVStages * (SMEM_KV_SIZE_PER_STAGE + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE) + + constexpr_align(kNumKVStages * 8 * 2, kSwizzleAlignment); + + // Align to swizzling alignment bytes + extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + + // Q data and barriers on shared memory + auto smem_q = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i); + }); + auto smem_weights = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE * i); + }); + auto q_barrier_ptr = reinterpret_cast(smem_weights[kNumQStages]); + auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + i; }); + auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + (kNumQStages + i); }); + + // Separate math warpgroups and tma load warps into KV groups + // Each math warpgroup corresponds to a tma load warp + const auto& kv_group_idx = __shfl_sync(0xffffffff, threadIdx.x >= kNumMathThreads ? (threadIdx.x - kNumMathThreads) / 32 : warpgroup_idx, 0); + + // Per group KV data and barriers on shared memory + const auto& smem_offset = SMEM_Q_PIPE_SIZE + SMEM_KV_PIPE_SIZE * kv_group_idx; + auto smem_kv = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * i); + }); + auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i); + }); + auto kv_barrier_ptr = reinterpret_cast(smem_kv_scales[kNumKVStages]); + auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + i; }); + auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + kNumKVStages + i; }); + + // Initialize barriers + if (warp_idx >= kNumMathThreads / 32 and cute::elect_one_sync()) { + if (kv_group_idx == 0) { + #pragma unroll + for (uint32_t i = 0; i < kNumQStages; ++ i) { + full_q_barriers[i]->init(1); + empty_q_barriers[i]->init(kNumMathThreads); + } + } + if (kv_group_idx < kNumMathWarpGroups) { + #pragma unroll + for (uint32_t i = 0; i < kNumKVStages; ++ i) { + full_kv_barriers[i]->init(1); + empty_kv_barriers[i]->init(128); + } + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumTMARegisters = 64; + constexpr uint32_t kNumMathRegisters = 104; + + // Scheduler + auto scheduler = PagedMQALogitsScheduler(batch_size, blockIdx.x, context_lens, schedule_meta); + DG_STATIC_ASSERT(SPLIT_KV % BLOCK_KV == 0, "Unaligned SPLIT_KV"); + + // Q and KV pipeline + const auto& get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple { + return {q_iter_idx % kNumQStages, (q_iter_idx / kNumQStages) & 1}; // Q pipeline stage and phase + }; + const auto& get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple { + return {kv_iter_idx % kNumKVStages, (kv_iter_idx / kNumKVStages) & 1}; // KV pipeline stage and phase + }; + uint32_t q_iter_idx = 0, kv_iter_idx = 0; + + if (warp_idx >= kNumMathThreads / 32) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + if (kv_group_idx >= kNumMathWarpGroups) + return; + + const auto& issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) { + if (kv_group_idx == 0 and cute::elect_one_sync()) { + tma_copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads); + tma_copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_idx); + full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); + } + }; + + // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none + uint32_t q_idx = batch_size, kv_idx, num_kv; + uint32_t next_q_idx, next_kv_idx, next_num_kv; + bool fetched_next_task; + + // Prefetch the first Q + if ((fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv))) + issue_tma_q(0, next_q_idx), q_iter_idx = 1; + + int kv_block_idx_ptr = 32; + uint32_t kv_block_idx_storage; + + while (fetched_next_task) { + // Prefetch next Q when current Q changes + bool prefetch_q = (q_idx != next_q_idx and scheduler.exist_q_idx(next_q_idx + 1)); + q_idx = next_q_idx; + kv_idx = next_kv_idx; + num_kv = next_num_kv; + + // Wait Q consumer release and issue TMA Q + if (prefetch_q) { + CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); + empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); + issue_tma_q(q_stage_idx, q_idx + 1); + } + + // Read KV block index + // TODO: deal with `-1`? + if (kv_idx == 0 or kv_block_idx_ptr == 32) { + kv_block_idx_ptr = 0; + kv_block_idx_storage = (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups < num_kv ? + __ldg(block_table + q_idx * block_table_stride + (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups)) : 0); + } + const auto& kv_block_idx = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr ++); + + // Wait KV consumer release + CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); + empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); + + // Issue TMA KV + if (cute::elect_one_sync()) { + tma_copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], + smem_kv[kv_stage_idx], 0, 0, 1, kv_block_idx); + tma_copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], + smem_kv_scales[kv_stage_idx], 0, kv_block_idx); + full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); + } + + // Fetch next task + fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv); + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + float accum[WGMMA::kNumAccum], weights[kNextN][kNumHeads / 4]; + const auto& sub_warp_offset = (warp_idx % 4) * 16; + const auto& v_0_offset = lane_idx / 4 + 0; + const auto& v_1_offset = lane_idx / 4 + 8; + + // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none + uint32_t q_idx = batch_size, kv_idx; + uint32_t next_q_idx, next_kv_idx, next_num_kv; + uint32_t q_stage_idx, q_phase; + + while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) { + // Current Q changes + if (q_idx != next_q_idx) { + // Release Last Q empty + if (q_iter_idx > 0) + empty_q_barriers[(q_iter_idx - 1) % kNumQStages]->arrive(); + + // Wait TMA Q arrival + CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Read weights + #pragma unroll + for (uint32_t i = 0; i < kNextN; ++ i) { + #pragma unroll + for (uint32_t j = 0; j < kNumHeads / 4; ++ j) + weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2); + } + } + + // Get current Q and KV index + q_idx = next_q_idx; + kv_idx = next_kv_idx; + + // Calculate KV offset in advance + auto kv_offset = q_idx * kNextN * logits_stride + ((kv_idx + kv_group_idx) * BLOCK_KV + sub_warp_offset); + + // Compute `[kNextN * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [kNextN, BLOCK_KV]` + // Wait TMA KV arrival + CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Issue WGMMA + DG_STATIC_ASSERT(BLOCK_KV == 64, "Invalid block size"); + DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim"); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_kv[kv_stage_idx] + k * WGMMA::K, to_swizzle_cute_type(), 0, kHeadDim * 8); + auto desc_b = make_smem_desc(smem_q[q_stage_idx] + k * WGMMA::K, to_swizzle_cute_type(), 0, kHeadDim * 8); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + + // Read per-KV scales + float scale_kv_0 = ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_0_offset); + float scale_kv_1 = ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_1_offset); + + // Wait WGMMA + warpgroup_wait<0>(); + + // Release KV empty + empty_kv_barriers[kv_stage_idx]->arrive(); + + // Reduce over the head dim and store + static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2; + DG_STATIC_ASSERT(WGMMA::kNumAccum % kNumAccumPerReduce == 0, "Invalid accumulation"); + DG_STATIC_ASSERT(WGMMA::kNumAccum / kNumAccumPerReduce == kNextN, "Invalid accumulation"); + DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head"); + #pragma unroll + for (uint32_t i = 0; i < kNextN; ++ i) { + auto shifted_accum = accum + i * kNumAccumPerReduce; + const auto& transform = [&](const uint32_t& j) { + return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)]; + }; + + // Intra-thread reduction + float sum[4] = {transform(0), transform(1), transform(2), transform(3)}; + #pragma unroll + for (uint32_t j = 1; j < kNumHeads / 8; ++ j) { + #pragma unroll + for (uint32_t k = 0; k < 4; k ++) + sum[k] += transform(j * 4 + k); + } + float v_0 = (sum[0] + sum[1]) * scale_kv_0; + float v_1 = (sum[2] + sum[3]) * scale_kv_1; + + // Inter-thread reduction + #pragma unroll + for (uint32_t j = 0; j < 2; ++ j) { + const auto& offset = static_cast(1u << j); + v_0 += __shfl_xor_sync(0xffffffffu, v_0, offset); + v_1 += __shfl_xor_sync(0xffffffffu, v_1, offset); + } + + // Store into the global memory + // NOTES: we have redundant writes here, consider more carefully + logits[kv_offset + i * logits_stride + v_0_offset] = v_0; + logits[kv_offset + i * logits_stride + v_1_offset] = v_1; + } + } + } +} + +} // namespace deep_gemm diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh new file mode 100644 index 0000000000000000000000000000000000000000..e3bf98478923a2bf560e69e6ecc802d218fb82c1 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh @@ -0,0 +1,287 @@ +#pragma once +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include +#include + +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; + +template +__device__ __forceinline__ +uint32_t get_swizzled_bank_group_idx(const uint32_t& offset, const uint32_t& lane_idx) { + constexpr uint32_t kGroupsInSwizzleRange = kSwizzleMode / kSwizzleBase; + + const auto& bank_group_idx = offset + lane_idx * kGroupsInSwizzleRange; + + constexpr uint32_t kNumBankGroups = 128 / kSwizzleBase; + constexpr bool kHasShortcut = kGroupsInSwizzleRange == kNumBankGroups; + auto row = kHasShortcut ? (offset / kNumBankGroups + lane_idx) : (bank_group_idx / kNumBankGroups); + auto col = kHasShortcut ? (offset) : (bank_group_idx % kNumBankGroups); + col ^= row % kGroupsInSwizzleRange; + + return (row * kNumBankGroups + col) % kGroupsInSwizzleRange; +} + +template +__global__ void __launch_bounds__(kNumMathThreads + kNumTMAThreads, 1) +sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_d, + float* sqr_sum) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // kSwizzleAMode and kSwizzleBMode must be 128 for now + constexpr uint32_t kSwizzleAMode = cute::min(BLOCK_K * sizeof(nv_bfloat16), 128); + constexpr uint32_t kSwizzleBMode = cute::min(BLOCK_K * sizeof(float), 128); + DG_STATIC_ASSERT(BLOCK_K == 64, "Invalid block K"); + DG_STATIC_ASSERT(kSwizzleAMode == 128, "Invalid swizzle A mode"); + DG_STATIC_ASSERT(kSwizzleBMode == 128, "Invalid swizzle B mode"); + + DG_STATIC_ASSERT(kSwizzleCDMode / sizeof(float) == BLOCK_N, "Invalid block N"); + DG_STATIC_ASSERT(kNumMathThreads == 128, "Invalid MMA threads"); + + // Utils + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = get_lane_idx(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // Share memory sizes + constexpr uint32_t SMEM_CD_SIZE = BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(nv_bfloat16); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(float); + DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_d); + } + + // Data on shared memory (layout as ordered below) + // Fill D/A/B pointers + auto smem_cd = reinterpret_cast(smem_buffer); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE)); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(128); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + __syncthreads(); + + constexpr uint32_t kNumKBlocks = constexpr_ceil_div(SHAPE_K, BLOCK_K); + constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits; + constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits; + const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0); + const uint32_t m_block_idx = block_idx / kNumSplits; + const uint32_t k_split_idx = block_idx % kNumSplits; + const uint32_t k_offset = (k_split_idx * kNumKBlocksPerSplit + cute::min(k_split_idx, kRemainKBlocks)) * BLOCK_K; + const uint32_t m_offset = shape_m * k_split_idx; + const uint32_t num_total_stages = kNumKBlocksPerSplit + (k_split_idx < kRemainKBlocks); + constexpr uint32_t kNumTMARegisters = 40; + constexpr uint32_t kNumMathRegisters = 256; + + // TMA load warp + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cutlass::arch::warpgroup_reg_dealloc(); + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait consumer release + const auto& stage_idx = s % kNumStages; + empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1); + + // Compute offsets + uint32_t m_idx = m_block_idx * BLOCK_M; + uint32_t k_idx = k_offset + s * BLOCK_K; + + // Issue TMAs + tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx); + tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0); + + // Arrive at full barriers + constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; + full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes); + } + + for (uint32_t s = num_total_stages; s < num_total_stages + kNumStages; ++ s) { + const auto& stage_idx = s % kNumStages; + empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1); + } + } else if (warp_idx < kNumMathThreads / 32) { + cutlass::arch::warpgroup_reg_alloc(); + + DG_STATIC_ASSERT(BLOCK_M == 64, "Invalid block M"); + DG_STATIC_ASSERT(BLOCK_K * sizeof(nv_bfloat16) == kSwizzleAMode, "Invalid block K"); + constexpr uint32_t BLOCK_M_PER_WARP = BLOCK_M / 4; + constexpr uint32_t WGMMA_M = 64; + constexpr uint32_t WGMMA_N = BLOCK_N; + constexpr uint32_t WGMMA_K = 8; + + using WGMMA = typename TF32MMASelector::type; + float accum[WGMMA::kNumAccum] = {0}; + + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(nv_bfloat16); + constexpr uint32_t kNumLoads = BLOCK_K / kNumElemsPerBankGroup; + float sqr_sum_acc_0 = 0; + float sqr_sum_acc_1 = 0; + + #pragma unroll kNumStages < 8 ? kNumStages : kNumStages / 2 + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait TMA arrival + const auto& stage_idx = s % kNumStages; + full_barriers[stage_idx]->wait((s / kNumStages) & 1); + + constexpr uint32_t kNumRegPerWgmma = WGMMA::M * WGMMA::K / 128; + constexpr uint32_t kNumWgmmaPerBlockK = BLOCK_K / WGMMA::K; + + float a[kNumRegPerWgmma * kNumWgmmaPerBlockK]; + // Assume swizzle A mode is 128 + DG_STATIC_ASSERT(kSwizzleAMode == 128, "Invalid swizzle A mode"); + + // Load BF16 A fragment from shared memory into registers, and transpose to FP32 + uint32_t row = warp_idx * 16 + lane_idx / 4; + #pragma unroll + for (uint32_t i = 0; i < kNumLoads; ++ i) { + // Refer to the A layout in https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n8-a + uint32_t bank_group_idx = (row ^ i) % 8; + nv_bfloat16* a_bf16_smem_ptr_upper = smem_a[stage_idx] + row * BLOCK_K + bank_group_idx * kNumElemsPerBankGroup; + nv_bfloat16* a_bf16_smem_ptr_lower = smem_a[stage_idx] + (row + 8) * BLOCK_K + bank_group_idx * kNumElemsPerBankGroup; + + uint32_t elem_offset = lane_idx % 4; + nv_bfloat16 a_bf16[kNumRegPerWgmma]; + a_bf16[0] = a_bf16_smem_ptr_upper[elem_offset]; + a_bf16[2] = a_bf16_smem_ptr_upper[elem_offset + 4]; + a_bf16[1] = a_bf16_smem_ptr_lower[elem_offset]; + a_bf16[3] = a_bf16_smem_ptr_lower[elem_offset + 4]; + + auto a_bf16x2_ptr = reinterpret_cast(a_bf16); + auto a_float2_ptr = reinterpret_cast(a); + float2 a_float2_0 = __bfloat1622float2(a_bf16x2_ptr[0]); + float2 a_float2_1 = __bfloat1622float2(a_bf16x2_ptr[1]); + a_float2_ptr[i * 2 + 0] = a_float2_0; + a_float2_ptr[i * 2 + 1] = a_float2_1; + sqr_sum_acc_0 += a_float2_0.x * a_float2_0.x + a_float2_1.x * a_float2_1.x; + sqr_sum_acc_1 += a_float2_0.y * a_float2_0.y + a_float2_1.y * a_float2_1.y; + } + + warpgroup_wait<0>(); + if (s > 0) + empty_barriers[(s - 1) % kNumStages]->arrive(); + + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + + constexpr int kNumElemsInSwizzleRange = 128 / sizeof(float); + constexpr uint32_t kNumWgmmaInSwizzleRange = kNumElemsInSwizzleRange / WGMMA::K; + DG_STATIC_ASSERT(BLOCK_K % kNumElemsInSwizzleRange == 0, "Invalid block K"); + + #pragma unroll + for (int i = 0; i < BLOCK_K / kNumElemsInSwizzleRange; i++) { + #pragma unroll + for (int k = 0; k < kNumElemsInSwizzleRange / WGMMA::K; k++) { + auto b_desc = make_smem_desc(smem_b[stage_idx] + i * BLOCK_N * kNumElemsInSwizzleRange + k * WGMMA::K, 1); + WGMMA::wgmma(a + (i * kNumWgmmaInSwizzleRange + k) * kNumRegPerWgmma, b_desc, accum, 1); + } + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + } + + const auto& reduced_sum_0 = warp_reduce_sum<4>(sqr_sum_acc_0); + const auto& reduced_sum_1 = warp_reduce_sum<4>(sqr_sum_acc_1); + + const auto& m_idx = m_block_idx * BLOCK_M + (warp_idx * BLOCK_M_PER_WARP + lane_idx / 4); + if (lane_idx % 4 == 0) { + if (m_idx < shape_m) + sqr_sum[m_offset + m_idx] = reduced_sum_0; + if (m_idx + 8 < shape_m) + sqr_sum[m_offset + m_idx + 8] = reduced_sum_1; + } + warpgroup_wait<0>(); + empty_barriers[(num_total_stages-1) % kNumStages]->arrive(); + + // Write accum to shared memory + // Every 2 threads (one pair) will write to the same bank group (16 bytes). + // Refer to the D layout in https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n8-d + uint32_t is_odd_pair = lane_idx / 2 % 2; + + // Four threads per group; write the data to the same row. + uint32_t row_idx = lane_idx / 4; + + // Even/odd index pairs write to the same column, we need to reorder idx: + // group even pair indices consecutively, and likewise for odd ones. + uint32_t reordered_pair_idx = is_odd_pair * 8 + row_idx; + + auto shifted_smem_ptr = reinterpret_cast(smem_cd) + + (warp_idx * BLOCK_M_PER_WARP + row_idx) * kSwizzleCDMode + // Row offset, each warp has 16 rows + lane_idx % 2 * 8; // One thread of a pair writes 8 bytes + + #pragma unroll + for (uint32_t i = 0; i < (kSwizzleCDMode / sizeof(float)) / 4; i += 2) { + // Get the swizzled bank group index (16 bytes per group) + uint32_t bank_group_idx = get_swizzled_bank_group_idx(i + is_odd_pair, reordered_pair_idx); + auto smem_ptr = shifted_smem_ptr + bank_group_idx * kNumBankGroupBytes; // Col offset, 16 bytes per group + + // 0/1 write to the same row, 2/3 write to another row + auto values = reinterpret_cast(accum + i * 2); + st_shared(smem_ptr, values[0], values[1]); + st_shared(smem_ptr + 8 * kSwizzleCDMode, values[2], values[3]); + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(128, 1); + + // Issue TMA stores + if (warp_idx == 0 and cute::elect_one_sync()) { + if constexpr (kNumSplits == 1) { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M); + } else { + cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M, k_split_idx); + } + cute::tma_store_arrive(); + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + +} // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/smxx_clean_logits.cuh b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/smxx_clean_logits.cuh new file mode 100644 index 0000000000000000000000000000000000000000..cc9e5e6b0c7ce95acf0b7149221dc4d4f0f83a21 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/smxx_clean_logits.cuh @@ -0,0 +1,67 @@ +#pragma once + +#include +#include + +#include + +namespace deep_gemm { + +template +__global__ __launch_bounds__(kNumWarps * 32, 1) +void smxx_clean_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const uint64_t stride_logits, + const uint32_t* cu_seq_len_k_start, const uint32_t* cu_seq_len_k_end, float* logits) { + const uint32_t& num_sms = gridDim.x; + const uint32_t& sm_idx = blockIdx.x; + const uint32_t& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + constexpr float neg_inf = -cute::numeric_limits::infinity(); + + // Allocate filled `-inf` shared memory + extern __shared__ __align__(1024) float smem_buffer[]; + #pragma unroll + for (uint32_t i = threadIdx.x; i < BLOCK_KV; i += kNumWarps * 32) + smem_buffer[i] = neg_inf; + cute::tma_store_fence(); + __syncthreads(); + + // Assign sequence to each warp + const auto& assign_task = [&](const uint32_t& num, const uint32_t& idx, + const uint32_t& start, const uint32_t& total) -> cute::tuple { + const auto& per = total / num, rem = total % num; + return {start + idx * per + min(idx, rem), per + (idx < rem)}; + }; + CUTE_TIE_DECL(assign_task(num_sms, sm_idx, 0, seq_len), sm_seq_start, sm_seq_len); + CUTE_TIE_DECL(assign_task(kNumWarps, warp_idx, sm_seq_start, sm_seq_len), warp_seq_start, warp_seq_len); + + if (cute::elect_one_sync()) { + for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) { + const auto& ks = cu_seq_len_k_start == nullptr ? 0 : __ldg(cu_seq_len_k_start + i / kNextN); + const auto& ke = __ldg(cu_seq_len_k_end + i / kNextN) - kNextN + i % kNextN + 1; + const auto& aligned_ks = ks / 4 * 4, aligned_ke = (ke + 3) / 4 * 4; + + for (uint32_t left = 0; left < seq_len_kv; left += BLOCK_KV) { + const auto& right = min(left + BLOCK_KV, static_cast(stride_logits)); + if (right <= ks or ke <= left) { + cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (right - left) * sizeof(float)); + } else { + if (left < aligned_ks) + cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (aligned_ks - left) * sizeof(float)); + if (aligned_ke < right) + cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + aligned_ke, (right - aligned_ke) * sizeof(float)); + } + } + } + } + + for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) { + const auto& ks = cu_seq_len_k_start == nullptr ? 0 : __ldg(cu_seq_len_k_start + i / kNextN); + const auto& ke = __ldg(cu_seq_len_k_end + i / kNextN) - kNextN + i % kNextN + 1; + const auto& aligned_ks = ks / 4 * 4, aligned_ke = (ke + 3) / 4 * 4; + for (uint32_t j = aligned_ks; j < ks; ++ j) + logits[i * stride_logits + j] = neg_inf; + for (uint32_t j = ke; j < aligned_ke; ++ j) + logits[i * stride_logits + j] = neg_inf; + } +} + +} diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/smxx_layout.cuh b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/smxx_layout.cuh new file mode 100644 index 0000000000000000000000000000000000000000..bea7000276c3e382c1acfeff545d6181351849b6 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/deep_gemm/impls/smxx_layout.cuh @@ -0,0 +1,176 @@ +#pragma once + +#include + +namespace deep_gemm { + +template +__global__ void transpose_fp32(const float* sf, float* out, const uint32_t mn) { + typedef typename Vectorized::vec_t in_vec_t; + constexpr static uint32_t kNumElemsPerVec = sizeof(in_vec_t) / sizeof(float); + constexpr static uint32_t SF_VEC_K = SF_K / kNumElemsPerVec; + + // Shapes and strides + extern __shared__ float smem_buffer[]; + constexpr auto kNumTMAAlignedElems = static_cast(16 / sizeof(float)); + const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN); + const auto tma_aligned_mn = align(mn, kNumTMAAlignedElems); + + // Shift into the block + sf = sf + static_cast(blockIdx.y) * mn * SF_K; + out = out + static_cast(blockIdx.y) * tma_aligned_mn * SF_K; + const auto& local_sf = reinterpret_cast(sf + static_cast(blockIdx.x) * (BLOCK_MN * SF_K)); + + // Load + for (uint32_t i = threadIdx.x; i < in_block_mn * SF_VEC_K; i += kNumThreads) { + auto in_vec = __ldg(local_sf + i); + const auto& in_values = reinterpret_cast(&in_vec); + + const auto& row = i / SF_VEC_K, col = (i % SF_VEC_K) * kNumElemsPerVec; + #pragma unroll + for (uint32_t j = 0; j < kNumElemsPerVec; ++ j) + smem_buffer[row * PADDED_SF_K + col + j] = in_values[j]; + } + __syncthreads(); + + // Store + #pragma unroll + for (uint32_t i = threadIdx.x; i < in_block_mn * SF_K; i += kNumThreads) { + const auto& sf_k_idx = i / in_block_mn, mn_idx = i % in_block_mn; + const auto& global_mn_idx = blockIdx.x * BLOCK_MN + mn_idx; + out[sf_k_idx * tma_aligned_mn + global_mn_idx] = ld_shared(smem_buffer + mn_idx * PADDED_SF_K + sf_k_idx); + } +} + +// NOTES: the two kernels below always pack the K dimension + +template +__global__ void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, const uint32_t mn) { + extern __shared__ uint32_t smem_buffer[]; + + // Shapes and strides + constexpr auto kNumPackedSFK = constexpr_ceil_div(SF_K, 4u); + constexpr auto kNumTMAAlignedElems = static_cast(16 / sizeof(int)); + const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN); + const auto tma_aligned_mn = align(mn, kNumTMAAlignedElems); + + // Shift into the group + sf = sf + static_cast(blockIdx.y) * mn * SF_K; + out = out + static_cast(blockIdx.y) * tma_aligned_mn * kNumPackedSFK; + + // Load FP32 SFs + DG_STATIC_ASSERT(BLOCK_MN % 4 == 0, "Invalid block size"); + const auto local_sf = reinterpret_cast(sf + static_cast(blockIdx.x) * (BLOCK_MN * SF_K)); + const auto num_values = in_block_mn * SF_K; + const auto num_uint4 = num_values / 4; + #pragma unroll + for (uint32_t i = threadIdx.x; i < num_uint4; i += kNumThreads) { + const auto& [x, y, z, w] = __ldg(reinterpret_cast(local_sf) + i); + st_shared(reinterpret_cast(smem_buffer) + i, x, y, z, w); + } + + // Fill unaligned values as well + if (const auto unaligned_idx = num_uint4 * 4 + threadIdx.x; unaligned_idx < num_values) + st_shared(smem_buffer + unaligned_idx, __ldg(local_sf + unaligned_idx)); + __syncthreads(); + + // Pack into UE8M0 and store + #pragma unroll + for (uint32_t i = threadIdx.x; i < (kNumPackedSFK * BLOCK_MN); i += kNumThreads) { + const auto sf_k_pack_idx = i / BLOCK_MN, mn_idx = i % BLOCK_MN; + + // Load shared memory + uint32_t values[4]; + #pragma unroll + for (uint32_t j = 0; j < 4; ++ j) { + const auto sf_k_idx = sf_k_pack_idx * 4 + j; + values[j] = sf_k_idx < SF_K ? ld_shared(smem_buffer + mn_idx * SF_K + sf_k_idx) : 0; + } + + // Pack and store + uint32_t packed = 0; + packed |= (values[0] >> 23u); + packed |= (values[1] >> 15u); + packed |= (values[2] >> 7u); + packed |= (values[3] << 1u); + if (const auto global_mn_idx = blockIdx.x * BLOCK_MN + mn_idx; global_mn_idx < mn) + out[sf_k_pack_idx * tma_aligned_mn + global_mn_idx] = packed; + } +} + +template +__global__ void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks, + const uint32_t mn, uint32_t sf_k, const uint32_t packed_sf_k) { + // Always packing the K dimension + // NOTES: should also assert `mn % 4 == 0` at launch + DG_STATIC_ASSERT(kTransposed, "Currently only support transposed SFs (MN-major)"); + DG_STATIC_ASSERT(BLOCK_MN % 4 == 0, "Invalid block sizes"); + DG_STATIC_ASSERT(BLOCK_PACKED_SF_K == kNumThreads / 32, "Invalid block sizes"); + + // Shapes and strides + const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN); + const auto in_block_mn_uint4 = in_block_mn / 4; + const auto in_block_packed_sf_k = min(BLOCK_PACKED_SF_K, packed_sf_k - blockIdx.y * BLOCK_PACKED_SF_K); + + // Shift into the right block along MN + sf += blockIdx.x * BLOCK_MN; + out += blockIdx.x * BLOCK_MN; + + // Each warp is responsible for a packed row + const auto warp_idx = threadIdx.x / 32; + const auto lane_idx = get_lane_idx(); + const auto packed_sf_k_idx = static_cast(blockIdx.y) * BLOCK_PACKED_SF_K + warp_idx; + if (warp_idx >= in_block_packed_sf_k) + return; + + // Make an offset on the input + uint32_t input_offset = 0; + if constexpr (kNumGroups > 1) { + // Load each group's size + DG_STATIC_ASSERT(kNumGroups <= 128, "Too many groups"); + uint32_t group_ks[4]; + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) { + const auto group_idx = lane_idx * 4 + i; + group_ks[i] = group_idx < kNumGroups ? __ldg(ks + group_idx) : 0; + } + __syncwarp(); + + // Make the offset + sf_k = 0; + auto sum_packed_sf_k = 0; + #pragma unroll + for (uint32_t i = 0; i < kNumGroups; ++ i) { + const auto sf_k_in_group = __shfl_sync(0xffffffff, group_ks[i % 4] / 128, i / 4); + sf_k += sf_k_in_group; + sum_packed_sf_k += ceil_div(sf_k_in_group, 4u); + if (packed_sf_k_idx < sum_packed_sf_k) + break; + if (const auto remainder = sf_k_in_group % 4; remainder > 0) + input_offset += 4 - remainder; + } + } + + for (uint32_t mn_idx = get_lane_idx(); mn_idx < in_block_mn_uint4; mn_idx += 32) { + // Load + uint4 values[4]; + #pragma unroll + for (uint32_t j = 0; j < 4; ++ j) { + values[j] = make_uint4(0, 0, 0, 0); + if (const auto sf_k_idx = packed_sf_k_idx * 4 + j - input_offset; sf_k_idx < sf_k) + values[j] = __ldg(reinterpret_cast(sf + sf_k_idx * mn) + mn_idx); + } + + // Pack and store + uint4 packed; + packed.x = (values[0].x >> 23u) | (values[1].x >> 15u) | (values[2].x >> 7u) | (values[3].x << 1u); + packed.y = (values[0].y >> 23u) | (values[1].y >> 15u) | (values[2].y >> 7u) | (values[3].y << 1u); + packed.z = (values[0].z >> 23u) | (values[1].z >> 15u) | (values[2].z >> 7u) | (values[3].z << 1u); + packed.w = (values[0].w >> 23u) | (values[1].w >> 15u) | (values[2].w >> 7u) | (values[3].w << 1u); + reinterpret_cast(out + packed_sf_k_idx * mn)[mn_idx] = packed; + } +} + +} // namespace deep_gemm diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/03_visualize_layout/options.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/03_visualize_layout/options.h new file mode 100644 index 0000000000000000000000000000000000000000..9cfb284e6716575c81a5f73e7746ce13664d2250 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/03_visualize_layout/options.h @@ -0,0 +1,121 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include + +// Cutlass command line parser +#include "cutlass/util/command_line.h" + +class Options { +public: + + bool help; + bool good; + std::vector extent; ///< extent of tile to fill + std::vector stride; ///< stride vector for layout function + std::vector output_shape; ///< output shape + int vectorize; ///< sequences of consecutive output elements are concatenated into a vector + /// if, and only if, they were consecutive in source memory + +public: + + /// Options + Options(): + help(false), + good(true), + extent({32, 8}), + stride({32}), + output_shape({16, 8}), + vectorize(1) { + + } + + /// Constructs from command line parser + Options(cutlass::CommandLine const & cmd_line): help(false), good(true) { + + if (cmd_line.check_cmd_line_flag("help") || + cmd_line.check_cmd_line_flag("h")) { + + help = true; + } + + if (cmd_line.check_cmd_line_flag("extent")) { + cmd_line.get_cmd_line_arguments("extent", extent); + } + else { + extent = {32, 8}; + } + + if (cmd_line.check_cmd_line_flag("stride")) { + cmd_line.get_cmd_line_arguments("stride", stride); + } + + int default_output_shape[] = {16, 8}; + + if (cmd_line.check_cmd_line_flag("output-shape")) { + cmd_line.get_cmd_line_arguments("output-shape", output_shape); + } + + for (int i = int(output_shape.size()); i < 2; ++i) { + output_shape.push_back(default_output_shape[i]); + } + + if (cmd_line.check_cmd_line_flag("vectorize")) { + cmd_line.get_cmd_line_argument("vectorize", vectorize); + } + else { + vectorize = 1; + } + + if (output_shape.front() % vectorize) { + + std::cerr << "Error: --vectorize=" << vectorize + << " must divide contiguous elements in --output-shape=" + << output_shape.at(0) << "," << output_shape.at(1) << std::endl; + + good = false; + } + } + + /// Prints usage statement + static void print_usage(std::ostream &out) { + out + << " Options:\n" + << " --help Displays this help message.\n" + << " --extent= Specifies the layout-specific extent (as comma-delimited array).\n" + << " --stride= Specifies the layout-specific stride vector (comma-delimited array)\n" + << " --output-shape= Specifies the dimensions of a row-major output matrix. \n" + << " --vectorize= If possible, vectorizes the output into vectors of consecutive elements\n"; + } +}; diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/03_visualize_layout/register_layout.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/03_visualize_layout/register_layout.h new file mode 100644 index 0000000000000000000000000000000000000000..c840c90e0ddf6f149a8ec82aba8a2c0fc6691d2c --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/03_visualize_layout/register_layout.h @@ -0,0 +1,59 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief CUTLASS layout visualization example +*/ + +#pragma once + +#include +#include + +#include "options.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct VisualizeLayoutBase { + virtual bool visualize(Options const &) = 0; + virtual bool verify(bool verbose, std::ostream &out) = 0; + virtual void print_csv(std::ostream &out, char delim = '|', char new_line = '\n') = 0; + virtual std::ostream &print_help(std::ostream &out) { + return out; + } + virtual ~VisualizeLayoutBase() { } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +void RegisterLayouts(std::map > &layouts); + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/03_visualize_layout/visualize_layout.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/03_visualize_layout/visualize_layout.h new file mode 100644 index 0000000000000000000000000000000000000000..13318a05838039745e262a5a59a105701c4907cb --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/03_visualize_layout/visualize_layout.h @@ -0,0 +1,383 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief CUTLASS layout visualization example +*/ + +#pragma once + +#include +#include +#include + +#include "cutlass/coord.h" +#include "cutlass/util/reference/host/tensor_foreach.h" + +#include "register_layout.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Permits copying dynamic vectors into static-length vectors +template +struct vector_to_coord { + + vector_to_coord(TensorCoord &coord, std::vector const &vec) { + + coord[Rank - 1] = vec.at(Rank - 1); + + if (Rank > 1) { + vector_to_coord(coord, vec); + } + } +}; + +/// Permits copying dynamic vectors into static-length vectors +template +struct vector_to_coord { + + vector_to_coord(TensorCoord &coord, std::vector const &vec) { + + coord[0] = vec.at(0); + } +}; + +/// Permits copying dynamic vectors into static-length vectors +template +struct vector_to_coord { + + vector_to_coord(TensorCoord &coord, std::vector const &vec) { + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +std::ostream &operator<<(std::ostream &out, std::vector const &vec) { + auto it = vec.begin(); + if (it != vec.end()) { + out << *it; + for (++it; it != vec.end(); ++it) { + out << ", " << *it; + } + } + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Permits copying static-length vectors into dynamic vectors +template +struct coord_to_vector { + + coord_to_vector(std::vector &vec, TensorCoord const &coord) { + + vec.at(Rank - 1) = coord[Rank - 1]; + coord_to_vector(vec, coord); + } +}; + +/// Permits copying static-length vectors into dynamic vectors +template +struct coord_to_vector { + + coord_to_vector(std::vector &vec, TensorCoord const &coord) { + + vec.at(0) = coord[0]; + } +}; + +/// Permits copying static-length vectors into dynamic vectors +template +struct coord_to_vector { + + coord_to_vector(std::vector &vec, TensorCoord const &coord) { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure representing an element in source memory +struct Element { + + std::vector coord; ///< logical coordinate of element (as vector) + int offset; ///< linear offset from source memory + int color; ///< enables coloring each element to indicate + + /// Default ctor + inline Element(): offset(-1), color(0) { } + + /// Construct from logical coordinate and initial offset + inline Element( + std::vector const &coord_, + int offset_, + int color_ = 0 + ): + coord(coord_), offset(offset_), color(color_) { } + + /// Returns true if element is in a defined state + inline bool valid() const { + return offset >= 0; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Visualizes memory layouts by constructing a 'shape' +template +class VisualizeLayout : public VisualizeLayoutBase { +public: + + using Layout = Layout_; + using TensorCoord = typename Layout::TensorCoord; + using Stride = typename Layout::Stride; + +public: + + Options options; + Layout layout; + TensorCoord extent; + std::vector elements; + +public: + + /// Initializes the problem space + VisualizeLayout() { + + } + + /// visualization method + bool visualize(Options const &options_) { + + options = options_; + + if (options.extent.size() != TensorCoord::kRank) { + + std::cerr + << "--extent must have rank " << TensorCoord::kRank + << " (given: " << options.extent.size() << ")" << std::endl; + + return false; + } + + vector_to_coord(extent, options.extent); + + // Construct the layout for a packed tensor + if (options.stride.empty()) { + + layout = Layout::packed(extent); + } + else if (options.stride.size() != Stride::kRank) { + + std::cerr + << "--stride must have rank " << Stride::kRank + << " (given: " << options.stride.size() << ")" << std::endl; + + return false; + } + else { + // Stride from + Stride stride; + vector_to_coord(stride, options.stride); + + layout = Layout(stride); + } + + // Resize elements, setting elements to 'undefined' state + elements.resize(layout.capacity(extent)); + + // enumerate points in tensor space and assign + cutlass::reference::host::TensorForEachLambda( + extent, + [&](TensorCoord coord) { + + std::vector coord_vec(TensorCoord::kRank, 0); + coord_to_vector(coord_vec, coord); + + int offset = int(layout(coord)); + + if (offset >= int(elements.size())) { + std::cerr + << "Layout error - " << coord_vec + << " is out of range (computed offset: " << offset + << ", capacity: " << elements.size() << std::endl; + + throw std::out_of_range("(TensorForEach) layout error - coordinate out of range"); + } + + elements.at(offset) = Element(coord_vec, offset); + }); + + return true; + } + + /// Verifies the layout satisfies vectorization requirements + bool verify(bool verbose, std::ostream &out) { + return true; + } + +private: + + /// returns a pair (is_vectorizable, one_changing_rank) to determine if a + /// vector exists (consecutive logical coordinates or uniformly invalid) + /// at the given location. + std::pair< bool, int > _is_vectorizable(int i) const { + // (all elements are invalid) or + // (all elements are valid AND + // exactly one rank is changing AND + // elements are consecutive) + + // Don't need vectorization. + if (options.vectorize <= 2) return std::make_pair(false, -1); + + // Boundary check. + if (i > int(elements.size()) || (i + options.vectorize - 1) > int(elements.size())) + return std::make_pair(false, -1); + + // Check if either all elements are valid or invalid. + bool all_elements_invalid = std::all_of( + elements.begin() + i, elements.begin() + i + options.vectorize, + [](Element const &e) { return !e.valid(); }); + + bool all_elements_valid = std::all_of( + elements.begin() + i, elements.begin() + i + options.vectorize, + [](Element const &e) { return e.valid(); }); + + if (!all_elements_invalid && !all_elements_valid) + return std::make_pair(false, -1); + + // From here, it is vectorizable. + if (all_elements_invalid) return std::make_pair(true, -1); + + // Check if only exactly one rank is changing. + int one_changing_rank = -1; + for (int j = 0; j < options.vectorize; ++j) { + for (int r = 0; r < TensorCoord::kRank; ++r) { + if (elements.at(i + j).coord.at(r) != elements.at(i).coord.at(r)) { + if (one_changing_rank == -1) { + one_changing_rank = r; + } else if (one_changing_rank != r) { + return std::make_pair(false, -1); + } + } + } + } + + return std::make_pair(true, one_changing_rank); + } + + /// Prints a vector of elements + void _print_vector(std::ostream &out, int i, int one_changing_rank) { + Element const &base_element = elements.at(i); + if (base_element.valid()) { + out << "("; + for (int r = 0; r < TensorCoord::kRank; ++r) { + if (r) { + out << ", "; + } + + if (r == one_changing_rank) { + out + << base_element.coord.at(r) + << ".." + << (base_element.coord.at(r) + options.vectorize - 1); + } + else { + out << base_element.coord.at(r); + } + } + out << ")"; + } + else { + out << " "; + } + } + + /// Prints a single element + void _print_element(std::ostream &out, int k) { + Element const &element = elements.at(k); + if (element.valid()) { + out << "("; + for (int v = 0; v < TensorCoord::kRank; ++v) { + out << (v ? ", " : "") << element.coord.at(v); + } + out << ")"; + } + else { + out << " "; + } + } + +public: + + /// Pretty-prints the layout to the console + void print_csv(std::ostream &out, char delim = '|', char new_line = '\n') { + int row = -1; + + for (int i = 0; i < int(elements.size()); i += options.vectorize) { + if (i % options.output_shape.at(0)) { + out << delim; + } + else { + if (row >= 0) { + out << new_line; + } + ++row; + if (row == options.output_shape.at(1)) { + out << new_line; + row = 0; + } + } + + auto is_vector = _is_vectorizable(i); + + if (is_vector.first) { + _print_vector(out, i, is_vector.second); // print a vector starting at element i + } + else { + for (int j = 0; j < options.vectorize; ++j) { // print individual elements [i..i+j) + _print_element(out, i + j); + } + } + } + + out << new_line << std::flush; + } + + /// Help message + virtual std::ostream &print_help(std::ostream &out) { + out << "TensorCoord rank " << TensorCoord::kRank << ", Stride rank: " << Stride::kRank; + return out; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_conv2d_run.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_conv2d_run.h new file mode 100644 index 0000000000000000000000000000000000000000..df4cb76ad13df4d4a55b5fce810be5361fb081cd --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_conv2d_run.h @@ -0,0 +1,719 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/reduction/device/reduce_split_k.h" +#include "cutlass/reduction/thread/reduction_operators.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" + +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/device/convolution.h" +#include "cutlass/util/reference/device/tensor_relu.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/tensor_view_io.h" + +#include "reference/device/tensor_scale_bias.h" +#include "helper.h" + +#define CHECK_GT(val1, val2) \ + if((val1) <= (val2)) \ + std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n"; +#define CHECK_TRUE(val) \ + if(!(val)) \ + std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n"; + + +template +class B2bNonFusedConv2dRun { +public: + + using Conv2d0 = Conv2d0_; + using Conv2d1 = Conv2d1_; + using ElementAccumulator = typename Conv2d0::ElementAccumulator; + using ElementCompute = typename Conv2d0::ElementCompute; + + static cutlass::conv::Operator const kConvolutionalOperator = Conv2d0::kConvolutionalOperator; + static_assert(kConvolutionalOperator == Conv2d1::kConvolutionalOperator, + "Fused convolution operators must be the same"); + +public: + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_Bias; + uint64_t seed; + + cutlass::HostTensor tensor_A0; + cutlass::HostTensor tensor_B0; + cutlass::HostTensor tensor_C0; + cutlass::HostTensor tensor_Bias0; + cutlass::HostTensor tensor_D0_computed; + cutlass::HostTensor tensor_D0_reference; + + cutlass::HostTensor tensor_B1; + cutlass::HostTensor tensor_C1; + cutlass::HostTensor tensor_Bias1; + cutlass::HostTensor tensor_D1_computed; + cutlass::HostTensor tensor_D1_reference; + + +public: + + B2bNonFusedConv2dRun( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { + + } + + /// Helper to initialize a tensor view + template + void initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 16) { + scope = 2; + } + else { + scope = 8; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope, -scope, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { + std::cerr << "Not implemented\n"; + } + } + + void initialize( + cutlass::conv::Conv2dProblemSize const &problem_size_0, + cutlass::conv::Conv2dProblemSize const &problem_size_1, + uint64_t seed = 2019) { + + tensor_A0.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size_0)); + tensor_B0.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0)); + tensor_C0.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); + tensor_Bias0.resize({1, 1, 1, problem_size_0.K}); + tensor_D0_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); + tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); + tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1)); + tensor_C1.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); + tensor_Bias1.resize({1, 1, 1, problem_size_1.K}); + tensor_D1_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); + tensor_D1_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); + + initialize_tensor(tensor_A0.host_view(), init_A, seed); + initialize_tensor(tensor_B0.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C0.host_view(), init_C, seed * 39); + initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed * 83); + initialize_tensor(tensor_B1.host_view(), init_B, seed * 18); + initialize_tensor(tensor_C1.host_view(), init_C, seed * 40); + initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed * 84); + + tensor_A0.sync_device(); + tensor_B0.sync_device(); + tensor_C0.sync_device(); + tensor_Bias0.sync_device(); + tensor_D0_computed.sync_device(); + tensor_D0_reference.sync_device(); + tensor_B1.sync_device(); + tensor_C1.sync_device(); + tensor_Bias1.sync_device(); + tensor_D1_computed.sync_device(); + tensor_D1_reference.sync_device(); + } + + /// Executes one test + bool run( + cutlass::conv::Conv2dProblemSize const &problem_size_0, + cutlass::conv::Conv2dProblemSize const &problem_size_1, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + ElementCompute alpha0 = ElementCompute(1), + ElementCompute beta0 = ElementCompute(0), + ElementCompute alpha1 = ElementCompute(1), + ElementCompute beta1 = ElementCompute(0), + bool relu = true, + int warm_ups = 1, + int runs = 100) { + + initialize(problem_size_0, problem_size_1); + + // configure the operator + Conv2d0 conv2d_op_0; + Conv2d1 conv2d_op_1; + + typename Conv2d0::Arguments conv2d_args_0( + problem_size_0, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + {tensor_Bias0.device_data(), typename Conv2d0::LayoutC::Stride(0)}, + tensor_D0_computed.device_ref(), + {alpha0, beta0}, + split_k_mode + ); + typename Conv2d1::Arguments conv2d_args_1( + problem_size_1, + tensor_D0_computed.device_ref(), + tensor_B1.device_ref(), + {tensor_Bias1.device_data(), typename Conv2d1::LayoutC::Stride(0)}, + tensor_D1_computed.device_ref(), + {alpha1, beta1}, + split_k_mode + ); + + + cutlass::Status status = conv2d_op_0.initialize(conv2d_args_0); + + CUTLASS_CHECK(status); + + status = conv2d_op_1.initialize(conv2d_args_1); + + CUTLASS_CHECK(status); + + for(int i = 0; i < warm_ups; i++) { + status = conv2d_op_0(); + CUTLASS_CHECK(status); + status = conv2d_op_1(); + CUTLASS_CHECK(status); + } + + // + // Run Conv2d + // + cudaEvent_t start, stop1, stop2; + cudaEventCreate(&start); + cudaEventCreate(&stop1); + cudaEventCreate(&stop2); + + cudaEventRecord(start); + + + for(int i = 0; i < runs; i++) { + // run conv2d operator + status = conv2d_op_0(); + CUTLASS_CHECK(status); + } + cudaEventRecord(stop1); + + for(int i = 0; i < runs; i++) { + // run conv2d operator + status = conv2d_op_1(); + CUTLASS_CHECK(status); + } + cudaEventRecord(stop2); + cudaDeviceSynchronize(); + float conv2d0Time, conv2d1Time, totalTime; + cudaEventElapsedTime(&conv2d0Time, start, stop1); + cudaEventElapsedTime(&conv2d1Time, stop1, stop2); + cudaEventElapsedTime(&totalTime, start, stop2); + std::cout << "conv2d 0 time " << conv2d0Time / (float)runs << " ms\n"; + std::cout << "conv2d 1 time " << conv2d1Time / (float)runs << " ms\n"; + std::cout << "Non-fusion time " << totalTime / (float)runs << " ms\n"; + + tensor_D0_computed.sync_host(); + tensor_D1_computed.sync_host(); + + bool passed = false; + + cutlass::reference::device::Conv2d< + typename Conv2d0::ElementA, + typename Conv2d0::LayoutA, + typename Conv2d0::ElementB, + typename Conv2d0::LayoutB, + typename Conv2d0::ElementC, + typename Conv2d0::LayoutC, + ElementCompute, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size_0, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + {tensor_Bias0.device_data(), typename Conv2d0::LayoutC::Stride(0)}, + tensor_D0_reference.device_ref(), + alpha0, + beta0); + + if(relu) { + cutlass::reference::device::TensorReLu(tensor_D0_reference.device_view()); + } + + cutlass::reference::device::Conv2d< + typename Conv2d1::ElementA, + typename Conv2d1::LayoutA, + typename Conv2d1::ElementB, + typename Conv2d1::LayoutB, + typename Conv2d1::ElementC, + typename Conv2d1::LayoutC, + ElementCompute, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size_1, + tensor_D0_reference.device_ref(), + tensor_B1.device_ref(), + {tensor_Bias1.device_data(), typename Conv2d1::LayoutC::Stride(0)}, + tensor_D1_reference.device_ref(), + alpha1, + beta1); + + if(relu) { + cutlass::reference::device::TensorReLu(tensor_D1_reference.device_view()); + } + + cudaError_t result = cudaDeviceSynchronize(); + CHECK_TRUE(result == cudaSuccess); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_D0_reference.sync_host(); + tensor_D1_reference.sync_host(); + + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_computed.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_reference.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_computed.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_reference.host_view()), 0); + + passed = cutlass::reference::host::TensorEquals( + tensor_D1_computed.host_view(), + tensor_D1_reference.host_view()); + + CHECK_TRUE(passed); + + if (!passed) { + std::stringstream fname; + + fname << "error_B2bImplicitGemm_device_nonfused.txt"; + std::cerr << "Dumping results in " << fname.str() << "\n"; + + std::ofstream results(fname.str()); + + results << problem_size_0 << std::endl; + results << problem_size_1 << std::endl; + + results + << "\nA0:\n" << tensor_A0.host_view() << "\n" + << "\nB0:\n" << tensor_B0.host_view() << "\n" + << "\nC0:\n" << tensor_C0.host_view() << "\n" + << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" + << "\nD0 reference:\n" << tensor_D0_reference.host_view() << "\n" + << "\nD0 computed:\n" << tensor_D0_computed.host_view() << "\n" + << "\nB1:\n" << tensor_B1.host_view() << "\n" + << "\nC1:\n" << tensor_C1.host_view() << "\n" + << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" + << "\nD1 reference:\n" << tensor_D1_reference.host_view() << "\n" + << "\nD1 computed:\n" << tensor_D1_computed.host_view(); + + + } + + return passed; + } + +}; + +template +class B2bFusedConv2dRun { +public: + + using B2bConv2d = B2bConv2d_; + using ElementAccumulator = typename B2bConv2d::ElementAccumulator; + using ElementCompute = typename B2bConv2d::ElementCompute; + + static cutlass::conv::Operator const kConvolutionalOperator = B2bConv2d::kConvolutionalOperator; + +public: + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_Scale; + cutlass::Distribution::Kind init_Bias; + uint64_t seed; + + cutlass::HostTensor tensor_A0; + cutlass::HostTensor tensor_B0; + cutlass::HostTensor tensor_C0; + cutlass::HostTensor tensor_Scale0; + cutlass::HostTensor tensor_Bias0; + cutlass::HostTensor tensor_Z0_reference; + cutlass::HostTensor tensor_D0_reference; + + cutlass::HostTensor tensor_B1; + cutlass::HostTensor tensor_C1; + cutlass::HostTensor tensor_Bias1; + cutlass::HostTensor tensor_D1_computed; + cutlass::HostTensor tensor_D1_reference; + + +public: + + B2bFusedConv2dRun( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), + init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { + + } + + /// Helper to initialize a tensor view + template + void initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 16) { + scope = 2; + } + else { + scope = 8; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope, -scope, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { + } + } + + void initialize( + cutlass::conv::Conv2dProblemSize const &problem_size_0, + cutlass::conv::Conv2dProblemSize const &problem_size_1, + ElementCompute alpha0, + ElementCompute alpha1, + uint64_t seed = 2019) { + + tensor_A0.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size_0)); + tensor_B0.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0)); + tensor_C0.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); + if(alpha0 == ElementCompute(0)) //per-channel scale + tensor_Scale0.resize({1, problem_size_0.K}); + tensor_Bias0.resize({1, problem_size_0.K}); + tensor_Z0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); + tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); + tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1)); + tensor_C1.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); + tensor_Bias1.resize({1, 1, 1, problem_size_1.K}); + tensor_D1_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); + tensor_D1_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); + + initialize_tensor(tensor_A0.host_view(), init_A, seed); + initialize_tensor(tensor_B0.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C0.host_view(), init_C, seed * 39); + if(alpha0 == ElementCompute(0)) //per-channel scale + initialize_tensor(tensor_Scale0.host_view(), init_Scale, seed * 61); + initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed * 83); + initialize_tensor(tensor_B1.host_view(), init_B, seed * 18); + initialize_tensor(tensor_C1.host_view(), init_C, seed * 40); + initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed * 84); + + tensor_A0.sync_device(); + tensor_B0.sync_device(); + tensor_C0.sync_device(); + if(alpha0 == ElementCompute(0)) //per-channel scale + tensor_Scale0.sync_device(); + tensor_Bias0.sync_device(); + tensor_D0_reference.sync_device(); + tensor_B1.sync_device(); + tensor_C1.sync_device(); + tensor_Bias1.sync_device(); + tensor_D1_computed.sync_device(); + tensor_D1_reference.sync_device(); + } + + /// Executes one test + bool run( + cutlass::conv::Conv2dProblemSize const &problem_size_0, + cutlass::conv::Conv2dProblemSize const &problem_size_1, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + ElementCompute alpha0 = ElementCompute(1), + ElementCompute beta0 = ElementCompute(0), + ElementCompute alpha1 = ElementCompute(1), + ElementCompute beta1 = ElementCompute(0), + bool relu = true, + int warm_ups = 1, + int runs = 100) { + + initialize(problem_size_0, problem_size_1, alpha0, alpha1); + + // configure the operator + B2bConv2d b2b_conv2d_op; + + typename B2bConv2d::Arguments b2b_conv2d_args( + problem_size_0, + problem_size_1, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + tensor_C0.device_ref(), + tensor_Scale0.device_ref(), + tensor_Bias0.device_ref(), + tensor_B1.device_ref(), + {tensor_Bias1.device_data(), typename B2bConv2d::LayoutC::Stride(0)}, + tensor_D1_computed.device_ref(), + {alpha0, beta0}, + {alpha1, beta1}, + split_k_mode + ); + + cutlass::Status status = b2b_conv2d_op.can_implement(b2b_conv2d_args); + + if(status != cutlass::Status::kSuccess) { + std::cout << "Problem sizes not supported.\n" + << "Requirments:\n" + << " problem_size_0.N*P*Q = problem_size_1.N*P*Q\n" + << " problem_size_0.K = problem_size_1.C\n" + << " problem_size_1.R = problem_size_1.S = 1\n" + << " ThreadblockShape0::kN = problem_size_0.K\n" + << " ThreadblockShape1::kN = problem_size_1.K" << std::endl; + } + + CUTLASS_CHECK(status); + + status = b2b_conv2d_op.initialize(b2b_conv2d_args); + + CUTLASS_CHECK(status); + + for(int i = 0; i < warm_ups; i++) { + status = b2b_conv2d_op(); + CUTLASS_CHECK(status); + } + + // + // Run the Conv2d + // + + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + cudaEventRecord(start); + + for(int i = 0; i < runs; i++) { + + // run conv2d operator + status = b2b_conv2d_op(); + CUTLASS_CHECK(status); + } + + cudaEventRecord(stop); + cudaDeviceSynchronize(); + float conv2dTime; + cudaEventElapsedTime(&conv2dTime, start, stop); + std::cout << "Fusion time " << conv2dTime / (float)runs << " ms\n"; + + tensor_D1_computed.sync_host(); + + bool passed = false; + + cutlass::reference::device::Conv2d< + typename B2bConv2d::ElementA, + typename B2bConv2d::LayoutA, + typename B2bConv2d::ElementB, + typename B2bConv2d::LayoutB, + ElementAccumulator, + typename B2bConv2d::LayoutC, + ElementAccumulator, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size_0, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + tensor_Z0_reference.device_ref(), + tensor_Z0_reference.device_ref(), + ElementAccumulator(1), // intermediate alpha = 1 + ElementAccumulator(0) // beta = 0 + ); + + cutlass::reference::device::TensorScaleBiasConv2d< + ElementAccumulator, + typename B2bConv2d::ElementC, + typename B2bConv2d::LayoutC, + ElementCompute, + typename B2bConv2d::LayoutScaleBias + >( + problem_size_0, + tensor_Z0_reference.device_ref(), + tensor_D0_reference.device_ref(), + alpha0, + tensor_Scale0.device_ref(), + tensor_Bias0.device_ref() + ); + + if(relu) { + cutlass::reference::device::TensorReLu(tensor_D0_reference.device_view()); + } + + cutlass::reference::device::Conv2d< + typename B2bConv2d::ElementA, + typename B2bConv2d::LayoutA, + typename B2bConv2d::ElementB, + typename B2bConv2d::LayoutB, + typename B2bConv2d::ElementC, + typename B2bConv2d::LayoutC, + ElementCompute, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size_1, + tensor_D0_reference.device_ref(), + tensor_B1.device_ref(), + {tensor_Bias1.device_data(), typename B2bConv2d::LayoutC::Stride(0)}, + tensor_D1_reference.device_ref(), + alpha1, + beta1); + + if(relu) { + cutlass::reference::device::TensorReLu(tensor_D1_reference.device_view()); + } + + cudaError_t result = cudaDeviceSynchronize(); + CHECK_TRUE(result == cudaSuccess); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_D0_reference.sync_host(); + tensor_D1_reference.sync_host(); + + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_reference.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_computed.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_reference.host_view()), 0); + + passed = cutlass::reference::host::TensorEquals( + tensor_D1_computed.host_view(), + tensor_D1_reference.host_view()); + + CHECK_TRUE(passed); + + if (!passed) { + std::stringstream fname; + + fname << "error_B2bImplicitGemm_device_fused.txt"; + std::cerr << "Dumping results in " << fname.str() << "\n"; + + std::ofstream results(fname.str()); + + results << problem_size_0 << std::endl; + results << problem_size_1 << std::endl; + + results + << "\nA0:\n" << tensor_A0.host_view() << "\n" + << "\nB0:\n" << tensor_B0.host_view() << "\n" + << "\nC0:\n" << tensor_C0.host_view() << "\n" + << "\nScale0:\n" << tensor_Scale0.host_view() << "\n" + << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" + << "\nB1:\n" << tensor_B1.host_view() << "\n" + << "\nC1:\n" << tensor_C1.host_view() << "\n" + << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" + << "\nD1 reference:\n" << tensor_D1_reference.host_view() << "\n" + << "\nD1 computed:\n" << tensor_D1_computed.host_view(); + + + } + + return passed; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_gemm_run.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_gemm_run.h new file mode 100644 index 0000000000000000000000000000000000000000..f0e85cda3a4f36e3f20bbfeb8642ec0c62339a25 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_gemm_run.h @@ -0,0 +1,763 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include +#include + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_relu.h" + +#include "reference/device/tensor_scale_bias.h" +#include "helper.h" + +#define CHECK_GT(val1, val2) \ + if((val1) <= (val2)) \ + std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n"; +#define CHECK_TRUE(val) \ + if(!(val)) \ + std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n"; + +//////////////////////////////////////////////////////////////////////////////// + +template +struct B2bNonFusedGemmRun +{ + + using Gemm0 = Gemm0_; + using Gemm1 = Gemm1_; + using ElementAccumulator = typename Gemm0::ElementAccumulator; + using ElementCompute = typename Gemm0::GemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_Bias; + uint64_t seed; + + // + // Methods + // + + B2bNonFusedGemmRun( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, 2, -2, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { + std::cerr << "Not implemented\n"; + return false; + } + + return true; + } + + + + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size_0, + cutlass::gemm::GemmCoord problem_size_1, + ElementCompute alpha0 = ElementCompute(1), + ElementCompute beta0 = ElementCompute(0), + ElementCompute alpha1 = ElementCompute(1), + ElementCompute beta1 = ElementCompute(0), + bool relu = true, + int warm_ups = 1, + int runs = 100) { + + // + // Allocate the GEMM workspace + // + + cutlass::HostTensor< + typename Gemm0::ElementA, + typename Gemm0::LayoutA> tensor_A0(problem_size_0.mk()); + + cutlass::HostTensor< + typename Gemm0::ElementB, + typename Gemm0::LayoutB> tensor_B0(problem_size_0.kn()); + + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn()); + + cutlass::HostTensor< + ElementCompute, + typename Gemm0::LayoutC> tensor_Bias0({1, problem_size_0.n()}); + + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn()); + + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm0::LayoutC> reference_D0(problem_size_0.mn()); + + cutlass::HostTensor< + typename Gemm1::ElementB, + typename Gemm1::LayoutB> tensor_B1(problem_size_1.kn()); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn()); + + cutlass::HostTensor< + ElementCompute, + typename Gemm1::LayoutC> tensor_Bias1({1, problem_size_1.n()}); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn()); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm1::LayoutC> reference_D1(problem_size_1.mn()); + + + CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); + CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018)); + CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); + CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2014)); + CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016)); + CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); + CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2013)); + + cutlass::reference::host::TensorFill( + tensor_D0.host_view()); + cutlass::reference::host::TensorFill( + tensor_D1.host_view()); + cutlass::reference::host::TensorFill( + reference_D0.host_view()); + cutlass::reference::host::TensorFill( + reference_D1.host_view()); + + tensor_A0.sync_device(); + tensor_B0.sync_device(); + tensor_C0.sync_device(); + tensor_Bias0.sync_device(); + tensor_D0.sync_device(); + tensor_B1.sync_device(); + tensor_C1.sync_device(); + tensor_Bias1.sync_device(); + tensor_D1.sync_device(); + reference_D0.sync_device(); + reference_D1.sync_device(); + + // + // Initialize the GEMM operator + // + + typename Gemm0::Arguments arguments_0{ + problem_size_0, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}, + tensor_D0.device_ref(), + {alpha0, beta0} + }; + + typename Gemm1::Arguments arguments_1{ + problem_size_1, + tensor_D0.device_ref(), + tensor_B1.device_ref(), + {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}, + tensor_D1.device_ref(), + {alpha1, beta1} + }; + + + Gemm0 gemm_op_0; + Gemm1 gemm_op_1; + + cutlass::Status status = gemm_op_0.initialize(arguments_0); + + CUTLASS_CHECK(status); + + status = gemm_op_1.initialize(arguments_1); + + CUTLASS_CHECK(status); + + for(int i = 0; i < warm_ups; i++) { + status = gemm_op_0(); + CUTLASS_CHECK(status); + status = gemm_op_1(); + CUTLASS_CHECK(status); + } + + // + // Run the GEMM + // + cudaEvent_t start, stop1, stop2; + cudaEventCreate(&start); + cudaEventCreate(&stop1); + cudaEventCreate(&stop2); + + cudaEventRecord(start); + + for(int i = 0; i < runs; i++) { + status = gemm_op_0(); + + CUTLASS_CHECK(status); + } + cudaEventRecord(stop1); + for(int i = 0; i < runs; i++) { + status = gemm_op_1(); + + CUTLASS_CHECK(status); + } + + cudaEventRecord(stop2); + cudaDeviceSynchronize(); + float gemm0Time, gemm1Time, totalTime; + cudaEventElapsedTime(&gemm0Time, start, stop1); + cudaEventElapsedTime(&gemm1Time, stop1, stop2); + cudaEventElapsedTime(&totalTime, start, stop2); + std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n"; + std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n"; + std::cout << "Non-fusion time " << totalTime / (float)runs << " ms\n"; + + tensor_D0.sync_host(); + tensor_D1.sync_host(); + + // + // Verify + // + cutlass::reference::device::Gemm< + typename Gemm0::ElementA, typename Gemm0::LayoutA, + typename Gemm0::ElementB, typename Gemm0::LayoutB, + typename Gemm0::ElementC, typename Gemm0::LayoutC, ElementCompute, + ElementAccumulator, typename Gemm0::Operator> + reference_gemm_0; + + cutlass::reference::device::Gemm< + typename Gemm1::ElementA, typename Gemm1::LayoutA, + typename Gemm1::ElementB, typename Gemm1::LayoutB, + typename Gemm1::ElementC, typename Gemm1::LayoutC, ElementCompute, + ElementAccumulator, typename Gemm1::Operator> + reference_gemm_1; + + reference_gemm_0( + problem_size_0, + alpha0, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + beta0, + {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}, + reference_D0.device_ref() + ); + + if(relu) { + cutlass::reference::device::TensorReLu(reference_D0.device_view()); + } + + reference_gemm_1( + problem_size_1, + alpha1, + reference_D0.device_ref(), + tensor_B1.device_ref(), + beta1, + {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}, + reference_D1.device_ref() + ); + + if(relu) { + cutlass::reference::device::TensorReLu(reference_D1.device_view()); + } + + // Wait for kernels to finish + cudaDeviceSynchronize(); + reference_D0.sync_host(); + reference_D1.sync_host(); + + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals( + reference_D1.host_view(), + tensor_D1.host_view()); + + CHECK_TRUE(passed); + if (!passed) { + + std::stringstream fname; + + fname << "error_B2bGemm_device_nonfused.txt"; + std::cerr << "Dumping results in " << fname.str() << "\n"; + + std::ofstream file(fname.str()); + + file + << "A0 =\n" << tensor_A0.host_view() + << "\nB0 =\n" << tensor_B0.host_view() + << "\nC0 =\n" << tensor_C0.host_view() + << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" + << "\nD0 =\n" << tensor_D0.host_view() + << "\nB1 =\n" << tensor_B1.host_view() + << "\nC1 =\n" << tensor_C1.host_view() + << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" + << "\n\nReference =\n" << reference_D1.host_view() + << "\nComputed =\n" << tensor_D1.host_view(); + } + return passed; + } +}; + +template +struct B2bFusedGemmRun +{ + + using B2bGemm = B2bGemm_; + using ElementAccumulator = typename B2bGemm::ElementAccumulator; + using ElementCompute = typename B2bGemm::B2bGemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_Scale; + cutlass::Distribution::Kind init_Bias; + uint64_t seed; + + // + // Methods + // + + B2bFusedGemmRun( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), + init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, 2, -2, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { + std::cerr << "Not implemented\n"; + return false; + } + + return true; + } + + + + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size_0, + cutlass::gemm::GemmCoord problem_size_1, + ElementCompute alpha0 = ElementCompute(1), + ElementCompute beta0 = ElementCompute(0), + ElementCompute alpha1 = ElementCompute(1), + ElementCompute beta1 = ElementCompute(0), + cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm, + + // batch_count is used as split-k when mode is kGemm according + // to the GemmUniversal interface + + int batch_count = 1, + int64_t batch_stride_A0 = 0, + int64_t batch_stride_B0 = 0, + int64_t batch_stride_C0 = 0, + int64_t batch_stride_B1 = 0, + int64_t batch_stride_C1 = 0, + int64_t batch_stride_D1 = 0, + int64_t batch_stride_Bias0 = 0, + int64_t batch_stride_Scale0 = 0, + bool relu = true, + int warm_ups = 1, + int runs = 100) { + + // + // Allocate the GEMM workspace + // + + cutlass::gemm::GemmCoord CoordA0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k()); + cutlass::gemm::GemmCoord CoordB0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k()); + cutlass::gemm::GemmCoord CoordC0(problem_size_0.m(), batch_count * problem_size_0.n(), problem_size_0.k()); + cutlass::gemm::GemmCoord CoordB1(problem_size_1.m(), problem_size_1.n(), batch_count * problem_size_1.k()); + cutlass::gemm::GemmCoord CoordC1(problem_size_1.m(), batch_count * problem_size_1.n(), problem_size_1.k()); + + cutlass::HostTensor< + typename B2bGemm::ElementA, + typename B2bGemm::LayoutA> tensor_A0(CoordA0.mk()); + + cutlass::HostTensor< + typename B2bGemm::ElementB, + typename B2bGemm::LayoutB> tensor_B0(CoordB0.kn()); + + cutlass::HostTensor< + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> tensor_C0(CoordC0.mn()); + + cutlass::HostTensor< + typename B2bGemm::ElementScaleBias, + typename B2bGemm::LayoutScaleBias> tensor_Scale0; + + if(alpha0 == ElementCompute(0)) //per-channel scale + tensor_Scale0.resize({1, batch_count * problem_size_0.n()}); + + cutlass::HostTensor< + typename B2bGemm::ElementScaleBias, + typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, batch_count * problem_size_0.n()}); + + cutlass::HostTensor< + ElementAccumulator, + typename B2bGemm::LayoutC> reference_Z0(CoordC0.mn()); + + cutlass::HostTensor< + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> reference_D0(CoordC0.mn()); + + cutlass::HostTensor< + typename B2bGemm::ElementB, + typename B2bGemm::LayoutB> tensor_B1(CoordB1.kn()); + + cutlass::HostTensor< + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> tensor_C1(CoordC1.mn()); + + cutlass::HostTensor< + typename B2bGemm::ElementC, + typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, batch_count * problem_size_1.n()}); + + cutlass::HostTensor< + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> tensor_D1(CoordC1.mn()); + + cutlass::HostTensor< + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> reference_D1(CoordC1.mn()); + + + CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); + CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018)); + CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); + if(alpha0 == ElementCompute(0)) //per-channel scale + CHECK_TRUE(initialize_tensor(tensor_Scale0.host_view(), init_Scale, seed + 2014)); + CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2013)); + CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016)); + CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); + CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2012)); + + cutlass::reference::host::TensorFill( + tensor_D1.host_view()); + cutlass::reference::host::TensorFill( + reference_D0.host_view()); + cutlass::reference::host::TensorFill( + reference_D1.host_view()); + + tensor_A0.sync_device(); + tensor_B0.sync_device(); + tensor_C0.sync_device(); + if(alpha0 == ElementCompute(0)) //per-channel scale + tensor_Scale0.sync_device(); + tensor_Bias0.sync_device(); + tensor_B1.sync_device(); + tensor_C1.sync_device(); + tensor_Bias1.sync_device(); + tensor_D1.sync_device(); + reference_D0.sync_device(); + reference_D1.sync_device(); + + // + // Initialize the GEMM operator + // + + typename B2bGemm::Arguments arguments{ + mode, + problem_size_0, + problem_size_1, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + tensor_C0.device_ref(), + tensor_Scale0.device_ref(), + tensor_Bias0.device_ref(), + tensor_B1.device_ref(), + {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)}, + tensor_D1.device_ref(), + batch_stride_A0, + batch_stride_B0, + batch_stride_B1, + batch_stride_C1, + batch_stride_D1, + batch_stride_Bias0, + batch_stride_Scale0, + {alpha0, beta0}, + {alpha1, beta1}, + batch_count, + }; + + B2bGemm b2b_gemm_op; + + cutlass::Status status = b2b_gemm_op.can_implement(arguments); + + if(status != cutlass::Status::kSuccess) { + std::cout << "Problem sizes not supported.\n" + << "Requirments:\n" + << " problem_size_0.M = problem_size_1.M\n" + << " problem_size_0.N = problem_size_1.K\n" + << " ThreadblockShape0::kN = problem_size_0.N\n" + << " ThreadblockShape1::kN = problem_size_1.N" << std::endl; + } + + status = b2b_gemm_op.initialize(arguments); + + CUTLASS_CHECK(status); + + for(int i = 0; i < warm_ups; i++) { + status = b2b_gemm_op(); + CUTLASS_CHECK(status); + } + + // + // Run the GEMM + // + + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + cudaEventRecord(start); + + for(int i = 0; i < runs; i++) { + status = b2b_gemm_op(); + + CUTLASS_CHECK(status); + } + + cudaEventRecord(stop); + cudaDeviceSynchronize(); + float gemmTime; + cudaEventElapsedTime(&gemmTime, start, stop); + std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n"; + + tensor_D1.sync_host(); + + // + // Verify + // + + cutlass::reference::device::GemmComplex< + typename B2bGemm::ElementA, typename B2bGemm::LayoutA, + typename B2bGemm::ElementB, typename B2bGemm::LayoutB, + ElementAccumulator, typename B2bGemm::LayoutC, + ElementAccumulator, ElementAccumulator + >( + + problem_size_0, + ElementAccumulator(1), //intermediate alpha=1 + tensor_A0.device_ref(), + cutlass::ComplexTransform::kNone, + tensor_B0.device_ref(), + cutlass::ComplexTransform::kNone, + ElementAccumulator(0), //beta = 0 + reference_Z0.device_ref(), + reference_Z0.device_ref(), + ElementAccumulator(0), + int(batch_count), + batch_stride_A0, + batch_stride_B0, + batch_stride_C0, + batch_stride_C0 + ); + + cutlass::reference::device::TensorScaleBiasGemmBatched< + ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC, + ElementCompute, typename B2bGemm::LayoutScaleBias + > ( + problem_size_0, + reference_Z0.device_ref(), + reference_D0.device_ref(), + alpha0, + tensor_Scale0.device_ref(), + tensor_Bias0.device_ref(), + int(batch_count), + batch_stride_C0, + batch_stride_C0, + batch_stride_Scale0, + batch_stride_Bias0 + ); + + if(relu) { + cutlass::reference::device::TensorReLu(reference_D0.device_view()); + } + + cutlass::reference::device::GemmComplex< + typename B2bGemm::ElementA, typename B2bGemm::LayoutA, + typename B2bGemm::ElementB, typename B2bGemm::LayoutB, + typename B2bGemm::ElementC, typename B2bGemm::LayoutC, + ElementCompute, ElementAccumulator + >( + problem_size_1, + alpha1, //intermediate alpha=1 + reference_D0.device_ref(), + cutlass::ComplexTransform::kNone, + tensor_B1.device_ref(), + cutlass::ComplexTransform::kNone, + beta1, //beta = 0 + {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)}, + reference_D1.device_ref(), + ElementAccumulator(0), + int(batch_count), + batch_stride_C0, + batch_stride_B1, + batch_stride_C1, + batch_stride_D1 + ); + + if(relu) { + cutlass::reference::device::TensorReLu(reference_D1.device_view()); + } + + cudaDeviceSynchronize(); + reference_D0.sync_host(); + reference_D1.sync_host(); + + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals( + reference_D1.host_view(), + tensor_D1.host_view()); + + CHECK_TRUE(passed); + if (!passed) + { + + std::stringstream fname; + + fname << "error_B2bGemm_device_fused.txt"; + std::cerr << "Dumping results in " << fname.str() << "\n"; + + std::ofstream file(fname.str()); + + file + << "A0 =\n" << tensor_A0.host_view() + << "\nB0 =\n" << tensor_B0.host_view() + << "\nC0 =\n" << tensor_C0.host_view() + << "\nScale0:\n" << tensor_Scale0.host_view() << "\n" + << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" + << "\nB1 =\n" << tensor_B1.host_view() + << "\nC1 =\n" << tensor_C1.host_view() + << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" + << "\n\nReference =\n" << reference_D1.host_view() + << "\nComputed =\n" << tensor_D1.host_view(); + } + return passed; + } + +}; + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_grouped_gemm_run.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_grouped_gemm_run.h new file mode 100644 index 0000000000000000000000000000000000000000..b6267a153b6ed399120585f885aaa33c24af9e24 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_grouped_gemm_run.h @@ -0,0 +1,450 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Containers for running grouped back-to-back GEMMs +*/ + +#pragma once + +#include +#include +#include + +#include "cutlass/util/device_memory.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_relu.h" + +#include "reference/device/tensor_scale_bias.h" +#include "helper.h" + +#define CHECK_GT(val1, val2) \ + if((val1) <= (val2)) \ + std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n"; +#define CHECK_TRUE(val) \ + if(!(val)) \ + std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n"; + +//////////////////////////////////////////////////////////////////////////////// + +template +struct B2bFusedGroupedGemmRun +{ + + using B2bGemm = B2bGemm_; + using ElementAccumulator = typename B2bGemm::ElementAccumulator; + using ElementCompute = typename B2bGemm::BaseKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_Scale; + cutlass::Distribution::Kind init_Bias; + uint64_t seed; + + // + // Methods + // + + B2bFusedGroupedGemmRun( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), + init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, 1, -1, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { + std::cerr << "Not implemented\n"; + return false; + } + + return true; + } + + /// Executes one test + bool run( + std::vector problem_sizes_0, + std::vector problem_sizes_1, + ElementCompute alpha0 = ElementCompute(1), + ElementCompute beta0 = ElementCompute(0), + ElementCompute alpha1 = ElementCompute(1), + ElementCompute beta1 = ElementCompute(0), + bool relu = true, + int warm_ups = 1, + int runs = 100) { + + using HostTensorA = cutlass::HostTensor; + using HostTensorB = cutlass::HostTensor; + using HostTensorC = cutlass::HostTensor; + using HostTensorScale = cutlass::HostTensor; + using HostTensorZ = cutlass::HostTensor; + using HostTensorBias = cutlass::HostTensor; + + int problem_count = (int)problem_sizes_0.size(); + + std::vector host_tensor_A0(problem_count); + std::vector host_tensor_B0(problem_count); + std::vector host_tensor_C0(problem_count); + std::vector host_tensor_Scale0(problem_count); + std::vector host_tensor_Bias0(problem_count); + std::vector host_tensor_B1(problem_count); + std::vector host_tensor_C1(problem_count); + std::vector host_tensor_Bias1(problem_count); + std::vector host_tensor_D1(problem_count); + std::vector host_tensor_Z(problem_count); + std::vector host_tensor_ref_D0(problem_count); + std::vector host_tensor_ref_D1(problem_count); + + std::vector ref_A0(problem_count); + std::vector ref_B0(problem_count); + std::vector ref_C0(problem_count); + std::vector ref_Scale0(problem_count); + std::vector ref_Bias0(problem_count); + std::vector ref_B1(problem_count); + std::vector ref_C1(problem_count); + std::vector ref_Bias1(problem_count); + std::vector ref_D1(problem_count); + std::vector ref_Z(problem_count); + std::vector ref_ref_D0(problem_count); + std::vector ref_ref_D1(problem_count); + + for (int i = 0; i < problem_count; ++i) { + // + // Allocate the GEMM workspace + // + + auto problem_size_0 = problem_sizes_0[i]; + auto problem_size_1 = problem_sizes_1[i]; + + host_tensor_A0.at(i) = HostTensorA(problem_size_0.mk()); + host_tensor_B0.at(i) = HostTensorB(problem_size_0.kn()); + host_tensor_C0.at(i) = HostTensorC(problem_size_0.mn()); + if (alpha0 == ElementCompute(0)) //per-channel scale + host_tensor_Scale0.at(i) = HostTensorScale(typename HostTensorZ::Layout::TensorCoord{1, problem_size_0.n()}); + host_tensor_Bias0.at(i) = HostTensorScale(typename HostTensorBias::Layout::TensorCoord{1, problem_size_0.n()}); + host_tensor_Z.at(i) = HostTensorZ(problem_size_0.mn()); + host_tensor_ref_D0.at(i) = HostTensorC(problem_size_0.mn()); + host_tensor_B1.at(i) = HostTensorB(problem_size_1.kn()); + host_tensor_C1.at(i) = HostTensorC(problem_size_1.mn()); + host_tensor_Bias1.at(i) = HostTensorScale(typename HostTensorBias::Layout::TensorCoord{1, problem_size_1.n()}); + host_tensor_D1.at(i) = HostTensorC(problem_size_1.mn()); + host_tensor_ref_D1.at(i) = HostTensorC(problem_size_1.mn()); + + CHECK_TRUE(initialize_tensor(host_tensor_A0.at(i).host_view(), init_A, seed + 2019)); + CHECK_TRUE(initialize_tensor(host_tensor_B0.at(i).host_view(), init_B, seed + 2018)); + CHECK_TRUE(initialize_tensor(host_tensor_C0.at(i).host_view(), init_C, seed + 2017)); + if (alpha0 == ElementCompute(0)) //per-channel scale + CHECK_TRUE(initialize_tensor(host_tensor_Scale0.at(i).host_view(), init_Scale, seed + 2014)); + CHECK_TRUE(initialize_tensor(host_tensor_Bias0.at(i).host_view(), init_Bias, seed + 2013)); + CHECK_TRUE(initialize_tensor(host_tensor_B1.at(i).host_view(), init_B, seed + 2016)); + CHECK_TRUE(initialize_tensor(host_tensor_C1.at(i).host_view(), init_C, seed + 2015)); + CHECK_TRUE(initialize_tensor(host_tensor_Bias1.at(i).host_view(), init_Bias, seed + 2012)); + + cutlass::reference::host::TensorFill( + host_tensor_D1.at(i).host_view()); + cutlass::reference::host::TensorFill( + host_tensor_ref_D0.at(i).host_view()); + cutlass::reference::host::TensorFill( + host_tensor_ref_D1.at(i).host_view()); + + host_tensor_A0.at(i).sync_device(); + host_tensor_B0.at(i).sync_device(); + host_tensor_C0.at(i).sync_device(); + if (alpha0 == ElementCompute(0)) //per-channel scale + host_tensor_Scale0.at(i).sync_device(); + host_tensor_Bias0.at(i).sync_device(); + host_tensor_B1.at(i).sync_device(); + host_tensor_C1.at(i).sync_device(); + host_tensor_Bias1.at(i).sync_device(); + host_tensor_D1.at(i).sync_device(); + host_tensor_ref_D0.at(i).sync_device(); + host_tensor_ref_D1.at(i).sync_device(); + + ref_A0.at(i) = (host_tensor_A0.at(i).device_ref()); + ref_B0.at(i) = (host_tensor_B0.at(i).device_ref()); + ref_C0.at(i) = (host_tensor_C0.at(i).device_ref()); + if (alpha0 == ElementCompute(0)) //per-channel scale + ref_Scale0.at(i) = (host_tensor_Scale0.at(i).device_ref()); + ref_Bias0.at(i) = (host_tensor_Bias0.at(i).device_ref()); + ref_B1.at(i) = (host_tensor_B1.at(i).device_ref()); + ref_C1.at(i) = {host_tensor_Bias1.at(i).device_data(), typename B2bGemm::LayoutC::Stride(0)}; + ref_Bias1.at(i) = (host_tensor_Bias1.at(i).device_ref()); + ref_D1.at(i) = (host_tensor_D1.at(i).device_ref()); + ref_Z.at(i) = (host_tensor_Z.at(i).device_ref()); + ref_ref_D0.at(i) = (host_tensor_ref_D0.at(i).device_ref()); + ref_ref_D1.at(i) = (host_tensor_ref_D1.at(i).device_ref()); + } + + // + // Initialize the GEMM operator + // + + cutlass::DeviceAllocation device_ref_A0(problem_count); + device_ref_A0.copy_from_host(ref_A0.data()); + cutlass::DeviceAllocation device_ref_B0(problem_count); + device_ref_B0.copy_from_host(ref_B0.data()); + cutlass::DeviceAllocation device_ref_C0(problem_count); + device_ref_C0.copy_from_host(ref_C0.data()); + cutlass::DeviceAllocation device_ref_Scale0(problem_count); + device_ref_Scale0.copy_from_host(ref_Scale0.data()); + cutlass::DeviceAllocation device_ref_Bias0(problem_count); + device_ref_Bias0.copy_from_host(ref_Bias0.data()); + cutlass::DeviceAllocation device_ref_B1(problem_count); + device_ref_B1.copy_from_host(ref_B1.data()); + cutlass::DeviceAllocation device_ref_C1(problem_count); + device_ref_C1.copy_from_host(ref_C1.data()); + cutlass::DeviceAllocation device_ref_Bias1(problem_count); + device_ref_Bias1.copy_from_host(ref_Bias1.data()); + cutlass::DeviceAllocation device_ref_D1(problem_count); + device_ref_D1.copy_from_host(ref_D1.data()); + + cutlass::DeviceAllocation device_problem_sizes_0(problem_count); + device_problem_sizes_0.copy_from_host(problem_sizes_0.data()); + cutlass::DeviceAllocation device_problem_sizes_1(problem_count); + device_problem_sizes_1.copy_from_host(problem_sizes_1.data()); + + B2bGemm b2b_gemm_op; + + int threadblock_count = B2bGemm::sufficient(problem_sizes_1.data(), problem_count); + if (!threadblock_count) { + std::cout << "Active CUDA device lacks hardware resources to run CUTLASS Grouped GEMM kernel." << std::endl; + return false; + } + + typename B2bGemm::Arguments arguments{ + problem_count, + device_problem_sizes_0.get(), + device_problem_sizes_1.get(), + device_ref_A0.get(), + device_ref_B0.get(), + device_ref_C0.get(), + device_ref_Scale0.get(), + device_ref_Bias0.get(), + device_ref_B1.get(), + device_ref_C1.get(), + device_ref_D1.get(), + {alpha0, beta0}, + {alpha1, beta1}, + threadblock_count + }; + + cutlass::Status status = b2b_gemm_op.can_implement(arguments); + + if(status != cutlass::Status::kSuccess) { + std::cout << "Problem sizes not supported.\n" + << "Requirments:\n" + << " problem_size_0.M = problem_size_1.M\n" + << " problem_size_0.N = problem_size_1.K\n" + << " ThreadblockShape0::kN = problem_size_0.N\n" + << " ThreadblockShape1::kN = problem_size_1.N" << std::endl; + } + + status = b2b_gemm_op.initialize(arguments); + + CUTLASS_CHECK(status); + + for(int i = 0; i < warm_ups; i++) { + status = b2b_gemm_op(); + CUTLASS_CHECK(status); + } + + // + // Run the GEMM + // + + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + cudaEventRecord(start); + + for(int i = 0; i < runs; i++) { + status = b2b_gemm_op(); + CUTLASS_CHECK(status); + } + + cudaEventRecord(stop); + cudaDeviceSynchronize(); + float gemmTime; + cudaEventElapsedTime(&gemmTime, start, stop); + std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n"; + + for (int i = 0; i < problem_count; ++i) { + host_tensor_D1.at(i).sync_host(); + + // + // Verify + // + + cutlass::reference::device::Gemm< + typename B2bGemm::ElementA, typename B2bGemm::LayoutA, + typename B2bGemm::ElementB, typename B2bGemm::LayoutB, + ElementAccumulator, typename B2bGemm::LayoutC, + ElementAccumulator, ElementAccumulator> + reference_gemm_0; + + cutlass::reference::device::Gemm< + typename B2bGemm::ElementA, typename B2bGemm::LayoutA, + typename B2bGemm::ElementB, typename B2bGemm::LayoutB, + typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute, + ElementAccumulator> + reference_gemm_1; + + auto problem_size_0 = problem_sizes_0[i]; + auto problem_size_1 = problem_sizes_1[i]; + + reference_gemm_0( + problem_size_0, + ElementAccumulator(1), //intermediate alpha=1 + ref_A0.at(i), + ref_B0.at(i), + ElementAccumulator(0), //beta = 0 + ref_Z.at(i), + ref_Z.at(i), + ElementAccumulator(0) + ); + + cutlass::reference::device::TensorScaleBiasGemm< + ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC, + ElementCompute, typename B2bGemm::LayoutC + > ( + problem_size_0, + ref_Z.at(i), + ref_ref_D0.at(i), + alpha0, + ref_Scale0.at(i), + ref_Bias0.at(i) + ); + + if(relu) { + cutlass::reference::device::TensorReLu(host_tensor_ref_D0.at(i).device_view()); + } + + reference_gemm_1( + problem_size_1, + alpha1, + ref_ref_D0.at(i), + ref_B1.at(i), + beta1, + {host_tensor_Bias1.at(i).device_data(), typename B2bGemm::LayoutC::Stride(0)}, + ref_ref_D1.at(i) + ); + if(relu) { + cutlass::reference::device::TensorReLu(host_tensor_ref_D1.at(i).device_view()); + } + cudaDeviceSynchronize(); + host_tensor_ref_D0.at(i).sync_host(); + host_tensor_ref_D1.at(i).sync_host(); + + CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_ref_D0.at(i).host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_D1.at(i).host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_ref_D1.at(i).host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals( + host_tensor_ref_D1.at(i).host_view(), + host_tensor_D1.at(i).host_view()); + + CHECK_TRUE(passed); + if (!passed) + { + + std::stringstream fname; + + fname << "error_B2bGemm_device_fused.txt"; + std::cerr << "Check failed for GEMM " << i << " in the group." << std::endl; + std::cerr << "Dumping results in " << fname.str() << "\n"; + + std::ofstream file(fname.str()); + + file + << "GEMM " << i << " in group\n" + << "A0 =\n" << host_tensor_A0.at(i).host_view() + << "\nB0 =\n" << host_tensor_B0.at(i).host_view() + << "\nC0 =\n" << host_tensor_C0.at(i).host_view() + << "\nScale0:\n" << host_tensor_Scale0.at(i).host_view() << "\n" + << "\nBias0:\n" << host_tensor_Bias0.at(i).host_view() << "\n" + << "\nB1 =\n" << host_tensor_B1.at(i).host_view() + << "\nC1 =\n" << host_tensor_C1.at(i).host_view() + << "\nBias1:\n" << host_tensor_Bias1.at(i).host_view() << "\n" + << "\n\nReference =\n" << host_tensor_ref_D1.at(i).host_view() + << "\nComputed =\n" << host_tensor_D1.at(i).host_view(); + + return false; + } + } + return true; + } + +}; + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_interleaved_conv2d_run.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_interleaved_conv2d_run.h new file mode 100644 index 0000000000000000000000000000000000000000..4693e86423d9cdc142ee8bc84348e706c1c83280 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_interleaved_conv2d_run.h @@ -0,0 +1,749 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/reduction/device/reduce_split_k.h" +#include "cutlass/reduction/thread/reduction_operators.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/host_reorder.h" + +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/device/convolution.h" +#include "cutlass/util/reference/device/tensor_relu.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/tensor_view_io.h" + +#include "reference/device/tensor_scale_bias.h" +#include "helper.h" + +#define CHECK_GT(val1, val2) \ + if((val1) <= (val2)) \ + std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n"; +#define CHECK_TRUE(val) \ + if(!(val)) \ + std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n"; + + +template +class B2bInterleavedNonFusedConv2dRun { +public: + + using Conv2d0 = Conv2d0_; + using Conv2d1 = Conv2d1_; + using ElementAccumulator = typename Conv2d0::ElementAccumulator; + using ElementCompute = typename Conv2d0::ElementCompute; + + static cutlass::conv::Operator const kConvolutionalOperator = Conv2d0::kConvolutionalOperator; + static_assert(kConvolutionalOperator == Conv2d1::kConvolutionalOperator, + "Fused convolution operators must be the same"); + +public: + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_Bias; + uint64_t seed; + + cutlass::HostTensor tensor_A0; + cutlass::HostTensor tensor_B0; + cutlass::HostTensor tensor_B0_reordered; + cutlass::HostTensor tensor_C0; + cutlass::HostTensor tensor_Bias0; + cutlass::HostTensor tensor_D0_computed; + cutlass::HostTensor tensor_D0_reference; + + cutlass::HostTensor tensor_B1; + cutlass::HostTensor tensor_B1_reordered; + cutlass::HostTensor tensor_C1; + cutlass::HostTensor tensor_Bias1; + cutlass::HostTensor tensor_D1_computed; + cutlass::HostTensor tensor_D1_reference; + + +public: + + B2bInterleavedNonFusedConv2dRun( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { + + } + + /// Helper to initialize a tensor view + template + void initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 16) { + scope = 2; + } + else { + scope = 8; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope, -scope, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { + } + } + + void initialize( + cutlass::conv::Conv2dProblemSize const &problem_size_0, + cutlass::conv::Conv2dProblemSize const &problem_size_1, uint64_t seed = 2019) { + + tensor_A0.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size_0)); + tensor_B0.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0)); + tensor_B0_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0)); + tensor_C0.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); + tensor_Bias0.resize({1, 1, 1, problem_size_0.K}); + tensor_D0_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); + tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); + tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1)); + tensor_B1_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1)); + tensor_C1.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); + tensor_Bias1.resize({1, 1, 1, problem_size_1.K}); + tensor_D1_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); + tensor_D1_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); + + initialize_tensor(tensor_A0.host_view(), init_A, seed); + initialize_tensor(tensor_B0.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C0.host_view(), init_C, seed * 39); + initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed * 83); + initialize_tensor(tensor_B1.host_view(), init_B, seed * 18); + initialize_tensor(tensor_C1.host_view(), init_C, seed * 40); + + //Reorder B0 and B1 + cutlass::reorder_convK( + tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), implicit_gemm_problem_size(kConvolutionalOperator, problem_size_0)); + cutlass::reorder_convK( + tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), implicit_gemm_problem_size(kConvolutionalOperator, problem_size_1)); + + tensor_A0.sync_device(); + tensor_B0.sync_device(); + tensor_B0_reordered.sync_device(); + tensor_C0.sync_device(); + tensor_Bias0.sync_device(); + tensor_D0_computed.sync_device(); + tensor_D0_reference.sync_device(); + tensor_B1.sync_device(); + tensor_B1_reordered.sync_device(); + tensor_C1.sync_device(); + tensor_Bias1.sync_device(); + tensor_D1_computed.sync_device(); + tensor_D1_reference.sync_device(); + } + + /// Executes one test + bool run( + cutlass::conv::Conv2dProblemSize const &problem_size_0, + cutlass::conv::Conv2dProblemSize const &problem_size_1, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + ElementCompute alpha0 = ElementCompute(1), + ElementCompute beta0 = ElementCompute(0), + ElementCompute alpha1 = ElementCompute(1), + ElementCompute beta1 = ElementCompute(0), + bool relu = true, + int warm_ups = 1, + int runs = 100) { + + initialize(problem_size_0, problem_size_1); + + // configure the operator + Conv2d0 conv2d_op_0; + Conv2d1 conv2d_op_1; + + typename Conv2d0::Arguments conv2d_args_0( + problem_size_0, + tensor_A0.device_ref(), + tensor_B0_reordered.device_ref(), + tensor_C0.device_ref(), + tensor_D0_computed.device_ref(), + {alpha0, beta0}, + split_k_mode + ); + typename Conv2d1::Arguments conv2d_args_1( + problem_size_1, + tensor_D0_computed.device_ref(), + tensor_B1_reordered.device_ref(), + tensor_C1.device_ref(), + tensor_D1_computed.device_ref(), + {alpha1, beta1}, + split_k_mode + ); + + + cutlass::Status status = conv2d_op_0.initialize(conv2d_args_0); + + CUTLASS_CHECK(status); + + status = conv2d_op_1.initialize(conv2d_args_1); + + CUTLASS_CHECK(status); + + for(int i = 0; i < warm_ups; i++) { + status = conv2d_op_0(); + CUTLASS_CHECK(status); + status = conv2d_op_1(); + CUTLASS_CHECK(status); + } + + // + // Run Conv2d + // + cudaEvent_t start, stop1, stop2; + cudaEventCreate(&start); + cudaEventCreate(&stop1); + cudaEventCreate(&stop2); + + cudaEventRecord(start); + + + for(int i = 0; i < runs; i++) { + // run conv2d operator + status = conv2d_op_0(); + CUTLASS_CHECK(status); + } + cudaEventRecord(stop1); + + for(int i = 0; i < runs; i++) { + // run conv2d operator + status = conv2d_op_1(); + CUTLASS_CHECK(status); + } + cudaEventRecord(stop2); + cudaDeviceSynchronize(); + float conv2d0Time, conv2d1Time, totalTime; + cudaEventElapsedTime(&conv2d0Time, start, stop1); + cudaEventElapsedTime(&conv2d1Time, stop1, stop2); + cudaEventElapsedTime(&totalTime, start, stop2); + std::cout << "conv2d 0 time " << conv2d0Time / (float)runs << " ms\n"; + std::cout << "conv2d 1 time " << conv2d1Time / (float)runs << " ms\n"; + std::cout << "Non-fusion time " << totalTime / (float)runs << " ms\n"; + + tensor_D0_computed.sync_host(); + tensor_D1_computed.sync_host(); + + bool passed = false; + + cutlass::reference::device::Conv2d< + typename Conv2d0::ElementA, + typename Conv2d0::LayoutA, + typename Conv2d0::ElementB, + typename Conv2d0::LayoutB, + typename Conv2d0::ElementC, + typename Conv2d0::LayoutC, + ElementCompute, + ElementAccumulator, + cutlass::NumericConverterClamp + >( + kConvolutionalOperator, + problem_size_0, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + tensor_C0.device_ref(), + tensor_D0_reference.device_ref(), + alpha0, + beta0); + + if(relu) { + cutlass::reference::device::TensorReLu(tensor_D0_reference.device_view()); + } + + cutlass::reference::device::Conv2d< + typename Conv2d1::ElementA, + typename Conv2d1::LayoutA, + typename Conv2d1::ElementB, + typename Conv2d1::LayoutB, + typename Conv2d1::ElementC, + typename Conv2d1::LayoutC, + ElementCompute, + ElementAccumulator, + cutlass::NumericConverterClamp + >( + kConvolutionalOperator, + problem_size_1, + tensor_D0_reference.device_ref(), + tensor_B1.device_ref(), + tensor_C1.device_ref(), + tensor_D1_reference.device_ref(), + alpha1, + beta1); + + if(relu) { + cutlass::reference::device::TensorReLu(tensor_D1_reference.device_view()); + } + + cudaError_t result = cudaDeviceSynchronize(); + CHECK_TRUE(result == cudaSuccess); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_D0_reference.sync_host(); + tensor_D1_reference.sync_host(); + + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_computed.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_reference.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_computed.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_reference.host_view()), 0); + + passed = cutlass::reference::host::TensorEquals( + tensor_D1_computed.host_view(), + tensor_D1_reference.host_view()); + + CHECK_TRUE(passed); + + if (!passed) { + std::stringstream fname; + + fname << "error_B2bImplicitGemm_device_interleaved_nonfused.txt"; + std::cerr << "Dumping results in " << fname.str() << "\n"; + + std::ofstream results(fname.str()); + + results << problem_size_0 << std::endl; + results << problem_size_1 << std::endl; + + results + << "\nA0:\n" << tensor_A0.host_view() << "\n" + << "\nB0:\n" << tensor_B0.host_view() << "\n" + << "\nB0_reordered:\n" << tensor_B0_reordered.host_view() << "\n" + << "\nC0:\n" << tensor_C0.host_view() << "\n" + << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" + << "\nD0 reference:\n" << tensor_D0_reference.host_view() << "\n" + << "\nD0 computed:\n" << tensor_D0_computed.host_view() << "\n" + << "\nB1:\n" << tensor_B1.host_view() << "\n" + << "\nB1_reordered:\n" << tensor_B1_reordered.host_view() << "\n" + << "\nC1:\n" << tensor_C1.host_view() << "\n" + << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" + << "\nD1 reference:\n" << tensor_D1_reference.host_view() << "\n" + << "\nD1 computed:\n" << tensor_D1_computed.host_view(); + + + } + + return passed; + } + +}; + +template +class B2bInterleavedFusedConv2dRun { +public: + + using B2bConv2d = B2bConv2d_; + using ElementAccumulator = typename B2bConv2d::ElementAccumulator; + using ElementCompute = typename B2bConv2d::ElementCompute; + + static cutlass::conv::Operator const kConvolutionalOperator = B2bConv2d::kConvolutionalOperator; + +public: + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_Scale; + cutlass::Distribution::Kind init_Bias; + uint64_t seed; + + cutlass::HostTensor tensor_A0; + cutlass::HostTensor tensor_B0; + cutlass::HostTensor tensor_B0_reordered; + cutlass::HostTensor tensor_C0; + cutlass::HostTensor tensor_Scale0; + cutlass::HostTensor tensor_Bias0; + cutlass::HostTensor tensor_Z0_reference; + cutlass::HostTensor tensor_D0_reference; + + cutlass::HostTensor tensor_B1; + cutlass::HostTensor tensor_B1_reordered; + cutlass::HostTensor tensor_C1; + cutlass::HostTensor tensor_Bias1; + cutlass::HostTensor tensor_D1_computed; + cutlass::HostTensor tensor_D1_reference; + + +public: + + B2bInterleavedFusedConv2dRun( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), + init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { + + } + + /// Helper to initialize a tensor view + template + void initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 16) { + scope = 2; + } + else { + scope = 8; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope, -scope, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { + } + } + + void initialize( + cutlass::conv::Conv2dProblemSize const &problem_size_0, + cutlass::conv::Conv2dProblemSize const &problem_size_1, + ElementCompute alpha0, + ElementCompute alpha1, + uint64_t seed = 2019) { + + tensor_A0.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size_0)); + tensor_B0.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0)); + tensor_B0_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0)); + tensor_C0.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); + if(alpha0 == ElementCompute(0)) //per-channel scale + tensor_Scale0.resize({1, problem_size_0.K}); + tensor_Bias0.resize({1, problem_size_0.K}); + tensor_Z0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); + tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); + tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1)); + tensor_B1_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1)); + tensor_C1.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); + tensor_Bias1.resize({1, 1, 1, problem_size_1.K}); + tensor_D1_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); + tensor_D1_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); + + initialize_tensor(tensor_A0.host_view(), init_A, seed); + initialize_tensor(tensor_B0.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C0.host_view(), init_C, seed * 39); + if(alpha0 == ElementCompute(0)) //per-channel scale + initialize_tensor(tensor_Scale0.host_view(), init_Scale, seed * 61); + initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed * 83); + initialize_tensor(tensor_B1.host_view(), init_B, seed * 18); + initialize_tensor(tensor_C1.host_view(), init_C, seed * 40); + initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed * 84); + + //Reorder B0 and B1 + cutlass::reorder_convK<16, InterleavedK>( + tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), implicit_gemm_problem_size(kConvolutionalOperator, problem_size_0)); + cutlass::reorder_convK( + tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), implicit_gemm_problem_size(kConvolutionalOperator, problem_size_1)); + + tensor_A0.sync_device(); + tensor_B0.sync_device(); + tensor_B0_reordered.sync_device(); + tensor_C0.sync_device(); + if(alpha0 == ElementCompute(0)) //per-channel scale + tensor_Scale0.sync_device(); + tensor_Bias0.sync_device(); + tensor_D0_reference.sync_device(); + tensor_B1.sync_device(); + tensor_B1_reordered.sync_device(); + tensor_C1.sync_device(); + tensor_Bias1.sync_device(); + tensor_D1_computed.sync_device(); + tensor_D1_reference.sync_device(); + } + + /// Executes one test + bool run( + cutlass::conv::Conv2dProblemSize const &problem_size_0, + cutlass::conv::Conv2dProblemSize const &problem_size_1, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + ElementCompute alpha0 = ElementCompute(1), + ElementCompute beta0 = ElementCompute(0), + ElementCompute alpha1 = ElementCompute(1), + ElementCompute beta1 = ElementCompute(0), + bool relu = true, + int warm_ups = 1, + int runs = 100) { + + initialize(problem_size_0, problem_size_1, alpha0, alpha1); + + // configure the operator + B2bConv2d b2b_conv2d_op; + + typename B2bConv2d::Arguments b2b_conv2d_args( + problem_size_0, + problem_size_1, + tensor_A0.device_ref(), + tensor_B0_reordered.device_ref(), + tensor_C0.device_ref(), + tensor_Scale0.device_ref(), + tensor_Bias0.device_ref(), + tensor_B1_reordered.device_ref(), + tensor_C1.device_ref(), + tensor_D1_computed.device_ref(), + {alpha0, beta0}, + {alpha1, beta1}, + split_k_mode + ); + + cutlass::Status status = b2b_conv2d_op.can_implement(b2b_conv2d_args); + + if(status != cutlass::Status::kSuccess) { + std::cout << "Problem sizes not supported.\n" + << "Requirments:\n" + << " problem_size_0.N*P*Q = problem_size_1.N*P*Q\n" + << " problem_size_0.K = problem_size_1.C\n" + << " problem_size_1.R = problem_size_1.S = 1\n" + << " ThreadblockShape0::kN = problem_size_0.K\n" + << " ThreadblockShape1::kN = problem_size_1.K" << std::endl; + } + + CUTLASS_CHECK(status); + + status = b2b_conv2d_op.initialize(b2b_conv2d_args); + + CUTLASS_CHECK(status); + + for(int i = 0; i < warm_ups; i++) { + status = b2b_conv2d_op(); + CUTLASS_CHECK(status); + } + + // + // Run the Conv2d + // + + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + cudaEventRecord(start); + + for(int i = 0; i < runs; i++) { + + // run conv2d operator + status = b2b_conv2d_op(); + CUTLASS_CHECK(status); + } + + cudaEventRecord(stop); + cudaDeviceSynchronize(); + float conv2dTime; + cudaEventElapsedTime(&conv2dTime, start, stop); + std::cout << "Fusion time " << conv2dTime / (float)runs << " ms\n"; + + tensor_D1_computed.sync_host(); + + bool passed = false; + + cutlass::reference::device::Conv2d< + typename B2bConv2d::ElementA, + typename B2bConv2d::LayoutA, + typename B2bConv2d::ElementB, + typename B2bConv2d::LayoutB, + ElementAccumulator, + typename B2bConv2d::LayoutC, + ElementAccumulator, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size_0, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + tensor_Z0_reference.device_ref(), + tensor_Z0_reference.device_ref(), + ElementAccumulator(1), // intermediate alpha = 1 + ElementAccumulator(0) // beta = 0 + ); + + cutlass::reference::device::TensorScaleBiasConv2d< + ElementAccumulator, + typename B2bConv2d::ElementC, + typename B2bConv2d::LayoutC, + ElementCompute, + typename B2bConv2d::LayoutScaleBias, + cutlass::NumericConverterClamp + >( + problem_size_0, + tensor_Z0_reference.device_ref(), + tensor_D0_reference.device_ref(), + alpha0, + tensor_Scale0.device_ref(), + tensor_Bias0.device_ref() + ); + + if(relu) { + cutlass::reference::device::TensorReLu(tensor_D0_reference.device_view()); + } + + cutlass::reference::device::Conv2d< + typename B2bConv2d::ElementA, + typename B2bConv2d::LayoutA, + typename B2bConv2d::ElementB, + typename B2bConv2d::LayoutB, + typename B2bConv2d::ElementC, + typename B2bConv2d::LayoutC, + ElementCompute, + ElementAccumulator, + cutlass::NumericConverterClamp + >( + kConvolutionalOperator, + problem_size_1, + tensor_D0_reference.device_ref(), + tensor_B1.device_ref(), + tensor_C1.device_ref(), + tensor_D1_reference.device_ref(), + alpha1, + beta1); + + if(relu) { + cutlass::reference::device::TensorReLu(tensor_D1_reference.device_view()); + } + + cudaError_t result = cudaDeviceSynchronize(); + CHECK_TRUE(result == cudaSuccess); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_D0_reference.sync_host(); + tensor_D1_reference.sync_host(); + + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_reference.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_computed.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_reference.host_view()), 0); + + passed = cutlass::reference::host::TensorEquals( + tensor_D1_computed.host_view(), + tensor_D1_reference.host_view()); + + CHECK_TRUE(passed); + + if (!passed) { + std::stringstream fname; + + fname << "error_B2bImplicitGemm_device_interleaved_fused.txt"; + std::cerr << "Dumping results in " << fname.str() << "\n"; + + std::ofstream results(fname.str()); + + results << problem_size_0 << std::endl; + results << problem_size_1 << std::endl; + + results + << "\nA0:\n" << tensor_A0.host_view() << "\n" + << "\nB0:\n" << tensor_B0.host_view() << "\n" + << "\nB0_reordered:\n" << tensor_B0_reordered.host_view() << "\n" + << "\nC0:\n" << tensor_C0.host_view() << "\n" + << "\nScale0:\n" << tensor_Scale0.host_view() << "\n" + << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" + << "\nB1:\n" << tensor_B1.host_view() << "\n" + << "\nB1_reordered:\n" << tensor_B1_reordered.host_view() << "\n" + << "\nC1:\n" << tensor_C1.host_view() << "\n" + << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" + << "\nD1 reference:\n" << tensor_D1_reference.host_view() << "\n" + << "\nD1 computed:\n" << tensor_D1_computed.host_view(); + + + } + + return passed; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_interleaved_gemm_run.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_interleaved_gemm_run.h new file mode 100644 index 0000000000000000000000000000000000000000..453f44cd0c44f9320085460935a3478bc884ff0e --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_interleaved_gemm_run.h @@ -0,0 +1,798 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include +#include + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/host_reorder.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_relu.h" + +#include "reference/device/tensor_scale_bias.h" +#include "helper.h" + +#define CHECK_GT(val1, val2) \ + if((val1) <= (val2)) \ + std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n"; +#define CHECK_TRUE(val) \ + if(!(val)) \ + std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n"; + +template +struct B2bInterleavedNonFusedGemmRun +{ + + using Gemm0 = Gemm0_; + using Gemm1 = Gemm1_; + using ElementAccumulator = typename Gemm0::ElementAccumulator; + using ElementCompute = typename Gemm0::GemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_Bias; + uint64_t seed; + + // + // Methods + // + + B2bInterleavedNonFusedGemmRun( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, 2, -2, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { + std::cerr << "Not implemented\n"; + return false; + } + + return true; + } + + + + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size_0, + cutlass::gemm::GemmCoord problem_size_1, + ElementCompute alpha0 = ElementCompute(1), + ElementCompute beta0 = ElementCompute(0), + ElementCompute alpha1 = ElementCompute(1), + ElementCompute beta1 = ElementCompute(0), + bool relu = true, + int warm_ups = 1, + int runs = 100) { + + // + // Allocate the GEMM workspace + // + + cutlass::HostTensor< + typename Gemm0::ElementA, + typename Gemm0::LayoutA> tensor_A0(problem_size_0.mk()); + + cutlass::HostTensor< + typename Gemm0::ElementB, + typename Gemm0::LayoutB> tensor_B0(problem_size_0.kn()); + + cutlass::HostTensor< + typename Gemm0::ElementB, + typename Gemm0::LayoutB> tensor_B0_reordered(problem_size_0.kn()); + + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn()); + + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm0::LayoutC> tensor_Bias0({1, problem_size_0.n()}); + + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn()); + + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm0::LayoutC> reference_D0(problem_size_0.mn()); + + cutlass::HostTensor< + typename Gemm1::ElementB, + typename Gemm1::LayoutB> tensor_B1(problem_size_1.kn()); + + cutlass::HostTensor< + typename Gemm1::ElementB, + typename Gemm1::LayoutB> tensor_B1_reordered(problem_size_1.kn()); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn()); + + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm1::LayoutC> tensor_Bias1({1, problem_size_1.n()}); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn()); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm1::LayoutC> reference_D1(problem_size_1.mn()); + + CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); + CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018)); + CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); + CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2014)); + CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016)); + CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); + CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2013)); + + //Reorder B0 and B1 + cutlass::reorder_column( + tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), problem_size_0); + cutlass::reorder_column( + tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), problem_size_1); + + cutlass::reference::host::TensorFill( + tensor_D0.host_view()); + cutlass::reference::host::TensorFill( + tensor_D1.host_view()); + cutlass::reference::host::TensorFill( + reference_D0.host_view()); + cutlass::reference::host::TensorFill( + reference_D1.host_view()); + + tensor_A0.sync_device(); + tensor_B0.sync_device(); + tensor_B0_reordered.sync_device(); + tensor_C0.sync_device(); + tensor_Bias0.sync_device(); + tensor_D0.sync_device(); + tensor_B1.sync_device(); + tensor_B1_reordered.sync_device(); + tensor_C1.sync_device(); + tensor_Bias1.sync_device(); + tensor_D1.sync_device(); + reference_D0.sync_device(); + reference_D1.sync_device(); + + // + // Initialize the GEMM operator + // + + typename Gemm0::Arguments arguments_0{ + problem_size_0, + tensor_A0.device_ref(), + tensor_B0_reordered.device_ref(), + {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}, + tensor_D0.device_ref(), + {alpha0, beta0} + }; + + typename Gemm1::Arguments arguments_1{ + problem_size_1, + tensor_D0.device_ref(), + tensor_B1_reordered.device_ref(), + {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}, + tensor_D1.device_ref(), + {alpha1, beta1} + }; + + + Gemm0 gemm_op_0; + Gemm1 gemm_op_1; + + cutlass::Status status = gemm_op_0.initialize(arguments_0); + + CUTLASS_CHECK(status); + + status = gemm_op_1.initialize(arguments_1); + + CUTLASS_CHECK(status); + + for(int i = 0; i < warm_ups; i++) { + status = gemm_op_0(); + CUTLASS_CHECK(status); + status = gemm_op_1(); + CUTLASS_CHECK(status); + } + + // + // Run the GEMM + // + cudaEvent_t start, stop1, stop2; + cudaEventCreate(&start); + cudaEventCreate(&stop1); + cudaEventCreate(&stop2); + + cudaEventRecord(start); + + for(int i = 0; i < runs; i++) { + status = gemm_op_0(); + + CUTLASS_CHECK(status); + } + cudaEventRecord(stop1); + for(int i = 0; i < runs; i++) { + status = gemm_op_1(); + + CUTLASS_CHECK(status); + } + + cudaEventRecord(stop2); + cudaDeviceSynchronize(); + float gemm0Time, gemm1Time, totalTime; + cudaEventElapsedTime(&gemm0Time, start, stop1); + cudaEventElapsedTime(&gemm1Time, stop1, stop2); + cudaEventElapsedTime(&totalTime, start, stop2); + std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n"; + std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n"; + std::cout << "Non-fusion time " << totalTime / (float)runs << " ms\n"; + + tensor_D0.sync_host(); + tensor_D1.sync_host(); + + // + // Verify + // + cutlass::reference::device::Gemm< + typename Gemm0::ElementA, typename Gemm0::LayoutA, + typename Gemm0::ElementB, typename Gemm0::LayoutB, + typename Gemm0::ElementC, typename Gemm0::LayoutC, ElementCompute, + ElementAccumulator, typename Gemm0::Operator> + reference_gemm_0; + + cutlass::reference::device::Gemm< + typename Gemm1::ElementA, typename Gemm1::LayoutA, + typename Gemm1::ElementB, typename Gemm1::LayoutB, + typename Gemm1::ElementC, typename Gemm1::LayoutC, ElementCompute, + ElementAccumulator, typename Gemm1::Operator> + reference_gemm_1; + + reference_gemm_0( + problem_size_0, + alpha0, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + beta0, + {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}, + reference_D0.device_ref() + ); + + if(relu) { + cutlass::reference::device::TensorReLu(reference_D0.device_view()); + } + + reference_gemm_1( + problem_size_1, + alpha1, + reference_D0.device_ref(), + tensor_B1.device_ref(), + beta1, + {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}, + reference_D1.device_ref() + ); + + if(relu) { + cutlass::reference::device::TensorReLu(reference_D1.device_view()); + } + + // Wait for kernels to finish + cudaDeviceSynchronize(); + reference_D0.sync_host(); + reference_D1.sync_host(); + + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals( + reference_D1.host_view(), + tensor_D1.host_view()); + + CHECK_TRUE(passed); + if (!passed) { + + std::stringstream fname; + + fname << "error_B2bGemm_device_interleaved_nonfused.txt"; + std::cerr << "Dumping results in " << fname.str() << "\n"; + + std::ofstream file(fname.str()); + + file + << "A0 =\n" << tensor_A0.host_view() + << "\nB0 =\n" << tensor_B0.host_view() + << "\nB0_reordered =\n" << tensor_B0_reordered.host_view() + << "\nC0 =\n" << tensor_C0.host_view() + << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" + << "\nD0 =\n" << tensor_D0.host_view() + << "\nB1 =\n" << tensor_B1.host_view() + << "\nB1_reordered =\n" << tensor_B1_reordered.host_view() + << "\nC1 =\n" << tensor_C1.host_view() + << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" + << "\n\nReference =\n" << reference_D1.host_view() + << "\nComputed =\n" << tensor_D1.host_view(); + } + return passed; + } +}; + +template +struct B2bInterleavedFusedGemmRun +{ + + using B2bGemm = B2bGemm_; + using ElementAccumulator = typename B2bGemm::ElementAccumulator; + using ElementCompute = typename B2bGemm::B2bGemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_Scale; + cutlass::Distribution::Kind init_Bias; + uint64_t seed; + + // + // Methods + // + + B2bInterleavedFusedGemmRun( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), + init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, 2, -2, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { + std::cerr << "Not implemented\n"; + return false; + } + + return true; + } + + + + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size_0, + cutlass::gemm::GemmCoord problem_size_1, + ElementCompute alpha0 = ElementCompute(1), + ElementCompute beta0 = ElementCompute(0), + ElementCompute alpha1 = ElementCompute(1), + ElementCompute beta1 = ElementCompute(0), + cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm, + + // batch_count is used as split-k when mode is kGemm according + // to the GemmUniversal interface + + int batch_count = 1, + + int64_t batch_stride_A0 = 0, + int64_t batch_stride_B0 = 0, + int64_t batch_stride_C0 = 0, + int64_t batch_stride_B1 = 0, + int64_t batch_stride_C1 = 0, + int64_t batch_stride_D1 = 0, + int64_t batch_stride_Bias0 = 0, + int64_t batch_stride_Scale0 = 0, + bool relu = true, + int warm_ups = 1, + int runs = 100) { + + // + // Allocate the GEMM workspace + // + + cutlass::gemm::GemmCoord CoordA0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k()); + cutlass::gemm::GemmCoord CoordB0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k()); + cutlass::gemm::GemmCoord CoordC0(problem_size_0.m(), batch_count * problem_size_0.n(), problem_size_0.k()); + cutlass::gemm::GemmCoord CoordB1(problem_size_1.m(), problem_size_1.n(), batch_count * problem_size_1.k()); + cutlass::gemm::GemmCoord CoordC1(problem_size_1.m(), batch_count * problem_size_1.n(), problem_size_1.k()); + + cutlass::HostTensor< + typename B2bGemm::ElementA, + typename B2bGemm::LayoutA> tensor_A0(CoordA0.mk()); + + cutlass::HostTensor< + typename B2bGemm::ElementB, + typename B2bGemm::LayoutB> tensor_B0(CoordB0.kn()); + + cutlass::HostTensor< + typename B2bGemm::ElementB, + typename B2bGemm::LayoutB> tensor_B0_reordered(CoordB0.kn()); + + cutlass::HostTensor< + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> tensor_C0(CoordC0.mn()); + + cutlass::HostTensor< + typename B2bGemm::ElementScaleBias, + typename B2bGemm::LayoutScaleBias> tensor_Scale0; + + if(alpha0 == ElementCompute(0)) //per-channel scale + tensor_Scale0.resize({1, batch_count * problem_size_0.n()}); + + cutlass::HostTensor< + typename B2bGemm::ElementScaleBias, + typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, batch_count * problem_size_0.n()}); + + cutlass::HostTensor< + ElementAccumulator, + typename B2bGemm::LayoutC> reference_Z0(CoordC0.mn()); + + cutlass::HostTensor< + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> reference_D0(CoordC0.mn()); + + cutlass::HostTensor< + typename B2bGemm::ElementB, + typename B2bGemm::LayoutB> tensor_B1(CoordB1.kn()); + + cutlass::HostTensor< + typename B2bGemm::ElementB, + typename B2bGemm::LayoutB> tensor_B1_reordered(CoordB1.kn()); + + cutlass::HostTensor< + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> tensor_C1(CoordC1.mn()); + + cutlass::HostTensor< + typename B2bGemm::ElementC, + typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, batch_count * problem_size_1.n()}); + + cutlass::HostTensor< + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> tensor_D1(CoordC1.mn()); + + cutlass::HostTensor< + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> reference_D1(CoordC1.mn()); + + + CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); + CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018)); + CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); + if(alpha0 == ElementCompute(0)) //per-channel scale + CHECK_TRUE(initialize_tensor(tensor_Scale0.host_view(), init_Scale, seed + 2014)); + CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2013)); + CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016)); + CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); + CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2012)); + + //Reorder B0 + cutlass::reorder_column<16>( + tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), CoordB0); + cutlass::reorder_column( + tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), CoordB1); + + cutlass::reference::host::TensorFill( + tensor_D1.host_view()); + cutlass::reference::host::TensorFill( + reference_D0.host_view()); + cutlass::reference::host::TensorFill( + reference_D1.host_view()); + + tensor_A0.sync_device(); + tensor_B0.sync_device(); + tensor_B0_reordered.sync_device(); + tensor_C0.sync_device(); + if(alpha0 == ElementCompute(0)) //per-channel scale + tensor_Scale0.sync_device(); + tensor_Bias0.sync_device(); + tensor_B1.sync_device(); + tensor_B1_reordered.sync_device(); + tensor_C1.sync_device(); + tensor_Bias1.sync_device(); + tensor_D1.sync_device(); + reference_D0.sync_device(); + reference_D1.sync_device(); + // tensor_Bias0_batched.sync_device(); + + // + // Initialize the GEMM operator + // + + typename B2bGemm::Arguments arguments{ + mode, + problem_size_0, + problem_size_1, + tensor_A0.device_ref(), + tensor_B0_reordered.device_ref(), + tensor_C0.device_ref(), + tensor_Scale0.device_ref(), + tensor_Bias0.device_ref(), + tensor_B1_reordered.device_ref(), + {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)}, + tensor_D1.device_ref(), + batch_stride_A0, + batch_stride_B0, + batch_stride_B1, + batch_stride_C1, + batch_stride_D1, + batch_stride_Bias0, + batch_stride_Scale0, + {alpha0, beta0}, + {alpha1, beta1}, + batch_count, + }; + + B2bGemm b2b_gemm_op; + + cutlass::Status status = b2b_gemm_op.can_implement(arguments); + + if(status != cutlass::Status::kSuccess) { + std::cout << "Problem sizes not supported.\n" + << "Requirments:\n" + << " problem_size_0.M = problem_size_1.M\n" + << " problem_size_0.N = problem_size_1.K\n" + << " ThreadblockShape0::kN = problem_size_0.N\n" + << " ThreadblockShape1::kN = problem_size_1.N" << std::endl; + } + + status = b2b_gemm_op.initialize(arguments); + + CUTLASS_CHECK(status); + + for(int i = 0; i < warm_ups; i++) { + status = b2b_gemm_op(); + CUTLASS_CHECK(status); + } + + // + // Run the GEMM + // + + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + cudaEventRecord(start); + + for(int i = 0; i < runs; i++) { + status = b2b_gemm_op(); + + CUTLASS_CHECK(status); + } + + cudaEventRecord(stop); + cudaDeviceSynchronize(); + float gemmTime; + cudaEventElapsedTime(&gemmTime, start, stop); + std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n"; + + tensor_D1.sync_host(); + + // + // Verify + // + + cutlass::reference::device::GemmComplex< + typename B2bGemm::ElementA, typename B2bGemm::LayoutA, + typename B2bGemm::ElementB, typename B2bGemm::LayoutB, + ElementAccumulator, typename B2bGemm::LayoutC, + ElementAccumulator, ElementAccumulator + >( + problem_size_0, + ElementAccumulator(1), //intermediate alpha=1 + tensor_A0.device_ref(), + cutlass::ComplexTransform::kNone, + tensor_B0.device_ref(), + cutlass::ComplexTransform::kNone, + ElementAccumulator(0), //beta = 0 + reference_Z0.device_ref(), + reference_Z0.device_ref(), + ElementAccumulator(0), + int(batch_count), + batch_stride_A0, + batch_stride_B0, + batch_stride_C0, + batch_stride_C0 + ); + + cutlass::reference::device::TensorScaleBiasGemmBatched< + ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC, + ElementCompute, typename B2bGemm::LayoutScaleBias + > ( + problem_size_0, + reference_Z0.device_ref(), + reference_D0.device_ref(), + alpha0, + tensor_Scale0.device_ref(), + tensor_Bias0.device_ref(), + int(batch_count), + batch_stride_C0, + batch_stride_C0, + batch_stride_Scale0, + batch_stride_Bias0 + ); + + if(relu) { + cutlass::reference::device::TensorReLu(reference_D0.device_view()); + } + + cutlass::reference::device::GemmComplex< + typename B2bGemm::ElementA, typename B2bGemm::LayoutA, + typename B2bGemm::ElementB, typename B2bGemm::LayoutB, + typename B2bGemm::ElementC, typename B2bGemm::LayoutC, + ElementCompute, ElementAccumulator + >( + problem_size_1, + alpha1, //intermediate alpha=1 + reference_D0.device_ref(), + cutlass::ComplexTransform::kNone, + tensor_B1.device_ref(), + cutlass::ComplexTransform::kNone, + beta1, //beta = 0 + {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)}, + reference_D1.device_ref(), + ElementAccumulator(0), + int(batch_count), + batch_stride_C0, + batch_stride_B1, + batch_stride_C1, + batch_stride_D1 + ); + + if(relu) { + cutlass::reference::device::TensorReLu(reference_D1.device_view()); + } + + cudaDeviceSynchronize(); + reference_D0.sync_host(); + reference_D1.sync_host(); + + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals( + reference_D1.host_view(), + tensor_D1.host_view()); + + CHECK_TRUE(passed); + if (!passed) + { + + std::stringstream fname; + + fname << "error_B2bGemm_device_interleaved_fused.txt"; + std::cerr << "Dumping results in " << fname.str() << "\n"; + + std::ofstream file(fname.str()); + + file + << "A0 =\n" << tensor_A0.host_view() + << "\nB0 =\n" << tensor_B0.host_view() + << "\nB0_reordered =\n" << tensor_B0_reordered.host_view() + << "\nC0 =\n" << tensor_C0.host_view() + << "\nScale0:\n" << tensor_Scale0.host_view() << "\n" + << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" + << "\nB1 =\n" << tensor_B1.host_view() + << "\nB1_reordered =\n" << tensor_B1_reordered.host_view() + << "\nC1 =\n" << tensor_C1.host_view() + << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" + << "\n\nReference =\n" << reference_D1.host_view() + << "\nComputed =\n" << tensor_D1.host_view(); + } + return passed; + } + +}; + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/device/b2b_gemm.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/device/b2b_gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..f9b2f49cd67708e655b17fdae086a731db66571a --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/device/b2b_gemm.h @@ -0,0 +1,352 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" + +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/epilogue/thread/linear_combination_relu.h" + +#include "kernel/b2b_gemm.h" +#include "kernel/default_b2b_gemm.h" +#include "kernel/default_b2b_gemm_smem_accumulator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassSimt, + /// Tag indicating architecture to tune for + typename ArchTag_ = arch::Sm70, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp0_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Epilogue output operator + typename EpilogueOutputOp1_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Stage accumulator in shared memory + bool SmemAccumulator = false, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator> +class B2bGemm { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape0 = ThreadblockShape0_; + using ThreadblockShape1 = ThreadblockShape1_; + using WarpShape0 = WarpShape0_; + using WarpShape1 = WarpShape1_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp0 = EpilogueOutputOp0_; + using EpilogueOutputOp1 = EpilogueOutputOp1_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp1::kCount; + static ComplexTransform const kTransformA = ComplexTransform::kNone; + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + /// Derived types + using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; + using LayoutScaleBias = layout::RowMajor; + + /// Define the kernel + using B2bGemmKernel = typename kernel::DefaultB2bGemm< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + kStages, + Operator, + SmemAccumulator + >::B2bGemmKernel; + + using Arguments = typename B2bGemmKernel::Arguments; + +private: + + /// Kernel parameters object + typename B2bGemmKernel::Params params_; + +public: + + /// Constructs the GEMM. + B2bGemm() { } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args) { + + Status status = B2bGemmKernel::can_implement( + args.problem_size_0, + args.problem_size_1, + args.ref_A0.non_const_ref(), + args.ref_B0.non_const_ref(), + args.ref_C0.non_const_ref(), + args.ref_B1.non_const_ref(), + args.ref_C1.non_const_ref(), + args.ref_D1 + ); + + if (status != Status::kSuccess) { + return status; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + size_t bytes = 0; + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size_0, + {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK}, + args.batch_count); + + return bytes; + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size_0, + {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK}, + args.batch_count); +// cutlass::gemm::GemmCoord grid_shape_1 = threadblock_swizzle.get_tiled_shape( +// args.problem_size_1, +// {ThreadblockShape1::kM, ThreadblockShape1::kN, ThreadblockShape1::kK}, +// args.batch_count); + + // Initialize the Params structure + params_ = typename B2bGemmKernel::Params{ + args.mode, + args.problem_size_0, + args.problem_size_1, + grid_shape, + args.ref_A0.non_const_ref(), + args.ref_B0.non_const_ref(), + args.ref_C0.non_const_ref(), + args.ref_Scale0.non_const_ref(), + args.ref_Bias0.non_const_ref(), + args.ref_B1.non_const_ref(), + args.ref_C1.non_const_ref(), + args.ref_D1, + args.batch_stride_A0, + args.batch_stride_B0, + args.batch_stride_B1, + args.batch_stride_C1, + args.batch_stride_D1, + args.batch_stride_Bias0, + args.batch_stride_Scale0, + args.epilogue0, + args.epilogue1, + static_cast(workspace), + }; + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + params_.ref_A0.reset(args.ref_A0.non_const_ref().data()); + params_.ref_B0.reset(args.ref_B0.non_const_ref().data()); + params_.ref_C0.reset(args.ref_C0.non_const_ref().data()); + params_.ref_Scale0.reset(args.ref_Scale0.non_const_ref().data()); + params_.ref_Bias0.reset(args.ref_Bias0.non_const_ref().data()); + params_.ref_B1.reset(args.ref_B1.non_const_ref().data()); + params_.ref_C1.reset(args.ref_C1.non_const_ref().data()); + params_.ref_D1.reset(args.ref_D1.data()); + params_.output_op_0 = args.epilogue0; + params_.output_op_1 = args.epilogue1; + params_.semaphore = static_cast(workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(B2bGemmKernel::kThreadCount, 1, 1); + + cudaError_t result; + + int smem_size = int(sizeof(typename B2bGemmKernel::SharedStorage)); + if (smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + cutlass::Kernel<<>>(params_); + + result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/device/b2b_implicit_gemm_convolution.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/device/b2b_implicit_gemm_convolution.h new file mode 100644 index 0000000000000000000000000000000000000000..e780537aff4cdc733a66044c771eb7bcc6a7ec03 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/device/b2b_implicit_gemm_convolution.h @@ -0,0 +1,300 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Template for device-level Implicit GEMM +*/ + +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" +#include "cutlass/conv/convolution.h" + +#include "kernel/b2b_implicit_gemm_convolution.h" +#include "kernel/default_b2b_conv2d_fprop.h" +#include "kernel/default_b2b_conv2d_fprop_sm75.h" +#include "kernel/default_b2b_conv2d_fprop_sm80.h" +#include "kernel/default_b2b_conv2d_fprop_smem_accumulator_sm75.h" +#include "kernel/default_b2b_conv2d_fprop_smem_accumulator_sm80.h" + +namespace cutlass { +namespace conv { +namespace device { + +template +class B2bImplicitGemmConvolution { +public: + + using B2bImplicitGemmKernel = B2bImplicitGemmKernel_; + + using ElementA = typename B2bImplicitGemmKernel::ElementA; + using LayoutA = typename B2bImplicitGemmKernel::LayoutA; + using ElementB = typename B2bImplicitGemmKernel::ElementB; + using LayoutB = typename B2bImplicitGemmKernel::LayoutB; + using ElementC = typename B2bImplicitGemmKernel::ElementC; + using LayoutC = typename B2bImplicitGemmKernel::LayoutC; + using ElementAccumulator = typename B2bImplicitGemmKernel::ElementAccumulator; + using ElementCompute = typename B2bImplicitGemmKernel::ElementCompute; + using ElementScaleBias = typename B2bImplicitGemmKernel::ElementScaleBias; + using LayoutScaleBias = typename B2bImplicitGemmKernel::LayoutScaleBias; + using OperatorClass = typename B2bImplicitGemmKernel::OperatorClass; + using ArchTag = typename B2bImplicitGemmKernel::ArchTag; + using ThreadblockShape0 = typename B2bImplicitGemmKernel::ThreadblockShape0; + using ThreadblockShape1 = typename B2bImplicitGemmKernel::ThreadblockShape1; + using WarpShape0 = typename B2bImplicitGemmKernel::WarpShape0; + using WarpShape1 = typename B2bImplicitGemmKernel::WarpShape1; + using InstructionShape = typename B2bImplicitGemmKernel::InstructionShape; + using ThreadblockSwizzle = typename B2bImplicitGemmKernel::ThreadblockSwizzle; + using EpilogueOutputOp0 = typename B2bImplicitGemmKernel::EpilogueOutputOp0; + using EpilogueOutputOp1 = typename B2bImplicitGemmKernel::EpilogueOutputOp1; + static int const kStages = B2bImplicitGemmKernel::kStages; + static int const kConvDim = B2bImplicitGemmKernel::kConvDim; + using WarpMmaOperator0 = typename B2bImplicitGemmKernel::WarpMmaOperator0; + using WarpMmaOperator1 = typename B2bImplicitGemmKernel::WarpMmaOperator1; + using ArchMmaOperator = typename B2bImplicitGemmKernel::ArchMmaOperator; + using MathOperator = typename B2bImplicitGemmKernel::MathOperator; + + static cutlass::conv::Operator const kConvolutionalOperator = B2bImplicitGemmKernel::kConvolutionalOperator; + static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = B2bImplicitGemmKernel::kIteratorAlgorithm; + + static int const kWarpCount = + (ThreadblockShape0::kM / WarpShape0::kM) * + (ThreadblockShape0::kN / WarpShape0::kN); + + /// Argument structure + using Arguments = typename B2bImplicitGemmKernel::Arguments; + +private: + + /// Kernel parameters object + typename B2bImplicitGemmKernel::Params params_; + +public: + + /// Constructs Implicit GEMM + B2bImplicitGemmConvolution() { } + + /// Determines whether the Implicit GEMM can execute the given problem. + static Status can_implement(Arguments const &args) { + + // dispatch to iterators + Status status = B2bImplicitGemmKernel::B2bMma::IteratorA0::can_implement(args.problem_size_0); + if (Status::kSuccess != status) { + return status; + } + + status = B2bImplicitGemmKernel::B2bMma::IteratorB0::can_implement(args.problem_size_0); + if (Status::kSuccess != status) { + return status; + } + + status = B2bImplicitGemmKernel::B2bMma::IteratorB1::can_implement(args.problem_size_1); + if (Status::kSuccess != status) { + return status; + } + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape( + threadblock_swizzle.get_tiled_shape( + cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_0), + {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK}, + args.problem_size_0.split_k_slices)); + + if (!(grid.y <= std::numeric_limits::max() && + grid.z <= std::numeric_limits::max())) { + + return Status::kErrorInvalidProblem; + } + + // Determine if fusion sizes are valid + + cutlass::gemm::GemmCoord problem_size_0 = implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_0); + cutlass::gemm::GemmCoord problem_size_1 = implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_1); + + if(problem_size_0.m() != problem_size_1.m()) + return Status::kErrorInvalidProblem; + + if(problem_size_0.n() != problem_size_1.k()) + return Status::kErrorInvalidProblem; + + if(args.problem_size_1.R != 1 || args.problem_size_1.S != 1) + return Status::kErrorInvalidProblem; + + if(problem_size_0.n() > ThreadblockShape0::kN) + return Status::kErrorInvalidProblem; + + if(problem_size_1.n() > ThreadblockShape1::kN) + return Status::kErrorInvalidProblem; + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + size_t workspace_bytes = 0; + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_0), + {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK}, + args.problem_size_0.split_k_slices); + + if(args.split_k_mode == SplitKMode::kParallel) { + + // Split-K parallel: CTAs in k-dimension write the partial results in a temporary workspace. + // The user needs to call a reduction operator to obtain the final output tensor + workspace_bytes = + sizeof(ElementAccumulator) * + size_t(cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, args.problem_size_0)) * + size_t(grid_tiled_shape.k()); + } + + else if(args.split_k_mode == SplitKMode::kSerial && args.problem_size_0.split_k_slices > 1) { + + // Split-K serial: The user workspace is used to store semaphore and serialize writing the + // final reduced output to user's output tensor + workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n()); + } + + return workspace_bytes; + } + + /// Initializes GEMM state from arguments. + Status initialize( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + if (args.problem_size_0.split_k_slices > 1) { + + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + + cudaError_t status = cudaMemsetAsync(workspace, 0, get_workspace_size(args), stream); + + if (status != cudaSuccess) { + return Status::kErrorInternal; + } + } + + // initialize the params structure from the arguments + params_ = typename B2bImplicitGemmKernel::Params( + args, + static_cast(workspace) + ); + + int smem_size = int(sizeof(typename B2bImplicitGemmKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Initializes GEMM state from arguments. + Status update(Arguments const &args, void *workspace = nullptr) { + + // update the params structure from the arguments + params_.ptr_A0 = args.ref_A0.data(); + params_.ptr_B0 = args.ref_B0.data(); + params_.ptr_C0 = args.ref_C0.data(); + params_.ptr_Scale0 = args.ref_Scale0.data(); + params_.ptr_Bias0 = args.ref_Bias0.data(); + params_.ptr_B1 = args.ref_B1.data(); + params_.ptr_C1 = args.ref_C1.data(); + params_.ptr_D1 = args.ref_D1.data(); + params_.output_op_0 = args.output_op_0; + params_.output_op_1 = args.output_op_1; + params_.semaphore = static_cast(workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(32 * kWarpCount, 1, 1); + + int smem_size = int(sizeof(typename B2bImplicitGemmKernel::SharedStorage)); + + cutlass::Kernel<<>>(params_); + + cudaError_t result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +} // namespace device +} // namespace conv +} // namespace cutlass +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..ef11f4ab83fb0aff78d7f494bb57df7392474e48 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h @@ -0,0 +1,811 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" + +#include "kernel/b2b_gemm_grouped_problem_visitor.h" +#include "threadblock/grouped_threadblock_swizzle.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +namespace detail { + +/// Utility struct for returning the type of the problem visitor used by the swizzling function, +/// if it is a grouped swizzling function, or a default visitor. This is used only for defining +/// the parameters of the problem visitor used in GroupedParams. +template < + typename B2bMma_, + typename ThreadblockSwizzle_, + typename Enable = void +> +struct ProblemVisitorOrDefault; + +/// Return a generic problem visitor for GEMM problems +template < + typename B2bMma_, + typename ThreadblockSwizzle_ +> +struct ProblemVisitorOrDefault::value + >::type> { + using value = B2bGemmGroupedProblemVisitor::value>; +}; + +/// Return the problem visitor specified by the swizzling function +template < + typename B2bMma_, + typename ThreadblockSwizzle_ +> +struct ProblemVisitorOrDefault::value + >::type> { + using value = typename ThreadblockSwizzle_::ProblemVisitor; +}; + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_ ///! Threadblock swizzling function +> +struct B2bGemm { + + using B2bMma = B2bMma_; + using Epilogue = Epilogue_; + using OutputOp0 = typename B2bMma::OutputOp; + using OutputOp1 = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using ElementA0 = typename B2bMma::IteratorA0::Element; + using LayoutA0 = typename B2bMma::IteratorA0::Layout; + using ElementB0 = typename B2bMma::IteratorB0::Element; + using LayoutB0 = typename B2bMma::IteratorB0::Layout; + using ElementB1 = typename B2bMma::IteratorB1::Element; + using LayoutB1 = typename B2bMma::IteratorB1::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + + using ScaleBiasData = typename B2bMma::IteratorAccumulatorScaleBias::Element; + + /// Data types needed for higher-level containers. In some cases, a single type must be exposed + /// despite the B2b GEMM using two GEMMs under the hood. In such cases, we select the values from + /// the second GEMM (other than for ElementA/ElementB) + using ElementA = typename B2bMma::IteratorA0::Element; + using LayoutA = typename B2bMma::IteratorA0::Layout; + using ElementB = typename B2bMma::IteratorB0::Element; + using LayoutB = typename B2bMma::IteratorB0::Layout; + + static ComplexTransform const kTransformA = B2bMma::kTransformA; + static ComplexTransform const kTransformB = B2bMma::kTransformB; + using Operator = typename B2bMma::Operator0; + + using OperatorClass = typename Operator::OperatorClass; + using ThreadblockShape = typename B2bMma::Shape0; + using WarpShape = typename Operator::Shape; + using InstructionShape = typename Operator::InstructionShape; + using ArchTag = typename B2bMma::ArchTag; + + static int const kStages = B2bMma::kStages; + static int const kAlignmentA = B2bMma::IteratorA::AccessType::kElements; + static int const kAlignmentB = B2bMma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + using Mma = B2bMma; + using EpilogueOutputOp = OutputOp1; + + /// Warp count (concept: GemmShape) + using WarpCount0 = typename B2bMma::WarpCount0; + static int const kThreadCount = 32 * WarpCount0::kCount; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm; + GemmCoord problem_size_0{0,0,0}; + GemmCoord problem_size_1{0,0,0}; + typename B2bMma::IteratorA0::TensorRef ref_A0{}; + typename B2bMma::IteratorB0::TensorRef ref_B0{}; + typename Epilogue::OutputTileIterator::TensorRef ref_C0{}; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0{}; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0{}; + typename B2bMma::IteratorB1::TensorRef ref_B1{}; + typename Epilogue::OutputTileIterator::TensorRef ref_C1{}; + typename Epilogue::OutputTileIterator::TensorRef ref_D1{}; + int64_t batch_stride_A0{0}; + int64_t batch_stride_B0{0}; + int64_t batch_stride_B1{0}; + int64_t batch_stride_C1{0}; + int64_t batch_stride_D1{0}; + int64_t batch_stride_Bias0{0}; + int64_t batch_stride_Scale0{0}; + typename OutputOp0::Params epilogue0 {}; + typename OutputOp1::Params epilogue1 {}; + int batch_count{1}; + + // + // Methods + // + + /// Default ctor + Arguments() = default; + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + GemmUniversalMode mode_, + GemmCoord problem_size_0_, + GemmCoord problem_size_1_, + typename B2bMma::IteratorA0::TensorRef ref_A0_, + typename B2bMma::IteratorB0::TensorRef ref_B0_, + typename Epilogue::OutputTileIterator::TensorRef ref_C0_, + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0_, + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0_, + typename B2bMma::IteratorB1::TensorRef ref_B1_, + typename Epilogue::OutputTileIterator::TensorRef ref_C1_, + typename Epilogue::OutputTileIterator::TensorRef ref_D1_, + int64_t batch_stride_A0_, + int64_t batch_stride_B0_, + int64_t batch_stride_B1_, + int64_t batch_stride_C1_, + int64_t batch_stride_D1_, + int64_t batch_stride_Bias0_, + int64_t batch_stride_Scale0_, + typename OutputOp0::Params epilogue0_ = typename OutputOp0::Params(), + typename OutputOp1::Params epilogue1_ = typename OutputOp1::Params(), + int batch_count_ = 1 + ): + mode(mode_), + problem_size_0(problem_size_0_), + problem_size_1(problem_size_1_), + ref_A0(ref_A0_), + ref_B0(ref_B0_), + ref_C0(ref_C0_), + ref_Scale0(ref_Scale0_), + ref_Bias0(ref_Bias0_), + ref_B1(ref_B1_), + ref_C1(ref_C1_), + ref_D1(ref_D1_), + batch_stride_A0(batch_stride_A0_), + batch_stride_B0(batch_stride_B0_), + batch_stride_B1(batch_stride_B1_), + batch_stride_C1(batch_stride_C1_), + batch_stride_D1(batch_stride_D1_), + batch_stride_Bias0(batch_stride_Bias0_), + batch_stride_Scale0(batch_stride_Scale0_), + epilogue0(epilogue0_), + epilogue1(epilogue1_), + batch_count(batch_count_) { + } + }; + + // Arguments structure for grouped B2B problems + struct GroupedArguments { + GemmCoord* problem_size_0; + GemmCoord* problem_size_1; + typename B2bMma::IteratorA0::TensorRef* ref_A0; + typename B2bMma::IteratorB0::TensorRef* ref_B0; + typename Epilogue::OutputTileIterator::TensorRef* ref_C0; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0; + typename B2bMma::IteratorB1::TensorRef* ref_B1; + typename Epilogue::OutputTileIterator::TensorRef* ref_C1; + typename Epilogue::OutputTileIterator::TensorRef* ref_D1; + + // Epilogue params remain constant across all problems in the group. Thus, + // the parameter here is not a pointer. + typename OutputOp0::Params epilogue0; + typename OutputOp1::Params epilogue1; + + int problem_count; + int threadblock_count; + GemmCoord* host_problem_sizes; + + CUTLASS_HOST_DEVICE + GroupedArguments( + int problem_count, + GemmCoord* problem_size_0_, + GemmCoord* problem_size_1_, + typename B2bMma::IteratorA0::TensorRef* ref_A0_, + typename B2bMma::IteratorB0::TensorRef* ref_B0_, + typename Epilogue::OutputTileIterator::TensorRef* ref_C0_, + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0_, + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0_, + typename B2bMma::IteratorB1::TensorRef* ref_B1_, + typename Epilogue::OutputTileIterator::TensorRef* ref_C1_, + typename Epilogue::OutputTileIterator::TensorRef* ref_D1_, + typename OutputOp0::Params epilogue0_ = typename OutputOp0::Params(), + typename OutputOp1::Params epilogue1_ = typename OutputOp1::Params(), + int threadblock_count = 0 + ) : problem_size_0(problem_size_0_), problem_size_1(problem_size_1_), + ref_A0(ref_A0_), ref_B0(ref_B0_), ref_C0(ref_C0_), + ref_Scale0(ref_Scale0_), ref_Bias0(ref_Bias0_), ref_B1(ref_B1_), + ref_C1(ref_C1_), ref_D1(ref_D1_), epilogue0(epilogue0_), epilogue1(epilogue1_), + problem_count(problem_count), + threadblock_count(threadblock_count) + {} + }; + + /// Parameters structure + struct Params { + cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm; + cutlass::gemm::GemmCoord problem_size_0{}; + cutlass::gemm::GemmCoord problem_size_1{}; + cutlass::gemm::GemmCoord grid_tiled_shape{}; + int swizzle_log_tile{0}; + typename B2bMma::IteratorA0::Params params_A0{}; + typename B2bMma::IteratorA0::TensorRef ref_A0{}; + typename B2bMma::IteratorB0::Params params_B0{}; + typename B2bMma::IteratorB0::TensorRef ref_B0{}; + typename Epilogue::OutputTileIterator::Params params_C0{}; + typename Epilogue::OutputTileIterator::TensorRef ref_C0{}; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0{}; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0{}; + typename B2bMma::IteratorB1::Params params_B1{}; + typename B2bMma::IteratorB1::TensorRef ref_B1{}; + typename Epilogue::OutputTileIterator::Params params_C1{}; + typename Epilogue::OutputTileIterator::TensorRef ref_C1{}; + typename Epilogue::OutputTileIterator::Params params_D1{}; + typename Epilogue::OutputTileIterator::TensorRef ref_D1{}; + typename OutputOp0::Params output_op_0{}; + typename OutputOp1::Params output_op_1{}; + int64_t batch_stride_A0{0}; + int64_t batch_stride_B0{0}; + int64_t batch_stride_B1{0}; + int64_t batch_stride_C1{0}; + int64_t batch_stride_D1{0}; + int64_t batch_stride_Bias0{0}; + int64_t batch_stride_Scale0{0}; + int *semaphore = nullptr; + int gemm_k_iterations_0{0}; + int gemm_k_size_0{0}; + int gemm_k_iterations_1{0}; + int gemm_k_size_1{0}; + + // + // Methods + // + + Params() = default; + + CUTLASS_HOST_DEVICE + Params( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord const & problem_size_0, + cutlass::gemm::GemmCoord const & problem_size_1, + cutlass::gemm::GemmCoord const & grid_tiled_shape, + typename B2bMma::IteratorA0::TensorRef ref_A0, + typename B2bMma::IteratorB0::TensorRef ref_B0, + typename Epilogue::OutputTileIterator::TensorRef ref_C0, + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0, + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0, + typename B2bMma::IteratorB1::TensorRef ref_B1, + typename Epilogue::OutputTileIterator::TensorRef ref_C1, + typename Epilogue::OutputTileIterator::TensorRef ref_D1, + int64_t batch_stride_A0, + int64_t batch_stride_B0, + int64_t batch_stride_B1, + int64_t batch_stride_C1, + int64_t batch_stride_D1, + int64_t batch_stride_Bias0, + int64_t batch_stride_Scale0, + typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(), + typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(), + int *workspace = nullptr + ): + mode(mode), + problem_size_0(problem_size_0), + problem_size_1(problem_size_1), + grid_tiled_shape(grid_tiled_shape), + swizzle_log_tile(ThreadblockSwizzle::get_log_tile(grid_tiled_shape)), + params_A0(ref_A0.layout()), + ref_A0(ref_A0), + params_B0(ref_B0.layout()), + ref_B0(ref_B0), + params_C0(ref_C0.layout()), + ref_C0(ref_C0), + ref_Scale0(ref_Scale0), + ref_Bias0(ref_Bias0), + params_B1(ref_B1.layout()), + ref_B1(ref_B1), + params_C1(ref_C1.layout()), + ref_C1(ref_C1), + params_D1(ref_D1.layout()), + ref_D1(ref_D1), + batch_stride_A0(batch_stride_A0), + batch_stride_B0(batch_stride_B0), + batch_stride_B1(batch_stride_B1), + batch_stride_C1(batch_stride_C1), + batch_stride_D1(batch_stride_D1), + batch_stride_Bias0(batch_stride_Bias0), + batch_stride_Scale0(batch_stride_Scale0), + output_op_0(output_op_0), + output_op_1(output_op_1) { + + int total_gemm_k_iterations_0 = (problem_size_0.k() + B2bMma::Shape0::kK - 1) / B2bMma::Shape0::kK; + int gemm_k_iterations_0 = (total_gemm_k_iterations_0 + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); + gemm_k_size_0 = gemm_k_iterations_0 * B2bMma::Shape0::kK; + int total_gemm_k_iterations_1 = (problem_size_1.k() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK; + int gemm_k_iterations_1 = (total_gemm_k_iterations_1 + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); + gemm_k_size_1 = gemm_k_iterations_1 * B2bMma::Shape1::kK; + + semaphore = workspace; + } + }; + + struct GroupedParams { + cutlass::gemm::GemmCoord* problem_size_0; + cutlass::gemm::GemmCoord* problem_size_1; + cutlass::gemm::GemmCoord* grid_tiled_shape; + typename B2bMma::IteratorA0::TensorRef* ref_A0; + typename B2bMma::IteratorB0::TensorRef* ref_B0; + typename Epilogue::OutputTileIterator::TensorRef* ref_C0; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0; + typename B2bMma::IteratorB1::TensorRef* ref_B1; + typename Epilogue::OutputTileIterator::TensorRef* ref_C1; + typename Epilogue::OutputTileIterator::TensorRef* ref_D1; + + // Epilogue params remain constant across all problems in the group. Thus, + // the parameter here is not a pointer. + typename OutputOp0::Params output_op_0; + typename OutputOp1::Params output_op_1; + + using ProblemVisitor = typename detail::ProblemVisitorOrDefault::value; + typename ProblemVisitor::Params problem_visitor; + int threadblock_count; + int* workspace; + + CUTLASS_HOST_DEVICE + GroupedParams() {} + + CUTLASS_HOST_DEVICE + GroupedParams( + GroupedArguments const &args, + void *workspace = nullptr, + int tile_count = 0 + ) : + problem_size_0(args.problem_size_0), problem_size_1(args.problem_size_1), + ref_A0(args.ref_A0), ref_B0(args.ref_B0), ref_C0(args.ref_C0), + ref_Scale0(args.ref_Scale0), ref_Bias0(args.ref_Bias0), ref_B1(args.ref_B1), ref_C1(args.ref_C1), ref_D1(args.ref_D1), + output_op_0(args.epilogue0), output_op_1(args.epilogue1), + problem_visitor(args.problem_size_0, args.problem_size_1, args.problem_count, workspace, tile_count), + threadblock_count(args.threadblock_count), + workspace(reinterpret_cast(workspace)) {} + + CUTLASS_HOST_DEVICE + void transpose() { + // Only row-major outputs are currently supported, so no transpose is performed + } + + /// Returns non-grouped parameters to be used as input to the kernel-level + /// operator for the problem indicated by problem_visitor. + CUTLASS_HOST_DEVICE + Params to_single_params(const ProblemVisitor& problem_visitor) const { + GemmCoord problem_size0 = problem_visitor.problem_size0(); + GemmCoord problem_size1 = problem_visitor.problem_size1(); + int32_t idx = problem_visitor.problem_index(); + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size1); + + return Params( + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size0, + problem_size1, + grid_shape, + ref_A0[idx], + ref_B0[idx], + ref_C0[idx], + ref_Scale0[idx], + ref_Bias0[idx], + ref_B1[idx], + ref_C1[idx], + ref_D1[idx], + 0, 0, 0, 0, 0, 0, 0, // Batched B2B GEMMs within the grouped kernel are currently unsupported + output_op_0, + output_op_1, + workspace + ); + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename B2bMma::B2bMmaSharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + B2bGemm() { } + + /// Determines whether kernel satisfies alignment + static Status can_implement( + cutlass::gemm::GemmCoord const & problem_size_0, + cutlass::gemm::GemmCoord const & problem_size_1, + typename B2bMma::IteratorA0::TensorRef ref_A0, + typename B2bMma::IteratorB0::TensorRef ref_B0, + typename Epilogue::OutputTileIterator::TensorRef ref_C0, + typename B2bMma::IteratorB1::TensorRef ref_B1, + typename Epilogue::OutputTileIterator::TensorRef ref_C1, + typename Epilogue::OutputTileIterator::TensorRef ref_D1) { + + static int const kAlignmentA = B2bMma::IteratorA0::AccessType::kElements; + static int const kAlignmentB = B2bMma::IteratorB0::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + if (!TensorRef_aligned(ref_A0, kAlignmentA)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_B0, kAlignmentB)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_C0, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_B1, kAlignmentB)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_C1, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_D1, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if ((problem_size_0.m() % kAlignmentA) || (problem_size_0.k() % kAlignmentA) || + (problem_size_0.n() % kAlignmentB) || (problem_size_0.k() % kAlignmentB) || + (problem_size_0.m() % kAlignmentC) || (problem_size_0.n() % kAlignmentC) || + (problem_size_1.m() % kAlignmentA) || (problem_size_1.k() % kAlignmentA) || + (problem_size_1.n() % kAlignmentB) || (problem_size_1.k() % kAlignmentB) || + (problem_size_1.m() % kAlignmentC) || (problem_size_1.n() % kAlignmentC)) { + + return Status::kErrorMisalignedOperand; + } + + // Determine if fusion sizes are valid + if(problem_size_0.m() != problem_size_1.m()) + return Status::kErrorInvalidProblem; + + if(problem_size_0.n() != problem_size_1.k()) + return Status::kErrorInvalidProblem; + + if(problem_size_0.n() > B2bMma::Shape0::kN) + return Status::kErrorInvalidProblem; + + if(problem_size_1.n() > B2bMma::Shape1::kN) + return Status::kErrorInvalidProblem; + + return Status::kSuccess; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + ThreadblockSwizzle threadblock_swizzle; + run_with_swizzle(params, shared_storage, threadblock_swizzle); + } + + /// Executes one GEMM with an externally-provided swizzling function + CUTLASS_DEVICE + void run_with_swizzle(Params const ¶ms, SharedStorage &shared_storage, ThreadblockSwizzle& threadblock_swizzle) { + + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + + return; + } + + ElementA0 *ptr_A0 = static_cast(params.ref_A0.data()); + ElementB0 *ptr_B0 = static_cast(params.ref_B0.data()); + ElementB1 *ptr_B1 = static_cast(params.ref_B1.data()); + + ScaleBiasData *ptr_Bias0 = static_cast(params.ref_Bias0.data()); + ScaleBiasData *ptr_Scale0 = static_cast(params.ref_Scale0.data()); + + int offset_k_0 = 0; + int offset_k_1 = 0; + + int problem_size_k_0 = params.problem_size_0.k(); + int problem_size_k_1 = params.problem_size_1.k(); + + if (params.mode == GemmUniversalMode::kGemm) { + + // Problem size is a function of threadblock index in the K dimension + problem_size_k_0 = min( + problem_size_k_0, + (threadblock_tile_offset.k() + 1) * params.gemm_k_size_0); + + // Problem size is a function of threadblock index in the K dimension + problem_size_k_1 = min( + problem_size_k_1, + (threadblock_tile_offset.k() + 1) * params.gemm_k_size_1); + + offset_k_0 = threadblock_tile_offset.k() * params.gemm_k_size_0; + offset_k_1 = threadblock_tile_offset.k() * params.gemm_k_size_1; + } + + else if (params.mode == GemmUniversalMode::kBatched) { + ptr_A0 += threadblock_tile_offset.k() * params.batch_stride_A0; + ptr_B0 += threadblock_tile_offset.k() * params.batch_stride_B0; + ptr_B1 += threadblock_tile_offset.k() * params.batch_stride_B1; + ptr_Bias0 += threadblock_tile_offset.k() * params.batch_stride_Bias0; + ptr_Scale0 += threadblock_tile_offset.k() * params.batch_stride_Scale0; + } + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A0{ + threadblock_tile_offset.m() * B2bMma::Shape0::kM, + offset_k_0, + }; + + cutlass::MatrixCoord tb_offset_B0{ + offset_k_0, + threadblock_tile_offset.n() * B2bMma::Shape0::kN + }; + + cutlass::MatrixCoord tb_offset_B1{ + offset_k_1, + threadblock_tile_offset.n() * B2bMma::Shape1::kN + }; + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations_0 = (problem_size_k_0 - tb_offset_A0.column() + B2bMma::Shape0::kK - 1) / B2bMma::Shape0::kK; + + // Compute threadblock-scoped matrix multiply-add + // int gemm_k_iterations_1 = (problem_size_k_1 - tb_offset_B1.row() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK; + + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename B2bMma::IteratorA0 iterator_A0( + params.params_A0, + ptr_A0, + {params.problem_size_0.m(), problem_size_k_0}, + thread_idx, + tb_offset_A0); + + typename B2bMma::IteratorB0 iterator_B0( + params.params_B0, + ptr_B0, + {problem_size_k_0, params.problem_size_0.n()}, + thread_idx, + tb_offset_B0); + + typename B2bMma::IteratorB1 iterator_B1( + params.params_B1, + ptr_B1, + {problem_size_k_1, params.problem_size_1.n()}, + thread_idx, + tb_offset_B1); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + + // Construct iterators to accumulator scale/bias vector + typename B2bMma::IteratorAccumulatorScaleBias iterator_Scale0( + ptr_Scale0, + {1, params.problem_size_0.n()}, + thread_idx, + warp_idx, + MatrixCoord( + 0, threadblock_tile_offset.n() * B2bMma::Shape0::kN + ) + ); + + typename B2bMma::IteratorAccumulatorScaleBias iterator_Bias0( + ptr_Bias0, + {1, params.problem_size_0.n()}, + thread_idx, + warp_idx, + MatrixCoord( + 0, threadblock_tile_offset.n() * B2bMma::Shape0::kN + ) + ); + + // + // Main loop + // + + OutputOp0 output_op_0(params.output_op_0); + + if (cutlass::gemm::threadblock::detail::IsGroupedSwizzle::value) { + // Wait for all threads to finish their epilogue phases from the previous tile. + __syncthreads(); + } + + // Construct thread-scoped matrix multiply + B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx, params.problem_size_0.n()); + + typename B2bMma::FragmentC0 src_accum; + typename B2bMma::FragmentC1 accumulators; + + src_accum.clear(); + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0, + iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0); + + // + // Epilogue + // + + OutputOp1 output_op_1(params.output_op_1); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + //assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * B2bMma::Shape1::kM, + threadblock_tile_offset.n() * B2bMma::Shape1::kN + ); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + ElementC *ptr_C1 = static_cast(params.ref_C1.data()); + ElementC *ptr_D1 = static_cast(params.ref_D1.data()); + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + if (params.mode == GemmUniversalMode::kGemm) { + // If performing a reduction via split-K, fetch the initial synchronization + + if (params.grid_tiled_shape.k() > 1) { + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + } + else if (params.mode == GemmUniversalMode::kBatched) { + ptr_C1 += threadblock_tile_offset.k() * params.batch_stride_C1; + ptr_D1 += threadblock_tile_offset.k() * params.batch_stride_D1; + } + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C1( + params.params_C1, + ptr_C1, + params.problem_size_1.mn(), + thread_idx, + threadblock_offset + ); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D1( + params.params_D1, + ptr_D1, + params.problem_size_1.mn(), + thread_idx, + threadblock_offset + ); + + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_offset.k()) { + iterator_C1 = iterator_D1; + } + + semaphore.wait(threadblock_tile_offset.k()); + + __threadfence(); + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op_1, iterator_D1, accumulators, iterator_C1); + + // + // Release the semaphore + // + + if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + __threadfence(); + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_gemm_grouped_problem_visitor.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_gemm_grouped_problem_visitor.h new file mode 100644 index 0000000000000000000000000000000000000000..cc35d91be6d8b780a7cfb1235e0f5a190ec843ab --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_gemm_grouped_problem_visitor.h @@ -0,0 +1,157 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Scheduler for grouped B2b GEMMs +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/gemm/kernel/grouped_problem_visitor.h" +#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Visitor class to abstract away the algorithm for iterating over tiles +template +struct B2bGemmGroupedProblemVisitor : public GroupedProblemVisitor< + detail::GemmGroupedProblemSizeHelper, + ThreadblockShape, + GroupScheduleMode_, + PrefetchTileCount, + ThreadCount> { + + using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper; + using Base = GroupedProblemVisitor; + using BaseParams = typename Base::Params; + using SharedStorage = typename Base::SharedStorage; + static bool const kTransposed = Transposed; + + cutlass::gemm::GemmCoord const *problem_sizes0; + cutlass::gemm::GemmCoord const *problem_sizes1; + + struct Params { + cutlass::gemm::GemmCoord const *problem_sizes0; + cutlass::gemm::GemmCoord const *problem_sizes1; + int32_t problem_count; + void const *workspace; + int32_t tile_count; + + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + Params(): problem_sizes0(nullptr), problem_sizes1(nullptr), + problem_count(0), workspace(nullptr), tile_count(0) { } + + /// Ctor + CUTLASS_HOST_DEVICE + Params( + cutlass::gemm::GemmCoord const *problem_sizes0, + cutlass::gemm::GemmCoord const *problem_sizes1, + int32_t problem_count, + void const *workspace = nullptr, + int32_t tile_count = 0 + ): + problem_sizes0(problem_sizes0), + problem_sizes1(problem_sizes1), + problem_count(problem_count), + workspace(workspace), + tile_count(tile_count) + {} + + /// Convert the B2b-GEMM-specific parameters to those used by the base class + CUTLASS_HOST_DEVICE + BaseParams to_base() const { + return BaseParams(// Set problem_sizes as problem_sizes0 because these determine + // shape of the grid used in the non-grouped B2b GEMM + problem_sizes0, + problem_count, + workspace, + tile_count); + } + + }; + + // + // Methods + // + CUTLASS_DEVICE + B2bGemmGroupedProblemVisitor( + Params const ¶ms_, + SharedStorage &shared_storage_, + int32_t block_idx + ): Base ( + params_.to_base(), + shared_storage_, block_idx), + problem_sizes0(params_.problem_sizes0), + problem_sizes1(params_.problem_sizes1) + {} + + /// Returns the problem size 0 for the current problem + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size0() const { + GemmCoord problem = problem_sizes0[this->problem_idx]; + ProblemSizeHelper::possibly_transpose_problem(problem); + return problem; + } + + /// Returns the problem size 1 for the current problem + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size1() const { + GemmCoord problem = problem_sizes1[this->problem_idx]; + ProblemSizeHelper::possibly_transpose_problem(problem); + return problem; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_implicit_gemm_convolution.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_implicit_gemm_convolution.h new file mode 100644 index 0000000000000000000000000000000000000000..6794fcc1d977d2ba92fa1d678d327a1d8db39ee4 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_implicit_gemm_convolution.h @@ -0,0 +1,521 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined Implicit GEMM kernel. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/semaphore.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/epilogue/threadblock/output_iterator_parameter.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) + typename ConvProblemSize_ = Conv2dProblemSize ///! Convolutional operator on 2D or 3D problem +> +struct B2bImplicitGemmConvolution { + + using B2bMma = B2bMma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp0 = typename B2bMma::OutputOp; + using EpilogueOutputOp1 = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static Operator const kConvolutionalOperator = ConvOperator; + + using ElementA = typename B2bMma::IteratorA0::Element; + using LayoutA = typename B2bMma::IteratorA0::Layout; + using ElementB = typename B2bMma::IteratorB0::Element; + using LayoutB = typename B2bMma::IteratorB0::Layout; + using ElementC = typename EpilogueOutputOp1::ElementOutput; + + /// Set output tensor C layout + using LayoutC = LayoutA; + + using ElementAccumulator = typename EpilogueOutputOp0::ElementAccumulator; + using ElementCompute = typename EpilogueOutputOp0::ElementCompute; + + /// Scale and Bias + using ElementScaleBias = typename B2bMma::IteratorAccumulatorScaleBias::Element; + using LayoutScaleBias = typename B2bMma::IteratorAccumulatorScaleBias::Layout; + + using WarpMmaOperator0 = typename B2bMma::Policy0::Operator; + using WarpMmaOperator1 = typename B2bMma::Policy1::Operator; + + using ArchMmaOperator = typename WarpMmaOperator0::ArchMmaOperator; + using MathOperator = typename ArchMmaOperator::Operator; + + using OperatorClass = typename WarpMmaOperator0::OperatorClass; + using ArchTag = typename WarpMmaOperator0::ArchTag; + + using ThreadblockShape0 = typename B2bMma::Shape0; + using ThreadblockShape1 = typename B2bMma::Shape1; + using WarpShape0 = typename WarpMmaOperator0::Shape; + using WarpShape1 = typename WarpMmaOperator1::Shape; + using InstructionShape = typename ArchMmaOperator::Shape; + + static int const kStages = B2bMma::kStages; + static IteratorAlgorithm const kIteratorAlgorithm = B2bMma::IteratorA0::kIteratorAlgorithm; + + /// Warp count (concept: GemmShape) + using WarpCount0 = typename B2bMma::WarpCount0; + static int const kThreadCount = 32 * WarpCount0::kCount; + + using TensorRefA0 = typename B2bMma::IteratorA0::TensorRef; + using TensorRefB0 = typename B2bMma::IteratorB0::TensorRef; + using TensorRefScaleBias0 = typename B2bMma::IteratorAccumulatorScaleBias::TensorRef; + using TensorRefB1 = typename B2bMma::IteratorB1::TensorRef; + using TensorRefC = cutlass::TensorRef; + + /// Check iterator A and B convolution dimension are the same and + // set device::B2bImplicitGemmConvolution::kConvDim + static_assert(B2bMma::IteratorA0::kConvDim == B2bMma::IteratorB0::kConvDim, + "Convolution on different dimensions is not supported"); + static int const kConvDim = B2bMma::IteratorA0::kConvDim; + + /// Conv dimension and problem size structure (Conv2d or Conv3d) + using ConvProblemSize = ConvProblemSize_; + + /// Wgrad C stride idx for implicit gemm algorithm + // Conv2d row-major matrix C (KxRSC) + // Conv3d row-major matrix C (KxTRSC) + static int const kWgradCStrideIdx = + cutlass::platform::is_same::value ? 2 : 3; + + /// This chooses the appropriate stride element of the C tensor. + static int const kTensorCStrideIdx = + (kConvolutionalOperator == conv::Operator::kWgrad ? kWgradCStrideIdx : 0); + + // + // + // + using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter< + LayoutC, + typename Epilogue::OutputTileIterator::Layout, + TensorRefC, + ConvOperator, + ConvProblemSize + >; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + ConvProblemSize problem_size_0; + ConvProblemSize problem_size_1; + TensorRefA0 ref_A0; + TensorRefB0 ref_B0; + TensorRefC ref_C0; + TensorRefScaleBias0 ref_Scale0; + TensorRefScaleBias0 ref_Bias0; + TensorRefB1 ref_B1; + TensorRefC ref_C1; + TensorRefC ref_D1; + typename EpilogueOutputOp0::Params output_op_0; + typename EpilogueOutputOp1::Params output_op_1; + SplitKMode split_k_mode; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() { } + + CUTLASS_HOST_DEVICE + Arguments( + ConvProblemSize const & problem_size_0, + ConvProblemSize const & problem_size_1 + ): + problem_size_0(problem_size_0), + problem_size_1(problem_size_1) { } + + CUTLASS_HOST_DEVICE + Arguments( + ConvProblemSize const & problem_size_0, + ConvProblemSize const & problem_size_1, + TensorRefA0 const & ref_A0, + TensorRefB0 const & ref_B0, + TensorRefC const & ref_C0, + TensorRefScaleBias0 const & ref_Scale0, + TensorRefScaleBias0 const & ref_Bias0, + TensorRefB1 const & ref_B1, + TensorRefC const & ref_C1, + TensorRefC const & ref_D1, + typename EpilogueOutputOp0::Params const & output_op_0, + typename EpilogueOutputOp1::Params const & output_op_1, + SplitKMode const & split_k_mode = SplitKMode::kSerial + ): + problem_size_0(problem_size_0), + problem_size_1(problem_size_1), + ref_A0(ref_A0), + ref_B0(ref_B0), + ref_C0(ref_C0), + ref_Scale0(ref_Scale0), + ref_Bias0(ref_Bias0), + ref_B1(ref_B1), + ref_C1(ref_C1), + ref_D1(ref_D1), + output_op_0(output_op_0), + output_op_1(output_op_1), + split_k_mode(split_k_mode) + { + + } + + }; + + /// Parameters structure + struct Params { + ConvProblemSize problem_size_0; + ConvProblemSize problem_size_1; + cutlass::gemm::GemmCoord grid_tiled_shape; + gemm::GemmCoord implicit_gemm_problem_size_0; + gemm::GemmCoord implicit_gemm_problem_size_1; + int swizzle_log_tile; + int gemm_k_iterations_0; + int gemm_k_iterations_1; + typename B2bMma::IteratorA0::Params iterator_A0; + typename B2bMma::IteratorA0::Element const *ptr_A0; + typename B2bMma::IteratorB0::Params iterator_B0; + typename B2bMma::IteratorB0::Element const *ptr_B0; + typename Epilogue::OutputTileIterator::Params iterator_C0; + typename Epilogue::OutputTileIterator::Element *ptr_C0; + typename B2bMma::IteratorAccumulatorScaleBias::Element *ptr_Scale0; + typename B2bMma::IteratorAccumulatorScaleBias::Element *ptr_Bias0; + typename B2bMma::IteratorB1::Params iterator_B1; + typename B2bMma::IteratorB1::Element const *ptr_B1; + typename Epilogue::OutputTileIterator::Params iterator_C1; + typename Epilogue::OutputTileIterator::Element *ptr_C1; + typename Epilogue::OutputTileIterator::Params iterator_D1; + typename Epilogue::OutputTileIterator::Element *ptr_D1; + typename EpilogueOutputOp0::Params output_op_0; + typename EpilogueOutputOp1::Params output_op_1; + int *semaphore; + SplitKMode split_k_mode; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): swizzle_log_tile(0), gemm_k_iterations_0(0), gemm_k_iterations_1(0) { } + + /// + CUTLASS_HOST_DEVICE + Params( + Arguments const &args, + int *semaphore = nullptr + ): + problem_size_0(args.problem_size_0), + problem_size_1(args.problem_size_1), + implicit_gemm_problem_size_0(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_0)), + implicit_gemm_problem_size_1(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_1)), + iterator_A0(B2bMma::IteratorA0::getParams(args.problem_size_0, args.ref_A0.layout())), + ptr_A0(args.ref_A0.data()), + iterator_B0(args.problem_size_0, args.ref_B0.layout()), + ptr_B0(args.ref_B0.data()), + iterator_C0(ConvOutputIteratorParameter::layout(args.ref_C0)), + ptr_C0(args.ref_C0.data()), + ptr_Scale0(args.ref_Scale0.data()), + ptr_Bias0(args.ref_Bias0.data()), + iterator_B1(args.problem_size_1, args.ref_B1.layout()), + ptr_B1(args.ref_B1.data()), + iterator_C1(ConvOutputIteratorParameter::layout(args.ref_C1)), + ptr_C1(args.ref_C1.data()), + iterator_D1(ConvOutputIteratorParameter::layout(args.ref_D1)), + ptr_D1(args.ref_D1.data()), + output_op_0(args.output_op_0), + output_op_1(args.output_op_1), + semaphore(semaphore), + split_k_mode(args.split_k_mode) + { + gemm_k_iterations_0 = implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape0::kK, args.problem_size_0); + gemm_k_iterations_1 = implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape1::kK, args.problem_size_1); + + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + implicit_gemm_problem_size_0, + {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK}, + args.problem_size_0.split_k_slices); + + swizzle_log_tile = ThreadblockSwizzle().get_log_tile(grid_tiled_shape); + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename B2bMma::B2bMmaSharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + B2bImplicitGemmConvolution() { } + + /// Executes one ImplicitGEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_idx = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || + params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) { + + return; + } + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename B2bMma::IteratorA0 iterator_A0( + params.iterator_A0, + params.problem_size_0, + params.ptr_A0, + thread_idx, + MatrixCoord( + threadblock_tile_idx.m() * B2bMma::Shape0::kM, + threadblock_tile_idx.k() * B2bMma::Shape0::kK + ) + ); + + typename B2bMma::IteratorB0 iterator_B0( + params.iterator_B0, + params.problem_size_0, + params.ptr_B0, + thread_idx, + MatrixCoord( + threadblock_tile_idx.k() * B2bMma::Shape0::kK, + threadblock_tile_idx.n() * B2bMma::Shape0::kN + ) + ); + + typename B2bMma::IteratorB1 iterator_B1( + params.iterator_B1, + params.problem_size_1, + params.ptr_B1, + thread_idx, + MatrixCoord( + threadblock_tile_idx.k() * B2bMma::Shape1::kK, + threadblock_tile_idx.n() * B2bMma::Shape1::kN + ) + ); + + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + + // Construct iterators to accumulator scale/bias vector + typename B2bMma::IteratorAccumulatorScaleBias iterator_Scale0( + params.ptr_Scale0, + {1, params.problem_size_0.K}, + thread_idx, + warp_idx, + MatrixCoord( + 0, threadblock_tile_idx.n() * B2bMma::Shape0::kN + ) + ); + + typename B2bMma::IteratorAccumulatorScaleBias iterator_Bias0( + params.ptr_Bias0, + {1, params.problem_size_0.K}, + thread_idx, + warp_idx, + MatrixCoord( + 0, threadblock_tile_idx.n() * B2bMma::Shape0::kN + ) + ); + + + // + // Main loop + // + + EpilogueOutputOp0 output_op_0(params.output_op_0); + + // Construct thread-scoped matrix multiply + B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename B2bMma::FragmentC0 src_accum; + typename B2bMma::FragmentC1 accumulators; + + src_accum.clear(); + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + b2bMma(params.gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0, + iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0); + + // + // Epilogue + // + + EpilogueOutputOp1 output_op_1(params.output_op_1); + + // Construct the semaphore. + int block_idx = threadblock_tile_idx.m() + threadblock_tile_idx.n() * params.grid_tiled_shape.m(); + + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // Compute logical position within grid + threadblock_tile_idx = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // If performing a reduction via split-K, fetch the initial synchronization + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op_1.set_k_partition(threadblock_tile_idx.k(), params.grid_tiled_shape.k()); + } + + MatrixCoord threadblock_offset( + threadblock_tile_idx.m() * B2bMma::Shape1::kM, + threadblock_tile_idx.n() * B2bMma::Shape1::kN + ); + + // Tile iterator writing to destination tensor + typename Epilogue::OutputTileIterator iterator_D1( + params.iterator_D1, + params.ptr_D1, + ConvOutputIteratorParameter::extent(params.problem_size_1), + thread_idx, + threadblock_offset + ); + + // Tile iterator reading from source accumulator tensor + typename Epilogue::OutputTileIterator iterator_C1( + params.iterator_C1, + params.ptr_C1, + ConvOutputIteratorParameter::extent(params.problem_size_1), + thread_idx, + threadblock_offset + ); + + + // Construct the epilogue + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_idx.k()) { + iterator_C1 = iterator_D1; + } + + semaphore.wait(threadblock_tile_idx.k()); + + __threadfence(); + } + // Each split-k-slice writes to a unique tensor location + else if (params.split_k_mode == SplitKMode::kParallel) { + iterator_D1.add_pointer_offset(threadblock_tile_idx.k() * + cutlass::conv::implicit_gemm_tensor_c_size(ConvOperator, params.problem_size_1)); + } + + // Run efficient epilogue + epilogue(output_op_1, iterator_D1, accumulators, iterator_C1); + + // + // Release the semaphore + // + + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_idx.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_idx.k() + 1; + } + + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop.h new file mode 100644 index 0000000000000000000000000000000000000000..46254b5343936d2b558dd92adc0e253ee4ce2696 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop.h @@ -0,0 +1,94 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" + +#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h" +#include "cutlass/transform/threadblock/vector_iterator.h" +#include "cutlass/transform/warp/vector_fragment_iterator.h" + +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" + +#include "kernel/b2b_implicit_gemm_convolution.h" +#include "threadblock/b2b_implicit_gemm_pipelined.h" +#include "threadblock/b2b_implicit_gemm_multistage.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv2dFprop +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape0, + typename ThreadblockShape1, + typename WarpShape0, + typename WarpShape1, + typename InstructionShape, + typename EpilogueOutputOp0, + typename EpilogueOutputOp1, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, + bool SmemAccumulator = false +> struct DefaultB2bConv2dFprop; + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm75.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm75.h new file mode 100644 index 0000000000000000000000000000000000000000..dbb21aecd6544c87905832418513029e58e6a33e --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm75.h @@ -0,0 +1,749 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" + +#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h" +#include "cutlass/transform/threadblock/vector_iterator.h" +#include "cutlass/transform/warp/vector_fragment_iterator.h" + +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" + +#include "kernel/default_b2b_conv2d_fprop.h" +#include "kernel/b2b_implicit_gemm_convolution.h" +#include "threadblock/b2b_implicit_gemm_pipelined.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassTensorOp convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm +/// and 2 stage pipeline. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape0, + typename ThreadblockShape1, + typename WarpShape0, + typename WarpShape1, + typename InstructionShape, + typename EpilogueOutputOp0, + typename EpilogueOutputOp1, + typename ThreadblockSwizzle, + typename MathOperatorTag +> +struct DefaultB2bConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic +> { + + // Define the core components from GEMM + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; + using IteratorA0 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA0 + > + >; + + using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; + using IteratorB0 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB0 + > + >; + + using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; + + // Use fragment iterator for A operand + using AccumulatorLayout = cutlass::layout::ColumnMajor; + using FragmentIteratorA1 = + cutlass::gemm::warp::MmaTensorOpFragmentIterator< + cutlass::MatrixShape, //warp shape + cutlass::MatrixShape, //accumulator shape + MmaCore1::Shape::kK, //kBlocksColumn + ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 2; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Warp-level iterators to load scale and bias vectors + using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< + MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, + LayoutScaleBias, InstructionShape, kElementsPerAccess>; + + // Define iterators over tiles from the B operand + using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; + using IteratorB1 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB1 + > + >; + + using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + using MmaPolicy0 = typename MmaCore0::MmaPolicy; + using MmaPolicy1 = typename MmaCore1::MmaPolicy; + + // Define the Mma + using B2bMma = threadblock::B2bImplicitGemmPipelined< + ThreadblockShape0, + IteratorA0, + SmemIteratorA0, + IteratorB0, + SmemIteratorB0, + ThreadblockShape1, + FragmentIteratorA1, + IteratorAccumulatorScaleBias, + FragmentIteratorA1ScaleBias, + IteratorB1, + SmemIteratorB1, + ElementC, + LayoutC, + EpilogueOutputOp0, + MmaPolicy0, + MmaPolicy1 + >; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape1, + WarpMmaTensorOp1, + 1, + EpilogueOutputOp1 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< + B2bMma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and 2 stage +/// pipeline with interleaved layout. +template < + typename ElementA, + typename ElementB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape0, + typename ThreadblockShape1, + typename WarpShape0, + typename WarpShape1, + typename InstructionShape, + typename EpilogueOutputOp0, + typename EpilogueOutputOp1, + typename ThreadblockSwizzle, + typename MathOperatorTag, + int InterleavedK +> +struct DefaultB2bConv2dFprop < + ElementA, + layout::TensorNCxHWx, + ElementB, + layout::TensorCxRSKx, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + false +> { + + // Define the core components from GEMM + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, layout::RowMajorInterleaved, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + 2, MathOperatorTag, true>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, layout::RowMajorInterleaved, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + 2, MathOperatorTag, true>; + + // Define iterators over tiles from the A operand + // Note GEMM shared memory threadmap is used here because conv global memory + // layout needs to be mapped to fprop which is similar to the crosswise + // layout which is used by the interleaved GEMM shared memory threadmap. + // The Interleaved GEMM global memory layout is similar to the congruous + // layout. + using ThreadMapA0 = typename MmaCore0::SmemThreadMapA; + using IteratorA0 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, layout::TensorNCxHWx, + ThreadMapA0 + > + >; + + using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; + + // Define iterators over tiles from the B operand + // Note GEMM shared memory threadmap is used here because conv global memory + // layout needs to be mapped to fprop which is similar to the crosswise + // layout which is used by the interleaved GEMM shared memory threadmap. + // The Interleaved GEMM global memory layout is similar to the congruous + // layout. + using ThreadMapB0 = typename MmaCore0::SmemThreadMapB; + using IteratorB0 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, layout::TensorCxRSKx, + ThreadMapB0 + > + >; + + using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; + + // Use fragment iterator for A operand + using AccumulatorLayout = cutlass::layout::RowMajor; + using FragmentIteratorA1 = + cutlass::gemm::warp::MmaTensorOpFragmentIterator< + cutlass::MatrixShape, //warp shape + cutlass::MatrixShape, //accumulator shape + MmaCore1::Shape::kK, //kBlocksColumn + ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 4; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Warp-level iterators to load scale and bias vectors + using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< + MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, + LayoutScaleBias, InstructionShape, kElementsPerAccess>; + + // Define iterators over tiles from the B operand + using ThreadMapB1 = typename MmaCore1::SmemThreadMapB; + using IteratorB1 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, layout::TensorCxRSKx, + ThreadMapB1 + > + >; + + using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + using MmaPolicy0 = typename MmaCore0::MmaPolicy; + using MmaPolicy1 = typename MmaCore1::MmaPolicy; + + // Define the Mma + using B2bMma = threadblock::B2bImplicitGemmPipelined< + ThreadblockShape0, + IteratorA0, + SmemIteratorA0, + IteratorB0, + SmemIteratorB0, + ThreadblockShape1, + FragmentIteratorA1, + IteratorAccumulatorScaleBias, + FragmentIteratorA1ScaleBias, + IteratorB1, + SmemIteratorB1, + ElementC, + LayoutC, + EpilogueOutputOp0, + MmaPolicy0, + MmaPolicy1 + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< + ThreadblockShape1, + WarpMmaTensorOp1, + 1, + EpilogueOutputOp1, + EpilogueOutputOp1::kCount, + InterleavedK + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< + B2bMma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm +/// and 2 stage pipeline. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape0, + typename ThreadblockShape1, + typename WarpShape0, + typename WarpShape1, + typename InstructionShape, + typename EpilogueOutputOp0, + typename EpilogueOutputOp1, + typename ThreadblockSwizzle, + typename MathOperatorTag +> +struct DefaultB2bConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized +> { + + // Define the core components from GEMM + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; + using IteratorA0 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA0 + > + >; + + using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; + using IteratorB0 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB0 + > + >; + + using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; + + // Use fragment iterator for A operand + using AccumulatorLayout = cutlass::layout::ColumnMajor; + using FragmentIteratorA1 = + cutlass::gemm::warp::MmaTensorOpFragmentIterator< + cutlass::MatrixShape, //warp shape + cutlass::MatrixShape, //accumulator shape + MmaCore1::Shape::kK, //kBlocksColumn + ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 2; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Warp-level iterators to load scale and bias vectors + using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< + MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, + LayoutScaleBias, InstructionShape, kElementsPerAccess>; + + // Define iterators over tiles from the B operand + using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; + using IteratorB1 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB1 + > + >; + + using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + using MmaPolicy0 = typename MmaCore0::MmaPolicy; + using MmaPolicy1 = typename MmaCore1::MmaPolicy; + + // Define the Mma + using B2bMma = threadblock::B2bImplicitGemmPipelined< + ThreadblockShape0, + IteratorA0, + SmemIteratorA0, + IteratorB0, + SmemIteratorB0, + ThreadblockShape1, + FragmentIteratorA1, + IteratorAccumulatorScaleBias, + FragmentIteratorA1ScaleBias, + IteratorB1, + SmemIteratorB1, + ElementC, + LayoutC, + EpilogueOutputOp0, + MmaPolicy0, + MmaPolicy1 + >; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape1, + WarpMmaTensorOp1, + 1, + EpilogueOutputOp1 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< + B2bMma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and 2 stage +/// pipeline with interleaved layout. +template < + typename ElementA, + typename ElementB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape0, + typename ThreadblockShape1, + typename WarpShape0, + typename WarpShape1, + typename InstructionShape, + typename EpilogueOutputOp0, + typename EpilogueOutputOp1, + typename ThreadblockSwizzle, + typename MathOperatorTag, + int InterleavedK +> +struct DefaultB2bConv2dFprop < + ElementA, + layout::TensorNCxHWx, + ElementB, + layout::TensorCxRSKx, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized +> { + + // Define the core components from GEMM + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, layout::RowMajorInterleaved, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + 2, MathOperatorTag, true>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, layout::RowMajorInterleaved, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + 2, MathOperatorTag, true>; + + // Define iterators over tiles from the A operand + // Note GEMM shared memory threadmap is used here because conv global memory + // layout needs to be mapped to fprop which is similar to the crosswise + // layout which is used by the interleaved GEMM shared memory threadmap. + // The Interleaved GEMM global memory layout is similar to the congruous + // layout. + + // Define iterators over tiles from the A operand + using ThreadMapA0 = typename MmaCore0::SmemThreadMapA; + using IteratorA0 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, layout::TensorNCxHWx, + ThreadMapA0 + > + >; + + using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB0 = typename MmaCore0::SmemThreadMapB; + using IteratorB0 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, layout::TensorCxRSKx, + ThreadMapB0 + > + >; + + using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; + + // Use fragment iterator for A operand + using AccumulatorLayout = cutlass::layout::RowMajor; + using FragmentIteratorA1 = + cutlass::gemm::warp::MmaTensorOpFragmentIterator< + cutlass::MatrixShape, //warp shape + cutlass::MatrixShape, //accumulator shape + MmaCore1::Shape::kK, //kBlocksColumn + ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 4; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Warp-level iterators to load scale and bias vectors + using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< + MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, + LayoutScaleBias, InstructionShape, kElementsPerAccess>; + + using ThreadMapB1 = typename MmaCore1::SmemThreadMapB; + using IteratorB1 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, layout::TensorCxRSKx, + ThreadMapB1 + > + >; + + using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + using MmaPolicy0 = typename MmaCore0::MmaPolicy; + using MmaPolicy1 = typename MmaCore1::MmaPolicy; + + // Define the Mma + using B2bMma = threadblock::B2bImplicitGemmPipelined< + ThreadblockShape0, + IteratorA0, + SmemIteratorA0, + IteratorB0, + SmemIteratorB0, + ThreadblockShape1, + FragmentIteratorA1, + IteratorAccumulatorScaleBias, + FragmentIteratorA1ScaleBias, + IteratorB1, + SmemIteratorB1, + ElementC, + LayoutC, + EpilogueOutputOp0, + MmaPolicy0, + MmaPolicy1 + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< + ThreadblockShape1, + WarpMmaTensorOp1, + 1, + EpilogueOutputOp1, + EpilogueOutputOp1::kCount, + InterleavedK + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< + B2bMma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm80.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm80.h new file mode 100644 index 0000000000000000000000000000000000000000..6e0af4f96592a835a4ca818b13b658c9947d7cd8 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm80.h @@ -0,0 +1,740 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" + +#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h" +#include "cutlass/transform/threadblock/vector_iterator.h" +#include "cutlass/transform/warp/vector_fragment_iterator.h" + +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" + +#include "kernel/default_b2b_conv2d_fprop.h" +#include "kernel/b2b_implicit_gemm_convolution.h" +#include "threadblock/b2b_implicit_gemm_multistage.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassTensorOp convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage +/// pipeline. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape0, + typename ThreadblockShape1, + typename WarpShape0, + typename WarpShape1, + typename InstructionShape, + typename EpilogueOutputOp0, + typename EpilogueOutputOp1, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultB2bConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic +> { + + // Define the core components from GEMM + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; + using IteratorA0 = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA0 + >; + + using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; + using IteratorB0 = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB0 + >; + + using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; + + // Use fragment iterator for A operand + using AccumulatorLayout = cutlass::layout::ColumnMajor; + using FragmentIteratorA1 = + cutlass::gemm::warp::MmaTensorOpFragmentIterator< + cutlass::MatrixShape, //warp shape + cutlass::MatrixShape, //accumulator shape + MmaCore1::Shape::kK, //kBlocksColumn + ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 2; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Warp-level iterators to load scale and bias vectors + using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< + MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, + LayoutScaleBias, InstructionShape, kElementsPerAccess>; + + // Define iterators over tiles from the B operand + using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; + using IteratorB1 = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB1 + >; + + using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + using MmaPolicy0 = typename MmaCore0::MmaPolicy; + using MmaPolicy1 = typename MmaCore1::MmaPolicy; + + // Define the Mma + using B2bMma = threadblock::B2bImplicitGemmMultistage< + ThreadblockShape0, + IteratorA0, + SmemIteratorA0, + arch::CacheOperation::Always, + IteratorB0, + SmemIteratorB0, + arch::CacheOperation::Global, + ThreadblockShape1, + FragmentIteratorA1, + IteratorAccumulatorScaleBias, + FragmentIteratorA1ScaleBias, + IteratorB1, + SmemIteratorB1, + arch::CacheOperation::Global, + EpilogueOutputOp0, + MmaPolicy0, + MmaPolicy1, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape1, + WarpMmaTensorOp1, + 1, + EpilogueOutputOp1, + EpilogueOutputOp1::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< + B2bMma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage +/// pipeline with interleaved layout. +template < + typename ElementA, + typename ElementB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape0, + typename ThreadblockShape1, + typename WarpShape0, + typename WarpShape1, + typename InstructionShape, + typename EpilogueOutputOp0, + typename EpilogueOutputOp1, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int InterleavedK +> +struct DefaultB2bConv2dFprop < + ElementA, + layout::TensorNCxHWx, + ElementB, + layout::TensorCxRSKx, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic +> { + + // Define the core components from GEMM + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, layout::RowMajorInterleaved, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + Stages, MathOperatorTag, true>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, layout::RowMajorInterleaved, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + Stages, MathOperatorTag, true>; + + // Define iterators over tiles from the A operand + // Note GEMM shared memory threadmap is used here because conv global memory + // layout needs to be mapped to fprop which is similar to the crosswise + // layout which is used by the interleaved GEMM shared memory threadmap. + // The Interleaved GEMM global memory layout is similar to the congruous + // layout. + using ThreadMapA0 = typename MmaCore0::SmemThreadMapA; + using IteratorA0 = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, layout::TensorNCxHWx, + ThreadMapA0 + >; + + using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; + + // Define iterators over tiles from the B operand + // Note GEMM shared memory threadmap is used here because conv global memory + // layout needs to be mapped to fprop which is similar to the crosswise + // layout which is used by the interleaved GEMM shared memory threadmap. + // The Interleaved GEMM global memory layout is similar to the congruous + // layout. + using ThreadMapB0 = typename MmaCore0::SmemThreadMapB; + using IteratorB0 = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, layout::TensorCxRSKx, + ThreadMapB0 + >; + + using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; + + // Use fragment iterator for A operand + using AccumulatorLayout = cutlass::layout::RowMajor; + using FragmentIteratorA1 = + cutlass::gemm::warp::MmaTensorOpFragmentIterator< + cutlass::MatrixShape, //warp shape + cutlass::MatrixShape, //accumulator shape + MmaCore1::Shape::kK, //kBlocksColumn + ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 4; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Warp-level iterators to load scale and bias vectors + using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< + MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, + LayoutScaleBias, InstructionShape, kElementsPerAccess>; + + using ThreadMapB1 = typename MmaCore1::SmemThreadMapB; + using IteratorB1 = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, layout::TensorCxRSKx, + ThreadMapB1 + >; + + using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; + + + // Warp-level GEMM components + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + using MmaPolicy0 = typename MmaCore0::MmaPolicy; + using MmaPolicy1 = typename MmaCore1::MmaPolicy; + + // Define the Mma + using B2bMma = threadblock::B2bImplicitGemmMultistage< + ThreadblockShape0, + IteratorA0, + SmemIteratorA0, + arch::CacheOperation::Always, + IteratorB0, + SmemIteratorB0, + arch::CacheOperation::Global, + ThreadblockShape1, + FragmentIteratorA1, + IteratorAccumulatorScaleBias, + FragmentIteratorA1ScaleBias, + IteratorB1, + SmemIteratorB1, + arch::CacheOperation::Global, + EpilogueOutputOp0, + MmaPolicy0, + MmaPolicy1, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< + ThreadblockShape1, + WarpMmaTensorOp1, + 1, + EpilogueOutputOp1, + EpilogueOutputOp1::kCount, + InterleavedK + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< + B2bMma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and +/// multistage pipeline. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape0, + typename ThreadblockShape1, + typename WarpShape0, + typename WarpShape1, + typename InstructionShape, + typename EpilogueOutputOp0, + typename EpilogueOutputOp1, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultB2bConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized +> { + + // Define the core components from GEMM + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; + using IteratorA0 = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA0 + >; + + using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; + using IteratorB0 = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB0 + >; + + using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; + + // Use fragment iterator for A operand + using AccumulatorLayout = cutlass::layout::ColumnMajor; + using FragmentIteratorA1 = + cutlass::gemm::warp::MmaTensorOpFragmentIterator< + cutlass::MatrixShape, //warp shape + cutlass::MatrixShape, //accumulator shape + MmaCore1::Shape::kK, //kBlocksColumn + ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 2; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Warp-level iterators to load scale and bias vectors + using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< + MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, + LayoutScaleBias, InstructionShape, kElementsPerAccess>; + + // Define iterators over tiles from the B operand + using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; + using IteratorB1 = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB1 + >; + + using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + using MmaPolicy0 = typename MmaCore0::MmaPolicy; + using MmaPolicy1 = typename MmaCore1::MmaPolicy; + + // Define the Mma + using B2bMma = threadblock::B2bImplicitGemmMultistage< + ThreadblockShape0, + IteratorA0, + SmemIteratorA0, + arch::CacheOperation::Always, + IteratorB0, + SmemIteratorB0, + arch::CacheOperation::Global, + ThreadblockShape1, + FragmentIteratorA1, + IteratorAccumulatorScaleBias, + FragmentIteratorA1ScaleBias, + IteratorB1, + SmemIteratorB1, + arch::CacheOperation::Global, + EpilogueOutputOp0, + MmaPolicy0, + MmaPolicy1, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape1, + WarpMmaTensorOp1, + 1, + EpilogueOutputOp1, + EpilogueOutputOp1::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< + B2bMma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and +// multistage pipeline with interleaved layout. +template < + typename ElementA, + typename ElementB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape0, + typename ThreadblockShape1, + typename WarpShape0, + typename WarpShape1, + typename InstructionShape, + typename EpilogueOutputOp0, + typename EpilogueOutputOp1, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int InterleavedK +> +struct DefaultB2bConv2dFprop < + ElementA, + layout::TensorNCxHWx, + ElementB, + layout::TensorCxRSKx, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized +> { + + // Define the core components from GEMM + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, layout::RowMajorInterleaved, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + Stages, MathOperatorTag, true>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, layout::RowMajorInterleaved, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + Stages, MathOperatorTag, true>; + + // Define iterators over tiles from the A operand + // Note GEMM shared memory threadmap is used here because conv global memory + // layout needs to be mapped to fprop which is similar to the crosswise + // layout which is used by the interleaved GEMM shared memory threadmap. + // The Interleaved GEMM global memory layout is similar to the congruous + // layout. + using ThreadMapA0 = typename MmaCore0::SmemThreadMapA; + using IteratorA0 = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, layout::TensorNCxHWx, + ThreadMapA0 + >; + + using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; + + // Define iterators over tiles from the B operand + // Note GEMM shared memory threadmap is used here because conv global memory + // layout needs to be mapped to fprop which is similar to the crosswise + // layout which is used by the interleaved GEMM shared memory threadmap. + // The Interleaved GEMM global memory layout is similar to the congruous + // layout. + using ThreadMapB0 = typename MmaCore0::SmemThreadMapB; + using IteratorB0 = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, layout::TensorCxRSKx, + ThreadMapB0 + >; + + using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; + + // Use fragment iterator for A operand + using AccumulatorLayout = cutlass::layout::RowMajor; + using FragmentIteratorA1 = + cutlass::gemm::warp::MmaTensorOpFragmentIterator< + cutlass::MatrixShape, //warp shape + cutlass::MatrixShape, //accumulator shape + MmaCore1::Shape::kK, //kBlocksColumn + ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 4; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Warp-level iterators to load scale and bias vectors + using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< + MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, + LayoutScaleBias, InstructionShape, kElementsPerAccess>; + + using ThreadMapB1 = typename MmaCore1::SmemThreadMapB; + using IteratorB1 = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, layout::TensorCxRSKx, + ThreadMapB1 + >; + + using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; + + + // Warp-level GEMM components + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + using MmaPolicy0 = typename MmaCore0::MmaPolicy; + using MmaPolicy1 = typename MmaCore1::MmaPolicy; + + // Define the Mma + using B2bMma = threadblock::B2bImplicitGemmMultistage< + ThreadblockShape0, + IteratorA0, + SmemIteratorA0, + arch::CacheOperation::Always, + IteratorB0, + SmemIteratorB0, + arch::CacheOperation::Global, + ThreadblockShape1, + FragmentIteratorA1, + IteratorAccumulatorScaleBias, + FragmentIteratorA1ScaleBias, + IteratorB1, + SmemIteratorB1, + arch::CacheOperation::Global, + EpilogueOutputOp0, + MmaPolicy0, + MmaPolicy1, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< + ThreadblockShape1, + WarpMmaTensorOp1, + 1, + EpilogueOutputOp1, + EpilogueOutputOp1::kCount, + InterleavedK + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< + B2bMma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm75.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm75.h new file mode 100644 index 0000000000000000000000000000000000000000..35a5681dd429f2497a20820dfaafa297c2803c39 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm75.h @@ -0,0 +1,817 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" + +#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h" +#include "cutlass/transform/threadblock/vector_iterator.h" +#include "cutlass/transform/warp/vector_fragment_iterator.h" + +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" + +#include "kernel/default_b2b_conv2d_fprop.h" +#include "kernel/b2b_implicit_gemm_convolution.h" +#include "threadblock/b2b_implicit_gemm_pipelined_smem_accumulator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm +/// and 2 stage pipeline. +/// Accumulator will be staged in shared memory. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape0, + typename ThreadblockShape1, + typename WarpShape0, + typename WarpShape1, + typename InstructionShape, + typename EpilogueOutputOp0, + typename EpilogueOutputOp1, + typename ThreadblockSwizzle, + typename MathOperatorTag +> +struct DefaultB2bConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + true +> { + + // Define the core components from GEMM + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; + using IteratorA0 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA0 + > + >; + + using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; + using IteratorB0 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB0 + > + >; + + using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 2; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Define iterators over tiles from the B operand + using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; + using IteratorB1 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB1 + > + >; + + using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + using MmaPolicy0 = typename MmaCore0::MmaPolicy; + using MmaPolicy1 = typename MmaCore1::MmaPolicy; + + // Use fragment iterator for the accumulator + using SmemAccumulatorLayout = cutlass::layout::RowMajor; + using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< + WarpShape0, InstructionShape, + ElementAccumulator, + typename WarpMmaTensorOp0::Policy::Operator::FragmentC, + SmemAccumulatorLayout + >; + + // Store Accumulator tiles to Shared Memory + using SmemIteratorD0 = + cutlass::epilogue::warp::TileIteratorTensorOp< + WarpShape0, + InstructionShape, + ElementC, + SmemAccumulatorLayout + >; + + static int const kThreadCount = 32; + // load warp tile from Shared Memory accumulator + using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator< + MatrixShape, cutlass::gemm::Operand::kA, + ElementA, SmemAccumulatorLayout, + MatrixShape, + WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>; + + // Define the Mma + using B2bMma = threadblock::B2bImplicitGemmPipelinedSmemAccumulator< + ThreadblockShape0, + IteratorA0, + SmemIteratorA0, + IteratorB0, + SmemIteratorB0, + IteratorAccumulatorScaleBias, + FragmentIteratorAccumulator, + SmemIteratorD0, + ThreadblockShape1, + WarpIteratorA1, + IteratorB1, + SmemIteratorB1, + ElementC, + LayoutC, + EpilogueOutputOp0, + MmaPolicy0, + MmaPolicy1 + >; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape1, + WarpMmaTensorOp1, + 1, + EpilogueOutputOp1 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< + B2bMma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and 2 stage +/// pipeline with interleaved layout. +/// Accumulator will be staged in shared memory. +template < + typename ElementA, + typename ElementB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape0, + typename ThreadblockShape1, + typename WarpShape0, + typename WarpShape1, + typename InstructionShape, + typename EpilogueOutputOp0, + typename EpilogueOutputOp1, + typename ThreadblockSwizzle, + typename MathOperatorTag, + int InterleavedK +> +struct DefaultB2bConv2dFprop < + ElementA, + layout::TensorNCxHWx, + ElementB, + layout::TensorCxRSKx, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + true +> { + + // Define the core components from GEMM + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, layout::RowMajorInterleaved, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + 2, MathOperatorTag, true>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, layout::RowMajorInterleaved, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + 2, MathOperatorTag, true>; + + // Define iterators over tiles from the A operand + // Note GEMM shared memory threadmap is used here because conv global memory + // layout needs to be mapped to fprop which is similar to the crosswise + // layout which is used by the interleaved GEMM shared memory threadmap. + // The Interleaved GEMM global memory layout is similar to the congruous + // layout. + using ThreadMapA0 = typename MmaCore0::SmemThreadMapA; + using IteratorA0 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, layout::TensorNCxHWx, + ThreadMapA0 + > + >; + + using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; + + // Define iterators over tiles from the B operand + // Note GEMM shared memory threadmap is used here because conv global memory + // layout needs to be mapped to fprop which is similar to the crosswise + // layout which is used by the interleaved GEMM shared memory threadmap. + // The Interleaved GEMM global memory layout is similar to the congruous + // layout. + using ThreadMapB0 = typename MmaCore0::SmemThreadMapB; + using IteratorB0 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, layout::TensorCxRSKx, + ThreadMapB0 + > + >; + + using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 4; //For interleaved layout + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Define iterators over tiles from the B operand + using ThreadMapB1 = typename MmaCore1::SmemThreadMapB; + using IteratorB1 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, layout::TensorCxRSKx, + ThreadMapB1 + > + >; + + using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + using MmaPolicy0 = typename MmaCore0::MmaPolicy; + using MmaPolicy1 = typename MmaCore1::MmaPolicy; + + // Use fragment iterator for the accumulator + using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>; + using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< + WarpShape0, InstructionShape, + ElementAccumulator, + typename WarpMmaTensorOp0::Policy::Operator::FragmentC, + SmemAccumulatorLayout + >; + + + // Store Accumulator tiles to Shared Memory + using SmemIteratorD0 = + cutlass::epilogue::warp::TileIteratorTensorOp< + WarpShape0, + InstructionShape, + ElementC, + SmemAccumulatorLayout + >; + + static int const kThreadCount = 32; + // load warp tile from Shared Memory accumulator + using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical< + MatrixShape, cutlass::gemm::Operand::kA, + ElementA, SmemAccumulatorLayout, + MatrixShape, + WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>; + + // Define the Mma + using B2bMma = threadblock::B2bImplicitGemmPipelinedSmemAccumulator< + ThreadblockShape0, + IteratorA0, + SmemIteratorA0, + IteratorB0, + SmemIteratorB0, + IteratorAccumulatorScaleBias, + FragmentIteratorAccumulator, + SmemIteratorD0, + ThreadblockShape1, + WarpIteratorA1, + IteratorB1, + SmemIteratorB1, + ElementC, + LayoutC, + EpilogueOutputOp0, + MmaPolicy0, + MmaPolicy1 + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< + ThreadblockShape1, + WarpMmaTensorOp1, + 1, + EpilogueOutputOp1, + EpilogueOutputOp1::kCount, + InterleavedK + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< + B2bMma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm +/// and 2 stage pipeline. +/// Accumulator will be staged in shared memory. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape0, + typename ThreadblockShape1, + typename WarpShape0, + typename WarpShape1, + typename InstructionShape, + typename EpilogueOutputOp0, + typename EpilogueOutputOp1, + typename ThreadblockSwizzle, + typename MathOperatorTag +> +struct DefaultB2bConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + true +> { + + // Define the core components from GEMM + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; + using IteratorA0 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA0 + > + >; + + using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; + using IteratorB0 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB0 + > + >; + + using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 2; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Define iterators over tiles from the B operand + using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; + using IteratorB1 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB1 + > + >; + + using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + using MmaPolicy0 = typename MmaCore0::MmaPolicy; + using MmaPolicy1 = typename MmaCore1::MmaPolicy; + + // Use fragment iterator for the accumulator + using SmemAccumulatorLayout = cutlass::layout::RowMajor; + using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< + WarpShape0, InstructionShape, + ElementAccumulator, + typename WarpMmaTensorOp0::Policy::Operator::FragmentC, + SmemAccumulatorLayout + >; + + // Store Accumulator tiles to Shared Memory + using SmemIteratorD0 = + cutlass::epilogue::warp::TileIteratorTensorOp< + WarpShape0, + InstructionShape, + ElementC, + SmemAccumulatorLayout + >; + + static int const kThreadCount = 32; + // load warp tile from Shared Memory accumulator + using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator< + MatrixShape, cutlass::gemm::Operand::kA, + ElementA, SmemAccumulatorLayout, + MatrixShape, + WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>; + + // Define the Mma + using B2bMma = threadblock::B2bImplicitGemmPipelinedSmemAccumulator< + ThreadblockShape0, + IteratorA0, + SmemIteratorA0, + IteratorB0, + SmemIteratorB0, + IteratorAccumulatorScaleBias, + FragmentIteratorAccumulator, + SmemIteratorD0, + ThreadblockShape1, + WarpIteratorA1, + IteratorB1, + SmemIteratorB1, + ElementC, + LayoutC, + EpilogueOutputOp0, + MmaPolicy0, + MmaPolicy1 + >; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape1, + WarpMmaTensorOp1, + 1, + EpilogueOutputOp1 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< + B2bMma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and 2 stage +/// pipeline with interleaved layout. +/// Accumulator will be staged in shared memory. +template < + typename ElementA, + typename ElementB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape0, + typename ThreadblockShape1, + typename WarpShape0, + typename WarpShape1, + typename InstructionShape, + typename EpilogueOutputOp0, + typename EpilogueOutputOp1, + typename ThreadblockSwizzle, + typename MathOperatorTag, + int InterleavedK +> +struct DefaultB2bConv2dFprop < + ElementA, + layout::TensorNCxHWx, + ElementB, + layout::TensorCxRSKx, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + true +> { + + // Define the core components from GEMM + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, layout::RowMajorInterleaved, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + 2, MathOperatorTag, true>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, layout::RowMajorInterleaved, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + 2, MathOperatorTag, true>; + + // Define iterators over tiles from the A operand + // Note GEMM shared memory threadmap is used here because conv global memory + // layout needs to be mapped to fprop which is similar to the crosswise + // layout which is used by the interleaved GEMM shared memory threadmap. + // The Interleaved GEMM global memory layout is similar to the congruous + // layout. + + // Define iterators over tiles from the A operand + using ThreadMapA0 = typename MmaCore0::SmemThreadMapA; + using IteratorA0 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, layout::TensorNCxHWx, + ThreadMapA0 + > + >; + + using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; + + // Define iterators over tiles from the B operand + // Note GEMM shared memory threadmap is used here because conv global memory + // layout needs to be mapped to fprop which is similar to the crosswise + // layout which is used by the interleaved GEMM shared memory threadmap. + // The Interleaved GEMM global memory layout is similar to the congruous + // layout. + using ThreadMapB0 = typename MmaCore0::SmemThreadMapB; + using IteratorB0 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, layout::TensorCxRSKx, + ThreadMapB0 + > + >; + + using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 4; //For interleaved layout + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + using ThreadMapB1 = typename MmaCore1::SmemThreadMapB; + using IteratorB1 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, layout::TensorCxRSKx, + ThreadMapB1 + > + >; + + using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + using MmaPolicy0 = typename MmaCore0::MmaPolicy; + using MmaPolicy1 = typename MmaCore1::MmaPolicy; + + // Use fragment iterator for the accumulator + using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>; + using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< + WarpShape0, InstructionShape, + ElementAccumulator, + typename WarpMmaTensorOp0::Policy::Operator::FragmentC, + SmemAccumulatorLayout + >; + + + // Store Accumulator tiles to Shared Memory + using SmemIteratorD0 = + cutlass::epilogue::warp::TileIteratorTensorOp< + WarpShape0, + InstructionShape, + ElementC, + SmemAccumulatorLayout + >; + + static int const kThreadCount = 32; + // load warp tile from Shared Memory accumulator + using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical< + MatrixShape, cutlass::gemm::Operand::kA, + ElementA, SmemAccumulatorLayout, + MatrixShape, + WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>; + + // Define the Mma + using B2bMma = threadblock::B2bImplicitGemmPipelinedSmemAccumulator< + ThreadblockShape0, + IteratorA0, + SmemIteratorA0, + IteratorB0, + SmemIteratorB0, + IteratorAccumulatorScaleBias, + FragmentIteratorAccumulator, + SmemIteratorD0, + ThreadblockShape1, + WarpIteratorA1, + IteratorB1, + SmemIteratorB1, + ElementC, + LayoutC, + EpilogueOutputOp0, + MmaPolicy0, + MmaPolicy1 + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< + ThreadblockShape1, + WarpMmaTensorOp1, + 1, + EpilogueOutputOp1, + EpilogueOutputOp1::kCount, + InterleavedK + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< + B2bMma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm80.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm80.h new file mode 100644 index 0000000000000000000000000000000000000000..ac80be8e77989bb6b53b35cb5f6a67004040aa87 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm80.h @@ -0,0 +1,804 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" + +#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h" +#include "cutlass/transform/threadblock/vector_iterator.h" +#include "cutlass/transform/warp/vector_fragment_iterator.h" + +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" + +#include "kernel/default_b2b_conv2d_fprop.h" +#include "kernel/b2b_implicit_gemm_convolution.h" +#include "threadblock/b2b_implicit_gemm_multistage_smem_accumulator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage +/// pipeline. +/// Accumulator will be staged in shared memory. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape0, + typename ThreadblockShape1, + typename WarpShape0, + typename WarpShape1, + typename InstructionShape, + typename EpilogueOutputOp0, + typename EpilogueOutputOp1, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultB2bConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + true +> { + + // Define the core components from GEMM + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; + using IteratorA0 = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA0 + >; + + using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; + using IteratorB0 = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB0 + >; + + using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 2; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Define iterators over tiles from the B operand + using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; + using IteratorB1 = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB1 + >; + + using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + using MmaPolicy0 = typename MmaCore0::MmaPolicy; + using MmaPolicy1 = typename MmaCore1::MmaPolicy; + + // Use fragment iterator for the accumulator + using SmemAccumulatorLayout = cutlass::layout::RowMajor; + using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< + WarpShape0, InstructionShape, + ElementAccumulator, + typename WarpMmaTensorOp0::Policy::Operator::FragmentC, + SmemAccumulatorLayout + >; + + // Store Accumulator tiles to Shared Memory + using SmemIteratorD0 = + cutlass::epilogue::warp::TileIteratorTensorOp< + WarpShape0, + InstructionShape, + ElementC, + SmemAccumulatorLayout + >; + + static int const kThreadCount = 32; + // load warp tile from Shared Memory accumulator + using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator< + MatrixShape, cutlass::gemm::Operand::kA, + ElementA, SmemAccumulatorLayout, + MatrixShape, + WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>; + + // Define the Mma + using B2bMma = threadblock::B2bImplicitGemmMultistageSmemAccumulator< + ThreadblockShape0, + IteratorA0, + SmemIteratorA0, + arch::CacheOperation::Always, + IteratorB0, + SmemIteratorB0, + arch::CacheOperation::Global, + IteratorAccumulatorScaleBias, + FragmentIteratorAccumulator, + SmemIteratorD0, + ThreadblockShape1, + WarpIteratorA1, + IteratorB1, + SmemIteratorB1, + arch::CacheOperation::Global, + EpilogueOutputOp0, + MmaPolicy0, + MmaPolicy1, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape1, + WarpMmaTensorOp1, + 1, + EpilogueOutputOp1, + EpilogueOutputOp1::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< + B2bMma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage +/// pipeline with interleaved layout. +/// Accumulator will be staged in shared memory. +template < + typename ElementA, + typename ElementB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape0, + typename ThreadblockShape1, + typename WarpShape0, + typename WarpShape1, + typename InstructionShape, + typename EpilogueOutputOp0, + typename EpilogueOutputOp1, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int InterleavedK +> +struct DefaultB2bConv2dFprop < + ElementA, + layout::TensorNCxHWx, + ElementB, + layout::TensorCxRSKx, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + true +> { + + // Define the core components from GEMM + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, layout::RowMajorInterleaved, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + Stages, MathOperatorTag, true>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, layout::RowMajorInterleaved, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + Stages, MathOperatorTag, true>; + + // Define iterators over tiles from the A operand + // Note GEMM shared memory threadmap is used here because conv global memory + // layout needs to be mapped to fprop which is similar to the crosswise + // layout which is used by the interleaved GEMM shared memory threadmap. + // The Interleaved GEMM global memory layout is similar to the congruous + // layout. + using ThreadMapA0 = typename MmaCore0::SmemThreadMapA; + using IteratorA0 = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, layout::TensorNCxHWx, + ThreadMapA0 + >; + + using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; + + // Define iterators over tiles from the B operand + // Note GEMM shared memory threadmap is used here because conv global memory + // layout needs to be mapped to fprop which is similar to the crosswise + // layout which is used by the interleaved GEMM shared memory threadmap. + // The Interleaved GEMM global memory layout is similar to the congruous + // layout. + using ThreadMapB0 = typename MmaCore0::SmemThreadMapB; + using IteratorB0 = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, layout::TensorCxRSKx, + ThreadMapB0 + >; + + using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 4; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + using ThreadMapB1 = typename MmaCore1::SmemThreadMapB; + using IteratorB1 = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, layout::TensorCxRSKx, + ThreadMapB1 + >; + + using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + using MmaPolicy0 = typename MmaCore0::MmaPolicy; + using MmaPolicy1 = typename MmaCore1::MmaPolicy; + + // Use fragment iterator for the accumulator + using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>; + using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< + WarpShape0, InstructionShape, + ElementAccumulator, + typename WarpMmaTensorOp0::Policy::Operator::FragmentC, + SmemAccumulatorLayout + >; + + + // Store Accumulator tiles to Shared Memory + using SmemIteratorD0 = + cutlass::epilogue::warp::TileIteratorTensorOp< + WarpShape0, + InstructionShape, + ElementC, + SmemAccumulatorLayout + >; + + static int const kThreadCount = 32; + // load warp tile from Shared Memory accumulator + using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical< + MatrixShape, cutlass::gemm::Operand::kA, + ElementA, SmemAccumulatorLayout, + MatrixShape, + WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>; + + // Define the Mma + using B2bMma = threadblock::B2bImplicitGemmMultistageSmemAccumulator< + ThreadblockShape0, + IteratorA0, + SmemIteratorA0, + arch::CacheOperation::Always, + IteratorB0, + SmemIteratorB0, + arch::CacheOperation::Global, + IteratorAccumulatorScaleBias, + FragmentIteratorAccumulator, + SmemIteratorD0, + ThreadblockShape1, + WarpIteratorA1, + IteratorB1, + SmemIteratorB1, + arch::CacheOperation::Global, + EpilogueOutputOp0, + MmaPolicy0, + MmaPolicy1, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< + ThreadblockShape1, + WarpMmaTensorOp1, + 1, + EpilogueOutputOp1, + EpilogueOutputOp1::kCount, + InterleavedK + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< + B2bMma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and +/// multistage pipeline. +/// Accumulator will be staged in shared memory. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape0, + typename ThreadblockShape1, + typename WarpShape0, + typename WarpShape1, + typename InstructionShape, + typename EpilogueOutputOp0, + typename EpilogueOutputOp1, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultB2bConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + true +> { + + // Define the core components from GEMM + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; + using IteratorA0 = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA0 + >; + + using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; + using IteratorB0 = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB0 + >; + + using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 2; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Define iterators over tiles from the B operand + using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; + using IteratorB1 = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB1 + >; + + using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + using MmaPolicy0 = typename MmaCore0::MmaPolicy; + using MmaPolicy1 = typename MmaCore1::MmaPolicy; + + // Use fragment iterator for the accumulator + using SmemAccumulatorLayout = cutlass::layout::RowMajor; + using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< + WarpShape0, InstructionShape, + ElementAccumulator, + typename WarpMmaTensorOp0::Policy::Operator::FragmentC, + SmemAccumulatorLayout + >; + + // Store Accumulator tiles to Shared Memory + using SmemIteratorD0 = + cutlass::epilogue::warp::TileIteratorTensorOp< + WarpShape0, + InstructionShape, + ElementC, + SmemAccumulatorLayout + >; + + static int const kThreadCount = 32; + // load warp tile from Shared Memory accumulator + using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator< + MatrixShape, cutlass::gemm::Operand::kA, + ElementA, SmemAccumulatorLayout, + MatrixShape, + WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>; + + // Define the Mma + using B2bMma = threadblock::B2bImplicitGemmMultistageSmemAccumulator< + ThreadblockShape0, + IteratorA0, + SmemIteratorA0, + arch::CacheOperation::Always, + IteratorB0, + SmemIteratorB0, + arch::CacheOperation::Global, + IteratorAccumulatorScaleBias, + FragmentIteratorAccumulator, + SmemIteratorD0, + ThreadblockShape1, + WarpIteratorA1, + IteratorB1, + SmemIteratorB1, + arch::CacheOperation::Global, + EpilogueOutputOp0, + MmaPolicy0, + MmaPolicy1, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape1, + WarpMmaTensorOp1, + 1, + EpilogueOutputOp1, + EpilogueOutputOp1::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< + B2bMma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and +// multistage pipeline with interleaved layout. +/// Accumulator will be staged in shared memory. +template < + typename ElementA, + typename ElementB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape0, + typename ThreadblockShape1, + typename WarpShape0, + typename WarpShape1, + typename InstructionShape, + typename EpilogueOutputOp0, + typename EpilogueOutputOp1, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int InterleavedK +> +struct DefaultB2bConv2dFprop < + ElementA, + layout::TensorNCxHWx, + ElementB, + layout::TensorCxRSKx, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + true +> { + + // Define the core components from GEMM + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, layout::RowMajorInterleaved, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + Stages, MathOperatorTag, true>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, layout::RowMajorInterleaved, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + Stages, MathOperatorTag, true>; + + // Define iterators over tiles from the A operand + // Note GEMM shared memory threadmap is used here because conv global memory + // layout needs to be mapped to fprop which is similar to the crosswise + // layout which is used by the interleaved GEMM shared memory threadmap. + // The Interleaved GEMM global memory layout is similar to the congruous + // layout. + using ThreadMapA0 = typename MmaCore0::SmemThreadMapA; + using IteratorA0 = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, layout::TensorNCxHWx, + ThreadMapA0 + >; + + using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; + + // Define iterators over tiles from the B operand + // Note GEMM shared memory threadmap is used here because conv global memory + // layout needs to be mapped to fprop which is similar to the crosswise + // layout which is used by the interleaved GEMM shared memory threadmap. + // The Interleaved GEMM global memory layout is similar to the congruous + // layout. + using ThreadMapB0 = typename MmaCore0::SmemThreadMapB; + using IteratorB0 = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, layout::TensorCxRSKx, + ThreadMapB0 + >; + + using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 4; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + using ThreadMapB1 = typename MmaCore1::SmemThreadMapB; + using IteratorB1 = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, layout::TensorCxRSKx, + ThreadMapB1 + >; + + using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; + + + // Warp-level GEMM components + using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + using MmaPolicy0 = typename MmaCore0::MmaPolicy; + using MmaPolicy1 = typename MmaCore1::MmaPolicy; + + // Use fragment iterator for the accumulator + using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>; + using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< + WarpShape0, InstructionShape, + ElementAccumulator, + typename WarpMmaTensorOp0::Policy::Operator::FragmentC, + SmemAccumulatorLayout + >; + + + // Store Accumulator tiles to Shared Memory + using SmemIteratorD0 = + cutlass::epilogue::warp::TileIteratorTensorOp< + WarpShape0, + InstructionShape, + ElementC, + SmemAccumulatorLayout + >; + + static int const kThreadCount = 32; + // load warp tile from Shared Memory accumulator + using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical< + MatrixShape, cutlass::gemm::Operand::kA, + ElementA, SmemAccumulatorLayout, + MatrixShape, + WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>; + + // Define the Mma + using B2bMma = threadblock::B2bImplicitGemmMultistageSmemAccumulator< + ThreadblockShape0, + IteratorA0, + SmemIteratorA0, + arch::CacheOperation::Always, + IteratorB0, + SmemIteratorB0, + arch::CacheOperation::Global, + IteratorAccumulatorScaleBias, + FragmentIteratorAccumulator, + SmemIteratorD0, + ThreadblockShape1, + WarpIteratorA1, + IteratorB1, + SmemIteratorB1, + arch::CacheOperation::Global, + EpilogueOutputOp0, + MmaPolicy0, + MmaPolicy1, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< + ThreadblockShape1, + WarpMmaTensorOp1, + 1, + EpilogueOutputOp1, + EpilogueOutputOp1::kCount, + InterleavedK + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< + B2bMma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..e2cc94377c0e972ee39fc051c6798368c69d2041 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm.h @@ -0,0 +1,503 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with + the appropriate threadblock-scoped epilogue. + + Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are + accommodated by exchanging A and B operands and assuming transposed layouts. Partial + specializations here choose 'device::GemmTransposed' to implement this functionality. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/epilogue/threadblock/epilogue.h" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_pipelined.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" + +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" + +#include "kernel/b2b_gemm.h" +#include "kernel/grouped.h" +#include "threadblock/default_b2b_mma.h" +#include "threadblock/grouped_threadblock_swizzle.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////// + +template +using IsGroupedSwizzle = cutlass::gemm::threadblock::detail::IsGroupedSwizzle; + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp0, + /// Epilogue output operator + typename EpilogueOutputOp1, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + /// Stage accumulator in shared memory + bool SmemAccumulator = false, + /// Whether or not the operation is grouped + typename Enable = void +> +struct DefaultB2bGemm; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Ampere Architecture +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of A matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp0, + /// Epilogue output operator + typename EpilogueOutputOp1, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator> +struct DefaultB2bGemm::value>::type> { + /// Define the threadblock-scoped matrix multiply-accumulate + using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + InstructionShape, Stages, Operator, EpilogueOutputOp0>::ThreadblockB2bMma; + + static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; + + /// Define the epilogue + using Epilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1, + EpilogueOutputOp1::kCount>::Epilogue; + + /// Define the kernel-level GEMM operator. + using B2bGemmKernel = kernel::B2bGemm; +}; + +/// Partial specialization for Ampere Architecture with grouped operation +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of A matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp0, + /// Epilogue output operator + typename EpilogueOutputOp1, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator> +struct DefaultB2bGemm::value>::type> { + /// Define the threadblock-scoped matrix multiply-accumulate + using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + InstructionShape, Stages, Operator, EpilogueOutputOp0>::ThreadblockB2bMma; + + static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; + + /// Define the epilogue + using Epilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1, + EpilogueOutputOp1::kCount>::Epilogue; + + /// Define the kernel-level GEMM operator. + using UnderlyingB2bGemmKernel = kernel::B2bGemm; + + using B2bGemmKernel = kernel::GroupedKernel; +}; + + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Turing Architecture +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp0, + /// Epilogue output operator + typename EpilogueOutputOp1, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Operation performed by GEMM + typename Operator +> +struct DefaultB2bGemm< + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementC, layout::RowMajor, + ElementAccumulator, + arch::OpClassTensorOp, + arch::Sm75, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + 2, + Operator, + false, + typename platform::enable_if::value>::type +> { + + /// Define the threadblock-scoped matrix multiply-accumulate + using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementAccumulator, + layout::RowMajor, + arch::OpClassTensorOp, + arch::Sm75, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + 2, + Operator, + EpilogueOutputOp0 + >::ThreadblockB2bMma; + + static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; + + /// Define the epilogue + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape1, + typename B2bMma::Operator1, + kPartitionsK1, + EpilogueOutputOp1, + EpilogueOutputOp1::kCount + >::Epilogue; + + /// Define the kernel-level GEMM operator. + using B2bGemmKernel = kernel::B2bGemm; +}; + + +/// Partial specialization for Ampere Integer Matrix Multiply Interleaved layout +template < + /// Element type for A matrix operand + typename ElementA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp0, + /// Epilogue output operator + typename EpilogueOutputOp1, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Number of Interleaved k + int InterleavedK, + /// Operation performed by GEMM + typename Operator> +struct DefaultB2bGemm< + ElementA, layout::ColumnMajorInterleaved, kAlignmentA, + ElementB, layout::RowMajorInterleaved, kAlignmentB, + ElementC, layout::ColumnMajorInterleaved, int32_t, + arch::OpClassTensorOp, arch::Sm80, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1, + ThreadblockSwizzle, Stages, + Operator, false, typename platform::enable_if::value>::type> { + using LayoutA = layout::ColumnMajorInterleaved; + using LayoutB = layout::RowMajorInterleaved; + using LayoutC = layout::ColumnMajorInterleaved; + + using ElementAccumulator = int32_t; + + /// Define the threadblock-scoped matrix multiply-accumulate + using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + InstructionShape, Stages, Operator, EpilogueOutputOp0, + true>::ThreadblockB2bMma; + + static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; + + /// Define the epilogue + using Epilogue = typename cutlass::epilogue::threadblock:: + DefaultInterleavedEpilogueTensorOp< + ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1, + 64 / sizeof_bits::value, InterleavedK>::Epilogue; + + /// Define the kernel-level GEMM operator. + using B2bGemmKernel = kernel::B2bGemm; +}; + +//////////////////////////////////////////////////////////////////////////////// + + +/// Partial specialization for Turing Integer Tensor Core Interleaved layout +template < + /// Element type for A matrix operand + typename ElementA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp0, + /// Epilogue output operator + typename EpilogueOutputOp1, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of Interleaved k + int InterleavedK, + /// Operation performed by GEMM + typename Operator> +struct DefaultB2bGemm, + kAlignmentA, ElementB, + layout::RowMajorInterleaved, kAlignmentB, + ElementC, layout::ColumnMajorInterleaved, + int32_t, arch::OpClassTensorOp, arch::Sm75, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1, + ThreadblockSwizzle, 2, Operator, false, + typename platform::enable_if::value>::type> { + using LayoutA = layout::ColumnMajorInterleaved; + using LayoutB = layout::RowMajorInterleaved; + using LayoutC = layout::ColumnMajorInterleaved; + + using ElementAccumulator = int32_t; + + /// Define the threadblock-scoped matrix multiply-accumulate + using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC, + arch::OpClassTensorOp, arch::Sm75, ThreadblockShape0, ThreadblockShape1, + WarpShape0, WarpShape1, InstructionShape, 2, Operator, EpilogueOutputOp0, true>::ThreadblockB2bMma; + + static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; + + /// Define the epilogue for the 2nd Gemm + using Epilogue = typename cutlass::epilogue::threadblock:: + DefaultInterleavedEpilogueTensorOp< + ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1, + 64 / sizeof_bits::value, InterleavedK>::Epilogue; + + /// Define the kernel-level GEMM operator. + using B2bGemmKernel = kernel::B2bGemm; +}; + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm_smem_accumulator.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm_smem_accumulator.h new file mode 100644 index 0000000000000000000000000000000000000000..0a4530f6ce82c18656d7bb64452ca08b1eebfe18 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm_smem_accumulator.h @@ -0,0 +1,384 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with + the appropriate threadblock-scoped epilogue. + + Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are + accommodated by exchanging A and B operands and assuming transposed layouts. Partial + specializations here choose 'device::GemmTransposed' to implement this functionality. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/epilogue/threadblock/epilogue.h" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_pipelined.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" + +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/transform/threadblock/vector_iterator.h" +#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h" + +#include "kernel/b2b_gemm.h" +#include "threadblock/default_b2b_mma.h" +#include "threadblock/default_b2b_mma_smem_accumulator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Ampere Architecture +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of A matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp0, + /// Epilogue output operator + typename EpilogueOutputOp1, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator> +struct DefaultB2bGemm { + /// Define the threadblock-scoped matrix multiply-accumulate + using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + InstructionShape, Stages, Operator, EpilogueOutputOp0, false, true>::ThreadblockB2bMma; + + static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; + + /// Define the epilogue + using Epilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1, + EpilogueOutputOp1::kCount>::Epilogue; + + /// Define the kernel-level GEMM operator. + using B2bGemmKernel = kernel::B2bGemm; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Turing Architecture +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp0, + /// Epilogue output operator + typename EpilogueOutputOp1, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Operation performed by GEMM + typename Operator +> +struct DefaultB2bGemm< + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementC, layout::RowMajor, + ElementAccumulator, + arch::OpClassTensorOp, + arch::Sm75, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + 2, + Operator, + true +> { + + /// Define the threadblock-scoped matrix multiply-accumulate + using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementAccumulator, + layout::RowMajor, + arch::OpClassTensorOp, + arch::Sm75, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + 2, + Operator, + EpilogueOutputOp0, + false, + true + >::ThreadblockB2bMma; + + static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; + + /// Define the epilogue + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape1, + typename B2bMma::Operator1, + kPartitionsK1, + EpilogueOutputOp1, + EpilogueOutputOp1::kCount + >::Epilogue; + + /// Define the kernel-level GEMM operator. + using B2bGemmKernel = kernel::B2bGemm; +}; + + +/// Partial specialization for Ampere Integer Matrix Multiply Interleaved layout +template < + /// Element type for A matrix operand + typename ElementA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp0, + /// Epilogue output operator + typename EpilogueOutputOp1, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Number of Interleaved k + int InterleavedK, + /// Operation performed by GEMM + typename Operator> +struct DefaultB2bGemm< + ElementA, layout::ColumnMajorInterleaved, kAlignmentA, + ElementB, layout::RowMajorInterleaved, kAlignmentB, + ElementC, layout::ColumnMajorInterleaved, int32_t, + arch::OpClassTensorOp, arch::Sm80, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1, + ThreadblockSwizzle, Stages, + Operator, true> { + using LayoutA = layout::ColumnMajorInterleaved; + using LayoutB = layout::RowMajorInterleaved; + using LayoutC = layout::ColumnMajorInterleaved; + + using ElementAccumulator = int32_t; + + /// Define the threadblock-scoped matrix multiply-accumulate + using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + InstructionShape, Stages, Operator, EpilogueOutputOp0, + true, true>::ThreadblockB2bMma; + + static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; + + /// Define the epilogue + using Epilogue = typename cutlass::epilogue::threadblock:: + DefaultInterleavedEpilogueTensorOp< + ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1, + 64 / sizeof_bits::value, InterleavedK>::Epilogue; + + /// Define the kernel-level GEMM operator. + using B2bGemmKernel = kernel::B2bGemm; +}; + +//////////////////////////////////////////////////////////////////////////////// + + +/// Partial specialization for Turing Integer Tensor Core Interleaved layout +template < + /// Element type for A matrix operand + typename ElementA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp0, + /// Epilogue output operator + typename EpilogueOutputOp1, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of Interleaved k + int InterleavedK, + /// Operation performed by GEMM + typename Operator> +struct DefaultB2bGemm, + kAlignmentA, ElementB, + layout::RowMajorInterleaved, kAlignmentB, + ElementC, layout::ColumnMajorInterleaved, + int32_t, arch::OpClassTensorOp, arch::Sm75, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1, + ThreadblockSwizzle, 2, Operator, true> { + using LayoutA = layout::ColumnMajorInterleaved; + using LayoutB = layout::RowMajorInterleaved; + using LayoutC = layout::ColumnMajorInterleaved; + + using ElementAccumulator = int32_t; + + /// Define the threadblock-scoped matrix multiply-accumulate + using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm75, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + InstructionShape, 2, Operator, EpilogueOutputOp0, true, true>::ThreadblockB2bMma; + + static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; + + /// Define the epilogue for the 2nd Gemm + using Epilogue = typename cutlass::epilogue::threadblock:: + DefaultInterleavedEpilogueTensorOp< + ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1, + 64 / sizeof_bits::value, InterleavedK>::Epilogue; + + /// Define the kernel-level GEMM operator. + using B2bGemmKernel = kernel::B2bGemm; +}; + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/grouped.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/grouped.h new file mode 100644 index 0000000000000000000000000000000000000000..0ac841d4edf08b2d846109f64e343525165b07b9 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/grouped.h @@ -0,0 +1,168 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief High-level interface for running a grouped version of a CUTLASS kernel +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/complex.h" +#include "cutlass/semaphore.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/trace.h" +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// High-level interface for running a grouped version of a CUTLASS kernel +template < + typename BaseKernel_ ///! Kernel-scoped matrix multiply-accumulate +> +struct GroupedKernel { +public: + + using BaseKernel = BaseKernel_; + using Epilogue = typename BaseKernel::Epilogue; + + /// Types that need to be exported to work properly with device::BaseGrouped + using ElementA = typename BaseKernel::ElementA; + using LayoutA = typename BaseKernel::LayoutA; + using TensorRefA = TensorRef; + static ComplexTransform const kTransformA = BaseKernel::kTransformA; + static int const kAlignmentA = BaseKernel::kAlignmentA; + + using ElementB = typename BaseKernel::ElementB; + using LayoutB = typename BaseKernel::LayoutB; + using TensorRefB = TensorRef; + static ComplexTransform const kTransformB = BaseKernel::kTransformB; + static int const kAlignmentB = BaseKernel::kAlignmentB; + + using ElementC = typename BaseKernel::ElementC; + using LayoutC = typename BaseKernel::LayoutC; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + static int const kAlignmentC = BaseKernel::kAlignmentC; + + using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC; + + using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp; + using ThreadblockSwizzle = typename BaseKernel::ThreadblockSwizzle; + + using Operator = typename BaseKernel::Operator; + using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator; + + using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; + using MathOperator = typename WarpMmaOperator::MathOperator; + using OperatorClass = typename WarpMmaOperator::OperatorClass; + using ArchTag = typename WarpMmaOperator::ArchTag; + using ThreadblockShape = typename BaseKernel::Mma::Shape; + using WarpShape = typename BaseKernel::WarpShape; + using InstructionShape = typename BaseKernel::InstructionShape; + static int const kStages = BaseKernel::Mma::kStages; + + using Mma = typename BaseKernel::Mma; + + using Arguments = typename BaseKernel::GroupedArguments; + using Params = typename BaseKernel::GroupedParams; + using ProblemVisitor = typename ThreadblockSwizzle::ProblemVisitor; + + static int const kThreadCount = BaseKernel::kThreadCount; + + /// Shared memory storage structure + struct SharedStorage { + typename BaseKernel::SharedStorage kernel; + + // ProblemVisitor shared storage can't be overlapped with others + typename ProblemVisitor::SharedStorage problem_visitor; + }; + +public: + + // + // Methods + // + + CUTLASS_DEVICE + GroupedKernel() { } + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const & problem_size) { + return Status::kSuccess; + } + + static Status can_implement(Arguments const &args) { + return Status::kSuccess; + } + + /// Executes a kernel-level GEMM in a loop + CUTLASS_DEVICE + void operator()(Params ¶ms, SharedStorage &shared_storage) { + + ThreadblockSwizzle swizzle(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); + + if (ProblemVisitor::kTransposed) { + params.transpose(); + } + + BaseKernel mma; + + // Outer 'persistent' loop to iterate over tiles + while (swizzle.problem_visitor.next_tile()) { + + typename BaseKernel::Params mma_params = params.to_single_params(swizzle.problem_visitor); + mma.run_with_swizzle(mma_params, shared_storage.kernel, swizzle); + + // Next tile + swizzle.problem_visitor.advance(gridDim.x); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/reference/device/tensor_scale_bias.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/reference/device/tensor_scale_bias.h new file mode 100644 index 0000000000000000000000000000000000000000..4bf3c532559cb27b7d638319276b447817a3cc49 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/reference/device/tensor_scale_bias.h @@ -0,0 +1,395 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines device-side elementwise operations on TensorView. Note, the operations defined + in this header are not specialized for any particular data layout and are therefore not + intended to offer the best possible performance. Rather, they are intended to be generic + reference implementations to support the CUTLASS unit tests. +*/ + +#pragma once + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/tensor_view.h" + +#include "cutlass/gemm/gemm.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace device { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +template < + typename TensorRefIn, ///< Input TensorRef Type + typename TensorRefOut, ///< Output TensorRef Type + typename ScalarType, ///< alpha Type + typename TensorRefScalar, ///< Scale/Bias TensorRef Type + typename OutputTile, + typename ConvertOp = NumericConverter +> +__global__ void TensorScaleBiasGemm( + gemm::GemmCoord problem_size, + TensorRefIn tensor_in, ///< input tensor + TensorRefOut tensor_out, ///< output tensor + ScalarType alpha, ///< alpha + TensorRefScalar tensor_scale, ///< scale tensor + TensorRefScalar tensor_bias ///< bias tensor +) { + + ConvertOp convert_op; + + MatrixCoord output_coord( + MatrixCoord::Index((threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kRow), + MatrixCoord::Index((threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kColumn) + ); + + // Update the output tensor + for (int j = 0; j < OutputTile::kRow; ++j) { + for (int i = 0; i < OutputTile::kColumn; ++i) { + MatrixCoord coord = output_coord + MatrixCoord(i, j); + if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) { + + ScalarType scale = alpha; + if(tensor_scale.good()) + scale = tensor_scale.at({0, coord.column()}); + + ScalarType bias = ScalarType(0); + + if(tensor_bias.good()) + bias = tensor_bias.at({0, coord.column()}); + + tensor_out.at(coord) = convert_op( + scale * ScalarType(tensor_in.at(coord)) + bias); + } + } + } +} + +template < + typename TensorRefIn, ///< Input TensorRef Type + typename TensorRefOut, ///< Output TensorRef Type + typename ScalarType, ///< alpha Type + typename TensorRefScalar, ///< Scale/Bias TensorRef Type + typename ConvertOp = NumericConverter, + int kMblock = 4, + int kNblock = 4 +> +__global__ void TensorScaleBiasGemmBatched( + gemm::GemmCoord problem_size, + TensorRefIn tensor_in, ///< input tensor + TensorRefOut tensor_out, ///< output tensor + ScalarType alpha, ///< alpha + TensorRefScalar tensor_scale, ///< scale tensor + TensorRefScalar tensor_bias, ///< bias tensor + int batch_count = 1, + int64_t batch_stride_tensor_in = 0, + int64_t batch_stride_tensor_out = 0, + int64_t batch_stride_tensor_scale = 0, + int64_t batch_stride_tensor_bias = 0 +) { + + ConvertOp convert_op; + int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock; + int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock; + int batch_idx = blockIdx.z; + + tensor_in.add_pointer_offset(batch_idx * batch_stride_tensor_in); + tensor_out.add_pointer_offset(batch_idx * batch_stride_tensor_out); + tensor_scale.add_pointer_offset(batch_idx * batch_stride_tensor_scale); + tensor_bias.add_pointer_offset(batch_idx * batch_stride_tensor_bias); + + for (; batch_idx < batch_count; batch_idx += gridDim.z) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + int row = row_block + i; + int col = col_block + j; + MatrixCoord coord = MatrixCoord(row, col); + if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) { + + ScalarType scale = alpha; + if(tensor_scale.good()) + scale = tensor_scale.at({0, coord.column()}); + + ScalarType bias = ScalarType(0); + + if(tensor_bias.good()) + bias = tensor_bias.at({0, coord.column()}); + + tensor_out.at(coord) = convert_op( + scale * ScalarType(tensor_in.at(coord)) + bias); + } + } + } + tensor_in.add_pointer_offset(batch_stride_tensor_in * gridDim.z); + tensor_out.add_pointer_offset(batch_stride_tensor_out * gridDim.z); + tensor_scale.add_pointer_offset(batch_stride_tensor_scale * gridDim.z); + tensor_bias.add_pointer_offset(batch_stride_tensor_bias * gridDim.z); + } +} + +template < + typename TensorRefIn, ///< Input TensorRef Type + typename TensorRefOut, ///< Output TensorRef Type + typename ScalarType, ///< alpha Type + typename TensorRefScalar, ///< Scale/Bias TensorRef Type + typename ConvertOp = NumericConverter, + int kThreadM = 4, // shape of a thread's tile in the GEMM M dimension + int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension + int kCtaShapeM = 16, // shape of a threadblock in units of threads + int kCtaShapeN = 8 // shape of a threadblock in units of threads +> +__global__ void TensorScaleBiasConv2d( + conv::Conv2dProblemSize problem_size, + TensorRefIn tensor_in, ///< input tensor + TensorRefOut tensor_out, ///< output tensor + ScalarType alpha, ///< alpha + TensorRefScalar tensor_scale, ///< scale tensor + TensorRefScalar tensor_bias ///< bias tensor +) { + + ConvertOp convert_op; + + int64_t npq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; + int k_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; + + int thread_n[kThreadM]; + int thread_p[kThreadM]; + int thread_q[kThreadM]; + + // Compute N, P, Q coordinates for each row of a thread's tile + int64_t PQ = int64_t(problem_size.P) * problem_size.Q; + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + int64_t npq = npq_start + m; + + thread_n[m] = int(npq / PQ); + + int64_t residual = npq % PQ; + thread_p[m] = int(residual / problem_size.Q); + thread_q[m] = int(residual % problem_size.Q); + } + + // Write out the results + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + if (thread_n[m] < problem_size.N && thread_p[m] < problem_size.P && thread_q[m] < problem_size.Q) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_k = k_start + n; + if (thread_k < problem_size.K) { + + ScalarType scale = alpha; + if(tensor_scale.good()) + scale = tensor_scale.at({0, thread_k}); + + ScalarType bias = ScalarType(0); + if(tensor_bias.good()) + bias = tensor_bias.at({0, thread_k}); + + tensor_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op( + scale * ScalarType( + tensor_in.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) + ) + bias); + } + } + } + } + +} + +} + +/// Apply scale and bias on a tensor +template < + typename ElementIn, ///< Input Type + typename ElementOut, ///< Output Type + typename Layout, ///< Layout of input/output tensor + typename ScalarType, ///< alpha Type + typename LayoutScaleBias, ///< Layout of scale and bias + typename ConvertOp = NumericConverter +> +void TensorScaleBiasGemm( + gemm::GemmCoord problem_size, + TensorRef tensor_in, ///< input tensor + TensorRef tensor_out, ///< output tensor + ScalarType alpha, ///< alpha + TensorRef tensor_scale, ///< scale tensor + TensorRef tensor_bias ///< bias tensor +) { + + using OutputTile = MatrixShape<4, 4>; + + dim3 block(16, 8); + + dim3 grid( + (problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow), + (problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn) + ); + + kernel::TensorScaleBiasGemm< + TensorRef, + TensorRef, + ScalarType, + TensorRef, + OutputTile, + ConvertOp + ><<< grid, block >>> ( + problem_size, + tensor_in, + tensor_out, + alpha, + tensor_scale, + tensor_bias + ); +} + +/// Apply scale and bias on a tensor +template < + typename ElementIn, ///< Input Type + typename ElementOut, ///< Output Type + typename Layout, ///< Layout of input/output tensor + typename ScalarType, ///< alpha Type + typename LayoutScaleBias, ///< Layout of scale and bias + typename ConvertOp = NumericConverter +> +void TensorScaleBiasGemmBatched( + gemm::GemmCoord problem_size, + TensorRef tensor_in, ///< input tensor + TensorRef tensor_out, ///< output tensor + ScalarType alpha, ///< alpha + TensorRef tensor_scale, ///< scale tensor + TensorRef tensor_bias, ///< bias tensor + int batch_count = 1, + int64_t batch_stride_tensor_in = 0, + int64_t batch_stride_tensor_out = 0, + int64_t batch_stride_tensor_scale = 0, + int64_t batch_stride_tensor_bias = 0 +) { + + int const kMblock = 4; + int const kNblock = 4; + + dim3 block(16, 8); + dim3 grid( + (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock), + (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock), + batch_count % std::numeric_limits::max() + ); + + kernel::TensorScaleBiasGemmBatched< + TensorRef, + TensorRef, + ScalarType, + TensorRef, + ConvertOp, + kMblock, + kNblock + ><<< grid, block >>> ( + problem_size, + tensor_in, + tensor_out, + alpha, + tensor_scale, + tensor_bias, + batch_count, + batch_stride_tensor_in, + batch_stride_tensor_out, + batch_stride_tensor_scale, + batch_stride_tensor_bias + ); +} + +/// Apply scale and bias on a tensor +template < + typename ElementIn, ///< Input Type + typename ElementOut, ///< Output Type + typename Layout, ///< Layout of input/output tensor + typename ScalarType, ///< alpha Type + typename LayoutScaleBias, ///< Layout of scale and bias + typename ConvertOp = NumericConverter +> +void TensorScaleBiasConv2d( + conv::Conv2dProblemSize problem_size, + TensorRef tensor_in, ///< input tensor + TensorRef tensor_out, ///< output tensor + ScalarType alpha, ///< alpha + TensorRef tensor_scale, ///< scale tensor + TensorRef tensor_bias ///< bias tensor +) { + + int const kThreadM = 4; // shape of a thread's tile in the GEMM M dimension + int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension + int const kCtaShapeM = 16; // shape of a threadblock in units of threads + int const kCtaShapeN = 8; // shape of a threadblock in units of threads + + int64_t npq = int64_t(problem_size.N) * problem_size.P * problem_size.Q; + int64_t blocks_m = (npq + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM); + + dim3 block(kCtaShapeM, kCtaShapeN); + dim3 grid(uint32_t(blocks_m), (problem_size.K + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN)); + + + kernel::TensorScaleBiasConv2d< + TensorRef, + TensorRef, + ScalarType, + TensorRef, + ConvertOp, + kThreadM, + kThreadN, + kCtaShapeM, + kCtaShapeN + ><<< grid, block >>> ( + problem_size, + tensor_in, + tensor_out, + alpha, + tensor_scale, + tensor_bias + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/test_run.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/test_run.h new file mode 100644 index 0000000000000000000000000000000000000000..1fba44d69655d8d8f14442c411da4cafd3db598f --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/test_run.h @@ -0,0 +1,95 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#include + +// Run tests on GPUs + +int testRun(int arch, std::vector & test_funcs, const std::string & test_name) { + + bool supported = false; + + int arch_major = arch / 10; + int arch_minor = arch - arch / 10 * 10; + + if(arch_major >= 8) { + // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. + // + // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples. + if (__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) { + supported = true; + } + } + else if(arch_major >= 7) { + // Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2. + // + // CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples. + if (__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)) { + supported = true; + } + } + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (!(props.major == arch_major && props.minor == arch_minor)) { + supported = false; + } + + if (!supported) { + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + std::cout << "This example isn't supported on current architecture" << std::endl; + return 0; + } + + bool pass = true; + + std::cout << "Device: " << props.name << std::endl; + std::cout << "Arch: SM" << arch << std::endl; + std::cout << "Test: " << test_name << std::endl; + for(auto func : test_funcs) { + pass &= func(); + } + + + if(pass) + return 0; + else + return -1; + +} + diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_multistage.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_multistage.h new file mode 100644 index 0000000000000000000000000000000000000000..1feb71cf11ceb30517e99dc25275644e6060b2db --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_multistage.h @@ -0,0 +1,831 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a multistage threadblock-scoped Implicit GEMM Convolution kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/cache_operation.h" +#include "cutlass/gemm/threadblock/mma_base.h" +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" + +#include "threadblock/b2b_mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape0_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA0_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA0_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA0, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB0_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB0_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB0, + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape1_, + /// Iterates over the intermediate accumulator tile + // (concept::MmaTensorOpFragmentIterator) + typename FragmentIteratorA1_, + /// Iterates over vectors of scale and bias vector in global memory + // (concept: VectorIterator) + typename IteratorAccumulatorScaleBias_, + /// WarpIterator to load Scale or Bias vector from threadblock fragment + typename FragmentIteratorA1ScaleBias_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB1_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB1_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB1, + /// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...) + typename OutputOp_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy0_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy1_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class B2bImplicitGemmMultistage : + public gemm::threadblock::B2bMmaBase { +public: + ///< Base class + using Base = gemm::threadblock::B2bMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape0 = Shape0_; + ///< Iterates over tiles of A operand in global memory + using IteratorA0 = IteratorA0_; + ///< Iterates over tiles of B operand in global memory + using IteratorB0 = IteratorB0_; + ///< Policy describing tuning details + using Policy0 = Policy0_; + + using SmemIteratorA0 = SmemIteratorA0_; + using SmemIteratorB0 = SmemIteratorB0_; + + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape1 = Shape1_; + ///< Iterates over tiles of A operand in global memory + using FragmentIteratorA1 = FragmentIteratorA1_; + ///< Iterates over tiles of the scale and bias vectors in global memory + using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; + ///< WarpIterator to load Scale or Bias vector from threadblock fragment + using FragmentIteratorA1ScaleBias = FragmentIteratorA1ScaleBias_; + ///< Iterates over tiles of B operand in global memory + using IteratorB1 = IteratorB1_; + ///< Policy describing tuning details + using Policy1 = Policy1_; + + using SmemIteratorB1 = SmemIteratorB1_; + + ///< Epilogue after 1st Gemm + using OutputOp = OutputOp_; + + static const bool PerChannelScale = (OutputOp::kScale == + epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling); + + static cutlass::arch::CacheOperation::Kind const kCacheOpA0 = CacheOpA0; + static cutlass::arch::CacheOperation::Kind const kCacheOpB0 = CacheOpB0; + static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1; + + // + // Dependent types + // + + using ElementC = typename Policy0::Operator::ElementC; + + /// Fragment of accumulator tile + using FragmentC0 = typename Policy0::Operator::FragmentC; + + /// Warp-level Mma + using Operator0 = typename Policy0::Operator; + + /// Fragment of Scale and Bias loaded from global memory + using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment; + + /// Fragment of accumulator tile + using FragmentC1 = typename Policy1::Operator::FragmentC; + + /// Warp-level Mma + using Operator1 = typename Policy1::Operator; + + /// Internal structure exposed for introspection. + struct Detail { + + static_assert(Base::kWarpGemmIterations0 > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + static_assert(Base::kWarpGemmIterations1 > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA0 = + IteratorA0::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB0 = + IteratorB0::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB1 = + IteratorB1::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA0 = + (AsyncCopyIterationsPerStageA0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB0 = + (AsyncCopyIterationsPerStageB0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB1 = + (AsyncCopyIterationsPerStageB1 + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1; + }; + + private: + + using WarpLoadedFragmentA0 = typename Operator0::FragmentA; + using WarpLoadedFragmentB0 = typename Operator0::FragmentB; + /// Warp Fragment of operand A1 loaded from accmulator tile + using WarpLoadedFragmentA1 = typename FragmentIteratorA1::Fragment; + using WarpLoadedFragmentA1ScaleBias = + typename FragmentIteratorA1ScaleBias::Fragment; + using WarpLoadedFragmentB1 = typename Operator1::FragmentB; + using WarpTransformedFragmentA0 = typename Operator0::TransformedFragmentA; + using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB; + using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA; + using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB; + + private: + + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA0 smem_iterator_A0_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB0 smem_iterator_B0_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB1 smem_iterator_B1_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + B2bImplicitGemmMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::B2bMmaSharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A0_(shared_storage.shared_storage0.operand_A_ref(), thread_idx), + smem_iterator_B0_(shared_storage.shared_storage0.operand_B_ref(), thread_idx), + smem_iterator_B1_(shared_storage.shared_storage1.operand_B_ref(), thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN); + int warp_idx_k = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount0::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount0::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A0_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations0 * warp_idx_k}); + this->warp_tile_iterator_B0_.add_tile_offset( + {Base::kWarpGemmIterations0 * warp_idx_k, warp_idx_n}); + this->warp_tile_iterator_B1_.add_tile_offset( + {Base::kWarpGemmIterations1 * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance_0( + IteratorA0 &iterator_A0, IteratorB0 &iterator_B0, + int group_start_A0 = 0, int group_start_B0 = 0) { + + iterator_A0.set_iteration_index(group_start_A0); + this->smem_iterator_A0_.set_iteration_index(group_start_A0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA0; ++j) { + + if (group_start_A0 + j < Detail::AsyncCopyIterationsPerStageA0) { + typename IteratorA0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A0_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA0::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr, iterator_A0.get(), iterator_A0.valid()); + + ++iterator_A0; + + ++this->smem_iterator_A0_; + } + } + + iterator_B0.set_iteration_index(group_start_B0); + + this->smem_iterator_B0_.set_iteration_index(group_start_B0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB0; ++j) { + if (group_start_B0 + j < Detail::AsyncCopyIterationsPerStageB0) { + typename IteratorB0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B0_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB0::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr, iterator_B0.get(), iterator_B0.valid()); + + ++iterator_B0; + ++this->smem_iterator_B0_; + } + } + } + + CUTLASS_DEVICE + void copy_tiles_and_advance_1( + IteratorB1 &iterator_B1, + int group_start_B1 = 0) { + + iterator_B1.set_iteration_index(group_start_B1); + + this->smem_iterator_B1_.set_iteration_index(group_start_B1); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) { + if (group_start_B1 + j < Detail::AsyncCopyIterationsPerStageB1) { + typename IteratorB1::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B1_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr, iterator_B1.get(), iterator_B1.valid()); + + ++iterator_B1; + ++this->smem_iterator_B1_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations_0, + ///< destination accumulator tile + FragmentC1 &accum, + ///< iterator over A0 operand in global memory + IteratorA0 iterator_A0, + ///< iterator over B0 operand in global memory + IteratorB0 iterator_B0, + ///< iterator over A1 operand scale vector in global memory + IteratorAccumulatorScaleBias iterator_A1_scale, + ///< iterator over A1 operand bias vector in global memory + IteratorAccumulatorScaleBias iterator_A1_bias, + ///< iterator over B1 operand in global memory + IteratorB1 iterator_B1, + ///< initial value of accumulator + FragmentC0 const &src_accum, + ///< epilogue operation after 1st Gemm + OutputOp output_op_0, + ///< Imaginary strides used for planar-complex only - ignored here + int64_t imag_stride_A = 0, + int64_t imag_stride_B = 0) { + + // + // Prologue + // + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations_0) { + + iterator_A0.set_iteration_index(0); + this->smem_iterator_A0_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA0; ++j) { + typename IteratorA0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A0_.get()); + + int const kSrcBytes = + sizeof_bits::value * + IteratorA0::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr, iterator_A0.get(), iterator_A0.valid()); + + ++iterator_A0; + ++this->smem_iterator_A0_; + } + + iterator_B0.set_iteration_index(0); + this->smem_iterator_B0_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB0; ++j) { + typename IteratorB0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B0_.get()); + + int const kSrcBytes = + sizeof_bits::value * + IteratorB0::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr, iterator_B0.get(), iterator_B0.valid()); + + ++iterator_B0; + ++this->smem_iterator_B0_; + } + + // Move to the next stage + iterator_A0.advance(); + iterator_B0.advance(); + + this->smem_iterator_A0_.add_tile_offset({0, 1}); + this->smem_iterator_B0_.add_tile_offset({1, 0}); + + // Inserts a fence to group cp.async instructions into stages. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + FragmentC0 accum0 = src_accum; + + // Waits until kStages-2 stages have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA0 warp_loaded_frag_A0[2]; + WarpLoadedFragmentB0 warp_loaded_frag_B0[2]; + WarpTransformedFragmentA0 warp_transformed_frag_A0[2]; + WarpTransformedFragmentB0 warp_transformed_frag_B0[2]; + + Operator0 warp_mma0; + + this->warp_tile_iterator_A0_.set_kgroup_index(0); + this->warp_tile_iterator_B0_.set_kgroup_index(0); + + this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[0]); + this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[0]); + + ++this->warp_tile_iterator_A0_; + ++this->warp_tile_iterator_B0_; + + // Start issuing the first group of the next stage outside of the mainloop + copy_tiles_and_advance_0(iterator_A0, iterator_B0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma0.transform(warp_transformed_frag_A0[0], warp_transformed_frag_B0[0], + warp_loaded_frag_A0[0], warp_loaded_frag_B0[0]); + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations_0 > (-Base::kStages + 1);) { + + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; + ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); + this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); + + this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A0_; + ++this->warp_tile_iterator_B0_; + + if (warp_mma_k > 0) + warp_mma0.transform(warp_transformed_frag_A0[warp_mma_k % 2], + warp_transformed_frag_B0[warp_mma_k % 2], + warp_loaded_frag_A0[warp_mma_k % 2], + warp_loaded_frag_B0[warp_mma_k % 2]); + + // Issue global->shared copies for the next stage + int group_start_iteration_A0, group_start_iteration_B0; + + if (warp_mma_k + 1 == Base::kWarpGemmIterations0) { + group_start_iteration_A0 = 0; + group_start_iteration_B0 = 0; + } else { + group_start_iteration_A0 = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA0; + group_start_iteration_B0 = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB0; + } + + copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0, + group_start_iteration_B0); + + warp_mma0( + accum0, + warp_transformed_frag_A0[warp_mma_k % 2], + warp_transformed_frag_B0[warp_mma_k % 2], + accum0 + ); + + if (warp_mma_k + 1 == Base::kWarpGemmIterations0) + warp_mma0.transform(warp_transformed_frag_A0[(warp_mma_k + 1) % 2], + warp_transformed_frag_B0[(warp_mma_k + 1) % 2], + warp_loaded_frag_A0[(warp_mma_k + 1) % 2], + warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); + + if (warp_mma_k + 2 == Base::kWarpGemmIterations0) { + // Inserts a fence to group cp.async instructions into stages. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages of cp.async have committed + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A0.advance(); + iterator_B0.advance(); + + this->smem_iterator_A0_.add_tile_offset({0, 1}); + this->smem_iterator_B0_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A0_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A0_.add_tile_offset( + {0, -Base::kStages * Policy0::kPartitionsK * + Base::kWarpGemmIterations0}); + this->warp_tile_iterator_B0_.add_tile_offset( + {-Base::kStages * Policy0::kPartitionsK * + Base::kWarpGemmIterations0, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations_0; + } + } + + } + + // Insert fence and wait for all outstanding cp.async operations to commit. + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + + + // 2nd Implicit Gemm + + /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile + FragmentIteratorA1 warp_tile_iterator_A1_(accum0); + FragmentA1ScaleBias tb_frag_A1_scale; + FragmentA1ScaleBias tb_frag_A1_bias; + FragmentIteratorA1ScaleBias warp_tile_iterator_A1_scale_(tb_frag_A1_scale); + FragmentIteratorA1ScaleBias warp_tile_iterator_A1_bias_(tb_frag_A1_bias); + + if(PerChannelScale) { + tb_frag_A1_scale.clear(); + iterator_A1_scale.load(tb_frag_A1_scale); + ++iterator_A1_scale; + } + tb_frag_A1_bias.clear(); + iterator_A1_bias.load(tb_frag_A1_bias); + ++iterator_A1_bias; + + + // + // Prologue + // + int gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1; + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations_1) { + + iterator_B1.set_iteration_index(0); + this->smem_iterator_B1_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB1; ++j) { + typename IteratorB1::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B1_.get()); + + int const kSrcBytes = + sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr, iterator_B1.get(), iterator_B1.valid()); + + ++iterator_B1; + ++this->smem_iterator_B1_; + } + + // Move to the next stage + iterator_B1.advance(); + + this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Inserts a fence to group cp.async instructions into stages. + cutlass::arch::cp_async_fence(); + } + + // Waits until kStages-2 stages have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA1 warp_loaded_frag_A1[2]; + WarpLoadedFragmentA1ScaleBias warp_loaded_frag_A1_scale[2]; + WarpLoadedFragmentA1ScaleBias warp_loaded_frag_A1_bias[2]; + WarpLoadedFragmentB1 warp_loaded_frag_B1[2]; + WarpTransformedFragmentA1 warp_transformed_frag_A1[2]; + WarpTransformedFragmentB1 warp_transformed_frag_B1[2]; + + Operator1 warp_mma1; + + if(PerChannelScale) { + warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[0]); + ++warp_tile_iterator_A1_scale_; + } + warp_tile_iterator_A1_bias_.load(warp_loaded_frag_A1_bias[0]); + ++warp_tile_iterator_A1_bias_; + + warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0], + warp_loaded_frag_A1_scale[0], + warp_loaded_frag_A1_bias[0], + output_op_0); + ++warp_tile_iterator_A1_; + + this->warp_tile_iterator_B1_.set_kgroup_index(0); + this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]); + ++this->warp_tile_iterator_B1_; + + // Start issuing the first group of the next stage outside of the mainloop + copy_tiles_and_advance_1(iterator_B1); + + smem_write_stage_idx = Base::kStages - 1; + smem_read_stage_idx = 0; + + warp_mma1.transform(warp_transformed_frag_A1[0], warp_transformed_frag_B1[0], + warp_loaded_frag_A1[0], warp_loaded_frag_B1[0]); + + + // + // Mainloop + // + + CUTLASS_PRAGMA_UNROLL + for (gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1 - (Base::kStages - 1); + gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; + ++warp_mma_k) { + + // Load threadblock-level scale/bias vector from global memory + if (warp_mma_k + 1 == Base::kWarpGemmIterations1) { + if(PerChannelScale) { + tb_frag_A1_scale.clear(); + iterator_A1_scale.load(tb_frag_A1_scale); + ++iterator_A1_scale; + } + tb_frag_A1_bias.clear(); + iterator_A1_bias.load(tb_frag_A1_bias); + ++iterator_A1_bias; + } + + // Load warp-level scale bias fragment from threadblock scale/bias vector + if(PerChannelScale) { + warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]); + ++warp_tile_iterator_A1_scale_; + } + warp_tile_iterator_A1_bias_.load(warp_loaded_frag_A1_bias[(warp_mma_k + 1) % 2]); + ++warp_tile_iterator_A1_bias_; + + // Load warp-level tile from accumulator fragment + warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2], + warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2], + warp_loaded_frag_A1_bias[(warp_mma_k + 1) % 2], + output_op_0); + ++warp_tile_iterator_A1_; + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); + this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_B1_; + + if (warp_mma_k > 0) + warp_mma1.transform(warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + warp_loaded_frag_A1[warp_mma_k % 2], + warp_loaded_frag_B1[warp_mma_k % 2]); + + // Issue global->shared copies for the next stage + int group_start_iteration_B1; + + if (warp_mma_k + 1 == Base::kWarpGemmIterations1) { + group_start_iteration_B1 = 0; + } else { + group_start_iteration_B1 = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB1; + } + + copy_tiles_and_advance_1(iterator_B1, + group_start_iteration_B1); + + warp_mma1( + accum, + warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + accum + ); + + if (warp_mma_k + 1 == Base::kWarpGemmIterations1) + warp_mma1.transform(warp_transformed_frag_A1[(warp_mma_k + 1) % 2], + warp_transformed_frag_B1[(warp_mma_k + 1) % 2], + warp_loaded_frag_A1[(warp_mma_k + 1) % 2], + warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + + if (warp_mma_k + 2 == Base::kWarpGemmIterations1) { + // Inserts a fence to group cp.async instructions into stages. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages of cp.async have committed + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_B1.advance(); + + this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_B1_.add_tile_offset( + {-Base::kStages * Policy1::kPartitionsK * + Base::kWarpGemmIterations1, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + } + } + + } + + // Insert fence and wait for all outstanding cp.async operations to commit. + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_multistage_smem_accumulator.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_multistage_smem_accumulator.h new file mode 100644 index 0000000000000000000000000000000000000000..64181870b5a76afd7112d74f5f0dc1913ada6323 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_multistage_smem_accumulator.h @@ -0,0 +1,816 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a multistage threadblock-scoped Implicit GEMM Convolution kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/cache_operation.h" +#include "cutlass/gemm/threadblock/mma_base.h" +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" + +#include "threadblock/b2b_mma_base_smem_accumulator.h" +#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape0_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA0_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA0_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA0, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB0_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB0_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB0, + /// Iterates over vectors of scale and bias vector in global memory + // (concept: VectorIterator) + typename IteratorAccumulatorScaleBias_, + /// Iterates over accumulator tile + typename FragmentIteratorAccumulator_, + /// Iterates over accumulator tile in shared memory + typename SmemIteratorD0_, + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape1_, + /// Iterates over the intermediate accumulator tile + // (concept::MmaTensorOpFragmentIterator) + typename WarpIteratorA1_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB1_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB1_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB1, + /// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...) + typename OutputOp_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy0_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy1_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class B2bImplicitGemmMultistageSmemAccumulator : + public gemm::threadblock::B2bMmaBaseSmemAccumulator { +public: + ///< Base class + using Base = gemm::threadblock::B2bMmaBaseSmemAccumulator; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape0 = Shape0_; + ///< Iterates over tiles of A operand in global memory + using IteratorA0 = IteratorA0_; + ///< Iterates over tiles of B operand in global memory + using IteratorB0 = IteratorB0_; + ///< Iterates over tiles of the scale and bias vectors in global memory + using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; + ///< Policy describing tuning details + using Policy0 = Policy0_; + + using SmemIteratorA0 = SmemIteratorA0_; + using SmemIteratorB0 = SmemIteratorB0_; + using SmemIteratorD0 = SmemIteratorD0_; ///< Iterates over accumulator tile in shared memory + + using FragmentIteratorAccumulator = FragmentIteratorAccumulator_; ///< Iterates over accumulator tile + + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape1 = Shape1_; + ///< Iterates over tiles of B operand in global memory + using IteratorB1 = IteratorB1_; + ///< Policy describing tuning details + using Policy1 = Policy1_; + + using SmemIteratorB1 = SmemIteratorB1_; + using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate accumulator tile in shared memory + + ///< Epilogue after 1st Gemm + using OutputOp = OutputOp_; + + static const bool PerChannelScale = (OutputOp::kScale == + epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling); + + static cutlass::arch::CacheOperation::Kind const kCacheOpA0 = CacheOpA0; + static cutlass::arch::CacheOperation::Kind const kCacheOpB0 = CacheOpB0; + static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1; + + // + // Dependent types + // + + using ElementC = typename Policy0::Operator::ElementC; + + /// Fragment of accumulator tile + using FragmentC0 = typename Policy0::Operator::FragmentC; + + /// Warp-level Mma + using Operator0 = typename Policy0::Operator; + + /// Fragment of Scale and Bias loaded from global memory + using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment; + + /// Fragment of accumulator tile + using FragmentC1 = typename Policy1::Operator::FragmentC; + + /// Warp-level Mma + using Operator1 = typename Policy1::Operator; + + /// Epilog in shared memory + using Epilogue0 = epilogue::threadblock::EpilogueSmemAccumulator< + SmemIteratorD0, ///< SmemTileIterator + FragmentIteratorAccumulator, ///< AccumulatorFragmentIterator + IteratorAccumulatorScaleBias, ///< ScaleBiasIterator + OutputOp>; ///< Output operator + + /// Internal structure exposed for introspection. + struct Detail { + + static_assert(Base::kWarpGemmIterations0 > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + static_assert(Base::kWarpGemmIterations1 > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA0 = + IteratorA0::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB0 = + IteratorB0::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB1 = + IteratorB1::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA0 = + (AsyncCopyIterationsPerStageA0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB0 = + (AsyncCopyIterationsPerStageB0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB1 = + (AsyncCopyIterationsPerStageB1 + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1; + }; + + private: + + using WarpLoadedFragmentA0 = typename Operator0::FragmentA; + using WarpLoadedFragmentB0 = typename Operator0::FragmentB; + using WarpLoadedFragmentA1 = typename Operator1::FragmentA; + using WarpLoadedFragmentB1 = typename Operator1::FragmentB; + using WarpTransformedFragmentA0 = typename Operator0::TransformedFragmentA; + using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB; + using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA; + using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB; + + private: + + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA0 smem_iterator_A0_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB0 smem_iterator_B0_; + + /// Shared Memory Iterator to store accumulator tile + SmemIteratorD0 smem_iterator_D0_; + + /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile + WarpIteratorA1 warp_tile_iterator_A1_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB1 smem_iterator_B1_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + B2bImplicitGemmMultistageSmemAccumulator( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::B2bMmaSharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_A_ref(), thread_idx), + smem_iterator_B0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_B_ref(), thread_idx), + smem_iterator_D0_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx), + warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx), + smem_iterator_B1_(shared_storage.b2b_mma_shared_storage.shared_storage1.operand_B_ref(), thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn_0 = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN); + int warp_idx_k_0 = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN); + + int warp_idx_m_0 = warp_idx_mn_0 % Base::WarpCount0::kM; + int warp_idx_n_0 = warp_idx_mn_0 / Base::WarpCount0::kM; + + int warp_idx_mn_1 = warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); + + int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; + int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A0_.add_tile_offset( + {warp_idx_m_0, Base::kWarpGemmIterations0 * warp_idx_k_0}); + this->warp_tile_iterator_B0_.add_tile_offset( + {Base::kWarpGemmIterations0 * warp_idx_k_0, warp_idx_n_0}); + warp_tile_iterator_A1_.add_tile_offset( + {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); + this->warp_tile_iterator_B1_.add_tile_offset( + {Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1}); + + // Add smem accumulator iterator warp offset + smem_iterator_D0_.add_tile_offset({ warp_idx_m_0 * SmemIteratorD0::TileIterations::kRow, + warp_idx_n_0 * SmemIteratorD0::TileIterations::kColumn}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance_0( + IteratorA0 &iterator_A0, IteratorB0 &iterator_B0, + int group_start_A0 = 0, int group_start_B0 = 0) { + + iterator_A0.set_iteration_index(group_start_A0); + this->smem_iterator_A0_.set_iteration_index(group_start_A0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA0; ++j) { + + if (group_start_A0 + j < Detail::AsyncCopyIterationsPerStageA0) { + typename IteratorA0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A0_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA0::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr, iterator_A0.get(), iterator_A0.valid()); + + ++iterator_A0; + + ++this->smem_iterator_A0_; + } + } + + iterator_B0.set_iteration_index(group_start_B0); + + this->smem_iterator_B0_.set_iteration_index(group_start_B0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB0; ++j) { + if (group_start_B0 + j < Detail::AsyncCopyIterationsPerStageB0) { + typename IteratorB0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B0_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB0::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr, iterator_B0.get(), iterator_B0.valid()); + + ++iterator_B0; + ++this->smem_iterator_B0_; + } + } + } + + CUTLASS_DEVICE + void copy_tiles_and_advance_1( + IteratorB1 &iterator_B1, + int group_start_B1 = 0) { + + iterator_B1.set_iteration_index(group_start_B1); + + this->smem_iterator_B1_.set_iteration_index(group_start_B1); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) { + if (group_start_B1 + j < Detail::AsyncCopyIterationsPerStageB1) { + typename IteratorB1::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B1_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr, iterator_B1.get(), iterator_B1.valid()); + + ++iterator_B1; + ++this->smem_iterator_B1_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations_0, + ///< destination accumulator tile + FragmentC1 &accum, + ///< iterator over A0 operand in global memory + IteratorA0 iterator_A0, + ///< iterator over B0 operand in global memory + IteratorB0 iterator_B0, + ///< iterator over A1 operand scale vector in global memory + IteratorAccumulatorScaleBias iterator_accum0_scale, + ///< iterator over A1 operand bias vector in global memory + IteratorAccumulatorScaleBias iterator_accum0_bias, + ///< iterator over B1 operand in global memory + IteratorB1 iterator_B1, + ///< initial value of accumulator + FragmentC0 const &src_accum, + ///< epilogue operation after 1st Gemm + OutputOp output_op_0, + ///< Imaginary strides used for planar-complex only - ignored here + int64_t imag_stride_A = 0, + int64_t imag_stride_B = 0) { + + // + // Prologue + // + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations_0) { + + iterator_A0.set_iteration_index(0); + this->smem_iterator_A0_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA0; ++j) { + typename IteratorA0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A0_.get()); + + int const kSrcBytes = + sizeof_bits::value * + IteratorA0::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr, iterator_A0.get(), iterator_A0.valid()); + + ++iterator_A0; + ++this->smem_iterator_A0_; + } + + iterator_B0.set_iteration_index(0); + this->smem_iterator_B0_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB0; ++j) { + typename IteratorB0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B0_.get()); + + int const kSrcBytes = + sizeof_bits::value * + IteratorB0::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr, iterator_B0.get(), iterator_B0.valid()); + + ++iterator_B0; + ++this->smem_iterator_B0_; + } + + // Move to the next stage + iterator_A0.advance(); + iterator_B0.advance(); + + this->smem_iterator_A0_.add_tile_offset({0, 1}); + this->smem_iterator_B0_.add_tile_offset({1, 0}); + + // Inserts a fence to group cp.async instructions into stages. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + FragmentC0 accum0 = src_accum; + + // Waits until kStages-2 stages have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA0 warp_loaded_frag_A0[2]; + WarpLoadedFragmentB0 warp_loaded_frag_B0[2]; + WarpTransformedFragmentA0 warp_transformed_frag_A0[2]; + WarpTransformedFragmentB0 warp_transformed_frag_B0[2]; + + Operator0 warp_mma0; + + this->warp_tile_iterator_A0_.set_kgroup_index(0); + this->warp_tile_iterator_B0_.set_kgroup_index(0); + + this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[0]); + this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[0]); + + ++this->warp_tile_iterator_A0_; + ++this->warp_tile_iterator_B0_; + + // Start issuing the first group of the next stage outside of the mainloop + copy_tiles_and_advance_0(iterator_A0, iterator_B0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma0.transform(warp_transformed_frag_A0[0], warp_transformed_frag_B0[0], + warp_loaded_frag_A0[0], warp_loaded_frag_B0[0]); + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations_0 > (-Base::kStages + 1);) { + + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; + ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); + this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); + + this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A0_; + ++this->warp_tile_iterator_B0_; + + if (warp_mma_k > 0) + warp_mma0.transform(warp_transformed_frag_A0[warp_mma_k % 2], + warp_transformed_frag_B0[warp_mma_k % 2], + warp_loaded_frag_A0[warp_mma_k % 2], + warp_loaded_frag_B0[warp_mma_k % 2]); + + // Issue global->shared copies for the next stage + int group_start_iteration_A0, group_start_iteration_B0; + + if (warp_mma_k + 1 == Base::kWarpGemmIterations0) { + group_start_iteration_A0 = 0; + group_start_iteration_B0 = 0; + } else { + group_start_iteration_A0 = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA0; + group_start_iteration_B0 = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB0; + } + + copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0, + group_start_iteration_B0); + + warp_mma0( + accum0, + warp_transformed_frag_A0[warp_mma_k % 2], + warp_transformed_frag_B0[warp_mma_k % 2], + accum0 + ); + + if (warp_mma_k + 1 == Base::kWarpGemmIterations0) + warp_mma0.transform(warp_transformed_frag_A0[(warp_mma_k + 1) % 2], + warp_transformed_frag_B0[(warp_mma_k + 1) % 2], + warp_loaded_frag_A0[(warp_mma_k + 1) % 2], + warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); + + if (warp_mma_k + 2 == Base::kWarpGemmIterations0) { + // Inserts a fence to group cp.async instructions into stages. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages of cp.async have committed + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A0.advance(); + iterator_B0.advance(); + + this->smem_iterator_A0_.add_tile_offset({0, 1}); + this->smem_iterator_B0_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A0_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A0_.add_tile_offset( + {0, -Base::kStages * Policy0::kPartitionsK * + Base::kWarpGemmIterations0}); + this->warp_tile_iterator_B0_.add_tile_offset( + {-Base::kStages * Policy0::kPartitionsK * + Base::kWarpGemmIterations0, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations_0; + } + } + + } + + // Insert fence and wait for all outstanding cp.async operations to commit. + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + + /// Epilogue for the first Implicit Gemm + Epilogue0 epilogue0; + + epilogue0(output_op_0, smem_iterator_D0_, accum0, iterator_accum0_scale, iterator_accum0_bias); + + __syncthreads(); + + // 2nd Implicit Gemm + + // + // Prologue + // + int gemm_k_iterations_1 = Shape0::kN / Shape1::kK; + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations_1) { + + iterator_B1.set_iteration_index(0); + this->smem_iterator_B1_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB1; ++j) { + typename IteratorB1::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B1_.get()); + + int const kSrcBytes = + sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr, iterator_B1.get(), iterator_B1.valid()); + + ++iterator_B1; + ++this->smem_iterator_B1_; + } + + // Move to the next stage + iterator_B1.advance(); + + this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Inserts a fence to group cp.async instructions into stages. + cutlass::arch::cp_async_fence(); + } + + // Waits until kStages-2 stages have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA1 warp_loaded_frag_A1[2]; + WarpLoadedFragmentB1 warp_loaded_frag_B1[2]; + WarpTransformedFragmentA1 warp_transformed_frag_A1[2]; + WarpTransformedFragmentB1 warp_transformed_frag_B1[2]; + + Operator1 warp_mma1; + + warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0]); + ++warp_tile_iterator_A1_; + + this->warp_tile_iterator_B1_.set_kgroup_index(0); + this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]); + ++this->warp_tile_iterator_B1_; + + // Start issuing the first group of the next stage outside of the mainloop + copy_tiles_and_advance_1(iterator_B1); + + smem_write_stage_idx = Base::kStages - 1; + smem_read_stage_idx = 0; + + warp_mma1.transform(warp_transformed_frag_A1[0], warp_transformed_frag_B1[0], + warp_loaded_frag_A1[0], warp_loaded_frag_B1[0]); + + + // + // Mainloop + // + + CUTLASS_PRAGMA_UNROLL + for ( gemm_k_iterations_1 = Shape0::kN / Shape1::kK - (Base::kStages - 1); + gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; + ++warp_mma_k) { + + // Load warp-level tile from accumulator fragment + // skip warp tile loading for the last kgroup + if(gemm_k_iterations_1 > (-Base::kStages + 2) || warp_mma_k < Base::kWarpGemmIterations1 - 1) { + warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2]); + } + ++warp_tile_iterator_A1_; + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); + this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_B1_; + + + if (warp_mma_k > 0) + warp_mma1.transform(warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + warp_loaded_frag_A1[warp_mma_k % 2], + warp_loaded_frag_B1[warp_mma_k % 2]); + + // Issue global->shared copies for the next stage + int group_start_iteration_B1; + + if (warp_mma_k + 1 == Base::kWarpGemmIterations1) { + group_start_iteration_B1 = 0; + } else { + group_start_iteration_B1 = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB1; + } + + copy_tiles_and_advance_1(iterator_B1, + group_start_iteration_B1); + + warp_mma1( + accum, + warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + accum + ); + + if (warp_mma_k + 1 == Base::kWarpGemmIterations1) + warp_mma1.transform(warp_transformed_frag_A1[(warp_mma_k + 1) % 2], + warp_transformed_frag_B1[(warp_mma_k + 1) % 2], + warp_loaded_frag_A1[(warp_mma_k + 1) % 2], + warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + + if (warp_mma_k + 2 == Base::kWarpGemmIterations1) { + // Inserts a fence to group cp.async instructions into stages. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages of cp.async have committed + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_B1.advance(); + + this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_B1_.add_tile_offset( + {-Base::kStages * Policy1::kPartitionsK * + Base::kWarpGemmIterations1, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + } + } + + } + + // Insert fence and wait for all outstanding cp.async operations to commit. + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_pipelined.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_pipelined.h new file mode 100644 index 0000000000000000000000000000000000000000..97466a1ce72e632f6436b49a3f0c578b03f543da --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_pipelined.h @@ -0,0 +1,553 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" + +#include "threadblock/b2b_mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape0_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorA0_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA0_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB0_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB0_, + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape1_, + /// Iterates over the intermediate accumulator tile + // (concept::MmaTensorOpFragmentIterator) + typename FragmentIteratorA1_, + /// Iterates over vectors of scale and bias vector in global memory + // (concept: VectorIterator) + typename IteratorAccumulatorScaleBias_, + /// FragmentIterator to load Scale or Bias vector from threadblock fragment + typename FragmentIteratorA1ScaleBias_, + // (concept: VectorFragmentIterator) + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB1_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB1_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...) + typename OutputOp_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy0_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy1_, + /// Transformation applied to A operand + typename TransformA0_ = NumericArrayConverter< + typename SmemIteratorA0_::Element, + typename IteratorA0_::Element, + IteratorA0_::Fragment::kElements>, + /// + /// Transformation applied to B operand + typename TransformB0_ = NumericArrayConverter< + typename SmemIteratorB0_::Element, + typename IteratorB0_::Element, + IteratorB0_::Fragment::kElements>, + /// + /// Transformation applied to B operand + typename TransformB1_ = NumericArrayConverter< + typename SmemIteratorB1_::Element, + typename IteratorB1_::Element, + IteratorB1_::Fragment::kElements>, + /// Used for partial specialization + typename Enable = bool +> +class B2bImplicitGemmPipelined : + public gemm::threadblock::B2bMmaBase { +public: + + ///< Base class + using Base = gemm::threadblock::B2bMmaBase; + + using Shape0 = Shape0_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA0 = IteratorA0_; ///< Iterates over tiles of A operand in global memory + using IteratorB0 = IteratorB0_; ///< Iterates over tiles of B operand in global memory + using Policy0 = Policy0_; ///< Policy0 describing tuning details + + using SmemIteratorA0 = SmemIteratorA0_; + using SmemIteratorB0 = SmemIteratorB0_; + + using Shape1 = Shape1_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using FragmentIteratorA1 = FragmentIteratorA1_; ///< Iterates over tiles of A1 operand from accumulator tile + using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; ///< Iterates over tiles of the scale and bias vectors in global memory + using FragmentIteratorA1ScaleBias = + FragmentIteratorA1ScaleBias_; ///< WarpIterator to load Scale or Bias vector from the threadblock fragment + using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory + using Policy1 = Policy1_; ///< Policy1 describing tuning details + + using SmemIteratorB1 = SmemIteratorB1_; + + + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + + using OutputOp = OutputOp_; ///< Epilogue after 1st Gemm + + static const bool PerChannelScale = (OutputOp::kScale == + epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling); + + using TransformA0 = TransformA0_; + using TransformB0 = TransformB0_; + using TransformB1 = TransformB1_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA0 = typename IteratorA0::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB0 = typename IteratorB0::Fragment; + + /// Fragment of accumulator tile + using FragmentC0 = typename Policy0::Operator::FragmentC; + + /// Warp-level Mma + using Operator0 = typename Policy0::Operator; + + /// Fragment of Scale and Bias loaded from global memory + using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB1 = typename IteratorB1::Fragment; + + /// Fragment of accumulator tile + using FragmentC1 = typename Policy1::Operator::FragmentC; + + /// Warp-level Mma + using Operator1 = typename Policy1::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy0::Operator::ArchTag; + + /// Complex transform on A0 operand + static ComplexTransform const kTransformA0 = Operator0::kTransformA; + + /// Complex transform on B0 operand + static ComplexTransform const kTransformB0 = Operator0::kTransformB; + + /// Complex transform on B1 operand + static ComplexTransform const kTransformB1 = Operator1::kTransformB; + + // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) + static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); + +private: + + using WarpFragmentA0 = typename Operator0::FragmentA; + using WarpFragmentB0 = typename Operator0::FragmentB; + /// Warp Fragment of operand A1 loaded from accmulator tile + using WarpFragmentA1 = typename FragmentIteratorA1::Fragment; + /// Warp Fragment of operand A1 scale and bias loaded from threadblock fragment + using WarpFragmentA1ScaleBias = + typename FragmentIteratorA1ScaleBias::Fragment; + using WarpFragmentB1 = typename Operator1::FragmentB; + +protected: + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA0 smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B0 operand to shared memory + SmemIteratorB0 smem_iterator_B0_; + + /// Iterator to write threadblock-scoped tile of B1 operand to shared memory + SmemIteratorB1 smem_iterator_B1_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + B2bImplicitGemmPipelined( + typename Base::B2bMmaSharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.shared_storage0.operand_A_ref(), thread_idx), + smem_iterator_B0_(shared_storage.shared_storage0.operand_B_ref(), thread_idx), + smem_iterator_B1_(shared_storage.shared_storage1.operand_B_ref(), thread_idx) { + + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN); + int warp_idx_k = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount0::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount0::kM; + + //These may change across different GEMM layers + int tile_offset_k_0 = Base::kWarpGemmIterations0 * warp_idx_k; + int tile_offset_k_1 = Base::kWarpGemmIterations1 * warp_idx_k; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A0_.add_tile_offset({warp_idx_m, tile_offset_k_0}); + this->warp_tile_iterator_B0_.add_tile_offset({tile_offset_k_0, warp_idx_n}); + this->warp_tile_iterator_B1_.add_tile_offset({tile_offset_k_1, warp_idx_n}); + + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations_0, ///< number of iterations of the mainloop + FragmentC1 &accum, ///< destination accumulator tile + IteratorA0 iterator_A, ///< iterator over A operand in global memory + IteratorB0 iterator_B0, ///< iterator over B0 operand in global memory + IteratorAccumulatorScaleBias iterator_A1_scale, ///< iterator over A1 operand scale vectors in global memory + IteratorAccumulatorScaleBias iterator_A1_bias, ///< iterator over A1 operand bias vectors in global memory + IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory + FragmentC0 const &src_accum, ///< source accumulator tile + OutputOp output_op_0, ///< epilogue operation after 1st Gemm + TransformA0 transform_A0 = TransformA0(), ///< transformation applied to A0 fragment + TransformB0 transform_B0 = TransformB0(), ///< transformation applied to B0 fragment + TransformB1 transform_B1 = TransformB1()) { ///< transformation applied to B1 fragment + + // + // Prologue + // + + // Perform accumulation in the 'd' output operand + FragmentC0 accum0 = src_accum; + + FragmentA0 tb_frag_A; + FragmentB0 tb_frag_B0; + + tb_frag_A.clear(); + tb_frag_B0.clear(); + + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B0.load(tb_frag_B0); + + ++iterator_A; + ++iterator_B0; + + this->smem_iterator_A_.store(transform_A0(tb_frag_A)); + this->smem_iterator_B0_.store(transform_B0(tb_frag_B0)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B0_; + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA0 warp_frag_A0[2]; + WarpFragmentB0 warp_frag_B0[2]; + + this->warp_tile_iterator_A0_.set_kgroup_index(0); + this->warp_tile_iterator_B0_.set_kgroup_index(0); + + this->warp_tile_iterator_A0_.load(warp_frag_A0[0]); + this->warp_tile_iterator_B0_.load(warp_frag_B0[0]); + + ++this->warp_tile_iterator_A0_; + ++this->warp_tile_iterator_B0_; + + Operator0 warp_mma0; + + int smem_write_stage_idx = 1; + + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing + // shared memory loads (which have the tightest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations0 - 1) { + + // Write fragments to shared memory + this->smem_iterator_A_.store(transform_A0(tb_frag_A)); + + this->smem_iterator_B0_.store(transform_B0(tb_frag_B0)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B0_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); + } + else { + this->warp_tile_iterator_A0_.add_tile_offset( + {0, -Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0}); + this->warp_tile_iterator_B0_.add_tile_offset( + {-Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0, + 0}); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); + this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); + + this->warp_tile_iterator_A0_.load(warp_frag_A0[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B0_.load(warp_frag_B0[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A0_; + ++this->warp_tile_iterator_B0_; + + if (warp_mma_k == 0) { + + iterator_A.load(tb_frag_A); + iterator_B0.load(tb_frag_B0); + + ++iterator_A; + ++iterator_B0; + } + + warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2], + warp_frag_B0[warp_mma_k % 2], accum0); + + } + } + + + //2nd Implicit Gemm + + /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile + FragmentIteratorA1 warp_tile_iterator_A1_(accum0); + + + + // + // Prologue + // + + FragmentA1ScaleBias tb_frag_A1_scale; + FragmentA1ScaleBias tb_frag_A1_bias; + FragmentIteratorA1ScaleBias warp_tile_iterator_A1_scale_(tb_frag_A1_scale); + FragmentIteratorA1ScaleBias warp_tile_iterator_A1_bias_(tb_frag_A1_bias); + FragmentB1 tb_frag_B1; + + if(PerChannelScale) + tb_frag_A1_scale.clear(); + tb_frag_A1_bias.clear(); + tb_frag_B1.clear(); + + // The last kblock is loaded in the prolog + if(PerChannelScale) + iterator_A1_scale.load(tb_frag_A1_scale); + iterator_A1_bias.load(tb_frag_A1_bias); + iterator_B1.load(tb_frag_B1); + + + if(PerChannelScale) + ++iterator_A1_scale; + ++iterator_A1_bias; + ++iterator_B1; + + this->smem_iterator_B1_.store(transform_B1(tb_frag_B1)); + + ++this->smem_iterator_B1_; + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA1ScaleBias warp_frag_A1_scale[2]; + WarpFragmentA1ScaleBias warp_frag_A1_bias[2]; + WarpFragmentA1 warp_frag_A1[2]; + WarpFragmentB1 warp_frag_B1[2]; + + this->warp_tile_iterator_B1_.set_kgroup_index(0); + + if(PerChannelScale) + warp_tile_iterator_A1_scale_.load(warp_frag_A1_scale[0]); + warp_tile_iterator_A1_bias_.load(warp_frag_A1_bias[0]); + warp_tile_iterator_A1_.load(warp_frag_A1[0], warp_frag_A1_scale[0], + warp_frag_A1_bias[0], output_op_0); + this->warp_tile_iterator_B1_.load(warp_frag_B1[0]); + + ++warp_tile_iterator_A1_; + if(PerChannelScale) + ++warp_tile_iterator_A1_scale_; + ++warp_tile_iterator_A1_bias_; + ++this->warp_tile_iterator_B1_; + + Operator1 warp_mma1; + + smem_write_stage_idx = 1; + + int gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1; + + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing + // shared memory loads (which have the tightest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_PRAGMA_UNROLL + for (; gemm_k_iterations_1 > 0; --gemm_k_iterations_1) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations1 - 1) { + + this->smem_iterator_B1_.store(transform_B1(tb_frag_B1)); + + __syncthreads(); + + ++this->smem_iterator_B1_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); + } + else { + this->warp_tile_iterator_B1_.add_tile_offset( + {-Base::kStages * Policy1::kPartitionsK * Base::kWarpGemmIterations1, + 0}); + } + + smem_write_stage_idx ^= 1; + + if(PerChannelScale) { + tb_frag_A1_scale.clear(); + iterator_A1_scale.load(tb_frag_A1_scale); + ++iterator_A1_scale; + } + tb_frag_A1_bias.clear(); + iterator_A1_bias.load(tb_frag_A1_bias); + ++iterator_A1_bias; + } + + this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); + + if(PerChannelScale) + warp_tile_iterator_A1_scale_.load(warp_frag_A1_scale[(warp_mma_k + 1) % 2]); + warp_tile_iterator_A1_bias_.load(warp_frag_A1_bias[(warp_mma_k + 1) % 2]); + warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2], + warp_frag_A1_scale[(warp_mma_k + 1) % 2], + warp_frag_A1_bias[(warp_mma_k + 1) % 2], + output_op_0); + this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]); + + if(PerChannelScale) + ++warp_tile_iterator_A1_scale_; + ++warp_tile_iterator_A1_bias_; + ++warp_tile_iterator_A1_; + ++this->warp_tile_iterator_B1_; + + if (warp_mma_k == 0) { + + iterator_B1.load(tb_frag_B1); + + ++iterator_B1; + } + + warp_mma1(accum, warp_frag_A1[warp_mma_k % 2], + warp_frag_B1[warp_mma_k % 2], accum); + + } + } + + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_pipelined_smem_accumulator.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_pipelined_smem_accumulator.h new file mode 100644 index 0000000000000000000000000000000000000000..e5d91b127ca45d99a57f1cc656d59e491335dcbe --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_pipelined_smem_accumulator.h @@ -0,0 +1,535 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" + +#include "threadblock/b2b_mma_base_smem_accumulator.h" +#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape0_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorA0_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA0_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB0_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB0_, + /// Iterates over vectors of scale and bias vector in global memory + // (concept: VectorIterator) + typename IteratorAccumulatorScaleBias_, + /// Iterates over accumulator tile + typename FragmentIteratorAccumulator_, + /// Iterates over accumulator tile in shared memory + typename SmemIteratorD0_, + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape1_, + /// Iterates over the intermediate accumulator tile in shared memory + typename WarpIteratorA1_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB1_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB1_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...) + typename OutputOp_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy0_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy1_, + /// Transformation applied to A0 operand + typename TransformA0_ = NumericArrayConverter< + typename SmemIteratorA0_::Element, + typename IteratorA0_::Element, + IteratorA0_::Fragment::kElements>, + /// + /// Transformation applied to B0 operand + typename TransformB0_ = NumericArrayConverter< + typename SmemIteratorB0_::Element, + typename IteratorB0_::Element, + IteratorB0_::Fragment::kElements>, + /// + /// Transformation applied to B1 operand + typename TransformB1_ = NumericArrayConverter< + typename SmemIteratorB1_::Element, + typename IteratorB1_::Element, + IteratorB1_::Fragment::kElements>, + /// Used for partial specialization + typename Enable = bool +> +class B2bImplicitGemmPipelinedSmemAccumulator : + public gemm::threadblock::B2bMmaBaseSmemAccumulator { +public: + + ///< Base class + using Base = gemm::threadblock::B2bMmaBaseSmemAccumulator; + + using Shape0 = Shape0_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA0 = IteratorA0_; ///< Iterates over tiles of A operand in global memory + using IteratorB0 = IteratorB0_; ///< Iterates over tiles of B operand in global memory + using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; ///< Iterates over tiles of the scale and bias vectors in global memory + using Policy0 = Policy0_; ///< Policy0 describing tuning details + + using SmemIteratorA0 = SmemIteratorA0_; + using SmemIteratorB0 = SmemIteratorB0_; + using SmemIteratorD0 = SmemIteratorD0_; ///< Iterates over accumulator tile in shared memory + + using FragmentIteratorAccumulator = FragmentIteratorAccumulator_; ///< Iterates over accumulator tile + + using Shape1 = Shape1_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory + using Policy1 = Policy1_; ///< Policy1 describing tuning details + + using SmemIteratorB1 = SmemIteratorB1_; + using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate accumulator tile in shared memory + + + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + + using OutputOp = OutputOp_; ///< Epilogue after 1st Gemm + + using TransformA0 = TransformA0_; + using TransformB0 = TransformB0_; + using TransformB1 = TransformB1_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA0 = typename IteratorA0::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB0 = typename IteratorB0::Fragment; + + /// Fragment of accumulator tile + using FragmentC0 = typename Policy0::Operator::FragmentC; + + /// Warp-level Mma + using Operator0 = typename Policy0::Operator; + + /// Fragment of operand B loaded from global memory + using FragmentB1 = typename IteratorB1::Fragment; + + /// Fragment of accumulator tile + using FragmentC1 = typename Policy1::Operator::FragmentC; + + /// Warp-level Mma + using Operator1 = typename Policy1::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy0::Operator::ArchTag; + + /// Complex transform on A0 operand + static ComplexTransform const kTransformA0 = Operator0::kTransformA; + + /// Complex transform on B0 operand + static ComplexTransform const kTransformB0 = Operator0::kTransformB; + + /// Complex transform on B1 operand + static ComplexTransform const kTransformB1 = Operator1::kTransformB; + + /// staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) + static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); + + /// Epilog in shared memory + using Epilogue0 = epilogue::threadblock::EpilogueSmemAccumulator< + SmemIteratorD0, ///< SmemTileIterator + FragmentIteratorAccumulator, ///< AccumulatorFragmentIterator + IteratorAccumulatorScaleBias, ///< ScaleBiasIterator + OutputOp>; ///< Output operator + + + +private: + + using WarpFragmentA0 = typename Operator0::FragmentA; + using WarpFragmentB0 = typename Operator0::FragmentB; + using WarpFragmentA1 = typename Operator1::FragmentA; + using WarpFragmentB1 = typename Operator1::FragmentB; + +protected: + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA0 smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B0 operand to shared memory + SmemIteratorB0 smem_iterator_B0_; + + /// Shared Memory Iterator to store accumulator tile + SmemIteratorD0 smem_iterator_D0_; + + /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile + WarpIteratorA1 warp_tile_iterator_A1_; + + /// Iterator to write threadblock-scoped tile of B1 operand to shared memory + SmemIteratorB1 smem_iterator_B1_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + B2bImplicitGemmPipelinedSmemAccumulator( + typename Base::B2bMmaSharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_A_ref(), thread_idx), + smem_iterator_B0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_B_ref(), thread_idx), + smem_iterator_D0_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx), + warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx), + smem_iterator_B1_(shared_storage.b2b_mma_shared_storage.shared_storage1.operand_B_ref(), thread_idx) { + + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn_0 = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN); + int warp_idx_k_0 = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN); + + int warp_idx_m_0 = warp_idx_mn_0 % Base::WarpCount0::kM; + int warp_idx_n_0 = warp_idx_mn_0 / Base::WarpCount0::kM; + + int tile_offset_k_0 = Base::kWarpGemmIterations0 * warp_idx_k_0; + + int warp_idx_mn_1 = warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); + + int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; + int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; + + int tile_offset_k_1 = Base::kWarpGemmIterations1 * warp_idx_k_1; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A0_.add_tile_offset({warp_idx_m_0, tile_offset_k_0}); + this->warp_tile_iterator_B0_.add_tile_offset({tile_offset_k_0, warp_idx_n_0}); + warp_tile_iterator_A1_.add_tile_offset({warp_idx_m_1, tile_offset_k_1}); + this->warp_tile_iterator_B1_.add_tile_offset({tile_offset_k_1, warp_idx_n_1}); + + // Add smem accumulator iterator warp offset + smem_iterator_D0_.add_tile_offset({ warp_idx_m_0 * SmemIteratorD0::TileIterations::kRow, + warp_idx_n_0 * SmemIteratorD0::TileIterations::kColumn}); + + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations_0, ///< number of iterations of the mainloop + FragmentC1 &accum, ///< destination accumulator tile + IteratorA0 iterator_A, ///< iterator over A operand in global memory + IteratorB0 iterator_B0, ///< iterator over B0 operand in global memory + IteratorAccumulatorScaleBias iterator_accum0_scale, ///< iterator over D0 scale vector in global memory + IteratorAccumulatorScaleBias iterator_accum0_bias, ///< iterator over D0 bias vector in global memory + IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory + FragmentC0 const &src_accum, ///< source accumulator tile + OutputOp output_op_0, ///< epilogue operation after 1st Gemm + TransformA0 transform_A0 = TransformA0(), ///< transformation applied to A0 fragment + TransformB0 transform_B0 = TransformB0(), ///< transformation applied to B0 fragment + TransformB1 transform_B1 = TransformB1()) { ///< transformation applied to B1 fragment + + // + // Prologue + // + + // Perform accumulation in the 'd' output operand + FragmentC0 accum0 = src_accum; + + FragmentA0 tb_frag_A; + FragmentB0 tb_frag_B0; + + tb_frag_A.clear(); + tb_frag_B0.clear(); + + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B0.load(tb_frag_B0); + + ++iterator_A; + ++iterator_B0; + + this->smem_iterator_A_.store(transform_A0(tb_frag_A)); + this->smem_iterator_B0_.store(transform_B0(tb_frag_B0)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B0_; + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA0 warp_frag_A0[2]; + WarpFragmentB0 warp_frag_B0[2]; + + this->warp_tile_iterator_A0_.set_kgroup_index(0); + this->warp_tile_iterator_B0_.set_kgroup_index(0); + + this->warp_tile_iterator_A0_.load(warp_frag_A0[0]); + this->warp_tile_iterator_B0_.load(warp_frag_B0[0]); + + ++this->warp_tile_iterator_A0_; + ++this->warp_tile_iterator_B0_; + + Operator0 warp_mma0; + + int smem_write_stage_idx = 1; + + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing + // shared memory loads (which have the tightest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations0 - 1) { + + // Write fragments to shared memory + this->smem_iterator_A_.store(transform_A0(tb_frag_A)); + + this->smem_iterator_B0_.store(transform_B0(tb_frag_B0)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B0_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); + } + else { + this->warp_tile_iterator_A0_.add_tile_offset( + {0, -Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0}); + this->warp_tile_iterator_B0_.add_tile_offset( + {-Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0, + 0}); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); + this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); + + this->warp_tile_iterator_A0_.load(warp_frag_A0[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B0_.load(warp_frag_B0[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A0_; + ++this->warp_tile_iterator_B0_; + + if (warp_mma_k == 0) { + + iterator_A.load(tb_frag_A); + iterator_B0.load(tb_frag_B0); + ++iterator_A; + ++iterator_B0; + } + + warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2], + warp_frag_B0[warp_mma_k % 2], accum0); + + } + } + + /// Epilogue for the first Implicit Gemm + Epilogue0 epilogue0; + + epilogue0(output_op_0, smem_iterator_D0_, accum0, iterator_accum0_scale, iterator_accum0_bias); + + __syncthreads(); + + /// 2nd Implicit Gemm + + + // + // Prologue + // + + FragmentB1 tb_frag_B1; + + tb_frag_B1.clear(); + + // The last kblock is loaded in the prolog + iterator_B1.load(tb_frag_B1); + + ++iterator_B1; + + this->smem_iterator_B1_.store(transform_B1(tb_frag_B1)); + + ++this->smem_iterator_B1_; + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA1 warp_frag_A1[2]; + WarpFragmentB1 warp_frag_B1[2]; + + this->warp_tile_iterator_B1_.set_kgroup_index(0); + + warp_tile_iterator_A1_.load(warp_frag_A1[0]); + this->warp_tile_iterator_B1_.load(warp_frag_B1[0]); + + ++warp_tile_iterator_A1_; + ++this->warp_tile_iterator_B1_; + + Operator1 warp_mma1; + + smem_write_stage_idx = 1; + + int gemm_k_iterations_1 = Shape0::kN / Shape1::kK; + + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_PRAGMA_UNROLL + for (; gemm_k_iterations_1 > 0; --gemm_k_iterations_1) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations1 - 1) { + + // Write fragments to shared memory + this->smem_iterator_B1_.store(transform_B1(tb_frag_B1)); + + __syncthreads(); + + ++this->smem_iterator_B1_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); + } + else { + this->warp_tile_iterator_B1_.add_tile_offset( + {-Base::kStages * Policy1::kPartitionsK * + Base::kWarpGemmIterations1, + 0}); + } + + smem_write_stage_idx ^= 1; + + } + + this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); + + // skip warp tile loading for the last kgroup + if(gemm_k_iterations_1 > 1 || warp_mma_k < Base::kWarpGemmIterations1 - 1) + warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]); + + ++warp_tile_iterator_A1_; + ++this->warp_tile_iterator_B1_; + + if (warp_mma_k == 0) { + + iterator_B1.load(tb_frag_B1); + + ++iterator_B1; + } + + warp_mma1(accum, warp_frag_A1[warp_mma_k % 2], + warp_frag_B1[warp_mma_k % 2], accum); + + } + } + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base.h new file mode 100644 index 0000000000000000000000000000000000000000..c845f2023f83e09303efc0df2dc0d560b5f080ca --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base.h @@ -0,0 +1,236 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape0_, + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape1_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy0_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy1_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class B2bMmaBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape0 = Shape0_; + using Shape1 = Shape1_; + + ///< Policy describing tuning details + using Policy0 = Policy0_; + using Policy1 = Policy1_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator0 = typename Policy0::Operator; + using Operator1 = typename Policy1::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm0 = typename Policy0::Operator::Shape; + using WarpGemm1 = typename Policy1::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount0 = GemmShape; + using WarpCount1 = GemmShape; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations0 = + (WarpGemm0::kK / Operator0::Policy::MmaShape::kK); + static int const kWarpGemmIterations1 = + (WarpGemm1::kK / Operator1::Policy::MmaShape::kK); + + /// Number of stages + static int const kStages = Stages; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + template< + typename Shape_, + typename Policy_ + > + class SharedStorage { + public: + // + // Type definitions + // + using Shape = Shape_; + using Policy = Policy_; + using Operator = typename Policy::Operator; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = TensorRef; + + + /// Shape of the A matrix operand in shared memory + using ShapeA = MatrixShape; + + /// Shape of the B matrix operand in shared memory + using ShapeB = + MatrixShape; + + public: + // + // Data members + // + + /// Buffer for A operand + AlignedBuffer operand_A; + + /// Buffer for B operand + AlignedBuffer operand_B; + + public: + + // + // Methods + // + + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator::LayoutA LayoutA() { + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() { + return TensorRefA{operand_A.data(), LayoutA()}; + } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { + return TensorRefB{operand_B.data(), LayoutB()}; + } + }; + + using SharedStorage0 = SharedStorage; + using SharedStorage1 = SharedStorage; + union B2bMmaSharedStorage { + SharedStorage0 shared_storage0; + SharedStorage1 shared_storage1; + }; + + + protected: + + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A0 operand from shared memory + typename Operator0::IteratorA warp_tile_iterator_A0_; + + /// Iterator to load a warp-scoped tile of B0 operand from shared memory + typename Operator0::IteratorB warp_tile_iterator_B0_; + + /// Iterator to load a warp-scoped tile of B1 operand from shared memory + typename Operator1::IteratorB warp_tile_iterator_B1_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + B2bMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + B2bMmaSharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + warp_tile_iterator_A0_(shared_storage.shared_storage0.operand_A_ref(), lane_idx), + warp_tile_iterator_B0_(shared_storage.shared_storage0.operand_B_ref(), lane_idx), + warp_tile_iterator_B1_(shared_storage.shared_storage1.operand_B_ref(), lane_idx) { + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h new file mode 100644 index 0000000000000000000000000000000000000000..c0356df46fb18cedbf984a0058e7d680e2a6650a --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h @@ -0,0 +1,179 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "threadblock/b2b_mma_base.h" +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape0_, + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape1_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy0_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy1_, + /// Shared Memory Accumulator Iterator + typename SmemAccumulatorIterator0_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class B2bMmaBaseSmemAccumulator : + public B2bMmaBase { + + public: + ///< Base class + using Base = B2bMmaBase; + + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape0 = Shape0_; + using Shape1 = Shape1_; + + ///< Policy describing tuning details + using Policy0 = Policy0_; + using Policy1 = Policy1_; + + + using SmemAccumulatorIterator0 = SmemAccumulatorIterator0_; + + // + // Nested structs + // + /// Shared storage object needed by accumulator + template< + typename Shape_, + typename Element_, + typename Layout_, + typename Padding_ + > + class AccumulatorSharedStorage { + public: + // + // Type definitions + // + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using Padding = Padding_; + + /// Tensor reference to the accumulator + using TensorRefAccum = TensorRef; + + /// Shape of the accumulator matrix in shared memory + using ShapeAccum = MatrixShape; + + public: + // + // Data members + // + + /// Buffer for accumulator + AlignedBuffer accum; + + public: + + // + // Methods + // + + /// Returns a layout object for the Accum matrix + CUTLASS_DEVICE + static Layout LayoutAccum() { + return Layout::packed({ShapeAccum::kRow, ShapeAccum::kColumn}); + } + + /// Returns a TensorRef to the Accumulator + CUTLASS_HOST_DEVICE + TensorRefAccum accum_ref() { + return TensorRefAccum{accum.data(), LayoutAccum()}; + } + + }; + + using AccumulatorSharedStorage0 = AccumulatorSharedStorage< + Shape0, typename SmemAccumulatorIterator0::Element, + typename SmemAccumulatorIterator0::TensorLayout, + typename SmemAccumulatorIterator0::Padding>; + + struct B2bMmaSharedStorage { + typename Base::B2bMmaSharedStorage b2b_mma_shared_storage; + AccumulatorSharedStorage0 accumulator_shared_storage0; + }; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + B2bMmaBaseSmemAccumulator( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + B2bMmaSharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + Base(shared_storage.b2b_mma_shared_storage, thread_idx, warp_idx, lane_idx) { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage.h new file mode 100644 index 0000000000000000000000000000000000000000..b9388a73316d4c806ab62a548e67312008ab9527 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage.h @@ -0,0 +1,903 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" + +#include "threadblock/b2b_mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape0_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA0_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA0_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA0, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB0_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB0_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB0, + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape1_, + /// Iterates over the intermediate accumulator tile + // (concept::MmaTensorOpFragmentIterator) + typename FragmentIteratorA1_, + /// Iterates over vectors of scale and bias vector in global memory + // (concept: VectorIterator) + typename IteratorAccumulatorScaleBias_, + /// WarpIterator to load Scale or Bias vector from threadblock fragment + typename FragmentIteratorA1ScaleBias_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB1_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB1_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB1, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...) + typename OutputOp_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy0_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy1_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class B2bMmaMultistage : + public B2bMmaBase { +public: + ///< Base class + using Base = B2bMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape0 = Shape0_; + ///< Iterates over tiles of A operand in global memory + using IteratorA0 = IteratorA0_; + using IteratorA = IteratorA0; + ///< Iterates over tiles of B operand in global memory + using IteratorB0 = IteratorB0_; + using IteratorB = IteratorB0; + ///< Policy describing tuning details + using Policy0 = Policy0_; + + using SmemIteratorA0 = SmemIteratorA0_; + using SmemIteratorB0 = SmemIteratorB0_; + + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape1 = Shape1_; + ///< Iterates over intermediate accumulator tile + using FragmentIteratorA1 = FragmentIteratorA1_; + ///< Iterates over tiles of the scale and bias vectors in global memory + using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; + ///< WarpIterator to load Scale or Bias vector from threadblock fragment + using FragmentIteratorA1ScaleBias = FragmentIteratorA1ScaleBias_; + ///< Iterates over tiles of B operand in global memory + using IteratorB1 = IteratorB1_; + ///< Policy describing tuning details + using Policy1 = Policy1_; + + ///< Export Policy0 as the threadblock-level Mma's policy + using Policy = Policy0; + using Shape = Shape0; + + using SmemIteratorB1 = SmemIteratorB1_; + + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + + ///< Epilogue after 1st Gemm + using OutputOp = OutputOp_; + + static const bool PerChannelScale = (OutputOp::kScale == + epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling); + + static cutlass::arch::CacheOperation::Kind const kCacheOpA0 = CacheOpA0; + static cutlass::arch::CacheOperation::Kind const kCacheOpB0 = CacheOpB0; + static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC0 = typename Policy0::Operator::FragmentC; + + /// Warp-level Mma + using Operator0 = typename Policy0::Operator; + + /// Fragment of Scale and Bias loaded from global memory + using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment; + + /// Fragment of accumulator tile + using FragmentC1 = typename Policy1::Operator::FragmentC; + + /// Warp-level Mma + using Operator1 = typename Policy1::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on A operand + static ComplexTransform const kTransformA0 = Operator0::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB0 = Operator0::kTransformB; + + /// Complex transform on B operand + static ComplexTransform const kTransformB1 = Operator1::kTransformB; + + /// Complex transform exports needed by higher-level kernels + static ComplexTransform const kTransformA = kTransformA0; + static ComplexTransform const kTransformB = kTransformB0; + + /// Internal structure exposed for introspection. + struct Detail { + + static_assert(Base::kWarpGemmIterations0 > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + static_assert(Base::kWarpGemmIterations1 > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const TBLoadIterationsA0 = + IteratorA0::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const TBLoadIterationsB0 = + IteratorB0::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const TBLoadIterationsB1 = + IteratorB1::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA0 = + (TBLoadIterationsA0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB0 = + (TBLoadIterationsB0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB1 = + (TBLoadIterationsB1 + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1; + }; + + private: + + using WarpLoadedFragmentA0 = typename Operator0::FragmentA; + using WarpLoadedFragmentB0 = typename Operator0::FragmentB; + /// Warp Fragment of operand A1 loaded from accmulator tile + using WarpLoadedFragmentA1 = typename FragmentIteratorA1::Fragment; + using WarpLoadedFragmentA1ScaleBias = + typename FragmentIteratorA1ScaleBias::Fragment; + using WarpLoadedFragmentB1 = typename Operator1::FragmentB; + using WarpTransformedFragmentA0 = typename Operator0::TransformedFragmentA; + using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB; + using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA; + using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB; + + private: + + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA0 smem_iterator_A0_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB0 smem_iterator_B0_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB1 smem_iterator_B1_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + B2bMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::B2bMmaSharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx, + ///< GEMM0 N is used for accumulator extent + int problem_size_0_n + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A0_(shared_storage.shared_storage0.operand_A_ref(), thread_idx), + smem_iterator_B0_(shared_storage.shared_storage0.operand_B_ref(), thread_idx), + smem_iterator_B1_(shared_storage.shared_storage1.operand_B_ref(), thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN); + int warp_idx_k = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount0::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount0::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A0_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations0 * warp_idx_k}); + this->warp_tile_iterator_B0_.add_tile_offset( + {Base::kWarpGemmIterations0 * warp_idx_k, warp_idx_n}); + this->warp_tile_iterator_B1_.add_tile_offset( + {Base::kWarpGemmIterations1 * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance_0(IteratorA0 &iterator_A0, IteratorB0 &iterator_B0, + int group_start_A0 = 0, int group_start_B0 = 0) { + iterator_A0.set_iteration_index(group_start_A0 * + IteratorA0::kAccessesPerVector); + this->smem_iterator_A0_.set_iteration_index(group_start_A0); + + // Load for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA0; ++j) { + if (group_start_A0 + j < Detail::TBLoadIterationsA0) { + typename IteratorA0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A0_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA0::ThreadMap::kElementsPerAccess / + IteratorA0::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA0::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A0.get(); + + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A0.valid()); + + ++iterator_A0; + } + + ++this->smem_iterator_A0_; + } + } + + iterator_B0.set_iteration_index(group_start_B0 * + IteratorB0::kAccessesPerVector); + this->smem_iterator_B0_.set_iteration_index(group_start_B0); + + // Load for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB0; ++j) { + if (group_start_B0 + j < Detail::TBLoadIterationsB0) { + typename IteratorB0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B0_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB0::ThreadMap::kElementsPerAccess / + IteratorB0::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B0.get(); + + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B0.valid()); + + ++iterator_B0; + } + ++this->smem_iterator_B0_; + } + } + } + + CUTLASS_DEVICE + void copy_tiles_and_advance_1(IteratorB1 &iterator_B1, + int group_start_B1 = 0) { + iterator_B1.set_iteration_index(group_start_B1 * + IteratorB1::kAccessesPerVector); + this->smem_iterator_B1_.set_iteration_index(group_start_B1); + + // Load for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) { + if (group_start_B1 + j < Detail::TBLoadIterationsB1) { + typename IteratorB1::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B1_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / + IteratorB1::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B1.get(); + + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B1.valid()); + + ++iterator_B1; + } + ++this->smem_iterator_B1_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations_0, + ///< destination accumulator tile + FragmentC1 &accum, + ///< iterator over A0 operand in global memory + IteratorA0 iterator_A0, + ///< iterator over B0 operand in global memory + IteratorB0 iterator_B0, + ///< iterator over A1 operand scale vector in global memory + IteratorAccumulatorScaleBias iterator_A1_scale, + ///< iterator over A1 operand bias vector in global memory + IteratorAccumulatorScaleBias iterator_A1_bias, + ///< iterator over B1 operand in global memory + IteratorB1 iterator_B1, + ///< initial value of accumulator + FragmentC0 const &src_accum, + ///< epilogue operation after 1st Gemm + OutputOp output_op_0) + { + // + // Prologue + // + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations_0) { + + iterator_A0.clear_mask(gemm_k_iterations_0 == 0); + iterator_B0.clear_mask(gemm_k_iterations_0 == 0); + + iterator_A0.set_iteration_index(0); + this->smem_iterator_A0_.set_iteration_index(0); + + // Load for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLoadIterationsA0; ++j) { + typename IteratorA0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A0_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA0::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA0::ThreadMap::kElementsPerAccess / + IteratorA0::kAccessesPerVector / 8; + + int src_bytes = (iterator_A0.valid() ? kSrcBytes : 0); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A0.get(), iterator_A0.valid()); + + ++iterator_A0; + } + + ++this->smem_iterator_A0_; + } + + iterator_B0.set_iteration_index(0); + this->smem_iterator_B0_.set_iteration_index(0); + + // Load for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLoadIterationsB0; ++j) { + typename IteratorB0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B0_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB0::ThreadMap::kElementsPerAccess / + IteratorB0::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B0.get(), iterator_B0.valid()); + + ++iterator_B0; + } + + ++this->smem_iterator_B0_; + } + + // Move to the next stage + iterator_A0.add_tile_offset({0, 1}); + iterator_B0.add_tile_offset({1, 0}); + + this->smem_iterator_A0_.add_tile_offset({0, 1}); + this->smem_iterator_B0_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + FragmentC0 accum0 = src_accum; + + // DEPBAR+SYNC + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA0 warp_loaded_frag_A0[2]; + WarpLoadedFragmentB0 warp_loaded_frag_B0[2]; + WarpTransformedFragmentA0 warp_transformed_frag_A0[2]; + WarpTransformedFragmentB0 warp_transformed_frag_B0[2]; + + Operator0 warp_mma0; + + this->warp_tile_iterator_A0_.set_kgroup_index(0); + this->warp_tile_iterator_B0_.set_kgroup_index(0); + + this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[0]); + this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[0]); + + ++this->warp_tile_iterator_A0_; + ++this->warp_tile_iterator_B0_; + + iterator_A0.clear_mask(gemm_k_iterations_0 == 0); + iterator_B0.clear_mask(gemm_k_iterations_0 == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma0.transform(warp_transformed_frag_A0[0], warp_transformed_frag_B0[0], + warp_loaded_frag_A0[0], warp_loaded_frag_B0[0]); + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations_0 > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; + ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); + this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); + + this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A0_; + ++this->warp_tile_iterator_B0_; + + if (warp_mma_k > 0) + warp_mma0.transform(warp_transformed_frag_A0[warp_mma_k % 2], + warp_transformed_frag_B0[warp_mma_k % 2], + warp_loaded_frag_A0[warp_mma_k % 2], + warp_loaded_frag_B0[warp_mma_k % 2]); + + warp_mma0( + accum0, + warp_transformed_frag_A0[warp_mma_k % 2], + warp_transformed_frag_B0[warp_mma_k % 2], + accum0 + ); + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations0 - 1) { + int group_start_iteration_A0, group_start_iteration_B0; + + group_start_iteration_A0 = warp_mma_k * Detail::kAccessesPerGroupA0; + group_start_iteration_B0 = warp_mma_k * Detail::kAccessesPerGroupB0; + + copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0, + group_start_iteration_B0); + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations0) { + int group_start_iteration_A0, group_start_iteration_B0; + group_start_iteration_A0 = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA0; + group_start_iteration_B0 = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB0; + + copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0, + group_start_iteration_B0); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A0.add_tile_offset({0, 1}); + iterator_B0.add_tile_offset({1, 0}); + + this->smem_iterator_A0_.add_tile_offset({0, 1}); + this->smem_iterator_B0_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A0_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A0_.add_tile_offset( + {0, -Base::kStages * Policy0::kPartitionsK * + Base::kWarpGemmIterations0}); + this->warp_tile_iterator_B0_.add_tile_offset( + {-Base::kStages * Policy0::kPartitionsK * + Base::kWarpGemmIterations0, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations_0; + iterator_A0.clear_mask(gemm_k_iterations_0 == 0); + iterator_B0.clear_mask(gemm_k_iterations_0 == 0); + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations0) + warp_mma0.transform(warp_transformed_frag_A0[(warp_mma_k + 1) % 2], + warp_transformed_frag_B0[(warp_mma_k + 1) % 2], + warp_loaded_frag_A0[(warp_mma_k + 1) % 2], + warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); + } + + } + + // Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + + // 2nd Gemm + + /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile + FragmentIteratorA1 warp_tile_iterator_A1_(accum0); + FragmentA1ScaleBias tb_frag_A1_scale; + FragmentA1ScaleBias tb_frag_A1_bias; + FragmentIteratorA1ScaleBias warp_tile_iterator_A1_scale_(tb_frag_A1_scale); + FragmentIteratorA1ScaleBias warp_tile_iterator_A1_bias_(tb_frag_A1_bias); + + if(PerChannelScale) { + tb_frag_A1_scale.clear(); + iterator_A1_scale.load(tb_frag_A1_scale); + ++iterator_A1_scale; + } + tb_frag_A1_bias.clear(); + iterator_A1_bias.load(tb_frag_A1_bias); + ++iterator_A1_bias; + + // + // Prologue + // + int gemm_k_iterations_1 = (FragmentIteratorA1::Policy::kIterations + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1; + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations_1) { + + iterator_B1.clear_mask(gemm_k_iterations_1 == 0); + + iterator_B1.set_iteration_index(0); + this->smem_iterator_B1_.set_iteration_index(0); + + // Load for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLoadIterationsB1; ++j) { + typename IteratorB1::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B1_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / + IteratorB1::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B1.get(), iterator_B1.valid()); + + ++iterator_B1; + } + + ++this->smem_iterator_B1_; + } + + // Move to the next stage + iterator_B1.add_tile_offset({1, 0}); + + this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // DEPBAR+SYNC + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA1 warp_loaded_frag_A1[2]; + WarpLoadedFragmentA1ScaleBias warp_loaded_frag_A1_scale[2]; + WarpLoadedFragmentA1ScaleBias warp_loaded_frag_A1_bias[2]; + WarpLoadedFragmentB1 warp_loaded_frag_B1[2]; + WarpTransformedFragmentA1 warp_transformed_frag_A1[2]; + WarpTransformedFragmentB1 warp_transformed_frag_B1[2]; + + Operator1 warp_mma1; + + if(PerChannelScale) { + warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[0]); + ++warp_tile_iterator_A1_scale_; + } + warp_tile_iterator_A1_bias_.load(warp_loaded_frag_A1_bias[0]); + ++warp_tile_iterator_A1_bias_; + + warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0], + warp_loaded_frag_A1_scale[0], + warp_loaded_frag_A1_bias[0], + output_op_0); + ++warp_tile_iterator_A1_; + + this->warp_tile_iterator_B1_.set_kgroup_index(0); + this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]); + ++this->warp_tile_iterator_B1_; + + iterator_B1.clear_mask(gemm_k_iterations_1 == 0); + + smem_write_stage_idx = Base::kStages - 1; + smem_read_stage_idx = 0; + + warp_mma1.transform(warp_transformed_frag_A1[0], warp_transformed_frag_B1[0], + warp_loaded_frag_A1[0], warp_loaded_frag_B1[0]); + + // + // Mainloop + // + + gemm_k_iterations_1 = (FragmentIteratorA1::Policy::kIterations + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1 - (Base::kStages - 1); + CUTLASS_PRAGMA_UNROLL + for (; gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; + ++warp_mma_k) { + + // Load threadblock-level scale/bias vector from global memory + if (warp_mma_k + 1 == Base::kWarpGemmIterations1) { + if(PerChannelScale) { + tb_frag_A1_scale.clear(); + iterator_A1_scale.load(tb_frag_A1_scale); + ++iterator_A1_scale; + } + tb_frag_A1_bias.clear(); + iterator_A1_bias.load(tb_frag_A1_bias); + ++iterator_A1_bias; + } + + // Load warp-level scale bias fragment from threadblock scale/bias vector + if(PerChannelScale) { + warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]); + ++warp_tile_iterator_A1_scale_; + } + warp_tile_iterator_A1_bias_.load(warp_loaded_frag_A1_bias[(warp_mma_k + 1) % 2]); + ++warp_tile_iterator_A1_bias_; + + // Load warp-level tile from accumulator fragment + warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2], + warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2], + warp_loaded_frag_A1_bias[(warp_mma_k + 1) % 2], + output_op_0); + ++warp_tile_iterator_A1_; + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); + this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_B1_; + + if (warp_mma_k > 0) + warp_mma1.transform(warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + warp_loaded_frag_A1[warp_mma_k % 2], + warp_loaded_frag_B1[warp_mma_k % 2]); + + + warp_mma1( + accum, + warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + accum + ); + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations1 - 1) { + int group_start_iteration_B1; + + group_start_iteration_B1 = warp_mma_k * Detail::kAccessesPerGroupB1; + + copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations1) { + int group_start_iteration_B1; + group_start_iteration_B1 = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB1; + + copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_B1.add_tile_offset({1, 0}); + + this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_B1_.add_tile_offset( + {-Base::kStages * Policy1::kPartitionsK * + Base::kWarpGemmIterations1, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + iterator_B1.clear_mask(gemm_k_iterations_1 == 1); + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations1) + warp_mma1.transform(warp_transformed_frag_A1[(warp_mma_k + 1) % 2], + warp_transformed_frag_B1[(warp_mma_k + 1) % 2], + warp_loaded_frag_A1[(warp_mma_k + 1) % 2], + warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + } + + } + + // Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h new file mode 100644 index 0000000000000000000000000000000000000000..b089bdba2bea2069dd9755b09e680278a771fe2a --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h @@ -0,0 +1,887 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" + +#include "threadblock/b2b_mma_base_smem_accumulator.h" +#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape0_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA0_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA0_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA0, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB0_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB0_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB0, + /// Iterates over vectors of scale and bias vector in global memory + // (concept: VectorIterator) + typename IteratorAccumulatorScaleBias_, + /// Iterates over accumulator tile + typename FragmentIteratorAccumulator_, + /// Iterates over accumulator tile in shared memory + typename SmemIteratorD0_, + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape1_, + /// Iterates over the intermediate accumulator tile in shared memory + typename WarpIteratorA1_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB1_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB1_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB1, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...) + typename OutputOp_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy0_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy1_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class B2bMmaMultistageSmemAccumulator : + public gemm::threadblock::B2bMmaBaseSmemAccumulator { +public: + ///< Base class + using Base = gemm::threadblock::B2bMmaBaseSmemAccumulator; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape0 = Shape0_; + ///< Iterates over tiles of A operand in global memory + using IteratorA0 = IteratorA0_; + using IteratorA = IteratorA0; + ///< Iterates over tiles of B operand in global memory + using IteratorB0 = IteratorB0_; + using IteratorB = IteratorB0; + ///< Iterates over tiles of the scale and bias vectors in global memory + using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; + ///< Policy describing tuning details + using Policy0 = Policy0_; + + using SmemIteratorA0 = SmemIteratorA0_; + using SmemIteratorB0 = SmemIteratorB0_; + using SmemIteratorD0 = SmemIteratorD0_; ///< Iterates over accumulator tile in shared memory + + using FragmentIteratorAccumulator = FragmentIteratorAccumulator_; ///< Iterates over accumulator tile + + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape1 = Shape1_; + ///< Iterates over tiles of B operand in global memory + using IteratorB1 = IteratorB1_; + ///< Policy describing tuning details + using Policy1 = Policy1_; + + ///< Export Policy0 as the threadblock-level Mma's policy + using Policy = Policy0; + using Shape = Shape0; + + using SmemIteratorB1 = SmemIteratorB1_; + using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate accumulator tile in shared memory + + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + + ///< Epilogue after 1st Gemm + using OutputOp = OutputOp_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA0 = CacheOpA0; + static cutlass::arch::CacheOperation::Kind const kCacheOpB0 = CacheOpB0; + static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC0 = typename Policy0::Operator::FragmentC; + + /// Warp-level Mma + using Operator0 = typename Policy0::Operator; + + /// Fragment of Scale and Bias loaded from global memory + using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment; + + /// Fragment of accumulator tile + using FragmentC1 = typename Policy1::Operator::FragmentC; + + /// Warp-level Mma + using Operator1 = typename Policy1::Operator; + + /// Epilog in shared memory + using Epilogue0 = epilogue::threadblock::EpilogueSmemAccumulator< + SmemIteratorD0, ///< SmemTileIterator + FragmentIteratorAccumulator, ///< AccumulatorFragmentIterator + IteratorAccumulatorScaleBias, ///< ScaleBiasIterator + OutputOp>; ///< Output operator + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on A operand + static ComplexTransform const kTransformA0 = Operator0::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB0 = Operator0::kTransformB; + + /// Complex transform on B operand + static ComplexTransform const kTransformB1 = Operator1::kTransformB; + + /// Complex transform exports needed by higher-level kernels + static ComplexTransform const kTransformA = kTransformA0; + static ComplexTransform const kTransformB = kTransformB0; + + /// Internal structure exposed for introspection. + struct Detail { + + static_assert(Base::kWarpGemmIterations0 > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + static_assert(Base::kWarpGemmIterations1 > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const TBLoadIterationsA0 = + IteratorA0::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const TBLoadIterationsB0 = + IteratorB0::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const TBLoadIterationsB1 = + IteratorB1::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA0 = + (TBLoadIterationsA0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB0 = + (TBLoadIterationsB0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB1 = + (TBLoadIterationsB1 + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1; + }; + + private: + + using WarpLoadedFragmentA0 = typename Operator0::FragmentA; + using WarpLoadedFragmentB0 = typename Operator0::FragmentB; + using WarpLoadedFragmentA1 = typename Operator1::FragmentA; + using WarpLoadedFragmentB1 = typename Operator1::FragmentB; + using WarpTransformedFragmentA0 = typename Operator0::TransformedFragmentA; + using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB; + using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA; + using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB; + + private: + + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA0 smem_iterator_A0_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB0 smem_iterator_B0_; + + /// Shared Memory Iterator to store accumulator tile + SmemIteratorD0 smem_iterator_D0_; + + /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile + WarpIteratorA1 warp_tile_iterator_A1_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB1 smem_iterator_B1_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + B2bMmaMultistageSmemAccumulator( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::B2bMmaSharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx, + ///< GEMM0 N is used for accumulator extent + int problem_size_0_n + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_A_ref(), thread_idx), + smem_iterator_B0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_B_ref(), thread_idx), + smem_iterator_D0_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx), + warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), {Base::WarpGemm1::kM, problem_size_0_n}, lane_idx ), + smem_iterator_B1_(shared_storage.b2b_mma_shared_storage.shared_storage1.operand_B_ref(), thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn_0 = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN); + int warp_idx_k_0 = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN); + + int warp_idx_m_0 = warp_idx_mn_0 % Base::WarpCount0::kM; + int warp_idx_n_0 = warp_idx_mn_0 / Base::WarpCount0::kM; + + int warp_idx_mn_1 = warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); + + int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; + int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A0_.add_tile_offset( + {warp_idx_m_0, Base::kWarpGemmIterations0 * warp_idx_k_0}); + this->warp_tile_iterator_B0_.add_tile_offset( + {Base::kWarpGemmIterations0 * warp_idx_k_0, warp_idx_n_0}); + warp_tile_iterator_A1_.add_tile_offset( + {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); + this->warp_tile_iterator_B1_.add_tile_offset( + {Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1}); + + // Add smem accumulator iterator warp offset + smem_iterator_D0_.add_tile_offset({ warp_idx_m_0 * SmemIteratorD0::TileIterations::kRow, + warp_idx_n_0 * SmemIteratorD0::TileIterations::kColumn}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance_0(IteratorA0 &iterator_A0, IteratorB0 &iterator_B0, + int group_start_A0 = 0, int group_start_B0 = 0) { + iterator_A0.set_iteration_index(group_start_A0 * + IteratorA0::kAccessesPerVector); + this->smem_iterator_A0_.set_iteration_index(group_start_A0); + + // cp.async for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA0; ++j) { + if (group_start_A0 + j < Detail::TBLoadIterationsA0) { + typename IteratorA0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A0_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA0::ThreadMap::kElementsPerAccess / + IteratorA0::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA0::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A0.get(); + + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A0.valid()); + + ++iterator_A0; + } + + ++this->smem_iterator_A0_; + } + } + + iterator_B0.set_iteration_index(group_start_B0 * + IteratorB0::kAccessesPerVector); + this->smem_iterator_B0_.set_iteration_index(group_start_B0); + + // cp.async for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB0; ++j) { + if (group_start_B0 + j < Detail::TBLoadIterationsB0) { + typename IteratorB0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B0_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB0::ThreadMap::kElementsPerAccess / + IteratorB0::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B0.get(); + + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B0.valid()); + + ++iterator_B0; + } + ++this->smem_iterator_B0_; + } + } + } + + CUTLASS_DEVICE + void copy_tiles_and_advance_1(IteratorB1 &iterator_B1, + int group_start_B1 = 0) { + iterator_B1.set_iteration_index(group_start_B1 * + IteratorB1::kAccessesPerVector); + this->smem_iterator_B1_.set_iteration_index(group_start_B1); + + // cp.async for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) { + if (group_start_B1 + j < Detail::TBLoadIterationsB1) { + typename IteratorB1::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B1_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / + IteratorB1::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B1.get(); + + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B1.valid()); + + ++iterator_B1; + } + ++this->smem_iterator_B1_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations_0, + ///< destination accumulator tile + FragmentC1 &accum, + ///< iterator over A0 operand in global memory + IteratorA0 iterator_A0, + ///< iterator over B0 operand in global memory + IteratorB0 iterator_B0, + ///< iterator over A1 operand scale vector in global memory + IteratorAccumulatorScaleBias iterator_accum0_scale, + ///< iterator over A1 operand bias vector in global memory + IteratorAccumulatorScaleBias iterator_accum0_bias, + ///< iterator over B1 operand in global memory + IteratorB1 iterator_B1, + ///< initial value of accumulator + FragmentC0 const &src_accum, + ///< epilogue operation after 1st Gemm + OutputOp output_op_0) + { + // + // Prologue + // + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations_0) { + + iterator_A0.clear_mask(gemm_k_iterations_0 == 0); + iterator_B0.clear_mask(gemm_k_iterations_0 == 0); + + iterator_A0.set_iteration_index(0); + this->smem_iterator_A0_.set_iteration_index(0); + + // cp.async for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLoadIterationsA0; ++j) { + typename IteratorA0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A0_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA0::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA0::ThreadMap::kElementsPerAccess / + IteratorA0::kAccessesPerVector / 8; + + int src_bytes = (iterator_A0.valid() ? kSrcBytes : 0); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A0.get(), iterator_A0.valid()); + + ++iterator_A0; + } + + ++this->smem_iterator_A0_; + } + + iterator_B0.set_iteration_index(0); + this->smem_iterator_B0_.set_iteration_index(0); + + // cp.async for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLoadIterationsB0; ++j) { + typename IteratorB0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B0_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB0::ThreadMap::kElementsPerAccess / + IteratorB0::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B0.get(), iterator_B0.valid()); + + ++iterator_B0; + } + + ++this->smem_iterator_B0_; + } + + // Move to the next stage + iterator_A0.add_tile_offset({0, 1}); + iterator_B0.add_tile_offset({1, 0}); + + this->smem_iterator_A0_.add_tile_offset({0, 1}); + this->smem_iterator_B0_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + FragmentC0 accum0 = src_accum; + + // DEPBAR+SYNC + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA0 warp_loaded_frag_A0[2]; + WarpLoadedFragmentB0 warp_loaded_frag_B0[2]; + WarpTransformedFragmentA0 warp_transformed_frag_A0[2]; + WarpTransformedFragmentB0 warp_transformed_frag_B0[2]; + + Operator0 warp_mma0; + + this->warp_tile_iterator_A0_.set_kgroup_index(0); + this->warp_tile_iterator_B0_.set_kgroup_index(0); + + this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[0]); + this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[0]); + + ++this->warp_tile_iterator_A0_; + ++this->warp_tile_iterator_B0_; + + iterator_A0.clear_mask(gemm_k_iterations_0 == 0); + iterator_B0.clear_mask(gemm_k_iterations_0 == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma0.transform(warp_transformed_frag_A0[0], warp_transformed_frag_B0[0], + warp_loaded_frag_A0[0], warp_loaded_frag_B0[0]); + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations_0 > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; + ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); + this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); + + this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A0_; + ++this->warp_tile_iterator_B0_; + + if (warp_mma_k > 0) + warp_mma0.transform(warp_transformed_frag_A0[warp_mma_k % 2], + warp_transformed_frag_B0[warp_mma_k % 2], + warp_loaded_frag_A0[warp_mma_k % 2], + warp_loaded_frag_B0[warp_mma_k % 2]); + + warp_mma0( + accum0, + warp_transformed_frag_A0[warp_mma_k % 2], + warp_transformed_frag_B0[warp_mma_k % 2], + accum0 + ); + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations0 - 1) { + int group_start_iteration_A0, group_start_iteration_B0; + + group_start_iteration_A0 = warp_mma_k * Detail::kAccessesPerGroupA0; + group_start_iteration_B0 = warp_mma_k * Detail::kAccessesPerGroupB0; + + copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0, + group_start_iteration_B0); + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations0) { + int group_start_iteration_A0, group_start_iteration_B0; + group_start_iteration_A0 = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA0; + group_start_iteration_B0 = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB0; + + copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0, + group_start_iteration_B0); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A0.add_tile_offset({0, 1}); + iterator_B0.add_tile_offset({1, 0}); + + this->smem_iterator_A0_.add_tile_offset({0, 1}); + this->smem_iterator_B0_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A0_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A0_.add_tile_offset( + {0, -Base::kStages * Policy0::kPartitionsK * + Base::kWarpGemmIterations0}); + this->warp_tile_iterator_B0_.add_tile_offset( + {-Base::kStages * Policy0::kPartitionsK * + Base::kWarpGemmIterations0, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations_0; + iterator_A0.clear_mask(gemm_k_iterations_0 == 0); + iterator_B0.clear_mask(gemm_k_iterations_0 == 0); + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations0) + warp_mma0.transform(warp_transformed_frag_A0[(warp_mma_k + 1) % 2], + warp_transformed_frag_B0[(warp_mma_k + 1) % 2], + warp_loaded_frag_A0[(warp_mma_k + 1) % 2], + warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); + } + + } + + // Insert fence and wait for all outstanding cp.async operations to commit. + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + + /// Epilogue for the first Implicit Gemm + Epilogue0 epilogue0; + + epilogue0(output_op_0, smem_iterator_D0_, accum0, iterator_accum0_scale, iterator_accum0_bias); + + __syncthreads(); + + + // 2nd Gemm + + // + // Prologue + // + int gemm_k_iterations_1 = Shape0::kN / Shape1::kK; + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations_1) { + + iterator_B1.clear_mask(gemm_k_iterations_1 == 0); + + iterator_B1.set_iteration_index(0); + this->smem_iterator_B1_.set_iteration_index(0); + + // cp.async for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLoadIterationsB1; ++j) { + typename IteratorB1::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B1_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / + IteratorB1::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B1.get(), iterator_B1.valid()); + + ++iterator_B1; + } + + ++this->smem_iterator_B1_; + } + + // Move to the next stage + iterator_B1.add_tile_offset({1, 0}); + + this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // DEPBAR+SYNC + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA1 warp_loaded_frag_A1[2]; + WarpLoadedFragmentB1 warp_loaded_frag_B1[2]; + WarpTransformedFragmentA1 warp_transformed_frag_A1[2]; + WarpTransformedFragmentB1 warp_transformed_frag_B1[2]; + + Operator1 warp_mma1; + + warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0]); + ++warp_tile_iterator_A1_; + + this->warp_tile_iterator_B1_.set_kgroup_index(0); + this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]); + ++this->warp_tile_iterator_B1_; + + iterator_B1.clear_mask(gemm_k_iterations_1 == 0); + + smem_write_stage_idx = Base::kStages - 1; + smem_read_stage_idx = 0; + + warp_mma1.transform(warp_transformed_frag_A1[0], warp_transformed_frag_B1[0], + warp_loaded_frag_A1[0], warp_loaded_frag_B1[0]); + + // + // Mainloop + // + + CUTLASS_PRAGMA_UNROLL + for ( gemm_k_iterations_1 = Shape0::kN / Shape1::kK - (Base::kStages - 1); + gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; + ++warp_mma_k) { + + // Load warp-level tile from accumulator fragment + // skip warp tile loading for the last kgroup + if(gemm_k_iterations_1 > (-Base::kStages + 2) || warp_mma_k < Base::kWarpGemmIterations1 - 1) { + warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2]); + } + ++warp_tile_iterator_A1_; + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); + this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_B1_; + + + if (warp_mma_k > 0) + warp_mma1.transform(warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + warp_loaded_frag_A1[warp_mma_k % 2], + warp_loaded_frag_B1[warp_mma_k % 2]); + + + warp_mma1( + accum, + warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + accum + ); + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations1 - 1) { + int group_start_iteration_B1; + + group_start_iteration_B1 = warp_mma_k * Detail::kAccessesPerGroupB1; + + copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations1) { + int group_start_iteration_B1; + group_start_iteration_B1 = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB1; + + copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_B1.add_tile_offset({1, 0}); + + this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_B1_.add_tile_offset( + {-Base::kStages * Policy1::kPartitionsK * + Base::kWarpGemmIterations1, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + iterator_B1.clear_mask(gemm_k_iterations_1 == 1); + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations1) + warp_mma1.transform(warp_transformed_frag_A1[(warp_mma_k + 1) % 2], + warp_transformed_frag_B1[(warp_mma_k + 1) % 2], + warp_loaded_frag_A1[(warp_mma_k + 1) % 2], + warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + } + + } + + // Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined.h new file mode 100644 index 0000000000000000000000000000000000000000..28fcc94f755df181b085a5d0019ee97b64e5d340 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined.h @@ -0,0 +1,562 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped Back-to-back fused GEMM kernel. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" + +#include "threadblock/b2b_mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape0_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorA0_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA0_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB0_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB0_, + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape1_, + /// Iterates over the intermediate accumulator tile + // (concept::MmaTensorOpFragmentIterator) + typename FragmentIteratorA1_, + /// Iterates over vectors of scale and bias vector in global memory + // (concept: VectorIterator) + typename IteratorAccumulatorScaleBias_, + /// FragmentIterator to load Scale or Bias vector from threadblock fragment + typename FragmentIteratorA1ScaleBias_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB1_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB1_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...) + typename OutputOp_, + /// Policy describing tuning details (concept: MmaPipelinedPolicy) + typename Policy0_, + /// Policy describing tuning details (concept: MmaPipelinedPolicy) + typename Policy1_, + /// Transformation applied to A0 operand + typename TransformA0_ = NumericArrayConverter< + typename SmemIteratorA0_::Element, + typename IteratorA0_::Element, + IteratorA0_::Fragment::kElements>, + /// + /// Transformation applied to B0 operand + typename TransformB0_ = NumericArrayConverter< + typename SmemIteratorB0_::Element, + typename IteratorB0_::Element, + IteratorB0_::Fragment::kElements>, + /// + /// Transformation applied to B1 operand + typename TransformB1_ = NumericArrayConverter< + typename SmemIteratorB1_::Element, + typename IteratorB1_::Element, + IteratorB1_::Fragment::kElements>, + /// Used for partial specialization + typename Enable = bool +> +class B2bMmaPipelined : + public B2bMmaBase { +public: + + ///< Base class + using Base = B2bMmaBase; + + using Shape0 = Shape0_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA0 = IteratorA0_; ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA0; + using IteratorB0 = IteratorB0_; ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB0; + using Policy0 = Policy0_; ///< Policy describing tuning details + + using SmemIteratorA0 = SmemIteratorA0_; + using SmemIteratorB0 = SmemIteratorB0_; + + using Shape1 = Shape1_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using FragmentIteratorA1 = FragmentIteratorA1_; ///< Iterates over intermediate accumulator tile + using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; ///< Iterates over tiles of the scale and bias vectors in global memory + using FragmentIteratorA1ScaleBias = + FragmentIteratorA1ScaleBias_; ///< WarpIterator to load Scale or Bias vector from the threadblock fragment + using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory + using Policy1 = Policy1_; ///< Policy describing tuning details + using Policy = Policy1; ///< Export Policy1 as the threadblock-level Mma's policy + using Shape = Shape1; + + using SmemIteratorB1 = SmemIteratorB1_; + + + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + + using OutputOp = OutputOp_; ///< Epilogue after 1st Gemm + + static const bool PerChannelScale = (OutputOp::kScale == + epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling); + + using TransformA0 = TransformA0_; + using TransformB0 = TransformB0_; + using TransformB1 = TransformB1_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA0 = typename IteratorA0::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB0 = typename IteratorB0::Fragment; + + /// Fragment of accumulator tile + using FragmentC0 = typename Policy0::Operator::FragmentC; + + /// Warp-level Mma + using Operator0 = typename Policy0::Operator; + + /// Fragment of Scale and Bias loaded from global memory + using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB1 = typename IteratorB1::Fragment; + + /// Fragment of accumulator tile + using FragmentC1 = typename Policy1::Operator::FragmentC; + + /// Warp-level Mma + using Operator1 = typename Policy1::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy0::Operator::ArchTag; + + /// Complex transform on A0 operand + static ComplexTransform const kTransformA0 = Operator0::kTransformA; + + /// Complex transform on B0 operand + static ComplexTransform const kTransformB0 = Operator0::kTransformB; + + /// Complex transform on B1 operand + static ComplexTransform const kTransformB1 = Operator1::kTransformB; + + /// Complex transform exports needed by higher-level kernels + static ComplexTransform const kTransformA = kTransformA0; + static ComplexTransform const kTransformB = kTransformB0; + + /// staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) + static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); + +private: + + using WarpFragmentA0 = typename Operator0::FragmentA; + using WarpFragmentB0 = typename Operator0::FragmentB; + /// Warp Fragment of operand A1 loaded from accmulator tile + using WarpFragmentA1 = typename FragmentIteratorA1::Fragment; + /// Warp Fragment of operand A1 scale and bias loaded from threadblock fragment + using WarpFragmentA1ScaleBias = + typename FragmentIteratorA1ScaleBias::Fragment; + using WarpFragmentB1 = typename Operator1::FragmentB; + +protected: + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA0 smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B0 operand to shared memory + SmemIteratorB0 smem_iterator_B0_; + + /// Iterator to write threadblock-scoped tile of B1 operand to shared memory + SmemIteratorB1 smem_iterator_B1_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + B2bMmaPipelined( + typename Base::B2bMmaSharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx, ///< ID of each thread within a warp + int problem_size_0_n ///< GEMM0 N is used for accumulator extent + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.shared_storage0.operand_A_ref(), thread_idx), + smem_iterator_B0_(shared_storage.shared_storage0.operand_B_ref(), thread_idx), + smem_iterator_B1_(shared_storage.shared_storage1.operand_B_ref(), thread_idx) { + + + // Compute warp location within threadblock tile by mapping the warp_id to three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + //These should stay the same across different GEMM layers + int warp_idx_mn = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN); + int warp_idx_k = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount0::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount0::kM; + + //These may change across different GEMM layers + int tile_offset_k_0 = Base::kWarpGemmIterations0 * warp_idx_k; + int tile_offset_k_1 = Base::kWarpGemmIterations1 * warp_idx_k; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A0_.add_tile_offset({warp_idx_m, tile_offset_k_0}); + this->warp_tile_iterator_B0_.add_tile_offset({tile_offset_k_0, warp_idx_n}); + this->warp_tile_iterator_B1_.add_tile_offset({tile_offset_k_1, warp_idx_n}); + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations_0, ///< number of iterations of the mainloop + FragmentC1 &accum, ///< destination accumulator tile + IteratorA0 iterator_A, ///< iterator over A operand in global memory + IteratorB0 iterator_B0, ///< iterator over B0 operand in global memory + IteratorAccumulatorScaleBias iterator_A1_scale, ///< iterator over A1 operand scale vectors in global memory + IteratorAccumulatorScaleBias iterator_A1_bias, ///< iterator over A1 operand bias vectors in global memory + IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory + FragmentC0 const &src_accum, ///< source accumulator tile + OutputOp output_op_0, ///< epilogue operation after 1st Gemm + TransformA0 transform_A0 = TransformA0(), ///< transformation applied to A0 fragment + TransformB0 transform_B0 = TransformB0(), ///< transformation applied to B0 fragment + TransformB1 transform_B1 = TransformB1()) { ///< transformation applied to B1 fragment + + // + // Prologue + // + + // Perform accumulation in the 'd' output operand + FragmentC0 accum0 = src_accum; + + FragmentA0 tb_frag_A; + FragmentB0 tb_frag_B0; + + tb_frag_A.clear(); + tb_frag_B0.clear(); + + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B0.load(tb_frag_B0); + + ++iterator_A; + ++iterator_B0; + + this->smem_iterator_A_.store(transform_A0(tb_frag_A)); + this->smem_iterator_B0_.store(transform_B0(tb_frag_B0)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B0_; + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA0 warp_frag_A0[2]; + WarpFragmentB0 warp_frag_B0[2]; + + this->warp_tile_iterator_A0_.set_kgroup_index(0); + this->warp_tile_iterator_B0_.set_kgroup_index(0); + + this->warp_tile_iterator_A0_.load(warp_frag_A0[0]); + this->warp_tile_iterator_B0_.load(warp_frag_B0[0]); + + ++this->warp_tile_iterator_A0_; + ++this->warp_tile_iterator_B0_; + + Operator0 warp_mma0; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_A.clear_mask(gemm_k_iterations_0 <= 1); + iterator_B0.clear_mask(gemm_k_iterations_0 <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing + // shared memory loads (which have the tightest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations0 - 1) { + + // Write fragments to shared memory + this->smem_iterator_A_.store(transform_A0(tb_frag_A)); + + this->smem_iterator_B0_.store(transform_B0(tb_frag_B0)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B0_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); + } + else { + this->warp_tile_iterator_A0_.add_tile_offset( + {0, -Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0}); + this->warp_tile_iterator_B0_.add_tile_offset( + {-Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0, + 0}); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); + this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); + + this->warp_tile_iterator_A0_.load(warp_frag_A0[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B0_.load(warp_frag_B0[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A0_; + ++this->warp_tile_iterator_B0_; + + if (warp_mma_k == 0) { + + iterator_A.load(tb_frag_A); + iterator_B0.load(tb_frag_B0); + ++iterator_A; + ++iterator_B0; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_A.clear_mask(gemm_k_iterations_0 <= 2); + iterator_B0.clear_mask(gemm_k_iterations_0 <= 2); + } + + warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2], + warp_frag_B0[warp_mma_k % 2], accum0); + } + } + + //2nd Gemm + + /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile + FragmentIteratorA1 warp_tile_iterator_A1_(accum0); + + // + // Prologue + // + + FragmentA1ScaleBias tb_frag_A1_scale; + FragmentA1ScaleBias tb_frag_A1_bias; + FragmentIteratorA1ScaleBias warp_tile_iterator_A1_scale_(tb_frag_A1_scale); + FragmentIteratorA1ScaleBias warp_tile_iterator_A1_bias_(tb_frag_A1_bias); + FragmentB1 tb_frag_B1; + + if(PerChannelScale) + tb_frag_A1_scale.clear(); + tb_frag_A1_bias.clear(); + tb_frag_B1.clear(); + + // The last kblock is loaded in the prolog + if(PerChannelScale) + iterator_A1_scale.load(tb_frag_A1_scale); + iterator_A1_bias.load(tb_frag_A1_bias); + iterator_B1.load(tb_frag_B1); + + if(PerChannelScale) + ++iterator_A1_scale; + ++iterator_A1_bias; + ++iterator_B1; + + this->smem_iterator_B1_.store(transform_B1(tb_frag_B1)); + + ++this->smem_iterator_B1_; + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA1ScaleBias warp_frag_A1_scale[2]; + WarpFragmentA1ScaleBias warp_frag_A1_bias[2]; + WarpFragmentA1 warp_frag_A1[2]; + WarpFragmentB1 warp_frag_B1[2]; + + this->warp_tile_iterator_B1_.set_kgroup_index(0); + + if(PerChannelScale) + warp_tile_iterator_A1_scale_.load(warp_frag_A1_scale[0]); + warp_tile_iterator_A1_bias_.load(warp_frag_A1_bias[0]); + warp_tile_iterator_A1_.load(warp_frag_A1[0], warp_frag_A1_scale[0], + warp_frag_A1_bias[0], output_op_0); + this->warp_tile_iterator_B1_.load(warp_frag_B1[0]); + + ++warp_tile_iterator_A1_; + if(PerChannelScale) + ++warp_tile_iterator_A1_scale_; + ++warp_tile_iterator_A1_bias_; + ++this->warp_tile_iterator_B1_; + + Operator1 warp_mma1; + + smem_write_stage_idx = 1; + + int gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1; + + // Avoid reading out of bounds + iterator_B1.clear_mask(gemm_k_iterations_1 <= 1); + + // + // Mainloop + // + + // Note: The main loop does not support Base::WarpGemmIterations == 2. + CUTLASS_PRAGMA_UNROLL + for (; gemm_k_iterations_1 > 0; --gemm_k_iterations_1) { + + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations1 - 1) { + + // Write fragments to shared memory + this->smem_iterator_B1_.store(transform_B1(tb_frag_B1)); + + __syncthreads(); + ++this->smem_iterator_B1_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); + } + else { + this->warp_tile_iterator_B1_.add_tile_offset( + {-Base::kStages * Policy1::kPartitionsK * + Base::kWarpGemmIterations1, + 0}); + } + + smem_write_stage_idx ^= 1; + + if(PerChannelScale) { + tb_frag_A1_scale.clear(); + iterator_A1_scale.load(tb_frag_A1_scale); + ++iterator_A1_scale; + } + tb_frag_A1_bias.clear(); + iterator_A1_bias.load(tb_frag_A1_bias); + ++iterator_A1_bias; + } + + this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); + + if(PerChannelScale) + warp_tile_iterator_A1_scale_.load(warp_frag_A1_scale[(warp_mma_k + 1) % 2]); + warp_tile_iterator_A1_bias_.load(warp_frag_A1_bias[(warp_mma_k + 1) % 2]); + warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2], + warp_frag_A1_scale[(warp_mma_k + 1) % 2], + warp_frag_A1_bias[(warp_mma_k + 1) % 2], + output_op_0); + this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]); + + if(PerChannelScale) + ++warp_tile_iterator_A1_scale_; + ++warp_tile_iterator_A1_bias_; + ++warp_tile_iterator_A1_; + ++this->warp_tile_iterator_B1_; + + if (warp_mma_k == 0) { + + iterator_B1.load(tb_frag_B1); + ++iterator_B1; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_B1.clear_mask(gemm_k_iterations_1 <= 2); + } + + warp_mma1(accum, warp_frag_A1[warp_mma_k % 2], + warp_frag_B1[warp_mma_k % 2], accum); + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h new file mode 100644 index 0000000000000000000000000000000000000000..b3754945b3c14ef563603850945914357d7a00b8 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h @@ -0,0 +1,552 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped Back-to-back fused GEMM kernel. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" + +#include "threadblock/b2b_mma_base_smem_accumulator.h" +#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape0_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorA0_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA0_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB0_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB0_, + /// Iterates over vectors of scale and bias vector in global memory + // (concept: VectorIterator) + typename IteratorAccumulatorScaleBias_, + /// Iterates over accumulator tile + typename FragmentIteratorAccumulator_, + /// Iterates over accumulator tile in shared memory + typename SmemIteratorD0_, + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape1_, + /// Iterates over the intermediate accumulator tile in shared memory + typename WarpIteratorA1_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB1_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB1_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...) + typename OutputOp_, + /// Policy describing tuning details (concept: MmaPipelinedPolicy) + typename Policy0_, + /// Policy describing tuning details (concept: MmaPipelinedPolicy) + typename Policy1_, + /// Transformation applied to A0 operand + typename TransformA0_ = NumericArrayConverter< + typename SmemIteratorA0_::Element, + typename IteratorA0_::Element, + IteratorA0_::Fragment::kElements>, + /// + /// Transformation applied to B0 operand + typename TransformB0_ = NumericArrayConverter< + typename SmemIteratorB0_::Element, + typename IteratorB0_::Element, + IteratorB0_::Fragment::kElements>, + /// + /// Transformation applied to B1 operand + typename TransformB1_ = NumericArrayConverter< + typename SmemIteratorB1_::Element, + typename IteratorB1_::Element, + IteratorB1_::Fragment::kElements>, + /// Used for partial specialization + typename Enable = bool +> +class B2bMmaPipelinedSmemAccumulator : + public B2bMmaBaseSmemAccumulator { +public: + + ///< Base class + using Base = B2bMmaBaseSmemAccumulator; + + using Shape0 = Shape0_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA0 = IteratorA0_; ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA0; + using IteratorB0 = IteratorB0_; ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB0; + using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; ///< Iterates over tiles of the scale and bias vectors in global memory + using Policy0 = Policy0_; ///< Policy0 describing tuning details + + using SmemIteratorA0 = SmemIteratorA0_; + using SmemIteratorB0 = SmemIteratorB0_; + using SmemIteratorD0 = SmemIteratorD0_; ///< Iterates over accumulator tile in shared memory + + using FragmentIteratorAccumulator = FragmentIteratorAccumulator_; ///< Iterates over accumulator tile + + using Shape1 = Shape1_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory + using Policy1 = Policy1_; ///< Policy1 describing tuning details + using Policy = Policy1; ///< Export Policy1 as the threadblock-level Mma's policy + using Shape = Shape1; + + using SmemIteratorB1 = SmemIteratorB1_; + using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate accumulator tile in shared memory + + + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + + using OutputOp = OutputOp_; ///< Epilogue after 1st Gemm + + using TransformA0 = TransformA0_; + using TransformB0 = TransformB0_; + using TransformB1 = TransformB1_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA0 = typename IteratorA0::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB0 = typename IteratorB0::Fragment; + + /// Fragment of accumulator tile + using FragmentC0 = typename Policy0::Operator::FragmentC; + + /// Warp-level Mma + using Operator0 = typename Policy0::Operator; + + /// Fragment of operand B loaded from global memory + using FragmentB1 = typename IteratorB1::Fragment; + + /// Fragment of accumulator tile + using FragmentC1 = typename Policy1::Operator::FragmentC; + + /// Warp-level Mma + using Operator1 = typename Policy1::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy0::Operator::ArchTag; + + /// Complex transform on A0 operand + static ComplexTransform const kTransformA0 = Operator0::kTransformA; + + /// Complex transform on B0 operand + static ComplexTransform const kTransformB0 = Operator0::kTransformB; + + /// Complex transform on B1 operand + static ComplexTransform const kTransformB1 = Operator1::kTransformB; + + /// Complex transform exports needed by higher-level kernels + static ComplexTransform const kTransformA = kTransformA0; + static ComplexTransform const kTransformB = kTransformB0; + + /// staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) + static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); + + /// Epilog in shared memory + using Epilogue0 = epilogue::threadblock::EpilogueSmemAccumulator< + SmemIteratorD0, ///< SmemTileIterator + FragmentIteratorAccumulator, ///< AccumulatorFragmentIterator + IteratorAccumulatorScaleBias, ///< ScaleBiasIterator + OutputOp>; ///< Output operator + + + +private: + + using WarpFragmentA0 = typename Operator0::FragmentA; + using WarpFragmentB0 = typename Operator0::FragmentB; + using WarpFragmentA1 = typename Operator1::FragmentA; + using WarpFragmentB1 = typename Operator1::FragmentB; + +protected: + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA0 smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B0 operand to shared memory + SmemIteratorB0 smem_iterator_B0_; + + /// Shared Memory Iterator to store accumulator tile + SmemIteratorD0 smem_iterator_D0_; + + /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile + WarpIteratorA1 warp_tile_iterator_A1_; + + /// Iterator to write threadblock-scoped tile of B1 operand to shared memory + SmemIteratorB1 smem_iterator_B1_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + B2bMmaPipelinedSmemAccumulator( + typename Base::B2bMmaSharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx, ///< ID of each thread within a warp + int problem_size_0_n ///< GEMM0 N is used for accumulator extent + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_A_ref(), thread_idx), + smem_iterator_B0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_B_ref(), thread_idx), + smem_iterator_D0_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx), + warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), {Base::WarpGemm1::kM, problem_size_0_n}, lane_idx), + smem_iterator_B1_(shared_storage.b2b_mma_shared_storage.shared_storage1.operand_B_ref(), thread_idx) { + + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn_0 = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN); + int warp_idx_k_0 = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN); + + int warp_idx_m_0 = warp_idx_mn_0 % Base::WarpCount0::kM; + int warp_idx_n_0 = warp_idx_mn_0 / Base::WarpCount0::kM; + + int tile_offset_k_0 = Base::kWarpGemmIterations0 * warp_idx_k_0; + + int warp_idx_mn_1 = warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); + + int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; + int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; + + int tile_offset_k_1 = Base::kWarpGemmIterations1 * warp_idx_k_1; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A0_.add_tile_offset({warp_idx_m_0, tile_offset_k_0}); + this->warp_tile_iterator_B0_.add_tile_offset({tile_offset_k_0, warp_idx_n_0}); + warp_tile_iterator_A1_.add_tile_offset({warp_idx_m_1, tile_offset_k_1}); + this->warp_tile_iterator_B1_.add_tile_offset({tile_offset_k_1, warp_idx_n_1}); + + // Add smem accumulator iterator warp offset + smem_iterator_D0_.add_tile_offset({ warp_idx_m_0 * SmemIteratorD0::TileIterations::kRow, + warp_idx_n_0 * SmemIteratorD0::TileIterations::kColumn}); + + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations_0, ///< number of iterations of the mainloop + FragmentC1 &accum, ///< destination accumulator tile + IteratorA0 iterator_A, ///< iterator over A operand in global memory + IteratorB0 iterator_B0, ///< iterator over B0 operand in global memory + IteratorAccumulatorScaleBias iterator_accum0_scale, ///< iterator over D0 scale vector in global memory + IteratorAccumulatorScaleBias iterator_accum0_bias, ///< iterator over D0 bias vector in global memory + IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory + FragmentC0 const &src_accum, ///< source accumulator tile + OutputOp output_op_0, ///< epilogue operation after 1st Gemm + TransformA0 transform_A0 = TransformA0(), ///< transformation applied to A0 fragment + TransformB0 transform_B0 = TransformB0(), ///< transformation applied to B0 fragment + TransformB1 transform_B1 = TransformB1()) { ///< transformation applied to B1 fragment + + // + // Prologue + // + + // Perform accumulation in the 'd' output operand + FragmentC0 accum0 = src_accum; + + FragmentA0 tb_frag_A; + FragmentB0 tb_frag_B0; + + tb_frag_A.clear(); + tb_frag_B0.clear(); + + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B0.load(tb_frag_B0); + + ++iterator_A; + ++iterator_B0; + + this->smem_iterator_A_.store(transform_A0(tb_frag_A)); + this->smem_iterator_B0_.store(transform_B0(tb_frag_B0)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B0_; + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA0 warp_frag_A0[2]; + WarpFragmentB0 warp_frag_B0[2]; + + this->warp_tile_iterator_A0_.set_kgroup_index(0); + this->warp_tile_iterator_B0_.set_kgroup_index(0); + + this->warp_tile_iterator_A0_.load(warp_frag_A0[0]); + this->warp_tile_iterator_B0_.load(warp_frag_B0[0]); + + ++this->warp_tile_iterator_A0_; + ++this->warp_tile_iterator_B0_; + + Operator0 warp_mma0; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_A.clear_mask(gemm_k_iterations_0 <= 1); + iterator_B0.clear_mask(gemm_k_iterations_0 <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing + // shared memory loads (which have the tightest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations0 - 1) { + + // Write fragments to shared memory + this->smem_iterator_A_.store(transform_A0(tb_frag_A)); + + this->smem_iterator_B0_.store(transform_B0(tb_frag_B0)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B0_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); + } + else { + this->warp_tile_iterator_A0_.add_tile_offset( + {0, -Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0}); + this->warp_tile_iterator_B0_.add_tile_offset( + {-Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0, + 0}); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); + this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); + + this->warp_tile_iterator_A0_.load(warp_frag_A0[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B0_.load(warp_frag_B0[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A0_; + ++this->warp_tile_iterator_B0_; + + if (warp_mma_k == 0) { + + iterator_A.load(tb_frag_A); + iterator_B0.load(tb_frag_B0); + ++iterator_A; + ++iterator_B0; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_A.clear_mask(gemm_k_iterations_0 <= 2); + iterator_B0.clear_mask(gemm_k_iterations_0 <= 2); + } + + warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2], + warp_frag_B0[warp_mma_k % 2], accum0); + } + } + + /// Epilogue for the first Implicit Gemm + Epilogue0 epilogue0; + + epilogue0(output_op_0, smem_iterator_D0_, accum0, iterator_accum0_scale, iterator_accum0_bias); + + __syncthreads(); + + //2nd Gemm + + // + // Prologue + // + + FragmentB1 tb_frag_B1; + + tb_frag_B1.clear(); + + // The last kblock is loaded in the prolog + iterator_B1.load(tb_frag_B1); + + ++iterator_B1; + + this->smem_iterator_B1_.store(transform_B1(tb_frag_B1)); + + ++this->smem_iterator_B1_; + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA1 warp_frag_A1[2]; + WarpFragmentB1 warp_frag_B1[2]; + + this->warp_tile_iterator_B1_.set_kgroup_index(0); + + warp_tile_iterator_A1_.load(warp_frag_A1[0]); + this->warp_tile_iterator_B1_.load(warp_frag_B1[0]); + + ++warp_tile_iterator_A1_; + ++this->warp_tile_iterator_B1_; + + Operator1 warp_mma1; + + smem_write_stage_idx = 1; + + int gemm_k_iterations_1 = Shape0::kN / Shape1::kK; + + // Avoid reading out of bounds + iterator_B1.clear_mask(gemm_k_iterations_1 <= 1); + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_PRAGMA_UNROLL + for (; gemm_k_iterations_1 > 0; --gemm_k_iterations_1) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations1 - 1) { + + // Write fragments to shared memory + this->smem_iterator_B1_.store(transform_B1(tb_frag_B1)); + + __syncthreads(); + + ++this->smem_iterator_B1_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); + } + else { + this->warp_tile_iterator_B1_.add_tile_offset( + {-Base::kStages * Policy1::kPartitionsK * + Base::kWarpGemmIterations1, + 0}); + } + + smem_write_stage_idx ^= 1; + + } + + this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); + + // skip warp tile loading for the last kgroup + if(gemm_k_iterations_1 > 1 || warp_mma_k < Base::kWarpGemmIterations1 - 1) + warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]); + + ++warp_tile_iterator_A1_; + ++this->warp_tile_iterator_B1_; + + if (warp_mma_k == 0) { + + iterator_B1.load(tb_frag_B1); + + ++iterator_B1; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_B1.clear_mask(gemm_k_iterations_1 <= 2); + } + + warp_mma1(accum, warp_frag_A1[warp_mma_k % 2], + warp_frag_B1[warp_mma_k % 2], accum); + + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma.h new file mode 100644 index 0000000000000000000000000000000000000000..cbbc24a8f34ad9cbe64d39205a918b2bd3bfd040 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma.h @@ -0,0 +1,584 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" + +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" +#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h" +#include "cutlass/transform/threadblock/vector_iterator.h" +#include "cutlass/transform/warp/vector_fragment_iterator.h" + +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" + +#include "threadblock/b2b_mma_pipelined.h" +#include "threadblock/b2b_mma_multistage.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor = false, + /// Staging the accumulators in shared memory. + bool SmemAccumulator = false> +struct DefaultB2bMma; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output with 2-stage pipeline +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Epilogue output operator + typename EpilogueOutputOp> +struct DefaultB2bMma { + // Define the MmaCore components + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, layout::RowMajor, + arch::OpClassTensorOp, 2, Operator>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, layout::RowMajor, + arch::OpClassTensorOp, 2, Operator>; + + // Define iterators over tiles from the A operand + using IteratorA0 = + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, typename MmaCore0::IteratorThreadMapA, kAlignmentA>; + + // Define iterators over tiles from the B operand + using IteratorB0 = + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, typename MmaCore0::IteratorThreadMapB, kAlignmentB>; + + // Use fragment iterator for A operand + using AccumulatorLayout = cutlass::layout::ColumnMajor; + using FragmentIteratorA1 = + cutlass::gemm::warp::MmaTensorOpFragmentIterator< + cutlass::MatrixShape, //warp shape + cutlass::MatrixShape, //accumulator shape + MmaCore1::Shape::kK, //kBlocksColumn + ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp>; + + using ElementScaleBias = typename EpilogueOutputOp::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 2; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Warp-level iterators to load scale and bias vectors + using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< + MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, + LayoutScaleBias, InstructionShape, kElementsPerAccess>; + + // Define iterators over tiles from the B operand + using IteratorB1 = + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB, kAlignmentB>; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelined< + typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA, + IteratorB0, typename MmaCore0::SmemIteratorB, + typename MmaCore1::Shape, FragmentIteratorA1, + IteratorAccumulatorScaleBias, FragmentIteratorA1ScaleBias, + IteratorB1, typename MmaCore1::SmemIteratorB, + ElementAccumulator, layout::RowMajor, + EpilogueOutputOp, + typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy>; + +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output for multi-stage +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Number of stages used in the multistage mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + /// Epilogue output operator + typename EpilogueOutputOp> +struct DefaultB2bMma { + + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + + // Define the MmaCore components + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, Operator, false, CacheOpA, CacheOpB>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, Operator, false, CacheOpA, CacheOpB>; + + // Define iterators over tiles from the A operand + using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; + using AccessTypeA0 = cutlass::Array; + using IteratorA0 = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA0, AccessTypeA0>; + + // Define iterators over tiles from the B operand + using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; + using AccessTypeB0 = cutlass::Array; + using IteratorB0 = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB0, AccessTypeB0>; + + // Use fragment iterator for A operand + using AccumulatorLayout = cutlass::layout::ColumnMajor; + using FragmentIteratorA1 = + cutlass::gemm::warp::MmaTensorOpFragmentIterator< + cutlass::MatrixShape, //warp shape + cutlass::MatrixShape, //accumulator shape + MmaCore1::Shape::kK, //kBlocksColumn + ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp>; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 2; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Warp-level iterators to load scale and bias vectors + using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< + MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, + LayoutScaleBias, InstructionShape, kElementsPerAccess>; + + + // Define iterators over tiles from the B operand + using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; + using AccessTypeB1 = cutlass::Array; + using IteratorB1 = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB1, AccessTypeB1>; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaMultistage< + typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA, + MmaCore0::kCacheOpA, + IteratorB0, typename MmaCore0::SmemIteratorB, MmaCore0::kCacheOpB, + typename MmaCore1::Shape, FragmentIteratorA1, + IteratorAccumulatorScaleBias, FragmentIteratorA1ScaleBias, + IteratorB1, typename MmaCore1::SmemIteratorB, MmaCore1::kCacheOpB, + ElementAccumulator, layout::RowMajor, + EpilogueOutputOp, + typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy, Stages>; + +}; + + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for column-major-interleaved output with 2-stage pipeline +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename OperatorClass, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Number of Interleaved K + int InterleavedK> +struct DefaultB2bMma, OperatorClass, arch::Sm75, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + InstructionShape, 2, Operator, EpilogueOutputOp, true> { + // Define the MmaCore components + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, + layout::ColumnMajorInterleaved, OperatorClass, 2, Operator, + true>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, + layout::ColumnMajorInterleaved, OperatorClass, 2, Operator, + true>; + + static_assert(kAlignmentA == 128 / sizeof_bits::value, + "Alignment must match thread data map's vector length"); + + static_assert(kAlignmentB ==128 / sizeof_bits::value, + "Alignment must match thread data map's vector length"); + + // Define iterators over tiles from the A operand + using IteratorA0 = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, ElementA, + LayoutA, 1, typename MmaCore0::IteratorThreadMapA>; + + // Define iterators over tiles from the B operand + using IteratorB0 = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, ElementB, + LayoutB, 0, typename MmaCore0::IteratorThreadMapB>; + + // Use fragment iterator for A1 operand + using AccumulatorLayout = cutlass::layout::RowMajor; //AccumulatorsInRowMajor = true + using FragmentIteratorA1 = + cutlass::gemm::warp::MmaTensorOpFragmentIterator< + cutlass::MatrixShape, //warp shape + cutlass::MatrixShape, //accumulator shape + MmaCore1::Shape::kK, //kBlocksColumn + ElementAccumulator, ElementA, AccumulatorLayout, + InstructionShape, EpilogueOutputOp>; + + using ElementScaleBias = typename EpilogueOutputOp::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 4; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Warp-level iterators to load scale and bias vectors + using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< + MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, + LayoutScaleBias, InstructionShape, kElementsPerAccess>; + + // Define iterators over tiles from the B operand + using IteratorB1 = + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB>; + + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelined< + typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA, + IteratorB0, typename MmaCore0::SmemIteratorB, + typename MmaCore1::Shape, FragmentIteratorA1, + IteratorAccumulatorScaleBias, FragmentIteratorA1ScaleBias, + IteratorB1, typename MmaCore1::SmemIteratorB, + ElementAccumulator, layout::ColumnMajorInterleaved, + EpilogueOutputOp, + typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for column-major-interleaved output with multi-stage +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Number of stages used in the multistage mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Number of Interleaved K + int InterleavedK> +struct DefaultB2bMma, OperatorClass, ArchTag, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + InstructionShape, Stages, Operator, EpilogueOutputOp, true> { + // Define the MmaCore components + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, + layout::ColumnMajorInterleaved, OperatorClass, Stages, + Operator, true>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, + layout::ColumnMajorInterleaved, OperatorClass, Stages, + Operator, true>; + + // Define iterators over tiles from the A operand + using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA0 = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA0, AccessTypeA>; + + // Define iterators over tiles from the B operand + using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB0 = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB0, AccessTypeB>; + + // Use fragment iterator for A1 operand + using AccumulatorLayout = cutlass::layout::RowMajor; //AccumulatorsInRowMajor = true + using FragmentIteratorA1 = + cutlass::gemm::warp::MmaTensorOpFragmentIterator< + cutlass::MatrixShape, //warp shape + cutlass::MatrixShape, //accumulator shape + MmaCore1::Shape::kK, //kBlocksColumn + ElementAccumulator, ElementA, AccumulatorLayout, + InstructionShape, EpilogueOutputOp>; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 4; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Warp-level iterators to load scale and bias vectors + using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< + MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, + LayoutScaleBias, InstructionShape, kElementsPerAccess>; + + // Define iterators over tiles from the B operand + using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; + using IteratorB1 = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB1, AccessTypeB>; + + + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaMultistage< + typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA, + MmaCore0::kCacheOpA, + IteratorB0, typename MmaCore0::SmemIteratorB, MmaCore0::kCacheOpB, + typename MmaCore1::Shape, FragmentIteratorA1, + IteratorAccumulatorScaleBias, FragmentIteratorA1ScaleBias, + IteratorB1, typename MmaCore1::SmemIteratorB, MmaCore1::kCacheOpB, + ElementAccumulator, layout::ColumnMajorInterleaved, + EpilogueOutputOp, + typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy, Stages>; +}; + +//////////////////////////////////////////////////////////////////////////////// + + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma_smem_accumulator.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma_smem_accumulator.h new file mode 100644 index 0000000000000000000000000000000000000000..a848f5c43995682d3e21d70a5cdb12b3855ae4f9 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma_smem_accumulator.h @@ -0,0 +1,605 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" + +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h" + +#include "threadblock/b2b_mma_pipelined_smem_accumulator.h" +#include "threadblock/b2b_mma_multistage_smem_accumulator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output with 2-stage pipeline +/// Accumulator will be staged in shared memory. +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Epilogue output operator + typename EpilogueOutputOp> +struct DefaultB2bMma { + // Define the MmaCore components + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, layout::RowMajor, + arch::OpClassTensorOp, 2, Operator>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, layout::RowMajor, + arch::OpClassTensorOp, 2, Operator>; + + // Define iterators over tiles from the A operand + using IteratorA0 = + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, typename MmaCore0::IteratorThreadMapA, kAlignmentA>; + + // Define iterators over tiles from the B operand + using IteratorB0 = + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, typename MmaCore0::IteratorThreadMapB, kAlignmentB>; + + // Define iterators over tiles from the B operand + using IteratorB1 = + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB, kAlignmentB>; + + // Warp-level GEMM components + using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + + // Use fragment iterator for the accumulator + using SmemAccumulatorLayout = cutlass::layout::RowMajor; + using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< + WarpShape0, InstructionShape, + ElementAccumulator, + typename WarpMmaTensorOp0::Policy::Operator::FragmentC, + SmemAccumulatorLayout + >; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 2; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Store Accumulator tiles to Shared Memory + using SmemIteratorD0 = + cutlass::epilogue::warp::TileIteratorTensorOp< + WarpShape0, + InstructionShape, + typename EpilogueOutputOp::ElementOutput, + SmemAccumulatorLayout + >; + + static int const kThreadCount = 32; + // load warp tile from Shared Memory accumulator + using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator< + MatrixShape, cutlass::gemm::Operand::kA, + ElementA, SmemAccumulatorLayout, + MatrixShape, + WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount, true>; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelinedSmemAccumulator< + typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA, + IteratorB0, typename MmaCore0::SmemIteratorB, + IteratorAccumulatorScaleBias, + FragmentIteratorAccumulator, SmemIteratorD0, + typename MmaCore1::Shape, WarpIteratorA1, + IteratorB1, typename MmaCore1::SmemIteratorB, + ElementAccumulator, layout::RowMajor, + EpilogueOutputOp, + typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy>; + +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output for multi-stage +/// Accumulator will be staged in shared memory. +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Number of stages used in the multistage mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + /// Epilogue output operator + typename EpilogueOutputOp> +struct DefaultB2bMma { + + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + + // Define the MmaCore components + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, Operator, false, CacheOpA, CacheOpB>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, Operator, false, CacheOpA, CacheOpB>; + + // Define iterators over tiles from the A operand + using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; + using AccessTypeA0 = cutlass::Array; + using IteratorA0 = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA0, AccessTypeA0>; + + // Define iterators over tiles from the B operand + using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; + using AccessTypeB0 = cutlass::Array; + using IteratorB0 = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB0, AccessTypeB0>; + + // Define iterators over tiles from the B operand + using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; + using AccessTypeB1 = cutlass::Array; + using IteratorB1 = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB1, AccessTypeB1>; + + // Warp-level GEMM components + using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + + // Use fragment iterator for the accumulator + using SmemAccumulatorLayout = cutlass::layout::RowMajor; + using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< + WarpShape0, InstructionShape, + ElementAccumulator, + typename WarpMmaTensorOp0::Policy::Operator::FragmentC, + SmemAccumulatorLayout + >; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 2; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + + // Store Accumulator tiles to Shared Memory + using SmemIteratorD0 = + cutlass::epilogue::warp::TileIteratorTensorOp< + WarpShape0, + InstructionShape, + typename EpilogueOutputOp::ElementOutput, + SmemAccumulatorLayout + >; + + static int const kThreadCount = 32; + // load warp tile from Shared Memory accumulator + using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator< + MatrixShape, cutlass::gemm::Operand::kA, + ElementA, SmemAccumulatorLayout, + MatrixShape, + WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount, true>; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaMultistageSmemAccumulator< + typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA, + MmaCore0::kCacheOpA, + IteratorB0, typename MmaCore0::SmemIteratorB, MmaCore0::kCacheOpB, + IteratorAccumulatorScaleBias, + FragmentIteratorAccumulator, SmemIteratorD0, + typename MmaCore1::Shape, WarpIteratorA1, + IteratorB1, typename MmaCore1::SmemIteratorB, MmaCore1::kCacheOpB, + ElementAccumulator, layout::RowMajor, + EpilogueOutputOp, + typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy, Stages>; + +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for column-major-interleaved output with 2-stage pipeline +/// Accumulator will be staged in shared memory. +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename OperatorClass, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Number of Interleaved K + int InterleavedK> +struct DefaultB2bMma, OperatorClass, arch::Sm75, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + InstructionShape, 2, Operator, EpilogueOutputOp, true, true> { + // Define the MmaCore components + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, + layout::ColumnMajorInterleaved, OperatorClass, 2, Operator, + true>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, + layout::ColumnMajorInterleaved, OperatorClass, 2, Operator, + true>; + + static_assert(kAlignmentA == 128 / sizeof_bits::value, + "Alignment must match thread data map's vector length"); + + static_assert(kAlignmentB ==128 / sizeof_bits::value, + "Alignment must match thread data map's vector length"); + + // Define iterators over tiles from the A operand + using IteratorA0 = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, ElementA, + LayoutA, 1, typename MmaCore0::IteratorThreadMapA>; + + // Define iterators over tiles from the B operand + using IteratorB0 = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, ElementB, + LayoutB, 0, typename MmaCore0::IteratorThreadMapB>; + + // Define iterators over tiles from the B operand + using IteratorB1 = + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB>; + + // Warp-level GEMM components + using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + + // Use fragment iterator for the accumulator + using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>; + using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< + WarpShape0, InstructionShape, + ElementAccumulator, + typename WarpMmaTensorOp0::Policy::Operator::FragmentC, + SmemAccumulatorLayout + >; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 4; //For interleaved layout + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Store Accumulator tiles to Shared Memory + using SmemIteratorD0 = + cutlass::epilogue::warp::TileIteratorTensorOp< + WarpShape0, + InstructionShape, + typename EpilogueOutputOp::ElementOutput, + SmemAccumulatorLayout + >; + + static int const kThreadCount = 32; + // load warp tile from Shared Memory accumulator + using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator< + MatrixShape, cutlass::gemm::Operand::kA, + ElementA, SmemAccumulatorLayout, + MatrixShape, + WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount, true>; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelinedSmemAccumulator< + typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA, + IteratorB0, typename MmaCore0::SmemIteratorB, + IteratorAccumulatorScaleBias, + FragmentIteratorAccumulator, SmemIteratorD0, + typename MmaCore1::Shape, WarpIteratorA1, + IteratorB1, typename MmaCore1::SmemIteratorB, + ElementAccumulator, layout::ColumnMajorInterleaved, + EpilogueOutputOp, + typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy>; + +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for column-major-interleaved output with multi-stage +/// Accumulator will be staged in shared memory. +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Number of stages used in the multistage mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Number of Interleaved K + int InterleavedK> +struct DefaultB2bMma, OperatorClass, ArchTag, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + InstructionShape, Stages, Operator, EpilogueOutputOp, true, true> { + // Define the MmaCore components + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, + layout::ColumnMajorInterleaved, OperatorClass, Stages, + Operator, true>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, + layout::ColumnMajorInterleaved, OperatorClass, Stages, + Operator, true>; + + // Define iterators over tiles from the A operand + using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA0 = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA0, AccessTypeA>; + + // Define iterators over tiles from the B operand + using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB0 = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB0, AccessTypeB>; + + // Define iterators over tiles from the B operand + using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; + using IteratorB1 = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB1, AccessTypeB>; + + // Warp-level GEMM components + using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + using MmaPolicy0 = typename MmaCore0::MmaPolicy; + using MmaPolicy1 = typename MmaCore1::MmaPolicy; + + // Use fragment iterator for the accumulator + using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>; + using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< + WarpShape0, InstructionShape, + ElementAccumulator, + typename WarpMmaTensorOp0::Policy::Operator::FragmentC, + SmemAccumulatorLayout + >; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 4; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Store Accumulator tiles to Shared Memory + using SmemIteratorD0 = + cutlass::epilogue::warp::TileIteratorTensorOp< + WarpShape0, + InstructionShape, + typename EpilogueOutputOp::ElementOutput, + SmemAccumulatorLayout + >; + + static int const kThreadCount = 32; + // load warp tile from Shared Memory accumulator + using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator< + MatrixShape, cutlass::gemm::Operand::kA, + ElementA, SmemAccumulatorLayout, + MatrixShape, + WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount, true >; + + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaMultistageSmemAccumulator< + typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA, + MmaCore0::kCacheOpA, + IteratorB0, typename MmaCore0::SmemIteratorB, MmaCore0::kCacheOpB, + IteratorAccumulatorScaleBias, + FragmentIteratorAccumulator, SmemIteratorD0, + typename MmaCore1::Shape, WarpIteratorA1, + IteratorB1, typename MmaCore1::SmemIteratorB, MmaCore1::kCacheOpB, + ElementAccumulator, layout::ColumnMajorInterleaved, + EpilogueOutputOp, + typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy, Stages>; +}; + +//////////////////////////////////////////////////////////////////////////////// + + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/grouped_threadblock_swizzle.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/grouped_threadblock_swizzle.h new file mode 100644 index 0000000000000000000000000000000000000000..e033138057e6ddc5b316822e18a8f473c2f59913 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/grouped_threadblock_swizzle.h @@ -0,0 +1,125 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implements several threadblock-swizzling functions for grouped kernels +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/kernel/grouped_problem_visitor.h" +#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" +#include "kernel/b2b_gemm_grouped_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +struct GroupedThreadblockSwizzleBase {}; + +/// Helper for determining if a swizzling function is specialized for grouped operation +template +struct IsGroupedSwizzle { + static bool const value = cutlass::platform::is_base_of::value; +}; + +} // namespace detail + +/// Swizzling function for grouped kernels +template +struct GroupedThreadblockSwizzle : detail::GroupedThreadblockSwizzleBase { + + using ProblemVisitor = ProblemVisitor_; + ProblemVisitor problem_visitor; + + CUTLASS_HOST_DEVICE + GroupedThreadblockSwizzle(typename ProblemVisitor::Params& params, + typename ProblemVisitor::SharedStorage& shared_storage, + int block_idx) : problem_visitor(params, shared_storage, block_idx) {} + + /// Obtains the threadblock offset (in units of threadblock-scoped tiles) + CUTLASS_DEVICE + GemmCoord get_tile_offset(int /*log_tile*/) const { + GemmCoord problem_size = problem_visitor.problem_size(); + int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + + return GemmCoord(int(threadblock_idx / grid_shape.n()), + int(threadblock_idx % grid_shape.n()), + 0); + } + + /// Dummy method to satisfy API for threadblock swizzling functions + CUTLASS_HOST_DEVICE + static int get_log_tile(GemmCoord /*tiled_shape*/) { + return 0; + } +}; + +template < + typename ThreadblockShape, + typename LayoutC, + cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode_ = cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, + int PrefetchTileCount = 128, + int ThreadCount = PrefetchTileCount> +struct B2bGemmGroupedThreadblockSwizzle : GroupedThreadblockSwizzle< + cutlass::gemm::kernel::B2bGemmGroupedProblemVisitor< + ThreadblockShape, + GroupScheduleMode_, + PrefetchTileCount, + ThreadCount, + platform::is_same::value + > + > { + using Base = GroupedThreadblockSwizzle::value>>; + + CUTLASS_HOST_DEVICE + B2bGemmGroupedThreadblockSwizzle(typename Base::ProblemVisitor::Params& params, + typename Base::ProblemVisitor::SharedStorage& shared_storage, + int block_idx) : Base(params, shared_storage, block_idx) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h new file mode 100644 index 0000000000000000000000000000000000000000..fc8f96c1fd8b0b485d48e2873afb52a0f8c2516b --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h @@ -0,0 +1,536 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief GEMM kernel to support the epilogue visitor model + for customized softmax partial reduction epilogue fusion. + + This source file will likely be moved to `include/cutlass/gemm/kernel/` in the future once + its usage has been stabilized. For now, it is included in this example to demonstrate + some basic output fusion options. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/complex.h" +#include "cutlass/semaphore.h" + +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_ ///! Threadblock swizzling function +> +struct GemmWithEpilogueVisitor { +public: + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueVisitor = typename Epilogue::Visitor; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using TensorRefA = TensorRef; + + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using TensorRefB = TensorRef; + + using ElementC = typename EpilogueVisitor::ElementOutput; + using LayoutC = typename Epilogue::Layout; + using TensorRefC = TensorRef; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformB; + using Operator = typename Mma::Operator; + + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + using ElementNorm = typename EpilogueVisitor::ElementNorm; + using ElementSum = typename EpilogueVisitor::ElementSum; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Split-K preserves splits that are 128b aligned + static int const kSplitKAlignment = const_max( + 128 / sizeof_bits::value, + 128 / sizeof_bits::value + ); + + // + // Structures + // + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmUniversalMode mode; + GemmCoord problem_size; + int batch_count; + + TensorRefA ref_A; + TensorRefB ref_B; + TensorRefC ref_C; + TensorRefC ref_D; + + ElementNorm *ptr_Max; + ElementSum *ptr_Sum; + + int64_t batch_stride_A; + int64_t batch_stride_B; + + typename EpilogueVisitor::Arguments epilogue_visitor; + + // + // Methods + // + + Arguments(): + mode(GemmUniversalMode::kGemm), + batch_count(1) + { } + + + /// constructs an arguments structure + Arguments( + GemmUniversalMode mode_, + GemmCoord problem_size_, + int batch_count_, + TensorRefA ref_A_, + TensorRefB ref_B_, + TensorRefC ref_C_, + TensorRefC ref_D_, + ElementNorm *ptr_Max_, + ElementSum *ptr_Sum_, + int64_t batch_stride_A_, + int64_t batch_stride_B_, + typename EpilogueVisitor::Arguments epilogue_visitor_ + ): + mode(mode_), + problem_size(problem_size_), + batch_count(batch_count_), + ref_A(ref_A_), + ref_B(ref_B_), + ref_C(ref_C_), + ref_D(ref_D_), + ptr_Max(ptr_Max_), + ptr_Sum(ptr_Sum_), + batch_stride_A(batch_stride_A_), + batch_stride_B(batch_stride_B_), + epilogue_visitor(epilogue_visitor_) + { + + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorB::Params params_B; + typename EpilogueVisitor::OutputTileIterator::Params params_C; + typename EpilogueVisitor::OutputTileIterator::Params params_D; + + GemmUniversalMode mode; + int batch_count; + int gemm_k_size; + + void * ptr_A; + void * ptr_B; + ElementC * ptr_C; + ElementC * ptr_D; + + ElementNorm * ptr_Max; + ElementSum * ptr_Sum; + + int64_t batch_stride_A; + int64_t batch_stride_B; + + typename EpilogueVisitor::Params epilogue_visitor; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + swizzle_log_tile(0), + params_A(0), + params_B(0), + params_C(0), + params_D(0), + batch_count(0), + gemm_k_size(0), + mode(cutlass::gemm::GemmUniversalMode::kGemm), + ptr_A(nullptr), + ptr_B(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + ptr_Max(nullptr), + ptr_Sum(nullptr), + batch_stride_A(0), + batch_stride_B(0) + { } + + + Params( + Arguments const &args + ): + problem_size(args.problem_size), + swizzle_log_tile(0), + params_A(args.ref_A.layout()), + params_B(args.ref_B.layout()), + params_C(args.ref_C.layout()), + params_D(args.ref_D.layout()), + mode(args.mode), + batch_count(args.batch_count), + gemm_k_size(args.problem_size.k()), + ptr_A(args.ref_A.data()), + ptr_B(args.ref_B.data()), + ptr_C(args.ref_C.data()), + ptr_D(args.ref_D.data()), + ptr_Max(args.ptr_Max), + ptr_Sum(args.ptr_Sum), + batch_stride_A(args.batch_stride_A), + batch_stride_B(args.batch_stride_B), + epilogue_visitor(args.epilogue_visitor) + { + + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.batch_count); + + if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) { + + int const kAlignK = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); + + gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); + + if (gemm_k_size) { + grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); + } + } + + swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); + } + }; + + /// Shared memory storage structure + union SharedStorage { + + typename Mma::SharedStorage main_loop; + + struct { + typename Epilogue::SharedStorage epilogue; + typename EpilogueVisitor::SharedStorage visitor; + } epilogue; + }; + +public: + + // + // Methods + // + + CUTLASS_DEVICE + GemmWithEpilogueVisitor() { } + + /// Determines whether kernel satisfies alignment + static Status can_implement( + cutlass::gemm::GemmCoord const & problem_size) { + + CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()"); + + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + bool isAMisaligned = false; + bool isBMisaligned = false; + bool isCMisaligned = false; + + if (platform::is_same::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } else if (platform::is_same::value) { + isAMisaligned = problem_size.m() % kAlignmentA; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } + + if (platform::is_same::value) { + isBMisaligned = problem_size.n() % kAlignmentB; + } else if (platform::is_same::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } + + if (platform::is_same::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } else if (platform::is_same::value) { + isCMisaligned = problem_size.m() % kAlignmentC; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } + + if (isAMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); + return Status::kErrorMisalignedOperand; + } + + if (isBMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); + return Status::kErrorMisalignedOperand; + } + + if (isCMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); + return Status::kErrorMisalignedOperand; + } + + CUTLASS_TRACE_HOST(" returning kSuccess"); + + return Status::kSuccess; + } + + static Status can_implement(Arguments const &args) { + return can_implement(args.problem_size); + } + + #define SPLIT_K_ENABLED 1 + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + + return; + } + + int offset_k = 0; + int problem_size_k = params.problem_size.k(); + + ElementA *ptr_A = static_cast(params.ptr_A); + ElementB *ptr_B = static_cast(params.ptr_B); + + + #if SPLIT_K_ENABLED + // + // Fetch pointers based on mode. + // + if (params.mode == GemmUniversalMode::kGemm || + params.mode == GemmUniversalMode::kGemmSplitKParallel) { + + if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { + + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } + + offset_k = threadblock_tile_offset.k() * params.gemm_k_size; + } + else if (params.mode == GemmUniversalMode::kBatched) { + ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; + ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; + } + else if (params.mode == GemmUniversalMode::kArray) { + ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; + ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; + } + #endif + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + offset_k, + }; + + cutlass::MatrixCoord tb_offset_B{ + offset_k, + threadblock_tile_offset.n() * Mma::Shape::kN + }; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, + ptr_A, + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B( + params.params_B, + ptr_B, + {problem_size_k, params.problem_size.n()}, + thread_idx, + tb_offset_B); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma( + gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, + accumulators); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + //assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN + ); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // + // Construct the epilogue visitor + // + + EpilogueVisitor epilogue_visitor( + params.epilogue_visitor, + shared_storage.epilogue.visitor, + params.problem_size.mn(), + thread_idx, + warp_idx, + lane_idx, + params.params_C, + params.params_D, + params.ptr_C, + params.ptr_D, + params.ptr_Max, + params.ptr_Sum, + threadblock_offset, + blockIdx.y *params.problem_size.m() ); + + if (params.mode == GemmUniversalMode::kGemm) { + // Indicate which position in a serial reduction the output operator is currently updating + epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) { + epilogue_visitor.set_batch_index(threadblock_tile_offset.k()); + } + + // Construct the epilogue + Epilogue epilogue( + shared_storage.epilogue.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Execute the epilogue operator to update the destination tensor. + epilogue(epilogue_visitor, accumulators); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/35_gemm_softmax/gemm_with_softmax.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/35_gemm_softmax/gemm_with_softmax.h new file mode 100644 index 0000000000000000000000000000000000000000..31b2b76940cd06b6ec54a647acc7943064cefe4b --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/35_gemm_softmax/gemm_with_softmax.h @@ -0,0 +1,663 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + +*/ + +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory.h" +#include "cutlass/arch/memory_sm75.h" + +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/kernel/default_gemm_complex.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h" +#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h" +#include "cutlass/reduction/kernel/reduce_softmax_final.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "gemm_with_epilogue_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Kernel computes partial reduction +// +// +// 2. Sum[m, n'] = sum_n(exp(D[m, n] - N[m, 0])) +// +template < + typename ElementD_, + typename ElementNorm_, + typename ElementSum_, + typename ElementSoft_, + typename ElementSoftmaxCompute_, + int Alignment, + typename ApplyShape_ = MatrixShape<1, 1024> +> +class ApplySoftmax { +public: + + using ElementD = ElementD_; + using ElementNorm = ElementNorm_; + using ElementSum = ElementSum_; + using ElementSoft = ElementSoft_; + using ElementSoftmaxCompute = ElementSoftmaxCompute_; + + static int const kAlignment = Alignment; + using ApplyShape = ApplyShape_; + + using Layout = cutlass::layout::RowMajor; + + using TensorRefD = TensorRef; + using TensorRefN = TensorRef; + using TensorRefSum = TensorRef; + using TensorRefSoft = TensorRef; + + using FragmentSoftmax = Array; + + // + // Arguments + // + + struct Arguments { + + MatrixCoord extent; ///< Extent of D and Softmax matrices + int batch_count; ///< Batch count + TensorRefD ref_D; ///< D matrix computed by GEMM+Max (input) + TensorRefN ref_N; ///< Norm tensor (input) + TensorRefSum ref_S; ///< Sum tensor (input) + TensorRefSoft ref_Soft; ///< Softmax tensor (output) + int64_t batch_stride_D; ///< Batch stride for D tensor + int64_t batch_stride_N; ///< Batch stride for N tensor + int64_t batch_stride_S; ///< Batch stride for S tensor + int64_t batch_stride_Soft; ///< Batch stride for softmax tensor + + // + // Methods + // + Arguments(): + batch_count(1), + batch_stride_D(0), + batch_stride_N(0), + batch_stride_S(0), + batch_stride_Soft(0) + { } + + Arguments( + MatrixCoord extent_, ///< Extent of D and Softmax matrices + int batch_count_, ///< Batch count + TensorRefD ref_D_, ///< D matrix computed by GEMM+PartialReduce + TensorRefN ref_N_, ///< Output parameter for N + TensorRefSum ref_S_, ///< Output parameter for N + TensorRefSoft ref_Soft_, ///< Softmax + int64_t batch_stride_D_ = 0, + int64_t batch_stride_N_ = 0, + int64_t batch_stride_S_ = 0, + int64_t batch_stride_Soft_ = 0 + ): + extent(extent_), + batch_count(batch_count_), + ref_D(ref_D_), + ref_N(ref_N_), + ref_S(ref_S_), + ref_Soft(ref_Soft_), + batch_stride_D(batch_stride_D_), + batch_stride_N(batch_stride_N_), + batch_stride_S(batch_stride_S_), + batch_stride_Soft(batch_stride_Soft_) + { + + } + }; + + // + // Params struct + // + + struct Params { + Arguments args; + + // + // Methods + // + Params() { } + + Params(Arguments const &args_): args(args_) { } + }; + + // + // SharedStorage + // + + struct SharedStorage { + + }; + +private: + +public: + + CUTLASS_DEVICE + ApplySoftmax() { } + + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + apply(params, shared_storage); + } + +private: + + + /// Compute Softmax + CUTLASS_DEVICE + void apply(Params const ¶ms, SharedStorage &shared_storage) { + + using AccessTypeD = AlignedArray; + + int block_batch = blockIdx.z; + int block_m = blockIdx.x * ApplyShape::kRow; + int block_n = 0; + + int thread_m = threadIdx.y; + int thread_n = threadIdx.x * kAlignment; + + int idx_m = block_m + thread_m; + int idx_n = block_n + thread_n; + + int batch_offset_norm = block_batch * params.args.batch_stride_N; + int batch_offset_sum = block_batch * params.args.batch_stride_S; + + // Kill off thread if it is outside the row boundary + if (params.args.extent.row() <= idx_m) { + return; + } + + // + // Setup pointers to load D again + // + + using AccessTypeD = AlignedArray; + using AccessTypeSoft = AlignedArray; + using FragmentSoft = Array; + using ConvertSoftCompute = cutlass::NumericArrayConverter; + using ConvertSoftOutput = cutlass::NumericArrayConverter; + + using Mul = cutlass::multiplies; + using Minus = cutlass::minus; + using Exp = cutlass::fast_exp_op; + + ConvertSoftCompute convert_soft_compute; + ConvertSoftOutput convert_soft_output; + + Minus minus; + Mul mul; + Exp exponential; + + using ConvertSum = cutlass::NumericConverter; + using ConvertNorm = cutlass::NumericConverter; + + ConvertSum convert_sum; + ConvertNorm convert_norm; + + AccessTypeD *access_d = reinterpret_cast( + params.args.ref_D.data() + + params.args.batch_stride_D * block_batch + + params.args.ref_D.layout()({idx_m, idx_n})); + + AccessTypeSoft *access_soft = reinterpret_cast( + params.args.ref_Soft.data() + + params.args.batch_stride_Soft * block_batch + + params.args.ref_Soft.layout()({idx_m, idx_n})); + + ElementSum inv_sum = (params.args.ref_S.data())[idx_m + batch_offset_sum]; + ElementNorm norm = (params.args.ref_N.data())[idx_m + batch_offset_norm]; + + // + // Loop + // + CUTLASS_PRAGMA_UNROLL + for ( + int idx = 0; + idx < params.args.extent.column(); + idx += ApplyShape::kColumn * kAlignment) { + + if (idx_n < params.args.extent.column()) { + AccessTypeD fetch; + arch::global_load(fetch, access_d, true); + + FragmentSoftmax result = mul(exponential(minus(convert_soft_compute(fetch), convert_norm(norm))), convert_sum(inv_sum)); + FragmentSoft soft = convert_soft_output(result); + + arch::global_store(soft, access_soft, true); + } + + access_d += ApplyShape::kColumn; + access_soft += ApplyShape::kColumn; + idx_n += ApplyShape::kColumn * kAlignment; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// +template < + typename ElementA_, + typename LayoutA_, + typename ElementB_, + typename LayoutB_, + typename ElementC_, + typename ElementCompute_, + typename OperatorClass_, + typename ArchTag_, + typename ThreadblockShape_, + typename WarpShape_, + typename InstructionShape_, + typename EpilogueFunctorOp_, + int kStages_, + typename ApplyShape_ = MatrixShape<1, 1024>, + int AlignmentA_ = 128 / cutlass::sizeof_bits::value, + int AlignmentB_ = 128 / cutlass::sizeof_bits::value, + int AlignmentSoftmax_ = 128 / cutlass::sizeof_bits::value, + typename ElementNorm_ = float, + typename ElementSum_ = float, + typename ElementSoftmax_ = ElementC_ +> +class GemmSoftmax { +public: + + /////////////////////////////////////////////////////////////////////////////////////////////// + + // + // Type definitions + // + + using ElementA = ElementA_; + using ElementB = ElementB_; + using ElementC = ElementC_; + using ElementCompute = ElementCompute_; + using ElementSum = ElementSum_; + using ElementSoft = ElementSoftmax_; + using ElementSoftmaxCompute = float; + + using LayoutA = LayoutA_; + using LayoutB = LayoutB_; + + using EpilogueFunctorOp = EpilogueFunctorOp_; + using ElementNorm = ElementNorm_; + + using ApplyShape = ApplyShape_; + + // These are mandatory layouts. + using LayoutC = cutlass::layout::RowMajor; + using LayoutN = cutlass::layout::RowMajor; + using LayoutS = cutlass::layout::RowMajor; + using LayoutSoft = cutlass::layout::RowMajor; + + using TensorRefA = TensorRef; + using TensorRefB = TensorRef; + using TensorRefC = TensorRef; + using TensorRefN = TensorRef; + using TensorRefSum = TensorRef; + using TensorRefSoft = TensorRef; + + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + + static int const kStages = kStages_; + static int const AlignmentA = AlignmentA_; + static int const AlignmentB = AlignmentB_; + static int const AlignmentSoftmax = AlignmentSoftmax_; + + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle; + + /////////////////////////////////////////////////////////////////////////////////////////////// + + // basic GEMM kernel + using DefaultGemmKernel = typename cutlass::gemm::kernel::DefaultGemm< + ElementA, + LayoutA, + AlignmentA, + ElementB, + LayoutB, + AlignmentB, + ElementC, + LayoutC, + ElementCompute, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueFunctorOp, + ThreadblockSwizzle, + kStages, + true, + typename cutlass::gemm::device::DefaultGemmConfiguration< + OperatorClass, ArchTag, ElementA, ElementB, ElementC, ElementCompute>::Operator, + cutlass::gemm::SharedMemoryClearOption::kNone + >::GemmKernel; + + /////////////////////////////////////////////////////////////////////////////////////////////// + + // Epilogue visitor + using EpilogueVisitor = typename cutlass::epilogue::threadblock::EpilogueVisitorSoftmax< + ThreadblockShape, + DefaultGemmKernel::kThreadCount, + typename DefaultGemmKernel::Epilogue::OutputTileIterator, + ElementCompute, + ElementNorm, + ElementSum, + ElementSoftmaxCompute, + EpilogueFunctorOp + >; + + /// Epilogue + using Epilogue = typename cutlass::epilogue::threadblock::EpilogueWithVisitorFromExistingEpilogue< + EpilogueVisitor, + typename DefaultGemmKernel::Epilogue + >::Epilogue; + + // GEMM + using GemmKernel = gemm::kernel::GemmWithEpilogueVisitor< + typename DefaultGemmKernel::Mma, + Epilogue, + ThreadblockSwizzle + >; + + // Softmax kernel + using SoftmaxApplyKernel = kernel::ApplySoftmax< + ElementC, + ElementNorm, + ElementSum, + ElementSoft, + ElementSoftmaxCompute, + AlignmentSoftmax, + ApplyShape + >; + + using ApplyFinalReductionKernel = cutlass::reduction::kernel::ApplySoftmaxFinalReduction< + ElementNorm, + ElementSum, + ElementSoftmaxCompute, + ThreadblockShape + >; + +public: + + /// Arguments class + struct Arguments { + + typename GemmKernel::Arguments gemm; + typename SoftmaxApplyKernel::Arguments softmax; + typename ApplyFinalReductionKernel::Arguments reduction; + cutlass::gemm::GemmCoord extend; + + // + // Methods + // + Arguments() { } + + Arguments( + cutlass::gemm::GemmCoord problem_size, + int32_t batch_count_, + TensorRefA ref_A_, + TensorRefB ref_B_, + TensorRefC ref_C_, + TensorRefC ref_D_, + typename EpilogueFunctorOp::Params linear_scaling, + TensorRefN ref_N_, + TensorRefSum ref_S_, + TensorRefSoft ref_Softmax_, + int64_t batch_stride_A_ = 0, + int64_t batch_stride_B_ = 0, + int64_t batch_stride_C_ = 0, + int64_t batch_stride_D_ = 0, + int64_t batch_stride_Max_ = 0, + int64_t batch_stride_Sum_ = 0, + int64_t batch_stride_Softmax_ = 0 + ): + gemm( + cutlass::gemm::GemmUniversalMode::kBatched, + problem_size, + batch_count_, + ref_A_, + ref_B_, + ref_C_, + ref_D_, + ref_N_.data(), + ref_S_.data(), + batch_stride_A_, + batch_stride_B_, + typename EpilogueVisitor::Arguments( + linear_scaling, + batch_stride_C_, + batch_stride_D_, + batch_stride_Max_, + batch_stride_Sum_ + ) + ), + reduction( + problem_size, + ref_N_.data(), + ref_S_.data(), + batch_stride_Max_, + batch_stride_Sum_ + ), + softmax( + MatrixCoord(problem_size.m(), problem_size.n()), + batch_count_, + ref_D_, + ref_N_, + ref_S_, + ref_Softmax_, + batch_stride_D_, + batch_stride_Max_, + batch_stride_Sum_, + batch_stride_Softmax_ + ), + extend(problem_size) + { + + } + }; + + struct Params { + + typename GemmKernel::Params gemm; + typename SoftmaxApplyKernel::Params softmax; + typename ApplyFinalReductionKernel::Params reduction; + MatrixCoord extend; + // + // Methods + // + Params() { } + + Params(Arguments const &args): + gemm(args.gemm), + reduction(args.reduction), + softmax(args.softmax), + extend(MatrixCoord(args.extend.m(), args.extend.n())) + { + + } + }; + +public: + + // Gemm + + + // + // Methods + // + +private: + + Params params_; + +public: + + /// Ctor + GemmSoftmax() { + + } + + /// Initialize + Status initialize(Arguments const &args) { + + params_ = Params(args); + + return cutlass::Status::kSuccess; + } + + /// Run + Status run(cudaStream_t stream) { + + // + // Launch the GEMM + max kernel + // + + dim3 gemm_grid = ThreadblockSwizzle().get_grid_shape(params_.gemm.grid_tiled_shape); + dim3 gemm_block(GemmKernel::kThreadCount, 1, 1); + + int gemm_smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + cudaError_t result; + + if (gemm_smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute(cutlass::Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + gemm_smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + cutlass::Kernel<<>>(params_.gemm); + + result = cudaGetLastError(); + + if (result != cudaSuccess) { + return cutlass::Status::kErrorInternal; + } + + + // + // Launch the ApplyFinalReductionKernel + // + + int thread_per_block = 128; + int block_per_row = (params_.extend.row() + thread_per_block - 1) / thread_per_block; + if (block_per_row < 4) { + thread_per_block = 32; + block_per_row = (params_.extend.row() + thread_per_block - 1) / thread_per_block; + } + + dim3 final_reduction_grid(block_per_row, 1, params_.softmax.args.batch_count); + dim3 final_reduction_block(thread_per_block); + + Kernel<<< + final_reduction_grid, final_reduction_block, sizeof(typename ApplyFinalReductionKernel::SharedStorage), stream + >>>(params_.reduction); + + result = cudaGetLastError(); + + if (result != cudaSuccess) { + return cutlass::Status::kErrorInternal; + } + + // + // Launch the SoftmaxApplyKernel + // + + dim3 apply_block(SoftmaxApplyKernel::ApplyShape::kColumn, SoftmaxApplyKernel::ApplyShape::kRow); + + int threadblock_rows = SoftmaxApplyKernel::ApplyShape::kRow; + int threadblock_columns = SoftmaxApplyKernel::ApplyShape::kColumn * SoftmaxApplyKernel::kAlignment; + + dim3 apply_grid( + (params_.softmax.args.extent.row() + threadblock_rows - 1) / threadblock_rows, + (params_.softmax.args.extent.column() + threadblock_columns - 1) / threadblock_columns, + params_.softmax.args.batch_count); + + Kernel<<< + apply_grid, apply_block, sizeof(typename SoftmaxApplyKernel::SharedStorage), stream + >>>(params_.softmax); + + result = cudaGetLastError(); + + if (result != cudaSuccess) { + return cutlass::Status::kErrorInternal; + } + + return cutlass::Status::kSuccess; + } + + /// Function call operator + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/37_gemm_layernorm_gemm_fusion/gemm_with_epilogue_visitor.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/37_gemm_layernorm_gemm_fusion/gemm_with_epilogue_visitor.h new file mode 100644 index 0000000000000000000000000000000000000000..5b36be05853718518851b4967f769b72697bccad --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/37_gemm_layernorm_gemm_fusion/gemm_with_epilogue_visitor.h @@ -0,0 +1,444 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief GEMM kernel to support the epilogue visitor model + for customized layernorm partial reduction epilogue fusion. + + This source file will likely be moved to `include/cutlass/gemm/kernel/` in the future once + its usage has been stabilized. For now, it is included in this example to demonstrate + some basic output fusion options. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/complex.h" +#include "cutlass/semaphore.h" + +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_ ///! Threadblock swizzling function +> +struct GemmWithEpilogueVisitor { +public: + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueVisitor = typename Epilogue::Visitor; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using TensorRefA = TensorRef; + + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using TensorRefB = TensorRef; + + using ElementC = typename EpilogueVisitor::ElementOutput; + using LayoutC = typename Epilogue::Layout; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformB; + using Operator = typename Mma::Operator; + + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Split-K preserves splits that are 128b aligned + static int const kSplitKAlignment = const_max( + 128 / sizeof_bits::value, + 128 / sizeof_bits::value + ); + + // + // Structures + // + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmUniversalMode mode; + GemmCoord problem_size; + + TensorRefA ref_A; + TensorRefB ref_B; + + typename EpilogueVisitor::Arguments epilogue_visitor; + + // + // Methods + // + + Arguments(): + mode(GemmUniversalMode::kGemm) + { } + + + /// constructs an arguments structure + Arguments( + GemmUniversalMode mode_, + GemmCoord problem_size_, + TensorRefA ref_A_, + TensorRefB ref_B_, + typename EpilogueVisitor::Arguments epilogue_visitor_ + ): + mode(mode_), + problem_size(problem_size_), + ref_A(ref_A_), + ref_B(ref_B_), + epilogue_visitor(epilogue_visitor_) + { + + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorB::Params params_B; + + GemmUniversalMode mode; + int gemm_k_size; + + void * ptr_A; + void * ptr_B; + + typename EpilogueVisitor::Params epilogue_visitor; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + swizzle_log_tile(0), + params_A(0), + params_B(0), + gemm_k_size(0), + mode(cutlass::gemm::GemmUniversalMode::kGemm), + ptr_A(nullptr), + ptr_B(nullptr) + { } + + + Params( + Arguments const &args + ): + problem_size(args.problem_size), + swizzle_log_tile(0), + params_A(args.ref_A.layout()), + params_B(args.ref_B.layout()), + mode(args.mode), + gemm_k_size(args.problem_size.k()), + ptr_A(args.ref_A.data()), + ptr_B(args.ref_B.data()), + epilogue_visitor(args.epilogue_visitor) + { + + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, 1); + + if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) { + + int const kAlignK = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); + + gemm_k_size = round_up(args.problem_size.k(), kAlignK); + + if (gemm_k_size) { + grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); + } + } + + swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); + } + }; + + /// Shared memory storage structure + union SharedStorage { + + typename Mma::SharedStorage main_loop; + + struct { + typename Epilogue::SharedStorage epilogue; + typename EpilogueVisitor::SharedStorage visitor; + } epilogue; + }; + +public: + + // + // Methods + // + + CUTLASS_DEVICE + GemmWithEpilogueVisitor() { } + + /// Determines whether kernel satisfies alignment + static Status can_implement( + cutlass::gemm::GemmCoord const & problem_size) { + + CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()"); + + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + bool isAMisaligned = false; + bool isBMisaligned = false; + bool isCMisaligned = false; + + if (platform::is_same::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } else if (platform::is_same::value) { + isAMisaligned = problem_size.m() % kAlignmentA; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } + + if (platform::is_same::value) { + isBMisaligned = problem_size.n() % kAlignmentB; + } else if (platform::is_same::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } + + if (platform::is_same::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } else if (platform::is_same::value) { + isCMisaligned = problem_size.m() % kAlignmentC; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } + + if (isAMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); + return Status::kErrorMisalignedOperand; + } + + if (isBMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); + return Status::kErrorMisalignedOperand; + } + + if (isCMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); + return Status::kErrorMisalignedOperand; + } + + CUTLASS_TRACE_HOST(" returning kSuccess"); + + return Status::kSuccess; + } + + static Status can_implement(Arguments const &args) { + return can_implement(args.problem_size); + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + + return; + } + + int offset_k = 0; + int problem_size_k = params.problem_size.k(); + + ElementA *ptr_A = static_cast(params.ptr_A); + ElementB *ptr_B = static_cast(params.ptr_B); + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + offset_k, + }; + + cutlass::MatrixCoord tb_offset_B{ + offset_k, + threadblock_tile_offset.n() * Mma::Shape::kN + }; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, + ptr_A, + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B( + params.params_B, + ptr_B, + {problem_size_k, params.problem_size.n()}, + thread_idx, + tb_offset_B); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma( + gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, + accumulators); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + //assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN + ); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // + // Construct the epilogue visitor + // + + EpilogueVisitor epilogue_visitor( + params.epilogue_visitor, + shared_storage.epilogue.visitor, + params.problem_size.mn(), + thread_idx, + warp_idx, + lane_idx, + threadblock_offset); + + if (params.mode == GemmUniversalMode::kGemm) { + // Indicate which position in a serial reduction the output operator is currently updating + epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) { + epilogue_visitor.set_batch_index(threadblock_tile_offset.k()); + } + + // Construct the epilogue + Epilogue epilogue( + shared_storage.epilogue.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Execute the epilogue operator to update the destination tensor. + epilogue(epilogue_visitor, accumulators); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/37_gemm_layernorm_gemm_fusion/gemm_with_layernorm.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/37_gemm_layernorm_gemm_fusion/gemm_with_layernorm.h new file mode 100644 index 0000000000000000000000000000000000000000..813f75604ba2bc3a3ce145e1e490d6a607de1981 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/37_gemm_layernorm_gemm_fusion/gemm_with_layernorm.h @@ -0,0 +1,1066 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief A file contains all functioning classes needed by GemmLayernorm. + + GemmLayernorm example = GEMM0 with partial reduction fused in epilogue (EpilogueVisitorLayerNorm) + + lightweight full reduction kernel (ApplyFinalReduction) + + GEMM1 with elemenwise operations fused in mainloop (GemmLayernormMainloopFusion) + +*/ + +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory.h" +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/gemm/device/gemm_layernorm_mainloop_fusion.h" +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/kernel/default_gemm_complex.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "gemm_with_epilogue_visitor.h" +#include "helper.h" +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementVariance_, + typename ElementMean_, + typename ElementLayernormCompute_, + typename ElementOutput, + typename ThreadblockShape_, + bool IsShiftedVariance_ = false +> +class ApplyFinalReduction { +public: + + using ElementVariance = ElementVariance_; + using ElementMean = ElementMean_; + using ElementLayernormCompute = ElementLayernormCompute_; + using ThreadblockShape = ThreadblockShape_; + + // Pre-processing has ensured the layout equivalent to RowMajor + using Layout = cutlass::layout::RowMajor; + + using TensorVariance = TensorRef; + using TensorMean = TensorRef; + + static bool const kIsShiftedVariance = IsShiftedVariance_; + + // + // Arguments + // + + struct Arguments { + + MatrixCoord extent; ///< Extent of D and Layernorm matrices + TensorVariance ref_Variance; ///< Sum Square or Variance tensor (input / output) + TensorMean ref_Mean; ///< Sum or Mean tensor (input / output) + ElementOutput *ptr_Shifted_K; ///< Shifted K tensor pointer + + // + // Methods + // + Arguments(){ } + + Arguments( + MatrixCoord extent_, + TensorVariance ref_Variance_, + TensorMean ref_Mean_, + ElementOutput *ptr_Shifted_K_ + ): + extent(extent_), + ref_Variance(ref_Variance_), + ref_Mean(ref_Mean_), + ptr_Shifted_K(ptr_Shifted_K_) + { + + } + }; + + struct SharedStorage { + + + }; + + // + // Params struct + // + + struct Params { + Arguments args; + + // + // Methods + // + Params() { } + + Params(Arguments const &args_): args(args_) { } + }; + +private: + +public: + + CUTLASS_DEVICE + ApplyFinalReduction() { } + + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + apply(params, shared_storage); + } + +private: + + /// Partial reduction + CUTLASS_DEVICE + void apply(Params const ¶ms, SharedStorage &shared_storage) { + + int threadblock_num = (params.args.extent.column() + ThreadblockShape::kM - 1) / ThreadblockShape::kM; + + int block_n = blockIdx.x * blockDim.x; + + int thread_n = threadIdx.x; + + int idx_n = block_n + thread_n; + + if (idx_n >= params.args.extent.row()) { + return; + } + + using ConvertVarianceOutput = cutlass::NumericConverter; + using ConvertMeanOutput = cutlass::NumericConverter; + + using ConvertVariance = cutlass::NumericConverter; + using ConvertMean = cutlass::NumericConverter; + + using ConvertShiftK = cutlass::NumericConverter; + + ConvertVariance convert_variance; + ConvertMean convert_mean; + + ConvertVarianceOutput convert_variance_output; + ConvertMeanOutput convert_mean_output; + + ElementVariance *access_square = params.args.ref_Variance.data() + idx_n; + ElementMean *access_mean = params.args.ref_Mean.data() + idx_n; + + ElementVariance *access_square_bak = access_square; + ElementMean *access_mean_bak = access_mean; + + ElementLayernormCompute frag_square_sum = ElementLayernormCompute(0); + ElementLayernormCompute frag_element_sum = ElementLayernormCompute(0); + ElementVariance fetch_square; + ElementMean fetch_mean; + + CUTLASS_PRAGMA_UNROLL + for (int idx_m = 0; idx_m < threadblock_num; idx_m++) { + arch::global_load(fetch_square, access_square, true); + arch::global_load(fetch_mean, access_mean, true); + frag_element_sum += convert_mean(fetch_mean); + frag_square_sum += convert_variance(fetch_square); + access_square += params.args.extent.row(); + access_mean += params.args.extent.row(); + } + + ElementLayernormCompute mean = frag_element_sum; + ElementLayernormCompute square_mean = frag_square_sum; + + ElementLayernormCompute variance; + + if (kIsShiftedVariance && params.args.ptr_Shifted_K != nullptr) { + ElementOutput *access_shift_k = params.args.ptr_Shifted_K + idx_n; + ElementOutput fetch_shift_k; + ConvertShiftK convert_shift_k; + arch::global_load(fetch_shift_k, access_shift_k, true); + ElementLayernormCompute shifted_mean = mean - convert_shift_k(fetch_shift_k); + variance = cutlass::constants::one() / cutlass::fast_sqrt(square_mean - shifted_mean * shifted_mean + ElementLayernormCompute(1e-6)); + }else{ + variance = cutlass::constants::one() / cutlass::fast_sqrt(square_mean - mean * mean + ElementLayernormCompute(1e-6)); + } + + mean = -mean * variance; + + access_square = access_square_bak; + access_mean = access_mean_bak; + + access_square[0] = convert_variance_output(variance); + access_mean[0] = convert_mean_output(mean); + + } + +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ThreadblockShape_, + int ThreadCount, + typename OutputTileIterator_, + typename AccumulatorTile_, + typename ElementAccumulator_, + typename ElementVariance_, + typename ElementMean_, + typename ElementLayernormCompute_, + typename ElementwiseFunctor_, + bool IsShiftedVariance_ = false +> +class EpilogueVisitorLayerNorm { +public: + + using ElementVariance = ElementVariance_; + using ElementMean = ElementMean_; + using ElementLayernormCompute = ElementLayernormCompute_; + + using AccumulatorTile = AccumulatorTile_; + + using ThreadblockShape = ThreadblockShape_; + static int const kThreadCount = ThreadCount; + + using OutputTileIterator = OutputTileIterator_; + using ElementwiseFunctor = ElementwiseFunctor_; + + static int const kIterations = OutputTileIterator::kIterations; + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + static int const kRowIterations = OutputTileIterator::ThreadMap::Iterations::kRow; + + static int const kThreads = OutputTileIterator::ThreadMap::kThreads; + + static bool const kIsShiftedVariance = IsShiftedVariance_; + + using ElementOutput = typename OutputTileIterator::Element; + + static int const kDeltaRow = OutputTileIterator::ThreadMap::Delta::kRow; + + /// Array type used in Shift-K Layernorm + static int const kRowAccessCount = kIterations * kRowIterations; + + using ConvertedShiftFragment = Array; + + // Conducts manual transpose externally (already supported) for column major + using LayoutOutput = cutlass::layout::RowMajor; + + using ElementAccumulator = ElementAccumulator_; + + using AccumulatorFragment = Array; + using LayernormFragment = Array; + using OutputVector = Array; + using TensorRefD = TensorRef; + + static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::RowArrangement::Detail::kShapeWidth; + static int const kThreadsInColumn = kThreads / kThreadsPerRow; + static int const kHalfThreadsPerRow = (kThreadsPerRow >> 1); + + /// Argument structure + struct Arguments { + + typename ElementwiseFunctor::Params elementwise; + TensorRefD ref_C; + TensorRefD ref_D; + ElementVariance *ptr_Variance; + ElementMean *ptr_Mean; + ElementOutput *ptr_Shifted_K; + + // + // Methods + // + Arguments(): + ptr_Variance(nullptr), + ptr_Mean(nullptr), + ptr_Shifted_K(nullptr) + { + + } + + Arguments( + typename ElementwiseFunctor::Params elementwise_, + TensorRefD ref_C_, + TensorRefD ref_D_, + ElementVariance *ptr_Variance, + ElementMean *ptr_Mean_, + ElementOutput *ptr_Shifted_K_ = nullptr + ): + elementwise(elementwise_), + ref_C(ref_C_), + ref_D(ref_D_), + ptr_Variance(ptr_Variance), + ptr_Mean(ptr_Mean_), + ptr_Shifted_K(ptr_Shifted_K_) + { + + } + }; + + struct Params { + + typename ElementwiseFunctor::Params elementwise; + typename OutputTileIterator::Params params_C; + typename OutputTileIterator::Params params_D; + typename OutputTileIterator::Element *ptr_C; + typename OutputTileIterator::Element *ptr_D; + ElementVariance *ptr_Variance; + ElementMean *ptr_Mean; + ElementOutput *ptr_Shifted_K; + + // + // Methods + // + CUTLASS_HOST_DEVICE + Params(): + ptr_D(nullptr), + ptr_Variance(nullptr), + ptr_Mean(nullptr) + { + + } + + CUTLASS_HOST_DEVICE + Params(Arguments const &args): + elementwise(args.elementwise), + params_C(args.ref_C.layout()), + params_D(args.ref_D.layout()), + ptr_C(args.ref_C.data()), + ptr_D(args.ref_D.data()), + ptr_Variance(args.ptr_Variance), + ptr_Mean(args.ptr_Mean), + ptr_Shifted_K(args.ptr_Shifted_K) + { + + } + }; + + /// Shared storage + struct SharedStorage { + + }; + +private: + + Params const & params_; + SharedStorage & shared_storage_; + MatrixCoord extent_; + ElementwiseFunctor elementwise_; + + OutputTileIterator iterator_C_; + OutputTileIterator iterator_D_; + typename OutputTileIterator::Fragment fragment_C_; + typename OutputTileIterator::Fragment fragment_D_; + + ElementAccumulator alpha_; + ElementAccumulator beta_; + ConvertedShiftFragment shift_k_frag_; + + ElementLayernormCompute accum_sum_square_; + ElementLayernormCompute accum_sum_element_; + + MatrixCoord thread_offset_; + +public: + + CUTLASS_DEVICE + EpilogueVisitorLayerNorm( + Params const ¶ms, ///< Parameters routed to the epilogue + SharedStorage &shared_storage, ///< Shared storage needed by the functors here + MatrixCoord const &problem_size0, ///< Problem size of the output + int thread_idx, ///< Thread index within the threadblock + int warp_idx, ///< Warp index within the threadblock + int lane_idx, ///< Lane index within the warp + MatrixCoord const &threadblock_offset = MatrixCoord(0, 0) + ): + params_(params), + shared_storage_(shared_storage), + extent_(problem_size0), + elementwise_(params.elementwise), + iterator_C_(params.params_C, params.ptr_C, problem_size0, thread_idx, threadblock_offset), + iterator_D_(params.params_D, params.ptr_D, problem_size0, thread_idx, threadblock_offset) + { + alpha_ = (params.elementwise.alpha_ptr ? *params.elementwise.alpha_ptr : params.elementwise.alpha); + beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta); + + if (beta_ == ElementAccumulator()) { + iterator_C_.clear_mask(); + } + } + + /// Helper to indicate split-K behavior + CUTLASS_DEVICE + void set_k_partition( + int split_k_index, ///< Index of this threadblock within split-K partitioned scheme + int split_k_slices) { ///< Total number of split-K slices + + } + + /// Called to set the batch index + CUTLASS_DEVICE + void set_batch_index(int batch_idx) { + + } + + /// Called at the start of the epilogue just before iterating over accumulator slices + CUTLASS_DEVICE + void begin_epilogue() { + + // If shift-K feature is enabled, we load shift-k fragment + // at the very beginning of an epilogue + if (kIsShiftedVariance && params_.ptr_Shifted_K != nullptr) { + shift_k_frag_.clear(); + int thread_offset_row_base = iterator_D_.thread_start_row(); + + CUTLASS_PRAGMA_UNROLL + for (int iter_idx = 0; iter_idx < kIterations; ++iter_idx) { + int step_offset = iter_idx * OutputTileIterator::Shape::kRow; + CUTLASS_PRAGMA_UNROLL + for (int rid = 0; rid < kRowIterations; ++rid) { + int row_step_offset = rid * kDeltaRow; + int row_offset = thread_offset_row_base + step_offset + row_step_offset; + bool is_load = (row_offset < extent_.row()); + shift_k_frag_[iter_idx * kRowIterations + rid] = load_shift_k_(row_offset, is_load); + } + + } + + } + + } + + /// Called at the start of one step before starting accumulator exchange + CUTLASS_DEVICE + void begin_step(int step_idx) { + fragment_D_.clear(); + + if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { + fragment_C_.clear(); + iterator_C_.load(fragment_C_); + ++iterator_C_; + } + } + + /// Called at the start of a row + CUTLASS_DEVICE + void begin_row(int row_idx) { + + } + + /// Called after accumulators have been exchanged for each accumulator vector + CUTLASS_DEVICE + void visit( + int iter_idx, + int row_idx, + int column_idx, + int frag_idx, + AccumulatorFragment const &accum) { + + using Mul = cutlass::multiplies; + using Minus = cutlass::minus; + using Exp = cutlass::fast_exp_op; + + [[maybe_unused]] Minus minus; + [[maybe_unused]] Mul mul; + [[maybe_unused]] Exp exponential; + + LayernormFragment result; + + thread_offset_ = + iterator_D_.thread_start() + + OutputTileIterator::ThreadMap::iteration_offset(frag_idx); + + NumericArrayConverter source_converter; + OutputVector &source_vector = reinterpret_cast(&fragment_C_)[frag_idx]; + + bool column_guard = (thread_offset_.column() < extent_.column()); + + if (elementwise_.kScale == cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { + result = source_converter(elementwise_(accum)); + }else{ + result = source_converter(elementwise_(accum, source_vector)); + } + + + ElementLayernormCompute inv_scalar = cutlass::constants::one() / ElementLayernormCompute(extent_.column()); + + // Fragment is cleared for non-reachable columns so no need to check against column guard + accum_sum_element_ = element_sum_accumulator_(result); + + // Square sum is different. Non-reachable columns should've been computed for shift-k + // Otherwise we will incorrectly have some extra k^2 added into square sum. + if (column_guard) { + accum_sum_square_ = (kIsShiftedVariance) ? \ + square_sum_accumulator_(result, shift_k_frag_[iter_idx * kRowIterations + row_idx]) : \ + square_sum_accumulator_(result); + } + else { + accum_sum_square_ = ElementLayernormCompute(0); + } + + accum_sum_element_ *= inv_scalar; + accum_sum_square_ *= inv_scalar; + + // After performing the in-thread reduction, we then perform cross-thread / in-warp reduction + CUTLASS_PRAGMA_UNROLL + for (int i = kHalfThreadsPerRow; i > 0; i >>= 1) { + accum_sum_element_ += __shfl_xor_sync(0xFFFFFFFF, accum_sum_element_, i); + accum_sum_square_ += __shfl_xor_sync(0xFFFFFFFF, accum_sum_square_, i); + } + + // Convert to the output + NumericArrayConverter output_converter; + OutputVector &output = reinterpret_cast(&fragment_D_)[frag_idx]; + output = output_converter(result); + } + + /// Called at the start of a row + CUTLASS_DEVICE + void end_row(int row_idx) { + + using ConvertVarianceOutput = cutlass::NumericConverter; + using ConvertMeanOutput = cutlass::NumericConverter; + + ConvertVarianceOutput convert_variance_output; + ConvertMeanOutput convert_mean_output; + + bool is_write_thread = (thread_offset_.row() < extent_.row() && (threadIdx.x % kThreadsPerRow) == 0); + int row_offset = thread_offset_.row() + blockIdx.y * extent_.row(); + + ElementVariance *curr_ptr_sum_square = params_.ptr_Variance + row_offset; + ElementMean *curr_ptr_element_sum = params_.ptr_Mean + row_offset; + + arch::global_store( + convert_variance_output(accum_sum_square_), + (void *)curr_ptr_sum_square, + is_write_thread); + + arch::global_store( + convert_mean_output(accum_sum_element_), + (void *)curr_ptr_element_sum, + is_write_thread); + + } + + /// Called after all accumulator elements have been visited + CUTLASS_DEVICE + void end_step(int step_idx) { + + iterator_D_.store(fragment_D_); + ++iterator_D_; + } + + /// Called after all steps have been completed + CUTLASS_DEVICE + void end_epilogue() { + + } + +private: + + CUTLASS_DEVICE + ElementLayernormCompute load_shift_k_(int row_offset, bool is_load) { + using ConvertShiftK = cutlass::NumericConverter; + ConvertShiftK convert_shift_k; + ElementOutput shift_k_val; + + // Computes the address to load shift_k element + ElementOutput *curr_ptr_shift_k = params_.ptr_Shifted_K + row_offset; + // Conditionally loads from global memory + arch::global_load(shift_k_val, (void *)curr_ptr_shift_k, is_load); + // Converts data type to return + ElementLayernormCompute converted_shift_k_val = convert_shift_k(shift_k_val); + + return converted_shift_k_val; + } + + CUTLASS_DEVICE + ElementLayernormCompute square_sum_accumulator_(LayernormFragment const &accum) { + ElementLayernormCompute sum_ = ElementLayernormCompute(0); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < LayernormFragment::kElements; ++i) { + auto accum_ = accum[i]; + sum_ += accum_ * accum_; + } + + return sum_; + } + + CUTLASS_DEVICE + ElementLayernormCompute square_sum_accumulator_(LayernormFragment const &accum, ElementLayernormCompute shift_k_val) { + ElementLayernormCompute sum_ = ElementLayernormCompute(0); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < LayernormFragment::kElements; ++i) { + auto accum_ = accum[i] - shift_k_val; + sum_ += accum_ * accum_; + } + + return sum_; + } + + CUTLASS_DEVICE + ElementLayernormCompute element_sum_accumulator_(LayernormFragment const &accum) { + ElementLayernormCompute sum_ = ElementLayernormCompute(0); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < LayernormFragment::kElements; ++i) { + sum_ += accum[i]; + } + + return sum_; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// +template < + typename ElementInputA0_, + typename LayoutInputA0_, + typename ElementInputB0_, + typename LayoutInputB0_, + typename ElementOutput_, + typename LayoutOutput_, + typename ElementCompute_, + typename EpilogueFunctorOp_, + typename ThreadblockShape_, + typename WarpShape_, + typename InstructionShape_, + int Stages0, + int Stages1, + bool IsShiftedVariance_ = false +> +class GemmLayernorm { +public: + + /////////////////////////////////////////////////////////////////////////////////////////////// + + // + // Type definitions + // + + static bool const kInternalTranspose = cutlass::platform::is_same::value; + static bool const kIsShiftedVariance = IsShiftedVariance_; + + // These is mandatory layout. + using LayoutInputScaleBias = cutlass::layout::RowMajor; + + // These are mandatory data types. + using ElementLayernormCompute = float; + using ElementInputScaleBias = cutlass::half_t; + + // These are mandatory params required by mainloop fusion + using OperatorClass = cutlass::arch::OpClassTensorOp; + using ArchTag = cutlass::arch::Sm80; + + // These are mandatory layouts and data types + // that are inheritated from pre-defined params + + using LayoutSumSqr = LayoutInputScaleBias; + using LayoutSum = LayoutInputScaleBias; + + using ElementMean = ElementInputScaleBias; + using ElementVariance = ElementInputScaleBias; + + /////////////////////////////////////////////////////////////////////////////////////////////// + + using LayoutInputA0 = LayoutInputA0_; + using LayoutInputB0 = LayoutInputB0_; + using LayoutInputA1 = LayoutOutput_; + using LayoutInputB1 = LayoutOutput_; + using LayoutOutputC0 = LayoutOutput_; + using LayoutOutputC1 = LayoutOutput_; + + using ElementInputA0 = ElementInputA0_; + using ElementInputB0 = ElementInputB0_; + using ElementOutputC0 = ElementOutput_; + using ElementCompute = ElementCompute_; + using ElementInputB1 = ElementInputB0_; + + using ElementInputA1 = ElementOutputC0; + using ElementOutputC1 = ElementOutputC0; + + using EpilogueFunctorOp = EpilogueFunctorOp_; + + using TensorRefA = TensorRef; + using TensorRefB = TensorRef; + using TensorRefC = TensorRef; + using TensorVariance = TensorRef; + using TensorMean = TensorRef; + + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + + static int const kStages0 = Stages0; + static int const kStages1 = Stages1; + + using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + + /////////////////////////////////////////////////////////////////////////////////////////////// + + using MapArguments = cutlass::gemm::kernel::detail::MapArguments< + ElementInputA0, + LayoutInputA0, + cutlass::ComplexTransform::kNone, + 128 / cutlass::sizeof_bits::value, + ElementInputB0, + LayoutInputB0, + cutlass::ComplexTransform::kNone, + 128 / cutlass::sizeof_bits::value, + LayoutOutputC0, + kInternalTranspose + >; + + using DefaultGemmKernel = typename cutlass::gemm::kernel::DefaultGemm< + typename MapArguments::ElementA, + typename MapArguments::LayoutA, + MapArguments::kAlignmentA, + typename MapArguments::ElementB, + typename MapArguments::LayoutB, + MapArguments::kAlignmentB, + ElementOutputC0, + typename MapArguments::LayoutC, + ElementCompute, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueFunctorOp, + SwizzleThreadBlock, + kStages0, + true, + typename cutlass::gemm::device::DefaultGemmConfiguration< + OperatorClass, ArchTag, ElementInputA0, ElementInputB0, ElementOutputC0, ElementCompute>::Operator, + cutlass::gemm::SharedMemoryClearOption::kNone + >::GemmKernel; + + /////////////////////////////////////////////////////////////////////////////////////////////// + + // Epilogue visitor + using EpilogueVisitor = kernel::EpilogueVisitorLayerNorm< + ThreadblockShape, + DefaultGemmKernel::kThreadCount, + typename DefaultGemmKernel::Epilogue::OutputTileIterator, + typename DefaultGemmKernel::Epilogue::AccumulatorFragmentIterator::AccumulatorTile, + ElementCompute, + ElementVariance, + ElementMean, + ElementLayernormCompute, + EpilogueFunctorOp, + kIsShiftedVariance + >; + + /// Epilogue + using Epilogue = typename cutlass::epilogue::threadblock::EpilogueWithVisitorFromExistingEpilogue< + EpilogueVisitor, + typename DefaultGemmKernel::Epilogue + >::Epilogue; + + // GEMM + using GemmEpilogueFusion = gemm::kernel::GemmWithEpilogueVisitor< + typename DefaultGemmKernel::Mma, + Epilogue, + SwizzleThreadBlock + >; + + using ApplyFinalReductionKernel = kernel::ApplyFinalReduction< + ElementVariance, + ElementMean, + ElementLayernormCompute, + ElementOutputC0, + ThreadblockShape, + kIsShiftedVariance + >; + +using GemmMainloopFusion = typename cutlass::gemm::device::GemmLayernormMainloopFusion< + ElementInputA1, LayoutInputA1, + ElementInputB1, LayoutInputB1, + ElementInputScaleBias, LayoutInputScaleBias, + ElementOutputC1, LayoutOutputC1, + ElementCompute, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueFunctorOp, + SwizzleThreadBlock, + kStages1 +>; + +public: + + /// Arguments class + struct Arguments { + + typename GemmEpilogueFusion::Arguments gemm0; + typename GemmMainloopFusion::Arguments gemm1; + typename ApplyFinalReductionKernel::Arguments reduction; + cutlass::gemm::GemmCoord extend; + + // + // Methods + // + Arguments() { } + + Arguments( + cutlass::gemm::GemmCoord problem_size0, + cutlass::gemm::GemmCoord problem_size1, + ElementInputA0 * ptr_A, + ElementInputB0 * ptr_B, + ElementOutputC0 * ptr_C, + ElementOutputC0 * ptr_D, + ElementOutputC0 * ptr_E, + ElementOutputC0 * ptr_O, + int64_t ldm_A, + int64_t ldm_B, + int64_t ldm_C, + int64_t ldm_D, + int64_t ldm_E, + int64_t ldm_O, + typename EpilogueFunctorOp::Params linear_scaling, + TensorVariance ref_Variance_, + TensorMean ref_Mean_, + TensorVariance ref_Gamma_, + TensorMean ref_Beta_, + ElementOutputC0 *ptr_Shifted_K = nullptr + ): + gemm0( + cutlass::gemm::GemmUniversalMode::kGemm, + {kInternalTranspose ? problem_size0.n() : problem_size0.m(),\ + kInternalTranspose ? problem_size0.m() : problem_size0.n(),\ + problem_size0.k()}, + {kInternalTranspose ? ptr_B : ptr_A, \ + kInternalTranspose ? ldm_B : ldm_A}, + {kInternalTranspose ? ptr_A : ptr_B, \ + kInternalTranspose ? ldm_A : ldm_B}, + typename EpilogueVisitor::Arguments( + linear_scaling, + {ptr_C, ldm_C}, + {ptr_D, ldm_D}, + ref_Variance_.data(), + ref_Mean_.data(), + ptr_Shifted_K + ) + ), + reduction( + MatrixCoord(kInternalTranspose ? problem_size0.n() : problem_size0.m(),\ + kInternalTranspose ? problem_size0.m() : problem_size0.n()), + ref_Variance_, + ref_Mean_, + ptr_Shifted_K + ), + gemm1( + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size1, + 1, + linear_scaling, + kInternalTranspose ? ptr_E : ptr_D, + kInternalTranspose ? ptr_D : ptr_E, + ref_Variance_.data(), + ref_Mean_.data(), + ref_Gamma_.data(), + ref_Beta_.data(), + ptr_O, + ptr_O, + problem_size1.m() * problem_size1.k(), + problem_size1.n() * problem_size1.k(), + problem_size1.n(), + problem_size1.n(), + problem_size1.k(), + problem_size1.k(), + problem_size1.m() * problem_size1.n(), + problem_size1.m() * problem_size1.n(), + kInternalTranspose ? ldm_E : ldm_D, + kInternalTranspose ? ldm_D : ldm_D, + ref_Variance_.layout().stride(0), + ref_Mean_.layout().stride(0), + ref_Gamma_.layout().stride(0), + ref_Beta_.layout().stride(0), + ldm_O, + ldm_O + ), + extend(problem_size0) + { + + } + }; + + struct Params { + + typename GemmEpilogueFusion::Params gemm0; + typename ApplyFinalReductionKernel::Params reduction; + MatrixCoord extend; + // + // Methods + // + Params() { } + + Params(Arguments const &args): + gemm0(args.gemm0), + reduction(args.reduction), + extend(MatrixCoord(args.extend.m(), args.extend.n())) + { + + } + }; + +public: + + // Gemm + + + // + // Methods + // + +private: + + Params params_; + GemmMainloopFusion gemm_fusion_op; + +public: + + /// Ctor + GemmLayernorm() { + + } + + /// Initialize + Status initialize(Arguments const &args) { + + params_ = Params(args); + cutlass::Status status; + size_t workspace_size = gemm_fusion_op.get_workspace_size(args.gemm1); + cutlass::device_memory::allocation workspace(workspace_size); + status = gemm_fusion_op.can_implement(args.gemm1); + CUTLASS_CHECK(status); + + status = gemm_fusion_op.initialize(args.gemm1, workspace.get()); + CUTLASS_CHECK(status); + + return cutlass::Status::kSuccess; + } + + /// Run + Status run(cudaStream_t stream) { + + // + // Launch the GEMM + layernorm kernel + // + + dim3 gemm_grid = SwizzleThreadBlock().get_grid_shape(params_.gemm0.grid_tiled_shape); + dim3 gemm_block(GemmEpilogueFusion::kThreadCount, 1, 1); + + int gemm_smem_size = int(sizeof(typename GemmEpilogueFusion::SharedStorage)); + + cutlass::Kernel<<>>(params_.gemm0); + + cudaError_t result = cudaGetLastError(); + + if (result != cudaSuccess) { + return cutlass::Status::kErrorInternal; + } + + // + // Launch the ApplyFinalReductionKernel + // + + // always performs reduction from leading dimension + int leading_dim_0 = kInternalTranspose ? params_.extend.row() : params_.extend.column(); + int leading_dim_1 = kInternalTranspose ? params_.extend.column() : params_.extend.row(); + + int thread_per_block = 128; + int block_per_row = (leading_dim_1 + thread_per_block - 1) / thread_per_block; + if (block_per_row < 4) { + thread_per_block = 32; + block_per_row = (leading_dim_1 + thread_per_block - 1) / thread_per_block; + } + + dim3 final_reduction_block(thread_per_block); + dim3 final_reduction_grid(block_per_row); + + Kernel<<< + final_reduction_grid, final_reduction_block, sizeof(typename ApplyFinalReductionKernel::SharedStorage), stream + >>>(params_.reduction); + + result = cudaGetLastError(); + + if (result != cudaSuccess) { + return cutlass::Status::kErrorInternal; + } + + // + // Launch the GEMM + mainloop fusion kernel + // + + cutlass::Status status = gemm_fusion_op(); + CUTLASS_CHECK(status); + + return cutlass::Status::kSuccess; + } + + /// Function call operator + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/39_gemm_permute/layouts.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/39_gemm_permute/layouts.h new file mode 100644 index 0000000000000000000000000000000000000000..d4d9ed316667496d03fca7cbe460f413cba50290 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/39_gemm_permute/layouts.h @@ -0,0 +1,506 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines additional layout functions used in Permute GEMM example to simplify + computing reference permutations of 4/5D tensors when source data is column-major. +*/ +#pragma once +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/coord.h" +#include "cutlass/tensor_coord.h" + +namespace cutlass { +namespace layout { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Mapping function for 4-D CWHN tensors. +class TensorCWHN { +public: + /// Logical rank of tensor + static int const kRank = 4; + + /// Rank of stride vector + static int const kStrideRank = 3; + + /// Index type used for coordinates + using Index = int32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate (n, h, w, c) + using TensorCoord = Tensor4DCoord; + + /// Stride vector + using Stride = Coord; + +private: + // + // Data members + // + + /// Stride data member - [n, hn, whn] + Stride stride_; + +public: + // + // Methods + // + + /// Constructor + CUTLASS_HOST_DEVICE + TensorCWHN(Stride const &stride = Stride(0)): stride_(stride) { } + + /// Constructor + CUTLASS_HOST_DEVICE + TensorCWHN( + typename Stride::Index stride_h, ///< number of elements between adjacent N coordinates + typename Stride::Index stride_w, ///< number of elements between adjacent C coordinates + typename Stride::Index stride_c ///< number of elements between adjacent W coordinates + ): + stride_(make_Coord(stride_h, stride_w, stride_c)) { } + + /// Constructor + // Once convolutions implement 64b stride this ctor can be deleted + CUTLASS_HOST_DEVICE + TensorCWHN(Coord const &stride): + stride_(make_Coord( + static_cast(stride[0]), + static_cast(stride[1]), + static_cast(stride[2])) + ) { } + + /// Helper returns a layout to a tightly packed WCNH tensor. + CUTLASS_HOST_DEVICE + static TensorCWHN packed(TensorCoord const &extent) { + return TensorCWHN( + make_Coord( + extent.n(), + extent.h() * extent.n(), + extent.w() * extent.h() * extent.n() + ) + ); + } + + /// Returns the offset of a coordinate (n, h, w, c) in linear memory. + CUTLASS_HOST_DEVICE + LongIndex operator()(TensorCoord const &coord) const { + return coord.n() + + LongIndex(stride_[0] * coord.h()) + + LongIndex(stride_[1] * coord.w()) + + LongIndex(stride_[2] * coord.c()); + } + + /// Returns the offset of a pitchlinear coordinate in linear memory. + CUTLASS_HOST_DEVICE + LongIndex operator()(PitchLinearCoord coord) const { + return coord.contiguous() + LongIndex(coord.strided() * stride_[2]); + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride stride() const { + return stride_; + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride & stride() { + return stride_; + } + + /// Compute the number of contiguous elements needed to store a tensor with the given size + CUTLASS_HOST_DEVICE + LongIndex capacity(TensorCoord const &extent) const { + // it does not make sense if the extent is larger than stride + // and we could not rely on the capacity calculation in such cases + // we could move this checkers to debug code only + if ((extent.n() > stride_[0]) + || (extent.h() * stride_[0] > stride_[1]) + || (extent.w() * stride_[1] > stride_[2])) { + assert(0); + } + return extent.c() * stride_[2]; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Mapping function for 4-D NHCW tensors. +class TensorNHCW { +public: + /// Logical rank of tensor + static int const kRank = 4; + + /// Rank of stride vector + static int const kStrideRank = 3; + + /// Index type used for coordinates + using Index = int32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate (n, h, w, c) + using TensorCoord = Tensor4DCoord; + + /// Stride vector + using Stride = Coord; + +private: + // + // Data members + // + + /// Stride data member - [w, cw, hcw] + Stride stride_; + +public: + // + // Methods + // + + /// Constructor + CUTLASS_HOST_DEVICE + TensorNHCW(Stride const &stride = Stride(0)): stride_(stride) { } + + /// Constructor + CUTLASS_HOST_DEVICE + TensorNHCW( + typename Stride::Index stride_c, ///< number of elements between adjacent C coordinates + typename Stride::Index stride_h, ///< number of elements between adjacent H coordinates + typename Stride::Index stride_n ///< number of elements between adjacent N coordinates + ): + stride_(make_Coord(stride_c, stride_h, stride_n)) { } + + /// Constructor + // Once convolutions implement 64b stride this ctor can be deleted + CUTLASS_HOST_DEVICE + TensorNHCW(Coord const &stride): + stride_(make_Coord( + static_cast(stride[0]), + static_cast(stride[1]), + static_cast(stride[2])) + ) { } + + /// Helper returns a layout to a tightly packed WCNH tensor. + CUTLASS_HOST_DEVICE + static TensorNHCW packed(TensorCoord const &extent) { + return TensorNHCW( + make_Coord( + extent.w(), + extent.c() * extent.w(), + extent.h() * extent.c() * extent.w() + ) + ); + } + + /// Returns the offset of a coordinate (n, h, w, c) in linear memory. + CUTLASS_HOST_DEVICE + LongIndex operator()(TensorCoord const &coord) const { + return coord.w() + + LongIndex(stride_[0] * coord.c()) + + LongIndex(stride_[1] * coord.h()) + + LongIndex(stride_[2] * coord.n()); + } + + /// Returns the offset of a pitchlinear coordinate in linear memory. + CUTLASS_HOST_DEVICE + LongIndex operator()(PitchLinearCoord coord) const { + return coord.contiguous() + LongIndex(coord.strided() * stride_[2]); + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride stride() const { + return stride_; + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride & stride() { + return stride_; + } + + /// Compute the number of contiguous elements needed to store a tensor with the given size + CUTLASS_HOST_DEVICE + LongIndex capacity(TensorCoord const &extent) const { + // it does not make sense if the extent is larger than stride + // and we could not rely on the capacity calculation in such cases + // we could move this checkers to debug code only + if ((extent.w() > stride_[0]) + || (extent.c() * stride_[0] > stride_[1]) + || (extent.h() * stride_[1] > stride_[2])) { + assert(0); + } + return extent.n() * stride_[2]; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Mapping function for 4-D NHCW tensors. +class TensorNCWH { +public: + /// Logical rank of tensor + static int const kRank = 4; + + /// Rank of stride vector + static int const kStrideRank = 3; + + /// Index type used for coordinates + using Index = int32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate (n, h, w, c) + using TensorCoord = Tensor4DCoord; + + /// Stride vector + using Stride = Coord; + +private: + // + // Data members + // + + /// Stride data member - [h, wh, cwh] + Stride stride_; + +public: + // + // Methods + // + + /// Constructor + CUTLASS_HOST_DEVICE + TensorNCWH(Stride const &stride = Stride(0)): stride_(stride) { } + + /// Constructor + CUTLASS_HOST_DEVICE + TensorNCWH( + typename Stride::Index stride_w, ///< number of elements between adjacent C coordinates + typename Stride::Index stride_c, ///< number of elements between adjacent H coordinates + typename Stride::Index stride_n ///< number of elements between adjacent N coordinates + ): + stride_(make_Coord(stride_w, stride_c, stride_n)) { } + + /// Constructor + // Once convolutions implement 64b stride this ctor can be deleted + CUTLASS_HOST_DEVICE + TensorNCWH(Coord const &stride): + stride_(make_Coord( + static_cast(stride[0]), + static_cast(stride[1]), + static_cast(stride[2])) + ) { } + + /// Helper returns a layout to a tightly packed WCNH tensor. + CUTLASS_HOST_DEVICE + static TensorNCWH packed(TensorCoord const &extent) { + return TensorNCWH( + make_Coord( + extent.h(), + extent.w() * extent.h(), + extent.c() * extent.w() * extent.h() + ) + ); + } + + /// Returns the offset of a coordinate (n, h, w, c) in linear memory. + CUTLASS_HOST_DEVICE + LongIndex operator()(TensorCoord const &coord) const { + return coord.h() + + LongIndex(stride_[0] * coord.w()) + + LongIndex(stride_[1] * coord.c()) + + LongIndex(stride_[2] * coord.n()); + } + + /// Returns the offset of a pitchlinear coordinate in linear memory. + CUTLASS_HOST_DEVICE + LongIndex operator()(PitchLinearCoord coord) const { + return coord.contiguous() + LongIndex(coord.strided() * stride_[2]); + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride stride() const { + return stride_; + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride & stride() { + return stride_; + } + + /// Compute the number of contiguous elements needed to store a tensor with the given size + CUTLASS_HOST_DEVICE + LongIndex capacity(TensorCoord const &extent) const { + // it does not make sense if the extent is larger than stride + // and we could not rely on the capacity calculation in such cases + // we could move this checkers to debug code only + if ((extent.h() > stride_[0]) + || (extent.w() * stride_[0] > stride_[1]) + || (extent.c() * stride_[1] > stride_[2])) { + assert(0); + } + return extent.n() * stride_[2]; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Mapping function for 5-D CWHDN tensors. +class TensorCWHDN { +public: + /// Logical rank of tensor + static int const kRank = 5; + + /// Rank of stride vector + static int const kStrideRank = 4; + + /// Index type used for coordinates + using Index = int32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate (n, d, h, w, c) + using TensorCoord = Tensor5DCoord; + + /// Stride vector + using Stride = Coord; + +private: + // + // Data members + // + + /// Stride data member - [n, dn, hdn, whdn] + Stride stride_; + +public: + // + // Methods + // + + /// Constructor + CUTLASS_HOST_DEVICE + TensorCWHDN(Stride const &stride = Stride(0)): stride_(stride) { } + + /// Constructor + CUTLASS_HOST_DEVICE + TensorCWHDN( + typename Stride::Index n, + typename Stride::Index dn, + typename Stride::Index hdn, + typename Stride::Index whdn): + stride_(make_Coord(n, dn, hdn, whdn)) { } + + /// Constructor + // Once convolutions implement 64b stride this ctor can be deleted + CUTLASS_HOST_DEVICE + TensorCWHDN(Coord const &stride): + stride_(make_Coord( + static_cast(stride[0]), + static_cast(stride[1]), + static_cast(stride[2]), + static_cast(stride[3])) + ) { } + + /// Helper returns a layout to a tightly packed CWHDN tensor. + CUTLASS_HOST_DEVICE + static TensorCWHDN packed(TensorCoord const &extent) { + return TensorCWHDN( + make_Coord( + extent.n(), + extent.d() * extent.n(), + extent.h() * extent.d() * extent.n(), + extent.w() * extent.h() * extent.d() * extent.n() + ) + ); + } + + /// Returns the offset of a coordinate (n, d, h, w, c) in linear memory. + CUTLASS_HOST_DEVICE + LongIndex operator()(TensorCoord const &coord) const { + return coord.n() + + LongIndex(stride_[0] * coord.d()) + + LongIndex(stride_[1] * coord.h()) + + LongIndex(stride_[2] * coord.w()) + + LongIndex(stride_[3] * coord.c()); + } + + /// Returns the offset of a pitchlinear coordinate in linear memory. + CUTLASS_HOST_DEVICE + LongIndex operator()(PitchLinearCoord coord) const { + return coord.contiguous() + LongIndex(coord.strided() * stride_[3]); + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride stride() const { + return stride_; + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride & stride() { + return stride_; + } + + /// Compute the number of contiguous elements needed to store a tensor with the given size + CUTLASS_HOST_DEVICE + LongIndex capacity(TensorCoord const &extent) const { + // it does not make sense if the extent is larger than stride + // and we could not rely on the capacity calculation in such cases + // we could move this checkers to debug code only + if ((extent.n() > stride_[0]) + || (extent.d() * stride_[0] > stride_[1]) + || (extent.h() * stride_[1] > stride_[2]) + || (extent.w() * stride_[2] > stride_[3])) { + assert(0); + } + return extent.c() * stride_[3]; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace layout +} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/39_gemm_permute/permute_info.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/39_gemm_permute/permute_info.h new file mode 100644 index 0000000000000000000000000000000000000000..6baf635a8d5420ec736ea01eceba672b9df2e954 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/39_gemm_permute/permute_info.h @@ -0,0 +1,344 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Contains additional metadata about layout permute functions used in the example. +*/ + +#include "cutlass/tensor_coord.h" +#include "cutlass/layout/permute.h" + +/// Additional permutation metadata to facilitate testing/printing +template +struct PermuteInfo; + +/// Specialization for default case (no permute). Other specializations must follow this template. +template<> +struct PermuteInfo { + + /// Whether this is a BMM or GEMM permutation (NoPermute can actually be either) + static bool constexpr kBatched = false; + + /// Minimal divisor for row extent + static int constexpr kRowFactor = 1; + + /// Minimum divisor for column extent + static int constexpr kColumnFactor = 1; + + /// Minimum divisor for batch size dimension + static int constexpr kBatchFactor = 1; + + /// Tensor layout used in permutation operation + using Layout = cutlass::layout::PackedVectorLayout; + + static std::string name() { + return "NoPermute"; + } + + /// User-friendly description of the permute operation + static std::string desc() { + return "no permutation"; + } + + /// Infer original higher-rank tensor shape from GEMM/BMM matrix extents. + /// For direct (output) permutations, must be a simple reshape of extent. + /// For inverse (input) permutations, must return shape *before* permute operation. + /// In case of NoPermute, simply use a linear (rank 1) view of the memory + static Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) { + return Layout::TensorCoord(extent.row() * extent.column() * batch_count); + } + + /// Compute the permuted higher-rank tensor shape from the original shape. + static Layout::TensorCoord permute(Layout::TensorCoord const &s) { + return s; + } +}; + +template +struct PermuteInfo> { + + static bool constexpr kBatched = true; + static int constexpr kRowFactor = 1; + static int constexpr kColumnFactor = 1; + static int constexpr kBatchFactor = D1; + + using Layout = cutlass::layout::TensorNHWC; + + static std::string name() { + return "Tensor4DPermuteBMM0213<" + std::to_string(D1) + ">"; + } + + static std::string desc() { + return "batched GEMM permutation [0, 2, 1, 3]"; + } + + static Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) { + int D0 = batch_count / D1; + int D2 = extent.row(); + int D3 = extent.column(); + return {D0, D1, D2, D3}; + } + + static Layout::TensorCoord permute(Layout::TensorCoord const &s) { + return {s[0], s[2], s[1], s[3]}; + } +}; + +template +struct PermuteInfo> +: public PermuteInfo> { + + static bool constexpr kBatched = true; + static int constexpr kRowFactor = 1; + static int constexpr kColumnFactor = D1; + static int constexpr kBatchFactor = 1; + + using Base = PermuteInfo>; + using Layout = typename Base::Layout; + + static typename Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) { + int D0 = batch_count; + int D2 = extent.row(); + int D3 = extent.column() / D1; + return {D0, D1, D2, D3}; + } +}; + +template +struct PermuteInfo> { + + static bool constexpr kBatched = true; + static int constexpr kRowFactor = 1; + static int constexpr kColumnFactor = 1; + static int constexpr kBatchFactor = D1; + + using Layout = cutlass::layout::TensorNHCW; + + static std::string name() { + return "Tensor4DPermuteBMM0321<" + std::to_string(D1) + ">"; + } + + static std::string desc() { + return "batched GEMM permutation [0, 3, 2, 1]"; + } + + static Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) { + int D0 = batch_count / D1; + int D2 = extent.row(); + int D3 = extent.column(); + return {D0, D1, D2, D3}; + } + + static Layout::TensorCoord permute(Layout::TensorCoord const &s) { + return {s[0], s[3], s[2], s[1]}; + } +}; + +template +struct PermuteInfo> +: public PermuteInfo> { + + static bool constexpr kBatched = true; + static int constexpr kRowFactor = D1; + static int constexpr kColumnFactor = 1; + static int constexpr kBatchFactor = 1; + + using Base = PermuteInfo>; + using Layout = typename Base::Layout; + + static typename Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) { + int D0 = batch_count; + int D2 = extent.row() / D1; + int D3 = extent.column(); + return {D0, D1, D2, D3}; + } +}; + +template +struct PermuteInfo> { + + static bool constexpr kBatched = false; + static int constexpr kRowFactor = D1; + static int constexpr kColumnFactor = D2; + static int constexpr kBatchFactor = 1; + + using Layout = cutlass::layout::TensorNHWC; + + static std::string name() { + return "Tensor4DPermute0213<" + std::to_string(D1) + "," + std::to_string(D2) + ">"; + } + + static std::string desc() { + return "normal GEMM permutation [0, 2, 1, 3]"; + } + + static Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) { + int D0 = extent.row() / D1; + int D3 = extent.column() / D2; + return {D0, D1, D2, D3}; + } + + static Layout::TensorCoord permute(Layout::TensorCoord const &s) { + return {s[0], s[2], s[1], s[3]}; + } +}; + +template +struct PermuteInfo> +: public PermuteInfo> { + + static bool constexpr kBatched = false; + static int constexpr kRowFactor = D2; + static int constexpr kColumnFactor = D1; + static int constexpr kBatchFactor = 1; + + using Base = PermuteInfo>; + using Layout = typename Base::Layout; + + static typename Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) { + int D0 = extent.row() / D2; + int D3 = extent.column() / D1; + return {D0, D1, D2, D3}; + } +}; + +template +struct PermuteInfo> +: public PermuteInfo> { + using Layout = cutlass::layout::TensorCWHN; +}; + +template +struct PermuteInfo> +: public PermuteInfo> { + using Layout = cutlass::layout::TensorCWHN; +}; + +template +struct PermuteInfo> { + + static bool constexpr kBatched = false; + static int constexpr kRowFactor = T1; + static int constexpr kColumnFactor = T2 * T3; + static int constexpr kBatchFactor = 1; + + using Layout = cutlass::layout::TensorNDHWC; + + static std::string name() { + return "Tensor5DPermute20314<" + std::to_string(T1) + "," + std::to_string(T2) + "," + std::to_string(T3) + ">"; + } + + static std::string desc() { + return "normal GEMM permutation [2, 0, 3, 1, 4]"; + } + + static Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) + { + int const T0 = extent.row() / T1; + int const T4 = extent.column() / (T2 * T3); + return {T0, T1, T2, T3, T4}; + } + + static Layout::TensorCoord permute(Layout::TensorCoord const &s) + { + return {s[2], s[0], s[3], s[1], s[4]}; + } +}; + +template +struct PermuteInfo> +: public PermuteInfo> { + + static bool constexpr kBatched = false; + static int constexpr kRowFactor = T2; + static int constexpr kColumnFactor = T1 * T3; + static int constexpr kBatchFactor = 1; + + using Base = PermuteInfo>; + using Layout = typename Base::Layout; + + static typename Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) { + int const T0 = extent.row() / T2; + int const T4 = extent.column() / (T1 * T3); + return {T0, T1, T2, T3, T4}; + } +}; + +template +struct PermuteInfo> { + + static bool constexpr kBatched = false; + static int constexpr kRowFactor = T1; + static int constexpr kColumnFactor = T2 * T3; + static int constexpr kBatchFactor = 1; + + using Layout = cutlass::layout::TensorCWHDN; + + static std::string name() { + return "Tensor5DPermute02413<" + std::to_string(T1) + "," + std::to_string(T2) + "," + std::to_string(T3) + ">"; + } + + static std::string desc() { + return "normal GEMM permutation [0, 2, 4, 1, 3]"; + } + + using Coord = cutlass::Tensor5DCoord; + + static Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) + { + int const T0 = extent.row() / T1; + int const T4 = extent.column() / (T2 * T3); + return {T0, T1, T2, T3, T4}; + } + + static Layout::TensorCoord permute(Layout::TensorCoord const &s) + { + return {s[0], s[2], s[4], s[1], s[3]}; + } +}; + +template +struct PermuteInfo> +: public PermuteInfo> { + + static bool constexpr kBatched = false; + static int constexpr kRowFactor = T2; + static int constexpr kColumnFactor = T1 * T3; + static int constexpr kBatchFactor = 1; + + using Base = PermuteInfo>; + using Layout = typename Base::Layout; + + static typename Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) { + int const T0 = extent.row() / T2; + int const T4 = extent.column() / (T1 * T3); + return {T0, T1, T2, T3, T4}; + } +}; diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/40_cutlass_py/conv2d.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/40_cutlass_py/conv2d.py new file mode 100644 index 0000000000000000000000000000000000000000..cd11c74b44e899cec472d4cfb9b3a73a1e2dcd20 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/40_cutlass_py/conv2d.py @@ -0,0 +1,177 @@ +################################################################################ +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ +""" +Basic example of using the CUTLASS Python interface to run a 2d convolution +""" + +import sys +print("This example is deprecated. Please see examples/python for examples of using " + "the CUTLASS Python interface.") +sys.exit(0) + +import argparse +import numpy as np +import torch + +import cutlass_bindings +import cutlass.backend as pycutlass +from cutlass.backend import * +from cutlass.backend.utils.reference_model import Conv2dReferenceModule +from cutlass.backend.utils.device import device_cc + + +parser = argparse.ArgumentParser( + description=("Launch a 2d convolution kernel from Python. " + "See https://docs.nvidia.com/deeplearning/performance/dl-performance-convolutional/index.html#convo-intro for notation.")) +parser.add_argument("--n", default=1, type=int, help="N dimension of the convolution") +parser.add_argument("--c", default=64, type=int, help="C dimension of the convolution") +parser.add_argument("--h", default=32, type=int, help="H dimension of the convolution") +parser.add_argument("--w", default=32, type=int, help="W dimension of the convolution") +parser.add_argument("--k", default=32, type=int, help="N dimension of the convolution") +parser.add_argument("--r", default=3, type=int, help="R dimension of the convolution") +parser.add_argument("--s", default=3, type=int, help="S dimension of the convolution") +parser.add_argument('--print_cuda', action="store_true", help="Print the underlying CUDA kernel") + +try: + args = parser.parse_args() +except: + sys.exit(0) + +# Check that the device is of a sufficient compute capability +cc = device_cc() +assert cc >= 70, "The CUTLASS Python Conv2d example requires compute capability greater than or equal to 70." + +alignment = 1 + +np.random.seed(0) + +# Allocate a pool of device memory to be used by the kernel +pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32) + +# Set the compiler to use to NVCC +pycutlass.compiler.nvcc() + +# Set up A, B, C and accumulator +A = TensorDescription(cutlass_bindings.float16, cutlass_bindings.TensorNHWC, alignment) +B = TensorDescription(cutlass_bindings.float16, cutlass_bindings.TensorNHWC, alignment) +C = TensorDescription(cutlass_bindings.float32, cutlass_bindings.TensorNHWC, alignment) +element_acc = cutlass_bindings.float32 +element_epilogue = cutlass_bindings.float32 + +# Select instruction shape based on the Tensor Core instructions supported +# by the device on which we are running +if cc == 70: + instruction_shape = [8, 8, 4] +elif cc == 75: + instruction_shape = [16, 8, 8] +else: + # Use CUTLASS kernels for CC 80 by default (e.g., for cases in which SM86 is used) + cc = 80 + instruction_shape = [16, 8, 16] + +math_inst = MathInstruction( + instruction_shape, + A.element, B.element, element_acc, + cutlass_bindings.OpClass.TensorOp, + MathOperation.multiply_add +) + +tile_description = TileDescription( + [128, 128, 32], # Threadblock shape + 2, # Number of stages + [2, 2, 1], # Number of warps within each dimension of the threadblock shape + math_inst +) + +epilogue_functor = pycutlass.LinearCombination(C.element, C.alignment, element_acc, element_epilogue) + +operation = Conv2dOperation( + conv_kind=cutlass_bindings.conv.Operator.fprop, + iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.optimized, + arch=cc, tile_description=tile_description, + A=A, B=B, C=C, stride_support=StrideSupport.Unity, + epilogue_functor=epilogue_functor +) + +if args.print_cuda: + print(operation.rt_module.emit()) + +operations = [operation, ] + +# Compile the operation +pycutlass.compiler.add_module(operations) + +# Randomly initialize tensors + +problem_size = cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(args.n, args.h, args.c, args.w), + cutlass_bindings.Tensor4DCoord(args.k, args.r, args.s, args.c), + cutlass_bindings.Tensor4DCoord(0, 0, 0, 0), # Padding + cutlass_bindings.MatrixCoord(1, 1), # Strides + cutlass_bindings.MatrixCoord(1, 1), # Dilation + cutlass_bindings.conv.Mode.cross_correlation, + 1, # Split k slices + 1 # Groups +) + +tensor_A_size = cutlass_bindings.conv.implicit_gemm_tensor_a_size(operation.conv_kind, problem_size) +tensor_B_size = cutlass_bindings.conv.implicit_gemm_tensor_b_size(operation.conv_kind, problem_size) +tensor_C_size = cutlass_bindings.conv.implicit_gemm_tensor_c_size(operation.conv_kind, problem_size) + +tensor_A = torch.ceil(torch.empty(size=(tensor_A_size,), dtype=torch.float16, device="cuda").uniform_(-8.5, 7.5)) +tensor_B = torch.ceil(torch.empty(size=(tensor_B_size,), dtype=torch.float16, device="cuda").uniform_(-8.5, 7.5)) +tensor_C = torch.ceil(torch.empty(size=(tensor_C_size,), dtype=torch.float32, device="cuda").uniform_(-8.5, 7.5)) +tensor_D = torch.ones(size=(tensor_C_size,), dtype=torch.float32, device="cuda") + +alpha = 1. +beta = 0. + +arguments = Conv2dArguments( + operation=operation, problem_size=problem_size, + A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D, + output_op=operation.epilogue_type(alpha, beta) +) + +# Run the operation +operation.run(arguments) +arguments.sync() + +# Run the host reference module and compare to the CUTLASS result +reference = Conv2dReferenceModule(A, B, C, operation.conv_kind) +tensor_D_ref = reference.run(tensor_A, tensor_B, tensor_C, problem_size, alpha, beta) + +try: + assert torch.equal(tensor_D, tensor_D_ref) +except: + assert torch.allclose(tensor_D, tensor_D_ref, rtol=1e-2) + +print("Passed.") diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/40_cutlass_py/customizable/conv2d.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/40_cutlass_py/customizable/conv2d.py new file mode 100644 index 0000000000000000000000000000000000000000..d2df3ed8d1f43cebecdd14240e5b819e184885f7 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/40_cutlass_py/customizable/conv2d.py @@ -0,0 +1,331 @@ +################################################################################ +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ + +import sys +print("This example is deprecated. Please see examples/python for examples of using " + "the CUTLASS Python interface.") +sys.exit(0) + +import numpy as np +import cutlass.backend as pycutlass +from cutlass.backend import * +from cutlass.backend.utils.device import device_cc +from cutlass.backend.conv2d_operation import * +from cutlass.backend.utils.reference_model import Conv2dReferenceModule +import torch.nn.functional as F + +import argparse + +# parse the arguments +parser = argparse.ArgumentParser(description="Launch CUTLASS convolution 2d kernels from Python") + +# Operation description +# math instruction description +parser.add_argument("-i", "--instruction_shape", + default=[1, 1, 1], nargs=3, type=int, + help="This option describes the size of MMA op") +parser.add_argument("-ta", "--element_a", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor A') +parser.add_argument("-tb", "--element_b", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor B') +parser.add_argument("-tc", "--element_c", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor C and output tensor D') +parser.add_argument("-tacc", "--element_acc", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of accumulator') +parser.add_argument('-m', "--math", default="multiply_add", + type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction") +parser.add_argument('-op', "--opcode", default="Simt", type=str, + choices=["Simt", 'TensorOp'], + help='This option describes whether you want to use tensor \ + cores (TensorOp) or regular SIMT cores (Simt) on GPU SM') +# tile description +parser.add_argument("-b", "--threadblock_shape", + default=[128, 128, 8], nargs=3, type=int, + help="This option describes the tile size a thread block with compute") +parser.add_argument("-s", "--stages", default=4, + type=int, help="Number of pipelines you want to use") +parser.add_argument("-w", "--warp_count", default=[ + 4, 2, 1], nargs=3, type=int, + help="This option describes the number of warps along M, N, and K of the threadblock") +parser.add_argument("-cc", "--compute_capability", default=80, + type=int, help="This option describes CUDA SM architecture number") +# A +parser.add_argument('-la', "--layout_a", default="TensorNHWC", type=str, choices=[ + "TensorNHWC", "TensorNC32HW32"], + help="Memory layout of input tensor A") +parser.add_argument('-aa', '--alignment_a', default=1, + type=int, help="Memory alignment of input tensor A") +# B +parser.add_argument('-lb', "--layout_b", default="TensorNHWC", type=str, choices=[ + "TensorNHWC", "TensorC32RSK32"], + help="Memory layout of input tensor B") +parser.add_argument('-ab', '--alignment_b', default=1, + type=int, help="Memory alignment of input tensor B") +# C +parser.add_argument('-lc', "--layout_c", default="TensorNHWC", type=str, choices=[ + "TensorNHWC", "TensorNC32HW32"], + help="Memory layout of input tensor C and output tensor D") +parser.add_argument('-ac', '--alignment_c', default=1, + type=int, help="Memory alignment of input tensor C and output tensor D") +# epilogue +parser.add_argument("-te", "--element_epilogue", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16'], + help='Data type of computation in the epilogue') +parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination", + type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'], + help="This option describes the epilogue part of the kernel") +# swizzling +parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[ + "IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", + "HorizontalSwizzle", "StridedDgradIdentitySwizzle1", "StridedDgradIdentitySwizzle4", + "StridedDgradHorizontalSwizzle"], + help="This option describes how thread blocks are scheduled on GPU") +# conv related +parser.add_argument("-co", "--conv_kind", default="fprop", type=str, choices=['fprop', 'dgrad', 'wgrad'], + help="The type of convolution: forward propagation (fprop), \ + gradient of activation (dgrad), gradient of weight (wgrad)") +parser.add_argument("-st", "--stride_support", default="Strided", type=str, choices=["Strided", "Unity"], + ) +parser.add_argument("-ia", "--iterator_algorithm", default="analytic", type=str, + choices=["analytic", "optimized", "fixed_channels", "few_channels"], + help="This option describes iterator algorithm") + +# arguments +parser.add_argument("-sm", "--split_k_mode", default="Serial", type=str, choices=["Serial", "Parallel"], + help="Split K Mode. Serial is used for non-splitK or serial-splitK.\ + Parallel is used for parallel splitK.") +parser.add_argument('-k', '--split_k_slices', default=1, + type=int, help="Number of split-k partitions. (default 1)") +parser.add_argument("-nhwc", "--nhwc", nargs=4, type=int, help="input size (NHWC)") +parser.add_argument("-krsc", "--krsc", nargs=4, type=int, help="filter size (KRSC)") +parser.add_argument("-pad", "--pad", nargs=4, type=int, help="padding (pad_h, _, pad_w, _)") +parser.add_argument("-stride", "--stride", nargs=2, type=int, help="stride (stride_h, stride_w)") +parser.add_argument("-dilation", "--dilation", nargs=2, type=int, help="dilation (dilation_h, dilation_w)") +parser.add_argument("-alpha", "--alpha", default=1.0, type=float, help="alpha") +parser.add_argument("-beta", "--beta", default=0.0, type=float, help="beta") +parser.add_argument('-bias', '--bias', action='store_true', help="C is bias vector") +# Activation function +parser.add_argument("-activ", "--activation_function", default="identity", + choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"], help="activation function") +parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float, + help="addition arguments for activation") + + +parser.add_argument('--print_cuda', action="store_true", + help="print the underlying CUDA kernel") + +try: + args = parser.parse_args() +except: + sys.exit(0) + +cc = device_cc() +if args.compute_capability != cc: + raise Exception(("Parameter --compute-capability of {} " + "does not match that of the device of {}.").format(args.compute_capability, cc)) + +pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32) + +np.random.seed(0) + +element_a = getattr(cutlass_bindings, args.element_a) +element_b = getattr(cutlass_bindings, args.element_b) +element_c = getattr(cutlass_bindings, args.element_c) +element_acc = getattr(cutlass_bindings, args.element_acc) +math_operation = getattr(MathOperation, args.math) +opclass = getattr(cutlass_bindings.OpClass, args.opcode) + +math_inst = MathInstruction( + args.instruction_shape, element_a, element_b, + element_acc, opclass, math_operation +) + +tile_description = TileDescription( + args.threadblock_shape, args.stages, args.warp_count, + math_inst +) + +layout_a = getattr(cutlass_bindings, args.layout_a) +layout_b = getattr(cutlass_bindings, args.layout_b) +layout_c = getattr(cutlass_bindings, args.layout_c) + +A = TensorDescription( + element_a, layout_a, args.alignment_a +) + +B = TensorDescription( + element_b, layout_b, args.alignment_b +) + +C = TensorDescription( + element_c, layout_c, args.alignment_c +) + +element_epilogue = getattr(cutlass_bindings, args.element_epilogue) +if (args.activation_function == "identity" + or (args.split_k_mode == "Parallel" and args.split_k_slices > 1)): + # + epilogue_functor = getattr(pycutlass, args.epilogue_functor)( + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) +else: + epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")( + getattr(pycutlass, args.activation_function)(element_epilogue), + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) + +iterator_algorithm = getattr(cutlass_bindings.conv.IteratorAlgorithm, args.iterator_algorithm) +swizzling_functor = getattr(cutlass_bindings, args.swizzling_functor) +stride_support = getattr(StrideSupport, args.stride_support) +conv_kind = getattr(cutlass_bindings.conv.Operator, args.conv_kind) + +operation = Conv2dOperation( + conv_kind=conv_kind, iterator_algorithm=iterator_algorithm, + arch=args.compute_capability, tile_description=tile_description, + A=A, B=B, C=C, stride_support=stride_support, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor +) + +if args.print_cuda: + print(operation.rt_module.emit()) + +operations = [operation,] + +if args.split_k_mode == "Parallel" and args.split_k_slices > 1: + if (args.activation_function == "identity"): + epilogue_functor_reduction = getattr(pycutlass, args.epilogue_functor)( + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) + else: + epilogue_functor_reduction = getattr(pycutlass, "LinearCombinationGeneric")( + getattr(pycutlass, args.activation_function)(element_epilogue), + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) + reduction_operation = ReductionOperation( + shape=cutlass_bindings.MatrixCoord(4, 32 * C.alignment), + C=C, element_accumulator=element_acc, + element_compute=element_epilogue, + epilogue_functor=epilogue_functor_reduction, + count=C.alignment + ) + operations.append(reduction_operation) + +pycutlass.compiler.add_module(operations) + +problem_size = cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(args.nhwc[0], args.nhwc[1], args.nhwc[2], args.nhwc[3]), + cutlass_bindings.Tensor4DCoord(args.krsc[0], args.krsc[1], args.krsc[2], args.krsc[3]), + cutlass_bindings.Tensor4DCoord(args.pad[0], args.pad[1], args.pad[2], args.pad[3]), + cutlass_bindings.MatrixCoord(args.stride[0], args.stride[1]), + cutlass_bindings.MatrixCoord(args.dilation[0], args.dilation[1]), + cutlass_bindings.conv.Mode.cross_correlation, + args.split_k_slices, 1 +) + + +# User-provide inputs +tensor_A_size = cutlass_bindings.conv.implicit_gemm_tensor_a_size( + conv_kind, problem_size +) +tensor_B_size = cutlass_bindings.conv.implicit_gemm_tensor_b_size( + conv_kind, problem_size +) +if args.bias: + tensor_C_size = cutlass_bindings.conv.implicit_gemm_tensor_c_extent( + conv_kind, problem_size + ).at(3) +else: + tensor_C_size = cutlass_bindings.conv.implicit_gemm_tensor_c_size( + conv_kind, problem_size + ) + +tensor_D_size = cutlass_bindings.conv.implicit_gemm_tensor_c_size( + conv_kind, problem_size + ) + +if args.element_a != "int8": + tensor_A = torch.ceil(torch.empty(size=(tensor_A_size,), dtype=getattr(torch, args.element_a), device="cuda").uniform_(-8.5, 7.5)) +else: + tensor_A = torch.empty(size=(tensor_A_size,), dtype=getattr(torch, args.element_a), device="cuda").uniform_(-2, 2) + +if args.element_b != "int8": + tensor_B = torch.ceil(torch.empty(size=(tensor_B_size,), dtype=getattr(torch, args.element_b), device="cuda").uniform_(-8.5, 7.5)) +else: + tensor_B = torch.empty(size=(tensor_B_size,), dtype=getattr(torch, args.element_b), device="cuda").uniform_(-2, 2) + +if args.element_c != "int8": + tensor_C = torch.ceil(torch.empty(size=(tensor_C_size,), dtype=getattr(torch, args.element_c), device="cuda").uniform_(-8.5, 7.5)) +else: + tensor_C = torch.empty(size=(tensor_C_size,), dtype=getattr(torch, args.element_c), device="cuda").uniform_(-2, 2) + +tensor_D = torch.ones(size=(tensor_D_size,), dtype=getattr(torch, args.element_c), device="cuda") + +arguments = Conv2dArguments( + operation=operation, problem_size=problem_size, A=tensor_A, + B=tensor_B, C=tensor_C, D=tensor_D, + output_op = operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)), + split_k_mode=getattr(cutlass_bindings.conv.SplitKMode, args.split_k_mode), + split_k_slices=problem_size.split_k_slices +) + +if args.split_k_mode == "Parallel" and args.split_k_slices > 1: + implicit_gemm_size = cutlass_bindings.conv.implicit_gemm_problem_size(conv_kind, arguments.problem_size) + reduction_arguments = ReductionArguments( + reduction_operation, + problem_size=[implicit_gemm_size.m(), implicit_gemm_size.n()], + partitions=problem_size.split_k_slices, + workspace=arguments.ptr_D, + destination=tensor_D, + source=tensor_C, + output_op = reduction_operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)), + bias = arguments.bias + ) + +operation.run(arguments) + +if args.split_k_mode == "Parallel" and args.split_k_slices > 1: + reduction_operation.run(reduction_arguments) + reduction_arguments.sync() +else: + arguments.sync() + +reference_model = Conv2dReferenceModule(A, B, C, conv_kind) + +tensor_D_ref = reference_model.run(tensor_A, tensor_B, tensor_C, arguments.problem_size, args.alpha, args.beta, args.bias) +if (args.activation_function != "identity"): + tensor_D_ref = getattr(F, args.activation_function)(*([tensor_D_ref,] + args.activation_args)) + +try: + assert torch.equal(tensor_D, tensor_D_ref) +except: + assert torch.allclose(tensor_D, tensor_D_ref, rtol=1e-2) +print("Passed.") diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/40_cutlass_py/customizable/gemm.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/40_cutlass_py/customizable/gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..3494fe53f5fe1066f15e73d7b1190ad0b4fe3fdb --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/40_cutlass_py/customizable/gemm.py @@ -0,0 +1,331 @@ +################################################################################ +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ + +import sys +print("This example is deprecated. Please see examples/python for examples of using " + "the CUTLASS Python interface.") +sys.exit(0) + +import numpy as np +import cutlass.backend as pycutlass +from cutlass.backend import * +from cutlass.backend.utils.device import device_cc +import cutlass_bindings +from bfloat16 import bfloat16 + +import argparse + + +# parse the arguments +parser = argparse.ArgumentParser(description="Launch CUTLASS GEMM kernels from Python: 'D = alpha * A * B + beta * C'") + +# Operation description +# math instruction description +parser.add_argument("-i", "--instruction_shape", + default=[1, 1, 1], nargs=3, type=int, + help="This option describes the size of MMA op") +parser.add_argument("-ta", "--element_a", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor A') +parser.add_argument("-tb", "--element_b", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor B') +parser.add_argument("-tc", "--element_c", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor C and output tensor D') +parser.add_argument("-tacc", "--element_acc", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of accumulator') +parser.add_argument('-m', "--math", default="multiply_add", + type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction") +parser.add_argument('-op', "--opcode", default="Simt", type=str, + choices=["Simt", 'TensorOp'], + help="This option describes whether you want to use tensor \ + cores (TensorOp) or regular SIMT cores (Simt) on GPU SM") +# tile description +parser.add_argument("-b", "--threadblock_shape", + default=[128, 128, 8], nargs=3, type=int, + help="This option describes the tile size a thread block with compute") +parser.add_argument("-s", "--stages", default=4, + type=int, help="Number of pipelines you want to use") +parser.add_argument("-w", "--warp_count", default=[4, 2, 1], nargs=3, type=int, + help="This option describes the number of warps along M, N, and K of the threadblock") +parser.add_argument("-cc", "--compute_capability", default=80, + type=int, help="This option describes CUDA SM architecture number") +# A +parser.add_argument('-la', "--layout_a", default="RowMajor", type=str, choices=[ + "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"], + help="Memory layout of input tensor A") +parser.add_argument('-aa', '--alignment_a', default=1, + type=int, help="Memory alignment of input tensor A") +# B +parser.add_argument('-lb', "--layout_b", default="RowMajor", type=str, choices=[ + "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"], + help="Memory layout of input tensor B") +parser.add_argument('-ab', '--alignment_b', default=1, + type=int, help="Memory alignment of input tensor B") +# C +parser.add_argument('-lc', "--layout_c", default="RowMajor", type=str, choices=[ + "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"], + help="Memory layout of input tensor C and output tensor D") +parser.add_argument('-ac', '--alignment_c', default=1, + type=int, help="Memory alignment of input tensor C and output tensor D") +# epilogue +parser.add_argument("-te", "--element_epilogue", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16'], help='Epilogue datatype') +parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination", + type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'], + help="This option describes the epilogue part of the kernel") +# swizzling +parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[ + "IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle", "BatchedIdentitySwizzle"], + help="This option describes how thread blocks are scheduled on GPU") + +# Argument +parser.add_argument("-p", "--problem_size", + default=[128, 128, 128], nargs=3, type=int, + help="GEMM problem size M, N, K") +parser.add_argument("-alpha", "--alpha", default=1.0, type=float, + help="Scaling factor of A * B") +parser.add_argument("-beta", "--beta", default=0.0, type=float, + help="Scaling factor of C") +parser.add_argument("-gm", "--gemm_mode", default="Gemm", type=str, + choices=["Gemm", "GemmSplitKParallel", "Batched", "Array"], + help="GEMM mode. Gemm is used for non-splitK or serial-splitK. \ + GemmSplitKParallel is used for parallel splitK") +parser.add_argument('-k', '--split_k_slices', default=1, + type=int, help="Number of split-k partitions. (default 1)") +parser.add_argument('-bias', '--bias', action='store_true', help="C is bias vector") +parser.add_argument('-batch', '--batch', default=1, type=int, help="batch size for batched GEMM") + +# Activation function +parser.add_argument("-activ", "--activation_function", default="identity", + choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"], help="activation function") +parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float, + help="addition arguments for activation") +parser.add_argument('--print_cuda', action="store_true", + help="print the underlying CUDA kernel") + +try: + args = parser.parse_args() +except: + sys.exit(0) + +cc = device_cc() +if args.compute_capability != cc: + raise Exception(("Parameter --compute-capability of {} " + "does not match that of the device of {}.").format(args.compute_capability, cc)) + +pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32) +pycutlass.compiler.nvcc() + +np.random.seed(0) + +element_a = getattr(cutlass_bindings, args.element_a) +element_b = getattr(cutlass_bindings, args.element_b) +element_c = getattr(cutlass_bindings, args.element_c) +element_acc = getattr(cutlass_bindings, args.element_acc) +math_operation = getattr(MathOperation, args.math) +opclass = getattr(cutlass_bindings.OpClass, args.opcode) + +math_inst = MathInstruction( + args.instruction_shape, element_a, element_b, + element_acc, opclass, math_operation +) + +tile_description = TileDescription( + args.threadblock_shape, args.stages, args.warp_count, + math_inst +) + +layout_a = getattr(cutlass_bindings, args.layout_a) +layout_b = getattr(cutlass_bindings, args.layout_b) +layout_c = getattr(cutlass_bindings, args.layout_c) + +A = TensorDescription( + element_a, layout_a, args.alignment_a +) + +B = TensorDescription( + element_b, layout_b, args.alignment_b +) + +C = TensorDescription( + element_c, layout_c, args.alignment_c +) + +element_epilogue = getattr(cutlass_bindings, args.element_epilogue) +if (args.activation_function == "identity" + or (args.gemm_mode == "GemmSplitKParallel" and args.split_k_slices > 1)): + # + epilogue_functor = getattr(pycutlass, args.epilogue_functor)( + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) +else: + epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")( + getattr(pycutlass, args.activation_function)(element_epilogue), + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) + +swizzling_functor = getattr(cutlass_bindings, args.swizzling_functor) + +operation = GemmOperationUniversal( + arch=args.compute_capability, tile_description=tile_description, + A=A, B=B, C=C, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor +) + +if args.print_cuda: + print(operation.rt_module.emit()) + +operations = [operation, ] + +if args.gemm_mode == "GemmSplitKParallel": + if (args.activation_function == "identity"): + epilogue_functor_reduction = getattr(pycutlass, args.epilogue_functor)( + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) + else: + epilogue_functor_reduction = getattr(pycutlass, "LinearCombinationGeneric")( + getattr(pycutlass, args.activation_function)(element_epilogue), + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) + + reduction_operation = ReductionOperation( + shape=cutlass_bindings.MatrixCoord(4, 32 * C.alignment), + C=C, element_accumulator=element_acc, + element_compute=element_epilogue, + epilogue_functor=epilogue_functor_reduction, + count=C.alignment + ) + operations.append(reduction_operation) + +pycutlass.compiler.add_module(operations) + +# User-provide inputs + +problem_size = cutlass_bindings.gemm.GemmCoord( + args.problem_size[0], args.problem_size[1], args.problem_size[2]) + +tensor_a_size = args.batch * problem_size.m() * problem_size.k() +if args.element_a != "int8": + if args.element_a == "bfloat16": + tensor_A = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(tensor_a_size,)) + ).astype(bfloat16) + else: + tensor_A = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(tensor_a_size,)) + ).astype(getattr(np, args.element_a)) +else: + tensor_A = np.random.uniform( + low=-2, high=2,size=(tensor_a_size,) + ).astype(getattr(np, args.element_a)) + +tensor_b_size = args.batch * problem_size.k() * problem_size.n() +if args.element_b != "int8": + if args.element_b == "bfloat16": + tensor_B = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(tensor_b_size,)) + ).astype(bfloat16) + else: + tensor_B = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(tensor_b_size,)) + ).astype(getattr(np, args.element_b)) +else: + tensor_B = np.random.uniform( + low=-2, high=2, size=(tensor_b_size,) + ).astype(getattr(np, args.element_b)) + +if args.element_c != "int8": + if args.bias: + if args.layout_c == "RowMajor": + tensor_c_size = args.batch * problem_size.n() + elif args.layout_c == "ColumnMajor": + tensor_c_size = args.batch * problem_size.m() + else: + raise ValueError(args.layout_c) + else: + tensor_c_size = args.batch * problem_size.m() * problem_size.n() + if args.element_c == "bfloat16": + tensor_C = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(tensor_c_size,)) + ).astype(bfloat16) + else: + tensor_C = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(tensor_c_size,)) + ).astype(getattr(np, args.element_c)) +else: + tensor_C = np.random.uniform( + low=-2, high=2, size=(args.batch * problem_size.m() * problem_size.n(),) + ).astype(getattr(np, args.element_c)) + +tensor_D = np.zeros( + shape=(args.batch * problem_size.m() * problem_size.n(),) +).astype(getattr(np, args.element_c)) + +output_op = operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)) + +arguments = GemmArguments( + operation=operation, problem_size=problem_size, + A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D, + output_op=output_op, + gemm_mode=getattr(cutlass_bindings.gemm.Mode, args.gemm_mode), + split_k_slices=args.split_k_slices, batch=args.batch +) + +if args.gemm_mode == "GemmSplitKParallel": + reduction_arguments = ReductionArguments( + operation=reduction_operation, + problem_size=[problem_size.m(), problem_size.n()], + partitions=args.split_k_slices, workspace=arguments.ptr_D, + destination=tensor_D, source=tensor_C, + output_op=reduction_operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)), + bias = arguments.bias + ) + +operation.run(arguments) + +if args.gemm_mode == "GemmSplitKParallel": + reduction_operation.run(reduction_arguments) + reduction_arguments.sync() +else: + arguments.sync() + +# run the host reference module +reference = ReferenceModule(A, B, C) +tensor_D_ref = reference.run( + tensor_A, tensor_B, tensor_C, problem_size, args.alpha, args.beta, args.bias, args.batch) + +tensor_D_ref = getattr(pycutlass, args.activation_function).numpy(*([tensor_D_ref,] + args.activation_args)) + +try: + assert np.array_equal(tensor_D, tensor_D_ref) +except: + assert np.allclose(tensor_D, tensor_D_ref, atol=1e-5) +print("Passed.") diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/40_cutlass_py/customizable/gemm_grouped.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/40_cutlass_py/customizable/gemm_grouped.py new file mode 100644 index 0000000000000000000000000000000000000000..a3319e60f9a688de49ca4b26a07692794ea71cb3 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/40_cutlass_py/customizable/gemm_grouped.py @@ -0,0 +1,298 @@ +################################################################################ +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ + +import sys +print("This example is deprecated. Please see examples/python for examples of using " + "the CUTLASS Python interface.") +sys.exit(0) + +import numpy as np +import cutlass.backend as pycutlass +from cutlass.backend import * +from cutlass.backend.utils.device import device_cc +import csv + +import argparse + +# parse the arguments +parser = argparse.ArgumentParser( + description="Launch CUTLASS GEMM Grouped kernels from Python") + +# Operation description +# math instruction description +parser.add_argument("-i", "--instruction_shape", + default=[1, 1, 1], nargs=3, type=int, + help="This option describes the size of MMA op") +parser.add_argument("-ta", "--element_a", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor A') +parser.add_argument("-tb", "--element_b", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor B') +parser.add_argument("-tc", "--element_c", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor C and output tensor D') +parser.add_argument("-tacc", "--element_acc", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of accumulator') +parser.add_argument('-m', "--math", default="multiply_add", + type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction") +parser.add_argument('-op', "--opcode", default="Simt", type=str, + choices=["Simt", 'TensorOp'], help='This option describes whether you want to use tensor \ + cores (TensorOp) or regular SIMT cores (Simt) on GPU SM') +# tile description +parser.add_argument("-b", "--threadblock_shape", + default=[128, 128, 8], nargs=3, type=int, + help="This option describes the tile size a thread block with compute") +parser.add_argument("-s", "--stages", default=4, + type=int, help="Number of pipelines you want to use") +parser.add_argument("-w", "--warp_count", default=[ + 4, 2, 1], nargs=3, type=int, + help="This option describes the number of warps along M, N, and K of the threadblock") +parser.add_argument("-cc", "--compute_capability", default=80, + type=int, help="This option describes CUDA SM architecture number") +# A +parser.add_argument('-la', "--layout_a", default="RowMajor", type=str, choices=[ + "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"], + help="Memory layout of input tensor A") +parser.add_argument('-aa', '--alignment_a', default=1, + type=int, help="Memory alignment of input tensor A") +# B +parser.add_argument('-lb', "--layout_b", default="RowMajor", type=str, choices=[ + "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"], + help="Memory layout of input tensor B") +parser.add_argument('-ab', '--alignment_b', default=1, + type=int, help="Memory alignment of input tensor B") +# C +parser.add_argument('-lc', "--layout_c", default="RowMajor", type=str, choices=[ + "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"], + help="Memory layout of input tensor C and output tensor D") +parser.add_argument('-ac', '--alignment_c', default=1, + type=int, help="Memory alignment of input tensor C and output tensor D") +# epilogue +parser.add_argument("-te", "--element_epilogue", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16'], help='Epilogue datatype') +parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination", + type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'], + help="This option describes the epilogue part of the kernel") +# swizzling +parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[ + "IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle"], + help="This option describes how thread blocks are scheduled on GPU. \ + NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels. \ + This parameter is passed in at present to match the APIs of other kernels. The parameter \ + is unused within the kernel") +# precompute mode +parser.add_argument("-pm", "--precompute_mode", + default="Device", type=str, choices=["Host", "Device"], + help="Grouped Gemm Scheduing on device only (Device) or using host precompute (Host)") +# arguments +parser.add_argument("-p", "--problem_size_dir", type=str, default="grouped_gemm_problem_size.csv", + help="path to the csv file contains the problem sizes") +parser.add_argument("-alpha", "--alpha", default=1.0, type=float, help="alpha") +parser.add_argument("-beta", "--beta", default=0.0, type=float, help="beta") +parser.add_argument('-bias', '--bias', action='store_true', help="C is bias vector") + +# Activation function +parser.add_argument("-activ", "--activation_function", default="identity", + choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"], help="activation function") +parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float, + help="addition arguments for activation") +parser.add_argument('--print_cuda', action="store_true", + help="print the underlying CUDA kernel") + +try: + args = parser.parse_args() +except: + sys.exit(0) + +cc = device_cc() +if args.compute_capability != cc: + raise Exception(("Parameter --compute-capability of {} " + "does not match that of the device of {}.").format(args.compute_capability, cc)) + +pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32) + +np.random.seed(0) + +element_a = getattr(cutlass_bindings, args.element_a) +element_b = getattr(cutlass_bindings, args.element_b) +element_c = getattr(cutlass_bindings, args.element_c) +element_acc = getattr(cutlass_bindings, args.element_acc) +math_operation = getattr(MathOperation, args.math) +opclass = getattr(cutlass_bindings.OpClass, args.opcode) + +math_inst = MathInstruction( + args.instruction_shape, element_a, element_b, + element_acc, opclass, math_operation +) + +tile_description = TileDescription( + args.threadblock_shape, args.stages, args.warp_count, + math_inst +) + +layout_a = getattr(cutlass_bindings, args.layout_a) +layout_b = getattr(cutlass_bindings, args.layout_b) +layout_c = getattr(cutlass_bindings, args.layout_c) + +A = TensorDescription( + element_a, layout_a, args.alignment_a +) + +B = TensorDescription( + element_b, layout_b, args.alignment_b +) + +C = TensorDescription( + element_c, layout_c, args.alignment_c +) + +element_epilogue = getattr(cutlass_bindings, args.element_epilogue) +if args.activation_function == "identity": + epilogue_functor = getattr(pycutlass, args.epilogue_functor)( + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) +else: + epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")( + getattr(pycutlass, args.activation_function)(element_epilogue), + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) +swizzling_functor = getattr(cutlass_bindings, args.swizzling_functor) +precompute_mode = getattr(SchedulerMode, args.precompute_mode) + +operation = GemmOperationGrouped( + arch=args.compute_capability, tile_description=tile_description, + A=A, B=B, C=C, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor, + precompute_mode=precompute_mode +) + +if args.print_cuda: + print(operation.rt_module.emit()) + +pycutlass.compiler.add_module([operation, ]) + +reference_module = ReferenceModule(A, B, C) + +# get problems +problem_sizes = [] +with open(args.problem_size_dir) as csv_file: + reader = csv.reader(csv_file) + for row in reader: + problem_sizes.append( + cutlass_bindings.gemm.GemmCoord(int(row[0]), int(row[1]), int(row[2])) + ) + +problem_count = len(problem_sizes) + +tensor_As = [] +tensor_Bs = [] +tensor_Cs = [] +tensor_Ds = [] +problem_sizes_coord = [] +tensor_D_refs = [] + +for problem_size in problem_sizes: + if args.element_a != "int8": + if args.element_a == "bfloat16": + tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m() + * problem_size.k(),))).astype(bfloat16) + else: + tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m() + * problem_size.k(),))).astype(getattr(np, args.element_a)) + else: + tensor_A = np.random.uniform(low=-2, high=2, size=(problem_size.m() + * problem_size.k(),)).astype(getattr(np, args.element_a)) + + if args.element_b != "int8": + if args.element_b == "bfloat16": + tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.k() + * problem_size.n(),))).astype(bfloat16) + else: + tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.k() + * problem_size.n(),))).astype(getattr(np, args.element_b)) + else: + tensor_B = np.random.uniform(low=-2, high=2, size=(problem_size.k() + * problem_size.n(),)).astype(getattr(np, args.element_b)) + + if args.element_c != "int8": + if args.bias: + if args.layout_c == "RowMajor": + c_size = problem_size.n() + elif args.layout_c == "ColumnMajor": + c_size = problem_size.m() + else: + raise ValueError(args.layout_c) + else: + c_size = problem_size.m() * problem_size.n() + if args.element_c == "bfloat16": + tensor_C = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(c_size,)) + ).astype(bfloat16) + else: + tensor_C = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(c_size,)) + ).astype(getattr(np, args.element_c)) + else: + tensor_C = np.random.uniform( + low=-2, high=2, size=(problem_size.m() * problem_size.n(),) + ).astype(getattr(np, args.element_c)) + tensor_D = np.zeros( + shape=(problem_size.m() * problem_size.n(),) + ).astype(getattr(np, args.element_c)) + + tensor_As.append(tensor_A) + tensor_Bs.append(tensor_B) + tensor_Cs.append(tensor_C) + tensor_Ds.append(tensor_D) + tensor_D_ref = reference_module.run( + tensor_A, tensor_B, tensor_C, problem_size, + args.alpha, args.beta, args.bias) + tensor_D_ref = getattr(pycutlass, args.activation_function).numpy(*([tensor_D_ref,] + args.activation_args)) + tensor_D_refs.append(tensor_D_ref) + problem_sizes_coord.append(problem_size) + +arguments = GemmGroupedArguments( + operation, problem_sizes_coord, tensor_As, tensor_Bs, tensor_Cs, tensor_Ds, + output_op=operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)) +) + +operation.run(arguments) + +arguments.sync() + +for tensor_d, tensor_d_ref in zip(tensor_Ds, tensor_D_refs): + try: + assert np.array_equal(tensor_d, tensor_d_ref) + except: + assert np.allclose(tensor_d, tensor_d_ref, rtol=1e-5) + +print("Passed.") diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/40_cutlass_py/gemm.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/40_cutlass_py/gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..dfd4113a6e9a8e524927502f5bbb9b2d1c30820e --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/40_cutlass_py/gemm.py @@ -0,0 +1,153 @@ +################################################################################ +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ +""" +Basic example of using the CUTLASS Python interface to run a GEMM +""" + +import sys +print("This example is deprecated. Please see examples/python for examples of using " + "the CUTLASS Python interface.") +sys.exit(0) + +import argparse +import numpy as np + +import cutlass_bindings +import cutlass.backend as pycutlass +from cutlass.backend import * +from cutlass.backend.utils.device import device_cc + + +parser = argparse.ArgumentParser(description="Launch a GEMM kernel from Python: 'D = alpha * A * B + beta * C'") +parser.add_argument("--m", default=128, type=int, help="M dimension of the GEMM") +parser.add_argument("--n", default=128, type=int, help="N dimension of the GEMM") +parser.add_argument("--k", default=128, type=int, help="K dimension of the GEMM") +parser.add_argument('--print_cuda', action="store_true", help="Print the underlying CUDA kernel") + +try: + args = parser.parse_args() +except: + sys.exit(0) + +# Check that the device is of a sufficient compute capability +cc = device_cc() +assert cc >= 70, "The CUTLASS Python GEMM example requires compute capability greater than or equal to 70." + +alignment = 8 +assert args.m % alignment == 0, "M dimension of size {} is not divisible by alignment of {}".format(args.m, alignment) +assert args.n % alignment == 0, "N dimension of size {} is not divisible by alignment of {}".format(args.n, alignment) +assert args.k % alignment == 0, "K dimension of size {} is not divisible by alignment of {}".format(args.k, alignment) + +np.random.seed(0) + +# Allocate a pool of device memory to be used by the kernel +pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32) + +# Set the compiler to use to NVCC +pycutlass.compiler.nvcc() + +# Set up A, B, C and accumulator +A = TensorDescription(cutlass_bindings.float16, cutlass_bindings.ColumnMajor, alignment) +B = TensorDescription(cutlass_bindings.float16, cutlass_bindings.RowMajor, alignment) +C = TensorDescription(cutlass_bindings.float32, cutlass_bindings.ColumnMajor, alignment) +element_acc = cutlass_bindings.float32 +element_epilogue = cutlass_bindings.float32 + +# Select instruction shape based on the Tensor Core instructions supported +# by the device on which we are running +if cc == 70: + instruction_shape = [8, 8, 4] +elif cc == 75: + instruction_shape = [16, 8, 8] +else: + # Use CUTLASS kernels for CC 80 by default (e.g., for cases in which SM86 is used) + cc = 80 + instruction_shape = [16, 8, 16] + +math_inst = MathInstruction( + instruction_shape, + A.element, B.element, element_acc, + cutlass_bindings.OpClass.TensorOp, + MathOperation.multiply_add +) + +tile_description = TileDescription( + [128, 128, 32], # Threadblock shape + 2, # Number of stages + [2, 2, 1], # Number of warps within each dimension of the threadblock shape + math_inst +) + +epilogue_functor = pycutlass.LinearCombination(C.element, C.alignment, element_acc, element_epilogue) + +operation = GemmOperationUniversal( + arch=cc, tile_description=tile_description, + A=A, B=B, C=C, + epilogue_functor=epilogue_functor) + +if args.print_cuda: + print(operation.rt_module.emit()) + +operations = [operation, ] + +# Compile the operation +pycutlass.compiler.add_module(operations) + +# Randomly initialize tensors +tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(args.m * args.k,))).astype(np.float16) +tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(args.k * args.n,))).astype(np.float16) +tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(args.m * args.n,))).astype(np.float32) +tensor_D = np.zeros(shape=(args.m * args.n,)).astype(np.float32) + +problem_size = cutlass_bindings.gemm.GemmCoord(args.m, args.n, args.k) +alpha = 1. +beta = 0. + +arguments = GemmArguments( + operation=operation, problem_size=problem_size, + A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D, + output_op=operation.epilogue_type(alpha, beta)) + +# Run the operation +operation.run(arguments) +arguments.sync() + +# Run the host reference module and compare to the CUTLASS result +reference = ReferenceModule(A, B, C) +tensor_D_ref = reference.run(tensor_A, tensor_B, tensor_C, problem_size, alpha, beta) + +try: + assert np.array_equal(tensor_D, tensor_D_ref) +except: + assert np.allclose(tensor_D, tensor_D_ref, atol=1e-5) + +print("Passed.") diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/40_cutlass_py/gemm_grouped.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/40_cutlass_py/gemm_grouped.py new file mode 100644 index 0000000000000000000000000000000000000000..508b0894827315003e46650a55371b4c78d83812 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/40_cutlass_py/gemm_grouped.py @@ -0,0 +1,172 @@ +################################################################################ +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ +""" +Basic example of using the CUTLASS Python interface to run a grouped GEMM +""" + +import sys +print("This example is deprecated. Please see examples/python for examples of using " + "the CUTLASS Python interface.") +sys.exit(0) + +import argparse +import numpy as np + +import cutlass_bindings +import cutlass.backend as pycutlass +from cutlass.backend import * +from cutlass.backend.utils.device import device_cc + + +parser = argparse.ArgumentParser(description="Launch a grouped GEMM kernel from Python") +parser.add_argument('--print_cuda', action="store_true", help="Print the underlying CUDA kernel") + +try: + args = parser.parse_args() +except: + sys.exit(0) + +# Check that the device is of a sufficient compute capability +cc = device_cc() +assert cc >= 70, "The CUTLASS Python grouped GEMM example requires compute capability greater than or equal to 70." + +np.random.seed(0) + +# Allocate a pool of device memory to be used by the kernel +pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32) + +# Set the compiler to use to NVCC +pycutlass.compiler.nvcc() + +# Set up A, B, C and accumulator +alignment = 1 +A = TensorDescription(cutlass_bindings.float16, cutlass_bindings.ColumnMajor, alignment) +B = TensorDescription(cutlass_bindings.float16, cutlass_bindings.RowMajor, alignment) +C = TensorDescription(cutlass_bindings.float32, cutlass_bindings.ColumnMajor, alignment) +element_acc = cutlass_bindings.float32 +element_epilogue = cutlass_bindings.float32 + +# Select instruction shape based on the Tensor Core instructions supported +# by the device on which we are running +if cc == 70: + instruction_shape = [8, 8, 4] +elif cc == 75: + instruction_shape = [16, 8, 8] +else: + # Use CUTLASS kernels for CC 80 by default (e.g., for cases in which SM86 is used) + cc = 80 + instruction_shape = [16, 8, 16] + +math_inst = MathInstruction( + instruction_shape, + A.element, B.element, element_acc, + cutlass_bindings.OpClass.TensorOp, + MathOperation.multiply_add +) + +tile_description = TileDescription( + [128, 128, 32], # Threadblock shape + 2, # Number of stages + [2, 2, 1], # Number of warps within each dimension of the threadblock shape + math_inst +) + +epilogue_functor = pycutlass.LinearCombination(C.element, C.alignment, element_acc, element_epilogue) + +operation = GemmOperationGrouped( + arch=cc, tile_description=tile_description, + A=A, B=B, C=C, + epilogue_functor=epilogue_functor, + precompute_mode=SchedulerMode.Device) + +if args.print_cuda: + print(operation.rt_module.emit()) + +operations = [operation, ] + +# Compile the operation +pycutlass.compiler.add_module(operations) + +# Initialize tensors for each problem in the group +problem_sizes = [ + cutlass_bindings.gemm.GemmCoord(128, 128, 64), + cutlass_bindings.gemm.GemmCoord(512, 256, 128) +] +problem_count = len(problem_sizes) + +alpha = 1. +beta = 0. + +tensor_As = [] +tensor_Bs = [] +tensor_Cs = [] +tensor_Ds = [] +tensor_D_refs = [] + +reference = ReferenceModule(A, B, C) + +for problem_size in problem_sizes: + # Randomly initialize tensors + m = problem_size.m() + n = problem_size.n() + k = problem_size.k() + tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(m * k,))).astype(np.float16) + tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(k * n,))).astype(np.float16) + tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(m * n,))).astype(np.float32) + tensor_D = np.zeros(shape=(m * n,)).astype(np.float32) + + tensor_As.append(tensor_A) + tensor_Bs.append(tensor_B) + tensor_Cs.append(tensor_C) + tensor_Ds.append(tensor_D) + + # Run the reference GEMM + tensor_D_ref = reference.run(tensor_A, tensor_B, tensor_C, problem_size, alpha, beta) + tensor_D_refs.append(tensor_D_ref) + +arguments = GemmGroupedArguments( + operation, problem_sizes, tensor_As, tensor_Bs, tensor_Cs, tensor_Ds, + output_op=operation.epilogue_type(alpha, beta) +) + +# Run the operation +operation.run(arguments) +arguments.sync() + +# Compare the CUTLASS result to the host reference result +for tensor_d, tensor_d_ref in zip(tensor_Ds, tensor_D_refs): + try: + assert np.array_equal(tensor_d, tensor_d_ref) + except: + assert np.allclose(tensor_d, tensor_d_ref, rtol=1e-5) + +print("Passed.") diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/debug_utils.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/debug_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..a22f12b711bbc9a6dc1a6cc6020dd6df6e3ffe0b --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/debug_utils.h @@ -0,0 +1,234 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////// +// Debugging functions +//////////////////////////////////////////////////////////////////////////////// +// Nans & inf detection +#define NANCHECK(frag) \ + { \ + for (size_t _i = 0; _i < frag.size(); ++_i) { \ + assert(std::isfinite(float(frag[_i]))); \ + assert(!std::isnan(float(frag[_i]))); \ + } \ + } + +// Print on the first thread of the first block +#if 1 +#define PRINT_WARP_ID 0 +#define PRINT_LANE_ID 0 +#define PRINT_B0_T0(msg, ...) \ + if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && \ + threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \ + threadIdx.z == 0) { \ + printf(msg "\n", ##__VA_ARGS__); \ + } +#define PRINT_T0(msg, ...) \ + if (threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \ + threadIdx.z == 0) { \ + printf(msg "\n", ##__VA_ARGS__); \ + } +#define PRINT_TX_LX(msg, ...) \ + for (int bx = 0; bx < gridDim.x; ++bx) { \ + for (int by = 0; by < gridDim.y; ++by) { \ + for (int bz = 0; bz < gridDim.z; ++bz) { \ + for (int tx = 0; tx < blockDim.x; ++tx) { \ + for (int ty = 0; ty < blockDim.y; ++ty) { \ + for (int tz = 0; tz < blockDim.z; ++tz) { \ + __syncthreads(); \ + if (blockIdx.x == bx && blockIdx.y == by && blockIdx.z == bz && \ + threadIdx.x == tx && threadIdx.y == ty && \ + threadIdx.z == tz) { \ + printf( \ + "[%d,%d,%d][%d,%d,%d]" msg "\n", \ + bx, \ + by, \ + bz, \ + tx, \ + ty, \ + tz, \ + ##__VA_ARGS__); \ + } \ + } \ + } \ + } \ + } \ + } \ + } +#else +#define PRINT_B0_T0 +#define PRINT_TX_LX +#endif + +struct __string_view { + char const* data; + std::size_t size; +}; +#if __cplusplus >= 201402L +template +constexpr __string_view __get_type_name() { + char const* p = __PRETTY_FUNCTION__; + while (*p++ != '=') + ; + for (; *p == ' '; ++p) + ; + char const* p2 = p; + int count = 1; + for (;; ++p2) { + switch (*p2) { + case '[': + ++count; + break; + case ']': + --count; + if (!count) + return {p, std::size_t(p2 - p)}; + } + } + return {}; +} +#else +template +constexpr __string_view __get_type_name() { + return {"unsupported", 11}; +} +#endif + +// Print a given array +#define PRINT_ACCUM8_T0_L0_START(name, accum, start) \ + PRINT_B0_T0( \ + "%s[%d:%d] - {%f, %f, %f, %f, %f, %f, %f, %f}", \ + name, \ + int(start), \ + int(start + 8), \ + float(accum[start + 0]), \ + float(accum[start + 1]), \ + float(accum[start + 2]), \ + float(accum[start + 3]), \ + float(accum[start + 4]), \ + float(accum[start + 5]), \ + float(accum[start + 6]), \ + float(accum[start + 7])); +#define PRINT_ACCUM8_T0_L0(name, accum) PRINT_ACCUM8_T0_L0_START(name, accum, 0) +#define PRINT_FRAG_T0_L0(name, frag) \ + { \ + auto typeStr = __get_type_name(); \ + PRINT_B0_T0("printing %s (%s)", name, typeStr.data); \ + for (size_t _start = 0; _start < frag.size(); _start += 8) { \ + PRINT_ACCUM8_T0_L0_START(" ", frag, _start); \ + } \ + /*__syncthreads(); \ + NANCHECK(frag); */ \ + } +#define PRINT_ARRAY_T0_L0_INCR(name, array, length, incr) \ + { \ + PRINT_B0_T0("printing %s (len=%d)", name, int(length)); \ + for (int _start = 0; _start < length; _start += incr) { \ + PRINT_ACCUM8_T0_L0_START(" ", array, _start); \ + } \ + } +#define PRINT_ARRAY_T0_L0(name, array, length) \ + PRINT_ARRAY_T0_L0_INCR(name, array, length, 8) + +// Print a 4x4 matrix +#define PRINT_TENSOR4x4_T0_L0_START(name, ref, start_x, start_y) \ + PRINT_B0_T0( \ + "%s[%d:%d, %d:%d]:\n %f, %f, %f, %f\n %f, %f, %f, %f\n %f, %f, %f, %f\n %f, %f, %f, %f", \ + name, \ + int(start_x), \ + int(start_x + 4), \ + int(start_y), \ + int(start_y + 4), \ + float(ref.at({start_x + 0, start_y + 0})), \ + float(ref.at({start_x + 0, start_y + 1})), \ + float(ref.at({start_x + 0, start_y + 2})), \ + float(ref.at({start_x + 0, start_y + 3})), \ + float(ref.at({start_x + 1, start_y + 0})), \ + float(ref.at({start_x + 1, start_y + 1})), \ + float(ref.at({start_x + 1, start_y + 2})), \ + float(ref.at({start_x + 1, start_y + 3})), \ + float(ref.at({start_x + 2, start_y + 0})), \ + float(ref.at({start_x + 2, start_y + 1})), \ + float(ref.at({start_x + 2, start_y + 2})), \ + float(ref.at({start_x + 2, start_y + 3})), \ + float(ref.at({start_x + 3, start_y + 0})), \ + float(ref.at({start_x + 3, start_y + 1})), \ + float(ref.at({start_x + 3, start_y + 2})), \ + float(ref.at({start_x + 3, start_y + 3}))); +#define PRINT_TENSOR4x4_T0_L0(name, ref) \ + PRINT_TENSOR4x4_T0_L0_START(name, ref, 0, 0) + +#define PRINT_PROBLEM_SIZE(name, ps) \ + PRINT_B0_T0( \ + "%s.problem_size: {.m=%d, .n=%d, .k=%d}", \ + name, \ + int(ps.m()), \ + int(ps.n()), \ + int(ps.k())) + +template +CUTLASS_DEVICE void print_warp_accum( + AccumT accum, + LaneOffsetT lane_offset, + int32_t num_rows, + int32_t num_cols) { + bool is_main = blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && + threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0; + for (int row = 0; row < num_rows; ++row) { + for (int col = 0; col < num_cols; ++col) { + if (col % 32 == 0) { + if (is_main) { + printf("\nmat[%3d, %3d:%3d]", row, col, col + 32); + } + __syncthreads(); + } + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + if (row == accum_m && col == accum_n && + (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)) { + printf(" %6.1f", float(accum[idx])); + } + }, + [&](int accum_m) {}); + __syncthreads(); + } + if (is_main) { + printf("\n"); + } + } +} diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/default_fmha_grouped.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/default_fmha_grouped.h new file mode 100644 index 0000000000000000000000000000000000000000..14604f10c368ba1132100f20dd073e45d7afcf2d --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/default_fmha_grouped.h @@ -0,0 +1,299 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with + the appropriate threadblock-scoped epilogue. + + Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are + accommodated by exchanging A and B operands and assuming transposed layouts. Partial + specializations here choose 'device::GemmTransposed' to implement this functionality. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/complex.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "fmha_grouped.h" +#include "gemm_kernel_utils.h" +#include "gemm/custom_mma.h" +#include "gemm/find_default_mma.h" +#include "gemm/mma_from_smem.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The datatype of Q/K/V + typename scalar_t_, + // Architecture we are targeting (eg `cutlass::arch::Sm80`) + typename ArchTag_, + // If Q/K/V are correctly aligned in memory and we can run a fast kernel + bool isAligned_, + int kQueriesPerBlock, + int kKeysPerBlock, + int kMaxK = (int)cutlass::platform::numeric_limits::max(), + GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly + > +struct DefaultFMHAGrouped { + using scalar_t = scalar_t_; + using accum_t = float; + using output_t = scalar_t; + + // Accumulator between 2 iterations + // Using `accum_t` improves perf on f16 at the cost of + // numerical errors + using output_accum_t = accum_t; + + using ArchTag = ArchTag_; + static bool const kIsAligned = isAligned_; + static bool const kSingleValueIteration = kMaxK <= kKeysPerBlock; + static constexpr bool kIsHalf = cutlass::sizeof_bits::value == 16; + static int const kWarpSize = 32; + static int const kNumWarpsPerBlock = kQueriesPerBlock * kKeysPerBlock / (kWarpSize * kWarpSize); + + struct MM0 { + /* + In this first matmul, we compute a block of `Q @ K.T`. + While the calculation result is still hot in registers, we update + `mi`, `m_prime`, `s_prime` in shared-memory, and then store this value + into a shared-memory ("AccumulatorSharedStorage") that is used later as + operand A for the second matmul (see MM1) + */ + + using GemmType = gemm_kernel_utils::DefaultGemmType; + using OpClass = typename GemmType::OpClass; + + using ElementA = scalar_t; + using ElementB = scalar_t; + using ElementC = scalar_t; + using ElementAccumulator = accum_t; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration< + OpClass, + ArchTag, + ElementA, + ElementB, + ElementC, + ElementAccumulator + >; + + static int const kAlignmentA = + kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment; + static int const kAlignmentB = + kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; + + using ThreadblockShape = cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using InstructionShape = typename GemmType::InstructionShape; + + static int const kStages = DefaultConfig::kStages; + using Operator = typename GemmType::Operator; + + using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementAccumulator, + LayoutC, + OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + ArchTag::kMinComputeCapability >= 80 && kIsHalf + ? 4 + : DefaultConfig::kStages, + Operator + >::DefaultMma; + + using MmaCore = typename DefaultMma::MmaCore; + using IteratorA = typename DefaultMma::IteratorA; + using IteratorB = typename DefaultMma::IteratorB; + using DefaultThreadblockMma = typename DefaultMma::ThreadblockMma; + using Mma = typename cutlass::platform::conditional< + kSingleValueIteration, + typename MakeCustomMma::Mma, + DefaultThreadblockMma>::type; + using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator< + typename Mma::Operator::IteratorC, + ElementAccumulator, + kWarpSize>::Iterator; + + static_assert(MmaCore::WarpCount::kCount == kNumWarpsPerBlock, ""); + + // Epilogue to store to shared-memory in a format that we can use later for + // the second matmul + using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm< + typename Mma::Operator::IteratorC, + typename Mma::Operator, + scalar_t, + WarpShape, + ThreadblockShape>; + using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; + }; + + struct MM1 { + /* + Second matmul: perform `attn @ V` where `attn` is the attention (not + normalized) and stored in shared memory + */ + + using GemmType = typename MM0::GemmType; + using OpClass = typename GemmType::OpClass; + + using ElementA = scalar_t; + using ElementB = scalar_t; + using ElementC = output_accum_t; + using ElementAccumulator = accum_t; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration< + OpClass, + ArchTag, + ElementA, + ElementB, + ElementC, + ElementAccumulator + >; + + static int const kAlignmentA = DefaultConfig::kAlignmentA; + static int const kAlignmentB = + kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; + + using ThreadblockShape = typename MM0::ThreadblockShape; + using WarpShape = typename MM0::WarpShape; + using InstructionShape = typename MM0::InstructionShape; + + using EpilogueOutputOp = typename DefaultConfig::EpilogueOutputOp; + + static int const kStages = DefaultConfig::kStages; + using Operator = typename GemmType::Operator; + + using ThreadblockSwizzle = void; // Swizzling is unused + static bool const kSplitKSerial = false; + + using DefaultGemm = cutlass::gemm::kernel::DefaultGemm< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + ArchTag::kMinComputeCapability >= 80 && kIsHalf + ? 4 + : DefaultConfig::kStages, + kSplitKSerial, + Operator>; + + using WarpIteratorA = typename cutlass::gemm::threadblock:: + DefaultWarpIteratorAFromSharedMemory< + typename DefaultGemm::Mma::Policy::Operator::Shape, // WarpShape + typename DefaultGemm::Mma::Policy::Operator::InstructionShape, + typename DefaultGemm::Mma::Policy::Operator::IteratorA, + typename DefaultGemm::Mma::Policy>::WarpIterator; + + using DefaultMmaFromSmem = + typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + MM0::AccumulatorSharedStorage::Shape::kN, // kMaxK + WarpIteratorA, + false>; // kScaleOperandA + + using Mma = typename DefaultMmaFromSmem::Mma; + using IteratorB = typename Mma::IteratorB; + using WarpCount = typename Mma::WarpCount; + static_assert(WarpCount::kCount == kNumWarpsPerBlock, ""); + + using DefaultEpilogue = typename DefaultGemm::Epilogue; + using OutputTileIterator = + typename cutlass::epilogue::threadblock::PredicatedTileIterator< + typename DefaultEpilogue::OutputTileIterator::ThreadMap, + output_t>; + using OutputTileIteratorAccum = + typename cutlass::epilogue::threadblock::PredicatedTileIterator< + typename DefaultEpilogue::OutputTileIterator::ThreadMap, + output_accum_t>; + }; + +/// Define the kernel in terms of the default kernel + using FMHAKernel = kernel::FMHAGrouped< + MM0, + MM1, + scalar_t, + accum_t, + output_t, + output_accum_t, + kSingleValueIteration, + GroupScheduleMode_ + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/epilogue/epilogue_pipelined.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/epilogue/epilogue_pipelined.h new file mode 100644 index 0000000000000000000000000000000000000000..84f89a5ff4d7885d9d7bbec70cbf1913bc9890d1 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/epilogue/epilogue_pipelined.h @@ -0,0 +1,622 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + File copied from "cutlass/epilogue/threadblock/epilogue.h" + then modified to: + (1) load 2 source fragments at the same time (pipelining) + (2) support reading from a different dtype + (3) pass the row id to the OutputOp if it takes it + (see MemoryEfficientAttentionNormalize) + Note that in general the fragment passed to the OutputOp could + span multiple rows but it does not happen with the configurations we have +*/ +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) +#include "cutlass/functional.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" +#include "cutlass/epilogue/threadblock/epilogue_base.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +template +struct ApplyEpilogueOp { + static CUTLASS_DEVICE typename Op::FragmentOutput apply( + Op const& output_op, + int row_id, + typename Op::FragmentAccumulator const& accum, + typename Op::FragmentOutput const& source) { + return output_op(accum, source); + } + static CUTLASS_DEVICE typename Op::FragmentOutput apply( + Op const& output_op, + int row_id, + typename Op::FragmentAccumulator const& accum) { + return output_op(accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Epilogue operator +template < + typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) + typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: + ///< gemm::warp::MmaTensorOp) + int PartitionsK, ///< Number of partitions of the K dimension + typename OutputTileIterator_, ///< Tile iterator writing output tensors + typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting + ///< accumulators + typename WarpTileIterator_, ///< Warp-scoped tile iterator writing + ///< accumulators to SMEM + typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading + ///< from SMEM + typename OutputOp_, ///< Output operator + typename Padding_, ///< Padding added to SMEM allocation to avoid bank + ///< conflicts (concept: MatrixShape) + int FragmentsPerPartition = + 1, ///< Used to coarsten the epilogue granularity + int IterationsUnroll = ///< Used to reduce binary size when epilogue op is + ///< large + (!IsEpilogueFunctorHeavy::value), + typename OutputTileSourceIterator_ = + OutputTileIterator_ ///< Tile iterator reading tensors + > +class EpiloguePipelined : public EpilogueBase< + Shape_, + typename WarpMmaOperator_::Shape, + PartitionsK, + AccumulatorFragmentIterator_, + WarpTileIterator_, + Padding_, + FragmentsPerPartition> { + public: + using Base = EpilogueBase< + Shape_, + typename WarpMmaOperator_::Shape, + PartitionsK, + AccumulatorFragmentIterator_, + WarpTileIterator_, + Padding_, + FragmentsPerPartition>; + + using Shape = Shape_; + using WarpMmaOperator = WarpMmaOperator_; + static int const kPartitionsK = PartitionsK; + using OutputTileIterator = OutputTileIterator_; + using OutputTileSourceIterator = OutputTileSourceIterator_; + using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; + using WarpTileIterator = WarpTileIterator_; + using SharedLoadIterator = SharedLoadIterator_; + using OutputOp = OutputOp_; + using Padding = Padding_; + + using Layout = layout::RowMajor; + using LongIndex = typename Layout::LongIndex; + + /// The complete warp-level accumulator tile + using AccumulatorTile = typename Base::AccumulatorTile; + + /// Accumulator element + using ElementAccumulator = typename WarpTileIterator::Element; + + /// Output element + using ElementOutput = typename OutputTileIterator::Element; + using ElementSource = typename OutputTileSourceIterator::Element; + + /// Output access size + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + /// Tensor reference to destination tensor + using TensorRef = typename OutputTileIterator::TensorRef; + + /// Tensor reference to sync tensor + using SyncTensorRef = + typename cutlass::TensorRef; + + /// Const tensor reference to source tensor + using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; + + /// Array type used to output + using OutputAccessType = Array< + typename OutputTileIterator::Element, + OutputTileIterator::kElementsPerAccess>; + using SourceAccessType = Array< + typename OutputTileSourceIterator::Element, + OutputTileSourceIterator::kElementsPerAccess>; + + /// Array type used by output functor + using AccumulatorAccessType = Array< + typename WarpTileIterator::Element, + OutputTileIterator::kElementsPerAccess>; + + /// Number of warps + using WarpCount = typename Base::WarpCount; + + static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 + ? Base::kFragmentsPerIteration + : kPartitionsK; + static int constexpr kSmemPointerOffset = + Base::SharedStorage::StorageShape::kCount / kSmemTiles; + + public: + static_assert( + OutputTileSourceIterator::Fragment::kElements == + OutputTileIterator::Fragment::kElements, + "Mismatch between input tile and output tile iterator (kElements)"); + static_assert( + OutputTileSourceIterator::kIterations == OutputTileIterator::kIterations, + "Mismatch between input tile and output tile iterator (kIterations)"); + static_assert( + SharedLoadIterator::Fragment::kElements == + OutputTileIterator::Fragment::kElements, + "Mismatch between shared load iterator and output tile iterator."); + + static_assert( + OutputTileIterator::kElementsPerAccess, + "OutputTileIterator::kElementsPerAccess must not be zero."); + + static_assert( + !(OutputTileIterator::Fragment::kElements % + OutputTileIterator::kElementsPerAccess), + "Divisibility"); + + private: + /// Loads fragment from shared memory aligned with output tensor + SharedLoadIterator shared_load_iterator_; + + public: + /// Constructor + CUTLASS_DEVICE + EpiloguePipelined( + typename Base::SharedStorage& shared_storage, ///< Shared storage object + int thread_idx, ///< ID of a thread within the threadblock + int warp_idx, ///< ID of warp within threadblock + int lane_idx ///< Id of thread within warp + ) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + shared_load_iterator_(shared_storage.reference(), thread_idx) {} + + /// Streams the result to global memory + CUTLASS_DEVICE + void operator()( + OutputOp const& output_op, ///< Output operator + OutputTileIterator + destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& + accumulators, ///< Complete warp-level accumulator tile + OutputTileSourceIterator + source_iterator) { ///< Threadblock tile coordinate in GEMM (in units + ///< of threadblock tiles) + + if (!output_op.is_source_needed()) { + compute_source_not_needed_(output_op, destination_iterator, accumulators); + } else { + compute_source_needed_( + output_op, destination_iterator, accumulators, source_iterator); + } + } + CUTLASS_DEVICE + void operator()( + OutputOp const& output_op, ///< Output operator + OutputTileIterator + destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& + accumulators) { ///< Complete warp-level accumulator tile + compute_source_not_needed_(output_op, destination_iterator, accumulators); + } + + private: + template + struct acc2smem_source_not_needed; + + template + struct acc2smem_source_not_needed> { + template + CUTLASS_DEVICE static void helper( + AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator& warp_tile_iterator) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { + ++accum_fragment_iterator; + } + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { + typename AccumulatorFragmentIterator::Fragment accum_fragment; + + accum_fragment_iterator.load(accum_fragment); + ++accum_fragment_iterator; + + warp_tile_iterator.store(accum_fragment); + if (p < Base::kFragmentsPerIteration - 1) { + warp_tile_iterator.add_pointer_offset(kSmemPointerOffset); + } + } + + if (Base::kFragmentsPerIteration > 1) { + warp_tile_iterator.add_pointer_offset( + kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); + } + } + + CUTLASS_DEVICE + static void push( + size_t pos, + AccumulatorFragmentIterator const& iterator_begin, + WarpTileIterator& warp_tile_iterator) { + int dummy[] = { + (pos == (Seq * Base::kFragmentsPerIteration)) && + (helper( + iterator_begin, warp_tile_iterator), + 0)...}; + + CUTLASS_UNUSED(dummy[0]); + } + }; + + static_assert( + kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, + "One of these must be exactly 1."); + + /// Streams the result to global memory + CUTLASS_DEVICE + void compute_source_not_needed_( + OutputOp const& output_op, ///< Output operator + OutputTileIterator + destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& + accumulators ///< Complete warp-level accumulator tile + ) { + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + + // + // Iterate over accumulator tile + // + +#pragma unroll( \ + IterationsUnroll \ + ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration \ + : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; + iter += Base::kFragmentsPerIteration) { + // + // Convert and store fragment + // + + __syncthreads(); + + acc2smem_source_not_needed>:: + push(iter, accum_fragment_iterator, this->warp_tile_iterator_); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { + typename SharedLoadIterator::Fragment + aligned_accum_fragment[kPartitionsK]; + + shared_load_iterator_.load(aligned_accum_fragment[0]); + + if (p < Base::kFragmentsPerIteration - 1) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + } else if (kPartitionsK > 1) { + plus add_fragments; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator_.load(aligned_accum_fragment[i]); + aligned_accum_fragment[0] = add_fragments( + aligned_accum_fragment[0], aligned_accum_fragment[i]); + } + + shared_load_iterator_.add_pointer_offset( + (1 - kPartitionsK) * kSmemPointerOffset); + } + + // + // Compute the output result + // + + typename OutputTileIterator::Fragment output_fragment; + + apply_output_operator_source_not_needed_( + destination_iterator.thread_start_row(), + output_fragment, + output_op, + aligned_accum_fragment[0]); + + // + // Store the final result + // + + destination_iterator.store(output_fragment); + ++destination_iterator; + } + + if (Base::kFragmentsPerIteration > 1) { + shared_load_iterator_.add_pointer_offset( + kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); + } + } + } + + template + struct acc2smem_source_needed; + + template + struct acc2smem_source_needed> { + template + CUTLASS_DEVICE static void helper( + AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator& warp_tile_iterator) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { + ++accum_fragment_iterator; + } + + typename AccumulatorFragmentIterator::Fragment accum_fragment; + accum_fragment_iterator.load(accum_fragment); + warp_tile_iterator.store(accum_fragment); + } + + CUTLASS_DEVICE + static void push( + size_t pos, + AccumulatorFragmentIterator const& iterator_begin, + WarpTileIterator& warp_tile_iterator) { + int dummy[] = { + (pos == Seq) && + (helper(iterator_begin, warp_tile_iterator), 0)...}; + } + }; + + /// Streams the result to global memory + CUTLASS_DEVICE + void compute_source_needed_( + OutputOp const& output_op, ///< Output operator + OutputTileIterator + destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& + accumulators, ///< Complete warp-level accumulator tile + OutputTileSourceIterator + source_iterator ///< Threadblock tile coordinate in GEMM (in units of + ///< threadblock tiles) + ) { + typename OutputTileSourceIterator::Fragment source_fragment[2]; + + source_fragment[0].clear(); + source_iterator.load(source_fragment[0]); + ++source_iterator; + source_fragment[1].clear(); + + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + + // + // Iterate over accumulator tile + // + +#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { + if (iter > 0) { + __syncthreads(); + } + // + // Load the source for next iteration (pipelining) + // + + if (iter + 1 < OutputTileIterator::kIterations) { + source_iterator.load(source_fragment[(iter + 1) % 2]); + } + ++source_iterator; + acc2smem_source_needed< + cutlass::make_index_sequence>:: + push(iter, accum_fragment_iterator, this->warp_tile_iterator_); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + typename SharedLoadIterator::Fragment + aligned_accum_fragment[kPartitionsK]; + + shared_load_iterator_.load(aligned_accum_fragment[0]); + + // If the number of k-slices is > 1 - perform a reduction amongst the + // k-slices + if (kPartitionsK > 1) { + plus add_fragments; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator_.load(aligned_accum_fragment[i]); + aligned_accum_fragment[0] = add_fragments( + aligned_accum_fragment[0], aligned_accum_fragment[i]); + } + + shared_load_iterator_.add_pointer_offset( + (1 - kPartitionsK) * kSmemPointerOffset); + } + + // + // Compute the output result + // + + typename OutputTileIterator::Fragment output_fragment; + + apply_output_operator_( + destination_iterator.thread_start_row(), + output_fragment, + output_op, + aligned_accum_fragment[0], + source_fragment[iter % 2]); + + // + // Store the final result + // + + destination_iterator.store(output_fragment); + ++destination_iterator; + } + } + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_( + int begin_row, + typename OutputTileIterator::Fragment& output_fragment, + OutputOp const& output_op, ///< Output operator + typename SharedLoadIterator::Fragment const& aligned_accum_fragment, + typename OutputTileSourceIterator::Fragment const& source_fragment) { + OutputAccessType* output_frag_ptr = + reinterpret_cast(&output_fragment); + + AccumulatorAccessType const* compute_frag_ptr = + reinterpret_cast(&aligned_accum_fragment); + + SourceAccessType const* source_frag_ptr = + reinterpret_cast(&source_fragment); + + int const kOutputOpIterations = OutputTileIterator::Fragment::kElements / + OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + // Call the output operator + output_frag_ptr[i] = ApplyEpilogueOp::apply( + output_op, + begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess), + compute_frag_ptr[i], + source_frag_ptr[i]); + } + } + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_source_not_needed_( + int begin_row, + typename OutputTileIterator::Fragment& output_fragment, + OutputOp const& output_op, ///< Output operator + typename SharedLoadIterator::Fragment const& aligned_accum_fragment) { + OutputAccessType* output_frag_ptr = + reinterpret_cast(&output_fragment); + + AccumulatorAccessType const* compute_frag_ptr = + reinterpret_cast(&aligned_accum_fragment); + + int const kOutputOpIterations = OutputTileIterator::Fragment::kElements / + OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + // Call the output operator + output_frag_ptr[i] = ApplyEpilogueOp::apply( + output_op, + begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess), + compute_frag_ptr[i]); + } + } + + // This should be constexpr, but it's only supported on c++14 + static int CUTLASS_HOST_DEVICE getRowOffset(int i) { + using ThreadMap = typename OutputTileIterator::ThreadMap; + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + int frag_row_idx = + (row + + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + int frag_idx = ThreadMap::kElementsPerAccess * + (frag_row_idx * ThreadMap::Iterations::kColumn + column); + if (i < frag_idx + ThreadMap::kElementsPerAccess) { + return row_offset; + } + } + } + } + } + return -1; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/epilogue/epilogue_rescale_output.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/epilogue/epilogue_rescale_output.h new file mode 100644 index 0000000000000000000000000000000000000000..411b5574ec715b053a341e8d5f5dfadff3511555 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/epilogue/epilogue_rescale_output.h @@ -0,0 +1,252 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory + to match canonical tensor layouts in global memory. Epilogues support + conversion and reduction operations. + + This is a copy of cutlass/epilogue/threadblock/epilogue.h that can + handle "row_id" as a first argument, as uses it to get the corresponding + `m_prime` / `s_prime` to rescale the output. +*/ + +#pragma once +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/functional.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" +#include "cutlass/epilogue/threadblock/epilogue_base.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "epilogue_pipelined.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies a linear combination operator to an array of elements. +// output <- alpha * accumulator + beta * source +// with: +// alpha = 1 / s_prime (to normalize when isLast=True, 1 otherwise) +// beta = alpha / m_prime (renormalize the output when the max changes) +// source is the current output +template < + typename ElementOutput_, ///< Data type used to store tensors + typename ElementSource_, //< Data type for source (usually matches + //`ElementOutput`) + int Count, ///< Number of elements computed per operation. + ///< Usually it is 128/sizeof_bits, + ///< but we use 64 or 32 sometimes when there are not enough data + ///< to store + typename ElementAccumulator_, ///< Accumulator data type + typename ElementCompute_, ///< Data type used to compute linear combination + bool isFirst, + bool isLast, + typename FragmentAlphaBeta_, + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest> +class MemoryEfficientAttentionNormalize { + public: + using ElementOutput = ElementOutput_; + using ElementSource = ElementSource_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + + static int const kCount = Count; + + using FragmentOutput = Array; + using FragmentSource = Array; + using FragmentAccumulator = Array; + using ComputeFragment = Array; + using FragmentAlphaBeta = FragmentAlphaBeta_; + + static FloatRoundStyle const kRound = Round; + + private: + // + // Data members + // + + FragmentAlphaBeta const& s_prime_; + FragmentAlphaBeta const& m_prime_; + + public: + /// Constructs the function object, possibly loading from pointers in host + /// memory + CUTLASS_HOST_DEVICE + MemoryEfficientAttentionNormalize( + FragmentAlphaBeta const& s_prime, + FragmentAlphaBeta const& m_prime) + : s_prime_(s_prime), m_prime_(m_prime) {} + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + return !isFirst; + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) {} + + /// Computes linear scaling: D = alpha * accumulator + beta * source + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + int row, + FragmentAccumulator const& accumulator, + FragmentSource const& source) const { + assert(!isFirst); + + // Convert source to interal compute numeric type + NumericArrayConverter + source_converter; + NumericArrayConverter + accumulator_converter; + + // Convert to destination numeric type + NumericArrayConverter + destination_converter; + + ComputeFragment converted_source = source_converter(source); + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + ComputeFragment intermediate; + + multiplies mul_add_source; + multiply_add mul_add_accumulator; + + ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1; + ElementCompute beta = alpha * m_prime_[row]; + + intermediate = mul_add_source(beta, converted_source); // X = beta * C + + intermediate = mul_add_accumulator( + alpha, converted_accumulator, intermediate); // D = alpha * Accum + X + + return destination_converter(intermediate); + } + + /// Computes linear scaling: D = alpha * accumulator + CUTLASS_HOST_DEVICE + FragmentOutput operator()(int row, FragmentAccumulator const& accumulator) + const { + assert(isFirst); + + // Convert source to interal compute numeric type + NumericArrayConverter + accumulator_converter; + + // Convert to destination numeric type + NumericArrayConverter + destination_converter; + + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + + ComputeFragment intermediate; + multiplies mul_accumulator; + + ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1; + + intermediate = mul_accumulator( + alpha, converted_accumulator); // X = alpha * C + uniform + + return destination_converter(intermediate); + } +}; + +} // namespace thread + +namespace threadblock { +template < + typename EO, + typename ES, + int Count, + typename EA, + typename EC, + bool F, + bool L, + typename FAB, + FloatRoundStyle R> +struct ApplyEpilogueOp> { + using Op = thread:: + MemoryEfficientAttentionNormalize; + static CUTLASS_DEVICE typename Op::FragmentOutput apply( + Op const& output_op, + int row_id, + typename Op::FragmentAccumulator const& accum, + typename Op::FragmentSource const& source) { + return output_op(row_id, accum, source); + } + static CUTLASS_DEVICE typename Op::FragmentOutput apply( + Op const& output_op, + int row_id, + typename Op::FragmentAccumulator const& accum) { + return output_op(row_id, accum); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/epilogue/epilogue_thread_apply_logsumexp.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/epilogue/epilogue_thread_apply_logsumexp.h new file mode 100644 index 0000000000000000000000000000000000000000..b110abecedaa3324fe360bc919da2ed7a07baba3 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/epilogue/epilogue_thread_apply_logsumexp.h @@ -0,0 +1,174 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing linear combination operations used by epilogues. +*/ + +#pragma once + +#include + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct ArrayExponential { + CUTLASS_HOST_DEVICE + Array operator()( + Array const& input) const { + Array result; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ElementsPerAccess; ++i) { + result[i] = expf(input[i]); + } + + return result; + } +}; + +template +struct ArrayExponential { + CUTLASS_DEVICE + Array operator()( + Array const& input) const { + Array result; + + int const kVectorCount = ElementsPerAccess / 2; + + __half2 const* input_ptr = + reinterpret_cast<__half2 const*>(input.raw_data()); + __half2* res_ptr = reinterpret_cast<__half2*>(result.raw_data()); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kVectorCount; ++i) { + res_ptr[i] = h2exp(input_ptr[i]); + } + + return result; + } +}; +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies: +/// output <- (input - lse).exp() +template < + typename ElementOutput_, // output + typename ElementLSE_, // accumulator from LSE + typename ElementAccumulator_, // accumulator from matmul + typename ElementCompute_, // intermediate compute (and exp calculation) + int ElementsPerAccess> +class ApplyLogSumExp { + public: + using ElementOutput = ElementOutput_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + using ElementLSE = ElementLSE_; + + static int const kElementsPerAccess = ElementsPerAccess; + static int const kCount = kElementsPerAccess; + static const ScaleType::Kind kScale = + cutlass::epilogue::thread::ScaleType::NoBetaScaling; + + using FragmentOutput = Array; + using FragmentAccumulator = Array; + using FragmentCompute = Array; + using FragmentLSE = Array; + using FragmentScaleBias = FragmentLSE; // Used by epilogue_smem_accumulator.h + + public: + // + // Methods + // + + CUTLASS_HOST_DEVICE + ApplyLogSumExp() {} + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + return true; + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) {} + + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const& AB, + FragmentLSE const& scale_unused, + // bias used as LSE + FragmentLSE const& bias) const { + FragmentCompute frag_AB = NumericArrayConverter< + ElementCompute, + ElementAccumulator, + kElementsPerAccess>()(AB); + FragmentCompute frag_lse_compute = + NumericArrayConverter()( + bias); + FragmentCompute frag_compute; + + minus minus_lse; + detail::ArrayExponential apply_exp; + frag_compute = minus_lse(frag_AB, frag_lse_compute); + frag_compute = apply_exp(frag_compute); + + return NumericArrayConverter< + ElementOutput, + ElementCompute, + kElementsPerAccess>()(frag_compute); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/fmha_backward_test.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/fmha_backward_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc25462ac4251158d2262fb3f6b61ea0228d09f --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/fmha_backward_test.py @@ -0,0 +1,232 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import argparse +import torch +import sys +import os +from piped_subprocess import PipedSubprocess, TORCH_DTYPE_NAME +import math + + +parser = argparse.ArgumentParser() +parser.add_argument("example_exe", type=str, help="Path to the 41_fused_multi_head_attention_backward executable") +args = parser.parse_args() + +torch.manual_seed(0) +dtype = torch.float16 +B, Mq, Mkv, H, K, Kv = 2, 1024, 1024, 5, 128, 128 +causal = True +repeat_count = 100 + +ATOL = { + torch.float: 5e-4, + torch.half: 9.5e-2, + torch.bfloat16: 7e-1, +}[dtype] + +RTOL = { + torch.float: 1e-4, + torch.half: 2e-2, + torch.bfloat16: 1e-1, +}[dtype] + + +assert not (causal and Mq < Mkv), "causal only supports seqlenK <= seqlenQ" + +fmha_bw_binary = args.example_exe +if not os.path.isfile(fmha_bw_binary): + print(f"""No such file: `{fmha_bw_binary}`\nDid you forget to run "make 41_fused_multi_head_attention"?""") + sys.exit(1) + +def create_lower_triangular_mask(): + return torch.triu(torch.full( # type: ignore + [1, Mq, Mkv], + dtype=dtype, + fill_value=float("-inf"), + ), diagonal=1) + +def ref_mha_bmk(q, k, v, mask): + # Multi-head attention with inputs/outputs in BMK format + q = q.float() + k = k.float() + v = v.float() + + q = q * (1 / q.shape[-1] ** 0.5) + attn = q @ k.transpose(-2, -1) + if mask is not None: + attn += mask + attn_max = attn.max(-1, True).values + attn_norm = (attn - attn_max).exp().sum(-1, True) + attn = attn.softmax(-1) + lse = attn_max + attn_norm.log() + lse = lse.squeeze(2) + return attn @ v, lse + + +def bmhk2bmk(t): + return t.permute((0, 2, 1, 3)).reshape( + [t.shape[0] * t.shape[2], t.shape[1], t.shape[3]] + ) + +def ref_mha_bmhk(q, k, v, mask): + # Multi-head attention with inputs/outputs in BMHK format + assert q.ndim == 4 + + out, lse = ref_mha_bmk(bmhk2bmk(q), bmhk2bmk(k), bmhk2bmk(v), mask=mask) + out = out.reshape([q.shape[0], q.shape[2], q.shape[1], v.shape[3]]) + return out.permute((0, 2, 1, 3)), lse.reshape([q.shape[0], q.shape[2], q.shape[1]]) + +def ref_mha_bw_bmhk(q, k, v, mask, lse, out, grad_out, delta): + lse = lse[:, :, :q.shape[1]] #BMH, unpad Q dimension + delta = delta.reshape([-1, delta.shape[-1], 1]) + + # bmhk -> bmk + q, k, v, out, grad_out = [bmhk2bmk(x).float() for x in (q, k, v, out, grad_out)] + + attn_T = k @ q.transpose(-2, -1) + if mask is not None: + attn_T += mask.transpose(-2, -1) + attn_T = attn_T * (1 / q.shape[-1] ** 0.5) + attn_T = attn_T - lse.reshape([-1, 1, lse.shape[-1]]) + attn_T = attn_T.exp() + + grad_v = attn_T @ grad_out + + dov = grad_out @ v.transpose(-2, -1) + tmp = (dov - delta) * attn_T.transpose(-2, -1) + tmp = tmp / (q.shape[-1] ** 0.5) + + grad_q = tmp @ k + grad_k = tmp.transpose(-2, -1) @ q + + return [x.reshape([B, H, x.shape[1], x.shape[-1]]).permute([0, 2, 1, 3]) for x in [grad_q, grad_k, grad_v]] + + +print("initializing tensors...") +query = torch.randn([B, Mq, H, K], dtype=dtype) +key = 3 * torch.randn([B, Mkv, H, K], dtype=dtype) +value = 3 * torch.randn([B, Mkv, H, Kv], dtype=dtype) +mask = create_lower_triangular_mask() if causal else None + +# let PyTorch compute gradients +query.requires_grad_(True) +key.requires_grad_(True) +value.requires_grad_(True) + +print("computing fw...") +out, lse = ref_mha_bmhk(query, key, value, mask=mask) +out = out.to(dtype).contiguous() +grad_out = 3 * torch.randn([B, Mq, H, Kv], dtype=dtype) + +print("computing bw with autograd...") +out.backward(grad_out) +scale = (1 / query.shape[-1] ** 0.5) + + +# Additional data needed by the kernel +delta = (grad_out.float() * out.float()).sum(-1).transpose(-2, -1).contiguous() +pad_amount = (32 - (lse.shape[2] % 32)) % 32 +lse = torch.nn.functional.pad(lse, [0, pad_amount], value=math.inf) + +print("computing bw with reference implem...") +gQr, gKr, gVr = ref_mha_bw_bmhk(query, key, value, mask, lse, out, grad_out, delta) + +with PipedSubprocess(fmha_bw_binary) as bw_kernel: + # Send kernel arguments + bw_kernel.write( + TORCH_DTYPE_NAME[query.dtype], + "scale", scale, + "head_dim", K, + "head_dim_value", Kv, + "num_queries", Mq, + "num_keys", Mkv, + "num_heads", H, + "custom_mask_type", (1 if causal else 0), + "num_batches", B, + "repeat_count", repeat_count, + "num_splits_key", (Mkv // 128), + ) + bw_kernel.writeTensor(query, "query", ["q_strideB", "q_strideM", "q_strideH"]) + bw_kernel.writeTensor(key, "key", ["k_strideB", "k_strideM", "k_strideH"]) + bw_kernel.writeTensor(value, "value", ["v_strideB", "v_strideM", "v_strideH"]) + bw_kernel.writeTensor(lse, "logsumexp", ["lse_strideB", "lse_strideH"]) + bw_kernel.writeTensor(out, "output", ["o_strideB", "o_strideM", "o_strideH"]) + bw_kernel.writeTensor(grad_out, "grad_output", ["gO_strideB", "gO_strideM", "gO_strideH"]) + bw_kernel.writeTensor(delta, "delta", ["delta_strideB", "delta_strideH"]) + + if bw_kernel.read() != "OK": + print("Got unexpected output") + print(bw_kernel.subp.communicate()[0]) + sys.exit(0) + + # Read kernel output + gQ = bw_kernel.readTensor("grad_query", ["gQ_strideB", "gQ_strideM", "gQ_strideH"], query.shape).float() + gK = bw_kernel.readTensor("grad_key", ["gK_strideB", "gK_strideM", "gK_strideH"], key.shape).float() + gV = bw_kernel.readTensor("grad_value", ["gV_strideB", "gV_strideM", "gV_strideH"], value.shape).float() + runtime_ms = float(bw_kernel.readNamed("runtime_ms")) + +float_ops = B * H * sum([ + # att = Q @ K.transpose + Mq * Mkv * K * 2, + # att @ dO + Mkv * Mq * Kv * 2, + # dov = dO @ V + Mq * Kv * Mkv * 2, + # dov @ K + Mq * K * Mkv * 2, + # dov @ Q + Mq * K * Mkv * 2, +]) +if causal: + float_ops //= 2 + +print(f""" +Fused multi-head attention - backward + batch_size={B} + num_queries={Mq} + num_keys={Mkv} + num_heads={H} + head_dim={K} + head_dim_value={Kv} + + Correctness: + grad_query: {"PASS" if torch.allclose(gQ, gQr, rtol=RTOL, atol=ATOL) else "FAIL"} (delta: {(gQ - gQr).abs().max()}) + grad_key: {"PASS" if torch.allclose(gK, gKr, rtol=RTOL, atol=ATOL) else "FAIL"} (delta: {(gK - gKr).abs().max()}) + grad_value: {"PASS" if torch.allclose(gV, gVr, rtol=RTOL, atol=ATOL) else "FAIL"} (delta: {(gV - gVr).abs().max()}) + (atol={ATOL} / rtol={RTOL}) + Runtime: {runtime_ms}ms ({(float_ops / (1024 ** 4)) / (runtime_ms / 1000):.4f} TFlops) +""") + +assert torch.allclose(query.grad.float(), gQr, rtol=RTOL, atol=ATOL), "Reference implementation does not match PyTorch autograd!" +assert torch.allclose(key.grad.float(), gKr, rtol=RTOL, atol=ATOL), "Reference implementation does not match PyTorch autograd!" +assert torch.allclose(value.grad.float(), gVr, rtol=RTOL, atol=ATOL), "Reference implementation does not match PyTorch autograd!" diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/fmha_grouped.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/fmha_grouped.h new file mode 100644 index 0000000000000000000000000000000000000000..afc25e43401a7bd75bba62f1e06b33208e11a181 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/fmha_grouped.h @@ -0,0 +1,1023 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Grouped FMHA kernel +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/complex.h" +#include "cutlass/semaphore.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/trace.h" +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" + +#include "fmha_grouped_problem_visitor.h" +#include "gemm_kernel_utils.h" +#include "gemm/mma_accum_lambda_iterator.h" +#include "epilogue/epilogue_rescale_output.h" + + +namespace { + static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) { + // source: https://stackoverflow.com/a/51549250 + return (value >= 0) + ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) + : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value))); +} +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename MM0_, ///! Structure for computing P = Q @ K + typename MM1_, ///! Structure for computing O = P @ V + typename scalar_t_, + typename accum_t_, + typename output_t_, + typename output_accum_t_, + bool kKeepOutputInRF, ///! Whether the intermediate output from MM0_ should be kept in the register file + GroupScheduleMode GroupScheduleMode_ ///! Type of scheduling to perform +> +struct FMHAGrouped { +public: + using MM0 = MM0_; + using MM1 = MM1_; + + using scalar_t = scalar_t_; + using accum_t = accum_t_; + using output_t = output_t_; + using output_accum_t = output_accum_t_; + + static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; + + static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF && + !cutlass::platform::is_same::value; + + // Parameters to satisfy BaseGrouped + using ElementA = scalar_t; + using ElementB = scalar_t; + using ElementC = accum_t; + using LayoutA = typename MM0::LayoutA; + using LayoutB = typename MM0::ElementB; + using LayoutC = typename MM1::ElementC; + static ComplexTransform const kTransformA = ComplexTransform::kNone; + static ComplexTransform const kTransformB = ComplexTransform::kNone; + static int const kAlignmentA = MM0::kAlignmentA; + static int const kAlignmentB = MM0::kAlignmentB; + static int const kAlignmentC = 1; + using Mma = typename MM1::Mma; + using EpilogueOutputOp = typename MM1::EpilogueOutputOp; + using ThreadblockSwizzle = void; + using Operator = typename MM1::Operator; + using WarpShape = typename MM1::WarpShape; + using InstructionShape = typename MM1::InstructionShape; + + using ElementQ = scalar_t; + using ElementK = scalar_t; + using ElementP = accum_t; + using ElementV = scalar_t; + using ElementO = output_t; + using ElementOAccum = output_accum_t; + using ElementAccumulator = accum_t; + + using LayoutQ = typename MM0::LayoutA; + using LayoutK = typename MM0::LayoutB; + using LayoutP = typename MM0::LayoutC; + using LayoutV = typename MM1::LayoutB; + using LayoutO = typename MM1::LayoutC; + + static bool const kPreloadV = (MM1::Mma::ArchTag::kMinComputeCapability >= 80 && + cutlass::sizeof_bits::value == 16); + + static int const kAlignmentQ = MM0::kAlignmentA; + static int const kAlignmentK = MM0::kAlignmentB; + static int const kAlignmentV = 1; + + using ThreadblockShape = typename MM0::ThreadblockShape; + + static int const kQueriesPerBlock = ThreadblockShape::kM; + static int const kKeysPerBlock = ThreadblockShape::kN; + + static constexpr bool kSupportsDropout = false; + static constexpr bool kSupportsBias = false; + + /// Warp count (concept: GemmShape) + using WarpCount = typename MM1::WarpCount; + static int const kThreadsPerWarp = 32; + static int const kThreadCount = kThreadsPerWarp * WarpCount::kCount; + + static constexpr int kNumWarpsPerBlock = + kQueriesPerBlock * kKeysPerBlock / (kThreadsPerWarp * kThreadsPerWarp); + + using ProblemVisitor = FMHAGroupedProblemVisitor< + ThreadblockShape, + kGroupScheduleMode, + kThreadCount, + kThreadCount>; + + // + // Structures + // + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmCoord *problem_sizes0{nullptr}; + GemmCoord *problem_sizes1{nullptr}; + + int problem_count{0}; + int threadblock_count{0}; + + ElementQ ** ptr_Q{nullptr}; + ElementK ** ptr_K{nullptr}; + ElementP ** ptr_P{nullptr}; + ElementV ** ptr_V{nullptr}; + ElementO ** ptr_O{nullptr}; + ElementOAccum ** ptr_O_accum{nullptr}; + + typename LayoutQ::Stride::LongIndex *ldq{nullptr}; + typename LayoutK::Stride::LongIndex *ldk{nullptr}; + typename LayoutP::Stride::LongIndex *ldv{nullptr}; + typename LayoutO::Stride::LongIndex *ldo{nullptr}; + + // Whether causal masking is to be performed + bool causal{false}; + + // Scale + ElementAccumulator scale{0}; + + // Only used by device-level operator + GemmCoord *host_problem_sizes{nullptr}; + + // + // Methods + // + + /// Default ctor + Arguments() = default; + + /// Ctor + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord *problem_sizes0, + GemmCoord *problem_sizes1, + int problem_count, + int threadblock_count, + ElementQ ** ptr_Q, + ElementK ** ptr_K, + ElementP ** ptr_P, + ElementV ** ptr_V, + ElementO ** ptr_O, + ElementOAccum ** ptr_O_accum, + typename LayoutQ::Stride::LongIndex *ldq, + typename LayoutK::Stride::LongIndex *ldk, + typename LayoutP::Stride::LongIndex *ldp, + typename LayoutV::Stride::LongIndex *ldv, + typename LayoutO::Stride::LongIndex *ldo, + bool causal, + ElementAccumulator scale, + GemmCoord *host_problem_sizes=nullptr + ): + problem_sizes0(problem_sizes0), + problem_sizes1(problem_sizes1), + problem_count(problem_count), + threadblock_count(threadblock_count), + ptr_Q(ptr_Q), + ptr_K(ptr_K), + ptr_P(ptr_P), + ptr_V(ptr_V), + ptr_O(ptr_O), + ptr_O_accum(kNeedsOutputAccumulatorBuffer ? ptr_O_accum : (accum_t**)ptr_O), + ldq(ldq), + ldk(ldk), + ldv(ldv), + ldo(ldo), + causal(causal), + scale(scale), + host_problem_sizes(host_problem_sizes) + { + + } + + bool __host__ check_supported() { + CHECK_ALIGNED_PTR(ptr_Q, kAlignmentQ); + CHECK_ALIGNED_PTR(ptr_K, kAlignmentK); + CHECK_ALIGNED_PTR(ptr_V, kAlignmentV); + XFORMERS_CHECK(ldq % kAlignmentQ == 0, "query is not correctly aligned"); + XFORMERS_CHECK(ldk % kAlignmentK == 0, "key is not correctly aligned"); + XFORMERS_CHECK(ldv % kAlignmentV == 0, "value is not correctly aligned"); + return true; + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + + typename ProblemVisitor::Params problem_visitor; + int threadblock_count; + + ElementQ ** ptr_Q; + ElementK ** ptr_K; + ElementP ** ptr_P; + ElementV ** ptr_V; + ElementO ** ptr_O; + ElementOAccum ** ptr_O_accum; + + typename LayoutQ::Stride::LongIndex *ldq; + typename LayoutK::Stride::LongIndex *ldk; + typename LayoutP::Stride::LongIndex *ldv; + typename LayoutO::Stride::LongIndex *ldo; + + ElementAccumulator scale; + bool causal; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + ptr_Q(nullptr), + ptr_K(nullptr), + ptr_P(nullptr), + ptr_V(nullptr), + ptr_O(nullptr), + ptr_O_accum(nullptr), + ldq(nullptr), + ldk(nullptr), + ldv(nullptr), + ldo(nullptr), + causal(false), + scale(0) + { } + + CUTLASS_HOST_DEVICE + Params(Arguments const &args, + void *workspace = nullptr, + int tile_count = 0): + problem_visitor(args.problem_sizes0, args.problem_sizes1, args.problem_count, workspace, tile_count), + threadblock_count(args.threadblock_count), + ptr_Q(args.ptr_Q), + ptr_K(args.ptr_K), + ptr_P(args.ptr_P), + ptr_V(args.ptr_V), + ptr_O(args.ptr_O), + ptr_O_accum(kNeedsOutputAccumulatorBuffer ? args.ptr_O_accum : (accum_t**)args.ptr_O), + ldq(args.ldq), + ldk(args.ldk), + ldv(args.ldv), + ldo(args.ldo), + causal(args.causal), + scale(args.scale) + { + + } + + CUTLASS_HOST_DEVICE + void update( + Arguments const &args, + void *workspace = nullptr, + int tile_count = 0) { + + problem_visitor = typename ProblemVisitor::Params(args.problem_sizes0, + args.problem_sizes1, + args.problem_count, + workspace, tile_count); + threadblock_count = args.threadblock_count; + ptr_Q = args.ptr_Q; + ptr_K = args.ptr_K; + ptr_P = args.ptr_P; + ptr_V = args.ptr_V; + ptr_O = args.ptr_O; + ptr_O_accum = kNeedsOutputAccumulatorBuffer ? args.ptr_O_accum : (accum_t**)args.ptr_O; + ldq = args.ldq; + ldk = args.ldk; + ldv = args.ldv; + ldo = args.ldo; + causal = args.causal; + scale = args.scale; + } + }; + + // Shared storage - depends on kernel params + struct ScalingCoefs { + cutlass::Array m_prime; + cutlass::Array s_prime; + cutlass::Array mi; + cutlass::Array out_rescale; + cutlass::Array + addition_storage; + }; + + struct SharedStorageEpilogueAtEnd : ScalingCoefs { + struct SharedStorageAfterMM0 { + // Everything here might be overwritten during MM0 + typename MM0::AccumulatorSharedStorage si; + typename MM1::Mma::SharedStorage mm1; + }; + + union { + typename MM0::Mma::SharedStorage mm0; + SharedStorageAfterMM0 after_mm0; + typename MM1::DefaultEpilogue::SharedStorage epilogue; + }; + + CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& + epilogue_shared_storage() { + return epilogue; + } + + // ProblemVisitor shared storage can't be overlapped with others + typename ProblemVisitor::SharedStorage problem_visitor; + }; + + struct SharedStorageEpilogueInLoop : ScalingCoefs { + struct SharedStorageAfterMM0 { + // Everything here might be overwritten during MM0 + typename MM0::AccumulatorSharedStorage si; + typename MM1::Mma::SharedStorage mm1; + typename MM1::DefaultEpilogue::SharedStorage epilogue; + }; + + union { + typename MM0::Mma::SharedStorage mm0; + SharedStorageAfterMM0 after_mm0; + }; + + CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& + epilogue_shared_storage() { + return after_mm0.epilogue; + } + + // ProblemVisitor shared storage can't be overlapped with others + typename ProblemVisitor::SharedStorage problem_visitor; + }; + + using SharedStorage = typename cutlass::platform::conditional< + kKeepOutputInRF, + SharedStorageEpilogueAtEnd, + SharedStorageEpilogueInLoop>::type; + +private: + + // Parameters to be used by an individual tile + struct TileParams { + + CUTLASS_HOST_DEVICE + static int query_start(int threadblock_idx) { + return threadblock_idx * kQueriesPerBlock; + } + + // Returns whether this threadblock computes within the number of queries, + // which is determined by the M dimension of problem 0 + CUTLASS_HOST_DEVICE + static bool can_compute(int threadblock_idx, const GemmCoord& problem_size0) { + return query_start(threadblock_idx) < problem_size0.m(); + } + + CUTLASS_HOST_DEVICE + static int num_queries(int threadblock_idx, const GemmCoord& problem_size0) { + return problem_size0.m() - query_start(threadblock_idx); + } + + CUTLASS_HOST_DEVICE + static int num_keys(int threadblock_idx, const GemmCoord& problem_size0, bool causal) { + int nk = problem_size0.n(); + if (causal) { + nk = cutlass::fast_min(int32_t(query_start(threadblock_idx) + kQueriesPerBlock), nk); + } + return nk; + } + + }; + +public: + + // + // Methods + // + + CUTLASS_DEVICE + FMHAGrouped() { } + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const & problem_size) { + return Status::kSuccess; + } + + static Status can_implement(Arguments const &args) { + return Status::kSuccess; + } + + static CUTLASS_DEVICE int16_t thread_id() { + return threadIdx.x; + } + + static CUTLASS_DEVICE int8_t warp_id() { + return threadIdx.x / kThreadsPerWarp; + } + + static CUTLASS_DEVICE int8_t lane_id() { + return threadIdx.x % kThreadsPerWarp; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + auto& m_prime = shared_storage.m_prime; + auto& s_prime = shared_storage.s_prime; + [[maybe_unused]] auto& si = shared_storage.after_mm0.si; + auto& mi = shared_storage.mi; + auto& out_rescale = shared_storage.out_rescale; + + ProblemVisitor problem_visitor( + params.problem_visitor, + shared_storage.problem_visitor, + blockIdx.x); + + // Outer 'persistent' loop to iterate over tiles + while (problem_visitor.next_tile()) { + + GemmCoord problem_size0 = problem_visitor.problem_size0(); + GemmCoord problem_size1 = problem_visitor.problem_size1(); + const int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); + + if (!TileParams::can_compute(threadblock_idx, problem_size0)) { + problem_visitor.advance(gridDim.x); + continue; + } + + const int32_t problem_idx = problem_visitor.problem_index(); + + if (thread_id() < kQueriesPerBlock) { + s_prime[thread_id()] = ElementAccumulator(0); + out_rescale[thread_id()] = accum_t(1.0); + m_prime[thread_id()] = + -cutlass::platform::numeric_limits::infinity(); + mi[thread_id()] = -cutlass::platform::numeric_limits::infinity(); + } + + ElementO *ptr_O = params.ptr_O[problem_idx] + TileParams::query_start(threadblock_idx) * params.ldo[problem_idx]; + ElementOAccum *ptr_O_accum = params.ptr_O_accum[problem_idx] + TileParams::query_start(threadblock_idx) * params.ldo[problem_idx]; + const int num_queries = TileParams::num_queries(threadblock_idx, problem_size0); + + auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator { + using OutputTileIterator = typename MM1::OutputTileIterator; + return OutputTileIterator( + typename OutputTileIterator::Params{(int32_t)params.ldo[problem_idx]}, + ptr_O, + typename OutputTileIterator::TensorCoord{ + num_queries, problem_size1.n()}, + thread_id(), + {0, col}); + }; + + auto createOutputAccumIter = [&](int col) -> + typename MM1::OutputTileIteratorAccum { + using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum; + return OutputTileIteratorAccum( + typename OutputTileIteratorAccum::Params{(int32_t)params.ldo[problem_idx]}, + ptr_O_accum, + typename OutputTileIteratorAccum::TensorCoord{ + num_queries, problem_size1.n()}, + thread_id(), + {0, col}); + }; + + typename MM1::Mma::FragmentC accum_o; + accum_o.clear(); + + const int num_keys = TileParams::num_keys(threadblock_idx, problem_size0, params.causal); + + for (int32_t iter_key_start = 0; iter_key_start < num_keys; + iter_key_start += kKeysPerBlock) { + int32_t problem_size_0_m = + cutlass::fast_min((int32_t)kQueriesPerBlock, num_queries); + int32_t problem_size_0_n = cutlass::fast_min( + (int32_t)kKeysPerBlock, num_keys - iter_key_start); + int32_t const& problem_size_0_k = problem_size0.k(); + int32_t const& problem_size_1_n = problem_size1.n(); + int32_t const& problem_size_1_k = problem_size_0_n; + + auto prologueV = [&](int blockN) { + typename MM1::Mma::IteratorB iterator_V( + typename MM1::IteratorB::Params{typename MM1::LayoutB(params.ldv[problem_idx])}, + params.ptr_V[problem_idx] + iter_key_start * params.ldv[problem_idx], + {problem_size_1_k, problem_size_1_n}, + thread_id(), + cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); + + MM1::Mma::prologue( + shared_storage.after_mm0.mm1, + iterator_V, + thread_id(), + problem_size_1_k); + }; + + __syncthreads(); // Need to have shared memory initialized, and `m_prime` + // updated from end of prev iter + + // + // MATMUL: Q.K_t + // + // Computes the block-matrix product of: + // (a) query[query_start:query_end, :] + // with + // (b) key[iter_key_start:iter_key_start + kKeysPerBlock] + // and stores that into `shared_storage.si` + // + + ElementQ *ptr_Q = params.ptr_Q[problem_idx] + TileParams::query_start(threadblock_idx) * params.ldq[problem_idx]; + + // Construct iterators to A and B operands + typename MM0::IteratorA iterator_A( + typename MM0::IteratorA::Params( + typename MM0::MmaCore::LayoutA(params.ldq[problem_idx])), + ptr_Q, + {problem_size_0_m, problem_size_0_k}, + thread_id(), + {0, 0}); + + typename MM0::IteratorB iterator_B( + typename MM0::IteratorB::Params( + typename MM0::MmaCore::LayoutB(params.ldk[problem_idx])), + params.ptr_K[problem_idx] + iter_key_start * params.ldk[problem_idx], + {problem_size_0_k, problem_size_0_n}, + thread_id(), + {0, 0}); + + // Construct thread-scoped matrix multiply + typename MM0::Mma mma( + shared_storage.mm0, thread_id(), warp_id(), lane_id()); + + typename MM0::Mma::FragmentC accum; + + accum.clear(); + + auto gemm_k_iterations = + (problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + __syncthreads(); + + if (kPreloadV) { + prologueV(0); + } else { + MM1::Mma::drain_cp_asyncs(); + } + + typename MM0::Mma::Operator::IteratorC::TensorCoord + iteratorC_tile_offset = { + (warp_id() % MM0::Mma::WarpCount::kM), + (warp_id() / MM0::Mma::WarpCount::kM) + }; + + // Mask out last if causal + if (params.causal && num_keys - iter_key_start <= kKeysPerBlock) { + auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset( + lane_id(), warp_id(), iteratorC_tile_offset); + int32_t last_col; + MM0::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + last_col = TileParams::query_start(threadblock_idx) + accum_m - iter_key_start; + }, + [&](int accum_m, int accum_n, int idx) { + if (accum_n > last_col) { + accum[idx] = + -cutlass::platform::numeric_limits::infinity(); + } + }, + [&](int accum_m) {}); + } + // DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] { + // DISPATCH_BOOL( + // num_keys - iter_key_start >= kKeysPerBlock, + // kFullColumns, + // ([&] { + // // Update `mi` from accum stored in registers + // // Also does accum[i] <- exp(accum[i] - mi) + // iterative_softmax< + // typename MM0::Mma::Operator::IteratorC, + // kFullColumns, + // kIsFirst>( + // accum_o, + // accum, + // mi, + // m_prime, + // s_prime, + // lane_id(), + // thread_id(), + // warp_id(), + // num_keys - iter_key_start, + // iteratorC_tile_offset, + // kSupportsBias ? 1.0f : params.scale); + // })); + // })); + + // Update `mi` from accum stored in registers + // Also does accum[i] <- exp(accum[i] - mi) + iterative_softmax( + accum_o, + accum, + mi, + m_prime, + s_prime, + out_rescale, + shared_storage.addition_storage, + lane_id(), + thread_id(), + warp_id(), + num_keys - iter_key_start, + iter_key_start == 0, + iteratorC_tile_offset, + kSupportsBias ? 1.0f : params.scale); + + // Output results to shared-memory + int warp_idx_mn_0 = warp_id() % + (MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN); + auto output_tile_coords = cutlass::MatrixCoord{ + warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM, + warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM}; + + MM0::B2bGemm::accumToSmem( + shared_storage.after_mm0.si, accum, lane_id(), output_tile_coords); + + __syncthreads(); + + // + // MATMUL: Attn . V + // Run the matmul `attn @ V` for a block of attn and V. + // `attn` is read from shared memory (in `shared_storage_si`) + // `V` is read from global memory (with iterator_B) + // + + const int64_t nBlockN = kKeepOutputInRF ? 1 + : ceil_div( + (int64_t)problem_size_1_n, + int64_t(MM1::ThreadblockShape::kN)); + + // Iterate over the N dimension of GEMM1 + for (int blockN = 0; blockN < nBlockN; ++blockN) { + int gemm_k_iterations = + (problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add and store it in accum + // (in registers) + if (!kPreloadV) { + __syncthreads(); // we share shmem between mma and epilogue + } + + typename MM1::Mma::IteratorB iterator_V( + typename MM1::IteratorB::Params{typename MM1::LayoutB(params.ldv[problem_idx])}, + params.ptr_V[problem_idx] + iter_key_start * params.ldv[problem_idx], + {problem_size_1_k, problem_size_1_n}, + thread_id(), + cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); + + typename MM1::Mma mma_pv( + // operand A: Pij_dropped in shared memory + shared_storage.after_mm0.si.accum_ref(), + // operand B: shared memory staging area for Vj, which is loaded + // from global memory + shared_storage.after_mm0.mm1.operand_B_ref(), + (int)thread_id(), + (int)warp_id(), + (int)lane_id()); + + mma_pv.set_prologue_done(kPreloadV); + if (!kKeepOutputInRF) { + accum_o.clear(); + } + + mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o); + __syncthreads(); + + if (kPreloadV && !kKeepOutputInRF && blockN + 1 < nBlockN) { + prologueV(blockN + 1); + } + + if (!kKeepOutputInRF) { + MM1::Mma::drain_cp_asyncs(); + DISPATCH_BOOL( + iter_key_start == 0, kIsFirst, ([&] { + DISPATCH_BOOL( + (iter_key_start + kKeysPerBlock) >= num_keys, + kIsLast, + ([&] { + using DefaultEpilogue = typename MM1::DefaultEpilogue; + using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp; + using ElementCompute = typename DefaultOp::ElementCompute; + using EpilogueOutputOp = typename cutlass::epilogue:: + thread::MemoryEfficientAttentionNormalize< + typename cutlass::platform::conditional< + kIsLast::value, + output_t, + output_accum_t>::type, + output_accum_t, + DefaultOp::kCount, + typename DefaultOp::ElementAccumulator, + output_accum_t, + kIsFirst::value, + kIsLast::value, + cutlass::Array>; + using Epilogue = typename cutlass::epilogue::threadblock:: + EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename MM1::Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename cutlass::platform::conditional< + kIsLast::value, + typename MM1::OutputTileIterator, + typename MM1::OutputTileIteratorAccum>::type, + typename DefaultEpilogue:: + AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true, // IterationsUnroll + typename MM1::OutputTileIteratorAccum // Read + // iterator + >; + + int col = blockN * MM1::Mma::Shape::kN; + auto source_iter = createOutputAccumIter(col); + auto dest_iter = gemm_kernel_utils::call_conditional< + kIsLast::value, + decltype(createOutputIter), + decltype(createOutputAccumIter)>:: + apply(createOutputIter, createOutputAccumIter, col); + EpilogueOutputOp rescale(s_prime, out_rescale); + Epilogue epilogue( + shared_storage.epilogue_shared_storage(), + thread_id(), + warp_id(), + lane_id()); + epilogue(rescale, dest_iter, accum_o, source_iter); + })); + })); + if (!kKeepOutputInRF) { + __syncthreads(); + } + } + } + __syncthreads(); // we modify `m_prime` after + } + + if (kKeepOutputInRF) { + constexpr bool kIsFirst = true; + constexpr bool kIsLast = true; + using DefaultEpilogue = typename MM1::DefaultEpilogue; + using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp; + using ElementCompute = typename DefaultOp::ElementCompute; + using EpilogueOutputOp = + typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize< + output_t, // output + output_accum_t, // source + DefaultOp::kCount, + typename DefaultOp::ElementAccumulator, // accum + output_accum_t, // compute + kIsFirst, + kIsLast, + cutlass::Array>; + using Epilogue = + typename cutlass::epilogue::threadblock::EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename MM1::Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename MM1::OutputTileIterator, // destination + typename DefaultEpilogue::AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true, // IterationsUnroll + typename MM1::OutputTileIteratorAccum // source tile + >; + auto dest_iter = createOutputIter(0); + EpilogueOutputOp rescale(s_prime, out_rescale); + Epilogue epilogue( + shared_storage.epilogue_shared_storage(), + thread_id(), + warp_id(), + lane_id()); + MM1::Mma::drain_cp_asyncs(); + epilogue(rescale, dest_iter, accum_o); + } + + // Next tile + problem_visitor.advance(gridDim.x); + __syncthreads(); // Don't start the next iteration until all threads are done using shared memory. + } + } + + template + CUTLASS_DEVICE static void iterative_softmax( + typename WarpIteratorC::Fragment& frag_o, // output so far + typename WarpIteratorC::Fragment& frag, + cutlass::Array& mi, + cutlass::Array& m_prime, + cutlass::Array& s_prime, + cutlass::Array& out_rescale, + cutlass::Array& + addition_storage, + int8_t lane_id, + int8_t thread_id, + int8_t warp_id, + int max_col, + bool is_first, + typename WarpIteratorC::TensorCoord const& tile_offset, + float scaling) { + /* Iterates on the accumulator and corresponding position on result matrix + + (1) Update `mi[r]` to the max value of the row `r` + (2) In a second iteration do the following: + (a) accum <- exp(accum - mi) + (b) m_prime <- exp(m_prime - mi) + (c) s_prime <- s_prime * m_prime + sum(accum) + + All of this is done on registers, before we store all of this + on shared memory for the next matmul with Value. + */ + using Fragment = typename WarpIteratorC::Fragment; + using LambdaIterator = typename DefaultMmaAccumLambdaIterator< + WarpIteratorC, + accum_t, + kThreadsPerWarp>::Iterator; + // Convert to `accum_t` (rather than double) + constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E + + static_assert(kQueriesPerBlock % kNumWarpsPerBlock == 0, ""); + static constexpr int kLinesPerWarp = kQueriesPerBlock / kNumWarpsPerBlock; + + frag = cutlass::multiplies()(scaling * kLog2e, frag); + + auto lane_offset = + LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset); + + // First update `mi` to the max per-row + { + accum_t max; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + max = -cutlass::platform::numeric_limits::infinity(); + }, + [&](int accum_m, int accum_n, int idx) { + if (accum_n < max_col) { + max = cutlass::fast_max(max, frag[idx]); + } + }, + [&](int accum_m) { + // Having 4x atomicMax seems faster than reduce within warp + // first... + atomicMaxFloat(&mi[accum_m], max); + }); + } + + // Make sure we all share the update values for `mi` + __syncthreads(); + + // Doing this `exp` is quite expensive. Let's + // split it across the warps + bool restore_mi_to_minus_inf = false; + if (lane_id < kLinesPerWarp) { + int id = warp_id * kLinesPerWarp + lane_id; + auto m_prime_id = m_prime[id]; + auto mi_id = mi[id]; + bool changed = m_prime_id < mi_id; // `false` if both are -inf + if (changed) { + auto m_prime_exp = exp2f(m_prime_id - mi_id); + out_rescale[id] = m_prime_exp; + s_prime[id] *= m_prime_exp; + } else { + // Only when bias is enabled, it's possible that all the first values + // of attention are masked to `-inf`. In that case we want to avoid + // `nan = exp2f(-inf - (-inf))` so we temporarily set `mi` to 0 + if (kSupportsBias && + mi_id == -cutlass::platform::numeric_limits::infinity()) { + restore_mi_to_minus_inf = true; + mi[id] = 0.0f; + } + out_rescale[id] = 1.0f; + } + } + __syncthreads(); // Update output fragments + if (kKeepOutputInRF && !is_first) { + accum_t line_rescale; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { line_rescale = out_rescale[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { + frag_o[idx] = frag_o[idx] * line_rescale; + }, + [&](int accum_m) {}); + } + // Update accum_m, accum_n, ... + { + accum_t mi_row, total_row; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { mi_row = mi[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { + frag[idx] = + (accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0); + }, + [&](int accum_m) {}); + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { total_row = 0.0; }, + [&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; }, + [&](int accum_m) { + if (LambdaIterator::reduceSameRow( + lane_id, total_row, [](accum_t a, accum_t b) { + return a + b; + })) { + // NOTE: we could atomically add `total_row` to `s_prime`, but + // it's faster (and deterministic) to avoid atomics here + addition_storage + [accum_m + kQueriesPerBlock * tile_offset.column()] = + total_row; + } + }); + } + + __syncthreads(); + if (lane_id < kLinesPerWarp) { + int id = warp_id * kLinesPerWarp + lane_id; + accum_t total_row = s_prime[id]; + if (restore_mi_to_minus_inf) { + // Restore `mi`, see above when we set `restore_mi_to_minus_inf=true` + mi[id] = -cutlass::platform::numeric_limits::infinity(); + } else { + m_prime[id] = mi[id]; + } + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) { + total_row += addition_storage[id + kQueriesPerBlock * i]; + } + s_prime[id] = total_row; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/fmha_grouped_problem_visitor.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/fmha_grouped_problem_visitor.h new file mode 100644 index 0000000000000000000000000000000000000000..f88219304b8e239cbb656004ad17636ac98d4e04 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/fmha_grouped_problem_visitor.h @@ -0,0 +1,178 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Scheduler for grouped FMHA +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/gemm/kernel/grouped_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { +// Helper for correctly representing problem sizes in grouped kernels +template +struct FMHAGroupedProblemSizeHelper { + + CUTLASS_HOST_DEVICE + static cutlass::gemm::GemmCoord grid_shape(const cutlass::gemm::GemmCoord& problem) { + // FMHA only partitions tiles across the M dimension. + return cutlass::gemm::GemmCoord( + ((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), 1, 1); + } + + CUTLASS_HOST_DEVICE + static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem) {} + + CUTLASS_HOST_DEVICE + static int32_t tile_count(const cutlass::gemm::GemmCoord& grid) { + return grid.m() * grid.n(); + } +}; + +} // namespace detail + +/// Visitor class to abstract away the algorithm for iterating over tiles +template +struct FMHAGroupedProblemVisitor : public GroupedProblemVisitor< + detail::FMHAGroupedProblemSizeHelper, + ThreadblockShape, + GroupScheduleMode_, + PrefetchTileCount, + ThreadCount> { + + using ProblemSizeHelper = detail::FMHAGroupedProblemSizeHelper; + using Base = GroupedProblemVisitor; + using BaseParams = typename Base::Params; + using SharedStorage = typename Base::SharedStorage; + + cutlass::gemm::GemmCoord const *problem_sizes0; + cutlass::gemm::GemmCoord const *problem_sizes1; + + struct Params { + cutlass::gemm::GemmCoord const *problem_sizes0; + cutlass::gemm::GemmCoord const *problem_sizes1; + int32_t problem_count; + void const *workspace; + int32_t tile_count; + + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + Params(): problem_sizes0(nullptr), problem_sizes1(nullptr), + problem_count(0), workspace(nullptr), tile_count(0) { } + + /// Ctor + CUTLASS_HOST_DEVICE + Params( + cutlass::gemm::GemmCoord const *problem_sizes0, + cutlass::gemm::GemmCoord const *problem_sizes1, + int32_t problem_count, + void const *workspace = nullptr, + int32_t tile_count = 0 + ): + problem_sizes0(problem_sizes0), + problem_sizes1(problem_sizes1), + problem_count(problem_count), + workspace(workspace), + tile_count(tile_count) + {} + + /// Convert the FMHA-specific parameters to those used by the base class + CUTLASS_HOST_DEVICE + BaseParams to_base() const { + return BaseParams(// Set problem_sizes as problem_sizes1 because these determine + // shape of the final output of FMHA + problem_sizes1, + problem_count, + workspace, + tile_count); + } + + }; + + // + // Methods + // + CUTLASS_DEVICE + FMHAGroupedProblemVisitor( + Params const ¶ms_, + SharedStorage &shared_storage_, + int32_t block_idx + ): Base ( + params_.to_base(), + shared_storage_, block_idx), + problem_sizes0(params_.problem_sizes0), + problem_sizes1(params_.problem_sizes1) + {} + + /// Returns the problem size 0 for the current problem + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size0() const { + GemmCoord problem = problem_sizes0[this->problem_idx]; + ProblemSizeHelper::possibly_transpose_problem(problem); + return problem; + } + + /// Returns the problem size 1 for the current problem + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size1() const { + GemmCoord problem = problem_sizes1[this->problem_idx]; + ProblemSizeHelper::possibly_transpose_problem(problem); + return problem; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma.h new file mode 100644 index 0000000000000000000000000000000000000000..f3a1d4cbc275a166e0575b365413864ffb93c9f3 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma.h @@ -0,0 +1,124 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "custom_mma_multistage.h" +#include "custom_mma_pipelined.h" +#include "cutlass/gemm/threadblock/mma_multistage.h" +#include "cutlass/gemm/threadblock/mma_pipelined.h" + +template +struct MakeCustomMma; + +template < + typename Shape, + typename IteratorA, + typename SmemIteratorA, + cutlass::arch::CacheOperation::Kind CacheOpA, + typename IteratorB, + typename SmemIteratorB, + cutlass::arch::CacheOperation::Kind CacheOpB, + typename ElementC, + typename LayoutC, + typename Policy, + int Stages, + cutlass::gemm::SharedMemoryClearOption SharedMemoryClear, + int kMaxK> +struct MakeCustomMma< + cutlass::gemm::threadblock::MmaMultistage< + Shape, + IteratorA, + SmemIteratorA, + CacheOpA, + IteratorB, + SmemIteratorB, + CacheOpB, + ElementC, + LayoutC, + Policy, + Stages, + SharedMemoryClear>, + kMaxK> { + // Reduce the number of stages if we don't need that many + static int constexpr kStages = + kMaxK == cutlass::platform::numeric_limits::max() + ? Stages + : cutlass::const_min( + Stages, + (kMaxK + int(Shape::kK) - 1) / int(Shape::kK)); + using Mma = cutlass::gemm::threadblock::CustomMmaMultistage< + Shape, + IteratorA, + SmemIteratorA, + CacheOpA, + IteratorB, + SmemIteratorB, + CacheOpB, + ElementC, + LayoutC, + Policy, + kStages, + SharedMemoryClear, + kMaxK>; +}; + +template < + typename Shape, + typename IteratorA, + typename SmemIteratorA, + typename IteratorB, + typename SmemIteratorB, + typename ElementC, + typename LayoutC, + typename Policy, + int kMaxK> +struct MakeCustomMma< + cutlass::gemm::threadblock::MmaPipelined< + Shape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + Policy>, + kMaxK> { + using Mma = cutlass::gemm::threadblock::CustomMmaPipelined< + Shape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + Policy>; +}; diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma_base.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma_base.h new file mode 100644 index 0000000000000000000000000000000000000000..66c099d15bf038fd06fd2b3ca6c03d1de3f7ffe9 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma_base.h @@ -0,0 +1,182 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/mma_base.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class CustomMmaBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape< + Shape::kM / WarpGemm::kM, + Shape::kN / WarpGemm::kN, + Shape::kK / WarpGemm::kK>; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + /// Number of stages + static int const kStages = Stages; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + template + struct OperandSharedStorage { + AlignedBuffer buffer; + using TensorRef = TensorRef; + + CUTLASS_DEVICE + static OperandLayout Layout() { + return OperandLayout::packed({OperandShape::kRow, OperandShape::kColumn}); + } + + /// Returns a TensorRef to the operand + CUTLASS_HOST_DEVICE + TensorRef ref() { + return TensorRef{buffer.data(), Layout()}; + } + }; + + /// Shape of the A matrix operand in shared memory + using ShapeA = MatrixShape< + Shape::kM + Policy::SmemPaddingA::kRow, + Shape::kK * kStages + Policy::SmemPaddingA::kColumn>; + + /// Shape of the B matrix operand in shared memory + using ShapeB = MatrixShape< + Shape::kK * kStages + Policy::SmemPaddingB::kRow, + Shape::kN + Policy::SmemPaddingB::kColumn>; + + using SharedStorageA = OperandSharedStorage< + typename Operator::ElementA, + ShapeA, + typename Operator::LayoutA>; + using SharedStorageB = OperandSharedStorage< + typename Operator::ElementB, + ShapeB, + typename Operator::LayoutB>; + using TensorRefA = typename SharedStorageA::TensorRef; + using TensorRefB = typename SharedStorageB::TensorRef; + + struct SharedStorage { + /// Buffer for A operand + SharedStorageA operand_A; + + /// Buffer for B operand + SharedStorageB operand_B; + }; + + protected: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + CustomMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorageA& shared_storageA, + SharedStorageB& shared_storageB, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_A_(shared_storageA.ref(), lane_idx), + warp_tile_iterator_B_(shared_storageB.ref(), lane_idx) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma_multistage.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma_multistage.h new file mode 100644 index 0000000000000000000000000000000000000000..145315e413172e1f29a10b13a9d9d7fa19537006 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma_multistage.h @@ -0,0 +1,760 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/cache_operation.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "custom_mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// Upper boundon the K dimension + int kMaxK = cutlass::platform::numeric_limits::max(), + /// Used for partial specialization + typename Enable = bool> +class CustomMmaMultistage : public CustomMmaBase { + public: + ///< Base class + using Base = CustomMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + /// Internal structure exposed for introspection. + struct Detail { + static_assert( + Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + }; + + static bool const kSmemContainsEntireMat = kMaxK <= Shape::kK * Stages; + static constexpr int kNumStagesConcurrentLoad = + kSmemContainsEntireMat ? Stages : Stages - 1; + + private: + using WarpLoadedFragmentA = typename Operator::FragmentA; + using WarpLoadedFragmentB = typename Operator::FragmentB; + using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + + private: + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + bool prologue_done_; + + // Set to `True` to ensure the accumulator will be zero outside the GEMM + // footprint + bool zero_outside_bounds_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + CustomMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorageA& shared_storageA, + typename Base::SharedStorageB& shared_storageB, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storageA.ref(), thread_idx), + smem_iterator_B_(shared_storageB.ref(), thread_idx), + prologue_done_(false), + zero_outside_bounds_(false) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + CUTLASS_DEVICE + CustomMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& st, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : CustomMmaMultistage( + st.operand_A, + st.operand_B, + thread_idx, + warp_idx, + lane_idx) {} + + CUTLASS_DEVICE + bool set_prologue_done(bool value) { + prologue_done_ = value; + return true; + } + + CUTLASS_DEVICE + bool set_zero_outside_bounds(bool value) { + zero_outside_bounds_ = value; + return true; + } + + template + CUTLASS_DEVICE static void prologue( + typename Base::SharedStorage& shared_storage, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + int thread_idx, + int problem_size_k) { + prologue( + shared_storage.operand_A, + shared_storage.operand_B, + iterator_A, + iterator_B, + thread_idx, + problem_size_k); + } + + template + CUTLASS_DEVICE static void prologue( + typename Base::SharedStorageA& shared_storageA, + typename Base::SharedStorageB& shared_storageB, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + int thread_idx, + int problem_size_k) { + SmemIteratorA smem_iterator_A(shared_storageA.ref(), thread_idx); + SmemIteratorB smem_iterator_B(shared_storageB.ref(), thread_idx); + int32_t iter = (problem_size_k + Base::Shape::kK - 1) / Base::Shape::kK; + _prologue( + iterator_A, iterator_B, iter, smem_iterator_A, smem_iterator_B); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance( + IteratorA& iterator_A, + IteratorB& iterator_B, + int group_start_A = 0, + int group_start_B = 0) { + iterator_A.set_iteration_index( + group_start_A * IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + if (zero_outside_bounds_ || + SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B.set_iteration_index( + group_start_B * IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + if (zero_outside_bounds_ || + SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + template + CUTLASS_DEVICE static void _prologue( + IteratorA& iterator_A, + IteratorB& iterator_B, + int32_t& gemm_k_iterations, + SmemIteratorA& smem_iterator_A_, + SmemIteratorB& smem_iterator_B_) { + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < kNumStagesConcurrentLoad; + ++stage, --gemm_k_iterations) { + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + + if (kLoadA) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + } + + ++iterator_A; + } + + ++smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + if (kLoadB) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + } + + ++iterator_B; + } + + ++smem_iterator_B_; + } + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + smem_iterator_A_.add_tile_offset({0, 1}); + smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC& accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< initial value of accumulator + FragmentC const& src_accum) { + // + // Prologue + // + + if (!prologue_done_) { + _prologue( + iterator_A, + iterator_B, + gemm_k_iterations, + smem_iterator_A_, + smem_iterator_B_); + } else if (!kSmemContainsEntireMat) { + _prologue( + iterator_A, + iterator_B, + gemm_k_iterations, + smem_iterator_A_, + smem_iterator_B_); + } else { + gemm_k_iterations -= kNumStagesConcurrentLoad; + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // + // Clear the remaining tiles of SMEM. This is a functional requirement for + // some kernels so that all accumulator elements outside the GEMM footprint + // are zero. + // + + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + /// Iterator to write threadblock-scoped tile of A operand to shared + /// memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + + typename IteratorA::AccessType zero_A; + zero_A.clear(); + + last_smem_iterator_A.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_A.get()); + + *dst_ptr = zero_A; + + ++last_smem_iterator_A; + } + + /// Iterator to write threadblock-scoped tile of B operand to shared + /// memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; + + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_B.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B; + } + } + + // Waits until kStages-2 stages have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA warp_loaded_frag_A[2]; + WarpLoadedFragmentB warp_loaded_frag_B[2]; + WarpTransformedFragmentA warp_transformed_frag_A[2]; + WarpTransformedFragmentB warp_transformed_frag_B[2]; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma.transform( + warp_transformed_frag_A[0], + warp_transformed_frag_B[0], + warp_loaded_frag_A[0], + warp_loaded_frag_B[0]); + + // tf32x3 kernels use staging accumulation. warp_mma uses a temporary + // accumulator and this temporary accumulator is added to the final + // accumulator once in every mainloop iteration. + plus plus_accum; + + FragmentC tmp_accum; + + if (platform::is_same< + typename Operator::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + tmp_accum.clear(); + } + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-kNumStagesConcurrentLoad);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + this->warp_tile_iterator_A_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations); + + // In case of a non-circular buffer ("kSmemContainsEntireMat") + // make sure we don't load out of bounds data. + if (!kSmemContainsEntireMat || + gemm_k_iterations > (-kNumStagesConcurrentLoad) || + warp_mma_k < Base::kWarpGemmIterations - 1) { + this->warp_tile_iterator_A_.load( + warp_loaded_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load( + warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + } + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k > 0) + warp_mma.transform( + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + warp_loaded_frag_A[warp_mma_k % 2], + warp_loaded_frag_B[warp_mma_k % 2]); + + if (platform::is_same< + typename Operator::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + warp_mma( + tmp_accum, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + tmp_accum); + + if (warp_mma_k == 0) { + accum = plus_accum(accum, tmp_accum); + tmp_accum.clear(); + } + } else { + warp_mma( + accum, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + accum); + } + + // Issue global->shared copies for the this stage + if (!kSmemContainsEntireMat && + warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance( + iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + if (!kSmemContainsEntireMat) { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance( + iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); + } + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (!kSmemContainsEntireMat && + smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, + -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations) + warp_mma.transform( + warp_transformed_frag_A[(warp_mma_k + 1) % 2], + warp_transformed_frag_B[(warp_mma_k + 1) % 2], + warp_loaded_frag_A[(warp_mma_k + 1) % 2], + warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + } + } + + if (platform::is_same< + typename Operator::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + accum = plus_accum(accum, tmp_accum); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma_pipelined.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma_pipelined.h new file mode 100644 index 0000000000000000000000000000000000000000..b967b86c0131ada37a19be623486c61562588246 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma_pipelined.h @@ -0,0 +1,401 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "custom_mma_base.h" +#include "cutlass/gemm/gemm.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Transformation applied to A operand + typename TransformA_ = NumericArrayConverter< + typename SmemIteratorA_::Element, + typename IteratorA_::Element, + IteratorA_::Fragment::kElements>, + /// + /// Transformation applied to B operand + typename TransformB_ = NumericArrayConverter< + typename SmemIteratorB_::Element, + typename IteratorB_::Element, + IteratorB_::Fragment::kElements>, + /// Used for partial specialization + typename Enable = bool> +class CustomMmaPipelined : public CustomMmaBase { + public: + ///< Base class + using Base = CustomMmaBase; + + using Shape = + Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA = + IteratorA_; ///< Iterates over tiles of A operand in global memory + using IteratorB = + IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + using TransformA = TransformA_; + using TransformB = TransformB_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA = typename IteratorA::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) + static_assert( + (Base::kStages == 2), + "MmaPipelined requires kStages set to value 2"); + + static bool const kSmemContainsEntireMat = false; + + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + + protected: + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + CustomMmaPipelined( + typename Base::SharedStorageA& shared_storageA, + typename Base::SharedStorageB& shared_storageB, + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ) + : Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storageA.ref(), thread_idx), + smem_iterator_B_(shared_storageB.ref(), thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + CUTLASS_DEVICE + CustomMmaPipelined( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& st, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : CustomMmaPipelined( + st.operand_A, + st.operand_B, + thread_idx, + warp_idx, + lane_idx) {} + + CUTLASS_DEVICE + bool set_prologue_done(bool value) { + // NOT IMPLEMENTED FOR PIPELINED + } + + CUTLASS_DEVICE + bool set_zero_outside_bounds(bool value) { + // NOT NEEDED FOR PIPELINED + // shared memory will always be zero-filled + } + + template + CUTLASS_DEVICE static void prologue( + typename Base::SharedStorage& shared_storage, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + int thread_idx, + int problem_size_k) { + prologue( + shared_storage.operand_A, + shared_storage.operand_B, + iterator_A, + iterator_B, + thread_idx, + problem_size_k); + } + + template + CUTLASS_DEVICE static void prologue( + typename Base::SharedStorageA& shared_storageA, + typename Base::SharedStorageB& shared_storageB, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + int thread_idx, + int problem_size_k) { + // NOT IMPLEMENTED FOR PIPELINED + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile + IteratorA iterator_A, ///< iterator over A operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + FragmentC const& src_accum, ///< source accumulator tile + TransformA transform_A = + TransformA(), ///< transformation applied to A fragment + TransformB transform_B = + TransformB()) { ///< transformation applied to B fragment + + // + // Prologue + // + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentA tb_frag_A; + FragmentB tb_frag_B; + + tb_frag_A.clear(); + tb_frag_B.clear(); + + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + this->smem_iterator_A_.store(transform_A(tb_frag_A)); + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_A.clear_mask(gemm_k_iterations <= 1); + iterator_B.clear_mask(gemm_k_iterations <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* + // issuing shared memory loads (which have the tightest latency + // requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + // Write fragments to shared memory + this->smem_iterator_A_.store(transform_A(tb_frag_A)); + + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } else { + this->warp_tile_iterator_A_.add_tile_offset( + {0, + -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k == 0) { + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_A.clear_mask(gemm_k_iterations <= 2); + iterator_B.clear_mask(gemm_k_iterations <= 2); + } + + warp_mma( + accum, + warp_frag_A[warp_mma_k % 2], + warp_frag_B[warp_mma_k % 2], + accum); + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/find_default_mma.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/find_default_mma.h new file mode 100644 index 0000000000000000000000000000000000000000..560da450ff54a86e723e11b9aa00597095b73127 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/find_default_mma.h @@ -0,0 +1,191 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Cutlass provides helper template functions to figure out the right + datastructures to instantiate to run a GEMM with various parameters (see + `cutlass/gemm/threadblock/default_mma.h`). However, due to template + instantiation priority rules, it will only create an MmaMultiStage with + kStages=3 (otherwise creates an MmePipelined - which is not compatible with + FastF32). kStages=3 uses too much shared memory and we want to use kStages=2, + so we just copy-pasted some code from `default_mma.h` and + `default_mma_core.h` files and wrapped this template to allow our usecase. + + This is really only for the FastF32 case - aka using TensorCores with fp32. +*/ + +#pragma once + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Layout type for C and D matrix operand + typename LayoutC, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + typename Enable_ = void> +struct FindDefaultMma { + static constexpr bool AccumulatorsInRowMajor = false; + static constexpr SharedMemoryClearOption SharedMemoryClear = + SharedMemoryClearOption::kNone; + using DefaultMma = cutlass::gemm::threadblock::DefaultMma< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementAccumulator, + LayoutC, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + Stages, + Operator, + AccumulatorsInRowMajor, + SharedMemoryClear>; +}; + +/// Specialization for sm80 / FastF32 / multistage with kStages=2 +template < + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + int kStages, + typename Operator> +struct FindDefaultMma< + ElementA_, + LayoutA_, + kAlignmentA, + ElementB_, + LayoutB_, + kAlignmentB, + ElementAccumulator, + layout::RowMajor, + arch::OpClassTensorOp, + arch::Sm80, + ThreadblockShape, + WarpShape, + InstructionShape, + kStages, + Operator, + typename cutlass::platform::enable_if<(kAlignmentA > 1)>::type> { + using LayoutC = layout::RowMajor; + using OperatorClass = arch::OpClassTensorOp; + using ArchTag = arch::Sm80; + + using DefaultMma_ = cutlass::gemm::threadblock::DefaultMma< + ElementA_, + LayoutA_, + kAlignmentA, + ElementB_, + LayoutB_, + kAlignmentB, + ElementAccumulator, + LayoutC, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + 3, + Operator>; + struct DefaultMma : DefaultMma_ { + using MmaCore_ = typename DefaultMma_::MmaCore; + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< + typename MmaCore_::Shape, + typename DefaultMma_::IteratorA, + typename MmaCore_::SmemIteratorA, + MmaCore_::kCacheOpA, + typename DefaultMma_::IteratorB, + typename MmaCore_::SmemIteratorB, + MmaCore_::kCacheOpB, + ElementAccumulator, + LayoutC, + typename MmaCore_::MmaPolicy, + kStages>; + }; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/mma_accum_lambda_iterator.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/mma_accum_lambda_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..7692389c5c9bd77e68cc622d3cc774a7c605f1b5 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/mma_accum_lambda_iterator.h @@ -0,0 +1,378 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/functional.h" +#include "cutlass/gemm/warp/mma_simt_tile_iterator.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" +#include "cutlass/matrix_shape.h" + +/* +TensorCores have different accumulator layouts. +This file provides a class to easily map the accumulator +i-th element with the corresponding matrix row/col. +*/ + +template +struct AccumLambdaIteratorSm80 { + static_assert( + cutlass::platform:: + is_same::value, + "only RowMajor is supported"); + + using Policy = typename T::Policy; + using InstructionShape = typename T::InstructionShape; + using OpDelta = typename T::OpDelta; + using Shape = typename T::Shape; + static int const kElementsPerAccess = InstructionShape::kN / 4; + static int const kRowsPerTile = 8; + static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile; + + static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset( + int8_t lane_id, + int8_t warp_id, + typename T::TensorCoord const& tile_offset) { + int quad = (lane_id >> 2); + int lane_in_quad = (lane_id & 3); + return cutlass::MatrixCoord( + quad + tile_offset.row() * Shape::kRow, + lane_in_quad * kElementsPerAccess + + tile_offset.column() * Shape::kColumn); + } + + template + CUTLASS_DEVICE static void iterateRows( + cutlass::MatrixCoord& lane_offset, + FA beginRow, + FB op, + FC endRow) { + // See cutlass/gemm/warp/mma_tensor_op_tile_iterator.h + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < kAccumulatorRows; ++row) { + int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + + row * kRowsPerTile + lane_offset.row(); + beginRow(accum_m); + + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { + int mma_accum_start = kAccumulatorRows * kElementsPerAccess * + (mma_n * Policy::MmaIterations::kRow + mma_m); + CUTLASS_PRAGMA_UNROLL + for (int col = 0; col < kElementsPerAccess; ++col) { + int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + + col + lane_offset.column(); + int idx = mma_accum_start + row * kElementsPerAccess + col; + op(accum_m, accum_n, idx); + } + } + + endRow(accum_m); + } + } + } + + template + CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) { + // In each warp, 4 threads will work on the same row + // - the ones with the same `quad` + auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1); + myValue = fn(myValue, otherV); + otherV = __shfl_xor_sync(0xffffffff, myValue, 2); + myValue = fn(myValue, otherV); + int lane_in_quad = (lane_id & 3); + return lane_in_quad == 0; + } +}; + +template +struct AccumLambdaIteratorSm70 { + static_assert( + cutlass::platform:: + is_same::value, + "only RowMajor is supported"); + + using Policy = typename T::Policy; + using InstructionShape = typename T::InstructionShape; + using OpDelta = typename T::OpDelta; + using Shape = typename T::Shape; + using Element = accum_t; + + static int const kElementsPerPartial = 4; + using EleShapePerPatial = typename cutlass::platform::conditional< + cutlass::platform::is_same::value, + cutlass::MatrixShape<2, 2>, + cutlass::MatrixShape<1, 4>>::type; + static int const kElementsPerMma = 8; + static int const kAccumulatorPatials = 2; + using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>; + + static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset( + int8_t lane_id, + int8_t warp_id, + typename T::TensorCoord const& tile_offset) { + int quad = (lane_id >> 2); + int lane_in_quad = (lane_id & 3); + int accum_m, accum_n; + + if (cutlass::platform::is_same::value) { + // (quad[2],quad[0])+lane_in_quad[0] + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1); + // (quad[1])+lane_in_quad[1] + accum_n = + ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials + + (lane_in_quad & 2); + } else { + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + + lane_in_quad; // (quad[2],quad[0]) + accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials; + } + return cutlass::MatrixCoord( + accum_m + tile_offset.row() * Shape::kRow, + accum_n + tile_offset.column() * Shape::kColumn); + } + + template + CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) { + static_assert( + cutlass::platform::is_same::value, + "update to support non-float accum"); + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16 + // T0 & T2 share same line within a quad + auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 1); + myValue = fn(myValue, otherV); + // quad 0 and quad 2 are on the same lines + otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 3); + myValue = fn(myValue, otherV); + return (lane_id & ((1 << 1) | (1 << 3))) == 0; + } + + template + CUTLASS_DEVICE static void iterateRows( + cutlass::MatrixCoord& lane_offset, + FA beginRow, + FB op, + FC endRow) { + CUTLASS_PRAGMA_UNROLL + for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < EleShapePerPatial::kRow; ++m) { + int accum_m = tile_m * Policy::InterleavedTile::kRow + + mma_m * QuadShapePerPatialMma::kRow + m * 2 + lane_offset.row(); + beginRow(accum_m); + + CUTLASS_PRAGMA_UNROLL + for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; + ++tile_n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; + ++mma_n) { + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < kAccumulatorPatials; ++p) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < EleShapePerPatial::kColumn; ++n) { + int mma_accum_start = + (((tile_n * Policy::TileIterations::kRow + tile_m) * + Policy::MmaIterations::kColumn + + mma_n) * + Policy::MmaIterations::kRow + + mma_m) * + kElementsPerMma; + int accum_n = tile_n * Policy::InterleavedTile::kColumn + + mma_n * QuadShapePerPatialMma::kColumn + + p * Policy::InterleavedTile::kColumn / 2 + n + + lane_offset.column(); + int idx = mma_accum_start + p * kElementsPerPartial + + m * EleShapePerPatial::kColumn + n; + op(accum_m, accum_n, idx); + } + } + } + } + endRow(accum_m); + } + } + } + } +}; + +template +struct AccumLambdaIteratorSimt { + using Policy = typename T::Policy; + using Iterations = typename T::Iterations; + using Element = typename T::Element; + using Delta = typename T::Delta; + using Shape = typename T::Shape; + static_assert( + cutlass::platform:: + is_same::value, + "only RowMajor is supported"); + + template + CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) { + CUTLASS_PRAGMA_UNROLL + for (int bit = 1; bit < Policy::WarpShape::kColumn; bit *= 2) { + auto otherV = __shfl_xor_sync(0xffffffff, myValue, bit); + myValue = fn(myValue, otherV); + } + return (lane_id & (Policy::WarpShape::kColumn - 1)) == 0; + } + + template + CUTLASS_DEVICE static void iterateRows( + cutlass::MatrixCoord& lane_offset, + FA beginRow, + FB op, + FC endRow) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) { + int accum_m = mma_m * Delta::kRow + m + lane_offset.row(); + beginRow(accum_m); + + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) { + int accum_n = + mma_n * Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN + + lane_offset.column(); + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) { + int idx = n + + Policy::LaneMmaShape::kN * + (mma_n + + Iterations::kColumn * + (m + mma_m * Policy::LaneMmaShape::kM)); + op(accum_m, accum_n + n, idx); + } + } + endRow(accum_m); + } + } + } + + static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset( + int8_t lane_id, + int8_t warp_id, + typename T::TensorCoord const& tile_offset) { + static_assert( + cutlass::platform::is_same< + typename Policy::LaneLayout, + cutlass::layout::RowMajorInterleaved<1>>::value, + ""); + typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); + + cutlass::MatrixCoord lane_offset = lane_layout.inverse(lane_id) * + cutlass::MatrixCoord(Policy::LaneMmaShape::kM, + Policy::LaneMmaShape::kN); + return lane_offset + + tile_offset * cutlass::MatrixCoord(Shape::kRow, Shape::kColumn); + } +}; + +template +struct DefaultMmaAccumLambdaIterator; + +// Simt +template +struct DefaultMmaAccumLambdaIterator< + cutlass::gemm::warp::MmaSimtTileIterator< + S, + cutlass::gemm::Operand::kC, + accum_t, + cutlass::layout::RowMajor, + P, + 1, + 1>, + accum_t, + kWarpSize> { + using WarpIterator = typename cutlass::gemm::warp::MmaSimtTileIterator< + S, + cutlass::gemm::Operand::kC, + accum_t, + cutlass::layout::RowMajor, + P, + 1, + 1>; + using Iterator = AccumLambdaIteratorSimt; +}; + +// TensorOp - Volta +template +struct DefaultMmaAccumLambdaIterator< + cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + cutlass::MatrixShape<1, 1>>, + accum_t, + kWarpSize> { + using WarpIterator = + typename cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + cutlass::MatrixShape<1, 1>>; + using Iterator = AccumLambdaIteratorSm70; +}; + +// TensorOp - Sm75+ +template < + typename S1, + typename S2, + typename S3, + typename accum_t, + int kWarpSize> +struct DefaultMmaAccumLambdaIterator< + cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + S3>, + accum_t, + kWarpSize> { + using WarpIterator = + typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + S3>; + using Iterator = AccumLambdaIteratorSm80; +}; diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/mma_from_smem.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/mma_from_smem.h new file mode 100644 index 0000000000000000000000000000000000000000..f2b94d003119b144cea632b1302bfe024d15806d --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/mma_from_smem.h @@ -0,0 +1,1955 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tools and utils to store a GEMM output in shmem, and to use that + output as operandA for another GEMM back-to-back +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/functional.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/platform/platform.h" +#include "cutlass/transform/threadblock/vector_iterator.h" + +#include "../epilogue/epilogue_thread_apply_logsumexp.h" +#include "../gemm/mma_accum_lambda_iterator.h" +#include "../gemm_kernel_utils.h" +#include "../iterators/default_warp_iterator_from_smem.h" +#include "../iterators/make_residual_last.h" +#include "../iterators/transpose_warp_iterator.h" +#include "../iterators/warp_iterator_from_smem.h" +#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h" +#include "cutlass/gemm/threadblock/mma_base.h" +#include "cutlass/gemm/threadblock/mma_multistage.h" +#include "cutlass/gemm/threadblock/mma_pipelined.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +/// Shared storage object needed by accumulator +/// From 13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h +template < + typename Shape_, + typename Element_, + typename Layout_, + typename Padding_> +class AccumulatorSharedStorage { + public: + // + // Type definitions + // + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using Padding = Padding_; + + /// Tensor reference to the accumulator + using TensorRefAccum = cutlass::TensorRef; + + /// Shape of the accumulator matrix in shared memory + using ShapeAccum = cutlass:: + MatrixShape; + + public: + // + // Data members + // + + /// Buffer for accumulator + cutlass::AlignedBuffer accum; + + public: + // + // Methods + // + + /// Returns a layout object for the Accum matrix + CUTLASS_DEVICE + static Layout LayoutAccum() { + return Layout::packed({ShapeAccum::kRow, ShapeAccum::kColumn}); + } + + /// Returns a TensorRef to the Accumulator + CUTLASS_HOST_DEVICE + TensorRefAccum accum_ref() { + return TensorRefAccum{accum.data(), LayoutAccum()}; + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Taken from +// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + // Maximum K dimension - also the dimension of the shared-memory + // holding `OperandA` + int kMaxK_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Layout in shared-memory of operand A + typename SmemLayoutA, + /// Used for partial specialization + typename Enable = bool> +class MmaBaseFromSharedMemory { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + static constexpr int kMaxK = kMaxK_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape< + Shape::kM / WarpGemm::kM, + Shape::kN / WarpGemm::kN, + Shape::kK / WarpGemm::kK>; + using WarpCount1 = WarpCount; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator::Policy::MmaShape::kK); + static int const kWarpGemmIterations1 = kWarpGemmIterations; + + /// Number of stages + static int const kStages = Stages; + + /// If this is true, we fill the entire shmem buffer at start + /// and don't need to iterate through it in a circular fashion + static bool const kSmemContainsEntireB = kMaxK <= Shape::kK * kStages; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = + TensorRef; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the B matrix operand in shared memory + using ShapeB = MatrixShape< + Shape::kK * kStages + Policy::SmemPaddingB::kRow, + Shape::kN + Policy::SmemPaddingB::kColumn>; + + public: + // + // Data members + // + + /// Buffer for B operand + AlignedBuffer operand_B; + + public: + // + // Methods + // + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { + return TensorRefB{operand_B.data(), LayoutB()}; + } + }; + + protected: + // + // Data members + // + + // /// Iterator to load a warp-scoped tile of A operand from shared memory + // typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + MmaBaseFromSharedMemory( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + TensorRefB& b_tile, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_B_(b_tile, lane_idx) {} +}; + +namespace { + +// has necessary trait compliance with WarpIteratorFromSmem but doesn't do +// anything, can be default initialized, and uses fragment that takes up +// (almost) no space. this warp iterator is selected at compile time when +// elementwise on-the-fly scaling for operand A is disabled, in which case +// operations related to loading scale factors for operand A get wiped out by +// the compiler. +template +class NoOpWarpIteratorScale { + public: + // in pipelined+multistage MMA implementations we keep an array of fragments. + // if we aren't using scaling we don't want to waste registers on fragments + // of scale elements, so ideally this would be sized 0. + // Since arrays of zero-sized objects are not allowed, using size as 1. + // The compiler will most likely wipe it out anyways. + using Fragment = cutlass::Array; + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale() {} + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale(TensorRef const&, int) {} + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale& add_tile_offset( + typename TensorRef::TensorCoord const&) { + return *this; + } + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale& operator++() { + return *this; + } + + CUTLASS_DEVICE + void load(Fragment&) const {} +}; + +// if scaling is enabled, performs fragment elementwise multiplication between +// fragment and its scaling factor. +template +class FragmentElementwiseScaler; + +// specialization for scaling being enabled. +template +class FragmentElementwiseScaler { + public: + // cast scale_frag to correct type then apply elementwise to fragment + CUTLASS_DEVICE + static Fragment apply(Fragment frag, FragmentScale const& scale_frag) { + Fragment converted_scale_frag = cutlass::NumericArrayConverter< + typename Fragment::Element, + typename FragmentScale::Element, + FragmentScale::kElements>()(scale_frag); + return cutlass::multiplies()(frag, converted_scale_frag); + } +}; + +// specialization for scaling being disabled. doesn't do anything and should +// just get wiped out by the compiler. +template +class FragmentElementwiseScaler { + public: + CUTLASS_DEVICE + static Fragment apply(Fragment frag, FragmentScale const&) { + return frag; + } +}; +} // namespace + +//////////////////////////////////////////////////////////////////////////////// +// Taken from +// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + // BEGIN smem + /// Iterates over the intermediate accumulator tile in shared memory + typename WarpIteratorA_, + /// whether or not to perform elementwise multiplication of A + // by another matrix (A_scale) that is also kept in shared memory prior + // to matmul A @ B + bool ScaleOperandA_, + /// Max GEMM problem size in K dimension + int MaxK, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Transformation applied to B operand + typename TransformB_ = NumericArrayConverter< + typename SmemIteratorB_::Element, + typename IteratorB_::Element, + IteratorB_::Fragment::kElements>, + /// Used for partial specialization + typename Enable = bool> +class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory< + Shape_, + MaxK, + Policy_, + 2, + typename WarpIteratorA_::Layout> { + public: + ///< Base class + using Base = MmaBaseFromSharedMemory< + Shape_, + MaxK, + Policy_, + 2, + typename WarpIteratorA_::Layout>; + + using Shape = + Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + static constexpr bool ScaleOperandA = ScaleOperandA_; + + using WarpIteratorA = WarpIteratorA_; + ///< loads fragments of A_scale from shared memory if operand A scaling is + ///< enabled. otherwise no-op. + using WarpIteratorAScale = typename cutlass::platform::conditional< + ScaleOperandA, + WarpIteratorA, + NoOpWarpIteratorScale>::type; + + using IteratorB = + IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + using SmemIteratorB = SmemIteratorB_; + + using TransformB = TransformB_; + + // + // Dependent types + // + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) + static_assert( + (Base::kStages == 2), + "MmaPipelined requires kStages set to value 2"); + + private: + using WarpFragmentA = typename Operator::FragmentA; + + /// fragment type of OperandA elementwise scaling matrix. (almost) empty + /// if operand A scaling is disabled. + using WarpFragmentAScale = typename WarpIteratorAScale::Fragment; + + using WarpFragmentB = typename Operator::FragmentB; + + /// applies scaling factor to operand A fragment if operand A scaling is + /// enabled. otherwise no-op. + using FragmentAScaler = FragmentElementwiseScaler< + WarpFragmentA, + WarpFragmentAScale, + ScaleOperandA>; + + protected: + // /// Iterator to write threadblock-scoped tile of A operand to shared memory + // SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to load a warp-scoped tile of A operand from intermediate + /// accumulator tile + WarpIteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of A_scale from intermediate + /// accumulator tile (only used if ScaleOperandA_ is true) + WarpIteratorAScale warp_tile_iterator_A_scale_; + + public: + /// constructor for MMA with operand A scaling enabled. + CUTLASS_DEVICE + MmaPipelinedFromSharedMemory( + typename Base::TensorRefA a, // Operand A in shared memory + typename Base::TensorRefA a_scale, // Operand A_scale in shared memory + typename Base::TensorRefB + b_staging, // staging memory for loading tiles of B + int thread_idx, + int warp_idx, + int lane_idx) + : Base(b_staging, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A_(a, lane_idx), + warp_tile_iterator_A_scale_(a_scale, lane_idx), + smem_iterator_B_(b_staging, thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_A_scale_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + /// Construct from tensor references + CUTLASS_DEVICE + MmaPipelinedFromSharedMemory( + typename Base::TensorRefA a, ///< Operand A in shared memory + typename Base::TensorRefB b_staging, ///< staging memory for loading B + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx) ///< ID of each thread within a warp + : Base(b_staging, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A_(a, lane_idx), + smem_iterator_B_(b_staging, thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + // For API compatibility with MmaMultistageFromSharedMemory + // but not supported as it worsens perf: older gpus < sm80 don't + // support async transfers and have to waste registers + CUTLASS_DEVICE + void set_prologue_done(bool value) {} + CUTLASS_DEVICE + static void prologue( + typename Base::SharedStorage& shared_storage, + IteratorB iterator_B1, + int thread_idx, + int problem_size_0_n) {} + + CUTLASS_DEVICE + static void drain_cp_asyncs() {} + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile + // IteratorA iterator_A, ///< iterator over A + // operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + FragmentC const& src_accum, ///< source accumulator tile + // TransformA transform_A = TransformA(), ///< transformation + // applied to A fragment + TransformB transform_B = + TransformB()) { ///< transformation applied to B fragment + + // + // Prologue + // + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentB tb_frag_B; + + tb_frag_B.clear(); + + // The last kblock is loaded in the prolog + iterator_B.set_residual_tile(gemm_k_iterations == 1); + iterator_B.load(tb_frag_B); + + ++iterator_B; + + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + ++this->smem_iterator_B_; + + __syncthreads(); + + // remember that WarpFragmentAScale and WarpIteratorAScale are empty/no-op + // if scaling is disabled. + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentAScale warp_frag_A_scale[2]; + WarpFragmentB warp_frag_B[2]; + warp_frag_A[0].clear(); + warp_frag_A_scale[0].clear(); + warp_frag_B[0].clear(); + + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_A_scale_.load(warp_frag_A_scale[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_A_scale_; + ++this->warp_tile_iterator_B_; + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_B.set_residual_tile(gemm_k_iterations == 2); + iterator_B.clear_mask(gemm_k_iterations <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* + // issuing shared memory loads (which have the tightest latency + // requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + bool hasNext = true; + + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + if (gemm_k_iterations > 1) { + // Write fragments to shared memory + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + } + + __syncthreads(); + + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory SMEM: Don't reset iterator A, as + // we are continuing our iteration at this point + if (smem_write_stage_idx == 1) { + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } else { + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + } + + smem_write_stage_idx ^= 1; + hasNext = gemm_k_iterations > 1; + } + + // Only read the next if we need to + if (hasNext) { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_A_scale_.load( + warp_frag_A_scale[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_A_scale_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k == 0) { + iterator_B.load(tb_frag_B); + + ++iterator_B; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_B.set_residual_tile(gemm_k_iterations == 3); + iterator_B.clear_mask(gemm_k_iterations <= 2); + } + } + + warp_mma( + accum, + FragmentAScaler::apply( + warp_frag_A[warp_mma_k % 2], warp_frag_A_scale[warp_mma_k % 2]), + warp_frag_B[warp_mma_k % 2], + accum); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Taken from +// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape1_, + /// Iterates over the intermediate accumulator tile in shared memory + typename WarpIteratorA1_, + /// whether or not to perform elementwise multiplication of A + // by another matrix (A_scale) that is also kept in shared memory prior + // to matmul A @ B + bool ScaleOperandA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB1_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB1_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB1, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy1_, + /// Number of stages, + int Stages_, + int kMaxK_, + /// Used for partial specialization + typename Enable = bool> +class MmaMultistageFromSharedMemory : public MmaBaseFromSharedMemory< + Shape1_, + kMaxK_, + Policy1_, + Stages_, + typename WarpIteratorA1_::Layout> { + public: + ///< Base class + using Base = MmaBaseFromSharedMemory< + Shape1_, + kMaxK_, + Policy1_, + Stages_, + typename WarpIteratorA1_::Layout>; + + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape1 = Shape1_; + ///< Iterates over tiles of B operand in global memory + using IteratorB1 = IteratorB1_; + using IteratorB = IteratorB1; + ///< Policy describing tuning details + using Policy1 = Policy1_; + + using SmemIteratorB1 = SmemIteratorB1_; + using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate + ///< accumulator tile in shared memory + static constexpr bool ScaleOperandA = ScaleOperandA_; + + ///< warp level iterator over A_scale matrix tile kept in shared memory. + ///< if elementwise A scaling is disabled then everything this does is no-op. + using WarpIteratorAScale = typename cutlass::platform::conditional< + ScaleOperandA, + WarpIteratorA1, + NoOpWarpIteratorScale>::type; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1; + static constexpr bool kSmemContainsEntireB = Base::kSmemContainsEntireB; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC1 = typename Policy1::Operator::FragmentC; + using FragmentC = FragmentC1; + + /// Warp-level Mma + using Operator1 = typename Policy1::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on B operand + static ComplexTransform const kTransformB1 = Operator1::kTransformB; + + /// Internal structure exposed for introspection. + struct Detail { + static_assert( + Base::kWarpGemmIterations1 > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand B + static int const TBLoadIterationsB1 = + IteratorB1::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB1 = + (TBLoadIterationsB1 + Base::kWarpGemmIterations1 - 1) / + Base::kWarpGemmIterations1; + }; + + static constexpr int kNumStagesConcurrentLoad = + kSmemContainsEntireB ? Base::kStages : Base::kStages - 1; + + private: + using WarpLoadedFragmentA1 = typename Operator1::FragmentA; + /// fragment of OperandA scale matrix. if operand A scaling is disabled this + /// is (almost) empty. + using WarpLoadedFragmentA1Scale = typename WarpIteratorAScale::Fragment; + using WarpLoadedFragmentB1 = typename Operator1::FragmentB; + using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA; + using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB; + + /// applies elementwise scaling to fragment of A. if operand A scaling is + /// disabled this is a no-op. + using FragmentAScaler = FragmentElementwiseScaler< + WarpLoadedFragmentA1, + WarpLoadedFragmentA1Scale, + ScaleOperandA>; + + private: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A1 operand from intermediate + /// accumulator tile + WarpIteratorA1 warp_tile_iterator_A1_; + + /// Iterator to load a warp-scoped tile of A1_scale operand from shared memory + /// if operand A scaling is disabled everything this does is a no-op. + WarpIteratorAScale warp_tile_iterator_A1_scale_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB1 smem_iterator_B1_; + + bool prologue_done_; + + public: + /// constructor for MMA with operand A scaling enabled. + CUTLASS_DEVICE + MmaMultistageFromSharedMemory( + typename Base::TensorRefA a, + typename Base::TensorRefA a_scale, + typename Base::TensorRefB b_tile, + int thread_idx, + int warp_idx, + int lane_idx) + : Base(b_tile, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A1_(a, lane_idx), + warp_tile_iterator_A1_scale_(a_scale, lane_idx), + smem_iterator_B1_(b_tile, thread_idx), + prologue_done_(false) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + int warp_idx_mn_1 = + warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; + int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; + + // Add per-warp offsets in units of warp-level tiles + warp_tile_iterator_A1_.add_tile_offset( + {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); + warp_tile_iterator_A1_scale_.add_tile_offset( + {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1}); + } + + /// Construct from tensor references + CUTLASS_DEVICE + MmaMultistageFromSharedMemory( + typename Base::TensorRefA a, + typename Base::TensorRefB b_tile, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(b_tile, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A1_(a, lane_idx), + smem_iterator_B1_(b_tile, thread_idx), + prologue_done_(false) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn_1 = + warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); + + int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; + int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; + + // Add per-warp offsets in units of warp-level tiles + warp_tile_iterator_A1_.add_tile_offset( + {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1}); + } + + CUTLASS_DEVICE + void set_prologue_done(bool value) { + prologue_done_ = value; + } + + CUTLASS_DEVICE + static void prologue( + typename Base::SharedStorage& shared_storage, + IteratorB iterator_B1, + int thread_idx, + int problem_size_0_n) { + SmemIteratorB1 smem_iterator_B1(shared_storage.operand_B_ref(), thread_idx); + _prologue( + iterator_B1, + (problem_size_0_n + Base::Shape::kK - 1) / Base::Shape::kK, + smem_iterator_B1); + } + + CUTLASS_DEVICE + static void drain_cp_asyncs() { + // commit and drain all pending and predicated cp.async pnz from the GEMM + // mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance_1( + IteratorB1& iterator_B1, + int group_start_B1 = 0) { + iterator_B1.set_iteration_index( + group_start_B1 * IteratorB1::kAccessesPerVector); + this->smem_iterator_B1_.set_iteration_index(group_start_B1); + + // Load for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) { + if (group_start_B1 + j < Detail::TBLoadIterationsB1) { + typename IteratorB1::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B1_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / + IteratorB1::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B1.get(); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B1.valid()); + + ++iterator_B1; + } + ++this->smem_iterator_B1_; + } + } + } + + CUTLASS_DEVICE + static void _prologue( + IteratorB& iterator_B1, + int32_t gemm_k_iterations_1, + SmemIteratorB1& smem_iterator_B1_) { + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < kNumStagesConcurrentLoad; + ++stage, --gemm_k_iterations_1) { + iterator_B1.set_residual_tile(gemm_k_iterations_1 == 1); + iterator_B1.clear_mask(gemm_k_iterations_1 == 0); + + iterator_B1.set_iteration_index(0); + smem_iterator_B1_.set_iteration_index(0); + + // Load for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLoadIterationsB1; ++j) { + typename IteratorB1::AccessType* dst_ptr = + reinterpret_cast( + smem_iterator_B1_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / + IteratorB1::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B1.get(), iterator_B1.valid()); + + ++iterator_B1; + } + + ++smem_iterator_B1_; + } + + // Move to the next stage + iterator_B1.add_tile_offset({1, 0}); + + smem_iterator_B1_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + iterator_B1.set_residual_tile(gemm_k_iterations_1 == 1); + iterator_B1.clear_mask(gemm_k_iterations_1 == 0); + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations_1_, + ///< destination accumulator tile + FragmentC1& accum, + ///< iterator over B1 operand in global memory + IteratorB1 iterator_B1, + ///< initial value of accumulator + FragmentC1 const& src_accum) { + // 2nd Gemm + + // + // Prologue + // + // Perform accumulation in the 'd' output operand + accum = src_accum; + + if (!prologue_done_) { + _prologue(iterator_B1, gemm_k_iterations_1_, smem_iterator_B1_); + } else if (!kSmemContainsEntireB) { + // Restore the iterators increments + + int gemm_k_iterations_1 = gemm_k_iterations_1_; + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < kNumStagesConcurrentLoad; + ++stage, --gemm_k_iterations_1) { + iterator_B1.set_iteration_index(0); + this->smem_iterator_B1_.set_iteration_index(0); + + // Load for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLoadIterationsB1; ++j) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + ++iterator_B1; + } + ++this->smem_iterator_B1_; + } + iterator_B1.add_tile_offset({1, 0}); + this->smem_iterator_B1_.add_tile_offset({1, 0}); + } + iterator_B1.set_residual_tile(gemm_k_iterations_1 <= 1); + iterator_B1.clear_mask(gemm_k_iterations_1 <= 0); + } + + // DEPBAR+SYNC + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // remember that WarpFragmentAScale and WarpIteratorAScale are no-op/empty + // if scaling is disabled. + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA1 warp_loaded_frag_A1[2]; + WarpLoadedFragmentA1Scale warp_loaded_frag_A1_scale[2]; + WarpLoadedFragmentB1 warp_loaded_frag_B1[2]; + WarpTransformedFragmentA1 warp_transformed_frag_A1[2]; + WarpTransformedFragmentB1 warp_transformed_frag_B1[2]; + + Operator1 warp_mma1; + + warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0]); + ++warp_tile_iterator_A1_; + + warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[0]); + ++warp_tile_iterator_A1_scale_; + + this->warp_tile_iterator_B_.set_kgroup_index(0); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B1[0]); + ++this->warp_tile_iterator_B_; + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma1.transform( + warp_transformed_frag_A1[0], + warp_transformed_frag_B1[0], + FragmentAScaler::apply( + warp_loaded_frag_A1[0], warp_loaded_frag_A1_scale[0]), + warp_loaded_frag_B1[0]); + + // tf32x3 kernels use staging accumulation. warp_mma uses a temporary + // accumulator and this temporary accumulator is added to the final + // accumulator once in every mainloop iteration. + plus plus_accum; + + FragmentC1 tmp_accum; + + if (platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + tmp_accum.clear(); + } + + // + // Mainloop + // + + CUTLASS_PRAGMA_UNROLL + for (int gemm_k_iterations_1 = gemm_k_iterations_1_ - (Base::kStages - 1); + gemm_k_iterations_1 > (-Base::kStages + 1); + gemm_k_iterations_1--) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; + ++warp_mma_k) { + // Load warp-level tile from accumulator fragment (A) + // or shared memory (operand B) + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations1); + // skip warp tile loading for the last kgroup (we are out of the buf) + if (gemm_k_iterations_1 > (-Base::kStages + 2) || + warp_mma_k < Base::kWarpGemmIterations1 - 1) { + warp_tile_iterator_A1_.load( + warp_loaded_frag_A1[(warp_mma_k + 1) % 2]); + warp_tile_iterator_A1_scale_.load( + warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load( + warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + } + ++warp_tile_iterator_A1_; + ++warp_tile_iterator_A1_scale_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k > 0) + warp_mma1.transform( + warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + FragmentAScaler::apply( + warp_loaded_frag_A1[warp_mma_k % 2], + warp_loaded_frag_A1_scale[warp_mma_k % 2]), + warp_loaded_frag_B1[warp_mma_k % 2]); + + if (platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + warp_mma1( + tmp_accum, + warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + tmp_accum); + + if (warp_mma_k == 0) { + accum = plus_accum(accum, tmp_accum); + tmp_accum.clear(); + } + } else { + warp_mma1( + accum, + warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + accum); + } + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations1 - 1) { + int group_start_iteration_B1; + + group_start_iteration_B1 = warp_mma_k * Detail::kAccessesPerGroupB1; + + if (!kSmemContainsEntireB) { + copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); + } + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations1) { + int group_start_iteration_B1; + group_start_iteration_B1 = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB1; + + if (!kSmemContainsEntireB) { + copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); + } + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_B1.add_tile_offset({1, 0}); + + this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (!kSmemContainsEntireB) { + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy1::kPartitionsK * + Base::kWarpGemmIterations1, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + } + + iterator_B1.set_residual_tile(gemm_k_iterations_1 == 2); + iterator_B1.clear_mask(gemm_k_iterations_1 == 1); + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations1) + warp_mma1.transform( + warp_transformed_frag_A1[(warp_mma_k + 1) % 2], + warp_transformed_frag_B1[(warp_mma_k + 1) % 2], + FragmentAScaler::apply( + warp_loaded_frag_A1[(warp_mma_k + 1) % 2], + warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]), + warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + } + } + + if (platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + accum = plus_accum(accum, tmp_accum); + } + } +}; + +// Converts a "regular" Mma into their counterpart from shared memory +template < + typename Mma_, + int kMaxK, + typename WarpIteratorA_, + /// whether or not to apply elementwise multiplication of operand A by + /// another matrix in shared memory before usage in A @ B + bool kScaleOperandA, + bool kTransposeA = false> +struct DefaultMmaFromSharedMemory; + +// Mma pipelined +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + typename WarpIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Transformation applied to A operand + typename TransformA_, + /// Transformation applied to B operand + typename TransformB_, + // Max MMA problem size K + int kMaxK, + /// whether or not to apply elementwise multiplication of operand A by + /// another matrix in shared memory before usage in A @ B + bool kScaleOperandA, + bool kTransposeA> +struct DefaultMmaFromSharedMemory< + MmaPipelined< + Shape_, + IteratorA_, + SmemIteratorA_, + IteratorB_, + SmemIteratorB_, + ElementC_, + LayoutC_, + Policy_, + TransformA_, + TransformB_>, + kMaxK, + WarpIteratorA_, + kScaleOperandA, + kTransposeA> { + using RegularMma = MmaPipelined< + Shape_, + IteratorA_, + SmemIteratorA_, + IteratorB_, + SmemIteratorB_, + ElementC_, + LayoutC_, + Policy_, + TransformA_, + TransformB_>; + + using WarpShape = typename Policy_::Operator::Shape; + using InstructionShape = typename Policy_::Operator::InstructionShape; + using ArchMmaOperator = typename Policy_::Operator; + + static constexpr bool kIsTransposedA = false; + using WarpIteratorA = WarpIteratorA_; + using IteratorB = + typename cutlass::transform::threadblock::MakeIteratorResidualLast< + IteratorB_>::Iterator; + + using Mma = typename cutlass::gemm::threadblock::MmaPipelinedFromSharedMemory< + Shape_, + WarpIteratorA, + kScaleOperandA, + kMaxK, + IteratorB, + SmemIteratorB_, + ElementC_, + LayoutC_, + Policy_>; +}; + +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + typename WarpIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + int kMaxK, + /// whether or not to apply elementwise multiplication of operand A by + /// another matrix in shared memory before usage in A @ B + bool kScaleOperandA, + bool kTransposeA> +struct DefaultMmaFromSharedMemory< + MmaMultistage< + Shape_, + IteratorA_, + SmemIteratorA_, + CacheOpA, + IteratorB_, + SmemIteratorB_, + CacheOpB, + ElementC_, + LayoutC_, + Policy_, + Stages, + SharedMemoryClear>, + kMaxK, + WarpIteratorA_, + kScaleOperandA, + kTransposeA> { + using RegularMma = MmaMultistage< + Shape_, + IteratorA_, + SmemIteratorA_, + CacheOpA, + IteratorB_, + SmemIteratorB_, + CacheOpB, + ElementC_, + LayoutC_, + Policy_, + Stages, + SharedMemoryClear>; + + using WarpShape = typename Policy_::Operator::Shape; + using InstructionShape = typename Policy_::Operator::InstructionShape; + using WarpIteratorTranspose = TransposeWarpIterator; + static constexpr bool kIsTransposedA = + WarpIteratorTranspose::kSupportsTranspose && kTransposeA; + using WarpIteratorA = typename platform::conditional< + kIsTransposedA, + typename WarpIteratorTranspose::Iterator, + WarpIteratorA_>::type; + + // Reduce the number of stages if we don't need that many + static int constexpr kStagesMax = + (kMaxK + int(Shape_::kK) - 1) / int(Shape_::kK); + static int constexpr kStages = cutlass::const_min(Stages, kStagesMax); + + using IteratorB = + typename cutlass::transform::threadblock::MakeIteratorResidualLast< + IteratorB_>::Iterator; + using Mma = + typename cutlass::gemm::threadblock::MmaMultistageFromSharedMemory< + Shape_, + WarpIteratorA, + kScaleOperandA, + IteratorB, + SmemIteratorB_, + RegularMma::kCacheOpB, + ElementC_, + LayoutC_, + Policy_, + kStages, + kMaxK>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename IteratorC, + typename Operator, + typename scalar_t, + typename WarpShape_, + typename ThreadblockShape_> +struct B2bGemm; + +// Tensor Cores >= Sm75 specialization (Ampere ...) +template < /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Element type + typename Element_, + /// Layout of operand in memory + typename Layout_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_, + /// Interval between adjacent *MMA instructions (in units of MMA + /// instructions, concept: MatrixShape) + typename OpDelta_, + typename Operator, + typename scalar_t, + typename WarpShape_, + typename ThreadblockShape_> +struct B2bGemm< + cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< + Shape_, + Element_, + Layout_, + InstructionShape_, + OpDelta_>, + Operator, + scalar_t, + WarpShape_, + ThreadblockShape_> { + using IteratorC = + typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< + Shape_, + Element_, + Layout_, + InstructionShape_, + OpDelta_>; + using FragmentC = typename IteratorC::Fragment; + using InstructionShape = InstructionShape_; + using WarpShape = WarpShape_; + using ThreadblockShape = ThreadblockShape_; + using accum_t = Element_; + using lse_scalar_t = float; + + using SmemAccumulatorLayout = cutlass::layout::RowMajor; + + // Iterator to load accumulators (results of matmul in registers) + using FragmentIteratorAccumulator = + cutlass::epilogue::warp::FragmentIteratorTensorOp< + WarpShape, + InstructionShape, + accum_t, + typename Operator::Policy::Operator::FragmentC, + cutlass::layout::RowMajor>; + + // Iterator to store to shared-memory + using SmemIteratorD0 = typename cutlass::epilogue::warp::TileIteratorTensorOp< + WarpShape, + InstructionShape, + scalar_t, // accum_t, + SmemAccumulatorLayout>; + using AccumulatorSharedStorage = + cutlass::gemm::threadblock::AccumulatorSharedStorage< + ThreadblockShape, + typename SmemIteratorD0::Element, + typename SmemIteratorD0::TensorLayout, + typename SmemIteratorD0::Padding>; + // We need to provide an operation for the epilogue. Let's create an + // operation that does nothing (ScaleType::Nothing), just converts + // from accum_t (float) -> scalar_t (can be half) + using OutputOpNoOp = cutlass::epilogue::thread::LinearCombination< + typename SmemIteratorD0::Element, // ElementOutput + FragmentIteratorAccumulator::Fragment::kElements, + accum_t, // ElementAccumulator + typename SmemIteratorD0::Element, // ElementCompute + cutlass::epilogue::thread::ScaleType::Nothing>; + using Epilogue = cutlass::epilogue::threadblock::EpilogueSmemAccumulator< + SmemIteratorD0, + FragmentIteratorAccumulator, + SmemIteratorD0, // ScaleBiasIterator - not used + OutputOpNoOp>; + + // Epilogue 2: with LSE (for backwards pass) + static int const kElementsPerAccess = 2; // TODO: Why 2? + using IteratorAccumulatorLSE = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + // Shape + cutlass::MatrixShape, + // WarpShape + cutlass::MatrixShape, + lse_scalar_t, + cutlass::layout::RowMajor, + kElementsPerAccess>>; + using EpilogueOpApplyLSE = cutlass::epilogue::thread::ApplyLogSumExp< + scalar_t, // ElementOutput_ + lse_scalar_t, // ElementLSE_ + accum_t, // ElementAccumulator_ + accum_t, // ElementCompute_ + 128 / cutlass::sizeof_bits::value + // FragmentIteratorAccumulator::Fragment::kElements + // InstructionShape::kM * InstructionShape::kN / 32 + >; + using EpilogueWithLSE = + cutlass::epilogue::threadblock::EpilogueSmemAccumulator< + SmemIteratorD0, + FragmentIteratorAccumulator, + IteratorAccumulatorLSE, + EpilogueOpApplyLSE>; + + static void CUTLASS_DEVICE accumToSmem( + AccumulatorSharedStorage& shared_storage, + FragmentC const& accum, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + SmemIteratorD0 smem_iterator_attn(shared_storage.accum_ref(), lane_id); + smem_iterator_attn.add_tile_offset( + tile_coords * + cutlass::MatrixCoord{ + SmemIteratorD0::TileIterations::kRow, + SmemIteratorD0::TileIterations::kColumn}); + Epilogue epilogue; + epilogue(OutputOpNoOp({}), smem_iterator_attn, accum); + } + + static void CUTLASS_DEVICE accumApplyLSEToSmem( + AccumulatorSharedStorage& shared_storage, + FragmentC& accum, + lse_scalar_t const* lse, + int32_t lse_extents, + int thread_id, + int warp_id, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + constexpr int32_t kAlignLSE = 32; + IteratorAccumulatorLSE iterator_lse( + lse, + {(int32_t)0, (int32_t)ceil_div(lse_extents, kAlignLSE) * kAlignLSE}, + thread_id, + warp_id, + cutlass::MatrixCoord{0, 0} // offset + ); + + SmemIteratorD0 smem_iterator_attn(shared_storage.accum_ref(), lane_id); + smem_iterator_attn.add_tile_offset( + tile_coords * + cutlass::MatrixCoord{ + SmemIteratorD0::TileIterations::kRow, + SmemIteratorD0::TileIterations::kColumn}); + EpilogueWithLSE epilogue; + EpilogueOpApplyLSE minus_lse_exp({}); + epilogue( + minus_lse_exp, + smem_iterator_attn, + accum, + // scale - unused + iterator_lse, + // bias + iterator_lse); + } +}; + +// Volta Specialization +// only supported for f16 +template +struct B2bGemm< + cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + cutlass::MatrixShape<32, 32>, + float, + cutlass::layout::RowMajor, + cutlass::gemm::GemmShape<16, 16, 4>, + cutlass::MatrixShape<1, 1>>, + Operator, + cutlass::half_t, + WarpShape_, + ThreadblockShape_> { + using IteratorC = + cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + cutlass::MatrixShape<32, 32>, + float, + cutlass::layout::RowMajor, + cutlass::gemm::GemmShape<16, 16, 4>, + cutlass::MatrixShape<1, 1>>; + using scalar_t = cutlass::half_t; + using accum_t = IteratorC::Element; + using WarpShape = WarpShape_; + using ThreadblockShape = ThreadblockShape_; + using FragmentC = IteratorC::Fragment; + using lse_scalar_t = float; + + // Storage in shared-memory for Q.Kt + using SmemAccumulatorLayout = + cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>; + using AccumulatorSharedStorage = + cutlass::gemm::threadblock::AccumulatorSharedStorage< + ThreadblockShape, + scalar_t, + SmemAccumulatorLayout, + cutlass::MatrixShape<0, 0> // Padding + >; + using TensorRef = cutlass::TensorRef; + using Policy = typename IteratorC::Policy; + using Element = accum_t; + // Those are MmaVoltaTensorOpAccumulatorTileIterator private fields + // Let's copy their values + static int const kElementsPerPartial = 4; + using EleShapePerPatial = typename cutlass::platform::conditional< + cutlass::platform::is_same::value, + cutlass::MatrixShape<2, 2>, + cutlass::MatrixShape<1, 4>>::type; + static int const kElementsPerMma = 8; + static int const kAccumulatorPatials = 2; + using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>; + + static void CUTLASS_DEVICE accumToSmem( + AccumulatorSharedStorage& shared_storage, + FragmentC const& accum, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + // ctor - from MmaVoltaTensorOpAccumulatorTileIterator + TensorRef ref_(shared_storage.accum_ref()); + int quad = (lane_id >> 2); + int lane_in_quad = (lane_id & 3); + int accum_m, accum_n; + + if (cutlass::platform::is_same::value) { + // (quad[2],quad[0])+lane_in_quad[0] + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1); + // (quad[1])+lane_in_quad[1] + accum_n = + ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials + + (lane_in_quad & 2); + } else { + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + + lane_in_quad; // (quad[2],quad[0]) + accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials; + } + cutlass::MatrixCoord lane_offset(accum_m, accum_n); + + // Tile offset + ref_.add_coord_offset( + tile_coords * + cutlass::MatrixCoord( + {IteratorC::Shape::kRow, IteratorC::Shape::kColumn})); + + using AccessType = cutlass::Array; + + // store - from MmaVoltaTensorOpAccumulatorTileIterator + CUTLASS_PRAGMA_UNROLL + for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; ++tile_n) { + CUTLASS_PRAGMA_UNROLL + for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + int mma_accum_start = + (((tile_n * Policy::TileIterations::kRow + tile_m) * + Policy::MmaIterations::kColumn + + mma_n) * + Policy::MmaIterations::kRow + + mma_m) * + kElementsPerMma; + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < kAccumulatorPatials; ++p) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < EleShapePerPatial::kRow; ++m) { + int accum_m = tile_m * Policy::InterleavedTile::kRow + + mma_m * QuadShapePerPatialMma::kRow + m * 2; + int accum_n = tile_n * Policy::InterleavedTile::kColumn + + mma_n * QuadShapePerPatialMma::kColumn + + p * Policy::InterleavedTile::kColumn / 2; + int r = (accum_m + lane_offset.row()); + AccessType to_store; + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < EleShapePerPatial::kColumn; ++n) { + int idx = mma_accum_start + p * kElementsPerPartial + + m * EleShapePerPatial::kColumn + n; + int c = (accum_n + n + lane_offset.column()); + to_store[n] = scalar_t(accum[idx]); + } + int c = (accum_n + lane_offset.column()); + assert(r < 32); + assert(c < 32); + *reinterpret_cast( + ref_.data() + ref_.offset({r, c})) = to_store; + } + } + } + } + } + } + } + + static void CUTLASS_DEVICE accumApplyLSEToSmem( + AccumulatorSharedStorage& shared_storage, + typename IteratorC::Fragment& accum, + lse_scalar_t const* lse, + int lse_extent, + int thread_id, + int warp_id, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + // Non-optimized way to apply LSE to registers + // NOTE: accum is attn.T + // TODO: Optimize for each architecture + static constexpr int WarpSize = 32; + using AccumLambdaIterator = + typename DefaultMmaAccumLambdaIterator:: + Iterator; + auto lane_offset = + AccumLambdaIterator::get_lane_offset(lane_id, warp_id, tile_coords); + + cutlass::Array lse_prefetched; + lse_prefetched.clear(); + int rowIdx = 0; + int colIdx = 0; + AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + ++rowIdx; + colIdx = 0; + }, + [&](int accum_m, int accum_n, int idx) { + if (rowIdx == 1) { + lse_prefetched[colIdx] = accum_n < lse_extent + ? lse[accum_n] + : cutlass::platform::numeric_limits::infinity(); + } + accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]); + ++colIdx; + }, + [&](int accum_m) {}); + accumToSmem(shared_storage, accum, lane_id, tile_coords); + } +}; + +// Simt Specialization +// for f32 on Sm70-Sm75 and f16/f32 below + +template < + typename Operator, + typename OperatorPolicy, + typename scalar_t, + typename WarpShape_, + typename ThreadblockShape_> +struct B2bGemm< + cutlass::gemm::warp::MmaSimtTileIterator< + cutlass::MatrixShape<32, 32>, + cutlass::gemm::Operand::kC, + float, + cutlass::layout::RowMajor, + OperatorPolicy, + 1, + 1>, + Operator, + scalar_t, + WarpShape_, + ThreadblockShape_> { + using IteratorC = cutlass::gemm::warp::MmaSimtTileIterator< + cutlass::MatrixShape<32, 32>, + cutlass::gemm::Operand::kC, + float, + cutlass::layout::RowMajor, + OperatorPolicy, + 1, + 1>; + using accum_t = typename IteratorC::Element; + using WarpShape = WarpShape_; + using ThreadblockShape = ThreadblockShape_; + using FragmentC = typename IteratorC::Fragment; + using lse_scalar_t = float; + + // Storage in shared-memory for Q.Kt + using AccumulatorSharedStorage = + cutlass::gemm::threadblock::AccumulatorSharedStorage< + ThreadblockShape, + scalar_t, + cutlass::layout::ColumnMajor, + cutlass::MatrixShape<0, 0> // Padding + >; + + static void CUTLASS_DEVICE accumToSmem( + AccumulatorSharedStorage& shared_storage, + FragmentC const& accum, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + using Policy = typename IteratorC::Policy; + using Element = typename IteratorC::Element; + using Iterations = typename IteratorC::Iterations; + using Delta = typename IteratorC::Delta; + + auto ref_ = shared_storage.accum_ref(); + // ctor - MmaSimtTileIterator + // compute offset based on thread ID and lane layout + typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); + + MatrixCoord lane_offset = lane_layout.inverse(lane_id) * + MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN); + + ref_.add_coord_offset(lane_offset); + + // Tile offset + ref_.add_coord_offset( + tile_coords * + cutlass::MatrixCoord( + {IteratorC::Shape::kRow, IteratorC::Shape::kColumn})); + + // store - MmaSimtTileIterator + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) { + int r = + Policy::LaneMmaShape::kM * (mma_m * Policy::WarpShape::kRow) + + m; + int c = mma_n * Delta::kColumn + n; + int idx = n + + Policy::LaneMmaShape::kN * + (mma_n + + Iterations::kColumn * + (m + mma_m * Policy::LaneMmaShape::kM)); + ref_.at({r, c}) = scalar_t(accum[idx]); + } + } + } + } + } + + static void CUTLASS_DEVICE accumApplyLSEToSmem( + AccumulatorSharedStorage& shared_storage, + typename IteratorC::Fragment& accum, + lse_scalar_t const* lse, + int lse_extent, + int thread_id, + int warp_id, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + // Non-optimized way to apply LSE to registers + // NOTE: accum is attn.T + // TODO: Optimize for each architecture + static constexpr int WarpSize = 32; + using AccumLambdaIterator = + typename DefaultMmaAccumLambdaIterator:: + Iterator; + auto lane_offset = + AccumLambdaIterator::get_lane_offset(lane_id, warp_id, tile_coords); + + cutlass::Array lse_prefetched; + lse_prefetched.clear(); + int rowIdx = 0; + int colIdx = 0; + AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + ++rowIdx; + colIdx = 0; + }, + [&](int accum_m, int accum_n, int idx) { + if (rowIdx == 1) { + lse_prefetched[colIdx] = accum_n < lse_extent + ? lse[accum_n] + : cutlass::platform::numeric_limits::infinity(); + } + accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]); + ++colIdx; + }, + [&](int accum_m) {}); + accumToSmem(shared_storage, accum, lane_id, tile_coords); + } +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm_kernel_utils.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm_kernel_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..3703257a1730852809e3f5597d134f8146e1f9e2 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm_kernel_utils.h @@ -0,0 +1,258 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/arch/mma.h" + +//////////////////////////////////////////////////////////////////////////////// +// Some helper functions +//////////////////////////////////////////////////////////////////////////////// +#define DISPATCH_TYPES(tensor, func) \ + { \ + if (query.scalar_type() == at::ScalarType::Float) { \ + using scalar_t = float; \ + func(); \ + } else if (query.scalar_type() == at::ScalarType::Half) { \ + using scalar_t = cutlass::half_t; \ + func(); \ + } else if (query.scalar_type() == at::ScalarType::BFloat16) { \ + using scalar_t = cutlass::bfloat16_t; \ + func(); \ + } else { \ + XFORMERS_CHECK(false, "Only fp32, half & bf16 supported at the moment"); \ + } \ + } + +#define DISPATCH_BOOL(BOOL_V, BOOL_NAME, F) \ + { \ + if (BOOL_V) { \ + using BOOL_NAME = std::true_type; \ + F(); \ + } else { \ + using BOOL_NAME = std::false_type; \ + F(); \ + } \ + } + +#define DISPATCH_ARCHTAG(CC, func) \ + { \ + if (CC >= 80) { \ + using ArchTag = cutlass::arch::Sm80; \ + func(); \ + } else if (CC >= 75) { \ + using ArchTag = cutlass::arch::Sm75; \ + func(); \ + } else if (CC >= 70) { \ + using ArchTag = cutlass::arch::Sm70; \ + func(); \ + } else if (CC >= 50) { \ + using ArchTag = cutlass::arch::Sm50; \ + func(); \ + } else { \ + XFORMERS_CHECK( \ + false, \ + "Your device is too old. We require compute capability >= 50"); \ + } \ + } + +#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR) \ + XFORMERS_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ + XFORMERS_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ + XFORMERS_CHECK(TENSOR.is_contiguous()); + +#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR) \ + XFORMERS_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ + XFORMERS_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ + XFORMERS_CHECK( \ + TENSOR.stride(-1) == 1, #TENSOR ": last dimension must be contiguous"); + +#ifdef TORCH_CHECK +#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ + XFORMERS_CHECK( \ + uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned") +#define XFORMERS_CHECK TORCH_CHECK +#elif defined(__CUDACC_RTC__) +#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ + if (!(uint64_t(PTR) % ALIGNMENT == 0)) { \ + return false; \ + } +#define XFORMERS_CHECK(COND, ERR) \ + if (!(COND)) { \ + return false; \ + } +#else +#include +#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ + if (!(uint64_t(PTR) % ALIGNMENT == 0)) { \ + std::cerr << #PTR " is not correctly aligned\n"; \ + return false; \ + } +#define XFORMERS_CHECK(COND, ERR) \ + if (!(COND)) { \ + std::cerr << "'" #COND "' failed: " << ERR << "\n"; \ + return false; \ + } +#endif + +#define ASSIGN_CHECK_OVERFLOW(A, B) \ + { \ + A = B; \ + XFORMERS_CHECK( \ + B < std::numeric_limits::max(), #B " overflows"); \ + } + +namespace gemm_kernel_utils { + +template +constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m) { + return (n + m - 1) / m; +} + +template +constexpr CUTLASS_HOST_DEVICE integer align_up(integer n, integer m) { + return ((n + m - 1) / m) * m; +} + +//////////////////////////////////////////////////////////////////////////////// +// Determine the type of GEMM we do (TensorCores or not, Shapes ...) +// TODO: Maybe we could rely on Cutlass's DefaultGemm templates +//////////////////////////////////////////////////////////////////////////////// + +// Fallback to Simt (FMA on cuda cores) if not in a special case below +template +struct DefaultGemmType { + static constexpr int ThreadK = 8; + static constexpr int WarpK = 8; + static constexpr int kMinimumAlignment = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using OpClass = cutlass::arch::OpClassSimt; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Specialization for tensorcores with f32 +template +struct DefaultGemmType< + ArchTag, + float, + typename cutlass::platform::enable_if< + ArchTag::kMinComputeCapability >= 80>::type> { + static constexpr int ThreadK = 32; + static constexpr int WarpK = 32; + static constexpr int kMinimumAlignment = 4; + using OpClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Operator = cutlass::arch::OpMultiplyAddFastF32; +}; + +// Specialization for tensorcores with f16/bf16 - Sm75+ +template +struct DefaultGemmType< + ArchTag, + scalar_t, + typename cutlass::platform::enable_if< + ArchTag::kMinComputeCapability >= 75 && + cutlass::sizeof_bits::value == 16>::type> { + static constexpr int ThreadK = 32; + static constexpr int WarpK = 32; + static constexpr int kMinimumAlignment = 4; + using OpClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Specialization for tensorcores with f16 - Volta +template <> +struct DefaultGemmType { + static constexpr int ThreadK = 32; + static constexpr int WarpK = 32; + static constexpr int kMinimumAlignment = 2; + using OpClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Enables to do +// `auto x = kCondition ? fa(arg) : fb(arg)` +// when `fa` and `fb` have different types +template +struct call_conditional; + +template +struct call_conditional { + template + static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg) + -> decltype(ta(arg)) { + return ta(arg); + } +}; + +template +struct call_conditional { + template + static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg) + -> decltype(tb(arg)) { + return tb(arg); + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Mark a variable as warp-uniform - enables some compiler optimizations +// The cheapest way to do it is just to broadcast it from lane 0 +//////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE T warp_uniform(T value) { + struct { + union { + T value; + uint32_t asInt; + }; + } p; + p.value = value; + p.asInt = __shfl_sync(0xffffffff, (unsigned)p.asInt, 0); + return p.value; +} + +template +CUTLASS_DEVICE T* warp_uniform(T* ptr) { + struct { + union { + T* ptr; + uint32_t asInt[2]; + }; + } p; + p.ptr = ptr; + p.asInt[0] = warp_uniform(p.asInt[0]); + p.asInt[1] = warp_uniform(p.asInt[1]); + return p.ptr; +} +} // namespace gemm_kernel_utils diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/default_warp_iterator_from_smem.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/default_warp_iterator_from_smem.h new file mode 100644 index 0000000000000000000000000000000000000000..930ee46dfe6d1ab8e9482145b0379de8a7eddc6f --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/default_warp_iterator_from_smem.h @@ -0,0 +1,142 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Instantiates the right WarpIterator to read from shared memory + The class `DefaultWarpIteratorAFromSharedMemory` is useful when reading + data dumped with `B2bGemm::accumToSmem`. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h" +#include "cutlass/platform/platform.h" + +#include "warp_iterator_from_smem.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +template < + typename WarpShape, + typename InstructionShape, + typename RegularWarpIterator, + typename Policy, + typename Enable = void> +struct DefaultWarpIteratorAFromSharedMemory {}; + +// TensorOp - Ampere half +template +struct DefaultWarpIteratorAFromSharedMemory< + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, kInstrK>, + RegularWarpIterator, + Policy, + typename platform::enable_if<( + sizeof_bits::value == 16 && + Policy::Operator::Policy::OpDelta::kRow == 1)>::type> { + using OpDelta = typename Policy::Operator::Policy::OpDelta; + using WarpShape = cutlass::MatrixShape<32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, kInstrK>; + + using WarpIterator = cutlass::gemm::warp::WarpIteratorFromSmem< + cutlass::gemm::Operand::kA, + typename RegularWarpIterator::Element, + cutlass::MatrixShape>; +}; + +// TensorOp - Ampere f32 +template +struct DefaultWarpIteratorAFromSharedMemory< + WarpShape, + cutlass::gemm::GemmShape<16, 8, 8>, + RegularWarpIterator, + Policy, + typename platform::enable_if<( + sizeof_bits::value != 16 || + Policy::Operator::Policy::OpDelta::kRow != 1)>::type> { + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + static constexpr auto kWarpSize = 32; + using OpDelta = typename Policy::Operator::Policy::OpDelta; + + using WarpIterator = + cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator< + cutlass::MatrixShape, + cutlass::gemm::Operand::kA, + typename RegularWarpIterator::Element, + cutlass::layout::RowMajor, + cutlass::MatrixShape, + OpDelta::kRow, + kWarpSize>; +}; + +// TensorOp - Volta +template +struct DefaultWarpIteratorAFromSharedMemory< + WarpShape, + cutlass::gemm::GemmShape<16, 16, 4>, + RegularWarpIterator, + Policy> { + using InstructionShape = cutlass::gemm::GemmShape<16, 16, 4>; + static constexpr auto kWarpSize = 32; + using OpDelta = typename Policy::Operator::Policy::OpDelta; + + using WarpIterator = + cutlass::gemm::warp::MmaVoltaTensorOpMultiplicandTileIterator< + cutlass::MatrixShape<32, 32>, // MatrixShape, + cutlass::gemm::Operand::kA, + typename RegularWarpIterator::Element, + cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>, + cutlass::MatrixShape<16, 4>, + OpDelta::kRow, + kWarpSize>; +}; + +// Simt +template +struct DefaultWarpIteratorAFromSharedMemory< + WarpShape, + cutlass::gemm::GemmShape<1, 1, 1>, + RegularWarpIterator, + Policy> { + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + static constexpr auto kWarpSize = 32; + + // We just use the same iterator, as we reproduced the same shared-memory + // schema. Just modify it to handle non-complete tiles. + using WarpIterator = RegularWarpIterator; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/epilogue_predicated_tile_iterator.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/epilogue_predicated_tile_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..7a52e96a36bd5cddf5db65fba130b4b45ff2ff66 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/epilogue_predicated_tile_iterator.h @@ -0,0 +1,751 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue iterator that supports prefetching + + Mostly copied from "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +*/ + +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/transform/pitch_linear_thread_map.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////// + +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load and store output tile from global memory in +/// epilogue. +/// +/// Satisfies: ReadableTileIterator | PredicatedTileIterator | +/// ForwardTileIterator +/// +template < + typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) + typename Element_, ///< Element data type + bool ScatterD = false, ///< Scatter D operand or not + bool UseCUDAStore = false> +class PredicatedTileIteratorPrefetch { + public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Count::kTile; + + static_assert( + ThreadMap::Iterations::kRow > 0, + "ThreadMap::Iterations::kRow must be > 0"); + static_assert( + ThreadMap::Iterations::kGroup > 0, + "ThreadMap::Iterations::kGroup must be > 0"); + static_assert( + ThreadMap::Iterations::kCluster > 0, + "ThreadMap::Iterations::kCluster must be > 0"); + static_assert( + ThreadMap::Iterations::kColumn > 0, + "ThreadMap::Iterations::kColumn must be > 0"); + + /// Fragment object + using Fragment = Array< + Element, + ThreadMap::Iterations::kColumn * ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster * + ThreadMap::kElementsPerAccess>; + + /// Memory access size + using AccessType = AlignedArray; + + // + // Parameters struct + // + + /// Uses a non-template class + struct Params : PredicatedTileIteratorParams { + using Base = PredicatedTileIteratorParams; + + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : PredicatedTileIteratorParams( + layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, + make_OutputTileThreadMapDesc()) {} + + CUTLASS_HOST_DEVICE + Params(Base const& base) : Base(base) {} + }; + + /// Mask object + struct Mask { + static int const kCount = ThreadMap::Iterations::kColumn; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { + enable(); + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = false; + } + } + + ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask + CUTLASS_DEVICE void enable() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = true; + } + } + }; + + private: + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. + PredicatedTileIteratorParams params_; + + /// Byte-level pointer + uint8_t* byte_pointer_; + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in rows + Index extent_row_; + + /// Extent of the matrix tile in rows + Index extent_column_; + + /// A thread's starting row position (assuming steady-state predicates have + /// been computed) + Index thread_start_row_; + + /// A thread's starting column + Index thread_start_column_; + + /// Internal state counter + int state_[3]; + + /// Scatter indices + int const* indices_; + + // + // Static asserts about internal strides + // + + static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); + static_assert( + sizeof(PredicatedTileIteratorParams::stride) == 8, + "Expected 64b strides"); + + private: + // + // Methods + // + + public: + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + PredicatedTileIteratorPrefetch( + PredicatedTileIteratorParams const& params, + Element* pointer, + TensorCoord extent, + int thread_idx, + TensorCoord threadblock_offset = TensorCoord(), + int const* indices = nullptr) + : params_(params), indices_(indices) { + TensorCoord thread_offset = + ThreadMap::initial_offset(thread_idx) + threadblock_offset; + + extent_row_ = extent.row(); + extent_column_ = extent.column(); + + thread_start_row_ = thread_offset.row(); + thread_start_column_ = thread_offset.column(); + + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { + mask_.predicates[c] = + ((thread_offset.column() + ThreadMap::Delta::kColumn * c) < + extent.column()); + } + + // Null pointer performs no accesses + if (!pointer) { + mask_.clear(); + } + + if (ScatterD && !indices) { + mask_.clear(); + } + + // Initialize pointer + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.row()) * LongIndex(params_.stride) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / + kElementsPerAccess; + + if (ScatterD) { + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / + kElementsPerAccess; + } + + // Initialize internal state counter + state_[0] = state_[1] = state_[2] = 0; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_DEVICE + void prefetch_all() { + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < kIterations; ++iter) { + prefetch(); + ++(*this); + } + } + + CUTLASS_DEVICE + void prefetch() { + uint8_t* byte_pointer = byte_pointer_; + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + // on windows using unsigned long here gives the error + // error: asm operand type size(4) does not match + // type/size implied by constraint 'l' + uint64_t addr = (uint64_t)((void*)&memory_pointer + [column * ThreadMap::Delta::kColumn / + kElementsPerAccess]); + asm volatile("prefetch.global.L1 [ %1 ];" : "=l"(addr) : "l"(addr)); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { + byte_pointer += params_.increment_row; + } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, int64_t byte_offset) const { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + if (ScatterD && row_guard) { + assert(indices_); + + memory_pointer = reinterpret_cast( + byte_pointer + byte_offset + + LongIndex(indices_[row_offset + thread_start_row_]) * + LongIndex(params_.stride)); + } + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr + [frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer + [column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { + byte_pointer += params_.increment_row; + } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) const { + load_with_byte_offset(frag, 0); + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, int64_t byte_offset) const { + uint8_t* byte_pointer = byte_pointer_; + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + if (ScatterD && row_guard) { + assert(indices_); + + memory_pointer = reinterpret_cast( + byte_pointer + byte_offset + + LongIndex(indices_[row_offset + thread_start_row_]) * + LongIndex(params_.stride)); + } + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + bool guard = row_guard && mask_.predicates[column]; + + if (UseCUDAStore) { + if (guard) { + memory_pointer + [column * ThreadMap::Delta::kColumn / kElementsPerAccess] = + frag_ptr + [frag_row_idx * ThreadMap::Iterations::kColumn + + column]; + } + } else { + cutlass::arch::global_store( + frag_ptr + [frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer + [column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { + byte_pointer += params_.increment_row; + } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) const { + store_with_byte_offset(frag, 0); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void downsample_load_with_byte_offset( + Fragment& frag, + int64_t byte_offset, + int convolution_P, + int convolution_Q, + int add_P, + int add_Q, + int problem_N) const { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + int output_row = row_offset + thread_start_row_; + int output_N = output_row / (convolution_P * convolution_Q); + int output_PQ = output_row % (convolution_P * convolution_Q); + int output_P = output_PQ / convolution_Q; + int output_Q = output_PQ % convolution_Q; + + int input_row = output_N * 2 * convolution_P * 2 * convolution_Q + + (2 * output_P + add_P) * 2 * convolution_Q + 2 * output_Q + add_Q; + + int64_t byte_offset = + (input_row - output_row) * problem_N * sizeof(float); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr + [frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer + [column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + byte_pointer += params_.increment_row; + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void upsample_load_with_byte_offset( + Fragment& frag, + int64_t byte_offset, + int convolution_P, + int convolution_Q, + int add_P, + int add_Q, + int problem_N) const { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + int output_row = row_offset + thread_start_row_; + int output_N = output_row / (convolution_P * convolution_Q); + int output_PQ = output_row % (convolution_P * convolution_Q); + int output_P = output_PQ / convolution_Q; + int output_Q = output_PQ % convolution_Q; + int row_add_P = add_P; + int row_add_Q = add_Q; + if (output_P > convolution_P - 2) + row_add_P = 0; + if (output_Q > convolution_Q - 2) + row_add_Q = 0; + + int input_row = output_N * (convolution_P / 2) * (convolution_Q / 2) + + ((output_P + row_add_P) / 2) * (convolution_Q / 2) + + (output_Q + row_add_Q) / 2; + + int64_t byte_offset = + (input_row - output_row) * problem_N * sizeof(float); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr + [frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer + [column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + byte_pointer += params_.increment_row; + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + CUTLASS_DEVICE + MatrixCoord thread_start() const { + return MatrixCoord(thread_start_row_, thread_start_column_); + } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_row() const { + return thread_start_row_; + } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_column() const { + return thread_start_column_; + } + + /// Extent of the matrix in rows + CUTLASS_DEVICE + Index extent_row() const { + return extent_row_; + } + + /// Extent of the matrix in columns + CUTLASS_DEVICE + Index extent_column() const { + return extent_column_; + } + + /// Advances to the next position to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorPrefetch& operator++() { + ++state_[0]; + + if (!ScatterD) { + byte_pointer_ += params_.advance_row; + } + + thread_start_row_ += ThreadMap::Shape::kRow; + + if (state_[0] == ThreadMap::Count::kRow) { + state_[0] = 0; + ++state_[1]; + byte_pointer_ += params_.advance_group; + + thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * + ThreadMap::Shape::kRow * ThreadMap::Count::kRow; + + if (state_[1] == ThreadMap::Count::kGroup) { + state_[1] = 0; + ++state_[2]; + byte_pointer_ += params_.advance_cluster; + + thread_start_row_ += ThreadMap::Count::kGroup * + ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * + ThreadMap::Shape::kRow; + + if (state_[2] == ThreadMap::Count::kCluster) { + state_[2] = 0; + byte_pointer_ += params_.advance_tile; + } + } + } + + return *this; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { + mask_.clear(); + } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { + mask_.enable(); + } + + ///< Sets the mask + CUTLASS_DEVICE void get_mask(Mask& mask) const { + mask = mask_; + } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const& mask) { + mask_ = mask; + } +}; + +template +struct MakePrefetchableIterator { + using Iterator = PredicatedTileIteratorPrefetch< + typename IT::ThreadMap, + typename IT::Element>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/make_residual_last.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/make_residual_last.h new file mode 100644 index 0000000000000000000000000000000000000000..a667d675276111d074abd868e96c141d2f2f3829 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/make_residual_last.h @@ -0,0 +1,97 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "predicated_tile_access_iterator_residual_last.h" +#include "predicated_tile_iterator_residual_last.h" + +namespace cutlass { +namespace transform { +namespace threadblock { + +template +struct MakeIteratorResidualLast; + +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + int AccessSize, + bool Gather> +struct MakeIteratorResidualLast> { + using Iterator = PredicatedTileIteratorResidualLast< + Shape, + Element, + Layout, + AdvanceRank, + ThreadMap, + AccessSize, + Gather>; +}; + +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + typename AccessType, + bool Gather> +struct MakeIteratorResidualLast> { + using Iterator = PredicatedTileAccessIteratorResidualLast< + Shape, + Element, + Layout, + AdvanceRank, + ThreadMap, + AccessType, + Gather>; +}; +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/predicated_tile_access_iterator_residual_last.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/predicated_tile_access_iterator_residual_last.h new file mode 100644 index 0000000000000000000000000000000000000000..d007f0445b8301a16e371c6f1d9911dfda5ad915 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/predicated_tile_access_iterator_residual_last.h @@ -0,0 +1,2114 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates calculating the address and predicates to the load of tiles + from pitch-linear rank=2 tensors. + + This iterator uses masks to guard out-of-bounds accesses. The first tile + this iterator visits maybe partial, then the remaining tiles are complete. + So, we only need to compute the predicates twice, once before the first tile + and once for the remaining full tiles which can share the same predicates. + + A precomputed "Params" object minimizes the amount of state that must be + stored in registers, and integer addition is used to advance the pointer + through memory. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h" + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedTileAccessIteratorResidualLast +/// +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + typename AccessType, + bool Gather = false> +class PredicatedTileAccessIteratorResidualLast; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for pitch-linear +/// data. +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_, + bool Gather> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::PitchLinear, + AdvanceRank, + ThreadMap_, + AccessType_, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates< + Shape, + Element, + Layout, + AdvanceRank, + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = + ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert( + !(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + using Mask = typename UnderlyingPredicates::Mask; + + /// Uses a non-template class + struct Params : PredicatedTileAccessIteratorParams { + using Base = PredicatedTileAccessIteratorParams; + + // Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : Base( + layout.stride(0), + MakePredicatedTileAccessIteratorDesc< + Shape, + Element, + Layout, + kAdvanceRank, + ThreadMap>()()) {} + + CUTLASS_HOST_DEVICE + Params(Base const& base) : Base(base) {} + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + private: + // + // Data members + // + + UnderlyingPredicates the_predicates; + Mask residual_tile_mask; + + /// Parameters object with precomputed internal state + Params params_; + + /// Internal pointer to first access of tile + BytePointer pointer_; + + /// Below is used when Gather is turned on. We need to record strided_offset + /// and contiguous_offset separated to compute the offset by using + /// + /// offset = contiguous_offset + indices[strided_offset] + /// + + /// Gather indices + int const* indices_; + + Index gather_offset_strided; + + private: + /// Computes predicates based on internally tracked per-thread offset. + CUTLASS_DEVICE + void compute_predicates_( + /// Extent of the matrix window + TensorCoord extent, + /// optionally, simplify predicate calculation during 'steady state' phase + bool is_steady_state = false) { + the_predicates.compute_predicates_(extent, is_steady_state); + } + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + /// Gather indices + int const* indices = nullptr) + : params_(params), + pointer_(reinterpret_cast( + const_cast(pointer))), + the_predicates(extent), + indices_(indices) { + the_predicates.set_predicates(thread_id, threadblock_offset); + the_predicates.get_mask(residual_tile_mask); + + // Working around a weird compiler bug happening on P100 for the backward. + // I've seen together: the_predicates.predicates_[0] = 14 (instead of 15) + // residual_tile_mask[0] = 15 (correct) + // + // Adding prints when the value is calculated (in `compute_predicates_`) + // sometimes removes the bug. The consequence is that we skip some + // element of a tensor, leading to wrong results + // Setting `compute_predicates_`'s second argument (`is_steady_state`) to + // true also seems to get rid of the bug - at the cost of twice as many + // comparisons. +#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 700) + constexpr bool kWorkAroundCompilerBug = false; +#else + constexpr bool kWorkAroundCompilerBug = true; +#endif + the_predicates.compute_predicates_(extent, true && !kWorkAroundCompilerBug); + + // update internal pointers + Layout layout(params_.stride_); + + if (!Gather) { + add_pointer_offset(layout(the_predicates.thread_offset_)); + } else { + gather_offset_strided = the_predicates.thread_offset_.strided(); + add_pointer_offset( + layout(make_Coord(the_predicates.thread_offset_.contiguous(), 0))); + } + } + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + the_predicates.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool is_residual_tile) { + if (is_residual_tile) { + the_predicates.set_mask(residual_tile_mask); + } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += sizeof_bits::value * pointer_offset / 8; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + if (!Gather) { + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided()); + pointer_ += Shape::kContiguous * tile_offset.contiguous(); + } else { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous()); + pointer_ += Shape::kStrided * tile_offset.strided(); + } + } else { + add_pointer_offset(Shape::kContiguous * tile_offset.contiguous()); + gather_offset_strided += Shape::kStrided * tile_offset.strided(); + } + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + if (Gather) { + assert(indices_); + + if (!valid()) { + return nullptr; + } + + LongIndex contiguous_offset = the_predicates.iteration_contiguous_ * + (ThreadMap::Delta::kContiguous * sizeof_bits::value / + 8) + + the_predicates.iteration_vector_; + int strided_index = gather_offset_strided + + the_predicates.iteration_strided_ * ThreadMap::Delta::kStrided; + + LongIndex strided_offset = indices_[strided_index] * + LongIndex(params_.stride_) * sizeof_bits::value / 8; + + return reinterpret_cast( + pointer_ + contiguous_offset + strided_offset); + } + + return reinterpret_cast( + pointer_ + + the_predicates.iteration_contiguous_ * + (ThreadMap::Delta::kContiguous * + sizeof_bits::value) / + 8) + + the_predicates.iteration_vector_; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + the_predicates.operator++(); + + ++the_predicates.iteration_vector_; + if (the_predicates.iteration_vector_ < kAccessesPerVector) { + return *this; + } + + the_predicates.iteration_vector_ = 0; + ++the_predicates.iteration_contiguous_; + + if (the_predicates.iteration_contiguous_ < + ThreadMap::Iterations::kContiguous) { + return *this; + } + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + the_predicates.iteration_contiguous_ = 0; + ++the_predicates.iteration_strided_; + + if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) { + if (!Gather) { + pointer_ += params_.inc_strided_; + } + + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + the_predicates.iteration_strided_ = 0; + + if (!Gather) { + // advance to next tile + pointer_ += params_.inc_next_; + + // now return to start tile - if the iterator is subsequently advanced, + // this subtraction as well as the subsequent integer addition are both + // elided by the compiler. + pointer_ -= params_.inc_advance_; + } + + return *this; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + the_predicates.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + the_predicates.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + the_predicates.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + the_predicates.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() const { + return the_predicates.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for column-major +/// data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_, + bool Gather> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::ColumnMajor, + AdvanceRank, + ThreadMap_, + AccessType_, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessType, + Gather>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))){}; + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row(), + threadblock_offset.column()), + indices) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for row-major +/// data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_, + bool Gather> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::RowMajor, + AdvanceRank, + ThreadMap_, + AccessType_, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessType, + Gather>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))){}; + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + /// Gather indices + int const* indices = nullptr) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column(), + threadblock_offset.row()), + indices) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank 2 +/// data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::AffineRankN<2>, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRankN<2>; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates< + Shape, + Element, + layout::PitchLinear, + AdvanceRank, + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = + ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert( + !(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingPredicates::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + friend PredicatedTileAccessIteratorResidualLast; + + private: + /// stride of pitch-linear layout (units of Element) + Coord stride_; + /// amount (in byte) to increment pointer to move to next access along + /// contiguous dimension + LongIndex inc_contiguous_; + /// amount (in byte) to increment pointer from first access of current + /// contiguous dimension to first access of next one. + LongIndex inc_strided_; + /// amount (in byte) to increment pointer from last access of current + /// contiguous dimension to first access of next one. + LongIndex inc_next_strided_; + /// amount (in byte) to increment pointer from last access to first access + /// of next tile + LongIndex inc_next_; + /// amount (in byte) to increment pointer from first access of current tile + /// to first access of next tile + LongIndex inc_advance_; + + public: + // Default ctor + CUTLASS_HOST_DEVICE + Params() + : stride_(0), + inc_contiguous_(0), + inc_strided_(0), + inc_next_(0), + inc_advance_(0) {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : stride_({layout.stride(0), layout.stride(1)}) { + inc_contiguous_ = + (LongIndex(stride_[0]) * ThreadMap::Delta::kContiguous) * + sizeof_bits::value / 8; + + inc_strided_ = (LongIndex(stride_[1]) * ThreadMap::Delta::kStrided) * + sizeof_bits::value / 8; + + inc_next_strided_ = inc_strided_ - + LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_; + + if (kAdvanceRank) { + // advance along strided dimension + inc_advance_ = Shape::kStrided * LongIndex(stride_[1]) * + sizeof_bits::value / 8; + } else { + // advance along contiguous dimension + inc_advance_ = + Shape::kContiguous * stride_[0] * sizeof_bits::value / 8; + } + + inc_next_ = inc_advance_ - + LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_ - + LongIndex(ThreadMap::Iterations::kStrided - 1) * inc_strided_; + }; + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + // + // Data members + // + + /// Parameters object with precomputed internal state + Params params_; + + /// Internal pointer to first access of tile + BytePointer pointer_; + + UnderlyingPredicates the_predicates; + Mask residual_tile_mask; + + private: + /// Computes predicates based on internally tracked per-thread offset. + CUTLASS_DEVICE + void compute_predicates_( + /// Extent of the matrix window + TensorCoord extent, + /// optionally, simplify predicate calculation during 'steady state' phase + bool is_steady_state = false) { + the_predicates.compute_predicates_(extent, is_steady_state); + } + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : params_(params), + pointer_(reinterpret_cast( + const_cast(pointer))), + the_predicates(extent) { + the_predicates.set_predicates(thread_id, threadblock_offset); + + // update internal pointers + Layout layout(params_.stride_); + add_pointer_offset(layout(the_predicates.thread_offset_)); + } + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + the_predicates.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool is_residual_tile) { + if (is_residual_tile) { + the_predicates.set_mask(residual_tile_mask); + } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += sizeof_bits::value * pointer_offset / 8; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset[1]); + pointer_ += Shape::kContiguous * tile_offset[0]; + } else { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset[0]); + pointer_ += Shape::kStrided * tile_offset[1]; + } + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(pointer_) + + the_predicates.iteration_vector_; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + the_predicates.operator++(); + ++the_predicates.iteration_vector_; + if (the_predicates.iteration_vector_ < kAccessesPerVector) { + return *this; + } + + the_predicates.iteration_vector_ = 0; + ++the_predicates.iteration_contiguous_; + + if (the_predicates.iteration_contiguous_ < + ThreadMap::Iterations::kContiguous) { + pointer_ += params_.inc_contiguous_; + return *this; + } + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + the_predicates.iteration_contiguous_ = 0; + ++the_predicates.iteration_strided_; + + if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) { + pointer_ += params_.inc_next_strided_; + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + the_predicates.iteration_strided_ = 0; + + // advance to next tile + pointer_ += params_.inc_next_; + + // now return to start tile - if the iterator is subsequently advanced, this + // subtraction as well as the subsequent integer addition are both elided by + // the compiler. + pointer_ -= params_.inc_advance_; + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + the_predicates.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + the_predicates.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + the_predicates.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + the_predicates.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return the_predicates.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank 2 +/// column-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::AffineRank2ColumnMajor, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))){}; + }; + + private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row(), + threadblock_offset.column())) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset( + make_Coord(tile_offset.row(), tile_offset.column())); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank-2 +/// row-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::AffineRank2RowMajor, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))){}; + }; + + private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column(), + threadblock_offset.row())) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset( + make_Coord(tile_offset.column(), tile_offset.row())); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for column-major +/// interleaved data. It is mapped to the congruous layout. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// + +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_, + int InterleavedK> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::ColumnMajorInterleaved, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::ColumnMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape< + Shape::kRow * kInterleavedK, + Shape::kColumn / kInterleavedK>, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord( + extent.row() * kInterleavedK, + extent.column() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row() * kInterleavedK, + threadblock_offset.column() / kInterleavedK)) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for row-major +/// interleaved data. +// It is mapped to the congruous layout. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_, + int InterleavedK> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::RowMajorInterleaved, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::RowMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape< + Shape::kColumn * kInterleavedK, + Shape::kRow / kInterleavedK>, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord( + extent.column() * kInterleavedK, + extent.row() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column() * kInterleavedK, + threadblock_offset.row() / kInterleavedK)) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/predicated_tile_iterator_residual_last.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/predicated_tile_iterator_residual_last.h new file mode 100644 index 0000000000000000000000000000000000000000..fa40d850c84528c7cf2e5d9eaa5961774843405d --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/predicated_tile_iterator_residual_last.h @@ -0,0 +1,2119 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of tiles from pitch-linear rank=2 + tensors. + + This iterator uses masks to guard out-of-bounds accesses. The first tile + this iterator visits maybe partial, then the remaining tiles are complete. + So, we only need to compute the predicates twice, once before the first tile + and once for the remaining full tiles which can share the same predicates. + + A precomputed "Params" object minimizes the amount of state that must be + stored in registers, and integer addition is used to advance the pointer + through memory. +*/ + +#pragma once + +#include "cutlass/arch/memory.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedTileIteratorResidualLast +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +/// Regular tile iterator using a precomputed control structure to minimize +/// register liveness and integer arithmetic. +/// +/// Layout is assumed to be invariant at the time the precomputed "Params" +/// object is constructed. +/// +/// Base pointer and tensor extents may be specified at the time the iterator is +/// constructed. Subsequently, they are assumed to be immutable. +/// +/// Adding a logical coordinate offset may be performed at the time the iterator +/// is constructed. Subsequent additions to logical coordinate offset may be +/// performed but are relatively expensive. +/// +/// Visitation order is intended to first visit a "residual" tile that may be +/// partially full in both the advance dimension and the steady-state dimension. +/// This is assumed to be the last tile in the iteration sequence. Advancing an +/// iterator that has just been constructed moves to the first tile that is full +/// in the advance dimension and recomputes predicates. Subsequent accesses may +/// be performed without updating internal predicates and are efficient in terms +/// of live register state and pointer arithmetic instructions. +/// +/// To be efficient, this assumes the iterator will be dereferenced and advanced +/// at least once outside any looping structure to minimize integer arithmetic. +/// +/// Accesses out of bounds are safe so long as `clear_mask()` is called prior to +/// dereferencing the iterator. +/// +/// +/// Example: +/// +/// An efficient pipeline structure may be constructed as follows: +/// +// template +// __global__ void kernel( +// typename Iterator::Params params, +// typename Iterator::Element *ptr, +// TensorCoord extent) { +// +// typename Iterator::Fragment fragment; +// +// TensorCoord threadblock_offset(0, 0); +// +// Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offsets); +// +// +// fragment = *iter; // load "residue" tile first +// ++iter; // advance to first "steady state" tile and update +// internal masks +// +// +// #pragma unroll +// for (int i = Remaining - 1; i >= 0; --i) { +// +// f(fragment); +// +// if (!i) { +// iter.clear_mask(); // light-weight operation to clear masks - +// subsequent loads become NO-OPs. +// } +// +// fragment = *iter; // load tile during "steady state" phase +// ++iter; // advance to next tile - lightweight due to +// steady-state masks +// } +// } +// +// void host(TensorView view) { +// +// using Iterator = +// transform::threadblock::PredicatedTileIteratorResidualLast; +// +// typename Iterator::Params params(view.layout()); +// +// kernel(params, view.data()); +// } +/// +/// +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + int AccessSize = ThreadMap::kElementsPerAccess, + bool Gather = false> +class PredicatedTileIteratorResidualLast; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize, + bool Gather> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::PitchLinear, + AdvanceRank, + ThreadMap_, + AccessSize, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + /// Type used for internal memory accesses + using AccessType = AlignedArray< + Element, + AccessSize, + (AccessSize * sizeof_bits::value / 8)>; + + /// Underlying iterator to compute the addresses + using TileAccessIterator = PredicatedTileAccessIteratorResidualLast< + Shape, + Element, + Layout, + kAdvanceRank, + ThreadMap, + AccessType, + Gather>; + + static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename TileAccessIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + using Base = typename TileAccessIterator::Params::Base; + + friend PredicatedTileIteratorResidualLast; + + private: + /// Parameters object + typename TileAccessIterator::Params params_; + + public: + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : params_(layout) {} + + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Base const& base) : params_(base) {} + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + private: + // + // Data members + // + + /// Data member to the tile access iterator + TileAccessIterator address_iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + /// Gather indices + int const* indices = nullptr) + : address_iterator_( + params.params_, + pointer, + extent, + thread_id, + threadblock_offset, + indices) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + address_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + if (kAdvanceRank) + address_iterator_.add_tile_offset({0, 1}); + else + address_iterator_.add_tile_offset({1, 0}); + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + address_iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + address_iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + address_iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + address_iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + address_iterator_.get_mask(mask); + } + + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + load_with_byte_offset( + frag, pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + address_iterator_.set_iteration_index(idx); + char const* byte_ptr = + reinterpret_cast(address_iterator_.get()) + + byte_offset; + + AccessType const* access_ptr = + reinterpret_cast(byte_ptr); + + cutlass::arch::global_load( + frag_ptr[idx], access_ptr, address_iterator_.valid()); + + ++address_iterator_; + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_byte_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + store_with_byte_offset( + frag, pointer_offset * sizeof_bits::value / 8); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + address_iterator_.set_iteration_index(0); + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + char* byte_ptr = + reinterpret_cast(address_iterator_.get()) + byte_offset; + AccessType* access_ptr = reinterpret_cast(byte_ptr); + + if (address_iterator_.valid()) { + *access_ptr = frag_ptr[idx]; + } + ++address_iterator_; + } + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_byte_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize, + bool Gather> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::ColumnMajor, + AdvanceRank, + ThreadMap_, + AccessSize, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessSize, + Gather>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const& threadblock_offset, ///< Initial offset of threadblock + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row(), + threadblock_offset.column()), + indices) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize, + bool Gather> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::RowMajor, + AdvanceRank, + ThreadMap_, + AccessSize, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessSize, + Gather>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const& threadblock_offset, ///< Initial offset of threadblock + int const* indices = nullptr ///< Gather indices + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column(), + threadblock_offset.row()), + indices) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for affine rank-2 data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::AffineRankN<2>, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRankN<2>; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + /// Type used for internal memory accesses + using AccessType = AlignedArray< + Element, + AccessSize, + (AccessSize * sizeof_bits::value / 8)>; + + /// Underlying iterator to compute the addresses + using TileAccessIterator = PredicatedTileAccessIteratorResidualLast< + Shape, + Element, + Layout, + kAdvanceRank, + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename TileAccessIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + friend PredicatedTileIteratorResidualLast; + + private: + /// Parameters object + typename TileAccessIterator::Params params_; + + public: + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : params_(layout) {} + + CUTLASS_HOST_DEVICE + Params() {} + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + private: + // + // Data members + // + + /// Data member to the tile access iterator + TileAccessIterator address_iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : address_iterator_( + params.params_, + pointer, + extent, + thread_id, + threadblock_offset) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + address_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + if (kAdvanceRank) + address_iterator_.add_tile_offset(make_Coord(0, 1)); + else + address_iterator_.add_tile_offset(make_Coord(1, 0)); + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + address_iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + address_iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + address_iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + address_iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + address_iterator_.get_mask(mask); + } + + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + load_with_byte_offset( + frag, pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + address_iterator_.set_iteration_index(idx); + char const* byte_ptr = + reinterpret_cast(address_iterator_.get()) + + byte_offset; + + AccessType const* access_ptr = + reinterpret_cast(byte_ptr); + + cutlass::arch::global_load( + frag_ptr[idx], access_ptr, address_iterator_.valid()); + + ++address_iterator_; + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_byte_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + store_with_byte_offset( + frag, pointer_offset * sizeof_bits::value / 8); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + address_iterator_.set_iteration_index(0); + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + char* byte_ptr = + reinterpret_cast(address_iterator_.get()) + byte_offset; + AccessType* access_ptr = reinterpret_cast(byte_ptr); + + if (address_iterator_.valid()) { + *access_ptr = frag_ptr[idx]; + } + ++address_iterator_; + } + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_byte_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for affine rank 2 +/// column-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::AffineRank2ColumnMajor, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessSize>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))) {} + }; + + private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const& threadblock_offset, ///< Initial offset of threadblock + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row(), + threadblock_offset.column())) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for affine rank 2 +/// row-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::AffineRank2RowMajor, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessSize>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))) {} + }; + + private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const& threadblock_offset, ///< Initial offset of threadblock + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column(), + threadblock_offset.row())) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for interleaved data. +/// It is mapped to the congruous layout. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// + +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize, + int InterleavedK> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::ColumnMajorInterleaved, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::ColumnMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape< + Shape::kRow * kInterleavedK, + Shape::kColumn / kInterleavedK>, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessSize>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord( + extent.row() * kInterleavedK, + extent.column() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row() * kInterleavedK, + threadblock_offset.column() / kInterleavedK)) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for interleaved-32 +/// data. It is mapped to the congruous layout. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize, + int InterleavedK> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::RowMajorInterleaved, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::RowMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape< + Shape::kColumn * kInterleavedK, + Shape::kRow / kInterleavedK>, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessSize>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord( + extent.column() * kInterleavedK, + extent.row() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column() * kInterleavedK, + threadblock_offset.row() / kInterleavedK)) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/transpose_warp_iterator.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/transpose_warp_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..18858ab7327a351d5bb742acb09da72bfccedf30 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/transpose_warp_iterator.h @@ -0,0 +1,55 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "warp_iterator_from_smem.h" + +template +struct TransposeWarpIterator { + using Iterator = char; + static bool constexpr kSupportsTranspose = false; +}; + +template < + /// Operand identity + cutlass::gemm::Operand Operand, + /// Data type of A elements + typename Element, + typename InstructionShape, + bool kTranspose> +struct TransposeWarpIterator< + cutlass::gemm::warp:: + WarpIteratorFromSmem> { + using Iterator = cutlass::gemm::warp:: + WarpIteratorFromSmem; + static bool constexpr kSupportsTranspose = true; +}; diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/warp_iterator_from_smem.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/warp_iterator_from_smem.h new file mode 100644 index 0000000000000000000000000000000000000000..3f4ebec698b23f85f7d0a0a5450aa21f3a9fc17e --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/warp_iterator_from_smem.h @@ -0,0 +1,283 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Inspired from + "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h" Loads tiles of GEMM + operands from a RowMajor shared-memory layout into registers to use by A100 + TensorCores. + + The difference with "mma_tensor_op_tile_access_iterator.h" is that: + (1) We use "ldmatrix" to load tiles, rather than manual loads (slightly + faster) (2) We support to transpose the operand (eg read `A.transpose()` when + the shared memory holds `A`) + + This is only implemented for the specific shapes. +*/ +#pragma once + +#include + +//////////////////////////////////////////////////////////////////////////////// +namespace cutlass { +namespace gemm { +namespace warp { + +template < + /// Operand identity + Operand Operand_, + /// Data type of A elements + typename Element_, + typename InstructionShape_, + bool kTranspose = false> +class WarpIteratorFromSmem { + public: + /// Shape of tile to load (concept: MatrixShape) + using Shape = cutlass::MatrixShape<32, 32>; + + /// Operand tag + static Operand const kOperand = Operand_; + static_assert( + kOperand == Operand::kA, + "No support for OperandB at the moment"); + + /// Basic check + static_assert( + kOperand == Operand::kA || kOperand == Operand::kB, + "WarpIteratorFromSmem may only be instantiated for A or B operands to warp-level Mma."); + + /// Element type + using Element = Element_; + static_assert(sizeof_bits::value == 16, "Only supported for half"); + + /// Layout of source tile + using Layout = cutlass::layout::RowMajor; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = InstructionShape_; + static_assert(InstructionShape::kRow == 16, "Only supports 16x8x8 / 16x8x16"); + static_assert( + InstructionShape::kColumn == 8 || InstructionShape::kColumn == 16, + "Only supports 16x8x8 / 16x8x16"); + + /// Delta between *MMA operations (in units of *MMA operations, concept: + /// MatrixShape) + static int const kOpDelta = 1; + + /// Number of participating threads + static int const kThreads = 32; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Number of elements accessed per Shared Memory load + static int const kElementsPerAccess = + (sizeof_bits::value >= 32 ? 1 + : 32 / sizeof_bits::value); + + using InstructionCount = MatrixShape< + Shape::kRow / InstructionShape::kRow, + Shape::kColumn / InstructionShape::kColumn>; + + static int const kIterations = (kOperand == Operand::kA) + ? InstructionCount::kColumn + : InstructionCount::kRow; + + public: + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + using Fragment = Array< + Element, + (kOperand == Operand::kA) + ? (Shape::kRow* InstructionShape::kColumn / kThreads) + : (Shape::kColumn* InstructionShape::kRow / kThreads)>; + + /// Memory access type + // using AccessType = AlignedArray; + using AccessType = Array; + + static int constexpr kWarpShapeDivisibleInner = + (kOperand == Operand::kA ? InstructionShape::kColumn + : InstructionShape::kRow); + static int constexpr kAccessesInner = + (kWarpShapeDivisibleInner / kElementsPerAccess) / 4; + // Number of 32bits tiles to load per `ldmatrix` + static int const kTilesPerInstruction = InstructionShape::kRow / 8; + static_assert(kTilesPerInstruction == 2, "Only supports 16x8x16 and 16x8x8"); + + private: + /// Underlying tensor reference + TensorRef ref_; + + /// Origin + MatrixCoord origin_; + + /// Iterations in a tile + int iterations_; + + public: + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem(TensorRef const& ref, int lane_id) + : WarpIteratorFromSmem(ref, {Shape::kRow, Shape::kColumn}, lane_id) {} + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem(TensorRef const& ref, TensorCoord extent, int lane_id) + : ref_(ref), iterations_(0) { + // See also: + // https://docs.nvidia.com/cuda/archive/11.7.1/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-1688 + // 16x8x8: kAccessesInner = 1 (1 ldmatrix.x4) + // 16x8x16: kAccessesInner = 2 (2 ldmatrix.x4) + int ldsm_vec_num = (lane_id >> 3); + if (kOperand == Operand::kA) { + origin_ = MatrixCoord(lane_id % 8, 0); + static_assert( + InstructionCount::kRow * kTilesPerInstruction == 4, + "can't use ldmatrix.x4"); + int access_m_idx = ldsm_vec_num % kTilesPerInstruction; + int inner_idx = (ldsm_vec_num / kTilesPerInstruction) % kAccessesInner; + int inst_m_idx = ldsm_vec_num / (kTilesPerInstruction * kAccessesInner); + MatrixCoord offset( + access_m_idx * 8 + inst_m_idx * InstructionShape::kRow, + inner_idx * 4 * kElementsPerAccess); + if (kTranspose) { + offset = MatrixCoord(offset.column(), offset.row()); + } + origin_ += offset; + } else { + // Note: This is not tested or used + origin_ = MatrixCoord(0, lane_id % 8); + static_assert(InstructionCount::kColumn * kAccessesInner == 4, ""); + CUTLASS_PRAGMA_UNROLL + for (int inst_n_idx = 0; inst_n_idx < InstructionCount::kColumn; + ++inst_n_idx) { + CUTLASS_PRAGMA_UNROLL + for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) { + int access_idx = inner_idx + kAccessesInner * inst_n_idx; + + MatrixCoord offset( + inner_idx * 4 * kElementsPerAccess, inst_n_idx * 8); + + if (access_idx == ldsm_vec_num) { + if (kTranspose) { + offset = MatrixCoord(offset.column(), offset.row()); + } + origin_ += offset; + } + } + } + } + + ref_.add_coord_offset(origin_); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem& add_tile_offset(TensorCoord const& tile_offset) { + TensorCoord coord_offset( + tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); + if (kTranspose) { + coord_offset = TensorCoord{coord_offset.column(), coord_offset.row()}; + } + origin_ += coord_offset; + + ref_.add_coord_offset(coord_offset); + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_DEVICE + void advance() { + if (kOperand == Operand::kA) { + add_tile_offset({0, 1}); + } else { + add_tile_offset({1, 0}); + } + + iterations_ = 0; + } + + /// increase iterations in a tile + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem& operator++() { + iterations_++; + + if (iterations_ >= kIterations) + advance(); + + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_DEVICE + void load(Fragment& frag) const { + AccessType* access_ptr = reinterpret_cast(&frag); + using LoadLayout = typename platform:: + conditional::type; + + CUTLASS_PRAGMA_UNROLL + for (int access_m_idx = 0; access_m_idx < + (InstructionCount::kRow * kTilesPerInstruction * kAccessesInner) / 4; + ++access_m_idx) { + MatrixCoord offset; + if (kOperand == Operand::kA) { + offset = MatrixCoord( + access_m_idx * 16, iterations_ * InstructionShape::kColumn); + } else { + offset = MatrixCoord(iterations_ * InstructionShape::kRow, 0); + } + if (kTranspose) { + offset = MatrixCoord(offset.column(), offset.row()); + } + cutlass::arch::ldsm( + access_ptr[access_m_idx], ref_.data() + ref_.offset(offset)); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/kernel_backward.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/kernel_backward.h new file mode 100644 index 0000000000000000000000000000000000000000..5cdb7c2145488c585a6665bb6170b3638e28692a --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/kernel_backward.h @@ -0,0 +1,2554 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include +#include + +#include +#include + +#ifdef HAS_PYTORCH +#include +#include +#include +#include +#endif + +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/fast_math.h" +#include "cutlass/functional.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" + +#include "debug_utils.h" +#include "gemm_kernel_utils.h" + +#include "cutlass/epilogue/thread/linear_combination_relu.h" +#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h" +#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h" +#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/integer_subbyte.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/platform/platform.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/transform/threadblock/vector_iterator.h" +#include "epilogue/epilogue_pipelined.h" +#include "iterators/epilogue_predicated_tile_iterator.h" + +#include "gemm/custom_mma.h" +#include "gemm/find_default_mma.h" +#include "gemm/mma_accum_lambda_iterator.h" +#include "gemm/mma_from_smem.h" +#include "transform/tile_smem_loader.h" + +using namespace gemm_kernel_utils; + +namespace { + +template +struct GmemTile { + /* + Helper functions to efficient store/load RF to gmem + + GEMM accumulators have a particular format on A100, and + it takes some compute/shared-memory to rearrange them to + a RowMajor or ColumnMajor format in global memory through + an Epilogue. The same complexity goes for loading into RF. + + This class loads/stores RF as they are, and can be used for + efficient accumulation across gemms for instance: + + ``` + GmemTile tile; + for (int i = 0; i < N; ++i) { + // ... + + Fragment accum; + if (i == 0) { + accum.clear(); + } else { + tile.load(accum); + } + mma(accum, ...); + if (i < N-1) { + // Store for next GEMM + tile.store(accum); + } else { + // Store in tensor (eg RowMajor) + epilogue(accum); + } + + // ... + } + ``` + */ + + // 128bits per thread + using AccessType = cutlass::Array; + static constexpr int32_t kBytes = sizeof(AccessType); + static constexpr int32_t kStride = kNumThreads * AccessType::kElements; + static constexpr int32_t kNumIters = + FragmentType::kElements / AccessType::kElements; + static constexpr int32_t kElementsStored = + kNumThreads * FragmentType::kElements; + static_assert( + FragmentType::kElements % AccessType::kElements == 0, + "fragment not aligned on 128 bits"); + + float* ptr; + + CUTLASS_DEVICE void load(FragmentType& fragment, int thread_id) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNumIters; ++i) { + AccessType* __restrict__ gmem_ptr = reinterpret_cast( + ptr + thread_id * AccessType::kElements + i * kStride); + AccessType sub_fragment; + cutlass::arch::global_load( + sub_fragment, gmem_ptr, true); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < AccessType::kElements; ++j) { + fragment[i * AccessType::kElements + j] = sub_fragment[j]; + } + } + } + + CUTLASS_DEVICE void store(FragmentType const& fragment, int thread_id) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNumIters; ++i) { + AccessType* __restrict__ gmem_ptr = reinterpret_cast( + ptr + thread_id * AccessType::kElements + i * kStride); + AccessType sub_fragment; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < AccessType::kElements; ++j) { + sub_fragment[j] = fragment[i * AccessType::kElements + j]; + } + cutlass::arch::global_store( + sub_fragment, gmem_ptr, true); + } + } + + CUTLASS_DEVICE void storeAtomicAdd( + FragmentType const& fragment, + int thread_id) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNumIters; ++i) { + float* gmem_ptr = ptr + thread_id * AccessType::kElements + i * kStride; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < AccessType::kElements; ++j) { + float val = fragment[i * AccessType::kElements + j]; + float* ptr = gmem_ptr + j; + atomicAdd(ptr, val); + } + } + } +}; + +struct AtomicLock { + CUTLASS_DEVICE static void acquire( + int32_t* lock, + int set_val, + int thread_id) { + if (thread_id == 0) { + while (atomicCAS(lock, 0 /*cmp*/, set_val /*setval*/) != set_val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + __nanosleep(40); +#endif + } + } + __syncthreads(); + } + CUTLASS_DEVICE static void release(int32_t* lock, int thread_id) { + if (thread_id == 0) { + int status = 0; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + asm volatile("st.global.release.gpu.b32 [%0], %1;\n" + : + : "l"(lock), "r"(status)); +#else + asm volatile("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status)); +#endif + } + } +}; + +template +constexpr int getWarpsPerSmBw() { + bool is_half = !cutlass::platform::is_same::value; + if (Arch::kMinComputeCapability >= 80) { + return is_half ? 12 : 8; + } + return 8; +} +} // namespace + +template < + // which arch we target (eg `cutlass::arch::Sm80`) + typename ArchTag_, + // input/output type + typename scalar_t_, + // run optimized kernel because memory accesses will be aligned + bool kIsAligned_, + // use dropout if enabled + bool kApplyDropout_, + // when doing a GEMM, preload the next one (uses more shmem) + bool kPreload_, + // block dimensions + int kBlockSizeI_, + int kBlockSizeJ_, + // upperbound on `max(value.shape[-1], query.shape[-1])` + int kMaxK_ = (int)cutlass::platform::numeric_limits::max(), + // assumes that `cu_seqlen` is None, and + // (1) `num_queries % kBlockSizeI == 0` + // (2) `num_keys % kBlockSizeJ == 0` + bool kKeysQueriesAlignedToBlockSize_ = false, + // Allows to parallelize across keys + bool kEnableSplitKeys_ = true> +struct AttentionBackwardKernel { + enum CustomMaskType { + NoCustomMask = 0, + CausalFromTopLeft = 1, + CausalFromBottomRight = 2, + NumCustomMaskTypes, + }; + using scalar_t = scalar_t_; + using output_t = scalar_t; + using output_accum_t = float; + using lse_scalar_t = float; + using accum_t = float; + using ArchTag = ArchTag_; + static constexpr bool kIsAligned = kIsAligned_; + static constexpr bool kApplyDropout = kApplyDropout_; + static constexpr bool kPreload = kPreload_; + static constexpr int kBlockSizeI = kBlockSizeI_; + static constexpr int kBlockSizeJ = kBlockSizeJ_; + static constexpr int kMaxK = kMaxK_; + static constexpr bool kKeysQueriesAlignedToBlockSize = + kKeysQueriesAlignedToBlockSize_; + + static constexpr int64_t kWarpSize = 32; + + // If this is true, we store and accumulate dK/dV in RF + // rather than going back to gmem everytime + static constexpr bool kIsHalf = cutlass::sizeof_bits::value <= 16; + static constexpr bool kOutputInRF = kIsHalf && kMaxK <= kBlockSizeI; + static_assert( + !kPreload || + (kIsHalf && ArchTag::kMinComputeCapability >= 80 && kOutputInRF), + "preload MMA not supported"); + static constexpr bool kPrologueQK = kPreload; + static constexpr bool kPrologueGV = kPreload; + static constexpr bool kPrologueDOV = kPreload; + static constexpr bool kPrologueGQ = kPreload; + static constexpr bool kPrologueGK = kPreload; + + static constexpr int64_t kNumWarpsPerBlock = + (kBlockSizeI * kBlockSizeJ) / (32 * 32); + + // Compute delta for the f16 kernels + // TODO: Figure out why it's slower on the f32 kernels + // (something due to RF pressure?) + // TODO: Remove condition on `kOutputInRF` - this is needed to work + // around a compiler bug on V100, not exactly sure why but I spent + // too much time on this already. Reproducible with + // (B, Mq, Mkv, K) = (1, 1, 1, 136) for instance + static constexpr bool kKernelComputesDelta = + kIsHalf && (kOutputInRF || ArchTag::kMinComputeCapability != 70); + + // Launch bounds + static constexpr int64_t kNumThreads = kWarpSize * kNumWarpsPerBlock; + static constexpr int64_t kMinBlocksPerSm = + getWarpsPerSmBw() / kNumWarpsPerBlock; + + using GemmType = DefaultGemmType; + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration< + typename GemmType::OpClass, + ArchTag, + scalar_t, + scalar_t, + scalar_t, // ElementC + accum_t // ElementAccumulator + >; + static constexpr auto kOptimalAlignement = cutlass::platform::max( + DefaultConfig::kAlignmentA, + DefaultConfig::kAlignmentB); + static constexpr auto kMinimumAlignment = GemmType::kMinimumAlignment; + + struct MatmulQK { + /* + attn_T = k_j @ q_i.transpose(-2, -1) # matmul + attn_T = (attn_T - logsumexp[i_start:i_end].unsqueeze(1).transpose(-2, + -1)).exp() # epilogue + + with attn_T.shape = (kBlockSizeJ, kBlockSizeI) + */ + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using DefaultMma = typename cutlass::gemm::threadblock::DefaultMma< + scalar_t, // ElementA + cutlass::layout::RowMajor, // LayoutA + kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment, + scalar_t, // ElementB + cutlass::layout::ColumnMajor, // LayoutB + kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment, + accum_t, // ElementC + cutlass::layout::RowMajor, // LayoutC + typename GemmType::OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + typename GemmType::InstructionShape, + DefaultConfig::kStages, + typename GemmType::Operator, + false, // AccumulatorsInRowMajor = false, + cutlass::gemm::SharedMemoryClearOption::kNone>; + using MmaCore = typename DefaultMma::MmaCore; + using Mma = + typename MakeCustomMma::Mma; + + // used for efficient load of bias tile (Bij) from global memory to shared + // memory + using BiasLoader = TileSmemLoader< + scalar_t, + // Bij is applied to transposed attn matrix tile (Pij.T). Bij is loaded + // row-major but needs to have transposed shape so we get the same + // elements. + cutlass::MatrixShape, + MmaCore::kThreads, + // input restriction: kv_len has to be a multiple of this value + 128 / cutlass::sizeof_bits::value>; + + // Epilogue to store to shared-memory in a format that we can use later for + // the second matmul + using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm< + typename Mma::Operator::IteratorC, + typename Mma::Operator, + scalar_t, + WarpShape, + ThreadblockShape>; + using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator< + typename Mma::Operator::IteratorC, + accum_t, + kWarpSize>::Iterator; + using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; + }; + + struct MatmulGradV { + /* + grad_v[j_start:j_end] += attn_T @ do_i # matmul + + Dimensions: (kBlockSizeJ * kNumWarpsPerBlock, kBlockSizeI, K) + (we might need to iterate multiple times on K) + */ + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using InstructionShape = typename GemmType::InstructionShape; + + using DefaultGemm = cutlass::gemm::kernel::DefaultGemm< + scalar_t, // ElementA, + cutlass::layout::RowMajor, // LayoutA, + DefaultConfig::kAlignmentA, + scalar_t, // ElementB, + cutlass::layout::RowMajor, // LayoutB, + kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment, + output_t, + cutlass::layout::RowMajor, // LayoutC, + accum_t, + typename GemmType::OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + typename GemmType::InstructionShape, + typename DefaultConfig::EpilogueOutputOp, + void, // ThreadblockSwizzle - not used + DefaultConfig::kStages, + false, // SplitKSerial + typename GemmType::Operator>; + + // if dropout: + // for computing dVj += (Pij.T * Zij) @ dOi + // Pij_dropped.T = Pij.T * Zij is computed on the fly as fragments of + // Pij.T are loaded in. The reason we do it this way is because Pij.T and + // Zij are reused in later steps, while Pij_dropped.T is only needed in + // this step. computing Pij_dropped.T on the fly allows us to avoid + // keeping all 3 of Pij_dropped.T, Pij.T, and Zij in shared memory at the + // same time. + // if no dropout: + // for computing dVj += Pij.T @ dOi + using WarpIteratorA = typename cutlass::gemm::threadblock:: + DefaultWarpIteratorAFromSharedMemory< + typename DefaultGemm::Mma::Operator::Shape, // WarpShape + typename DefaultGemm::Mma::Operator:: + InstructionShape, // InstructionShape + typename DefaultGemm::Mma::Operator:: + IteratorA, // RegularWarpIterator + typename DefaultGemm::Mma::Policy // Policy + >::WarpIterator; + using DefaultMmaFromSmem = + typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + MatmulQK::AccumulatorSharedStorage::Shape::kN, + WarpIteratorA, + kApplyDropout>; // kScaleOperandA + + using Mma = typename DefaultMmaFromSmem::Mma; + using IteratorB = typename Mma::IteratorB; + using WarpCount = typename Mma::WarpCount; + + // Epilogue + using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp; + using DefaultEpilogue = typename DefaultGemm::Epilogue; + using OutputTileIterator = + typename cutlass::epilogue::threadblock::MakePrefetchableIterator< + typename DefaultEpilogue::OutputTileIterator>::Iterator; + using AccumTileGmem = GmemTile; + }; + + struct MatmulDOIVJ { + /* + doi_t_vj = do_i @ v_j.transpose(-2, -1) # matmul + tmp = (doi_t_vj - Di.unsqueeze(1)) * attn # inplace / epilogue? + */ + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + + using ElementC = output_t; + using ElementAccum = accum_t; + + // no-op output op - epilogue just stores result to global memory + using BiasGradEpilogueOutputOp = + typename cutlass::epilogue::thread::LinearCombination< + ElementC, + DefaultConfig::EpilogueOutputOp::kCount, + typename DefaultConfig::EpilogueOutputOp::ElementAccumulator, + typename DefaultConfig::EpilogueOutputOp::ElementCompute, + cutlass::epilogue::thread::ScaleType::Nothing>; + + using DefaultGemm = typename cutlass::gemm::kernel::DefaultGemm< + scalar_t, // ElementA + cutlass::layout::RowMajor, // LayoutA + kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment, + scalar_t, // ElementB + cutlass::layout::ColumnMajor, // LayoutB + kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment, + ElementC, // ElementC + cutlass::layout::RowMajor, // LayoutC + ElementAccum, // ElementAccumulator + typename GemmType::OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + typename GemmType::InstructionShape, + BiasGradEpilogueOutputOp, // EpilogueOutputOp + void, // ThreadblockSwizzle (not used) + // multiple preloads, dropout Zij tile, and 3 stages push us over shared + // memory capacity on A100. set a ceiling on number of stages to save + // shared memory if dropout is in use. + kPreload && kApplyDropout && (kBlockSizeI * kBlockSizeJ > 64 * 64) + ? cutlass::const_min(2, DefaultConfig::kStages) + : DefaultConfig::kStages, // Stages + false, // SplitKSerial + typename GemmType::Operator, + cutlass::gemm::SharedMemoryClearOption::kNone>; + using Mma = typename MakeCustomMma::Mma; + using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator< + typename Mma::Operator::IteratorC, + ElementAccum, + kWarpSize>::Iterator; + + // epilogue used to write bias gradient, which is just the output of this + // matmul with some operations applied to the fragment + using BiasGradEpilogue = typename DefaultGemm::Epilogue; + + // Epilogue to store to shared-memory in a format that we can use later for + // the second matmul + using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm< + typename DefaultGemm::Mma::Operator::IteratorC, + typename DefaultGemm::Mma::Operator, + scalar_t, + WarpShape, + ThreadblockShape>; + using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; + }; + + struct MatmulGradQ { + // grad_q <- tmp @ k_j + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using InstructionShape = typename GemmType::InstructionShape; + + using DefaultGemm = cutlass::gemm::kernel::DefaultGemm< + scalar_t, // ElementA, + cutlass::layout::RowMajor, // LayoutA, + DefaultConfig::kAlignmentA, + scalar_t, // ElementB, + cutlass::layout::RowMajor, // LayoutB, + kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment, + output_t, + cutlass::layout::RowMajor, // LayoutC, + accum_t, + typename GemmType::OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + typename GemmType::InstructionShape, + typename DefaultConfig::EpilogueOutputOp, + void, // ThreadblockSwizzle - not used + DefaultConfig::kStages, + false, // SplitKSerial + typename GemmType::Operator>; + + using WarpIteratorA = typename cutlass::gemm::threadblock:: + DefaultWarpIteratorAFromSharedMemory< + typename DefaultGemm::Mma::Operator::Shape, + typename DefaultGemm::Mma::Operator::InstructionShape, + typename DefaultGemm::Mma::Operator::IteratorA, + typename DefaultGemm::Mma::Policy>::WarpIterator; + using DefaultMmaFromSmem = + typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + MatmulDOIVJ::AccumulatorSharedStorage::Shape::kN, + WarpIteratorA, + false>; // kScaleOperandA + using Mma = typename DefaultMmaFromSmem::Mma; + using IteratorB = typename Mma::IteratorB; + using WarpCount = typename Mma::WarpCount; + + // Epilogue + using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp; + using DefaultEpilogue = typename DefaultGemm::Epilogue; + using OutputTileIterator = + typename cutlass::epilogue::threadblock::MakePrefetchableIterator< + typename DefaultEpilogue::OutputTileIterator>::Iterator; + using AccumTileGmem = GmemTile; + }; + struct MatmulGradK { + // grad_k <- tmp.transpose(-2, -1) @ q_i + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using InstructionShape = typename GemmType::InstructionShape; + + using DefaultGemm = cutlass::gemm::kernel::DefaultGemm< + scalar_t, // ElementA, + cutlass::layout::RowMajor, // LayoutA, + DefaultConfig::kAlignmentA, + scalar_t, // ElementB, + cutlass::layout::RowMajor, // LayoutB, + kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment, + output_t, + cutlass::layout::RowMajor, // LayoutC, + accum_t, + typename GemmType::OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + typename GemmType::InstructionShape, + typename DefaultConfig::EpilogueOutputOp, + void, // ThreadblockSwizzle - not used + DefaultConfig::kStages, + false, // SplitKSerial + typename GemmType::Operator>; + + using WarpIteratorA = typename cutlass::gemm::threadblock:: + DefaultWarpIteratorAFromSharedMemory< + typename DefaultGemm::Mma::Operator::Shape, + typename DefaultGemm::Mma::Operator::InstructionShape, + typename DefaultGemm::Mma::Operator::IteratorA, + typename DefaultGemm::Mma::Policy>::WarpIterator; + using DefaultMmaFromSmemN = + typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + MatmulQK::AccumulatorSharedStorage::Shape::kN, // kMaxK + WarpIteratorA, + false>; // kScaleOperandA + using DefaultMmaFromSmemT = + typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + MatmulDOIVJ::AccumulatorSharedStorage::Shape::kM, // kMaxK + WarpIteratorA, + false, // kScaleOperandA + kPreload>; // kTransposeA + using DefaultMmaFromSmem = typename cutlass::platform::conditional< + DefaultMmaFromSmemT::kIsTransposedA, + DefaultMmaFromSmemT, + DefaultMmaFromSmemN>::type; + using Mma = typename DefaultMmaFromSmem::Mma; + using IteratorB = typename Mma::IteratorB; + using WarpCount = typename Mma::WarpCount; + + // Epilogue + using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp; + using DefaultEpilogue = typename DefaultGemm::Epilogue; + using OutputTileIterator = + typename cutlass::epilogue::threadblock::MakePrefetchableIterator< + typename DefaultEpilogue::OutputTileIterator>::Iterator; + using AccumTileGmem = GmemTile; + }; + + static constexpr bool kEnableSplitKeys = kEnableSplitKeys_; + + static constexpr bool kNeedsAccumGradQ = kEnableSplitKeys || + !cutlass::platform::is_same::value; + static constexpr bool kNeedsAccumGradK = !kOutputInRF && + !cutlass::platform::is_same::value; + static constexpr bool kNeedsAccumGradV = !kOutputInRF && + !cutlass::platform::is_same::value; + + struct GradQTempStorage { + int32_t lock; + int32_t counter; + int32_t pad[2]; // pad to 128bits + output_accum_t buffer[MatmulGradQ::AccumTileGmem::kElementsStored]; + }; + + struct Params { + // Input tensors + scalar_t* query_ptr = nullptr; // [Mq, nH, K] + scalar_t* key_ptr = nullptr; // [Mk, nH, K] + scalar_t* value_ptr = nullptr; // [Mk, nH, Kv] + scalar_t* bias_ptr = nullptr; + lse_scalar_t* logsumexp_ptr = nullptr; // [nH, Mq] + scalar_t* output_ptr = nullptr; // [Mq, nH, Kv] + scalar_t* grad_output_ptr = nullptr; // [Mq, nH, Kv] + accum_t* delta_ptr = nullptr; // [nH, Mq] + int32_t* cu_seqlens_q_ptr = nullptr; + int32_t* cu_seqlens_k_ptr = nullptr; + + // Output tensors + output_t* grad_query_ptr = nullptr; // [Mq, nH, K] + output_t* grad_key_ptr = nullptr; // [Mk, nH, K] + output_t* grad_value_ptr = nullptr; // [Mk, nH, Kv] + output_t* grad_bias_ptr = nullptr; + + // Accumulators + output_accum_t* workspace = nullptr; // [Mq, Kq] + [Mkv, Kq] + [Mkv, Kv] + output_accum_t* workspace_gv = + nullptr; // (will be calculated by the kernel) + GradQTempStorage* workspace_gq = + nullptr; // (will be calculated by the kernel) + + // Scale + accum_t scale = 1.0f; + + // Dimensions/strides + int32_t head_dim = -1; + int32_t head_dim_value = -1; + int32_t num_queries = -1; + int32_t num_keys = -1; + int32_t num_heads = -1; + uint8_t custom_mask_type = NoCustomMask; + + int32_t q_strideM = -1; + int32_t k_strideM = -1; + int32_t v_strideM = -1; + int32_t bias_strideM = 0; + int32_t gO_strideM = -1; + int32_t gB_strideM = -1; + int8_t gQKV_strideM_multiplier = 1; // 3 for packed, 1 otherwise + +#ifdef HAS_PYTORCH + // dropout + at::PhiloxCudaState rng_engine_inputs = {0, 0}; +#endif + // RNG sequence offset based on batch_id and head_id + unsigned long long dropout_batch_head_rng_offset = 0; + float dropout_prob = 0.0f; + + CUTLASS_HOST_DEVICE int32_t o_strideM() const { + return head_dim_value * num_heads; + } + CUTLASS_HOST_DEVICE int32_t gQ_strideM() const { + return gQKV_strideM_multiplier * num_heads * head_dim; + } + CUTLASS_HOST_DEVICE int32_t gK_strideM() const { + return gQKV_strideM_multiplier * num_heads * head_dim; + } + CUTLASS_HOST_DEVICE int32_t gV_strideM() const { + return gQKV_strideM_multiplier * num_heads * head_dim_value; + } + + // Everything below is only used in `advance_to_block` + // and shouldn't use registers + int64_t o_strideH = -1; + int32_t q_strideH = -1; + int32_t k_strideH = -1; + int32_t v_strideH = -1; + int64_t bias_strideH = 0; + int64_t o_strideB = -1; + int64_t q_strideB = -1; + int64_t k_strideB = -1; + int64_t v_strideB = -1; + int64_t bias_strideB = 0; + int64_t lse_strideB = -1; + int64_t lse_strideH = -1; + int64_t delta_strideB = -1; + int64_t delta_strideH = -1; + int32_t num_batches = -1; + int16_t num_splits_key = 1; // We use `gridDim.x` inside kernel + + int64_t gO_strideB = 0; + int64_t gQ_strideB = 0; + int64_t gK_strideB = 0; + int64_t gV_strideB = 0; + int64_t gB_strideB = 0; + int64_t gO_strideH = 0; + int64_t gQ_strideH = 0; + int64_t gK_strideH = 0; + int64_t gV_strideH = 0; + int64_t gB_strideH = 0; + + CUTLASS_DEVICE int16_t num_splits_key_device() const { + return kEnableSplitKeys ? gridDim.x : 1; + } + CUTLASS_DEVICE int16_t split_key_device() const { + return kEnableSplitKeys ? blockIdx.x : 0; + } + + CUTLASS_DEVICE bool advance_to_block() { + int64_t batch_id = blockIdx.z; + int32_t head_id = blockIdx.y; + + if (kNeedsAccumGradQ || kNeedsAccumGradK || kNeedsAccumGradV) { + assert(workspace_size() == 0 || workspace != nullptr); + + workspace += (batch_id * num_heads + head_id) * workspace_strideBH(); + workspace = warp_uniform(workspace); + workspace_gv = workspace + workspace_elements_gk(); + workspace_gq = + (GradQTempStorage*)(workspace_gv + workspace_elements_gv()); + if (kEnableSplitKeys) { + workspace_gv += workspace_elements_gv() * split_key_device() / + num_splits_key_device(); + workspace += workspace_elements_gk() * split_key_device() / + num_splits_key_device(); + } + } else { + workspace = nullptr; + } + + // Advance pointers that depend on the total concatenated + // number of queries, as `num_queries` is modified in the block + // below + dropout_batch_head_rng_offset = + batch_id * (num_heads * num_queries * num_keys) + + head_id * (num_queries * num_keys); + logsumexp_ptr += batch_id * lse_strideB + head_id * lse_strideH; + + if (cu_seqlens_q_ptr != nullptr) { + assert(cu_seqlens_k_ptr != nullptr); + cu_seqlens_q_ptr += batch_id; + cu_seqlens_k_ptr += batch_id; + int32_t q_start = cu_seqlens_q_ptr[0]; + int32_t k_start = cu_seqlens_k_ptr[0]; + int64_t q_next_start = cu_seqlens_q_ptr[1]; + int64_t k_next_start = cu_seqlens_k_ptr[1]; + assert(q_next_start - q_start <= num_queries); + assert(k_next_start - k_start <= num_keys); + num_queries = q_next_start - q_start; + num_keys = k_next_start - k_start; + + // Jump manually + batch_id = 0; + + query_ptr += q_start * q_strideM; + key_ptr += k_start * k_strideM; + value_ptr += k_start * v_strideM; + assert(bias_ptr == nullptr); + assert(grad_bias_ptr == nullptr); + output_ptr += q_start * o_strideM(); + grad_output_ptr += q_start * gO_strideM; + delta_ptr += q_start; + + grad_query_ptr += q_start * gQ_strideM(); + grad_key_ptr += k_start * gK_strideM(); + grad_value_ptr += k_start * gV_strideM(); + } + + query_ptr += batch_id * q_strideB + head_id * q_strideH; + key_ptr += batch_id * k_strideB + head_id * k_strideH; + value_ptr += batch_id * v_strideB + head_id * v_strideH; + if (bias_ptr != nullptr) { + bias_ptr += batch_id * bias_strideB + head_id * bias_strideH; + } + output_ptr += batch_id * o_strideB + head_id * o_strideH; + grad_output_ptr += batch_id * gO_strideB + head_id * gO_strideH; + delta_ptr += batch_id * delta_strideB + head_id * delta_strideH; + + grad_query_ptr += batch_id * gQ_strideB + head_id * gQ_strideH; + grad_key_ptr += batch_id * gK_strideB + head_id * gK_strideH; + grad_value_ptr += batch_id * gV_strideB + head_id * gV_strideH; + if (grad_bias_ptr != nullptr) { + grad_bias_ptr += batch_id * gB_strideB + head_id * gB_strideH; + } + + // Some values are modified above + // Signal to the compiler that they are the same in all threads + // and can be stored in warp-uniform registers (Sm75+) + num_queries = warp_uniform(num_queries); + num_keys = warp_uniform(num_keys); + custom_mask_type = warp_uniform(custom_mask_type); + + query_ptr = warp_uniform(query_ptr); + key_ptr = warp_uniform(key_ptr); + value_ptr = warp_uniform(value_ptr); + bias_ptr = warp_uniform(bias_ptr); + logsumexp_ptr = warp_uniform(logsumexp_ptr); + output_ptr = warp_uniform(output_ptr); + grad_output_ptr = warp_uniform(grad_output_ptr); + delta_ptr = warp_uniform(delta_ptr); + + grad_query_ptr = warp_uniform(grad_query_ptr); + grad_key_ptr = warp_uniform(grad_key_ptr); + grad_value_ptr = warp_uniform(grad_value_ptr); + grad_bias_ptr = warp_uniform(grad_bias_ptr); + +#if 0 + PRINT_T0("[b:%d h:%d] dp[0]:%f Q:%f K:%f V:%f LSE:%f", + int(blockIdx.z), int(blockIdx.y), + float(delta_ptr[0]), + float(query_ptr[0]), float(key_ptr[0]), float(value_ptr[0]), + float(logsumexp_ptr[0]) + ) +#endif + return true; + } + + __host__ dim3 getBlocksGrid() const { + return dim3(num_splits_key, num_heads, num_batches); + } + __host__ dim3 getThreadsGrid() const { + return dim3(kWarpSize * kNumWarpsPerBlock, 1, 1); + } + CUTLASS_HOST_DEVICE int64_t workspace_elements_gk() const { + if (!kNeedsAccumGradK) { + return 0; + } + return num_splits_key * align_up(num_keys, (int32_t)kBlockSizeJ) * + align_up(head_dim, (int32_t)kBlockSizeI); + } + CUTLASS_HOST_DEVICE int64_t workspace_elements_gv() const { + if (!kNeedsAccumGradV) { + return 0; + } + return num_splits_key * align_up(num_keys, (int32_t)kBlockSizeJ) * + align_up(head_dim_value, (int32_t)kBlockSizeI); + } + CUTLASS_HOST_DEVICE int64_t workspace_elements_gq() const { + if (!kNeedsAccumGradQ) { + return 0; + } + int num_blocks = ceil_div(num_queries, kBlockSizeI); + int num_cols = ceil_div(head_dim, MatmulGradQ::ThreadblockShape::kN); + return num_blocks * num_cols * sizeof(GradQTempStorage) / + sizeof(output_accum_t); + } + CUTLASS_HOST_DEVICE int64_t workspace_strideBH() const { + // Aligned on 128bits + return align_up( + workspace_elements_gk() + workspace_elements_gv() + + workspace_elements_gq(), + int64_t(4)); + } + CUTLASS_HOST_DEVICE int64_t workspace_size() const { + // Returns size of buffer we need to run this kernel + return num_batches * num_heads * workspace_strideBH() * sizeof(float); + } + CUTLASS_HOST_DEVICE bool should_zero_workspace() const { + return num_splits_key > 1; + } + }; + + // shared storage for keeping Zij matrix. not needed if we aren't using + // dropout, in which case we use an empty array to save shared memory + using ZijSharedStorage = typename cutlass::platform::conditional< + kApplyDropout, + typename MatmulQK::AccumulatorSharedStorage, + // dummy shared storage object that takes up no space. + typename cutlass::gemm::threadblock::AccumulatorSharedStorage< +#ifdef _WIN32 + // windows builds throw the error: + // "type containing an unknown-size array is not allowed" + // if we try to make Zij shared storage zero-sized. + // To get around this just make it sized 1 on windows. + typename cutlass::gemm::GemmShape<1, 1, 0>, +#else + typename cutlass::gemm::GemmShape<0, 0, 0>, +#endif + typename MatmulQK::AccumulatorSharedStorage::Element, + typename MatmulQK::AccumulatorSharedStorage::Layout, + typename cutlass::MatrixShape<0, 0>>>::type; + + struct SharedStoragePrologue { + struct { + cutlass::Array di; // (do_i * o_i).sum(-1) + typename MatmulQK::Mma::SharedStorageA mm_qk_k; + } persistent; + union { + struct { + // part1 - after Q.K / dV / dO.V + union { + // 1. efficient load of bias tile Bij, which is then applied to Pij + typename MatmulQK::BiasLoader::SmemTile bias; + // 4. store Pij. it is needed: + // - in dVj += (Pij.T * Zij) @ dOi + // - in dSij = Pij * (dPij - Di) + // 6. dVj += (Pij.T * Zij) @ dOi + // 10. write to fragment + typename MatmulQK::AccumulatorSharedStorage attn_shared_storage; + }; + // 5. store Zij. it is needed in dVj += (Pij.T * Zij) @ dOi + ZijSharedStorage zij; + + union { + // 2. prologue for dVj + // 6. workspace for dVj += (Pij.T * Zij) @ dOi + typename MatmulGradV::Mma::SharedStorage mm_gradV; + // 7. dVj epilogue + typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue; + }; + + // 3. prologue for dPij_dropped + // 8. used in dPij_dropped = dOi @ Vj.T + typename MatmulDOIVJ::Mma::SharedStorage mm_doivj; + } part1; + + struct { + // part2 - dQ + union { + typename MatmulQK::AccumulatorSharedStorage + tmpT_shared_storage; // (from part1) + typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage; + }; + typename MatmulGradK::Mma::SharedStorage mm_gradK; // (preload) + typename MatmulGradQ::Mma::SharedStorage mm_gradQ; // (preload) + union { + // store dB = dSij to global memory + typename MatmulDOIVJ::BiasGradEpilogue::SharedStorage gradB_epilogue; + typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue; + }; + + } part2; + + struct { + // part3 - after last iteration on dQ's epilogue / dK + union { + typename MatmulQK::AccumulatorSharedStorage + tmpT_shared_storage; // (from part1) + typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage; + }; + typename MatmulGradK::Mma::SharedStorage mm_gradK; // (preload) + typename MatmulGradQ::DefaultEpilogue::SharedStorage + gradQ_epilogue_lastIter; + + typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue; + } part3; + + struct { + // part4 - after last iteration on dK's epilogue / preload next K.Q_t + typename MatmulQK::Mma::SharedStorageB mm_qk_q; + + // If we reach end of current key, dump RF->gmem with "final" epilogues + typename MatmulGradK::DefaultEpilogue::SharedStorage + gradK_epilogue_final; + typename MatmulGradV::DefaultEpilogue::SharedStorage + gradV_epilogue_final; + } part4; + }; + static void print_size() { + // Field size +#define FSZ(f) int((sizeof(((SharedStoragePrologue*)0)->f))) + + printf("Total smem: %d bytes\n", int(sizeof(SharedStoragePrologue))); + printf(" persistent: %db\n", FSZ(persistent)); + printf(" mm_qk_k: %db\n", FSZ(persistent.mm_qk_k)); + printf(" part1: %db\n", FSZ(part1)); + printf(" bias: %db\n", FSZ(part1.bias)); + printf(" attn_shared_storage: %db\n", FSZ(part1.attn_shared_storage)); + printf(" zij: %db\n", FSZ(part1.zij)); + printf(" mm_gradV: %db\n", FSZ(part1.mm_gradV)); + printf(" gradV_epilogue: %db\n", FSZ(part1.gradV_epilogue)); + printf(" mm_doivj: %db\n", FSZ(part1.mm_doivj)); + printf(" part2: %db\n", FSZ(part2)); + printf(" tmpT_shared_storage: %db\n", FSZ(part2.tmpT_shared_storage)); + printf(" tmp_shared_storage: %db\n", FSZ(part2.tmp_shared_storage)); + printf(" mm_gradK: %db\n", FSZ(part2.mm_gradK)); + printf(" mm_gradQ: %db\n", FSZ(part2.mm_gradQ)); + printf(" gradB_epilogue: %db\n", FSZ(part2.gradB_epilogue)); + printf(" gradQ_epilogue: %db\n", FSZ(part2.gradQ_epilogue)); + printf(" part3: %db\n", FSZ(part3)); + printf(" tmpT_shared_storage: %db\n", FSZ(part3.tmpT_shared_storage)); + printf(" part4: %db\n", FSZ(part4)); + printf(" mm_qk_q: %db\n", FSZ(part4.mm_qk_q)); + printf( + " gradK_epilogue_final: %db\n", FSZ(part4.gradK_epilogue_final)); + printf( + " gradV_epilogue_final: %db\n", FSZ(part4.gradV_epilogue_final)); + } +// =========================================== +#define FIELD(INSIDE_STRUCT, FIELDNAME) \ + CUTLASS_DEVICE auto& FIELDNAME() { \ + return INSIDE_STRUCT.FIELDNAME; \ + } + + FIELD(persistent, di) + FIELD(persistent, mm_qk_k) + FIELD(part1, bias) + FIELD(part1, attn_shared_storage) + FIELD(part1, zij) + FIELD(part1, mm_gradV) + FIELD(part1, gradV_epilogue) + FIELD(part1, mm_doivj) + FIELD(part2, mm_gradK) + FIELD(part2, mm_gradQ) + FIELD(part2, gradB_epilogue) + FIELD(part2, gradQ_epilogue) + FIELD(part2, tmp_shared_storage) + FIELD(part3, tmpT_shared_storage) + FIELD(part3, gradQ_epilogue_lastIter) + FIELD(part3, gradK_epilogue) + FIELD(part4, mm_qk_q) + FIELD(part4, gradK_epilogue_final) + FIELD(part4, gradV_epilogue_final) + }; + + struct SharedStorageNoPrologue { + struct { + cutlass::Array di; // (do_i * o_i).sum(-1) + } persistent; + union { + struct { + // part1 - Q.K matmul + typename MatmulQK::Mma::SharedStorageA mm_qk_k; + typename MatmulQK::Mma::SharedStorageB mm_qk_q; + } part1; + + struct { + // part2 - compute gradV + union { + // 1. efficient load of bias tile Bij, which is then applied to Pij + typename MatmulQK::BiasLoader::SmemTile bias; + // 2. store Pij to shared memory. it is needed: + // - in this step, where it is used in dVj += (Pij.T * Zij) @ dOi + // - in next step where it is used in dSij = Pij * (dPij - Di) + typename MatmulQK::AccumulatorSharedStorage attn_shared_storage; + }; + // 3. store Zij. it is needed in this step, where it is used + // to compute Pij_dropped = Pij * Zij on the fly as fragments of Pij are + // loaded for the computation of dVj. + ZijSharedStorage zij; + + union { + typename MatmulGradV::Mma::SharedStorage mm_gradV; + typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue; + }; + } part2; + + struct { + // part3 - DO.V matmul + union { + // first compute dPij = (dOi @ Vj.T) * Zij + // and dSij = Pij * (dPij - Di) + struct { + // (from part2) - Pij for computing dSij = Pij * (dPij - Di) + typename MatmulQK::AccumulatorSharedStorage attn_shared_storage; + // matmul to compute dOiVj + typename MatmulDOIVJ::Mma::SharedStorage mm_doivj; + }; + // then store dB = dSij to global memory + typename MatmulDOIVJ::BiasGradEpilogue::SharedStorage gradB_epilogue; + }; + } part3; + + struct { + // part4 - compute gradQ + typename MatmulQK::AccumulatorSharedStorage + tmpT_shared_storage; // (from part2) + typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage; + union { + typename MatmulGradQ::Mma::SharedStorage mm_gradQ; + typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue; + typename MatmulGradQ::DefaultEpilogue::SharedStorage + gradQ_epilogue_lastIter; + }; + } part4; + + struct { + // part5 - compute gradK + typename MatmulQK::AccumulatorSharedStorage + tmpT_shared_storage; // (from part2) + typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage; + union { + typename MatmulGradK::Mma::SharedStorage mm_gradK; + typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue; + }; + } part5; + + struct { + // part6 - store RF accumulated into gmem + typename MatmulGradK::DefaultEpilogue::SharedStorage + gradK_epilogue_final; + typename MatmulGradV::DefaultEpilogue::SharedStorage + gradV_epilogue_final; + } part6; + }; + static void print_size() { +#define FIELD_SIZEOF(f) int((sizeof(((SharedStorageNoPrologue*)0)->f))) + printf("Total smem: %d bytes\n", int(sizeof(SharedStorageNoPrologue))); + printf(" persistent: %db\n", FIELD_SIZEOF(persistent)); + printf(" part1: %db\n", FIELD_SIZEOF(part1)); + printf(" part2: %db\n", FIELD_SIZEOF(part2)); + printf(" part3: %db\n", FIELD_SIZEOF(part3)); + printf(" part4: %db\n", FIELD_SIZEOF(part4)); + printf(" part5: %db\n", FIELD_SIZEOF(part5)); + printf(" part6: %db\n", FIELD_SIZEOF(part6)); + } +// =========================================== +#define FIELD(INSIDE_STRUCT, FIELDNAME) \ + CUTLASS_DEVICE auto& FIELDNAME() { \ + return INSIDE_STRUCT.FIELDNAME; \ + } + + FIELD(persistent, di) + FIELD(part1, mm_qk_k) + FIELD(part1, mm_qk_q) + FIELD(part2, bias) + FIELD(part2, attn_shared_storage) + FIELD(part2, zij) + FIELD(part2, mm_gradV) + FIELD(part2, gradV_epilogue) + FIELD(part3, mm_doivj) + FIELD(part3, gradB_epilogue) + FIELD(part4, tmpT_shared_storage) + FIELD(part4, tmp_shared_storage) + FIELD(part4, mm_gradQ) + FIELD(part4, gradQ_epilogue) + FIELD(part4, gradQ_epilogue_lastIter) + FIELD(part5, mm_gradK) + FIELD(part5, gradK_epilogue) + FIELD(part6, gradK_epilogue_final) + FIELD(part6, gradV_epilogue_final) + }; + + using SharedStorage = typename cutlass::platform::conditional< + kPreload, + SharedStoragePrologue, + SharedStorageNoPrologue>::type; + + struct OutputFragments { + typename MatmulGradV::Mma::FragmentC gradV; + typename MatmulGradK::Mma::FragmentC gradK; + + CUTLASS_DEVICE void clear() { + gradV.clear(); + gradK.clear(); + } + }; + + static bool __host__ check_supported(Params const& p) { + CHECK_ALIGNED_PTR(p.query_ptr, kMinimumAlignment); + CHECK_ALIGNED_PTR(p.key_ptr, kMinimumAlignment); + CHECK_ALIGNED_PTR(p.value_ptr, kMinimumAlignment); + CHECK_ALIGNED_PTR(p.output_ptr, kMinimumAlignment); + CHECK_ALIGNED_PTR(p.grad_output_ptr, kMinimumAlignment); + CHECK_ALIGNED_PTR(p.bias_ptr, kMinimumAlignment); + XFORMERS_CHECK(p.lse_strideH % 8 == 0, "LSE is not correctly aligned"); + XFORMERS_CHECK(p.lse_strideB % 8 == 0, "LSE is not correctly aligned"); + XFORMERS_CHECK( + p.num_heads <= 1 || p.q_strideH % kMinimumAlignment == 0, + "query is not correctly aligned (strideH)"); + XFORMERS_CHECK( + p.num_heads <= 1 || p.k_strideH % kMinimumAlignment == 0, + "key is not correctly aligned (strideH)"); + XFORMERS_CHECK( + p.num_heads <= 1 || p.v_strideH % kMinimumAlignment == 0, + "value is not correctly aligned (strideH)"); + XFORMERS_CHECK( + p.num_batches <= 1 || p.q_strideB % kMinimumAlignment == 0, + "query is not correctly aligned (strideB)"); + XFORMERS_CHECK( + p.num_batches <= 1 || p.k_strideB % kMinimumAlignment == 0, + "key is not correctly aligned (strideB)"); + XFORMERS_CHECK( + p.num_batches <= 1 || p.v_strideB % kMinimumAlignment == 0, + "value is not correctly aligned (strideB)"); + XFORMERS_CHECK( + p.q_strideM % kMinimumAlignment == 0, + "query is not correctly aligned (strideM)"); + XFORMERS_CHECK( + p.k_strideM % kMinimumAlignment == 0, + "key is not correctly aligned (strideM)"); + XFORMERS_CHECK( + p.v_strideM % kMinimumAlignment == 0, + "value is not correctly aligned (strideM)"); + if (p.bias_ptr) { + XFORMERS_CHECK( + p.num_batches <= 1 || p.bias_strideB % kMinimumAlignment == 0, + "attn_bias is not correctly aligned (strideB)"); + XFORMERS_CHECK( + p.num_heads <= 1 || p.bias_strideH % kMinimumAlignment == 0, + "attn_bias is not correctly aligned (strideH)"); + XFORMERS_CHECK( + p.bias_strideM % kMinimumAlignment == 0, + "attn_bias is not correctly aligned (strideM)"); + } + if (p.grad_bias_ptr) { + XFORMERS_CHECK( + p.num_batches <= 1 || p.gB_strideB % kMinimumAlignment == 0, + "attn_bias.grad is not correctly aligned (strideB)"); + XFORMERS_CHECK( + p.num_heads <= 1 || p.gB_strideH % kMinimumAlignment == 0, + "attn_bias.grad is not correctly aligned (strideH)"); + XFORMERS_CHECK( + p.gB_strideM % kMinimumAlignment == 0, + "attn_bias.grad is not correctly aligned (strideM)"); + } + XFORMERS_CHECK( + !(p.cu_seqlens_q_ptr && p.bias_ptr), + "CuSeqlen + bias not implemented yet"); + XFORMERS_CHECK( + p.custom_mask_type < NumCustomMaskTypes, + "Invalid value for `custom_mask_type`"); + XFORMERS_CHECK( + p.dropout_prob <= 1.0f && p.dropout_prob >= 0.0f, + "Invalid value for `dropout_prob`"); + XFORMERS_CHECK( + kApplyDropout || p.dropout_prob == 0.0f, + "Set `kApplyDropout`=True to support `dropout_prob > 0`"); + XFORMERS_CHECK(p.head_dim > 0, "Invalid value for `head_dim`"); + XFORMERS_CHECK(p.head_dim_value > 0, "Invalid value for `head_dim_value`"); + XFORMERS_CHECK(p.num_queries > 0, "Invalid value for `num_queries`"); + XFORMERS_CHECK(p.num_keys > 0, "Invalid value for `num_keys`"); + XFORMERS_CHECK(p.num_heads > 0, "Invalid value for `num_heads`"); + XFORMERS_CHECK(p.num_batches > 0, "Invalid value for `num_batches`"); + XFORMERS_CHECK(p.head_dim <= kMaxK, "kMaxK: Expected `head_dim < kMaxK`"); + XFORMERS_CHECK( + p.head_dim_value <= kMaxK, "kMaxK: Expected `head_dim_value < kMaxK`"); + if (kKeysQueriesAlignedToBlockSize) { + XFORMERS_CHECK( + p.cu_seqlens_k_ptr == nullptr, + "This kernel does not support cu_seqlen"); + XFORMERS_CHECK( + p.cu_seqlens_q_ptr == nullptr, + "This kernel does not support cu_seqlen"); + XFORMERS_CHECK( + p.num_queries % kBlockSizeI == 0, + "kKeysQueriesAlignedToBlockSize condition not respected"); + XFORMERS_CHECK( + p.num_keys % kBlockSizeJ == 0, + "kKeysQueriesAlignedToBlockSize condition not respected"); + } + XFORMERS_CHECK( + kEnableSplitKeys || p.num_splits_key == 1, "SplitKeys is disabled"); + XFORMERS_CHECK( + p.num_splits_key > 0, "Invalid `num_splits_key` (expected >0)"); + XFORMERS_CHECK( + p.num_splits_key <= cutlass::ceil_div(p.num_keys, kBlockSizeJ), + "Invalid `num_splits_key` (too large)"); + return true; + } + + static CUTLASS_DEVICE void attention_kernel(Params p) { + extern __shared__ char smem_buffer[]; + SharedStorage& shared_storage = *((SharedStorage*)smem_buffer); + + uint16_t thread_id = threadIdx.x; + uint8_t warp_id = warp_uniform(thread_id / 32); + uint8_t lane_id = thread_id % 32; + + int32_t key_start = p.split_key_device() * kBlockSizeJ; + if (key_start >= p.num_keys) { + return; + } + if (kPrologueQK) { + int32_t query_start = getQueryStart(p, key_start); + prologueQkNextIteration( + shared_storage, p, query_start, key_start, warp_id, lane_id); + } + + // Computes (dO*out).sum(-1) and writes it to `p.delta_ptr` + if (kKernelComputesDelta) { + constexpr int kOptimalElements = + 128 / cutlass::sizeof_bits::value; + if (p.head_dim_value % kOptimalElements == 0) { + for (int query_start = 0; query_start < p.num_queries; + query_start += kBlockSizeI) { + computeDelta(p, query_start, warp_id, lane_id); + } + } else { + for (int query_start = 0; query_start < p.num_queries; + query_start += kBlockSizeI) { + computeDelta<1>(p, query_start, warp_id, lane_id); + } + } + __syncthreads(); + } + + OutputFragments output_frags; + + curandStatePhilox4_32_10_t rng_state_init; +#ifdef HAS_PYTORCH + if (kApplyDropout) { + auto seeds = at::cuda::philox::unpack(p.rng_engine_inputs); + // each element of the attention matrix P with shape + // (batch_sz, n_heads, n_queries, n_keys) is associated with a single + // offset in RNG sequence. we initialize the RNG state with offset that + // starts at the beginning of a (n_queries, n_keys) matrix for this + // block's batch_id and head_id + // initializing rng state is very expensive, so we run once per kernel, + // rather than once per iteration. each iteration takes a copy of the + // initialized RNG state and offsets it as needed. + curand_init( + std::get<0>(seeds), + 0, + std::get<1>(seeds) + p.dropout_batch_head_rng_offset, + &rng_state_init); + } +#endif + CUTLASS_PRAGMA_UNROLL + for (; key_start < p.num_keys; + key_start += p.num_splits_key_device() * kBlockSizeJ) { + output_frags.clear(); + + CUTLASS_PRAGMA_UNROLL + for (int32_t query_start_shifted = getQueryStart(p, key_start); + query_start_shifted < getQueryStartShift(p) + getQueryEnd(p); + query_start_shifted += kBlockSizeI) { + // This line here + // vvvvvvvvvvvvvv + warp_id = warp_uniform(warp_id); + // ^^^^^^^^^^^^^^ + // ... makes everything use less RF and be 10% faster. Why? + // I don't know. My theory is that it forces `nvcc` to + // re-compute indices, offsets etc... and not keep them + // from the previous iteration, which prevents MASSIVE + // register spilling. + + int32_t query_start = query_start_shifted; + if (query_start >= p.num_queries) { + query_start = query_start % getQueryEnd(p); + } + + processBlockIJ( + shared_storage, + output_frags, + p, + query_start, + key_start, + rng_state_init, + warp_id, + lane_id); + } + if (kOutputInRF) { + writeFragsToGmem( + shared_storage, output_frags, p, key_start, warp_id, lane_id); + } else if (getQueryStart(p, key_start) >= p.num_queries) { + zfillGradKV( + p, key_start, warp_id, lane_id); + } + __syncthreads(); + } + } + + template + static CUTLASS_DEVICE void zfillGradKV( + Params const& p, + int32_t key_start, + uint8_t warp_id, + uint8_t lane_id) { + constexpr int kThreadsPerKey = 8; + constexpr int kParallelKeys = kNumThreads / kThreadsPerKey; + static_assert(kBlockSizeJ % kParallelKeys == 0, ""); + // This function is not really optimized, but should rarely be used + // It's only used when some keys are "useless" and don't attend to + // any query, due to causal masking + + int thread_id = 32 * warp_id + lane_id; + int k_shift = lane_id % kThreadsPerKey; + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kBlockSizeJ; j += kParallelKeys) { + int key = key_start + j + (thread_id / kThreadsPerKey); + if (!skipBoundsChecks && key >= p.num_keys) { + continue; + } + auto gv_ptr = p.grad_value_ptr + key * p.gV_strideM(); + auto gk_ptr = p.grad_key_ptr + key * p.gK_strideM(); + + for (int k = k_shift; k < p.head_dim_value; k += kThreadsPerKey) { + gv_ptr[k] = scalar_t(0); + } + for (int k = k_shift; k < p.head_dim; k += kThreadsPerKey) { + gk_ptr[k] = scalar_t(0); + } + } + } + + template + static CUTLASS_DEVICE void processBlockIJ( + SharedStorage& shared_storage, + OutputFragments& output_frags, + Params& p, + int32_t query_start, + int32_t key_start, + const curandStatePhilox4_32_10_t& curand_state_init, + uint8_t warp_id, + uint8_t lane_id) { + cutlass::Array + dropout_keep_mask_doivj; + dropout_keep_mask_doivj.fill(cutlass::uint1b_t{1}); + const float dropout_scale = + kApplyDropout ? 1.0 / (1.0 - p.dropout_prob) : 1.0f; + + cutlass::MatrixCoord no_offset{0, 0}; + accum_t scale = p.scale; + int16_t thread_id = 32 * warp_id + lane_id; + + auto rematerializeThreadIds = [&]() { + // Prevents `nvcc` from keeping values deduced from + // `thread_id`, `warp_id`, ... in RF - to reduce register pressure + warp_id = warp_uniform(thread_id / 32); + lane_id = thread_id % 32; + thread_id = 32 * warp_id + lane_id; + }; + + bool isFirstQuery = (query_start == getQueryStart(p, key_start)); + int32_t next_query, next_key; + incrIteration(p, query_start, key_start, next_query, next_key); + bool isLastQuery = next_key != key_start; + + accum_t di_rf = accum_t(0); + if (thread_id < kBlockSizeI) { + if (query_start + thread_id < p.num_queries) { + di_rf = p.delta_ptr[query_start + thread_id]; + } + shared_storage.di()[thread_id] = di_rf; + } + + int32_t num_queries_in_block = skipBoundsChecks + ? MatmulQK::Mma::Shape::kN + : warp_uniform(cutlass::fast_min( + (int32_t)MatmulQK::Mma::Shape::kN, p.num_queries - query_start)); + int32_t num_keys_in_block = skipBoundsChecks + ? MatmulQK::Mma::Shape::kM + : warp_uniform(cutlass::fast_min( + (int32_t)MatmulQK::Mma::Shape::kM, p.num_keys - key_start)); + + auto prologueGradV = [&](int col) { + typename MatmulGradV::Mma::IteratorB iterator_dO( + {int32_t(p.gO_strideM)}, + p.grad_output_ptr + query_start * p.gO_strideM + col, + {num_queries_in_block, p.head_dim_value - col}, + thread_id, + no_offset); + MatmulGradV::Mma::prologue( + shared_storage.mm_gradV(), + iterator_dO, + thread_id, + num_queries_in_block); + }; + auto prologueGradQ = [&](int col) { + typename MatmulGradQ::Mma::IteratorB iterator_K( + {int32_t(p.k_strideM)}, + p.key_ptr + key_start * p.k_strideM + col, + {num_keys_in_block, p.head_dim - col}, + thread_id, + no_offset); + MatmulGradQ::Mma::prologue( + shared_storage.mm_gradQ(), iterator_K, thread_id, num_keys_in_block); + }; + auto prologueGradK = [&](int col) { + typename MatmulGradK::Mma::IteratorB iterator_Q( + {int32_t(p.q_strideM)}, + p.query_ptr + query_start * p.q_strideM + col, + {num_queries_in_block, p.head_dim - col}, + thread_id, + no_offset); + MatmulGradK::Mma::prologue( + shared_storage.mm_gradK(), + iterator_Q, + thread_id, + num_queries_in_block); + }; + auto prologueDOV = [&]() { + typename MatmulDOIVJ::Mma::IteratorA iterator_A( + {int32_t(p.gO_strideM)}, + p.grad_output_ptr + query_start * p.gO_strideM, + {num_queries_in_block, p.head_dim_value}, + thread_id, + no_offset); + typename MatmulDOIVJ::Mma::IteratorB iterator_B( + {int32_t(p.v_strideM)}, + p.value_ptr + key_start * p.v_strideM, + {p.head_dim_value, num_keys_in_block}, + thread_id, + no_offset); + MatmulDOIVJ::Mma::prologue( + shared_storage.mm_doivj(), + iterator_A, + iterator_B, + thread_id, + p.head_dim_value); + }; + + ///////////////////////////////////////////////////////////////////////////////////////////////// + // MatmulQK + ///////////////////////////////////////////////////////////////////////////////////////////////// + { + using Mma = typename MatmulQK::Mma; + + cutlass::gemm::GemmCoord problem_size( + num_keys_in_block, + num_queries_in_block, + p.head_dim // k + ); + + // k_j + typename Mma::IteratorA iterator_A( + {int32_t(p.k_strideM)}, + p.key_ptr + key_start * p.k_strideM, + {problem_size.m(), problem_size.k()}, + thread_id, + no_offset); + + // q_i.transpose(-2, -1) + typename Mma::IteratorB iterator_B( + {int32_t(p.q_strideM)}, + p.query_ptr + query_start * p.q_strideM, + {problem_size.k(), problem_size.n()}, + thread_id, + no_offset); + + Mma mma( + shared_storage.mm_qk_k(), + shared_storage.mm_qk_q(), + thread_id, + warp_id, + lane_id); + + typename Mma::FragmentC accum; + + accum.clear(); + + auto gemm_k_iterations = + (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma.set_prologue_done(kPrologueQK); + mma.set_zero_outside_bounds(!skipBoundsChecks); + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + accum = cutlass::multiplies()(scale, accum); + + // Epilogue: add LSE + exp and store that to our shared memory buffer + // shmem <- (matmul_result - + // logsumexp[i_start:i_end].unsqueeze(1)).exp() + int warp_idx_mn_0 = + warp_id % (Mma::Base::WarpCount::kM * Mma::Base::WarpCount::kN); + auto output_tile_coords = cutlass::MatrixCoord{ + warp_idx_mn_0 % Mma::Base::WarpCount::kM, + warp_idx_mn_0 / Mma::Base::WarpCount::kM}; + + // apply bias if applicable + if (p.bias_ptr != nullptr) { + // load bias tile Bij into shared memory + typename MatmulQK::BiasLoader::GmemTileIterator bias_iter( + {cutlass::layout::RowMajor(p.bias_strideM)}, + p.bias_ptr + query_start * p.bias_strideM + key_start, + {num_queries_in_block, num_keys_in_block}, + thread_id); + cutlass::TensorRef bias_tensor_ref( + shared_storage.bias().data(), + cutlass::layout::RowMajor(MatmulQK::ThreadblockShape::kM)); + typename MatmulQK::BiasLoader::SmemTileIterator smem_tile_iter( + bias_tensor_ref, thread_id); + MatmulQK::BiasLoader::load(bias_iter, smem_tile_iter); + + // Pij += Bij, where Pij is in register fragment and Bij is in shmem + auto lane_offset = MatmulQK::AccumLambdaIterator::get_lane_offset( + lane_id, warp_id, output_tile_coords); + MatmulQK::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_n) {}, + [&](int accum_m, int accum_n, int idx) { + // remember we are transposed + accum[idx] += bias_tensor_ref.at({accum_n, accum_m}); + }, + [&](int accum_n) {}); + } + + // Apply mask + if (p.custom_mask_type == CausalFromTopLeft || + p.custom_mask_type == CausalFromBottomRight) { + auto lane_offset = MatmulQK::AccumLambdaIterator::get_lane_offset( + lane_id, warp_id, output_tile_coords); + int shift = query_start - key_start; + if (p.custom_mask_type == CausalFromBottomRight) { + shift += p.num_keys - p.num_queries; + } + // current_key = key_start + accum_m + // current_query = query_start + accum_n + // mask if: `current_key > current_query` + MatmulQK::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + if (accum_m > accum_n + shift) { + accum[idx] = + -cutlass::platform::numeric_limits::infinity(); + } + }, + [&](int accum_m) {}); + } + + __syncthreads(); + if (kPrologueGV) { + prologueGradV(0); + } + if (kPrologueDOV) { + prologueDOV(); + } + + MatmulQK::B2bGemm::accumApplyLSEToSmem( + shared_storage.attn_shared_storage(), + accum, + p.logsumexp_ptr + query_start, + problem_size.n(), + thread_id, + warp_id, + lane_id, + output_tile_coords); +#if 0 + auto accum_ref_attnT = shared_storage.attn_shared_storage().accum_ref(); + PRINT_TENSOR4x4_T0_L0("attn_T", accum_ref_attnT); +#endif + + // if we are using dropout, compute Zij, writing it to shared memory. + // each element of Zij is: + // - 0 with probability dropout_p + // - 1 / (1 - dropout_p) with probability 1 - dropout_p + if (kApplyDropout) { + auto zij = shared_storage.zij().accum_ref(); + // each thread generates a contiguous sequence of elements in Zij, all + // in the same row. the reason they have to come from the same row is + // that sampling random numbers from a contiguous random number sequence + // is much more efficient than jumping around, and the linear offset of + // each element of Z (the global matrix) maps to an offset in a random + // number sequence. for Z, the end of a row and the beginning of the + // next have adjacent offsets, but for Zij (tile of global matrix), this + // is not necessarily the case. + // We must fill the entire `zij` shmem with values (even out of bounds + // on the K-dimension) otherwise we can get NaNs during the GEMM + const int kQueriesPerBlock = kBlockSizeI; + const int threads_per_row = cutlass::fast_min( + int32_t(kNumThreads / kQueriesPerBlock), num_keys_in_block); + const int elts_per_thread = cutlass::round_nearest( + cutlass::ceil_div(num_keys_in_block, threads_per_row), 4); + + const int thread_i = thread_id / threads_per_row; + const int thread_start_j = + (thread_id % threads_per_row) * elts_per_thread; + + if (thread_i < kQueriesPerBlock && thread_start_j < num_keys_in_block) { + curandStatePhilox4_32_10_t curand_state = curand_state_init; + skipahead( + (query_start + thread_i) * p.num_keys + + (key_start + thread_start_j), + &curand_state); + + // generate elements of Zij, 4 elements at a time + for (int zij_start_col_idx = thread_start_j; zij_start_col_idx < + cutlass::fast_min(thread_start_j + elts_per_thread, + num_keys_in_block); + zij_start_col_idx += 4) { + const float4 rand_uniform_quad = curand_uniform4(&curand_state); + + CUTLASS_PRAGMA_UNROLL + for (int quad_idx = 0; quad_idx < 4; ++quad_idx) { + // we'll write Zij transposed since attention is also transposed + // during the matmul to compute dV. + zij.at({zij_start_col_idx + quad_idx /*k*/, thread_i /*q*/}) = + (&rand_uniform_quad.x)[quad_idx] > p.dropout_prob + ? scalar_t(dropout_scale) + : scalar_t(0); + } + } + } + __syncthreads(); +#if 0 + PRINT_TENSOR4x4_T0_L0("zij", zij); + PRINT_TENSOR4x4_T0_L0_START("zij", zij, kBlockSizeJ - 4, kBlockSizeI - 4); +#endif + + // Save mask for later DOIVJ matmul + + int warp_idx_mn_0 = warp_id % + (MatmulDOIVJ::Mma::Base::WarpCount::kM * + MatmulDOIVJ::Mma::Base::WarpCount::kN); + auto output_tile_coords_doivj = cutlass::MatrixCoord{ + warp_idx_mn_0 % MatmulDOIVJ::Mma::Base::WarpCount::kM, + warp_idx_mn_0 / MatmulDOIVJ::Mma::Base::WarpCount::kM}; + auto lane_offset = MatmulDOIVJ::AccumLambdaIterator::get_lane_offset( + lane_id, warp_id, output_tile_coords_doivj); + MatmulDOIVJ::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m /*q*/, int accum_n /*k*/, int idx) { + if (zij.at({accum_n, accum_m}) == scalar_t(0)) { + dropout_keep_mask_doivj[idx] = cutlass::uint1b_t{0}; + } + }, + [&](int accum_m) {}); + } + __syncthreads(); + } + rematerializeThreadIds(); + + ///////////////////////////////////////////////////////////////////////////////////////////////// + // GradV matmul + // + // grad_v[j_start:j_end] += attn_T @ do_i + ///////////////////////////////////////////////////////////////////////////////////////////////// + constexpr bool kSingleIterationGradV = + kMaxK <= MatmulGradV::ThreadblockShape::kN; + for (int col = 0; col < (kSingleIterationGradV ? 1 : p.head_dim_value); + col += MatmulGradV::ThreadblockShape::kN) { + using Mma = typename MatmulGradV::Mma; + using AccumTileGmem = typename MatmulGradQ::AccumTileGmem; + + cutlass::gemm::GemmCoord problem_size( + num_keys_in_block, p.head_dim_value - col, num_queries_in_block); + auto createEpilogueIter = [&]() { + return typename MatmulGradV::OutputTileIterator( + typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()}, + p.grad_value_ptr + key_start * p.gV_strideM() + col, + {num_keys_in_block, p.head_dim_value - col}, + thread_id); + }; + typename Mma::IteratorB iterator_B( + {int32_t(p.gO_strideM)}, + p.grad_output_ptr + query_start * p.gO_strideM + col, + {num_queries_in_block, p.head_dim_value - col}, + thread_id, + no_offset); + + // if dropout: dVj += (Pij.T * Zij) @ dOi + // otherwise: dVj += Pij.T @ dOi + Mma mma( + // operand A: Pij.T + shared_storage.attn_shared_storage().accum_ref(), + // operand A_scale Zij.T: + // if we're using dropout, operand A is Pij_dropped.T = Pij.T * Zij.T + // which is computed on the fly as fragments of Pij.T are loaded in + shared_storage.zij().accum_ref(), + // operand B: dOi - which was loaded into shared memory previously + // when we computed dVj + shared_storage.mm_gradV().operand_B_ref(), + thread_id, + warp_id, + lane_id); + + int storage_id = col / MatmulGradV::ThreadblockShape::kN; + AccumTileGmem gmem_tile{ + p.workspace_gv + storage_id * AccumTileGmem::kElementsStored}; + if (!kOutputInRF) { + if (isFirstQuery || !kNeedsAccumGradV) { + output_frags.gradV.clear(); + } else { + gmem_tile.load(output_frags.gradV, thread_id); + } + } + mma.set_prologue_done(kPrologueGV); + + auto gemm_k_iterations = + (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + __syncthreads(); + + mma(gemm_k_iterations, + output_frags.gradV, + iterator_B, + output_frags.gradV); + __syncthreads(); + if (kPrologueGV && !kSingleIterationGradV && + col + MatmulGradV::ThreadblockShape::kN < p.head_dim_value) { + prologueGradV(col + MatmulGradV::ThreadblockShape::kN); + } + + if (!kOutputInRF) { + if (kNeedsAccumGradV && !isLastQuery) { + gmem_tile.store(output_frags.gradV, thread_id); + } else { + accumulateInGmem( + shared_storage.gradV_epilogue(), + output_frags.gradV, + createEpilogueIter(), + isFirstQuery || kNeedsAccumGradV, + warp_id, + lane_id); + } + } + } + __syncthreads(); + + ///////////////////////////////////////////////////////////////////////////////////////////////// + // MatmulDOIVJ + ///////////////////////////////////////////////////////////////////////////////////////////////// + { + using Mma = typename MatmulDOIVJ::Mma; + // do_i + typename Mma::IteratorA iterator_A( + {int32_t(p.gO_strideM)}, + p.grad_output_ptr + query_start * p.gO_strideM, + {num_queries_in_block, p.head_dim_value}, + thread_id, + no_offset); + + // v_j.transpose(-2, -1) + typename Mma::IteratorB iterator_B( + {int32_t(p.v_strideM)}, + p.value_ptr + key_start * p.v_strideM, + {p.head_dim_value, num_keys_in_block}, + thread_id, + no_offset); + + Mma mma(shared_storage.mm_doivj(), thread_id, warp_id, lane_id); + mma.set_prologue_done(kPrologueDOV); + mma.set_zero_outside_bounds(!skipBoundsChecks); + + typename Mma::FragmentC accum; + + accum.clear(); + + auto gemm_k_iterations = + (p.head_dim_value + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + __syncthreads(); + if (kPrologueGQ) { + prologueGradQ(0); + } + if (kPrologueGK) { + prologueGradK(0); + } + + int warp_idx_mn_0 = + warp_id % (Mma::Base::WarpCount::kM * Mma::Base::WarpCount::kN); + auto output_tile_coords = cutlass::MatrixCoord{ + warp_idx_mn_0 % Mma::Base::WarpCount::kM, + warp_idx_mn_0 / Mma::Base::WarpCount::kM}; + // TODO: This must be terribly inefficient. There must be a better way + // tmp [RF] <- (accum [RF] - Di [smem] ) * attn_T.T [smem] + // attn_shared_storage [smem] <- tmp.T + // tmp_shared_storage [smem] <- tmp + { + using LambdaIterator = typename MatmulDOIVJ::AccumLambdaIterator; + auto lane_offset = LambdaIterator::get_lane_offset( + lane_id, warp_id, output_tile_coords); + // if dropout was used, compute dPij = dPij_dropped * Zij + if (kApplyDropout) { + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + if (dropout_keep_mask_doivj[idx].get()) { + accum[idx] *= dropout_scale; + } else { + accum[idx] = 0; + } + }, + [&](int accum_m) {}); + } + + auto attn_T = shared_storage.attn_shared_storage().accum_ref(); +#if 0 + PRINT_B0_T0("doivj_dropped"); + print_warp_accum(accum, lane_offset, 4, 4); + PRINT_TENSOR4x4_T0_L0("attn_T", attn_T) +#endif + accum_t current_di; + // dSij = (dPij - Di) * Pij + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { current_di = shared_storage.di()[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { + // TODO: Otherwise we can get nans as we + // might have infs here (only seen on f16 tho) + if (skipBoundsChecks || + (accum_m < num_queries_in_block && + accum_n < num_keys_in_block)) { + accum_t attn = attn_T.at({accum_n, accum_m}); + accum[idx] = (accum[idx] - current_di) * attn; + } else { + accum[idx] = 0; + } + }, + [&](int accum_m) { + + }); + + // store bias gradient tile dBij to global memory, + // where dBij = dSij = Pij * (dPij - Di) + if (p.grad_bias_ptr != nullptr) { + typename MatmulDOIVJ::BiasGradEpilogue::OutputTileIterator + output_iter( + typename MatmulDOIVJ::BiasGradEpilogue::OutputTileIterator:: + Params{p.gB_strideM}, + // grad_bias_ptr is offset to point at beginning of + // matrix of shape (queries, keys) for a given + // (batch_id, head_id) the pointer arithmetic here produces + // a pointer to the start of the current tile within that + // matrix + p.grad_bias_ptr + query_start * p.gB_strideM + key_start, + {num_queries_in_block, num_keys_in_block}, + thread_id); + + // no-op epilogue operator - just casting and storing contents of + // accum to global memory + typename MatmulDOIVJ::BiasGradEpilogue::OutputOp output_op( + typename MatmulDOIVJ::BiasGradEpilogue::OutputOp::Params{1, 1}); + typename MatmulDOIVJ::BiasGradEpilogue epilogue( + shared_storage.gradB_epilogue(), thread_id, warp_id, lane_id); + epilogue(output_op, output_iter, accum, output_iter); + } + + accum = accum * scale; + +#if 0 + PRINT_B0_T0("(doivj - di) * attn * scale"); + print_warp_accum(accum, lane_offset, 4, 4); +#endif + + __syncthreads(); + if (!MatmulGradK::DefaultMmaFromSmem::kIsTransposedA) { + auto tmpT = shared_storage.tmpT_shared_storage().accum_ref(); + // attn <- attn_T.T + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + tmpT.at({accum_n, accum_m}) = scalar_t(accum[idx]); + }, + [&](int accum_m) {}); + } + } + + MatmulDOIVJ::B2bGemm::accumToSmem( + shared_storage.tmp_shared_storage(), + accum, + lane_id, + output_tile_coords); + __syncthreads(); + } + // Force `nvcc` to recompute values that depend on the variables just below + // to use less RF and prevent some spilling + p.head_dim = warp_uniform(p.head_dim); + p.k_strideM = warp_uniform(p.k_strideM); + rematerializeThreadIds(); + + ///////////////////////////////////////////////////////////////////////////////////////////////// + // GradQ matmul + // + // grad_q[i_start:i_end] += tmp @ k_j + ///////////////////////////////////////////////////////////////////////////////////////////////// + // Skip the loop & associated branches if we know at compile time the number + // of iterations + constexpr bool kSingleIterationGradQ = + kMaxK <= MatmulGradQ::ThreadblockShape::kN; + for (int col = 0; col < (kSingleIterationGradQ ? 1 : p.head_dim); + col += MatmulGradQ::ThreadblockShape::kN) { + using Mma = typename MatmulGradQ::Mma; + using AccumTileGmem = typename MatmulGradQ::AccumTileGmem; + + cutlass::gemm::GemmCoord problem_size( + num_queries_in_block, + false ? MatmulGradQ::ThreadblockShape::kN : p.head_dim - col, + num_keys_in_block); + + // k_j + typename Mma::IteratorB iterator_B( + {int32_t(p.k_strideM)}, + p.key_ptr + key_start * p.k_strideM + col, + {problem_size.k(), problem_size.n()}, + thread_id, + no_offset); + + auto a = shared_storage.tmp_shared_storage().accum_ref(); + Mma mma( + // operand A: dSij + shared_storage.tmp_shared_storage().accum_ref(), + // operand B: Kj + shared_storage.mm_gradQ().operand_B_ref(), + thread_id, + warp_id, + lane_id); + + typename Mma::FragmentC accum; + + int col_id = col / MatmulGradQ::ThreadblockShape::kN; + int num_cols = kSingleIterationGradQ + ? 1 + : ceil_div(p.head_dim, MatmulGradQ::ThreadblockShape::kN); + int storage_id = (col_id + query_start / kBlockSizeI * num_cols); + + if (p.num_splits_key_device() > 1) { + AtomicLock::acquire( + &p.workspace_gq[storage_id].lock, + p.split_key_device() + 1, + thread_id); + // Make sure we can see other block's output + __threadfence(); + } + + AccumTileGmem gmem_tile{&p.workspace_gq[storage_id].buffer[0]}; + if (!kNeedsAccumGradQ || + (p.num_splits_key_device() == 1 && key_start == 0)) { + // if we know we are the first to access it, we know it's only zeros. + // Avoids a load from gmem (and gmem init as well) + accum.clear(); + } else { + gmem_tile.load(accum, thread_id); + } + + auto gemm_k_iterations = + (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + __syncthreads(); + mma.set_prologue_done(kPrologueGQ); + mma(gemm_k_iterations, accum, iterator_B, accum); + __syncthreads(); + bool isLastColumn = kSingleIterationGradQ || + (col + MatmulGradQ::ThreadblockShape::kN >= p.head_dim); + if (kPrologueGQ && !isLastColumn) { + prologueGradQ(col + MatmulGradQ::ThreadblockShape::kN); + } + + bool isLast = [&]() { + int32_t next_key = key_start + p.num_splits_key_device() * kBlockSizeJ; + if (p.num_keys <= next_key) { + return true; + } + if (query_start < getSmallestQueryForKey(p, next_key)) { + return true; + } + return false; + }(); + // Output results + if (p.num_splits_key_device() > 1) { + int32_t numAddsSoFar = -1; + if (isLast && thread_id == 0) { + numAddsSoFar = atomicAdd(&p.workspace_gq[storage_id].counter, 1) + + 1; // `atomicAdd` returns the old value + } + isLast = __syncthreads_or( + numAddsSoFar == getNumParallelBlocksForQuery(p, query_start)); + assert(numAddsSoFar <= getNumParallelBlocksForQuery(p, query_start)); + } + if (kNeedsAccumGradQ && !isLast) { + gmem_tile.store(accum, thread_id); + if (p.num_splits_key_device() > 1) { + // Make sure everyone wrote before we release the lock + __threadfence(); + __syncthreads(); + AtomicLock::release(&p.workspace_gq[storage_id].lock, thread_id); + } + } else { + // NOTE: We're not releasing the lock because no one is expected + // to come after us (we're the last one to write) + typename MatmulGradQ::OutputTileIterator output_it( + typename MatmulGradQ::OutputTileIterator::Params{p.gQ_strideM()}, + p.grad_query_ptr + query_start * p.gQ_strideM() + col, + {problem_size.m(), problem_size.n()}, + thread_id); + bool storage_contains_zeros = kNeedsAccumGradQ || key_start == 0 || + (p.num_splits_key_device() > 1); + accumulateInGmem( + isLastColumn ? shared_storage.gradQ_epilogue_lastIter() + : shared_storage.gradQ_epilogue(), + accum, + output_it, + storage_contains_zeros, + warp_id, + lane_id); + } + } + ///////////////////////////////////////////////////////////////////////////////////////////////// + // GradK matmul + // + // grad_k[i_start:i_end] += tmp.transpose(-2, -1) @ q_i + ///////////////////////////////////////////////////////////////////////////////////////////////// + rematerializeThreadIds(); + + constexpr bool kSingleIterationGradK = + kMaxK <= MatmulGradK::ThreadblockShape::kN; + for (int col = 0; col < (kSingleIterationGradK ? 1 : p.head_dim); + col += MatmulGradK::ThreadblockShape::kN) { + using Mma = typename MatmulGradK::Mma; + using AccumTileGmem = typename MatmulGradQ::AccumTileGmem; + + cutlass::gemm::GemmCoord problem_size( + num_keys_in_block, + false ? MatmulGradK::ThreadblockShape::kN : p.head_dim - col, + num_queries_in_block); + auto createEpilogueIter = [&]() { + return typename MatmulGradK::OutputTileIterator( + typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()}, + p.grad_key_ptr + key_start * p.gK_strideM() + col, + {num_keys_in_block, + false ? MatmulGradK::ThreadblockShape::kN : p.head_dim - col}, + thread_id); + }; + + // q_i + typename Mma::IteratorB iterator_B( + {int32_t(p.q_strideM)}, + p.query_ptr + query_start * p.q_strideM + col, + {problem_size.k(), problem_size.n()}, + thread_id, + no_offset); + + auto getTmp = [&](int) { return &shared_storage.tmp_shared_storage(); }; + auto getTmpT = [&](int) { return &shared_storage.tmpT_shared_storage(); }; + // this is basically: + // opA = kIsTransposedA ? getTmp() : getTmpT(); + bool constexpr kIsTransposedA = + MatmulGradK::DefaultMmaFromSmem::kIsTransposedA; + auto& opA = *call_conditional< + kIsTransposedA, + decltype(getTmp), + decltype(getTmpT)>::apply(getTmp, getTmpT, 0); + Mma mma( + // operand A: dSij.T + opA.accum_ref(), + // operand B: Qi + shared_storage.mm_gradK().operand_B_ref(), + thread_id, + warp_id, + lane_id); + + int storage_id = col / MatmulGradK::ThreadblockShape::kN; + AccumTileGmem gmem_tile{ + p.workspace + storage_id * AccumTileGmem::kElementsStored}; + if (!kOutputInRF) { + if (isFirstQuery || !kNeedsAccumGradK) { + output_frags.gradK.clear(); + } else { + gmem_tile.load(output_frags.gradK, thread_id); + } + } + mma.set_prologue_done(kPrologueGK); + + auto gemm_k_iterations = + (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + __syncthreads(); + + mma(gemm_k_iterations, + output_frags.gradK, + iterator_B, + output_frags.gradK); + __syncthreads(); + bool isLastColumn = kSingleIterationGradK || + col + MatmulGradK::ThreadblockShape::kN >= p.head_dim; + if (kPrologueGK && !isLastColumn) { + prologueGradK(col + MatmulGradK::ThreadblockShape::kN); + } + + if (kPrologueQK && isLastColumn) { + int32_t next_query, next_key; + incrIteration(p, query_start, key_start, next_query, next_key); + DISPATCH_BOOL( + next_key != key_start, kForceReloadK, ([&]() { + prologueQkNextIteration( + shared_storage, p, next_query, next_key, warp_id, lane_id); + })); + } + + // Output results + if (!kOutputInRF) { + if (kNeedsAccumGradK && !isLastQuery) { + gmem_tile.store(output_frags.gradK, thread_id); + } else { + accumulateInGmem( + isLastColumn ? shared_storage.gradK_epilogue_final() + : shared_storage.gradK_epilogue(), + output_frags.gradK, + createEpilogueIter(), + isFirstQuery || kNeedsAccumGradK, + warp_id, + lane_id); + __syncthreads(); + } + } + } + } + + static CUTLASS_DEVICE int32_t getQueryStartShift(Params const& p) { + if (p.custom_mask_type == NoCustomMask && p.num_splits_key_device() > 1) { + return (p.split_key_device() * kBlockSizeI) % getQueryEnd(p); + } + return 0; + } + + // Iteration order logic + static CUTLASS_DEVICE int32_t + getQueryStart(Params const& p, int32_t key_start) { + return getSmallestQueryForKey(p, key_start) + getQueryStartShift(p); + }; + static CUTLASS_DEVICE int32_t getQueryEnd(Params const& p) { + return align_up(p.num_queries, kBlockSizeI); + }; + + static CUTLASS_DEVICE int32_t + getSmallestQueryForKey(Params const& p, int32_t key_start) { + if (p.custom_mask_type == CausalFromTopLeft) { + return (key_start / kBlockSizeI) * kBlockSizeI; + } else if (p.custom_mask_type == CausalFromBottomRight) { + int first_query = + cutlass::fast_max(0, key_start - p.num_keys + p.num_queries); + return (first_query / kBlockSizeI) * kBlockSizeI; + } + return 0; + }; + + // Returns how many kernel blocks will write to a given block in `grad_query` + // This is usually equal to the number of key splits, but can be different + // for instance in the causal case, or varying seqlen + static CUTLASS_DEVICE int32_t + getNumParallelBlocksForQuery(Params const& p, int32_t query_start) { + int16_t num_key_blocks = ceil_div(p.num_keys, kBlockSizeJ); + if (p.custom_mask_type == CausalFromTopLeft) { + int32_t last_key_for_block = query_start + kBlockSizeI - 1; + last_key_for_block = cutlass::fast_min(last_key_for_block, p.num_keys); + num_key_blocks = ceil_div(last_key_for_block, kBlockSizeJ); + } else if (p.custom_mask_type == CausalFromBottomRight) { + int32_t last_key_for_block = + query_start + (kBlockSizeI - 1) + (1 + p.num_keys - p.num_queries); + last_key_for_block = cutlass::fast_min(last_key_for_block, p.num_keys); + num_key_blocks = ceil_div(last_key_for_block, kBlockSizeJ); + } + return cutlass::fast_min(p.num_splits_key_device(), num_key_blocks); + }; + + // Returns the next block to process + static CUTLASS_DEVICE void incrIteration( + Params const& p, + int32_t query_start, + int32_t key_start, + int32_t& next_query, + int32_t& next_key) { + next_query = query_start + kBlockSizeI; + next_key = key_start; + auto query_shift = getQueryStartShift(p); + // Wrap around + if (query_shift) { + if (next_query >= p.num_queries) { + next_query = getSmallestQueryForKey(p, key_start); + return; + } else if (query_start < query_shift && query_shift <= next_query) { + // jump to next key + } else { + return; + } + } else { + if (next_query < p.num_queries) { + return; + } + // jump to next key + } + // Next key + next_key = key_start + p.num_splits_key_device() * kBlockSizeJ; + next_query = getQueryStart(p, next_key); + } + + template + static CUTLASS_DEVICE void prologueQkNextIteration( + SharedStorage& shared_storage, + Params const& p, + int32_t query_start, + int32_t key_start, + uint8_t warp_id, + uint8_t lane_id) { + if (query_start >= p.num_queries || key_start >= p.num_keys) { + return; + } + + static constexpr bool kReloadK = + kForceReloadK || !MatmulQK::Mma::kSmemContainsEntireMat; + int thread_id = 32 * warp_id + lane_id; + typename MatmulQK::Mma::IteratorA iterator_A( + {int32_t(p.k_strideM)}, + p.key_ptr + key_start * p.k_strideM, + {p.num_keys - key_start, p.head_dim}, + thread_id, + cutlass::MatrixCoord{0, 0}); + + typename MatmulQK::Mma::IteratorB iterator_B( + {int32_t(p.q_strideM)}, + p.query_ptr + query_start * p.q_strideM, + {p.head_dim, p.num_queries - query_start}, + thread_id, + cutlass::MatrixCoord{0, 0}); + + MatmulQK::Mma::template prologue( + shared_storage.mm_qk_k(), + shared_storage.mm_qk_q(), + iterator_A, + iterator_B, + thread_id, + p.head_dim); + } + + template + static CUTLASS_DEVICE void writeFragsToGmem( + SharedStorage& shared_storage, + OutputFragments& output_frags, + Params const& p, + int32_t key_start, + uint8_t warp_id, + uint8_t lane_id) { + uint16_t thread_id = 32 * warp_id + lane_id; + int32_t num_keys_in_block = skipBoundsChecks + ? MatmulQK::Mma::Shape::kM + : cutlass::fast_min( + (int32_t)MatmulQK::Mma::Shape::kM, p.num_keys - key_start); + typename MatmulGradV::OutputTileIterator outputV_it( + typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()}, + p.grad_value_ptr + key_start * p.gV_strideM(), + {num_keys_in_block, p.head_dim_value}, + thread_id); + + accumulateInGmem( + shared_storage.gradV_epilogue_final(), + output_frags.gradV, + outputV_it, + true, + warp_id, + lane_id); + + typename MatmulGradK::OutputTileIterator outputK_it( + typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()}, + p.grad_key_ptr + key_start * p.gK_strideM(), + {num_keys_in_block, + false ? MatmulGradK::ThreadblockShape::kN : p.head_dim}, + thread_id); + accumulateInGmem( + shared_storage.gradK_epilogue_final(), + output_frags.gradK, + outputK_it, + true, + warp_id, + lane_id); + } + + template + static CUTLASS_DEVICE void accumulateInGmem( + typename MatmulT::DefaultEpilogue::SharedStorage& epilogue_smem, + typename MatmulT::Mma::FragmentC const& accum, + typename MatmulT::OutputTileIterator output_it, + bool first, + uint8_t warp_id, + uint8_t lane_id) { + using DefaultEpilogue = typename MatmulT::DefaultEpilogue; + using DefaultOutputOp = typename MatmulT::DefaultOutputOp; + using Mma = typename MatmulT::Mma; + int thread_id = 32 * warp_id + lane_id; + DISPATCH_BOOL( + first, kIsFirst, ([&]() { + static constexpr auto ScaleType = kIsFirst::value + ? cutlass::epilogue::thread::ScaleType::Nothing + : cutlass::epilogue::thread::ScaleType::NoBetaScaling; + using EpilogueOutputOp = + typename cutlass::epilogue::thread::LinearCombination< + typename DefaultOutputOp::ElementOutput, + DefaultOutputOp::kCount, + typename DefaultOutputOp::ElementAccumulator, + typename DefaultOutputOp::ElementCompute, + ScaleType>; + using Epilogue = + typename cutlass::epilogue::threadblock::EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename MatmulT::OutputTileIterator, + typename DefaultEpilogue::AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true // IterationsUnroll + >; + EpilogueOutputOp rescale({1, 1}); + Epilogue epilogue(epilogue_smem, thread_id, warp_id, lane_id); + epilogue(rescale, output_it, accum, output_it); + })); + } + + template + static CUTLASS_DEVICE void computeDelta( + Params const& p, + int32_t query_start, + uint8_t warp_id, + uint8_t lane_id) { + // Each thread computes one value for Delta + // Depending on warp configuration, we might have multiple + // threads of the same warp working on the same row + using AccessType = cutlass::Array; + static_assert(kNumThreads >= kBlockSizeI, ""); + static constexpr int kNumThreadsPerLine = kNumThreads / kBlockSizeI; + int16_t thread_id = 32 * warp_id + lane_id; + + int16_t laneFirstCol = kElementsPerAccess * (lane_id % kNumThreadsPerLine); + int16_t laneRow = thread_id / kNumThreadsPerLine; + bool rowPred = (query_start + laneRow) < p.num_queries; + bool pred = rowPred; + + // on windows, previous syntax __restrict__ AccessType* + // resulted in error: "restrict" is not allowed + const AccessType* __restrict__ grad_output_ptr = + reinterpret_cast( + p.grad_output_ptr + (query_start + laneRow) * p.gO_strideM + + laneFirstCol); + const AccessType* __restrict__ output_ptr = + reinterpret_cast( + p.output_ptr + (query_start + laneRow) * p.o_strideM() + + laneFirstCol); + + static constexpr int64_t kMaxIters = + kMaxK / (kElementsPerAccess * kNumThreadsPerLine); + constexpr int kPipelineStages = 2; + accum_t delta_value = accum_t(0); + using GlobalLoad = + cutlass::arch::global_load; + AccessType frag_grad_output[kPipelineStages]; + AccessType frag_output[kPipelineStages]; + + auto loadAndIncrement = [&](int ld_pos, bool is_valid) { + frag_grad_output[ld_pos].clear(); + frag_output[ld_pos].clear(); + GlobalLoad(frag_grad_output[ld_pos], grad_output_ptr, is_valid); + GlobalLoad(frag_output[ld_pos], output_ptr, is_valid); + grad_output_ptr += kNumThreadsPerLine; + output_ptr += kNumThreadsPerLine; + }; + + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < kPipelineStages - 1; ++iter) { + int ld_pos = iter % kPipelineStages; + pred = pred && + (laneFirstCol + iter * kElementsPerAccess * kNumThreadsPerLine) < + p.head_dim_value; + loadAndIncrement(ld_pos, pred); + } + auto columnIteration = [&](int iter) { + // Load for next iter + int ld_pos = (iter + kPipelineStages - 1) % kPipelineStages; + pred = pred && + (laneFirstCol + + (iter + kPipelineStages - 1) * kElementsPerAccess * + kNumThreadsPerLine) < p.head_dim_value; + loadAndIncrement(ld_pos, pred); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < AccessType::kElements; ++i) { + delta_value += accum_t(frag_output[iter % kPipelineStages][i]) * + accum_t(frag_grad_output[iter % kPipelineStages][i]); + } + }; + + // If we have a small lower-bound for K, we can unroll the loop + if (kMaxK <= 256) { + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < kMaxIters; ++iter) { + columnIteration(iter); + } + } else { + int num_iters = + ceil_div(p.head_dim_value, kElementsPerAccess * kNumThreadsPerLine) * + (kElementsPerAccess * kNumThreadsPerLine); + for (int iter = 0; iter < num_iters; ++iter) { + columnIteration(iter); + } + } + + // Reduce between workers + static_assert( + kNumThreadsPerLine == 1 || kNumThreadsPerLine == 2 || + kNumThreadsPerLine == 4, + ""); + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kNumThreadsPerLine; i *= 2) { + delta_value = delta_value + __shfl_xor_sync(0xffffffff, delta_value, i); + } + + // Store in gmem + if (rowPred) { + p.delta_ptr[query_start + laneRow] = delta_value; + } + } +}; + +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_backward_batched_impl(typename AK::Params p) { + if (!p.advance_to_block()) { + return; + } + AK::attention_kernel(p); +} + +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_backward_batched(typename AK::Params params); diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/kernel_forward.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/kernel_forward.h new file mode 100644 index 0000000000000000000000000000000000000000..ed4e16775951a0062fdc7e2fa435bae13610b470 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/kernel_forward.h @@ -0,0 +1,1322 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#ifdef HAS_PYTORCH +#include +#include +#endif + +#include +#include +#include +#include + +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/vector.h" +#include "cutlass/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/platform/platform.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "debug_utils.h" +#include "epilogue/epilogue_pipelined.h" +#include "epilogue/epilogue_rescale_output.h" +#include "gemm/custom_mma.h" +#include "gemm/find_default_mma.h" +#include "gemm/mma_from_smem.h" +#include "gemm_kernel_utils.h" +#include "transform/tile_smem_loader.h" + +using namespace gemm_kernel_utils; + +namespace { +template +constexpr int getWarpsPerSmFw() { + return ( + Arch::kMinComputeCapability >= 80 && + !cutlass::platform::is_same::value + ? 16 + : 12); +} +static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) { + // source: https://stackoverflow.com/a/51549250 + return (value >= 0) + ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) + : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value))); +} +} // namespace + +// If ToBatchHookType_ is supplied other than this default (which is +// never the case in the xformers library) then the user is +// defining the logic which each block uses to find its data to work on, +// with the advance_to_batch function with the following signature. +// It should return false if there is no work to do for this block. +// In general this will not work with saving for backward due to fixed layout +// for logsumexp and incompatible rngs for dropout, so is likely only useful for +// custom inference. +struct DefaultToBatchHook { + template + CUTLASS_DEVICE static bool advance_to_batch( + Params&, + int64_t& /* q_start */, + int64_t& /* k_start */) { + return true; + } +}; + +template < + // The datatype of Q/K/V + typename scalar_t_, + // Architecture we are targeting (eg `cutlass::arch::Sm80`) + typename ArchTag, + // If Q/K/V are correctly aligned in memory and we can run a fast kernel + bool isAligned_, + int kQueriesPerBlock_, + int kKeysPerBlock_, + // upperbound on `max(value.shape[-1], query.shape[-1])` + int kMaxK_ = (int)cutlass::platform::numeric_limits::max(), + // This is quite slower on V100 for some reason + // Set to false if you know at compile-time you will never need dropout + bool kSupportsDropout_ = true, + bool kSupportsBias_ = true, + typename ToBatchHookType_ = DefaultToBatchHook> +struct AttentionKernel { + enum CustomMaskType { + NoCustomMask = 0, + CausalFromTopLeft = 1, + CausalFromBottomRight = 2, + NumCustomMaskTypes, + }; + + using scalar_t = scalar_t_; + using accum_t = float; + using lse_scalar_t = float; + using output_t = scalar_t; + // Accumulator between 2 iterations + // Using `accum_t` improves perf on f16 at the cost of + // numerical errors + using output_accum_t = accum_t; + static constexpr bool kSupportsDropout = kSupportsDropout_; + static constexpr bool kSupportsBias = kSupportsBias_; + static constexpr int kKeysPerBlock = kKeysPerBlock_; + static constexpr int kQueriesPerBlock = kQueriesPerBlock_; + static constexpr int kMaxK = kMaxK_; + static constexpr bool kIsAligned = isAligned_; + static constexpr bool kSingleValueIteration = kMaxK <= kKeysPerBlock; + static constexpr int32_t kAlignLSE = 32; // block size of backward + static constexpr bool kIsHalf = cutlass::sizeof_bits::value == 16; + static constexpr bool kPreloadV = + ArchTag::kMinComputeCapability >= 80 && kIsHalf; + static constexpr bool kKeepOutputInRF = kSingleValueIteration; + static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF && + !cutlass::platform::is_same::value; + + static_assert(kQueriesPerBlock % 32 == 0, ""); + static_assert(kKeysPerBlock % 32 == 0, ""); + static constexpr int kNumWarpsPerBlock = + kQueriesPerBlock * kKeysPerBlock / (32 * 32); + static constexpr int kWarpSize = 32; + + // Launch bounds + static constexpr int kNumThreads = kWarpSize * kNumWarpsPerBlock; + static constexpr int kMinBlocksPerSm = + getWarpsPerSmFw() / kNumWarpsPerBlock; + + struct Params { + // Input tensors + scalar_t* query_ptr = nullptr; // [num_queries, num_heads, head_dim] + scalar_t* key_ptr = nullptr; // [num_keys, num_heads, head_dim] + scalar_t* value_ptr = nullptr; // [num_keys, num_heads, head_dim_value] + scalar_t* attn_bias_ptr = nullptr; // [num_heads, num_queries, num_keys] + int32_t* seqstart_q_ptr = nullptr; + int32_t* seqstart_k_ptr = nullptr; + + int32_t* seqlen_k_ptr = nullptr; + uint32_t causal_diagonal_offset = 0; + + // Output tensors + output_t* output_ptr = nullptr; // [num_queries, num_heads, head_dim_value] + // [num_queries, num_heads, head_dim_value] + output_accum_t* output_accum_ptr = nullptr; + // [num_heads, num_queries] - can be null + lse_scalar_t* logsumexp_ptr = nullptr; + + // Scale + accum_t scale = 0.0; + + // Dimensions/strides + int32_t head_dim = 0; + int32_t head_dim_value = 0; + int32_t num_queries = 0; + int32_t num_keys = 0; + int32_t num_keys_absolute = 0; + + uint8_t custom_mask_type = NoCustomMask; + + int32_t q_strideM = 0; + int32_t k_strideM = 0; + int32_t v_strideM = 0; + int32_t bias_strideM = 0; + + int32_t o_strideM = 0; + + // Everything below is only used in `advance_to_block` + // and shouldn't use registers + int32_t q_strideH = 0; + int32_t k_strideH = 0; + int32_t v_strideH = 0; + int64_t bias_strideH = 0; + + int64_t q_strideB = 0; + int64_t k_strideB = 0; + int64_t v_strideB = 0; + int64_t bias_strideB = 0; + + int32_t num_batches = 0; + int32_t num_heads = 0; + + // dropout + bool use_dropout = false; + unsigned long long dropout_batch_head_rng_offset = 0; + float dropout_prob = 0.0f; +#ifdef HAS_PYTORCH + at::PhiloxCudaState rng_engine_inputs = at::PhiloxCudaState(0, 0); +#endif + + // Moves pointers to what we should process + // Returns "false" if there is no work to do + CUTLASS_DEVICE bool advance_to_block() { + auto batch_id = blockIdx.z; + auto head_id = blockIdx.y; + auto query_start = blockIdx.x * kQueriesPerBlock; + + auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE; + + if (kSupportsDropout) { + dropout_batch_head_rng_offset = + batch_id * num_heads * num_queries * num_keys + + head_id * num_queries * num_keys; + } + + int64_t q_start = 0, k_start = 0; + // Advance to current batch - in case of different sequence lengths + constexpr bool kToBatchHook = + !cutlass::platform::is_same:: + value; + if (kToBatchHook) { + // Call out to a custom implementation. + if (!ToBatchHookType_::advance_to_batch(*this, q_start, k_start)) { + return false; + } + } else if (seqstart_q_ptr != nullptr) { + assert(seqstart_k_ptr != nullptr); + seqstart_q_ptr += batch_id; + + q_start = seqstart_q_ptr[0]; + int64_t q_next_start = seqstart_q_ptr[1]; + int64_t k_end; + seqstart_k_ptr += batch_id; + + if (seqlen_k_ptr) { + k_start = seqstart_k_ptr[0]; + k_end = k_start + seqlen_k_ptr[batch_id]; + } else { + k_start = seqstart_k_ptr[0]; + k_end = seqstart_k_ptr[1]; + } + + num_queries = q_next_start - q_start; + num_keys = k_end - k_start; + + if (query_start >= num_queries) { + return false; + } + } else { + query_ptr += batch_id * q_strideB; + key_ptr += batch_id * k_strideB; + value_ptr += batch_id * v_strideB; + output_ptr += int64_t(batch_id * num_queries) * o_strideM; + if (output_accum_ptr != nullptr) { + output_accum_ptr += + int64_t(batch_id * num_queries) * (head_dim_value * num_heads); + } + q_start = 0; + k_start = 0; + } + + // Advance to the current batch / head / query_start + query_ptr += (q_start + query_start) * q_strideM + head_id * q_strideH; + key_ptr += k_start * k_strideM + head_id * k_strideH; + + value_ptr += k_start * v_strideM + head_id * v_strideH; + output_ptr += + int64_t(q_start + query_start) * o_strideM + head_id * head_dim_value; + + if (kSupportsBias && attn_bias_ptr != nullptr) { + attn_bias_ptr += (batch_id * bias_strideB) + (head_id * bias_strideH); + } + if (output_accum_ptr != nullptr) { + output_accum_ptr += + int64_t(q_start + query_start) * (head_dim_value * num_heads) + + head_id * head_dim_value; + } else { + // Accumulate directly in the destination buffer (eg for f32) + output_accum_ptr = (accum_t*)output_ptr; + } + + if (logsumexp_ptr != nullptr) { + // lse[batch_id, head_id, query_start] + logsumexp_ptr += + batch_id * lse_dim * num_heads + head_id * lse_dim + query_start; + } + + // Custom masking + if (custom_mask_type == CausalFromBottomRight) { + causal_diagonal_offset = num_keys - num_queries; + } + // We use num_keys_absolute to index into the rng_state + // We need this index to match between forward and backwards + num_keys_absolute = num_keys; + if (custom_mask_type == CausalFromTopLeft || + custom_mask_type == CausalFromBottomRight) { + // the bottom row of the current block is query_start + kQueriesPerBlock + // the last active key is then query_start + causal_diagonal_offset + + // kQueriesPerBlock so num_keys is the min between actual num_keys and + // this to avoid extra computations + num_keys = cutlass::fast_min( + int32_t(query_start + causal_diagonal_offset + kQueriesPerBlock), + num_keys); + } + + num_queries -= query_start; + num_batches = 0; // no longer used after + + // If num_queries == 1, and there is only one key head we're wasting + // 15/16th of tensor core compute In that case : + // - we only launch kernels for head_id % kQueriesPerBlock == 0 + // - we iterate over heads instead of queries (strideM = strideH) + if (num_queries == 1 && k_strideH == 0 && v_strideH == 0) { + if (head_id % kQueriesPerBlock != 0) + return false; + q_strideM = q_strideH; + num_queries = num_heads; + num_heads = 1; // unused but here for intent + // remove causal since n_query = 1 + // otherwise, offset would change with head ! + custom_mask_type = NoCustomMask; + o_strideM = head_dim_value; + } + + // Make sure the compiler knows these variables are the same on all + // the threads of the warp. + // Only worth doing if they could have been modified above. + query_ptr = warp_uniform(query_ptr); + key_ptr = warp_uniform(key_ptr); + value_ptr = warp_uniform(value_ptr); + if (kSupportsBias) { + attn_bias_ptr = warp_uniform(attn_bias_ptr); + } + output_ptr = warp_uniform(output_ptr); + output_accum_ptr = warp_uniform(output_accum_ptr); + logsumexp_ptr = warp_uniform(logsumexp_ptr); + num_queries = warp_uniform(num_queries); + num_keys = warp_uniform(num_keys); + num_heads = warp_uniform(num_heads); + o_strideM = warp_uniform(o_strideM); + custom_mask_type = warp_uniform(custom_mask_type); + return true; + } + + __host__ dim3 getBlocksGrid() const { + return dim3( + ceil_div(num_queries, (int32_t)kQueriesPerBlock), + num_heads, + num_batches); + } + + __host__ dim3 getThreadsGrid() const { + return dim3(kWarpSize, kNumWarpsPerBlock, 1); + } + }; + + struct MM0 { + /* + In this first matmul, we compute a block of `Q @ K.T`. + While the calculation result is still hot in registers, we update + `mi`, `m_prime`, `s_prime` in shared-memory, and then store this value + into a shared-memory ("AccumulatorSharedStorage") that is used later as + operand A for the second matmul (see MM1) + */ + using GemmType = DefaultGemmType; + + using OpClass = typename GemmType::OpClass; + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration< + OpClass, + ArchTag, + scalar_t, + scalar_t, + scalar_t, // ElementC + accum_t // ElementAccumulator + >; + static constexpr int kAlignmentA = + kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment; + static constexpr int kAlignmentB = + kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; + using ThreadblockShape = cutlass::gemm:: + GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma< + scalar_t, // ElementA, + cutlass::layout::RowMajor, // LayoutA, + kAlignmentA, + scalar_t, // ElementB, + cutlass::layout::ColumnMajor, // LayoutB, + kAlignmentB, + accum_t, + cutlass::layout::RowMajor, // LayoutC, + OpClass, + ArchTag, // ArchTag + ThreadblockShape, // ThreadblockShape + WarpShape, // WarpShape + typename GemmType::InstructionShape, // InstructionShape + ArchTag::kMinComputeCapability >= 80 && kIsHalf + ? 4 + : DefaultConfig::kStages, + typename GemmType::Operator // Operator + >::DefaultMma; + using MmaCore = typename DefaultMma::MmaCore; + using IteratorA = typename DefaultMma::IteratorA; + using IteratorB = typename DefaultMma::IteratorB; + using DefaultThreadblockMma = typename DefaultMma::ThreadblockMma; + using Mma = typename cutlass::platform::conditional< + kSingleValueIteration, + typename MakeCustomMma::Mma, + DefaultThreadblockMma>::type; + using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator< + typename Mma::Operator::IteratorC, + accum_t, + kWarpSize>::Iterator; + static_assert( + MmaCore::WarpCount::kM * MmaCore::WarpCount::kN * + MmaCore::WarpCount::kK == + kNumWarpsPerBlock, + ""); + + // used for efficient load of bias tile Bij from global to shared memory + using BiasLoader = TileSmemLoader< + scalar_t, + cutlass::MatrixShape, + MmaCore::kThreads, + // input restriction: kv_len has to be a multiple of this value + 128 / cutlass::sizeof_bits::value>; + + // Epilogue to store to shared-memory in a format that we can use later for + // the second matmul + using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm< + typename Mma::Operator::IteratorC, + typename Mma::Operator, + scalar_t, + WarpShape, + ThreadblockShape>; + using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; + }; + + struct MM1 { + /** + Second matmul: perform `attn @ V` where `attn` is the attention (not + normalized) and stored in shared memory + */ + using GemmType = DefaultGemmType; + + using OpClass = typename GemmType::OpClass; + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration< + OpClass, + ArchTag, + scalar_t, + scalar_t, + output_accum_t, // ElementC + accum_t // ElementAccumulator + >; + static constexpr int kAlignmentA = DefaultConfig::kAlignmentA; // from smem + static constexpr int kAlignmentB = + kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; + using ThreadblockShape = cutlass::gemm:: + GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using InstructionShape = typename GemmType::InstructionShape; + + using LayoutB = cutlass::layout::RowMajor; + using DefaultGemm = cutlass::gemm::kernel::DefaultGemm< + scalar_t, // ElementA, + cutlass::layout::RowMajor, // LayoutA, + kAlignmentA, + scalar_t, // ElementB, + LayoutB, // LayoutB, + kAlignmentB, + output_accum_t, + cutlass::layout::RowMajor, // LayoutC, + accum_t, + OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + typename GemmType::InstructionShape, + typename DefaultConfig::EpilogueOutputOp, + void, // ThreadblockSwizzle - not used + ArchTag::kMinComputeCapability >= 80 && kIsHalf + ? 4 + : DefaultConfig::kStages, + false, // SplitKSerial + typename GemmType::Operator>; + + using WarpIteratorA = typename cutlass::gemm::threadblock:: + DefaultWarpIteratorAFromSharedMemory< + typename DefaultGemm::Mma::Policy::Operator::Shape, // WarpShape + typename DefaultGemm::Mma::Policy::Operator::InstructionShape, + typename DefaultGemm::Mma::Policy::Operator::IteratorA, + typename DefaultGemm::Mma::Policy>::WarpIterator; + using DefaultMmaFromSmem = + typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + MM0::AccumulatorSharedStorage::Shape::kN, // kMaxK + WarpIteratorA, + false>; // kScaleOperandA + using Mma = typename DefaultMmaFromSmem::Mma; + using IteratorB = typename Mma::IteratorB; + using WarpCount = typename Mma::WarpCount; + static_assert( + WarpCount::kM * WarpCount::kN * WarpCount::kK == kNumWarpsPerBlock, + ""); + + using DefaultEpilogue = typename DefaultGemm::Epilogue; + using OutputTileIterator = + typename cutlass::epilogue::threadblock::PredicatedTileIterator< + typename DefaultEpilogue::OutputTileIterator::ThreadMap, + output_t>; + using OutputTileIteratorAccum = + typename cutlass::epilogue::threadblock::PredicatedTileIterator< + typename DefaultEpilogue::OutputTileIterator::ThreadMap, + output_accum_t>; + }; + + static constexpr int64_t kAlignmentQ = MM0::kAlignmentA; + static constexpr int64_t kAlignmentK = MM0::kAlignmentB; + static constexpr int64_t kAlignmentV = 1; + + // Shared storage - depends on kernel params + struct ScalingCoefs { + cutlass::Array m_prime; + cutlass::Array s_prime; + cutlass::Array mi; + cutlass::Array out_rescale; + cutlass::Array + addition_storage; + }; + + struct SharedStorageEpilogueAtEnd : ScalingCoefs { + struct SharedStorageAfterMM0 { + // Everything here might be overwritten during MM0 + union { + typename MM0::BiasLoader::SmemTile bias; + typename MM0::AccumulatorSharedStorage si; + }; + typename MM1::Mma::SharedStorage mm1; + }; + + union { + typename MM0::Mma::SharedStorage mm0; + SharedStorageAfterMM0 after_mm0; + typename MM1::DefaultEpilogue::SharedStorage epilogue; + }; + + CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& + epilogue_shared_storage() { + return epilogue; + } + }; + + struct SharedStorageEpilogueInLoop : ScalingCoefs { + struct SharedStorageAfterMM0 { + // Everything here might be overwritten during MM0 + union { + typename MM0::BiasLoader::SmemTile bias; + typename MM0::AccumulatorSharedStorage si; + }; + typename MM1::Mma::SharedStorage mm1; + typename MM1::DefaultEpilogue::SharedStorage epilogue; + }; + + union { + typename MM0::Mma::SharedStorage mm0; + SharedStorageAfterMM0 after_mm0; + }; + + CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& + epilogue_shared_storage() { + return after_mm0.epilogue; + } + }; + + using SharedStorage = typename cutlass::platform::conditional< + kSingleValueIteration || kKeepOutputInRF, + SharedStorageEpilogueAtEnd, + SharedStorageEpilogueInLoop>::type; + + static bool __host__ check_supported(Params const& p) { + CHECK_ALIGNED_PTR(p.query_ptr, kAlignmentQ); + CHECK_ALIGNED_PTR(p.key_ptr, kAlignmentK); + CHECK_ALIGNED_PTR(p.value_ptr, kAlignmentV); + if (kSupportsBias) { + CHECK_ALIGNED_PTR(p.attn_bias_ptr, kAlignmentQ); + XFORMERS_CHECK( + p.num_batches <= 1 || p.bias_strideB % kAlignmentQ == 0, + "attn_bias is not correctly aligned (strideB)"); + XFORMERS_CHECK( + p.num_heads <= 1 || p.bias_strideH % kAlignmentQ == 0, + "attn_bias is not correctly aligned (strideH)"); + XFORMERS_CHECK( + p.bias_strideM % kAlignmentQ == 0, + "attn_bias is not correctly aligned"); + } + XFORMERS_CHECK( + p.q_strideM % kAlignmentQ == 0, + "query is not correctly aligned (strideM)"); + XFORMERS_CHECK( + p.k_strideM % kAlignmentK == 0, + "key is not correctly aligned (strideM)"); + XFORMERS_CHECK( + p.v_strideM % kAlignmentV == 0, + "value is not correctly aligned (strideM)"); + XFORMERS_CHECK( + p.num_heads <= 1 || p.q_strideH % kAlignmentQ == 0, + "query is not correctly aligned (strideH)"); + XFORMERS_CHECK( + p.num_heads <= 1 || p.k_strideH % kAlignmentK == 0, + "key is not correctly aligned (strideH)"); + XFORMERS_CHECK( + p.num_heads <= 1 || p.v_strideH % kAlignmentV == 0, + "value is not correctly aligned (strideH)"); + XFORMERS_CHECK( + p.custom_mask_type < NumCustomMaskTypes, + "invalid value for `custom_mask_type`"); + return true; + } + + static void CUTLASS_DEVICE attention_kernel(Params& p) { + // In this block, we will only ever: + // - read query[query_start:query_end, :] + // - write to output[query_start:query_end, :] + + extern __shared__ char smem_buffer[]; + SharedStorage& shared_storage = *((SharedStorage*)smem_buffer); + auto& m_prime = shared_storage.m_prime; + auto& s_prime = shared_storage.s_prime; + auto& mi = shared_storage.mi; + auto& out_rescale = shared_storage.out_rescale; + const uint32_t query_start = blockIdx.x * kQueriesPerBlock; + + static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, ""); + if (thread_id() < kQueriesPerBlock) { + s_prime[thread_id()] = accum_t(0); + out_rescale[thread_id()] = accum_t(1.0); + m_prime[thread_id()] = + -cutlass::platform::numeric_limits::infinity(); + mi[thread_id()] = -cutlass::platform::numeric_limits::infinity(); + } + typename MM1::Mma::FragmentC accum_o; + accum_o.clear(); + + auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator { + using OutputTileIterator = typename MM1::OutputTileIterator; + return OutputTileIterator( + typename OutputTileIterator::Params{(int32_t)p.o_strideM}, + p.output_ptr, + typename OutputTileIterator::TensorCoord{ + p.num_queries, p.head_dim_value}, + thread_id(), + {0, col}); + }; + + auto createOutputAccumIter = [&](int col) -> + typename MM1::OutputTileIteratorAccum { + using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum; + return OutputTileIteratorAccum( + typename OutputTileIteratorAccum::Params{ + (int32_t)(p.head_dim_value * p.num_heads)}, + p.output_accum_ptr, + typename OutputTileIteratorAccum::TensorCoord{ + p.num_queries, p.head_dim_value}, + thread_id(), + {0, col}); + }; + +#ifdef HAS_PYTORCH + curandStatePhilox4_32_10_t curand_state_init; + if (kSupportsDropout && p.use_dropout) { + const auto seeds = at::cuda::philox::unpack(p.rng_engine_inputs); + + // each element of the attention matrix P with shape + // (batch_sz, n_heads, n_queries, n_keys) is associated with a single + // offset in RNG sequence. we initialize the RNG state with offset that + // starts at the beginning of a (n_queries, n_keys) matrix for this + // block's batch_id and head_id + // initializing rng state is very expensive, so we run once per kernel, + // rather than once per iteration. each iteration takes a copy of the + // initialized RNG state and offsets it as needed. + curand_init( + std::get<0>(seeds), + 0, + std::get<1>(seeds) + p.dropout_batch_head_rng_offset, + &curand_state_init); + } +#endif + + // Iterate through keys + for (int32_t iter_key_start = 0; iter_key_start < p.num_keys; + iter_key_start += kKeysPerBlock) { + int32_t problem_size_0_m = + cutlass::fast_min((int32_t)kQueriesPerBlock, p.num_queries); + int32_t problem_size_0_n = cutlass::fast_min( + int32_t(kKeysPerBlock), p.num_keys - iter_key_start); + int32_t const& problem_size_0_k = p.head_dim; + int32_t const& problem_size_1_n = p.head_dim_value; + int32_t const& problem_size_1_k = problem_size_0_n; + + auto prologueV = [&](int blockN) { + typename MM1::Mma::IteratorB iterator_V( + typename MM1::IteratorB::Params{typename MM1::LayoutB(p.v_strideM)}, + p.value_ptr + iter_key_start * p.v_strideM, + {problem_size_1_k, problem_size_1_n}, + thread_id(), + cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); + MM1::Mma::prologue( + shared_storage.after_mm0.mm1, + iterator_V, + thread_id(), + problem_size_1_k); + }; + + __syncthreads(); // Need to have shared memory initialized, and `m_prime` + // updated from end of prev iter + // + // MATMUL: Q.K_t + // + // Computes the block-matrix product of: + // (a) query[query_start:query_end, :] + // with + // (b) key[iter_key_start:iter_key_start + kKeysPerBlock] + // and stores that into `shared_storage.si` + // + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {0, 0, 0}; + + cutlass::MatrixCoord tb_offset_A{ + tb_tile_offset.m() * MM0::Mma::Shape::kM, tb_tile_offset.k()}; + + cutlass::MatrixCoord tb_offset_B{ + tb_tile_offset.k(), tb_tile_offset.n() * MM0::Mma::Shape::kN}; + + // Construct iterators to A and B operands + typename MM0::IteratorA iterator_A( + typename MM0::IteratorA::Params( + typename MM0::MmaCore::LayoutA(p.q_strideM)), + p.query_ptr, + {problem_size_0_m, problem_size_0_k}, + thread_id(), + tb_offset_A); + + typename MM0::IteratorB iterator_B( + typename MM0::IteratorB::Params( + typename MM0::MmaCore::LayoutB(p.k_strideM)), + p.key_ptr + iter_key_start * p.k_strideM, + {problem_size_0_k, problem_size_0_n}, + thread_id(), + tb_offset_B); + + auto my_warp_id = warp_uniform(warp_id()); + auto my_lane_id = lane_id(); + + // Construct thread-scoped matrix multiply + typename MM0::Mma mma( + shared_storage.mm0, thread_id(), my_warp_id, my_lane_id); + + typename MM0::Mma::FragmentC accum; + + accum.clear(); + + auto gemm_k_iterations = + (problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + __syncthreads(); + + if (kPreloadV) { + prologueV(0); + } else { + MM1::Mma::drain_cp_asyncs(); + } + + typename MM0::Mma::Operator::IteratorC::TensorCoord + iteratorC_tile_offset = { + (tb_tile_offset.m() * MM0::Mma::WarpCount::kM) + + (my_warp_id % MM0::Mma::WarpCount::kM), + (tb_tile_offset.n() * MM0::Mma::WarpCount::kN) + + (my_warp_id / MM0::Mma::WarpCount::kM)}; + + // multiply by scaling factor + if (kSupportsBias) { + accum = + cutlass::multiplies()(p.scale, accum); + } + + // apply attention bias if applicable + if (kSupportsBias && p.attn_bias_ptr != nullptr) { + // load bias tile Bij into shared memory + typename MM0::BiasLoader::GmemTileIterator bias_iter( + {cutlass::layout::RowMajor(p.bias_strideM)}, + // attn_bias_pointer points to matrix of size (n_queries, n_keys) + // for the relevant batch_id and head_id + p.attn_bias_ptr + query_start * p.bias_strideM + iter_key_start, + {problem_size_0_m, problem_size_0_n}, + thread_id()); + cutlass::TensorRef bias_tensor_ref( + shared_storage.after_mm0.bias.data(), + cutlass::layout::RowMajor(MM0::ThreadblockShape::kN)); + typename MM0::BiasLoader::SmemTileIterator smem_tile_iter( + bias_tensor_ref, thread_id()); + MM0::BiasLoader::load(bias_iter, smem_tile_iter); + + // Pij += Bij, Pij is in register fragment and Bij is in shared memory + auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset( + my_lane_id, my_warp_id, iteratorC_tile_offset); + MM0::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + if (accum_m < problem_size_0_m && accum_n < problem_size_0_n) { + accum[idx] += bias_tensor_ref.at({accum_m, accum_n}); + } + }, + [&](int accum_m) {}); + } + + // Mask out last if causal + // This is only needed if upper-right corner of current query / key block + // intersects the mask Coordinates of upper-right corner of current block + // is y=query_start x=min(iter_key_start + kKeysPerBlock, num_keys)) The + // first masked element is x = y + offset -> query_start + offset There is + // intersection (and we need to mask) if min(iter_key_start + + // kKeysPerBlock, num_keys)) >= query_start + offset + if (p.custom_mask_type && + cutlass::fast_min(iter_key_start + kKeysPerBlock, p.num_keys) >= + (query_start + p.causal_diagonal_offset)) { + auto query_start = blockIdx.x * kQueriesPerBlock; + auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset( + my_lane_id, my_warp_id, iteratorC_tile_offset); + int32_t last_col; + MM0::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + // last absolute col is (last absolute query + offset) + // last local col is (last absolute query + offset - + // iter_key_start) + last_col = query_start + accum_m + p.causal_diagonal_offset - + iter_key_start; + }, + [&](int accum_m, int accum_n, int idx) { + if (accum_n > last_col) { + accum[idx] = + -cutlass::platform::numeric_limits::infinity(); + } + }, + [&](int accum_m) {}); + } + // Update `mi` from accum stored in registers + // Also does accum[i] <- exp(accum[i] - mi) + iterative_softmax( + accum_o, + accum, + mi, + m_prime, + s_prime, + out_rescale, + shared_storage.addition_storage, + my_lane_id, + thread_id(), + my_warp_id, + p.num_keys - iter_key_start, + iter_key_start == 0, + iteratorC_tile_offset, + kSupportsBias ? 1.0f : p.scale); + + // Output results to shared-memory + int warp_idx_mn_0 = my_warp_id % + (MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN); + auto output_tile_coords = cutlass::MatrixCoord{ + warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM, + warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM}; + + MM0::B2bGemm::accumToSmem( + shared_storage.after_mm0.si, accum, my_lane_id, output_tile_coords); + + __syncthreads(); + +#ifdef HAS_PYTORCH + // apply dropout (if applicable) after we've written Pij to smem. + // dropout is applied by multiplying each element of Pij by: + // - 0 with probability dropout_p + // - 1 / (1 - dropout_p) with probability 1 - dropout_p + // + // for backward purposes we want to be able to map each element of the + // attention matrix to the same random uniform number as the one we used + // in forward, without needing to use the same iteration order or having + // to store the dropout matrix. its possible to do this in registers but + // it ends up being very slow because each thread having noncontiguous + // strips of the Pij tile means we have to skip around a lot, and also + // have to generate a single random number at a time + if (kSupportsDropout && p.use_dropout) { + auto si = shared_storage.after_mm0.si.accum_ref(); + // each thread handles a contiguous sequence of elements from Sij, all + // coming from the same row. the reason they have to come from the same + // row is that the sampling random numbers from a contiguous random + // number sequence is much more efficient than jumping around, and the + // linear offset of each element of S (the global matrix) maps to an + // offset in a random number sequence. for S, the end of a row and the + // beginning of the next have adjacent offsets, but for Sij, this is not + // necessarily the case. + const int num_threads = blockDim.x * blockDim.y * blockDim.z; + const int threads_per_row = + cutlass::fast_min(num_threads / problem_size_0_m, problem_size_0_n); + const int elts_per_thread = cutlass::round_nearest( + cutlass::ceil_div(problem_size_0_n, threads_per_row), 4); + + const int thread_i = thread_id() / threads_per_row; + const int thread_start_j = + (thread_id() % threads_per_row) * elts_per_thread; + + if (thread_i < problem_size_0_m && thread_start_j < problem_size_0_n) { + curandStatePhilox4_32_10_t curand_state = curand_state_init; + skipahead( + static_cast( + (query_start + thread_i) * p.num_keys_absolute + + (iter_key_start + thread_start_j)), + &curand_state); + const float dropout_scale = 1.0 / (1.0 - p.dropout_prob); + + // apply dropout scaling to elements this thread is responsible for, + // in chunks of 4 + for (int sij_start_col_idx = thread_start_j; sij_start_col_idx < + cutlass::fast_min(thread_start_j + elts_per_thread, + problem_size_0_n); + sij_start_col_idx += 4) { + const float4 rand_uniform_quad = curand_uniform4(&curand_state); + + CUTLASS_PRAGMA_UNROLL + for (int quad_idx = 0; quad_idx < 4; ++quad_idx) { + si.at({thread_i, sij_start_col_idx + quad_idx}) *= + static_cast( + dropout_scale * + ((&rand_uniform_quad.x)[quad_idx] > p.dropout_prob)); + } + } + } + __syncthreads(); // p.use_dropout should have same value kernel-wide + } +#endif + + // + // MATMUL: Attn . V + // Run the matmul `attn @ V` for a block of attn and V. + // `attn` is read from shared memory (in `shared_storage_si`) + // `V` is read from global memory (with iterator_B) + // + + const int64_t nBlockN = kSingleValueIteration + ? 1 + : ceil_div( + (int64_t)problem_size_1_n, int64_t(MM1::ThreadblockShape::kN)); + for (int blockN = 0; blockN < nBlockN; ++blockN) { + int gemm_k_iterations = + (problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add and store it in accum + // (in registers) + if (!kPreloadV) { + __syncthreads(); // we share shmem between mma and epilogue + } + + typename MM1::Mma::IteratorB iterator_V( + typename MM1::IteratorB::Params{typename MM1::LayoutB(p.v_strideM)}, + p.value_ptr + iter_key_start * p.v_strideM, + {problem_size_1_k, problem_size_1_n}, + thread_id(), + cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); + typename MM1::Mma mma_pv( + // operand A: Pij_dropped in shared memory + shared_storage.after_mm0.si.accum_ref(), + // operand B: shared memory staging area for Vj, which is loaded + // from global memory + shared_storage.after_mm0.mm1.operand_B_ref(), + (int)thread_id(), + (int)my_warp_id, + (int)my_lane_id); + mma_pv.set_prologue_done(kPreloadV); + if (!kKeepOutputInRF) { + accum_o.clear(); + } + mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o); + __syncthreads(); + + if (kPreloadV && !kSingleValueIteration && blockN + 1 < nBlockN) { + prologueV(blockN + 1); + } + + if (!kKeepOutputInRF) { + MM1::Mma::drain_cp_asyncs(); + DISPATCH_BOOL( + iter_key_start == 0, kIsFirst, ([&] { + DISPATCH_BOOL( + (iter_key_start + kKeysPerBlock) >= p.num_keys, + kIsLast, + ([&] { + using DefaultEpilogue = typename MM1::DefaultEpilogue; + using DefaultOp = + typename MM1::DefaultConfig::EpilogueOutputOp; + using ElementCompute = typename DefaultOp::ElementCompute; + using EpilogueOutputOp = typename cutlass::epilogue:: + thread::MemoryEfficientAttentionNormalize< + typename cutlass::platform::conditional< + kIsLast::value, + output_t, + output_accum_t>::type, + output_accum_t, + DefaultOp::kCount, + typename DefaultOp::ElementAccumulator, + ElementCompute, + kIsFirst::value, + kIsLast::value, + cutlass::Array>; + using Epilogue = typename cutlass::epilogue::threadblock:: + EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename MM1::Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename cutlass::platform::conditional< + kIsLast::value, + typename MM1::OutputTileIterator, + typename MM1::OutputTileIteratorAccum>::type, + typename DefaultEpilogue:: + AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true, // IterationsUnroll + typename MM1::OutputTileIteratorAccum // Read + // iterator + >; + + int col = blockN * MM1::Mma::Shape::kN; + auto source_iter = createOutputAccumIter(col); + auto dest_iter = call_conditional< + kIsLast::value, + decltype(createOutputIter), + decltype(createOutputAccumIter)>:: + apply(createOutputIter, createOutputAccumIter, col); + EpilogueOutputOp rescale(s_prime, out_rescale); + Epilogue epilogue( + shared_storage.epilogue_shared_storage(), + thread_id(), + my_warp_id, + my_lane_id); + epilogue(rescale, dest_iter, accum_o, source_iter); + })); + })); + if (!kSingleValueIteration) { + __syncthreads(); + } + } + } + __syncthreads(); // we modify `m_prime` after + } + + if (kKeepOutputInRF) { + constexpr bool kIsFirst = true; + constexpr bool kIsLast = true; + using DefaultEpilogue = typename MM1::DefaultEpilogue; + using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp; + using ElementCompute = typename DefaultOp::ElementCompute; + using EpilogueOutputOp = + typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize< + output_t, // output + output_accum_t, // source + DefaultOp::kCount, + typename DefaultOp::ElementAccumulator, // accum + output_accum_t, // compute + kIsFirst, + kIsLast, + cutlass::Array>; + using Epilogue = + typename cutlass::epilogue::threadblock::EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename MM1::Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename MM1::OutputTileIterator, // destination + typename DefaultEpilogue::AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true, // IterationsUnroll + typename MM1::OutputTileIteratorAccum // source tile + >; + auto dest_iter = createOutputIter(0); + EpilogueOutputOp rescale(s_prime, out_rescale); + Epilogue epilogue( + shared_storage.epilogue_shared_storage(), + thread_id(), + warp_id(), + lane_id()); + MM1::Mma::drain_cp_asyncs(); + epilogue(rescale, dest_iter, accum_o); + } + + // 7. Calculate logsumexp + // To make the backward easier, we pad logsumexp with `inf` + // this avoids a few bound checks, and is not more expensive during fwd + static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, ""); + if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) { + auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE; + constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E + if (thread_id() < p.num_queries) { + p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()] / kLog2e) + + cutlass::fast_log(accum_t(s_prime[thread_id()])); + } else if (thread_id() < lse_dim) { + p.logsumexp_ptr[thread_id()] = + cutlass::platform::numeric_limits::infinity(); + } + } + } + + template + CUTLASS_DEVICE static void iterative_softmax( + typename WarpIteratorC::Fragment& frag_o, // output so far + typename WarpIteratorC::Fragment& frag, + cutlass::Array& mi, + cutlass::Array& m_prime, + cutlass::Array& s_prime, + cutlass::Array& out_rescale, + cutlass::Array& + addition_storage, + int8_t lane_id, + int8_t thread_id, + int8_t warp_id, + int max_col, + bool is_first, + typename WarpIteratorC::TensorCoord const& tile_offset, + float scaling) { + /* Iterates on the accumulator and corresponding position on result matrix + + (1) Update `mi[r]` to the max value of the row `r` + (2) In a second iteration do the following: + (a) accum <- exp(accum - mi) + (b) m_prime <- exp(m_prime - mi) + (c) s_prime <- s_prime * m_prime + sum(accum) + + All of this is done on registers, before we store all of this + on shared memory for the next matmul with Value. + */ + using Fragment = typename WarpIteratorC::Fragment; + using LambdaIterator = typename DefaultMmaAccumLambdaIterator< + WarpIteratorC, + accum_t, + kWarpSize>::Iterator; + // Convert to `accum_t` (rather than double) + constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E + + static_assert(kQueriesPerBlock % kNumWarpsPerBlock == 0, ""); + static constexpr int kLinesPerWarp = kQueriesPerBlock / kNumWarpsPerBlock; + + frag = cutlass::multiplies()(scaling * kLog2e, frag); + + auto lane_offset = + LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset); + + // First update `mi` to the max per-row + { + accum_t max; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + max = -cutlass::platform::numeric_limits::infinity(); + }, + [&](int accum_m, int accum_n, int idx) { + if (accum_n < max_col) { + max = cutlass::fast_max(max, frag[idx]); + } + }, + [&](int accum_m) { + // Having 4x atomicMax seems faster than reduce within warp + // first... + atomicMaxFloat(&mi[accum_m], max); + }); + } + + // Make sure we all share the update values for `mi` + __syncthreads(); + + // Doing this `exp` is quite expensive. Let's + // split it across the warps + bool restore_mi_to_minus_inf = false; + if (lane_id < kLinesPerWarp) { + int id = warp_id * kLinesPerWarp + lane_id; + auto m_prime_id = m_prime[id]; + auto mi_id = mi[id]; + bool changed = m_prime_id < mi_id; // `false` if both are -inf + if (changed) { + auto m_prime_exp = exp2f(m_prime_id - mi_id); + out_rescale[id] = m_prime_exp; + s_prime[id] *= m_prime_exp; + } else { + // Only when bias is enabled, it's possible that all the first values + // of attention are masked to `-inf`. In that case we want to avoid + // `nan = exp2f(-inf - (-inf))` so we temporarily set `mi` to 0 + if (kSupportsBias && + mi_id == -cutlass::platform::numeric_limits::infinity()) { + restore_mi_to_minus_inf = true; + mi[id] = 0.0f; + } + out_rescale[id] = 1.0f; + } + } + __syncthreads(); // Update output fragments + if (kKeepOutputInRF && !is_first) { + accum_t line_rescale; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { line_rescale = out_rescale[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { + frag_o[idx] = frag_o[idx] * line_rescale; + }, + [&](int accum_m) {}); + } + // Update accum_m, accum_n, ... + { + accum_t mi_row, total_row; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { mi_row = mi[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { + frag[idx] = + (accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0); + }, + [&](int accum_m) {}); + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { total_row = 0.0; }, + [&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; }, + [&](int accum_m) { + if (LambdaIterator::reduceSameRow( + lane_id, total_row, [](accum_t a, accum_t b) { + return a + b; + })) { + // NOTE: we could atomically add `total_row` to `s_prime`, but + // it's faster (and deterministic) to avoid atomics here + addition_storage + [accum_m + kQueriesPerBlock * tile_offset.column()] = + total_row; + } + }); + } + __syncthreads(); + if (lane_id < kLinesPerWarp) { + int id = warp_id * kLinesPerWarp + lane_id; + accum_t total_row = s_prime[id]; + if (restore_mi_to_minus_inf) { + // Restore `mi`, see above when we set `restore_mi_to_minus_inf=true` + mi[id] = -cutlass::platform::numeric_limits::infinity(); + } else { + m_prime[id] = mi[id]; + } + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) { + total_row += addition_storage[id + kQueriesPerBlock * i]; + } + s_prime[id] = total_row; + } + } + + static CUTLASS_DEVICE int8_t lane_id() { + return threadIdx.x; + } + static CUTLASS_DEVICE int8_t warp_id() { + return threadIdx.y; + } + static CUTLASS_DEVICE int16_t thread_id() { + return threadIdx.x + threadIdx.y * blockDim.x; + } +}; + +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_batched_impl(typename AK::Params p) { + if (!p.advance_to_block()) { + return; + } + AK::attention_kernel(p); +} + +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_batched(typename AK::Params params); diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/piped_subprocess.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/piped_subprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..536bdb4305d112343f5bf348ff6a93e68a520d0b --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/piped_subprocess.py @@ -0,0 +1,144 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from typing import List +import torch +import subprocess +import sys +import tempfile +import os +import numpy as np + + +TORCH_DTYPE_NAME = { + torch.float32: "f32", + torch.float16: "f16", + torch.bfloat16: "b16" +} +NAME_TORCH_DTYPE = {v: k for k, v in TORCH_DTYPE_NAME.items()} + +def _tensor_from_storage(tensor: torch.Tensor, dtype) -> torch.Tensor: + # PyTorch >= 2.0 + if hasattr(tensor, 'untyped_storage'): + return torch.tensor([], dtype=dtype).set_(tensor.untyped_storage()) + return torch.tensor([], dtype=dtype).set_(tensor.storage().untyped()) + +class PipedSubprocess: + def __init__(self, binary: str) -> None: + self.binary = binary + self.tempdir_ctx = tempfile.TemporaryDirectory() + + def __enter__(self) -> "PipedSubprocess": + self.subp = subprocess.Popen(self.binary, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=sys.stderr, text=True, bufsize=0) + self.tempdir = self.tempdir_ctx.__enter__() + self.file_counter = 0 + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.tempdir_ctx.__exit__(exc_type, exc_val, exc_tb) + + def temp_filename(self, suffix: str) -> str: + self.file_counter += 1 + return os.path.join(self.tempdir, f"{self.file_counter}{suffix}") + + def write(self, *args) -> None: + for a in args: + self.subp.stdin.write(str(a) + " ") + + def writeTensor(self, tensor: torch.Tensor, name: str, stride_names: List[str]) -> None: + print(f"Py ->C++: {TORCH_DTYPE_NAME[tensor.dtype]}:{name}") + tensor_u8 = _tensor_from_storage(tensor, torch.uint8) + self.write("tensor_begin", f"{TORCH_DTYPE_NAME[tensor.dtype]}:{name}", tensor_u8.shape[0]) + filename = self.temp_filename(f"{name}.tensor") + assert tensor.storage_offset() == 0 + with open(filename, "wb+") as fd: + fd.write(bytes(tensor_u8.numpy())) + self.write("file", filename) + self.write("tensor_end") + + for stride_name, stride_value in zip(stride_names, tensor.stride()): + self.write(stride_name, stride_value) + + def readTensor(self, name, stride_name, shape) -> torch.Tensor: + tmpfile = self.temp_filename(f"{name}.tensor") + self.write("tmpfile", tmpfile) + + self.readExpect("tensor_begin") + dtype_str, name = self.read().split(":") + print(f"C++->Py : {dtype_str}:{name}") + u8len = int(self.read()) + dtype = NAME_TORCH_DTYPE[dtype_str] + + self.readExpect("file") + self.readExpect(tmpfile) + + with open(tmpfile, "rb") as fd: + data = fd.read(u8len) + # `np.array` is not strictly needed, but avoids a torch warning + tensor_u8 = torch.frombuffer(np.array(data), dtype=torch.uint8, count=u8len) + self.readExpect("tensor_end") + + tensor = _tensor_from_storage(tensor_u8, dtype) + strides = [] + for sn in stride_name: + self.readExpect(sn) + strides.append(int(self.read())) + if len(strides) != shape: + strides.append(1) + assert len(strides) == len(shape), name + return torch.as_strided(tensor, shape, strides) + + def readNamed(self, name: str): + self.readExpect(name) + return self.read() + + def readExpect(self, what: str) -> None: + r = self.read() + if r != what: + raise ValueError(f"Read {r} but expected {what}") + + def read(self): + read_all = [] + # Skip initial whitespace + while True: + r = self.subp.stdout.read(1) + if r not in [' ', "\n"]: + read_all.append(r) + break + # Read data + while True: + r = self.subp.stdout.read(1) + if r in [' ', "\n"]: + break + read_all.append(r) + return ''.join(read_all) + diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/transform/tile_smem_loader.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/transform/tile_smem_loader.h new file mode 100644 index 0000000000000000000000000000000000000000..048c1e019b07495a9f529e9e3c678471a98f5e98 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/transform/tile_smem_loader.h @@ -0,0 +1,90 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +template < + typename scalar_t, // scalar type + typename ThreadblockTileShape, // size of tile to load + int Threads, // number of participating threads + int ElementsPerAccess> // thread access width in elements +class TileSmemLoader { + public: + using SmemTile = + cutlass::AlignedBuffer; + + using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap< + cutlass::layout::PitchLinearShape< + ThreadblockTileShape::kColumn, // contiguous + ThreadblockTileShape::kRow>, // strided + Threads, // Threads + ElementsPerAccess>; // ElementsPerAccess + + using GmemTileIterator = + cutlass::transform::threadblock::PredicatedTileIterator< + ThreadblockTileShape, // Shape + scalar_t, // Element + cutlass::layout::RowMajor, // Layout + 0, // AdvanceRank + ThreadMap>; // ThreadMap + + using SmemTileIterator = cutlass::transform::threadblock::RegularTileIterator< + ThreadblockTileShape, // Shape + scalar_t, // Element + cutlass::layout::RowMajor, // Layout + 0, // AdvanceRank + ThreadMap>; // ThreadMap + + using Fragment = typename GmemTileIterator::Fragment; + + /// load a tile from global memory into shared memory + CUTLASS_DEVICE + static void load( + GmemTileIterator tile_load_iter, + SmemTileIterator tile_store_iter) { + Fragment tb_frag; + tb_frag.clear(); + tile_load_iter.load(tb_frag); + tile_store_iter.store(tb_frag); + + __syncthreads(); + } +}; diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_bias_act_epilogue_tensor_op.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_bias_act_epilogue_tensor_op.h new file mode 100644 index 0000000000000000000000000000000000000000..8f76fe1f30a2987061571fc948723ff6d22a4c89 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_bias_act_epilogue_tensor_op.h @@ -0,0 +1,154 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_clamp.h" +#include "cutlass/epilogue/thread/conversion_op.h" +#include "cutlass/epilogue/thread/reduction_op.h" + +#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" + +#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h" +#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h" +#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" +#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h" +#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/epilogue/threadblock/shared_load_iterator.h" +#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h" + +// #include "cutlass/epilogue/threadblock/epilogue.h" +#include "cutlass/epilogue/threadblock/interleaved_epilogue.h" + +#include "fused_bias_act_epilogue.h" +#include "../warp/fused_bias_act_fragment_iterator_tensor_op.h" +#include "output_tile_thread_map_for_fused_bias.h" +#include "default_thread_map_tensor_op_for_fused_bias.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + + +//////////////////////////////////////////////////////////////////////////////// + +/// Defines sensible defaults for epilogues for TensorOps. +template < + typename Shape_, + typename WarpMmaTensorOp_, + int PartitionsK, + typename OutputOp_, + int ElementsPerAccess +> +struct DefaultFusedBiasActEpilogueTensorOp { + + using Shape = Shape_; + using WarpMmaTensorOp = WarpMmaTensorOp_; + static int const kPartitionsK = PartitionsK; + using OutputOp = OutputOp_; + static int const kElementsPerAccess = ElementsPerAccess; + using ElementOutput = typename OutputOp::ElementOutput; + using LayoutC = typename WarpMmaTensorOp::LayoutC; + using ElementAccumulator = typename WarpMmaTensorOp::ElementC; + + // + // Thread map + // + + using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOpForFusedBias< + Shape, + typename WarpMmaTensorOp::Shape, + kPartitionsK, + ElementOutput, + kElementsPerAccess + >::Type; + + using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< + OutputTileThreadMap, + ElementOutput + >; + + using AccumulatorFragmentIterator = typename std::conditional::value, + cutlass::epilogue::warp::FragmentIteratorComplexTensorOp< + typename WarpMmaTensorOp::Shape, + typename WarpMmaTensorOp::Policy::Operator::Shape, + typename WarpMmaTensorOp::Policy::Operator::ElementC, + typename WarpMmaTensorOp::Policy::Operator::FragmentC, + LayoutC>, + cutlass::epilogue::warp::FusedBiasActFragmentIteratorTensorOp< + typename WarpMmaTensorOp::Shape, + typename WarpMmaTensorOp::Policy::Operator::Shape, + typename WarpMmaTensorOp::Policy::Operator::ElementC, + typename WarpMmaTensorOp::Policy::Operator::FragmentC, + LayoutC> >::type; + + // + // Define the epilogue + // + using Epilogue = cutlass::epilogue::threadblock::FusedBiasActEpilogue< + Shape, + WarpMmaTensorOp, + kPartitionsK, + OutputTileIterator, + AccumulatorFragmentIterator, + OutputOp + >; +}; + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_thread_map_tensor_op_for_fused_bias.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_thread_map_tensor_op_for_fused_bias.h new file mode 100644 index 0000000000000000000000000000000000000000..35fe758ae707180456cc695cbe0ffabb6cd065c5 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_thread_map_tensor_op_for_fused_bias.h @@ -0,0 +1,113 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + +*/ + +#pragma once + +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/pitch_linear.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Defines the optimal thread map for TensorOp accumulator layouts +template < + typename ThreadblockShape_, + typename WarpShape_, + int PartitionsK, + typename Element_, + int ElementsPerAccess +> +struct DefaultThreadMapTensorOpForFusedBias { + + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + static int const kPartitionsK = PartitionsK; + using Element = Element_; + static int const kElementsPerAccess = ElementsPerAccess; + + // + // Definitions + // + + struct Detail { + + /// Tensor Operations fundamentally perform operations on 8 rows + static int const kTensorOpRows = 8; + static int const kWarpSize = 32; + + static_assert( + !(ThreadblockShape::kM % WarpShape::kM) && + !(ThreadblockShape::kM % WarpShape::kM), "Divisibility"); + + /// Number of warps + using WarpCount = gemm::GemmShape< + ThreadblockShape::kM / WarpShape::kM, + ThreadblockShape::kN / WarpShape::kN, + kPartitionsK + >; + + /// Number of participating threads + static int const kThreads = WarpCount::kCount * kWarpSize; + }; + + // + // ThreadMap + // + + /// ThreadMap to be used by epilogue::PredicatedTileIterator satisfying concept OutputTileThreadMap + using Type = OutputTileOptimalThreadMapBiasAct < + OutputTileShape, + OutputTileShape<1, WarpShape::kM / Detail::kTensorOpRows, 1, 1, WarpShape::kM / Detail::kTensorOpRows>, + Detail::kThreads, + kElementsPerAccess, + sizeof_bits::value + >; +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h new file mode 100644 index 0000000000000000000000000000000000000000..04d988cca4aab92c67bab562a8120bb32a68ff43 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h @@ -0,0 +1,213 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. + +*/ + +#pragma once +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/layout/vector.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/functional.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" +#include "cutlass/epilogue/threadblock/epilogue_base.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Epilogue operator without splitk +template < + typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) + typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) + int PartitionsK, ///< Number of partitions of the K dimension + typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors + typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators + typename OutputOp_ ///< Output operator +> +class FusedBiasActEpilogue { + +public: + + using Shape = Shape_; + using WarpMmaOperator = WarpMmaOperator_; + static int const kPartitionsK = PartitionsK; + using OutputTileIterator = OutputTileIterator_; + using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; + using OutputOp = OutputOp_; + + /// Output layout is always row-major + using Layout = layout::RowMajor; + using LongIndex = typename Layout::LongIndex; + + /// The complete warp-level accumulator tile + using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile; + + /// Output element + using ElementOutput = typename OutputTileIterator::Element; + + /// Output access size + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + +public: + + + static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero."); + + static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess), + "Divisibility"); + +public: + + /// Constructor + CUTLASS_DEVICE + FusedBiasActEpilogue( + ){ } + + /// Streams the result to global memory + CUTLASS_DEVICE + void operator()( + OutputOp const &output_op, ///< Output operator + AccumulatorTile &accumulators, ///< Complete warp-level accumulator tile + AccumulatorTile & fused_bias_act_accumlators, + OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + + bool need_bias = output_op.is_source_needed(); + + if (need_bias) + compute_source_needed_(output_op, accumulators, fused_bias_act_accumlators, source_iterator); + else + compute_source_no_needed_(output_op, accumulators, fused_bias_act_accumlators); + + + } + + CUTLASS_DEVICE + void operator()( + OutputOp const &output_op, ///< Output operator + AccumulatorTile &accumulators, ///< Complete warp-level accumulator tile + AccumulatorTile & fused_bias_act_accumlators) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + + compute_source_no_needed_(output_op, accumulators, fused_bias_act_accumlators); + } + + CUTLASS_DEVICE + void compute_source_needed_( + OutputOp const &output_op, ///< Output operator + AccumulatorTile &accumulators, ///< Complete warp-level accumulator tile + AccumulatorTile & fused_bias_act_accumlators, + OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + + typename OutputTileIterator::Fragment source_fragment; + + + source_fragment.clear(); + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + AccumulatorFragmentIterator fused_bias_act_fragment_iterator(fused_bias_act_accumlators); + + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { + + source_iterator.load(source_fragment); + ++source_iterator; + + typename AccumulatorFragmentIterator::Fragment accum_fragment; + + accum_fragment_iterator.load(accum_fragment); + ++accum_fragment_iterator; + + typename AccumulatorFragmentIterator::Fragment fused_bias_act_fragment; + fused_bias_act_fragment = output_op(accum_fragment, source_fragment); + + fused_bias_act_fragment_iterator.store(fused_bias_act_fragment); + ++fused_bias_act_fragment_iterator; + } + } + + CUTLASS_DEVICE + void compute_source_no_needed_( + OutputOp const &output_op, ///< Output operator + AccumulatorTile &accumulators, ///< Complete warp-level accumulator tile + AccumulatorTile & fused_bias_act_accumlators) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + AccumulatorFragmentIterator fused_bias_act_fragment_iterator(fused_bias_act_accumlators); + + + + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < AccumulatorFragmentIterator::kIterations; ++iter) { + + typename AccumulatorFragmentIterator::Fragment accum_fragment; + + accum_fragment_iterator.load(accum_fragment); + ++accum_fragment_iterator; + + typename AccumulatorFragmentIterator::Fragment fused_bias_act_fragment; + fused_bias_act_fragment = output_op(accum_fragment); + + fused_bias_act_fragment_iterator.store(fused_bias_act_fragment); + ++fused_bias_act_fragment_iterator; + } + } + +}; + + + + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/output_tile_thread_map_for_fused_bias.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/output_tile_thread_map_for_fused_bias.h new file mode 100644 index 0000000000000000000000000000000000000000..c92307f5774652b54141970f7bf8ee878e01a060 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/output_tile_thread_map_for_fused_bias.h @@ -0,0 +1,311 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Metaprogram for determining the mapping of output elements to threads for epilogue tiles. + + +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/fast_math.h" + +#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// RowArrangement determines how one or more warps cover a region of consecutive rows. +template < + typename Shape, + int WarpsRemaining, + int ElementsPerAccess, + int ElementSize, + bool Is2dTile +> +struct RowArrangementBiasAct; + +/// RowArrangement in which each warp's access is a 1D tiled arrangement. +template < + typename Shape, + int WarpsRemaining, + int ElementsPerAccess, + int ElementSize +> +struct RowArrangementBiasAct { + static int const kWarpSize = 32; + static int const kElementsPerAccess = ElementsPerAccess; + static int const kElementSize = ElementSize; + + static int const kIterationsRow = 1; + static int const kDeltaRow = 1; + static int const kIterationsColumn = Shape::kColumn / kElementsPerAccess / kWarpSize; + static int const kDeltaColumn = kWarpSize * kElementsPerAccess; + + static int const kAccessWidth = kWarpSize; + static int const kAccessRows = 1; + static int const kWarpPartitionsRow = 1; + static int const kWarpPartitionsColumn = WarpsRemaining; +}; + +/// RowArrangement in which each warp's access is a 2D tiled arrangement. +template < + typename Shape, + int WarpsRemaining, + int ElementsPerAccess, + int ElementSize +> +struct RowArrangementBiasAct { + + static int const kMemoryAccessSize = 4;//128; + static int const kWarpSize = 32; + + static int const kElementsPerAccess = ElementsPerAccess; + static int const kElementSize = ElementSize; + + struct Detail { + static int const kShapeRow = Shape::kRow / WarpsRemaining; + static int const kShapeWidth = Shape::kColumn / kElementsPerAccess; + + static int const kTargetMemoryAccessWidth = + kMemoryAccessSize / (kElementsPerAccess * kElementSize / 8); + + static int const kTargetAccessRows = kWarpSize / kTargetMemoryAccessWidth; + }; + + static int const kAccessWidth = + (Detail::kTargetAccessRows > Detail::kShapeRow ? + kWarpSize / Detail::kShapeRow + : const_min( + Detail::kShapeWidth, + const_min(kWarpSize, kMemoryAccessSize / (kElementsPerAccess * kElementSize / 8)) + )); + + static int const kAccessRows = + (Detail::kTargetAccessRows > Detail::kShapeRow ? + Detail::kShapeRow + : const_min(Shape::kRow, kWarpSize / kAccessWidth)); + + static int const kIterationsRow = Detail::kShapeRow / kAccessRows; + static int const kDeltaRow = kAccessRows; + + static int const kIterationsColumn = Detail::kShapeWidth / kAccessWidth; + static int const kDeltaColumn = kAccessWidth * kElementsPerAccess; + + static_assert( kAccessWidth * kElementsPerAccess <= Shape::kColumn, "Accessing too many elements per access"); + static_assert( kIterationsColumn > 0, "Iteration Count Column must be > 0" ); + static_assert( kIterationsRow > 0, "Iteration Count Row must be > 0" ); + + static int const kWarpPartitionsRow = 1; + static int const kWarpPartitionsColumn = 1; +}; + +} + +//////////////////////////////////////////////////////////////////////////////// + +/// Template metaprogram for partitioning a 4D space across warps to achieve several performance +/// objectives: +/// +/// - coalesced memory accesses in units of 16 Byte lines +/// - minimal address arithmetic +/// - minimal predicate calculations +/// +template < + typename Shape_, + typename Count_, + int Threads, + int ElementsPerAccess, + int ElementSize +> +struct OutputTileOptimalThreadMapBiasAct { + + using Shape = Shape_; + using Count = Count_; + + static int const kWarpSize = 32; + static int const kThreads = Threads; + static int const kWarpCount = kThreads / kWarpSize; + + static int const kElementsPerAccess = ElementsPerAccess; + static int const kElementSize = ElementSize; + + // + // Metaprogram computation + // + + struct Detail { + + // Clusters + static int const kIterationsCluster = + ((Shape::kCluster > kWarpCount) ? + Shape::kCluster / kWarpCount + : 1); + + static int const kDeltaCluster = + ((Shape::kCluster > kWarpCount) ? + Shape::kRow * Count::kRow * Shape::kGroup * Count::kGroup * Shape::kCluster / kIterationsCluster + : 1); + + static int const kCompactedDeltaCluster = + ((Shape::kCluster > kWarpCount) ? + Shape::kRow * Shape::kGroup * Shape::kCluster / kIterationsCluster + : 1); + + static int const kWarpPartitionsCluster = + ((Shape::kCluster > kWarpCount) ? + kWarpCount + : kWarpCount / Shape::kCluster); + + static int const kWarpsRemainingForGroups = + ((Shape::kCluster > kWarpCount) ? 1 : kWarpCount / Shape::kCluster); + + // Groups + static int const kIterationsGroup = + ((Shape::kGroup > kWarpsRemainingForGroups) ? + Shape::kGroup / kWarpsRemainingForGroups + : 1); + + static int const kDeltaGroup = + ((Shape::kGroup > kWarpsRemainingForGroups) ? + Shape::kRow * Count::kRow * Shape::kGroup / kIterationsGroup + : 1); + + static int const kCompactedDeltaGroup = + ((Shape::kGroup > kWarpsRemainingForGroups) ? + Shape::kRow * Shape::kGroup / kIterationsGroup + : 1); + + static int const kWarpPartitionsGroup = + ((Shape::kGroup > kWarpsRemainingForGroups) ? + 1 + : kWarpsRemainingForGroups / Shape::kGroup); + + static int const kWarpsRemainingForRows = + ((Shape::kGroup > kWarpsRemainingForGroups) ? + 1 + : kWarpsRemainingForGroups / Shape::kGroup); + + // Rows + using RowArrangement = detail::RowArrangementBiasAct< + Shape, + kWarpsRemainingForRows, + kElementsPerAccess, + kElementSize, + (Shape::kRow > kWarpsRemainingForRows) + >; + + // Warp partitions + using WarpPartitions = OutputTileShape< + RowArrangement::kWarpPartitionsColumn, + RowArrangement::kWarpPartitionsRow, + kWarpPartitionsGroup, + kWarpPartitionsCluster, + 1>; + + static int const kAccessWidth = RowArrangement::kAccessWidth; + static int const kAccessRows = RowArrangement::kAccessRows; + }; + + // + // Output + // + + using Iterations = OutputTileShape< + Detail::RowArrangement::kIterationsColumn, + Detail::RowArrangement::kIterationsRow, + Detail::kIterationsGroup, + Detail::kIterationsCluster, + 1>; + + using Delta = OutputTileShape< + Detail::RowArrangement::kDeltaColumn, + Detail::RowArrangement::kDeltaRow, + Detail::kDeltaGroup, + Detail::kDeltaCluster, + 1>; + + /// Initial offset function + CUTLASS_HOST_DEVICE + static MatrixCoord initial_offset(int thread_idx) { + + int warp_idx = thread_idx / kWarpSize; + int lane_idx = thread_idx % kWarpSize; + + // Compute warp location + int cluster_idx = warp_idx / Detail::WarpPartitions::kCluster; + int residual_cluster = warp_idx % Detail::WarpPartitions::kCluster; + + int group_idx = residual_cluster / Detail::WarpPartitions::kGroup; + int residual_group = residual_cluster % Detail::WarpPartitions::kGroup; + + int row_idx = residual_group / Detail::WarpPartitions::kRow; + int col_idx = residual_group % Detail::WarpPartitions::kRow; + + // Compute per-lane offset + int lane_row_offset = lane_idx / Detail::kAccessWidth; + int lane_col_offset = lane_idx % Detail::kAccessWidth; + + // Compute coordinate in output space + int cluster_offset = cluster_idx * Shape::kRow * Count::kRow * Shape::kGroup * Count::kGroup; + int group_offset = group_idx * Shape::kRow * Count::kRow; + int row_offset = row_idx * Iterations::kRow * Detail::kAccessRows; + int column_offset = col_idx * Iterations::kColumn * Detail::kAccessWidth * kElementsPerAccess; + + return MatrixCoord( + cluster_offset + group_offset + row_offset + lane_row_offset, + (column_offset + lane_col_offset) * kElementsPerAccess + ); + } + +}; + + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/warp/fused_bias_act_fragment_iterator_tensor_op.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/warp/fused_bias_act_fragment_iterator_tensor_op.h new file mode 100644 index 0000000000000000000000000000000000000000..88259d102e3967905e9a6189a8fd3d6b56ba67ee --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/warp/fused_bias_act_fragment_iterator_tensor_op.h @@ -0,0 +1,189 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief This defines a "fragment" iterator for visiting the fragments of an accumulator tile + that participate in one warp-level store operation. + + Typically, the accumulator tile is the largest single block of register-backed storage + within the kernel. Storing it to memory is best accomplished by partitioning it into + smaller tiles and storing these sequentially. + + Round trips through shared memory during the Epilogue phase require partitioning, as + shared memory capacity is typically insufficient for a threadblock's total accumulator + size. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" + +#include "cutlass/epilogue/warp/tensor_op_policy.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace warp { + +//////////////////////////////////////////////////////////////////////////////// + +/// +template < + typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) + typename OperatorShape, ///< matrix multiply operation shape (concept: gemm::GemmShape) + typename OperatorElementC, ///< matrix multiply operation data type (concept: data type) + typename OperatorFragmentC, ///< matrix multiply operation fragment (concept: Array) + typename Layout ///< target shared memory layout +> +class FusedBiasActFragmentIteratorTensorOp; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for row-major shared memory +template < + typename WarpShape_, ///< shape of the warp-level GEMM tile + typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape) + typename OperatorElementC_, ///< matrix multiply operation data type (concept: data type) + typename OperatorFragmentC_ ///< matrix multiply operation fragment (concept: Array) +> +class FusedBiasActFragmentIteratorTensorOp { +public: + + using WarpShape = WarpShape_; + using OperatorShape = OperatorShape_; + using OperatorElementC = OperatorElementC_; + using OperatorFragmentC = OperatorFragmentC_; + using Layout = layout::RowMajor; + + using Policy = TensorOpPolicy; + + /// This is the fragment size produced by one access of the iterator. + using Fragment = Array< + OperatorElementC, + Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>; + + /// This is the complete warp-level accumulator tile. + using AccumulatorTile = Array< + OperatorElementC, + OperatorFragmentC::kElements * Policy::OperatorCount::kRow * Policy::OperatorCount::kColumn>; + + using OutputAccumulatorTile = AccumulatorTile; + + /// Number of times this iterator can be incremented + static int const kIterations = Policy::kIterations; + +private: + + /// Internal access type + using AccessType = Array; + +private: + + // + // Data members + // + + /// Accumulator tile + AccessType *accumulators_; + + /// Internal index + int index_; + +public: + + /// Constructs an iterator + CUTLASS_HOST_DEVICE + FusedBiasActFragmentIteratorTensorOp(AccumulatorTile &accum): + accumulators_(reinterpret_cast(&accum)), + index_(0) { + } + + /// Increments + CUTLASS_HOST_DEVICE + FusedBiasActFragmentIteratorTensorOp &operator++() { + ++index_; + return *this; + } + + /// Decrements + CUTLASS_HOST_DEVICE + FusedBiasActFragmentIteratorTensorOp &operator--() { + --index_; + return *this; + } + + /// Loads a fragment from the referenced part of the accumulator tile + CUTLASS_HOST_DEVICE + void load(Fragment &frag, int index_offset = 0) const { + + int index = index_ + index_offset; + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { + + int accumulator_access_offset = + index + n * Policy::kAccumulatorColumnStride / Policy::kElementsPerAccess; + + frag_ptr[n] = accumulators_[accumulator_access_offset]; + } + } + /// Stores a fragment from the referenced part of the accumulator tile + CUTLASS_HOST_DEVICE + void store(Fragment &frag, int index_offset = 0) const { + + int index = index_ + index_offset; + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { + + int accumulator_access_offset = + index + n * Policy::kAccumulatorColumnStride / Policy::kElementsPerAccess; + + accumulators_[accumulator_access_offset] = frag_ptr[n]; + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/gemm/warp/mma_tensor_op_fragment_iterator_without_output_op.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/gemm/warp/mma_tensor_op_fragment_iterator_without_output_op.h new file mode 100644 index 0000000000000000000000000000000000000000..cd4417bbb56834835d1dfc9f812a8e2ff628e9df --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/gemm/warp/mma_tensor_op_fragment_iterator_without_output_op.h @@ -0,0 +1,427 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/array.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_conversion.h" + +namespace cutlass { +namespace gemm { +namespace warp { + + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Size of the accumulation tile shape (concept: MatrixShape) + typename AccumulatorShape_, + /// KBlocks columns to compute residual + int KBlocksColumn_, + /// Accumulator Element type + typename ElementAccumulator_, + /// Element type + typename Element_, + /// Layout of operand in memory + typename Layout_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_, + /// Whether beta is zero + bool IsBetaZero_ > +class MmaTensorOpPureFragmentIterator; + + +// Partial specialization for col-major accumulator tile +// And Element type is the same as Accumulator Element type + +template < + /// Shape of warp tile to load (concept: MatrixShape) + typename Shape_, + /// Shape of the warp accumulation tile (concept: MatrixShape) + typename AccumulatorShape_, + /// KBlocks columns to compute residual + int KBlocksColumn_, + /// Element type + typename Element_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_> +class MmaTensorOpPureFragmentIterator { + public: + + /// Shape of warp tile to load (concept: MatrixShape) + using Shape = Shape_; + + /// Shape of the warp accumulation tile (concept: MatrixShape) + using AccumulatorShape = AccumulatorShape_; + + /// KBlocks columns to compute residual + static int const kKBlockColumn = KBlocksColumn_; + + /// Element type + using Element = Element_; + + /// Layout of source tile + using Layout = cutlass::layout::ColumnMajor; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = InstructionShape_; + + /// Whether beta is zero + static bool const IsBetaZero = true; + + /// Number of participating threads + static int const kThreads = 32; + + /// Internal structure of iterator - made public to enable introspection + struct Policy { + static_assert( + !(Shape::kRow % InstructionShape::kM) && + !(Shape::kColumn % InstructionShape::kN), + "Shape of warp-level Mma must be divisible by operator shape."); + static_assert( + !(AccumulatorShape::kRow % Shape::kRow) && + !(AccumulatorShape::kColumn % Shape::kColumn), + "Shape of Warp Accumulator must be divisible by warp shape."); + static_assert( + !(kKBlockColumn % Shape::kColumn), + "KBlock size must be divisible by warp shape."); + + /// Number of times this iterator can be incremented + static int const kIterations = AccumulatorShape::kCount / Shape::kCount; + }; + +private: + + static int const kElementsPerAccess = InstructionShape::kM * InstructionShape::kN / kThreads; + + /// Number of mma operations performed by a warp + using MmaIterations = MatrixShape; + /// Number of mma operations performed by the entire accumulator + using AccumulatorIterations = MatrixShape; + + /// Number of K iterations + static int const kKBlockIterations = (AccumulatorShape::kColumn + kKBlockColumn - 1) / kKBlockColumn; + static int const kResidualColumn = AccumulatorShape::kColumn - (kKBlockIterations - 1) * kKBlockColumn; + static int const kKBlockColumnIterations = kKBlockColumn / Shape::kColumn + * (AccumulatorShape::kRow / Shape::kRow); + static int const kResidualIndex = kResidualColumn / Shape::kColumn + * (AccumulatorShape::kRow / Shape::kRow); + +public: + + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + /// This is the fragment size produced by one access of the iterator. + using Fragment = Array; + + /// Accumulator Fragment object + using AccumulatorFragment = Array; + + +private: + + /// Internal access type + using AccessType = Array; + +private: + // + // Data members + // + + /// Accumulator tile + AccessType const *accumulators_; + + /// Internal index + int index_; + + /// Used to access residual tile first + bool is_residual_tile_; + +public: + /// Constructs an iterator + CUTLASS_HOST_DEVICE + MmaTensorOpPureFragmentIterator(AccumulatorFragment const &accum) + : accumulators_(reinterpret_cast(&accum)), + index_(0), is_residual_tile_(true) {} + + /// Add offset + CUTLASS_HOST_DEVICE + void add_offset(int index_offset) { + index_ += index_offset; + if(is_residual_tile_ && index_ >= kKBlockColumnIterations) { + index_ = index_ - kKBlockColumnIterations + kResidualIndex; + is_residual_tile_ = false; + } + } + + /// Increments + CUTLASS_HOST_DEVICE + MmaTensorOpPureFragmentIterator &operator++() { + add_offset(1); + return *this; + } + + /// Decrements + CUTLASS_HOST_DEVICE + MmaTensorOpPureFragmentIterator &operator--() { + add_offset(-1); + return *this; + } + + /// Loads a fragment from the referenced part of the accumulator tile + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + + AccessType src_fragment; + src_fragment.clear(); + + + AccessType *frag_ptr = reinterpret_cast(&frag); + + int index_m = (index_ * MmaIterations::kRow) % AccumulatorIterations::kRow; + int index_n = (index_ * MmaIterations::kRow) / AccumulatorIterations::kRow + * MmaIterations::kColumn; + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; n++) { + for (int m = 0; m < MmaIterations::kRow; m++) { + int accumulator_access_offset = + (n + index_n) * AccumulatorIterations::kRow + m + index_m; + + frag_ptr[n * MmaIterations::kRow + m].clear(); + if(!(is_residual_tile_ && index_ >= kResidualIndex)) + frag_ptr[n * MmaIterations::kRow + m] = accumulators_[accumulator_access_offset]; + // frag_ptr[n * MmaIterations::kRow + m] = output_op(accumulators_[accumulator_access_offset], src_fragment); + } + } + } + +}; + +// Partial specialization for row-major accumulator tile + +template < + /// Shape of warp tile to load (concept: MatrixShape) + typename Shape_, + /// Shape of the warp accumulation tile (concept: MatrixShape) + typename AccumulatorShape_, + /// KBlocks columns to compute residual + int KBlocksColumn_, + /// Accumulator Element type + typename ElementAccumulator_, + /// Element type + typename Element_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_> +class MmaTensorOpPureFragmentIterator { + public: + + /// Shape of warp tile to load (concept: MatrixShape) + using Shape = Shape_; + + /// Shape of the warp accumulation tile (concept: MatrixShape) + using AccumulatorShape = AccumulatorShape_; + + /// KBlocks columns to compute residual + static int const kKBlockColumn = KBlocksColumn_; + + /// Accumulator Element type + using ElementAccumulator = ElementAccumulator_; + + /// Element type + using Element = Element_; + + /// Layout of source tile + using Layout = cutlass::layout::RowMajor; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = InstructionShape_; + + /// Whether beta is zero + static bool const IsBetaZero = true; + + /// Number of participating threads + static int const kThreads = 32; + + /// Internal structure of iterator - made public to enable introspection + struct Policy { + static_assert( + !(Shape::kRow % InstructionShape::kM) && + !(Shape::kColumn % InstructionShape::kN), + "Shape of warp-level Mma must be divisible by operator shape."); + static_assert( + !(AccumulatorShape::kRow % Shape::kRow) && + !(AccumulatorShape::kColumn % Shape::kColumn), + "Shape of Warp Accumulator must be divisible by warp shape."); + static_assert( + !(kKBlockColumn % Shape::kColumn), + "KBlock size must be divisible by warp shape."); + + /// Number of times this iterator can be incremented + static int const kIterations = AccumulatorShape::kCount / Shape::kCount; + }; + +private: + + static int const kElementsPerAccess = InstructionShape::kM * InstructionShape::kN / kThreads; + + /// Number of mma operations performed by a warp + using MmaIterations = MatrixShape; + /// Number of mma operations performed by the entire accumulator + using AccumulatorIterations = MatrixShape; + + /// Number of K iterations + static int const kKBlockIterations = (AccumulatorShape::kColumn + kKBlockColumn - 1) / kKBlockColumn; + static int const kResidualColumn = AccumulatorShape::kColumn - (kKBlockIterations - 1) * kKBlockColumn; + static int const kKBlockColumnIterations = kKBlockColumn / Shape::kColumn + * (AccumulatorShape::kRow / Shape::kRow); + static int const kResidualIndex = kResidualColumn / Shape::kColumn + * (AccumulatorShape::kRow / Shape::kRow); + +public: + + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + /// This is the fragment size produced by one access of the iterator. + using Fragment = Array; + + /// Accumulator Fragment object + using AccumulatorFragment = Array; + + +private: + + /// Internal access type + using AccessType = Array; + using FragmentAccessType = Array; + +private: + // + // Data members + // + + /// Accumulator tile + AccessType const *accumulators_; + + /// Internal index + int index_; + + /// Used to access residual tile first + bool is_residual_tile_; + +public: + /// Constructs an iterator + CUTLASS_HOST_DEVICE + MmaTensorOpPureFragmentIterator(AccumulatorFragment const &accum) + : accumulators_(reinterpret_cast(&accum)), + index_(0), is_residual_tile_(true) {} + + /// Add offset + CUTLASS_HOST_DEVICE + void add_offset(int index_offset) { + index_ += index_offset; + if(is_residual_tile_ && index_ >= kKBlockColumnIterations) { + index_ = index_ - kKBlockColumnIterations + kResidualIndex; + is_residual_tile_ = false; + } + } + + /// Increments + CUTLASS_HOST_DEVICE + MmaTensorOpPureFragmentIterator &operator++() { + add_offset(1); + return *this; + } + + /// Decrements + CUTLASS_HOST_DEVICE + MmaTensorOpPureFragmentIterator &operator--() { + add_offset(-1); + return *this; + } + + /// Loads a fragment from the referenced part of the accumulator tile + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + + + FragmentAccessType src_fragment; + src_fragment.clear(); + + FragmentAccessType *frag_ptr = reinterpret_cast(&frag); + + int index_m = (index_ * MmaIterations::kRow) % AccumulatorIterations::kRow; + int index_n = (index_ * MmaIterations::kRow) / AccumulatorIterations::kRow + * MmaIterations::kColumn; + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; m++) { + for (int n = 0; n < MmaIterations::kColumn; n++) { + int accumulator_access_offset = + (m + index_m) * AccumulatorIterations::kColumn + n + index_n; + + frag_ptr[m * MmaIterations::kColumn + n].clear(); + if(!(is_residual_tile_ && index_ >= kResidualIndex)) + frag_ptr[m * MmaIterations::kColumn + n] = (accumulators_[accumulator_access_offset]); + } + } + } + +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_all_code.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_all_code.py new file mode 100644 index 0000000000000000000000000000000000000000..15b10296d1bdf9c59b21fd91497881601bb91617 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_all_code.py @@ -0,0 +1,129 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import gen_turing_and_volta as api_generator +import gen_sample as sample_creater +import gen_cmake as cmake_creater +import gen_verify as verify_creater +import gen_device as b2b_fused_generator +import replace_fix_impl_header + +import argparse +import os +import json + + +parser = argparse.ArgumentParser(description="Generates Fused Multi-GEMM CUTLASS Kernels") +parser.add_argument("--config-file", default="config.json", help="JSON file containing configuration to generate") +parser.add_argument("--gen-name", default="FusedMultiGemmForward", help="Specific the output name") +parser.add_argument("--output-dir", default="", help="Specifies the output dir") +parser.add_argument("--cutlass-dir", default="", help="Specifies the dependent CUTLASS repo dir") +parser.add_argument("--gen-include-cutlass-dir", default="", help="Specifies the generated CUTLASS code include dir, if needed.") +args = parser.parse_args() + +gen_name = args.gen_name + +cutlass_deps_dir = args.cutlass_dir + +output_dir = args.output_dir +output_dir += "/" + +cutlass_deps_root = args.gen_include_cutlass_dir +if cutlass_deps_root == '': + cutlass_deps_root = cutlass_deps_dir + "/include/" +cutlass_deps_root +='/' + + +if not os.path.exists(output_dir): + os.makedirs(output_dir) + +if not os.path.exists(output_dir + "/" + "auto_gen"): + os.mkdir(output_dir + "/" + "auto_gen") + +if not os.path.exists(output_dir + "/" + "fixed_impl"): + os.mkdir(output_dir + "/" + "fixed_impl" ) + +if not os.path.exists(output_dir + "/" + "sample"): + os.mkdir(output_dir + "/" + "sample" ) + +if not os.path.exists(output_dir + "/" + "auto_gen" + "/" + "device"): + os.mkdir(output_dir + "/" + "auto_gen" + "/" + "device") +if not os.path.exists(output_dir + "/" + "auto_gen" + "/" + "kernel"): + os.mkdir(output_dir + "/" + "auto_gen" + "/" + "kernel") +if not os.path.exists(output_dir + "/" + "auto_gen" + "/" + "threadblock"): + os.mkdir(output_dir + "/" + "auto_gen" + "/" + "threadblock") + +with open(args.config_file, 'r') as infile: + gemm_info_dict = json.load(infile) + +keys = sorted(gemm_info_dict.keys()) +fuse_gemm_info = [gemm_info_dict[k] for k in keys] + + +for_cutlass_gen_user_include_header_file = [ + cutlass_deps_root + "cutlass/epilogue/thread/linear_combination_leaky_relu.h", + cutlass_deps_root + "cutlass/epilogue/thread/linear_combination.h", +] + +for_fused_wrapper = [ + cutlass_deps_root + "cutlass/epilogue/thread/linear_combination_leaky_relu.h", + cutlass_deps_root + "cutlass/epilogue/thread/linear_combination.h", + "auto_gen/device/" + gen_name + ".h", + cutlass_deps_root + "cutlass/gemm/device/gemm_batched.h", + cutlass_deps_root + "cutlass/cutlass.h", +] + +# Copy fixed implementation to the output directory +fix_impl = replace_fix_impl_header.replace_fix_impl("../fixed_impl/", output_dir +"/fixed_impl/", cutlass_deps_root) +fix_impl.gen_code() + +auto_gen_output_dir = output_dir + "/auto_gen/" +project_root = "" +turing_plus = b2b_fused_generator.gen_device(fuse_gemm_info, gen_name, for_cutlass_gen_user_include_header_file, cutlass_deps_root, project_root, auto_gen_output_dir) +turing_plus.gen_code(75, 'hmma1688', False) + +api = api_generator.gen_one_API(fuse_gemm_info, gen_name, for_fused_wrapper, output_dir) +api.gen_code() + +# Generate C++ sample +os.system("cp ../leaky_bias.h " + output_dir + "/sample/") +os.system("cp ../utils.h " + output_dir + "/sample/") + +sample_dir = output_dir + "/sample/" +sample = sample_creater.gen_test(fuse_gemm_info, gen_name, for_cutlass_gen_user_include_header_file, sample_dir) +sample.gen_cpp_sample() + +cmake_gen = cmake_creater.gen_build_sys(cutlass_deps_dir, output_dir) +cmake_gen.gen_code() + +verify = verify_creater.gen_verify(fuse_gemm_info, gen_name, for_fused_wrapper, output_dir) +verify.gen_code() diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_cmake.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_cmake.py new file mode 100644 index 0000000000000000000000000000000000000000..c99ac9ae2c383ae030621280eac899f453ab517b --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_cmake.py @@ -0,0 +1,131 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +class gen_build_sys: + def __init__(self, cutlass_deps_dir, output_dir = "../"): + self.output_dir = output_dir + self.cutlass_deps_dir = cutlass_deps_dir + + def gen_top(self): + code = "" + code += '''\ +# Auto Generated code - Do not edit. + +cmake_minimum_required(VERSION 3.8) +project(CUTLASS_MULTI_GEMMS LANGUAGES CXX CUDA) +find_package(CUDAToolkit) +set(CUDA_PATH ${{CUDA_TOOLKIT_ROOT_DIR}}) +set(CUTLASS_PATH \"{cutlass_deps_dir}/include\") +set(CUTLASS_UTIL_PATH \"{cutlass_deps_dir}/tools/util/include\") +list(APPEND CMAKE_MODULE_PATH ${{CUDAToolkit_LIBRARY_DIR}}) +'''.format(cutlass_deps_dir=self.cutlass_deps_dir) + + code += '''\ +set(GPU_ARCHS \"\" CACHE STRING + \"List of GPU architectures (semicolon-separated) to be compiled for.\") + +if(\"${GPU_ARCHS}\" STREQUAL \"\") + set(GPU_ARCHS \"70\") +endif() + +foreach(arch ${GPU_ARCHS}) + set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -gencode arch=compute_${arch},code=sm_${arch}\") + if(SM STREQUAL 70 OR SM STREQUAL 75) + set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} -DWMMA\") + set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -DWMMA\") + set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -DWMMA\") + endif() +endforeach() + +set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS}\") +set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS}\") +set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -Xcompiler -Wall\") + +set(CMAKE_C_FLAGS_DEBUG \"${CMAKE_C_FLAGS_DEBUG} -Wall -O0\") +set(CMAKE_CXX_FLAGS_DEBUG \"${CMAKE_CXX_FLAGS_DEBUG} -Wall -O0\") +set(CMAKE_CUDA_FLAGS_DEBUG \"${CMAKE_CUDA_FLAGS_DEBUG} -O0 -G -Xcompiler -Wall\") + +set(CMAKE_CXX_STANDARD 11) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +if(CMAKE_CXX_STANDARD STREQUAL \"11\") + set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} --expt-extended-lambda\") + set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr\") +endif() + +set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -g -O3\") +set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -Xcompiler -O3\") +set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -Xcompiler=-fno-strict-aliasing\") + +set(COMMON_HEADER_DIRS + ${PROJECT_SOURCE_DIR} + ${CUDAToolkit_INCLUDE_DIRS} +) + +set(COMMON_LIB_DIRS + ${CUDAToolkit_LIBRARY_DIR} +) +list(APPEND COMMON_HEADER_DIRS ${CUTLASS_PATH}) +list(APPEND COMMON_HEADER_DIRS ${CUTLASS_UTIL_PATH}) +''' + code += '''\ +include_directories( + ${COMMON_HEADER_DIRS} +) + +link_directories( + ${COMMON_LIB_DIRS} +) + +add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0) +add_definitions(-DGOOGLE_CUDA=1) + +add_executable(sample + sample/sample.cu + one_api.cu +) +target_link_libraries(sample PRIVATE + -lcudart + -lnvToolsExt + ${CMAKE_THREAD_LIBS_INIT} +) + +if(NOT DEFINED LIB_INSTALL_PATH) + set(LIB_INSTALL_PATH ${CMAKE_CURRENT_BINARY_DIR}) +endif() +''' + return code + + def gen_code(self): + top_code = self.gen_top() + with open(self.output_dir + "CMakeLists.txt", "w") as f: + f.write(top_code) diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_customized_epilogue.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_customized_epilogue.py new file mode 100644 index 0000000000000000000000000000000000000000..5abd22c81ee48ae5ce5bbbb88c8c80d3dafe5025 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_customized_epilogue.py @@ -0,0 +1,120 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import ast + +fuse_gemm_info = [ + { + 'epilogue': { + 'tp': 'LeakyRelu', #'CustomizedLeaky_RELU' + 'bias': {'addbias': False, 'bias_tp': 'mat'}, + 'args': [('float', 'leaky_alpha', 1.3), ], + 'func': ''' +y = max(leaky_alpha * x, x) +y = y * x + ''' + } + }, + +] +class AnalysisNodeVisitor(ast.NodeVisitor): + def visit_Import(self,node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_ImportFrom(self,node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_Assign(self,node): + print('Node type: Assign and fields: ', node._fields) + # print('Node type: Assign and targets value: ', node.targets, node.value) + + ast.NodeVisitor.generic_visit(self, node) + + def visit_BinOp(self, node): + print('Node type: BinOp and fields: ', node._fields) + print('node op: ', type(node.op).__name__) + ast.NodeVisitor.generic_visit(self, node) + + def visit_Expr(self, node): + print('Node type: Expr and fields: ', node._fields) + ast.NodeVisitor.generic_visit(self, node) + + def visit_Num(self,node): + print('Node type: Num and fields: ', node._fields) + print('Node type: Num: ', node.n) + + def visit_Name(self,node): + print('Node type: Name and fields: ', node._fields) + print('Node type: Name and fields: ', type(node.ctx).__name__, node.id) + + ast.NodeVisitor.generic_visit(self, node) + + def visit_Str(self, node): + print('Node type: Str and fields: ', node._fields) + +class CodeVisitor(ast.NodeVisitor): + def visit_BinOp(self, node): + if isinstance(node.op, ast.Add): + node.op = ast.Sub() + self.generic_visit(node) + + def visit_Assign(self, node): + print('Assign %s' % node.value) + self.generic_visit(node) + + def visit_Name(self, node): + print("Name:", node.id) + self.generic_visit(node) + + + def visit_FunctionDef(self, node): + print('Function Name:%s'% node.name.op) + self.generic_visit(node) + func_log_stmt = ast.Print( + dest = None, + values = [ast.Str(s = 'calling func: %s' % node.name, lineno = 0, col_offset = 0)], + nl = True, + lineno = 0, + col_offset = 0, + ) + node.body.insert(0, func_log_stmt) + +visitor = AnalysisNodeVisitor() + +code = \ +''' + +a=max(leaky_alpha * x, x +1) + +''' + +visitor.visit(ast.parse(code)) diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_device.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_device.py new file mode 100644 index 0000000000000000000000000000000000000000..6cd01ef16bb078e627cfaf910911b2c16baab76e --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_device.py @@ -0,0 +1,469 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from typing import * + +import helper +import gen_ir + +import gen_kernel as gen_ker + + +class gen_device: + def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, cutlass_deps_root, project_root, output_dir = "../"): + self.fuse_gemm_info = fuse_gemm_info + self.raw_gemm_info = fuse_gemm_info + self.b2b_num = len(fuse_gemm_info) + self.user_header_file = user_header_file + self.args = {} + # device arg struct memebr + self.arg_member = [] + self.gen_class_name = gen_class_name + self.gen_kernel_name = gen_class_name + "Kernel" + self.template_args = [] + self.__tempalate_arg_list = {'Stages': int, 'SplitKSerial': bool, 'IsBetaZero': bool, 'AlignmentA': int, 'AlignmentB': int} + + self.file_name = output_dir + "/device/" +gen_class_name +".h" + self.sample_dir = output_dir + + + self.cutlass_deps_root = cutlass_deps_root + self.project_root = project_root + self.this_file_root = output_dir + "/device/" + + self.first_use_1stage = False + + ## gen kernel + self.gen_kernel = gen_ker.gen_kernel(self.template_args, self.gen_class_name, self.b2b_num, output_dir, cutlass_deps_root, project_root) + + + def __check_arg_type(self, temp_arg): + if temp_arg in self.__tempalate_arg_list.keys(): + return self.__tempalate_arg_list[temp_arg] + + find_sub = False + for candidate_arg in self.__tempalate_arg_list.keys(): + if (temp_arg.find(candidate_arg) != -1): + return self.__tempalate_arg_list[candidate_arg] + + return 'typename' + + # def gen_B2b2bGemm_class(): + def set_arch(self, sm_cap, mma_tp): + if sm_cap == 75 or sm_cap == 80 or sm_cap == 86: + self.arch = "cutlass::arch::Sm" + str(sm_cap) + + if mma_tp is 'hmma1688': + self.mma_shape = [16, 8, 8] + self.mma_tp = 'hmma' + elif mma_tp is 'imma8816': + self.mma_tp = 'imma' + self.mma_shape = [8, 8, 16] + else: + return 0 + + def gen_include_header(self): + code = '''\ +/* Auto Generated code - Do not edit.*/ + +#pragma once + +#include \"{cutlass_root}cutlass/cutlass.h\" +#include \"{cutlass_root}cutlass/numeric_types.h\" +#include \"{cutlass_root}cutlass/arch/arch.h\" +#include \"{cutlass_root}cutlass/device_kernel.h\" + +#include \"{cutlass_root}cutlass/gemm/threadblock/threadblock_swizzle.h\" + +#include \"{cutlass_root}cutlass/gemm/device/default_gemm_configuration.h\" +#include \"{cutlass_root}cutlass/epilogue/thread/linear_combination_relu.h\" +#include \"{cutlass_root}cutlass/epilogue/thread/linear_combination.h\" + +#include \"{project_root}../kernel/b2b_gemm.h\" +#include \"{project_root}../kernel/default_b2b_gemm.h\" +'''.format(cutlass_root=self.cutlass_deps_root, project_root=self.project_root, this_file_root=self.this_file_root) + include_user_header = "" + for header in self.user_header_file: + include_user_header += "#include \"" + header + "\"\n" + return code + include_user_header + + def gen_code(self, sm_cap, mma_tp, ifprint = True): + self.set_arch(sm_cap, mma_tp) + + self.update_b2b_args() + print(self.fuse_gemm_info) + self.update_b2b_class_template_args() + + func_code = self.gen_all_func() + member_var_code = "private:\n typename B2bGemmKernel::Params params_;\n" + + gen_code = gen_ir.gen_template_class(self.gen_class_name, self.template_args, func_code + member_var_code) + code = self.gen_include_header() + gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("device", gen_code))) + + if ifprint: + print(code) + + print("[INFO]: Gen device code output Dir: is ", self.file_name) + with open(self.file_name, 'w+') as f: + f.write(code) + + + gen_kernel = self.gen_kernel.gen_code(self.first_use_1stage) + print(gen_kernel) + + def update_b2b_class_template_args(self): + for arg in self.args.keys(): + self.template_args.append([self.__check_arg_type(arg), arg, self.args[arg]]) + + def update_b2b_args(self): + + self.args['ElementA'] = helper.type_2_cutlass_type(self.fuse_gemm_info[0]['A_tp']) + self.args['LayoutA'] = helper.type_2_cutlass_type(self.fuse_gemm_info[0]['A_format']) + + cnt = 0 + + warp_M_tile = 32 + + # Determine maximum N_tile + Max_Ntile = 0 + for layer in self.fuse_gemm_info: + n_tile = layer['mnk'][1] + if n_tile > Max_Ntile: + Max_Ntile = n_tile + if Max_Ntile >= 256: + warp_M_tile = 16 + + stages_temp = [] + + for layer in self.fuse_gemm_info: + cnt_str = str(cnt) + B_tp_str= 'ElementB' + cnt_str + B_format_str = 'LayoutB' + cnt_str + C_tp_str= 'ElementC' + cnt_str + C_format_str = 'LayoutC' + cnt_str + Acc_str = 'ElementAccumulator' + cnt_str + + self.args[B_tp_str] = helper.type_2_cutlass_type(layer['B_tp']) + self.args[B_format_str] = helper.type_2_cutlass_type(layer['B_format']) + self.args[C_tp_str] = helper.type_2_cutlass_type(layer['C_tp']) + self.args[C_format_str] = helper.type_2_cutlass_type(layer['C_format']) + self.args[Acc_str] = helper.type_2_cutlass_type(layer['Acc_tp']) + + + mnk = layer['mnk'][:] + + tile_mnk = mnk[:] + + tile_mnk[2] = 32 # force the ktile is 32 + + #N tile gen + if mnk[1] > 1024: + assert(0) + elif mnk[1] > 512: + tile_mnk[1] = 1024 + elif mnk[1] > 256: + tile_mnk[1] = 512 + elif mnk[1] > 128: + tile_mnk[1] = 256 + elif mnk[1] > 64: + tile_mnk[1] = 128 + elif mnk[1] > 32: + tile_mnk[1] = 64 + else : + tile_mnk[1] = 32 + + if tile_mnk[1] == 512: + stages_temp.append(1) + else: + stages_temp.append(2) + + tile_mnk[0] = 4 * warp_M_tile + + + + epilogue_setted_type = helper.get_epilogue_tp(layer) + cutlass_epilogue_name = "LinearCombinationRelu" + if epilogue_setted_type.lower() == 'leakyrelu': + cutlass_epilogue_name = "LinearCombinationLeakyRelu" + elif epilogue_setted_type.lower() == 'identity': + cutlass_epilogue_name = "LinearCombination" + + epilogue_str = 'EpilogueOutputOp' + cnt_str + if cnt != len(self.fuse_gemm_info) - 1: + n = layer['mnk'][1] + Fragments = tile_mnk[1] // 8 * 2 + self.args[epilogue_str] = "cutlass::epilogue::thread::" + cutlass_epilogue_name + "" + else: + n = layer['mnk'][1] + n_mod_8 = n % 4 + N_align_elements = 1 + if n_mod_8 == 0: + N_align_elements = 8 + elif n_mod_8 == 4: + N_align_elements = 4 + elif n_mod_8 == 2 or n_mod_8 == 6: + N_align_elements = 2 + + self.args[epilogue_str] = "cutlass::epilogue::thread::" + cutlass_epilogue_name+ "" + + + + ThreadBlockShape_str = 'ThreadblockShape' + cnt_str + + self.args[ThreadBlockShape_str] = helper.cvt_2_cutlass_shape(tile_mnk) + + WarpShape_str = 'WarpShape' + cnt_str + tile_mnk[0] = warp_M_tile + self.args[WarpShape_str] = helper.cvt_2_cutlass_shape(tile_mnk) + cnt += 1 + + + self.args['ElementD'] = helper.type_2_cutlass_type(self.fuse_gemm_info[self.b2b_num - 1]['C_tp']) + self.args['LayoutD'] = helper.type_2_cutlass_type(self.fuse_gemm_info[self.b2b_num - 1]['C_format']) + + self.args['InstructionShape'] = helper.cvt_2_cutlass_shape(self.mma_shape) + self.args['OperatorClass'] = 'arch::OpClassTensorOp' + self.args['ArchTag'] = self.arch + self.args['ThreadblockSwizzle'] = 'threadblock::GemmBatchedIdentityThreadblockSwizzle' + + + for i in range(self.b2b_num): + self.args[helper.var_idx('Stages', i)] = "2" + + self.args['AlignmentA'] = str(8) + self.args['AlignmentB'] = str(8) + self.args['SplitKSerial'] = 'false' + self.args['Operator'] = 'typename DefaultGemmConfiguration::Operator' + self.args['IsBetaZero'] = 'false' + + + def gen_using_kernel(self): + code = "using B2bGemmKernel = typename kernel::DefaultB2bGemm<\n" + code += " " + "ElementA,\n" + code += " " + "LayoutA,\n" + + for i in range(self.b2b_num): + code += " " + helper.var_idx("ElementB", i) + ",\n" + code += " " + helper.var_idx("LayoutB", i) + ",\n" + code += " " + helper.var_idx("ElementC", i) + ",\n" + code += " " + helper.var_idx("LayoutC", i) + ",\n" + code += " " + helper.var_idx("ElementAccumulator", i) + ",\n" + code += " " + helper.var_idx("EpilogueOutputOp", i) + ",\n" + code += " " + helper.var_idx("ThreadblockShape", i) + ",\n" + code += " " + helper.var_idx("WarpShape", i) + ",\n" + + code += " " + "ElementD,\n" + code += " " + "LayoutD,\n" + code += " " + "InstructionShape,\n" + code += " " + "OperatorClass,\n" + code += " " + "ArchTag,\n" + code += " " + "ThreadblockSwizzle,\n" + + for i in range(self.b2b_num): + code += " " + helper.var_idx("Stages", i) + ",\n" + + + code += " " + "AlignmentA,\n" + code += " " + "AlignmentB,\n" + code += " " + "SplitKSerial,\n" + code += " " + "Operator,\n" + code += " " + "IsBetaZero_\n" + + code += ">::B2bGemmKernel;\n\n" + + return code + + def gen_args(self): + + def gen_arg_member(b2b_num): + data_members = [] + + for i in range(b2b_num): + member_type = "GemmCoord" + member_name = "problem_size_" + str(i) + data_members.append((member_type, member_name)) + + member_type = "TensorRef" + member_name = "ref_A0" + data_members.append((member_type, member_name)) + + for i in range(b2b_num): + member_type = "TensorRef" + member_name = "ref_B" + str(i) + data_members.append((member_type, member_name)) + member_type = "TensorRef" + member_name = "ref_C" + str(i) + data_members.append((member_type, member_name)) + + member_type = "TensorRef" + member_name = helper.var_idx("ref_D", b2b_num - 1) + data_members.append((member_type, member_name)) + + for i in range(b2b_num): + member_type = "typename EpilogueOutputOp" + str(i) + "::Params" + member_name = "epilogue" + str(i) + data_members.append((member_type, member_name)) + + data_members.append(('int', 'batch_count')) + + return data_members + + def gen_arg_struct_default_ctor(struct_name, data_members, inital_param_num, inital_value): + constructs_code = gen_ir.indentation + "CUTLASS_HOST_DEVICE\n" + \ + gen_ir.indentation + struct_name + " (): " + for i in range(inital_param_num): + final_param = ',' + if i == inital_param_num - 1: + final_param = '{ }' + constructs_code += data_members[i][1] + inital_value + final_param + + constructs_code += "\n" + return constructs_code + + def gen_arg_struct_ctor(struct_name, data_members): + constructs_code = gen_ir.indentation + "CUTLASS_HOST_DEVICE\n" + \ + gen_ir.indentation + struct_name + " (\n" + cnt = 0 + param_num = len(data_members) + for param in data_members: + final = ',\n' + if cnt == param_num - 1: + final = '\n):\n' + constructs_code += gen_ir.indentation + param[0] + " " + param[1] + "_" + final + cnt += 1 + + cnt = 0 + for param in data_members: + final = '),\n' + if cnt == param_num - 1: + final = ") { }\n" + constructs_code += gen_ir.indentation + param[1] + "(" + param[1] + "_" + final + cnt += 1 + + constructs_code += "\n" + return constructs_code + + # (variable type, variable name) + struct_member = gen_arg_member(self.b2b_num) + self.arg_member = struct_member + + codeBody = "" + for each_member in struct_member: + codeBody += gen_ir.indentation + each_member[0] + " " + each_member[1] + ";\n" + + codeBody += gen_arg_struct_default_ctor("Arguments", struct_member, self.b2b_num, "(0,0,0)") + "\n" + codeBody += gen_arg_struct_ctor("Arguments", struct_member) + "\n" + struct_code = gen_ir.gen_struct("Arguments", codeBody) + return struct_code + + def gen_func_constructs(self): + code = self.gen_class_name +"() {}" + return code + + def gen_func_initialize(self): + code = "Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {\n" + \ + "// Determine grid shape\n" + \ + "ThreadblockSwizzle threadblock_swizzle;\n" + \ + "cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(\n" + \ + " args.problem_size_0, \n" + \ + " { ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK },\n" + \ + " args.batch_count);\n" + \ + "// Initialize the Params structure\n" + \ + "params_ = typename B2bGemmKernel::Params{\n" + for i in range(self.b2b_num): + code += helper.var_idx(" args.problem_size_", i) + ",\n" + code += " grid_shape,\n" + \ + " args.ref_A0.non_const_ref(),\n" + for i in range(self.b2b_num): + code += helper.var_idx(" args.ref_B", i) + ".non_const_ref(),\n" + code += helper.var_idx(" args.ref_C", i) + ".non_const_ref(),\n" + + code += helper.var_idx(" args.ref_D", self.b2b_num - 1) + ",\n" + for i in range(self.b2b_num): + code += helper.var_idx(" args.epilogue", i) + ",\n" + + code += " args.batch_count\n" + code += "};\n" + \ + "return Status::kSuccess;\n" + \ + "}\n" + return code + + def gen_func_run(self): + code = "Status run(cudaStream_t stream = nullptr) {\n" + \ + "\n" + \ + " ThreadblockSwizzle threadblock_swizzle;\n" + \ + "\n" + \ + " dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);\n" + \ + " dim3 block(B2bGemmKernel::kThreadCount, 1, 1);\n" + \ + "\n" + \ + " cudaError_t result;\n" + \ + "\n" + \ + " int smem_size = int(sizeof(typename B2bGemmKernel::SharedStorage));\n" + \ + " if (smem_size >= (48 << 10)) {\n" + \ + " result = cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);\n" + \ + "\n" + \ + " if (result != cudaSuccess) {\n" + \ + " return Status::kErrorInternal;\n" + \ + " }\n" + \ + " }\n" + \ + " cutlass::Kernel<<>>(params_);\n" + \ + " result = cudaGetLastError();\n" + \ + " return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;\n" + \ + " }\n" + + return code + def gen_func_operator(self): + opeartor_with_arg_code = "Status operator()(\n" + \ + " Arguments const &args,\n" + \ + " void *workspace = nullptr,\n" + \ + " cudaStream_t stream = nullptr) {\n" + \ + " Status status = initialize(args, workspace);\n" + \ + " \n" + \ + " if (status == Status::kSuccess) {\n" + \ + " status = run(stream);\n" + \ + " }\n" + \ + " return status;\n" + \ + "}\n" + operator_code = "Status operator()(\n" + \ + " cudaStream_t stream = nullptr) {\n" + \ + " Status status = run(stream);\n" + \ + " return status;\n" + \ + "}\n" + return opeartor_with_arg_code + "\n" + operator_code + + def gen_all_func(self): + return self.gen_using_kernel() + "\n" + \ + self.gen_args() + "\n" + \ + self.gen_func_constructs() + "\n" + \ + self.gen_func_initialize() + "\n" + \ + self.gen_func_run() + "\n" + \ + self.gen_func_operator() diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_ir.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_ir.py new file mode 100644 index 0000000000000000000000000000000000000000..28df5665cd25e1c490b10c51cb556e5397241ca2 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_ir.py @@ -0,0 +1,249 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import helper + + +indentation = " " + + +def append_word(word): + code = "" + code += word + code += " " + return code + + +def gen_namespace(namespace, codeBody): + code_gen = "namespace " + namespace + " {\n" + code_gen += codeBody + code_gen += "} // namespace " + namespace + "\n" + return code_gen + + +def gen_expression(type, lval, rval = None): + code_gen = "" + code_gen += append_word(type) + code_gen += append_word(lval) + if rval is not None: + code_gen += append_word("=") + code_gen += append_word(rval) + return code_gen + + +def gen_class(name, codeBody, inheritance_code = None): + code_gen = "" + if inheritance_code is None: + code_gen = "class " + name + "{\n" + else: + code_gen = "class " + name + " : "+ inheritance_code + "{\n" + code_gen += codeBody + code_gen += "}; // class " + name + "\n" + return code_gen + + +def gen_struct(name, codeBody, specialized = None): + specialized_code = "" + if specialized is not None: + specialized_code = "<" + specialized + ">" + code_gen = "struct " + name + specialized_code + "{\n" + code_gen += codeBody + code_gen += "}; // struct " + name + "\n" + return code_gen + + +def gen_template_arg(arg_type, arg_name, default_val = None): + rval = None + if default_val is not None: + rval = str(default_val) + + arg_typename = "" + if arg_type is int: + arg_typename = "int" + elif arg_type is bool: + arg_typename = "bool" + else: + arg_typename = "typename" + + internal_arg_name = arg_name + "_" + + code_gen = indentation + code_gen += gen_expression(arg_typename, internal_arg_name, rval) + + return code_gen + + +def gen_template_args(args, set_default = True): + arg_len = len(args) + cnt = 1 + code_gen = "" + for arg_tuple in args: + arg_type = arg_tuple[0] + arg_name = arg_tuple[1] + arg_default_val = None + if len(arg_tuple) == 3 and set_default: + arg_default_val = arg_tuple[2] + + code_gen += gen_template_arg(arg_type, arg_name, arg_default_val) + if cnt != arg_len: + code_gen += ",\n" + cnt += 1 + + return code_gen + + +def gen_template_head(args, set_default = True): + code_gen = "template <\n" + code_gen += gen_template_args(args, set_default) + code_gen += ">\n" + return code_gen + + +def export_template_args(args): + code_gen = "public:\n" + for arg_tuple in args: + code_gen += indentation + arg_type = arg_tuple[0] + arg_name = arg_tuple[1] + internal_arg_name = arg_name + "_" + + typename = "" + if arg_type is int: + typename = "static int const" + elif arg_type is bool: + typename = "static bool const" + else: + typename = "using" + + code_gen += gen_expression(typename, arg_name, internal_arg_name) + code_gen += ";\n" + return code_gen + + +def gen_template_class(class_name, args, codeBody, set_default = True, inheritance_code = None): + code_gen = "" + + code_gen += gen_template_head(args, set_default) + code_gen += gen_class(class_name, export_template_args(args) + codeBody, inheritance_code) + + return code_gen + + +def gen_template_struct(struct_name, args, codeBody, speicalized = None, set_default = True, export_args = True): + code_gen = "" + code_gen += gen_template_head(args, set_default) + code = export_template_args(args) + codeBody + if export_args is False: + code = codeBody + code_gen += gen_struct(struct_name, code , speicalized) + + return code_gen + + +def gen_declare_template_struct(name, *params): + code = name + "<" + cnt = 0 + param_num = len(params) + for param in params: + final = ", " + if cnt == param_num - 1: + final = "" + code += param + final + cnt += 1 + code += ">;\n" + return code + + +def filtered_param(params, name_and_value_pair, keep_ = False): + rtn_template_args = [] + speicalized_template_args = [] + + for param in params: + param_name = "" + if len(param) >= 1: + param_name = param[1] + else: + param_name = param[0] + + hit_flag = False + set_value = "" + for n_v_pair in name_and_value_pair: + + filter_name = n_v_pair[0] + set_value = n_v_pair[1] + + if param_name == (filter_name + "_") or param_name == filter_name : + hit_flag = True + break + + + if hit_flag is False: + rtn_template_args.append(param) + + if hit_flag is True: + speicalized_template_args.append(set_value) + else: + if keep_ is True: + speicalized_template_args.append(param_name + "_") + else: + speicalized_template_args.append(param_name) + + + specialized_template_arg_str = helper.list_2_string(speicalized_template_args) + + return rtn_template_args, specialized_template_arg_str + + +def gen_func(func_name, arg_lists, code_body, only_declare = False, with_cudaStream = True): + code = "void " + func_name + "(\n" + for arg in arg_lists: + arg_tp = arg[0] + arg_nm = arg[1] + code += " " + arg_tp + " " + arg_nm + ",\n" + code += "cudaStream_t stream)" + if only_declare : + return code + code += "{\n" + + code += code_body + "\n" + code += "}\n" + return code + + +def indent_level(code, level = 0): + rtn_code = "" + for i in range(level): + rtn_code += " " + + rtn_code += code + + return rtn_code diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_kernel.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..af901ac8523a93611763cdc25c7fd06355936afe --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_kernel.py @@ -0,0 +1,476 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import gen_ir +import helper +import gen_threadblock as gen_tb + + +class gen_default_Gemm: + def __init__(self, template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root): + self.gen_class_name = "B2bGemm" + self.template_param = template_param + self.b2b_num = b2b_num + + self.cutlass_deps_root = cutlass_deps_root + self.project_root = project_root + + def gen_B2bMma(self, specialized_template_args): + code = "using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<\n" + code += specialized_template_args + code += ">::ThreadblockB2bMma;\n" + + # print(code) + return code + + def gen_epilogue(self): + epilogue_code = "" + epilogue_code += helper.var_idx("static const int kPartitionsK", self.b2b_num - 1) + helper.var_idx(" = ThreadblockShape", self.b2b_num - 1) + helper.var_idx("::kK / WarpShape", self.b2b_num - 1) + "::kK;\n" + + epilogue_code += "using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<\n" + epilogue_code += " " + helper.var_idx("ThreadblockShape", self.b2b_num - 1) + ",\n" + epilogue_code += " " + helper.var_idx("typename B2bMma::Operator", self.b2b_num - 1) + ",\n" + epilogue_code += " " + helper.var_idx("kPartitionsK", self.b2b_num - 1) + ",\n" + epilogue_code += " " + helper.var_idx("EpilogueOutputOp", self.b2b_num - 1) + ",\n" + epilogue_code += " " + helper.var_idx("EpilogueOutputOp", self.b2b_num - 1) + "::kCount\n" + epilogue_code += ">::Epilogue;\n" + + epilogue_code += "using B2bGemmKernel = kernel::B2bGemm;\n\n" + + return epilogue_code + + + def gen_include_header(self): + code = ''' +/* Auto Generated code - Do not edit.*/ + +#pragma once +#include \"{cutlass_dir}cutlass/cutlass.h\" + +#include \"{cutlass_dir}cutlass/layout/matrix.h\" +#include \"{cutlass_dir}cutlass/numeric_types.h\" + +#include \"{cutlass_dir}cutlass/epilogue/threadblock/epilogue.h\" +#include \"{cutlass_dir}cutlass/epilogue/thread/linear_combination.h\" + +#include \"{cutlass_dir}cutlass/gemm/gemm.h\" +#include \"{cutlass_dir}cutlass/gemm/kernel/gemm_pipelined.h\" +#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm75.h\" +#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm70.h\" +#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm80.h\" +#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_simt.h\" +#include \"{cutlass_dir}cutlass/gemm/threadblock/threadblock_swizzle.h\" +#include \"{cutlass_dir}cutlass/epilogue/threadblock/default_epilogue_tensor_op.h\" +#include \"{cutlass_dir}cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h\" +#include \"{cutlass_dir}cutlass/epilogue/threadblock/default_epilogue_simt.h\" + +#include \"{cutlass_dir}cutlass/transform/threadblock/predicated_tile_iterator.h\" + +#include \"../kernel/b2b_gemm.h\" +#include \"../threadblock/default_b2b_mma.h\" +'''.format(cutlass_dir=self.cutlass_deps_root) + return code + + def gen_code(self): + gen_using = '' + # Generate default template struct + gen_code = gen_ir.gen_template_struct("Default" + self.gen_class_name, self.template_param,"", speicalized = None, set_default=False) + + + filter_list = [] + filter_list.append(('Stages', 2)) + filter_list.append(("OperatorClass", "arch::OpClassTensorOp")) + filter_list.append(("ArchTag", "arch::Sm75")) + + for i in range(self.b2b_num): + filter_list.append((helper.var_idx("LayoutC", i), "layout::RowMajor")) + + + rtn_template_args, speicalized_template_args = gen_ir.filtered_param(self.template_param, filter_list, keep_= True) + + + B2bMma_code = self.gen_B2bMma(speicalized_template_args) + epilogue_and_rest_code = self.gen_epilogue() + + gen_special_code = gen_ir.gen_template_struct("Default" + self.gen_class_name, rtn_template_args, B2bMma_code + epilogue_and_rest_code, speicalized = speicalized_template_args, set_default=False) + + code = gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("kernel", gen_code + gen_special_code))) + + return self.gen_include_header() + code + + +class gen_Kernel: + def __init__(self, template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root): + self.gen_class_name = "B2bGemm" + self.template_param = template_param + self.b2bnum = b2b_num + + self.cutlass_deps_root = cutlass_deps_root + self.project_root = project_root + + def gen_include_header(self): + code = ''' +#pragma once + +#include \"{cutlass_dir}cutlass/cutlass.h\" +#include \"{cutlass_dir}cutlass/gemm/gemm.h\" +#include \"{cutlass_dir}cutlass/matrix_coord.h\"\n'''.format(cutlass_dir=self.cutlass_deps_root) + return code + + def gen_Params(self): + gen_param = "" + for i in range(self.b2bnum): + gen_param += " " + helper.var_idx("cutlass::gemm::GemmCoord problem_size_", i) + ";\n" + gen_param += " " + "cutlass::gemm::GemmCoord grid_tiled_shape;\n" + gen_param += " " + "typename B2bMma::IteratorA0::Params params_A0;\n" + gen_param += " " + "typename B2bMma::IteratorA0::TensorRef ref_A0;\n" + + for i in range(self.b2bnum): + gen_param += " " + helper.var_idx("typename B2bMma::IteratorB", i) + helper.var_idx("::Params params_B", i) + ";\n" + gen_param += " " + helper.var_idx("typename B2bMma::IteratorB", i) + helper.var_idx("::TensorRef ref_B", i) + ";\n" + if i == self.b2bnum - 1: + gen_param += " " + helper.var_idx("typename Epilogue::OutputTileIterator::Params params_C", i) + ";\n" + gen_param += " " + helper.var_idx("typename Epilogue::OutputTileIterator::TensorRef ref_C", i) + ";\n" + + else: + gen_param += " " + helper.var_idx("typename FusedAddBiasEpilogue", i) + helper.var_idx("::OutputTileIterator::Params params_C", i) + ";\n" + gen_param += " " + helper.var_idx("typename FusedAddBiasEpilogue", i) + helper.var_idx("::OutputTileIterator::TensorRef ref_C", i) + ";\n" + + + + + gen_param += " " + helper.var_idx("typename Epilogue::OutputTileIterator::Params params_D", self.b2bnum - 1) + ";\n" + gen_param += " " + helper.var_idx("typename Epilogue::OutputTileIterator::TensorRef ref_D", self.b2bnum - 1) + ";\n" + + for i in range(self.b2bnum): + gen_param += " " + helper.var_idx("typename OutputOp", i) + helper.var_idx("::Params output_op_", i) + ";\n" + + gen_param += " " + 'int batch_count' + ";\n" + gen_param += " " + 'int gemm_k_iterations_0' + ";\n" + + + return gen_param + + def gen_Memberfunc(self): + code_default = "\nCUTLASS_HOST_DEVICE\n" + code_default += "Params()" + + code_default += " { } \n\n" + + code_construct = "\nCUTLASS_HOST_DEVICE\n" + code_construct += "Params(\n" + + for i in range(self.b2bnum): + code_construct += " " + helper.var_idx("cutlass::gemm::GemmCoord const & problem_size_", i) + ",\n" + + code_construct += " " + "cutlass::gemm::GemmCoord const & grid_tiled_shape,\n" + + code_construct += " " + "typename B2bMma::IteratorA0::TensorRef ref_A0,\n" + + for i in range(self.b2bnum): + code_construct += " " + helper.var_idx("typename B2bMma::IteratorB", i) + helper.var_idx("::TensorRef ref_B", i) + ",\n" + if i == self.b2bnum - 1: + code_construct += " " + helper.var_idx("typename Epilogue::OutputTileIterator::TensorRef ref_C", i) + ",\n" + else: + code_construct += " " + helper.var_idx("typename FusedAddBiasEpilogue", i) + helper.var_idx("::OutputTileIterator::TensorRef ref_C", i) + ",\n" + + code_construct += " " + helper.var_idx("typename Epilogue::OutputTileIterator::TensorRef ref_D", self.b2bnum - 1) + ",\n" + for i in range(self.b2bnum): + code_construct += " " + helper.var_idx("typename OutputOp", i) + helper.var_idx("::Params output_op_", i) + helper.var_idx(" = typename OutputOp", i) + "::Params(),\n" + + code_construct += " " + "int batch_count = 1\n" + + code_construct += "):\n" + + for i in range(self.b2bnum): + code_construct += " " + helper.var_idx("problem_size_", i) + helper.var_idx("(problem_size_", i) + "),\n" + + code_construct += " " + "grid_tiled_shape(grid_tiled_shape),\n" + code_construct += " " + "params_A0(ref_A0.layout()),\n" + code_construct += " " + "ref_A0(ref_A0),\n" + + for i in range(self.b2bnum): + code_construct += " " + helper.var_idx("params_B", i) + helper.var_idx("(ref_B", i) + ".layout()),\n" + code_construct += " " + helper.var_idx("ref_B", i) + helper.var_idx("(ref_B", i) + "),\n" + code_construct += " " + helper.var_idx("params_C", i) + helper.var_idx("(ref_C", i) + ".layout()),\n" + code_construct += " " + helper.var_idx("ref_C", i) + helper.var_idx("(ref_C", i) + "),\n" + + code_construct += " " + helper.var_idx("params_D", self.b2bnum - 1) + helper.var_idx("(ref_D", self.b2bnum - 1) + ".layout()),\n" + code_construct += " " + helper.var_idx("ref_D", self.b2bnum - 1) + helper.var_idx("(ref_D", self.b2bnum - 1) + "),\n" + + for i in range(self.b2bnum): + code_construct += " " + helper.var_idx("output_op_", i) + helper.var_idx("(output_op_", i) + "), \n" + + code_construct += " " + "batch_count(batch_count) {\n" + code_construct += " " + helper.var_idx("gemm_k_iterations_", 0) + helper.var_idx(" = (problem_size_", 0) + helper.var_idx(".k() + B2bMma::Shape", 0) + helper.var_idx("::kK - 1) / B2bMma::Shape", 0) + "::kK;\n" + + code_construct += "}\n" + + return code_default + code_construct + + def gen_using(self): + code_using = "" + + for i in range(self.b2bnum - 1): + code_using += " " + helper.var_idx("using OutputOp", i) + helper.var_idx(" = typename B2bMma::OutputOp", i) + ";\n" + + code_using += " " + helper.var_idx("using OutputOp", self.b2bnum - 1) + " = typename Epilogue::OutputOp;\n" + + for i in range(self.b2bnum - 1): + code_using += " " + helper.var_idx("using FusedAddBiasEpilogue", i) + helper.var_idx(" = typename B2bMma::FusedAddBiasEpilogue", i) +";\n" + + + code_using += " " + "using WarpCount0 = typename B2bMma::WarpCount0;\n" + code_using += " " + "static int const kThreadCount = 32 * WarpCount0::kCount;\n" + + code_using += gen_ir.gen_struct("Params", self.gen_Params() + self.gen_Memberfunc()) + + code_using += "union SharedStorage {\n" + code_using += " " + "typename B2bMma::B2bMmaSharedStorage main_loop;\n" + code_using += " " + "typename Epilogue::SharedStorage epilogue;\n" + code_using += "};\n" + + return code_using + + def gen_can_implement(self): + gen_code = "" + return gen_code + + def gen_operator_and_constr(self): + ctr_code = "CUTLASS_HOST_DEVICE\n" + ctr_code += self.gen_class_name + "() { } \n\n" + operator_code = "CUTLASS_DEVICE\n" + operator_code += "void operator()(Params const ¶ms, SharedStorage &shared_storage) {\n" + operator_code += " " + "ThreadblockSwizzle threadblock_swizzle;\n" + operator_code += " " + "cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);\n" + operator_code += " " + "int batch_idx = threadblock_tile_offset.k();\n" + operator_code += " " + "if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||\n" + operator_code += " " + "params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {\n" + operator_code += " " + " " + "return;\n" + operator_code += " " + "}\n" + + operator_code += " " + "cutlass::MatrixCoord tb_offset_A0{\n" + operator_code += " " + " " + "threadblock_tile_offset.m() * B2bMma::Shape0::kM,\n" + operator_code += " " + " " + "0\n" + operator_code += " " + "};\n" + + for i in range(self.b2bnum): + operator_code += " " + helper.var_idx("cutlass::MatrixCoord tb_offset_B", i) + "{\n" + operator_code += " " + " " + "0,\n" + operator_code += " " + " " + helper.var_idx("threadblock_tile_offset.n() * B2bMma::Shape", i) + "::kN\n" + operator_code += " " + "};\n" + + operator_code += " " + "int thread_idx = threadIdx.x;\n\n" + + operator_code += " " + "MatrixCoord threadblock_offset(\n" + operator_code += " " + " " + helper.var_idx("threadblock_tile_offset.m() * B2bMma::Shape", self.b2bnum - 1) + "::kM,\n" + operator_code += " " + " " + helper.var_idx("threadblock_tile_offset.n() * B2bMma::Shape", self.b2bnum - 1) + "::kN\n" + operator_code += " " + ");\n" + + operator_code += " " + "typename B2bMma::IteratorA0 iterator_A0(\n" + operator_code += " " + " " + "params.params_A0,\n" + operator_code += " " + " " + "params.ref_A0.data(),\n" + operator_code += " " + " " + "params.problem_size_0.mk(),\n" + operator_code += " " + " " + "thread_idx,\n" + operator_code += " " + " " + "tb_offset_A0);\n" + + operator_code += " " + "iterator_A0.add_pointer_offset(batch_idx * params.problem_size_0.m() * params.problem_size_0.k());\n\n" + + + for i in range (self.b2bnum): + operator_code += " " + helper.var_idx("typename B2bMma::IteratorB", i ) + helper.var_idx(" iterator_B", i) + "(\n" + operator_code += " " + " " + helper.var_idx("params.params_B", i) + ",\n" + operator_code += " " + " " + helper.var_idx("params.ref_B", i) + ".data(),\n" + operator_code += " " + " " + helper.var_idx("params.problem_size_", i) + ".kn(),\n" + operator_code += " " + " " + "thread_idx,\n" + operator_code += " " + " " + helper.var_idx("tb_offset_B", i) + ");\n" + operator_code += " " + helper.var_idx("iterator_B", i) + helper.var_idx(".add_pointer_offset(batch_idx * params.problem_size_", i) + helper.var_idx(".n() * params.problem_size_", i) + ".k());\n\n" + + + for i in range (self.b2bnum - 1): + operator_code += " " + helper.var_idx("typename FusedAddBiasEpilogue", i ) + helper.var_idx("::OutputTileIterator iterator_C", i) + "(\n" + operator_code += " " + " " + helper.var_idx("params.params_C", i) + ",\n" + operator_code += " " + " " + helper.var_idx("params.ref_C", i) + ".data(),\n" + operator_code += " " + " " + helper.var_idx("params.problem_size_" , i) + ".mn(),\n" + operator_code += " " + " " + "thread_idx,\n" + operator_code += " " + " " + "threadblock_offset" + ");\n" + operator_code += " " + helper.var_idx("int ref_C", i) + helper.var_idx("_stride = params.ref_C", i) + ".stride()[0];\n" + operator_code += " " + helper.var_idx("iterator_C", i) + helper.var_idx(".add_pointer_offset(batch_idx * params.problem_size_", i) + helper.var_idx(".n() * (ref_C", i) + helper.var_idx("_stride == 0 ? 1 : params.problem_size_", i) + ".m()));\n\n" + + + for i in range (self.b2bnum - 1): + operator_code += " " + helper.var_idx("FusedAddBiasEpilogue", i ) + helper.var_idx(" epilogue_", i ) + ";\n" + + + operator_code += " " + "int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);\n" + operator_code += " " + "int lane_idx = threadIdx.x % 32;\n" + + for i in range (self.b2bnum - 1): + operator_code += " " + helper.var_idx("OutputOp", i) + helper.var_idx(" output_op_", i) + helper.var_idx("(params.output_op_", i) + ");\n" + + operator_code += " " + "B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);\n" + + operator_code += " " + "typename B2bMma::FragmentC0 src_accum;\n" + operator_code += " " + helper.var_idx("typename B2bMma::FragmentC", self.b2bnum - 1)+ " accumulators;\n" + + operator_code += " " + "src_accum.clear();\n" + operator_code += " " + "accumulators.clear();\n" + operator_code += " " + "b2bMma(params.gemm_k_iterations_0, accumulators, iterator_A0, " + + for i in range(self.b2bnum): + operator_code += helper.var_idx("iterator_B", i) + ", " + + operator_code += "src_accum" + if self.b2bnum != 1: + operator_code += ", " + for i in range(self.b2bnum - 1): + operator_code += helper.var_idx("output_op_", i) + ", " + + for i in range(self.b2bnum - 1): + operator_code += helper.var_idx("epilogue_", i) + ", " + + for i in range(self.b2bnum - 1): + final = ", " + if i == self.b2bnum - 2: + final ="" + operator_code += helper.var_idx("iterator_C", i) + final + operator_code += ");\n" + + operator_code += " " + helper.var_idx("OutputOp", self.b2bnum - 1) + helper.var_idx(" output_op_", self.b2bnum - 1) + helper.var_idx("(params.output_op_", self.b2bnum - 1) + ");\n" + operator_code += " " + "threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);\n" + + + + operator_code += " " + helper.var_idx("typename Epilogue::OutputTileIterator iterator_C", self.b2bnum - 1) + "(\n" + operator_code += " " + " " + helper.var_idx("params.params_C", self.b2bnum - 1) + ",\n" + operator_code += " " + " " + helper.var_idx("params.ref_C", self.b2bnum - 1) + ".data(),\n" + operator_code += " " + " " + helper.var_idx("params.problem_size_", self.b2bnum - 1) + ".mn(),\n" + operator_code += " " + " " + "thread_idx,\n" + operator_code += " " + " " + "threadblock_offset\n" + operator_code += " " + ");\n" + operator_code += " " + helper.var_idx("int ref_C", self.b2bnum - 1) + helper.var_idx("_stride = params.ref_C", self.b2bnum - 1) + ".stride()[0];\n" + + operator_code += " " + helper.var_idx("iterator_C", self.b2bnum - 1) + helper.var_idx(".add_pointer_offset(batch_idx * params.problem_size_", self.b2bnum - 1) + helper.var_idx(".n() * (ref_C", self.b2bnum - 1) + helper.var_idx("_stride == 0 ? 1 : params.problem_size_", self.b2bnum - 1) + ".m()));\n\n" + + operator_code += " " + helper.var_idx("typename Epilogue::OutputTileIterator iterator_D", self.b2bnum - 1) + "(\n" + operator_code += " " + " " + helper.var_idx("params.params_D", self.b2bnum - 1) + ",\n" + operator_code += " " + " " + helper.var_idx("params.ref_D", self.b2bnum - 1) + ".data(),\n" + operator_code += " " + " " + helper.var_idx("params.problem_size_", self.b2bnum - 1) + ".mn(),\n" + operator_code += " " + " " + "thread_idx,\n" + operator_code += " " + " " + "threadblock_offset\n" + operator_code += " " + ");\n" + operator_code += " " + helper.var_idx("iterator_D", self.b2bnum - 1) + helper.var_idx(".add_pointer_offset(batch_idx * params.problem_size_", self.b2bnum - 1) + helper.var_idx(".n() * params.problem_size_", self.b2bnum - 1) + ".m());\n\n" + + + operator_code += " " + "Epilogue epilogue(\n" + operator_code += " " + " " + "shared_storage.epilogue,\n" + operator_code += " " + " " + "thread_idx,\n" + operator_code += " " + " " + "warp_idx,\n" + operator_code += " " + " " + "lane_idx\n" + operator_code += " " + ");\n" + + operator_code += " " + "epilogue(" + operator_code += helper.var_idx("output_op_", self.b2bnum - 1) + ", " + operator_code += helper.var_idx("iterator_D", self.b2bnum - 1) + ", " + operator_code += "accumulators, " + operator_code += helper.var_idx("iterator_C", self.b2bnum - 1) + ");\n" + operator_code += "}\n" + + return ctr_code + operator_code + + def gen_include_header(self): + code = ''' +#pragma once + +#include \"{cutlass_dir}cutlass/cutlass.h\" + +#include \"{cutlass_dir}cutlass/gemm/gemm.h\" +#include \"{cutlass_dir}cutlass/matrix_coord.h\" +#include \"{cutlass_dir}cutlass/semaphore.h\" +'''.format(cutlass_dir=self.cutlass_deps_root) + return code + def gen_code(self): + + template_param = [] + template_param.append(("typename", "B2bMma")) + template_param.append(("typename", "Epilogue")) + template_param.append(("typename", "ThreadblockSwizzle")) + template_param.append((bool, "SplitKSerial")) + + code_body = "" + code_body += self.gen_using() + code_body += self.gen_operator_and_constr() + + struct_code = gen_ir.gen_template_struct(self.gen_class_name, template_param, code_body) + code = self.gen_include_header() + code += gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("kernel", struct_code))) + + return self.gen_include_header() + code + + + +class gen_kernel: + def __init__(self, template_param, gen_class_name, b2b_num, output_dir, cutlass_deps_root, project_root): + self.template_param = template_param + + self.gen_class_name = "B2bGemm" + self.gen_kernel_name = gen_class_name + "Kernel" + self.template_args = [] + + self.cutlass_deps_root = cutlass_deps_root + self.project_root = project_root + + self.gen_default_b2b_gemm = gen_default_Gemm(template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root) + self.gen_Kerenl = gen_Kernel(template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root) + + # Include gen_threadBlock + self.gen_threadBlock = gen_tb.gen_threadblock(template_param, gen_class_name, b2b_num, output_dir, cutlass_deps_root, project_root) + + self.file_dir = output_dir + "/kernel/" + + def gen_code(self, first_use_1stage): + + default_b2b_gemm = self.gen_default_b2b_gemm.gen_code() + + print("[INFO]: Gen kernel code [default_b2b_gemm.h]output Dir: is ", self.file_dir) + + with open(self.file_dir + "default_b2b_gemm.h", "w+") as f: + f.write(default_b2b_gemm) + + kernel = self.gen_Kerenl.gen_code() + print("[INFO]: Gen kernel code [b2b_gemm.h]output Dir: is ", self.file_dir) + + with open(self.file_dir + "b2b_gemm.h", "w+") as f: + f.write(kernel) + + # Call code to gen threadblock + self.gen_threadBlock.gen_code(first_use_1stage) diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_sample.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..75c2b6869364a6d6f73215e470080ce307f4df55 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_sample.py @@ -0,0 +1,232 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import helper +import gen_ir as ir + +class gen_test: + def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"): + self.fuse_gemm_info = fuse_gemm_info + self.gen_class_name = gen_class_name + self.user_header_file = user_header_file + self.sample_dir = output_dir + self.b2b_num = len(fuse_gemm_info) + + def gen_cpp_sample(self): + code = "/* Auto Generated code - Do not edit.*/\n" + code += "#include \n" + + code += "#include \"cutlass/gemm/device/gemm_batched.h\" \n" + code += "#include \"cutlass/cutlass.h\" \n" + + code += "#include \"../cutlass_irrelevant.h\" \n" + code += "#include \"../cutlass_verify.h\" \n" + + code += "#include \"leaky_bias.h\" \n" + + code += "#include \"utils.h\" \n" + + + + code += "int main(int args, char * argv[]) {\n" + code += " " + "int M = atoi(argv[1]);\n" + code += " " + "int K0 = " + str(self.fuse_gemm_info[0]['mnk'][0]) + ";\n" + code += " " + "if(args == 3);\n" + code += " " + " " + "K0 = atoi(argv[2]);\n" + code += " " + "int B = 1;\n" + code += " " + "if(args == 4);\n" + code += " " + " " + "B = atoi(argv[3]);\n" + + code += " " + "srand(1234UL);\n" + code += " " + "int device_id = 0;\n" + code += " " + "cudaGetDevice(&device_id);\n" + code += " " + "cudaDeviceProp prop;\n" + code += " " + "cudaGetDeviceProperties(&prop, device_id);\n" + code += " " + "int sm = prop.major *10 + prop.minor;\n" + code += "using ElementCompute = cutlass::half_t;\n" + + for i in range(self.b2b_num): + code += " " + helper.var_idx("ElementCompute alpha", i) + " = ElementCompute(1);\n" + addbias = helper.get_epilogue_add_bias_or_not( self.fuse_gemm_info[i]) + if addbias: + code += " " + helper.var_idx("ElementCompute beta", i) + " = ElementCompute(1);\n" + else: + code += " " + helper.var_idx("ElementCompute beta", i) + " = ElementCompute(0);\n" + + code += " " + "size_t flops = 0;\n" + + for i in range(self.b2b_num): + m = self.fuse_gemm_info[i]['mnk'][0] + n = self.fuse_gemm_info[i]['mnk'][1] + k = self.fuse_gemm_info[i]['mnk'][2] + + bias_shape = helper.get_epilogue_bias_shape(self.fuse_gemm_info[i]) + + this_k = "K0" + if (i > 0): + this_k = str(k) + + code += " " + "flops += size_t(2) * size_t(M) * size_t(B) * " + "size_t(" + str(n) + ") * size_t(" + this_k + ");\n" + + code += " " + helper.var_idx("cutlass::gemm::GemmCoord problem_size_", i) + "(" + "M" + ", " + str(n) + ", " + this_k + ");\n" + + code += " " + helper.var_idx("memory_unit Mat_A", i) + helper.var_idx("(B * problem_size_", i) + helper.var_idx(".m() * problem_size_", i) + ".k());\n" + code += " " + helper.var_idx("memory_unit Mat_B", i) + helper.var_idx("(B * problem_size_", i) + helper.var_idx(".n() * problem_size_", i) + ".k());\n" + code += " " + helper.var_idx("memory_unit Mat_C", i) + "(B * " + str(bias_shape[0]) + " * " + str(bias_shape[1]) + ");\n" + code += " " + helper.var_idx("memory_unit Mat_D_cutlass_ref", i) + helper.var_idx("(B * problem_size_", i) + helper.var_idx(".m() * problem_size_", i) + ".n());\n" + + code += " " + helper.var_idx("Mat_A", i) + ".init();\n" + code += " " + helper.var_idx("Mat_B", i) + ".init();\n" + code += " " + helper.var_idx("Mat_C", i) + ".init();\n" + + + + code += " " + helper.var_idx("memory_unit Mat_D", self.b2b_num - 1) + helper.var_idx("(B * problem_size_", i) + helper.var_idx(".m() * problem_size_",self.b2b_num - 1) + ".n());\n" + + params = [] + params.append("M") + params.append("B") + + params.append("Mat_A0.device_ptr") + for i in range(self.b2b_num): + params.append(helper.var_idx("Mat_B", i) + ".device_ptr") + params.append(helper.var_idx("Mat_C", i) + ".device_ptr") + if i != self.b2b_num-1: + params.append(helper.var_idx("Mat_D_cutlass_ref", i) + ".device_ptr") + params.append(helper.var_idx("Mat_D", self.b2b_num - 1) + ".device_ptr") + + code += " " + "Param arguments = {\n" + code += " " + " " + "M,\n" + code += " " + " " + "K0,\n" + code += " " + " " + "B,\n" + + code += " " + " " + "reinterpret_cast(Mat_A0.device_ptr),\n" + cnt = 1 + for i in range(self.b2b_num): + bias_flag = helper.get_epilogue_add_bias_or_not( self.fuse_gemm_info[i]) + code += " " + " " + "reinterpret_cast(" + helper.var_idx("Mat_B", i) + ".device_ptr" + "),\n" + cnt += 1 + if bias_flag: + code += " " + " " + "reinterpret_cast(" + helper.var_idx("Mat_C", i) + ".device_ptr" + "),\n" + cnt += 1 + else: + code += " " + " " + "reinterpret_cast(NULL),\n" + + epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i]) + acc_tp = helper.get_epilogue_compute_tp(self.fuse_gemm_info[i]) + for arg in epilogue_args: + arg_value = str(arg[2]) + + code += " " + " " + helper.type_2_cutlass_type(acc_tp) + "(" + arg_value + "),\n" + + if i != self.b2b_num - 1: + code += " " + " " + "reinterpret_cast(" + helper.var_idx("Mat_D_cutlass_ref", i) + ".device_ptr" + "),\n" + else: + code += " " + " " + "reinterpret_cast(" + helper.var_idx("Mat_D", i) + ".device_ptr" + ")};\n" + + + + + code += " " + "TI(FUSED_CUTLASS);\n" + code += " " + "for(int i = 0; i < 100; i++){\n" + code += " " + " " + "one_api(arguments, sm, NULL);\n" + + code += " " + "}\n" + code += " " + "TO(FUSED_CUTLASS, \"FUSED_CUTLASS\", 100);\n" + + code += "\n" + + for i in range(self.b2b_num): + code_this = "" + + N_str = str(self.fuse_gemm_info[i]['mnk'][1]) + + code_this += " " + helper.var_idx("typename Gemm", i) + helper.var_idx("::Arguments arguments_", i) + "{\n" + code_this += " " + " " + helper.var_idx("problem_size_", i) + ",\n" + ldmA = str(self.fuse_gemm_info[i]['mnk'][2]) + if i == 0: + ldmA = "K0" + ldmB = str(self.fuse_gemm_info[i]['mnk'][2]) + if i == 0: + ldmB = "K0" + ldmC = str(self.fuse_gemm_info[i]['mnk'][1]) + + ldmBias = str(helper.get_epilogue_bias_ldm(self.fuse_gemm_info[i])) + + if self.fuse_gemm_info[i]['A_format'] is 'Col': + ldmA = "M" + if self.fuse_gemm_info[i]['B_format'] is 'Row': + ldmB = str(self.fuse_gemm_info[i]['mnk'][1]) + if self.fuse_gemm_info[i]['C_format'] is 'Col': + ldmC = "M" + + if i == 0: + code_this += " " + " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + "*>(" + helper.var_idx("Mat_A", i) + ".device_ptr), " + ldmA + "}, " + "M * " + ldmA + ",\n" + else: + code_this += " " + " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + "*>(" + helper.var_idx("Mat_D_cutlass_ref", i - 1) + ".device_ptr), " + ldmA + "}, " + "M * " + ldmA + ",\n" + + code_this += " " + " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['B_tp']) + "*>(" + helper.var_idx("Mat_B", i) + ".device_ptr), " + ldmB + "}, " + N_str + " * " + ldmB + ",\n" + + M_bias = str(helper.get_epilogue_bias_shape(self.fuse_gemm_info[i])[0]) + + code_this += " " + " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("Mat_C", i) + ".device_ptr), " + ldmBias + "}, " + M_bias + " * " + N_str + ",\n" + code_this += " " + " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("Mat_D_cutlass_ref", i) + ".device_ptr), " + ldmC + "}, " + "M * " + ldmC + ",\n" + code_this += " " + " " + "{ " + helper.var_idx("alpha", i) + ", " + helper.var_idx("beta", i) + for epilogue_arg in helper.get_epilogue_args(self.fuse_gemm_info[i]): + arg_value = str(epilogue_arg[2]) + code_this += ", " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + "(" + str(arg_value) + ")" + code_this += " " + " },\n" + code_this += " " + " " + "B};\n" + + code += code_this + + + + code += " " + "TI(UNFUSED_CUTLASS);\n" + code += " " + "for(int i = 0; i < 100; i++){\n" + code += " " + " " + self.gen_class_name + "_verify(\n" + for i in range(self.b2b_num): + code += " " + " " + " " + helper.var_idx("arguments_", i) + ",\n" + code += " " + " " + " " + "NULL);\n" + + code += " " + "}\n" + code += " " + "TO(UNFUSED_CUTLASS, \"UNFUSED_CUTLASS\", 100);\n" + + code += " " + helper.var_idx("Mat_D_cutlass_ref", self.b2b_num - 1) + ".d2h();\n" + code += " " + helper.var_idx("Mat_D", self.b2b_num - 1) + ".d2h();\n" + code += " " + helper.var_idx("check_result(Mat_D_cutlass_ref", self.b2b_num - 1) + helper.var_idx(".host_ptr, Mat_D", self.b2b_num - 1) \ + + helper.var_idx(".host_ptr, Mat_D", self.b2b_num - 1) + ".elements);\n" + + code += "\n\n}\n" + + with open(self.sample_dir + "sample.cu", "w+") as f: + f.write(code) diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_threadblock.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_threadblock.py new file mode 100644 index 0000000000000000000000000000000000000000..6dab42b29f499b85a6d7052c41442ad71ba25925 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_threadblock.py @@ -0,0 +1,1013 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import gen_ir +import helper + + +class gen_default_b2b_mma: + def __init__(self, template_param, gen_class_name, b2b_num,cutlass_deps_root, project_root): + self.gen_class_name = "DefaultB2bMma" + self.template_param = template_param + self.b2b_num = b2b_num + + self.cutlass_deps_root = cutlass_deps_root + self.project_root = project_root + + def gen_include_header(self): + code = ''' +/* Auto Generated code - Do not edit.*/ + +#pragma once + +#include \"{cutlass_dir}cutlass/cutlass.h\" +#include \"{cutlass_dir}cutlass/numeric_types.h\" +#include \"{cutlass_dir}cutlass/arch/arch.h\" + +#include \"{cutlass_dir}cutlass/transform/threadblock/predicated_tile_iterator.h\" +#include \"{cutlass_dir}cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h\" +#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm70.h\" +#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm75.h\" +#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm80.h\" + +#include \"../threadblock/b2b_mma_pipelined.h\" +#include \"../../fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h\" +#include \"../../fixed_impl/epilogue/threadblock/default_bias_act_epilogue_tensor_op.h\" +#include \"../../fixed_impl/gemm/warp/mma_tensor_op_fragment_iterator_without_output_op.h\" +'''.format(cutlass_dir=self.cutlass_deps_root) + return code + + + def gen_using_MmaCore(self, stage): + threadBlockShape = "ThreadblockShape" + warpShape = "WarpShape" + instrunctionShape = "InstructionShape" + Mma_typename = "typename cutlass::gemm::threadblock::DefaultMmaCore" + + + gen_code = "" + + for i in range(self.b2b_num): + code_using = "using MmaCore" + str(i) + gen_code += code_using + " = " + gen_ir.gen_declare_template_struct(Mma_typename, \ + helper.var_idx(threadBlockShape, i), helper.var_idx(warpShape, i), instrunctionShape, \ + "ElementA", "LayoutA", \ + helper.var_idx("ElementB", i), helper.var_idx("LayoutB", i), \ + helper.var_idx("ElementAccumulator", i), "layout::RowMajor", \ + "OperatorClass", str(stage), "Operator") + return gen_code + + def gen_using_FusedAddBiasEpilogue(self): + gen_code = "" + for i in range(self.b2b_num - 1): + code_using = helper.var_idx("using FusedAddBiasEpilogue", i) + epilogue_name = "typename cutlass::epilogue::threadblock::DefaultFusedBiasActEpilogueTensorOp" + template_args = helper.var_idx("::Epilogue" + + gen_code += code_using + " = " + epilogue_name + template_args + ";\n" + + return gen_code + + + def gen_using_Iterator(self): + code_using = "using IteratorA0" + iterator_typename = "cutlass::transform::threadblock::PredicatedTileIterator" + MmaCore = "MmaCore0" + matrix_shape = "cutlass::MatrixShape<" + MmaCore + "::Shape::kM, " + MmaCore + "::Shape::kK>" + iterator_map = "typename " + MmaCore + "::IteratorThreadMapA" + gen_code = code_using + " = " + gen_ir.gen_declare_template_struct(iterator_typename, \ + matrix_shape, "ElementA", "LayoutA", "1", iterator_map, "AlignmentA_") + + for i in range(self.b2b_num): + code_using = "using IteratorB" + str(i) + iterator_typename = "cutlass::transform::threadblock::PredicatedTileIterator" + MmaCore = "MmaCore" + str(i) + matrix_shape = "cutlass::MatrixShape<" + MmaCore + "::Shape::kK, " + MmaCore + "::Shape::kN>" + iterator_map = "typename " + MmaCore + "::IteratorThreadMapB" + + gen_code += code_using + " = " + gen_ir.gen_declare_template_struct(iterator_typename, \ + matrix_shape, helper.var_idx("ElementB", i), helper.var_idx("LayoutB", i), "0", iterator_map, "AlignmentB_") + + return gen_code + + def gen_fragment_iterator(self): + gen_code = "using AccumulatorLayout = cutlass::layout::ColumnMajor;\n" + + for i in range(1, self.b2b_num): + code_using = "using FragmentIteratorA" + str(i) + iterator_typename = "cutlass::gemm::warp::MmaTensorOpPureFragmentIterator" + curr_MmaCore = "MmaCore" + str(i) + prev_MmaCore = "MmaCore" + str(i - 1) + Matrix_shape_curr = "cutlass::MatrixShape<" + curr_MmaCore + "::WarpShape::kM, " + curr_MmaCore + "::InstructionShape::kK>" + Matrix_shape_prev = "cutlass::MatrixShape<" + prev_MmaCore + "::WarpShape::kM, " + prev_MmaCore + "::WarpShape::kN>" + Curr_shape_kK = curr_MmaCore + "::Shape::kK" + + gen_code += code_using + " = " + gen_ir.gen_declare_template_struct(iterator_typename, \ + Matrix_shape_curr, Matrix_shape_prev, Curr_shape_kK, \ + helper.var_idx("ElementAccumulator", i-1), "ElementA", \ + "AccumulatorLayout", "InstructionShape_", "true") + + return gen_code + + def gen_threadblockmma(self): + code_using = "using ThreadblockB2bMma" + iterator_typename = "cutlass::gemm::threadblock::B2bMmaPipelined" + + MmaPipelined_param_Mma0_shape = "typename MmaCore0::Shape" + MmaPipelined_param_Mma0_iteratorA = "IteratorA0" + MmaPipelined_param_Mma0_smemIteratorA = "typename MmaCore0::SmemIteratorA" + MmaPipelined_param_Mma0_iteratorB = "IteratorB0" + MmaPipelined_param_Mma0_smemIteratorB = "typename MmaCore0::SmemIteratorB" + + MmaPipelined_param_list = MmaPipelined_param_Mma0_shape + ", " + MmaPipelined_param_Mma0_iteratorA + ", " + MmaPipelined_param_Mma0_smemIteratorA + ", " + MmaPipelined_param_Mma0_iteratorB + ", " + MmaPipelined_param_Mma0_smemIteratorB + ", " + + for i in range(1, self.b2b_num): + MmaPipelined_param_Mma_shape = "typename MmaCore" + str(i) + "::Shape" + MmaPipelined_param_Mma_iteratorA = "FragmentIteratorA" + str(i) + MmaPipelined_param_Mma_iteratorB = "IteratorB" + str(i) + MmaPipelined_param_Mma_smemIteratorB = "typename MmaCore" + str(i) + "::SmemIteratorB" + + MmaPipelined_param_list += MmaPipelined_param_Mma_shape + ", " + MmaPipelined_param_Mma_iteratorA + ", " + MmaPipelined_param_Mma_iteratorB + ", " + MmaPipelined_param_Mma_smemIteratorB + ", " + + MmaPipelined_param_list += "ElementAccumulator0, layout::RowMajor, " + + for i in range(self.b2b_num - 1): + epilogue_name = "EpilogueOutputOp" + str(i) + MmaPipelined_param_list += epilogue_name + ", " + + for i in range(self.b2b_num - 1): + epilogue_name = "FusedAddBiasEpilogue" + str(i) + MmaPipelined_param_list += epilogue_name + ", " + + for i in range(self.b2b_num): + MmaPolicy = "typename MmaCore" + str(i) + "::MmaPolicy" + MmaPipelined_param_list += MmaPolicy + ", " + + + cnt = 0 + for i in range(self.b2b_num): + MmaStage = helper.var_idx("Stages", i) + final = ", " + if cnt == self.b2b_num - 1: + final = "" + MmaPipelined_param_list += MmaStage + final + cnt += 1 + + gen_code = code_using + " = " + gen_ir.gen_declare_template_struct(iterator_typename, MmaPipelined_param_list) + + return gen_code + + + + def gen_code(self): + gen_using = '' + # Generate default template struct + gen_code = gen_ir.gen_template_struct(self.gen_class_name, self.template_param, "", speicalized = None, set_default=False) + + # Generate specialized template struct + + mmacore_codebody = self.gen_using_MmaCore(2) + iterator_codebody = self.gen_using_Iterator() + fragment_iterator_codebody = self.gen_fragment_iterator() + epilogue_iterator_codebody = self.gen_using_FusedAddBiasEpilogue() + threadBlockMma = self.gen_threadblockmma() + specialized_code = mmacore_codebody + iterator_codebody + fragment_iterator_codebody + epilogue_iterator_codebody + threadBlockMma + + # Specialize layout C -> cutlass::layout::RowMajor + + rtn_template_args, speicalized_template_args = gen_ir.filtered_param(self.template_param, [ ('LayoutD', "cutlass::layout::RowMajor")], keep_= True) + + gen_speical_code = gen_ir.gen_template_struct(self.gen_class_name, rtn_template_args, specialized_code, speicalized = speicalized_template_args, set_default=False) + code = gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("threadblock", gen_code + gen_speical_code))) + + return self.gen_include_header() + code + + +class gen_b2b_mme_pipelined: + def __init__(self, template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root): + self.gen_class_name = "B2bMmaPipelined" + self.template_param = template_param + self.b2b_num = b2b_num + self.cutlass_deps_root = cutlass_deps_root + self.project_root = project_root + + + def gen_include_header(self): + code = ''' +#pragma once + +#include \"{cutlass_dir}cutlass/cutlass.h\" +#include \"{cutlass_dir}cutlass/array.h\" +#include \"{cutlass_dir}cutlass/aligned_buffer.h\" +#include \"{cutlass_dir}cutlass/numeric_conversion.h\" + +#include \"{cutlass_dir}cutlass/numeric_types.h\" +#include \"{cutlass_dir}cutlass/matrix_shape.h\" + +#include \"{cutlass_dir}cutlass/gemm/gemm.h\" +#include \"{cutlass_dir}cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h\" + +#include \"../threadblock/b2b_mma_base.h\"\n'''.format(cutlass_dir = self.cutlass_deps_root) + return code + + + def gen_using(self): + code_using = "using FragmentA0 = typename IteratorA0::Fragment;\n" + + code_using += "using Base = B2bMmaBase<" + for i in range(self.b2b_num): + code_using += helper.var_idx("Shape", i) + "_, " + for i in range(self.b2b_num): + code_using += helper.var_idx("Policy", i) + "_, " + for i in range(self.b2b_num): + code_using += helper.var_idx("Stage", i) + "_, " + code_using = code_using[: -2] + ">;\n" + + + for i in range(self.b2b_num): + code_using += helper.var_idx("using FragmentB", i) + helper.var_idx(" = typename IteratorB", i) + "::Fragment;\n" + code_using += helper.var_idx("using FragmentC", i) + helper.var_idx(" = typename Policy", i) + "::Operator::FragmentC;\n" + code_using += helper.var_idx("using Operator", i) + helper.var_idx(" = typename Policy", i) + "::Operator;\n" + + for i in range(self.b2b_num - 1): + code_using += helper.var_idx("using IteratorC", i) + helper.var_idx(" = typename FusedAddBiasEpilogue", i) + "::OutputTileIterator;\n" + + code_using += "using ArchTag = typename Policy0::Operator::ArchTag;\n" + code_using += "static ComplexTransform const kTransformA0 = Operator0::kTransformA;\n" + + for i in range(self.b2b_num): + code_using += helper.var_idx("static ComplexTransform const kTransformB", i) + helper.var_idx(" = Operator", i) + "::kTransformB;\n" + + code_using += "private:\n" + code_using += "using WarpFragmentA0 = typename Operator0::FragmentA;\n" + code_using += "using WarpFragmentB0 = typename Operator0::FragmentB;\n" + + for i in range(1, self.b2b_num): + code_using += helper.var_idx("using WarpFragmentA", i) + helper.var_idx(" = typename FragmentIteratorA", i) + "::Fragment;\n" + code_using += helper.var_idx("using WarpFragmentB", i) + helper.var_idx(" = typename Operator", i) + "::FragmentB;\n" + + code_using += "protected:\n" + + code_using += "SmemIteratorA0 smem_iterator_A_;\n" + + for i in range(self.b2b_num): + code_using += helper.var_idx("SmemIteratorB", i) + helper.var_idx(" smem_iterator_B", i) + "_;\n" + + return code_using + + + def gen_operator(self, first_use_1stage = False): + code = "" + def gen_operator_param(b2b_num): + param_code = "" + param_code += "int gemm_k_iterations_0,\n" + param_code += helper.var_idx("FragmentC", b2b_num-1) + helper.var_idx(" &accum", b2b_num-1) + ",\n" + param_code += "IteratorA0 iterator_A,\n" + + for i in range(b2b_num): + param_code += helper.var_idx("IteratorB", i) + " " + helper.var_idx("iterator_B", i) + ",\n" + + param_code += "FragmentC0 const &src_accum, \n" + + for i in range(b2b_num - 1): + param_code += helper.var_idx("OutputOp", i) + " " + helper.var_idx("output_op_", i) + ",\n" + for i in range(b2b_num - 1): + param_code += helper.var_idx("FusedAddBiasEpilogue", i) + " " + helper.var_idx("epilogue_", i) + ",\n" + for i in range(b2b_num - 1): + param_code += helper.var_idx("IteratorC", i) + " " + helper.var_idx("iterator_C", i) + ",\n" + + + param_code += "TransformA0 transform_A0 = TransformA0(), \n" + + for i in range(b2b_num): + final = "(),\n" + if i == b2b_num - 1: + final = "()\n" + param_code += helper.var_idx("TransformB", i) + " " + helper.var_idx("transform_B", i) + " = " +helper.var_idx("TransformB", i) + final + + return param_code + + + + def gen_first_gemm_1stage(b2b_num): + accu_code = " FragmentC0 accum0 = src_accum;\n" + if b2b_num == 1: + accu_code = " accum0 = src_accum;\n" + + code ="\ +\n\ + FragmentA0 tb_frag_A;\n\ + FragmentB0 tb_frag_B0;\n\ +\n\ + int smem_write_stage_idx = 1;\n\ +\n\ + tb_frag_A.clear();\n\ + tb_frag_B0.clear();\n\ +\n\ + // The last kblock is loaded in the prolog\n\ + iterator_A.load(tb_frag_A);\n\ + iterator_B0.load(tb_frag_B0);\n\ +\n\ + ++iterator_A;\n\ + ++iterator_B0;\n\ +\n\ + WarpFragmentA0 warp_frag_A0;\n\ + WarpFragmentB0 warp_frag_B0;\n\ +\n\ + Operator0 warp_mma0;\n\ +\n\ + // Avoid reading out of bounds\n\ + if (gemm_k_iterations_0 <= 1) {\n\ + iterator_A.clear_mask();\n\ + iterator_B0.clear_mask();\n\ + }\n\ +\n\ + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing \n\ + // shared memory loads (which have the tightest latency requirement).\n\ +\n\ + //\n\ + // Mainloop\n\ + //\n\ +\n\ + // Note: The main loop does not support Base::WarpGemmIterations == 2.\n\ + CUTLASS_GEMM_LOOP\n\ + for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) {\n\ +\n\ + this->smem_iterator_A_.store(tb_frag_A);\n\ + this->smem_iterator_B0_.store(tb_frag_B0);\n\ +\n\ + __syncthreads();\n\ + //\n\ + // Loop over GEMM K dimension\n\ + //\n\ +\n\ + CUTLASS_PRAGMA_UNROLL\n\ + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) {\n\ +\n\ + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group\n\ + // as the case may be.\n\ +\n\ + this->warp_tile_iterator_A0_.set_kgroup_index(warp_mma_k % Base::kWarpGemmIterations0);\n\ + this->warp_tile_iterator_B0_.set_kgroup_index(warp_mma_k % Base::kWarpGemmIterations0);\n\ +\n\ + this->warp_tile_iterator_A0_.load(warp_frag_A0);\n\ + this->warp_tile_iterator_B0_.load(warp_frag_B0);\n\ +\n\ + ++this->warp_tile_iterator_A0_;\n\ + ++this->warp_tile_iterator_B0_;\n\ +\n\ + warp_mma0(accum0, warp_frag_A0, warp_frag_B0, accum0);\n\ + }\n\ + this->warp_tile_iterator_A0_.add_tile_offset({0, -Policy0::kPartitionsK * Base::kWarpGemmIterations0});\n\ + this->warp_tile_iterator_B0_.add_tile_offset({-Policy0::kPartitionsK * Base::kWarpGemmIterations0, 0});\n\ +\n\ + __syncthreads();\n\ + iterator_A.load(tb_frag_A);\n\ + iterator_B0.load(tb_frag_B0);\n\ +\n\ + ++iterator_A;\n\ + ++iterator_B0;\n\ +\n\ + if(gemm_k_iterations_0 <= 2) {\n\ + iterator_A.clear_mask();\n\ + iterator_B0.clear_mask();\n\ + }\n\ + }\n" + + return accu_code + code + + + def gen_first_gemm_2stage(b2b_num): + + accu_code = " FragmentC0 accum0 = src_accum;\n" + if b2b_num == 1: + accu_code = " accum0 = src_accum;\n" + + code ="\ +\n\ + FragmentA0 tb_frag_A;\n\ + FragmentB0 tb_frag_B0;\n\ +\n\ + tb_frag_A.clear();\n\ + tb_frag_B0.clear();\n\ +\n\ + // The last kblock is loaded in the prolog\n\ + iterator_A.load(tb_frag_A);\n\ + iterator_B0.load(tb_frag_B0);\n\ +\n\ + ++iterator_A;\n\ + ++iterator_B0;\n\ +\n\ + this->smem_iterator_A_.store(tb_frag_A);\n\ + this->smem_iterator_B0_.store(tb_frag_B0);\n\ +\n\ + ++this->smem_iterator_A_;\n\ + ++this->smem_iterator_B0_;\n\ +\n\ + __syncthreads();\n\ +\n\ + // Pair of fragments used to overlap shared memory loads and math instructions\n\ + WarpFragmentA0 warp_frag_A0[2];\n\ + WarpFragmentB0 warp_frag_B0[2];\n\ +\n\ + this->warp_tile_iterator_A0_.set_kgroup_index(0);\n\ + this->warp_tile_iterator_B0_.set_kgroup_index(0);\n\ +\n\ + this->warp_tile_iterator_A0_.load(warp_frag_A0[0]);\n\ + this->warp_tile_iterator_B0_.load(warp_frag_B0[0]);\n\ +\n\ + ++this->warp_tile_iterator_A0_;\n\ + ++this->warp_tile_iterator_B0_;\n\ +\n\ + Operator0 warp_mma0;\n\ +\n\ + int smem_write_stage_idx = 1;\n\ +\n\ + // Avoid reading out of bounds\n\ + if (gemm_k_iterations_0 <= 1) {\n\ + iterator_A.clear_mask();\n\ + iterator_B0.clear_mask();\n\ + }\n\ +\n\ + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing \n\ + // shared memory loads (which have the tightest latency requirement).\n\ + iterator_A.load(tb_frag_A);\n\ +\n\ + //\n\ + // Mainloop\n\ + //\n\ +\n\ + // Note: The main loop does not support Base::WarpGemmIterations == 2.\n\ + CUTLASS_GEMM_LOOP\n\ + for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) {\n\ +\n\ + //\n\ + // Loop over GEMM K dimension\n\ + //\n\ +\n\ + CUTLASS_PRAGMA_UNROLL\n\ + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) {\n\ +\n\ + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group\n\ + // as the case may be.\n\ +\n\ + if (warp_mma_k == Base::kWarpGemmIterations0 - 1) {\n\ +\n\ + // Write fragments to shared memory\n\ + this->smem_iterator_A_.store(tb_frag_A);\n\ +\n\ + this->smem_iterator_B0_.store(tb_frag_B0);\n\ +\n\ + __syncthreads();\n\ +\n\ + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing \n\ + // shared memory loads (which have the tightest latency requirement).\n\ + iterator_A.load(tb_frag_A);\n\ + \n\ + ++this->smem_iterator_B0_;\n\ + ++this->smem_iterator_A_;\n\ + \n\ +\n\ + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory\n\ + if (smem_write_stage_idx == 1) {\n\ + this->smem_iterator_A_.add_tile_offset({0, -Base::Stage0});\n\ + this->smem_iterator_B0_.add_tile_offset({-Base::Stage0, 0});\n\ + }\n\ + else {\n\ + this->warp_tile_iterator_A0_.add_tile_offset(\n\ + {0, -Base::Stage0 * Policy0::kPartitionsK * Base::kWarpGemmIterations0});\n\ + this->warp_tile_iterator_B0_.add_tile_offset(\n\ + {-Base::Stage0 * Policy0::kPartitionsK * Base::kWarpGemmIterations0,\n\ + 0});\n\ + }\n\ +\n\ + smem_write_stage_idx ^= 1;\n\ + }\n\ +\n\ + this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);\n\ + this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);\n\ + \n\ + this->warp_tile_iterator_A0_.load(warp_frag_A0[(warp_mma_k + 1) % 2]);\n\ + this->warp_tile_iterator_B0_.load(warp_frag_B0[(warp_mma_k + 1) % 2]);\n\ +\n\ + ++this->warp_tile_iterator_A0_;\n\ + ++this->warp_tile_iterator_B0_;\n\ +\n\ + if (warp_mma_k == 0) {\n\ +\n\ + iterator_B0.load(tb_frag_B0);\n\ +\n\ + ++iterator_A;\n\ + ++iterator_B0;\n\ +\n\ + // Avoid reading out of bounds if this was the last loop iteration\n\ + if (gemm_k_iterations_0 <= 2) {\n\ + iterator_A.clear_mask();\n\ + iterator_B0.clear_mask();\n\ + }\n\ + }\n\ +\n\ + warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2], warp_frag_B0[warp_mma_k % 2], accum0);\n\ + }\n\ + }\n" + return accu_code + code + + def gen_other_gemms_2stage(b2b_num): + + code = "" + + def gemm_teamplate(id): + code = "// " + str(id + 1) + " Gemm" + code += " /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile\n" + + code += " " + helper.var_idx("FragmentC", id - 1) + helper.var_idx(" after_epilogue_accu", id - 1) + ";\n" + code += " " + helper.var_idx("epilogue_", id - 1) + helper.var_idx("(output_op_", id - 1) + helper.var_idx(", accum", id - 1) \ + + helper.var_idx(", after_epilogue_accu", id - 1) + helper.var_idx(", iterator_C", id - 1) +");\n" + + # FragmentIteratorA1 warp_tile_iterator_A1_(accum0); + code += " " + helper.var_idx("FragmentIteratorA", id) + helper.var_idx(" warp_tile_iterator_A", id) +"_(" + helper.var_idx("after_epilogue_accu", id - 1) + ");\n" + # FragmentB1 tb_frag_B1; + code += " " + helper.var_idx("FragmentB", id) + " " + helper.var_idx("tb_frag_B", id) + ";\n" + # tb_frag_B1.clear(); + code += " " + helper.var_idx("tb_frag_B", id) + ".clear();\n" + # iterator_B1.load(tb_frag_B1); + code += " " + helper.var_idx("iterator_B", id) + ".load(" + helper.var_idx("tb_frag_B", id) + ");\n" + # ++iterator_B1; + code += " " + "++" + helper.var_idx("iterator_B", id) + ";\n" + # this->smem_iterator_B1_.store(tb_frag_B1); + code += " " + helper.var_idx("this->smem_iterator_B", id) + "_.store(" + helper.var_idx("tb_frag_B", id) + ");\n" + # ++this->smem_iterator_B1_; + code += " " + helper.var_idx("++this->smem_iterator_B", id) + "_;\n" + # __syncthreads(); + code += " " + "__syncthreads();\n" + # WarpFragmentA1 warp_frag_A1[2]; + code += " " + helper.var_idx("WarpFragmentA", id) + helper.var_idx(" warp_frag_A", id) + "[2];\n" + # WarpFragmentB1 warp_frag_B1[2]; + code += " " + helper.var_idx("WarpFragmentB", id) + helper.var_idx(" warp_frag_B", id) + "[2];\n" + # this->warp_tile_iterator_B1_.set_kgroup_index(0); + code += " " + helper.var_idx("this->warp_tile_iterator_B", id) + "_.set_kgroup_index(0);\n" + # warp_tile_iterator_A1_.load(warp_frag_A1[0], output_op_0); + code += " " + helper.var_idx("warp_tile_iterator_A", id) + helper.var_idx("_.load(warp_frag_A", id) + "[0]);\n" + # this->warp_tile_iterator_B1_.load(warp_frag_B1[0]); + code += " " + helper.var_idx("this->warp_tile_iterator_B", id) + helper.var_idx("_.load(warp_frag_B", id) + "[0]);\n" + # ++warp_tile_iterator_A1_; + code += " " + helper.var_idx("++warp_tile_iterator_A", id) + "_;\n" + # ++this->warp_tile_iterator_B1_; + code += " " + helper.var_idx("++this->warp_tile_iterator_B", id) + "_;\n" + # Operator1 warp_mma1; + code += " " + helper.var_idx("Operator", id) + " " + helper.var_idx("warp_mma", id) + ";\n" + # smem_write_stage_idx = 1; + code += " " + "smem_write_stage_idx = 1;\n" + # int gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1; + code += " " + helper.var_idx("int gemm_k_iterations_", id) + " = " + helper.var_idx("FragmentIteratorA", id) + helper.var_idx("::Policy::kIterations / Base::kWarpGemmIterations", id) +";\n" + # if (gemm_k_iterations_1 <= 1) { + # iterator_B1.clear_mask(); + # } + code += " " + "if (" + helper.var_idx("gemm_k_iterations_", id) + " <= 1 ){\n" \ + + " " + " " + helper.var_idx("iterator_B", id) + ".clear_mask();\n" \ + + " " +"}\n" + # CUTLASS_PRAGMA_UNROLL + code += " " + "CUTLASS_PRAGMA_UNROLL\n" + # for (; gemm_k_iterations_1 > 0; --gemm_k_iterations_1) { + code += " " + helper.var_idx("for (; gemm_k_iterations_", id) + helper.var_idx(" > 0; --gemm_k_iterations_", id) + ") {\n" + # CUTLASS_PRAGMA_UNROLL + code += " " + " " + "CUTLASS_PRAGMA_UNROLL\n" + # for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; ++warp_mma_k) { + code += " " + " " + helper.var_idx("for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations", id) + "; ++warp_mma_k) {\n" + # if (warp_mma_k == Base::kWarpGemmIterations1 - 1) { + code += " " + " " + " " + helper.var_idx("if (warp_mma_k == Base::kWarpGemmIterations", id) + " - 1) {\n" + # this->smem_iterator_B1_.store(tb_frag_B1); + code += " " + " " + " " + " " + helper.var_idx(" this->smem_iterator_B", id) + helper.var_idx("_.store(tb_frag_B", id) + ");\n" + # __syncthreads(); + code += " " + " " + " " + " " + "__syncthreads();\n" + # ++smem_iterator_B1_; + code += " " + " " + " " + " " + helper.var_idx(" ++smem_iterator_B", id) + "_;\n" + # if (smem_write_stage_idx == 1) { + # smem_iterator_B1_.add_tile_offset({-Base::Stage, 0}); + # } + code += " " + " " + " " + " " + "if ( smem_write_stage_idx == 1 ) {\n" \ + + " " + " " + " " + " " + " " + helper.var_idx("smem_iterator_B", id) + helper.var_idx("_.add_tile_offset({-Base::Stage", i) + ", 0});\n" \ + + " " + " " + " " + " " +"}\n" + # else { + # this->warp_tile_iterator_B1_.add_tile_offset( + # {-Base::Stage * Policy1::kPartitionsK * + # Base::kWarpGemmIterations1, + # 0}); + # } + code += " " + " " + " " + " " + "else {\n" \ + + " " + " " + " " + " " + " " + helper.var_idx("this->warp_tile_iterator_B", id) + "_.add_tile_offset(\n" \ + + " " + " " + " " + " " + " " + helper.var_idx("{-Base::Stage", id) + helper.var_idx(" * Policy", id) + "::kPartitionsK *\n" \ + + " " + " " + " " + " " + " " + helper.var_idx("Base::kWarpGemmIterations", id) + ",\n" \ + + " " + " " + " " + " " + " " + "0});\n" \ + + " " + " " + " " + " " + "}\n" + + # smem_write_stage_idx ^= 1; + # } + code += " " + " " + " " + " " + "smem_write_stage_idx ^= 1;\n" \ + + " " + " " + " " + "}\n" + + # this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); + code += " " + " " + " " + helper.var_idx("this->warp_tile_iterator_B", id) + helper.var_idx("_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations", id) + ");\n" + # warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2], output_op_0); + code += " " + " " + " " + helper.var_idx("warp_tile_iterator_A", id) + helper.var_idx("_.load(warp_frag_A", id) + "[(warp_mma_k + 1) % 2]);\n" + # this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]); + code += " " + " " + " " + helper.var_idx("this->warp_tile_iterator_B", id) + helper.var_idx("_.load(warp_frag_B", id) + "[(warp_mma_k + 1) % 2]);\n" + # ++warp_tile_iterator_A1_; + code += " " + " " + " " + helper.var_idx("++warp_tile_iterator_A", id) + "_;\n" + # ++this->warp_tile_iterator_B1_; + code += " " + " " + " " + helper.var_idx("++this->warp_tile_iterator_B", id) + "_;\n" + # if (warp_mma_k == 0) { + # iterator_B1.load(tb_frag_B1); + # ++iterator_B1; + # if (gemm_k_iterations_1 <= 2) { + # iterator_B1.clear_mask(); + # } + # } + code += " " + " " + " " + " if (warp_mma_k == 0) {\n" \ + + " " + " " + " " + " " + helper.var_idx("iterator_B", id) + helper.var_idx(".load(tb_frag_B", id) + ");\n" \ + + " " + " " + " " + " " + helper.var_idx("++iterator_B", id) +";\n" \ + + " " + " " + " " + " " + helper.var_idx("if (gemm_k_iterations_", id) +" <= 2) {\n" \ + + " " + " " + " " + " " + " " + helper.var_idx("iterator_B", id) + ".clear_mask();\n" \ + + " " + " " + " " + " " + "}\n" \ + + " " + " " + " " + "}\n" + # warp_mma1(accum, warp_frag_A1[warp_mma_k % 2], warp_frag_B1[warp_mma_k % 2], accum); + # } + # } + code += " " + " " + " " + helper.var_idx("warp_mma", id) + helper.var_idx("(accum", id) + helper.var_idx(", warp_frag_A", id) + helper.var_idx("[warp_mma_k % 2], warp_frag_B", id) + helper.var_idx("[warp_mma_k % 2], accum", id) + ");\n" \ + + " " + " " + "}\n" \ + + " " + "}\n\n\n" + + return code + + for i in range (1, b2b_num): + clear_accu = "" + if i != b2b_num - 1: + clear_accu = " " + helper.var_idx("FragmentC", i) + helper.var_idx(" accum", i) +";\n" + clear_accu += " " + helper.var_idx("accum", i) +".clear();\n" + code += clear_accu + gemm_teamplate(i) + + return code + + operator_code = " CUTLASS_DEVICE\n\ + void operator()(\n " + gen_operator_param(self.b2b_num) + ") {\n" + if first_use_1stage: + operator_code += gen_first_gemm_1stage(self.b2b_num) + else: + operator_code += gen_first_gemm_2stage(self.b2b_num) + operator_code += gen_other_gemms_2stage(self.b2b_num) + "}\n" + return operator_code + + def gen_construct_func(self): + name = self.gen_class_name + func_code = "CUTLASS_DEVICE\n" + func_code += name + "(\n" \ + + " " + "typename Base::B2bMmaSharedStorage &shared_storage,\n" \ + + " " + "int thread_idx,\n" \ + + " " + "int warp_idx,\n" \ + + " " + "int lane_idx\n" \ + + "):\n" + func_code += " " + "Base(shared_storage, thread_idx, warp_idx, lane_idx),\n" \ + + " " + "smem_iterator_A_(shared_storage.sharedStorage0.operand_A_ref(), thread_idx),\n" + + for i in range(self.b2b_num): + final = ",\n" + if i == self.b2b_num - 1: + final = " {\n" + func_code += helper.var_idx("smem_iterator_B", i) + helper.var_idx("_(shared_storage.sharedStorage", i) +".operand_B_ref(), thread_idx)" + final + + func_code += " " + "int warp_idx_mn = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN);\n" + func_code += " " + "int warp_idx_k = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN);\n" + + func_code += " " + "int warp_idx_m = warp_idx_mn % Base::WarpCount0::kM;\n" + func_code += " " + "int warp_idx_n = warp_idx_mn / Base::WarpCount0::kM;\n" + + for i in range(self.b2b_num): + func_code += " " + helper.var_idx("int tile_offset_k", i) + helper.var_idx(" = Base::kWarpGemmIterations", i) + " * warp_idx_k;\n" + + func_code += " " + "this->warp_tile_iterator_A0_.add_tile_offset({warp_idx_m, tile_offset_k0});\n" + + for i in range(self.b2b_num): + func_code += " " + helper.var_idx("this->warp_tile_iterator_B", i) + helper.var_idx("_.add_tile_offset({tile_offset_k", i) + ", warp_idx_n});\n" + + func_code += "}\n" + + return func_code + + def gen_member_func(self, first_use_1stage): + code = "public:\n" + code += self.gen_operator(first_use_1stage) + code += self.gen_construct_func() + + return code + + def gen_code(self, first_use_1stage): + + def gen_template_args(b2b_num): + template_param = [] + template_param.append(("typename", "Shape0")) + template_param.append(("typename", "IteratorA0")) + template_param.append(("typename", "SmemIteratorA0")) + template_param.append(("typename", "IteratorB0")) + template_param.append(("typename", "SmemIteratorB0")) + + for i in range(1, b2b_num): + template_param.append(("typename", helper.var_idx("Shape", i))) + template_param.append(("typename", helper.var_idx("FragmentIteratorA", i))) + template_param.append(("typename", helper.var_idx("IteratorB", i))) + template_param.append(("typename", helper.var_idx("SmemIteratorB", i))) + + template_param.append(("typename", "ElementC")) + template_param.append(("typename", "LayoutC")) + + for i in range(0, b2b_num - 1): + template_param.append(("typename", helper.var_idx("OutputOp", i))) + + for i in range(0, b2b_num - 1): + template_param.append(("typename", helper.var_idx("FusedAddBiasEpilogue", i))) + + for i in range(0, b2b_num): + template_param.append(("typename", helper.var_idx("Policy", i))) + for i in range(0, b2b_num): + template_param.append((int, helper.var_idx("Stage", i))) + + template_param.append(("typename","TransformA0", "NumericArrayConverter")) + + for i in range(0, b2b_num): + cvtr = helper.var_idx("NumericArrayConverter" + template_param.append(("typename", helper.var_idx("TransformB", i), cvtr)) + + template_param.append(("typename", "Enable", "bool")) + + return template_param + + template_param = gen_template_args(self.b2b_num) + inheritance_code = "public B2bMmaBase<" + for i in range(self.b2b_num): + inheritance_code += helper.var_idx("Shape", i) + "_, " + for i in range(self.b2b_num): + inheritance_code += helper.var_idx("Policy", i) + "_, " + for i in range(self.b2b_num - 1): + inheritance_code += helper.var_idx("Stage", i) + "_, " + inheritance_code += helper.var_idx("Stage", self.b2b_num - 1) + "_" + inheritance_code += ">" + + code_body = "" + using_code= self.gen_using() + func_code = self.gen_member_func(first_use_1stage) + + code_body = using_code + func_code + + class_code = gen_ir.gen_template_class(self.gen_class_name, template_param, code_body, inheritance_code = inheritance_code) + + code = self.gen_include_header() + code += gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("threadblock", class_code))) + # print(code) + return code + + +class gen_b2b_mma_base: + def __init__(self, template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root): + self.gen_class_name = gen_class_name + self.template_param = template_param + self.b2b_num = b2b_num + self.cutlass_deps_root = cutlass_deps_root + self.project_root = project_root + + def gen_include_header(self): + code = ''' +#pragma once + +#include \"{cutlass_dirs}cutlass/aligned_buffer.h\" +#include \"{cutlass_dirs}cutlass/arch/memory.h\" +#include \"{cutlass_dirs}cutlass/array.h\" +#include \"{cutlass_dirs}cutlass/cutlass.h\" +#include \"{cutlass_dirs}cutlass/gemm/gemm.h\" +#include \"{cutlass_dirs}cutlass/matrix_shape.h\" +#include \"{cutlass_dirs}cutlass/numeric_types.h\"\n'''.format(cutlass_dirs=self.cutlass_deps_root) + return code + + def gen_shared_storage(self): + code = \ +" template< \n\ + typename Shape_,\n\ + typename Policy_,\n\ + int ThisStage_\n\ +>\n\ +class SharedStorage {\n\ +public:\n\ + using Shape = Shape_;\n\ + using Policy = Policy_;\n\ + static int const ThisStage = ThisStage_;\n\ + using Operator = typename Policy::Operator;\n\ + \ + using TensorRefA = TensorRef;\n\ + \ + /// Tensor reference to the B operand \n\ + using TensorRefB = TensorRef;\n\ +\n\ + /// Shape of the A matrix operand in shared memory \n\ + using ShapeA = MatrixShape;\n\ +\n\ + /// Shape of the B matrix operand in shared memory\n\ + using ShapeB =\n\ + MatrixShape;\n\ +\n\ + public:\n\ +\n\ + /// Buffer for A operand\n\ + AlignedBuffer operand_A;\n\ +\n\ + /// Buffer for B operand\n\ + AlignedBuffer operand_B;\n\ +\n\ + public:\n\ +\n\ + /// Returns a layout object for the A matrix\n\ + CUTLASS_DEVICE\n\ + static typename Operator::LayoutA LayoutA() {\n\ + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});\n\ + }\n\ +\n\ + /// Returns a layout object for the B matrix\n\ + CUTLASS_HOST_DEVICE\n\ + static typename Operator::LayoutB LayoutB() {\n\ + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});\n\ + }\n\ +\n\ + /// Returns a TensorRef to the A operand\n\ + CUTLASS_HOST_DEVICE\n\ + TensorRefA operand_A_ref() {\n\ + return TensorRefA{operand_A.data(), LayoutA()};\n\ + }\n\ +\n\ + /// Returns a TensorRef to the B operand\n\ + CUTLASS_HOST_DEVICE\n\ + TensorRefB operand_B_ref() {\n\ + return TensorRefB{operand_B.data(), LayoutB()};\n\ + }\n\ + CUTLASS_HOST_DEVICE\n\ + void * get_B_Shared_ptr() {\n\ + return operand_B.data();\n\ + }\n\ + };\n" + return code + + def gen_using_and_misc(self, b2b_num): + code_using = "" + for i in range(b2b_num): + code_using += "using Operator" +str(i) + " = typename Policy" + str(i) +"::Operator;\n" + + for i in range(b2b_num): + code_using += "using WarpGemm" +str(i) + " = typename Policy" + str(i) +"::Operator::Shape;\n" + + for i in range(b2b_num): + code_using += "using WarpCount" +str(i) + " = GemmShape<" + helper.var_idx("Shape", i) +"::kM / " + helper.var_idx("WarpGemm", i) +"::kM, "\ + + helper.var_idx("Shape", i) +"::kN / " + helper.var_idx("WarpGemm", i) +"::kN, "\ + + helper.var_idx("Shape", i) +"::kK / " + helper.var_idx("WarpGemm", i) +"::kK>;\n" + + code_misc = "" + for i in range(b2b_num): + code_misc += "static int const " + helper.var_idx("kWarpGemmIterations", i) + " = (" + helper.var_idx("WarpGemm", i) + "::kK / " + helper.var_idx("Operator", i) +"::Policy::MmaShape::kK);\n" + + code = code_using + code_misc + self.gen_shared_storage() + + for i in range(b2b_num): + code += "using " + helper.var_idx("SharedStorage", i) + " = SharedStorage<" + helper.var_idx("Shape", i) + ", " + helper.var_idx("Policy", i) +", " + helper.var_idx("Stage", i) + ">;\n" + + def gen_union_shared_storage(b2b_num): + code = "" + for i in range(b2b_num): + code += " " +helper.var_idx("SharedStorage", i) + " " + helper.var_idx("sharedStorage", i) +";\n" + return code + + code += "union B2bMmaSharedStorage {\n" + gen_union_shared_storage(self.b2b_num) + "};\n" + + for i in range(b2b_num - 1): + code += helper.var_idx("void * C", i) + "_smm_ptr;\n" + + return code + + def gen_protected(self): + code = "\nprotected:\n" + code += "typename Operator0::IteratorA warp_tile_iterator_A0_;\n" + for i in range(self.b2b_num): + code += "typename Operator" +str(i) + "::IteratorB" +" warp_tile_iterator_B" + str(i) + "_;\n" + return code + + def gen_public_member(self): + code = "\npublic:\n" + + code += "CUTLASS_DEVICE\n" + code += \ + "B2bMmaBase(\n" + \ + " B2bMmaSharedStorage & shared_storage,\n" + \ + " int thread_idx,\n" + \ + " int warp_idx,\n" + \ + " int lane_idx\n" + \ + "):\n" + \ + " warp_tile_iterator_A0_(shared_storage.sharedStorage0.operand_A_ref(), lane_idx),\n" + for i in range(self.b2b_num): + final = ",\n" + if i == self.b2b_num-1: + final = "\n" + + iterator = " warp_tile_iterator_B" + str(i) + "_" + shared_storage = "shared_storage.sharedStorage" + str(i) + ".operand_B_ref()" + code += iterator + "(" + shared_storage + ", lane_idx)" + final + + + code += "{\n" + for i in range(self.b2b_num - 1): + code += helper.var_idx(" C", i) + helper.var_idx("_smm_ptr = shared_storage.sharedStorage", i) + ".get_B_Shared_ptr();\n" + code += "}\n" + + return code + + def gen_code(self): + + template_arg = [] + for i in range(self.b2b_num): + template_arg.append(("typename", helper.var_idx("Shape", i))) + for i in range(self.b2b_num): + template_arg.append(("typename", helper.var_idx("Policy", i))) + for i in range(self.b2b_num): + template_arg.append((int, helper.var_idx("Stage", i))) + + + + code_body = self.gen_using_and_misc(self.b2b_num) + code_body += self.gen_protected() + code_body += self.gen_public_member() + + class_code = gen_ir.gen_template_class("B2bMmaBase", template_arg, code_body) + + code = self.gen_include_header() + gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("threadblock", class_code))) + + return code + + +class gen_threadblock: + def __init__(self, template_param, gen_class_name, b2b_num, output_dir, cutlass_deps_root, project_root): + self.gen_class_name = gen_class_name + self.template_param = template_param + self.b2b_num = b2b_num + self.file_dir = output_dir + "/threadblock/" + + self.cutlass_deps_root = cutlass_deps_root + self.project_root = project_root + + + self.gen_b2b_mma_base = gen_b2b_mma_base(template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root) + self.gen_b2b_mma_pipelined = gen_b2b_mme_pipelined(template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root) + self.gen_default_b2b_mma = gen_default_b2b_mma(template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root) + + + def gen_code(self, first_use_1stage): + + base_code = self.gen_b2b_mma_base.gen_code() + print("[INFO]: Gen kernel code [b2b_mma_base.h]output Dir: is ", self.file_dir) + + with open(self.file_dir + "b2b_mma_base.h", "w+") as f: + f.write(base_code) + pipeline_code = self.gen_b2b_mma_pipelined.gen_code(first_use_1stage = first_use_1stage) + print("[INFO]: Gen kernel code [b2b_mma_pipelined.h]output Dir: is ", self.file_dir) + + with open(self.file_dir + "b2b_mma_pipelined.h", "w+") as f: + f.write(pipeline_code) + default_code = self.gen_default_b2b_mma.gen_code() + print("[INFO]: Gen kernel code [default_b2b_mma.h]output Dir: is ", self.file_dir) + + with open(self.file_dir + "default_b2b_mma.h", "w+") as f: + f.write(default_code) diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_turing_and_volta.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_turing_and_volta.py new file mode 100644 index 0000000000000000000000000000000000000000..8697ca88ee4cb40056cca06b8e028abd6b491749 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_turing_and_volta.py @@ -0,0 +1,456 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import helper +import gen_ir as ir + +class gen_turing_impl: + def __init__(self,fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"): + self.fuse_gemm_info = fuse_gemm_info + self.class_name = gen_class_name + self.gen_class_name = gen_class_name + "_turing_impl" + self.user_header_file = "" + for header in user_header_file: + self.user_header_file += "#include \"" + header + "\"\n" + self.output_dir = output_dir + self.b2b_num = len(fuse_gemm_info) + + self.gen_turing_unfused = gen_volta_turing_fuse_act_impl(fuse_gemm_info, gen_class_name, user_header_file, output_dir) + + def gen_using(self): + code_using = "using b2b_gemm = typename cutlass::gemm::device::" + self.class_name + ";" + + return code_using + "\n" + + def gen_initialize(self): + code = "" + for i in range(self.b2b_num): + code_this = "" + + code_this += helper.var_idx(helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + " alpha", i) + " = " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + "(1);\n" + beta = "(1)" + + if helper.get_epilogue_add_bias_or_not(self.fuse_gemm_info[i]) is False: + beta = "(0)" + code_this += helper.var_idx(helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + " beta", i) + " = " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + beta + ";\n" + k_str = str(self.fuse_gemm_info[i]['mnk'][2]) + if i == 0: + k_str = "K0" + code_this += helper.var_idx("cutlass::gemm::GemmCoord problem_size_", i) + "(M, " + str(self.fuse_gemm_info[i]['mnk'][1]) + ", " + k_str + ");\n" + code += code_this + code += "typename b2b_gemm::Arguments arguments{\n" + + for i in range(self.b2b_num): + code += " " + helper.var_idx("problem_size_", i) + ",\n" + + + code += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + "*>(" + helper.var_idx("A", 0) + "), " + helper.var_idx("problem_size_", 0) + ".k()},\n" + + for i in range(self.b2b_num): + + ldmB = str(self.fuse_gemm_info[i]['mnk'][2]) + if i == 0: + ldmB = "K0" + + if self.fuse_gemm_info[i]['B_format'] is 'Row': + ldmB = str(self.fuse_gemm_info[i]['mnk'][1]) + + ldmC = str(helper.get_epilogue_bias_ldm(self.fuse_gemm_info[i])) + + code += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['B_tp']) + "*>(" + helper.var_idx("B", i) + "), " + ldmB + "},\n" + code += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("C", i) + "), " + ldmC + "},\n" + code += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("D", self.b2b_num -1) + "), " + helper.var_idx("problem_size_", self.b2b_num - 1) + ".n()},\n" + + + for i in range(self.b2b_num): + code += " " + "{ " + helper.var_idx("alpha", i) + ", " + helper.var_idx("beta", i) + for epilogue_arg in helper.get_epilogue_args(self.fuse_gemm_info[i]): + arg_name = helper.var_idx("Epilogue", i) + "_" + epilogue_arg[1] + code += ", " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + "(" + str(arg_name) + ")" + code += "},\n" + code += " " + "Batch};\n\n" + + code += " " "b2b_gemm gemm_op;\n" + code += " " + "gemm_op.initialize(arguments);\n" + return code + "\n" + + + + def gen_run(self): + code = " " + "gemm_op(stream);\n" + + return code + + def gen_wrapper(self): + code_body = "" + + arg_lists = [] + arg_lists.append(["int", "M"]) + arg_lists.append(["int", "K0"]) + arg_lists.append(["int", "Batch"]) + arg_lists.append(["void*", helper.var_idx("A", 0)]) + for i in range(self.b2b_num): + arg_lists.append(["void*", helper.var_idx("B", i)]) + arg_lists.append(["void*", helper.var_idx("C", i)]) + arg_lists.append(["void*", helper.var_idx("D", i)]) + epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i]) + acc_tp = helper.get_epilogue_compute_tp(self.fuse_gemm_info[i]) + for arg in epilogue_args: + arg_tp = arg[0] + arg_name = helper.var_idx("Epilogue", i) + "_" + arg[1] + arg_lists.append([arg_tp, arg_name]) + + if self.b2b_num == 1: + code_body += self.gen_turing_unfused.gen_using(False) #False -> Turing, True -> Volta + code_body += self.gen_turing_unfused.gen_initialize() + code_body += self.gen_turing_unfused.gen_run() + else: + code_body += self.gen_using() + code_body += self.gen_initialize() + code_body += self.gen_run() + + code = ir.gen_func(self.gen_class_name, arg_lists, code_body) + + return code + + def gen_code(self): + + code = self.gen_wrapper() + helper.write_2_headfile("turing_impl.h", self.output_dir, self.user_header_file + "\n" + code) + +class gen_volta_turing_fuse_act_impl: + def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"): + self.fuse_gemm_info = fuse_gemm_info + self.gen_class_name = gen_class_name + "_volta_impl" + self.user_header_file = "" + for header in user_header_file: + self.user_header_file += "#include \"" + header + "\"\n" + self.output_dir = output_dir + self.b2b_num = len(fuse_gemm_info) + + def perf_tiling(self, layer_mnk): + mnk = layer_mnk[:] + block_tile = mnk[:] + block_tile[2] = 32 # force the K tile to be 32 + + # M tile gen + block_tile[0] = 32 + + # N tile gen + if mnk[1] > 128: + block_tile[1] = 256 + elif mnk[1] > 64: + block_tile[1] = 128 + elif mnk[1] > 32: + block_tile[1] = 64 + else : + block_tile[1] = 32 + + warp_tile = block_tile[:] + if block_tile[1] == 256: + warp_tile[1] = 64 + elif block_tile[1] == 128: + warp_tile[1] = 32 + elif block_tile[1] == 64: + warp_tile[1] = 32 + else : + warp_tile[1] = 32 + + warp_tile[0] = 32 + + return block_tile, warp_tile + + + def process_epilogue(self, epilogue_tp, n, C_tp, Acc_tp): + epilogue_setted_type = epilogue_tp + cutlass_epilogue_name = "LinearCombinationRelu" + if epilogue_setted_type.lower() == 'leakyrelu': + cutlass_epilogue_name = "LinearCombinationLeakyRelu" + elif epilogue_setted_type.lower() == 'identity': + cutlass_epilogue_name = "LinearCombination" + + + n_mod_8 = n % 4 + N_align_elements = 1 + if n_mod_8 == 0: + N_align_elements = 8 + elif n_mod_8 == 4: + N_align_elements = 4 + elif n_mod_8 == 2 or n_mod_8 == 6: + N_align_elements = 2 + + epilogue_str = "cutlass::epilogue::thread::" + cutlass_epilogue_name+ "<" + C_tp + ", " + str(N_align_elements) + ", " + Acc_tp + ", " + Acc_tp + ">" + + return epilogue_str + + def gen_using(self, volta = True): + code_using = "" + volta_arch = "cutlass::arch::Sm70" + volta_tc = "cutlass::gemm::GemmShape<8, 8, 4>" + + turing_arch = "cutlass::arch::Sm75" + turing_tc = "cutlass::gemm::GemmShape<16, 8, 8>" + + arch = "" + tc = "" + if volta: + arch = volta_arch + tc = volta_tc + else: + arch = turing_arch + tc = turing_tc + + for i in range(self.b2b_num): + + k = self.fuse_gemm_info[i]['mnk'][2] + + k_mod_8 = k % 4 + ab_ldm = 1 + if k_mod_8 == 0: + ab_ldm = 8 + elif k_mod_8 == 4: + ab_ldm = 4 + elif k_mod_8 == 2 or k_mod_8 == 6: + ab_ldm = 2 + + block_tile, warp_tile = self.perf_tiling(self.fuse_gemm_info[i]['mnk']) + + this_gemm_config = helper.var_idx("using Gemm", i) + " = cutlass::gemm::device::GemmBatched<\n" + this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + ",\n" + this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_format']) + ",\n" + this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['B_tp']) + ",\n" + this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['B_format']) + ",\n" + this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + ",\n" + this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_format']) + ",\n" + this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + ",\n" + this_gemm_config += " " + "cutlass::arch::OpClassTensorOp,\n" + this_gemm_config += " " + arch + ",\n" + this_gemm_config += " " + "cutlass::gemm::GemmShape<" + str(block_tile[0]) + ", " + str(block_tile[1]) + ", " + str(block_tile[2]) + ">,\n" + this_gemm_config += " " + "cutlass::gemm::GemmShape<" + str(warp_tile[0]) + ", " + str(warp_tile[1]) + ", " + str(warp_tile[2]) + ">,\n" + this_gemm_config += " " + tc + ",\n" + this_gemm_config += " " + self.process_epilogue(helper.get_epilogue_tp(self.fuse_gemm_info[i]), self.fuse_gemm_info[i]['mnk'][1], helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']), helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp'])) + ",\n" + this_gemm_config += " " + "cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,\n" + this_gemm_config += " " + "2,\n" + this_gemm_config += " " + str(ab_ldm) + ",\n" + this_gemm_config += " " + str(ab_ldm) + ">;\n" + + code_using += this_gemm_config + "\n" + + return code_using + "\n" + + def gen_initialize(self): + code = "" + for i in range(self.b2b_num): + code_this = "" + + N_str = str(self.fuse_gemm_info[i]['mnk'][1]) + + code_this += helper.var_idx(helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + " alpha", i) + " = " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + "(1);\n" + beta = "(1)" + if helper.get_epilogue_add_bias_or_not( self.fuse_gemm_info[i]) is False: + beta = "(0)" + code_this += helper.var_idx(helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + " beta", i) + " = " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + beta + ";\n" + + k_str = str(self.fuse_gemm_info[i]['mnk'][2]) + if i == 0: + k_str = "K0" + code_this += helper.var_idx("cutlass::gemm::GemmCoord problem_size_", i) + "(M, " + str(self.fuse_gemm_info[i]['mnk'][1]) + ", " + k_str + ");\n" + code_this += helper.var_idx("typename Gemm", i) + helper.var_idx("::Arguments arguments_", i) + "{\n" + code_this += " " + helper.var_idx("problem_size_", i) + ",\n" + ldmA = k_str + ldmB = k_str + ldmC = str(self.fuse_gemm_info[i]['mnk'][1]) + + ldmBias = str(helper.get_epilogue_bias_ldm(self.fuse_gemm_info[i])) + + if self.fuse_gemm_info[i]['A_format'] is 'Col': + ldmA = "M" + if self.fuse_gemm_info[i]['B_format'] is 'Row': + ldmB = str(self.fuse_gemm_info[i]['mnk'][1]) + if self.fuse_gemm_info[i]['C_format'] is 'Col': + ldmC = "M" + + if i == 0: + code_this += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + "*>(" + helper.var_idx("A", i) + "), " + ldmA + "}, " + "M * " + ldmA + ",\n" + else: + code_this += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + "*>(" + helper.var_idx("D", i - 1) + "), " + ldmA + "}, " + "M * " + ldmA + ",\n" + + code_this += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['B_tp']) + "*>(" + helper.var_idx("B", i) + "), " + ldmB + "}, " + N_str + " * " + ldmB + ",\n" + + M_bias = str(helper.get_epilogue_bias_shape(self.fuse_gemm_info[i])[0]) + + code_this += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("C", i) + "), " + ldmBias + "}, " + M_bias + " * " + N_str + ",\n" + code_this += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("D", i) + "), " + ldmC + "}, " + "M * " + ldmC + ",\n" + code_this += " " + "{ " + helper.var_idx("alpha", i) + ", " + helper.var_idx("beta", i) + for epilogue_arg in helper.get_epilogue_args(self.fuse_gemm_info[i]): + arg_name = helper.var_idx("Epilogue", i) + "_" + epilogue_arg[1] + code_this += ", " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + "(" + str(arg_name) + ")" + code_this += " },\n" + code_this += " " + "Batch};\n" + + code_this += " " + helper.var_idx("Gemm", i) + helper.var_idx(" gemm_op_", i) + ";\n" + code_this += " " + helper.var_idx("gemm_op_", i) + helper.var_idx(".initialize(arguments_", i) + ", nullptr);\n" + + code += code_this + "\n" + return code + "\n" + + + def gen_run(self): + code = "" + for i in range(self.b2b_num): + code_this = "" + code_this += " " + helper.var_idx("gemm_op_", i) + "(stream);\n" + + code += code_this + return code + + def gen_wrapper(self): + code_body = "" + + arg_lists = [] + arg_lists.append(["int", "M"]) + arg_lists.append(["int", "K0"]) + arg_lists.append(["int", "Batch"]) + arg_lists.append(["void*", helper.var_idx("A", 0)]) + for i in range(self.b2b_num): + arg_lists.append(["void*", helper.var_idx("B", i)]) + arg_lists.append(["void*", helper.var_idx("C", i)]) + arg_lists.append(["void*", helper.var_idx("D", i)]) + epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i]) + acc_tp = helper.get_epilogue_compute_tp(self.fuse_gemm_info[i]) + for arg in epilogue_args: + arg_tp = arg[0] + arg_name = helper.var_idx("Epilogue", i) + "_" + arg[1] + arg_lists.append([arg_tp, arg_name]) + code_body += self.gen_using() + code_body += self.gen_initialize() + code_body += self.gen_run() + + code = ir.gen_func(self.gen_class_name, arg_lists, code_body) + + return code + + def gen_code(self): + code = self.gen_wrapper() + helper.write_2_headfile("volta_impl.h", self.output_dir, self.user_header_file + "\n" + code) + +class gen_one_API: + def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"): + self.fuse_gemm_info = fuse_gemm_info + self.gen_class_name = gen_class_name + self.user_header_file = "" + for header in user_header_file: + self.user_header_file += "#include \"" + header + "\"\n" + self.output_dir = output_dir + self.b2b_num = len(fuse_gemm_info) + + self.gen_volta = gen_volta_turing_fuse_act_impl(fuse_gemm_info, gen_class_name, user_header_file, output_dir) + + self.gen_turing = gen_turing_impl(fuse_gemm_info, gen_class_name, user_header_file, output_dir) + + def gen_CUTLASS_irrelevant_API(self): + code = "" + code += "#include \n" + code += "#include \n" + + param_name = "Fused" + str(self.b2b_num) + "xGemm_" + for i in range(self.b2b_num): + param_name += str(self.fuse_gemm_info[i]['mnk'][1]) + "_" + param_name += "Params" + params = "" + params += " " + "int M;\n" + params += " " + "int K0;\n" + params += " " + "int Batch;\n" + params += " " + "const void* A0;\n" + for i in range(self.b2b_num): + params += " " + "const void* " + helper.var_idx("B", i) + ";\n" + params += " " + "const void* " + helper.var_idx("C", i) + ";\n" + epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i]) + acc_tp = helper.get_epilogue_compute_tp(self.fuse_gemm_info[i]) + for arg in epilogue_args: + arg_tp = arg[0] + arg_name = helper.var_idx("Epilogue", i) + "_" + arg[1] + params += " " + arg_tp + " " + arg_name + ";\n" + params += " " + "void* " + helper.var_idx("D", i) + ";\n" + code += ir.gen_struct(param_name, params) + code += "using Param = " + param_name + ";\n" + code += "void one_api( const Param & param, int sm, cudaStream_t stream);\n" + + + return code + + def gen_one_api(self): + code = "" + code += "/* Auto Generated code - Do not edit.*/\n" + code += "#include \"cutlass_irrelevant.h\"\n" + code += "#include \"api.h\"\n" + code += "void one_api( const Param & param, int sm, cudaStream_t stream) {\n" + + code += " " + "if (sm == 70) \n" + code += " " + " " + self.gen_class_name + "_volta_impl(param.M, param.K0, param.Batch, const_cast(param.A0), " + for i in range(self.b2b_num): + code += helper.var_idx("const_cast(param.B", i) + "), " + code += helper.var_idx("const_cast(param.C", i) + "), " + code += helper.var_idx("param.D", i) + ", " + epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i]) + for arg in epilogue_args: + arg_name = helper.var_idx("Epilogue", i) + "_" + arg[1] + code += "param." + arg_name + ", " + code += "stream);\n" + code += " " + "else if(sm >= 75) \n" + code += " " + " " + self.gen_class_name + "_turing_impl(param.M, param.K0, param.Batch, const_cast(param.A0), " + for i in range(self.b2b_num): + code += helper.var_idx("const_cast(param.B", i) + "), " + code += helper.var_idx("const_cast(param.C", i) + "), " + code += helper.var_idx("param.D", i) + ", " + epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i]) + for arg in epilogue_args: + arg_name = helper.var_idx("Epilogue", i) + "_" + arg[1] + code += "param." + arg_name + ", " + code += "stream);\n" + code += " " + "else assert(0);\n" + code += "}\n" + return code + + def gen_code(self): + + turing_code = self.gen_turing.gen_wrapper() + volta_code = self.gen_volta.gen_wrapper() + cutlass_irrelevant_code = self.gen_CUTLASS_irrelevant_API() + + one_api_code = self.gen_one_api() + with open(self.output_dir + "one_api.cu", "w+") as f: + f.write(one_api_code) + + helper.write_2_headfile("cutlass_irrelevant.h", self.output_dir, cutlass_irrelevant_code) + + helper.write_2_headfile("api.h", self.output_dir, self.user_header_file + "\n" + turing_code + volta_code) diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_verify.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_verify.py new file mode 100644 index 0000000000000000000000000000000000000000..5bd465162b60954061ab7cd341281025dd55c0d0 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_verify.py @@ -0,0 +1,92 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import helper +import gen_ir as ir + +import gen_turing_and_volta as gen_basic + + +class gen_verify: + def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"): + self.fuse_gemm_info = fuse_gemm_info + self.name = gen_class_name + "_verify" + self.b2b_num = len(fuse_gemm_info) + self.params = [] + self.user_header_file = "" + for header in user_header_file: + self.user_header_file += "#include \"" + header + "\"\n" + self.separate_cutlass = gen_basic.gen_volta_turing_fuse_act_impl(fuse_gemm_info, gen_class_name, user_header_file, output_dir) + self.gen_params() + self.output_dir = output_dir + + + def gen_code(self): + code = "" + code += self.user_header_file + code += self.separate_cutlass.gen_using(False) #False -> Turing, True -> Volta + + code_body = "" + for i in range(self.b2b_num): + code_body += " " + helper.var_idx("Gemm", i) + helper.var_idx(" gemm_op_", i) + ";\n" + code_body += " " + helper.var_idx("gemm_op_", i) + helper.var_idx(".initialize(Arguments_", i) + ", nullptr);\n" + + code_body += self.separate_cutlass.gen_run() + + code += ir.gen_func(self.name, self.params, code_body) + helper.write_2_headfile("cutlass_verify.h", self.output_dir, code) + + + def gen_params(self): + for i in range(self.b2b_num): + self.params.append( + ( + helper.var_idx("typename Gemm", i)+ "::Arguments", + helper.var_idx("Arguments_", i) + ) + ) + + + def get_params(self, declaration = True): + code = "" + if declaration: + for param in self.params: + code += param[0] + " " + param[1] + ";\n" + + return code + + + def gen_initialize(): + code = "" + initialize_code = self.separate_cutlass.gen_initialize() + + code = ir.gen_func("initialize", [[]]) diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/helper.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/helper.py new file mode 100644 index 0000000000000000000000000000000000000000..f757d9522476d992a6df0e55cfc961c66b92c19e --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/helper.py @@ -0,0 +1,135 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +def type_2_cutlass_type(input_type = "fp16"): + # float point type + if input_type == "fp32": + return "float" + if input_type == "bf16": + return "cutlass::bfloat16_t" + if input_type == "fp16": + return "cutlass::half_t" + + # integer type + if(input_type == "int32"): + return "int32_t" + if(input_type == "int8"): + return "int8_t" + + if input_type == 'Row': + return 'cutlass::layout::RowMajor' + if input_type == 'Col': + return 'cutlass::layout::ColumnMajor' + +def cvt_2_cutlass_shape(gemm_shape): + # gemm shape + if len(gemm_shape) == 3: + val = "cutlass::gemm::GemmShape<" \ + + str(gemm_shape[0]) + ", " \ + + str(gemm_shape[1]) + ", " \ + + str(gemm_shape[2]) + ">" + return val + + +def write_2_headfile(filename, file_dir, string): + with open(file_dir + filename, 'w') as f: + f.write("/* Auto Generated code - Do not edit.*/\n\n\n#pragma once\n" + string) + +def var_idx(variable, index): + return variable + str(index) + + +def list_2_string(input_list, ): + rtn_string = "" + + cnt = 0 + + for element in input_list: + final = ", \n" + if cnt == len(input_list) - 1: + final = "\n" + cnt += 1 + rtn_string += str(element) + final + + return rtn_string + + +def get_epilogue_info(layer_info): + return layer_info['epilogue'] + +def get_epilogue_tp(layer_info): + epilogue_info = get_epilogue_info(layer_info) + return epilogue_info['tp'] + +def get_epilogue_add_bias_or_not(layer_info): + epilogue_info = get_epilogue_info(layer_info) + return epilogue_info['bias']['addbias'] + +def get_epilogue_add_bias_tp(layer_info): + epilogue_info = get_epilogue_info(layer_info) + return epilogue_info['bias']['bias_tp'] + +def get_epilogue_args(layer_info): + epilogue_info = get_epilogue_info(layer_info) + return epilogue_info['args'] + +def get_epilogue_bias_shape(layer_info): + bias_tp = get_epilogue_add_bias_tp(layer_info).lower() + mn_shape = layer_info['mnk'][:-1] + + if bias_tp == 'mat': + mn_shape[0] = 'M' + return mn_shape + elif bias_tp == 'vec': + mn_shape[0] = 1 + return mn_shape + else: + assert(0) + +def get_epilogue_bias_ldm(layer_info): + bias_tp = get_epilogue_add_bias_tp(layer_info).lower() + mn_shape = layer_info['mnk'][:-1] + + c_layout = layer_info['C_format'].lower() + + if c_layout != 'row': + assert(0) + + if bias_tp == 'mat': + return mn_shape[1] + elif bias_tp == 'vec': + return 0 + else: + assert(0) + +def get_epilogue_compute_tp(layer_info): + return layer_info['Acc_tp'] diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/replace_fix_impl_header.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/replace_fix_impl_header.py new file mode 100644 index 0000000000000000000000000000000000000000..b9e79efc05f502bb95a10d1efc294453a7300db9 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/replace_fix_impl_header.py @@ -0,0 +1,67 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import os + +class replace_fix_impl: + def __init__(self, src_dir, dst_dir, cutlass_deps_root): + self.src_dir = src_dir + self.dst_dir = dst_dir + self.cutlass_deps_root = cutlass_deps_root + + + + def gen_code(self): + for sub_dir in os.walk(self.src_dir): + files_in_sub_dir = sub_dir[2] + + src_dirs = sub_dir[0] + output_dirs = self.dst_dir + sub_dir[0][len(self.src_dir):] + + if not os.path.exists(output_dirs): + os.mkdir(output_dirs) + + for f in files_in_sub_dir: + with open(src_dirs +"/" + f, 'r') as current_file: + output_lines = [] + lines = current_file.readlines() + + for line in lines: + if(len(line) >= len("#include \"cutlass") and line[:len("#include \"cutlass")] == "#include \"cutlass"): + new_line = "#include \"" + self.cutlass_deps_root + line[len("#include \""):] + # print(new_line) + output_lines.append(new_line) + else: + output_lines.append(line) + + with open(output_dirs + "/" + f, "w+") as dest_file: + dest_file.writelines(output_lines) diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/leaky_bias.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/leaky_bias.h new file mode 100644 index 0000000000000000000000000000000000000000..cc34a412a14aa363aa5a614ed8bd378b80d54568 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/leaky_bias.h @@ -0,0 +1,292 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once +#include + +template +__device__ +T add(T const & a, T const &b){ + return (a + b); +} + +template <> +__device__ +half2 add(half2 const & a, half2 const &b){ + return (__hadd2(a,b)); +} + +template +struct RELU{ + __device__ + T operator()(T const & a){ + return a > T(0) ? a : T(0); + } + __device__ + half2 operator()(half2 const & a){ + float2 a_fp32x2 = __half22float2(a); + a_fp32x2.x = a_fp32x2.x > 0.f ? a_fp32x2.x : 0.f; + a_fp32x2.y = a_fp32x2.y > 0.f ? a_fp32x2.y : 0.f; + if(a_fp32x2.x < 0.f || a_fp32x2.y < 0.f) + printf(" %f %f\n", a_fp32x2.x ,a_fp32x2.y); + return __float22half2_rn(a_fp32x2); + } +}; + +template +struct LEAKY_RELU{ + __device__ + T operator()(T const & a, T const & scale = half(1)){ + return a > T(0) ? a : scale * a; + } + __device__ + half2 operator()(half2 const & a, half const & scale = half(1)){ + half2 zero = __half2half2(half(0)); + half2 gt_zero = __hge2(a, zero); + half2 le_zero = __hle2(a, zero); + + + half2 scale_f16x2 = __half2half2(scale); + half2 mask_scale_f16x2 = __hfma2(le_zero, scale_f16x2, gt_zero); + return __hmul2(a, mask_scale_f16x2); + } +}; + +template +__global__ void leaky_and_activation(half* inout, half* bias, half scale, bool mat_bias){ + + constexpr bool N_MOD_2 = N & 1 ? false : true; + + using Access_tp = typename std::conditional::type; + + constexpr int Access_elements = sizeof(Access_tp) / sizeof(half); + + constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements); + + LEAKY_RELU Act; + Access_tp src_v[iter]; + Access_tp bias_v[iter]; + + int batch_id = blockIdx.y; + int batch_offset = batch_id * gridDim.x * N; + + for(int i = 0; i < iter; i++){ + int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements; + if (idx < N){ + src_v[i] = *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset); + if (mat_bias) + bias_v[i] = *reinterpret_cast(bias + blockIdx.x * N + idx + batch_offset); + else + bias_v[i] = *reinterpret_cast(bias + idx + batch_id * N); + *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset) = Act(add(src_v[i],bias_v[i]),scale); + } + + } +} + + + +template +__global__ void leaky_and_activation(half* inout, half scale){ + + constexpr bool N_MOD_2 = N & 1 ? false : true; + + using Access_tp = typename std::conditional::type; + + constexpr int Access_elements = sizeof(Access_tp) / sizeof(half); + + constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements); + + int batch_id = blockIdx.y; + int batch_offset = batch_id * gridDim.x * N; + + LEAKY_RELU Act; + Access_tp src_v[iter]; + + for(int i = 0; i < iter; i++){ + int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements; + if (idx < N){ + src_v[i] = *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset); + *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset) = Act(src_v[i], scale); + } + + } +} + + + +template +void leaky_and_activation(half* inout, half* bias, int m, int b, half scale, bool mat_bias){ + + dim3 grid(m, b); + if (bias == nullptr) + leaky_and_activation<<>>(inout, scale); + else + leaky_and_activation<<>>(inout, bias, scale, mat_bias); +} + +template +__global__ void relu_and_activation(half* inout, half* bias, bool mat_bias){ + + constexpr bool N_MOD_2 = N & 1 ? false : true; + + using Access_tp = typename std::conditional::type; + + constexpr int Access_elements = sizeof(Access_tp) / sizeof(half); + + constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements); + + RELU Act; + Access_tp src_v[iter]; + Access_tp bias_v[iter]; + + int batch_id = blockIdx.y; + int batch_offset = batch_id * gridDim.x * N; + + for(int i = 0; i < iter; i++){ + int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements; + if (idx < N){ + src_v[i] = *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset); + if (mat_bias) + bias_v[i] = *reinterpret_cast(bias + blockIdx.x * N + idx + batch_offset); + else + bias_v[i] = *reinterpret_cast(bias + idx + batch_id * N); + *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset) = Act(add(src_v[i],bias_v[i])); + } + + } +} + + + +template +__global__ void relu_and_activation(half* inout){ + + constexpr bool N_MOD_2 = N & 1 ? false : true; + + using Access_tp = typename std::conditional::type; + + constexpr int Access_elements = sizeof(Access_tp) / sizeof(half); + + constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements); + + int batch_id = blockIdx.y; + int batch_offset = batch_id * gridDim.x * N; + + RELU Act; + Access_tp src_v[iter]; + + for(int i = 0; i < iter; i++){ + int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements; + if (idx < N){ + src_v[i] = *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset); + *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset) = Act(src_v[i]); + } + + } +} + + + +template +void relu_and_activation(half* inout, half* bias, int m, int b, bool mat_bias){ + dim3 grid(m, b); + if (bias == nullptr) + relu_and_activation<<>>(inout); + else + relu_and_activation<<>>(inout, bias, mat_bias); +} + + +template +__global__ void identity_and_activation(half* inout, half* bias, bool mat_bias){ + + constexpr bool N_MOD_2 = N & 1 ? false : true; + + using Access_tp = typename std::conditional::type; + + constexpr int Access_elements = sizeof(Access_tp) / sizeof(half); + + constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements); + + int batch_id = blockIdx.y; + int batch_offset = batch_id * gridDim.x * N; + + Access_tp src_v[iter]; + Access_tp bias_v[iter]; + + for(int i = 0; i < iter; i++){ + int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements; + if (idx < N){ + src_v[i] = *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset); + if (mat_bias) + bias_v[i] = *reinterpret_cast(bias + blockIdx.x * N + idx + batch_offset); + else + bias_v[i] = *reinterpret_cast(bias + idx + batch_id * N); + *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset) = (add(src_v[i],bias_v[i])); + } + + } +} + +template +__global__ void identity_and_activation(half* inout){ + + constexpr bool N_MOD_2 = N & 1 ? false : true; + + using Access_tp = typename std::conditional::type; + + constexpr int Access_elements = sizeof(Access_tp) / sizeof(half); + + constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements); + + int batch_id = blockIdx.y; + int batch_offset = batch_id * gridDim.x * N; + Access_tp src_v[iter]; + + for(int i = 0; i < iter; i++){ + int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements; + if (idx < N){ + src_v[i] = *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset); + *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset) = (src_v[i]); + } + + } +} + +template +void identity_and_activation(half* inout, half* bias, int m, int b, bool mat_bias){ + dim3 grid(m, b); + if (bias == nullptr) + identity_and_activation<<>>(inout); + else + identity_and_activation<<>>(inout, bias, mat_bias); +} diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/utils.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..a825033f3bb9d195daec3c7691d100693ed1ce44 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/utils.h @@ -0,0 +1,94 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once +#define TI(tag) \ + cudaEvent_t _event_start_ ##tag; \ + cudaEvent_t _event_end_ ##tag; \ + float _event_time_ ##tag; \ + cudaEventCreate(& _event_start_ ##tag); \ + cudaEventCreate(& _event_end_ ##tag); \ + cudaEventRecord(_event_start_ ##tag); + +#define TO(tag, str, times) \ + cudaEventRecord(_event_end_ ##tag); \ + cudaEventSynchronize(_event_end_ ##tag); \ + cudaEventElapsedTime(&_event_time_ ##tag, _event_start_ ##tag, _event_end_ ##tag); \ + float _event_time_once_ ##tag = _event_time_ ##tag / times; \ + printf("%20s:\t %10.3fus\t", str, _event_time_once_ ##tag * 1000); \ + cudaDeviceSynchronize(); \ + printf("%20s string: %s\n",str, cudaGetErrorString(cudaGetLastError())); + +template +struct memory_unit{ + T* host_ptr; + T* device_ptr; + int size_bytes; + int elements; + void h2d(){ + cudaMemcpy(device_ptr, host_ptr, size_bytes, cudaMemcpyHostToDevice); + } + void d2h(){ + cudaMemcpy(host_ptr, device_ptr, size_bytes, cudaMemcpyDeviceToHost); + } + void free_all(){ + free(host_ptr); + cudaFree(device_ptr); + } + memory_unit(int elements_): size_bytes(elements_ * sizeof(T)), elements(elements_){ + host_ptr = (T*) malloc(elements_ * sizeof(T)); + cudaMalloc((void**)&device_ptr, elements_ * sizeof(T)); + } + void init(int abs_range = 1){ + for(int i = 0; i < elements; i++){ + host_ptr[i] = T(rand() % 100 / float(100) * 2 * abs_range - abs_range); + } + h2d(); + } +}; + +template +int check_result(T * a, T * b, int N){ + int cnt = 0; + for(int i = 0; i < N; i ++){ + float std = float(a[i]); + float my = float(b[i]); + + if(abs(std - my) / abs(std) > 1e-2) + { + // printf("my: %f , std: %f\n", my, std); + cnt++; + } + + } + printf("total err: %d / %d\n", cnt, N); + return cnt; +} diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/device/dual_gemm.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/device/dual_gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..c6b6f30d30db3b7f2aed1463373ad7d55b87eb94 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/device/dual_gemm.h @@ -0,0 +1,499 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Performs a dual gemm in one fused kernel: +``` +D0 = epilogue0(X @ B0, C0) +D1 = epilogue1(X @ B1, C1) +D2 = element_wise(D0, D1) +``` +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" + +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass/epilogue/thread/linear_combination_relu.h" +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" + +#include "../kernel/dual_gemm.h" +#include "../dual_gemm_common.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B0 matrix operand + typename LayoutB0_, + /// Layout type for B1 matrix operand + typename LayoutB1_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Epilogue output operator + typename EpilogueOutputOp0_, + typename EpilogueOutputOp1_, + typename EpilogueOutputOp2_, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + bool StoreD0 = true, + bool StoreD1 = true, + /// If true, kernel supports split-K with serial reduction + bool SplitKSerial = false, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator> +class DualGemm { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB0 = LayoutB0_; + using LayoutB1 = LayoutB1_; + using TensorRefB0 = TensorRef; + using TensorRefB1 = TensorRef; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp0 = EpilogueOutputOp0_; + using EpilogueOutputOp1 = EpilogueOutputOp1_; + using EpilogueOutputOp2 = EpilogueOutputOp2_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp1::kCount; + static bool const kSplitKSerial = SplitKSerial; + static bool constexpr kStoreD0 = StoreD0; + static bool constexpr kStoreD1 = StoreD1; + static ComplexTransform const kTransformA = ComplexTransform::kNone; + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + using LayoutScaleBias = layout::RowMajor; + /// Define the kernel + /// Define the threadblock-scoped matrix multiply-accumulate + static_assert(ArchTag::kMinComputeCapability >= 80, "Only multistage is implemented"); + static_assert(kStages >= 3, "Only multistage is implemented"); + using Mma0 = typename cutlass::gemm::threadblock::DefaultMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB0, kAlignmentB, + ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, + ThreadblockShape, WarpShape, + InstructionShape, Stages, Operator>::ThreadblockMma; + using Mma1 = typename cutlass::gemm::threadblock::DefaultMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB1, kAlignmentB, + ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, + ThreadblockShape, WarpShape, + InstructionShape, Stages, Operator>::ThreadblockMma; + using DualMma = threadblock::DualMmaMultistage< + typename Mma0::Shape, + typename Mma0::IteratorA, + typename Mma0::SmemIteratorA, + Mma0::kCacheOpA, + typename Mma0::IteratorB, + typename Mma0::SmemIteratorB, + Mma0::kCacheOpB, + typename Mma1::IteratorB, + typename Mma1::SmemIteratorB, + typename Mma0::ElementC, + typename Mma0::LayoutC, + typename Mma0::Policy, + typename Mma1::Policy, + Mma0::kStages, + SharedMemoryClearOption::kNone + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + /// Define the epilogue + using Epilogue0 = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, typename DualMma::Operator0, kPartitionsK, EpilogueOutputOp0, + EpilogueOutputOp0::kCount>::Epilogue; + using Epilogue1 = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, typename DualMma::Operator1, kPartitionsK, EpilogueOutputOp1, + EpilogueOutputOp1::kCount>::Epilogue; + + /// Define the kernel-level GEMM operator. + using DualGemmKernel = kernel::DualGemm< + DualMma, + Epilogue0, Epilogue1, EpilogueOutputOp2, + ThreadblockSwizzle, kSplitKSerial, + kStoreD0, kStoreD1>; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + DualGemmMode mode; + GemmCoord problem_size; + TensorRef ref_A0; + TensorRef ref_B0; + TensorRef ref_C0; + TensorRef ref_D0; + TensorRef ref_B1; + TensorRef ref_C1; + TensorRef ref_D1; + TensorRef ref_D2; + typename EpilogueOutputOp0::Params epilogue0; + typename EpilogueOutputOp1::Params epilogue1; + typename EpilogueOutputOp2::Params epilogue2; + int split_k_slices; + + int batch_count; + int64_t batch_stride_A; + int64_t batch_stride_B0; + int64_t batch_stride_B1; + int64_t batch_stride_C; + int64_t batch_stride_D; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments(): problem_size(0, 0, 0), split_k_slices(1) { + + } + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + DualGemmMode mode, + GemmCoord problem_size_, + TensorRef ref_A0_, + TensorRef ref_B0_, + TensorRef ref_C0_, + TensorRef ref_D0_, + TensorRef ref_B1_, + TensorRef ref_C1_, + TensorRef ref_D1_, + TensorRef ref_D2_, + typename EpilogueOutputOp0::Params epilogue0_ = + typename EpilogueOutputOp0::Params(), + typename EpilogueOutputOp1::Params epilogue1_ = + typename EpilogueOutputOp1::Params(), + typename EpilogueOutputOp2::Params epilogue2_ = + typename EpilogueOutputOp2::Params(), + int split_k_slices_ = 1, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B0 = 0, + int64_t batch_stride_B1 = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0 + ): + mode(mode), + problem_size(problem_size_), + ref_A0(ref_A0_), + ref_B0(ref_B0_), + ref_C0(ref_C0_), + ref_D0(ref_D0_), + ref_B1(ref_B1_), + ref_C1(ref_C1_), + ref_D1(ref_D1_), + ref_D2(ref_D2_), + epilogue0(epilogue0_), + epilogue1(epilogue1_), + epilogue2(epilogue2_), + split_k_slices(split_k_slices_), + batch_count(batch_count), + batch_stride_A(batch_stride_A), + batch_stride_B0(batch_stride_B0), + batch_stride_B1(batch_stride_B1), + batch_stride_C(batch_stride_C), + batch_stride_D(batch_stride_D) { + + } + }; + +private: + + /// Kernel parameters object + typename DualGemmKernel::Params params_; + +public: + + /// Constructs the GEMM. + DualGemm() = default; + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args) { + + if (args.mode == DualGemmMode::kBatched && kSplitKSerial) { + return Status::kErrorInvalidProblem; + } + if (!kSplitKSerial && args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + if (kStoreD0 != (args.ref_D0.data() != nullptr)) { + return Status::kErrorInternal; + } + if (kStoreD1 != (args.ref_D1.data() != nullptr)) { + return Status::kErrorInternal; + } + + Status status = DualGemmKernel::can_implement( + args.problem_size, + args.ref_A0.non_const_ref(), + args.ref_B0.non_const_ref(), + args.ref_C0.non_const_ref(), + args.ref_D0, + args.ref_B1.non_const_ref(), + args.ref_C1.non_const_ref(), + args.ref_D1, + args.ref_D2 + ); + + if (status != Status::kSuccess) { + return status; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + size_t bytes = 0; + + if (kSplitKSerial && args.split_k_slices > 1) { + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + + bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); + } + + return bytes; + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.mode == DualGemmMode::kBatched ? args.batch_count : args.split_k_slices); + + if (kSplitKSerial) { + if (args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + + size_t bytes = get_workspace_size(args); + + cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + } + else { + + if (args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + } + + // Initialize the Params structure + params_ = typename DualGemmKernel::Params{ + args.mode, + args.problem_size, + grid_shape, + args.ref_A0.non_const_ref(), + args.ref_B0.non_const_ref(), + args.ref_C0.non_const_ref(), + args.ref_D0, + args.ref_B1.non_const_ref(), + args.ref_C1.non_const_ref(), + args.ref_D1, + args.ref_D2, + args.epilogue0, + args.epilogue1, + args.epilogue2, + reinterpret_cast(workspace), + args.batch_stride_A, + args.batch_stride_B0, + args.batch_stride_B1, + args.batch_stride_C, + args.batch_stride_D, + }; + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + if (kSplitKSerial && args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + } + + params_.ref_A0.reset(args.ref_A0.non_const_ref().data()); + params_.ref_B0.reset(args.ref_B0.non_const_ref().data()); + params_.ref_C0.reset(args.ref_C0.non_const_ref().data()); + params_.ref_D0.reset(args.ref_D0.data()); + params_.ref_B1.reset(args.ref_B1.non_const_ref().data()); + params_.ref_C1.reset(args.ref_C1.non_const_ref().data()); + params_.ref_D1.reset(args.ref_D1.data()); + params_.ref_D2.reset(args.ref_D2.data()); + params_.output_op_0 = args.epilogue0; + params_.output_op_1 = args.epilogue1; + params_.output_op_2 = args.epilogue2; + params_.semaphore = reinterpret_cast(workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(DualGemmKernel::kThreadCount, 1, 1); + + cudaError_t result; + + int smem_size = int(sizeof(typename DualGemmKernel::SharedStorage)); + if (smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + cutlass::Kernel<<>>(params_); + + result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/dual_gemm_common.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/dual_gemm_common.h new file mode 100644 index 0000000000000000000000000000000000000000..12ebe05ea760c97dcdc895266bc82ccb97b5abbf --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/dual_gemm_common.h @@ -0,0 +1,52 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines common types used for all DualGemm operators. +*/ +#pragma once + +namespace cutlass { +namespace gemm { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +enum class DualGemmMode { + kGemm, + kBatched, + kInvalid +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/dual_gemm_run.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/dual_gemm_run.h new file mode 100644 index 0000000000000000000000000000000000000000..d042f93cac7f9830c2ce780c482c7034c64b8ac9 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/dual_gemm_run.h @@ -0,0 +1,938 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include +#include +#include + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_relu.h" + +#include "cutlass/platform/platform.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/device/gemm_universal.h" + +#include "dual_gemm_common.h" +#include "helper.h" + +#define CHECK_GT(val1, val2) \ + if((val1) <= (val2)) \ + std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n"; +#define CHECK_TRUE(val) \ + if(!(val)) \ + std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n"; + +template < + typename OutputOp, + typename Element, + typename Layout> +struct TensorEpilogueForEachFunc { + /// View type + using TensorView = cutlass::TensorView; + + /// Coordinate in tensor's index space + using TensorCoord = typename TensorView::TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view_x0; + TensorView view_x1; + TensorView view_y; + OutputOp output_op; + + + // + // Methods + // + + Params( + TensorView view_x0_ = TensorView(), + TensorView view_x1_ = TensorView(), + TensorView view_y_ = TensorView(), + OutputOp output_op_ = OutputOp(typename OutputOp::Params{}) + ): + view_x0(view_x0_), view_x1(view_x1_), view_y(view_y_), output_op(output_op_) { + } + }; + + Params params; + + CUTLASS_DEVICE + TensorEpilogueForEachFunc(Params const ¶ms): params(params) { + + } + + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + Element const & x0 = params.view_x0.at(coord); + Element const & x1 = params.view_x1.at(coord); + Element& y = params.view_y.at(coord); + y = params.output_op(x0, x1); + } +}; + +template < + typename OutputOp, + typename Element, + typename Layout> +void TensorEpilogueForEach( + cutlass::TensorView x0, + cutlass::TensorView x1, + cutlass::TensorView y) { + + using Func = TensorEpilogueForEachFunc; + using Params = typename Func::Params; + + cutlass::reference::device::TensorForEach( + y.extent(), + Params(x0, x1, y) + ); +} + +//////////////////////////////////////////////////////////////////////////////// + +template +struct NonFusedDualGemmRun +{ + + using Gemm0 = Gemm0_; + using Gemm1 = Gemm1_; + using ElementAccumulator = typename Gemm0::ElementAccumulator; + using ElementCompute = typename Gemm0::GemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_Bias; + uint64_t seed; + + // + // Methods + // + + NonFusedDualGemmRun( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, 2, -2, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { + std::cerr << "Not implemented\n"; + return false; + } + + return true; + } + + + + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha0 = ElementCompute(1), + ElementCompute beta0 = ElementCompute(0), + ElementCompute alpha1 = ElementCompute(1), + ElementCompute beta1 = ElementCompute(0), + bool is_profiling = true, + bool relu = false, + int warm_ups = 1, + int runs = 100) { + + // + // Allocate the GEMM workspace + // + + cutlass::HostTensor< + typename Gemm0::ElementA, + typename Gemm0::LayoutA> tensor_A0(problem_size.mk()); + + cutlass::HostTensor< + typename Gemm0::ElementB, + typename Gemm0::LayoutB> tensor_B0(problem_size.kn()); + + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm0::LayoutC> tensor_C0(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm0::LayoutC> tensor_Bias0({1, problem_size.n()}); + + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm0::LayoutC> tensor_D0(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm0::LayoutC> reference_D0(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm1::ElementB, + typename Gemm1::LayoutB> tensor_B1(problem_size.kn()); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm1::LayoutC> tensor_C1(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm1::LayoutC> tensor_Bias1({1, problem_size.n()}); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm1::LayoutC> tensor_D1(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm1::LayoutC> reference_D1(problem_size.mn()); + + + CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); + CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018)); + CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); + CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2014)); + CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016)); + CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); + CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2013)); + + cutlass::reference::host::TensorFill( + tensor_D0.host_view()); + cutlass::reference::host::TensorFill( + tensor_D1.host_view()); + cutlass::reference::host::TensorFill( + reference_D0.host_view()); + cutlass::reference::host::TensorFill( + reference_D1.host_view()); + + tensor_A0.sync_device(); + tensor_B0.sync_device(); + tensor_C0.sync_device(); + tensor_Bias0.sync_device(); + tensor_D0.sync_device(); + reference_D0.sync_device(); + tensor_B1.sync_device(); + tensor_C1.sync_device(); + tensor_Bias1.sync_device(); + tensor_D1.sync_device(); + reference_D1.sync_device(); + + // + // Initialize the GEMM operator + // + + int split_k_slices = Gemm0::kSplitKSerial ? 2 : 1; + typename Gemm0::Arguments arguments_0{ + problem_size, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}, + tensor_D0.device_ref(), + {alpha0, beta0}, + split_k_slices + }; + + split_k_slices = Gemm1::kSplitKSerial ? 2 : 1; + typename Gemm1::Arguments arguments_1{ + problem_size, + tensor_A0.device_ref(), + tensor_B1.device_ref(), + {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}, + tensor_D1.device_ref(), + {alpha1, beta1}, + split_k_slices + }; + + + Gemm0 gemm_op_0; + Gemm1 gemm_op_1; + + // Allocate workspace memory + cutlass::device_memory::allocation workspace0(gemm_op_0.get_workspace_size(arguments_0)); + cutlass::device_memory::allocation workspace1(gemm_op_1.get_workspace_size(arguments_1)); + + cutlass::Status status = gemm_op_0.initialize(arguments_0, workspace0.get()); + + CUTLASS_CHECK(status); + + status = gemm_op_1.initialize(arguments_1, workspace1.get()); + + CUTLASS_CHECK(status); + + for(int i = 0; i < warm_ups; i++) { + status = gemm_op_0(); + CUTLASS_CHECK(status); + status = gemm_op_1(); + CUTLASS_CHECK(status); + } + + if (is_profiling) { + // + // Profile the GEMM + // + + cudaEvent_t start, stop1, stop2; + cudaEventCreate(&start); + cudaEventCreate(&stop1); + cudaEventCreate(&stop2); + + cudaEventRecord(start); + + for(int i = 0; i < runs; i++) { + status = gemm_op_0(); + + CUTLASS_CHECK(status); + } + cudaEventRecord(stop1); + for(int i = 0; i < runs; i++) { + status = gemm_op_1(); + + CUTLASS_CHECK(status); + } + + cudaEventRecord(stop2); + cudaDeviceSynchronize(); + float gemm0Time, gemm1Time, totalTime; + cudaEventElapsedTime(&gemm0Time, start, stop1); + cudaEventElapsedTime(&gemm1Time, stop1, stop2); + cudaEventElapsedTime(&totalTime, start, stop2); + std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n"; + std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n"; + std::cout << "Non-fusion GEMM only time " << totalTime / (float)runs << " ms\n"; + } + + tensor_D0.sync_host(); + tensor_D1.sync_host(); + + // + // Verify + // + cutlass::reference::device::Gemm< + typename Gemm0::ElementA, typename Gemm0::LayoutA, + typename Gemm0::ElementB, typename Gemm0::LayoutB, + typename Gemm0::ElementC, typename Gemm0::LayoutC, ElementCompute, + ElementAccumulator, typename Gemm0::Operator> + reference_gemm_0; + + cutlass::reference::device::Gemm< + typename Gemm1::ElementA, typename Gemm1::LayoutA, + typename Gemm1::ElementB, typename Gemm1::LayoutB, + typename Gemm1::ElementC, typename Gemm1::LayoutC, ElementCompute, + ElementAccumulator, typename Gemm1::Operator> + reference_gemm_1; + + reference_gemm_0( + problem_size, + alpha0, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + beta0, + {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}, + reference_D0.device_ref() + ); + + if(relu) { + cutlass::reference::device::TensorReLu(reference_D0.device_view()); + } + + reference_gemm_1( + problem_size, + alpha1, + tensor_A0.device_ref(), + tensor_B1.device_ref(), + beta1, + {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}, + reference_D1.device_ref() + ); + + if(relu) { + cutlass::reference::device::TensorReLu(reference_D1.device_view()); + } + + // Wait for kernels to finish + cudaDeviceSynchronize(); + reference_D0.sync_host(); + reference_D1.sync_host(); + + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); + + bool passed0 = cutlass::reference::host::TensorEquals( + reference_D1.host_view(), + tensor_D1.host_view()); + CHECK_TRUE(passed0); + + bool passed1 = cutlass::reference::host::TensorEquals( + reference_D1.host_view(), + tensor_D1.host_view()); + CHECK_TRUE(passed1); + if (!passed0 || !passed1) { + + std::stringstream fname; + + fname << "error_DualGemm_device_nonfused.txt"; + std::cerr << "Dumping results in " << fname.str() << "\n"; + + std::ofstream file(fname.str()); + + file + << "A0 =\n" << tensor_A0.host_view() + << "\nB0 =\n" << tensor_B0.host_view() + << "\nC0 =\n" << tensor_C0.host_view() + << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" + << "\nD0 =\n" << tensor_D0.host_view() + << "\nB1 =\n" << tensor_B1.host_view() + << "\nC1 =\n" << tensor_C1.host_view() + << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" + << "\n\nReference =\n" << reference_D1.host_view() + << "\nComputed =\n" << tensor_D1.host_view(); + } + return passed0 && passed1; + } +}; + +template +struct DualFusedGemmRun +{ + + using DualGemm = DualGemm_; + using ElementAccumulator = typename DualGemm::ElementAccumulator; + using ElementCompute = typename DualGemm::DualGemmKernel::Epilogue0::OutputOp::ElementCompute; + using EpilogueOutputOp2 = typename DualGemm::EpilogueOutputOp2; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_Scale; + cutlass::Distribution::Kind init_Bias; + uint64_t seed; + + // + // Methods + // + + DualFusedGemmRun( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), + init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, 2, -2, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { + std::cerr << "Not implemented\n"; + return false; + } + + return true; + } + + + + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha0 = ElementCompute(1), + ElementCompute beta0 = ElementCompute(1), + ElementCompute alpha1 = ElementCompute(1), + ElementCompute beta1 = ElementCompute(1), + int batch_count = 1, + bool broadcast_b1 = false, + bool is_profiling = true, + bool relu = false, + int warm_ups = 1, + int runs = 100) { + + // + // Allocate the GEMM workspace + // + + cutlass::HostTensor< + typename DualGemm::ElementA, + typename DualGemm::LayoutA> tensor_A0( + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.k()) : + cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.k())); + + cutlass::HostTensor< + typename DualGemm::ElementB, + typename DualGemm::LayoutB0> tensor_B0( + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.k(), problem_size.n()) : + cutlass::MatrixCoord(problem_size.k(), batch_count * problem_size.n())); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> tensor_C0( + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : + cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n())); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutScaleBias> tensor_Bias0({batch_count, problem_size.n()}); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> tensor_D0( + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : + cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n())); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> reference_D0( + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : + cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n())); + + cutlass::HostTensor< + typename DualGemm::ElementB, + typename DualGemm::LayoutB1> tensor_B1( + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.k(), problem_size.n()) : + cutlass::MatrixCoord(problem_size.k(), batch_count * problem_size.n())); + if (broadcast_b1) { + tensor_B1.resize({problem_size.k(), batch_count}); + } + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> tensor_C1( + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : + cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n())); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutScaleBias> tensor_Bias1({batch_count, problem_size.n()}); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> tensor_D1( + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : + cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n())); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> tensor_D2( + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : + cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n())); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> reference_D1( + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : + cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n())); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> reference_D2( + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : + cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n())); + + CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); + CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2118)); + CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); + CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2011)); + CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2113)); + CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); + CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2012)); + + cutlass::reference::host::TensorFill( + tensor_D0.host_view()); + cutlass::reference::host::TensorFill( + tensor_D1.host_view()); + cutlass::reference::host::TensorFill( + tensor_D2.host_view()); + cutlass::reference::host::TensorFill( + reference_D0.host_view()); + cutlass::reference::host::TensorFill( + reference_D1.host_view()); + cutlass::reference::host::TensorFill( + reference_D2.host_view()); + + tensor_A0.sync_device(); + tensor_B0.sync_device(); + tensor_C0.sync_device(); + tensor_Bias0.sync_device(); + tensor_B1.sync_device(); + tensor_C1.sync_device(); + tensor_Bias1.sync_device(); + tensor_D0.sync_device(); + tensor_D1.sync_device(); + tensor_D2.sync_device(); + reference_D0.sync_device(); + reference_D1.sync_device(); + reference_D2.sync_device(); + + // + // Batch strides (irrelevant when batch_count == 1) + // + + int64_t batch_stride_A = problem_size.m() * problem_size.k(); + int64_t batch_stride_B0 = problem_size.k() * problem_size.n(); + int64_t batch_stride_B1 = problem_size.k() * problem_size.n(); + if (broadcast_b1) { + // B1 is a (column) vector + batch_stride_B1 = problem_size.k(); + } + int64_t batch_stride_Bias = problem_size.n(); + int64_t batch_stride_D = problem_size.m() * problem_size.n(); + + // + // Initialize the GEMM operator + // + + int split_k_slices = DualGemm::kSplitKSerial ? 2 : 1; + typename cutlass::TensorRef nullptr_ref{}; + decltype(nullptr_ref) ref_B0, ref_B1; + if (beta0 != ElementCompute(0)) { + ref_B0 = {tensor_Bias0.device_data(), typename DualGemm::LayoutC::Stride(0)}; + } + if (beta1 != ElementCompute(0)) { + ref_B1 = {tensor_Bias1.device_data(), typename DualGemm::LayoutC::Stride(0)}; + } + typename DualGemm::Arguments arguments{ + (batch_count > 1 ? + cutlass::gemm::DualGemmMode::kBatched : + cutlass::gemm::DualGemmMode::kGemm), + problem_size, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + ref_B0, + DualGemm::kStoreD0 ? tensor_D0.device_ref() : nullptr_ref, + (broadcast_b1 ? + typename DualGemm::TensorRefB1(tensor_B1.device_data(), 0) : + tensor_B1.device_ref()), + ref_B1, + DualGemm::kStoreD1 ? tensor_D1.device_ref() : nullptr_ref, + tensor_D2.device_ref(), + {alpha0, beta0}, + {alpha1, beta1}, + {}, + split_k_slices, + batch_count, + batch_stride_A, + batch_stride_B0, + batch_stride_B1, + batch_stride_Bias, + batch_stride_D, + }; + + // + // Run the GEMM + // + + DualGemm b2b_gemm_op; + + cutlass::device_memory::allocation workspace(b2b_gemm_op.get_workspace_size(arguments)); + + cutlass::Status status = b2b_gemm_op.can_implement(arguments); + + CUTLASS_CHECK(status); + + status = b2b_gemm_op.initialize(arguments, workspace.get()); + + CUTLASS_CHECK(status); + + for(int i = 0; i < warm_ups; i++) { + status = b2b_gemm_op(); + CUTLASS_CHECK(status); + } + + if (is_profiling) { + // + // Profile the GEMM + // + + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + cudaEventRecord(start); + + for(int i = 0; i < runs; i++) { + status = b2b_gemm_op(); + CUTLASS_CHECK(status); + } + + cudaEventRecord(stop); + cudaDeviceSynchronize(); + float gemmTime; + cudaEventElapsedTime(&gemmTime, start, stop); + std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n"; + } + + tensor_D0.sync_host(); + tensor_D1.sync_host(); + tensor_D2.sync_host(); + + // + // Verify + // + + using GemmUniversal0 = cutlass::gemm::device::GemmUniversal< + typename DualGemm::ElementA, typename DualGemm::LayoutA, + typename DualGemm::ElementB, typename DualGemm::LayoutB0, + typename DualGemm::ElementC, typename DualGemm::LayoutC, + ElementAccumulator + >; + + GemmUniversal0 reference_gemm0; + + typename GemmUniversal0::Arguments args0 { + (batch_count > 1 ? + cutlass::gemm::GemmUniversalMode::kBatched : + cutlass::gemm::GemmUniversalMode::kGemm), + problem_size, + batch_count, + {alpha0, beta0}, + tensor_A0.device_data(), + tensor_B0.device_data(), + tensor_Bias0.device_data(), + reference_D0.device_data(), + batch_stride_A, + batch_stride_B0, + batch_stride_Bias, + batch_stride_D, + tensor_A0.stride(0), + tensor_B0.stride(0), + 0, // zero stride for the bias vector + reference_D0.stride(0), + }; + + status = reference_gemm0.can_implement(args0); + CUTLASS_CHECK(status); + status = reference_gemm0(args0); + CUTLASS_CHECK(status); + + using GemmUniversal1 = cutlass::gemm::device::GemmUniversal< + typename DualGemm::ElementA, typename DualGemm::LayoutA, + typename DualGemm::ElementB, typename DualGemm::LayoutB1, + typename DualGemm::ElementC, typename DualGemm::LayoutC, + ElementAccumulator + >; + + GemmUniversal1 reference_gemm1; + + typename GemmUniversal1::Arguments args1 { + (batch_count > 1 ? + cutlass::gemm::GemmUniversalMode::kBatched : + cutlass::gemm::GemmUniversalMode::kGemm), + problem_size, + batch_count, + {alpha1, beta1}, + tensor_A0.device_data(), + tensor_B1.device_data(), + tensor_Bias1.device_data(), + reference_D1.device_data(), + batch_stride_A, + batch_stride_B1, + batch_stride_Bias, + batch_stride_D, + tensor_A0.stride(0), + (broadcast_b1 ? 0 : tensor_B1.stride(0)), + 0, // zero stride for the bias vector + reference_D1.stride(0), + }; + + status = reference_gemm1.can_implement(args1); + CUTLASS_CHECK(status); + status = reference_gemm1(args1); + CUTLASS_CHECK(status); + + if(relu) { + cutlass::reference::device::TensorReLu(reference_D0.device_view()); + cutlass::reference::device::TensorReLu(reference_D1.device_view()); + } + + TensorEpilogueForEach(reference_D0.device_view(), reference_D1.device_view(), reference_D2.device_view()); + cudaDeviceSynchronize(); + reference_D0.sync_host(); + reference_D1.sync_host(); + reference_D2.sync_host(); + + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D2.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D2.host_view()), 0); + + bool passed_out0 = true; + if (DualGemm::kStoreD0) { + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0); + passed_out0 = cutlass::reference::host::TensorEquals( + reference_D0.host_view(), + tensor_D0.host_view()); + } + CHECK_TRUE(passed_out0); + + bool passed_out1 = true; + if (DualGemm::kStoreD1) { + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); + passed_out1 = cutlass::reference::host::TensorEquals( + reference_D1.host_view(), + tensor_D1.host_view()); + } + CHECK_TRUE(passed_out1); + + bool passed_out2 = cutlass::reference::host::TensorEquals( + reference_D2.host_view(), + tensor_D2.host_view()); + CHECK_TRUE(passed_out2); + + bool passed = passed_out0 && passed_out1 && passed_out2; + if (!passed) + { + std::stringstream fname; + + fname << "error_DualGemm_device_fused.txt"; + std::cerr << "Dumping results in " << fname.str() << "\n"; + + std::ofstream file(fname.str()); + + file + << "A0 =\n" << tensor_A0.host_view() + << "\nB0 =\n" << tensor_B0.host_view() + << "\nC0 =\n" << tensor_C0.host_view() + << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" + << "\nB1 =\n" << tensor_B1.host_view() + << "\nC1 =\n" << tensor_C1.host_view() + << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" + << "\n\nReference0 =\n" << reference_D0.host_view() + << "\nComputed0 =\n" << tensor_D0.host_view() + << "\n\nReference1 =\n" << reference_D1.host_view() + << "\nComputed1 =\n" << tensor_D1.host_view() + << "\n\nReference2 =\n" << reference_D2.host_view() + << "\nComputed2 =\n" << tensor_D2.host_view(); + } + //std::cout << "A0 " << tensor_A0.host_view() << std::endl; + // std::cout << "reference_D0 " << reference_D0.host_view() << std::endl; + // std::cout << "reference_D1 " << reference_D1.host_view() << std::endl; + // std::cout << "reference_D2 " << reference_D2.host_view() << std::endl; + //std::cout << "reference_D0 " << reference_D0.host_view() << std::endl; + return passed; + } + +}; + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/kernel/dual_gemm.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/kernel/dual_gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..2ff5b5e24f8927fe2c71e3a428203d52298e6e86 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/kernel/dual_gemm.h @@ -0,0 +1,545 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" + +#include "../threadblock/dual_mma_multistage.h" +#include "../threadblock/dual_epilogue.h" +#include "../dual_gemm_common.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename DualMma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue0_, ///! Epilogue + typename Epilogue1_, ///! Epilogue + typename OutputOp2_, ///! Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + bool SplitKSerial, ///! If true, code supporting split-K via serial reduction is enabled. + bool StoreD0, + bool StoreD1 +> +struct DualGemm { + + using DualMma = DualMma_; + + using Epilogue0 = Epilogue0_; + using Epilogue1 = Epilogue1_; + using OutputOp0 = typename Epilogue0::OutputOp; + using OutputOp1 = typename Epilogue1::OutputOp; + using OutputOp2 = OutputOp2_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static constexpr bool kStoreD0 = StoreD0; + static constexpr bool kStoreD1 = StoreD1; + + using DualEpilogue = cutlass::epilogue::threadblock::DualEpilogue< + typename Epilogue0::Shape, + typename Epilogue0::WarpMmaOperator, + Epilogue0::kPartitionsK, + typename Epilogue0::OutputTileIterator, + typename Epilogue0::AccumulatorFragmentIterator, + typename Epilogue0::WarpTileIterator, + typename Epilogue0::SharedLoadIterator, + OutputOp0, + OutputOp1, + OutputOp2, + typename Epilogue0::Padding, + kStoreD0, + kStoreD1, + Epilogue0::kFragmentsPerIteration, + true // IterationsUnroll + >; + + using ElementA = typename DualMma::IteratorA::Element; + using ElementB = typename DualMma::IteratorB0::Element; + using ElementC = typename DualEpilogue::OutputTileIterator::Element; + + static bool const kSplitKSerial = SplitKSerial; + static_assert(!kSplitKSerial || (kStoreD0 && kStoreD1), + "Split-K serial requires buffers for D0/D1 for reduction"); + + /// Warp count (concept: GemmShape) + using WarpCount0 = typename DualMma::WarpCount; + static int const kThreadCount = 32 * WarpCount0::kCount; + + /// Parameters structure + struct Params { + DualGemmMode mode; + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + + // Mma0 + typename DualMma::IteratorA::Params params_A0; + typename DualMma::IteratorA::TensorRef ref_A0; + typename DualMma::IteratorB0::Params params_B0; + typename DualMma::IteratorB0::TensorRef ref_B0; + typename Epilogue0::OutputTileIterator::Params params_C0; + typename Epilogue0::OutputTileIterator::TensorRef ref_C0; + typename Epilogue0::OutputTileIterator::Params params_D0; + typename Epilogue0::OutputTileIterator::TensorRef ref_D0; + typename OutputOp0::Params output_op_0; + + // Mma1 + typename DualMma::IteratorB1::Params params_B1; + typename DualMma::IteratorB1::TensorRef ref_B1; + typename Epilogue1::OutputTileIterator::Params params_C1; + typename Epilogue1::OutputTileIterator::TensorRef ref_C1; + typename Epilogue1::OutputTileIterator::Params params_D1; + typename Epilogue1::OutputTileIterator::TensorRef ref_D1; + typename OutputOp1::Params output_op_1; + + typename Epilogue1::OutputTileIterator::Params params_D2; + typename Epilogue1::OutputTileIterator::TensorRef ref_D2; + typename OutputOp2::Params output_op_2; + + int *semaphore; + int gemm_k_size; + + int64_t batch_stride_A; + int64_t batch_stride_B0; + int64_t batch_stride_B1; + int64_t batch_stride_C; + int64_t batch_stride_D; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): swizzle_log_tile(0), semaphore(0), gemm_k_size(0) { } + + CUTLASS_HOST_DEVICE + Params( + DualGemmMode mode, + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmCoord const & grid_tiled_shape, + // Mma0: D0 = A @ B0 + C0 + typename DualMma::IteratorA::TensorRef ref_A0, + typename DualMma::IteratorB0::TensorRef ref_B0, + typename Epilogue0::OutputTileIterator::TensorRef ref_C0, + typename Epilogue0::OutputTileIterator::TensorRef ref_D0, + // Mma1: D1 = A @ B1 + C1 + typename DualMma::IteratorB1::TensorRef ref_B1, + typename Epilogue1::OutputTileIterator::TensorRef ref_C1, + typename Epilogue1::OutputTileIterator::TensorRef ref_D1, + + typename Epilogue1::OutputTileIterator::TensorRef ref_D2, + typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(), + typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(), + typename OutputOp2::Params output_op_2 = typename OutputOp2::Params(), + int *workspace = nullptr, + int64_t batch_stride_A = 1, + int64_t batch_stride_B0 = 1, + int64_t batch_stride_B1 = 1, + int64_t batch_stride_C = 1, + int64_t batch_stride_D = 1 + ): + mode(mode), + problem_size(problem_size), + grid_tiled_shape(grid_tiled_shape), + swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + // Mma0 + params_A0(ref_A0.layout()), + ref_A0(ref_A0), + params_B0(ref_B0.layout()), + ref_B0(ref_B0), + params_C0(ref_C0.layout()), + ref_C0(ref_C0), + params_D0(ref_D0.layout()), + ref_D0(ref_D0), + // Mma1 + params_B1(ref_B1.layout()), + ref_B1(ref_B1), + params_C1(ref_C1.layout()), + ref_C1(ref_C1), + params_D1(ref_D1.layout()), + ref_D1(ref_D1), + params_D2(ref_D2.layout()), + ref_D2(ref_D2), + output_op_0(output_op_0), + output_op_1(output_op_1), + output_op_2(output_op_2), + batch_stride_A(batch_stride_A), + batch_stride_B0(batch_stride_B0), + batch_stride_B1(batch_stride_B1), + batch_stride_C(batch_stride_C), + batch_stride_D(batch_stride_D) { + + int total_gemm_k_iterations = (problem_size.k() + DualMma::Shape::kK - 1) / DualMma::Shape::kK; + int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); + gemm_k_size = gemm_k_iterations * DualMma::Shape::kK; + + semaphore = workspace; + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename DualMma::SharedStorage main_loop; + typename DualEpilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + DualGemm() { } + + /// Determines whether kernel satisfies alignment + static Status can_implement( + cutlass::gemm::GemmCoord const & problem_size, + typename DualMma::IteratorA::TensorRef ref_A0, + typename DualMma::IteratorB0::TensorRef ref_B0, + typename Epilogue0::OutputTileIterator::TensorRef ref_C0, + typename Epilogue0::OutputTileIterator::TensorRef ref_D0, + typename DualMma::IteratorB1::TensorRef ref_B1, + typename Epilogue1::OutputTileIterator::TensorRef ref_C1, + typename Epilogue1::OutputTileIterator::TensorRef ref_D1, + typename Epilogue1::OutputTileIterator::TensorRef ref_D2) { + + static int const kAlignmentA = DualMma::IteratorA::AccessType::kElements; + static int const kAlignmentB = DualMma::IteratorB0::AccessType::kElements; + static int const kAlignmentC = Epilogue0::OutputTileIterator::kElementsPerAccess; + + if (!TensorRef_aligned(ref_A0, kAlignmentA)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_B0, kAlignmentB)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_C0, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_D0, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_B1, kAlignmentB)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_C1, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_D1, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_D2, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + return Status::kSuccess; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + + return; + } + + int offset_k = 0; + int problem_size_k = params.problem_size.k(); + + ElementA *ptr_A0 = static_cast(params.ref_A0.data()); + ElementB *ptr_B0 = static_cast(params.ref_B0.data()); + ElementB *ptr_B1 = static_cast(params.ref_B1.data()); + + // + // Fetch pointers based on mode. + // + if (params.mode == DualGemmMode::kGemm) { + if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } + + offset_k = threadblock_tile_offset.k() * params.gemm_k_size; + } + else if (params.mode == DualGemmMode::kBatched) { + ptr_A0 += threadblock_tile_offset.k() * params.batch_stride_A; + ptr_B0 += threadblock_tile_offset.k() * params.batch_stride_B0; + ptr_B1 += threadblock_tile_offset.k() * params.batch_stride_B1; + } + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A0{ + threadblock_tile_offset.m() * DualMma::Shape::kM, + offset_k, + }; + + cutlass::MatrixCoord tb_offset_B0{ + offset_k, + threadblock_tile_offset.n() * DualMma::Shape::kN + }; + + cutlass::MatrixCoord tb_offset_B1{ + offset_k, + threadblock_tile_offset.n() * DualMma::Shape::kN + }; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename DualMma::IteratorA iterator_A0( + params.params_A0, + ptr_A0, + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A0); + + typename DualMma::IteratorB0 iterator_B0( + params.params_B0, + ptr_B0, + {problem_size_k, params.problem_size.n()}, + thread_idx, + tb_offset_B0); + + typename DualMma::IteratorB1 iterator_B1( + params.params_B1, + ptr_B1, + {problem_size_k, params.problem_size.n()}, + thread_idx, + tb_offset_B1); + + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + + // Construct thread-scoped matrix multiply + typename DualMma::FragmentC accum0; + typename DualMma::FragmentC accum1; + accum0.clear(); + accum1.clear(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - offset_k + DualMma::Shape::kK - 1) / DualMma::Shape::kK; + + DualMma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + if (!kSplitKSerial || gemm_k_iterations > 0) { + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, + accum0, accum1, + iterator_A0, iterator_B0, iterator_B1, + accum0, accum1); + } + + // + // Epilogue + // + + OutputOp0 output_op_0(params.output_op_0); + OutputOp1 output_op_1(params.output_op_1); + OutputOp2 output_op_2(params.output_op_2); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + //assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * DualMma::Shape::kM, + threadblock_tile_offset.n() * DualMma::Shape::kN + ); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + ElementC *ptr_C0 = static_cast(params.ref_C0.data()); + ElementC *ptr_C1 = static_cast(params.ref_C1.data()); + ElementC *ptr_D0 = static_cast(params.ref_D0.data()); + ElementC *ptr_D1 = static_cast(params.ref_D1.data()); + ElementC *ptr_D2 = static_cast(params.ref_D2.data()); + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + if (params.mode == DualGemmMode::kGemm) { + // If performing a reduction via split-K, fetch the initial synchronization + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op_0.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + } + else if (params.mode == DualGemmMode::kBatched) { + ptr_C0 += threadblock_tile_offset.k() * params.batch_stride_C; + ptr_C1 += threadblock_tile_offset.k() * params.batch_stride_C; + ptr_D0 += threadblock_tile_offset.k() * params.batch_stride_D; + ptr_D1 += threadblock_tile_offset.k() * params.batch_stride_D; + ptr_D2 += threadblock_tile_offset.k() * params.batch_stride_D; + } + + // Tile iterator loading from source tensor. + typename Epilogue0::OutputTileIterator iterator_C0( + params.params_C0, + ptr_C0, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + typename Epilogue1::OutputTileIterator iterator_C1( + params.params_C1, + ptr_C1, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + // Tile iterator writing to destination tensor. + typename Epilogue0::OutputTileIterator iterator_D0( + params.params_D0, + ptr_D0, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + typename Epilogue1::OutputTileIterator iterator_D1( + params.params_D1, + ptr_D1, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + typename Epilogue1::OutputTileIterator iterator_D2( + params.params_D2, + ptr_D2, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + DualEpilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_offset.k()) { + iterator_C0 = iterator_D0; + iterator_C1 = iterator_D1; + } + + semaphore.wait(threadblock_tile_offset.k()); + + __threadfence(); + } + + // Execute the epilogue operator to update the destination tensor. + typename Epilogue0::OutputTileIterator source_iters[] = { + iterator_C0, iterator_C1 + }; + const bool writeToD2 = (!kSplitKSerial || params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1); + epilogue( + output_op_0, output_op_1, output_op_2, + iterator_D0, iterator_D1, iterator_D2, + accum0, accum1, + source_iters, + writeToD2 + ); + + // + // Release the semaphore + // + + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + __threadfence(); + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/test_run.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/test_run.h new file mode 100644 index 0000000000000000000000000000000000000000..c92af193d21860b69ba19275dbff464c95a96a71 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/test_run.h @@ -0,0 +1,95 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#include + +// Run tests on GPUs + +int testRun(int arch, std::vector & test_funcs, const std::string & test_name) { + + bool supported = false; + + int arch_major = arch / 10; + int arch_minor = arch - arch / 10 * 10; + + if(arch_major >= 8) { + // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. + // + // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples. + if (__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) { + supported = true; + } + } + else if(arch_major >= 7) { + // Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2. + // + // CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples. + if (__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)) { + supported = true; + } + } + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (props.major < arch_major || (props.major == arch_major && props.minor < arch_minor) ) { + supported = false; + } + + if (!supported) { + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + std::cout << "This example isn't supported on current architecture" << std::endl; + return 0; + } + + bool pass = true; + + std::cout << "Device: " << props.name << std::endl; + std::cout << "Arch: SM" << arch << std::endl; + std::cout << "Test: " << test_name << std::endl; + for(auto func : test_funcs) { + pass &= func(); + } + + + if(pass) + return 0; + else + return -1; + +} + diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/thread/left_silu_and_mul.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/thread/left_silu_and_mul.h new file mode 100644 index 0000000000000000000000000000000000000000..003b6f6c60d60ffd8660f531e7d332f3caec4dd7 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/thread/left_silu_and_mul.h @@ -0,0 +1,150 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing linear combination operations used by epilogues. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/epilogue/thread/linear_combination_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies a linear combination operator to an array of elements. +/// +/// D = alpha * accumulator + beta * source + uniform +/// +template < + typename ElementOutput_, ///< Data type used to load and store tensors + int Count, ///< Number of elements computed per operation. + ///< Usually it is 128/sizeof_bits, + ///< but we use 64 or 32 sometimes when there are not enough data to store + typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type + typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest +> +class LeftSiLUAndMul { +public: + + using ElementOutput = ElementOutput_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + + static int const kCount = Count; + using FragmentOutput = Array; + using FragmentAccumulator = Array; + using ComputeFragment = Array; + + static FloatRoundStyle const kRound = Round; + + struct Params{}; + +private: + + // + // Data members + // + + ElementCompute alpha_; + ElementCompute beta_; + +public: + + /// Constructs the function object, possibly loading from pointers in host memory + CUTLASS_HOST_DEVICE + LeftSiLUAndMul(Params const &/*params*/) {} + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + return true; + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) { + assert(false); + } + + /// Computes linear scaling: D = alpha * accumulator + beta * source + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const &lhs, + FragmentAccumulator const &rhs) const { + + // Convert source to interal compute numeric type + NumericArrayConverter accumulator_to_compute; + + // Convert to destination numeric type + NumericArrayConverter compute_to_output; + + ComputeFragment converted_lhs = accumulator_to_compute(lhs); + ComputeFragment converted_rhs = accumulator_to_compute(rhs); + + cutlass::epilogue::thread::SiLu silu; + cutlass::multiplies mul; + auto silu_lhs = silu(converted_lhs); + return compute_to_output(mul(silu_lhs, converted_rhs)); + } + + CUTLASS_HOST_DEVICE + ElementOutput operator()( + ElementAccumulator const& lhs, + ElementAccumulator const& rhs + ) const { + ElementCompute convert_lhs(lhs); + ElementCompute convert_rhs(rhs); + cutlass::epilogue::thread::SiLu silu; + cutlass::multiplies mul; + auto silu_lhs = silu(convert_lhs); + return ElementOutput(mul(silu_lhs, convert_rhs)); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/threadblock/dual_epilogue.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/threadblock/dual_epilogue.h new file mode 100644 index 0000000000000000000000000000000000000000..bcb34d8970b2f8e4901426db4e6f250934cbbad3 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/threadblock/dual_epilogue.h @@ -0,0 +1,424 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. + +*/ + +#pragma once +#include "cutlass/array.h" +#include CUDA_STD_HEADER(cassert) +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/layout/vector.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/functional.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +#include "cutlass/epilogue/threadblock/epilogue_base.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Epilogue operator +template < + typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) + typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) + int PartitionsK, ///< Number of partitions of the K dimension + typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors + typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators + typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM + typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM + ///< Output operator + typename OutputOp0_, + typename OutputOp1_, + typename OutputOp2_, + typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape) + bool StoreD0 = true, + bool StoreD1 = true, + int FragmentsPerPartition = 1, ///< Used to coarsten the epilogue granularity + int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large + (!IsEpilogueFunctorHeavy::value) +> +class DualEpilogue { + +public: + + using Base = EpilogueBase< + Shape_, + typename WarpMmaOperator_::Shape, + PartitionsK, + AccumulatorFragmentIterator_, + WarpTileIterator_, + Padding_, + FragmentsPerPartition>; + + using Shape = Shape_; + using WarpMmaOperator = WarpMmaOperator_; + static int const kPartitionsK = PartitionsK; + static bool constexpr kStoreD0 = StoreD0; + static bool constexpr kStoreD1 = StoreD1; + using OutputTileIterator = OutputTileIterator_; + using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; + using WarpTileIterator = WarpTileIterator_; + using SharedLoadIterator = SharedLoadIterator_; + using OutputOp0 = OutputOp0_; + using OutputOp1 = OutputOp1_; + using OutputOp2 = OutputOp2_; + using Padding = Padding_; + + using Layout = layout::RowMajor; + using LongIndex = typename Layout::LongIndex; + + /// The complete warp-level accumulator tile + using AccumulatorTile = typename Base::AccumulatorTile; + + /// Accumulator element + using ElementAccumulator = typename WarpTileIterator::Element; + + /// Output element + using ElementOutput = typename OutputTileIterator::Element; + + /// Output access size + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + /// Tensor reference to destination tensor + using TensorRef = typename OutputTileIterator::TensorRef; + + /// Tensor reference to sync tensor + using SyncTensorRef = typename cutlass::TensorRef; + + /// Const tensor reference to source tensor + using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; + + /// Array type used to output + using OutputAccessType = Array< + typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>; + + /// Array type used by output functor + using AccumulatorAccessType = Array; + + /// Number of warps + using WarpCount = typename Base::WarpCount; + + struct SharedStorage { + using Element = typename WarpTileIterator::Element; + + /// Tensor reference to shared memory allocation + using TensorRef = typename WarpTileIterator::TensorRef; + + /// Logical shape of the shared memory tile written to by all warps. + using Shape = typename Base::Shape; + + /// Shape of the shared memory allocation for the epilogue + using StorageShape = typename Base::SharedStorage::StorageShape; + + // + // Data members + // + + AlignedBuffer storage[2]; + + // + // Methods + // + + /// Returns a tensor reference to the shared memory buffer + CUTLASS_DEVICE + TensorRef reference(int i) { + return TensorRef( + storage[i].data(), + Layout::packed({StorageShape::kRow, StorageShape::kColumn})); + } + }; + + static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK; + static int constexpr kSmemPointerOffset = SharedStorage::StorageShape::kCount / kSmemTiles; + +public: + + static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements, + "Mismatch between shared load iterator and output tile iterator."); + + static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero."); + + static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess), + "Divisibility"); + +private: + + /// Loads fragment from shared memory aligned with output tensor + SharedLoadIterator shared_load_iterator0_; + SharedLoadIterator shared_load_iterator1_; + + /// Stores a warp's fragment of accumulators to SMEM + WarpTileIterator warp_tile_iterator0_; + WarpTileIterator warp_tile_iterator1_; + +public: + + /// Constructor + CUTLASS_DEVICE + DualEpilogue( + SharedStorage &shared_storage, ///< Shared storage object + int thread_idx, ///< ID of a thread within the threadblock + int warp_idx, ///< ID of warp within threadblock + int lane_idx ///< Id of thread within warp + ): + shared_load_iterator0_(shared_storage.reference(0), thread_idx), + shared_load_iterator1_(shared_storage.reference(1), thread_idx), + warp_tile_iterator0_(shared_storage.reference(0), lane_idx), + warp_tile_iterator1_(shared_storage.reference(1), lane_idx) + { + int warp_k = warp_idx / (WarpCount::kM * WarpCount::kN); + int warp_mn = warp_idx % (WarpCount::kM * WarpCount::kN); + int warp_m = warp_mn % WarpCount::kM; + int warp_n = warp_mn / WarpCount::kM; + + MatrixCoord warp_offset{warp_k * WarpCount::kM + warp_m, warp_n}; + + warp_tile_iterator0_.add_tile_offset(warp_offset); + warp_tile_iterator1_.add_tile_offset(warp_offset); + } + + /// Streams the result to global memory + CUTLASS_DEVICE + void operator()( + OutputOp0 const &output_op0, + OutputOp1 const &output_op1, + OutputOp2 const &output_op2, + OutputTileIterator dest0, + OutputTileIterator dest1, + OutputTileIterator dest2, + AccumulatorTile const &accumulator0, + AccumulatorTile const &accumulator1, + OutputTileIterator source_iterator[2], + bool writeToD2 // true if it's the final split-k + ) { + // TODO: Implement when no source is needed + + typename OutputTileIterator::Fragment source_fragment[2]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; ++i) { + source_fragment[i].clear(); + } + + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator[2] = {accumulator0, accumulator1}; + + // + // Iterate over accumulator tile + // + + #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { + + // + // Load the source + // + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; ++i) { + source_iterator[i].load(source_fragment[i]); + ++source_iterator[i]; + } + + // + // Convert and store fragment + // + + __syncthreads(); + + acc2smem_source_needed>::push( + iter, accum_fragment_iterator[0], this->warp_tile_iterator0_); + acc2smem_source_needed>::push( + iter, accum_fragment_iterator[1], this->warp_tile_iterator1_); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + typename SharedLoadIterator::Fragment aligned_accum_fragment0[kPartitionsK]; + typename SharedLoadIterator::Fragment aligned_accum_fragment1[kPartitionsK]; + + shared_load_iterator0_.load(aligned_accum_fragment0[0]); + shared_load_iterator1_.load(aligned_accum_fragment1[0]); + + // If the number of k-slices is > 1 - perform a reduction amongst the k-slices + if (kPartitionsK > 1) { + + plus add_fragments; + + CUTLASS_PRAGMA_UNROLL + for ( int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator0_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator1_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator0_.load(aligned_accum_fragment0[i]); + shared_load_iterator1_.load(aligned_accum_fragment1[i]); + aligned_accum_fragment0[0] = add_fragments(aligned_accum_fragment0[0], aligned_accum_fragment0[i]); + aligned_accum_fragment1[0] = add_fragments(aligned_accum_fragment1[0], aligned_accum_fragment1[i]); + } + + shared_load_iterator0_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset); + shared_load_iterator1_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset); + } + + // + // Compute the output result + // + + typename OutputTileIterator::Fragment output_fragment[3]; + + apply_output_operator_(output_fragment, + output_op0, output_op1, output_op2, + aligned_accum_fragment0[0], aligned_accum_fragment1[0], + source_fragment); + + + // + // Store the final result + // + + if (kStoreD0) { + dest0.store(output_fragment[0]); + ++dest0; + } + if (kStoreD1) { + dest1.store(output_fragment[1]); + ++dest1; + } + if (writeToD2) { + dest2.store(output_fragment[2]); + ++dest2; + } + } + } + +private: + + static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, "One of these must be exactly 1."); + + template + struct acc2smem_source_needed; + + template + struct acc2smem_source_needed> { + template + CUTLASS_DEVICE + static void helper(AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator &warp_tile_iterator) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { + ++accum_fragment_iterator; + } + + typename AccumulatorFragmentIterator::Fragment accum_fragment; + accum_fragment_iterator.load(accum_fragment); + warp_tile_iterator.store(accum_fragment); + } + + CUTLASS_DEVICE + static void push(size_t pos, + AccumulatorFragmentIterator const &iterator_begin, + WarpTileIterator &warp_tile_iterator) { + int dummy[] = {(pos == Seq) && (helper(iterator_begin, warp_tile_iterator), 0)...}; + } + }; + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_( + typename OutputTileIterator::Fragment (&output_fragment)[3], + OutputOp0 const &output_op0, + OutputOp1 const &output_op1, + OutputOp2 const &output_op2, + typename SharedLoadIterator::Fragment const& aligned_accum_fragment0, + typename SharedLoadIterator::Fragment const& aligned_accum_fragment1, + typename OutputTileIterator::Fragment const (&source_fragment)[2]) { + + OutputAccessType* output_frag_ptr[3] = { + reinterpret_cast(&output_fragment[0]), + reinterpret_cast(&output_fragment[1]), + reinterpret_cast(&output_fragment[2]) + }; + + AccumulatorAccessType const *compute_frag_ptr[2] = { + reinterpret_cast(&aligned_accum_fragment0), + reinterpret_cast(&aligned_accum_fragment1) + }; + + OutputAccessType const *source_frag_ptr[2] = { + reinterpret_cast(&source_fragment[0]), + reinterpret_cast(&source_fragment[1]) + }; + + int const kOutputOpIterations = + OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + + // Call the output operators + output_frag_ptr[0][i] = output_op0(compute_frag_ptr[0][i], source_frag_ptr[0][i]); + output_frag_ptr[1][i] = output_op1(compute_frag_ptr[1][i], source_frag_ptr[1][i]); + output_frag_ptr[2][i] = output_op2(output_frag_ptr[0][i], output_frag_ptr[1][i]); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/threadblock/dual_mma_base.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/threadblock/dual_mma_base.h new file mode 100644 index 0000000000000000000000000000000000000000..754719033e3ac19c7c6101db94d5a45b5d4e8b97 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/threadblock/dual_mma_base.h @@ -0,0 +1,232 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/threadblock/mma_base.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy0_, + /// B1-specific version of the policy (concept: MmaPolicy) + typename Policy1_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class DualMmaBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy0 = Policy0_; + using Policy1 = Policy1_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator0 = typename Policy0::Operator; + using Operator1 = typename Policy1::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm = typename Policy0::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator0::Policy::MmaShape::kK); + + /// Number of stages + static int const kStages = Stages; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the B operand + using TensorRefB0 = TensorRef; + using TensorRefB1 = TensorRef; + + static_assert(kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + static_assert((kWarpGemmIterations % 2) == 0, + "Inner loop iteration must be an even number."); + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the A matrix operand in shared memory + using ShapeA = MatrixShape; + + /// Shape of the B matrix operand in shared memory + using ShapeB0 = + MatrixShape; + using ShapeB1 = + MatrixShape; + + public: + // + // Data members + // + + /// Buffer for A operand + AlignedBuffer operand_A; + + /// Buffer for B operand + AlignedBuffer operand_B0; + AlignedBuffer operand_B1; + + public: + + // + // Methods + // + + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator0::LayoutA LayoutA() { + return Operator0::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator0::LayoutB LayoutB0() { + return Operator0::LayoutB::packed({ShapeB0::kRow, ShapeB0::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator1::LayoutB LayoutB1() { + return Operator1::LayoutB::packed({ShapeB1::kRow, ShapeB1::kColumn}); + } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() { + return TensorRefA{operand_A.data(), LayoutA()}; + } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB0 operand_B0_ref() { + return TensorRefB0{operand_B0.data(), LayoutB0()}; + } + CUTLASS_HOST_DEVICE + TensorRefB1 operand_B1_ref() { + return TensorRefB1{operand_B1.data(), LayoutB1()}; + } + }; + + protected: + + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator0::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator0::IteratorB warp_tile_iterator_B0_; + typename Operator1::IteratorB warp_tile_iterator_B1_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + DualMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), + warp_tile_iterator_B0_(shared_storage.operand_B0_ref(), lane_idx), + warp_tile_iterator_B1_(shared_storage.operand_B1_ref(), lane_idx) { + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/threadblock/dual_mma_multistage.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/threadblock/dual_mma_multistage.h new file mode 100644 index 0000000000000000000000000000000000000000..5109b41057ce370b1a18019dfe1709611446174f --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/threadblock/dual_mma_multistage.h @@ -0,0 +1,775 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/threadblock/mma_base.h" +#include "dual_mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B0 operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB0_, + /// Iterates over tiles of B0 operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB0_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Iterates over tiles of B1 operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB1_, + /// Iterates over tiles of B1 operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB1_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy0_, + /// B1-specific version of the policy (concept: MmaPolicy) + typename Policy1_, + /// Number of stages, + int Stages, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// Used for partial specialization + typename Enable = bool> +class DualMmaMultistage : + public DualMmaBase { +public: + ///< Base class + using Base = DualMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B0 operand in global memory + using IteratorB0 = IteratorB0_; + ///< Iterates over tiles of B1 operand in global memory + using IteratorB1 = IteratorB1_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy0 = Policy0_; + using Policy1 = Policy1_; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB0 = SmemIteratorB0_; + using SmemIteratorB1 = SmemIteratorB1_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy0::Operator::FragmentC; + + /// Warp-level Mma + using Operator0 = typename Policy0::Operator; + using Operator1 = typename Policy1::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator0::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB0 = Operator0::kTransformB; + static ComplexTransform const kTransformB1 = Operator1::kTransformB; + + /// Internal structure exposed for introspection. + struct Detail { + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB0::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + }; + + private: + + using WarpLoadedFragmentA = typename Operator0::FragmentA; + using WarpLoadedFragmentB0 = typename Operator0::FragmentB; + using WarpLoadedFragmentB1 = typename Operator1::FragmentB; + using WarpTransformedFragmentA = typename Operator0::TransformedFragmentA; + using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB; + using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB; + + private: + + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB0 smem_iterator_B0_; + SmemIteratorB1 smem_iterator_B1_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + DualMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B0_(shared_storage.operand_B0_ref(), thread_idx), + smem_iterator_B1_(shared_storage.operand_B1_ref(), thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B0_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + this->warp_tile_iterator_B1_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB0 &iterator_B0, IteratorB1 &iterator_B1, + int group_start_A = 0, int group_start_B = 0) { + iterator_A.set_iteration_index(group_start_A * + IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B0.set_iteration_index(group_start_B * + IteratorB0::kAccessesPerVector); + iterator_B1.set_iteration_index(group_start_B * + IteratorB1::kAccessesPerVector); + this->smem_iterator_B0_.set_iteration_index(group_start_B); + this->smem_iterator_B1_.set_iteration_index(group_start_B); + + // Async Copy for operand B0 + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B0_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB0::ThreadMap::kElementsPerAccess / + IteratorB0::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B0.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B0.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B0.valid()); + } + + ++iterator_B0; + } + ++this->smem_iterator_B0_; + } + } + // Async Copy for operand B1 + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB1::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B1_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / + IteratorB1::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B1.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B1.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B1.valid()); + } + + ++iterator_B1; + } + ++this->smem_iterator_B1_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC &accum0, + FragmentC &accum1, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB0 iterator_B0, + IteratorB1 iterator_B1, + ///< initial value of accumulator + FragmentC const &src_accum0, + FragmentC const &src_accum1 + ) { + + // + // Prologue + // + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations) { + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B0.clear_mask(gemm_k_iterations == 0); + iterator_B1.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B0.set_iteration_index(0); + iterator_B1.set_iteration_index(0); + this->smem_iterator_B0_.set_iteration_index(0); + this->smem_iterator_B1_.set_iteration_index(0); + + // Async Copy for operand B0 + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B0_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB0::ThreadMap::kElementsPerAccess / + IteratorB0::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B0.get(), iterator_B0.valid()); + + ++iterator_B0; + } + + ++this->smem_iterator_B0_; + } + // Async Copy for operand B1 + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB1::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B1_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / + IteratorB1::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B1.get(), iterator_B1.valid()); + + ++iterator_B1; + } + + ++this->smem_iterator_B1_; + } + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B0.add_tile_offset({1, 0}); + iterator_B1.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B0_.add_tile_offset({1, 0}); + this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum0 = src_accum0; + accum1 = src_accum1; + + // + // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels + // so that all accumulator elements outside the GEMM footprint are zero. + // + + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + + typename IteratorA::AccessType zero_A; + zero_A.clear(); + + last_smem_iterator_A.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + last_smem_iterator_A.get()); + + *dst_ptr = zero_A; + + ++last_smem_iterator_A; + } + + typename IteratorB0::AccessType zero_B; + zero_B.clear(); + + /// Iterator to write threadblock-scoped tile of B0 operand to shared memory + SmemIteratorB0 last_smem_iterator_B0(this->smem_iterator_B0_); + last_smem_iterator_B0.set_iteration_index(0); + + // Async Copy for operand B0 + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB0::AccessType *dst_ptr = + reinterpret_cast( + last_smem_iterator_B0.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B0; + } + + /// Iterator to write threadblock-scoped tile of B1 operand to shared memory + SmemIteratorB1 last_smem_iterator_B1(this->smem_iterator_B1_); + last_smem_iterator_B1.set_iteration_index(0); + + // Async Copy for operand B1 + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + + typename IteratorB1::AccessType *dst_ptr = + reinterpret_cast( + last_smem_iterator_B1.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B1; + } + } + + // Waits until stages up to the previous (kStages-2)th stage have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA warp_loaded_frag_A[2]; + WarpLoadedFragmentB0 warp_loaded_frag_B0[2]; + WarpLoadedFragmentB1 warp_loaded_frag_B1[2]; + WarpTransformedFragmentA warp_transformed_frag_A[2]; + WarpTransformedFragmentB0 warp_transformed_frag_B0[2]; + WarpTransformedFragmentB1 warp_transformed_frag_B1[2]; + + Operator0 warp_mma0; + Operator1 warp_mma1; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B0_.set_kgroup_index(0); + this->warp_tile_iterator_B1_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); + this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[0]); + this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B0_; + ++this->warp_tile_iterator_B1_; + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B0.clear_mask(gemm_k_iterations == 0); + iterator_B1.clear_mask(gemm_k_iterations == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma0.transform(warp_transformed_frag_A[0], warp_transformed_frag_B0[0], + warp_loaded_frag_A[0], warp_loaded_frag_B0[0]); + warp_mma1.transform(warp_transformed_frag_A[0], warp_transformed_frag_B1[0], + warp_loaded_frag_A[0], warp_loaded_frag_B1[0]); + + // tf32x3 kernels use staging accumulation. warp_mma uses a temporary + // accumulator and this temporary accumulator is added to the final + // accumulator once in every mainloop iteration. + plus plus_accum; + + FragmentC tmp_accum0, tmp_accum1; + + if (platform::is_same::value + || platform::is_same::value) { + + tmp_accum0.clear(); + tmp_accum1.clear(); + } + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B0_; + ++this->warp_tile_iterator_B1_; + + if (warp_mma_k > 0) { + warp_mma0.transform(warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B0[warp_mma_k % 2], + warp_loaded_frag_A[warp_mma_k % 2], + warp_loaded_frag_B0[warp_mma_k % 2]); + warp_mma1.transform(warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + warp_loaded_frag_A[warp_mma_k % 2], + warp_loaded_frag_B1[warp_mma_k % 2]); + } + + if (platform::is_same::value + || platform::is_same::value) { + + warp_mma0( + tmp_accum0, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B0[warp_mma_k % 2], + tmp_accum0 + ); + warp_mma1( + tmp_accum1, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + tmp_accum1 + ); + + if (warp_mma_k == 0) { + accum0 = plus_accum(accum0, tmp_accum0); + accum1 = plus_accum(accum1, tmp_accum1); + tmp_accum0.clear(); + tmp_accum1.clear(); + } + } else { + warp_mma0( + accum0, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B0[warp_mma_k % 2], + accum0 + ); + warp_mma1( + accum1, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + accum1 + ); + } + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_B0, iterator_B1, group_start_iteration_A, + group_start_iteration_B); + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_B0, iterator_B1, group_start_iteration_A, + group_start_iteration_B); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until stages up to the previous (kStages-2)th stage have committed. + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B0.add_tile_offset({1, 0}); + iterator_B1.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B0_.add_tile_offset({1, 0}); + this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); + this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy0::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B0_.add_tile_offset( + {-Base::kStages * Policy0::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + this->warp_tile_iterator_B1_.add_tile_offset( + {-Base::kStages * Policy1::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B0.clear_mask(gemm_k_iterations == 0); + iterator_B1.clear_mask(gemm_k_iterations == 0); + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations) { + warp_mma0.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], + warp_transformed_frag_B0[(warp_mma_k + 1) % 2], + warp_loaded_frag_A[(warp_mma_k + 1) % 2], + warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); + warp_mma1.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], + warp_transformed_frag_B1[(warp_mma_k + 1) % 2], + warp_loaded_frag_A[(warp_mma_k + 1) % 2], + warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + } + } + + } + + if (platform::is_same::value + || platform::is_same::value) { + accum0 = plus_accum(accum0, tmp_accum0); + accum1 = plus_accum(accum1, tmp_accum1); + } + + // commit and drain all pending and predicated cp.async pnz from the GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/51_hopper_gett/gett_kernel.cuh b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/51_hopper_gett/gett_kernel.cuh new file mode 100644 index 0000000000000000000000000000000000000000..609168c35480c88314229e15a4b8e405e35f5e7f --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/51_hopper_gett/gett_kernel.cuh @@ -0,0 +1,139 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/tensor.hpp" + +#include "cutlass/arch/arch.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/collective/collective_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +namespace example { + +// +// GETT entry point +// +template < + class ProblemShapeMNKL, + class ElementA, + class StrideA, + class ElementB, + class StrideB, + class ElementAccumulator, + class ElementC, + class StrideC, + class ElementD, + class StrideD, + class ElementEpilogue> +cutlass::Status +gett_kernel( + ProblemShapeMNKL problem_shape_mnkl, + ElementA const* ptr_A, StrideA stride_a_mkl, + ElementB const* ptr_B, StrideB stride_b_nkl, + ElementAccumulator _, + ElementC const* ptr_C, StrideC stride_c_mnl, + ElementD * ptr_D, StrideD stride_d_mnl, + ElementEpilogue alpha, ElementEpilogue beta, + cudaStream_t stream = 0) { + using namespace cute; + + // TileShape -- GETT configuration + // Specify the number of elements to take from each mode + // BLK_M = (M0,M1,...) BLK_N = (M0,M1,...) BLK_K = (K0,K1,...) + + // Take 128 from m0, 128 from n0, 64 from k0 + using TileShape = Shape, Shape<_128>, Shape<_64>>; + + /* Other examples: + * Take 32 elements from m0 and 4 elements from m1 + * Take 64 elements from n0 and 2 elements from n1 + * Take 8 elements from k0 and 8 elements from k1 + **/ + // using TileShape = Shape, Shape<_64,_2>, Shape<_8,_8>>; + + using EpilogueThreadOp = cutlass::epilogue::thread::LinearCombination< + ElementD, 1, ElementAccumulator, ElementEpilogue, cutlass::epilogue::thread::ScaleType::Default, + cutlass::FloatRoundStyle::round_to_nearest, ElementC>; + + // No changes are required to the default epilogue + using CollectiveEpilogue = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + cutlass::epilogue::collective::DefaultEpilogue< + ElementC, + StrideC, + StrideD, + EpilogueThreadOp, + cutlass::gemm::EpilogueDefault>>; + + // CollectiveMma for GETTs can be built using the CollectiveBuilders + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, StrideA, 128 / cutlass::sizeof_bits::value, + ElementB, StrideB, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + TileShape, Shape<_1,_2,_1>, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + // The GETT kernel is a composition of a collective mainloop and epilogue, just like any 3.x GEMM + using GettKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShapeMNKL, + CollectiveMainloop, + CollectiveEpilogue>; + + using GettOperator = cutlass::gemm::device::GemmUniversalAdapter; + + typename GettOperator::Arguments args { + cutlass::gemm::GemmUniversalMode::kBatched, + problem_shape_mnkl, + { ptr_A, stride_a_mkl, ptr_B, stride_b_nkl }, + { {alpha, beta}, ptr_C, stride_c_mnl, ptr_D, stride_d_mnl } + }; + +#if CUTLASS_DEBUG_TRACE_LEVEL > 0 + print("Problem shape:"); + print("\tM: "); print(cute::get<0>(problem_shape_mnkl)); print("\n"); + print("\tN: "); print(cute::get<1>(problem_shape_mnkl)); print("\n"); + print("\tK: "); print(cute::get<2>(problem_shape_mnkl)); print("\n"); + print("\tL: "); print(cute::get<3>(problem_shape_mnkl)); print("\n"); + print("TileSape:"); print(TileShape{}); print("\n"); +#endif + + GettOperator op; + return op(args, stream); +} + +} // namespace example diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..959e1e6392cf68f4e2b2da8e821bfa809e27b93c --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp @@ -0,0 +1,421 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/tensor.hpp" + +#include "gather_tensor.hpp" + +namespace cutlass { + ///Forward declaration + struct CudaHostAdapter; +} + +namespace cutlass::gemm::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileScheduler_, + class GatherA_, + class GatherB_ +> +class GemmGather +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + static_assert(ArchTag::kMinComputeCapability >= 90); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static_assert(cute::is_void_v or cute::is_same_v, + "Non-persistent warp-specialized kernel does not support specializing the tile scheduler."); + using TileSchedulerTag = TileScheduler_; + using TileScheduler = typename detail::TileSchedulerSelector< + TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + + using GatherA = GatherA_; + using GatherB = GatherB_; + + // Kernel level shared memory storage + struct SharedStorage { + union TensorStorage { + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + + MainloopTensorStorage mainloop; + EpilogueTensorStorage epilogue; + } tensors; + + struct PipelineStorage : cute::aligned_struct<16, _2> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + } pipelines; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + using GmemTiledCopyA = typename CollectiveMainloop::GmemTiledCopyA; + using GmemTiledCopyB = typename CollectiveMainloop::GmemTiledCopyB; + static_assert(cute::size(GmemTiledCopyA{}) == cute::size(GmemTiledCopyB{}), "Number of threads in A/B tiled copies must be the same."); + + static constexpr uint32_t NumLoadWarpGroups = cute::size(GmemTiledCopyA{}) / NumThreadsPerWarpGroup; + static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(cute::size(TiledMma{})) / NumThreadsPerWarpGroup; + static constexpr uint32_t NumWarpGroups = NumLoadWarpGroups + NumMmaWarpGroups; + static_assert(NumWarpGroups == 2 || NumWarpGroups == 3, "Number of warp groups must be 2 or 3 for good performance."); + + static constexpr uint32_t MaxThreadsPerBlock = NumWarpGroups * NumThreadsPerWarpGroup; + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + // Device side arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + GatherA gather_A{}; + GatherB gather_B{}; + }; + + // Kernel entry point API + struct Params { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + GatherA gather_A{}; + GatherB gather_B{}; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + (void) workspace; + auto problem_shape = args.problem_shape; + if constexpr (detail::Has_SwapAB_v) { + // swap M/N + get<0>(problem_shape) = get<1>(args.problem_shape); + get<1>(problem_shape) = get<0>(args.problem_shape); + } + return { + args.mode, + problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace), + args.gather_A, + args.gather_B + }; + } + + static bool + can_implement(Arguments const& args) { + bool implementable = (args.mode == GemmUniversalMode::kGemm) or + (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + return implementable; + } + + static + size_t + get_workspace_size(Arguments const& args) { + return 0; + } + + static + cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + return Status::kSuccess; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + auto cluster_shape = Shape<_1,_1,_1>{}; + auto tile_shape = TileShape{}; + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + return TileScheduler::get_tiled_cta_shape_mnl( + problem_shape_MNKL, tile_shape, cluster_shape); + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + using namespace cute; + using X = Underscore; + + // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. + #if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) + if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) { + printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); + return; + } + #endif + + enum class WarpGroupRole { + Producer = 0, + Consumer = 1, + }; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + int thread_idx = int(threadIdx.x); + int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; + int warp_group_idx = canonical_warp_group_idx(); + CUTLASS_ASSERT(warp_group_idx < NumWarpGroups); + WarpGroupRole warp_group_role = warp_group_idx < NumLoadWarpGroups ? WarpGroupRole::Producer : WarpGroupRole::Consumer; + + // Mainloop Load pipeline + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + typename MainloopPipeline::Params mainloop_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.producer_arv_count = NumLoadWarpGroups * NumThreadsPerWarpGroup; + mainloop_pipeline_params.consumer_arv_count = NumMmaWarpGroups * NumThreadsPerWarpGroup; + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params); + + // Epilogue Load pipeline + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.producer_arv_count = NumLoadWarpGroups * NumThreadsPerWarpGroup; + epi_load_pipeline_params.consumer_arv_count = NumMmaWarpGroups * NumThreadsPerWarpGroup; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + // Initialize starting pipeline states for the collectives + typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; + typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; + + // For the DMA Load (producer) we start with an opposite phase + // i.e., we skip all waits since we know that the buffer is indeed empty + PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + // Preconditions + static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + + // Separate out problem shape for convenience + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + auto M = get<0>(problem_shape_MNKL); + auto N = get<1>(problem_shape_MNKL); + auto K = get<2>(problem_shape_MNKL); + auto L = get<3>(problem_shape_MNKL); + + // Represent the full tensors + Tensor mA_mkl = make_gather_tensor(make_gmem_ptr(params.mainloop.ptr_A), make_shape(M,K,L), params.mainloop.dA, params.gather_A); //(m,k,l) + Tensor mB_nkl = make_gather_tensor(make_gmem_ptr(params.mainloop.ptr_B), make_shape(N,K,L), params.mainloop.dB, params.gather_B); //(n,k,l) + + // Get the appropriate blocks for this thread block -- potential for thread block locality + auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + TiledMma tiled_mma; + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, blk_shape, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, blk_shape, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + // Compute m_coord, n_coord, and l_coord with their post-tiled shapes + auto m_coord = idx2crd(int(blockIdx.x), shape<2>(gA_mkl)); + auto n_coord = idx2crd(int(blockIdx.y), shape<2>(gB_nkl)); + auto l_coord = idx2crd(int(blockIdx.z), shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + // Slice with m_coord and n_coord + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Get pipeline iterators and increments from tensor shapes + auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); + auto k_tile_count = size<2>(gA); + auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape); + auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape); + + // Wait for all threads in the thread block + __syncthreads(); + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue{params.epilogue, shared_storage.tensors.epilogue}; + + if (warp_group_role == WarpGroupRole::Producer) { + // Compute tile residues for predication + auto m_max_coord = M - size<0>(gA) * get<0>(blk_coord); // M - BLK_M * m_coord + auto n_max_coord = N - size<0>(gB) * get<1>(blk_coord); // N - BLK_N * n_coord + auto k_residue = K - size<1>(gA) * size<2>(gA); // K - BLK_K * k_coord_max + auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue); + + collective_mainloop.load( + mainloop_pipeline, + mainloop_pipe_producer_state, + gA, + gB, + k_tile_iter, k_tile_count, + residue_mnk, + thread_idx, + shared_storage.tensors.mainloop + ); + // Update starting mainloop pipeline state for the pipeline drain + mainloop_pipe_producer_state.advance(k_tile_count); + // Make sure mainloop consumer has been waited upon before issuing epilogue load + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + + if (collective_epilogue.is_producer_load_needed()) { + epi_load_pipe_producer_state = + collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + tiled_mma, + thread_idx, + shared_storage.tensors.epilogue + ); + collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); + } + } + else if (warp_group_role == WarpGroupRole::Consumer) { + Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) + + collective_mainloop.mma( + mainloop_pipeline, + mainloop_pipe_consumer_state, + accumulators, + k_tile_count, + warp_group_thread_idx, + shared_storage.tensors.mainloop, + params.mainloop + ); + + // Make sure the math instructions are done and free buffers before entering the epilogue + collective_mainloop.mma_tail( + mainloop_pipeline, + mainloop_pipe_consumer_state, + k_tile_count + ); + + // Epilogue and write to gD + collective_epilogue.store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + accumulators, + tiled_mma, + warp_group_thread_idx, + shared_storage.tensors.epilogue + ); + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/52_hopper_gather_scatter_fusion/gather_kernel.cuh b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/52_hopper_gather_scatter_fusion/gather_kernel.cuh new file mode 100644 index 0000000000000000000000000000000000000000..b4cafccbb555a14256fecff25e55f0f514e86173 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/52_hopper_gather_scatter_fusion/gather_kernel.cuh @@ -0,0 +1,136 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/numeric/math.hpp" + +namespace example +{ + +// Naive grid-stride loop implementation of gather +template +__global__ void +gather_kernel(Element const * __restrict__ input, + Element * __restrict__ output, + Func func, + int num_elems_input, + int num_elems_output, + cutlass::FastDivmod stride_divmod) +{ + Element const * input_b = input + blockIdx.z * num_elems_input; + Element * output_b = output + blockIdx.z * num_elems_output; + int tidx = threadIdx.x + blockIdx.x * blockDim.x; + for (int k = tidx; k < num_elems_output; k += blockDim.x * gridDim.x) { + int i,j; + stride_divmod(j, i, k); + output_b[k] = input_b[i + func(j) * stride_divmod.divisor]; + } +} + +// Gather elements along strided dimension of the tensor according to given indices +template +void +gather(Element const * input, + Element * output, + Func func, + int batch_size, + int num_elems_input, + int num_elems_output, + int stride, + cutlass::KernelHardwareInfo const& hw_info) +{ + // Upcast to uint128_t data type + int factor = 128 / cutlass::sizeof_bits::value; + assert(stride % factor == 0); + int stride_upcast = stride/factor; + int num_elems_input_upcast = num_elems_input / factor; + int num_elems_output_upcast = num_elems_output / factor; + + cutlass::FastDivmod stride_divmod(stride_upcast); + dim3 blocks(hw_info.sm_count, 1, batch_size); + gather_kernel<<>>(reinterpret_cast(input), + reinterpret_cast(output), + func, + num_elems_input_upcast, + num_elems_output_upcast, + stride_divmod); +} + +// Naive grid-stride loop implementation of scatter +template +__global__ void +scatter_kernel(Element const * __restrict__ input, + Element * __restrict__ output, + Func func, + int num_elems_input, + int num_elems_output, + cutlass::FastDivmod stride_divmod) +{ + Element const * input_b = input + blockIdx.z * num_elems_input; + Element * output_b = output + blockIdx.z * num_elems_output; + int tidx = threadIdx.x + blockIdx.x * blockDim.x; + for (int k = tidx; k < num_elems_input; k += blockDim.x * gridDim.x) { + int i,j; + stride_divmod(j, i, k); + output_b[i + func(j) * stride_divmod.divisor] = input_b[k]; + } +} + +// Gather elements along strided dimension of the tensor according to given indices +template +void +scatter(Element const * input, + Element * output, + Func func, + int batch_size, + int num_elems_input, + int num_elems_output, + int stride, + cutlass::KernelHardwareInfo const& hw_info) +{ + // Upcast to uint128_t data type + int factor = 128 / cutlass::sizeof_bits::value; + assert(stride % factor == 0); + int stride_upcast = stride/factor; + int num_elems_input_upcast = num_elems_input / factor; + int num_elems_output_upcast = num_elems_output / factor; + + cutlass::FastDivmod stride_divmod(stride_upcast); + dim3 blocks(hw_info.sm_count, 1, batch_size); + scatter_kernel<<>>(reinterpret_cast(input), + reinterpret_cast(output), + func, + num_elems_input_upcast, + num_elems_output_upcast, + stride_divmod); +} + +} // namespace example diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/52_hopper_gather_scatter_fusion/scatter_epilogue.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/52_hopper_gather_scatter_fusion/scatter_epilogue.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4292e018c16f58374181dfaf70d04e7bbd10fd44 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/52_hopper_gather_scatter_fusion/scatter_epilogue.hpp @@ -0,0 +1,222 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/detail.hpp" + +#include "cute/tensor.hpp" +#include "cute/numeric/numeric_types.hpp" + +#include "gather_tensor.hpp" + +namespace cutlass::epilogue::collective { + +/// Applies an element wise operation to all elements within the fragment +/// and scatter-writes them out to destination storage. +/// GatherC and ScatterD are types of user-defined functions that apply the +/// transoformation of the strided coordinate (e.g. through an index array). +template < + class StrideC_, + class StrideD_, + class ThreadEpilogueOp_, + class EpilogueSchedule_, + class GatherC_, + class ScatterD_ +> +class EpilogueGatherScatter { +public: + // + // Type Aliases + // + using EpilogueSchedule = EpilogueSchedule_; + + // derived types of output thread level operator + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementScalar = ElementCompute; + using ElementC = typename ThreadEpilogueOp::ElementC; + using StrideC = StrideC_; + using ElementD = typename ThreadEpilogueOp::ElementD; + using StrideD = StrideD_; + + // Every epilogue needs these two GmemTiledCopy{C,D} aliases. + // If you don't know what they should be, just use void. + using GmemTiledCopyC = void; + using GmemTiledCopyD = void; + + using GatherC = GatherC_; + using ScatterD = ScatterD_; + + static const int kOutputAlignment = ThreadEpilogueOp::kCount; + using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; + + static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + struct SharedStorage { }; + + // Host side epilogue arguments + struct Arguments { + typename ThreadEpilogueOp::Params thread_params{}; + ElementC const* ptr_C = nullptr; + StrideC dC{}; + ElementD* ptr_D = nullptr; + StrideD dD{}; + GatherC gather_C{}; + ScatterD scatter_D{}; + }; + + // Device side epilogue params + using Params = Arguments; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + [[maybe_unused]] ProblemShape const& _, + Arguments const& args, + [[maybe_unused]] void* workspace) { + return args; + } + + template + static bool + can_implement( + [[maybe_unused]] ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + return true; + } + + CUTLASS_HOST_DEVICE + EpilogueGatherScatter(Params const& params_) : params(params_) { } + + template< + class ProblemShapeMNKL, + class BlockShapeMNK, + class BlockCoordMNKL, + class FrgEngine, class FrgLayout, + class TiledMma, + class ResidueMNK + > + CUTLASS_DEVICE void + operator()( + ProblemShapeMNKL problem_shape_mnkl, + BlockShapeMNK blk_shape_MNK, + BlockCoordMNKL blk_coord_mnkl, + cute::Tensor const& accumulators, + TiledMma tiled_mma, + ResidueMNK residue_mnk, + int thread_idx, + char* smem_buf) + { + using namespace cute; + using X = Underscore; + + static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "ThreadBlock tile shape must be static"); + static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); + static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); + + (void) smem_buf; + ThreadEpilogueOp epilogue_op{params.thread_params}; + + // Separate out problem shape for convenience + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + + auto stride_c = detail::get_epilogue_stride(params.dC); + auto stride_d = detail::get_epilogue_stride(params.dD); + + // Represent the full output tensor + Tensor mC_mnl = make_gather_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), stride_c, params.gather_C); // (m,n,l) + Tensor mD_mnl = make_gather_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d, params.scatter_D); // (m,n,l) + + Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + + // Slice to get the tile this CTA is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; + Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + + // Partition source and destination tiles to match the accumulator partitioning + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) + Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N) + + static_assert(is_static::value, "Accumulator layout must be static"); + CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD), + "Source and destination must have the same number of elements."); + CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators), + "Accumulator count must have the same destination element count."); + + // Make an identity coordinate tensor for predicating our output MN tile + auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); + Tensor tCcD = thr_mma.partition_C(cD); + + // source is needed + if (epilogue_op.is_source_needed()) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators); ++i) { + if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + tCgD(i) = epilogue_op(accumulators(i), tCgC(i)); + } + } + } + // source is not needed, avoid load + else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators); ++i) { + if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + tCgD(i) = epilogue_op(accumulators(i)); + } + } + } + } + +private: + Params params; +}; + +} // namespace cutlass::epilogue::collective + diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/53_hopper_gemm_permute/permute_kernel.cuh b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/53_hopper_gemm_permute/permute_kernel.cuh new file mode 100644 index 0000000000000000000000000000000000000000..0cb1aad901891867c87c62c999be4e13de94d598 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/53_hopper_gemm_permute/permute_kernel.cuh @@ -0,0 +1,92 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Simple permutation kernel implementation. +*/ + +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/tensor_view.h" +#include "cutlass/fast_math.h" +#include "cute/numeric/numeric_types.hpp" + +namespace example +{ + +/** + * Assumes column-major input (M mode is contiguous, N mode is strided). + * For row major, the inputs must be switched accordingly. +*/ +template +__global__ void +permute_kernel(Element const* __restrict__ input, + Element* __restrict__ output, + Permute permute, + int64_t num_elems, + cutlass::FastDivmod stride_divmod) +{ + // CUTLASS 2.x batched permute functions assume 0 batch stride for target tensor + Element const * input_b = input + blockIdx.z * num_elems; + Element * output_b = output + (Batched ? 0 : blockIdx.z * num_elems); + for (int64_t k = threadIdx.x + blockIdx.x * blockDim.x; k < num_elems; k += blockDim.x * gridDim.x) + { + int i, j; + stride_divmod(j, i, k); + output_b[permute(cutlass::PitchLinearCoord(i, j))] = input_b[i + j * stride_divmod.divisor]; + } +} + +template +void permute(Element const* input, + Element * output, + int64_t num_elems, + int stride, + int batch_count, + cutlass::KernelHardwareInfo const& hw_info) +{ + // Upcast to uint128_t data type + int factor = 128 / cutlass::sizeof_bits::value; + assert(stride % factor == 0); + int stride_upcast = stride/factor; + int64_t num_elems_upcast = num_elems / factor; + Permute permute_upcast(cutlass::PitchLinearCoord(stride_upcast, int(num_elems_upcast/stride_upcast)), stride_upcast); + + cutlass::FastDivmod stride_divmod(stride); + dim3 blocks(hw_info.sm_count, 1, batch_count); + permute_kernel<<>>(reinterpret_cast(input), + reinterpret_cast(output), + permute_upcast, + num_elems_upcast, + stride_upcast); +} + +} // namespace example diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/53_hopper_gemm_permute/permute_traits.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/53_hopper_gemm_permute/permute_traits.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e7a45d7281b2b9d88414a7e37c92f20f955ebb40 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/53_hopper_gemm_permute/permute_traits.hpp @@ -0,0 +1,274 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Additional permutation information for the example. +*/ + +#include "cutlass/layout/permute.h" +#include "cutlass/gemm/gemm.h" + +namespace example +{ + +using namespace cute; + +// This struct is specialized below for different CUTLASS 2.x permutation ops +// to describe the operation in terms of target CuTe shape and stride order. +template +struct PermuteTraits {}; + +// Use X as a placeholder for shape division result +using X = Underscore; + +// Reshape a rank-2 shape into a multidimensional shape. +// Input: +// shape = (A, B, ...) +// target_shape = ((A1, ..., X, ..., Am), (B1, ..., X, ..., Bn), ...) +// Output: +// ((A1, ..., A/prod(A1..Am), ..., Am), (B1, ..., B/prod(B1..Bn), ..., Bn), ...) +template +constexpr auto +reshape(Shape const& shape, TargetShape const& target_shape) +{ + if constexpr (is_tuple::value) { + return cute::transform(shape, target_shape, [](auto && s, auto && t){ return reshape(s, t); }); + } + else { + auto idx = find_if(target_shape, [](auto x){ return is_underscore{}; }); + constexpr int I = decltype(idx)::value; + static_assert(I < tuple_size_v, "Each mode of TargetShape must contain a placeholder X"); + auto divisors = remove(target_shape); + assert(shape % product(divisors) == 0); + return replace(target_shape, shape / product(divisors)); + } +} + +// Given a tensor layout, compute a permutation layout consisting of: +// - sub-modes corresponding to the implied multidimensional shape of the source tensor +// - strides accounting for the permutation operation being performed +template +constexpr auto +make_permute_layout(Layout const& layout) { + static_assert(cute::rank(Shape{}) == 3, "Only rank-3 layouts are supported"); + if constexpr (Transpose) { + // Deal with tensor B by transposing appropriately before and after computing the permute layout. + // Its CuTe-canonical mode order is [N,K,L], while permute operations expect [row,col,batch]. + return select<1,0,2>(make_permute_layout(select<1,0,2>(layout))); + } + else { + if constexpr (cutlass::layout::is_trivial_permute) { + // Special case for NoPermute. Use a depth-2 layout for consistency with other permutations. + using ShapeProfile = tuple, tuple, tuple>; + return unflatten(layout, ShapeProfile{}); + } + else { + // Here's where the permutation layout is actually built + using ShapeProfile = typename PermuteTraits::ShapeProfile; + using StrideOrder = typename PermuteTraits::StrideOrder; + return make_ordered_layout(reshape(layout.shape(), ShapeProfile{}), StrideOrder{}); + } + } +} + +namespace detail +{ + +template +struct is_constant_pred { + template + constexpr auto operator()(T) { + return is_constant{}; + } +}; + +template +constexpr auto +inverse_impl(Permutation const & perm, seq) { + return cute::make_tuple(Int{})>{}...); +} + +} // namespace detail + +// Compute an inverse of a permutation represented as a tuple of cute::Int<> +template +constexpr auto +inverse(Permutation const & perm) { + auto flat_perm = flatten(perm); + return unflatten(detail::inverse_impl(flat_perm, tuple_seq{}), perm); +} + +template +using inverse_t = decltype(inverse(T{})); + +// Given a rank-2 layout of tensor that is assumed to have been permuted, +// compute the original rank-2 layout of the tensor prior to the permutation. +// This is needed to form the correct input to the standalone permutation kernel. +template +constexpr auto +make_original_layout(Layout const& layout) { + static_assert(cute::rank(Shape{}) == 3, "Only rank-3 layouts are supported"); + if constexpr (Transpose) { + // Deal with tensor B by transposing appropriately before and after computing the permute layout. + // Its CuTe-canonical mode order is [N,K,L], while permute operations expect [row,col,batch]. + return select<1,0,2>(make_original_layout(select<1,0,2>(layout))); + } + else { + using ShapeProfile = typename PermuteTraits::ShapeProfile; + auto re_shape = flatten(reshape(layout.shape(), ShapeProfile{})); + using IndexOrder = typename PermuteTraits::IndexOrder; + auto orig_shape = transform_leaf(IndexOrder{}, [&](auto i){ return get(re_shape); }); + using OrigOrder = conditional_t(), seq<0,1,2>, seq<1,0,2>>; + // print("Permuted shape: "); print(reshape(layout.shape(), ShapeProfile{})); print("\n"); + // print("Original shape: "); print(orig_shape); print("\n"); + return make_ordered_layout(product_each(orig_shape), OrigOrder{}); + } +} + +/////////////// Tensor4DPermute0213 //////////////////// + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = false; + using ShapeProfile = Shape>, Shape,X>, Shape>; + using IndexOrder = Step, Step<_1,_3>, Step<_4>>; + using StrideOrder = inverse_t; // Step, Step<_1,_3>, Step<_4>>; +}; + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = false; + using ShapeProfile = Shape>, Shape,X>, Shape>; + using IndexOrder = Step, Step<_1,_3>, Step<_4>>; + using StrideOrder = inverse_t; // Step, Step<_1,_3>, Step<_4>>; +}; + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = false; + using ShapeProfile = Shape,X>, Shape>, Shape>; + using IndexOrder = Step, Step<_0,_2>, Step<_4>>; + using StrideOrder = Step, Step<_0,_2>, Step<_4>>; +}; + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = false; + using ShapeProfile = Shape,X>, Shape>, Shape>; + using IndexOrder = Step, Step<_0,_2>, Step<_4>>; + using StrideOrder = Step, Step<_0,_2>, Step<_4>>; +}; + +/////////////// Tensor4DPermuteBMM0321 //////////////////// + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = true; + using ShapeProfile = Shape, Shape, Shape,X>>; + using IndexOrder = Step, Step<_1>, Step<_3>>; + using StrideOrder = Step, Step<_2>, Step<_1,_3>>; +}; + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = true; + using ShapeProfile = Shape>, Shape, Shape>; + using IndexOrder = Step, Step<_2>, Step<_1,_3>>; + using StrideOrder = Step, Step<_1>, Step<_3>>; +}; + +/////////////// Tensor4DPermuteBMM0213 //////////////////// + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = true; + using ShapeProfile = Shape, Shape, Shape,X>>; + using IndexOrder = Step, Step<_1,_2>, Step<_3>>; + using StrideOrder = Step, Step<_0>, Step<_1,_3>>; +}; + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = true; + using ShapeProfile = Shape, Shape>, Shape>; + using IndexOrder = Step, Step<_1>, Step<_2,_3>>; + using StrideOrder = Step, Step<_0,_2>, Step<_3>>; +}; + +/////////////// Tensor5DPermute02413 //////////////////// + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = false; + using ShapeProfile = Shape>, Shape,Int,X>, Shape>; + using IndexOrder = Step, Step<_4,_1,_3>, Step<_5>>; + using StrideOrder = inverse_t; // Step, Step<_1,_4,_2>, Step<_5>>; +}; + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = false; + using ShapeProfile = Shape>, Shape,Int>, Shape>; + using IndexOrder = Step, Step<_1,_4,_2>, Step<_5>>; + using StrideOrder = inverse_t; // Step, Step<_4,_1,_3>, Step<_5>>; +}; + +/////////////// Tensor5DPermute20314 //////////////////// + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = false; + using ShapeProfile = Shape,X>, Shape,Int>, Shape>; + using IndexOrder = Step, Step<_3,_1,_4>, Step<_5>>; + using StrideOrder = Step, Step<_0,_2,_4>, Step<_5>>; +}; + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = false; + using ShapeProfile = Shape>, Shape,Int>, Shape>; + using IndexOrder = Step, Step<_2,_4,_1>, Step<_5>>; + using StrideOrder = Step, Step<_0,_3,_1>, Step<_5>>; +}; + +} // namespace example diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/54_hopper_fp8_warp_specialized_gemm/hopper_fp8_commandline.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/54_hopper_fp8_warp_specialized_gemm/hopper_fp8_commandline.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e8ea5330b3163a26e09be16e26f459b02cf5bf02 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/54_hopper_fp8_warp_specialized_gemm/hopper_fp8_commandline.hpp @@ -0,0 +1,129 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Command line options parsing +template +struct Options { + + bool help = false; + + float alpha = 1.f, beta = 0.f; + float scale_a = 1.f, scale_b = 1.f, scale_c = 1.f, scale_d = 1.f, scale_aux = 1.f; + bool device_scale = false; + bool save_aux = true; + bool save_amax = true; + int iterations = 1000; + int m = 1024, n = 512, k = 1024, l = 1; + RasterOrderOptions raster; + int swizzle; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("scale_a", scale_a, 1.f); + cmd.get_cmd_line_argument("scale_b", scale_b, 1.f); + cmd.get_cmd_line_argument("scale_c", scale_c, 1.f); + cmd.get_cmd_line_argument("scale_d", scale_d, 1.f); + cmd.get_cmd_line_argument("scale_aux", scale_aux, 1.f); + cmd.get_cmd_line_argument("device_scale", device_scale, false); + cmd.get_cmd_line_argument("save_aux", save_aux, true); + cmd.get_cmd_line_argument("save_amax", save_amax, true); + cmd.get_cmd_line_argument("iterations", iterations); + + char raster_char; + cmd.get_cmd_line_argument("raster", raster_char); + + if (raster_char == 'N' || raster_char == 'n') { + raster = RasterOrderOptions::AlongN; + } + else if (raster_char == 'M' || raster_char == 'm') { + raster = RasterOrderOptions::AlongM; + } + else if (raster_char == 'H' || raster_char == 'h') { + raster = RasterOrderOptions::Heuristic; + } + + cmd.get_cmd_line_argument("swizzle", swizzle, 1); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "54_fp8_hopper_warp_specialized_gemm\n\n" + << " Hopper FP8 GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the l extent (batch) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --scale_a= Scaling factor for A\n" + << " --scale_b= Scaling factor for B\n" + << " --scale_c= Scaling factor for C\n" + << " --scale_d= Scaling factor for D (ignored for non-fp8 D)\n" + << " --scale_aux= Scaling factor for the auxiliary tensor (ignored for non-fp8 aux)\n" + << " --device_scale= Copy scalars to device memory before kernel launch (default: false)\n" + << " --save_aux= Save the pre-activation as an auxiliary tensor (default: true)\n" + << " --save_amax= Save the pre-scaled max absolute value of any fp8 outputs (aux and/or D) (default: true)\n" + << " --raster= CTA Rasterization direction (N for along N, M for along M, and H for heuristic)\n\n" + << " --swizzle= CTA Rasterization swizzle\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "54_fp8_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/55_hopper_mixed_dtype_gemm/mixed_dtype_utils.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/55_hopper_mixed_dtype_gemm/mixed_dtype_utils.hpp new file mode 100644 index 0000000000000000000000000000000000000000..98c7440679d455c12306c8e07078aa670f47328b --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/55_hopper_mixed_dtype_gemm/mixed_dtype_utils.hpp @@ -0,0 +1,246 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/util/command_line.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" + +#include "cute/tensor.hpp" + +#include +#include +#include "helper.h" + +enum MixedDtypeGemmMode { + ConvertOnly, + ScaleOnly, + ScaleWithZeroPoint +}; + +/// Command line options parsing +struct MixedDtypeOptions { + + bool help = false; + + float alpha = 1.0f; + float beta = 0.0f; + int iterations = 100; + int warmup = 10; + int mode = 1; + int m = 5120, n = 4096, k = 4096; + int g = 128; + int l = 1; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("g", g); + cmd.get_cmd_line_argument("mode", mode); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("warmup", warmup); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "55_hopper_mixed_dtype_gemm\n\n" + << " Hopper Mixed Data Type GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= The number of independent gemm problems with mnk shape\n" + << " --g= The size of each group for the scales and zeros. To broadcast a vector of scales or zeros, set the group size to K.\n" + << " --mode= The mode to run the gemm. 0 does (A @ B), 1 means A @ (scale * B), 2 means A @ (scale * B + zero-point).\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n" + << " --warmup= Number of warmup iterations to perform.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "55_hopper_mixed_dtype_gemm" << " --m=1024 --n=512 --k=1024 -g=1024 --l=10 --alpha=2 --mode=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k * l; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct MixedDtypeResult +{ + double avg_runtime_ms = 0.0; + double gflops = 0.0; + cutlass::Status status = cutlass::Status::kSuccess; + cudaError_t error = cudaSuccess; + bool passed = false; + +}; + +/// Profiling Loop +template +void mixed_dtype_profiling( + Gemm& gemm, + MixedDtypeOptions const& options, + MixedDtypeResult& result) { + + if (options.iterations <= 0) return; + + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + std::vector runtimes; + runtimes.reserve(options.iterations); + + for (int iter = 0; iter < options.warmup + options.iterations; ++iter) { + cudaEventRecord(start); + CUTLASS_CHECK(gemm.run()); + cudaEventRecord(stop); + cudaEventSynchronize(stop); + + if (iter >= options.warmup) { + float milliseconds = 0; + cudaEventElapsedTime(&milliseconds, start, stop); + runtimes.push_back(milliseconds); + } + } + + cudaEventDestroy(start); + cudaEventDestroy(stop); + + // Compute average setup and runtime and GFLOPs. + result.avg_runtime_ms = std::accumulate(runtimes.begin(), runtimes.end(), 0.0f) / runtimes.size(); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + +} + +/// Helpers to initialize a block of device data +template +bool initialize_tensor( + cutlass::DeviceAllocation& block, + uint64_t seed = 2023) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } + else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); + + return true; +} + +template +bool initialize_scale( + cutlass::DeviceAllocation& block, + MixedDtypeOptions const& options, + uint64_t seed = 2023) { + + // If no scales, initialize with 1 so we can use the same kernel to dequantize the data + float scope_max = 1.0f, scope_min = 1.0f; + if (options.mode != MixedDtypeGemmMode::ConvertOnly) { + float elt_max_f = float(cutlass::platform::numeric_limits::max()); + scope_max = 2.f; + scope_min = 0.1f; + } + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); + + return true; +} + +template +bool initialize_zero( + cutlass::DeviceAllocation& block, + MixedDtypeOptions const& options, + uint64_t seed = 2023) { + + // If no bias, initialize with 0 so we can use the same kernel to dequantize the data + float scope_max = 0.0f, scope_min = 0.0f; + if (options.mode == MixedDtypeGemmMode::ScaleWithZeroPoint) { + scope_max = 2.0f; + scope_min = -2.0f; + } + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); + + return true; +} diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/59_ampere_gather_scatter_conv/ampere_conv_kernel.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/59_ampere_gather_scatter_conv/ampere_conv_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..d3d0982d14cc84c9d3bcf0cace2973a0ca931302 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/59_ampere_gather_scatter_conv/ampere_conv_kernel.h @@ -0,0 +1,320 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/atom/copy_atom.hpp" +#include + +#include "cutlass/util/print_error.hpp" + +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_mma.hpp" + +using namespace cute; + +struct AmpereUnpredicatedFprop { + // + // Static config for conv problem shape + // + using D = _6; + using H = _4; + using W = _4; + + using T = _3; + using R = _3; + using S = _3; + + using Z = _4; + using P = _2; + using Q = _2; + + using C = _64; + using K = _128; + + // Tiler config + using Tiler_K = decltype(cute::min(K{}, _128{})); + using Tiler_C = decltype(cute::min(C{}, _32{})); + using Tiler_N = _4; + using TileM = Tiler_K; + using TileN = Shape; + using TileK = Shape; + using PIPE = _3; + using TilerFlt = Shape; + using TilerAct = Shape; + using TilerOut = Shape; + + using TileSizeM = Int; + using TileSizeN = Int; + using TileSizeK = Int; + static constexpr int Stages = PIPE::value; + + using ElementFlt = tfloat32_t; + using ElementAct = tfloat32_t; + using ElementOut = float; + + using TiledMma = TiledMMA< + MMA_Atom, + Layout>, + Tile<_32,_32,Underscore>>; + + static constexpr int MaxThreadsPerBlock = size(TiledMma{}); + static constexpr int MinBlocksPerMultiprocessor = 1; + + union SharedStorage { + struct { + ElementFlt sAMatrix[size(TileM{}) * size(TileK{}) * size(PIPE{})]; + ElementAct sBMatrix[size(TileN{}) * size(TileK{}) * size(PIPE{})]; + } mainloop; + + struct { + ElementOut sCMatrix[size(TileM{}) * size(TileN{})]; + } epilogue; + }; + + // + // Stencil tensor + // + + using GmemLayoutFlt = decltype(make_ordered_layout( + Shape< K, Shape< C, T, R, S>>{}, + tuple<_4, tuple<_0,_3,_2,_1>>{})); + + // We have 64 elements * 32b each in the major mode that we can vectorize + // Max vector size is 128b, so lay 16 threads along the major mode with a vector size of 4 + // Rest along the minor mode + using GmemTiledCopyFlt = decltype(make_tiled_copy( + Copy_Atom, ElementFlt>{}, + Layout, + Stride< _8, _1>>{}, + Layout>{})); + + // Following layout is also correct, but trades off dynamic strides in the slice for bank conflict free accesses + // using SmemLayoutFlt = decltype( + // composition(Swizzle<3,2,3>{}, + // make_ordered_layout( + // Shape{}, + // tuple< _1, _0, _2>{}))); + + using SmemLayoutAtomFlt = decltype( + composition(Swizzle<1,2,3>{}, + Layout>, + Stride<_4,Stride<_1,_32>>>{})); + + using SmemCopyAtomFlt = Copy_Atom; + + // + // Activation tensor + // + + // Activation tensor is major in the contraction mode, so vectorize that mode first + // Then lay out the rest of the threads along the other mode + using GmemTiledCopyAct = decltype(make_tiled_copy( + Copy_Atom, ElementAct>{}, + Layout, + Stride< _8, _1>>{}, + Layout>{})); + + // Following layout is also correct, but trades off dynamic strides in the slice for bank conflict free accesses + // using SmemLayoutAct = decltype( + // composition(Swizzle<3,2,3>{}, + // make_ordered_layout( + // Shape{}, + // tuple< _1, _0, _2>{}))); + + using SmemLayoutAtomAct = decltype( + composition(Swizzle<1,2,3>{}, + Layout>, + Stride<_4,Stride<_1,_32>>>{})); + + using SmemCopyAtomAct = Copy_Atom; + + // + // Output tensor + // + + using GmemTiledCopyOut = decltype(make_tiled_copy( + Copy_Atom, ElementAct>{}, + Layout, + Stride<_1, _8>>{}, + Layout>{})); + + using SmemCopyAtomOut = Copy_Atom, ElementOut>; + + // This can be optimized to make accesses BCF, but we use a col-major layout here to show off composability + using SmemLayoutOut = Layout>; + + // + // Conv functor + // + template + void __device__ + operator()(cute::Tensor mFlt, // ( K, (C,T,R,S)) + TensorActivation mAct, // ((N,Z,P,Q), (C,T,R,S)) + TensorOutput mOut, // ( K, (N,Z,P,Q)) + char* smem_buf) const { + using namespace cute; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveMma< + cutlass::gemm::MainloopSm80CpAsyncUnpredicated, + Shape, + ElementFlt, + Underscore, // Ignore the stride, we are passing full cute::Tensor to operator() + ElementAct, + Underscore, // Ignore the stride, we are passing full cute::Tensor to operator() + TiledMma, + GmemTiledCopyFlt, + SmemLayoutAtomFlt, + SmemCopyAtomFlt, + cute::identity, + GmemTiledCopyAct, + SmemLayoutAtomAct, + SmemCopyAtomAct, + cute::identity>; + + TiledMma tiled_mma; + Tensor accum = partition_fragment_C(tiled_mma, TilerOut{}); + clear(accum); + + // Set up tensors + // NOTE: blockIdx.x projects onto act-NDHW mode, y along the flt-K mode for the sake of higher dynamic range in NDHW + Tensor gA_mk = local_tile(mFlt, TilerFlt{}, make_coord(_,_)); // (BLK_M,BLK_K,m',k') + Tensor gB_nk = local_tile(mAct, TilerAct{}, make_coord(_,_)); // (BLK_N,BLK_K,n',_1) + Tensor gC_mn = local_tile(mOut, TilerOut{}, make_coord(_,_)); // (BLK_M,BLK_N,m',n') + + // Compute m_coord and n_coord with their post-tiled shapes + auto m_coord = idx2crd(int(blockIdx.y), shape<2>(gA_mk)); + auto n_coord = idx2crd(int(blockIdx.x), shape<2>(gB_nk)); + Tensor gA = gA_mk(_,_,m_coord,_); // (BLK_M,BLK_K,k') + Tensor gB = gB_nk(_,_,n_coord,_); // (BLK_N,BLK_K,_1) + Tensor gC = gC_mn(_,_,m_coord,n_coord); // (BLK_M,BLK_N) + + auto k_tile_iter = cute::make_coord_iterator(size<2>(gA)); + int k_tile_count = size<2>(gA); + + CollectiveMainloop collective_mma; + collective_mma( + accum, + gA, + gB, + accum, + k_tile_iter, k_tile_count, + Underscore{}, // no residue since we do not support predication + threadIdx.x, + smem_buf); + + // + // Epilogue + // + SharedStorage& storage = *reinterpret_cast(smem_buf); + Tensor sC = make_tensor(make_smem_ptr(&storage.epilogue.sCMatrix[0]), SmemLayoutOut{}); + + auto smem_tiled_copy_C = make_tiled_copy_C(SmemCopyAtomOut{}, tiled_mma); + auto smem_thr_copy_C = smem_tiled_copy_C.get_slice(threadIdx.x); + auto tCrC = smem_thr_copy_C.retile_S(accum); + auto tCsC = smem_thr_copy_C.partition_D(sC); + copy(smem_tiled_copy_C, tCrC, tCsC); + + __syncthreads(); + + GmemTiledCopyOut gmem_tiled_copy_C; + auto gmem_thr_copy_C = gmem_tiled_copy_C.get_slice(threadIdx.x); + auto tDsC = gmem_thr_copy_C.partition_S(sC); + auto tDgC = gmem_thr_copy_C.partition_D(gC); + copy(gmem_tiled_copy_C, tDsC, tDgC); + + #if 0 + if (thread0()) { + print("mAct = "); print(mAct); print('\n'); + print("mFlt = "); print(mFlt); print('\n'); + print("mOut = "); print(mOut); print('\n'); + print("gA = "); print(gA); print('\n'); + print("gB = "); print(gB); print('\n'); + print("gC = "); print(gC); print('\n'); + print("sA = "); print(sA.layout()); print('\n'); + print("sB = "); print(sB.layout()); print('\n'); + print("sC = "); print(sC.layout()); print('\n'); + print("tAgA = "); print(tAgA.layout()); print('\n'); + print("tBgB = "); print(tBgB.layout()); print('\n'); + print("tAsA = "); print(tAsA.layout()); print('\n'); + print("tBsB = "); print(tBsB.layout()); print('\n'); + print("tCsA = "); print(tCsA.layout()); print('\n'); + print("tCsB = "); print(tCsB.layout()); print('\n'); + print("tCrC = "); print(tCrC.layout()); print('\n'); + print("tCsC = "); print(tCsC.layout()); print('\n'); + print("tDsC = "); print(tDsC.layout()); print('\n'); + print("tDgC = "); print(tDgC.layout()); print('\n'); + print("gmem tiled copy A = "); print(gmem_tiled_copy_A); print('\n'); + print("gmem tiled copy B = "); print(gmem_tiled_copy_B); print('\n'); + print("gmem tiled copy C = "); print(gmem_tiled_copy_C); print('\n'); + print("k_tile_count = "); print(size<2>(gA)); print('\n'); + print("k_tile_iter = "); print(*k_tile_iter); print('\n'); + print("K_BLOCK_MAX = "); print(K_BLOCK_MAX); print('\n'); + } + #endif + } +}; + +template +inline int +fprop_reference( + TensorFlt mStencil, // Logical MK: ( K, (C,T,R,S)) + TensorAct mActivation, // Logical NK: ((N,Z,P,Q), (C,T,R,S)) + TensorOut mOutput, // Logical MN: ( K, (N,Z,P,Q)) + TensorOut mOutputRef) { + int32_t N = size<1,0>(mOutputRef); + int32_t Z = size<1,1>(mOutputRef); + int32_t P = size<1,2>(mOutputRef); + int32_t Q = size<1,3>(mOutputRef); + int32_t T = size<1,3>(mStencil); + int32_t R = size<1,2>(mStencil); + int32_t S = size<1,1>(mStencil); + int32_t C = size<1,0>(mStencil); + + size_t K = static_cast(size<0>(mOutputRef)); + size_t NZPQ = static_cast(size<1>(mOutputRef)); + size_t CTRS = static_cast(size<1>(mStencil)); + +#if defined(_OPENMP) + #pragma omp parallel for +#endif + for (size_t logical_m = 0; logical_m < K; ++logical_m) { + for (size_t logical_n = 0; logical_n < NZPQ; ++logical_n) { + auto accumulator = float(0); + for (size_t logical_k = 0; logical_k < CTRS; ++logical_k) { + accumulator += mStencil(logical_m, logical_k) * mActivation(logical_n, logical_k); + } + mOutputRef(logical_m, logical_n) = accumulator; + } + } + + return print_relative_error(mOutput, mOutputRef, /*print_verbose*/ false, /*print_error*/ true, /*error_margin*/ 0.01); +} diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/collective/builder.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/collective/builder.hpp new file mode 100644 index 0000000000000000000000000000000000000000..09bb9b6599062a8bc69cd7474a361d031943c652 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/collective/builder.hpp @@ -0,0 +1,242 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "dispatch_policy_extra.hpp" +#include "sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp" +#include "../pipeline/prefetch_pipeline_sm90.hpp" + +namespace cutlass::gemm::collective { + +namespace detail { + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template +constexpr int +compute_stage_count_or_override_prefetch(StageCount stage_count) { + return stages; +} + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template +constexpr int +compute_stage_count_or_override_prefetch(StageCountAutoCarveout stage_count) { + constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaAsync<1>::SharedStorage); + constexpr auto prefetch_pipeline_bytes = sizeof(typename cutlass::detail::PrefetcherPipelineSharedStorage); + constexpr auto a_bits = cute::sizeof_bits_v; + constexpr auto b_bits = cute::sizeof_bits_v; + constexpr int MK_bytes = cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})); //also the prefetch smem size + constexpr int NK_bytes = cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})); + constexpr int stage_bytes = MK_bytes + NK_bytes + static_cast(mainloop_pipeline_bytes); + + return (CapacityBytes - carveout_bytes - MK_bytes * PrefetchStagesActual - prefetch_pipeline_bytes) / stage_bytes; +} + +} // namespace detail + +// GMMA_TMA_WS_FP8_FAST_ACCUM_SS + prefetch +template < + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t< + cute::is_same_v> +> { + static_assert(is_static::value); + static_assert(is_static::value); + static_assert(detail::is_aligned(), + "Not meet TMA alignment requirement yet\n"); + static_assert(detail::is_input_fp8(), + "Only FP8 datatypes are compatible with these kernel schedules\n"); + // Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder + static_assert(!detail::is_use_rmem_A(), + "Not supported for fp8 non-TN warp specialized kernels yet\n"); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); + + using AtomLayoutMNK = Layout>; + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< + ElementA, ElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = decltype(detail::ss_smem_selector< + GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutAtomB = decltype(detail::ss_smem_selector< + GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + static constexpr int PipelineStages = detail::compute_stage_count_or_override_prefetch(StageCountType{}); + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch; + + using SmemCopyAtomA = void; + using SmemCopyAtomB = void; + + using CollectiveOp = CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + TagToStrideA_t, + ElementB, + TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + SmemCopyAtomB, + cute::identity + >; +}; + +// GMMA_TMA_WS_FP8_FAST_ACCUM_SS + prefetch and split DMA warps +template < + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t< + cute::is_same_v> +> { + static_assert(is_static::value); + static_assert(is_static::value); + static_assert(detail::is_aligned(), + "Not meet TMA alignment requirement yet\n"); + static_assert(detail::is_input_fp8(), + "Only FP8 datatypes are compatible with these kernel schedules\n"); + // Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder + static_assert(!detail::is_use_rmem_A(), + "Not supported for fp8 non-TN warp specialized kernels yet\n"); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); + + using AtomLayoutMNK = Layout>; + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< + ElementA, ElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = decltype(detail::ss_smem_selector< + GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutAtomB = decltype(detail::ss_smem_selector< + GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + static constexpr int PipelineStages = detail::compute_stage_count_or_override_prefetch(StageCountType{}); + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch; + + using SmemCopyAtomA = void; + using SmemCopyAtomB = void; + + using CollectiveOp = CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + TagToStrideA_t, + ElementB, + TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + SmemCopyAtomB, + cute::identity + >; +}; + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/collective/dispatch_policy_extra.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/collective/dispatch_policy_extra.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8bf0bf37202eaadd2f8e7f0573c196c7fa77a7be --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/collective/dispatch_policy_extra.hpp @@ -0,0 +1,61 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +namespace cutlass::gemm { + +// Standard non-persistent kernel with a single producer warp, and one prefetch warp. +// `A` is assumed to be static, and therefore the producer warp for `A` attempts to load `A` +// while the producer warp is waiting on griddepcontrol. +// GDC `launch_dependent_grids` is issued from the producer warp instead of math warps, and +// according to prefetch ratio. +struct KernelTmaWarpSpecializedFP8FastAccumWithPrefetch { }; + +// Non-persistent kernel with two producer warps (one for each of A and B), and one prefetch warp. +// `A` is assumed to be static, and therefore the producer warp for `A` attempts to load `A` +// while the producer warp for `B` is waiting on griddepcontrol. Producer warp for `A` does not +// wait on griddepcontrol and loads immediately. +struct KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA { }; + +template< + int Stages_, + class ClusterShape_ = Shape<_1,_1,_1>, + class KernelSchedule = KernelTmaWarpSpecializedFP8FastAccumWithPrefetch +> +struct MainloopSm90TmaGmmaWarpSpecializedWithPrefetch { + constexpr static int Stages = Stages_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm90; + using Schedule = KernelSchedule; +}; + +} // namespace cutlass::gemm diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/collective/sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/collective/sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4d9325b4e20d92587fa9f70e607223c4d7b8cf47 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/collective/sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp @@ -0,0 +1,871 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cutlass/arch/grid_dependency_control.h" + +#include "dispatch_policy_extra.hpp" + +#include "../pipeline/prefetch_pipeline_sm90.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +constexpr int PrefetchStages = 4; +constexpr int PrefetchInitialStages = 1; +// This determines how much shmem we set aside for prefetch. +// We don't reuse anything loaded by prefetcher, so we can keep +// loading into the same place -- there will be a conflict when +// writing, but it doesn't affect performance as much as the doors +// that this opens. +constexpr int PrefetchStagesActual = 1; + +} // namespace detail + +// WarpSpecialized Mainloop +template < + int Stages, + class ClusterShape, + class KernelSchedule, + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm90TmaGmmaWarpSpecializedWithPrefetch, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + static_assert(size<1>(ClusterShape{}) == 1, "Cluster shape N must be 1"); + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + + using PrefetcherPipeline = cutlass::PrefetchPipeline; + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + using PipelineParams = typename MainloopPipeline::Params; + + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + // Tile along modes in a way that maximizes the TMA box size. + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + static_assert(rank(SmemLayoutA{}) == 3 && size<2>(SmemLayoutA{}) == DispatchPolicy::Stages); + static_assert(rank(SmemLayoutB{}) == 3 && size<2>(SmemLayoutB{}) == DispatchPolicy::Stages); + + using PrefetchSmemLayoutA = decltype(make_layout(make_shape( + cute::Int(SmemLayoutA{})>{}, + cute::Int(SmemLayoutA{})>{}, + cute::Int{}))); + + static constexpr auto prefetch_smem_size = cute::cosize_v; + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. + static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using InternalElementA = cute::conditional_t>>; + using InternalElementB = cute::conditional_t>>; + + // Defined outside the class where it's used, to work around MSVC issues + using PrefetcherPipelineStorage = ::cutlass::detail::PrefetcherPipelineSharedStorage; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + cute::array_aligned smem_prefetch; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + PrefetcherPipelineStorage prefetcher_pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A; + StrideA dA; + ElementB const* ptr_B; + StrideB dB; + uint32_t mma_promotion_interval = 4; + float overlap_ratio = 0.5; + float prefetch_ratio = -1.0; + }; + + // Device side kernel params + struct Params { + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy_A_sm90( + GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{})); + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy_B_sm90( + GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{})); + + TMA_A tma_load_a; + TMA_B tma_load_b; + uint32_t tma_transaction_bytes = TmaTransactionBytesMK + TmaTransactionBytesNK; + uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; + float overlap_ratio = 0.5; + float prefetch_ratio = -1.0; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + auto ptr_A = reinterpret_cast(args.ptr_A); + auto ptr_B = reinterpret_cast(args.ptr_B); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); + + typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{}); + typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{}); + uint32_t transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t transaction_bytes_nk = TmaTransactionBytesNK; + uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk; + + return { + tma_load_a, + tma_load_b, + transaction_bytes, + transaction_bytes_mk, + transaction_bytes_nk, + args.overlap_ratio, + args.prefetch_ratio + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + bool implementable = cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + return false; + } + + if (args.overlap_ratio > 1.0) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: `overlap_ratio` must be either negative (disabled) or in [0, 1].\n"); + return false; + } + + if (args.prefetch_ratio > 1.0) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: `prefetch_ratio` must be either negative (disabled) or in [0, 1].\n"); + return false; + } + + return true; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 1; + static constexpr uint32_t TmaTransactionBytesMK = + cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytesNK = + cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)); + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// The rest of the tensors can be specified as needed by this collective. + template + CUTLASS_DEVICE auto + load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + return cute::make_tuple(gA_mkl, gB_nkl); + } + + template < + class TensorA, class TensorB, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PrefetcherPipeline prefetcher_pipeline, + PipelineState smem_pipe_write, + TensorA const& gA_mkl, + TensorB const& gB_nkl, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + bool disable_gdc = mainloop_params.overlap_ratio < 0.0; + float overlap_ratio = mainloop_params.overlap_ratio; + int launch_dep_grids_threshold = static_cast(static_cast(k_tile_count - 1) * overlap_ratio); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A + // + + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + auto cta_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto cta_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Applies the mapping from cta_tma_a + Tensor tAgA = cta_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = cta_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + // Applies the mapping from cta_tma_b + Tensor tBgB = cta_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = cta_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + // We have to wait on dependent grids because of B. + cutlass::arch::wait_on_dependent_grids(); + + // Signal prefetcher to stop + prefetcher_pipeline.producer_arrive(); + + bool launch_dep_grids = false; + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for (int cnt=0 ; k_tile_count > 0; --k_tile_count, ++cnt) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b, cute::TMA::CacheHintSm90::EVICT_LAST), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + ++k_tile_iter; + + if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) { + launch_dep_grids = true; + cutlass::arch::launch_dependent_grids(); + } + + // Advance smem_pipe_write + ++smem_pipe_write; + } + if (!disable_gdc && !launch_dep_grids) { + cutlass::arch::launch_dependent_grids(); + } + } + } + + template < + class TensorA, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load_MK( + Params const& mainloop_params, + MainloopPipeline pipeline, + PrefetcherPipeline prefetcher_pipeline, + PipelineState smem_pipe_write, + TensorA const& gA_mkl, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + bool disable_gdc = mainloop_params.overlap_ratio < 0.0; + float overlap_ratio = mainloop_params.overlap_ratio; + int launch_dep_grids_threshold = static_cast(static_cast(k_tile_count - 1) * overlap_ratio); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + + // + // Prepare the TMA loads for A + // + + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + auto cta_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + + // Applies the mapping from cta_tma_a + Tensor tAgA = cta_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = cta_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + // Don't wait on dependent grids when loading `A`, because + // we assume `A` (weights) are static. + + bool launch_dep_grids = false; + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for (int cnt=0 ; k_tile_count > 0; --k_tile_count, ++cnt) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + ++k_tile_iter; + + if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) { + launch_dep_grids = true; + cutlass::arch::launch_dependent_grids(); + } + + // Advance smem_pipe_write + ++smem_pipe_write; + } + if (!disable_gdc && !launch_dep_grids) { + cutlass::arch::launch_dependent_grids(); + } + } + } + + template < + class TensorB, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load_NK( + Params const& mainloop_params, + MainloopPipeline pipeline, + PrefetcherPipeline prefetcher_pipeline, + PipelineState smem_pipe_write, + TensorB const& gB_nkl, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for B + // + + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + auto cta_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Applies the mapping from cta_tma_b + Tensor tBgB = cta_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = cta_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_b = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + // Signal prefetcher to stop + prefetcher_pipeline.producer_arrive(); + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b, cute::TMA::CacheHintSm90::EVICT_LAST), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + + template < + class TensorA, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + prefetch_MK( + Params const& mainloop_params, + PrefetcherPipeline prefetcher_pipeline, + PipelineState smem_pipe_write, + TensorA const& gA_mkl, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + bool do_best_effort_prefetch = mainloop_params.prefetch_ratio < 0; + float prefetch_ratio = do_best_effort_prefetch ? 1.0 : mainloop_params.prefetch_ratio; + int prefetch_iters = static_cast(static_cast(k_tile_count) * 0.5 * prefetch_ratio); + prefetch_iters = min(k_tile_count, ((prefetch_iters + detail::PrefetchStages - 1) / detail::PrefetchStages) * detail::PrefetchStages); + + Tensor sA = make_tensor( + make_smem_ptr(shared_tensors.smem_prefetch.data()), PrefetchSmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + + // + // Prepare the TMA loads for A + // + + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + auto cta_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + + // Applies the mapping from cta_tma_a + Tensor tAgA = cta_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = cta_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + uint32_t prefetcher_stage = 0; + uint32_t prefetcher_phase = 0; + CUTLASS_PRAGMA_NO_UNROLL + for (int cnt = 0 ; cnt < prefetch_iters; ++cnt) { + + if (do_best_effort_prefetch && prefetcher_pipeline.have_producers_arrived()) { + break; + } + + prefetcher_pipeline.prefetcher_acquire(prefetcher_stage, prefetcher_phase, cnt >= detail::PrefetchStages); + using BarrierType = typename PrefetcherPipeline::PrefetcherBarrierType; + BarrierType* tma_barrier = prefetcher_pipeline.prefetcher_get_barrier(prefetcher_stage); + + int write_stage = 0; + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + ++k_tile_iter; + ++k_tile_iter; + + prefetcher_pipeline.advance_prefetcher_state(prefetcher_stage, prefetcher_phase); + } + prefetcher_pipeline.prefetcher_tail(prefetcher_stage, prefetcher_phase); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), + "ERROR : Incorrect number of MMAs in flight"); + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + warpgroup_fence_operand(accum); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + + warpgroup_commit_batch(); + + ++smem_pipe_read; + } + + warpgroup_fence_operand(accum); + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + warpgroup_fence_operand(accum); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_wait(); + warpgroup_fence_operand(accum); + + // UNLOCK smem_pipe_release, done _computing_ on it + pipeline.consumer_release(smem_pipe_release); + + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_read; + ++smem_pipe_release; + } + + warpgroup_fence_operand(accum); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/gemm_with_weight_prefetch_commandline.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/gemm_with_weight_prefetch_commandline.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8c85edb71bb7396ee3b83afd1fe9d56510f1a759 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/gemm_with_weight_prefetch_commandline.hpp @@ -0,0 +1,117 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Command line options parsing +struct Options { + + bool help = false; + + float alpha = 1.f, beta = 0.f; + float overlap_ratio = 0.5f, prefetch_ratio = 0.5f; + int iterations = 1000; + int n = 64, m = 1280, k = 8192, l = 1; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("p", prefetch_ratio, 0.5f); + cmd.get_cmd_line_argument("o", overlap_ratio, 0.5f); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "63_hopper_gemm_with_weight_prefetch\n\n" + << " Hopper FP8 GEMM using a non-persistent kernel with L2 weight prefetch. \n" + << " For more details please refer to the source file.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the l extent (batch) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --p= Prefetch ratio\n" + << " --o= Overlap ratio\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "63_hopper_gemm_with_weight_prefetch" << + " --m=1024 --n=512 --k=1024 --o=0.5 --p=0.5 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k * l; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } + + /// Compute effective bandwidth in GB/sec + double effective_bandwidth( + double runtime_s, + size_t bytes_a, + size_t bytes_b, + size_t bytes_c, + size_t bytes_d + ) const + { + static double const kBytesPerGiB = double(1ull << 30); + + double bytes_in = + (double)(l) * (double)(m) * (double)(k) * (double)(bytes_a) + // A + (double)(l) * (double)(n) * (double)(k) * (double)(bytes_b) + // B + (beta != 0.f ? (double)(l) * (double)(m) * (double)(n) * (double)(bytes_c) : 0.f); // C + double bytes_out = (double)(l) * (double)(m) * (double)(n) * (double)(bytes_d); // D + + double gb_total = (bytes_in + bytes_out) / kBytesPerGiB; + return gb_total / runtime_s; + } +}; diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp new file mode 100644 index 0000000000000000000000000000000000000000..73655ad2ace6674ed0ea7acb5b185505379d599e --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp @@ -0,0 +1,561 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/arch/mma_sm90.h" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" + +#include "cute/tensor.hpp" + +#include "../collective/dispatch_policy_extra.hpp" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +// GEMM + Prefetch for the A tensor + (optional) split DMA warps +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileScheduler_ +> +class GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileScheduler_, + cute::enable_if_t< + cute::is_same_v || + cute::is_same_v + > +> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + static constexpr bool IsGdcEnabled = cutlass::arch::IsGdcGloballyEnabled; + + static constexpr bool SplitWarps = cute::is_same_v; + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + static_assert(ArchTag::kMinComputeCapability >= 90); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static_assert(cute::is_void_v or cute::is_same_v, + "TMA warp-specialized kernel does not support specializing the tile scheduler."); + using TileSchedulerTag = TileScheduler_; + using TileScheduler = typename detail::TileSchedulerSelector< + TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + + // Kernel level shared memory storage + struct SharedStorage { + // Mainloop and epilogue don't use smem concurrently since kernel is non-persistent, so we can use a union + union TensorStorage { + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + + MainloopTensorStorage mainloop; + EpilogueTensorStorage epilogue; + } tensors; + + struct PipelineStorage : cute::aligned_struct<16, _1> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using PrefetcherPipelineStorage = typename CollectiveMainloop::PrefetcherPipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) PrefetcherPipelineStorage prefetcher; + } pipelines; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + static constexpr uint32_t NumLoadWarpGroups = 1; + static constexpr uint32_t NumMmaWarpGroups = 1; + static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{})) + (NumLoadWarpGroups * NumThreadsPerWarpGroup); + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + // Device side arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + (void) workspace; + auto problem_shape = args.problem_shape; + if constexpr (detail::Has_SwapAB_v) { + // swap M/N + get<0>(problem_shape) = get<1>(args.problem_shape); + get<1>(problem_shape) = get<0>(args.problem_shape); + } + return { + args.mode, + problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace) + }; + } + + static bool + can_implement(Arguments const& args) { + bool implementable = (args.mode == GemmUniversalMode::kGemm) or + (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + + return implementable; + } + + static + size_t + get_workspace_size(Arguments const& args) { + return 0; + } + + static + cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + return Status::kSuccess; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + auto cluster_shape = ClusterShape{}; + auto tile_shape = TileShape{}; + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + return TileScheduler::get_tiled_cta_shape_mnl( + problem_shape_MNKL, tile_shape, cluster_shape); + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + using namespace cute; + using X = Underscore; + +#if defined(__CUDA_ARCH_FEAT_SM90_ALL) +# define ENABLE_SM90_KERNEL_LEVEL 1 +#endif + +// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. +#if ! defined(ENABLE_SM90_KERNEL_LEVEL) + printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); +#else + + enum class WarpGroupRole { + Producer = 0, + Consumer = 1, + }; + // Split mode: use Warp0 to load NK and epilogue, Warp2 to load MK. + // Non-split mode: use Warp0 to load MK, NK and epilogue, Warp2 is unused. + // Both modes use Warp1 to prefetch. + enum class ProducerWarpRole { + Warp0 = 0, + PrefetchMK = 1, + Warp2 = 2, + UnusedWarp = 3 + }; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + int thread_idx = int(threadIdx.x); + int lane_idx = canonical_lane_idx(); + int warp_idx = canonical_warp_idx_sync(); + int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup; + int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; + auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); + auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); + int lane_predicate = cute::elect_one_sync(); + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + + + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_idx == 0) && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + // Mainloop Load pipeline + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + typename MainloopPipeline::Params mainloop_pipeline_params; + mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; + if (warp_group_role == WarpGroupRole::Producer && ( + producer_warp_role == ProducerWarpRole::Warp0 || + producer_warp_role == ProducerWarpRole::Warp2)) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + mainloop_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes; + } + if (warp_group_role == WarpGroupRole::Consumer) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup; + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); + bool should_prefetch = params.mainloop.prefetch_ratio > 0; + using PrefetcherPipeline = typename CollectiveMainloop::PrefetcherPipeline; + typename PrefetcherPipeline::Params prefetcher_pipeline_params; + prefetcher_pipeline_params.num_prefetchers = 1; + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::PrefetchMK) { + prefetcher_pipeline_params.should_prefetch = should_prefetch; + prefetcher_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes_mk; + } + PrefetcherPipeline prefetcher_pipeline(shared_storage.pipelines.prefetcher, prefetcher_pipeline_params); + + // Epilogue Load pipeline + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Warp0) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); + epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; + epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup; + if constexpr (CollectiveEpilogue::RequiresTransactionBytes) { + epi_load_pipeline_params.transaction_bytes = params.epilogue.tma_transaction_bytes; + } + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + // Initialize starting pipeline states for the collectives + // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; + typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; + + // For the DMA Load (producer) we start with an opposite phase + // i.e., we skip all waits since we know that the buffer is indeed empty + PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + auto cluster_wait_fn = [&] () { + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer thread blocks in the Cluster + if constexpr (size(ClusterShape{}) > 1) { + // Non-prefetcher warps arrive and wait, + // Prefetcher warp can go ahead without waiting. + cute::cluster_arrive_relaxed(); + if (warp_group_role != WarpGroupRole::Producer || + producer_warp_role != ProducerWarpRole::PrefetchMK) { + cute::cluster_wait(); + } + return [] () {}; + } + else { + // __syncthreads() but only for non prefetcher warps + if (should_prefetch) { + + // Use a named barrier to let the prefetcher warp start loading into the L2 + // without waiting to sync with all other warps. + // All other warps need to sync because the mainloop pipeline init + // should be visible to all of them. + // Prefetcher has its own barriers, and the only warps it would need to sync + // with would be the DMA warps. + using ClusterSyncWithPrefetchBarrier = typename cutlass::arch::NamedBarrier; + auto prefetcher_arrive_barrier = ClusterSyncWithPrefetchBarrier( + blockDim.x * blockDim.y * blockDim.z, + /*id*/ 0); + // Prefetcher warp doesn't arrive on this barrier. + auto cluster_arrive_barrier = ClusterSyncWithPrefetchBarrier( + blockDim.x * blockDim.y * blockDim.z - NumThreadsPerWarp, + /*id*/ 1); + + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::PrefetchMK) { + __syncwarp(); + prefetcher_arrive_barrier.arrive(); + } + else if (warp_group_role == WarpGroupRole::Producer) { + prefetcher_arrive_barrier.arrive_and_wait(); + cluster_arrive_barrier.arrive_and_wait(); + } + else { + prefetcher_arrive_barrier.arrive(); + cluster_arrive_barrier.arrive_and_wait(); + } + } else { + __syncthreads(); + } + return [] () {}; + } + } (); + + // Preconditions + static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + + // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + + // Get the appropriate blocks for this thread block -- potential for thread block locality + auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + TiledMma tiled_mma; + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + + // Prepare and partition the input tensors. Expects a tuple of tensors where: + // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) + // get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) + auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop); + static_assert(cute::tuple_size_v >= 2, "Output of load_init must have at least two elements (A, B)"); + + // Extract out partitioned A and B. + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + + // Compute m_coord, n_coord, and l_coord with their post-tiled shapes + auto m_coord = idx2crd(int(blockIdx.x), shape<2>(gA_mkl)); + auto n_coord = idx2crd(int(blockIdx.y), shape<2>(gB_nkl)); + auto l_coord = idx2crd(int(blockIdx.z), shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + // Get pipeline iterators and increments from tensor shapes + auto k_tile_iter = cute::make_coord_iterator(shape<3>(gA_mkl)); + auto k_tile_count = size<3>(gA_mkl); + + // Wait for all thread blocks in the Cluster + cluster_wait_fn(); + + if (warp_group_role == WarpGroupRole::Producer) { + if (producer_warp_role == ProducerWarpRole::Warp0) { + if constexpr(SplitWarps) { + collective_mainloop.load_NK( + params.mainloop, + mainloop_pipeline, + prefetcher_pipeline, + mainloop_pipe_producer_state, + gB_nkl, + blk_coord, + k_tile_iter, k_tile_count, + lane_idx, + block_rank_in_cluster, + shared_storage.tensors.mainloop + ); + } + else { + collective_mainloop.load( + params.mainloop, + mainloop_pipeline, + prefetcher_pipeline, + mainloop_pipe_producer_state, + gA_mkl, gB_nkl, + blk_coord, + k_tile_iter, k_tile_count, + lane_idx, + block_rank_in_cluster, + shared_storage.tensors.mainloop + ); + } + // Update starting mainloop pipeline state for the pipeline drain + mainloop_pipe_producer_state.advance(k_tile_count); + // Make sure mainloop consumer has been waited upon before issuing epilogue load + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + + if (collective_epilogue.is_producer_load_needed()) { + // Ensure warp is converged before issuing epilogue loads + __syncwarp(); + epi_load_pipe_producer_state = collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + tiled_mma, + lane_idx, + shared_storage.tensors.epilogue + ); + collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); + } + } + else if (SplitWarps && producer_warp_role == ProducerWarpRole::Warp2) { + collective_mainloop.load_MK( + params.mainloop, + mainloop_pipeline, + prefetcher_pipeline, + mainloop_pipe_producer_state, + gA_mkl, + blk_coord, + k_tile_iter, k_tile_count, + lane_idx, + block_rank_in_cluster, + shared_storage.tensors.mainloop + ); + // Update starting mainloop pipeline state for the pipeline drain + mainloop_pipe_producer_state.advance(k_tile_count); + // Make sure mainloop consumer has been waited upon before issuing epilogue load + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + } else if (producer_warp_role == ProducerWarpRole::PrefetchMK && should_prefetch) { + collective_mainloop.prefetch_MK( + params.mainloop, + prefetcher_pipeline, + mainloop_pipe_producer_state, + gA_mkl, + blk_coord, + k_tile_iter, k_tile_count, + lane_idx, + block_rank_in_cluster, + shared_storage.tensors.mainloop + ); + } + } + else if (warp_group_role == WarpGroupRole::Consumer) { + Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) + + collective_mainloop.mma( + mainloop_pipeline, + mainloop_pipe_consumer_state, + accumulators, + k_tile_count, + warp_group_thread_idx, + shared_storage.tensors.mainloop, + params.mainloop + ); + + // Make sure the math instructions are done and free buffers before entering the epilogue + collective_mainloop.mma_tail( + mainloop_pipeline, + mainloop_pipe_consumer_state, + k_tile_count + ); + + // Epilogue and write to gD + auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = + collective_epilogue.store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + accumulators, + tiled_mma, + warp_group_thread_idx, + shared_storage.tensors.epilogue + ); + + collective_epilogue.store_tail( + epi_load_pipeline, + epi_load_pipe_consumer_state_next, + epi_store_pipeline, + epi_store_pipe_producer_state_next + ); + } +#endif + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/pipeline/prefetch_pipeline_sm90.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/pipeline/prefetch_pipeline_sm90.hpp new file mode 100644 index 0000000000000000000000000000000000000000..157d201adec5ecd36ad76dedefd72ac3404d54a6 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/pipeline/prefetch_pipeline_sm90.hpp @@ -0,0 +1,161 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/barrier.h" +#include "cute/container/array.hpp" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +namespace detail { + +// MSVC work-around +template +struct PrefetcherPipelineSharedStorage { + using TransactionBarrier = cutlass::arch::ClusterTransactionBarrier; + using Barrier = cutlass::arch::ClusterBarrier; + + TransactionBarrier tma_barrier[Stages]; + Barrier producer_ready_barrier; +}; + +} // end namespace detail + +using namespace cute; + +// Prefetcher pipeline is modeled after PipelineTmaAsync, with a cluster transaction +// barrier providing control over the number of concurrent outstanding TMA loads. +// There is also an additional cluster barrier which is only used when `prefetch_ratio` is unset. +// `prefetch_ratio` determines how many K tiles get loaded, and when unset, the prefetcher checks +// whether DMA warps are done waiting on griddepcontrol, and if so, stops issuing more TMA loads. +template +class PrefetchPipeline { +public : + static constexpr uint32_t Stages = Stages_; + using SharedStorage = detail::PrefetcherPipelineSharedStorage; + + using TransactionBarrier = typename SharedStorage::TransactionBarrier; + using Barrier = typename SharedStorage::Barrier; + using PrefetcherBarrierType = typename TransactionBarrier::ValueType; + + struct Params { + uint32_t transaction_bytes = 0; + uint32_t num_prefetchers = 1; + bool should_prefetch = false; + }; + + // Constructor + CUTLASS_DEVICE + PrefetchPipeline(SharedStorage& storage, Params params) + : params_(params) + , tma_barrier_ptr_(&storage.tma_barrier[0]) + , producer_ready_barrier_ptr_(&storage.producer_ready_barrier) { + + int lane_predicate = cute::elect_one_sync(); + if (params.should_prefetch && lane_predicate) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Stages; ++i) { + tma_barrier_ptr_[i].init(params.num_prefetchers); + } + producer_ready_barrier_ptr_[0].init(1); + } + } + + CUTLASS_DEVICE + void producer_arrive() { + if (params_.should_prefetch) { + producer_ready_barrier_ptr_[0].arrive(); + } + } + + CUTLASS_DEVICE + bool have_producers_arrived() { + if (params_.should_prefetch) { + uint32_t barrier_status_ = producer_ready_barrier_ptr_[0].try_wait(0); + auto barrier_status = static_cast(barrier_status_); + if (barrier_status == BarrierStatus::WaitDone) { + return true; // exit prefetcher loop + } + return false; + } + return true; + } + + CUTLASS_DEVICE + void prefetcher_acquire(uint32_t stage, uint32_t phase, bool should_wait) { + if (params_.should_prefetch) { + if (should_wait) { + tma_barrier_ptr_[stage].wait(phase ^ 1); + } + tma_barrier_ptr_[stage].arrive_and_expect_tx(params_.transaction_bytes); + } + } + + CUTLASS_DEVICE + void advance_prefetcher_state(uint32_t& stage, uint32_t& phase) { + if (params_.should_prefetch) { + stage++; + if (stage == Stages) { + stage = 0; + phase ^= 1; + } + } + } + + CUTLASS_DEVICE + void prefetcher_tail(uint32_t stage, uint32_t phase) { + if (params_.should_prefetch) { + // Wait on any already-issued loads + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < stage; ++i) { + tma_barrier_ptr_[i].wait(phase); + } + } + } + + CUTLASS_DEVICE + PrefetcherBarrierType* prefetcher_get_barrier(uint32_t stage) { + return reinterpret_cast(&tma_barrier_ptr_[stage]); + } + +private : + TransactionBarrier* tma_barrier_ptr_ = nullptr; + Barrier* producer_ready_barrier_ptr_ = nullptr; + Params params_; + +}; + +} // end namespace cutlass diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp new file mode 100644 index 0000000000000000000000000000000000000000..85aff7567234e9812d0cd09f5a85f9fec1840bdc --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp @@ -0,0 +1,140 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Command line options parsing +template +struct Options { + + bool help = false; + bool verify = true; + + float alpha = 1.f, beta = 0.f; + float scale_a = 1.f, scale_b = 1.f, scale_c = 1.f, scale_d = 1.f, scale_aux = 1.f; + bool device_scale = false; + bool save_aux = true; + bool save_amax = true; + int iterations = 1000; + int warmup = 1000; + int m = 1024, n = 512, k = 1024, l = 1; + RasterOrderOptions raster; + int swizzle; + float epsilon = 0.02f; + float non_zero_floor = 1.f; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("scale_a", scale_a, 1.f); + cmd.get_cmd_line_argument("scale_b", scale_b, 1.f); + cmd.get_cmd_line_argument("scale_c", scale_c, 1.f); + cmd.get_cmd_line_argument("scale_d", scale_d, 1.f); + cmd.get_cmd_line_argument("scale_aux", scale_aux, 1.f); + cmd.get_cmd_line_argument("device_scale", device_scale, false); + cmd.get_cmd_line_argument("save_aux", save_aux, true); + cmd.get_cmd_line_argument("save_amax", save_amax, true); + cmd.get_cmd_line_argument("warmup", warmup); + cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("verify", verify); + cmd.get_cmd_line_argument("epsilon", epsilon); + cmd.get_cmd_line_argument("non-zero-floor", non_zero_floor); + + char raster_char; + cmd.get_cmd_line_argument("raster", raster_char); + + if (raster_char == 'N' || raster_char == 'n') { + raster = RasterOrderOptions::AlongN; + } + else if (raster_char == 'M' || raster_char == 'm') { + raster = RasterOrderOptions::AlongM; + } + else if (raster_char == 'H' || raster_char == 'h') { + raster = RasterOrderOptions::Heuristic; + } + + cmd.get_cmd_line_argument("swizzle", swizzle, 1); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling\n\n" + << " Hopper FP8 GEMM using a Warp Specialized kernel with Blockwise Scaling.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the l extent (batch) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --scale_a= Scaling factor for A\n" + << " --scale_b= Scaling factor for B\n" + << " --scale_c= Scaling factor for C\n" + << " --scale_d= Scaling factor for D (ignored for non-fp8 D)\n" + << " --scale_aux= Scaling factor for the auxiliary tensor (ignored for non-fp8 aux)\n" + << " --device_scale= Copy scalars to device memory before kernel launch (default: false)\n" + << " --save_aux= Save the pre-activation as an auxiliary tensor (default: true)\n" + << " --save_amax= Save the pre-scaled max absolute value of any fp8 outputs (aux and/or D) (default: true)\n" + << " --raster= CTA Rasterization direction (N for along N, M for along M, and H for heuristic)\n\n" + << " --swizzle= CTA Rasterization swizzle\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n" + << " --verify= Verify the results.\n\n" + << " --epsilon= The epsilon value for comparing the results.\n\n" + << " --non-zero-floor= The none zero floor for comparing the results.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ff446a91353ab2b2acd27528741a8ac076e90b3f --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp @@ -0,0 +1,270 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Command line options parsing +using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions; +template +struct Options { + using ProblemShape = _ProblemShape; + + bool help = false; + + float alpha = 1.f, beta = 0.f; + int iterations = 1000; + int m = 1024, n = 512, k = 1024, groups = 10; + std::string benchmark_path; + std::vector problem_sizes_after_alignment_host; + std::vector problem_sizes_host; + int const tma_alignment_bits = 128; + int const alignment = tma_alignment_bits / cutlass::sizeof_bits::value; + int const k_alignment = 128; + int const m_alignment = 128; + int const n_alignment = 128; + + RasterOrderOptions raster_order; + int swizzle; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("groups", groups); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + + char raster_char; + cmd.get_cmd_line_argument("raster", raster_char); + + if (raster_char == 'N' || raster_char == 'n') { + raster_order = RasterOrderOptions::AlongN; + } + else if (raster_char == 'M' || raster_char == 'm') { + raster_order = RasterOrderOptions::AlongM; + } + else if (raster_char == 'H' || raster_char == 'h') { + raster_order = RasterOrderOptions::Heuristic; + } + + cmd.get_cmd_line_argument("swizzle", swizzle, 1); + cmd.get_cmd_line_argument("benchmark", benchmark_path); + + // Decide how to initialize the problems + if (!benchmark_path.empty()) { + if (!benchmark_problems()) { + problem_sizes_after_alignment_host.clear(); + problem_sizes_host.clear(); + return; + } + } + else { + randomize_problems(cmd); + } + + } + + void randomize_problems(cutlass::CommandLine &cmd) { + int cmd_line_m = -1, cmd_line_n = -1, cmd_line_k = -1; + cmd.get_cmd_line_argument("m", cmd_line_m); + cmd.get_cmd_line_argument("n", cmd_line_n); + cmd.get_cmd_line_argument("k", cmd_line_k); + + problem_sizes_after_alignment_host.reserve(groups); + problem_sizes_host.reserve(groups); + for (int i = groups; i > 0; i--) { + int m = cmd_line_m; + int n = cmd_line_n; + int k = cmd_line_k; + if (m < 0) { + m = m_alignment * (rand() % (64 * alignment / m_alignment)); + } + if (n < 0) { + n = n_alignment * (rand() % (64 * alignment / n_alignment)); + } + if (k < 0) { + k = k_alignment * (rand() % (32 * alignment / k_alignment)); + } + problem_sizes_after_alignment_host.push_back({m, n, k}); + problem_sizes_host.push_back({m, n, k}); + } + } + + /// Load a benchmark + bool benchmark_problems() { + std::ifstream file(benchmark_path); + if (!file.good()) { + return false; + } + + while (file.good()) { + + int idx = -1; + std::string extent_str; + + file >> idx >> extent_str; + + if (idx < 0 || extent_str.empty()) { + break; + } + + cutlass::gemm::GemmCoord extent_after_alignment, extent; + std::vector tokens; + + cutlass::CommandLine::tokenize(tokens, extent_str, 'x'); + + for (int i = 0; i < int(tokens.size()); ++i) { + int x = std::atoi(tokens.at(i).c_str()); + + extent.at(i) = x; + // round up + if (x % alignment) { + x += (alignment - (x % alignment)); + } + + extent_after_alignment.at(i) = x; + } + + problem_sizes_after_alignment_host.push_back({extent_after_alignment.m(), extent_after_alignment.n(), extent_after_alignment.k()}); + problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()}); + } + groups = static_cast(problem_sizes_after_alignment_host.size()); + + return true; + } + + /// Calculate memory bandwidth statistics + template + auto gbps(double runtime_s) const { + double total_read_bytes = 0; + double total_write_bytes = 0; + + // Calculate bytes read and written for each problem + for (int i = 0; i < groups; ++i) { + auto problem = problem_sizes_host.at(i); + auto M = cute::get<0>(problem); + auto N = cute::get<1>(problem); + auto K = cute::get<2>(problem); + + if (M > 0) { // Only count active problems + // Matrix A: M*K elements read + total_read_bytes += M * K * sizeof(ElementA); + + // Matrix B: K*N elements read + total_read_bytes += K * N * sizeof(ElementB); + + // Matrix C: M*N elements read (for beta operation) + total_read_bytes += M * N * sizeof(ElementC); + + // Block scales for A and B + auto blockscale_shape = cute::shape(cute::get<1>(cute::zipped_divide(cute::make_layout(problem), TileShape{}))); + auto blockscale_m = cute::get<0>(blockscale_shape); + auto blockscale_n = cute::get<1>(blockscale_shape); + auto blockscale_k = cute::get<2>(blockscale_shape); + auto groupscale_m = blockscale_m * ScaleMsPerTile; + auto groupscale_n = blockscale_n * ScaleNsPerTile; + + total_read_bytes += groupscale_m * blockscale_k * sizeof(ElementBlockScale); // A scales + total_read_bytes += groupscale_n * blockscale_k * sizeof(ElementBlockScale); // B scales + + // Matrix D: M*N elements written + total_write_bytes += M * N * sizeof(ElementD); + } + } + + return (total_read_bytes + total_write_bytes) / 1.0e9 / runtime_s; + } + + double bandwidth_util(double eff_bandwidth) const { + int memoryClockRate; + int memoryBusWidth; + cudaDeviceGetAttribute(&memoryClockRate, cudaDevAttrMemoryClockRate, 0); + cudaDeviceGetAttribute(&memoryBusWidth, cudaDevAttrGlobalMemoryBusWidth , 0); + double bw = 2.0 * memoryClockRate * (memoryBusWidth / 8) / 1.0e6; + return eff_bandwidth / bw * 100.0; + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling\n\n" + << " Hopper FP8 Grouped GEMM using a Warp Specialized kernel with Blockwise Scaling.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --groups= Sets the number of individual GEMM problems for Grouped GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --raster= CTA Rasterization direction (N for along N, M for along M, and H for heuristic)\n\n" + << " --swizzle= CTA Rasterization swizzle\n\n" + << " --benchmark= Executes a benchmark problem size.\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling" << " --m=1024 --n=512 --k=1024 --groups=10 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Number of real-valued multiply-adds + uint64_t fmas = 0ull; + + for (auto const [m, n, k] : problem_sizes_host) { + fmas += static_cast(m) * + static_cast(n) * + static_cast(k); + } + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * uint64_t(fmas); + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/69_hopper_mixed_dtype_grouped_gemm/grouped_mixed_dtype_utils.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/69_hopper_mixed_dtype_grouped_gemm/grouped_mixed_dtype_utils.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8568b467dd87ca5c1b0111e6bd4584b9d3da0960 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/69_hopper_mixed_dtype_grouped_gemm/grouped_mixed_dtype_utils.hpp @@ -0,0 +1,191 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include + +#include "../55_hopper_mixed_dtype_gemm/mixed_dtype_utils.hpp" + +template +class GroupedMixedDtypeOptions : public MixedDtypeOptions { +public: + using ProblemShape = cutlass::gemm::GroupProblemShape>; + using UnderlyingProblemShape = typename ProblemShape::UnderlyingProblemShape; + + int groups = 6; + int c = 512; + std::string benchmark_path; + std::vector problem_sizes_host; + + GroupedMixedDtypeOptions() : MixedDtypeOptions() + { + m = 1024; + n = 2048; + k = 512; + }; + + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + cmd.get_cmd_line_argument("groups", groups); + cmd.get_cmd_line_argument("benchmark", benchmark_path); + cmd.get_cmd_line_argument("c", c); + MixedDtypeOptions::parse(argc, args); + + problem_sizes_host = benchmark_path.empty() ? randomize_problems(cmd) : load_benchmark_problems(); + } + + std::ostream& print_usage(std::ostream& out) const { + out << "69_hopper_mixed_dtype_grouped_gemm\n\n" + << "Options:\n" + << " --help Display this usage statement\n" + << " --m= Sets the M extent of the GEMM for all groups\n" + << " --n= Sets the N extent of the GEMM for all groups\n" + << " --k= Sets the K extent of the GEMM for all groups\n" + << " --c= Sets the chunk size for scaling the quantized weights\n" + << " --groups= Sets the number of individual GEMM problems\n" + << " --mode= The mode to run the gemm\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --iterations= Number of profiling iterations\n" + << " --warmup= Number of warmup iterations\n" + << " --benchmark= Executes a benchmark problem size\n"; + return out; + } + + double gflops(double runtime_s) const { + uint64_t fmas = std::accumulate(problem_sizes_host.begin(), problem_sizes_host.end(), 0ULL, + [](uint64_t sum, const UnderlyingProblemShape& problem) { + return sum + static_cast(cute::get<0>(problem)) * + static_cast(cute::get<1>(problem)) * + static_cast(cute::get<2>(problem)); + }); + return (2.0 * fmas) / (runtime_s * 1e9); + } + +private: + static constexpr int tma_alignment_bits = 128; + const int alignment = tma_alignment_bits / cutlass::sizeof_bits::value; + + std::vector randomize_problems(cutlass::CommandLine& cmd) { + std::vector problems; + problems.reserve(groups); + + int cmd_line_m = -1, cmd_line_n = -1, cmd_line_k = -1; + cmd.get_cmd_line_argument("m", cmd_line_m); + cmd.get_cmd_line_argument("n", cmd_line_n); + cmd.get_cmd_line_argument("k", cmd_line_k); + + for (int i = 0; i < groups; ++i) { + int m = (cmd_line_m >= 0) ? cmd_line_m : alignment * ((rand() % 64) + 1); + int n = (cmd_line_n >= 0) ? cmd_line_n : this->n; + int k = (cmd_line_k >= 0) ? cmd_line_k : this->k; + + if (k % alignment != 0) { + throw std::runtime_error("Error: k dimension must be a multiple of " + std::to_string(alignment)); + } + problems.push_back({m, n, k}); + } + return problems; + } + + std::vector load_benchmark_problems() { + std::ifstream file(benchmark_path); + if (!file) { + throw std::runtime_error("Failed to open benchmark file: " + benchmark_path); + } + + std::vector problems; + int idx; + std::string extent_str; + + while (file >> idx >> extent_str) { + if (idx < 0 || extent_str.empty()) break; + + std::vector tokens; + cutlass::CommandLine::tokenize(tokens, extent_str, 'x'); + + cutlass::gemm::GemmCoord extent; + for (int i = 0; i < std::min(3, static_cast(tokens.size())); ++i) { + int x = std::stoi(tokens[i]); + extent.at(i) = (x % alignment) ? x + (alignment - (x % alignment)) : x; + } + + if (extent.product()) { + problems.push_back({extent.m(), extent.n(), extent.k()}); + } + } + groups = static_cast(problems.size()); + return problems; + } +}; + +template +void grouped_mixed_dtype_profiling( + Gemm& gemm, + const GroupedMixedDtypeOptions& options, + MixedDtypeResult& result, + const std::vector& alpha_host, + const std::vector& beta_host) { + + if (options.iterations <= 0) return; + + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + std::vector runtimes; + runtimes.reserve(options.iterations); + + for (int iter = 0; iter < options.warmup + options.iterations; ++iter) { + cudaEventRecord(start); + CUTLASS_CHECK(gemm.run()); + cudaEventRecord(stop); + cudaEventSynchronize(stop); + + if (iter >= options.warmup) { + float milliseconds = 0; + cudaEventElapsedTime(&milliseconds, start, stop); + runtimes.push_back(milliseconds); + } + } + + cudaEventDestroy(start); + cudaEventDestroy(stop); + + result.avg_runtime_ms = std::accumulate(runtimes.begin(), runtimes.end(), 0.0f) / runtimes.size(); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + std::cout << " Groups : " << options.groups << '\n' + << " Avg runtime : " << result.avg_runtime_ms << " ms\n" + << " GFLOPS : " << result.gflops << '\n'; +} diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/fmha_common.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/fmha_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c60d9e953e2ae4afee89269f0a405b5be19ff3b5 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/fmha_common.hpp @@ -0,0 +1,127 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/kernel_hardware_info.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cute/tensor.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template +CUTE_DEVICE void gemm_reset_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) { + constexpr int rA = decltype(rank(tA))::value; + constexpr int rB = decltype(rank(tB))::value; + constexpr int rC = decltype(rank(tC))::value; + static_assert(rA == 3 && rB == 3 && rC == 3); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tA); k_block++) { + cute::gemm(atom, tA(_,_,k_block), tB(_,_,k_block), tC); + atom.accumulate_ = decltype(atom.accumulate_)::One; + } +} + +template +CUTE_DEVICE void gemm_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) { + atom.accumulate_ = decltype(atom.accumulate_)::Zero; + gemm_reset_zero_acc(atom, tA, tB, tC); +} + +template +CUTE_DEVICE constexpr auto unstageSmemLayout(Layout const& layout, Stages stages = {}) { + return composition(layout, prepend(make_layout(stages), _)); +} + +template +CUTE_DEVICE T warp_uniform(T a) { + return __shfl_sync(0xffffffff, a, 0); +} + +template +CUTE_HOST_DEVICE constexpr +auto +to_tiled_mma_sm100_ts( + TiledMMA, cute::C, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant>, + TAs...>, TMs...>) { + + return TiledMMA>, + TAs...>, TMs...>{}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +to_tiled_mma_sm100_ts( + TiledMMA, + TAs...>, TMs...>) { + return TiledMMA, + TAs...>, TMs...>{}; +} + +template +CUTLASS_DEVICE +void warpgroup_reg_set() { + if constexpr (RegCount < 128) { + cutlass::arch::warpgroup_reg_dealloc(); + } + else { + cutlass::arch::warpgroup_reg_alloc(); + } +} + +} // namespace cutlass::fmha::collective diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/fmha_fusion.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/fmha_fusion.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a33ce2d2ce04c357fe49ac6752e99e5f6e8c32b6 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/fmha_fusion.hpp @@ -0,0 +1,396 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +struct NoMask { + template + CUTLASS_DEVICE + int get_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + return ceil_div(get<1>(problem_size), get<1>(tile_shape)); + } + + template + CUTLASS_DEVICE + int get_masked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + return 0; + } + + template + CUTLASS_DEVICE + int get_unmasked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + return get_trip_count(blk_coord, tile_shape, problem_size); + } + + template + CUTLASS_DEVICE + void apply_mask( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size) { + + return; + } +}; + +struct ResidualMask : NoMask { + + using Base = NoMask; + + template + CUTLASS_DEVICE int get_masked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + if (get<1>(problem_size) % get<1>(tile_shape) != 0) { + return 1; + } + return 0; + } + + template + CUTLASS_DEVICE + int get_unmasked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + // if the sequence length does not divide the tile size evenly + if (get<1>(problem_size) % get<1>(tile_shape) != 0) { + return get_trip_count(blk_coord, tile_shape, problem_size) - 1; + } + return get_trip_count(blk_coord, tile_shape, problem_size); + } + + template + CUTLASS_DEVICE + void apply_mask( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size) { + + // This is useful is seqlen_k % kBlockN != 0 since it masks + // the remaining elements out from softmax. + // d % kHeadDim != 0 or seqlen_q % kBlockM do not suffer from similar + // issues as they are transparently taken care of by TMA and the + // epilogue, if it is instantiated with predication support. + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + if (get<1>(pos) >= get<1>(problem_size)) { + acc_qk(i) = -INFINITY; + } + } + } +}; + +struct ResidualMaskForBackward : NoMask { + + using Base = NoMask; + + template + CUTLASS_DEVICE int get_masked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + if (get<1>(problem_size) % get<1>(tile_shape) != 0) { + return 1; + } + return 0; + } + + template + CUTLASS_DEVICE + int get_unmasked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + // if the sequence length does not divide the tile size evenly + if (get<1>(problem_size) % get<1>(tile_shape) != 0) { + return get_trip_count(blk_coord, tile_shape, problem_size) - 1; + } + return get_trip_count(blk_coord, tile_shape, problem_size); + } + + template + CUTLASS_DEVICE + void apply_mask( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size) { + + // This is useful is seqlen_k % kBlockN != 0 since it masks + // the remaining elements out from softmax. + // d % kHeadDim != 0 or seqlen_q % kBlockM do not suffer from similar + // issues as they are transparently taken care of by TMA and the + // epilogue, if it is instantiated with predication support. + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + if (! elem_less(pos, select<0,1>(problem_size))) { + acc_qk(i) = -INFINITY; + } + } + } +}; + +// There are two ways to do causal if N_Q != N_K +// (1) The Q is at the beginning of the matrix +// (2) The Q is at the end of the matrix +template +struct CausalMask : NoMask { + + using Base = NoMask; + + static constexpr bool IsQBegin = kIsQBegin; + + template + CUTLASS_DEVICE + int get_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + // See note below on different ways to think about causal attention + // Again, we'd add the offset_q into the max_blocks_q calculation + int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size); + if constexpr (IsQBegin) { + int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape)); + return std::min(max_blocks_k, max_blocks_q); + } else { + const int offset_q = get<1>(problem_size) - get<0>(problem_size); + int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape) + offset_q, get<1>(tile_shape)); + return std::min(max_blocks_k, max_blocks_q); + } + } + + template + CUTLASS_DEVICE + int get_masked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + int trip_count = get_trip_count(blk_coord, tile_shape, problem_size); + if constexpr (IsQBegin) { + return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape)))); + } else { + const int offset_tile_q = (get<1>(problem_size) - get<0>(problem_size)) % get<1>(tile_shape); + return std::min(trip_count, int(ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape)))); + } + } + + template + CUTLASS_DEVICE + int get_unmasked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + return get_trip_count(blk_coord, tile_shape, problem_size) - get_masked_trip_count(blk_coord, tile_shape, problem_size); + } + + template + CUTLASS_DEVICE + void apply_mask( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size) { + + // There are two ways to do causal if N_Q != N_K + // (1) is to assume that the Q is at the beginning of the matrix + // - this is the default setting. + // (2) is that it is at the end of the matrix + // - this is usually what we want for inference settings + // where we only compute the next row and use cache for the rest + // - if you'd like this, you only need to set kIsQBegin=false + + if constexpr (IsQBegin) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + if ((get<0>(pos) < get<1>(pos)) || (get<1>(pos) >= get<1>(problem_size))) { + acc_qk(i) = -INFINITY; + } + } + } else { + const auto offset_q = get<1>(problem_size) - get<0>(problem_size); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + if ((get<0>(pos) + offset_q < get<1>(pos)) || (get<1>(pos) >= get<1>(problem_size))) { + acc_qk(i) = -INFINITY; + } + } + } + } +}; + +template +struct CausalForBackwardMask : CausalMask, ResidualMaskForBackward { + + using Base = CausalMask; + + template + CUTLASS_DEVICE + void apply_mask( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size) { + + // There are two ways to do causal if N_Q != N_K + // (1) is to assume that the Q is at the beginning of the matrix + // - this is what we demonstrate here + // (2) is that it is at the end of the matrix + // - this is usually what we want for inference settings + // where we only compute the next row and use cache for the rest + // - if you'd like this, you only need to add an offset like so: + // get<0>(pos) + offset_q < get<1>(pos) + int offset_q = 0; + if constexpr (!kIsQBegin) { + offset_q = get<1>(problem_size) - get<0>(problem_size); + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + bool masked = (get<0>(pos) + offset_q < get<1>(pos)) || !elem_less(pos, problem_size); + if (masked) { + acc_qk(i) = -INFINITY; + } + } + } + +}; + +struct VariableLength { + int max_length; + int* cumulative_length = nullptr; + int total_length = -1; + + CUTE_HOST_DEVICE operator int() const { + return max_length; + } +}; + +template struct is_variable_length_impl : std::false_type {}; +template<> struct is_variable_length_impl : std::true_type {}; +template constexpr bool is_variable_length_v = is_variable_length_impl>::value; + +template +CUTE_HOST_DEVICE +constexpr auto +apply_variable_length(Shape const& shape, Idx const& idx) { + return transform_leaf(shape, [&](auto const& s) { + if constexpr (is_variable_length_v) { + return s.cumulative_length[idx+1] - s.cumulative_length[idx]; + } + else { + return s; + } + }); +} + +template +CUTE_HOST_DEVICE +constexpr auto +apply_variable_length(Shape const& shape, Coord const& coord, Idx const& idx) { + auto new_shape = apply_variable_length(shape, idx); + auto new_coord = transform_leaf(shape, coord, [&](auto const& s, auto const& c) { + if constexpr (is_variable_length_v) { + return cute::make_tuple(c, s.cumulative_length[idx]); + } + else { + return c; + } + }); + return cute::make_tuple(new_shape, new_coord); +} + +template +CUTE_HOST_DEVICE +constexpr auto +apply_variable_length_offset(Shape const& shape, Coord const& coord) { + auto idx = back(back(coord)); + auto result_shape = transform_leaf(shape, [&](auto const& s) { + if constexpr (is_variable_length_v) { + return s.cumulative_length[idx+1] - s.cumulative_length[idx]; + } + else { + return s; + } + }); + auto result_offset = transform_leaf(coord, shape, [&](auto const& c, auto const& s) { + if constexpr (is_variable_length_v) { + return s.cumulative_length[idx]; + } + else { + return _0{}; + } + }); + return cute::make_tuple(result_shape, result_offset); +} + +} // namespace cutlass::fmha::collective + +namespace cute { + +template<> +struct is_integral : true_type {}; + +CUTE_HOST_DEVICE +void print(cutlass::fmha::collective::VariableLength a) { + printf("Varlen<%d, %p>", a.max_length, a.cumulative_length); +} + +} diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..616357cb0ef3ef4a585be39825ea1f34807e561f --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp @@ -0,0 +1,234 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +namespace cutlass::fmha::collective { + +template< + class Element, + class ElementAcc, + class TileShape, // Q, D, _ + class StrideO, // Q, D, B + class StrideLSE_, // Q, B + class OrderLoadEpilogue = cute::false_type +> +struct Sm100FmhaFwdEpilogueTmaWarpspecialized { + + using Pipeline = cutlass::PipelineAsync<2>; + +// using SmemLayoutO = decltypa(make_layout(append<3>(select<0,1>(TileShape_WG{}), _2{}))); + using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::K, Element, tuple_element_t<0, TileShape>, tuple_element_t<1, TileShape>>()); +// using SmemLayoutAtomO = decltype(make_ordered_layout(select<0,1>(TileShape{}), Step<_1, _0>{})); + using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, replace<2>(TileShape{}, _2{}), Step<_2, _1, _3>{})); + using SmemLayoutO_ = SmemLayoutO; + using StrideLSE = StrideLSE_; + using ElementOut = Element; + + static const int NumWarpsEpilogue = 1; + static const int NumWarpsLoad = 1; + + struct TensorStorage { + + using SmemLayoutO = SmemLayoutO_; + cute::array_aligned> smem_o; + + }; + + struct Arguments { + Element* ptr_O; + StrideO dO; + + ElementAcc* ptr_LSE; + StrideLSE dLSE; + }; + + using TMA_O = decltype(make_tma_copy( + SM90_TMA_STORE{}, + make_tensor((Element*) nullptr, repeat_like(StrideO{}, 0), StrideO{}), + SmemLayoutO{}(_,_,_0{}) + )); + + + struct Params { + TMA_O tma_store_o; + + ElementAcc* ptr_LSE; + StrideLSE dLSE; + }; + + // FMHA and MLA have different input ProblemShapes; + // get problem_shape_O according to the input ProblemShape. + template + CUTLASS_DEVICE static constexpr + auto get_problem_shape_O ( + ProblemShape const& problem_shape) { + if constexpr (rank_v(ProblemShape{}))> == 2) { + return replace<1>(select<0,2,3>(problem_shape), get<2, 0>(problem_shape)); + } else { + return select<0,2,3>(problem_shape); + } + } + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace = nullptr) { + + auto ptr_O = args.ptr_O; + StrideO dO = args.dO; + + auto problem_shape_O = get_problem_shape_O(problem_shape); + + if constexpr (is_variable_length_v>) { + auto cumulative_length_q = get<0>(problem_shape).cumulative_length; + if (cumulative_length_q != nullptr) { + int max_length_q = get<0>(problem_shape).max_length; + // for variable sequence lenght, the batch is in units of row_stride + get<2,1>(dO) = get<0>(dO); + get<2,1>(problem_shape_O) = max_length_q * (1 + get<2,1>(problem_shape_O)); + // offset ptr by the amount we add back in later + ptr_O -= max_length_q * get<0>(dO); + } + } + + auto tma_store_o = make_tma_copy( + SM90_TMA_STORE{}, + make_tensor(ptr_O, problem_shape_O, dO), + SmemLayoutO{}(_,_,_0{}) + ); + + return { + tma_store_o, + args.ptr_LSE, + args.dLSE + }; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + cute::prefetch_tma_descriptor(params.tma_store_o.get_tma_descriptor()); + } + + const Params& params; + + CUTLASS_DEVICE Sm100FmhaFwdEpilogueTmaWarpspecialized(const Params& params) : params(params) {} + + template + CUTLASS_DEVICE auto + store( + BlkCoord const& blk_coord_in, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& shared_storage, + Pipeline& pipeline, typename Pipeline::PipelineState& pipeline_consumer_state) { + + BlkCoord blk_coord = blk_coord_in; + uint32_t lane_predicate = cute::elect_one_sync(); + + using X = Underscore; + + int o0_index = 2 * get<0>(blk_coord); + int o1_index = 2 * get<0>(blk_coord) + 1; + + Tensor mO_qdl_p = params.tma_store_o.get_tma_tensor(get_problem_shape_O(problem_shape)); + // offset mode 0 by (max_length - real_length) + // offset mode 3,1 by cumulative_length + real_length + // the ptr is already offset by - max_length + // so in total this achieves + int offs_0 = 0; + int offs_2_1 = 0; + + if constexpr (is_variable_length_v>) { + auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length; + if (cumulative_length_q != nullptr) { + int max_length_q = get<0>(params_problem_shape).max_length; + offs_0 = max_length_q - get<0>(problem_shape); + offs_2_1 = cumulative_length_q[get<2,1>(blk_coord)] + get<0>(problem_shape); + get<2,1>(blk_coord) = 0; + } + } + + Tensor mO_qdl = domain_offset(make_coord(offs_0, _0{}, make_coord(_0{}, offs_2_1)), mO_qdl_p); + + Tensor gO_qdl = local_tile(mO_qdl, TileShape{}, make_coord(_, _, _), Step<_1, _1, X>{}); + Tensor gO = gO_qdl(_, _, _, _0{}, get<2>(blk_coord)); + Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{}); + auto block_tma = params.tma_store_o.get_slice(0); + Tensor tOsO = block_tma.partition_S(sO); + Tensor tOgO = block_tma.partition_D(gO); + + auto pipeline_release_state = pipeline_consumer_state; + + // O1 O2 + // one pipeline: O + // wait from corr, issue tma store on smem + pipeline.consumer_wait(pipeline_consumer_state); + ++pipeline_consumer_state; + + if (lane_predicate) { + copy(params.tma_store_o, tOsO(_,_,_,_0{}), tOgO(_,_,_,o0_index)); + } + tma_store_arrive(); + + pipeline.consumer_wait(pipeline_consumer_state); + ++pipeline_consumer_state; + + if (lane_predicate) { + copy(params.tma_store_o, tOsO(_,_,_,_1{}), tOgO(_,_,_,o1_index)); + } + tma_store_arrive(); + + tma_store_wait<1>(); + + pipeline.consumer_release(pipeline_release_state); + ++pipeline_release_state; + + tma_store_wait<0>(); + + if constexpr (cute::is_same_v) { + cutlass::arch::NamedBarrier::arrive((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + } + + pipeline.consumer_release(pipeline_release_state); + ++pipeline_release_state; + + } + +}; + +} // namespace cutlass::fmha::collective diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1e094bf42dd7235f69f9db5f8e5e569d313c7ce6 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp @@ -0,0 +1,1218 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cute/arch/simd_sm100.hpp" +#include "cute/tensor.hpp" +#include "cute/layout.hpp" + +#include "collective/fmha_common.hpp" +#include "collective/fmha_fusion.hpp" +#include "collective/sm100_fmha_load_tma_warpspecialized.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template< + class Element_, + class ElementQK_, + class ElementPV_, + class TileShape_, + class StrideQ_, + class StrideK_, + class StrideV_, + class Mask_, + // shape here is QG K H + // and referes to the two softmax warps + // (2, 1, 1) means that they are stacked (best for large Q since it loads the least K/V) + // (1, 2, 1) means they sit side by side (best for small Q / large K) + class ThreadShape = Shape<_2, _1, _1>, + // Since shared memory is sufficient for FMHA, there is no need to reuse shared memory. + class OrderLoadEpilogue = cute::false_type +> +struct Sm100FmhaFwdMainloopTmaWarpspecialized { + + using Element = Element_; + using ElementQK = ElementQK_; + using ElementPV = ElementPV_; + using TileShape = TileShape_; + using StrideQ = StrideQ_; + using StrideK = StrideK_; + using StrideV = StrideV_; + using Mask = Mask_; + + static constexpr int StageCountQ = 2; + static constexpr int StageCountKV = sizeof(Element_) == 1 ? 4 : 3; + + using StagesQ = cutlass::gemm::collective::StageCount; + using StagesKV = cutlass::gemm::collective::StageCount; + + using ClusterShape = Shape<_1, _1, _1>; + + static const int Alignment = 128 / sizeof_bits_v; + + using TileShapeQK = decltype(shape_div(TileShape{}, ThreadShape{})); + + using TileShapePV = decltype(select<0,2,1>(TileShapeQK{})); + + using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, StrideQ, Alignment, + Element, StrideK, Alignment, + ElementQK, + TileShapeQK, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp; + + using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // the stride for A does not matter since we do not load from smem at all + Element, StrideK, Alignment, + Element, decltype(select<1,0,2>(StrideV{})), Alignment, + ElementPV, + TileShapePV, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp; + + using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int{})); + using SmemLayoutK = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutB{}, Int{})); + using SmemLayoutV = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutB{}, Int{})); + + // Reuse shared memory for V and O. + static constexpr bool IsOrderLoadEpilogue = std::is_same_v; + struct TensorStorage { + cute::array_aligned> smem_q; + union { + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + }; + }; + + enum class TmemAllocation : uint32_t { + kSizeS = 128, + kSizeO = 128, + kSizeP = 32, + S0 = 0, + S1 = S0 + kSizeS, + V0 = S0, // stats storage from softmax to correction + V1 = S1, + P0 = S0 + kSizeP, + P1 = S1 + kSizeP, + O0 = S1 + kSizeS, + O1 = O0 + kSizeO, + kEnd = O1 + kSizeO + }; + + // indices for V0 / V1 + enum : int { + kIdxOldRowMax = 0, + kIdxNewRowMax = 1, + kIdxFinalRowSum = 0, + kIdxFinalRowMax = 1 + }; + + // from load to mma warp, protects q in smem + using PipelineQ = cutlass::PipelineTmaUmmaAsync< + StageCountQ, + typename CollectiveMmaQK::AtomThrShapeMNK + >; + + // from load to mma warp, protects k/v in smem + using PipelineKV = cutlass::PipelineTmaUmmaAsync< + StageCountKV, + typename CollectiveMmaQK::AtomThrShapeMNK + >; + + // from mma to softmax0/1 warp, protects S in tmem + // (not sure yet about the reverse direction) + // there is one pipe per softmax warp, and the mma warp alternates between them + using PipelineS = cutlass::PipelineUmmaAsync<1>; + + // from softmax0/1/ to correction wg + using PipelineC = cutlass::PipelineAsync<1>; + + // from mma to correction + using PipelineO = cutlass::PipelineUmmaAsync<2>; + + // from corr to epilogue + using PipelineE = cutlass::PipelineAsync<2>; + + using OrderBarrierSoftmax = cutlass::OrderedSequenceBarrier< + /*stages*/ 1, /*groups*/ 2>; + + static const int TransactionBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v); + + static const int TransactionBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v); + static const int TransactionBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v); + + static_assert(TransactionBytesLoadK == TransactionBytesLoadV, "K and V smem layouts must be of equal size"); + + using Load = Sm100FmhaLoadTmaWarpspecialized< + Element, StrideQ, StrideK, StrideV, + CollectiveMmaQK, CollectiveMmaPV, + SmemLayoutQ, SmemLayoutK, SmemLayoutV, + TensorStorage, PipelineQ, PipelineKV, Mask, TileShape + >; + + struct Arguments { + typename Load::Arguments load; + + // if zero, defaults to 1/sqrt(D) + float scale_softmax = 0.0f; + + // scaling factors to dequantize QKV + float scale_q = 1.0f; + float scale_k = 1.0f; + float scale_v = 1.0f; + + // scaling factor to quantize O + float inv_scale_o = 1.0f; + }; + + struct Params { + typename Load::Params load; + + float scale_softmax; + float scale_softmax_log2; + + float scale_output; + }; + + template + static bool can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace) { + + float scale_softmax = args.scale_softmax; + if (scale_softmax == 0.0f) { + scale_softmax = 1.0f / (float) std::sqrt(get<2>(problem_shape)); + } + float log2_e = static_cast(std::log2(std::exp(1.0))); + + return Params{ + Load::to_underlying_arguments(problem_shape, args.load, workspace), + args.scale_q * args.scale_k * scale_softmax, + args.scale_q * args.scale_k * log2_e * scale_softmax, + args.scale_v * args.inv_scale_o + }; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + Load::prefetch_tma_descriptors(params.load); + } + + template + CUTLASS_DEVICE void + load( + BlkCoord const& blk_coord, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) { + + Load load; + load.load(blk_coord, problem_shape, params.load, params_problem_shape, + storage, + pipeline_q, pipeline_q_producer_state, + pipeline_kv, pipeline_kv_producer_state); + } + + template + CUTLASS_DEVICE auto + mma( + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_consumer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_consumer_state, + PipelineS& pipeline_s0, typename PipelineS::PipelineState& pipeline_s0_producer_state, + PipelineS& pipeline_s1, typename PipelineS::PipelineState& pipeline_s1_producer_state, + PipelineO& pipeline_corr, typename PipelineO::PipelineState& pipeline_corr_producer_state) { + + auto pipeline_q_release_state = pipeline_q_consumer_state; + auto pipeline_kv_release_state = pipeline_kv_consumer_state; + + int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + + typename CollectiveMmaQK::TiledMma mma_qk; + ThrMMA thr_mma_qk = mma_qk.get_slice(0); + + typename CollectiveMmaPV::TiledMma mma_pv; + TiledMMA mma_pv_ts = to_tiled_mma_sm100_ts(mma_pv); + ThrMMA thr_mma_pv = mma_pv_ts.get_slice(0); + + Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + + Tensor tSrQ = thr_mma_qk.make_fragment_A(sQ); + Tensor tSrK = thr_mma_qk.make_fragment_B(sK); + Tensor tOrV = thr_mma_pv.make_fragment_B(sV); + + // tmem layout is + // S0 S1`O0 O1 + // sequential in memory, where S overlaps with P and V + + Tensor tStS = partition_fragment_C(mma_qk, select<0,1>(TileShapeQK{})); + Tensor tOtO = partition_fragment_C(mma_pv_ts, select<0,1>(TileShapePV{})); + + Tensor tStS0 = tStS; + tStS0.data() = tStS.data().get() + uint32_t(TmemAllocation::S0); + Tensor tStS1 = tStS; + tStS1.data() = tStS.data().get() + uint32_t(TmemAllocation::S1); + + Tensor tOtO0 = tOtO; + tOtO0.data() = tOtO.data().get() + uint32_t(TmemAllocation::O0); + Tensor tOtO1 = tOtO; + tOtO1.data() = tOtO.data().get() + uint32_t(TmemAllocation::O1); + + Tensor sP = make_tensor(make_smem_ptr((Element*)nullptr), typename CollectiveMmaPV::SmemLayoutA{}); + Tensor tOrP = thr_mma_pv.make_fragment_A(sP)(_, _, _, _0{}); // slice out staging + + Tensor tOrP0 = tOrP; + tOrP0.data() = tOrP0.data().get() + uint32_t(TmemAllocation::P0); + Tensor tOrP1 = tOrP; + tOrP1.data() = tOrP1.data().get() + uint32_t(TmemAllocation::P1); + + int k_index = 0; + int v_index = 0; + int q_index = 0; + + // wait for Q1 + q_index = pipeline_q_consumer_state.index(); + pipeline_q.consumer_wait(pipeline_q_consumer_state); + ++pipeline_q_consumer_state; + + Tensor tSrQ0 = tSrQ(_,_,_,q_index); + + + // wait for K1 + k_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm Q1 * K1 -> S1 + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index), tStS0); + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + // release K1 + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + // wait for Q2 + if constexpr (get<0>(ThreadShape{}) > 1 || get<2>(ThreadShape{}) > 1) { + q_index = pipeline_q_consumer_state.index(); + pipeline_q.consumer_wait(pipeline_q_consumer_state); + ++pipeline_q_consumer_state; + } + + Tensor tSrQ1 = tSrQ(_,_,_,q_index); + + if constexpr (get<1>(ThreadShape{}) > 1) { + k_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + // gemm Q2 * K1 -> S2 + gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index), tStS1); + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // release K1 + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + // wait for V1 + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // this acquire returns the ownership of all of S0 to the mma warp + // including the P0 part + // acquire corr first to take it out of the critical + // path since softmax takes longer + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + // gemm P1 * V1 -> O1 + gemm_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index), tOtO0); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + mma_pv_ts.accumulate_ = UMMA::ScaleOut::Zero; + + // loop: + mask_tile_count -= 1; + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + // wait for Ki + k_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm Q1 * Ki -> S1 + gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index), tStS0); + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + // gemm P2 * V(i-1) -> O2 + if constexpr (get<1>(ThreadShape{}) > 1) { + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index), tOtO1); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + // release V(i-1) + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + k_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + // gemm Q2 * Ki -> S2 + gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index), tStS1); + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // release Ki + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + // wait for Vi + v_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm P1 * Vi -> O1 + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index), tOtO0); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + } + + // release Q1 + pipeline_q.consumer_release(pipeline_q_release_state); + ++pipeline_q_release_state; + + // release Q2 + if constexpr (get<0>(ThreadShape{}) > 1) { + pipeline_q.consumer_release(pipeline_q_release_state); + ++pipeline_q_release_state; + } + + // wait for Vi + if constexpr (get<1>(ThreadShape{}) > 1) { + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + // gemm P2 * Vi -> O2 + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index), tOtO1); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + // release Vi + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // T0 S00 B1, T0 S10 B1, T0 S00 B2, T0 S01 B1, T0 S10 B2, T0 S11 B1, T0 S01 B2, T1 S00 B1, T0 S11 B2, ... + // Q1 * K1 , Q2 * K1 , S11 * V1 , Q1 * K2 , S21 * V1 , Q2 * K2 , S12 * V2 , Q1 * K3 , S22 * K2 , ... + } + + template + CUTLASS_DEVICE auto + softmax_step( + float& row_max, float& row_sum, + Stage stage, bool final_call, + BlkCoord const& blk_coord, CoordTensor const& cS, + Params const& params, ProblemShape const& problem_shape, + PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state, + PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state, + OrderBarrierSoftmax& order_s) { + + Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); + + Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{})); + tStS.data() = uint32_t(stage == _0{} ? TmemAllocation::S0 : TmemAllocation::S1); + + Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); + tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1); + Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); + + auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int{} * Int{}; + Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); + tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1)); + Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); + + // Each thread owns a single row + using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem + using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem + using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tStS); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + + Tensor tTMEM_LOADtS = thr_tmem_load.partition_S(tStS); + Tensor tTMEM_LOADcS = thr_tmem_load.partition_D(tScS); + + auto tiled_tmem_storev = make_tmem_copy(TMEM_STORE_V{}, tStS_v); + auto thr_tmem_storev = tiled_tmem_storev.get_slice(thread_idx); + + Tensor tTMEM_STOREVtS = thr_tmem_storev.partition_D(tStS_v); + Tensor tTMEM_STOREVcS = thr_tmem_storev.partition_S(tScS_v); + + auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tStS_P); + auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); + + Tensor tTMEM_STOREtS_x4 = thr_tmem_store.partition_D(tStS_P); + tTMEM_STOREtS_x4.data() = warp_uniform(tTMEM_STOREtS_x4.data().get()); + Tensor tTMEM_STOREcS = thr_tmem_store.partition_S(tScS_P); + + // wait on tensor core pipe + pipeline_s.consumer_wait(pipeline_s_consumer_state); + + // read all of S from tmem into reg mem + Tensor tTMEM_LOADrS = make_tensor(shape(tTMEM_LOADcS)); + copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS); + + if constexpr (need_apply_mask) { + Mask{}.apply_mask(tTMEM_LOADrS, tTMEM_LOADcS, problem_shape); + } + + ElementQK old_row_max = row_max; + { + // compute rowmax + float row_max_0 = row_max; + float row_max_1 = row_max; + float row_max_2 = row_max; + float row_max_3 = row_max; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 4) { + row_max_0 = ::fmax(row_max_0, tTMEM_LOADrS(i)); + row_max_1 = ::fmax(row_max_1, tTMEM_LOADrS(i+1)); + row_max_2 = ::fmax(row_max_2, tTMEM_LOADrS(i+2)); + row_max_3 = ::fmax(row_max_3, tTMEM_LOADrS(i+3)); + } + row_max = ::fmax(row_max_0, row_max_1); + row_max = ::fmax(row_max, row_max_2); + row_max = ::fmax(row_max, row_max_3); + } + + ElementQK row_max_safe = row_max == -INFINITY ? 0 : row_max; + + Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); + tTMEM_STOREVrS(kIdxOldRowMax) = old_row_max; + tTMEM_STOREVrS(kIdxNewRowMax) = row_max_safe; + copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS); + + pipeline_c.producer_commit(pipeline_c_producer_state); + ++pipeline_c_producer_state; + + // notify correction wg that they are ready (might need addtl ordering between S0 and S1 WG's) + + ElementQK scale = params.scale_softmax_log2; + ElementQK row_max_scale = row_max_safe * scale; + + float2 scale_fp32x2 = make_float2(scale, scale); + float2 minus_row_max_scale_fp32x2 = make_float2(-row_max_scale, -row_max_scale); + + Tensor tTMEM_STORErS_x4 = make_tensor(shape(tTMEM_STOREcS)); + + constexpr int kConversionsPerStep = 2; + + Tensor tTMEM_STORErS_x4_e = recast>(tTMEM_STORErS_x4); + + NumericArrayConverter convert; + + const int kReleasePipeCount = 10; // must be multiple of 2 + + order_s.wait(); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 2) { + float2 in = make_float2( + tTMEM_LOADrS(i + 0), + tTMEM_LOADrS(i + 1) + ); + float2 out; + cute::fma(out, scale_fp32x2, in, minus_row_max_scale_fp32x2); + tTMEM_LOADrS(i + 0) = out.x; + tTMEM_LOADrS(i + 1) = out.y; + + tTMEM_LOADrS(i+0) = ::exp2f(tTMEM_LOADrS(i+0)); + tTMEM_LOADrS(i+1) = ::exp2f(tTMEM_LOADrS(i+1)); + + Array in_conv; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kConversionsPerStep; j++) { + in_conv[j] = tTMEM_LOADrS(i + j); + } + tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv); + + + if (i == size(tTMEM_LOADrS) - kReleasePipeCount) { + order_s.arrive(); + } + + // this prevents register spills in fp16 + if constexpr (size<2>(tTMEM_STORErS_x4) == _2{}) { + if (i == size(tTMEM_LOADrS) - 6) { + copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, 0), tTMEM_STOREtS_x4(_, _, 0)); + } + } + } + + // tmem_store(reg_S8) -> op_P + CUTE_STATIC_ASSERT_V(size<2>(tTMEM_STORErS_x4) <= _2{}); + CUTE_STATIC_ASSERT_V(size<1>(tTMEM_STORErS_x4) == _1{}); + copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1), tTMEM_STOREtS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1)); + + cutlass::arch::fence_view_async_tmem_store(); + + // notify tensor core warp that P is ready + pipeline_s.consumer_release(pipeline_s_consumer_state); + ++pipeline_s_consumer_state; + + pipeline_c.producer_acquire(pipeline_c_producer_state); + + ElementQK acc_scale = 0.5f * ::exp2f(scale * (old_row_max - row_max_safe)); + row_sum *= acc_scale; + // row_sum = sum(reg_S) + float2 local_row_sum_f32x2 = make_float2(row_sum, row_sum); + float2 local_row_sum_1 = make_float2(0, 0); + float2 local_row_sum_2 = make_float2(0, 0); + float2 local_row_sum_3 = make_float2(0, 0); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 8) { + // row_sum += tTMEM_LOADrS(i); + float2 in = make_float2(tTMEM_LOADrS(i), tTMEM_LOADrS(i+1)); + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, in); + + in = make_float2(tTMEM_LOADrS(i+2), tTMEM_LOADrS(i+2+1)); + cute::add(local_row_sum_1, local_row_sum_1, in); + + in = make_float2(tTMEM_LOADrS(i+4), tTMEM_LOADrS(i+4+1)); + cute::add(local_row_sum_2, local_row_sum_2, in); + + in = make_float2(tTMEM_LOADrS(i+6), tTMEM_LOADrS(i+6+1)); + cute::add(local_row_sum_3, local_row_sum_3, in); + } + + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_1); + cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3); + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2); + float local_row_sum = local_row_sum_f32x2.x + local_row_sum_f32x2.y; + + row_sum = local_row_sum; + + if (final_call) { + // re-acquire the S part in the final step + pipeline_s.consumer_wait(pipeline_s_consumer_state); + + Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); + tTMEM_STOREVrS(kIdxFinalRowMax) = row_max; + tTMEM_STOREVrS(kIdxFinalRowSum) = row_sum; + copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS); + } + } + + template + CUTLASS_DEVICE auto + softmax( + Stage stage, + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state, + PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state, + OrderBarrierSoftmax& order_s) { + + int mask_tile_count = Mask{}.get_unmasked_trip_count(blk_coord, TileShape{}, problem_shape); + + ElementQK row_max = -INFINITY; + ElementQK row_sum = 0; + + Tensor cS_base = make_identity_tensor(select<0,1>(TileShapeQK{})); + auto logical_offset = make_coord( + get<0>(blk_coord) * get<0>(TileShape{}) + (stage % get<0>(ThreadShape{})) * get<0>(TileShapeQK{}), + 0 + (stage % get<1>(ThreadShape{})) * get<1>(TileShapeQK{}) + ); + Tensor cS = domain_offset(logical_offset, cS_base); + + pipeline_c.producer_acquire(pipeline_c_producer_state); + + CUTLASS_PRAGMA_NO_UNROLL + for (; mask_tile_count > 0; mask_tile_count -= 1) { + softmax_step( + row_max, row_sum, stage, + (mask_tile_count == 1) && + (Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape) == 0), + blk_coord, cS, params, problem_shape, + pipeline_s, pipeline_s_consumer_state, + pipeline_c, pipeline_c_producer_state, + order_s + ); + + cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); + } + + // Masked iterations + mask_tile_count = Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape); + + CUTLASS_PRAGMA_NO_UNROLL + for (; mask_tile_count > 0; mask_tile_count -= 1) { + softmax_step( + row_max, row_sum, stage, mask_tile_count == 1, + blk_coord, cS, params, problem_shape, + pipeline_s, pipeline_s_consumer_state, + pipeline_c, pipeline_c_producer_state, + order_s + ); + + cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); + } + + pipeline_c.producer_commit(pipeline_c_producer_state); + ++pipeline_c_producer_state; + + pipeline_c.producer_acquire(pipeline_c_producer_state); + // empty step to sync against pipe s + pipeline_s.consumer_release(pipeline_s_consumer_state); + ++pipeline_s_consumer_state; + } + + template + CUTLASS_DEVICE auto + correction_epilogue( + float scale, + Stage stage, + TensorO const& sO_01) { + + using ElementOut = typename TensorO::value_type; + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + Tensor sO = sO_01(_,_,stage); + + // As opposed to the softmax, we do not have enough registers here + // to load all of the values (for tile kv = 128), so we loop + // good values would be either 32 or 64 + const int kCorrectionTileSize = 32 / sizeof(ElementOut); + + using TMEM_LOAD = std::conditional_t; // 4x32 threads with 64 cols of 32b elem + + typename CollectiveMmaPV::TiledMma mma; + Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); + Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); + Tensor tOcO = mma.get_slice(0).partition_C(cO); + Tensor tOsO = mma.get_slice(0).partition_C(sO); + + Tensor tOtO_i = logical_divide(tOtO, make_layout(make_shape(_128{}, Int{}))); + Tensor tOcO_i = logical_divide(tOcO, make_layout(make_shape(_128{}, Int{}))); + Tensor tOsO_i = logical_divide(tOsO, make_layout(make_shape(_128{}, Int{}))); + + if constexpr (decltype(stage == _0{})::value) { + tOtO_i.data() = tOtO_i.data().get() + uint32_t(TmemAllocation::O0); + } + else { + static_assert(decltype(stage == _1{})::value, "stage is either 0 or 1"); + tOtO_i.data() = tOtO_i.data().get() + uint32_t(TmemAllocation::O1); + } + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i(make_coord(_, _), _0{})); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + + Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i(make_coord(_, _), _)); + Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i(make_coord(_, _), _)); + Tensor tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i(make_coord(_, _), _)); + + float2 scale_f32x2 = make_float2(scale, scale); + + // loop: + // TMEM_LOAD, FMUL2 scale, TMEM_STORE + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < get<2>(TileShape{}) / kCorrectionTileSize; i++) { + Tensor tTMEM_LOADtO_i = tTMEM_LOADtO(_, _0{}, _0{}, i); + Tensor tTMEM_LOADsO_i = tTMEM_LOADsO(_, _0{}, _0{}, i); + + Tensor tTMrO = make_tensor(shape(tTMEM_LOADcO(_, _0{}, _0{}, i))); + + copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO); + +#ifndef ONLY_SOFTMAX + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tTMrO); j += 2) { + float2 in = make_float2(tTMrO(j), tTMrO(j+1)); + float2 out; + cute::mul(out, scale_f32x2, in); + tTMrO(j) = out.x; + tTMrO(j+1) = out.y; + } +#endif + + constexpr int N = 4 / sizeof(ElementOut); + NumericArrayConverter convert; + + Tensor tSMrO = make_tensor_like(tTMrO); + + Tensor tCs = recast(tTMrO); + Tensor tCd = recast(tSMrO); + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tCs); j++) { + tCd(j) = convert.convert(tCs(j)); + } + + Tensor tSMsO_i = recast(tTMEM_LOADsO_i); + Tensor tSMrO_i = recast(tSMrO); + + copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tSMrO_i, tSMsO_i); + } + + cutlass::arch::fence_view_async_shared(); + } + + CUTLASS_DEVICE auto + correction_rescale( + float scale, + uint32_t tmem_O) { + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + // As opposed to the softmax, we do not have enough registers here + // to load all of the values (for tile kv = 128), so we loop + // good values would be either 32 or 64 + const int kCorrectionTileSize = 16; + + using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x; // 4x32 threads with 64 cols of 32b elem + using TMEM_STORE = SM100_TMEM_STORE_32dp32b16x; // 4x32 threads with 64 cols of 32b elem + + typename CollectiveMmaPV::TiledMma mma; + Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); + Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); + Tensor tOcO = mma.get_slice(0).partition_C(cO); + + Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int{}))); + Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int{}))); + + tOtO_i.data() = tOtO_i.data().get() + tmem_O; + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i); + auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); + + Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i); + Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i); + Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i); + Tensor tTMEM_STOREcO = thr_tmem_store.partition_S(tOcO_i); + static_assert(shape(tTMEM_STOREcO) == shape(tTMEM_LOADcO)); + + float2 scale_f32x2 = make_float2(scale, scale); + + Tensor tTMrO = make_tensor(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{})); + + auto copy_in = [&](int i) { + Tensor tTMEM_LOADtO_i = tTMEM_LOADtO; + tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO_i); + }; + + auto copy_out = [&](int i) { + Tensor tTMEM_STOREtO_i = tTMEM_STOREtO; + tTMEM_STOREtO_i.data() = tTMEM_STOREtO_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + copy(tiled_tmem_store, tTMrO_i, tTMEM_STOREtO_i); + }; + + // sequence: LLMSLMSLMSS + + // loop: + // TMEM_LOAD, FMUL2 scale, TMEM_STORE + copy_in(0); + + int count = get<2>(TileShape{}) / kCorrectionTileSize; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < count; i++) { + if (i != count - 1) { + copy_in(i+1); + } + + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tTMrO_i); j += 2) { + float2 in = make_float2(tTMrO_i(j), tTMrO_i(j+1)); + float2 out; + cute::mul(out, scale_f32x2, in); + tTMrO_i(j) = out.x; + tTMrO_i(j+1) = out.y; + } + + copy_out(i); + } + } + + template< + class BlkCoord, class ProblemShape, class ParamsProblemShape, + class TensorStorageEpi, class CollectiveEpilogue + > + CUTLASS_DEVICE auto + correction( + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + ParamsProblemShape const& params_problem_shape, + TensorStorageEpi& shared_storage_epi, + PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state, + PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state, + PipelineO& pipeline_o, typename PipelineO::PipelineState& pipeline_o_consumer_state, + PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state, + CollectiveEpilogue& epilogue) { + + int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{})); + + Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{})); + Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); + + Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); + Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); + + using TMEM_LOAD_V = SM100_TMEM_LOAD_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + + auto tiled_tmem_loadv = make_tmem_copy(TMEM_LOAD_V{}, tStS_v); + auto thr_tmem_loadv = tiled_tmem_loadv.get_slice(thread_idx); + + Tensor tTMEM_LOADVtS = thr_tmem_loadv.partition_S(tStS_v); + Tensor tTMEM_LOADVcS = thr_tmem_loadv.partition_D(tScS_v); + + Tensor tTMEM_LOADVtS0 = tTMEM_LOADVtS; + tTMEM_LOADVtS0.data() = tTMEM_LOADVtS0.data().get() + uint32_t(TmemAllocation::V0); + Tensor tTMEM_LOADVtS1 = tTMEM_LOADVtS; + tTMEM_LOADVtS1.data() = tTMEM_LOADVtS1.data().get() + uint32_t(TmemAllocation::V1); + + // ignore first signal from softmax as no correction is required + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + // handle the last iteration differently (i.e. tmem_load/stsm for epi) + mask_tile_count -= 1; + + CUTLASS_PRAGMA_NO_UNROLL + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + + Tensor tTMEM_LOADVrS = make_tensor(shape(tTMEM_LOADVcS)); + + // read row_wise new global max + copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS); + + // e^(scale * (old_max - new_max) + float scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + + correction_rescale(scale, uint32_t(TmemAllocation::O0)); + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + cutlass::arch::fence_view_async_tmem_store(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS); + + scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + + correction_rescale(scale, uint32_t(TmemAllocation::O1)); + + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + cutlass::arch::fence_view_async_tmem_store(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + } + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + // do the final correction to O1 + // better to somehow special-case it in the loop above + // doesn't matter for non-persistent code, but if it were + // persistent we do not want to release O too early + + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + + // read from V0 + // read row_sum and final row_max here + Tensor tTMEM_LOADVrS = make_tensor(shape(tTMEM_LOADVcS)); + copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS); + + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + // store to epi smem + + // loop: + // TMEM_LOAD + // FMUL2 scale = 1 / global_sum * out_quant_scale + // F2FP + // store to smem + Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{}); + Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE); + + correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO); + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax); + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; + } + } + + cutlass::arch::fence_view_async_tmem_load(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + // load from V1 + copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS); + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + + correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _1{}, sO); + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{}); + + ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; + } + } + + cutlass::arch::fence_view_async_tmem_load(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + } + + + template< + class BlkCoord, class ProblemShape, class ParamsProblemShape, + class TensorStorageEpi, class CollectiveEpilogue + > + CUTLASS_DEVICE auto + correction_empty( + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + ParamsProblemShape const& params_problem_shape, + TensorStorageEpi& shared_storage_epi, + PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state, + CollectiveEpilogue& epilogue) { + + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + + Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{}); + Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE); + float lse = -INFINITY; + int thread_idx = threadIdx.x % (4 * NumThreadsPerWarp); + +#define DSHOW(x) print(#x ": "); print(x); print("\n") + if (threadIdx.x % 128 == 0 && block0()) { + DSHOW(sO); + } +#if 1 + + using ElementOut = typename CollectiveEpilogue::ElementOut; + auto tiled_copy = make_cotiled_copy( + Copy_Atom, ElementOut>{}, + make_ordered_layout(make_shape(_128{}, Int{}), Step<_1, _0>{}), + sO.layout()); + + auto thr_copy = tiled_copy.get_slice(thread_idx); + auto tOgO = thr_copy.partition_D(sO); + auto tOrO = make_tensor(shape(tOgO(_,_,_,_0{}))); + clear(tOrO); + + copy(tiled_copy, tOrO, tOgO(_,_,_,_0{})); +#endif + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; + } + } + + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + + copy(tiled_copy, tOrO, tOgO(_,_,_,_1{})); + cutlass::arch::fence_view_async_shared(); + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{}); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; + } + } + + cutlass::arch::fence_view_async_shared(); + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + } + +}; + +} // namespace cutlass::fmha::collective diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_gen_epilogue_warpspecialized.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_gen_epilogue_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3d1e1e8be3b3738140b2d89a8f4f4bacdce34b77 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_gen_epilogue_warpspecialized.hpp @@ -0,0 +1,94 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" + +namespace cutlass::fmha::collective { + +template< + class Element_, + class StrideO_ +> +struct Sm100FmhaGenEpilogueWarpspecialized { + + using Pipeline = cutlass::PipelineAsync<2>; + + using SmemLayoutO = Layout>; + using SmemLayoutO_ = SmemLayoutO; + using Element = Element_; + using StrideOOrig = StrideO_; + using StrideO = decltype(replace<0>(StrideOOrig{}, 0)); + + struct TensorStorage { + + using SmemLayoutO = SmemLayoutO_; + cute::array_aligned> smem_o; + + }; + + struct Arguments { + Element* ptr_o; + StrideO dO; + }; + + using Params = Arguments; + + const Params& params; + + CUTLASS_DEVICE Sm100FmhaGenEpilogueWarpspecialized(const Params& params) : params(params) {} + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace = nullptr) { + return args; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + /* no-op */ + } + + template + CUTLASS_DEVICE auto + store( + BlkCoord const& blk_coord_in, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& shared_storage, + Pipeline& pipeline, typename Pipeline::PipelineState& pipeline_consumer_state) { + /* no-op */ + } +}; + +} // namespace cutlass::fmha::collective diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..afef02245f316cfb03624d9ab2f4ca92553372c9 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp @@ -0,0 +1,1113 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cute/arch/simd_sm100.hpp" +#include "cute/tensor.hpp" +#include "cute/layout.hpp" + +#include "collective/fmha_common.hpp" +#include "collective/fmha_fusion.hpp" +#include "collective/sm100_fmha_load_cpasync_warpspecialized.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template< + class Element_, + class ElementQK_, + class ElementPV_, + class ElementOut_, + class TileShape_, + class StrideQ_, + class StrideNewK_, + class StrideNewV_, + class StrideK_, + class StrideV_, + class StrideO_, + class Mask_ = ResidualMask, + // shape here is QG K H + // and referes to the two softmax warps + // (2, 1, 1) means that they are stacked (best for large Q since it loads the least K/V) + // (1, 2, 1) means they sit side by side (best for small Q / large K) + class ThreadShape = Shape<_1, _2, _1> +> +struct Sm100FmhaGenMainloopWarpspecialized { + + using Element = Element_; + using ElementQK = ElementQK_; + using ElementPV = ElementPV_; + using ElementAcc = ElementPV_; + using ElementOut = ElementOut_; + using TileShape = TileShape_; + using StrideQOrig = StrideQ_; + using StrideQ = decltype(replace<0>(StrideQ_{}, 0)); + using StrideNewK = StrideNewK_; + using StrideNewV = StrideNewV_; + using StrideCacheK = StrideK_; + using StrideCacheV = StrideV_; + using StrideK = StrideK_; + using StrideV = StrideV_; + using StrideOOrig = StrideO_; + using StrideO = decltype(replace<0>(StrideO_{}, 0)); + using Mask = Mask_; + + static constexpr int StageCountQ = get<1>(TileShape{}) == 256 ? 1 : 2; + static constexpr int StageCountKV = 256 * 11 / get<1>(TileShape{}); + + using StagesQ = cutlass::gemm::collective::StageCount; + using StagesKV = cutlass::gemm::collective::StageCount; + + using ClusterShape = Shape<_1, _1, _1>; + + static const int Alignment = 128 / sizeof_bits_v; + + using TileShapeQK = decltype(shape_div(TileShape{}, ThreadShape{})); + + using TileShapePV = decltype(select<0,2,1>(TileShapeQK{})); + + using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, StrideQ, Alignment, + Element, StrideK, Alignment, + ElementQK, + TileShapeQK, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp; + + using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // the stride for A does not matter since we do not load from smem at all + Element, StrideK, Alignment, + Element, decltype(select<1,0,2>(StrideV{})), Alignment, + ElementPV, + TileShapePV, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp; + + using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int{})); + using SmemLayoutK = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutB{}, Int{})); + using SmemLayoutV = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutB{}, Int{})); + + struct TensorStorage { + cute::array_aligned> smem_q; + union { + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + }; + }; + + enum class TmemAllocation : uint32_t { + kSizeS = 128, + kSizeO = 128, + kSizeP = 32, + S0 = 0, + S1 = S0 + kSizeS, + V0 = S0, // stats storage from softmax to correction + V1 = S1, + P0 = S0 + kSizeP, + P1 = S1 + kSizeP, + O0 = S1 + kSizeS, + O1 = O0 + kSizeO, + kEnd = O1 + kSizeO + }; + + // indices for V0 / V1 + enum : int { + kIdxOldRowMax = 0, + kIdxNewRowMax = 1, + kIdxFinalRowSum = 0, + kIdxFinalRowMax = 1 + }; + + // from load to mma warp, protects q in smem + using PipelineQ = cutlass::PipelineUmmaConsumerAsync< + StageCountQ, + typename CollectiveMmaQK::AtomThrShapeMNK + >; + + // from load to mma warp, protects k/v in smem + using PipelineKV = cutlass::PipelineUmmaConsumerAsync< + StageCountKV, + typename CollectiveMmaQK::AtomThrShapeMNK + >; + + // from mma to softmax0/1 warp, protects S in tmem + // (not sure yet about the reverse direction) + // there is one pipe per softmax warp, and the mma warp alternates between them + using PipelineS = cutlass::PipelineUmmaAsync<1>; + + // from softmax0/1/ to correction wg + using PipelineC = cutlass::PipelineAsync<1>; + + // from mma to correction + using PipelineO = cutlass::PipelineUmmaAsync<2>; + + // from corr to epilogue + using PipelineE = cutlass::PipelineAsync<2>; + + using OrderBarrierSoftmax = cutlass::OrderedSequenceBarrier< + /*stages*/ 1, /*groups*/ 2>; + + static_assert(cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v) == cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v), "K and V smem layouts must be of equal size"); + + using Load = Sm100FmhaLoadCpAsyncWarpspecialized< + Element, StrideQ, StrideNewK, StrideNewV, StrideCacheK, StrideCacheV, + TensorStorage, CollectiveMmaQK, CollectiveMmaPV, + SmemLayoutQ, SmemLayoutK, SmemLayoutV, + PipelineQ, PipelineKV, TileShape, Mask + >; + + struct Arguments { + typename Load::Arguments load; + + // if zero, defaults to 1/sqrt(D) + float scale_softmax = 0.0f; + + // scaling factors to dequantize QKV + float scale_q = 1.0f; + float scale_k = 1.0f; + float scale_v = 1.0f; + + // scaling factor to quantize O + float inv_scale_o = 1.0f; + }; + + struct Params { + typename Load::Params load; + + float scale_softmax; + float scale_softmax_log2; + + float scale_output; + }; + + template + static bool can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace) { + + float scale_softmax = args.scale_softmax; + if (scale_softmax == 0.0f) { + scale_softmax = 1.0f / (float) std::sqrt(get<2>(problem_shape)); + } + float log2_e = static_cast(std::log2(std::exp(1.0))); + + return Params{ + Load::to_underlying_arguments(problem_shape, args.load, workspace), + args.scale_q * args.scale_k * scale_softmax, + args.scale_q * args.scale_k * log2_e * scale_softmax, + args.scale_v * args.inv_scale_o + }; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + Load::prefetch_tma_descriptors(params.load); + } + + template + CUTLASS_DEVICE void + load( + BlkCoord const& blk_coord, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) { + + Load load; + load.load(blk_coord, problem_shape, params.load, params_problem_shape, + storage, + pipeline_q, pipeline_q_producer_state, + pipeline_kv, pipeline_kv_producer_state); + } + + template + CUTLASS_DEVICE auto + mma( + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_consumer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_consumer_state, + PipelineS& pipeline_s0, typename PipelineS::PipelineState& pipeline_s0_producer_state, + PipelineS& pipeline_s1, typename PipelineS::PipelineState& pipeline_s1_producer_state, + PipelineO& pipeline_corr, typename PipelineO::PipelineState& pipeline_corr_producer_state) { + + auto pipeline_q_release_state = pipeline_q_consumer_state; + auto pipeline_kv_release_state = pipeline_kv_consumer_state; + + int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + + typename CollectiveMmaQK::TiledMma mma_qk; + ThrMMA thr_mma_qk = mma_qk.get_slice(0); + + typename CollectiveMmaPV::TiledMma mma_pv; + TiledMMA mma_pv_ts = to_tiled_mma_sm100_ts(mma_pv); + ThrMMA thr_mma_pv = mma_pv_ts.get_slice(0); + + Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + + Tensor tSrQ = thr_mma_qk.make_fragment_A(sQ); + Tensor tSrK = thr_mma_qk.make_fragment_B(sK); + Tensor tOrV = thr_mma_pv.make_fragment_B(sV); + + // tmem layout is + // S0 S1`O0 O1 + // sequential in memory, where S overlaps with P and V + + Tensor tStS = partition_fragment_C(mma_qk, select<0,1>(TileShapeQK{})); + Tensor tOtO = partition_fragment_C(mma_pv_ts, select<0,1>(TileShapePV{})); + + Tensor tStS0 = tStS; + tStS0.data() = tStS.data().get() + uint32_t(TmemAllocation::S0); + Tensor tStS1 = tStS; + tStS1.data() = tStS.data().get() + uint32_t(TmemAllocation::S1); + + Tensor tOtO0 = tOtO; + tOtO0.data() = tOtO.data().get() + uint32_t(TmemAllocation::O0); + Tensor tOtO1 = tOtO; + tOtO1.data() = tOtO.data().get() + uint32_t(TmemAllocation::O1); + + Tensor sP = make_tensor(make_smem_ptr((Element*)nullptr), typename CollectiveMmaPV::SmemLayoutA{}); + Tensor tOrP = thr_mma_pv.make_fragment_A(sP)(_, _, _, _0{}); // slice out staging + + Tensor tOrP0 = tOrP; + tOrP0.data() = tOrP0.data().get() + uint32_t(TmemAllocation::P0); + Tensor tOrP1 = tOrP; + tOrP1.data() = tOrP1.data().get() + uint32_t(TmemAllocation::P1); + + int k_index = 0; + int v_index = 0; + int q_index = 0; + + // wait for Q1 + q_index = pipeline_q_consumer_state.index(); + pipeline_q.consumer_wait(pipeline_q_consumer_state); + ++pipeline_q_consumer_state; + + Tensor tSrQ0 = tSrQ(_,_,_,q_index); + + + // wait for K1 + k_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm Q1 * K1 -> S1 + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index), tStS0); + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + // release K1 + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + // wait for Q2 + if constexpr (get<0>(ThreadShape{}) > 1 || get<2>(ThreadShape{}) > 1) { + q_index = pipeline_q_consumer_state.index(); + pipeline_q.consumer_wait(pipeline_q_consumer_state); + ++pipeline_q_consumer_state; + } + + Tensor tSrQ1 = tSrQ(_,_,_,q_index); + + if constexpr (get<1>(ThreadShape{}) > 1) { + k_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + // gemm Q2 * K1 -> S2 + gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index), tStS1); + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // release K1 + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + // wait for V1 + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // this acquire returns the ownership of all of S0 to the mma warp + // including the P0 part + // acquire corr first to take it out of the critical + // path since softmax takes longer + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + // gemm P1 * V1 -> O1 + gemm_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index), tOtO0); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + mma_pv_ts.accumulate_ = UMMA::ScaleOut::Zero; + + // loop: + mask_tile_count -= 1; + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + // wait for Ki + k_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm Q1 * Ki -> S1 + gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index), tStS0); + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + // gemm P2 * V(i-1) -> O2 + if constexpr (get<1>(ThreadShape{}) > 1) { + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index), tOtO1); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + // release V(i-1) + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + k_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + // gemm Q2 * Ki -> S2 + gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index), tStS1); + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // release Ki + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + // wait for Vi + v_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm P1 * Vi -> O1 + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index), tOtO0); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + } + + // release Q1 + pipeline_q.consumer_release(pipeline_q_release_state); + ++pipeline_q_release_state; + + // release Q2 + if constexpr (get<0>(ThreadShape{}) > 1) { + pipeline_q.consumer_release(pipeline_q_release_state); + ++pipeline_q_release_state; + } + + // wait for Vi + if constexpr (get<1>(ThreadShape{}) > 1) { + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + // gemm P2 * Vi -> O2 + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index), tOtO1); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + // release Vi + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // T0 S00 B1, T0 S10 B1, T0 S00 B2, T0 S01 B1, T0 S10 B2, T0 S11 B1, T0 S01 B2, T1 S00 B1, T0 S11 B2, ... + // Q1 * K1 , Q2 * K1 , S11 * V1 , Q1 * K2 , S21 * V1 , Q2 * K2 , S12 * V2 , Q1 * K3 , S22 * K2 , ... + } + + template + CUTLASS_DEVICE auto + softmax_step( + float& row_max, float& row_sum, + Stage stage, bool final_call, + BlkCoord const& blk_coord, CoordTensor const& cS, + Params const& params, ProblemShape const& problem_shape, + PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state, + PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state, + OrderBarrierSoftmax& order_s) { + + Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); + + Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{})); + tStS.data() = uint32_t(stage == _0{} ? TmemAllocation::S0 : TmemAllocation::S1); + + Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); + tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1); + Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); + + auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int{} * Int{}; + Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); + tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1)); + Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); + + // Each thread owns a single row + using TMEM_LOAD = conditional_t(TileShapeQK{}) < _128{}, SM100_TMEM_LOAD_32dp32b8x, SM100_TMEM_LOAD_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem + using TMEM_STORE = conditional_t(TileShapeQK{}) < _128{}, SM100_TMEM_STORE_32dp32b8x, SM100_TMEM_STORE_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem + using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tStS); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + + Tensor tTMEM_LOADtS = thr_tmem_load.partition_S(tStS); + Tensor tTMEM_LOADcS = thr_tmem_load.partition_D(tScS); + + auto tiled_tmem_storev = make_tmem_copy(TMEM_STORE_V{}, tStS_v); + auto thr_tmem_storev = tiled_tmem_storev.get_slice(thread_idx); + + Tensor tTMEM_STOREVtS = thr_tmem_storev.partition_D(tStS_v); + Tensor tTMEM_STOREVcS = thr_tmem_storev.partition_S(tScS_v); + + auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tStS_P); + auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); + + Tensor tTMEM_STOREtS_x4 = thr_tmem_store.partition_D(tStS_P); + tTMEM_STOREtS_x4.data() = warp_uniform(tTMEM_STOREtS_x4.data().get()); + Tensor tTMEM_STOREcS = thr_tmem_store.partition_S(tScS_P); + + // wait on tensor core pipe + pipeline_s.consumer_wait(pipeline_s_consumer_state); + + // read all of S from tmem into reg mem + Tensor tTMEM_LOADrS = make_tensor(shape(tTMEM_LOADcS)); + copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS); + + if constexpr (need_apply_mask) { + Mask{}.apply_mask(tTMEM_LOADrS, tTMEM_LOADcS, problem_shape); + } + + ElementQK old_row_max = row_max; + { + // compute rowmax + float row_max_0 = row_max; + float row_max_1 = row_max; + float row_max_2 = row_max; + float row_max_3 = row_max; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 4) { + row_max_0 = ::fmax(row_max_0, tTMEM_LOADrS(i)); + row_max_1 = ::fmax(row_max_1, tTMEM_LOADrS(i+1)); + row_max_2 = ::fmax(row_max_2, tTMEM_LOADrS(i+2)); + row_max_3 = ::fmax(row_max_3, tTMEM_LOADrS(i+3)); + } + row_max = ::fmax(row_max_0, row_max_1); + row_max = ::fmax(row_max, row_max_2); + row_max = ::fmax(row_max, row_max_3); + } + + ElementQK row_max_safe = row_max == -INFINITY ? 0 : row_max; + + Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); + tTMEM_STOREVrS(kIdxOldRowMax) = old_row_max; + tTMEM_STOREVrS(kIdxNewRowMax) = row_max_safe; + copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS); + + pipeline_c.producer_commit(pipeline_c_producer_state); + ++pipeline_c_producer_state; + + // notify correction wg that they are ready (might need addtl ordering between S0 and S1 WG's) + + ElementQK scale = params.scale_softmax_log2; + ElementQK row_max_scale = row_max_safe * scale; + + float2 scale_fp32x2 = make_float2(scale, scale); + float2 minus_row_max_scale_fp32x2 = make_float2(-row_max_scale, -row_max_scale); + + Tensor tTMEM_STORErS_x4 = make_tensor(shape(tTMEM_STOREcS)); + + constexpr int kConversionsPerStep = 2; + + Tensor tTMEM_STORErS_x4_e = recast>(tTMEM_STORErS_x4); + + NumericArrayConverter convert; + + const int kReleasePipeCount = 10; // must be multiple of 2 + + order_s.wait(); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 2) { + float2 in = make_float2( + tTMEM_LOADrS(i + 0), + tTMEM_LOADrS(i + 1) + ); + float2 out; + cute::fma(out, scale_fp32x2, in, minus_row_max_scale_fp32x2); + tTMEM_LOADrS(i + 0) = out.x; + tTMEM_LOADrS(i + 1) = out.y; + + tTMEM_LOADrS(i+0) = ::exp2f(tTMEM_LOADrS(i+0)); + tTMEM_LOADrS(i+1) = ::exp2f(tTMEM_LOADrS(i+1)); + + Array in_conv; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kConversionsPerStep; j++) { + in_conv[j] = tTMEM_LOADrS(i + j); + } + tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv); + + + if (i == size(tTMEM_LOADrS) - kReleasePipeCount) { + order_s.arrive(); + } + + // this prevents register spills in fp16 + if constexpr (size<2>(tTMEM_STORErS_x4) == _2{}) { + if (i == size(tTMEM_LOADrS) - 6) { + copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, 0), tTMEM_STOREtS_x4(_, _, 0)); + } + } + } + + // tmem_store(reg_S8) -> op_P + CUTE_STATIC_ASSERT_V(size<2>(tTMEM_STORErS_x4) <= _2{}); + CUTE_STATIC_ASSERT_V(size<1>(tTMEM_STORErS_x4) == _1{}); + copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1), tTMEM_STOREtS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1)); + + cutlass::arch::fence_view_async_tmem_store(); + + // notify tensor core warp that P is ready + pipeline_s.consumer_release(pipeline_s_consumer_state); + ++pipeline_s_consumer_state; + + pipeline_c.producer_acquire(pipeline_c_producer_state); + + ElementQK acc_scale = 0.5f * ::exp2f(scale * (old_row_max - row_max_safe)); + row_sum *= acc_scale; + // row_sum = sum(reg_S) + float2 local_row_sum_f32x2 = make_float2(row_sum, row_sum); + float2 local_row_sum_1 = make_float2(0, 0); + float2 local_row_sum_2 = make_float2(0, 0); + float2 local_row_sum_3 = make_float2(0, 0); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 8) { + // row_sum += tTMEM_LOADrS(i); + float2 in = make_float2(tTMEM_LOADrS(i), tTMEM_LOADrS(i+1)); + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, in); + + in = make_float2(tTMEM_LOADrS(i+2), tTMEM_LOADrS(i+2+1)); + cute::add(local_row_sum_1, local_row_sum_1, in); + + in = make_float2(tTMEM_LOADrS(i+4), tTMEM_LOADrS(i+4+1)); + cute::add(local_row_sum_2, local_row_sum_2, in); + + in = make_float2(tTMEM_LOADrS(i+6), tTMEM_LOADrS(i+6+1)); + cute::add(local_row_sum_3, local_row_sum_3, in); + } + + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_1); + cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3); + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2); + float local_row_sum = local_row_sum_f32x2.x + local_row_sum_f32x2.y; + + row_sum = local_row_sum; + + if (final_call) { + // re-acquire the S part in the final step + pipeline_s.consumer_wait(pipeline_s_consumer_state); + + Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); + tTMEM_STOREVrS(kIdxFinalRowMax) = row_max; + tTMEM_STOREVrS(kIdxFinalRowSum) = row_sum; + copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS); + } + } + + template + CUTLASS_DEVICE auto + softmax( + Stage stage, + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state, + PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state, + OrderBarrierSoftmax& order_s) { + + int mask_tile_count = Mask{}.get_unmasked_trip_count(blk_coord, TileShape{}, problem_shape); + + ElementQK row_max = -INFINITY; + ElementQK row_sum = 0; + + Tensor cS_base = make_identity_tensor(select<0,1>(TileShapeQK{})); + auto logical_offset = make_coord( + get<0>(blk_coord) * get<0>(TileShape{}) + (stage % get<0>(ThreadShape{})) * get<0>(TileShapeQK{}), + 0 + (stage % get<1>(ThreadShape{})) * get<1>(TileShapeQK{}) + ); + Tensor cS = domain_offset(logical_offset, cS_base); + + pipeline_c.producer_acquire(pipeline_c_producer_state); + + CUTLASS_PRAGMA_NO_UNROLL + for (; mask_tile_count > 0; mask_tile_count -= 1) { + softmax_step( + row_max, row_sum, stage, + (mask_tile_count == 1) && + (Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape) == 0), + blk_coord, cS, params, problem_shape, + pipeline_s, pipeline_s_consumer_state, + pipeline_c, pipeline_c_producer_state, + order_s + ); + + cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); + } + + // Masked iterations + mask_tile_count = Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape); + + CUTLASS_PRAGMA_NO_UNROLL + for (; mask_tile_count > 0; mask_tile_count -= 1) { + softmax_step( + row_max, row_sum, stage, mask_tile_count == 1, + blk_coord, cS, params, problem_shape, + pipeline_s, pipeline_s_consumer_state, + pipeline_c, pipeline_c_producer_state, + order_s + ); + + cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); + } + + pipeline_c.producer_commit(pipeline_c_producer_state); + ++pipeline_c_producer_state; + + pipeline_c.producer_acquire(pipeline_c_producer_state); + // empty step to sync against pipe s + pipeline_s.consumer_release(pipeline_s_consumer_state); + ++pipeline_s_consumer_state; + } + + template + CUTLASS_DEVICE auto + correction_epilogue( + float scale_softmax_log2, float scale_out, Vector const& v0, Vector const& v1, + GTensor& gO, CTensor const& cO, Shape const& g_shape, + Epilogue const& epilogue) { + + using ElementOut = typename GTensor::value_type; + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + // As opposed to the softmax, we do not have enough registers here + // to load all of the values (for tile kv = 128), so we loop + // good values would be either 32 or 64 + const int kCorrectionTileSize = 32 / sizeof(ElementOut); + + using TMEM_LOAD = std::conditional_t; // 4x32 threads with 64 cols of 32b elem + + typename CollectiveMmaPV::TiledMma mma; + Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); + Tensor tOcO = mma.get_slice(0).partition_C(cO); + Tensor tOgO = mma.get_slice(0).partition_C(gO); + + Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int{}))); + Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int{}))); + Tensor tOgO_i = tOgO.compose(make_layout(make_shape(_128{}, Int{}))); + + Tensor tOtO0 = tOtO_i; + tOtO0.data() = tOtO0.data().get() + uint32_t(TmemAllocation::O0); + Tensor tOtO1 = tOtO_i; + tOtO1.data() = tOtO1.data().get() + uint32_t(TmemAllocation::O1); + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + + Tensor tTMEM_LOADtO0 = thr_tmem_load.partition_S(tOtO0); + Tensor tTMEM_LOADtO1 = thr_tmem_load.partition_S(tOtO1); + Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i); + Tensor tTMEM_LOADgO = thr_tmem_load.partition_D(tOgO_i); + + float row_max = std::max(v0(kIdxFinalRowMax), v1(kIdxFinalRowMax)); + float adj0 = ::exp2f(scale_softmax_log2 * (v0(kIdxFinalRowMax) - row_max)); + float adj1 = ::exp2f(scale_softmax_log2 * (v1(kIdxFinalRowMax) - row_max)); + float row_sum = adj0 * v0(kIdxFinalRowSum) + adj1 * v1(kIdxFinalRowSum); + float scale0 = scale_out * adj0 / row_sum; + float scale1 = scale_out * adj1 / row_sum; + + float2 scale0_f32x2 = make_float2(scale0, scale0); + float2 scale1_f32x2 = make_float2(scale1, scale1); + + // loop: + // TMEM_LOAD, TMEM_LOAD, FMUL2, FFMA2, STG + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < get<2>(TileShape{}) / kCorrectionTileSize; i++) { + Tensor tTMEM_LOADtO0_i = tTMEM_LOADtO0; + tTMEM_LOADtO0_i.data() = tTMEM_LOADtO0_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMEM_LOADtO1_i = tTMEM_LOADtO1; + tTMEM_LOADtO1_i.data() = tTMEM_LOADtO1_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMEM_LOADgO_i = tTMEM_LOADgO; + tTMEM_LOADgO_i.data() = tTMEM_LOADgO_i.data().get() + i * kCorrectionTileSize * stride<1>(gO); + + Tensor tTMrO0 = make_tensor(shape(tTMEM_LOADcO)); + Tensor tTMrO1 = make_tensor(shape(tTMEM_LOADcO)); + + copy(tiled_tmem_load, tTMEM_LOADtO0_i, tTMrO0); + copy(tiled_tmem_load, tTMEM_LOADtO1_i, tTMrO1); + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tTMrO0); j += 2) { + float2 in0 = make_float2(tTMrO0(j), tTMrO0(j+1)); + float2 in1 = make_float2(tTMrO1(j), tTMrO1(j+1)); + float2 out; + cute::mul(out, scale0_f32x2, in0); + cute::fma(out, scale1_f32x2, in1, out); + tTMrO0(j) = out.x; + tTMrO0(j+1) = out.y; + } + + constexpr int N = 4 / sizeof(ElementOut); + NumericArrayConverter convert; + + Tensor tSMrO = make_tensor_like(tTMrO0); + + Tensor tCs = recast(tTMrO0); + Tensor tCd = recast(tSMrO); + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tCs); j++) { + tCd(j) = convert.convert(tCs(j)); + } + + Tensor tSMgO_i = recast(tTMEM_LOADgO_i); + Tensor tSMrO_i = recast(tSMrO); + + // could use masking do this right for smaller D + if (get<0>(tTMEM_LOADcO(_0{})) < get<0>(g_shape)) { + copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tSMrO_i, tSMgO_i); + } + } + } + + CUTLASS_DEVICE auto + correction_rescale( + float scale, + uint32_t tmem_O) { + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + // As opposed to the softmax, we do not have enough registers here + // to load all of the values (for tile kv = 128), so we loop + // good values would be either 32 or 64 + const int kCorrectionTileSize = 32; + + using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 64 cols of 32b elem + using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 64 cols of 32b elem + + typename CollectiveMmaPV::TiledMma mma; + Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); + Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); + Tensor tOcO = mma.get_slice(0).partition_C(cO); + + Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int{}))); + Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int{}))); + + tOtO_i.data() = tOtO_i.data().get() + tmem_O; + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i); + auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); + + Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i); + Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i); + Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i); + Tensor tTMEM_STOREcO = thr_tmem_store.partition_S(tOcO_i); + static_assert(shape(tTMEM_STOREcO) == shape(tTMEM_LOADcO)); + + float2 scale_f32x2 = make_float2(scale, scale); + + Tensor tTMrO = make_tensor(make_shape(shape(tTMEM_LOADcO), Int(TileShape{}) / kCorrectionTileSize>{})); + + auto copy_in = [&](int i) { + Tensor tTMEM_LOADtO_i = tTMEM_LOADtO; + tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO_i); + }; + + auto copy_out = [&](int i) { + Tensor tTMEM_STOREtO_i = tTMEM_STOREtO; + tTMEM_STOREtO_i.data() = tTMEM_STOREtO_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + copy(tiled_tmem_store, tTMrO_i, tTMEM_STOREtO_i); + }; + + // sequence: LLMSLMSLMSS + + // loop: + // TMEM_LOAD, FMUL2 scale, TMEM_STORE + copy_in(0); + + int count = get<2>(TileShape{}) / kCorrectionTileSize; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < count; i++) { + if (i != count - 1) { + copy_in(i+1); + } + + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tTMrO_i); j += 2) { + float2 in = make_float2(tTMrO_i(j), tTMrO_i(j+1)); + float2 out; + cute::mul(out, scale_f32x2, in); + tTMrO_i(j) = out.x; + tTMrO_i(j+1) = out.y; + } + + copy_out(i); + } + } + + template + CUTLASS_DEVICE auto + correction( + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + TensorStorageEpi& shared_storage_epi, + PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state, + PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state, + PipelineO& pipeline_o, typename PipelineO::PipelineState& pipeline_o_consumer_state, + PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state, + Epilogue const& epilogue) { + + int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{})); + + Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{})); + Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); + + Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); + Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); + + using TMEM_LOAD_V = SM100_TMEM_LOAD_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + + auto tiled_tmem_loadv = make_tmem_copy(TMEM_LOAD_V{}, tStS_v); + auto thr_tmem_loadv = tiled_tmem_loadv.get_slice(thread_idx); + + Tensor tTMEM_LOADVtS = thr_tmem_loadv.partition_S(tStS_v); + Tensor tTMEM_LOADVcS = thr_tmem_loadv.partition_D(tScS_v); + + Tensor tTMEM_LOADVtS0 = tTMEM_LOADVtS; + tTMEM_LOADVtS0.data() = tTMEM_LOADVtS0.data().get() + uint32_t(TmemAllocation::V0); + Tensor tTMEM_LOADVtS1 = tTMEM_LOADVtS; + tTMEM_LOADVtS1.data() = tTMEM_LOADVtS1.data().get() + uint32_t(TmemAllocation::V1); + + // ignore first signal from softmax as no correction is required + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + // handle the last iteration differently (i.e. tmem_load/stsm for epi) + mask_tile_count -= 1; + + CUTLASS_PRAGMA_NO_UNROLL + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + + Tensor tTMEM_LOADVrS = make_tensor(shape(tTMEM_LOADVcS)); + + // read row_wise new global max + copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS); + + // e^(scale * (old_max - new_max) + float scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + + correction_rescale(scale, uint32_t(TmemAllocation::O0)); + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + cutlass::arch::fence_view_async_tmem_store(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS); + + scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + + correction_rescale(scale, uint32_t(TmemAllocation::O1)); + + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + cutlass::arch::fence_view_async_tmem_store(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + } + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + // do the final correction to O1 + // better to somehow special-case it in the loop above + // doesn't matter for non-persistent code, but if it were + // persistent we do not want to release O too early + + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + + // read from V0 + // read row_sum and final row_max here + Tensor tTMEM_LOADVrS0 = make_tensor(shape(tTMEM_LOADVcS)); + copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS0); + + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + // load from V1 + Tensor tTMEM_LOADVrS1 = make_tensor(shape(tTMEM_LOADVcS)); + copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS1); + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + auto pipeline_o_release_state = pipeline_o_consumer_state; + pipeline_o.consumer_wait(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + pipeline_o.consumer_wait(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + // store to epi smem + + // loop: + // TMEM_LOAD + // FMUL2 scale = 1 / global_sum * out_quant_scale + // F2FP + // store to smem + + Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); + auto g_shape = select<0,2>(problem_shape); + auto mO = make_tensor(make_gmem_ptr(epilogue.params.ptr_o), append<3>(select<0,1>(TileShapePV{}), get<3>(problem_shape)), epilogue.params.dO); + auto gO = mO(_, _, get<2>(blk_coord)); + + correction_epilogue(params.scale_softmax_log2, params.scale_output, tTMEM_LOADVrS0, tTMEM_LOADVrS1, gO, cO, g_shape, epilogue); + + cutlass::arch::fence_view_async_tmem_load(); + + pipeline_o.consumer_release(pipeline_o_release_state); + ++pipeline_o_release_state; + + pipeline_o.consumer_release(pipeline_o_release_state); + ++pipeline_o_release_state; + } + +}; + +} // namespace cutlass::fmha::collective diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_load_cpasync_warpspecialized.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_load_cpasync_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..259f1b4766e501420442c6c393a5fda8e89004ed --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_load_cpasync_warpspecialized.hpp @@ -0,0 +1,384 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cute/tensor.hpp" +#include "cute/layout.hpp" + +#include "collective/fmha_common.hpp" +#include "collective/fmha_fusion.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template< + class Element, + class StrideQ, + class StrideNewK, + class StrideNewV, + class StrideCacheK, + class StrideCacheV, + class TensorStorage, + class CollectiveMmaQK, + class CollectiveMmaPV, + class SmemLayoutQ, + class SmemLayoutK, + class SmemLayoutV, + class PipelineQ, + class PipelineKV, + class TileShape, + class Mask +> +struct Sm100FmhaLoadCpAsyncWarpspecialized { + + using TileShapeQK = typename CollectiveMmaQK::TileShape; + using TileShapePV = typename CollectiveMmaPV::TileShape; + + struct Arguments { + + const int* cache_batch_idx; + + const Element* ptr_q; + StrideQ dQ; + + const Element* ptr_new_k; + StrideNewK dNewK; + const Element* ptr_new_v; + StrideNewV dNewV; + + Element* ptr_cache_k; + StrideCacheK dCacheK; + Element* ptr_cache_v; + StrideCacheV dCacheV; + }; + + using Params = Arguments; + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace) { + + return args; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + } + + template + CUTLASS_DEVICE auto constexpr transpose(Tensor const& t) { + CUTE_STATIC_ASSERT_V(rank(t) == _2{}); + return t.compose(make_layout(make_shape(size<1>(t), size<0>(t)), make_stride(size<0>(t), _1{}))); + } + + template< + class CAtom, class TA, class TB, + class CountTensor, class CountLimit, + class SrcTensor, class DstTensor + > + CUTLASS_DEVICE void copy_with_limit( + TiledCopy const& tiled_copy, + CountTensor const& c, CountLimit const& l, + SrcTensor const& src, DstTensor&& dst) { + + //copy(tiled_copy, src, dst); +#if 1 + auto c_f = make_tensor(c.data(), flatten(c.layout())); + auto src_f = make_tensor(src.data(), flatten(src.layout())); + auto dst_f = make_tensor(dst.data(), flatten(dst.layout())); + auto c_v = group_modes<1,rank_v>(c_f); + auto src_v = group_modes<1,rank_v>(src_f); + auto dst_v = group_modes<1,rank_v>(dst_f); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<1>(src_v); i++) { + if (elem_less(c_v(_0{}, i), l)) { + copy(CAtom{}, src_v(_, i), dst_v(_, i)); + } + else { + clear(dst_v(_, i)); + } + } +#endif + } + + template + CUTLASS_DEVICE void + load( + BlkCoord const& blk_coord, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) { + + int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + mask_tile_count *= 2; + + int warp_idx = (threadIdx.x / 32) % 2; + int thread_idx = warp_idx * 32 + (threadIdx.x % 32); + + using X = Underscore; + + // this one is only executed by one thread, no need to elect_one + auto blk_coord_cache = blk_coord; + if (params.cache_batch_idx != nullptr) { + get<2,1>(blk_coord_cache) = params.cache_batch_idx[get<2,1>(blk_coord_cache)]; + } + + // Q1, K1, K2, V1, K3, V2, ... Kn, Vn-1, Vn + // two pipes: Q and KV + auto cQ = make_identity_tensor(select<0,2>(TileShape{})); + auto mQ = make_tensor(make_gmem_ptr(params.ptr_q), append<3>(select<0,2>(TileShapeQK{}), get<3>(problem_shape)), params.dQ); + auto gQ = mQ(_, _, get<2>(blk_coord)); + auto sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + + typename CollectiveMmaQK::TiledMma mma_qk; + ThrMMA thr_mma_qk = mma_qk.get_slice(0); + auto tSgQ = thr_mma_qk.partition_A(gQ); + auto tScQ = thr_mma_qk.partition_A(cQ); + + auto atom_q_tv = Layout, _16>, Stride, _1>>{}; + auto atom_kv_tv = Layout, _16>, Stride, _1>>{}; + + auto tiled_copy_q = make_cotiled_copy( + Copy_Atom, Element>{}, + atom_q_tv, + make_layout(shape(tSgQ), replace<0>(stride(tSgQ), replace<0>(stride<0>(tSgQ), get<2>(TileShape{}))))); + + auto thr_copy_q = tiled_copy_q.get_slice(thread_idx); + + auto tQsQ = thr_copy_q.partition_D(sQ); + auto tQgQ = thr_copy_q.partition_S(tSgQ); + auto tQcQ = thr_copy_q.partition_S(tScQ); + + auto limitQ = append<2>(get<0>(problem_shape), _128{}); + + // Q1 + int q0_index = get<0>(blk_coord); + + auto load_q = [&](int q_index, auto& state) { + pipeline_q.producer_acquire(state); + + // q is always loaded masked + using Vec = uint128_t; + Vec vzero = uint128_t(0, 0); + auto src = recast(tQgQ(_, _, _, _)); + auto dst = recast(tQsQ(_, _, _, _, state.index())); + auto c = tQcQ(_, _, _, _); + int vlen = sizeof(Vec) / sizeof(Element); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src); i++) { + auto cc = c(vlen*i); + Vec* dst_ptr = &dst(i); + const Vec* src_ptr = &src(i); + bool guard = elem_less(cc, limitQ); + cutlass::arch::cp_async_zfill<16, cutlass::arch::CacheOperation::Always>( + dst_ptr, src_ptr, guard + ); + } + + pipeline_q.producer_commit(state, cutlass::arch::cpasync_barrier_arrive); + }; + + load_q(q0_index, pipeline_q_producer_state); + ++pipeline_q_producer_state; + + auto cK_t = make_identity_tensor(select<1,2>(TileShapeQK{})); + auto cK = make_tensor(cK_t.data(), make_layout(get<0>(cK_t.layout()), get<1>(cK_t.layout()), make_layout(_2{}, get<1>(TileShapeQK{}) * stride<0>(cK_t)))); + auto mK = make_tensor(make_gmem_ptr(params.ptr_cache_k), select<1,2,3>(problem_shape), params.dCacheK); + auto gK = local_tile(mK(_, _, get<2>(blk_coord_cache)), TileShapeQK{}, make_coord(_, _, _0{}), Step{}); + auto sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + + auto tSgK = thr_mma_qk.partition_B(gK); + auto tScK = thr_mma_qk.partition_B(cK); + + auto tSlK = thr_mma_qk.partition_B(make_tensor((Element*) nullptr, make_ordered_layout(select<1,2>(TileShapeQK{}), Step<_1, _0>{}))); + auto tiled_copy_k = make_cotiled_copy( + Copy_Atom, Element>{}, + atom_kv_tv, + tSlK.layout()); + + auto thr_copy_k = tiled_copy_k.get_slice(thread_idx); + + auto tKsK = thr_copy_k.partition_D(sK); + auto tKgK = thr_copy_k.partition_S(tSgK); + auto tKcK = thr_copy_k.partition_S(tScK); + + int seqlen_cache_kv = get<1>(problem_shape) - ((params.ptr_new_k != nullptr) ? 1 : 0); + auto limitK = append<2>(seqlen_cache_kv, _128{}); + + auto cV_t = make_identity_tensor(select<1,2>(TileShapePV{})); + auto cV = make_tensor(cV_t.data(), make_layout(get<0>(cV_t.layout()), get<1>(cV_t.layout()), make_layout(_2{}, get<2>(TileShapePV{}) * stride<1>(cV_t)))); + auto mV = make_tensor(make_gmem_ptr(params.ptr_cache_v), select<2,1,3>(problem_shape), select<1,0,2>(params.dCacheV)); + auto gV = local_tile(mV(_, _, get<2>(blk_coord_cache)), TileShapePV{}, make_coord(_, _0{}, _), Step{}); + auto sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + + typename CollectiveMmaPV::TiledMma mma_pv; + ThrMMA thr_mma_pv = mma_pv.get_slice(0); + auto tOgV = thr_mma_pv.partition_B(gV); + auto tOcV = thr_mma_pv.partition_B(cV); + auto tOlV = thr_mma_pv.partition_B(make_tensor((Element*) nullptr, make_layout(select<1,2>(TileShapePV{})))); + + auto tiled_copy_v = make_cotiled_copy( + Copy_Atom, Element>{}, + atom_kv_tv, + tOlV.layout()); + + auto thr_copy_v = tiled_copy_v.get_slice(thread_idx); + + auto tVsV = thr_copy_v.partition_D(sV); + auto tVgV = thr_copy_v.partition_S(tOgV); + auto tVcV = thr_copy_v.partition_S(tOcV); + + auto limitV = select<1,0>(limitK); + + int full_tiles_cache = seqlen_cache_kv / get<1>(TileShapeQK{}); + + bool has_new = params.ptr_new_k != nullptr; + Tensor mNewK = make_tensor(make_gmem_ptr(params.ptr_new_k), select<1,2,3>(problem_shape), params.dNewK); + Tensor mNewV = make_tensor(make_gmem_ptr(params.ptr_new_v), select<1,2,3>(problem_shape), params.dNewV); + Tensor gNewK = mNewK(_, _, get<2>(blk_coord)); + Tensor gNewV = mNewV(_, _, get<2>(blk_coord)); + + auto load_k = [&](int k_index, auto& state) { + pipeline_kv.producer_acquire(state); + + if (k_index < full_tiles_cache) { + copy(tiled_copy_k, tKgK(_, _, _, _, k_index), tKsK(_, _, _, _, state.index())); + pipeline_kv.producer_commit(state, cutlass::arch::cpasync_barrier_arrive); + } else { + using Vec = uint128_t; + Vec vzero = uint128_t(0, 0); + auto src = recast(tKgK(_, _, _, _, k_index)); + auto dst = recast(tKsK(_, _, _, _, state.index())); + auto src2 = recast(gNewK); + auto c = tKcK(_, _, _, _, k_index); + int vlen = sizeof(Vec) / sizeof(Element); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src); i++) { + auto cc = c(vlen*i); + Vec* dst_ptr = &dst(i); + const Vec* src_ptr = &src(i); + bool guard = elem_less(cc, limitK); + if (get<0>(cc) == seqlen_cache_kv && has_new) { + src_ptr = &src2(_0{}, get<1>(cc) / vlen); + guard = true; + } + cutlass::arch::cp_async_zfill<16, cutlass::arch::CacheOperation::Global>( + dst_ptr, src_ptr, guard + ); + } + + pipeline_kv.producer_commit(state, cutlass::arch::cpasync_barrier_arrive); + } + }; + + auto load_v = [&](int v_index, auto& state) { + pipeline_kv.producer_acquire(state); + + if (v_index < full_tiles_cache) { + copy(tiled_copy_v, tVgV(_, _, _, _, v_index), tVsV(_, _, _, _, state.index())); + pipeline_kv.producer_commit(state, cutlass::arch::cpasync_barrier_arrive); + } else { + using Vec = uint128_t; + Vec vzero = uint128_t(0, 0); + auto src = recast(tVgV(_, _, _, _, v_index)); + auto dst = recast(tVsV(_, _, _, _, state.index())); + auto src2 = recast(gNewV); + int vlen = sizeof(Vec) / sizeof(Element); + auto c = tVcV(_, _, _, _, v_index); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src); i++) { + auto cc = c(vlen*i); + Vec* dst_ptr = &dst(i); + const Vec* src_ptr = &src(i); + bool guard = elem_less(cc, limitV); + if (get<1>(cc) == seqlen_cache_kv && has_new) { + src_ptr = &src2(_0{}, get<0>(cc) / vlen); + guard = true; + } + cutlass::arch::cp_async_zfill<16, cutlass::arch::CacheOperation::Global>( + dst_ptr, src_ptr, guard + ); + } + + pipeline_kv.producer_commit(state, cutlass::arch::cpasync_barrier_arrive); + } + }; + + // K1 + int k_index = 0; + int v_index = 0; + + load_k(k_index, pipeline_kv_producer_state); + + ++pipeline_kv_producer_state; + k_index += 1; + + mask_tile_count -= 1; + + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + load_k(k_index, pipeline_kv_producer_state); + + ++pipeline_kv_producer_state; + k_index += 1; + + load_v(v_index, pipeline_kv_producer_state); + + ++pipeline_kv_producer_state; + v_index += 1; + } + + // V1 + + load_v(v_index, pipeline_kv_producer_state); + + ++pipeline_kv_producer_state; + v_index += 1; + + if (has_new) { + for (int i = thread_idx; i < get<2>(TileShape{}); i += 64) { + gK(seqlen_cache_kv, i, 0) = gNewK(0, i); + gV(i, seqlen_cache_kv, 0) = gNewV(0, i); + } + } + } + +}; + +} // namespace cutlass::fmha::collective diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_load_tma_warpspecialized.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_load_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3606dcc7d9a0e9b764a4f361daed83da3715d67a --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_load_tma_warpspecialized.hpp @@ -0,0 +1,299 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cute/tensor.hpp" +#include "cute/layout.hpp" + +#include "collective/fmha_common.hpp" +#include "collective/fmha_fusion.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template< + class Element, + class StrideQ, + class StrideK, + class StrideV, + class CollectiveMmaQK, + class CollectiveMmaPV, + class SmemLayoutQ, + class SmemLayoutK, + class SmemLayoutV, + class TensorStorage, + class PipelineQ, + class PipelineKV, + class Mask, + class TileShape +> +struct Sm100FmhaLoadTmaWarpspecialized { + + using TileShapeQK = typename CollectiveMmaQK::TileShape; + using TileShapePV = typename CollectiveMmaPV::TileShape; + + struct Arguments { + const Element* ptr_Q; + StrideQ dQ; + const Element* ptr_K; + StrideK dK; + const Element* ptr_V; + StrideV dV; + }; + + using TMA_Q = typename CollectiveMmaQK::Params::TMA_A; + using TMA_K = typename CollectiveMmaQK::Params::TMA_B; + using TMA_V = typename CollectiveMmaPV::Params::TMA_B; + + struct Params { + TMA_Q tma_load_q; + TMA_K tma_load_k; + TMA_V tma_load_v; + }; + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace) { + + auto ptr_Q = args.ptr_Q; + auto ptr_K = args.ptr_K; + auto ptr_V = args.ptr_V; + auto dQ = args.dQ; + auto dK = args.dK; + auto dV = args.dV; + + using IntProblemShape = cute::tuple, int>>; + + IntProblemShape problem_shape_qk; + if constexpr (is_variable_length_v>) { + auto cumulative_length_q = get<0>(problem_shape).cumulative_length; + auto cumulative_length_k = get<1>(problem_shape).cumulative_length; + if (cumulative_length_q != nullptr && cumulative_length_k != nullptr ) { + get<0>(problem_shape_qk) = get<0>(problem_shape).total_length; + get<1>(problem_shape_qk) = get<1>(problem_shape).total_length; + get<2>(problem_shape_qk) = get<2>(problem_shape); + get<3>(problem_shape_qk) = get<3>(problem_shape); + } + } else { + problem_shape_qk = problem_shape; + } + + auto params_qk = CollectiveMmaQK::to_underlying_arguments( + problem_shape_qk, + typename CollectiveMmaQK::Arguments { + ptr_Q, dQ, + ptr_K, dK, + }, /*workspace=*/ nullptr); + + auto problem_shape_pv = select<0,2,1,3>(problem_shape_qk); + auto params_pv = CollectiveMmaPV::to_underlying_arguments( + problem_shape_pv, + typename CollectiveMmaPV::Arguments { + ptr_K, dK, // never used, dummy + ptr_V, select<1,0,2>(dV), + }, /*workspace=*/ nullptr); + + return Params{ + params_qk.tma_load_a, + params_qk.tma_load_b, + params_pv.tma_load_b + }; + } + + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor()); + } + + template + CUTLASS_DEVICE void + load( + BlkCoord const& blk_coord_in, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) { + + BlkCoord blk_coord_q = blk_coord_in; + BlkCoord blk_coord_kv = blk_coord_in; + + int mask_tile_count = Mask{}.get_trip_count(blk_coord_in, TileShape{}, problem_shape); + + using X = Underscore; + + // this one is only executed by one thread, no need to elect_one + + // Q1, K1, Q2, V1, K2, V2, K3, V3, ... + // two pipes: Q and KV + // from Memory (prod) to TensorCore (cons) + + // compute gQ, sQ + // we load 2*get<0>(blk_coord), and 2*get<0>(blk_coord) + 1 + ThrMMA mma_qk = typename CollectiveMmaQK::TiledMma{}.get_slice(0); + Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape)); + + int q_offs_0 = 0; + + if constexpr (is_variable_length_v>) { + auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length; + if (cumulative_length_q != nullptr) { + q_offs_0 = cumulative_length_q[get<2,1>(blk_coord_q)]; + get<2,1>(blk_coord_q) = 0; + } + } + + Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, _0{})), mQ_qdl_p); + + Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{}); + Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl); + Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + auto [tQgQ_qdl, tQsQ] = tma_partition( + params.tma_load_q, _0{}, make_layout(_1{}), + group_modes<0,3>(sQ), group_modes<0,3>(tSgQ_qdl) + ); + Tensor tQgQ = tQgQ_qdl(_, _, _0{}, get<2>(blk_coord_q)); + + // compute gK, sK + Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape)); + + int kv_offs_0 = 0; + + if constexpr (is_variable_length_v>) { + auto cumulative_length = get<1>(params_problem_shape).cumulative_length; + if (cumulative_length != nullptr) { + kv_offs_0 = cumulative_length[get<2,1>(blk_coord_kv)]; + get<2,1>(blk_coord_kv) = 0; + } + } + + Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, _0{})), mK_kdl_p); + + Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step{}); + Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl); + Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + auto [tKgK_kdl, tKsK] = tma_partition( + params.tma_load_k, _0{}, make_layout(_1{}), + group_modes<0,3>(sK), group_modes<0,3>(tSgK_kdl) + ); + Tensor tKgK = tKgK_kdl(_, _, _0{}, get<2>(blk_coord_kv)); + + // compute gV, sV + ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0); + Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape)); + + Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, _0{})), mV_dkl_p); + + Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step{}); + Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl); + Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + auto [tVgV_dkl, tVsV] = tma_partition( + params.tma_load_v, _0{}, make_layout(_1{}), + group_modes<0,3>(sV), group_modes<0,3>(tOgV_dkl) + ); + auto tVgV = tVgV_dkl(_, _0{}, _, get<2>(blk_coord_kv)); + + // blk_coord in decomposed in terms of TileShape, not TileShapeQK + // As such, it needs to be transformed as + // (a,b,c): a -> 2*a (Q0) 2*a+1 (Q1) + // b -> 2*a (Ki i even) 2*a+1 (Ki i odd) + + uint32_t lane_predicate = cute::elect_one_sync(); + + // Q1 + int q0_index = 2 * get<0>(blk_coord_q); + int q1_index = 2 * get<0>(blk_coord_q) + 1; + pipeline_q.producer_acquire(pipeline_q_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state); + copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q0_index), tQsQ(_, pipeline_q_producer_state.index())); + } + ++pipeline_q_producer_state; + + // K1 + int k_index = 0; + pipeline_kv.producer_acquire(pipeline_kv_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index())); + } + ++pipeline_kv_producer_state; + + // Q2 + pipeline_q.producer_acquire(pipeline_q_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state); + copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q1_index), tQsQ(_, pipeline_q_producer_state.index())); + } + ++pipeline_q_producer_state; + + // V1 + pipeline_kv.producer_acquire(pipeline_kv_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index())); + } + ++pipeline_kv_producer_state; + k_index += 1; + + // loop: + mask_tile_count -= 1; + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + // Ki + pipeline_kv.producer_acquire(pipeline_kv_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index())); + } + ++pipeline_kv_producer_state; + + // Vi + pipeline_kv.producer_acquire(pipeline_kv_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index())); + } + ++pipeline_kv_producer_state; + k_index += 1; + } + } +}; + +} // namespace cutlass::fmha::collective diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bf41af9f533ad7e5099d5ae483c93605b44c71f5 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp @@ -0,0 +1,1225 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cute/arch/simd_sm100.hpp" +#include "cute/tensor.hpp" +#include "cute/layout.hpp" + +#include "collective/fmha_common.hpp" +#include "collective/fmha_fusion.hpp" +#include "collective/sm100_fmha_mla_load_tma_warpspecialized.hpp" +#include "common/pipeline_mla.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template< + class Element_, + class ElementQK_, + class ElementPV_, + class ComposedTileShape_, + class StrideQ_, + class StrideK_, + class StrideV_, + class Mask_, + // shape here is QG K H + // and referes to the two softmax warps + // (2, 1, 1) means that they are stacked (best for large Q since it loads the least K/V) + // (1, 2, 1) means they sit side by side (best for small Q / large K) + class ThreadShape = Shape<_2, _1, _1>, + class OrderLoadEpilogue = cute::false_type +> +struct Sm100MlaFwdMainloopTmaWarpspecialized { + + using Element = Element_; + using ElementQK = ElementQK_; + using ElementPV = ElementPV_; + using ComposedTileShape = ComposedTileShape_; + using StrideQ = StrideQ_; + using StrideK = StrideK_; + using StrideV = StrideV_; + using Mask = Mask_; + + static constexpr int StageCountQ = 2; + static constexpr int StageCountK = 1; + static constexpr int StageCountV = 1; + static constexpr int StageCountKV = StageCountK + StageCountV; + // Support StageCountKV > 2 in the future. + static_assert(StageCountK == 1 && StageCountV == 1, "Only support StageCountK = StageCountV = 1!"); + static_assert(std::is_same_v>, "Only support ThreadShape = Shape<_2, _1, _1>"); + + using ClusterShape = Shape<_1, _1, _1>; + + static const int Alignment = 128 / sizeof_bits_v; + + static constexpr auto HeadDimLatent = size<2, 0>(ComposedTileShape{}); + static constexpr auto HeadDimRope = size<2, 1>(ComposedTileShape{}); + static constexpr auto HeadDimQK = HeadDimLatent + HeadDimRope; + static constexpr auto HeadDimPV = HeadDimLatent; + + using TileShapeQK = decltype(shape_div(replace<2>(ComposedTileShape{}, HeadDimQK), ThreadShape{})); + using TileShapePV = decltype(select<0,2,1>(shape_div(replace<2>(ComposedTileShape{}, HeadDimPV), ThreadShape{}))); + using TileShape = decltype(replace<2>(ComposedTileShape{}, HeadDimLatent)); + + using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, StrideQ, Alignment, + Element, StrideK, Alignment, + ElementQK, + TileShapeQK, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp; + + using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // the stride for A does not matter since we do not load from smem at all + Element, StrideK, Alignment, + Element, decltype(select<1,0,2>(StrideV{})), Alignment, + ElementPV, + TileShapePV, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp; + + using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int{})); + using SmemLayoutK = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutB{}, Int{})); + using SmemLayoutV = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutB{}, Int{})); + + using SmemStorageOneStageO = decltype(make_layout(replace<2>(TileShapePV{}, _1{}))); + + // Since the shared memory is not sufficient if we use separate Q, K, V, and O shared memory, + // we reuse shared memory for V and O to address this problem, + // and a barrier has been added to coordinate access to shared memory. + static constexpr bool IsOrderLoadEpilogue = std::is_same_v; + static const int NumWarpsEpilogue = 1; + static const int NumWarpsLoad = 1; + + struct TensorStorageQKVO { + cute::array_aligned> smem_q; + cute::array_aligned> smem_k; + cute::array_aligned> smem_o; // use as O0 + cute::array_aligned> smem_v; // use as V0 and O1 + }; + + struct TensorStorageQKV { + cute::array_aligned> smem_q; + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + }; + + using TensorStorage = std::conditional_t; + + enum class TmemAllocation : uint32_t { + kSizeS = 128, + kSizeO = 128, + kSizeP = 32, + S0 = 0, + S1 = S0 + kSizeS, + V0 = S0, // stats storage from softmax to correction + V1 = S1, + P0 = S0 + kSizeP, + P1 = S1 + kSizeP, + O0 = S1 + kSizeS, + O1 = O0 + kSizeO, + kEnd = O1 + kSizeO + }; + + // indices for V0 / V1 + enum : int { + kIdxOldRowMax = 0, + kIdxNewRowMax = 1, + kIdxFinalRowSum = 0, + kIdxFinalRowMax = 1 + }; + + // from load to mma warp, protects q in smem + using PipelineQ = cutlass::PipelineTmaUmmaAsync< + StageCountQ, + typename CollectiveMmaQK::AtomThrShapeMNK + >; + + // from load to mma warp, protects k/v in smem + using PipelineKV = cutlass::PipelineTmaAsyncMla< + StageCountKV, + typename CollectiveMmaQK::AtomThrShapeMNK + >; + + // from mma to softmax0/1 warp, protects S in tmem + // (not sure yet about the reverse direction) + // there is one pipe per softmax warp, and the mma warp alternates between them + using PipelineS = cutlass::PipelineUmmaAsync<1>; + + // from softmax0/1/ to correction wg + using PipelineC = cutlass::PipelineAsync<1>; + + // from mma to correction + using PipelineO = cutlass::PipelineUmmaAsync<2>; + + // from corr to epilogue + using PipelineE = cutlass::PipelineAsync<2>; + + using OrderBarrierSoftmax = cutlass::OrderedSequenceBarrier< + /*stages*/ 1, /*groups*/ 2>; + + static constexpr int TransactionBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v); + static constexpr int TransactionBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v); + static constexpr int TransactionBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v); + + using Load = Sm100MlaFwdLoadTmaWarpspecialized< + Element, StrideQ, StrideK, StrideV, + CollectiveMmaQK, CollectiveMmaPV, + SmemLayoutQ, SmemLayoutK, SmemLayoutV, + TensorStorage, PipelineQ, PipelineKV, Mask, TileShape, OrderLoadEpilogue + >; + + struct Arguments { + typename Load::Arguments load; + + // if zero, defaults to 1/sqrt(D) + float scale_softmax = 0.0f; + + // scaling factors to dequantize QKV + float scale_q = 1.0f; + float scale_k = 1.0f; + float scale_v = 1.0f; + + // scaling factor to quantize O + float inv_scale_o = 1.0f; + }; + + struct Params { + typename Load::Params load; + + float scale_softmax; + float scale_softmax_log2; + + float scale_output; + }; + + template + static bool can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace) { + + float scale_softmax = args.scale_softmax; + if (scale_softmax == 0.0f) { + scale_softmax = 1.0f / (float) std::sqrt(get<2, 0>(problem_shape) + get<2, 1>(problem_shape)); + } + float log2_e = static_cast(std::log2(std::exp(1.0))); + + return Params{ + Load::to_underlying_arguments(problem_shape, args.load, workspace), + args.scale_q * args.scale_k * scale_softmax, + args.scale_q * args.scale_k * log2_e * scale_softmax, + args.scale_v * args.inv_scale_o + }; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + Load::prefetch_tma_descriptors(params.load); + } + + template + CUTLASS_DEVICE void + load( + BlkCoord const& blk_coord, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) { + + Load load; + load.load(blk_coord, problem_shape, params.load, params_problem_shape, + storage, + pipeline_q, pipeline_q_producer_state, + pipeline_kv, pipeline_kv_producer_state); + } + + template + CUTLASS_DEVICE auto + mma( + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_consumer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_consumer_state, + PipelineS& pipeline_s0, typename PipelineS::PipelineState& pipeline_s0_producer_state, + PipelineS& pipeline_s1, typename PipelineS::PipelineState& pipeline_s1_producer_state, + PipelineO& pipeline_corr, typename PipelineO::PipelineState& pipeline_corr_producer_state) { + + auto pipeline_q_release_state = pipeline_q_consumer_state; + auto pipeline_kv_release_state = pipeline_kv_consumer_state; + + int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + + typename CollectiveMmaQK::TiledMma mma_qk; + ThrMMA thr_mma_qk = mma_qk.get_slice(0); + + typename CollectiveMmaPV::TiledMma mma_pv; + TiledMMA mma_pv_ts = to_tiled_mma_sm100_ts(mma_pv); + ThrMMA thr_mma_pv = mma_pv_ts.get_slice(0); + + Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + + Tensor tSrQ = thr_mma_qk.make_fragment_A(sQ); + Tensor tSrK = thr_mma_qk.make_fragment_B(sK); + Tensor tOrV = thr_mma_pv.make_fragment_B(sV); + + // tmem layout is + // S0 S1`O0 O1 + // sequential in memory, where S overlaps with P and V + + Tensor tStS = partition_fragment_C(mma_qk, select<0,1>(TileShapeQK{})); + Tensor tOtO = partition_fragment_C(mma_pv_ts, select<0,1>(TileShapePV{})); + + Tensor tStS0 = tStS; + tStS0.data() = tStS.data().get() + uint32_t(TmemAllocation::S0); + Tensor tStS1 = tStS; + tStS1.data() = tStS.data().get() + uint32_t(TmemAllocation::S1); + + Tensor tOtO0 = tOtO; + tOtO0.data() = tOtO.data().get() + uint32_t(TmemAllocation::O0); + Tensor tOtO1 = tOtO; + tOtO1.data() = tOtO.data().get() + uint32_t(TmemAllocation::O1); + + Tensor sP = make_tensor(make_smem_ptr((Element*)nullptr), typename CollectiveMmaPV::SmemLayoutA{}); + Tensor tOrP = thr_mma_pv.make_fragment_A(sP)(_, _, _, _0{}); // slice out staging + + Tensor tOrP0 = tOrP; + tOrP0.data() = tOrP0.data().get() + uint32_t(TmemAllocation::P0); + Tensor tOrP1 = tOrP; + tOrP1.data() = tOrP1.data().get() + uint32_t(TmemAllocation::P1); + + int k_index = 0; + int v_index = 0; + int q_index = 0; + + // wait for Q1 + q_index = pipeline_q_consumer_state.index(); + pipeline_q.consumer_wait(pipeline_q_consumer_state); + ++pipeline_q_consumer_state; + + Tensor tSrQ0 = tSrQ(_,_,_,q_index); + + + // wait for K1 + k_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm Q1 * K1 -> S1 + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index / 2), tStS0); + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + // release K1 + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + // wait for Q2 + if constexpr (get<0>(ThreadShape{}) > 1 || get<2>(ThreadShape{}) > 1) { + q_index = pipeline_q_consumer_state.index(); + pipeline_q.consumer_wait(pipeline_q_consumer_state); + ++pipeline_q_consumer_state; + } + + Tensor tSrQ1 = tSrQ(_,_,_,q_index); + + if constexpr (get<1>(ThreadShape{}) > 1) { + k_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + // gemm Q2 * K1 -> S2 + gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index / 2), tStS1); + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // release K1 + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + // wait for V1 + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // this acquire returns the ownership of all of S0 to the mma warp + // including the P0 part + // acquire corr first to take it out of the critical + // path since softmax takes longer + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + // gemm P1 * V1 -> O1 + gemm_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index / 2), tOtO0); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + mma_pv_ts.accumulate_ = UMMA::ScaleOut::Zero; + + // loop: + mask_tile_count -= 1; + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + // wait for Ki + k_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm Q1 * Ki -> S1 + gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index / 2), tStS0); + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + // gemm P2 * V(i-1) -> O2 + if constexpr (get<1>(ThreadShape{}) > 1) { + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index / 2), tOtO1); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + // release V(i-1) + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + k_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + // gemm Q2 * Ki -> S2 + gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index / 2), tStS1); + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // release Ki + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + // wait for Vi + v_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm P1 * Vi -> O1 + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index / 2), tOtO0); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + } + + // release Q1 + pipeline_q.consumer_release(pipeline_q_release_state); + ++pipeline_q_release_state; + + // release Q2 + if constexpr (get<0>(ThreadShape{}) > 1) { + pipeline_q.consumer_release(pipeline_q_release_state); + ++pipeline_q_release_state; + } + + // wait for Vi + if constexpr (get<1>(ThreadShape{}) > 1) { + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + // gemm P2 * Vi -> O2 + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index / 2), tOtO1); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + // release Vi + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // T0 S00 B1, T0 S10 B1, T0 S00 B2, T0 S01 B1, T0 S10 B2, T0 S11 B1, T0 S01 B2, T1 S00 B1, T0 S11 B2, ... + // Q1 * K1 , Q2 * K1 , S11 * V1 , Q1 * K2 , S21 * V1 , Q2 * K2 , S12 * V2 , Q1 * K3 , S22 * K2 , ... + } + + template + CUTLASS_DEVICE auto + softmax_step( + bool need_apply_mask, + float& row_max, float& row_sum, + Stage stage, bool final_call, + BlkCoord const& blk_coord, CoordTensor const& cS, + Params const& params, ProblemShape const& problem_shape, + PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state, + PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state, + OrderBarrierSoftmax& order_s) { + + Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); + + Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{})); + tStS.data() = uint32_t(stage == _0{} ? TmemAllocation::S0 : TmemAllocation::S1); + + Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); + tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1); + Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); + + auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int{} * Int{}; + Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); + tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1)); + Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); + + // Each thread owns a single row + using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem + using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem + using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tStS); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + + Tensor tTMEM_LOADtS = thr_tmem_load.partition_S(tStS); + Tensor tTMEM_LOADcS = thr_tmem_load.partition_D(tScS); + + auto tiled_tmem_storev = make_tmem_copy(TMEM_STORE_V{}, tStS_v); + auto thr_tmem_storev = tiled_tmem_storev.get_slice(thread_idx); + + Tensor tTMEM_STOREVtS = thr_tmem_storev.partition_D(tStS_v); + Tensor tTMEM_STOREVcS = thr_tmem_storev.partition_S(tScS_v); + + auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tStS_P); + auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); + + Tensor tTMEM_STOREtS_x4 = thr_tmem_store.partition_D(tStS_P); + tTMEM_STOREtS_x4.data() = warp_uniform(tTMEM_STOREtS_x4.data().get()); + Tensor tTMEM_STOREcS = thr_tmem_store.partition_S(tScS_P); + + // wait on tensor core pipe + pipeline_s.consumer_wait(pipeline_s_consumer_state); + + // read all of S from tmem into reg mem + Tensor tTMEM_LOADrS = make_tensor(shape(tTMEM_LOADcS)); + copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS); + + if constexpr (need_mask) { + if(need_apply_mask) { + Mask{}.apply_mask(tTMEM_LOADrS, tTMEM_LOADcS, problem_shape); + } + } + + ElementQK old_row_max = row_max; + { + // compute rowmax + float row_max_0 = row_max; + float row_max_1 = row_max; + float row_max_2 = row_max; + float row_max_3 = row_max; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 4) { + row_max_0 = ::fmax(row_max_0, tTMEM_LOADrS(i)); + row_max_1 = ::fmax(row_max_1, tTMEM_LOADrS(i+1)); + row_max_2 = ::fmax(row_max_2, tTMEM_LOADrS(i+2)); + row_max_3 = ::fmax(row_max_3, tTMEM_LOADrS(i+3)); + } + row_max = ::fmax(row_max_0, row_max_1); + row_max = ::fmax(row_max, row_max_2); + row_max = ::fmax(row_max, row_max_3); + } + + ElementQK row_max_safe = row_max == -INFINITY ? 0 : row_max; + + Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); + tTMEM_STOREVrS(kIdxOldRowMax) = old_row_max; + tTMEM_STOREVrS(kIdxNewRowMax) = row_max_safe; + copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS); + + pipeline_c.producer_commit(pipeline_c_producer_state); + ++pipeline_c_producer_state; + + // notify correction wg that they are ready (might need addtl ordering between S0 and S1 WG's) + + ElementQK scale = params.scale_softmax_log2; + ElementQK row_max_scale = row_max_safe * scale; + + float2 scale_fp32x2 = make_float2(scale, scale); + float2 minus_row_max_scale_fp32x2 = make_float2(-row_max_scale, -row_max_scale); + + Tensor tTMEM_STORErS_x4 = make_tensor(shape(tTMEM_STOREcS)); + + constexpr int kConversionsPerStep = 2; + + Tensor tTMEM_STORErS_x4_e = recast>(tTMEM_STORErS_x4); + + NumericArrayConverter convert; + + constexpr int kReleasePipeCount = 10; // must be multiple of 2 + + order_s.wait(); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 2) { + float2 in = make_float2( + tTMEM_LOADrS(i + 0), + tTMEM_LOADrS(i + 1) + ); + float2 out; + cute::fma(out, scale_fp32x2, in, minus_row_max_scale_fp32x2); + tTMEM_LOADrS(i + 0) = out.x; + tTMEM_LOADrS(i + 1) = out.y; + + tTMEM_LOADrS(i+0) = ::exp2f(tTMEM_LOADrS(i+0)); + tTMEM_LOADrS(i+1) = ::exp2f(tTMEM_LOADrS(i+1)); + + Array in_conv; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kConversionsPerStep; j++) { + in_conv[j] = tTMEM_LOADrS(i + j); + } + tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv); + + + if (i == size(tTMEM_LOADrS) - kReleasePipeCount) { + order_s.arrive(); + } + + // this prevents register spills in fp16 + if constexpr (size<2>(tTMEM_STORErS_x4) == _2{}) { + if (i == size(tTMEM_LOADrS) - 6) { + copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, 0), tTMEM_STOREtS_x4(_, _, 0)); + } + } + } + + // tmem_store(reg_S8) -> op_P + CUTE_STATIC_ASSERT_V(size<2>(tTMEM_STORErS_x4) <= _2{}); + CUTE_STATIC_ASSERT_V(size<1>(tTMEM_STORErS_x4) == _1{}); + copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1), tTMEM_STOREtS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1)); + + cutlass::arch::fence_view_async_tmem_store(); + + // notify tensor core warp that P is ready + pipeline_s.consumer_release(pipeline_s_consumer_state); + ++pipeline_s_consumer_state; + + pipeline_c.producer_acquire(pipeline_c_producer_state); + + ElementQK acc_scale = 0.5f * ::exp2f(scale * (old_row_max - row_max_safe)); + row_sum *= acc_scale; + // row_sum = sum(reg_S) + float2 local_row_sum_f32x2 = make_float2(row_sum, row_sum); + float2 local_row_sum_1 = make_float2(0, 0); + float2 local_row_sum_2 = make_float2(0, 0); + float2 local_row_sum_3 = make_float2(0, 0); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 8) { + // row_sum += tTMEM_LOADrS(i); + float2 in = make_float2(tTMEM_LOADrS(i), tTMEM_LOADrS(i+1)); + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, in); + + in = make_float2(tTMEM_LOADrS(i+2), tTMEM_LOADrS(i+2+1)); + cute::add(local_row_sum_1, local_row_sum_1, in); + + in = make_float2(tTMEM_LOADrS(i+4), tTMEM_LOADrS(i+4+1)); + cute::add(local_row_sum_2, local_row_sum_2, in); + + in = make_float2(tTMEM_LOADrS(i+6), tTMEM_LOADrS(i+6+1)); + cute::add(local_row_sum_3, local_row_sum_3, in); + } + + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_1); + cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3); + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2); + float local_row_sum = local_row_sum_f32x2.x + local_row_sum_f32x2.y; + + row_sum = local_row_sum; + + if (final_call) { + // re-acquire the S part in the final step + pipeline_s.consumer_wait(pipeline_s_consumer_state); + + Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); + tTMEM_STOREVrS(kIdxFinalRowMax) = row_max; + tTMEM_STOREVrS(kIdxFinalRowSum) = row_sum; + copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS); + } + } + + template + CUTLASS_DEVICE auto + softmax( + Stage stage, + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state, + PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state, + OrderBarrierSoftmax& order_s) { + const int mask_trip_count = Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape); + const int total_trip_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + int trip_idx = total_trip_count; + + ElementQK row_max = -INFINITY; + ElementQK row_sum = 0; + + Tensor cS_base = make_identity_tensor(select<0,1>(TileShapeQK{})); + auto logical_offset = make_coord( + get<0>(blk_coord) * get<0>(TileShape{}) + (stage % get<0>(ThreadShape{})) * get<0>(TileShapeQK{}), + 0 + (stage % get<1>(ThreadShape{})) * get<1>(TileShapeQK{}) + ); + Tensor cS = domain_offset(logical_offset, cS_base); + + pipeline_c.producer_acquire(pipeline_c_producer_state); + + constexpr bool NeedMask = !std::is_same_v; + + CUTLASS_PRAGMA_NO_UNROLL + for (; trip_idx > 0; trip_idx -= 1) { + softmax_step( + trip_idx <= mask_trip_count, + row_max, row_sum, stage, + trip_idx == 1, + blk_coord, cS, params, problem_shape, + pipeline_s, pipeline_s_consumer_state, + pipeline_c, pipeline_c_producer_state, + order_s + ); + + cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); + } + + pipeline_c.producer_commit(pipeline_c_producer_state); + ++pipeline_c_producer_state; + + pipeline_c.producer_acquire(pipeline_c_producer_state); + // empty step to sync against pipe s + pipeline_s.consumer_release(pipeline_s_consumer_state); + ++pipeline_s_consumer_state; + } + + template + CUTLASS_DEVICE auto + correction_epilogue( + float scale, + Stage stage, + TensorO const& sO_01) { + + using ElementOut = typename TensorO::value_type; + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + Tensor sO = sO_01(_,_,stage); + + // As opposed to the softmax, we do not have enough registers here + // to load all of the values (for tile kv = 128), so we loop + // good values would be either 32 or 64 + constexpr int kCorrectionTileSize = 32 / sizeof(ElementOut); + + using TMEM_LOAD = std::conditional_t; // 4x32 threads with 64 cols of 32b elem + + typename CollectiveMmaPV::TiledMma mma; + Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); + Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); + Tensor tOcO = mma.get_slice(0).partition_C(cO); + Tensor tOsO = mma.get_slice(0).partition_C(sO); + + Tensor tOtO_i = logical_divide(tOtO, make_layout(make_shape(_128{}, Int{}))); + Tensor tOcO_i = logical_divide(tOcO, make_layout(make_shape(_128{}, Int{}))); + Tensor tOsO_i = logical_divide(tOsO, make_layout(make_shape(_128{}, Int{}))); + + if constexpr (decltype(stage == _0{})::value) { + tOtO_i.data() = tOtO_i.data().get() + uint32_t(TmemAllocation::O0); + } + else { + static_assert(decltype(stage == _1{})::value, "stage is either 0 or 1"); + tOtO_i.data() = tOtO_i.data().get() + uint32_t(TmemAllocation::O1); + } + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i(make_coord(_, _), _0{})); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + + Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i(make_coord(_, _), _)); + Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i(make_coord(_, _), _)); + Tensor tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i(make_coord(_, _), _)); + + float2 scale_f32x2 = make_float2(scale, scale); + + // loop: + // TMEM_LOAD, FMUL2 scale, TMEM_STORE + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < get<2>(TileShape{}) / kCorrectionTileSize; i++) { + Tensor tTMEM_LOADtO_i = tTMEM_LOADtO(_, _0{}, _0{}, i); + Tensor tTMEM_LOADsO_i = tTMEM_LOADsO(_, _0{}, _0{}, i); + + Tensor tTMrO = make_tensor(shape(tTMEM_LOADcO(_, _0{}, _0{}, i))); + + copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO); + +#ifndef ONLY_SOFTMAX + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tTMrO); j += 2) { + float2 in = make_float2(tTMrO(j), tTMrO(j+1)); + float2 out; + cute::mul(out, scale_f32x2, in); + tTMrO(j) = out.x; + tTMrO(j+1) = out.y; + } +#endif + + constexpr int N = 4 / sizeof(ElementOut); + NumericArrayConverter convert; + + Tensor tSMrO = make_tensor_like(tTMrO); + + Tensor tCs = recast(tTMrO); + Tensor tCd = recast(tSMrO); + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tCs); j++) { + tCd(j) = convert.convert(tCs(j)); + } + + Tensor tSMsO_i = recast(tTMEM_LOADsO_i); + Tensor tSMrO_i = recast(tSMrO); + + copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tSMrO_i, tSMsO_i); + } + + cutlass::arch::fence_view_async_shared(); + } + + CUTLASS_DEVICE auto + correction_rescale( + float scale, + uint32_t tmem_O) { + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + // As opposed to the softmax, we do not have enough registers here + // to load all of the values (for tile kv = 128), so we loop + // good values would be either 32 or 64 + constexpr int kCorrectionTileSize = 16; + + using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x; // 4x32 threads with 64 cols of 32b elem + using TMEM_STORE = SM100_TMEM_STORE_32dp32b16x; // 4x32 threads with 64 cols of 32b elem + + typename CollectiveMmaPV::TiledMma mma; + Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); + Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); + Tensor tOcO = mma.get_slice(0).partition_C(cO); + + Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int{}))); + Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int{}))); + + tOtO_i.data() = tOtO_i.data().get() + tmem_O; + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i); + auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); + + Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i); + Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i); + Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i); + Tensor tTMEM_STOREcO = thr_tmem_store.partition_S(tOcO_i); + static_assert(shape(tTMEM_STOREcO) == shape(tTMEM_LOADcO)); + + float2 scale_f32x2 = make_float2(scale, scale); + + Tensor tTMrO = make_tensor(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{})); + + auto copy_in = [&](int i) { + Tensor tTMEM_LOADtO_i = tTMEM_LOADtO; + tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO_i); + }; + + auto copy_out = [&](int i) { + Tensor tTMEM_STOREtO_i = tTMEM_STOREtO; + tTMEM_STOREtO_i.data() = tTMEM_STOREtO_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + copy(tiled_tmem_store, tTMrO_i, tTMEM_STOREtO_i); + }; + + // sequence: LLMSLMSLMSS + + // loop: + // TMEM_LOAD, FMUL2 scale, TMEM_STORE + copy_in(0); + + constexpr int count = get<2>(TileShape{}) / kCorrectionTileSize; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < count; i++) { + if (i != count - 1) { + copy_in(i+1); + } + + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tTMrO_i); j += 2) { + float2 in = make_float2(tTMrO_i(j), tTMrO_i(j+1)); + float2 out; + cute::mul(out, scale_f32x2, in); + tTMrO_i(j) = out.x; + tTMrO_i(j+1) = out.y; + } + + copy_out(i); + } + } + + template< + class BlkCoord, class ProblemShape, class ParamsProblemShape, + class TensorStorageEpi, class CollectiveEpilogue + > + CUTLASS_DEVICE auto + correction( + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + ParamsProblemShape const& params_problem_shape, + TensorStorageEpi& shared_storage_epi, + PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state, + PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state, + PipelineO& pipeline_o, typename PipelineO::PipelineState& pipeline_o_consumer_state, + PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state, + CollectiveEpilogue& epilogue) { + + int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{})); + + Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{})); + Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); + + Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); + Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); + + using TMEM_LOAD_V = SM100_TMEM_LOAD_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + + auto tiled_tmem_loadv = make_tmem_copy(TMEM_LOAD_V{}, tStS_v); + auto thr_tmem_loadv = tiled_tmem_loadv.get_slice(thread_idx); + + Tensor tTMEM_LOADVtS = thr_tmem_loadv.partition_S(tStS_v); + Tensor tTMEM_LOADVcS = thr_tmem_loadv.partition_D(tScS_v); + + Tensor tTMEM_LOADVtS0 = tTMEM_LOADVtS; + tTMEM_LOADVtS0.data() = tTMEM_LOADVtS0.data().get() + uint32_t(TmemAllocation::V0); + Tensor tTMEM_LOADVtS1 = tTMEM_LOADVtS; + tTMEM_LOADVtS1.data() = tTMEM_LOADVtS1.data().get() + uint32_t(TmemAllocation::V1); + + // ignore first signal from softmax as no correction is required + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + // handle the last iteration differently (i.e. tmem_load/stsm for epi) + mask_tile_count -= 1; + + CUTLASS_PRAGMA_NO_UNROLL + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + + Tensor tTMEM_LOADVrS = make_tensor(shape(tTMEM_LOADVcS)); + + // read row_wise new global max + copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS); + + // e^(scale * (old_max - new_max) + float scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + + correction_rescale(scale, uint32_t(TmemAllocation::O0)); + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + cutlass::arch::fence_view_async_tmem_store(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS); + + scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + + correction_rescale(scale, uint32_t(TmemAllocation::O1)); + + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + cutlass::arch::fence_view_async_tmem_store(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + } + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + // do the final correction to O1 + // better to somehow special-case it in the loop above + // doesn't matter for non-persistent code, but if it were + // persistent we do not want to release O too early + + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + + // read from V0 + // read row_sum and final row_max here + Tensor tTMEM_LOADVrS = make_tensor(shape(tTMEM_LOADVcS)); + copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS); + + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + // store to epi smem + + // loop: + // TMEM_LOAD + // FMUL2 scale = 1 / global_sum * out_quant_scale + // F2FP + // store to smem + Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{}); + Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE); + correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO); + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax); + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; + } + } + + cutlass::arch::fence_view_async_tmem_load(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + // load from V1 + copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS); + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + + correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _1{}, sO); + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{}); + + ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; + } + } + + cutlass::arch::fence_view_async_tmem_load(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + } + + + template< + class BlkCoord, class ProblemShape, class ParamsProblemShape, + class TensorStorageEpi, class CollectiveEpilogue + > + CUTLASS_DEVICE auto + correction_empty( + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + ParamsProblemShape const& params_problem_shape, + TensorStorageEpi& shared_storage_epi, + PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state, + CollectiveEpilogue& epilogue) { + + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + + Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{}); + Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE); + float lse = -INFINITY; + int thread_idx = threadIdx.x % (4 * NumThreadsPerWarp); + +#define DSHOW(x) print(#x ": "); print(x); print("\n") + if (threadIdx.x % 128 == 0 && block0()) { + DSHOW(sO); + } +#if 1 + + using ElementOut = typename CollectiveEpilogue::ElementOut; + auto tiled_copy = make_cotiled_copy( + Copy_Atom, ElementOut>{}, + make_ordered_layout(make_shape(_128{}, Int{}), Step<_1, _0>{}), + sO.layout()); + + auto thr_copy = tiled_copy.get_slice(thread_idx); + auto tOgO = thr_copy.partition_D(sO); + auto tOrO = make_tensor(shape(tOgO(_,_,_,_0{}))); + clear(tOrO); + + copy(tiled_copy, tOrO, tOgO(_,_,_,_0{})); +#endif + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; + } + } + + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + + copy(tiled_copy, tOrO, tOgO(_,_,_,_1{})); + cutlass::arch::fence_view_async_shared(); + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{}); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; + } + } + + cutlass::arch::fence_view_async_shared(); + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + } + +}; + +} // namespace cutlass::fmha::collective diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c8fc13b95286fe18e94536204a304eefdabc67ca --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp @@ -0,0 +1,323 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cute/tensor.hpp" +#include "cute/layout.hpp" + +#include "collective/fmha_common.hpp" +#include "collective/fmha_fusion.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template< + class Element, + class StrideQ, + class StrideK, + class StrideV, + class CollectiveMmaQK, + class CollectiveMmaPV, + class SmemLayoutQ, + class SmemLayoutK, + class SmemLayoutV, + class TensorStorage, + class PipelineQ, + class PipelineKV, + class Mask, + class TileShape, + class OrderLoadEpilogue = cute::false_type +> +struct Sm100MlaFwdLoadTmaWarpspecialized { + + using TileShapeQK = typename CollectiveMmaQK::TileShape; + using TileShapePV = typename CollectiveMmaPV::TileShape; + + static constexpr int TransactionBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v); + static constexpr int TransactionBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v); + + static const int NumWarpsEpilogue = 1; + static const int NumWarpsLoad = 1; + + struct Arguments { + const Element* ptr_Q; + StrideQ dQ; + const Element* ptr_K; + StrideK dK; + const Element* ptr_V; + StrideV dV; + }; + + using TMA_Q = typename CollectiveMmaQK::Params::TMA_A; + using TMA_K = typename CollectiveMmaQK::Params::TMA_B; + using TMA_V = typename CollectiveMmaPV::Params::TMA_B; + + struct Params { + TMA_Q tma_load_q; + TMA_K tma_load_k; + TMA_V tma_load_v; + }; + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace) { + + auto ptr_Q = args.ptr_Q; + auto ptr_K = args.ptr_K; + auto ptr_V = args.ptr_V; + auto dQ = args.dQ; + auto dK = args.dK; + auto dV = args.dV; + + using IntProblemShape = cute::tuple, int>>; + + IntProblemShape problem_shape_qk; + if constexpr (is_variable_length_v>) { + auto cumulative_length_q = get<0>(problem_shape).cumulative_length; + auto cumulative_length_k = get<1>(problem_shape).cumulative_length; + if (cumulative_length_q != nullptr && cumulative_length_k != nullptr ) { + get<0>(problem_shape_qk) = get<0>(problem_shape).total_length; + get<1>(problem_shape_qk) = get<1>(problem_shape).total_length; + get<2>(problem_shape_qk) = get<2, 0>(problem_shape) + get<2, 1>(problem_shape); + get<3>(problem_shape_qk) = get<3>(problem_shape); + } + } else { + problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape));; + } + + auto problem_shape_pv = replace<1>(select<0,2,1,3>(problem_shape_qk), get<2, 0>(problem_shape)); + + auto params_qk = CollectiveMmaQK::to_underlying_arguments( + problem_shape_qk, + typename CollectiveMmaQK::Arguments { + ptr_Q, dQ, + ptr_K, dK, + }, /*workspace=*/ nullptr); + + auto params_pv = CollectiveMmaPV::to_underlying_arguments( + problem_shape_pv, + typename CollectiveMmaPV::Arguments { + ptr_K, dK, // never used, dummy + ptr_V, select<1,0,2>(dV), + }, /*workspace=*/ nullptr); + + return Params{ + params_qk.tma_load_a, + params_qk.tma_load_b, + params_pv.tma_load_b + }; + } + + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor()); + } + + template + CUTLASS_DEVICE void + load( + BlkCoord const& blk_coord_in, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) { + + BlkCoord blk_coord_q = blk_coord_in; + BlkCoord blk_coord_kv = blk_coord_in; + + auto problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape)); + auto problem_shape_v = replace<2>(problem_shape, get<2, 0>(problem_shape)); + + int mask_tile_count = Mask{}.get_trip_count(blk_coord_in, TileShape{}, problem_shape); + + using X = Underscore; + + // this one is only executed by one thread, no need to elect_one + + // Q1, K1, Q2, V1, K2, V2, K3, V3, ... + // two pipes: Q and KV + // from Memory (prod) to TensorCore (cons) + + // compute gQ, sQ + // we load 2*get<0>(blk_coord), and 2*get<0>(blk_coord) + 1 + ThrMMA mma_qk = typename CollectiveMmaQK::TiledMma{}.get_slice(0); + Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape_qk)); + + int q_offs_0 = 0; + + if constexpr (is_variable_length_v>) { + auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length; + if (cumulative_length_q != nullptr) { + q_offs_0 = cumulative_length_q[get<2,1>(blk_coord_q)]; + get<2,1>(blk_coord_q) = 0; + } + } + + Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, _0{})), mQ_qdl_p); + + Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{}); + Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl); + Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + auto [tQgQ_qdl, tQsQ] = tma_partition( + params.tma_load_q, _0{}, make_layout(_1{}), + group_modes<0,3>(sQ), group_modes<0,3>(tSgQ_qdl) + ); + Tensor tQgQ = tQgQ_qdl(_, _, _0{}, get<2>(blk_coord_q)); + + // compute gK, sK + Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape_qk)); + + int kv_offs_0 = 0; + + if constexpr (is_variable_length_v>) { + auto cumulative_length = get<1>(params_problem_shape).cumulative_length; + if (cumulative_length != nullptr) { + kv_offs_0 = cumulative_length[get<2,1>(blk_coord_kv)]; + get<2,1>(blk_coord_kv) = 0; + } + } + + Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, _0{})), mK_kdl_p); + + Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step{}); + Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl); + Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + auto [tKgK_kdl, tKsK] = tma_partition( + params.tma_load_k, _0{}, make_layout(_1{}), + group_modes<0,3>(sK), group_modes<0,3>(tSgK_kdl) + ); + Tensor tKgK = tKgK_kdl(_, _, _0{}, get<2>(blk_coord_kv)); + + // compute gV, sV + ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0); + Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape_v)); + + Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, _0{})), mV_dkl_p); + + Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step{}); + Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl); + Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + auto [tVgV_dkl, tVsV] = tma_partition( + params.tma_load_v, _0{}, make_layout(_1{}), + group_modes<0,3>(sV), group_modes<0,3>(tOgV_dkl) + ); + auto tVgV = tVgV_dkl(_, _0{}, _, get<2>(blk_coord_kv)); + + // blk_coord in decomposed in terms of TileShape, not TileShapeQK + // As such, it needs to be transformed as + // (a,b,c): a -> 2*a (Q0) 2*a+1 (Q1) + // b -> 2*a (Ki i even) 2*a+1 (Ki i odd) + + uint32_t lane_predicate = cute::elect_one_sync(); + + // Q1 + int q0_index = 2 * get<0>(blk_coord_q); + int q1_index = 2 * get<0>(blk_coord_q) + 1; + pipeline_q.producer_acquire(pipeline_q_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state); + copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q0_index), tQsQ(_, pipeline_q_producer_state.index())); + } + ++pipeline_q_producer_state; + + // K1 + int k_index = 0; + pipeline_kv.producer_acquire(pipeline_kv_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index() / 2)); + } + ++pipeline_kv_producer_state; + + // Q2 + pipeline_q.producer_acquire(pipeline_q_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state); + copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q1_index), tQsQ(_, pipeline_q_producer_state.index())); + } + ++pipeline_q_producer_state; + + if constexpr (cute::is_same_v) { + cutlass::arch::NamedBarrier::sync((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + } + + // V1 + pipeline_kv.producer_acquire_bytes(pipeline_kv_producer_state, TransactionBytesLoadV); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index() / 2)); + } + ++pipeline_kv_producer_state; + k_index += 1; + + // loop: + mask_tile_count -= 1; + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + // Ki + pipeline_kv.producer_acquire(pipeline_kv_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index() / 2)); + + // prefetch vi + cute::prefetch(params.tma_load_v, tVgV(_, k_index)); + } + ++pipeline_kv_producer_state; + + // Vi + pipeline_kv.producer_acquire_bytes(pipeline_kv_producer_state, TransactionBytesLoadV); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index() / 2)); + + // prefetch ki+1 + if(mask_tile_count > 1) { + cute::prefetch(params.tma_load_k, tKgK(_, k_index + 1)); + } + } + ++pipeline_kv_producer_state; + k_index += 1; + } + } +}; + +} // namespace cutlass::fmha::collective diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/common/pipeline_mla.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/common/pipeline_mla.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5bbeed9106b9c9f2631fe5c5ab635213295f92b7 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/common/pipeline_mla.hpp @@ -0,0 +1,250 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Support the producer to acquire specific bytes of data. +*/ + +#pragma once + +#include "cutlass/pipeline/sm100_pipeline.hpp" + +namespace cutlass { + +using namespace cute; + +template < + int Stages_, + class ClusterShape = Shape, + class AtomThrShape_MNK_ = Shape<_1,_1,_1> +> +class PipelineTmaAsyncMla { + +public: + static constexpr uint32_t Stages = Stages_; + using AtomThrShape_MNK = AtomThrShape_MNK_; + +private: + using Impl = PipelineTmaUmmaAsync; + +public: + using FullBarrier = typename Impl::FullBarrier; + using EmptyBarrier = typename Impl::EmptyBarrier; + using ProducerBarrierType = typename Impl::ProducerBarrierType; + using ConsumerBarrierType = typename Impl::ConsumerBarrierType; + using PipelineState = typename Impl::PipelineState; + using SharedStorage = typename Impl::SharedStorage; + using ThreadCategory = typename Impl::ThreadCategory; + using Params = typename Impl::Params; + + + using McastDirection = McastDirection; + + // Helper function to initialize barriers + static + CUTLASS_DEVICE + void + init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape) { + int warp_idx = canonical_warp_idx_sync(); + if (warp_idx == params.initializing_warp) { + // Barrier FULL and EMPTY init + constexpr int producer_arv_cnt = 1; + auto atom_thr_shape = AtomThrShape_MNK{}; + uint32_t const multicast_consumer_arrival_count = (cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape)) + + (cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape)) - 1; + + cutlass::arch::detail::initialize_barrier_array_pair_aligned( + storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count); + } + cutlass::arch::fence_barrier_init(); + } + + static + CUTLASS_DEVICE + void + init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape, McastDirection mcast_direction) { + auto atom_thr_shape = AtomThrShape_MNK{}; + + int warp_idx = canonical_warp_idx_sync(); + if (warp_idx == params.initializing_warp) { + // Barrier FULL and EMPTY init + constexpr int producer_arv_cnt = 1; + uint32_t const multicast_consumer_arrival_count = (mcast_direction == McastDirection::kRow) ? + cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape) : // Mcast with row ctas + cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape); // Mcast with col ctas + + cutlass::arch::detail::initialize_barrier_array_pair_aligned( + storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count); + } + cutlass::arch::fence_barrier_init(); + } + + CUTLASS_DEVICE + void init_masks(ClusterShape cluster_shape, dim3 block_id_in_cluster = cute::block_id_in_cluster()) { + // Calculate consumer mask + if (params_.role == ThreadCategory::Consumer) { + auto cluster_layout = make_layout(cluster_shape); + block_id_mask_ = detail::calculate_multicast_mask(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster); + } + } + + CUTLASS_DEVICE + void init_masks(ClusterShape cluster_shape, McastDirection mcast_direction) { + // Calculate consumer mask + dim3 block_id_in_cluster = cute::block_id_in_cluster(); + auto cluster_layout = make_layout(cluster_shape); + if (mcast_direction == McastDirection::kRow) { + block_id_mask_ = detail::calculate_multicast_mask(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster); + } + else { + block_id_mask_ = detail::calculate_multicast_mask(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster); + } + } + + +public: + template + CUTLASS_DEVICE + PipelineTmaAsyncMla(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {}, InitMasks = {}) + : impl_(storage, params, cluster_shape, cute::false_type{}, cute::false_type{}) + , params_(params) + , empty_barrier_ptr_(&storage.empty_barrier_[0]) + , full_barrier_ptr_(&storage.full_barrier_[0]) { + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr (cute::is_same_v) { + init_barriers(storage, params_, cluster_shape); + } + + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr (cute::is_same_v) { + init_masks(cluster_shape); + } + } + + template + CUTLASS_DEVICE + PipelineTmaAsyncMla(SharedStorage& storage, Params params, ClusterShape cluster_shape, McastDirection mcast_direction, InitBarriers = {}, InitMasks = {}) + : impl_(storage, params, cluster_shape, cute::false_type{}, cute::false_type{}) + , params_(params) + , empty_barrier_ptr_(&storage.empty_barrier_[0]) + , full_barrier_ptr_(&storage.full_barrier_[0]) { + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr (cute::is_same_v) { + init_barriers(storage, params_, cluster_shape, mcast_direction); + } + + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr (cute::is_same_v) { + init_masks(cluster_shape, mcast_direction); + } + } + + + CUTLASS_DEVICE + void producer_acquire(PipelineState state, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) { + impl_.producer_acquire(state, barrier_token); + } + + CUTLASS_DEVICE + void producer_acquire_bytes(uint32_t stage, uint32_t bytes, uint32_t phase, ProducerToken barrier_token) { + detail::pipeline_check_is_producer(params_.role); + if (barrier_token != BarrierStatus::WaitDone) { + empty_barrier_ptr_[stage].wait(phase); + } + + if (params_.is_leader) { + full_barrier_ptr_[stage].arrive_and_expect_tx(bytes); + } + #ifndef NDEBUG + if (params_.role == ThreadCategory::Consumer || params_.role == ThreadCategory::NonParticipant) { + asm volatile ("brkpt;\n" ::); + } + + // Most likely you have elected more than one leader + if (params_.is_leader && (threadIdx.x % 32 != 0)) { + asm volatile ("brkpt;\n" ::); + } + #endif + } + + CUTLASS_DEVICE + void producer_acquire_bytes(PipelineState state, uint32_t bytes, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) { + producer_acquire_bytes(state.index(), bytes, state.phase(), barrier_token); + } + + CUTLASS_DEVICE + ProducerBarrierType* producer_get_barrier(PipelineState state) { + return impl_.producer_get_barrier(state); + } + + CUTLASS_DEVICE + void consumer_wait(PipelineState state, ConsumerToken barrier_token = {BarrierStatus::WaitAgain}) { + impl_.consumer_wait(state, barrier_token); + } + + CUTLASS_DEVICE + void consumer_release(PipelineState state) { + consumer_release(state.index(), false); + } + +private: + Impl impl_; + Params params_; + EmptyBarrier *empty_barrier_ptr_; + FullBarrier *full_barrier_ptr_; + uint16_t block_id_mask_ = 0; + static constexpr bool is_2sm_mma = size(AtomThrShape_MNK{}) > 1; + + // Consumer signalling Producer of completion + // Ensures all blocks in the Same Row and Column get notifed. + CUTLASS_DEVICE + void consumer_release(uint32_t stage, uint32_t skip) { + detail::pipeline_check_is_consumer(params_.role); + uint64_t* smem_ptr = reinterpret_cast(&empty_barrier_ptr_[stage]); + if constexpr (is_2sm_mma) { // Mma cluster shape is 2x1 + if (!skip) { + cutlass::arch::umma_arrive_multicast_2x1SM(smem_ptr, block_id_mask_); + } + } + else { + if (!skip) { + if constexpr (cute::is_static_v and size(ClusterShape{}) == 1) { + cutlass::arch::umma_arrive(smem_ptr); + } + else { + cutlass::arch::umma_arrive_multicast(smem_ptr, block_id_mask_); + } + } + } + } +}; + +} diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/common/pow_2.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/common/pow_2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..eca93250f4c0cfdafac485ce5857f10469b3540d --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/common/pow_2.hpp @@ -0,0 +1,92 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include + +#include + +namespace cutlass::fmha { + +struct Pow2 { + int n; + int log2_n; + + explicit CUTE_DEVICE Pow2(int n) : n(n) { +#ifdef __CUDA_ARCH__ + log2_n = __ffs(n) - 1; +#endif + } + + template + CUTE_HOST_DEVICE T operator *(T const& b) const { + return n * b; + } + + template + CUTE_HOST_DEVICE auto operator *(Int const&) const { + if constexpr (N & (N - 1) == 0) { + return Pow2{n * N}; + } + return n * N; + } + +}; + +template +CUTE_HOST_DEVICE auto operator/(T const& a, Pow2 const& b) { + return a >> b.log2_n; +} + +template +CUTE_HOST_DEVICE auto operator%(T const& a, Pow2 const& b) { + return a & (b.n - 1); +} + +template +CUTE_HOST_DEVICE bool operator<(T const& a, Pow2 const& b) { + return a < b.n; +} + +CUTE_HOST_DEVICE void print(Pow2 const& a) { + printf("2^%d", a.log2_n); +} + +} // end namespace cutlass::fmha + +namespace cute { + +template <> +struct is_integral : true_type {}; + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/device/fmha.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/device/fmha.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f8406d3eb14de7e51aa069332fece686fbf4fa80 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/device/fmha.hpp @@ -0,0 +1,276 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief An universal device layer for cutlass 3.x-style kernels. +*/ + +#pragma once + +// common +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" + +#if !defined(__CUDACC_RTC__) +#include "cutlass/cluster_launch.hpp" +#include "cutlass/trace.h" +#endif // !defined(__CUDACC_RTC__) + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::fmha::device { + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 3.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +template +class FMHA { +public: + using Kernel = Kernel_; + + static int const kThreadCount = Kernel::MaxThreadsPerBlock; + + /// Argument structure: User API + using Arguments = typename Kernel::Arguments; + /// Argument structure: Kernel API + using Params = typename Kernel::Params; + +private: + + /// Kernel API parameters object + Params params_; + + bool is_initialized(bool set = false) { + static bool initialized = false; + if (set) initialized = true; + return initialized; + } + +public: + + /// Access the Params structure + Params const& params() const { + return params_; + } + + /// Determines whether the GEMM can execute the given problem. + static Status + can_implement(Arguments const& args) { + if (Kernel::can_implement(args)) { + return Status::kSuccess; + } + else { + return Status::kInvalid; + } + } + + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_bytes = 0; + workspace_bytes += Kernel::get_workspace_size(args); + return workspace_bytes; + } + + /// Computes the grid shape + static dim3 + get_grid_shape(Params const& params) { + return Kernel::get_grid_shape(params); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int /* smem_capacity */ = -1) { + CUTLASS_TRACE_HOST("FMHA::maximum_active_blocks()"); + int max_active_blocks = -1; + int smem_size = Kernel::SharedStorageSize; + + // first, account for dynamic smem capacity if needed + cudaError_t result; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return -1; + } + } + + // query occupancy after setting smem size + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + device_kernel, + Kernel::MaxThreadsPerBlock, + smem_size); + + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " + << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Initializes GEMM state from arguments. + Status + initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("FMHA::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + // Initialize the workspace + Status status = Kernel::initialize_workspace(args, workspace, stream); + if (status != Status::kSuccess) { + return status; + } + + // Initialize the Params structure + params_ = Kernel::to_underlying_arguments(args, workspace); + + if (is_initialized()) return Status::kSuccess; + + // account for dynamic smem capacity if needed + int smem_size = Kernel::SharedStorageSize; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + cudaError_t result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + is_initialized(true); + + return Status::kSuccess; + } + + /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params. + Status + update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("FMHA()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + if (workspace_bytes > 0 && nullptr == workspace) { + return Status::kErrorWorkspaceNull; + } + + params_ = Kernel::to_underlying_arguments(args, workspace); + return Status::kSuccess; + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. + /// Supplied params struct must be construct by calling Kernel::to_underling_arguments() + static Status + run(Params& params, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("FMHA::run()"); + dim3 const block = Kernel::get_block_shape(); + dim3 const grid = get_grid_shape(params); + + // configure smem size and carveout + int smem_size = Kernel::SharedStorageSize; + + Status launch_result; + // Use extended launch API only for mainloops that use it + if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) { + dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}), + cute::size<1>(typename Kernel::ClusterShape{}), + cute::size<2>(typename Kernel::ClusterShape{})); + void const* kernel = (void const*) device_kernel; + void* kernel_params[] = {¶ms}; + launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params); + } + else { + launch_result = Status::kSuccess; + device_kernel<<>>(params); + } + + cudaError_t result = cudaGetLastError(); + if (cudaSuccess == result && Status::kSuccess == launch_result) { + return Status::kSuccess; + } + else { + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + if (Status::kSuccess == status) { + status = run(params_, stream); + } + return status; + } + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + return run(args, workspace, stream); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + run(cudaStream_t stream = nullptr) { + return run(params_, stream); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + operator()(cudaStream_t stream = nullptr) { + return run(params_, stream); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/device/fmha_device_bwd.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/device/fmha_device_bwd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9e4efb34ec54bfb7183b6f6f1fdd80ea21ac7e28 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/device/fmha_device_bwd.hpp @@ -0,0 +1,375 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +// common +#include "cutlass/cutlass.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cute/tensor.hpp" + +#include "../device/fmha.hpp" +#include "../kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp" +#include "../kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp" +#include "../kernel/fmha_kernel_bwd_sum_OdO.hpp" +#include "../kernel/fmha_kernel_bwd_convert.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::fmha::device { + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 3.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class Element, + class ElementAccumulator, + class TileShape, + bool IsMla, + class Mask +> +class Sm100FmhaBwd { +private: + template + constexpr static auto to_bwd_shape(T shape) { + if constexpr (IsMla) { // remove GQA mode + constexpr int R = decltype(rank(shape))::value; + auto HB = get(shape); + auto rest = take<0,R-1>(shape); + return append(rest, make_shape(size<0>(HB), get<1>(HB))); + } + else { + return shape; + } + } + + template + constexpr static auto to_bwd_stride(T stride) { + if constexpr (IsMla) { // remove GQA mode + constexpr int R = decltype(rank(stride))::value; + auto HB = get(stride); + auto rest = take<0,R-1>(stride); + if constexpr (is_same_v(HB))>, _0>) { + return append(rest, make_stride(get<0,1>(HB), get<1>(HB))); + } + else { + return append(rest, make_stride(get<0,0>(HB), get<1>(HB))); + } + } + else { + return stride; + } + } + +public: + /// Argument structure: User API + struct Arguments { + // Q K D D_VO HB + ProblemShape problem_shape; + + const Element* ptr_Q; + cute::tuple, int>> stride_Q; + const Element* ptr_K; + cute::tuple, int>> stride_K; + const Element* ptr_V; + cute::tuple, int>> stride_V; + + const Element* ptr_O; + cute::tuple, int>> stride_O; + const ElementAccumulator* ptr_LSE; + cute::tuple, int>> stride_LSE; + + const Element* ptr_dO; + cute::tuple, int>> stride_dO; + + Element* ptr_dQ; + cute::tuple, int>> stride_dQ; + Element* ptr_dK; + cute::tuple, int>> stride_dK; + Element* ptr_dV; + cute::tuple, int>> stride_dV; + + ElementAccumulator softmax_scale; + + cutlass::KernelHardwareInfo hw_info; + }; + + using OperationSumOdO = cutlass::fmha::device::FMHA< + cutlass::fmha::kernel::FmhaKernelBwdSumOdO + >; + using OperationConvert = cutlass::fmha::device::FMHA< + cutlass::fmha::kernel::FmhaKernelBwdConvert + >; + + using OperationNormal= cutlass::fmha::device::FMHA< + cutlass::fmha::kernel::Sm100FmhaBwdKernelTmaWarpSpecialized< + ProblemShape, Element, ElementAccumulator, TileShape, Mask + > + >; + + using ProblemShapeMLA = decltype(to_bwd_shape(ProblemShape{})); + using OperationMla = cutlass::fmha::device::FMHA< + cutlass::fmha::kernel::Sm100FmhaBwdMlaKernelTmaWarpSpecialized< + ProblemShapeMLA, Element, ElementAccumulator, TileShape, Mask + > + >; + + using Operation = std::conditional_t; + + using Kernel = typename Operation::Kernel; + + struct Params { + OperationSumOdO op_sum_OdO; + Operation op; + OperationConvert op_convert; + ElementAccumulator* dQ_acc; + size_t dQ_acc_size; + }; + +private: + Params params_; + + static typename OperationSumOdO::Arguments to_sum_OdO_arguments( + Arguments const& args, + ElementAccumulator* sum_odo = nullptr, + ElementAccumulator* scaled_lse = nullptr) { + using namespace cute; + auto [Q_, K, D, D_VO, HB] = args.problem_shape; + auto [H, B] = HB; + auto [H_R, H_K] = H; + D = cutlass::round_up(D, 8); // Alignment + int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment + auto stride_sum_OdO = make_stride(_1{}, make_stride(make_stride(Q, Q*H_R), B == 1 ? 0 : Q*H_R*H_K)); + auto stride_scaled_lse = make_stride(_1{}, make_stride(make_stride(Q, Q*H_R), B == 1 ? 0 : Q*H_R*H_K)); + auto log2_e = log2f(expf(1.0f)); + return typename OperationSumOdO::Arguments { + args.problem_shape, + args.ptr_O, args.stride_O, + args.ptr_dO, args.stride_dO, + sum_odo, stride_sum_OdO, + args.ptr_LSE, args.stride_LSE, + scaled_lse, stride_scaled_lse, + -1.0f, -log2_e + }; + } + + static typename OperationConvert::Arguments to_convert_arguments(Arguments const& args, ElementAccumulator* src = nullptr) { + using namespace cute; + auto [Q_, K, D, D_VO, HB] = args.problem_shape; + auto [H, B] = HB; + auto [H_R, H_K] = H; + D = cutlass::round_up(D, 8); // Alignment + int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment + auto stride_src_dQ = make_stride(D, _1{}, make_stride(make_stride(D*Q, D*Q*H_R), B == 1 ? 0 : D*Q*H_R*H_K)); + return typename OperationConvert::Arguments { + args.problem_shape, + src, stride_src_dQ, + nullptr, args.stride_dK, + nullptr, args.stride_dV, + args.ptr_dQ, args.stride_dQ, + nullptr, args.stride_dK, + nullptr, args.stride_dV, + args.softmax_scale + }; + } + + static typename Operation::Arguments to_bwd_arguments( + Arguments const& args, + ElementAccumulator* sum_OdO = nullptr, cute::tuple, int>> const& stride_sum_OdO = {}, + ElementAccumulator* scaled_lse = nullptr, cute::tuple, int>> const& stride_scaled_lse = {}, + ElementAccumulator* dQ_acc = nullptr, cute::tuple, int>> const& stride_dQ = {}) { + + return typename Operation::Arguments{ + to_bwd_shape(args.problem_shape), + { args.ptr_Q, to_bwd_stride(args.stride_Q), + args.ptr_K, to_bwd_stride(args.stride_K), + args.ptr_V, to_bwd_stride(args.stride_V), + args.ptr_dO, to_bwd_stride(args.stride_dO), + scaled_lse, to_bwd_stride(stride_scaled_lse), + sum_OdO, to_bwd_stride(stride_sum_OdO), + dQ_acc, to_bwd_stride(stride_dQ), + args.softmax_scale }, + { args.ptr_dK, to_bwd_stride(args.stride_dK), + args.ptr_dV, to_bwd_stride(args.stride_dV) }, + args.hw_info + }; + } + +public: + + /// Determines whether the GEMM can execute the given problem. + static Status + can_implement(Arguments const& args) { + Status status = Status::kSuccess; + + status = OperationSumOdO::can_implement(to_sum_OdO_arguments(args)); + if (status != Status::kSuccess) { + return status; + } + + status = OperationConvert::can_implement(to_convert_arguments(args)); + if (status != Status::kSuccess) { + return status; + } + + status = Operation::can_implement(to_bwd_arguments(args)); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + auto [Q_, K, D, D_VO, HB] = args.problem_shape; + auto [H, B] = product_each(HB); + D = cutlass::round_up(D, 8); // Alignment + int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment + size_t workspace_bytes = 0; + // OdO vector + workspace_bytes += B*H*Q * sizeof(ElementAccumulator); + // scaled LSE vector + workspace_bytes += B*H*Q * sizeof(ElementAccumulator); + // FP32 versions of outputs that are churned (start off with Q only) + workspace_bytes += B*H*Q*D * sizeof(ElementAccumulator); + return workspace_bytes; + } + + /// Initializes state from arguments. + Status + initialize_split(Arguments const& args, void* workspace_dQ, void* workspace_sum_OdO, void* workspace_scaled_lse, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("Universal::initialize_split() - workspace_dQ=" + << workspace_dQ << ", workspace_sum_OdO=" << workspace_sum_OdO << "stream: " << (stream ? "non-null" : "null")); + + auto [Q_, K, D, D_VO, HB] = args.problem_shape; + auto [H, B] = product_each(HB); + D = cutlass::round_up(D, 8); // Alignment + int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment + ElementAccumulator* sum_OdO = reinterpret_cast(workspace_sum_OdO); + ElementAccumulator* scaled_lse = reinterpret_cast(workspace_scaled_lse); + ElementAccumulator* dQ_acc = reinterpret_cast(workspace_dQ); + params_.dQ_acc = dQ_acc; + params_.dQ_acc_size = B*H*Q*D * sizeof(ElementAccumulator); + auto args_sum_OdO = to_sum_OdO_arguments(args, sum_OdO, scaled_lse); + auto args_convert = to_convert_arguments(args, dQ_acc); + params_.op_sum_OdO.initialize(args_sum_OdO, nullptr, stream); + params_.op_convert.initialize(args_convert, nullptr, stream); + auto args_bwd = to_bwd_arguments( + args, sum_OdO, args_sum_OdO.stride_sum_OdO, + scaled_lse, args_sum_OdO.stride_scaled_lse, + dQ_acc, args_convert.stride_src_dQ + ); + params_.op.initialize(args_bwd, nullptr, stream); + + return Status::kSuccess; + } + + /// Initializes state from arguments. + Status + initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("Universal::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + auto [Q_, K, D, D_VO, HB] = args.problem_shape; + auto [H, B] = product_each(HB); + D = cutlass::round_up(D, 8); // Alignment + int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment + char* workspace_chr = reinterpret_cast(workspace); + ElementAccumulator* sum_OdO = reinterpret_cast(workspace_chr); + workspace_chr += B*H*Q * sizeof(ElementAccumulator); + ElementAccumulator* scaled_lse = reinterpret_cast(workspace_chr); + workspace_chr += B*H*Q * sizeof(ElementAccumulator); + ElementAccumulator* dQ_acc = reinterpret_cast(workspace_chr); + return initialize_split(args, dQ_acc, sum_OdO, scaled_lse, stream); + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. + /// Supplied params struct must be construct by calling Kernel::to_underling_arguments() + static Status + run(Params& params, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("FmhaDeviceBwd::run()"); + + Status result = Status::kSuccess; + result = params.op_sum_OdO.run(stream); + if (result != Status::kSuccess) { + return result; + } + + auto cuda_result = cudaMemsetAsync(params.dQ_acc, 0, params.dQ_acc_size, stream); + if (cuda_result != cudaSuccess) { + return Status::kErrorInternal; + } + + result = params.op.run(stream); + if (result != Status::kSuccess) { + return result; + } + + result = params.op_convert.run(stream); + if (result != Status::kSuccess) { + return result; + } + + return Status::kSuccess; + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + if (Status::kSuccess == status) { + status = run(params_, stream); + } + return status; + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + run(cudaStream_t stream = nullptr) { + return run(params_, stream); + } + +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/device/sm100_mla.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/device/sm100_mla.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4706d9fee5ab523e5c4f62cefd8ca0370178f253 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/device/sm100_mla.hpp @@ -0,0 +1,383 @@ +/*************************************************************************************************** + * Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief An universal device layer for cutlass 3.x-style kernels. +*/ + +#pragma once + +// common +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" + +#if !defined(__CUDACC_RTC__) +#include "cutlass/cluster_launch.hpp" +#include "cutlass/trace.h" +#endif // !defined(__CUDACC_RTC__) + +#include "kernel/sm100_fmha_mla_tma_warpspecialized.hpp" +#include "kernel/sm100_fmha_mla_reduction.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::fmha::device { + +using namespace cute; +using namespace cutlass::fmha::kernel; + + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 3.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +template< + class Kernel_ +> +class MLA { +public: + + using Kernel = Kernel_; + + using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel< + typename Kernel::ElementOut, + typename Kernel::ElementAcc, + typename Kernel::ElementAcc, + Kernel::TileShapeH::value, + Kernel::TileShapeL::value, + 256 /*Max split*/ + >; + + /// Argument structure: User API + using KernelArguments = typename Kernel::Arguments; + using ReductionArguments = typename ReductionKernel::Arguments; + + using Arguments = KernelArguments; + + /// Argument structure: Kernel API + using KernelParams = typename Kernel::Params; + using ReductionParams = typename ReductionKernel::Params; + struct Params { + KernelParams fmha_params; + ReductionParams reduction_params; + }; + +private: + + /// Kernel API parameters object + Params params_; + + bool is_initialized(bool set = false) { + static bool initialized = false; + if (set) initialized = true; + return initialized; + } + + static ReductionArguments to_reduction_args(Arguments const& args) { + auto [H, K, D, B] = args.problem_shape; + return ReductionArguments{ + nullptr, args.epilogue.ptr_o, nullptr, args.epilogue.ptr_lse, + args.mainloop.softmax_scale, B, args.split_kv, K, args.mainloop.ptr_seq, + args.ptr_split_kv, Kernel::TileShapeS::value + }; + } + +public: + + /// Access the Params structure + Params const& params() const { + return params_; + } + + static void set_split_kv (KernelArguments& args) { + if (args.split_kv >= 1) return; + auto [H, K, D, B] = args.problem_shape; + int sm_count = args.hw_info.sm_count; + int max_splits = ceil_div(K, 128); + int sms_per_batch = max(1, sm_count / B); + int split_heur = min(max_splits, sms_per_batch); + int waves = ceil_div(B * split_heur, sm_count); + int k_waves = ceil_div(max_splits, split_heur); + int split_wave_aware = ceil_div(max_splits, k_waves); + if (args.is_fused_reduction && split_wave_aware > 1) { + args.split_kv = std::min(split_wave_aware, static_cast(sm_count/2)); + } else { + args.split_kv = split_wave_aware; + } + } + + /// Determines whether the GEMM can execute the given problem. + static Status + can_implement(Arguments const& args) { + if (! Kernel::can_implement(args)) { + return Status::kInvalid; + } + if (! ReductionKernel::can_implement(to_reduction_args(args))) { + return Status::kInvalid; + } + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_bytes = 0; + workspace_bytes += Kernel::get_workspace_size(args); + workspace_bytes += ReductionKernel::get_workspace_size(to_reduction_args(args)); + return workspace_bytes; + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int /* smem_capacity */ = -1) { + CUTLASS_TRACE_HOST("MLA::maximum_active_blocks()"); + int max_active_blocks = -1; + int smem_size = Kernel::SharedStorageSize; + + // first, account for dynamic smem capacity if needed + cudaError_t result; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return -1; + } + } + + // query occupancy after setting smem size + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + device_kernel, + Kernel::MaxThreadsPerBlock, + smem_size); + + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " + << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Initializes GEMM state from arguments. + Status + initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("MLA::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + // Initialize the workspace + Status status = Kernel::initialize_workspace(args, workspace, stream); + if (status != Status::kSuccess) { + return status; + } + status = ReductionKernel::initialize_workspace(to_reduction_args(args), workspace, stream); + if (status != Status::kSuccess) { + return status; + } + KernelParams kernel_params = Kernel::to_underlying_arguments(args, workspace); + + ReductionArguments reduction_args = to_reduction_args(args); + if (reduction_args.split_kv > 1) { + reduction_args.ptr_oaccum = kernel_params.epilogue.ptr_o_acc; + reduction_args.ptr_lseaccum = kernel_params.epilogue.ptr_lse_acc; + } + ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace); + // Initialize the Params structure + params_ = Params {kernel_params, reduction_params}; + + if (is_initialized()) return Status::kSuccess; + + // account for dynamic smem capacity if needed + // no dynamic smem is needed for reduction kernel + int smem_size = Kernel::SharedStorageSize; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + cudaError_t result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + is_initialized(true); + + return Status::kSuccess; + } + + /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params. + Status + update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("MLA()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + if (workspace_bytes > 0 && nullptr == workspace) { + return Status::kErrorWorkspaceNull; + } + + auto fmha_params = Kernel::to_underlying_arguments(args, workspace); + + ReductionArguments reduction_args = to_reduction_args(args); + if (reduction_args.split_kv > 1) { + reduction_args.ptr_oaccum = fmha_params.epilogue.ptr_o_acc; + reduction_args.ptr_lseaccum = fmha_params.epilogue.ptr_lse_acc; + } + ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace); + // Initialize the Params structure + params_ = Params {fmha_params, reduction_params}; + + return Status::kSuccess; + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. + /// Supplied params struct must be construct by calling Kernel::to_underling_arguments() + static Status + run(Params& params, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("MLA::run()"); + dim3 const block = Kernel::get_block_shape(); + dim3 const grid = Kernel::get_grid_shape(params.fmha_params); + auto [H, K, D, B] = params.fmha_params.problem_shape; + auto [D_latent, D_rope] = D; + + // configure smem size and carveout + int smem_size = Kernel::SharedStorageSize; + + Status launch_result; + if (params.fmha_params.is_fused_reduction && params.reduction_params.split_kv > 1) { + auto result = cudaMemsetAsync(params.fmha_params.epilogue.ptr_o, 0, sizeof(typename Kernel::ElementOut) * H * D_latent * B, stream); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaMemsetAsync() returned error: " + << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + auto total_bytes = H * B * (sizeof(int) + sizeof(typename Kernel::ElementLSE)) + 2 * B * sizeof(int); + uint8_t* ws = reinterpret_cast(params.fmha_params.epilogue.ptr_lse_exchange_buff); + result = cudaMemsetAsync(ws, 0, total_bytes, stream); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaMemsetAsync() returned error: " + << cudaGetErrorString(result)); + return Status::kErrorInternal;; + } + } + // Use extended launch API only for mainloops that use it + if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) { + dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}), + cute::size<1>(typename Kernel::ClusterShape{}), + cute::size<2>(typename Kernel::ClusterShape{})); + void const* kernel = (void const*) device_kernel; + void* kernel_params[] = {¶ms.fmha_params}; + launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params); + } + else { + launch_result = Status::kSuccess; + device_kernel<<>>(params.fmha_params); + } + + cudaError_t result = cudaGetLastError(); + if (cudaSuccess != result or Status::kSuccess != launch_result) { + //return Status::kSuccess; + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } + if (!params.fmha_params.is_fused_reduction && params.reduction_params.split_kv > 1) { + // launch reduction kernel + dim3 const block = ReductionKernel::get_block_shape(); + dim3 const grid = ReductionKernel::get_grid_shape(params.reduction_params); + device_kernel<<>>(params.reduction_params); + cudaError_t result = cudaGetLastError(); + if (cudaSuccess == result) { + return Status::kSuccess; + } + else { + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } + } + else { + return Status::kSuccess; + } + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + if (Status::kSuccess == status) { + status = run(params_, stream); + } + return status; + } + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + return run(args, workspace, stream); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + run(cudaStream_t stream = nullptr) { + return run(params_, stream); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + operator()(cudaStream_t stream = nullptr) { + return run(params_, stream); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/fmha_causal_tile_scheduler.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/fmha_causal_tile_scheduler.hpp new file mode 100644 index 0000000000000000000000000000000000000000..572e67f6967c1e0c0d675cbadf120b8cf6114707 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/fmha_causal_tile_scheduler.hpp @@ -0,0 +1,197 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" + +namespace cutlass::fmha::kernel { + +//////////////////////////////////////////////////////////////////////////////// + +// Swizzle Q tile and H tile to improve L2 cache hit rate, +// and launch the longest main loop first to keep most SMs busy. + +struct CausalIndividualTileScheduler { + + static constexpr int TileQ = 16; + static constexpr int TileH = 8; + static constexpr int TileSize = TileQ * TileH; + + struct Params { + dim3 grid; + int tile_max_q; + FastDivmod divmod_tile_col; + FastDivmod divmod_tile_size; + FastDivmod divmod_tile_head; + }; + + bool valid_ = true; + Params params; + + CUTLASS_DEVICE + CausalIndividualTileScheduler(Params const& params) : params(params) {} + + template + static Params to_underlying_arguments( + ProblemSize const& problem_size, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, TileShape const& tile_shape) { + using namespace cute; + + dim3 grid(size<3,0>(problem_size), round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)), size<3,1>(problem_size)); + // gridDim.x must multiple of TileH + const int tile_col_count = grid.x / TileH; + const int tile_max_q = grid.y / TileQ * TileQ; + return Params{ grid , tile_max_q, tile_col_count, TileSize, TileH}; + } + + static dim3 get_grid_shape(Params const& params) { + return params.grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return valid_; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + const int block_idx = blockIdx.y * gridDim.x + blockIdx.x; + + int tile_idx, tile_tail; + params.divmod_tile_size(tile_idx, tile_tail, block_idx); + + int tile_row_idx, tile_col_idx; + params.divmod_tile_col(tile_row_idx,tile_col_idx, tile_idx); + + int row_offset_in_tail, col_offset_in_tail; + params.divmod_tile_head(row_offset_in_tail,col_offset_in_tail, tile_tail); + + const int row_idx = tile_row_idx * TileQ + row_offset_in_tail; + const int col_idx = tile_col_idx * TileH + col_offset_in_tail; + + // last q tile launch first + if(blockIdx.y >= params.tile_max_q) { + return make_coord(int(gridDim.y - 1 - blockIdx.y), _0{}, make_coord(int(blockIdx.x), int(blockIdx.z))); + } + + return make_coord(int(gridDim.y) - 1 - row_idx, _0{}, make_coord(col_idx, int(blockIdx.z))); + } + + CUTLASS_DEVICE + CausalIndividualTileScheduler& operator++() { + valid_ = false; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + + +//////////////////////////////////////////////////////////////////////////////// + +// Launch order: H Q B +struct CausalPersistentTileScheduler { + + struct Params { + int num_blocks; + FastDivmod divmod_h; + FastDivmod divmod_m_block; + FastDivmod divmod_b; + + KernelHardwareInfo hw_info; + }; + + int block_idx = 0; + Params params; + + CUTLASS_DEVICE + CausalPersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {} + + template + static Params to_underlying_arguments( + ProblemSize const& problem_size, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, TileShape const& tile_shape) { + using namespace cute; + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + hw_info.sm_count = sm_count; + + int num_m_blocks = cutlass::round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)); + int num_blocks = num_m_blocks * size<3,0>(problem_size) * size<3,1>(problem_size); + + return Params { + num_blocks, + { size<3,0>(problem_size) }, { num_m_blocks}, { size<3,1>(problem_size) }, + hw_info + }; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1); + return grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return block_idx < params.num_blocks; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + int block_decode = block_idx; + int m_block, bidb, bidh; + params.divmod_h(block_decode, bidh, block_decode); + params.divmod_m_block(block_decode, m_block, block_decode); + params.divmod_b(block_decode, bidb, block_decode); + return make_coord(m_block, _0{}, make_coord(bidh, bidb)); + } + + CUTLASS_DEVICE + CausalPersistentTileScheduler& operator++() { + block_idx += gridDim.x; + return *this; + } +}; +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::kernel diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_convert.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_convert.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7bee3d5f77c6518a13201a310a6adffeebc25432 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_convert.hpp @@ -0,0 +1,153 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; + +template +struct FmhaKernelBwdConvert { + + struct Arguments { + ProblemShape problem_shape; + + const ElementAcc* ptr_src_dQ; + tuple, int>> stride_src_dQ; + const ElementAcc* ptr_src_dK; + tuple, int>> stride_src_dK; + const ElementAcc* ptr_src_dV; + tuple, int>> stride_src_dV; + + Element* ptr_dest_dQ; + tuple, int>> stride_dest_dQ; + Element* ptr_dest_dK; + tuple, int>> stride_dest_dK; + Element* ptr_dest_dV; + tuple, int>> stride_dest_dV; + + ElementAcc scale = 1.0; + }; + + using Params = Arguments; + + using ClusterShape = Shape<_1, _1, _1>; + static constexpr int SharedStorageSize = 0; + + static const int MinBlocksPerMultiprocessor = 1; + static const int MaxThreadsPerBlock = 128; + using ArchTag = cutlass::arch::Sm90; + + static const int kBlockSeq = 8; + + static size_t get_workspace_size(Arguments const& args) { return 0; } + static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return cutlass::Status::kSuccess; + } + + static const int kNumThreadsD = 16; + static const int kNumThreadsSeq = MaxThreadsPerBlock / kNumThreadsD; + static const int kElementsPerLoad = 4; + + static const int kIterationsSeq = kBlockSeq / kNumThreadsSeq; + + static bool can_implement(Arguments const& args) { + return get<2>(args.problem_shape) % kElementsPerLoad == 0 && get<3>(args.problem_shape) % kElementsPerLoad == 0; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(size<4,0>(params.problem_shape), size<4,1>(params.problem_shape), ceil_div(std::max(size<0>(params.problem_shape), size<1>(params.problem_shape)), kBlockSeq)); + return grid; + } + + static dim3 get_block_shape() { + dim3 block(kNumThreadsD, kNumThreadsSeq, 1); + return block; + } + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return args; + } + + template + CUTLASS_DEVICE void copy(Params const& params, const ElementAcc* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, Count const& count, int d_dim) { + auto ptr_src_bh = ptr_src + get<2,0,0>(stride_src) * blockIdx.x + get<2,1>(stride_src) * blockIdx.y; + auto ptr_dest_bh = ptr_dest + get<2,0,0>(stride_dest) * blockIdx.x + get<2,1>(stride_dest) * blockIdx.y; + + int seqlen = count; + if constexpr (is_variable_length_v) { + int offset = count.cumulative_length[blockIdx.y]; + ptr_dest_bh += offset * get<0>(stride_dest); + seqlen = count.cumulative_length[blockIdx.y + 1] - offset; + } + + for (int idx_s_t = threadIdx.y; idx_s_t < kBlockSeq; idx_s_t += kNumThreadsSeq) { + int idx_s = idx_s_t + kBlockSeq * blockIdx.z; + if (idx_s >= seqlen) continue; + auto ptr_src_bhs = ptr_src_bh + idx_s * get<0>(stride_src); + auto ptr_dest_bhs = ptr_dest_bh + idx_s * get<0>(stride_dest); + + for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < d_dim; idx_d += kElementsPerLoad * kNumThreadsD) { + ElementAcc value_src[kElementsPerLoad]; + Element value_dest[kElementsPerLoad]; + + using VecSrc = uint_bit_t * kElementsPerLoad>; + using VecDest = uint_bit_t * kElementsPerLoad>; + *reinterpret_cast(value_src) = *reinterpret_cast(&ptr_src_bhs[idx_d]); + + for (int v = 0; v < kElementsPerLoad; v++) { + value_dest[v] = static_cast(params.scale * value_src[v]); + } + + *reinterpret_cast(&ptr_dest_bhs[idx_d]) = *reinterpret_cast(value_dest); + } + } + } + + CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { + if (params.ptr_src_dQ != nullptr) { + copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_shape), get<2>(params.problem_shape)); + } + if (params.ptr_src_dK != nullptr) { + copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<1>(params.problem_shape), get<2>(params.problem_shape)); + } + if (params.ptr_src_dV != nullptr) { + copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<1>(params.problem_shape), get<3>(params.problem_shape)); + } + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp new file mode 100644 index 0000000000000000000000000000000000000000..66780f353da60bd64347f06f30526873a72a21f2 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp @@ -0,0 +1,161 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; + +template +struct FmhaKernelBwdSumOdO { + + struct Arguments { + ProblemShape problem_shape; + + const Element* ptr_O; + cute::tuple, int>> stride_O; + const Element* ptr_dO; + cute::tuple, int>> stride_dO; + + ElementAcc* ptr_sum_OdO; + cute::tuple, int>> stride_sum_OdO; + + const ElementAcc* ptr_lse = nullptr; + cute::tuple, int>> stride_lse; + + ElementAcc* ptr_scaled_lse = nullptr; + cute::tuple, int>> stride_scaled_lse; + + ElementAcc sum_odo_scale = 1.0; + ElementAcc lse_scale = 1.0; + }; + + using Params = Arguments; + + using ClusterShape = Shape<_1, _1, _1>; + static constexpr int SharedStorageSize = 0; + + static const int MinBlocksPerMultiprocessor = 1; + static const int MaxThreadsPerBlock = 128; + using ArchTag = cutlass::arch::Sm100; + + static size_t get_workspace_size(Arguments const& args) { return 0; } + static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return cutlass::Status::kSuccess; + } + + static const int kBlockQ = 16; + + static const int kNumThreadsD = 8; + static const int kNumThreadsQ = MaxThreadsPerBlock / kNumThreadsD; + static const int kElementsPerLoad = 2; + + static const int kIterationsQ = kBlockQ / kNumThreadsQ; + + static bool can_implement(Arguments const& args) { + return get<2>(args.problem_shape) % kElementsPerLoad == 0 && get<3>(args.problem_shape) % kElementsPerLoad == 0; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(ceil_div(size<0>(params.problem_shape), kBlockQ), size<4,0>(params.problem_shape), size<4,1>(params.problem_shape)); + return grid; + } + + static dim3 get_block_shape() { + dim3 block(kNumThreadsD, kNumThreadsQ, 1); + return block; + } + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return args; + } + + CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { + auto ptr_O_bh = params.ptr_O + blockIdx.y * get<2,0,0>(params.stride_O) + blockIdx.z * get<2,1>(params.stride_O); + auto ptr_dO_bh = params.ptr_dO + blockIdx.y * get<2,0,0>(params.stride_dO) + blockIdx.z * get<2,1>(params.stride_dO); + auto ptr_sum_OdO_bh = params.ptr_sum_OdO + blockIdx.y * get<1,0,0>(params.stride_sum_OdO) + blockIdx.z * get<1,1>(params.stride_sum_OdO); + auto ptr_lse_bh = params.ptr_lse + blockIdx.y * get<1,0,0>(params.stride_lse) + blockIdx.z * get<1,1>(params.stride_lse); + auto ptr_scaled_lse_bh = params.ptr_scaled_lse + blockIdx.y * get<1,0,0>(params.stride_scaled_lse) + blockIdx.z * get<1,1>(params.stride_scaled_lse); + + auto problem_q = get<0>(params.problem_shape); + int seqlen_q = problem_q; + if constexpr (is_variable_length_v) { + int offset = problem_q.cumulative_length[blockIdx.z]; + ptr_O_bh += offset * get<0>(params.stride_O); + ptr_dO_bh += offset * get<0>(params.stride_dO); + ptr_lse_bh += offset * get<0>(params.stride_lse); + seqlen_q = problem_q.cumulative_length[blockIdx.z + 1] - offset; + } + + CUTLASS_PRAGMA_UNROLL + for (int idx_q_t = threadIdx.y; idx_q_t < kBlockQ; idx_q_t += kNumThreadsQ) { + int idx_q = idx_q_t + kBlockQ * blockIdx.x; + if (idx_q >= seqlen_q) continue; + ElementAcc acc = 0; + auto ptr_O_bhq = ptr_O_bh + idx_q * get<0>(params.stride_O); + auto ptr_dO_bhq = ptr_dO_bh + idx_q * get<0>(params.stride_dO); + auto ptr_sum_OdO_bhq = ptr_sum_OdO_bh + idx_q * get<0>(params.stride_sum_OdO); + auto ptr_lse_bhq = ptr_lse_bh + idx_q * get<0>(params.stride_lse); + auto ptr_scaled_lse_bhq = ptr_scaled_lse_bh + idx_q * get<0>(params.stride_scaled_lse); + + for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<3>(params.problem_shape); idx_d += kElementsPerLoad * kNumThreadsD) { + Element value_O[kElementsPerLoad]; + Element value_dO[kElementsPerLoad]; + + using Vec = uint_bit_t * kElementsPerLoad>; + *reinterpret_cast(value_O) = *reinterpret_cast(&ptr_O_bhq[idx_d]); + *reinterpret_cast(value_dO) = *reinterpret_cast(&ptr_dO_bhq[idx_d]); + + for (int v = 0; v < kElementsPerLoad; v++) { + acc += value_O[v] * value_dO[v]; + } + } + + for (int i = 1; i < kNumThreadsD; i *= 2) { + acc += __shfl_xor_sync((uint32_t)-1, acc, i, kNumThreadsD); + } + + if (threadIdx.x == 0) { + *ptr_sum_OdO_bhq = params.sum_odo_scale * acc; + if (params.ptr_scaled_lse) { + *ptr_scaled_lse_bhq = params.lse_scale * *ptr_lse_bhq; + } + } + } + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/fmha_options.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/fmha_options.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d4faa8d2151497fe1151ba879111c43cd1e4ccc7 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/fmha_options.hpp @@ -0,0 +1,85 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + + +#include "cutlass/cutlass.h" + +namespace cutlass::fmha::kernel { + +template +struct find_option; + +template +struct find_option { + using option_value = Default; +}; + +template +struct find_option : + std::conditional_t< + Option::tag == kTag, + Option, + find_option + > +{}; + +template +using find_option_t = typename find_option::option_value; + +enum class Tag { + kIsPersistent, + kNumMmaWarpGroups, + kLoadsQSeparately, + + kIsMainloopLocked, + kIsEpilogueLocked, + + kStagesQ, + kStagesKV, + + kEpilogueKind, + + kBlocksPerSM, + kClusterM, + + kAccQK +}; + +template +struct Option { + static constexpr auto tag = kTag; + using option_value = Value; +}; + +} // namespace cutlass::fmha::kernel diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/fmha_tile_scheduler.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/fmha_tile_scheduler.hpp new file mode 100644 index 0000000000000000000000000000000000000000..119f069ccee5eb757edbf368e1cc2e66b3e5253c --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/fmha_tile_scheduler.hpp @@ -0,0 +1,162 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.h" + +namespace cutlass::fmha::kernel { + +//////////////////////////////////////////////////////////////////////////////// + +struct IndividualTileScheduler { + + struct Params { + dim3 grid; + }; + + bool valid_ = true; + + CUTLASS_DEVICE + IndividualTileScheduler(Params const&) {} + + template + static Params to_underlying_arguments( + ProblemSize const& problem_size, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, TileShape const& tile_shape) { + using namespace cute; + dim3 grid(round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)), size<3,0>(problem_size), size<3,1>(problem_size)); + return Params{ grid }; + } + + static dim3 get_grid_shape(Params const& params) { + return params.grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return valid_; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + return make_coord(blockIdx.x, _0{}, make_coord(blockIdx.y, blockIdx.z)); + } + + CUTLASS_DEVICE + IndividualTileScheduler& operator++() { + valid_ = false; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct PersistentTileScheduler { + + struct Params { + int num_blocks; + FastDivmod divmod_m_block; + FastDivmod divmod_h; + FastDivmod divmod_b; + + KernelHardwareInfo hw_info; + }; + + int block_idx = 0; + Params params; + + CUTLASS_DEVICE + PersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {} + + template + static Params to_underlying_arguments( + ProblemSize const& problem_size, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, TileShape const& tile_shape) { + using namespace cute; + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + hw_info.sm_count = sm_count; + + int num_m_blocks = cutlass::round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)); + int num_blocks = num_m_blocks * size<3,0>(problem_size) * size<3,1>(problem_size); + + return Params { + num_blocks, + { num_m_blocks}, { size<3,0>(problem_size) }, { size<3,1>(problem_size) }, + hw_info + }; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1); + return grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return block_idx < params.num_blocks; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + int block_decode = block_idx; + int m_block, bidb, bidh; + params.divmod_m_block(block_decode, m_block, block_decode); + params.divmod_b(block_decode, bidb, block_decode); + params.divmod_h(block_decode, bidh, block_decode); + return make_coord(m_block, _0{}, make_coord(bidh, bidb)); + } + + CUTLASS_DEVICE + PersistentTileScheduler& operator++() { + block_idx += gridDim.x; + return *this; + } +}; + + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::kernel diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..742b507d4deb31ba9a9ae8b7bef3cc53b076014a --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp @@ -0,0 +1,1859 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cute/arch/simd_sm100.hpp" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "collective/fmha_common.hpp" + +#include + +namespace cutlass::fmha::kernel { + +using namespace cutlass::fmha::collective; + +using namespace cute; + +template< + class ProblemShape, + class Element, + class ElementAcc, + class TileShape, + class Mask +> +struct Sm100FmhaBwdKernelTmaWarpSpecialized { + + using TileShapeQ = decltype(get<0>(TileShape{})); + static_assert(std::is_same_v, "tile shape K must be 128"); + using TileShapeK = decltype(get<1>(TileShape{})); + static_assert(std::is_same_v, "tile shape K must be 128"); + using TileShapeDQK = decltype(get<2>(TileShape{})); + using TileShapeDVO = decltype(get<2>(TileShape{})); + + using TmemAllocator = cute::TMEM::Allocator1Sm; + struct TmemAllocation { + static constexpr uint32_t kDK = 0; // TileShapeK x TileShapeDQK x acc + static constexpr uint32_t kDV = kDK + TileShapeDQK{}; // TileShapeK x TileShapeDVO x acc + static constexpr uint32_t kDQ = kDV + TileShapeDVO{}; // TileShapeQ x TileShapeDQK x acc + static constexpr uint32_t kDP = kDQ; // TileShapeK x TileShapeQ x inp + static constexpr uint32_t kS = kDQ + max(TileShapeQ{}, TileShapeDQK{}); + static constexpr uint32_t kP = kS; + static constexpr uint32_t kTotal = kS + TileShapeQ{}; + }; + + static_assert( + static_cast(TmemAllocation::kTotal) <= TmemAllocator::Sm100TmemCapacityColumns, + "using too much tmem" + ); + + enum class WarpRole { + Empty = 0x0, Load = 0x1, Mma = 0x2, Compute = 0x3, Reduce = 0x4 + }; + + static constexpr unsigned long long kWarpAssignment = 0x12'3333'3333'4444ull; + static constexpr int kNumComputeWarps = 8; + static constexpr int kNumReduceWarps = 4; + CUTLASS_DEVICE WarpRole warp_idx_to_role(int warp_idx) { + return static_cast((kWarpAssignment >> (4 * warp_idx)) & 0xF); + } + + struct RegisterAllocation { + static constexpr int kWarpgroup0 = 160-8; + static constexpr int kWarpgroup1 = 128; + static constexpr int kWarpgroup2 = 96; + static constexpr int kReduce = kWarpgroup0; + static constexpr int kCompute = kWarpgroup1; + static constexpr int kMma = kWarpgroup2; + static constexpr int kEmpty = kWarpgroup2; + static constexpr int kLoad = kWarpgroup2; + + static_assert(kWarpgroup0 + 2 * kWarpgroup1 + kWarpgroup2 <= 512); + }; + + using ArchTag = cutlass::arch::Sm100; + + using ClusterShape = Shape<_1, _1, _1>; + using Schedule = cutlass::gemm::KernelTmaWarpSpecialized1SmSm100; + + static constexpr int MinBlocksPerMultiprocessor = 1; + static constexpr int kNumWarps = kNumComputeWarps + kNumReduceWarps + 4; + static constexpr int MaxThreadsPerBlock = NumThreadsPerWarp * kNumWarps; + + static constexpr int Alignment = 128 / sizeof_bits_v; + static constexpr int kStages = 2; + + using TensorStrideContiguousK = Stride, int>>; + using TensorStrideContiguousMN = Stride<_1, int, Stride, int>>; + using TensorStrideContiguousK_GQA = Stride, int>>; + using TensorStrideContiguousMN_GQA = Stride<_1, int, Stride, int>>; + + // compute S + using CollectiveMmaKQ = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStrideContiguousK_GQA, Alignment, + Element, TensorStrideContiguousK, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeKQ = typename CollectiveMmaKQ::TileShape; + using TiledMmaKQ = typename CollectiveMmaKQ::TiledMma; + + // compute dP + using CollectiveMmaVDO = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStrideContiguousK_GQA, Alignment, + Element, TensorStrideContiguousK, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeVDO = typename CollectiveMmaVDO::TileShape; + using TiledMmaVDO = typename CollectiveMmaVDO::TiledMma; + + // compute dV + using CollectiveMmaPDO = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // needs to match ordering of S calculation + Element, TensorStrideContiguousK, Alignment, + Element, TensorStrideContiguousMN, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapePDO = typename CollectiveMmaPDO::TileShape; + using TiledMmaPDO = decltype(to_tiled_mma_sm100_ts(typename CollectiveMmaPDO::TiledMma{})); + + // compute dK + using CollectiveMmaDSQ = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // somewhat arbitrary since we dump to smem, need to agree with the next one + Element, TensorStrideContiguousK , Alignment, + Element, TensorStrideContiguousMN, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeDSQ = typename CollectiveMmaDSQ::TileShape; + using TiledMmaDSQ = typename CollectiveMmaDSQ::TiledMma; + + // compute dQ + using CollectiveMmaDSK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // somewhat arbitrary since we dump to smem, need to agree with the previous one + Element, TensorStrideContiguousMN, Alignment, + Element, TensorStrideContiguousMN_GQA, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeDSK = typename CollectiveMmaDSK::TileShape; + using TiledMmaDSK = typename CollectiveMmaDSK::TiledMma; + + // pipelines are named Pipeline + static constexpr int kStagesComputeSmem = 1; + using PipelineLoadMmaQ = PipelineTmaUmmaAsync<2, ClusterShape>; + using PipelineLoadMmaDO = PipelineTmaUmmaAsync<1, ClusterShape>; + using PipelineLoadComputeLSE = PipelineAsync<1>; + using PipelineLoadComputeSumOdO = PipelineAsync<1>; + using PipelineMmaComputeS = PipelineUmmaAsync<1>; + using PipelineMmaComputeDP = PipelineUmmaAsync<1>; + using PipelineMmaReduceDQ = PipelineUmmaAsync<1>; + using PipelineComputeMmaP = PipelineUmmaConsumerAsync<1>; + using PipelineComputeMmaDS = PipelineUmmaConsumerAsync; + using PipelineMmaComputeDKDV = PipelineUmmaAsync<2>; + static constexpr int kStagesReduceTmaStore = 2; + using PipelineReduceTmaStore = PipelineTmaStore; + + struct PipelineStorage { + alignas(16) typename PipelineLoadMmaQ::SharedStorage load_mma_q; + alignas(16) typename PipelineLoadMmaDO::SharedStorage load_mma_do; + alignas(16) typename PipelineLoadComputeLSE::SharedStorage load_compute_lse; + alignas(16) typename PipelineLoadComputeSumOdO::SharedStorage load_compute_sum_odo; + alignas(16) typename PipelineMmaComputeS::SharedStorage mma_compute_s; + alignas(16) typename PipelineMmaComputeDP::SharedStorage mma_compute_dp; + alignas(16) typename PipelineMmaReduceDQ::SharedStorage mma_reduce_dq; + alignas(16) typename PipelineComputeMmaP::SharedStorage compute_mma_p; + alignas(16) typename PipelineComputeMmaDS::SharedStorage compute_mma_ds; + alignas(16) typename PipelineMmaComputeDKDV::SharedStorage mma_compute_dkdv; + }; + + template + static CUTE_DEVICE constexpr auto restage(Layout const& layout, Stages stages = {}) { + return composition(layout, make_tuple(_, _, _, make_layout(stages))); + } + + using SmemLayoutK = decltype(restage(typename CollectiveMmaKQ::SmemLayoutA{})); + using SmemLayoutV = decltype(restage(typename CollectiveMmaVDO::SmemLayoutA{})); + using SmemLayoutQ = decltype(restage(typename CollectiveMmaKQ::SmemLayoutB{}, _2{})); + using SmemLayoutDO = decltype(restage(typename CollectiveMmaVDO::SmemLayoutB{}, _1{})); + using SmemLayoutDS = decltype(restage(typename CollectiveMmaDSK::SmemLayoutA{}, Int{})); + using SmemLayoutLSE = Layout>; + using SmemLayoutSumOdO = Layout>; + + using SmemLayoutQT = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutB{}, _2{})); + using SmemLayoutKT = decltype(restage(typename CollectiveMmaDSK::SmemLayoutB{})); + using SmemLayoutDST = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutA{}, Int{})); + using SmemLayoutDOT = decltype(restage(typename CollectiveMmaPDO::SmemLayoutB{}, _1{})); + + using TileShapeDQ = _32; + using SmemAtomDQ = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::K, ElementAcc, TileShapeQ, TileShapeDQ + >()); + using SmemShapeDQ = Shape>; + using SmemLayoutDQ = decltype(tile_to_shape(SmemAtomDQ{}, SmemShapeDQ{}, Step<_2, _1, _3>{})); + + struct TensorStorage { + union { + alignas(2048) cute::array> smem_k; + alignas(2048) cute::array> smem_k_t; + }; + alignas(2048) cute::array> smem_v; + union { + alignas(2048) cute::array> smem_q; + alignas(2048) cute::array> smem_q_t; + }; + union { + alignas(2048) cute::array> smem_do; + alignas(2048) cute::array> smem_do_t; + }; + union { + alignas(2048) cute::array> smem_ds; + alignas(2048) cute::array> smem_ds_t; + }; + alignas(1024) cute::array> smem_dq; + alignas(16) cute::array> smem_lse; + alignas(16) cute::array> smem_sum_odo; + }; + + static constexpr int kTransactionsBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v); + static constexpr int kTransactionsBytesLoadDO = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutDO{})) * cute::sizeof_bits_v); + + static constexpr int kTransactionsBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v); + static constexpr int kTransactionsBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v); + + struct SharedStorage { + TensorStorage tensors; + PipelineStorage pipelines; + uint32_t tmem_base_ptr; + }; + + // this is tight enough that it won't work with sizeof due to padding for alignment + static constexpr int SharedStorageSize = offsetof(SharedStorage, tmem_base_ptr) + sizeof(uint32_t); + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem"); + + using TensorStride = TensorStrideContiguousK; // S D (H B) + using TensorStride_GQA = TensorStrideContiguousK_GQA; + using RowTensorStride = Stride<_1, Stride, int>>; // S (H B) + + struct MainloopArguments { + const Element* ptr_q; + TensorStride stride_q; + const Element* ptr_k; + TensorStride_GQA stride_k; + const Element* ptr_v; + TensorStride_GQA stride_v; + const Element* ptr_do; + TensorStride stride_do; + + const ElementAcc* ptr_lse; + RowTensorStride stride_lse; + + const ElementAcc* ptr_sum_odo; + RowTensorStride stride_sum_odo; + + ElementAcc* ptr_dq_acc; + TensorStride stride_dq_acc; + + ElementAcc softmax_scale = 1.0f / sqrtf(TileShapeDQK{}); + }; + + using TMA_K = typename CollectiveMmaKQ::Params::TMA_A; + using TMA_V = typename CollectiveMmaVDO::Params::TMA_A; + using TMA_Q = typename CollectiveMmaKQ::Params::TMA_B; + using TMA_DO = typename CollectiveMmaVDO::Params::TMA_B; + + using TMA_DQ = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{}, + make_tensor((const ElementAcc*)nullptr, make_shape(1, 1, make_shape(make_shape(1,1), 1)), TensorStride{}), + SmemLayoutDQ{}(_, _, _0{}) + )); + + struct MainloopParams { + TMA_K tma_load_k; + TMA_V tma_load_v; + TMA_Q tma_load_q; + TMA_DO tma_load_do; + TMA_DQ tma_red_dq; + }; + + struct EpilogueArguments { + Element* ptr_dk; + TensorStride_GQA stride_dk; + Element* ptr_dv; + TensorStride_GQA stride_dv; + }; + + struct Arguments { + ProblemShape problem_shape; + MainloopArguments mainloop; + EpilogueArguments epilogue; + KernelHardwareInfo hw_info; + }; + + struct Params { + ProblemShape problem_shape; + MainloopArguments mainloop; + MainloopParams mainloop_params; + EpilogueArguments epilogue; + KernelHardwareInfo hw_info; + }; + + + static bool can_implement(Arguments const& args) { + auto [Q, K, D, D_VO, HB] = args.problem_shape; + auto [H, B] = HB; + auto [H_R, H_K] = H; + if (Q <= 0 || K <= 0 || D <= 0 || D_VO <= 0 || H_R <= 0 || H_K <= 0 || B <= 0) { + return false; + } + if (D % Alignment != 0 || D_VO % Alignment != 0) { + return false; + } + return true; + } + + + static Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return Status::kSuccess; + } + + + static Params to_underlying_arguments(Arguments const& args, void*) { + auto [Q_, K_, D, D_VO, HB] = args.problem_shape; + int Q = Q_; + int K = K_; + + if constexpr (is_variable_length_v) { + Q = Q_.total_length; + } + if constexpr (is_variable_length_v) { + K = K_.total_length; + } + + auto params_kq = CollectiveMmaKQ::to_underlying_arguments( + make_shape(K, Q, D, HB), + typename CollectiveMmaKQ::Arguments { + args.mainloop.ptr_k, args.mainloop.stride_k, + args.mainloop.ptr_q, args.mainloop.stride_q, + }, /*workspace=*/nullptr); + + auto params_vdo = CollectiveMmaVDO::to_underlying_arguments( + make_shape(K, Q, D_VO, HB), + typename CollectiveMmaVDO::Arguments { + args.mainloop.ptr_v, args.mainloop.stride_v, + args.mainloop.ptr_do, args.mainloop.stride_do, + }, /*workspace=*/nullptr); + + TMA_DQ tma_red_dq = make_tma_copy( + SM90_TMA_REDUCE_ADD{}, + make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q_, D, HB), args.mainloop.stride_dq_acc), + SmemLayoutDQ{}(_, _, _0{}) + ); + + return Params{ + args.problem_shape, + args.mainloop, + MainloopParams{ + params_kq.tma_load_a, + params_vdo.tma_load_a, + params_kq.tma_load_b, + params_vdo.tma_load_b, + tma_red_dq + }, + args.epilogue, + args.hw_info + }; + } + + + template + static CUTLASS_DEVICE auto quantize(T const& input) { + constexpr int AlignmentS = 4; + auto output = make_tensor(shape(input)); + auto input_vec = recast>(input); + auto output_vec = recast>(output); + + cutlass::NumericArrayConverter epilogue_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(input_vec); i++) { + output_vec(i) = epilogue_op(input_vec(i)); + } + + return output; + } + + + template + CUTLASS_DEVICE void load( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + int iter_start, + int iter_end, + int iter_count, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineLoadMmaQ& pipeline_load_mma_q, + typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_producer_state, + PipelineLoadMmaDO& pipeline_load_mma_do, + typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_producer_state, + PipelineLoadComputeLSE& pipeline_load_compute_lse, + typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_producer_state, + PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo, + typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_producer_state) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + int iter_index = iter_start; + + using X = Underscore; + + uint16_t mcast_mask = 0; + + auto mK_in = mainloop_params.tma_load_k.get_tma_tensor(make_shape(K, D, HB)); + auto mV_in = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D_VO, HB)); + auto mQ_in = mainloop_params.tma_load_q.get_tma_tensor(make_shape(Q, D, HB)); + auto mDO_in = mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D_VO, HB)); + + auto mK = domain_offset(select<1,2,4>(blk_offset), mK_in); + auto mV = domain_offset(select<1,3,4>(blk_offset), mV_in); + auto mQ = domain_offset(select<0,2,4>(blk_offset), mQ_in); + auto mDO = domain_offset(select<0,3,4>(blk_offset), mDO_in); + + auto gK = local_tile(mK, TileShapeKQ{}, make_coord(_,_,_), Step<_1, X, _1>{}); + auto gQ = local_tile(mQ, TileShapeKQ{}, make_coord(_,_,_), Step{}); + auto gV = local_tile(mV, TileShapeVDO{}, make_coord(_,_,_), Step<_1, X, _1>{}); + auto gDO = local_tile(mDO, TileShapeVDO{}, make_coord(_,_,_), Step{}); + + ThrMMA cta_mma_kq = TiledMmaKQ{}.get_slice(_0{}); + ThrMMA cta_mma_vdo = TiledMmaVDO{}.get_slice(_0{}); + + auto tSTgK = cta_mma_kq.partition_A(gK); + auto tSTgQ = cta_mma_kq.partition_B(gQ); + auto tDPTgV = cta_mma_vdo.partition_A(gV); + auto tDPTgDO = cta_mma_vdo.partition_B(gDO); + + auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{}); + auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{}); + auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{}); + + auto [tKgK_mkl, tKsK] = tma_partition( + mainloop_params.tma_load_k, _0{}, make_layout(_1{}), + group_modes<0,3>(sK), group_modes<0,3>(tSTgK)); + auto [tQgQ_mkl, tQsQ] = tma_partition( + mainloop_params.tma_load_q, _0{}, make_layout(_1{}), + group_modes<0,3>(sQ), group_modes<0,3>(tSTgQ)); + auto [tVgV_mkl, tVsV] = tma_partition( + mainloop_params.tma_load_v, _0{}, make_layout(_1{}), + group_modes<0,3>(sV), group_modes<0,3>(tDPTgV)); + auto [tDOgDO_mkl, tDOsDO] = tma_partition( + mainloop_params.tma_load_do, _0{}, make_layout(_1{}), + group_modes<0,3>(sDO), group_modes<0,3>(tDPTgDO)); + + // set up lse and sum_odo + + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state); + auto tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state); + + pipeline_load_mma_q.producer_expect_transaction(pipeline_load_mma_q_producer_state, kTransactionsBytesLoadK); + + // load K + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_k.with(*tma_barrier, mcast_mask), + tKgK_mkl(_, blk_coord_k, _0{}, blk_coord_batch), + tKsK(_, _0{}) + ); + } + + // load Q + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask), + tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch), + tQsQ(_, pipeline_load_mma_q_producer_state.index()) + ); + } + + ++pipeline_load_mma_q_producer_state; + + pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state); + + // load LSE + // 32 threads loading 128 values of 32b each + // so 4*32b=128b + + int thread_idx = threadIdx.x % NumThreadsPerWarp; + int smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4; + int gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; + auto mLSE = make_tensor(mainloop_args.ptr_lse, make_shape(Q, HB), mainloop_args.stride_lse); + for (int i = 0; i < 4; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_lse.begin() + smem_idx + i, + &mLSE(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_lse_producer_state; + + + pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state); + tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state); + + pipeline_load_mma_do.producer_expect_transaction(pipeline_load_mma_do_producer_state, kTransactionsBytesLoadV); + + // load V + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_v.with(*tma_barrier, mcast_mask), + tVgV_mkl(_, blk_coord_k, _0{}, blk_coord_batch), + tVsV(_, _0{}) + ); + } + + // load dO + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask), + tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch), + tDOsDO(_, pipeline_load_mma_do_producer_state.index()) + ); + } + + ++pipeline_load_mma_do_producer_state; + + pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state); + + // load sum_OdO + smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4; + gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; + auto mSumOdO = make_tensor(mainloop_args.ptr_sum_odo, make_shape(Q, HB), mainloop_args.stride_sum_odo); + for (int i = 0; i < 4; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_sum_odo.begin() + smem_idx + i, + &mSumOdO(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_sum_odo_producer_state; + + iter_count -= 1; + iter_index += 1; + + while (iter_count > 0) { + if (iter_index == iter_end) { + iter_index = iter_start; + get<0,0>(blk_coord_batch) += 1; + } + + pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state); + tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state); + + // load Q + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask), + tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch), + tQsQ(_, pipeline_load_mma_q_producer_state.index()) + ); + } + + ++pipeline_load_mma_q_producer_state; + + pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state); + + // load LSE + smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4; + gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; + for (int i = 0; i < 4; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_lse.begin() + smem_idx + i, + &mLSE(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_lse_producer_state; + + pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state); + tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state); + + // load dO + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask), + tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch), + tDOsDO(_, pipeline_load_mma_do_producer_state.index()) + ); + } + + ++pipeline_load_mma_do_producer_state; + + pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state); + + // load sum_OdO + smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4; + gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; + for (int i = 0; i < 4; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_sum_odo.begin() + smem_idx + i, + &mSumOdO(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_sum_odo_producer_state; + + iter_count -= 1; + iter_index += 1; + } + } + + + template + CUTLASS_DEVICE void mma( + BlkCoord const& blk_coord, + ProblemShape_ const& problem_shape, + int iter_start, + int iter_end, + int iter_count, + MainloopArguments const& mainloop_args, + TensorStorage& shared_tensors, + PipelineLoadMmaQ& pipeline_load_mma_q, + typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state, + PipelineLoadMmaDO& pipeline_load_mma_do, + typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_consumer_state, + PipelineMmaComputeS& pipeline_mma_compute_s, + typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_producer_state, + PipelineMmaComputeDP& pipeline_mma_compute_dp, + typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_producer_state, + PipelineMmaReduceDQ& pipeline_mma_reduce_dq, + typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_producer_state, + PipelineComputeMmaP& pipeline_compute_mma_p, + typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_consumer_state, + PipelineComputeMmaDS& pipeline_compute_mma_ds, + typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_consumer_state, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_producer_state) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + + auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{}); + auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{}); + auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{}); + + auto sQT = make_tensor(make_smem_ptr(shared_tensors.smem_q_t.begin()), SmemLayoutQT{}); + auto sKT = make_tensor(make_smem_ptr(shared_tensors.smem_k_t.begin()), SmemLayoutKT{}); + auto sDS = make_tensor(make_smem_ptr(shared_tensors.smem_ds.begin()), SmemLayoutDS{}); + auto sDST = make_tensor(make_smem_ptr(shared_tensors.smem_ds_t.begin()), SmemLayoutDST{}); + auto sP = make_tensor(make_smem_ptr((Element*) nullptr), typename CollectiveMmaPDO::SmemLayoutA{}); + auto sDOT = make_tensor(make_smem_ptr(shared_tensors.smem_do_t.begin()), SmemLayoutDOT{}); + + Tensor tSTrK = TiledMmaKQ::make_fragment_A(sK); + Tensor tSTrQ = TiledMmaKQ::make_fragment_B(sQ); + + Tensor tDPTrV = TiledMmaVDO::make_fragment_A(sV); + Tensor tDPTrDO = TiledMmaVDO::make_fragment_B(sDO); + + Tensor tDQrDS = TiledMmaDSK::make_fragment_A(sDS); + Tensor tDQrKT = TiledMmaDSK::make_fragment_B(sKT); + + Tensor tDKrDST = TiledMmaDSQ::make_fragment_A(sDST); + Tensor tDKrQT = TiledMmaDSQ::make_fragment_B(sQT); + + Tensor tDVrP = TiledMmaPDO::make_fragment_A(sP)(_, _, _, _0{}); + tDVrP.data() = TmemAllocation::kP; + Tensor tDVrDOT = TiledMmaPDO::make_fragment_B(sDOT); + + TiledMmaKQ tiled_mma_kq; + TiledMmaVDO tiled_mma_vdo; + TiledMmaDSK tiled_mma_dsk; + TiledMmaDSQ tiled_mma_dsq; + TiledMmaPDO tiled_mma_pdo; + + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::Zero; + tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::Zero; + + Tensor tSTtST = partition_fragment_C(tiled_mma_kq, select<0,1>(TileShapeKQ{})); + tSTtST.data() = TmemAllocation::kS; + + Tensor tDPTtDPT = partition_fragment_C(tiled_mma_vdo, select<0,1>(TileShapeVDO{})); + tDPTtDPT.data() = TmemAllocation::kDP; + + Tensor tDQtDQ = partition_fragment_C(tiled_mma_dsk, select<0,1>(TileShapeDSK{})); + tDQtDQ.data() = TmemAllocation::kDQ; + + Tensor tDKtDK = partition_fragment_C(tiled_mma_dsq, select<0,1>(TileShapeDSQ{})); + tDKtDK.data() = TmemAllocation::kDK; + + Tensor tDVtDV = partition_fragment_C(tiled_mma_pdo, select<0,1>(TileShapePDO{})); + tDVtDV.data() = TmemAllocation::kDV; + + auto pipeline_load_mma_q_release_state = pipeline_load_mma_q_consumer_state; + + pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state); + pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state); + + // S = Q*K + tiled_mma_kq.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) { + cute::gemm(tiled_mma_kq, + tSTrK(_,_,k_block,_0{}), + tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()), + tSTtST); + tiled_mma_kq.accumulate_ = UMMA::ScaleOut::One; + } + + ++pipeline_load_mma_q_consumer_state; + + pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state); + ++pipeline_mma_compute_s_producer_state; + + pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state); + + pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state); + pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state); + + // dP = dO*V + tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) { + cute::gemm(tiled_mma_vdo, + tDPTrV(_,_,k_block,_0{}), + tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDPTtDPT); + tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state); + ++pipeline_mma_compute_dp_producer_state; + + pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state); + + // dV = P*dO + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) { + cute::gemm(tiled_mma_pdo, + tDVrP(_,_,k_block), + tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDVtDV); + tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state); + ++pipeline_compute_mma_p_consumer_state; + + pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state); + ++pipeline_load_mma_do_consumer_state; + + iter_count -= 1; + + // in tmem, S & P overlap + // and dP and dQ overlap + // so we need to acquire dQ and dP at the same time + while (iter_count > 0) { + pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state); + pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state); + + // S = Q*K + tiled_mma_kq.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) { + cute::gemm(tiled_mma_kq, + tSTrK(_,_,k_block,_0{}), + tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()), + tSTtST); + tiled_mma_kq.accumulate_ = UMMA::ScaleOut::One; + } + + ++pipeline_load_mma_q_consumer_state; + + pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state); + ++pipeline_mma_compute_s_producer_state; + + pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state); + + // we need to acquire dP here, because tmem dQ == tmem dP + pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state); + + // dQ = dS*K + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) { + cute::gemm(tiled_mma_dsk, + tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDQrKT(_,_,k_block,_0{}), + tDQtDQ); + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state); + ++pipeline_mma_reduce_dq_producer_state; + + // dK = dS*Q + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) { + cute::gemm(tiled_mma_dsq, + tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()), + tDKtDK); + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state); + ++pipeline_load_mma_q_release_state; + + pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state); + ++pipeline_compute_mma_ds_consumer_state; + + // we grab dq here, because in tmem dq == dp + pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state); + + pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state); + + // dP = dO*V + tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) { + cute::gemm(tiled_mma_vdo, + tDPTrV(_,_,k_block,_0{}), + tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDPTtDPT); + tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state); + ++pipeline_mma_compute_dp_producer_state; + + pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state); + + // dV = P*dO + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) { + cute::gemm(tiled_mma_pdo, + tDVrP(_,_,k_block), + tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDVtDV); + tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state); + ++pipeline_compute_mma_p_consumer_state; + + pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state); + ++pipeline_load_mma_do_consumer_state; + + iter_count -= 1; + } + + // signal to the epilogue that dV is ready + pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state); + pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state); + ++pipeline_mma_compute_dkdv_producer_state; + + pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state); + + pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state); + + // dK = dS*Q + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) { + cute::gemm(tiled_mma_dsq, + tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()), + tDKtDK); + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One; + } + + // signal to epilgue that dK is ready + pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state); + ++pipeline_mma_compute_dkdv_producer_state; + + // we've already acquired mma_reduce_dq in the loop + + // dQ = dS*K + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) { + cute::gemm(tiled_mma_dsk, + tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDQrKT(_,_,k_block,_0{}), + tDQtDQ); + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state); + ++pipeline_mma_reduce_dq_producer_state; + + pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state); + ++pipeline_load_mma_q_release_state; + + pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state); + ++pipeline_compute_mma_ds_consumer_state; + } + + + + template + CUTLASS_DEVICE void store( + TensorG gmem, + TensorR const& regs, + TensorC const& coord, + TensorShape const& tensor_shape) { + + Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); }); + + auto copy_op = make_cotiled_copy( + Copy_Atom, Element>{}, + make_layout(make_shape(_1{}, Int{})), + regs.layout() + ); + auto thr_copy = copy_op.get_slice(_0{}); + + Tensor quantized_regs = quantize(regs); + Tensor tCr = thr_copy.partition_S(quantized_regs); + Tensor tCg = thr_copy.partition_D(gmem); + Tensor tPc = thr_copy.partition_D(preds); + + copy_if(copy_op, tPc, tCr, tCg); + } + + + template + CUTLASS_DEVICE void epilogue_clear( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueArguments const& epilogue_args) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk); + auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in); + auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDK = domain_offset( + make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapeDSQ{})) + ); + + auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv); + auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in); + auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDV = domain_offset( + make_coord(blk_coord_k * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapePDO{})) + ); + + for (int i = threadIdx.x; i < size(gDK); i += blockDim.x) { + if (elem_less(cDK(i), select<1,2>(problem_shape))) { + gDK(i) = Element(0); + } + } + for (int i = threadIdx.x; i < size(gDV); i += blockDim.x) { + if (elem_less(cDV(i), select<1,3>(problem_shape))) { + gDV(i) = Element(0); + } + } + } + + + template + CUTLASS_DEVICE void epilogue( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueArguments const& epilogue_args, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + auto load_op = SM100_TMEM_LOAD_32dp32b16x{}; + + auto tDKtDK = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{}); + tDKtDK.data() = TmemAllocation::kDK; + + auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk); + auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in); + auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDK = domain_offset( + make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapeDSQ{})) + ); + + constexpr int kNumWarpgroups = kNumComputeWarps / 4; + int dp_idx = threadIdx.x % 128; + int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128; + + auto split_wg = [&](auto const& t) { + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int{}, size<2>(t) / Int{})))); + return p(_, _, make_coord(wg_idx, _)); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int{}, size<3>(t) / Int{})))); + return p(_, _, _, make_coord(wg_idx, _)); + } + }; + + auto tiled_t2r_dk = make_tmem_copy(load_op, tDKtDK); + auto thread_t2r_dk = tiled_t2r_dk.get_slice(dp_idx); + + Tensor tTR_cDK = split_wg(thread_t2r_dk.partition_D(cDK)); + Tensor tTR_gDK = split_wg(thread_t2r_dk.partition_D(gDK)); + Tensor tTR_rDK = make_tensor(shape(tTR_cDK)); + Tensor tTR_tDK = split_wg(thread_t2r_dk.partition_S(tDKtDK)); + + auto tDVtDV = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{}); + tDVtDV.data() = TmemAllocation::kDV; + + auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv); + auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in); + auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDV = domain_offset( + make_coord(blk_coord_k * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapePDO{})) + ); + + auto tiled_t2r_dv = make_tmem_copy(load_op, tDVtDV); + auto thread_t2r_dv = tiled_t2r_dv.get_slice(dp_idx); + + Tensor tTR_cDV = split_wg(thread_t2r_dv.partition_D(cDV)); + Tensor tTR_gDV = split_wg(thread_t2r_dv.partition_D(gDV)); + Tensor tTR_rDV = make_tensor(shape(tTR_cDV)); + Tensor tTR_tDV = split_wg(thread_t2r_dv.partition_S(tDVtDV)); + + pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state); + + // load tDVtDV + cute::copy(tiled_t2r_dv, tTR_tDV, tTR_rDV); + + // store tDVgDV + store(tTR_gDV, tTR_rDV, tTR_cDV, select<1,3>(problem_shape)); + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state); + ++pipeline_mma_compute_dkdv_consumer_state; + + pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state); + + // load tDKtDK + cute::copy(tiled_t2r_dk, tTR_tDK, tTR_rDK); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rDK); i++) { + tTR_rDK(i) = mainloop_args.softmax_scale * tTR_rDK(i); + } + + // store tDKgDK + store(tTR_gDK, tTR_rDK, tTR_cDK, select<1,2>(problem_shape)); + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state); + ++pipeline_mma_compute_dkdv_consumer_state; + + } + + + template + CUTLASS_DEVICE void compute( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + int iter_start, + int iter_end, + int iter_count, + MainloopArguments const& mainloop_args, + EpilogueArguments const& epilogue_args, + TensorStorage& shared_tensors, + PipelineLoadComputeLSE& pipeline_load_compute_lse, + typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_consumer_state, + PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo, + typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_consumer_state, + PipelineMmaComputeS& pipeline_mma_compute_s, + typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_consumer_state, + PipelineMmaComputeDP& pipeline_mma_compute_dp, + typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_consumer_state, + PipelineComputeMmaP& pipeline_compute_mma_p, + typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_producer_state, + PipelineComputeMmaDS& pipeline_compute_mma_ds, + typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_producer_state, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) { + + + auto [Q, K, D, D_VO, HB] = problem_shape; + int iter_index = iter_start; + + // in tmem, S & P overlap + // and dP and dQ overlap + + // there are two compute wg's that cooperatively compute softmax + // they are striped by this tmem atom, i.e. wg0 has 16 elems, then wg1 etc + + auto load_op = SM100_TMEM_LOAD_32dp32b16x{}; + auto store_op = []() { + if constexpr (sizeof(Element) == 1) { + return SM100_TMEM_STORE_32dp32b4x{}; + } + else { + return SM100_TMEM_STORE_32dp32b8x{}; + } + }(); + + Tensor tSTtST = partition_fragment_C(TiledMmaKQ{}, select<0,1>(TileShapeKQ{}))(make_coord(_,_),_0{},_0{}); + tSTtST.data() = TmemAllocation::kS; + + Tensor tDPTtDPT = partition_fragment_C(TiledMmaVDO{}, select<0,1>(TileShapeVDO{}))(make_coord(_,_),_0{},_0{}); + tDPTtDPT.data() = TmemAllocation::kDP; + + Tensor cST = make_identity_tensor(take<0,2>(TileShapeKQ{})); + Tensor cDPT = make_identity_tensor(take<0,2>(TileShapeVDO{})); + + constexpr int kNumWarpgroups = kNumComputeWarps / 4; + int dp_idx = threadIdx.x % 128; + int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128; + auto tiled_t2r = make_tmem_copy(load_op, tSTtST); + auto thread_t2r = tiled_t2r.get_slice(dp_idx); + + auto split_wg = [&](auto const& t) { + if constexpr (decltype(size<1>(t))::value > 1) { + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int{}, size<1>(t) / Int{}), size<2>(t)))); + return p(_, make_coord(wg_idx, _), _); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int{}, size<1>(t) / Int{}), size<2>(t), size<3>(t)))); + return p(_, make_coord(wg_idx, _), _, _); + } + } + else { + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int{}, size<2>(t) / Int{})))); + return p(_, _, make_coord(wg_idx, _)); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int{}, size<3>(t) / Int{})))); + return p(_, _, _, make_coord(wg_idx, _)); + } + + } + }; + + + Tensor tTR_cST_p = thread_t2r.partition_D(cST); + Tensor tTR_cST = split_wg(tTR_cST_p); + Tensor tTR_rST = make_tensor(shape(tTR_cST)); + // Tensor tTR_tST_p = thread_t2r.partition_S(tSTtST); + Tensor tTR_tST = split_wg(thread_t2r.partition_S(tSTtST)); + + Tensor tTR_cDPT_p = thread_t2r.partition_D(cDPT); + Tensor tTR_cDPT = split_wg(tTR_cDPT_p); + Tensor tTR_rDPT = make_tensor(shape(tTR_cDPT)); + Tensor tTR_tDPT = split_wg(thread_t2r.partition_S(tDPTtDPT)); + + Tensor sLSE = make_tensor(make_smem_ptr(shared_tensors.smem_lse.begin()), SmemLayoutLSE{}); + Tensor sSumOdO = make_tensor(make_smem_ptr(shared_tensors.smem_sum_odo.begin()), SmemLayoutSumOdO{}); + + auto sP = make_tensor(make_smem_ptr((Element*) nullptr), typename CollectiveMmaPDO::SmemLayoutA{}); + + auto tDVrP = TiledMmaPDO::make_fragment_A(sP)(_, _, _, _0{}); + auto tDVcST = TiledMmaPDO{}.get_slice(_0{}).partition_A(cST); + tDVrP.data() = TmemAllocation::kP; + + auto tiled_r2t = make_tmem_copy(store_op, tDVrP); + auto thread_r2t = tiled_r2t.get_slice(dp_idx); + + auto tRT_tP = split_wg(thread_r2t.partition_D(tDVrP)); + auto tRT_cST_p = thread_r2t.partition_S(tDVcST); + auto tRT_cST = split_wg(tRT_cST_p); + + bool is_residual_k = get<1>(blk_coord) * TileShapeK{} + TileShapeK{} > get<1>(problem_shape); + + CUTLASS_PRAGMA_NO_UNROLL + while (iter_count > 0) { + // wait for S and P + pipeline_mma_compute_s.consumer_wait(pipeline_mma_compute_s_consumer_state); + pipeline_compute_mma_p.producer_acquire(pipeline_compute_mma_p_producer_state); + // wait for LSE + pipeline_load_compute_lse.consumer_wait(pipeline_load_compute_lse_consumer_state); + + auto dispatch_bool = [](bool b, auto fn) { + if (b) { + fn(cute::true_type{}); + } + else { + fn(cute::false_type{}); + } + }; + + bool leading_causal_masking = false; + if constexpr (std::is_base_of_v, Mask>) { + leading_causal_masking = warp_uniform(iter_index == iter_start); + } else if constexpr (std::is_base_of_v, Mask>) { + int offset = get<1>(problem_shape) - get<0>(problem_shape); + int kv_left = get<1>(blk_coord) * TileShapeK{}; + int kv_right = kv_left + TileShapeK{} - 1; + int q_left = iter_index * TileShapeQ{} + offset; + int q_right = q_left + TileShapeQ{} - 1; + + leading_causal_masking = warp_uniform(!((q_left > kv_right) || (q_right < kv_left))); + } + bool trailing_residual_masking = false; + if constexpr (std::is_base_of_v) { + trailing_residual_masking = warp_uniform((iter_index == iter_end - 1) || is_residual_k); + } + + dispatch_bool(leading_causal_masking || trailing_residual_masking, [&](auto is_masked_tile) { + + // compute P = softmax(S, LSE) + cute::copy(tiled_t2r, tTR_tST, tTR_rST); + + if constexpr (decltype(is_masked_tile)::value) { + Mask{}.apply_mask(tTR_rST, [&](int i) { + auto c_transpose = tTR_cST(i); + return make_coord(get<1>(c_transpose) + iter_index * TileShapeQ{}, get<0>(c_transpose) + get<1>(blk_coord) * TileShapeK{}); + }, problem_shape); + } + + ElementAcc log2_e = static_cast(M_LOG2E); + float2 softmax_scale_log2_e; + softmax_scale_log2_e.x = mainloop_args.softmax_scale * log2_e; + softmax_scale_log2_e.y = mainloop_args.softmax_scale * log2_e; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rST); i += 2) { + float2 acc; + float2 lse; + float2 out; + acc.x = tTR_rST(i); + acc.y = tTR_rST(i + 1); + lse.x = sLSE(get<1>(tTR_cST(i)), pipeline_load_compute_lse_consumer_state.index()); + lse.y = sLSE(get<1>(tTR_cST(i+1)), pipeline_load_compute_lse_consumer_state.index()); + cute::fma(out, softmax_scale_log2_e, acc, lse); + tTR_rST(i) = ::exp2f(out.x); + tTR_rST(i+1) = ::exp2f(out.y); + } + + auto tRT_rST = quantize(tTR_rST); + auto tRT_rST_reshaped = make_tensor(tRT_rST.data(), shape(tRT_cST)); + + cutlass::arch::fence_view_async_tmem_load(); + cutlass::arch::NamedBarrier( + kNumComputeWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransformBarrier + ).arrive_and_wait(); + + cute::copy(tiled_r2t, tRT_rST_reshaped, tRT_tP); + }); + + // notify for P + cutlass::arch::fence_view_async_tmem_store(); + pipeline_compute_mma_p.producer_commit(pipeline_compute_mma_p_producer_state); + ++pipeline_compute_mma_p_producer_state; + // release S + pipeline_mma_compute_s.consumer_release(pipeline_mma_compute_s_consumer_state); + ++pipeline_mma_compute_s_consumer_state; + // release LSE + pipeline_load_compute_lse.consumer_release(pipeline_load_compute_lse_consumer_state); + ++pipeline_load_compute_lse_consumer_state; + + // wait for OdO + pipeline_load_compute_sum_odo.consumer_wait(pipeline_load_compute_sum_odo_consumer_state); + // wait for dP + pipeline_mma_compute_dp.consumer_wait(pipeline_mma_compute_dp_consumer_state); + + // wait for dS + // in principle, we could defer waiting for dS, and move in the freeing of dP + // however, that would force us to keep dS in registers longer + pipeline_compute_mma_ds.producer_acquire(pipeline_compute_mma_ds_producer_state); + + // compute dS = dsoftmax(P, dP, sum_OdO) + cute::copy(tiled_t2r, tTR_tDPT, tTR_rDPT); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rDPT); i += 2) { + float2 st; + st.x = tTR_rST(i); + st.y = tTR_rST(i+1); + float2 dpt; + dpt.x = tTR_rDPT(i); + dpt.y = tTR_rDPT(i+1); + float2 odo; + odo.x = sSumOdO(get<1>(tTR_cDPT(i)), pipeline_load_compute_sum_odo_consumer_state.index()); + odo.y = sSumOdO(get<1>(tTR_cDPT(i+1)), pipeline_load_compute_sum_odo_consumer_state.index()); + float2 dif; + // sum odo is negated during preprocess + cute::add(dif, dpt, odo); + float2 out; + cute::mul(out, dif, st); + tTR_rDPT(i) = out.x; + tTR_rDPT(i+1) = out.y; + } + + auto tTR_rDST = quantize(tTR_rDPT); + + // release dP + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dp.consumer_release(pipeline_mma_compute_dp_consumer_state); + ++pipeline_mma_compute_dp_consumer_state; + + Tensor sDS = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_ds.begin()), SmemLayoutDS{}) + (_, _, _, pipeline_compute_mma_ds_producer_state.index()); + + auto thread_layout = make_ordered_layout( + make_shape(_128{}, _128{}), + make_stride(_1{}, _0{}) + ); + + auto sDS_pi = as_position_independent_swizzle_tensor(sDS); + auto sDS_pi_slice_p = sDS_pi.compose(thread_layout)(dp_idx, _).compose(make_layout(shape(tTR_cDPT_p))); + auto sDS_pi_slice = split_wg(sDS_pi_slice_p); + + copy_aligned(tTR_rDST, sDS_pi_slice); + + // notify for dS + cutlass::arch::fence_view_async_shared(); + pipeline_compute_mma_ds.producer_commit(pipeline_compute_mma_ds_producer_state); + ++pipeline_compute_mma_ds_producer_state; + // release OdO + pipeline_load_compute_sum_odo.consumer_release(pipeline_load_compute_sum_odo_consumer_state); + ++pipeline_load_compute_sum_odo_consumer_state; + + iter_count -= 1; + iter_index += 1; + if (iter_index == iter_end) { + iter_index = iter_start; + } + } + + epilogue( + blk_coord, blk_offset, problem_shape, mainloop_args, epilogue_args, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state + ); + } + + template + CUTLASS_DEVICE void reduce( + BlkCoord const& blk_coord, + ProblemShape_ const& problem_shape, + int iter_start, + int iter_end, + int iter_count, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineMmaReduceDQ& pipeline_mma_reduce_dq, + typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_consumer_state, + PipelineReduceTmaStore& pipeline_reduce_tma_store, + typename PipelineReduceTmaStore::PipelineState& pipeline_reduce_tma_store_producer_state) { + + using X = Underscore; + + auto [Q, K, D, D_VO, HB] = problem_shape; + int iter_index = iter_start; + + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + // must match TileShapeDQ + auto load_op = SM100_TMEM_LOAD_32dp32b32x{}; + + auto tDQtDQ = partition_fragment_C(TiledMmaDSK{}, select<0,1>(TileShapeDSK{}))(make_coord(_,_),_0{},_0{}); + tDQtDQ.data() = TmemAllocation::kDQ; + + Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q, D, HB)); + auto gDQ = local_tile(mDQ, TileShapeKQ{}, make_coord(_,_,_), Step{}) + (_, _, _, _0{}, _); + + Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{})); + + Tensor sDQ = make_tensor(make_smem_ptr(shared_tensors.smem_dq.begin()), SmemLayoutDQ{}); + + int thread_idx = threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp); + auto tiled_t2r = make_tmem_copy(load_op, tDQtDQ); + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + + Tensor tTR_cDQ = thread_t2r.partition_D(cDQ); + Tensor tTR_sDQ = thread_t2r.partition_D(sDQ); + Tensor tTR_tDQ = thread_t2r.partition_S(tDQtDQ); + + auto block_tma = mainloop_params.tma_red_dq.get_slice(_0{}); + + Tensor tDQsDQ = block_tma.partition_S(sDQ); + Tensor tDQcDQ = block_tma.partition_S(cDQ); + Tensor tDQgDQ = block_tma.partition_D(gDQ); + + int lane_predicate = (threadIdx.x % (kNumReduceWarps * NumThreadsPerWarp)) == 0; + + while (iter_count > 0) { + pipeline_mma_reduce_dq.consumer_wait(pipeline_mma_reduce_dq_consumer_state); + + Tensor tTR_rDQ = make_tensor(shape(tTR_cDQ)); + + // load dQ from tmem to rmem + cute::copy(tiled_t2r, tTR_tDQ, tTR_rDQ); + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_reduce_dq.consumer_release(pipeline_mma_reduce_dq_consumer_state); + ++pipeline_mma_reduce_dq_consumer_state; + + // we don't have enough smem to dump it all to smem, so we do it in stages + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<2>(tTR_cDQ); i++) { + if (lane_predicate) { + pipeline_reduce_tma_store.producer_acquire(pipeline_reduce_tma_store_producer_state); + } + // wait in all threads for the acquire to complete + cutlass::arch::NamedBarrier( + kNumReduceWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransposeBarrier + ).arrive_and_wait(); + + cute::copy(tTR_rDQ(_, _, i), tTR_sDQ(_, _, _0{}, pipeline_reduce_tma_store_producer_state.index())); + + // wait for the stores to all be visible to the TMA + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier( + kNumReduceWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransposeBarrier + ).arrive_and_wait(); + if (lane_predicate) { + // launch tma store + copy(mainloop_params.tma_red_dq, tDQsDQ(_,_,_0{}, pipeline_reduce_tma_store_producer_state.index()), tDQgDQ(_,_,i,iter_index,blk_coord_batch)); + pipeline_reduce_tma_store.producer_commit(pipeline_reduce_tma_store_producer_state); + } + + ++pipeline_reduce_tma_store_producer_state; + } + + iter_count -= 1; + iter_index += 1; + if (iter_index == iter_end) { + iter_index = iter_start; + get<0,0>(blk_coord_batch) += 1; + } + } + } + + + CUTLASS_DEVICE void operator()(Params const& params, char* smem) { +#if (! defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && ! defined(CUTLASS_ARCH_MMA_SM100F_ENABLED)) + printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); +#else + int warp_idx = cutlass::canonical_warp_idx_sync(); + auto role = warp_idx_to_role(warp_idx); + uint32_t lane_predicate = cute::elect_one_sync(); + + if (role == WarpRole::Load && lane_predicate) { + prefetch_tma_descriptor(params.mainloop_params.tma_load_q.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_k.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_v.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_do.get_tma_descriptor()); + } + + SharedStorage& shared_storage = *reinterpret_cast(smem); + + int initializing_warp = 0; + typename PipelineLoadMmaQ::Params pipeline_load_mma_q_params; + if (role == WarpRole::Load) { + pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Producer; + } + if (role == WarpRole::Mma) { + pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Consumer; + } + pipeline_load_mma_q_params.is_leader = lane_predicate && (role == WarpRole::Load); + // Also loads K in the first iteration + pipeline_load_mma_q_params.transaction_bytes = kTransactionsBytesLoadQ; + pipeline_load_mma_q_params.initializing_warp = initializing_warp++; + PipelineLoadMmaQ pipeline_load_mma_q(shared_storage.pipelines.load_mma_q, pipeline_load_mma_q_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineLoadMmaDO::Params pipeline_load_mma_do_params; + if (role == WarpRole::Load) { + pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Producer; + } + if (role == WarpRole::Mma) { + pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Consumer; + } + pipeline_load_mma_do_params.is_leader = lane_predicate && (role == WarpRole::Load); + // Also loads V in the first iteration + pipeline_load_mma_do_params.transaction_bytes = kTransactionsBytesLoadDO; + pipeline_load_mma_do_params.initializing_warp = initializing_warp++; + PipelineLoadMmaDO pipeline_load_mma_do(shared_storage.pipelines.load_mma_do, pipeline_load_mma_do_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineLoadComputeLSE::Params pipeline_load_compute_lse_params; + if (role == WarpRole::Load) { + pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Consumer; + } + pipeline_load_compute_lse_params.producer_arv_count = NumThreadsPerWarp; + pipeline_load_compute_lse_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp; + pipeline_load_compute_lse_params.initializing_warp = initializing_warp++; + PipelineLoadComputeLSE pipeline_load_compute_lse( + shared_storage.pipelines.load_compute_lse, + pipeline_load_compute_lse_params, + /*barrier init*/ cute::true_type{}); + + typename PipelineLoadComputeSumOdO::Params pipeline_load_compute_sum_odo_params; + if (role == WarpRole::Load) { + pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Consumer; + } + pipeline_load_compute_sum_odo_params.producer_arv_count = NumThreadsPerWarp; + pipeline_load_compute_sum_odo_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp; + pipeline_load_compute_sum_odo_params.initializing_warp = initializing_warp++; + PipelineLoadComputeSumOdO pipeline_load_compute_sum_odo( + shared_storage.pipelines.load_compute_sum_odo, + pipeline_load_compute_sum_odo_params, + /*barrier init*/ cute::true_type{}); + + typename PipelineMmaComputeS::Params pipeline_mma_compute_s_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Consumer; + } + pipeline_mma_compute_s_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_compute_s_params.initializing_warp = initializing_warp++; + PipelineMmaComputeS pipeline_mma_compute_s( + shared_storage.pipelines.mma_compute_s, + pipeline_mma_compute_s_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineMmaComputeDP::Params pipeline_mma_compute_dp_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Consumer; + } + pipeline_mma_compute_dp_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_compute_dp_params.initializing_warp = initializing_warp++; + PipelineMmaComputeDP pipeline_mma_compute_dp( + shared_storage.pipelines.mma_compute_dp, + pipeline_mma_compute_dp_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineMmaReduceDQ::Params pipeline_mma_reduce_dq_params; + if (role == WarpRole::Mma) { + pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Producer; + } + if (role == WarpRole::Reduce) { + pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Consumer; + } + pipeline_mma_reduce_dq_params.consumer_arv_count = kNumReduceWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_reduce_dq_params.initializing_warp = initializing_warp++; + PipelineMmaReduceDQ pipeline_mma_reduce_dq( + shared_storage.pipelines.mma_reduce_dq, + pipeline_mma_reduce_dq_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineComputeMmaP::Params pipeline_compute_mma_p_params; + if (role == WarpRole::Mma) { + pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Consumer; + } + if (role == WarpRole::Compute) { + pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Producer; + } + pipeline_compute_mma_p_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_compute_mma_p_params.consumer_arv_count = 1; + pipeline_compute_mma_p_params.initializing_warp = initializing_warp++; + PipelineComputeMmaP pipeline_compute_mma_p( + shared_storage.pipelines.compute_mma_p, + pipeline_compute_mma_p_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineComputeMmaDS::Params pipeline_compute_mma_ds_params; + if (role == WarpRole::Mma) { + pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Consumer; + } + if (role == WarpRole::Compute) { + pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Producer; + } + pipeline_compute_mma_ds_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_compute_mma_ds_params.consumer_arv_count = 1; + pipeline_compute_mma_ds_params.initializing_warp = initializing_warp++; + PipelineComputeMmaDS pipeline_compute_mma_ds( + shared_storage.pipelines.compute_mma_ds, + pipeline_compute_mma_ds_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineMmaComputeDKDV::Params pipeline_mma_compute_dkdv_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Consumer; + } + pipeline_mma_compute_dkdv_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_compute_dkdv_params.initializing_warp = initializing_warp++; + PipelineMmaComputeDKDV pipeline_mma_compute_dkdv( + shared_storage.pipelines.mma_compute_dkdv, + pipeline_mma_compute_dkdv_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + PipelineReduceTmaStore pipeline_reduce_tma_store; + + TmemAllocator tmem_allocator; + + pipeline_init_arrive_relaxed(size(ClusterShape{})); + + pipeline_load_mma_q.init_masks(ClusterShape{}); + pipeline_load_mma_do.init_masks(ClusterShape{}); + pipeline_mma_compute_s.init_masks(ClusterShape{}); + pipeline_mma_compute_dp.init_masks(ClusterShape{}); + pipeline_mma_reduce_dq.init_masks(ClusterShape{}); + pipeline_compute_mma_p.init_masks(ClusterShape{}); + pipeline_compute_mma_ds.init_masks(ClusterShape{}); + pipeline_mma_compute_dkdv.init_masks(ClusterShape{}); + + typename decltype(pipeline_load_mma_q)::PipelineState pipeline_load_mma_q_consumer_state; + typename decltype(pipeline_load_mma_do)::PipelineState pipeline_load_mma_do_consumer_state; + typename decltype(pipeline_load_compute_lse)::PipelineState pipeline_load_compute_lse_consumer_state; + typename decltype(pipeline_load_compute_sum_odo)::PipelineState pipeline_load_compute_sum_odo_consumer_state; + typename decltype(pipeline_mma_compute_s)::PipelineState pipeline_mma_compute_s_consumer_state; + typename decltype(pipeline_mma_compute_dp)::PipelineState pipeline_mma_compute_dp_consumer_state; + typename decltype(pipeline_mma_reduce_dq)::PipelineState pipeline_mma_reduce_dq_consumer_state; + typename decltype(pipeline_compute_mma_p)::PipelineState pipeline_compute_mma_p_consumer_state; + typename decltype(pipeline_compute_mma_ds)::PipelineState pipeline_compute_mma_ds_consumer_state; + typename decltype(pipeline_mma_compute_dkdv)::PipelineState pipeline_mma_compute_dkdv_consumer_state; + + auto pipeline_load_mma_q_producer_state = make_producer_start_state(); + auto pipeline_load_mma_do_producer_state = make_producer_start_state(); + auto pipeline_load_compute_lse_producer_state = make_producer_start_state(); + auto pipeline_load_compute_sum_odo_producer_state = make_producer_start_state(); + auto pipeline_mma_compute_s_producer_state = make_producer_start_state(); + auto pipeline_mma_compute_dp_producer_state = make_producer_start_state(); + auto pipeline_mma_reduce_dq_producer_state = make_producer_start_state(); + auto pipeline_compute_mma_p_producer_state = make_producer_start_state(); + auto pipeline_compute_mma_ds_producer_state = make_producer_start_state(); + auto pipeline_mma_compute_dkdv_producer_state = make_producer_start_state(); + auto pipeline_reduce_tma_store_producer_state = make_producer_start_state(); + + pipeline_init_wait(size(ClusterShape{})); + + auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, _0{}, make_coord(make_coord(0, blockIdx.y), blockIdx.z)); + auto [problem_shape, blk_offset] = apply_variable_length_offset( + params.problem_shape, + blk_coord + ); + int iter_end = ceil_div(get<0>(problem_shape), TileShapeQ{}); + int iter_start = 0; + if constexpr (std::is_base_of_v, Mask>) { + iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{}; + } else if constexpr (std::is_base_of_v, Mask>) { + int offset = get<1>(problem_shape) - get<0>(problem_shape); + iter_start = max(0, (int(get<1>(blk_coord) * TileShapeK{}) - offset) / (int)TileShapeQ{}); + } + if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) { + return; + } + int iter_count = (iter_end - iter_start) * get<4,0,0>(problem_shape); + + if (iter_count <= 0) { + epilogue_clear( + blk_coord, + blk_offset, + problem_shape, + params.mainloop, + params.epilogue + ); + return; + } + + if (role == WarpRole::Load) { + warpgroup_reg_set(); + + load( + blk_coord, + blk_offset, + problem_shape, + iter_start, + iter_end, + iter_count, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_load_mma_q, pipeline_load_mma_q_producer_state, + pipeline_load_mma_do, pipeline_load_mma_do_producer_state, + pipeline_load_compute_lse, pipeline_load_compute_lse_producer_state, + pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_producer_state + ); + + } + else if (role == WarpRole::Mma) { + warpgroup_reg_set(); + + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + + mma( + blk_coord, + problem_shape, + iter_start, + iter_end, + iter_count, + params.mainloop, + shared_storage.tensors, + pipeline_load_mma_q, pipeline_load_mma_q_consumer_state, + pipeline_load_mma_do, pipeline_load_mma_do_consumer_state, + pipeline_mma_compute_s, pipeline_mma_compute_s_producer_state, + pipeline_mma_compute_dp, pipeline_mma_compute_dp_producer_state, + pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_producer_state, + pipeline_compute_mma_p, pipeline_compute_mma_p_consumer_state, + pipeline_compute_mma_ds, pipeline_compute_mma_ds_consumer_state, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_producer_state + ); + + } + else if (role == WarpRole::Compute) { + warpgroup_reg_set(); + + compute( + blk_coord, + blk_offset, + problem_shape, + iter_start, + iter_end, + iter_count, + params.mainloop, + params.epilogue, + shared_storage.tensors, + pipeline_load_compute_lse, pipeline_load_compute_lse_consumer_state, + pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_consumer_state, + pipeline_mma_compute_s, pipeline_mma_compute_s_consumer_state, + pipeline_mma_compute_dp, pipeline_mma_compute_dp_consumer_state, + pipeline_compute_mma_p, pipeline_compute_mma_p_producer_state, + pipeline_compute_mma_ds, pipeline_compute_mma_ds_producer_state, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state + ); + + cutlass::arch::NamedBarrier( + kNumComputeWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier + ).arrive_and_wait(); + + if (warp_idx % kNumComputeWarps == 0) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + } + else if (role == WarpRole::Reduce) { + warpgroup_reg_set(); + + reduce( + blk_coord, + problem_shape, + iter_start, + iter_end, + iter_count, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_consumer_state, + pipeline_reduce_tma_store, pipeline_reduce_tma_store_producer_state + ); + + pipeline_reduce_tma_store.producer_tail(pipeline_reduce_tma_store_producer_state); + } + else { + warpgroup_reg_set(); + + /* no-op */ + + } +#endif + } + + static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + static dim3 get_grid_shape(Params const& params) { + auto [Q, K, D, D_VO, HB] = params.problem_shape; + auto [H, B] = HB; + auto [H_R, H_K] = H; + dim3 grid(ceil_div(K, TileShapeK{}), H_K, B); + return grid; + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bf72843a43ecd36a98ef9b938d872700057e59a9 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp @@ -0,0 +1,1826 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cute/arch/simd_sm100.hpp" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "collective/fmha_common.hpp" + +#include + +namespace cutlass::fmha::kernel { + +using namespace cutlass::fmha::collective; + +using namespace cute; + +template< + class ProblemShape, + class Element, + class ElementAcc, + class TileShape, + class Mask +> +struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized { + + using TileShapeQ = decltype(get<0>(TileShape{})); + using TileShapeK = decltype(get<1>(TileShape{})); + using TileShapeDQK = decltype(get<2>(TileShape{})); + using TileShapeDVO = decltype(get<3>(TileShape{})); + + using TmemAllocator = cute::TMEM::Allocator1Sm; + struct TmemAllocation { + static constexpr uint32_t kDK = 0; // TileShapeK x TileShapeDQK x acc + static constexpr uint32_t kDV = kDK + TileShapeDQK{}; // TileShapeK x TileShapeDVO x acc + static constexpr uint32_t kDQ = kDV + TileShapeDVO{}; // TileShapeQ x TileShapeDQK x acc + static constexpr uint32_t kDP = kDQ; // TileShapeK x TileShapeQ x inp + static constexpr uint32_t kS = kDQ + 65536 * 16; + static constexpr uint32_t kP = kS; + static constexpr uint32_t kTotal = kDQ + TileShapeDQK{}; + }; + + static_assert( + static_cast(TmemAllocation::kTotal) <= TmemAllocator::Sm100TmemCapacityColumns, + "using too much tmem" + ); + + enum class WarpRole { + Empty = 0x0, Load = 0x1, Mma = 0x2, Compute = 0x3, Reduce = 0x4 + }; + + static constexpr unsigned long long kWarpAssignment = 0x12'3333'3333'4444ull; + static constexpr int kNumComputeWarps = 8; + static constexpr int kNumReduceWarps = 4; + + static constexpr int kLoadPerThread = TileShapeQ{} / NumThreadsPerWarp; + static_assert(TileShapeQ{} % NumThreadsPerWarp == 0, "TileShapeQ must be divisible by NumThreadsPerWarp"); + CUTLASS_DEVICE WarpRole warp_idx_to_role(int warp_idx) { + return static_cast((kWarpAssignment >> (4 * warp_idx)) & 0xF); + } + + struct RegisterAllocation { + static constexpr int kWarpgroup0 = 160-8; + static constexpr int kWarpgroup1 = 128; + static constexpr int kWarpgroup2 = 96; + static constexpr int kReduce = kWarpgroup0; + static constexpr int kCompute = kWarpgroup1; + static constexpr int kMma = kWarpgroup2; + static constexpr int kEmpty = kWarpgroup2; + static constexpr int kLoad = kWarpgroup2; + + static_assert(kWarpgroup0 + 2 * kWarpgroup1 + kWarpgroup2 <= 512); + }; + + using ArchTag = cutlass::arch::Sm100; + + using ClusterShape = Shape<_1, _1, _1>; + using Schedule = cutlass::gemm::KernelTmaWarpSpecialized1SmSm100; + + static constexpr int MinBlocksPerMultiprocessor = 1; + static constexpr int kNumWarps = kNumComputeWarps + kNumReduceWarps + 4; + static constexpr int MaxThreadsPerBlock = NumThreadsPerWarp * kNumWarps; + + static constexpr int Alignment = 128 / sizeof_bits_v; + static constexpr int kStages = 2; + + using TensorStrideContiguousK = Stride>; + using TensorStrideContiguousMN = Stride<_1, int, Stride>; + + // compute S + using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStrideContiguousK, Alignment, + Element, TensorStrideContiguousK, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeQK = typename CollectiveMmaQK::TileShape; + using TiledMmaQK = typename CollectiveMmaQK::TiledMma; + + // compute dP + using CollectiveMmaDOV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStrideContiguousK, Alignment, + Element, TensorStrideContiguousK, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeDOV = typename CollectiveMmaDOV::TileShape; + using TiledMmaDOV = typename CollectiveMmaDOV::TiledMma; + + // compute dV + using CollectiveMmaPDO = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // needs to match ordering of S calculation + Element, TensorStrideContiguousK, Alignment, + Element, TensorStrideContiguousMN, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapePDO = typename CollectiveMmaPDO::TileShape; + using TiledMmaPDO = typename CollectiveMmaPDO::TiledMma; + + // compute dK + using CollectiveMmaDSQ = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // somewhat arbitrary since we dump to smem, need to agree with the next one + Element, TensorStrideContiguousK , Alignment, + Element, TensorStrideContiguousMN, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeDSQ = typename CollectiveMmaDSQ::TileShape; + using TiledMmaDSQ = typename CollectiveMmaDSQ::TiledMma; + + // compute dQ + using CollectiveMmaDSK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // somewhat arbitrary since we dump to smem, need to agree with the previous one + Element, TensorStrideContiguousMN, Alignment, + Element, TensorStrideContiguousMN, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeDSK = typename CollectiveMmaDSK::TileShape; + using TiledMmaDSK = typename CollectiveMmaDSK::TiledMma; + + // pipelines are named Pipeline + static constexpr int kStagesComputeSmem = 1; + using PipelineLoadMmaQ = PipelineTmaUmmaAsync<2, ClusterShape>; + using PipelineLoadMmaDO = PipelineTmaUmmaAsync<1, ClusterShape>; + using PipelineLoadComputeLSE = PipelineAsync<1>; + using PipelineLoadComputeSumOdO = PipelineAsync<1>; + using PipelineMmaComputeS = PipelineUmmaAsync<1>; + using PipelineMmaComputeDP = PipelineUmmaAsync<1>; + using PipelineMmaReduceDQ = PipelineUmmaAsync<1>; + using PipelineComputeMmaP = PipelineUmmaConsumerAsync<1>; + using PipelineComputeMmaDS = PipelineUmmaConsumerAsync; + using PipelineMmaComputeDKDV = PipelineUmmaAsync<2>; + static constexpr int kStagesReduceTmaStore = 2; + using PipelineReduceTmaStore = PipelineTmaStore; + + struct PipelineStorage { + alignas(16) typename PipelineLoadMmaQ::SharedStorage load_mma_q; + alignas(16) typename PipelineLoadMmaDO::SharedStorage load_mma_do; + alignas(16) typename PipelineLoadComputeLSE::SharedStorage load_compute_lse; + alignas(16) typename PipelineLoadComputeSumOdO::SharedStorage load_compute_sum_odo; + alignas(16) typename PipelineMmaComputeS::SharedStorage mma_compute_s; + alignas(16) typename PipelineMmaComputeDP::SharedStorage mma_compute_dp; + alignas(16) typename PipelineMmaReduceDQ::SharedStorage mma_reduce_dq; + alignas(16) typename PipelineComputeMmaP::SharedStorage compute_mma_p; + alignas(16) typename PipelineComputeMmaDS::SharedStorage compute_mma_ds; + alignas(16) typename PipelineMmaComputeDKDV::SharedStorage mma_compute_dkdv; + }; + + template + static CUTE_DEVICE constexpr auto restage(Layout const& layout, Stages stages = {}) { + return composition(layout, make_tuple(_, _, _, make_layout(stages))); + } + + using SmemLayoutK = decltype(restage(typename CollectiveMmaQK::SmemLayoutB{})); + using SmemLayoutV = decltype(restage(typename CollectiveMmaDOV::SmemLayoutB{})); + using SmemLayoutQ = decltype(restage(typename CollectiveMmaQK::SmemLayoutA{}, _2{})); + using SmemLayoutDO = decltype(restage(typename CollectiveMmaDOV::SmemLayoutA{}, _1{})); + using SmemLayoutDS = decltype(restage(typename CollectiveMmaDSK::SmemLayoutA{}, Int{})); + using SmemLayoutLSE = Layout>; + using SmemLayoutSumOdO = Layout>; + + using SmemLayoutQT = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutB{}, _2{})); + using SmemLayoutKT = decltype(restage(typename CollectiveMmaDSK::SmemLayoutB{})); + using SmemLayoutDST = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutA{}, Int{})); + using SmemLayoutDOT = decltype(restage(typename CollectiveMmaPDO::SmemLayoutB{}, _1{})); + using SmemLayoutP = decltype(restage(typename CollectiveMmaPDO::SmemLayoutA{}, _1{})); + using SmemLayoutPT = decltype(restage(typename CollectiveMmaDSK::SmemLayoutA{}, _1{})); + + using TileShapeDQ = _32; + using SmemAtomDQ = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::K, ElementAcc, TileShapeQ, TileShapeDQ + >()); + using SmemShapeDQ = Shape>; + using SmemLayoutDQ = decltype(tile_to_shape(SmemAtomDQ{}, SmemShapeDQ{}, Step<_2, _1, _3>{})); + + struct TensorStorage { + union { + alignas(2048) cute::array> smem_k; + alignas(2048) cute::array> smem_k_t; + }; + alignas(2048) cute::array> smem_v; + union { + alignas(2048) cute::array> smem_q; + alignas(2048) cute::array> smem_q_t; + }; + union { + alignas(2048) cute::array> smem_do; + alignas(2048) cute::array> smem_do_t; + }; + union { + alignas(2048) cute::array> smem_ds; + alignas(2048) cute::array> smem_ds_t; + }; + union{ + alignas(2048) cute::array> smem_p; + alignas(2048) cute::array> smem_p_t; + }; + alignas(1024) cute::array> smem_dq; + alignas(16) cute::array> smem_lse; + alignas(16) cute::array> smem_sum_odo; + }; + + static constexpr int kTransactionsBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v); + static constexpr int kTransactionsBytesLoadDO = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutDO{})) * cute::sizeof_bits_v); + + static constexpr int kTransactionsBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v); + static constexpr int kTransactionsBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v); + + struct SharedStorage { + TensorStorage tensors; + PipelineStorage pipelines; + uint32_t tmem_base_ptr; + }; + + // this is tight enough that it won't work with sizeof due to padding for alignment + static constexpr int SharedStorageSize = offsetof(SharedStorage, tmem_base_ptr) + sizeof(uint32_t); + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem"); + + using TensorStride = TensorStrideContiguousK; // S D (H B) + using RowTensorStride = Stride<_1, Stride>; // S (H B) + + struct MainloopArguments { + const Element* ptr_q; + TensorStride stride_q; + const Element* ptr_k; + TensorStride stride_k; + const Element* ptr_v; + TensorStride stride_v; + const Element* ptr_do; + TensorStride stride_do; + + const ElementAcc* ptr_lse; + RowTensorStride stride_lse; + + const ElementAcc* ptr_sum_odo; + RowTensorStride stride_sum_odo; + + ElementAcc* ptr_dq_acc; + TensorStride stride_dq_acc; + + ElementAcc softmax_scale = 1.0f / sqrtf(TileShapeDQK{}); + }; + + using TMA_K = typename CollectiveMmaQK::Params::TMA_B; + using TMA_V = typename CollectiveMmaDOV::Params::TMA_B; + using TMA_Q = typename CollectiveMmaQK::Params::TMA_A; + using TMA_DO = typename CollectiveMmaDOV::Params::TMA_A; + + using TMA_DQ = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{}, + make_tensor((const ElementAcc*)nullptr, make_shape(1, 1, make_shape(1, 1)), TensorStride{}), + SmemLayoutDQ{}(_, _, _0{}) + )); + + struct MainloopParams { + TMA_K tma_load_k; + TMA_V tma_load_v; + TMA_Q tma_load_q; + TMA_DO tma_load_do; + TMA_DQ tma_red_dq; + }; + + struct EpilogueArguments { + Element* ptr_dk; + TensorStride stride_dk; + Element* ptr_dv; + TensorStride stride_dv; + }; + + struct Arguments { + ProblemShape problem_shape; + MainloopArguments mainloop; + EpilogueArguments epilogue; + KernelHardwareInfo hw_info; + }; + + struct Params { + ProblemShape problem_shape; + MainloopArguments mainloop; + MainloopParams mainloop_params; + EpilogueArguments epilogue; + KernelHardwareInfo hw_info; + }; + + + static bool can_implement(Arguments const& args) { + auto [Q, K, D, D_VO, HB] = args.problem_shape; + auto [H, B] = HB; + if (Q <= 0 || K <= 0 || D <= 0 || H <= 0 || B <= 0 || D_VO <= 0) { + return false; + } + if (D % Alignment != 0 || D_VO % Alignment != 0) { + return false; + } + return true; + } + + + static Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return Status::kSuccess; + } + + + static Params to_underlying_arguments(Arguments const& args, void*) { + auto [Q_, K_, D, D_VO, HB] = args.problem_shape; + int Q = Q_; + int K = K_; + + if constexpr (is_variable_length_v) { + Q = Q_.total_length; + } + if constexpr (is_variable_length_v) { + K = K_.total_length; + } + + auto params_kq = CollectiveMmaQK::to_underlying_arguments( + make_shape(Q, K, D, HB), + typename CollectiveMmaQK::Arguments { + args.mainloop.ptr_q, args.mainloop.stride_q, + args.mainloop.ptr_k, args.mainloop.stride_k, + }, /*workspace=*/nullptr); + + auto params_vdo = CollectiveMmaDOV::to_underlying_arguments( + make_shape(Q, K, D_VO, HB), + typename CollectiveMmaDOV::Arguments { + args.mainloop.ptr_do, args.mainloop.stride_do, + args.mainloop.ptr_v, args.mainloop.stride_v, + }, /*workspace=*/nullptr); + + TMA_DQ tma_red_dq = make_tma_copy( + SM90_TMA_REDUCE_ADD{}, + make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q_, D, HB), args.mainloop.stride_dq_acc), + SmemLayoutDQ{}(_, _, _0{}) + ); + + return Params{ + args.problem_shape, + args.mainloop, + MainloopParams{ + params_kq.tma_load_b, + params_vdo.tma_load_b, + params_kq.tma_load_a, + params_vdo.tma_load_a, + tma_red_dq + }, + args.epilogue, + args.hw_info + }; + } + + + template + static CUTLASS_DEVICE auto quantize(T const& input) { + constexpr int AlignmentS = 4; + auto output = make_tensor(shape(input)); + auto input_vec = recast>(input); + auto output_vec = recast>(output); + + cutlass::NumericArrayConverter epilogue_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(input_vec); i++) { + output_vec(i) = epilogue_op(input_vec(i)); + } + + return output; + } + + + template + CUTLASS_DEVICE void load( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineLoadMmaQ& pipeline_load_mma_q, + typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_producer_state, + PipelineLoadMmaDO& pipeline_load_mma_do, + typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_producer_state, + PipelineLoadComputeLSE& pipeline_load_compute_lse, + typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_producer_state, + PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo, + typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_producer_state) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + + using X = Underscore; + + uint16_t mcast_mask = 0; + + auto mK_in = mainloop_params.tma_load_k.get_tma_tensor(make_shape(K, D, HB)); + auto mV_in = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D_VO, HB)); + auto mQ_in = mainloop_params.tma_load_q.get_tma_tensor(make_shape(Q, D, HB)); + auto mDO_in = mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D_VO, HB)); + + auto mK = domain_offset(select<1,2,4>(blk_offset), mK_in); + auto mV = domain_offset(select<1,3,4>(blk_offset), mV_in); + auto mQ = domain_offset(select<0,2,4>(blk_offset), mQ_in); + auto mDO = domain_offset(select<0,3,4>(blk_offset), mDO_in); + + auto gK = local_tile(mK, TileShapeQK{}, make_coord(_,_,_), Step{}); + auto gQ = local_tile(mQ, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); + auto gV = local_tile(mV, TileShapeDOV{}, make_coord(_,_,_), Step{}); + auto gDO = local_tile(mDO, TileShapeDOV{}, make_coord(_,_,_), Step<_1, X, _1>{}); + + ThrMMA cta_mma_kq = TiledMmaQK{}.get_slice(_0{}); + ThrMMA cta_mma_vdo = TiledMmaDOV{}.get_slice(_0{}); + + auto tSTgK = cta_mma_kq.partition_B(gK); + auto tSTgQ = cta_mma_kq.partition_A(gQ); + auto tDPTgV = cta_mma_vdo.partition_B(gV); + auto tDPTgDO = cta_mma_vdo.partition_A(gDO); + + auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{}); + auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{}); + auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{}); + + auto [tKgK_mkl, tKsK] = tma_partition( + mainloop_params.tma_load_k, _0{}, make_layout(_1{}), + group_modes<0,3>(sK), group_modes<0,3>(tSTgK)); + auto [tQgQ_mkl, tQsQ] = tma_partition( + mainloop_params.tma_load_q, _0{}, make_layout(_1{}), + group_modes<0,3>(sQ), group_modes<0,3>(tSTgQ)); + auto [tVgV_mkl, tVsV] = tma_partition( + mainloop_params.tma_load_v, _0{}, make_layout(_1{}), + group_modes<0,3>(sV), group_modes<0,3>(tDPTgV)); + auto [tDOgDO_mkl, tDOsDO] = tma_partition( + mainloop_params.tma_load_do, _0{}, make_layout(_1{}), + group_modes<0,3>(sDO), group_modes<0,3>(tDPTgDO)); + + // set up lse and sum_odo + + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state); + auto tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state); + + pipeline_load_mma_q.producer_expect_transaction(pipeline_load_mma_q_producer_state, kTransactionsBytesLoadK); + + // load K + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_k.with(*tma_barrier, mcast_mask), + tKgK_mkl(_, blk_coord_k, _0{}, blk_coord_batch), + tKsK(_, _0{}) + ); + } + + // load Q + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask), + tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch), + tQsQ(_, pipeline_load_mma_q_producer_state.index()) + ); + } + + ++pipeline_load_mma_q_producer_state; + + pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state); + + // load LSE + // 32 threads loading kLoadPerThread * 32 values of 32b each + + int thread_idx = threadIdx.x % NumThreadsPerWarp; + int smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * kLoadPerThread; + int gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread; + auto mLSE = make_tensor(mainloop_args.ptr_lse, make_shape(Q, HB), mainloop_args.stride_lse); + for (int i = 0; i < kLoadPerThread; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_lse.begin() + smem_idx + i, + &mLSE(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_lse_producer_state; + + + pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state); + tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state); + + pipeline_load_mma_do.producer_expect_transaction(pipeline_load_mma_do_producer_state, kTransactionsBytesLoadV); + + // load V + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_v.with(*tma_barrier, mcast_mask), + tVgV_mkl(_, blk_coord_k, _0{}, blk_coord_batch), + tVsV(_, _0{}) + ); + } + + // load dO + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask), + tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch), + tDOsDO(_, pipeline_load_mma_do_producer_state.index()) + ); + } + + ++pipeline_load_mma_do_producer_state; + + pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state); + + // load sum_OdO + smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * kLoadPerThread; + gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread; + auto mSumOdO = make_tensor(mainloop_args.ptr_sum_odo, make_shape(Q, HB), mainloop_args.stride_sum_odo); + for (int i = 0; i < kLoadPerThread; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_sum_odo.begin() + smem_idx + i, + &mSumOdO(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_sum_odo_producer_state; + + iter_count -= 1; + iter_index += 1; + + while (iter_count > 0) { + pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state); + tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state); + + // load Q + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask), + tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch), + tQsQ(_, pipeline_load_mma_q_producer_state.index()) + ); + } + + ++pipeline_load_mma_q_producer_state; + + pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state); + + // load LSE + smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * kLoadPerThread; + gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread; + for (int i = 0; i < kLoadPerThread; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_lse.begin() + smem_idx + i, + &mLSE(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_lse_producer_state; + + pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state); + tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state); + + // load dO + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask), + tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch), + tDOsDO(_, pipeline_load_mma_do_producer_state.index()) + ); + } + + ++pipeline_load_mma_do_producer_state; + + pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state); + + // load sum_OdO + smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * kLoadPerThread; + gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread; + for (int i = 0; i < kLoadPerThread; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_sum_odo.begin() + smem_idx + i, + &mSumOdO(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_sum_odo_producer_state; + + iter_count -= 1; + iter_index += 1; + } + } + + + template + CUTLASS_DEVICE void mma( + BlkCoord const& blk_coord, + ProblemShape_ const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + TensorStorage& shared_tensors, + PipelineLoadMmaQ& pipeline_load_mma_q, + typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state, + PipelineLoadMmaDO& pipeline_load_mma_do, + typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_consumer_state, + PipelineMmaComputeS& pipeline_mma_compute_s, + typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_producer_state, + PipelineMmaComputeDP& pipeline_mma_compute_dp, + typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_producer_state, + PipelineMmaReduceDQ& pipeline_mma_reduce_dq, + typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_producer_state, + PipelineComputeMmaP& pipeline_compute_mma_p, + typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_consumer_state, + PipelineComputeMmaDS& pipeline_compute_mma_ds, + typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_consumer_state, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_producer_state) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + + auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{}); + auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{}); + auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{}); + + auto sQT = make_tensor(make_smem_ptr(shared_tensors.smem_q_t.begin()), SmemLayoutQT{}); + auto sKT = make_tensor(make_smem_ptr(shared_tensors.smem_k_t.begin()), SmemLayoutKT{}); + auto sDS = make_tensor(make_smem_ptr(shared_tensors.smem_ds.begin()), SmemLayoutDS{}); + auto sDST = make_tensor(make_smem_ptr(shared_tensors.smem_ds_t.begin()), SmemLayoutDST{}); + auto sP = make_tensor(make_smem_ptr(shared_tensors.smem_p.begin()), SmemLayoutP{}); + auto sDOT = make_tensor(make_smem_ptr(shared_tensors.smem_do_t.begin()), SmemLayoutDOT{}); + + Tensor tSTrK = TiledMmaQK::make_fragment_B(sK); + Tensor tSTrQ = TiledMmaQK::make_fragment_A(sQ); + + Tensor tDPTrV = TiledMmaDOV::make_fragment_B(sV); + Tensor tDPTrDO = TiledMmaDOV::make_fragment_A(sDO); + + Tensor tDQrDS = TiledMmaDSK::make_fragment_A(sDS); + Tensor tDQrKT = TiledMmaDSK::make_fragment_B(sKT); + + Tensor tDKrDST = TiledMmaDSQ::make_fragment_A(sDST); + Tensor tDKrQT = TiledMmaDSQ::make_fragment_B(sQT); + + Tensor tDVrP = TiledMmaPDO::make_fragment_A(sP); + Tensor tDVrDOT = TiledMmaPDO::make_fragment_B(sDOT); + + TiledMmaQK tiled_mma_qk; + TiledMmaDOV tiled_mma_dov; + TiledMmaDSK tiled_mma_dsk; + TiledMmaDSQ tiled_mma_dsq; + TiledMmaPDO tiled_mma_pdo; + + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::Zero; + tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::Zero; + + Tensor tSTtST = partition_fragment_C(tiled_mma_qk, select<0,1>(TileShapeQK{})); + tSTtST.data() = TmemAllocation::kS; + + Tensor tDPTtDPT = partition_fragment_C(tiled_mma_dov, select<0,1>(TileShapeDOV{})); + tDPTtDPT.data() = TmemAllocation::kDP; + + Tensor tDQtDQ = partition_fragment_C(tiled_mma_dsk, select<0,1>(TileShapeDSK{})); + tDQtDQ.data() = TmemAllocation::kDQ; + + Tensor tDKtDK = partition_fragment_C(tiled_mma_dsq, select<0,1>(TileShapeDSQ{})); + tDKtDK.data() = TmemAllocation::kDK; + + Tensor tDVtDV = partition_fragment_C(tiled_mma_pdo, select<0,1>(TileShapePDO{})); + tDVtDV.data() = TmemAllocation::kDV; + + auto pipeline_load_mma_q_release_state = pipeline_load_mma_q_consumer_state; + + pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state); + pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state); + + // S = Q*K + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) { + cute::gemm(tiled_mma_qk, + tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()), + tSTrK(_,_,k_block,_0{}), + tSTtST); + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One; + } + + ++pipeline_load_mma_q_consumer_state; + + pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state); + ++pipeline_mma_compute_s_producer_state; + + pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state); + + pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state); + pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state); + + // dP = dO*V + tiled_mma_dov.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) { + cute::gemm(tiled_mma_dov, + tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDPTrV(_,_,k_block,_0{}), + tDPTtDPT); + tiled_mma_dov.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state); + ++pipeline_mma_compute_dp_producer_state; + + pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state); + + // dV = P*dO + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) { + cute::gemm(tiled_mma_pdo, + tDVrP(_,_,k_block,_0{}), + tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDVtDV); + tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state); + ++pipeline_compute_mma_p_consumer_state; + + pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state); + ++pipeline_load_mma_do_consumer_state; + + iter_count -= 1; + + // in tmem, S & P overlap + // and dP and dQ overlap + // so we need to acquire dQ and dP at the same time + while (iter_count > 0) { + pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state); + pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state); + + // S = Q*K + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) { + cute::gemm(tiled_mma_qk, + tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()), + tSTrK(_,_,k_block,_0{}), + tSTtST); + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One; + } + + ++pipeline_load_mma_q_consumer_state; + + pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state); + ++pipeline_mma_compute_s_producer_state; + + pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state); + + // we need to acquire dP here, because tmem dQ == tmem dP + pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state); + + // dQ = dS*K + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) { + cute::gemm(tiled_mma_dsk, + tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDQrKT(_,_,k_block,_0{}), + tDQtDQ); + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state); + ++pipeline_mma_reduce_dq_producer_state; + + // dK = dS*Q + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) { + cute::gemm(tiled_mma_dsq, + tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()), + tDKtDK); + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state); + ++pipeline_load_mma_q_release_state; + + pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state); + ++pipeline_compute_mma_ds_consumer_state; + + // we grab dq here, because in tmem dq == dp + pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state); + + pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state); + + // dP = dO*V + tiled_mma_dov.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) { + cute::gemm(tiled_mma_dov, + tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDPTrV(_,_,k_block,_0{}), + tDPTtDPT); + tiled_mma_dov.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state); + ++pipeline_mma_compute_dp_producer_state; + + pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state); + + // dV = P*dO + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) { + cute::gemm(tiled_mma_pdo, + tDVrP(_,_,k_block,_0{}), + tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDVtDV); + tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state); + ++pipeline_compute_mma_p_consumer_state; + + pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state); + ++pipeline_load_mma_do_consumer_state; + + iter_count -= 1; + } + + // signal to the epilogue that dV is ready + pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state); + pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state); + ++pipeline_mma_compute_dkdv_producer_state; + + pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state); + + pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state); + + // dK = dS*Q + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) { + cute::gemm(tiled_mma_dsq, + tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()), + tDKtDK); + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One; + } + + // signal to epilgue that dK is ready + pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state); + ++pipeline_mma_compute_dkdv_producer_state; + + // we've already acquired mma_reduce_dq in the loop + + // dQ = dS*K + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) { + cute::gemm(tiled_mma_dsk, + tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDQrKT(_,_,k_block,_0{}), + tDQtDQ); + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state); + ++pipeline_mma_reduce_dq_producer_state; + + pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state); + ++pipeline_load_mma_q_release_state; + + pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state); + ++pipeline_compute_mma_ds_consumer_state; + } + + + + template + CUTLASS_DEVICE void store( + TensorG gmem, + TensorR const& regs, + TensorC const& coord, + TensorShape const& tensor_shape) { + + Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); }); + + auto copy_op = make_cotiled_copy( + Copy_Atom, Element>{}, + make_layout(make_shape(_1{}, Int{})), + regs.layout() + ); + auto thr_copy = copy_op.get_slice(_0{}); + + Tensor quantized_regs = quantize(regs); + Tensor tCr = thr_copy.partition_S(quantized_regs); + Tensor tCg = thr_copy.partition_D(gmem); + Tensor tPc = thr_copy.partition_D(preds); + + copy_if(copy_op, tPc, tCr, tCg); + } + + + template + CUTLASS_DEVICE void epilogue_clear( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueArguments const& epilogue_args) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk); + auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in); + auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDK = domain_offset( + make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapeDSQ{})) + ); + + auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv); + auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in); + auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDV = domain_offset( + make_coord(blk_coord_k * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapePDO{})) + ); + + for (int i = threadIdx.x; i < size(gDK); i += blockDim.x) { + if (elem_less(cDK(i), select<1,2>(problem_shape))) { + gDK(i) = Element(0); + } + } + for (int i = threadIdx.x; i < size(gDV); i += blockDim.x) { + if (elem_less(cDV(i), select<1,3>(problem_shape))) { + gDV(i) = Element(0); + } + } + + } + + + template + CUTLASS_DEVICE void epilogue( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueArguments const& epilogue_args, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + auto load_op = SM100_TMEM_LOAD_32dp32b16x{}; + + auto tDKtDK = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{}); + tDKtDK.data() = TmemAllocation::kDK; + + auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk); + auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in); + auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDK = domain_offset( + make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapeDSQ{})) + ); + + constexpr int kNumWarpgroups = kNumComputeWarps / 4; + int dp_idx = threadIdx.x % 128; + int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128; + + auto split_wg = [&](auto const& t) { + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int{}, size<2>(t) / Int{})))); + return p(_, _, make_coord(wg_idx, _)); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int{}, size<3>(t) / Int{})))); + return p(_, _, _, make_coord(wg_idx, _)); + } + }; + + auto tiled_t2r_dk = make_tmem_copy(load_op, tDKtDK); + auto thread_t2r_dk = tiled_t2r_dk.get_slice(dp_idx); + + Tensor tTR_cDK = split_wg(thread_t2r_dk.partition_D(cDK)); + Tensor tTR_gDK = split_wg(thread_t2r_dk.partition_D(gDK)); + Tensor tTR_rDK = make_tensor(shape(tTR_cDK)); + Tensor tTR_tDK = split_wg(thread_t2r_dk.partition_S(tDKtDK)); + + auto tDVtDV = partition_fragment_C(TiledMmaPDO{}, select<0,1>(TileShapePDO{}))(make_coord(_,_),_0{},_0{}); + tDVtDV.data() = TmemAllocation::kDV; + + auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv); + auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in); + auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDV = domain_offset( + make_coord(blk_coord_k * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapePDO{})) + ); + + auto tiled_t2r_dv = make_tmem_copy(load_op, tDVtDV); + auto thread_t2r_dv = tiled_t2r_dv.get_slice(dp_idx); + + Tensor tTR_cDV = split_wg(thread_t2r_dv.partition_D(cDV)); + Tensor tTR_gDV = split_wg(thread_t2r_dv.partition_D(gDV)); + Tensor tTR_rDV = make_tensor(shape(tTR_cDV)); + Tensor tTR_tDV = split_wg(thread_t2r_dv.partition_S(tDVtDV)); + + pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state); + + // load tDVtDV + cute::copy(tiled_t2r_dv, tTR_tDV, tTR_rDV); + + // store tDVgDV + store(tTR_gDV, tTR_rDV, tTR_cDV, select<1,3>(problem_shape)); + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state); + ++pipeline_mma_compute_dkdv_consumer_state; + + pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state); + + // load tDKtDK + cute::copy(tiled_t2r_dk, tTR_tDK, tTR_rDK); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rDK); i++) { + tTR_rDK(i) = mainloop_args.softmax_scale * tTR_rDK(i); + } + + // store tDKgDK + store(tTR_gDK, tTR_rDK, tTR_cDK, select<1,2>(problem_shape)); + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state); + ++pipeline_mma_compute_dkdv_consumer_state; + + } + + + template + CUTLASS_DEVICE void compute( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + EpilogueArguments const& epilogue_args, + TensorStorage& shared_tensors, + PipelineLoadComputeLSE& pipeline_load_compute_lse, + typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_consumer_state, + PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo, + typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_consumer_state, + PipelineMmaComputeS& pipeline_mma_compute_s, + typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_consumer_state, + PipelineMmaComputeDP& pipeline_mma_compute_dp, + typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_consumer_state, + PipelineComputeMmaP& pipeline_compute_mma_p, + typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_producer_state, + PipelineComputeMmaDS& pipeline_compute_mma_ds, + typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_producer_state, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) { + + + auto [Q, K, D, D_VO, HB] = problem_shape; + + // in tmem, S & P overlap + // and dP and dQ overlap + + // there are two compute wg's that cooperatively compute softmax + // they are striped by this tmem atom, i.e. wg0 has 16 elems, then wg1 etc + + auto load_op = SM100_TMEM_LOAD_16dp32b32x{}; + + Tensor tSTtST = partition_fragment_C(TiledMmaQK{}, select<0,1>(TileShapeQK{}))(make_coord(_,_),_0{},_0{}); + tSTtST.data() = TmemAllocation::kS; + + Tensor tDPTtDPT = partition_fragment_C(TiledMmaDOV{}, select<0,1>(TileShapeDOV{}))(make_coord(_,_),_0{},_0{}); + tDPTtDPT.data() = TmemAllocation::kDP; + + Tensor cST = make_identity_tensor(take<0,2>(TileShapeQK{})); + Tensor cDPT = make_identity_tensor(take<0,2>(TileShapeDOV{})); + Tensor cPT = make_identity_tensor(take<0,2>(TileShapeQK{})); + + constexpr int kNumWarpgroups = kNumComputeWarps / 4; + int dp_idx = threadIdx.x % 128; + int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128; + auto tiled_t2r = make_tmem_copy(load_op, tSTtST); + auto thread_t2r = tiled_t2r.get_slice(dp_idx); + + auto split_wg = [&](auto const& t) { + if constexpr (decltype(size<1>(t))::value > 1) { + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int{}, size<1>(t) / Int{}), size<2>(t)))); + return p(_, make_coord(wg_idx, _), _); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int{}, size<1>(t) / Int{}), size<2>(t), size<3>(t)))); + return p(_, make_coord(wg_idx, _), _, _); + } + } + else { + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int{}, size<2>(t) / Int{})))); + return p(_, _, make_coord(wg_idx, _)); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int{}, size<3>(t) / Int{})))); + return p(_, _, _, make_coord(wg_idx, _)); + } + } + }; + + Tensor tTR_cST_p = thread_t2r.partition_D(cST); + Tensor tTR_cST = split_wg(tTR_cST_p); + Tensor tTR_rST = make_tensor(shape(tTR_cST)); + Tensor tTR_tST = split_wg(thread_t2r.partition_S(tSTtST)); + + Tensor tTR_cDPT_p = thread_t2r.partition_D(cDPT); + Tensor tTR_cPT_p = thread_t2r.partition_D(cPT); + Tensor tTR_cDPT = split_wg(tTR_cDPT_p); + Tensor tTR_rDPT = make_tensor(shape(tTR_cDPT)); + Tensor tTR_tDPT = split_wg(thread_t2r.partition_S(tDPTtDPT)); + + Tensor sLSE = make_tensor(make_smem_ptr(shared_tensors.smem_lse.begin()), SmemLayoutLSE{}); + Tensor sSumOdO = make_tensor(make_smem_ptr(shared_tensors.smem_sum_odo.begin()), SmemLayoutSumOdO{}); + + bool is_residual_k = get<1>(blk_coord) * TileShapeK{} + TileShapeK{} >= get<1>(problem_shape); + int last_iter = iter_count - 1 + iter_index; + + CUTLASS_PRAGMA_NO_UNROLL + while (iter_count > 0) { + // wait for S and P + pipeline_mma_compute_s.consumer_wait(pipeline_mma_compute_s_consumer_state); + pipeline_compute_mma_p.producer_acquire(pipeline_compute_mma_p_producer_state); + // wait for LSE + pipeline_load_compute_lse.consumer_wait(pipeline_load_compute_lse_consumer_state); + + auto dispatch_bool = [](bool b, auto fn) { + if (b) { + fn(cute::true_type{}); + } + else { + fn(cute::false_type{}); + } + }; + + bool leading_causal_masking = false; + if constexpr (std::is_base_of_v, Mask>) { + leading_causal_masking = warp_uniform(iter_index == get<1>(blk_coord)); + } else if constexpr (std::is_base_of_v, Mask>) { + int offset = get<1>(problem_shape) - get<0>(problem_shape); + int kv_left = get<1>(blk_coord) * TileShapeK{}; + int kv_right = kv_left + TileShapeK{} - 1; + int q_left = iter_index * TileShapeQ{} + offset; + int q_right = q_left + TileShapeQ{} - 1; + + leading_causal_masking = warp_uniform(!((q_left > kv_right) || (q_right < kv_left))); + } + bool trailing_residual_masking = false; + if constexpr (std::is_base_of_v) { + trailing_residual_masking = warp_uniform((iter_index == last_iter) || is_residual_k); + } + + dispatch_bool(leading_causal_masking || trailing_residual_masking, [&](auto is_masked_tile) { + + // compute P = softmax(S, LSE) + cute::copy(tiled_t2r, tTR_tST, tTR_rST); + + if constexpr (decltype(is_masked_tile)::value) { + Mask{}.apply_mask(tTR_rST, [&](int i) { + auto c_transpose = tTR_cST(i); + return make_coord(get<0>(c_transpose) + iter_index * TileShapeQ{}, get<1>(c_transpose) + get<1>(blk_coord) * TileShapeK{}); + }, problem_shape); + } + + ElementAcc log2_e = static_cast(M_LOG2E); + float2 softmax_scale_log2_e; + softmax_scale_log2_e.x = mainloop_args.softmax_scale * log2_e; + softmax_scale_log2_e.y = mainloop_args.softmax_scale * log2_e; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rST); i += 2) { + float2 acc; + float2 lse; + float2 out; + acc.x = tTR_rST(i); + acc.y = tTR_rST(i + 1); + lse.x = sLSE(get<0>(tTR_cST(i)), pipeline_load_compute_lse_consumer_state.index()); + lse.y = sLSE(get<0>(tTR_cST(i+1)), pipeline_load_compute_lse_consumer_state.index()); + cute::fma(out, softmax_scale_log2_e, acc, lse); + tTR_rST(i) = ::exp2f(out.x); + tTR_rST(i+1) = ::exp2f(out.y); + } + + auto tRT_rST = quantize(tTR_rST); + + Tensor sP = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_p.begin()), SmemLayoutP{}) + (_, _, _, pipeline_compute_mma_p_producer_state.index()); + + cutlass::arch::fence_view_async_tmem_load(); + cutlass::arch::NamedBarrier( + kNumComputeWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransformBarrier + ).arrive_and_wait(); + + auto sP_pi = as_position_independent_swizzle_tensor(sP); + + auto thread_layout = make_ordered_layout( + make_shape(_64{}, _32{}, _2{}, _2{}), + make_stride(_3{}, _0{}, _1{}, _2{}) + ); + auto sP_pi_slice_p = sP_pi.compose(thread_layout)(((dp_idx/32) * 16) + (dp_idx % 16) , _, (dp_idx % 32 / 16), _).compose(make_layout(shape(tTR_cPT_p))); + auto sP_pi_slice = split_wg(sP_pi_slice_p); + copy_aligned(tRT_rST, sP_pi_slice); + }); + + // notify for P + cutlass::arch::fence_view_async_shared(); + pipeline_compute_mma_p.producer_commit(pipeline_compute_mma_p_producer_state); + ++pipeline_compute_mma_p_producer_state; + // release S + pipeline_mma_compute_s.consumer_release(pipeline_mma_compute_s_consumer_state); + ++pipeline_mma_compute_s_consumer_state; + // release LSE + pipeline_load_compute_lse.consumer_release(pipeline_load_compute_lse_consumer_state); + ++pipeline_load_compute_lse_consumer_state; + + // wait for OdO + pipeline_load_compute_sum_odo.consumer_wait(pipeline_load_compute_sum_odo_consumer_state); + // wait for dP + pipeline_mma_compute_dp.consumer_wait(pipeline_mma_compute_dp_consumer_state); + + // wait for dS + // in principle, we could defer waiting for dS, and move in the freeing of dP + // however, that would force us to keep dS in registers longer + pipeline_compute_mma_ds.producer_acquire(pipeline_compute_mma_ds_producer_state); + + // compute dS = dsoftmax(P, dP, sum_OdO) + cute::copy(tiled_t2r, tTR_tDPT, tTR_rDPT); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rDPT); i += 2) { + float2 st; + st.x = tTR_rST(i); + st.y = tTR_rST(i+1); + float2 dpt; + dpt.x = tTR_rDPT(i); + dpt.y = tTR_rDPT(i+1); + float2 odo; + odo.x = sSumOdO(get<0>(tTR_cDPT(i)), pipeline_load_compute_sum_odo_consumer_state.index()); + odo.y = sSumOdO(get<0>(tTR_cDPT(i+1)), pipeline_load_compute_sum_odo_consumer_state.index()); + float2 dif; + // sum odo is negated during preprocess + cute::add(dif, dpt, odo); + float2 out; + cute::mul(out, dif, st); + tTR_rDPT(i) = out.x; + tTR_rDPT(i+1) = out.y; + } + + auto tTR_rDST = quantize(tTR_rDPT); + + // release dP + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dp.consumer_release(pipeline_mma_compute_dp_consumer_state); + ++pipeline_mma_compute_dp_consumer_state; + + Tensor sDS = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_ds_t.begin()), SmemLayoutDST{}) + (_, _, _, pipeline_compute_mma_ds_producer_state.index()); + + auto thread_layout = make_ordered_layout( + make_shape(_64{}, _32{}, _2{}, _2{}), + make_stride(_3{}, _0{}, _1{}, _2{}) + ); + auto sDS_pi = as_position_independent_swizzle_tensor(sDS); + auto sDS_pi_slice_p = sDS_pi.compose(thread_layout)(((dp_idx/32) * 16) + (dp_idx % 16) , _, (dp_idx % 32 / 16), _).compose(make_layout(shape (tTR_cDPT_p))); + auto sDS_pi_slice = split_wg(sDS_pi_slice_p); + + copy_aligned(tTR_rDST, sDS_pi_slice); + + // notify for dS + cutlass::arch::fence_view_async_shared(); + pipeline_compute_mma_ds.producer_commit(pipeline_compute_mma_ds_producer_state); + ++pipeline_compute_mma_ds_producer_state; + // release OdO + pipeline_load_compute_sum_odo.consumer_release(pipeline_load_compute_sum_odo_consumer_state); + ++pipeline_load_compute_sum_odo_consumer_state; + + iter_count -= 1; + iter_index += 1; + } + + epilogue( + blk_coord, blk_offset, problem_shape, mainloop_args, epilogue_args, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state + ); + } + + template + CUTLASS_DEVICE void reduce( + BlkCoord const& blk_coord, + ProblemShape_ const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineMmaReduceDQ& pipeline_mma_reduce_dq, + typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_consumer_state, + PipelineReduceTmaStore& pipeline_reduce_tma_store, + typename PipelineReduceTmaStore::PipelineState& pipeline_reduce_tma_store_producer_state) { + + using X = Underscore; + + auto [Q, K, D, D_VO, HB] = problem_shape; + + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + // must match TileShapeDQ + auto load_op = SM100_TMEM_LOAD_16dp32b16x{}; + + auto tDQtDQ = partition_fragment_C(TiledMmaDSK{}, select<0,1>(TileShapeDSK{}))(make_coord(_,_),_0{},_0{}); + tDQtDQ.data() = TmemAllocation::kDQ; + + Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q, D, HB)); + auto gDQ = local_tile(mDQ, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}) + (_, _, _, _0{}, blk_coord_batch); + + Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{})); + + Tensor sDQ = make_tensor(make_smem_ptr(shared_tensors.smem_dq.begin()), SmemLayoutDQ{}); + + int thread_idx = threadIdx.x % (kNumReduceWarps * NumThreadsPerWarp); + auto tiled_t2r = make_tmem_copy(load_op, tDQtDQ); + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + + Tensor tTR_cDQ = thread_t2r.partition_D(cDQ); + Tensor tTR_gDQ = thread_t2r.partition_D(gDQ); + Tensor tTR_sDQ = thread_t2r.partition_D(sDQ); + Tensor tTR_tDQ = thread_t2r.partition_S(tDQtDQ); + + auto block_tma = mainloop_params.tma_red_dq.get_slice(_0{}); + + Tensor tDQsDQ = block_tma.partition_S(sDQ); + Tensor tDQcDQ = block_tma.partition_S(cDQ); + Tensor tDQgDQ = block_tma.partition_D(gDQ); + + int lane_predicate = (threadIdx.x % (kNumReduceWarps * NumThreadsPerWarp)) == 0; + + while (iter_count > 0) { + pipeline_mma_reduce_dq.consumer_wait(pipeline_mma_reduce_dq_consumer_state); + + Tensor tTR_rDQ = make_tensor(shape(tTR_cDQ)); + + // load dQ from tmem to rmem + cute::copy(tiled_t2r, tTR_tDQ, tTR_rDQ); + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_reduce_dq.consumer_release(pipeline_mma_reduce_dq_consumer_state); + ++pipeline_mma_reduce_dq_consumer_state; + + // we don't have enough smem to dump it all to smem, so we do it in stages + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<2>(tTR_cDQ); i++) { + if (lane_predicate) { + pipeline_reduce_tma_store.producer_acquire(pipeline_reduce_tma_store_producer_state); + } + // wait in all threads for the acquire to complete + cutlass::arch::NamedBarrier( + kNumReduceWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransposeBarrier + ).arrive_and_wait(); + + cute::copy(tTR_rDQ(_, _, i), tTR_sDQ(_, _, _0{}, pipeline_reduce_tma_store_producer_state.index())); + + // wait for the stores to all be visible to the TMA + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier( + kNumReduceWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransposeBarrier + ).arrive_and_wait(); + if (lane_predicate) { + // launch tma store + copy(mainloop_params.tma_red_dq, tDQsDQ(_,_,_0{}, pipeline_reduce_tma_store_producer_state.index()), tDQgDQ(_,_,i,iter_index)); + pipeline_reduce_tma_store.producer_commit(pipeline_reduce_tma_store_producer_state); + } + + ++pipeline_reduce_tma_store_producer_state; + } + + iter_count -= 1; + iter_index += 1; + } + } + + + CUTLASS_DEVICE void operator()(Params const& params, char* smem) { +#if (! defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && ! defined(CUTLASS_ARCH_MMA_SM100F_ENABLED)) + printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); +#else + int warp_idx = cutlass::canonical_warp_idx_sync(); + auto role = warp_idx_to_role(warp_idx); + uint32_t lane_predicate = cute::elect_one_sync(); + + if (role == WarpRole::Load && lane_predicate) { + prefetch_tma_descriptor(params.mainloop_params.tma_load_q.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_k.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_v.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_do.get_tma_descriptor()); + } + + SharedStorage& shared_storage = *reinterpret_cast(smem); + + int initializing_warp = 0; + typename PipelineLoadMmaQ::Params pipeline_load_mma_q_params; + if (role == WarpRole::Load) { + pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Producer; + } + if (role == WarpRole::Mma) { + pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Consumer; + } + pipeline_load_mma_q_params.is_leader = lane_predicate && (role == WarpRole::Load); + // Also loads K in the first iteration + pipeline_load_mma_q_params.transaction_bytes = kTransactionsBytesLoadQ; + pipeline_load_mma_q_params.initializing_warp = initializing_warp++; + PipelineLoadMmaQ pipeline_load_mma_q(shared_storage.pipelines.load_mma_q, pipeline_load_mma_q_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineLoadMmaDO::Params pipeline_load_mma_do_params; + if (role == WarpRole::Load) { + pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Producer; + } + if (role == WarpRole::Mma) { + pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Consumer; + } + pipeline_load_mma_do_params.is_leader = lane_predicate && (role == WarpRole::Load); + // Also loads V in the first iteration + pipeline_load_mma_do_params.transaction_bytes = kTransactionsBytesLoadDO; + pipeline_load_mma_do_params.initializing_warp = initializing_warp++; + PipelineLoadMmaDO pipeline_load_mma_do(shared_storage.pipelines.load_mma_do, pipeline_load_mma_do_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineLoadComputeLSE::Params pipeline_load_compute_lse_params; + if (role == WarpRole::Load) { + pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Consumer; + } + pipeline_load_compute_lse_params.producer_arv_count = NumThreadsPerWarp; + pipeline_load_compute_lse_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp; + pipeline_load_compute_lse_params.initializing_warp = initializing_warp++; + PipelineLoadComputeLSE pipeline_load_compute_lse( + shared_storage.pipelines.load_compute_lse, + pipeline_load_compute_lse_params, + /*barrier init*/ cute::true_type{}); + + typename PipelineLoadComputeSumOdO::Params pipeline_load_compute_sum_odo_params; + if (role == WarpRole::Load) { + pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Consumer; + } + pipeline_load_compute_sum_odo_params.producer_arv_count = NumThreadsPerWarp; + pipeline_load_compute_sum_odo_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp; + pipeline_load_compute_sum_odo_params.initializing_warp = initializing_warp++; + PipelineLoadComputeSumOdO pipeline_load_compute_sum_odo( + shared_storage.pipelines.load_compute_sum_odo, + pipeline_load_compute_sum_odo_params, + /*barrier init*/ cute::true_type{}); + + typename PipelineMmaComputeS::Params pipeline_mma_compute_s_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Consumer; + } + pipeline_mma_compute_s_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_compute_s_params.initializing_warp = initializing_warp++; + PipelineMmaComputeS pipeline_mma_compute_s( + shared_storage.pipelines.mma_compute_s, + pipeline_mma_compute_s_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineMmaComputeDP::Params pipeline_mma_compute_dp_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Consumer; + } + pipeline_mma_compute_dp_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_compute_dp_params.initializing_warp = initializing_warp++; + PipelineMmaComputeDP pipeline_mma_compute_dp( + shared_storage.pipelines.mma_compute_dp, + pipeline_mma_compute_dp_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineMmaReduceDQ::Params pipeline_mma_reduce_dq_params; + if (role == WarpRole::Mma) { + pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Producer; + } + if (role == WarpRole::Reduce) { + pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Consumer; + } + pipeline_mma_reduce_dq_params.consumer_arv_count = kNumReduceWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_reduce_dq_params.initializing_warp = initializing_warp++; + PipelineMmaReduceDQ pipeline_mma_reduce_dq( + shared_storage.pipelines.mma_reduce_dq, + pipeline_mma_reduce_dq_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineComputeMmaP::Params pipeline_compute_mma_p_params; + if (role == WarpRole::Mma) { + pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Consumer; + } + if (role == WarpRole::Compute) { + pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Producer; + } + pipeline_compute_mma_p_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_compute_mma_p_params.consumer_arv_count = 1; + pipeline_compute_mma_p_params.initializing_warp = initializing_warp++; + PipelineComputeMmaP pipeline_compute_mma_p( + shared_storage.pipelines.compute_mma_p, + pipeline_compute_mma_p_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineComputeMmaDS::Params pipeline_compute_mma_ds_params; + if (role == WarpRole::Mma) { + pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Consumer; + } + if (role == WarpRole::Compute) { + pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Producer; + } + pipeline_compute_mma_ds_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_compute_mma_ds_params.consumer_arv_count = 1; + pipeline_compute_mma_ds_params.initializing_warp = initializing_warp++; + PipelineComputeMmaDS pipeline_compute_mma_ds( + shared_storage.pipelines.compute_mma_ds, + pipeline_compute_mma_ds_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineMmaComputeDKDV::Params pipeline_mma_compute_dkdv_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Consumer; + } + pipeline_mma_compute_dkdv_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_compute_dkdv_params.initializing_warp = initializing_warp++; + PipelineMmaComputeDKDV pipeline_mma_compute_dkdv( + shared_storage.pipelines.mma_compute_dkdv, + pipeline_mma_compute_dkdv_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + PipelineReduceTmaStore pipeline_reduce_tma_store; + + TmemAllocator tmem_allocator; + + pipeline_init_arrive_relaxed(size(ClusterShape{})); + + pipeline_load_mma_q.init_masks(ClusterShape{}); + pipeline_load_mma_do.init_masks(ClusterShape{}); + pipeline_mma_compute_s.init_masks(ClusterShape{}); + pipeline_mma_compute_dp.init_masks(ClusterShape{}); + pipeline_mma_reduce_dq.init_masks(ClusterShape{}); + pipeline_compute_mma_p.init_masks(ClusterShape{}); + pipeline_compute_mma_ds.init_masks(ClusterShape{}); + pipeline_mma_compute_dkdv.init_masks(ClusterShape{}); + + typename decltype(pipeline_load_mma_q)::PipelineState pipeline_load_mma_q_consumer_state; + typename decltype(pipeline_load_mma_do)::PipelineState pipeline_load_mma_do_consumer_state; + typename decltype(pipeline_load_compute_lse)::PipelineState pipeline_load_compute_lse_consumer_state; + typename decltype(pipeline_load_compute_sum_odo)::PipelineState pipeline_load_compute_sum_odo_consumer_state; + typename decltype(pipeline_mma_compute_s)::PipelineState pipeline_mma_compute_s_consumer_state; + typename decltype(pipeline_mma_compute_dp)::PipelineState pipeline_mma_compute_dp_consumer_state; + typename decltype(pipeline_mma_reduce_dq)::PipelineState pipeline_mma_reduce_dq_consumer_state; + typename decltype(pipeline_compute_mma_p)::PipelineState pipeline_compute_mma_p_consumer_state; + typename decltype(pipeline_compute_mma_ds)::PipelineState pipeline_compute_mma_ds_consumer_state; + typename decltype(pipeline_mma_compute_dkdv)::PipelineState pipeline_mma_compute_dkdv_consumer_state; + + auto pipeline_load_mma_q_producer_state = make_producer_start_state(); + auto pipeline_load_mma_do_producer_state = make_producer_start_state(); + auto pipeline_load_compute_lse_producer_state = make_producer_start_state(); + auto pipeline_load_compute_sum_odo_producer_state = make_producer_start_state(); + auto pipeline_mma_compute_s_producer_state = make_producer_start_state(); + auto pipeline_mma_compute_dp_producer_state = make_producer_start_state(); + auto pipeline_mma_reduce_dq_producer_state = make_producer_start_state(); + auto pipeline_compute_mma_p_producer_state = make_producer_start_state(); + auto pipeline_compute_mma_ds_producer_state = make_producer_start_state(); + auto pipeline_mma_compute_dkdv_producer_state = make_producer_start_state(); + auto pipeline_reduce_tma_store_producer_state = make_producer_start_state(); + + pipeline_init_wait(size(ClusterShape{})); + + auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, _0{}, make_coord(blockIdx.y, blockIdx.z)); + auto [problem_shape, blk_offset] = apply_variable_length_offset( + params.problem_shape, + blk_coord + ); + int iter_count = ceil_div(get<0>(problem_shape), TileShapeQ{}); + int iter_start = 0; + if constexpr (std::is_base_of_v, Mask>) { + iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{}; + } else if constexpr (std::is_base_of_v, Mask>) { + int offset = get<1>(problem_shape) - get<0>(problem_shape); + iter_start = max(0, (int(get<1>(blk_coord) * TileShapeK{}) - offset) / (int)TileShapeQ{}); + } + if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) { + return; + } + iter_count -= iter_start; + + if (iter_count <= 0) { + epilogue_clear( + blk_coord, + blk_offset, + problem_shape, + params.mainloop, + params.epilogue + ); + return; + } + + if (role == WarpRole::Load) { + warpgroup_reg_set(); + + load( + blk_coord, + blk_offset, + problem_shape, + iter_start, + iter_count, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_load_mma_q, pipeline_load_mma_q_producer_state, + pipeline_load_mma_do, pipeline_load_mma_do_producer_state, + pipeline_load_compute_lse, pipeline_load_compute_lse_producer_state, + pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_producer_state + ); + + } + else if (role == WarpRole::Mma) { + warpgroup_reg_set(); + + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + + mma( + blk_coord, + problem_shape, + iter_start, + iter_count, + params.mainloop, + shared_storage.tensors, + pipeline_load_mma_q, pipeline_load_mma_q_consumer_state, + pipeline_load_mma_do, pipeline_load_mma_do_consumer_state, + pipeline_mma_compute_s, pipeline_mma_compute_s_producer_state, + pipeline_mma_compute_dp, pipeline_mma_compute_dp_producer_state, + pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_producer_state, + pipeline_compute_mma_p, pipeline_compute_mma_p_consumer_state, + pipeline_compute_mma_ds, pipeline_compute_mma_ds_consumer_state, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_producer_state + ); + + } + else if (role == WarpRole::Compute) { + warpgroup_reg_set(); + + compute( + blk_coord, + blk_offset, + problem_shape, + iter_start, + iter_count, + params.mainloop, + params.epilogue, + shared_storage.tensors, + pipeline_load_compute_lse, pipeline_load_compute_lse_consumer_state, + pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_consumer_state, + pipeline_mma_compute_s, pipeline_mma_compute_s_consumer_state, + pipeline_mma_compute_dp, pipeline_mma_compute_dp_consumer_state, + pipeline_compute_mma_p, pipeline_compute_mma_p_producer_state, + pipeline_compute_mma_ds, pipeline_compute_mma_ds_producer_state, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state + ); + + cutlass::arch::NamedBarrier( + kNumComputeWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier + ).arrive_and_wait(); + + if (warp_idx % kNumComputeWarps == 0) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + } + else if (role == WarpRole::Reduce) { + warpgroup_reg_set(); + + reduce( + blk_coord, + problem_shape, + iter_start, + iter_count, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_consumer_state, + pipeline_reduce_tma_store, pipeline_reduce_tma_store_producer_state + ); + + pipeline_reduce_tma_store.producer_tail(pipeline_reduce_tma_store_producer_state); + } + else { + warpgroup_reg_set(); + + /* no-op */ + + } +#endif + } + + static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + static dim3 get_grid_shape(Params const& params) { + auto [Q, K, D, D_VO, HB] = params.problem_shape; + auto [H, B] = HB; + dim3 grid(ceil_div(K, TileShapeK{}), H, B); + return grid; + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a88b9a871098f27cf569dc20434c1f35a103d9ff --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp @@ -0,0 +1,640 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" +#include "cutlass/arch/arch.h" +#include "cutlass/kernel_hardware_info.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cute/arch/tmem_allocator_sm100.hpp" + +#include "kernel/fmha_options.hpp" +#include "kernel/fmha_tile_scheduler.hpp" +#include "kernel/fmha_causal_tile_scheduler.hpp" +#include "collective/fmha_fusion.hpp" +#include "collective/fmha_common.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; +using namespace cutlass::fmha::collective; + +struct Sm100FmhaCtxKernelWarpspecializedSchedule { + + enum class WarpRole { + Softmax0, + Softmax1, + Correction, + MMA, + Load, + Epilogue, + Empty + }; + + static constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) { + int wg_idx = warp_idx / 4; // warp_idx + if (wg_idx == 0) return WarpRole::Softmax0; // 0 - 3 + if (wg_idx == 1) return WarpRole::Softmax1; // 4 - 7 + if (wg_idx == 2) return WarpRole::Correction; // 8 - 11 + if (warp_idx == 12) return WarpRole::MMA; // 12 + if (warp_idx == 13) return WarpRole::Load; // 13 + if (warp_idx == 14) return WarpRole::Epilogue; // 14 + return WarpRole::Empty; // 15 + } + + static const int NumWarpsSoftmax = 4; + static const int NumWarpsCorrection = 4; + static const int NumWarpsEpilogue = 1; + static const int NumWarpsLoad = 1; + + static const bool kDebugUsingPrintf = false; + static const int NumRegsSoftmax = 192; + static const int NumRegsCorrection = 96 - (kDebugUsingPrintf ? 16 : 0); + static const int NumRegsOther = 32 + (kDebugUsingPrintf ? 16 : 0); + static const int NumRegsEmpty = 24; + + static const int NumWarps = 16; + +}; + + +struct Sm100MlaFwdCtxKernelWarpspecializedSchedule { + + enum class WarpRole { + Softmax0, + Softmax1, + Correction, + MMA, + Load, + Epilogue, + Empty + }; + + static constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) { + int wg_idx = warp_idx / 4; // warp_idx + if (wg_idx == 0) return WarpRole::Softmax0; // 0 - 3 + if (wg_idx == 1) return WarpRole::Softmax1; // 4 - 7 + if (wg_idx == 2) return WarpRole::Correction; // 8 - 11 + if (warp_idx == 12) return WarpRole::MMA; // 12 + if (warp_idx == 13) return WarpRole::Load; // 13 + if (warp_idx == 14) return WarpRole::Epilogue; // 14 + return WarpRole::Empty; // 15 + } + + static const int NumWarpsSoftmax = 4; + static const int NumWarpsCorrection = 4; + static const int NumWarpsEpilogue = 1; + static const int NumWarpsLoad = 1; + + static const bool kDebugUsingPrintf = false; + static const int NumRegsSoftmax = 184; + static const int NumRegsCorrection = 96 - (kDebugUsingPrintf ? 16 : 0); + static const int NumRegsOther = 48 + (kDebugUsingPrintf ? 16 : 0); + static const int NumRegsEmpty = 24; + + static const int NumWarps = 16; + +}; + +template< + class ProblemShapeIn, + class CollectiveMainloop, + class CollectiveEpilogue, + class TileScheduler, + class KernelSchedule = Sm100FmhaCtxKernelWarpspecializedSchedule +> +struct Sm100FmhaFwdKernelTmaWarpspecialized { + + using TileShape = typename CollectiveMainloop::TileShape; + using ProblemShape = ProblemShapeIn; + + using WarpRole = typename KernelSchedule::WarpRole; + + constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) { + return KernelSchedule::warp_idx_to_WarpRole(warp_idx); + } + + static const int NumWarpsSoftmax = KernelSchedule::NumWarpsSoftmax; + static const int NumWarpsCorrection = KernelSchedule::NumWarpsCorrection; + static const int NumWarpsEpilogue = KernelSchedule::NumWarpsEpilogue; + static const int NumWarpsLoad = KernelSchedule::NumWarpsLoad; + + static_assert(NumWarpsEpilogue == CollectiveEpilogue::NumWarpsEpilogue); + static_assert(NumWarpsLoad == CollectiveEpilogue::NumWarpsLoad); + + static const int NumRegsSoftmax = KernelSchedule::NumRegsSoftmax; + static const int NumRegsCorrection = KernelSchedule::NumRegsCorrection; + static const int NumRegsOther = KernelSchedule::NumRegsOther; + static const int NumRegsEmpty = 24; + + static const int NumWarps = KernelSchedule::NumWarps; + + static constexpr bool IsMla = std::is_same_v; + + using ClusterShape = typename CollectiveMainloop::ClusterShape; + + using TmemAllocator = cute::TMEM::Allocator1Sm; + + struct SharedStorage { + using UnionType = union { + typename CollectiveMainloop::TensorStorage mainloop; + typename CollectiveEpilogue::TensorStorage epilogue; + }; + + using StructType = struct { + typename CollectiveMainloop::TensorStorage mainloop; + typename CollectiveEpilogue::TensorStorage epilogue; + }; + + static constexpr bool IsPersistent = std::is_same_v || std::is_same_v; + using MainloopEpilogueStorage = std::conditional_t, + StructType>, + UnionType>; + + MainloopEpilogueStorage mainloop_epilogue; + + struct PipelineStorage { + alignas(16) typename CollectiveMainloop::PipelineQ::SharedStorage load_q; + alignas(16) typename CollectiveMainloop::PipelineKV::SharedStorage load_kv; + alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s0; + alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s1; + alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s0_corr; + alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s1_corr; + alignas(16) typename CollectiveMainloop::PipelineO::SharedStorage mma_corr; + alignas(16) typename CollectiveMainloop::PipelineE::SharedStorage corr_epi; + alignas(16) typename CollectiveMainloop::OrderBarrierSoftmax::SharedStorage order_s01; + } pipelines; + + uint32_t tmem_base_ptr; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + struct Arguments { + ProblemShape problem_shape; + typename CollectiveMainloop::Arguments mainloop; + typename CollectiveEpilogue::Arguments epilogue; + cutlass::KernelHardwareInfo hw_info; + }; + + struct Params { + ProblemShape problem_shape; + typename CollectiveMainloop::Params mainloop; + typename CollectiveEpilogue::Params epilogue; + typename TileScheduler::Params tile_scheduler; + }; + + static const int MinBlocksPerMultiprocessor = 1; + static const int MaxThreadsPerBlock = NumWarps * cutlass::NumThreadsPerWarp; + using ArchTag = cutlass::arch::Sm100; + + static size_t get_workspace_size(Arguments const& args) { return 0; } + static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return cutlass::Status::kSuccess; + } + + static bool can_implement(Arguments const& args) { + return CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + } + + static dim3 get_grid_shape(Params const& params) { + return TileScheduler::get_grid_shape(params.tile_scheduler); + } + + static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return Params{ + args.problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace), + TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, ClusterShape{}, TileShape{}) + }; + } + + CUTLASS_DEVICE auto apply_batch(const Params ¶ms, ProblemShape const& problem_shape, int batch_idx) { + return apply_variable_length(params.problem_shape, batch_idx); + } + + CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { +#if (! defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && ! defined(CUTLASS_ARCH_MMA_SM100F_ENABLED)) + printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); +#else + + TileScheduler tile_scheduler{params.tile_scheduler}; + + int warp_idx = cutlass::canonical_warp_idx_sync(); + auto role = warp_idx_to_WarpRole(warp_idx); + uint32_t lane_predicate = cute::elect_one_sync(); + + if (role == WarpRole::Load && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + } + + if (role == WarpRole::Epilogue && lane_predicate) { + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + SharedStorage& shared_storage = *reinterpret_cast(smem); + + auto get_epilogue_storage = [&]() { + if constexpr (IsMla && CollectiveMainloop::IsOrderLoadEpilogue) { + return reinterpret_cast(shared_storage.mainloop_epilogue.mainloop.smem_o.data()); + } else { + return &shared_storage.mainloop_epilogue.epilogue; + } + }; + typename CollectiveEpilogue::TensorStorage & epilogue_storage = *get_epilogue_storage(); + + + typename CollectiveMainloop::PipelineQ::Params pipeline_load_q_params; + if (role == WarpRole::Load) { + pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Producer; + } + if (role == WarpRole::MMA) { + pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Consumer; + } + pipeline_load_q_params.is_leader = lane_predicate && (role == WarpRole::Load); + pipeline_load_q_params.transaction_bytes = CollectiveMainloop::TransactionBytesLoadQ; + typename CollectiveMainloop::PipelineQ pipeline_load_q( + shared_storage.pipelines.load_q, + pipeline_load_q_params, + ClusterShape{}, cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineKV::Params pipeline_load_kv_params; + if (role == WarpRole::Load) { + pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Producer; + } + if (role == WarpRole::MMA) { + pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Consumer; + } + pipeline_load_kv_params.is_leader = lane_predicate && (role == WarpRole::Load); + pipeline_load_kv_params.transaction_bytes = CollectiveMainloop::TransactionBytesLoadK; + typename CollectiveMainloop::PipelineKV pipeline_load_kv( + shared_storage.pipelines.load_kv, + pipeline_load_kv_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineS::Params pipeline_mma_s0_params; + if (role == WarpRole::MMA) { + pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer; + } + if (role == WarpRole::Softmax0) { + pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer; + } + pipeline_mma_s0_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineS pipeline_mma_s0( + shared_storage.pipelines.mma_s0, + pipeline_mma_s0_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineS::Params pipeline_mma_s1_params; + if (role == WarpRole::MMA) { + pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer; + } + if (role == WarpRole::Softmax1) { + pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer; + } + pipeline_mma_s1_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineS pipeline_mma_s1( + shared_storage.pipelines.mma_s1, + pipeline_mma_s1_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineC::Params pipeline_s0_corr_params; + if (role == WarpRole::Softmax0) { + pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer; + } + if (role == WarpRole::Correction) { + pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer; + } + pipeline_s0_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + pipeline_s0_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineC pipeline_s0_corr( + shared_storage.pipelines.s0_corr, + pipeline_s0_corr_params, + /*barrier init*/ cute::true_type{}); + + typename CollectiveMainloop::PipelineC::Params pipeline_s1_corr_params; + if (role == WarpRole::Softmax1) { + pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer; + } + if (role == WarpRole::Correction) { + pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer; + } + pipeline_s1_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + pipeline_s1_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineC pipeline_s1_corr( + shared_storage.pipelines.s1_corr, + pipeline_s1_corr_params, + /*barrier init*/ cute::true_type{}); + + typename CollectiveMainloop::PipelineO::Params pipeline_mma_corr_params; + if (role == WarpRole::MMA) { + pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Producer; + } + if (role == WarpRole::Correction) { + pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Consumer; + } + pipeline_mma_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineO pipeline_mma_corr( + shared_storage.pipelines.mma_corr, + pipeline_mma_corr_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineE::Params pipeline_corr_epi_params; + if (role == WarpRole::Correction) { + pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Producer; + } + if (role == WarpRole::Epilogue) { + pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Consumer; + } + pipeline_corr_epi_params.producer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + pipeline_corr_epi_params.consumer_arv_count = NumWarpsEpilogue * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineE pipeline_corr_epi( + shared_storage.pipelines.corr_epi, + pipeline_corr_epi_params, + /*barrier init*/ cute::true_type{}); + + typename CollectiveMainloop::OrderBarrierSoftmax::Params params_order_s01; + params_order_s01.group_id = role == WarpRole::Softmax1 ? 1 : 0; + params_order_s01.group_size = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::OrderBarrierSoftmax order_s01( + shared_storage.pipelines.order_s01, params_order_s01); + + TmemAllocator tmem_allocator; + + __syncthreads(); + + pipeline_load_q.init_masks(ClusterShape{}); + pipeline_load_kv.init_masks(ClusterShape{}); + pipeline_mma_s0.init_masks(ClusterShape{}); + pipeline_mma_s1.init_masks(ClusterShape{}); + pipeline_mma_corr.init_masks(ClusterShape{}); + + typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_consumer_state; + typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineKV::PipelineState pipeline_load_kv_consumer_state; + typename CollectiveMainloop::PipelineKV::PipelineState pipeline_load_kv_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_consumer_state; + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_consumer_state; + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_consumer_state; + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_consumer_state; + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_consumer_state; + typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_consumer_state; + typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_producer_state = cutlass::make_producer_start_state(); + + CollectiveMainloop mainloop; + CollectiveEpilogue epilogue{params.epilogue}; + + if (role == WarpRole::Softmax0 || role == WarpRole::Softmax1) { + warpgroup_reg_set(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + if (get<1>(logical_problem_shape) == 0) { + continue; + } + + bool is_softmax_0 = role == WarpRole::Softmax0; + + mainloop.softmax( + is_softmax_0 ? 0 : 1, blk_coord, + params.mainloop, logical_problem_shape, + is_softmax_0 ? pipeline_mma_s0 : pipeline_mma_s1, + is_softmax_0 ? pipeline_mma_s0_consumer_state : pipeline_mma_s1_consumer_state, + is_softmax_0 ? pipeline_s0_corr : pipeline_s1_corr, + is_softmax_0 ? pipeline_s0_corr_producer_state : pipeline_s1_corr_producer_state, + order_s01 + ); + + } + } + else if (role == WarpRole::Correction) { + cutlass::arch::warpgroup_reg_dealloc(); + + bool has_valid = false; + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + has_valid = true; + + if (get<1>(logical_problem_shape) == 0) { + mainloop.correction_empty( + blk_coord, + params.mainloop, logical_problem_shape, + params.problem_shape, + epilogue_storage, + pipeline_corr_epi, pipeline_corr_epi_producer_state, + epilogue + ); + continue; + } + + mainloop.correction( + blk_coord, + params.mainloop, logical_problem_shape, + params.problem_shape, + epilogue_storage, + pipeline_s0_corr, pipeline_s0_corr_consumer_state, + pipeline_s1_corr, pipeline_s1_corr_consumer_state, + pipeline_mma_corr, pipeline_mma_corr_consumer_state, + pipeline_corr_epi, pipeline_corr_epi_producer_state, + epilogue + ); + + } + + if constexpr (NumWarpsEpilogue == 0) { + static_assert(NumWarpsCorrection == 1); + + if (has_valid) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + } + + } + else if (role == WarpRole::MMA) { + warpgroup_reg_set(); + + bool allocated = false; + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + if (!allocated) { + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + allocated = true; + } + + if (get<1>(logical_problem_shape) == 0) { + continue; + } + + mainloop.mma( + blk_coord, + params.mainloop, logical_problem_shape, + shared_storage.mainloop_epilogue.mainloop, + pipeline_load_q, pipeline_load_q_consumer_state, + pipeline_load_kv, pipeline_load_kv_consumer_state, + pipeline_mma_s0, pipeline_mma_s0_producer_state, + pipeline_mma_s1, pipeline_mma_s1_producer_state, + pipeline_mma_corr, pipeline_mma_corr_producer_state + ); + + } + } + else if (role == WarpRole::Load) { + warpgroup_reg_set(); + + if constexpr (IsMla && CollectiveMainloop::IsOrderLoadEpilogue) { + cutlass::arch::NamedBarrier::arrive((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + } + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + if (get<1>(logical_problem_shape) == 0) { + continue; + } + + mainloop.load( + blk_coord, logical_problem_shape, + params.mainloop, params.problem_shape, + shared_storage.mainloop_epilogue.mainloop, + pipeline_load_q, pipeline_load_q_producer_state, + pipeline_load_kv, pipeline_load_kv_producer_state + ); + + } + } + else if (role == WarpRole::Epilogue) { + warpgroup_reg_set(); + + bool has_valid = false; + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + has_valid = true; + + epilogue.store( + blk_coord, logical_problem_shape, + params.epilogue, params.problem_shape, + epilogue_storage, + pipeline_corr_epi, pipeline_corr_epi_consumer_state + ); + + } + + static_assert(NumWarpsEpilogue <= 1); + if constexpr (NumWarpsEpilogue == 1) { + if(has_valid) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + } + + } + else if (role == WarpRole::Empty) { + warpgroup_reg_set(); + + /* no-op, donate regs and exit */ + } +#endif + } + +}; + +} // namespace cutlass::fmha::kernel diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_fmha_gen_kernel_warpspecialized.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_fmha_gen_kernel_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d59b20ff194c2eb6c0c18b2e0af03047d15601ea --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_fmha_gen_kernel_warpspecialized.hpp @@ -0,0 +1,580 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" +#include "cutlass/arch/arch.h" +#include "cutlass/kernel_hardware_info.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cute/arch/tmem_allocator_sm100.hpp" + +#include "kernel/fmha_options.hpp" +#include "kernel/fmha_tile_scheduler.hpp" +#include "collective/fmha_fusion.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; +using namespace cutlass::fmha::collective; + +struct Sm100FmhaGenKernelWarpspecializedSchedule { + + enum class WarpRole { + Softmax0, + Softmax1, + Correction, + MMA, + Load, + Epilogue, + Empty + }; + + static constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) { + if (warp_idx == 0) return WarpRole::Softmax0; // 0 - 3 + if (warp_idx == 1) return WarpRole::MMA; // 12 + if (warp_idx == 2 || warp_idx == 3) return WarpRole::Load; // 13 + if (warp_idx == 4) return WarpRole::Softmax1; // 4 - 7 + if (warp_idx == 8) return WarpRole::Correction; // 8 - 11 + return WarpRole::Empty; // 15 + } + + static const int NumWarpsSoftmax = 1; + static const int NumWarpsCorrection = 1; + static const int NumWarpsEpilogue = 0; + static const int NumWarpsLoad = 2; + + static const int NumRegsSoftmax = 192; + static const int NumRegsCorrection = 104; + static const int NumRegsOther = 248; + static const int NumRegsEmpty = 24; + + static const int NumWarps = 12; + +}; + +template< + class ProblemShapeIn, + class CollectiveMainloop, + class CollectiveEpilogue, + class TileScheduler, + class KernelSchedule = Sm100FmhaGenKernelWarpspecializedSchedule +> +struct Sm100FmhaGenKernelWarpspecialized { + + using TileShape = typename CollectiveMainloop::TileShape; + using ProblemShape = decltype(replace<0>(ProblemShapeIn{}, 0)); + + using WarpRole = typename KernelSchedule::WarpRole; + + constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) { + return KernelSchedule::warp_idx_to_WarpRole(warp_idx); + } + + static const int NumWarpsSoftmax = KernelSchedule::NumWarpsSoftmax; + static const int NumWarpsCorrection = KernelSchedule::NumWarpsCorrection; + static const int NumWarpsEpilogue = KernelSchedule::NumWarpsEpilogue; + static const int NumWarpsLoad = KernelSchedule::NumWarpsLoad; + + static const int NumRegsSoftmax = KernelSchedule::NumRegsSoftmax; + static const int NumRegsCorrection = KernelSchedule::NumRegsCorrection; + static const int NumRegsOther = KernelSchedule::NumRegsOther; + static const int NumRegsEmpty = 24; + + static const int NumWarps = KernelSchedule::NumWarps; + + using ClusterShape = typename CollectiveMainloop::ClusterShape; + + using TmemAllocator = cute::TMEM::Allocator1Sm; + + struct SharedStorage { + typename CollectiveMainloop::TensorStorage mainloop; + typename CollectiveEpilogue::TensorStorage epilogue; + + struct PipelineStorage { + alignas(16) typename CollectiveMainloop::PipelineQ::SharedStorage load_q; + alignas(16) typename CollectiveMainloop::PipelineKV::SharedStorage load_kv; + alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s0; + alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s1; + alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s0_corr; + alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s1_corr; + alignas(16) typename CollectiveMainloop::PipelineO::SharedStorage mma_corr; + alignas(16) typename CollectiveMainloop::PipelineE::SharedStorage corr_epi; + alignas(16) typename CollectiveMainloop::OrderBarrierSoftmax::SharedStorage order_s01; + } pipelines; + + uint32_t tmem_base_ptr; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + using StrideQOrig = typename CollectiveMainloop::StrideQOrig; + using StrideOOrig = typename CollectiveMainloop::StrideOOrig; + using StrideQ = typename CollectiveMainloop::StrideQ; + using StrideO = typename CollectiveMainloop::StrideO; + using StrideCacheK = typename CollectiveMainloop::StrideCacheK; + using StrideCacheV = typename CollectiveMainloop::StrideCacheV; + using StrideNewK = typename CollectiveMainloop::StrideNewK; + using StrideNewV = typename CollectiveMainloop::StrideNewV; + using Element = typename CollectiveMainloop::Element; + using ElementAcc = typename CollectiveMainloop::ElementAcc; + using ElementOut = typename CollectiveMainloop::ElementOut; + + struct Arguments { + // _1, max_seqlen_k, head_dim, ((h_g, h_kv), b) + ProblemShapeIn problem_shape; + const int* seqlen_kv; + const int* cache_batch_idx; + + const Element* ptr_q; // 1 x D x (H x B) + StrideQOrig dQ; + const Element* ptr_new_k; // 1 x D x (H x B) + StrideNewK dNewK; + const Element* ptr_new_v; // 1 x D x (H x B) + StrideNewV dNewV; + + Element* ptr_cache_k; // seqlen_max x D x (H x B) + StrideCacheK dCacheK; + Element* ptr_cache_v; // seqlen_max x D x (H x B) + StrideCacheV dCacheV; + ElementOut* ptr_o; // 1 x D x (H x B) + StrideOOrig dO; + + cutlass::KernelHardwareInfo hw_info; + + ElementAcc scale_softmax = 0.0f; + }; + + struct Params { + ProblemShape problem_shape; + const int* seqlen_kv; + typename CollectiveMainloop::Params mainloop; + typename CollectiveEpilogue::Params epilogue; + typename TileScheduler::Params tile_scheduler; + }; + + static const int MinBlocksPerMultiprocessor = 1; + static const int MaxThreadsPerBlock = NumWarps * cutlass::NumThreadsPerWarp; + using ArchTag = cutlass::arch::Sm100; + + static size_t get_workspace_size(Arguments const& args) { return 0; } + static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return cutlass::Status::kSuccess; + } + + static bool can_implement(Arguments const& args) { + return true; + } + + static dim3 get_grid_shape(Params const& params) { + return TileScheduler::get_grid_shape(params.tile_scheduler); + } + + static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + ProblemShape problem_shape = replace<0>(args.problem_shape, static_cast(get<0>(args.problem_shape))); + CUTE_STATIC_ASSERT_V(get<0>(args.problem_shape) == _1{}); + StrideQ dQ = replace<0>(args.dQ, 0); + StrideO dO = replace<0>(args.dO, 0); + get<0>(problem_shape) = get<3,0,0>(args.problem_shape); + get<3,0,0>(problem_shape) = 1; + get<0>(dQ) = get<2,0,0>(dQ); + get<0>(dO) = get<2,0,0>(dO); + + typename CollectiveMainloop::Arguments mainloop_args { + { + args.cache_batch_idx, + args.ptr_q, dQ, + args.ptr_new_k, args.dNewK, + args.ptr_new_v, args.dNewV, + args.ptr_cache_k, args.dCacheK, + args.ptr_cache_v, args.dCacheV, + }, + args.scale_softmax + }; + + typename CollectiveEpilogue::Arguments epilogue_args { + args.ptr_o, dO, + }; + + return Params{ + problem_shape, + args.seqlen_kv, + CollectiveMainloop::to_underlying_arguments(problem_shape, mainloop_args, workspace), + CollectiveEpilogue::to_underlying_arguments(problem_shape, epilogue_args, workspace), + TileScheduler::to_underlying_arguments(problem_shape, args.hw_info, ClusterShape{}, TileShape{}) + }; + } + + CUTLASS_DEVICE auto apply_batch(const Params ¶ms, ProblemShape const& problem_shape, int batch_idx) { + ProblemShape result = problem_shape; + get<1>(result) = params.seqlen_kv[batch_idx]; + if (params.mainloop.load.ptr_new_k != nullptr) { + get<1>(result) += 1; + } + return result; + } + + CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { +#if (! defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && ! defined(CUTLASS_ARCH_MMA_SM100F_ENABLED)) + printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); +#else + + TileScheduler tile_scheduler{params.tile_scheduler}; + + int warp_idx = cutlass::canonical_warp_idx_sync(); + auto role = warp_idx_to_WarpRole(warp_idx); + uint32_t lane_predicate = cute::elect_one_sync(); + + if (role == WarpRole::Load && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + } + + if (role == WarpRole::Epilogue && lane_predicate) { + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + SharedStorage& shared_storage = *reinterpret_cast(smem); + + typename CollectiveMainloop::PipelineQ::Params pipeline_load_q_params; + if (role == WarpRole::Load) { + pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Producer; + } + if (role == WarpRole::MMA) { + pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Consumer; + } + pipeline_load_q_params.producer_arv_count = NumWarpsLoad * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineQ pipeline_load_q( + shared_storage.pipelines.load_q, + pipeline_load_q_params, + ClusterShape{}, cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineKV::Params pipeline_load_kv_params; + if (role == WarpRole::Load) { + pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Producer; + } + if (role == WarpRole::MMA) { + pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Consumer; + } + pipeline_load_kv_params.producer_arv_count = NumWarpsLoad * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineKV pipeline_load_kv( + shared_storage.pipelines.load_kv, + pipeline_load_kv_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineS::Params pipeline_mma_s0_params; + if (role == WarpRole::MMA) { + pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer; + } + if (role == WarpRole::Softmax0) { + pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer; + } + pipeline_mma_s0_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineS pipeline_mma_s0( + shared_storage.pipelines.mma_s0, + pipeline_mma_s0_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineS::Params pipeline_mma_s1_params; + if (role == WarpRole::MMA) { + pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer; + } + if (role == WarpRole::Softmax1) { + pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer; + } + pipeline_mma_s1_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineS pipeline_mma_s1( + shared_storage.pipelines.mma_s1, + pipeline_mma_s1_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineC::Params pipeline_s0_corr_params; + if (role == WarpRole::Softmax0) { + pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer; + } + if (role == WarpRole::Correction) { + pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer; + } + pipeline_s0_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + pipeline_s0_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineC pipeline_s0_corr( + shared_storage.pipelines.s0_corr, + pipeline_s0_corr_params, + /*barrier init*/ cute::true_type{}); + + typename CollectiveMainloop::PipelineC::Params pipeline_s1_corr_params; + if (role == WarpRole::Softmax1) { + pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer; + } + if (role == WarpRole::Correction) { + pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer; + } + pipeline_s1_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + pipeline_s1_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineC pipeline_s1_corr( + shared_storage.pipelines.s1_corr, + pipeline_s1_corr_params, + /*barrier init*/ cute::true_type{}); + + typename CollectiveMainloop::PipelineO::Params pipeline_mma_corr_params; + if (role == WarpRole::MMA) { + pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Producer; + } + if (role == WarpRole::Correction) { + pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Consumer; + } + pipeline_mma_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineO pipeline_mma_corr( + shared_storage.pipelines.mma_corr, + pipeline_mma_corr_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineE::Params pipeline_corr_epi_params; + if (role == WarpRole::Correction) { + pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Producer; + } + if (role == WarpRole::Epilogue) { + pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Consumer; + } + pipeline_corr_epi_params.producer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + pipeline_corr_epi_params.consumer_arv_count = cute::max(1, NumWarpsEpilogue * cutlass::NumThreadsPerWarp); + typename CollectiveMainloop::PipelineE pipeline_corr_epi( + shared_storage.pipelines.corr_epi, + pipeline_corr_epi_params, + /*barrier init*/ cute::true_type{}); + + typename CollectiveMainloop::OrderBarrierSoftmax::Params params_order_s01; + params_order_s01.group_id = role == WarpRole::Softmax1 ? 1 : 0; + params_order_s01.group_size = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::OrderBarrierSoftmax order_s01( + shared_storage.pipelines.order_s01, params_order_s01); + + TmemAllocator tmem_allocator; + + __syncthreads(); + + pipeline_load_q.init_masks(ClusterShape{}); + pipeline_load_kv.init_masks(ClusterShape{}); + pipeline_mma_s0.init_masks(ClusterShape{}); + pipeline_mma_s1.init_masks(ClusterShape{}); + pipeline_mma_corr.init_masks(ClusterShape{}); + + typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_consumer_state; + typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineKV::PipelineState pipeline_load_kv_consumer_state; + typename CollectiveMainloop::PipelineKV::PipelineState pipeline_load_kv_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_consumer_state; + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_consumer_state; + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_consumer_state; + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_consumer_state; + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_consumer_state; + typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_consumer_state; + typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_producer_state = cutlass::make_producer_start_state(); + + CollectiveMainloop mainloop; + CollectiveEpilogue epilogue(params.epilogue); + + if (role == WarpRole::Softmax0 || role == WarpRole::Softmax1) { + warpgroup_reg_set(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + bool is_softmax_0 = role == WarpRole::Softmax0; + + mainloop.softmax( + is_softmax_0 ? 0 : 1, blk_coord, + params.mainloop, logical_problem_shape, + is_softmax_0 ? pipeline_mma_s0 : pipeline_mma_s1, + is_softmax_0 ? pipeline_mma_s0_consumer_state : pipeline_mma_s1_consumer_state, + is_softmax_0 ? pipeline_s0_corr : pipeline_s1_corr, + is_softmax_0 ? pipeline_s0_corr_producer_state : pipeline_s1_corr_producer_state, + order_s01 + ); + + } + } + else if (role == WarpRole::Correction) { + cutlass::arch::warpgroup_reg_dealloc(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + mainloop.correction( + blk_coord, + params.mainloop, logical_problem_shape, + shared_storage.epilogue, + pipeline_s0_corr, pipeline_s0_corr_consumer_state, + pipeline_s1_corr, pipeline_s1_corr_consumer_state, + pipeline_mma_corr, pipeline_mma_corr_consumer_state, + pipeline_corr_epi, pipeline_corr_epi_producer_state, + epilogue + ); + + + } + + if constexpr (NumWarpsEpilogue == 0) { + static_assert(NumWarpsCorrection == 1); + + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + } + else if (role == WarpRole::MMA) { + warpgroup_reg_set(); + + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + + mainloop.mma( + blk_coord, + params.mainloop, logical_problem_shape, + shared_storage.mainloop, + pipeline_load_q, pipeline_load_q_consumer_state, + pipeline_load_kv, pipeline_load_kv_consumer_state, + pipeline_mma_s0, pipeline_mma_s0_producer_state, + pipeline_mma_s1, pipeline_mma_s1_producer_state, + pipeline_mma_corr, pipeline_mma_corr_producer_state + ); + + + } + } + else if (role == WarpRole::Load) { + warpgroup_reg_set(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + mainloop.load( + blk_coord, logical_problem_shape, + params.mainloop, params.problem_shape, + shared_storage.mainloop, + pipeline_load_q, pipeline_load_q_producer_state, + pipeline_load_kv, pipeline_load_kv_producer_state + ); + + } + } + else if (role == WarpRole::Epilogue) { + warpgroup_reg_set(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + epilogue.store( + blk_coord, logical_problem_shape, + params.epilogue, params.problem_shape, + shared_storage.epilogue, + pipeline_corr_epi, pipeline_corr_epi_consumer_state + ); + + } + + static_assert(NumWarpsEpilogue <= 1); + if constexpr (NumWarpsEpilogue == 1) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + } + else if (role == WarpRole::Empty) { + warpgroup_reg_set(); + + /* no-op, donate regs and exit */ + } +#endif + } + +}; + +} // namespace cutlass::fmha::kernel diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_reduction.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_reduction.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2418bcf804a6bb57fc3358a741c3b333f4d19cfb --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_reduction.hpp @@ -0,0 +1,201 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/arch.h" +#include "cute/tensor.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; +template< + class ElementOut, + class ElementAcc, + class ElementScale, + size_t kNumHeads, + size_t kHeadDimLatent, + int kMaxSplits +> +struct Sm100FmhaMlaReductionKernel { + + static const int SharedStorageSize = 0; + static const int MaxThreadsPerBlock = 128; + static const int MinBlocksPerMultiprocessor = 1; + + using ArchTag = cutlass::arch::Sm100; + + static_assert(kHeadDimLatent % MaxThreadsPerBlock == 0); + struct Arguments { + ElementAcc* ptr_oaccum = nullptr; + ElementOut* ptr_o = nullptr; + ElementAcc* ptr_lseaccum = nullptr; + ElementAcc* ptr_lse = nullptr; + ElementScale scale = 1.f; + int num_batches = 0; + int split_kv = -1; + int dim_k = -1; + int* ptr_seq = nullptr; + int* ptr_split_kv = nullptr; + int tile_shape_s = 128; + }; + using Params = Arguments; + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return {args.ptr_oaccum, args.ptr_o, args.ptr_lseaccum, args.ptr_lse, + args.scale, args.num_batches, args.split_kv, args.dim_k, args.ptr_seq, + args.ptr_split_kv, args.tile_shape_s}; + } + + static size_t get_workspace_size(Arguments const& /*args*/) { + return 0; + } + + static Status initialize_workspace( + Arguments const& /*args*/, void* /*ws*/, cudaStream_t /*stream*/) { + return Status::kSuccess; + } + + static dim3 get_grid_shape(Params const& params) { + return dim3(kNumHeads, 1, params.num_batches); + } + + static dim3 get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + static bool can_implement(Arguments const& args) { + if (args.num_batches <= 0) return false; + if (args.split_kv <= 0) return false; + return true; + } + + CUTLASS_DEVICE void operator() (Params const& params, char* smem_raw) { + if (params.split_kv <= 1) return; + + auto blk_coord = make_coord(blockIdx.x, _0{}, blockIdx.z); + + auto dim_k = params.ptr_seq == nullptr ? params.dim_k : params.ptr_seq[get<2>(blk_coord)]; + auto local_split_kv = params.ptr_split_kv == nullptr ? params.split_kv : params.ptr_split_kv[get<2>(blk_coord)]; + auto k_tile_total = ceil_div(dim_k, params.tile_shape_s); + auto k_tile_per_cta = ceil_div(k_tile_total, local_split_kv); + local_split_kv = ceil_div(k_tile_total, k_tile_per_cta); + + if (local_split_kv == 1) return; + + __shared__ ElementAcc sLseScale[kMaxSplits]; + const size_t offset_lseaccum = get<0>(blk_coord) + kNumHeads * params.split_kv * get<2>(blk_coord); + const size_t offset_lse = get<0>(blk_coord) + kNumHeads * get<2>(blk_coord); + + Tensor gLSEaccum = make_tensor(make_gmem_ptr(params.ptr_lseaccum + offset_lseaccum), + make_shape(params.split_kv), Stride>{}); + + Tensor gLSE = make_tensor(make_gmem_ptr(params.ptr_lse + offset_lse), + Shape<_1>{}, Stride<_1>{}); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + if (warp_idx == 0) { + constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32); + + ElementAcc local_lse[kNLsePerThread]; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNLsePerThread; ++i) { + const int split = i * 32 + threadIdx.x; + local_lse[i] = split < local_split_kv ? gLSEaccum(split) : -std::numeric_limits::infinity(); + } + + ElementAcc lse_max = -std::numeric_limits::infinity(); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNLsePerThread; ++i) { + lse_max = fmax(local_lse[i], lse_max); + } + + CUTLASS_PRAGMA_UNROLL + for (int offset = 16; offset >= 1; offset /= 2) { + lse_max = fmax(__shfl_xor_sync(0xffffffff, lse_max, offset), lse_max); + } + + lse_max = __shfl_sync(0xffffffff, lse_max, 0); + + ElementAcc sum_lse = 0; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNLsePerThread; ++i) { + sum_lse = sum_lse + expf(local_lse[i] - lse_max); + } + + CUTLASS_PRAGMA_UNROLL + for (int offset = 16; offset >= 1; offset /= 2) { + sum_lse = sum_lse + __shfl_xor_sync(0xffffffff, sum_lse, offset); + } + + sum_lse = __shfl_sync(0xffffffff, sum_lse, 0); + + ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits::infinity() : logf(sum_lse) + lse_max; + if (threadIdx.x == 0 and params.ptr_lse != nullptr) { + gLSE(0) = global_lse; + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNLsePerThread; ++i) { + const int split = i * 32 + threadIdx.x; + if (split < local_split_kv) { + sLseScale[split] = expf(local_lse[i] - global_lse); + } + } + } + __syncthreads(); + + constexpr int Elements = kHeadDimLatent / MaxThreadsPerBlock; + const size_t offset_oaccum = kHeadDimLatent * params.split_kv * (get<0>(blk_coord) + kNumHeads * get<2>(blk_coord)); + Tensor gOaccum = make_tensor(make_gmem_ptr(params.ptr_oaccum + offset_oaccum), + Shape>{}, Stride<_1>{}); + ElementAcc local_val[Elements] = {0}; + for (int split = 0; split < local_split_kv; ++split) { + ElementAcc lse_scale = sLseScale[split]; + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < Elements; ++i) { + local_val[i] += lse_scale * gOaccum(threadIdx.x + MaxThreadsPerBlock * i); + } + gOaccum.data() = gOaccum.data() + kHeadDimLatent; + } + auto ptr_o_local = params.ptr_o + (get<0>(blk_coord) + get<2>(blk_coord) * kNumHeads) * kHeadDimLatent; + Tensor gO = make_tensor(make_gmem_ptr(ptr_o_local), Shape>{}, Stride<_1>{}); + + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < Elements; ++i) { + gO(threadIdx.x + MaxThreadsPerBlock * i) = static_cast(local_val[i]); + } + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_tma_warpspecialized.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e9edb90e571a58d7c1ec954e5f7c8bb11489fcf4 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_tma_warpspecialized.hpp @@ -0,0 +1,2237 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cute/arch/simd_sm100.hpp" +#include "cutlass/barrier.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "gather_tensor.hpp" // from examples/common +#include "common/pow_2.hpp" +#include "sm100_mla_tile_scheduler.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; + +template< + class TileShape, + class Element_, + class ElementAcc_, + class ElementOut_, + class ElementLSE_, + class TileScheduler, +#ifdef CPASYNC + bool kIsCpAsync = true +#else + bool kIsCpAsync = false +#endif +> +struct Sm100FmhaMlaKernelTmaWarpspecialized { + + using Element = Element_; + using ElementAcc = ElementAcc_; + using ElementOut = ElementOut_; + using ElementLSE = ElementLSE_; + + // only 2Sm mode is supported + static const bool kIs2Sm = true; + static const int MaxThreadsPerBlock = 256; + static const int MinBlocksPerMultiprocessor = 1; + static const int TotalSNum = 2; + static const int TotalPNum = 2; + using ArchTag = cutlass::arch::Sm100; + + using ClusterShape = cute::conditional_t, Shape<_1, _1, _1>>; + + using TileShapeH = tuple_element_t<0, TileShape>; + using TileShapeS = tuple_element_t<1, TileShape>; + using TileShapeD = tuple_element_t<2, TileShape>; + + using TileShapeL = tuple_element_t<0, TileShapeD>; + using TileShapeR = tuple_element_t<1, TileShapeD>; + static_assert(TileShapeL{} % TileShapeR{} == 0, "Rope head dim must divide latent head dim"); + + using ProblemShape = Shape; + using TensorStride = Stride; + using TmemAllocator = cute::conditional_t; + + static_assert(TileShapeH{} == 128); + static const int kWarpsInN = kIs2Sm ? 2 : 1; + + static const int kNumComputeWarps = 4; + static const int kNumLoadWarps = kIsCpAsync ? 2 : 1; + + enum class WarpRole { + kMma = 0x1, kLoad = 0x2, kCompute = 0x3, kLoadPageTable = 0x4, kEmpty=0x0 + }; + + static const long long unsigned int kWarpAssignment = kIsCpAsync ? 0x4221'3333ull : 0x0021'3333ull; + + static CUTLASS_DEVICE WarpRole warp_idx_to_role(int warp_idx) { + return static_cast((kWarpAssignment >> (4 * warp_idx)) & 0xF); + } + + static const int Alignment = 128 / sizeof_bits_v; + static const int AlignmentOut = 128 / sizeof_bits_v; + + using TileShapeQK = Shape; + static const int StagesQK = 24 / sizeof(Element); // free parameter + static const int IterationsQKLatent = decltype(TileShapeL{} / get<2>(TileShapeQK{}))::value; + static const int IterationsQKRope = decltype(TileShapeR{} / get<2>(TileShapeQK{}))::value; + static const int IterationsQK = IterationsQKLatent + IterationsQKRope; + + using Schedule = cute::conditional_t; + using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStride, Alignment, + Element, TensorStride, Alignment, + ElementAcc, + TileShapeQK, ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TiledMmaQK = typename CollectiveMmaQK::TiledMma; + using CtaShapeQK = typename CollectiveMmaQK::CtaShape_MNK; + + // chosen for unified smem staging between K and V + using TileShapePV = Shape; + using TransposeTensorStride = decltype(select<1,0,2>(TensorStride{})); + static const int StagesPV = StagesQK; // not sure why, but must be at least two. check pipes + static const int IterationsPV_K = decltype(TileShapeS{} / get<2>(TileShapePV{}))::value; + static const int IterationsPV_N = decltype(TileShapeL{} / get<1>(TileShapePV{}))::value; + + using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStride, Alignment, + Element, TransposeTensorStride, Alignment, + ElementAcc, + TileShapePV, ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using CtaShapePV = typename CollectiveMmaPV::CtaShape_MNK; + static_assert(std::is_same_v); + + using TiledMmaPV = typename CollectiveMmaPV::TiledMma; + + using AtomThrShapeMNK = typename CollectiveMmaQK::AtomThrShapeMNK; + static_assert(typename CollectiveMmaQK::AtomThrShapeMNK{} == typename CollectiveMmaPV::AtomThrShapeMNK{}, "schedule must match"); + + static const int StagesPageTable = kIsCpAsync ? StagesPV : 1; + + // pipelines from load to mma, PipelineTmaUmmaAsync, stages tbd + // use expect_tx for Q load + using PipelineLoadQK = cute::conditional_t, PipelineTmaUmmaAsync>; + using PipelineLoadPV = PipelineLoadQK; + // pipeline from mma (Q@K) to softmax, PipelineUmmaAsync, 2 stages + using PipelineS = PipelineUmmaAsync; + // pipeline from softmax (P) to mma (bmm2), PipelineUmmaAsync, 2 stages + using PipelineP = PipelineUmmaConsumerAsync; + // pipeline from mma to softmax (for rescale), PipelineUmmaAsync, 1 stage + using PipelineO = PipelineUmmaAsync<1, AtomThrShapeMNK>; + + using PipelinePT = PipelineAsync; + + struct PipelineStorage { + alignas(16) typename PipelineLoadQK::SharedStorage load_qk; + alignas(16) typename PipelineS::SharedStorage mma_s; + alignas(16) typename PipelineP::SharedStorage p_mma; + alignas(16) typename PipelineO::SharedStorage mma_o; + alignas(16) typename PipelinePT::SharedStorage load_page_table; + }; + + template + static CUTE_DEVICE constexpr auto unstageSmemLayout(Layout const& layout, Stages stages = {}) { + return composition(layout, make_tuple(_, _, _, make_layout(stages))); + } + + using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int{})); + using SmemLayoutKC = typename CollectiveMmaQK::SmemLayoutB; + using SmemLayoutVC = typename CollectiveMmaPV::SmemLayoutB; + using SmemLayoutP = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutA{}, make_shape(Int{}, _2{}))); + using SmemLayoutOut = decltype(take<0,2>(typename CollectiveMmaQK::CtaShape_MNK{})); + using TileShapeAcc = decltype(take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{})); + + static const int kBytesLoadQ = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v); + static const int kBytesLoadKC = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutKC{})) * cute::sizeof_bits_v); + static const int kBytesLoadVC = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutVC{})) * cute::sizeof_bits_v); + + // pre-condition for overlapped smem staging + static_assert(kBytesLoadKC == kBytesLoadVC); + static_assert(StagesQK == StagesPV); + + static const int kTransactionsBytesLoadQK = kBytesLoadKC; + static const int kTransactionsBytesLoadExtraQ = kBytesLoadQ; + static const int kTransactionsBytesLoadPV = kBytesLoadVC; + + static const int kNamedBarrierExchange = (int) cutlass::arch::ReservedNamedBarriers::TransformBarrier; + // This Named Barrier is introduced to solve Q tile loading overwritten issue when enable persistent + // tile scheduler for FP8 MLA. + static const int kNamedBarrierEpilogue = (int) cutlass::arch::ReservedNamedBarriers::EpilogueBarrier; + // + static const int kNamedBarrierTmemDealloc = (int) cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier; + + enum class TmemAllocation : uint32_t { + kSizeS = TileShapeS::value / kWarpsInN, + // Overall + kSizeO = TileShapeL::value / kWarpsInN, + // Between accumulators we loop over + kSizeAccO = decltype(get<1>(TileShapePV{}))::value / kWarpsInN, + kNumS = TotalSNum, + kNumP = TotalPNum, + kNumO = 1, + kS0 = 0, + kS1 = kS0 + kSizeS, + kO0 = kS1 + kSizeS, + kTotal = kO0 + kSizeO + }; + + static_assert(static_cast(TmemAllocation::kTotal) <= TmemAllocator::Sm100TmemCapacityColumns, "using too much tmem"); + + struct TensorStorage { + // to communicate max and row_sum + cute::array smem_exchange; + cute::array smem_page_table; + alignas(2048) cute::array> smem_q; + union { + alignas(2048) cute::array> smem_kc; + alignas(2048) cute::array> smem_vc; + }; + union { + alignas(2048) cute::array> smem_p; + alignas(2048) cute::array smem_acc; + }; + }; + + struct SharedStorage { + PipelineStorage pipelines; + TensorStorage tensors; + uint32_t tmem_base_ptr; + }; + + static const int SharedStorageSize = sizeof(SharedStorage); + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem"); + + struct MainloopArguments { + ElementAcc softmax_scale; + + // all tensors strides are (num_heads or seqlen, head_dim, batch) + // head_dim stride is always 1 + Element* ptr_q_latent; + TensorStride stride_q_latent; + Element* ptr_q_rope; + TensorStride stride_q_rope; + + Element* ptr_c_latent; + TensorStride stride_c_latent; + Element* ptr_k_rope; + TensorStride stride_k_rope; + + // for paged attention, we interpret what was previously [batch, seqlen] + // as [page_count, page_size], and index according to page_table + int* ptr_seq = nullptr; + int* ptr_page_table = nullptr; + // page table is [batch, seqlen or similar] + Stride<_1, int> stride_page_table = {}; + int page_count = 0; + int page_size = TileShapeS{}; // powers of two if kIsCpAsync, otherwise TileShapeS + }; + + struct EpilogueArguments { + ElementOut* ptr_o = nullptr; + TensorStride stride_o; + ElementLSE* ptr_lse = nullptr; + Stride<_1, int> stride_lse; + ElementAcc output_scale = 1.0f; + }; + + struct Arguments { + // (num_heads=128, seqlen, (d_latent=512, d_rope=64), batch_count) + // for paged attention, seqlen is max seqlen + ProblemShape problem_shape; + MainloopArguments mainloop; + EpilogueArguments epilogue; + KernelHardwareInfo hw_info; + int split_kv = -1; + int* ptr_split_kv = nullptr; + bool is_fused_reduction = false; + }; + + using TmaLoadQLatent = typename CollectiveMmaQK::Params::TMA_A; + using TmaLoadQRope = typename CollectiveMmaQK::Params::TMA_A; + using TmaLoadCLatent = typename CollectiveMmaQK::Params::TMA_B; + using TmaLoadKRope = typename CollectiveMmaQK::Params::TMA_B; + using TmaLoadCLatentTranspose = typename CollectiveMmaPV::Params::TMA_B; + + using GmemLayout = decltype(make_layout(Shape{}, Stride{})); + using SmemLayout = decltype(make_layout(TileShapeAcc{}, LayoutRight{})); + + using TmaReduceSum = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{}, + make_tensor(recast_ptr(nullptr), GmemLayout{}), SmemLayout{})); + + struct MainloopParams { + TmaLoadQLatent tma_load_q_latent; + TmaLoadQRope tma_load_q_rope; + TmaLoadCLatent tma_load_c_latent; + TmaLoadKRope tma_load_k_rope; + TmaLoadCLatentTranspose tma_load_c_latent_transpose; + }; + + struct EpilogueParams { + ElementOut* ptr_o = nullptr; + ElementAcc* ptr_o_acc = nullptr; + TensorStride stride_o; + TensorStride stride_o_acc; + ElementLSE* ptr_lse = nullptr; + ElementLSE* ptr_lse_acc = nullptr; + Stride<_1, int> stride_lse; + Stride<_1, int> stride_lse_acc; + ElementAcc output_scale = 1.0f; + ElementLSE* ptr_lse_exchange_buff = nullptr; + int* ptr_lse_max_exchange_buff = nullptr; + int* ptr_lock = nullptr; // semaphore + TmaReduceSum tma_reduce_sum; + }; + + struct Params { + ProblemShape problem_shape; + MainloopArguments mainloop; + EpilogueParams epilogue; + MainloopParams mainloop_params; + typename TileScheduler::Params tile_scheduler; + int split_kv = -1; + int* ptr_split_kv = nullptr; + bool is_fused_reduction = false; + }; + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + //workspace = nullptr; // let's get an error if one of these needs workspace + + auto [H, K, D, B] = args.problem_shape; + auto [L, R] = D; + + int paged_B = B; + int paged_K = K; + if (args.mainloop.ptr_page_table != nullptr) { + paged_B = args.mainloop.page_count; + paged_K = args.mainloop.page_size; + } + + auto params_qk_latent = CollectiveMmaQK::to_underlying_arguments( + make_shape(H, K, L, B), + typename CollectiveMmaQK::Arguments { + args.mainloop.ptr_q_latent, args.mainloop.stride_q_latent, + args.mainloop.ptr_c_latent, args.mainloop.stride_c_latent, + }, nullptr); + + auto params_qk_latent_paged = CollectiveMmaQK::to_underlying_arguments( + make_shape(H, paged_K, L, paged_B), + typename CollectiveMmaQK::Arguments { + args.mainloop.ptr_q_latent, args.mainloop.stride_q_latent, + args.mainloop.ptr_c_latent, args.mainloop.stride_c_latent, + }, nullptr); + + auto params_qk_rope = CollectiveMmaQK::to_underlying_arguments( + make_shape(H, K, R, B), + typename CollectiveMmaQK::Arguments { + args.mainloop.ptr_q_rope, args.mainloop.stride_q_rope, + args.mainloop.ptr_k_rope, args.mainloop.stride_k_rope, + }, nullptr); + + auto params_qk_rope_paged = CollectiveMmaQK::to_underlying_arguments( + make_shape(H, paged_K, R, paged_B), + typename CollectiveMmaQK::Arguments { + args.mainloop.ptr_q_rope, args.mainloop.stride_q_rope, + args.mainloop.ptr_k_rope, args.mainloop.stride_k_rope, + }, nullptr); + + + auto stride_c_latent_transpose = select<1,0,2>(args.mainloop.stride_c_latent); + auto params_pv_latent = CollectiveMmaPV::to_underlying_arguments( + make_shape(H, L, paged_K, paged_B), + typename CollectiveMmaPV::Arguments { + args.mainloop.ptr_q_latent, args.mainloop.stride_q_latent, // dummy, never used + args.mainloop.ptr_c_latent, stride_c_latent_transpose, + }, nullptr); + + MainloopParams mainloop_params { + params_qk_latent.tma_load_a, + params_qk_rope.tma_load_a, + params_qk_latent_paged.tma_load_b, + params_qk_rope_paged.tma_load_b, + params_pv_latent.tma_load_b + }; + + EpilogueParams epilogue_params; + + epilogue_params.ptr_o = args.epilogue.ptr_o; + epilogue_params.stride_o = args.epilogue.stride_o; + epilogue_params.ptr_lse = args.epilogue.ptr_lse; + epilogue_params.stride_lse = args.epilogue.stride_lse; + epilogue_params.output_scale = args.epilogue.output_scale; + epilogue_params.tma_reduce_sum = make_tma_copy(SM90_TMA_REDUCE_ADD{}, make_tensor(recast_ptr(args.epilogue.ptr_o), make_layout(make_shape(H, L, B), args.epilogue.stride_o)), SmemLayout{}); + + if (!args.is_fused_reduction && args.split_kv > 1) { + ElementAcc* ptr_o_acc = reinterpret_cast(workspace); + ElementLSE* ptr_lse_acc = reinterpret_cast(ptr_o_acc + H * L * args.split_kv * B); + epilogue_params.ptr_o_acc = ptr_o_acc; + epilogue_params.ptr_lse_acc = ptr_lse_acc; + + epilogue_params.stride_o_acc = make_tuple(static_cast(0 + L) * args.split_kv, _1{}, static_cast(0 + H * L) * args.split_kv); + epilogue_params.stride_lse_acc = make_tuple(_1{}, (0 + H) * args.split_kv); + } else if (args.is_fused_reduction && args.split_kv > 1) { + ElementLSE* ptr_lse_exchange_buff = reinterpret_cast(workspace); + epilogue_params.ptr_lse_exchange_buff = ptr_lse_exchange_buff; + int* ptr_lse_max_exchange_buff = reinterpret_cast(ptr_lse_exchange_buff + H * B); + epilogue_params.ptr_lse_max_exchange_buff = ptr_lse_max_exchange_buff; + int* ptr_lock = ptr_lse_max_exchange_buff + H * B; + epilogue_params.ptr_lock = ptr_lock; + } + + return {args.problem_shape, args.mainloop, epilogue_params, mainloop_params, + TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, ClusterShape{}, args.split_kv), + args.split_kv, args.ptr_split_kv, args.is_fused_reduction}; + } + + static size_t get_workspace_size(Arguments const& args) { + ProblemShape problem_shape = args.problem_shape; + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + auto split_kv = args.split_kv; + size_t workspace_size {0}; + if (args.is_fused_reduction && args.split_kv > 1) { + // one exchange buffer for LSE max and another buffer for total LSE + // two locks per batch, frist lock is for CTA0 / H=0..63 and the second is for CTA1 / H=64..127 + workspace_size = H * B * (sizeof(int) + sizeof(ElementLSE)) + 2 * B * sizeof(int); + } else if (!args.is_fused_reduction && args.split_kv > 1) { + workspace_size = (sizeof(ElementAcc) * D_latent + sizeof(ElementLSE)) * H * split_kv * B; + } + return workspace_size; + } + static Status initialize_workspace( + Arguments const& args, void* ws, cudaStream_t stream) { + auto workspace_size = get_workspace_size(args); + if (args.is_fused_reduction && args.split_kv > 1) { + auto result = cudaMemsetAsync(ws, 0, workspace_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaMemsetAsync() returned error: " + << cudaGetErrorString(result)); + return Status::kErrorInternal;; + } + } + return Status::kSuccess; + } + + static dim3 get_grid_shape(Params const& params) { + return TileScheduler::get_grid_shape(params.tile_scheduler); + } + + static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + static bool can_implement(Arguments const& args) { + if (kIsCpAsync) { + if ((args.mainloop.page_size & (args.mainloop.page_size - 1)) != 0) { + std::cerr << __FILE__ << "(" << __LINE__ << "): cpasync page size pow2\n"; + return false; + } + if (args.mainloop.page_size > TileShapeS{}) { + std::cerr << __FILE__ << "(" << __LINE__ << "): cpasync page size too big\n"; + return false; + } + } + else { + if (args.mainloop.ptr_page_table != nullptr && args.mainloop.page_size != TileShapeS{}) { + std::cerr << __FILE__ << "(" << __LINE__ << "): tma page size off\n"; + return false; + } + } + if (get<0>(args.problem_shape) != 128) { + std::cerr << __FILE__ << "(" << __LINE__ << "): heads off\n"; + return false; + } + if (get<1>(args.problem_shape) <= 0) { + std::cerr << __FILE__ << "(" << __LINE__ << "): heads off\n"; + return false; + } + if (args.split_kv <= 0) { + std::cerr << __FILE__ << "(" << __LINE__ << "): split-k off\n"; + return false; + } + if (args.is_fused_reduction && args.split_kv > 1) { + if (2 * args.split_kv > args.hw_info.sm_count || + std::is_same_v) { + return false; + } + } + return true; + } + + + CUTLASS_DEVICE void operator()(Params const& params, char* smem_raw) { +#if (! defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && ! defined(CUTLASS_ARCH_MMA_SM100F_ENABLED)) + printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); +#else + + TileScheduler tile_scheduler(params.tile_scheduler); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + auto role = warp_idx_to_role(warp_idx); + uint32_t lane_predicate = cute::elect_one_sync(); + + uint32_t cta_rank_in_cluster = cute::block_rank_in_cluster(); + int cta_coord_v = cta_rank_in_cluster % size<0>(AtomThrShapeMNK{}); + bool is_mma_leader_cta = cta_coord_v == 0; + + if (role == WarpRole::kLoad && lane_predicate && ! kIsCpAsync) { + prefetch_tma_descriptor(params.mainloop_params.tma_load_q_latent.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_c_latent.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_q_rope.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_k_rope.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_c_latent_transpose.get_tma_descriptor()); + } + SharedStorage& shared_storage = *reinterpret_cast(smem_raw); + + typename PipelineLoadQK::Params pipeline_load_qk_params; + if (role == WarpRole::kLoad) { + pipeline_load_qk_params.role = PipelineLoadQK::ThreadCategory::Producer; + } + if (role == WarpRole::kMma) { + pipeline_load_qk_params.role = PipelineLoadQK::ThreadCategory::Consumer; + } + if constexpr (kIsCpAsync) { + // we can make our life easier by unconditionally loading blocks + // since we know it'll always be legal + pipeline_load_qk_params.producer_arv_count = kNumLoadWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); + } + else { + pipeline_load_qk_params.is_leader = lane_predicate && (role == WarpRole::kLoad) && is_mma_leader_cta; + pipeline_load_qk_params.transaction_bytes = kTransactionsBytesLoadQK; + } + pipeline_load_qk_params.initializing_warp = 0; + PipelineLoadQK pipeline_load_qk(shared_storage.pipelines.load_qk, pipeline_load_qk_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineS::Params pipeline_mma_s_params; + if (role == WarpRole::kMma) { + pipeline_mma_s_params.role = PipelineS::ThreadCategory::Producer; + } + if (role == WarpRole::kCompute) { + pipeline_mma_s_params.role = PipelineS::ThreadCategory::Consumer; + } + pipeline_mma_s_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); + pipeline_mma_s_params.initializing_warp = 1; + PipelineS pipeline_mma_s( + shared_storage.pipelines.mma_s, + pipeline_mma_s_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineP::Params pipeline_p_mma_params; + if (role == WarpRole::kMma) { + pipeline_p_mma_params.role = PipelineP::ThreadCategory::Consumer; + } + if (role == WarpRole::kCompute) { + pipeline_p_mma_params.role = PipelineP::ThreadCategory::Producer; + } + pipeline_p_mma_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); + pipeline_p_mma_params.consumer_arv_count = 1; + pipeline_p_mma_params.initializing_warp = 2; + PipelineP pipeline_p_mma( + shared_storage.pipelines.p_mma, + pipeline_p_mma_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineO::Params pipeline_mma_o_params; + if (role == WarpRole::kMma) { + pipeline_mma_o_params.role = PipelineO::ThreadCategory::Producer; + } + if (role == WarpRole::kCompute) { + pipeline_mma_o_params.role = PipelineO::ThreadCategory::Consumer; + } + pipeline_mma_o_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); + pipeline_mma_o_params.initializing_warp = 3; + PipelineO pipeline_mma_o( + shared_storage.pipelines.mma_o, + pipeline_mma_o_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelinePT::Params pipeline_pt_params; + if (role == WarpRole::kLoad) { + pipeline_pt_params.role = PipelinePT::ThreadCategory::Consumer; + } + if (role == WarpRole::kLoadPageTable) { + pipeline_pt_params.role = PipelinePT::ThreadCategory::Producer; + } + pipeline_pt_params.consumer_arv_count = kNumLoadWarps * cutlass::NumThreadsPerWarp; + pipeline_pt_params.producer_arv_count = cutlass::NumThreadsPerWarp; + pipeline_pt_params.initializing_warp = 4; + PipelinePT pipeline_page_table( + shared_storage.pipelines.load_page_table, + pipeline_pt_params); + + TmemAllocator tmem_allocator; + + pipeline_init_arrive_relaxed(size(ClusterShape{})); + + pipeline_load_qk.init_masks(ClusterShape{}); // do we need an update here for 2Sm? + pipeline_mma_s.init_masks(ClusterShape{}); + pipeline_p_mma.init_masks(ClusterShape{}); + pipeline_mma_o.init_masks(ClusterShape{}); + + typename PipelineLoadQK::PipelineState pipeline_load_qk_consumer_state; + typename PipelineLoadQK::PipelineState pipeline_load_qk_producer_state = cutlass::make_producer_start_state(); + + typename PipelineS::PipelineState pipeline_mma_s_consumer_state; + typename PipelineS::PipelineState pipeline_mma_s_producer_state = cutlass::make_producer_start_state(); + + typename PipelineP::PipelineState pipeline_p_mma_consumer_state; + typename PipelineP::PipelineState pipeline_p_mma_producer_state = cutlass::make_producer_start_state(); + + typename PipelineO::PipelineState pipeline_mma_o_consumer_state; + typename PipelineO::PipelineState pipeline_mma_o_producer_state = cutlass::make_producer_start_state(); + + typename PipelinePT::PipelineState pipeline_pt_consumer_state; + typename PipelinePT::PipelineState pipeline_pt_producer_state = cutlass::make_producer_start_state(); + + pipeline_init_wait(size(ClusterShape{})); + + if (role == WarpRole::kLoadPageTable) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + load_page_table( + blk_coord, + problem_shape, + params.mainloop, + shared_storage.tensors, + pipeline_page_table, pipeline_pt_producer_state, + local_split_kv + ); + } + } + else if (role == WarpRole::kLoad) { + if constexpr (kIsCpAsync) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + load_cpasync( + blk_coord, + problem_shape, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_load_qk, pipeline_load_qk_producer_state, + local_split_kv, + /* must be shared pipe */ + pipeline_page_table, pipeline_pt_consumer_state + ); + cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait(); + } + } + else { + if (params.mainloop.ptr_page_table != nullptr) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + load_tma( + blk_coord, + problem_shape, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_load_qk, pipeline_load_qk_producer_state, + pipeline_load_qk, pipeline_load_qk_producer_state, + local_split_kv + ); + cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait(); + } + } + else { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + load_tma( + blk_coord, + problem_shape, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_load_qk, pipeline_load_qk_producer_state, + pipeline_load_qk, pipeline_load_qk_producer_state, + local_split_kv + ); + cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait(); + } + } + } + } + else if (role == WarpRole::kMma) { + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + + if (is_mma_leader_cta) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + mma(blk_coord, + problem_shape, + shared_storage.tensors, + pipeline_load_qk, pipeline_load_qk_consumer_state, + pipeline_load_qk, pipeline_load_qk_consumer_state, + pipeline_mma_s, pipeline_mma_s_producer_state, + pipeline_p_mma, pipeline_p_mma_consumer_state, + pipeline_mma_o, pipeline_mma_o_producer_state, + local_split_kv + ); + } + } + + //cutlass::arch::NamedBarrier((kNumComputeWarps + 1) * NumThreadsPerWarp, kNamedBarrierTmemDealloc).arrive_and_wait(); + + //uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + //tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + else if (role == WarpRole::kCompute) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto split_kv = params.split_kv; + auto local_split_kv = split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + compute( + blk_coord, + problem_shape, + params.mainloop, // for softmax_scale + params.epilogue, + shared_storage.tensors, // for smem_comm + pipeline_mma_s, pipeline_mma_s_consumer_state, + pipeline_p_mma, pipeline_p_mma_producer_state, + pipeline_mma_o, pipeline_mma_o_consumer_state, + local_split_kv, + params.is_fused_reduction + ); + } + + //cutlass::arch::NamedBarrier((kNumComputeWarps + 1) * NumThreadsPerWarp, kNamedBarrierTmemDealloc).arrive(); + } + + cute::cluster_sync(); + cutlass::arch::NamedBarrier((kNumComputeWarps + 1) * NumThreadsPerWarp, kNamedBarrierTmemDealloc).arrive(); + if (role == WarpRole::kMma) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } +#endif + } + + template + CUTLASS_DEVICE void load_page_table( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + TensorStorage& shared_tensors, + PipelinePT& pipeline_page_table, + typename PipelinePT::PipelineState& pipeline_pt_producer_state, int const& split_kv) { + + auto [H, K, D, B] = problem_shape; + int batch_coord = get<2>(blk_coord); + + auto mPT_l = make_tensor(make_gmem_ptr(mainloop_args.ptr_page_table), + make_shape(mainloop_args.page_count, B), + mainloop_args.stride_page_table); + auto mPT = mPT_l(_, batch_coord); + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + return; + } + + auto page_size = Pow2{mainloop_args.page_size}; + auto pages_per_tile = Pow2{TileShapeS{} / page_size}; + int thread_idx = threadIdx.x % cutlass::NumThreadsPerWarp; + + for (; k_tile_count > 0; ++k_index, --k_tile_count) { + pipeline_page_table.producer_acquire(pipeline_pt_producer_state); + + // assume a single warp + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < TileShapeS{}; i += cutlass::NumThreadsPerWarp) { + int idx = i + thread_idx; + bool guard = idx < pages_per_tile; + int smem_idx = pipeline_pt_producer_state.index() * TileShapeS::value + idx; + int pt_idx = pages_per_tile * k_index + idx; + + cutlass::arch::cp_async_zfill( + &shared_tensors.smem_page_table[smem_idx], &mPT(pt_idx), guard + ); + } + + pipeline_page_table.producer_commit(pipeline_pt_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_pt_producer_state; + } + } + + + struct Gather { + int& page_table_stage; + Pow2 pages_per_tile; + const int * __restrict__ smem_page_table; + + CUTLASS_DEVICE int operator()(int idx) const { + return smem_page_table[page_table_stage * TileShapeS::value + idx % pages_per_tile]; + } + + CUTLASS_DEVICE friend void print(Gather const&) { + printf(""); + } + + }; + + + template + CUTLASS_DEVICE void load_cpasync( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineLoadQK& pipeline_load, + typename PipelineLoadQK::PipelineState& pipeline_load_producer_state, + int const& split_kv, + PipelinePT& pipeline_page_table, + typename PipelinePT::PipelineState& pipeline_pt_consumer_state) { + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + using X = Underscore; + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + return; + } + + // partition all tensors + auto mQL = make_tensor(make_gmem_ptr(mainloop_args.ptr_q_latent), make_shape(H, D_latent, B), mainloop_args.stride_q_latent); + auto mQR = make_tensor(make_gmem_ptr(mainloop_args.ptr_q_rope), make_shape(H, D_rope, B), mainloop_args.stride_q_rope); + + int paged_B = mainloop_args.page_count; + auto paged_K = Pow2{mainloop_args.page_size}; + auto mPT_l = make_tensor(make_gmem_ptr(mainloop_args.ptr_page_table), make_shape(paged_B, B), mainloop_args.stride_page_table); + + int batch_coord = get<2>(blk_coord); + auto mPT = mPT_l(_, batch_coord); + + auto gQL = local_tile(mQL, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); + auto gQR = local_tile(mQR, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); + + ThrMMA cta_mma_qk = TiledMmaQK{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); + ThrMMA cta_mma_pv = TiledMmaPV{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); + + auto tSgQL = cta_mma_qk.partition_A(gQL); + auto tSgQR = cta_mma_qk.partition_A(gQR); + + Tensor sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + Tensor sKC = make_tensor(make_smem_ptr(shared_tensors.smem_kc.begin()), SmemLayoutKC{}); + Tensor sVC = make_tensor(make_smem_ptr(shared_tensors.smem_vc.begin()), SmemLayoutVC{}); + + auto make_copy_for = [](auto sT) { + auto rT_a = sT.layout()(_, _, _, _0{}); + auto rT = make_ordered_layout(shape(rT_a), stride(rT_a)); + auto threads = Int{}; + auto values = Int{}; + return make_cotiled_copy( + Copy_Atom, Element>{}, + make_ordered_layout( + make_shape(threads, values), + make_stride(_1{}, _0{})), + rT); + }; + + // like cute::copy, but makes sure we do all page table lookups first + auto copy_split = [](auto atom, auto src, auto dst) { + auto src_v = group_modes<1, rank_v>(src); + auto dst_v = group_modes<1, rank_v>(dst); + + auto src_v_ptrs = make_tensor(size<1>(src_v)); + for (int i = 0; i < size<1>(src_v); i++) { + src_v_ptrs(i) = &src_v(_0{}, i); + } + + + for (int i = 0; i < size<1>(src_v); i++) { + auto src_v_i = make_tensor( + make_gmem_ptr(src_v_ptrs(i)), + make_shape(shape<0>(src_v)), + make_stride(make_stride(_1{}, _0{})) + ); + atom.call(src_v_i, dst_v(_, i)); + } + }; + + auto tiled_copy_q = make_copy_for(sQ); + auto tiled_copy_kc = make_copy_for(sKC); + auto tiled_copy_vc = make_copy_for(sVC); + + auto thr_copy_q = tiled_copy_q.get_thread_slice(threadIdx.x % (kNumLoadWarps * cutlass::NumThreadsPerWarp)); + auto thr_copy_kc = tiled_copy_kc.get_thread_slice(threadIdx.x % (kNumLoadWarps * cutlass::NumThreadsPerWarp)); + auto thr_copy_vc = tiled_copy_vc.get_thread_slice(threadIdx.x % (kNumLoadWarps * cutlass::NumThreadsPerWarp)); + + auto tQsQ = thr_copy_q.partition_D(sQ); + auto tQgQL = thr_copy_q.partition_S(tSgQL); + auto tQgQR = thr_copy_q.partition_S(tSgQR); + + auto tKCsKC = thr_copy_kc.partition_D(sKC); + auto tVCsVC = thr_copy_vc.partition_D(sVC); + + auto pipeline_pt_release_state = pipeline_pt_consumer_state; + + int page_table_stage = -1; + Pow2 pages_per_tile{TileShapeS{} / paged_K}; + const int * __restrict__ smem_page_table = shared_tensors.smem_page_table.begin(); + Gather gather{page_table_stage, pages_per_tile, smem_page_table}; + + auto mCL = make_tensor( + make_gmem_ptr(mainloop_args.ptr_c_latent), + ComposedLayout{ + make_layout( + make_shape(make_shape(paged_K, paged_B), _1{}), + make_stride(make_stride(get<0>(mainloop_args.stride_c_latent), example::CustomStride(gather, get<2>(mainloop_args.stride_c_latent))), get<1>(mainloop_args.stride_c_latent))), + make_coord(_0{}, _0{}), + make_identity_layout(make_shape(paged_K * paged_B, D_latent))}); + + auto mKR = make_tensor( + make_gmem_ptr(mainloop_args.ptr_k_rope), + ComposedLayout{ + make_layout( + make_shape(make_shape(paged_K, paged_B), _1{}), + make_stride(make_stride(get<0>(mainloop_args.stride_k_rope), example::CustomStride(gather, get<2>(mainloop_args.stride_k_rope))), get<1>(mainloop_args.stride_k_rope))), + make_coord(_0{}, _0{}), + make_identity_layout(make_shape(paged_K * paged_B, D_latent))}); + + auto mCLT = make_tensor( + make_gmem_ptr(mainloop_args.ptr_c_latent), + ComposedLayout{ + make_layout( + make_shape(_1{}, make_shape(paged_K, paged_B)), + make_stride(get<1>(mainloop_args.stride_c_latent), make_stride(get<0>(mainloop_args.stride_c_latent), example::CustomStride(gather, get<2>(mainloop_args.stride_c_latent))))), + make_coord(_0{}, _0{}), + make_identity_layout(make_shape(D_latent, paged_K * paged_B))}); + + auto gCL = local_tile(mCL, TileShapeQK{}, make_coord(_,_,_), Step{}); + auto gKR = local_tile(mKR, TileShapeQK{}, make_coord(_,_,_), Step{}); + auto gCLT = local_tile(mCLT, TileShapePV{}, make_coord(_,_,_), Step{}); + + auto tSgCL = cta_mma_qk.partition_B(gCL); + auto tSgKR = cta_mma_qk.partition_B(gKR); + auto tOgCLT = cta_mma_pv.partition_B(gCLT); + + auto tKCgCL = thr_copy_kc.partition_S(tSgCL); + auto tKCgKR = thr_copy_kc.partition_S(tSgKR); + auto tVCgCLT = thr_copy_vc.partition_S(tOgCLT); + + // latent is first in memory, so let's load it first always + // startup: alternate Q and K, set tx count appropriately, for k_idx = 0 + auto& pipeline_acquire_state = pipeline_load_producer_state; + auto pipeline_commit_state = pipeline_acquire_state; + int pipeline_offset = 0; + + for (int i = 0; i < StagesPV; i++) { + cutlass::arch::cp_async_fence(); + } + + auto load_stage = [&](auto fn) { + pipeline_load.producer_acquire(pipeline_acquire_state); + fn(pipeline_acquire_state.index()); + cutlass::arch::cp_async_fence(); + + ++pipeline_acquire_state; + ++pipeline_offset; + + if (pipeline_offset == StagesPV - 1) { + cutlass::arch::cp_async_wait(); + pipeline_load.producer_commit(pipeline_commit_state); + ++pipeline_commit_state; + --pipeline_offset; + } + }; + + pipeline_page_table.consumer_wait(pipeline_pt_consumer_state); + page_table_stage = pipeline_pt_consumer_state.index(); + ++pipeline_pt_consumer_state; + + // each Q/K tile consists of rope and latent + for (int i = 0; i < IterationsQKLatent; i++) { + load_stage([&](int index) { + cute::copy(tiled_copy_q, tQgQL(_, _, _, _, _0{}, i, batch_coord), tQsQ(_, _, _, _, i)); + copy_split(tiled_copy_kc, tKCgCL(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index)); + }); + } + + for (int i = 0; i < IterationsQKRope; i++) { + load_stage([&](int index) { + cute::copy(tiled_copy_q, tQgQR(_, _, _, _, _0{}, i, batch_coord), tQsQ(_, _, _, _, IterationsQKLatent + i)); + copy_split(tiled_copy_kc, tKCgKR(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index)); + }); + } + + k_index += 1; + k_tile_count -= 1; + + // assume k_tile_count >= 1 + // perform K+Q load here + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + + pipeline_page_table.consumer_wait(pipeline_pt_consumer_state); + page_table_stage = pipeline_pt_consumer_state.index(); + ++pipeline_pt_consumer_state; + + for (int i = 0; i < IterationsQKLatent; i++) { + load_stage([&](int index) { + copy_split(tiled_copy_kc, tKCgCL(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index)); + }); + } + + for (int i = 0; i < IterationsQKRope; i++) { + load_stage([&](int index) { + copy_split(tiled_copy_kc, tKCgKR(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index)); + }); + } + + page_table_stage = pipeline_pt_release_state.index(); + + for (int i = 0; i < IterationsPV_K; i++) { + for (int j = 0; j < IterationsPV_N; j++) { + load_stage([&](int index) { + copy_split(tiled_copy_vc, tVCgCLT(_, _, _, _, j, IterationsPV_K * (k_index - 1) + i), tVCsVC(_, _, _, _, index)); + }); + } + } + + pipeline_page_table.consumer_release(pipeline_pt_release_state); + ++pipeline_pt_release_state; + + k_index += 1; + k_tile_count -= 1; + } + + page_table_stage = pipeline_pt_release_state.index(); + + for (int i = 0; i < IterationsPV_K; i++) { + for (int j = 0; j < IterationsPV_N; j++) { + load_stage([&](int index) { + copy_split(tiled_copy_vc, tVCgCLT(_, _, _, _, j, IterationsPV_K * (k_index - 1) + i), tVCsVC(_, _, _, _, index)); + }); + } + } + + pipeline_page_table.consumer_release(pipeline_pt_release_state); + ++pipeline_pt_release_state; + + while (pipeline_offset > 0) { + cutlass::arch::cp_async_fence(); + + cutlass::arch::cp_async_wait(); + pipeline_load.producer_commit(pipeline_commit_state); + ++pipeline_commit_state; + --pipeline_offset; + } + + cutlass::arch::cp_async_wait<0>(); + + } + + + template + CUTLASS_DEVICE void load_tma( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineLoadQK& pipeline_load_qk, + typename PipelineLoadQK::PipelineState& pipeline_load_qk_producer_state, + PipelineLoadPV& pipeline_load_pv, + typename PipelineLoadPV::PipelineState& pipeline_load_pv_producer_state, + int const& split_kv) { + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + return; + } + + using X = Underscore; + + // partition all tensors + auto mQL = mainloop_params.tma_load_q_latent.get_tma_tensor(make_shape(H, D_latent, B)); + auto mQR = mainloop_params.tma_load_q_rope.get_tma_tensor(make_shape(H, D_rope, B)); + + int paged_B = B; + int paged_K = K; + if constexpr (kIsPaged) { + paged_B = mainloop_args.page_count; + paged_K = mainloop_args.page_size; + } + auto mPT_l = make_tensor(make_gmem_ptr(mainloop_args.ptr_page_table), make_shape(paged_B, B), mainloop_args.stride_page_table); + + auto mCL = mainloop_params.tma_load_c_latent.get_tma_tensor(make_shape(paged_K, D_latent, paged_B)); + auto mKR = mainloop_params.tma_load_k_rope.get_tma_tensor(make_shape(paged_K, D_rope, paged_B)); + + auto mCLT = mainloop_params.tma_load_c_latent_transpose.get_tma_tensor(make_shape(D_latent, paged_K, paged_B)); + + auto gQL = local_tile(mQL, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); + auto gQR = local_tile(mQR, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); + + auto gCL = local_tile(mCL, TileShapeQK{}, make_coord(_,_,_), Step{}); + auto gKR = local_tile(mKR, TileShapeQK{}, make_coord(_,_,_), Step{}); + auto gCLT = local_tile(mCLT, TileShapePV{}, make_coord(_,_,_), Step{}); + + ThrMMA cta_mma_qk = TiledMmaQK{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); + ThrMMA cta_mma_pv = TiledMmaPV{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); + + auto tSgQL = cta_mma_qk.partition_A(gQL); + auto tSgQR = cta_mma_qk.partition_A(gQR); + + auto tSgCL = cta_mma_qk.partition_B(gCL); + auto tSgKR = cta_mma_qk.partition_B(gKR); + + auto tOgCLT = cta_mma_pv.partition_B(gCLT); + + Tensor sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + Tensor sKC = make_tensor(make_smem_ptr(shared_tensors.smem_kc.begin()), SmemLayoutKC{}); + Tensor sVC = make_tensor(make_smem_ptr(shared_tensors.smem_vc.begin()), SmemLayoutVC{}); + + auto [tQLgQL_mkl, tQsQ] = tma_partition( + mainloop_params.tma_load_q_latent, _0{}, make_layout(_1{}), + group_modes<0,3>(sQ), group_modes<0,3>(tSgQL)); + + auto [tQRgQR_mkl, tQsQ_ignore] = tma_partition( + mainloop_params.tma_load_q_rope, _0{}, make_layout(_1{}), + group_modes<0,3>(sQ), group_modes<0,3>(tSgQR)); + + auto [tCLgCL_nkl, tKCsKC] = tma_partition( + mainloop_params.tma_load_c_latent, _0{}, make_layout(_1{}), + group_modes<0,3>(sKC), group_modes<0,3>(tSgCL)); + + auto [tKRgKR_nkl, tKCsKC_ignore] = tma_partition( + mainloop_params.tma_load_k_rope, _0{}, make_layout(_1{}), + group_modes<0,3>(sKC), group_modes<0,3>(tSgKR)); + + auto [tCLTgCLT_nkl, tVCsVC] = tma_partition( + mainloop_params.tma_load_c_latent_transpose, _0{}, make_layout(_1{}), + group_modes<0,3>(sVC), group_modes<0,3>(tOgCLT)); + + uint16_t mcast_mask = 0; + + int batch_coord = get<2>(blk_coord); + Tensor tQLgQL = tQLgQL_mkl(_, _, _, batch_coord); + Tensor tQRgQR = tQRgQR_mkl(_, _, _, batch_coord); + + auto mPT = mPT_l(_, batch_coord); + + Tensor tCLgCL = tCLgCL_nkl(_, _, _, _); + Tensor tKRgKR = tKRgKR_nkl(_, _, _, _); + + // careful: stage and k are swapped here! + Tensor tCLTgCLT = tCLTgCLT_nkl(_, _, _, _); + + // latent is first in memory, so let's load it first always + // startup: alternate Q and K, set tx count appropriately, for k_idx = 0 + + // each Q/K tile consists of rope and latent + for (int i = 0; i < IterationsQKLatent; i++) { + pipeline_load_qk.producer_expect_transaction(pipeline_load_qk_producer_state, kTransactionsBytesLoadExtraQ); + pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state); + auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state); + + if (cute::elect_one_sync()) { + // expect the extra bytes + // load_qk ql + cute::copy(mainloop_params.tma_load_q_latent.with(*tma_barrier, mcast_mask), tQLgQL(_, _0{}, i), tQsQ(_, i)); + // load_qk cl + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask), + tCLgCL(_, _0{}, i, mPT(k_index)), + tKCsKC(_, pipeline_load_qk_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask), + tCLgCL(_, k_index, i, batch_coord), + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } + } + ++pipeline_load_qk_producer_state; + } + + for (int i = 0; i < IterationsQKRope; i++) { + pipeline_load_qk.producer_expect_transaction(pipeline_load_qk_producer_state, kTransactionsBytesLoadExtraQ); + pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state); + auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state); + + if (cute::elect_one_sync()) { + // expect the extra bytes + // load_qk ql + cute::copy(mainloop_params.tma_load_q_rope.with(*tma_barrier, mcast_mask), tQRgQR(_, _0{}, i), tQsQ(_, i + IterationsQKLatent)); + // load_qk cl + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask), + tKRgKR(_, _0{}, i, mPT(k_index)), + tKCsKC(_, pipeline_load_qk_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask), + tKRgKR(_, k_index, i, batch_coord), + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } + } + ++pipeline_load_qk_producer_state; + } + + k_index += 1; + k_tile_count -= 1; + + // assume k_tile_count >= 1 + // perform K+Q load here + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + + // perform K load + for (int i = 0; i < IterationsQKLatent; i++) { + pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state); + auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state); + + if (cute::elect_one_sync()) { + // load_qk cl + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask), + tCLgCL(_, _0{}, i, mPT(k_index)), + tKCsKC(_, pipeline_load_qk_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask), + tCLgCL(_, k_index, i, batch_coord), + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } + } + ++pipeline_load_qk_producer_state; + } + + for (int i = 0; i < IterationsQKRope; i++) { + pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state); + auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state); + + if (cute::elect_one_sync()) { + // load_qk cl + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask), + tKRgKR(_, _0{}, i, mPT(k_index)), + tKCsKC(_, pipeline_load_qk_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask), + tKRgKR(_, k_index, i, batch_coord), + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } + } + ++pipeline_load_qk_producer_state; + } + + // prefetch next K load to keep busy while we transpose-load from cache + const int kPrefetchDistance = 1; + for (int i = 0; i < IterationsQKLatent; i++) { + if (cute::elect_one_sync()) { + if constexpr (kIsPaged) { + if (k_tile_count > kPrefetchDistance) { + cute::prefetch( + mainloop_params.tma_load_c_latent, + tCLgCL(_, _0{}, i, mPT(k_index + kPrefetchDistance)) + ); + } + } + else { + cute::prefetch( + mainloop_params.tma_load_c_latent, + tCLgCL(_, k_index + kPrefetchDistance, i, batch_coord) + ); + } + } + } + + for (int i = 0; i < IterationsQKRope; i++) { + if (cute::elect_one_sync()) { + if constexpr (kIsPaged) { + if (k_tile_count > kPrefetchDistance) { + cute::prefetch( + mainloop_params.tma_load_k_rope, + tKRgKR(_, _0{}, i, mPT(k_index + kPrefetchDistance)) + ); + } + } + else { + cute::prefetch( + mainloop_params.tma_load_k_rope, + tKRgKR(_, k_index + kPrefetchDistance, i, batch_coord) + ); + } + } + } + + // perform V load (k_idx - 1) + + for (int i = 0; i < IterationsPV_K; i++) { + for (int j = 0; j < IterationsPV_N; j++) { + pipeline_load_pv.producer_acquire(pipeline_load_pv_producer_state); + auto tma_barrier = pipeline_load_pv.producer_get_barrier(pipeline_load_pv_producer_state); + + if (cute::elect_one_sync()) { + // load_pv cl + // note the transpose in indices! + // note we are off-by-one on k_index + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST), + tCLTgCLT(_, j, i, mPT(k_index - 1)), + tVCsVC(_, pipeline_load_pv_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST), + tCLTgCLT(_, j, IterationsPV_K * (k_index - 1) + i, batch_coord), + tVCsVC(_, pipeline_load_pv_producer_state.index()) + ); + } + } + ++pipeline_load_pv_producer_state; + } + } + + k_index += 1; + k_tile_count -= 1; + } + + for (int i = 0; i < IterationsPV_K; i++) { + for (int j = 0; j < IterationsPV_N; j++) { + pipeline_load_pv.producer_acquire(pipeline_load_pv_producer_state); + auto tma_barrier = pipeline_load_pv.producer_get_barrier(pipeline_load_pv_producer_state); + + if (cute::elect_one_sync()) { + // load_pv cl + // note the transpose in indices + // note we are off-by-one on k_index + + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST), + tCLTgCLT(_, j, i, mPT(k_index - 1)), + tVCsVC(_, pipeline_load_pv_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST), + tCLTgCLT(_, j, IterationsPV_K * (k_index - 1) + i, batch_coord), + tVCsVC(_, pipeline_load_pv_producer_state.index()) + ); + } + } + ++pipeline_load_pv_producer_state; + } + } + } + + template + CUTLASS_DEVICE void mma( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + TensorStorage& shared_tensors, + PipelineLoadQK& pipeline_load_qk, + typename PipelineLoadQK::PipelineState& pipeline_load_qk_consumer_state, + PipelineLoadPV& pipeline_load_pv, + typename PipelineLoadPV::PipelineState& pipeline_load_pv_consumer_state, + PipelineS& pipeline_mma_s, + typename PipelineS::PipelineState& pipeline_mma_s_producer_state, + PipelineP& pipeline_p_mma, + typename PipelineP::PipelineState& pipeline_p_mma_consumer_state, + PipelineO& pipeline_mma_o, + typename PipelineO::PipelineState& pipeline_mma_o_producer_state, + int const& split_kv) { + + auto [H, K, D, B] = problem_shape; + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + return; + } + + // mma init + Tensor sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + Tensor sKC = make_tensor(make_smem_ptr(shared_tensors.smem_kc.begin()), SmemLayoutKC{}); + Tensor sVC = make_tensor(make_smem_ptr(shared_tensors.smem_vc.begin()), SmemLayoutVC{}); + Tensor sP = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_p.begin()), SmemLayoutP{}); + + Tensor tSrQ = TiledMmaQK::make_fragment_A(sQ); + Tensor tSrKC = TiledMmaQK::make_fragment_B(sKC); + Tensor tOrP = TiledMmaPV::make_fragment_A(sP); + Tensor tOrVC = TiledMmaPV::make_fragment_B(sVC); + + TiledMmaQK tiled_mma_qk; + TiledMmaPV tiled_mma_pv; + + Tensor tStS = partition_fragment_C(tiled_mma_qk, select<0,1>(TileShapeQK{})); + Tensor tOtO = partition_fragment_C(tiled_mma_pv, select<0,1>(TileShapePV{})); + + tiled_mma_pv.accumulate_ = UMMA::ScaleOut::Zero; + + pipeline_mma_s.producer_acquire(pipeline_mma_s_producer_state); + + // Mma S0 S1 O0 S2 O1 ... Sn On-1 On + // S0 ownership -- ----- -- -- + // S1 ownership -- ----- ---- + // O ownership -- -- ---- -- + + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero; + for (int i = 0; i < IterationsQK; i++) { + pipeline_load_qk.consumer_wait(pipeline_load_qk_consumer_state); + int read_stage = pipeline_load_qk_consumer_state.index(); + + tStS.data() = uint32_t(pipeline_mma_s_producer_state.index() == 0 ? TmemAllocation::kS0 : TmemAllocation::kS1); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSrQ); ++k_block) { + cute::gemm(tiled_mma_qk, + tSrQ(_,_,k_block,i), + tSrKC(_,_,k_block,read_stage), + tStS); + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_qk.consumer_release(pipeline_load_qk_consumer_state); + ++pipeline_load_qk_consumer_state; + } + + pipeline_mma_s.producer_commit(pipeline_mma_s_producer_state); + ++pipeline_mma_s_producer_state; + + k_tile_count -= 1; + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + + pipeline_mma_s.producer_acquire(pipeline_mma_s_producer_state); + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero; + for (int i = 0; i < IterationsQK; i++) { + pipeline_load_qk.consumer_wait(pipeline_load_qk_consumer_state); + int read_stage = pipeline_load_qk_consumer_state.index(); + + tStS.data() = uint32_t(pipeline_mma_s_producer_state.index() == 0 ? TmemAllocation::kS0 : TmemAllocation::kS1); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSrQ); ++k_block) { + cute::gemm(tiled_mma_qk, + tSrQ(_,_,k_block,i), + tSrKC(_,_,k_block,read_stage), + tStS); + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_qk.consumer_release(pipeline_load_qk_consumer_state); + ++pipeline_load_qk_consumer_state; + } + + pipeline_mma_s.producer_commit(pipeline_mma_s_producer_state); + ++pipeline_mma_s_producer_state; + + pipeline_mma_o.producer_acquire(pipeline_mma_o_producer_state); + pipeline_p_mma.consumer_wait(pipeline_p_mma_consumer_state); + + for (int i = 0; i < IterationsPV_K; i++) { + auto acc_flag = tiled_mma_pv.accumulate_; + for (int j = 0; j < IterationsPV_N; j++) { + pipeline_load_pv.consumer_wait(pipeline_load_pv_consumer_state); + + int read_stage = pipeline_load_pv_consumer_state.index(); + + tOtO.data() = uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO); + tiled_mma_pv.accumulate_ = acc_flag; + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tOrP); ++k_block) { + cute::gemm(tiled_mma_pv, + tOrP(_,_,k_block, make_coord(i, pipeline_p_mma_consumer_state.index())), + tOrVC(_,_,k_block,read_stage), + tOtO); + tiled_mma_pv.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_pv.consumer_release(pipeline_load_pv_consumer_state); + ++pipeline_load_pv_consumer_state; + } + } + + pipeline_p_mma.consumer_release(pipeline_p_mma_consumer_state); + ++pipeline_p_mma_consumer_state; + pipeline_mma_o.producer_commit(pipeline_mma_o_producer_state); + ++pipeline_mma_o_producer_state; + + --k_tile_count; + } + + pipeline_mma_o.producer_acquire(pipeline_mma_o_producer_state); + pipeline_p_mma.consumer_wait(pipeline_p_mma_consumer_state); + + for (int i = 0; i < IterationsPV_K; i++) { + auto acc_flag = tiled_mma_pv.accumulate_; + for (int j = 0; j < IterationsPV_N; j++) { + pipeline_load_pv.consumer_wait(pipeline_load_pv_consumer_state); + + int read_stage = pipeline_load_pv_consumer_state.index(); + + tOtO.data() = uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO); + tiled_mma_pv.accumulate_ = acc_flag; + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tOrP); ++k_block) { + cute::gemm(tiled_mma_pv, + tOrP(_,_,k_block, make_coord(i, pipeline_p_mma_consumer_state.index())), + tOrVC(_,_,k_block,read_stage), + tOtO); + tiled_mma_pv.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_pv.consumer_release(pipeline_load_pv_consumer_state); + ++pipeline_load_pv_consumer_state; + } + } + + pipeline_p_mma.consumer_release(pipeline_p_mma_consumer_state); + ++pipeline_p_mma_consumer_state; + pipeline_mma_o.producer_commit(pipeline_mma_o_producer_state); + ++pipeline_mma_o_producer_state; + } + + + template + CUTLASS_DEVICE void softmax( + IsLastTile const& is_last_tile, + ElementAcc& row_max, + ElementAcc& row_sum, + ElementAcc& correction_factor, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + TensorStorage& shared_tensors, + int k_index, + uint32_t tmem_s, + int smem_p_index) { + + auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{}; + + TiledMmaQK tiled_mma_qk; + + Tensor tStS = partition_fragment_C(tiled_mma_qk, select<0,1>(TileShapeQK{})); + tStS.data() = tmem_s; + + CUTE_STATIC_ASSERT_V(shape<1>(tStS) == _1{}); + CUTE_STATIC_ASSERT_V(shape<2>(tStS) == _1{}); + Tensor tAcc = tStS(make_coord(_,_),_0{},_0{}); + + Tensor cS = make_identity_tensor(take<0,2>(CtaShapeQK{})); + + auto tiled_t2r = make_tmem_copy(load_op, tAcc); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_cS = thread_t2r.partition_D(cS); + Tensor tTR_rAcc = make_tensor(shape(tTR_cS)); + + Tensor tTR_rS_frag = make_tensor(shape(tTR_rAcc)); + const int AlignmentS = 4; + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); + Tensor tTR_rAcc_vec = recast>(tTR_rAcc); + Tensor tTR_rS_vec = recast>(tTR_rS_frag); + + // load s + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + + if (is_last_tile) { + for (int i = 0; i < size(tTR_rAcc); i++) { + if (get<1>(tTR_cS(i)) + TileShapeS{} * k_index >= get<1>(problem_shape)) { + tTR_rAcc(i) = -std::numeric_limits::infinity(); + } + } + } + + // max + ElementAcc row_max_new = row_max; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i += 1) { + row_max_new = ::fmax(row_max_new, tTR_rAcc(i)); + } + + // for 2x2 dp, reduce here + if constexpr (kWarpsInN > 1) { + shared_tensors.smem_exchange[threadIdx.x] = row_max_new; + cutlass::arch::NamedBarrier(kNumComputeWarps*NumThreadsPerWarp, kNamedBarrierExchange).sync(); + // (64, 2) shape + int peer_index = (threadIdx.x + 64) % 128; + row_max_new = cutlass::max(row_max_new, shared_tensors.smem_exchange[peer_index]); + } + + // find correction factor + ElementAcc softmax_scale_log2 = mainloop_args.softmax_scale * static_cast(M_LOG2E); + correction_factor = ::exp2f(softmax_scale_log2 * (row_max - row_max_new)); + row_max = row_max_new; + + // softmax + ElementAcc row_max_scale_log2 = row_max * softmax_scale_log2; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i++) { + tTR_rAcc(i) = ::exp2f(softmax_scale_log2 * tTR_rAcc(i) - row_max_scale_log2); + } + + // quantize + cutlass::NumericArrayConverter epilogue_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc_vec); i++) { + tTR_rS_vec(i) = epilogue_op(tTR_rAcc_vec(i)); + } + + Tensor sP = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_p.begin()), SmemLayoutP{})(_, _, _, make_coord(_, smem_p_index)); + + Tensor tOcP = TiledMmaPV{}.get_slice(_0{}).partition_A(cS); + + // have a mapping for each thread to coord + // find identical mapping to coords for the MMA + auto l = make_ordered_layout(make_shape(make_shape(_64{}, _2{}), make_shape(_16{}, TileShapeS{} / _32{})), make_stride(make_stride(_0{}, _3{}), make_stride(_1{}, _2{}))); + auto sP_ = as_position_independent_swizzle_tensor(sP); + copy_aligned(tTR_rS_frag, sP_.compose(l)(threadIdx.x, _)); + + // sum + row_sum *= correction_factor; + + static_assert(cute::is_same_v); + auto tTR_rAcc_float2 = recast(tTR_rAcc); + auto sums = make_tensor(_4{}); + static_assert(size(tTR_rAcc_float2) % size(sums) == 0); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(sums); i++) { + sums(i) = tTR_rAcc_float2(i); + } + CUTLASS_PRAGMA_UNROLL + for (int i = size(sums); i < size(tTR_rAcc_float2); i += size(sums)) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(sums); j++) { + cute::add(sums(j), sums(j), tTR_rAcc_float2(i + j)); + } + } + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < size(sums); i *= 2) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(sums); j += 2*i) { + cute::add(sums(j), sums(j), sums(j+i)); + } + } + row_sum += sums(0).x + sums(0).y; + } + + + CUTLASS_DEVICE void rescale( + ElementAcc correction_factor, + uint32_t tmem_o) { + + // for b2b gemm, do nothing + auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{}; + auto store_op = TMEM::tmem_load_to_store(load_op); + + TiledMmaPV tiled_mma_pv; + + Tensor tOtO = partition_fragment_C(tiled_mma_pv, select<0,1>(TileShapePV{})); + tOtO.data() = tmem_o; + + CUTE_STATIC_ASSERT_V(shape<1>(tOtO) == _1{}); + CUTE_STATIC_ASSERT_V(shape<2>(tOtO) == _1{}); + Tensor tAcc = tOtO(make_coord(_,_),_0{},_0{}); + + auto cta_tiler_pv = take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{}); + Tensor gO = make_tensor(make_gmem_ptr((ElementAcc*) nullptr), cta_tiler_pv, make_stride(0, 0)); + + auto tiled_t2r = make_tmem_copy(load_op, tAcc); + auto tiled_r2t = make_tmem_copy(store_op, tAcc); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + auto thread_r2t = tiled_r2t.get_slice(thread_idx); + Tensor tTR_gO = thread_t2r.partition_D(gO); + Tensor tTR_rAcc = make_tensor(shape(tTR_gO)); + + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); + + // load o + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + + // multiply by correction factor + float2 correction_factor_vec = make_float2(correction_factor, correction_factor); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i += 2) { + float2 in = make_float2(tTR_rAcc(i + 0), tTR_rAcc(i + 1)); + float2 out; + cute::mul(out, in, correction_factor_vec); + tTR_rAcc(i + 0) = out.x; + tTR_rAcc(i + 1) = out.y; + } + + // store o + copy(tiled_r2t, tTR_rAcc, tTR_tAcc); + } + + + template + CUTLASS_DEVICE void epilogue( + ElementAcc& row_max, + ElementAcc& row_sum, + BlkCoord const& cta_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueParams const& epilogue_args, + TensorStorage& shared_tensors, + uint32_t tmem_o, + int const& split_kv) { + + auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{}; + + TiledMmaPV tiled_mma_pv; + + Tensor tOtO = TiledMmaPV::make_fragment_C(partition_shape_C(TiledMmaPV{}, take<0, 2>(TileShapePV{}))); + tOtO.data() = tmem_o; + + CUTE_STATIC_ASSERT_V(shape<1>(tOtO) == _1{}); + CUTE_STATIC_ASSERT_V(shape<2>(tOtO) == _1{}); + Tensor tAcc = tOtO(make_coord(_,_),_0{},_0{}); + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + if (split_kv > 1) { + using ElementOutAcc = ElementAcc; + constexpr auto AlignmentOutAcc = 128 / cute::sizeof_bits_v; + Tensor mO = make_tensor(make_gmem_ptr(epilogue_args.ptr_o_acc + get<3>(cta_coord) * D_latent), make_shape(H, D_latent, B), epilogue_args.stride_o_acc); + auto cta_tiler_pv = take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{}); + Tensor gO = local_tile(mO, cta_tiler_pv, take<0,3>(cta_coord)); + + auto tiled_t2r = make_tmem_copy(load_op, tAcc); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_gO = thread_t2r.partition_D(gO); + Tensor tTR_rAcc = make_tensor(shape(tTR_gO)); + + Tensor tTR_rO_frag = make_tensor(shape(tTR_rAcc)); + Tensor tTR_rO_src = recast>(coalesce(tTR_rO_frag)); + Tensor tR2G_rO_dst = recast>(coalesce(tTR_gO)); + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); + + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + + cutlass::epilogue::thread::LinearCombination epilogue_op({epilogue_args.output_scale / row_sum}); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i++) { + tTR_rO_frag(i) = epilogue_op(tTR_rAcc(i)); + } + + copy(tTR_rO_src, tR2G_rO_dst); + + if (get<1>(cta_coord) == 0) { + if (epilogue_args.ptr_lse != nullptr) { + // compute LSE + ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max; + + // store LSE + Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse_acc + H * get<3>(cta_coord)), make_shape(H, B), epilogue_args.stride_lse_acc); + Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{}); + // for 2x2 dp, this must be conditional and the index is wrong + if (! kIs2Sm || (threadIdx.x < 64)) + { + gLSE(threadIdx.x) = lse; + } + } + } + } + else { + Tensor mO = make_tensor(make_gmem_ptr(epilogue_args.ptr_o), make_shape(H, D_latent, B), epilogue_args.stride_o); + auto cta_tiler_pv = take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{}); + Tensor gO = local_tile(mO, cta_tiler_pv, take<0,3>(cta_coord)); + + auto tiled_t2r = make_tmem_copy(load_op, tAcc); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_gO = thread_t2r.partition_D(gO); + Tensor tTR_rAcc = make_tensor(shape(tTR_gO)); + + Tensor tTR_rO_frag = make_tensor(shape(tTR_rAcc)); + Tensor tTR_rO_src = recast>(coalesce(tTR_rO_frag)); + Tensor tR2G_rO_dst = recast>(coalesce(tTR_gO)); + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); + + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + + cutlass::epilogue::thread::LinearCombination epilogue_op({epilogue_args.output_scale / row_sum}); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i++) { + tTR_rO_frag(i) = epilogue_op(tTR_rAcc(i)); + } + + copy(tTR_rO_src, tR2G_rO_dst); + + if (get<1>(cta_coord) == 0) { + if (epilogue_args.ptr_lse != nullptr) { + // compute LSE + ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max; + + // store LSE + Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse), make_shape(H, B), epilogue_args.stride_lse); + Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{}); + + // for 2x2 dp, this must be conditional and the index is wrong + if (! kIs2Sm || (threadIdx.x < 64)) + { + gLSE(threadIdx.x) = lse; + } + } + } + } + } + + template + CUTLASS_DEVICE ElementLSE epilogue_lse_reduction( + ElementAcc& row_max, + ElementAcc& row_sum, + BlkCoord const& cta_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueParams const& epilogue_args, + int const& local_split_kv) { + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + auto cta_tiler_pv = take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{}); + + constexpr int kNumThreads = kNumComputeWarps * NumThreadsPerWarp; + using Sync = cutlass::detail::NamedBarrierSync; + + auto wait = [](int* lock, int count) { + __threadfence(); + if (threadIdx.x == 0) { + atomicAdd(lock, 1); + while (atomicCAS(lock, count, count) != count) {}; + } + __threadfence(); + Sync::sync(); + }; + + const ElementLSE lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max; + Tensor mLSE_max_buff = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse_max_exchange_buff), make_shape(H, B), epilogue_args.stride_lse); + Tensor gLSE_max_buff = local_tile(mLSE_max_buff, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{}); + + int* local_lock = epilogue_args.ptr_lock + get<0>(cta_coord) + 2 * get<2>(cta_coord); + + if (! kIs2Sm || (threadIdx.x < 64)) { + atomicMax(&(gLSE_max_buff(threadIdx.x)), __float2int_rn(lse)); + } + wait(local_lock, local_split_kv); + + auto global_lse_max = static_cast(gLSE_max_buff(kIs2Sm ? threadIdx.x % 64 : threadIdx.x)); + + Tensor mLSE_buff = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse_exchange_buff), make_shape(H, B), epilogue_args.stride_lse); + Tensor gLSE_buff = local_tile(mLSE_buff, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{}); + + if (! kIs2Sm || (threadIdx.x < 64)) { + atomicAdd(&(gLSE_buff(threadIdx.x)), expf(lse - global_lse_max)); + } + wait(local_lock, 2*local_split_kv); + + const auto sum_lse = gLSE_buff(kIs2Sm ? threadIdx.x % 64 : threadIdx.x); + const auto global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits::infinity() : + cutlass::fast_log(sum_lse) + global_lse_max; + const auto lse_scale = expf(lse - global_lse); + + if (epilogue_args.ptr_lse != nullptr) { + Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse), make_shape(H, B), epilogue_args.stride_lse); + Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{}); + + // write out the global LSE + if (! kIs2Sm || (threadIdx.x < 64)) { + gLSE(threadIdx.x) = global_lse; + } + } + return lse_scale; + } + + + template + CUTLASS_DEVICE void epilogue_reduction( + ElementAcc& row_max, + ElementAcc& row_sum, + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueParams const& epilogue_args, + TensorStorage& shared_tensors, + int const& local_split_kv, + ElementLSE const& lse_scale) { + + constexpr int kNumThreads = kNumComputeWarps * NumThreadsPerWarp; + using Sync = cutlass::detail::NamedBarrierSync; + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{}; + + TiledMmaPV tiled_mma_pv; + Tensor tOtO = TiledMmaPV::make_fragment_C(partition_shape_C(TiledMmaPV{}, take<0, 2>(TileShapePV{}))); + + CUTE_STATIC_ASSERT_V(shape<1>(tOtO) == _1{}); + CUTE_STATIC_ASSERT_V(shape<2>(tOtO) == _1{}); + + using EpilogueLinearCombination = cutlass::epilogue::thread::LinearCombination; + EpilogueLinearCombination epilogue_op({epilogue_args.output_scale / row_sum * lse_scale}); + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < IterationsPV_N; ++k) { + auto cta_coord = replace<1>(blk_coord, k); + + uint32_t tmem_o = uint32_t(TmemAllocation::kO0) + k * uint32_t(TmemAllocation::kSizeAccO); + tOtO.data() = tmem_o; + + Tensor tAcc = tOtO(make_coord(_,_),_0{},_0{}); + + Tensor mO = make_tensor(make_gmem_ptr(epilogue_args.ptr_o), make_shape(H, D_latent, B), epilogue_args.stride_o); + Tensor gO = local_tile(mO, TileShapeAcc{}, take<0,3>(cta_coord)); + + auto tiled_t2r = make_tmem_copy(load_op, tAcc); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_gO = thread_t2r.partition_D(gO); + Tensor tTR_rAcc = make_tensor(shape(tTR_gO)); + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); + + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + + Tensor sO = make_tensor(make_smem_ptr(reinterpret_cast(shared_tensors.smem_acc.begin())), SmemLayout{}); + Tensor tTR_sO = thread_t2r.partition_D(sO); + + Sync::sync(); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i++) { + tTR_sO(i) = epilogue_op(tTR_rAcc(i)); + } + tma_store_fence(); + Sync::sync(); + + auto tma_reduce_sum_per_cta = epilogue_args.tma_reduce_sum.get_slice(_0{}); + auto gmem_tensor_coord = epilogue_args.tma_reduce_sum.get_tma_tensor(shape(mO)); + auto gmem_tensor_coord_per_cta = local_tile(gmem_tensor_coord, TileShapeAcc{}, take<0,3>(cta_coord)); + if (threadIdx.x % kNumThreads == 0) { + copy(epilogue_args.tma_reduce_sum, + tma_reduce_sum_per_cta.partition_S(sO), + tma_reduce_sum_per_cta.partition_D(gmem_tensor_coord_per_cta)); + tma_store_arrive(); + } + tma_store_wait<0>(); + } + } + + template + CUTLASS_DEVICE void compute( + CtaCoord const& cta_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueParams const& epilogue_args, + TensorStorage& shared_tensors, + PipelineS& pipeline_mma_s, + typename PipelineS::PipelineState& pipeline_mma_s_consumer_state, + PipelineP& pipeline_p_mma, + typename PipelineP::PipelineState& pipeline_p_mma_producer_state, + PipelineO& pipeline_mma_o, + typename PipelineO::PipelineState& pipeline_mma_o_consumer_state, + int const& split_kv, + bool const& is_fused_reduction) { + + auto [H, K, D, B] = problem_shape; + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(cta_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + + // if we return early, we have to make sure we release the load warp + cutlass::arch::NamedBarrier( + (kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, + kNamedBarrierEpilogue + ).arrive(); + + return; + } + int k_index_final = k_tile_total - 1; + + ElementAcc row_max = -std::numeric_limits::infinity(); + ElementAcc row_sum = 0; + ElementAcc correction_factor = 1; + + pipeline_p_mma.producer_acquire(pipeline_p_mma_producer_state); + pipeline_mma_s.consumer_wait(pipeline_mma_s_consumer_state); + + auto dispatch_bool = [](bool b, auto fn) { + if (b) { + fn(cute::true_type{}); + } + else { + fn(cute::false_type{}); + } + }; + + // softmax s0 -> p0 + dispatch_bool(k_index == k_index_final, [&](auto is_last_tile) { + softmax( + is_last_tile, + row_max, row_sum, correction_factor, + problem_shape, mainloop_args, shared_tensors, k_index, + uint32_t(pipeline_mma_s_consumer_state.index() == 0 ? TmemAllocation::kS0 : TmemAllocation::kS1), + pipeline_p_mma_producer_state.index() + ); + }); + + k_index += 1; + + cutlass::arch::fence_view_async_tmem_load(); + cutlass::arch::fence_view_async_shared(); + pipeline_mma_s.consumer_release(pipeline_mma_s_consumer_state); + ++pipeline_mma_s_consumer_state; + pipeline_p_mma.producer_commit(pipeline_p_mma_producer_state); + ++pipeline_p_mma_producer_state; + + k_tile_count -= 1; + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + pipeline_p_mma.producer_acquire(pipeline_p_mma_producer_state); + pipeline_mma_s.consumer_wait(pipeline_mma_s_consumer_state); + + // softmax s1 -> p1 + dispatch_bool(k_index == k_index_final, [&](auto is_last_tile) { + softmax( + is_last_tile, + row_max, row_sum, correction_factor, + problem_shape, mainloop_args, shared_tensors, k_index, + uint32_t(pipeline_mma_s_consumer_state.index() == 0 ? TmemAllocation::kS0 : TmemAllocation::kS1), + pipeline_p_mma_producer_state.index() + ); + }); + + cutlass::arch::fence_view_async_tmem_load(); + cutlass::arch::fence_view_async_shared(); + pipeline_mma_s.consumer_release(pipeline_mma_s_consumer_state); + ++pipeline_mma_s_consumer_state; + pipeline_p_mma.producer_commit(pipeline_p_mma_producer_state); + ++pipeline_p_mma_producer_state; + + pipeline_mma_o.consumer_wait(pipeline_mma_o_consumer_state); + + // rescale + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < IterationsPV_N; j++) { + rescale(correction_factor, uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO)); + } + + cutlass::arch::fence_view_async_tmem_store(); + pipeline_mma_o.consumer_release(pipeline_mma_o_consumer_state); + ++pipeline_mma_o_consumer_state; + + --k_tile_count; + k_index += 1; + } + + pipeline_mma_o.consumer_wait(pipeline_mma_o_consumer_state); + + if constexpr (kWarpsInN > 1) { + // reduce row_sum if needed (for 2x2 dp) + shared_tensors.smem_exchange[threadIdx.x] = row_sum; + cutlass::arch::NamedBarrier(kNumComputeWarps*NumThreadsPerWarp, kNamedBarrierExchange).sync(); + // (64, 2) shape + int peer_index = (threadIdx.x + 64) % 128; + row_sum += shared_tensors.smem_exchange[peer_index]; + } + + cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive(); + + const int actual_split_kv = ceil_div(k_tile_total, k_tile_per_cta); + if (!is_fused_reduction || actual_split_kv == 1) { + // epilogue + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < IterationsPV_N; j++) { + epilogue( + row_max, row_sum, + replace<1>(cta_coord, j), problem_shape, + mainloop_args, epilogue_args, shared_tensors, + uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO), + actual_split_kv + ); + } + } else { + const ElementLSE lse_scale = + epilogue_lse_reduction( + row_max, row_sum, + cta_coord, + problem_shape, + mainloop_args, epilogue_args, + actual_split_kv + ); + + epilogue_reduction(row_max, row_sum, + cta_coord, + problem_shape, + mainloop_args, epilogue_args, + shared_tensors, + actual_split_kv, + lse_scale + ); + } + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_o.consumer_release(pipeline_mma_o_consumer_state); + ++pipeline_mma_o_consumer_state; + } + +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::kernel diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_mla_tile_scheduler.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_mla_tile_scheduler.hpp new file mode 100644 index 0000000000000000000000000000000000000000..489f62f85368577a9df8471df19e98f4a23bc0c6 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_mla_tile_scheduler.hpp @@ -0,0 +1,160 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.h" + +namespace cutlass::fmha::kernel { + +//////////////////////////////////////////////////////////////////////////////// + +struct Sm100MlaIndividualTileScheduler { + + struct Params { + dim3 grid; + }; + + bool valid_ = true; + + CUTLASS_DEVICE + Sm100MlaIndividualTileScheduler(Params const&) {} + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, int const& split_kv) { + using namespace cute; + dim3 grid(get<0>(cluster_shape), get<3>(problem_shape) /* Batch */, split_kv /*Maximum Split KV*/); + return Params{ grid }; + } + + static dim3 get_grid_shape(Params const& params) { + return params.grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return valid_; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + return make_coord(blockIdx.x, _0{}, blockIdx.y, blockIdx.z); + } + + CUTLASS_DEVICE + Sm100MlaIndividualTileScheduler& operator++() { + valid_ = false; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct Sm100MlaPersistentTileScheduler { + + struct Params { + int num_blocks; + FastDivmod divmod_m_block; + FastDivmod divmod_b; + FastDivmod divmod_split_kv; + KernelHardwareInfo hw_info; + }; + + int block_idx = 0; + Params params; + + CUTLASS_DEVICE + Sm100MlaPersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {} + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, int const& split_kv) { + using namespace cute; + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = hw_info.sm_count; + if (sm_count <= 1 || sm_count % size<0>(cluster_shape) != 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + hw_info.sm_count = sm_count; + + int num_m_blocks = size<0>(cluster_shape); + int num_blocks = num_m_blocks * get<3>(problem_shape) /* Batch */; + num_blocks *= split_kv; /* Maximum Split KV*/ + + return Params { + num_blocks, + { num_m_blocks}, { get<3>(problem_shape) }, {split_kv}, + hw_info + }; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1); + return grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return block_idx < params.num_blocks; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + int block_decode = block_idx; + int m_block, bidb, n_split_kv; + params.divmod_m_block(block_decode, m_block, block_decode); + params.divmod_split_kv(block_decode, n_split_kv, block_decode); + params.divmod_b(block_decode, bidb, block_decode); + return make_coord(m_block, _0{}, bidb, n_split_kv); + } + + CUTLASS_DEVICE + Sm100MlaPersistentTileScheduler& operator++() { + block_idx += gridDim.x; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::kernel + diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/reference/fmha_bwd_reference.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/reference/fmha_bwd_reference.hpp new file mode 100644 index 0000000000000000000000000000000000000000..465a9871ab8040c17ac78000d6f580a7045f8e2d --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/reference/fmha_bwd_reference.hpp @@ -0,0 +1,410 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cute/tensor.hpp" +#include "collective/fmha_fusion.hpp" + +using namespace cutlass::fmha::collective; +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + class TensorDQ, /* class TensorDK, class TensorDV, */ + class Fusion +> +void __global__ fmha_bwd_reference_dQ_kernel( + ProblemShape problem_shape_in, + TensorQ mQ_in, TensorK mK_in, TensorV mV_in, + TensorO mO_in, TensorLSE mLSE_in, TensorDO mDO_in, + TensorDQ mDQ_in, /* TensorDK mDK, TensorDV mDV, */ + Fusion fusion) { + + using namespace cute; + using namespace cutlass::fmha::collective; + + using Element = typename TensorO::value_type; + using ElementAcc = typename TensorLSE::value_type; + + extern __shared__ char mS_mem[]; + Element* mS = reinterpret_cast(mS_mem); + + ElementAcc softmax_scale = 1.0f / sqrtf(size<2>(problem_shape_in)); + + for (int idx_L = blockIdx.y; idx_L < size<4>(problem_shape_in); idx_L += gridDim.y) { + auto [problem_shape, offset] = apply_variable_length_offset( + problem_shape_in, + make_coord(_0{}, _0{}, _0{}, _0{},idx2crd(idx_L, get<4>(problem_shape_in))) + ); + // problem_shape = problem_shape_in; + // offset = repeat_like(problem_shape_in, _0{}); + auto mQ = domain_offset(select<0,2,4>(offset), mQ_in); + auto mK = domain_offset(select<1,2,4>(offset), mK_in); + auto mV = domain_offset(select<1,3,4>(offset), mV_in); + auto mO = domain_offset(select<0,3,4>(offset), mO_in); + auto mLSE = domain_offset(select<0,4>(offset), mLSE_in); + auto mDO = domain_offset(select<0,3,4>(offset), mDO_in); + auto mDQ = domain_offset(select<0,2,4>(offset), mDQ_in); + for (int idx_Q = blockIdx.x; idx_Q < size<0>(problem_shape); idx_Q += gridDim.x) { + for (int idx_K = threadIdx.x; idx_K < size<1>(problem_shape); idx_K += blockDim.x) { + ElementAcc acc_qk = 0; + ElementAcc acc_dov = 0; + ElementAcc acc_doo = 0; + for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) { + acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L); + // acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L); + // acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L); + } // for idx_D0 + + for (int idx_D1 = 0; idx_D1 < size<3>(problem_shape); idx_D1++) { + acc_dov += mDO(idx_Q, idx_D1, idx_L) * mV(idx_K, idx_D1, idx_L); + acc_doo += mDO(idx_Q, idx_D1, idx_L) * mO(idx_Q, idx_D1, idx_L); + } + + auto id = make_identity_tensor(make_shape(1, 1)); + auto frag = make_tensor(Shape<_1, _1>{}); + frag(0) = acc_qk; + fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape); + acc_qk = frag(0); + + mS[idx_K] = static_cast(expf(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo)); + } // for idx_K + + __syncthreads(); + + for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) { + ElementAcc acc = 0; + for (int idx_K = 0; idx_K < size<1>(problem_shape); idx_K++) { + ElementAcc rK = mK(idx_K, idx_D, idx_L); + ElementAcc rDS = mS[idx_K]; + acc += rDS * rK; + } + mDQ(idx_Q, idx_D, idx_L) = static_cast(acc); + } // for idx_D + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + /* class TensorDQ, */ class TensorDK, /* class TensorDV, */ + class Fusion +> +void __global__ fmha_bwd_reference_dK_kernel( + ProblemShape problem_shape_in, + TensorQ mQ_in, TensorK mK_in, TensorV mV_in, + TensorO mO_in, TensorLSE mLSE_in, TensorDO mDO_in, + /* TensorDQ mDQ_in, */ TensorDK mDK_in, /* TensorDV mDV_in, */ + Fusion fusion) { + + using namespace cute; + using namespace cutlass::fmha::collective; + + using Element = typename TensorO::value_type; + using ElementAcc = typename TensorLSE::value_type; + + extern __shared__ char mS_mem[]; + Element* mS = reinterpret_cast(mS_mem); + + ElementAcc softmax_scale = 1.0f / sqrtf(size<2>(problem_shape_in)); + + auto [H, B] = get<4>(problem_shape_in); + auto [H_R, H_K] = H; + + for (int idx_HB = blockIdx.y; idx_HB < H_K * B; idx_HB += gridDim.y) { + auto [idx_H_K, idx_B] = idx2crd(idx_HB, make_shape(H_K, B)); + auto [problem_shape, offset] = apply_variable_length_offset( + problem_shape_in, + make_coord(_0{}, _0{}, _0{}, _0{}, make_coord(make_coord(_0{}, idx_H_K), idx_B)) + ); + auto [Q, K, D, D_VO, HB] = problem_shape; + auto [offset_Q, offset_K, offset_D, offset_D_VO, offset_HB] = offset; + + auto mQ = domain_offset(make_coord(offset_Q, offset_D, offset_HB), mQ_in); + auto mK = domain_offset(make_coord(offset_K, offset_D, offset_HB), mK_in); + auto mV = domain_offset(make_coord(offset_K, offset_D_VO, offset_HB), mV_in); + auto mO = domain_offset(make_coord(offset_Q, offset_D_VO, offset_HB), mO_in); + auto mLSE = domain_offset(make_coord(offset_Q, offset_HB), mLSE_in); + auto mDO = domain_offset(make_coord(offset_Q, offset_D_VO, offset_HB), mDO_in); + auto mDK = domain_offset(make_coord(offset_K, offset_D, offset_HB), mDK_in); + + for (int idx_K = blockIdx.x; idx_K < K; idx_K += gridDim.x) { + ElementAcc acc_dk = 0; + for (int idx_H_R = 0; idx_H_R < H_R; idx_H_R++) { + auto coord_HB = make_coord(make_coord(idx_H_R, idx_H_K), idx_B); + for (int idx_Q = threadIdx.x; idx_Q < Q; idx_Q += blockDim.x) { + ElementAcc acc_qk = 0; + ElementAcc acc_dov = 0; + ElementAcc acc_doo = 0; + for (int idx_D0 = 0; idx_D0 < D; idx_D0++) { + ElementAcc rQ = mQ(idx_Q, idx_D0, coord_HB); + ElementAcc rK = mK(idx_K, idx_D0, coord_HB); + acc_qk += rQ * rK; + } // for idx_D0 + + for (int idx_D1 = 0; idx_D1 < D_VO; idx_D1++) { + ElementAcc rDO = mDO(idx_Q, idx_D1, coord_HB); + ElementAcc rV = mV(idx_K, idx_D1, coord_HB); + ElementAcc rO = mO(idx_Q, idx_D1, coord_HB); + acc_dov += rDO * rV; + acc_doo += rDO * rO ; + } + auto id = make_identity_tensor(make_shape(1, 1)); + auto frag = make_tensor(Shape<_1, _1>{}); + frag(0) = acc_qk; + fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape); + acc_qk = frag(0); + + mS[idx_Q] = static_cast(expf(softmax_scale * acc_qk - mLSE(idx_Q, coord_HB)) * softmax_scale * (acc_dov - acc_doo)); + } // for idx_Q + + __syncthreads(); + + int idx_D = threadIdx.x; + if (idx_D < D) { + for (int idx_Q = 0; idx_Q < Q; idx_Q++) { + ElementAcc rQ = mQ(idx_Q, idx_D, coord_HB); + ElementAcc rDS = mS[idx_Q]; + acc_dk += rDS * rQ; + } + } + __syncthreads(); + } // for idx_H_R + + int idx_D = threadIdx.x; + if (idx_D < D) { + auto coord_HB = make_coord(make_coord(0, idx_H_K), idx_B); + mDK(idx_K, idx_D, coord_HB) = static_cast(acc_dk); + } + } // for idx_K + } // for idx_HB +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + /* class TensorDQ, class TensorDK, */ class TensorDV, + class Fusion +> +void __global__ fmha_bwd_reference_dV_kernel( + ProblemShape problem_shape_in, + TensorQ mQ_in, TensorK mK_in, TensorV mV_in, + TensorO mO_in, TensorLSE mLSE_in, TensorDO mDO_in, + /* TensorDQ mDQ_in, TensorDK mDK_in, */ TensorDV mDV_in, + Fusion fusion) { + + using namespace cute; + using namespace cutlass::fmha::collective; + + using Element = typename TensorO::value_type; + using ElementAcc = typename TensorLSE::value_type; + + extern __shared__ char mS_mem[]; + Element* mS = reinterpret_cast(mS_mem); + + ElementAcc softmax_scale = 1.0f / sqrtf(size<2>(problem_shape_in)); + + auto [H, B] = get<4>(problem_shape_in); + auto [H_R, H_K] = H; + + for (int idx_HB = blockIdx.y; idx_HB < H_K * B; idx_HB += gridDim.y) { + auto [idx_H_K, idx_B] = idx2crd(idx_HB, make_shape(H_K, B)); + auto [problem_shape, offset] = apply_variable_length_offset( + problem_shape_in, + make_coord(_0{}, _0{}, _0{}, _0{}, make_coord(make_coord(_0{}, idx_H_K), idx_B)) + ); + auto [Q, K, D, D_VO, HB] = problem_shape; + auto [offset_Q, offset_K, offset_D, offset_D_VO, offset_HB] = offset; + + auto mQ = domain_offset(make_coord(offset_Q, offset_D, offset_HB), mQ_in); + auto mK = domain_offset(make_coord(offset_K, offset_D, offset_HB), mK_in); + auto mV = domain_offset(make_coord(offset_K, offset_D_VO, offset_HB), mV_in); + auto mO = domain_offset(make_coord(offset_Q, offset_D_VO, offset_HB), mO_in); + auto mLSE = domain_offset(make_coord(offset_Q, offset_HB), mLSE_in); + auto mDO = domain_offset(make_coord(offset_Q, offset_D_VO, offset_HB), mDO_in); + auto mDV = domain_offset(make_coord(offset_K, offset_D_VO, offset_HB), mDV_in); + + for (int idx_K = blockIdx.x; idx_K < K; idx_K += gridDim.x) { + ElementAcc acc_dv = 0; + for (int idx_H_R = 0; idx_H_R < H_R; idx_H_R++) { + auto coord_HB = make_coord(make_coord(idx_H_R, idx_H_K), idx_B); + for (int idx_Q = threadIdx.x; idx_Q < Q; idx_Q += blockDim.x) { + ElementAcc acc_qk = 0; + + for (int idx_D0 = 0; idx_D0 < D; idx_D0++) { + ElementAcc rQ = mQ(idx_Q, idx_D0, coord_HB); + ElementAcc rK = mK(idx_K, idx_D0, coord_HB); + acc_qk += rQ * rK; + } // for idx_D0 + + auto id = make_identity_tensor(make_shape(1, 1)); + auto frag = make_tensor(Shape<_1, _1>{}); + frag(0) = acc_qk; + fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape); + acc_qk = frag(0); + + mS[idx_Q] = static_cast(expf(softmax_scale * acc_qk - mLSE(idx_Q, coord_HB))); + } // for idx_Q + + __syncthreads(); + + int idx_D_VO = threadIdx.x; + if (idx_D_VO < D_VO) { + for (int idx_Q = 0; idx_Q < Q; idx_Q++) { + ElementAcc rDO = mDO(idx_Q, idx_D_VO, coord_HB); + ElementAcc rP = mS[idx_Q]; + acc_dv += rP * rDO; + } + } // for idx_D + + __syncthreads(); + } // for idx_H_R + + int idx_D_VO = threadIdx.x; + if (idx_D_VO < D_VO) { + auto coord_HB = make_coord(make_coord(0, idx_H_K), idx_B); + mDV(idx_K, idx_D_VO, coord_HB) = static_cast(acc_dv); + } + } // for idx_K + } // for idx_L +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + /**/ class TensorDQ, /** / class TensorDK, / ** / class TensorDV, / **/ + class Fusion +> +void fmha_bwd_reference_dQ( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + /**/ TensorDQ mDQ, /** / TensorDK mDK, / ** / TensorDV mDV, / **/ + Fusion fusion) { + + using namespace cute; + + dim3 grid(size<0>(mDQ), size<2>(mDQ), 1); + dim3 block(256); + int shared_mem = size<0>(mK) * sizeof(typename TensorDQ::value_type); + fmha_bwd_reference_dQ_kernel<<>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, fusion); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + /** / class TensorDQ, / **/ class TensorDK, /** / class TensorDV, / **/ + class Fusion +> +void fmha_bwd_reference_dK( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + /** / TensorDQ mDQ, / **/ TensorDK mDK, /** / TensorDV mDV, / **/ + Fusion fusion) { + + using namespace cute; + + auto [K, D, HB] = mDK.shape(); + auto [H, B] = HB; + auto [H_R, H_K] = H; + dim3 grid(K, H_K * B, 1); + dim3 block(std::max(D, 256)); + int shared_mem = size<0>(mDO) * sizeof(typename TensorDK::value_type); + fmha_bwd_reference_dK_kernel<<>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDK, fusion); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + /** / class TensorDQ, / ** / class TensorDK, / **/ class TensorDV, /**/ + class Fusion +> +void fmha_bwd_reference_dV( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + /** / TensorDQ mDQ, / ** / TensorDK mDK, / **/ TensorDV mDV, /**/ + Fusion fusion) { + + using namespace cute; + + auto [K, D_VO, HB] = mDV.shape(); + auto [H, B] = HB; + auto [H_R, H_K] = H; + dim3 grid(K, H_K * B, 1); + dim3 block(std::max(D_VO, 256)); + int shared_mem = size<0>(mDO) * sizeof(typename TensorDV::value_type); + fmha_bwd_reference_dV_kernel<<>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDV, fusion); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + class TensorDQ, class TensorDK, class TensorDV, + class Fusion +> +void fmha_bwd_reference( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + TensorDQ mDQ, TensorDK mDK, TensorDV mDV, + Fusion fusion) { + + fmha_bwd_reference_dQ(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, fusion); + fmha_bwd_reference_dK(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDK, fusion); + fmha_bwd_reference_dV(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDV, fusion); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/reference/fmha_fwd_gen_reference.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/reference/fmha_fwd_gen_reference.hpp new file mode 100644 index 0000000000000000000000000000000000000000..40239c564b001b6571bda5efcc151395d239f4f1 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/reference/fmha_fwd_gen_reference.hpp @@ -0,0 +1,185 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include "cute/tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ElementAcc, + class ProblemShape, + class TensorQ, + class TensorNewK, + class TensorNewV, + class TensorCacheK, + class TensorCacheV, + class TensorO +> +void __global__ fmha_fwd_gen_reference_kernel( + ProblemShape problem_shape, + const int* seqlen_kv, const int* cache_batch_idx, + TensorQ mQ, TensorNewK mNewK, TensorNewV mNewV, + TensorCacheK mCacheK, TensorCacheV mCacheV, TensorO mO) { + + using namespace cute; + extern __shared__ char mS_mem[]; + ElementAcc* mS = reinterpret_cast(mS_mem); + + float scale = 1.0f / std::sqrt(float(get<2>(problem_shape))); + + if (mNewK.data() != nullptr) { + // 1. copy in new_k to cache + for (int idx_h = blockIdx.x; idx_h < size<3,0,1>(problem_shape); idx_h += gridDim.x) { + for (int idx_b = blockIdx.z; idx_b < size<3,1>(problem_shape); idx_b += gridDim.z) { + int idx_b_kv = cache_batch_idx != nullptr ? cache_batch_idx[idx_b] : idx_b; + for (int idx_d = threadIdx.x; idx_d < size<2>(problem_shape); idx_d += blockDim.x) { + mCacheK(seqlen_kv[idx_b], idx_d, make_coord(make_coord(_0{}, idx_h), idx_b_kv)) = + mNewK(_0{}, idx_d, make_coord(make_coord(_0{}, idx_h), idx_b)); + mCacheV(seqlen_kv[idx_b], idx_d, make_coord(make_coord(_0{}, idx_h), idx_b_kv)) = + mNewV(_0{}, idx_d, make_coord(make_coord(_0{}, idx_h), idx_b)); + } + } + } + } + + // 2. compute attention + for (int idx_h_kv = blockIdx.x; idx_h_kv < size<3,0,1>(problem_shape); idx_h_kv += gridDim.x) { + for (int idx_h_qo = blockIdx.y; idx_h_qo < size<3,0,0>(problem_shape); idx_h_qo += gridDim.y) { + int idx_h = idx_h_qo + size<3,0,0>(problem_shape) * idx_h_kv; + for (int idx_b = blockIdx.z; idx_b < size<3,1>(problem_shape); idx_b += gridDim.z) { + int idx_b_kv = cache_batch_idx != nullptr ? cache_batch_idx[idx_b] : idx_b; + const int kDim = 128; + ElementAcc reg_o[kDim] = {0}; + ElementAcc row_max = -INFINITY; + ElementAcc row_sum = 0; + auto iteration = [&](auto const& tK, auto const& tV) { + ElementAcc reg_s = 0; + for (int idx_d = 0; idx_d < kDim; idx_d++) { + ElementAcc eQ = mQ(_0{}, idx_d, make_coord(idx_h, idx_b)); + ElementAcc eK = tK(idx_d); + reg_s += eQ * eK; + } + + ElementAcc old_row_max = row_max; + row_max = std::max(row_max, reg_s); + + ElementAcc adjustment = std::exp(scale * (old_row_max - row_max)); + row_sum *= adjustment; + for (int idx_d = 0; idx_d < kDim; idx_d++) { + reg_o[idx_d] *= adjustment; + } + + ElementAcc reg_p = std::exp(scale * (reg_s - row_max)); + row_sum += reg_p; + + for (int idx_d = 0; idx_d < kDim; idx_d++) { + ElementAcc eV = tV(idx_d); + reg_o[idx_d] += reg_p * eV; + } + }; + + for (int idx_s = threadIdx.x; idx_s < seqlen_kv[idx_b]; idx_s += blockDim.x) { + iteration(mCacheK(idx_s, _, make_coord(idx_h, idx_b_kv)), mCacheV(idx_s, _, make_coord(idx_h, idx_b_kv))); + } + + if (mNewK.data() != nullptr && threadIdx.x == 0) { + iteration(mNewK(_0{}, _, make_coord(idx_h, idx_b)), mNewV(_0{}, _, make_coord(idx_h, idx_b))); + } + + mS[threadIdx.x] = row_max; + __syncthreads(); + float old_row_max = row_max; + for (int i = 0; i < blockDim.x; i++) { + row_max = std::max(row_max, mS[i]); + } + __syncthreads(); + + ElementAcc adjustment = std::exp(scale * (old_row_max - row_max)); + row_sum *= adjustment; + for (int idx_d = 0; idx_d < kDim; idx_d++) { + reg_o[idx_d] *= adjustment; + } + mS[threadIdx.x] = row_sum; + __syncthreads(); + + row_sum = 0; + for (int i = 0; i < blockDim.x; i++) { + row_sum += mS[i]; + } + __syncthreads(); + for (int idx_d = 0; idx_d < kDim; idx_d++) { + mS[idx_d] = 0; + } + __syncthreads(); + + for (int idx_d = 0; idx_d < kDim; idx_d++) { + reg_o[idx_d] /= row_sum; + atomicAdd(&mS[idx_d], reg_o[idx_d]); + } + + __syncthreads(); + for (int idx_d = threadIdx.x; idx_d < kDim; idx_d += blockDim.x) { + mO(_0{}, idx_d, make_coord(idx_h, idx_b)) = static_cast(mS[idx_d]); + } + } + } + } +} + +template< + class ElementAcc, + class ProblemShape, + class TensorQ, + class TensorNewK, + class TensorNewV, + class TensorCacheK, + class TensorCacheV, + class TensorO +> +void fmha_fwd_gen_reference( + ProblemShape problem_shape, + const int* seqlen_kv, const int* cache_batch_idx, + TensorQ mQ, TensorNewK mNewK, TensorNewV mNewV, + TensorCacheK mCacheK, TensorCacheV mCacheV, TensorO mO) { + + using namespace cute; + + dim3 grid(get<3,0,1>(problem_shape), get<3,0,0>(problem_shape), get<3,1>(problem_shape)); + dim3 block(128); + int shared_mem = int(sizeof(ElementAcc)) * std::max(128, block.x); + assert(get<2>(problem_shape) == 128); + fmha_fwd_gen_reference_kernel<<>>( + problem_shape, seqlen_kv, cache_batch_idx, + mQ, mNewK, mNewV, mCacheK, mCacheV, mO + ); +} diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/reference/fmha_fwd_reference.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/reference/fmha_fwd_reference.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d674ee95d2a7e30ef9c870506aaa414066e3df70 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/reference/fmha_fwd_reference.hpp @@ -0,0 +1,195 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/tensor.hpp" +#include "collective/fmha_fusion.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShapeIn, + class TensorQ, + class TensorK, + class TensorV, + class TensorO, + class TensorLSE, + class Mask +> +void __global__ fmha_reference_kernel( + ProblemShapeIn problem_shape_in, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, + Mask mask) { + + using namespace cute; + using namespace cutlass::fmha::collective; + + using Element = typename TensorO::value_type; + using ElementAccumulator = typename TensorLSE::value_type; + + extern __shared__ char mS_mem[]; + ElementAccumulator* mS = reinterpret_cast(mS_mem); + + ElementAccumulator softmax_scale = static_cast(1.0 / sqrt(1.0 * size<1>(mQ))); + + auto id = make_identity_tensor(make_shape(1, 1)); + + for (int idx_L = blockIdx.y; idx_L < size<4>(problem_shape_in); idx_L += gridDim.y) { + for (int idx_Q = blockIdx.x; idx_Q < size<0>(problem_shape_in); idx_Q += gridDim.x) { + + auto coord_L = idx2crd(idx_L, shape<4>(problem_shape_in)); + auto get_coord_in = [&]() { + if constexpr (rank_v(ProblemShapeIn{}))> == 2) { + return cute::make_tuple(idx_Q, _0{}, cute::make_tuple(_0{}, _0{}), cute::make_tuple(_0{}, _0{}), coord_L); + } else { + return cute::make_tuple(idx_Q, _0{}, _0{}, _0{}, coord_L); + } + }; + auto coord_in = get_coord_in(); + auto [problem_shape, coord] = apply_variable_length(problem_shape_in, coord_in, get<4,1>(coord_in)); + + int head_qk = 0; + int head_v = 0; + if constexpr (rank_v(problem_shape))> == 2) { + // MLA case: head_qk 192, head_v = 128 + head_qk = size<2, 0>(problem_shape) + size<2, 1>(problem_shape); + head_v = size<2, 0>(problem_shape); + } else { + head_qk = size<3>(problem_shape); + head_v = head_qk; + } + + if (get<0,0>(coord) >= get<0>(problem_shape)) continue; + + int offset_Q = 0; + if constexpr (rank<0>(decltype(coord){}) == 2) { + offset_Q = get<0,1>(coord); + } + + int offset_K = 0; + if constexpr (rank<1>(decltype(coord){}) == 2) { + offset_K = get<1,1>(coord); + } + + if (get<1>(problem_shape) == 0) { + for (int idx_D = threadIdx.x; idx_D < head_qk; idx_D += blockDim.x) { + mO(idx_Q + offset_Q, idx_D, idx_L) = Element(0); + } + + if (threadIdx.x == 0 && mLSE.data() != nullptr) { + mLSE(idx_Q + offset_Q, idx_L) = -INFINITY; + } + continue; + } + + for (int idx_K = threadIdx.x; idx_K < size<1>(problem_shape); idx_K += blockDim.x) { + ElementAccumulator acc = 0; + for (int idx_D = 0; idx_D < head_qk; idx_D++) { + ElementAccumulator eQ = mQ(idx_Q + offset_Q, idx_D, idx_L); + ElementAccumulator eK = mK(idx_K + offset_K, idx_D, idx_L); + acc += eQ * eK; + } + auto frag = make_tensor(Shape<_1, _1>{}); + frag(0) = acc; + mask.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape); + mS[idx_K] = frag(0); + } + + __syncthreads(); + + ElementAccumulator maxS = -std::numeric_limits::infinity(); + for (int idx_K = 0; idx_K < size<1>(problem_shape); idx_K++) { + maxS = std::max(maxS, mS[idx_K]); + } + if (maxS == -std::numeric_limits::infinity()) maxS = 0; + + __syncthreads(); + + for (int idx_K = threadIdx.x; idx_K < size<1>(problem_shape); idx_K += blockDim.x) { + mS[idx_K] = expf(softmax_scale * (mS[idx_K] - maxS)); + } + + __syncthreads(); + + ElementAccumulator sum = 0; + for (int idx_K = 0; idx_K < size<1>(problem_shape); idx_K++) { + sum += mS[idx_K]; + } + + ElementAccumulator scale = 1.0f / sum; + + + for (int idx_D = threadIdx.x; idx_D < head_v; idx_D += blockDim.x) { + ElementAccumulator acc = 0; + for (int idx_K = 0; idx_K < size<1>(problem_shape); idx_K++) { + ElementAccumulator eV = mV(idx_K + offset_K, idx_D, idx_L); + ElementAccumulator eK = static_cast(mS[idx_K]); + acc += eK * eV; + } + mO(idx_Q + offset_Q, idx_D, idx_L) = static_cast(acc * scale); + } + + + if (threadIdx.x == 0 && mLSE.data() != nullptr) { + mLSE(idx_Q + offset_Q, idx_L) = log(sum) + softmax_scale * maxS; + } + + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShapeIn, + class TensorQ, + class TensorK, + class TensorV, + class TensorO, + class TensorLSE, + class Mask +> +void fmha_reference( + ProblemShapeIn problem_shape_in, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, + Mask mask) { + + using namespace cute; + + dim3 grid(size<0>(mO), size<2>(mO), 1); + dim3 block(256); + int shared_mem = size<0>(mK) * int(sizeof(typename TensorLSE::value_type)); + fmha_reference_kernel<<>>(problem_shape_in, mQ, mK, mV, mO, mLSE, mask); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/reference/fmha_mla_reference.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/reference/fmha_mla_reference.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c83ebdb7479159943c8edc0974c998d37052405e --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/reference/fmha_mla_reference.hpp @@ -0,0 +1,201 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorSeq, + class TensorPageTable, + class TensorQL, + class TensorQR, + class TensorCL, + class TensorKR, + class TensorO, + class TensorLSE, + class Scale +> +void __global__ fmha_mla_reference_kernel( + ProblemShape problem_shape, + TensorSeq mSeq, TensorPageTable mPT, + TensorQL mQL, TensorQR mQR, + TensorCL mCL, TensorKR mKR, + TensorO mO, TensorLSE mLSE, + Scale softmax_scale) { + + using namespace cute; + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + using Element = typename TensorO::value_type; + using ElementAcc = typename TensorLSE::value_type; + + extern __shared__ ElementAcc mS[]; + // ElementAcc* mS = reinterpret_cast(mS_mem); + + for (int idx_B = blockIdx.y; idx_B < B; idx_B += gridDim.y) { + if (mSeq.data() != nullptr) { + K = mSeq(idx_B); + } + + for (int idx_H = blockIdx.x; idx_H < H; idx_H += gridDim.x) { + + for (int idx_K = threadIdx.x; idx_K < K; idx_K += blockDim.x) { + ElementAcc acc = 0; + + for (int idx_D = 0; idx_D < D_latent; idx_D++) { + int page_idx_K = idx_K; + int page_idx_B = idx_B; + if (mPT.data() != nullptr) { + page_idx_B = mPT(idx_K / size<0>(mCL), idx_B); + page_idx_K = idx_K % size<0>(mCL); + } + ElementAcc eQ = mQL(idx_H, idx_D, idx_B); + ElementAcc eK = mCL(page_idx_K, idx_D, page_idx_B); + acc += eQ * eK; + } + + for (int idx_D = 0; idx_D < D_rope; idx_D++) { + int page_idx_K = idx_K; + int page_idx_B = idx_B; + if (mPT.data() != nullptr) { + page_idx_B = mPT(idx_K / size<0>(mCL), idx_B); + page_idx_K = idx_K % size<0>(mCL); + } + ElementAcc eQ = mQR(idx_H, idx_D, idx_B); + ElementAcc eK = mKR(page_idx_K, idx_D, page_idx_B); + acc += eQ * eK; + } + mS[idx_K] = acc; + } + + __syncthreads(); + + ElementAcc maxS = -std::numeric_limits::infinity(); + for (int idx_K = 0; idx_K < K; idx_K++) { + maxS = std::max(maxS, mS[idx_K]); + } + if (maxS == -std::numeric_limits::infinity()) maxS = 0; + + __syncthreads(); + + for (int idx_K = threadIdx.x; idx_K < K; idx_K += blockDim.x) { + mS[idx_K] = expf(softmax_scale * (mS[idx_K] - maxS)); + } + + __syncthreads(); + + ElementAcc sum = 0; + for (int idx_K = 0; idx_K < K; idx_K++) { + sum += mS[idx_K]; + } + + ElementAcc o_scale = 1.0f / sum; + + for (int idx_D = threadIdx.x; idx_D < D_latent; idx_D += blockDim.x) { + ElementAcc acc = 0; + for (int idx_K = 0; idx_K < K; idx_K++) { + int page_idx_K = idx_K; + int page_idx_B = idx_B; + if (mPT.data() != nullptr) { + page_idx_B = mPT(idx_K / size<0>(mCL), idx_B); + page_idx_K = idx_K % size<0>(mCL); + } + ElementAcc eV = mCL(page_idx_K, idx_D, page_idx_B); + ElementAcc eK = static_cast(mS[idx_K]); + acc += eK * eV; + } + mO(idx_H, idx_D, idx_B) = static_cast(acc * o_scale); + } + + if (threadIdx.x == 0) { + mLSE(idx_H, idx_B) = log(sum) + softmax_scale * maxS; + } + + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorSeq, + class TensorPageTable, + class TensorQL, + class TensorQR, + class TensorCL, + class TensorKR, + class TensorO, + class TensorLSE, + class Scale +> +void fmha_mla_reference( + ProblemShape problem_shape, + TensorSeq mSeq, TensorPageTable mPT, + TensorQL mQL, TensorQR mQR, + TensorCL mCL, TensorKR mKR, + TensorO mO, TensorLSE mLSE, + Scale scale) { + + using namespace cute; + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + dim3 grid(H, B, 1); + dim3 block(256); + int shared_mem = K * int(sizeof(typename TensorLSE::value_type)) + 16; + cudaError_t result; + if (shared_mem >= (48 << 10)) { + result = cudaFuncSetAttribute( + &fmha_mla_reference_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + shared_mem); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + throw std::runtime_error("couldn't perform smem optin"); + } + } + fmha_mla_reference_kernel<<>>( + problem_shape, mSeq, mPT, mQL, mQR, mCL, mKR, mO, mLSE, scale); + cudaDeviceSynchronize(); + result = cudaGetLastError(); + if (cudaSuccess != result) { + throw std::runtime_error("couldn't execute reference"); + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/reference/reference_abs_error.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/reference/reference_abs_error.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8c29289c989931970fafd507a4e42f0502cbd2eb --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/reference/reference_abs_error.hpp @@ -0,0 +1,282 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +#pragma once + +#include + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct DeviceAllocation { + T* ptr_ = nullptr; + size_t offset_ = 0; + size_t size_ = 0; + + DeviceAllocation(DeviceAllocation const&) = delete; + DeviceAllocation& operator=(DeviceAllocation const&) = delete; + + DeviceAllocation() = default; + DeviceAllocation(size_t size) { reset(size); } + ~DeviceAllocation() { reset(); } + + void reset(size_t size, size_t offset=0) { + reset(); + auto ret = cudaMalloc(&ptr_, sizeof(T) * (size + offset)); + assert(ret == cudaSuccess); + size_ = size; + offset_ = offset; + } + + T* get() { + return ptr_ + offset_; + } + + const T* get() const { + return ptr_ + offset_; + } + + void reset() { + if (ptr_ != nullptr) { + auto ret = cudaFree(ptr_); + assert(ret == cudaSuccess); + } + } + + size_t size() const { return size_; } + + size_t get_storage_size() const { return (size_ + offset_) * sizeof(T); } + + void copy_from_host(const T* ptr, size_t sz) { + auto ret = cudaMemcpy(ptr_, ptr, sz * sizeof(T), cudaMemcpyDefault); + assert(ret == cudaSuccess); + } + + void copy_from_device(const T* ptr, size_t sz) { + auto ret = cudaMemcpy(ptr_, ptr, sz * sizeof(T), cudaMemcpyDefault); + assert(ret == cudaSuccess); + } +}; + +template +__global__ void reference_abs_diff_kernel( + Element* data, Element* data_ref, size_t count, + double* max_diff, double* sum_diff, + bool print_diff ) { + + double thread_max_diff = 0; + double thread_sum_diff = 0; + + __shared__ double block_max_diff; + __shared__ double block_sum_diff; + + for (size_t i = threadIdx.x + blockIdx.x * blockDim.x; i < count; i += blockDim.x * gridDim.x) { + if (data[i] == data_ref[i]) { + continue; + } + + double diff = fabs(data[i] - data_ref[i]); + if (print_diff) if (not isfinite(diff) || diff > 0.01f) printf("difference at %lld: %f ... %f vs %f\n", static_cast(i), diff, (double)data[i], (double)data_ref[i]); + thread_max_diff = fmax(diff, thread_max_diff); + thread_sum_diff += diff; + } + + for (int i = 0; i < blockDim.x; i++) { + if (i == threadIdx.x) { + if (i == 0) { + block_max_diff = thread_max_diff; + block_sum_diff = thread_sum_diff; + } + else { + block_max_diff = fmax(block_max_diff, thread_max_diff); + block_sum_diff += thread_sum_diff; + } + } + __syncthreads(); + } + + if (threadIdx.x == 0) { + atomicAdd(sum_diff, block_sum_diff); + + for (;;) { + unsigned long long prev = *reinterpret_cast(max_diff); + double prev_diff = reinterpret_cast(prev); + double new_max_diff = fmax(block_max_diff, prev_diff); + unsigned long long found = atomicCAS(reinterpret_cast(max_diff), prev, reinterpret_cast(new_max_diff)); + if (found == prev) break; + } + } +} + +template +void reference_abs_diff( + DeviceAllocation const& data, + DeviceAllocation const& data_ref, + double& max_diff, double& mean_diff) { + + static bool kPrintDiff = getenv("REF_PRINT_DIFF") && atoi(getenv("REF_PRINT_DIFF")) == 1; + + DeviceAllocation result; + result.reset(2); + assert(data.size() == data_ref.size()); + + cudaError_t err = cudaMemset(result.get(), 0, result.size() * sizeof(double)); + if (err != cudaSuccess) { + std::cerr << "Memset failed. Last CUDA error: " + << cudaGetErrorString(err) << std::endl; + max_diff = mean_diff = 1e20; + return; + } + + dim3 block(256, 1, 1); + dim3 grid(1024, 1, 1); + reference_abs_diff_kernel<<>>( + data.get(), data_ref.get(), data.size(), + result.get(), result.get() + 1, kPrintDiff); + + err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + std::cerr << "Difference kernel failed. Last CUDA error: " + << cudaGetErrorString(err) << std::endl; + max_diff = mean_diff = 1e20; + return; + } + + double result_host[2]; + err = cudaMemcpy(result_host, result.get(), result.size() * sizeof(double), cudaMemcpyDefault); + if (err != cudaSuccess) { + std::cerr << "Copy failed. Last CUDA error: " + << cudaGetErrorString(err) << std::endl; + max_diff = mean_diff = 1e20; + return; + } + + max_diff = result_host[0]; + mean_diff = result_host[1] / static_cast(data.size()); +} + +template +__global__ void reference_rel_diff_kernel( + Element* data, Element* data_ref, size_t count, + double* max_diff, double* sum_diff, + bool print_diff ) { + + double thread_max_diff = 0; + double thread_sum_diff = 0; + + __shared__ double block_max_diff; + __shared__ double block_sum_diff; + + for (size_t i = threadIdx.x + blockIdx.x * blockDim.x; i < count; i += blockDim.x * gridDim.x) { + if (data[i] == data_ref[i]) { + continue; + } + double diff = fabs(data[i] - data_ref[i]) / fabs(data_ref[i]); + if (print_diff) if (not isfinite(diff) || diff > 0.01f) printf("difference at %lld: %f ... %f vs %f\n", static_cast(i), diff, (double)data[i], (double)data_ref[i]); + thread_max_diff = fmax(diff, thread_max_diff); + thread_sum_diff += diff; + } + + for (int i = 0; i < blockDim.x; i++) { + if (i == threadIdx.x) { + if (i == 0) { + block_max_diff = thread_max_diff; + block_sum_diff = thread_sum_diff; + } + else { + block_max_diff = fmax(block_max_diff, thread_max_diff); + block_sum_diff += thread_sum_diff; + } + } + __syncthreads(); + } + + if (threadIdx.x == 0) { + atomicAdd(sum_diff, block_sum_diff); + + for (;;) { + unsigned long long prev = *reinterpret_cast(max_diff); + double prev_diff = reinterpret_cast(prev); + double new_max_diff = fmax(block_max_diff, prev_diff); + unsigned long long found = atomicCAS(reinterpret_cast(max_diff), prev, reinterpret_cast(new_max_diff)); + if (found == prev) break; + } + } +} + +template +void reference_rel_diff( + DeviceAllocation const& data, + DeviceAllocation const& data_ref, + double& max_diff, double& mean_diff) { + + static bool kPrintDiff = getenv("REF_PRINT_DIFF") && atoi(getenv("REF_PRINT_DIFF")) == 1; + + DeviceAllocation result; + result.reset(2); + assert(data.size() == data_ref.size()); + + cudaError_t err = cudaMemset(result.get(), 0, result.size() * sizeof(double)); + if (err != cudaSuccess) { + std::cerr << "Memset failed. Last CUDA error: " + << cudaGetErrorString(err) << std::endl; + max_diff = mean_diff = 1e20; + return; + } + + dim3 block(256, 1, 1); + dim3 grid(1024, 1, 1); + reference_rel_diff_kernel<<>>( + data.get(), data_ref.get(), data.size(), + result.get(), result.get() + 1, kPrintDiff); + + err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + std::cerr << "Difference kernel failed. Last CUDA error: " + << cudaGetErrorString(err) << std::endl; + max_diff = mean_diff = 1e20; + return; + } + + double result_host[2]; + err = cudaMemcpy(result_host, result.get(), result.size() * sizeof(double), cudaMemcpyDefault); + if (err != cudaSuccess) { + std::cerr << "Copy failed. Last CUDA error: " + << cudaGetErrorString(err) << std::endl; + max_diff = mean_diff = 1e20; + return; + } + + max_diff = result_host[0]; + mean_diff = result_host[1] / static_cast(data.size()); +} diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/86_blackwell_mixed_dtype_gemm/mixed_dtype_helper.cuh b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/86_blackwell_mixed_dtype_gemm/mixed_dtype_helper.cuh new file mode 100644 index 0000000000000000000000000000000000000000..f26e6be8246e4f0318d6ce4973beac1e50dafa5b --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/86_blackwell_mixed_dtype_gemm/mixed_dtype_helper.cuh @@ -0,0 +1,269 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/util/command_line.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" + +#include "cute/tensor.hpp" + +#include +#include +#include "helper.h" + +enum MixedDtypeGemmMode { + ConvertOnly, + ScaleOnly, + ScaleWithZeroPoint +}; + +/// Command line options parsing +struct MixedDtypeOptions { + + bool help = false; + bool verify = false; + + float alpha = 1.0f; + float beta = 0.0f; + int iterations = 1000; + int warmup = 1000; + int mode = 1; + int m = 5120, n = 4096, k = 4096; + int l = 1; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + if (cmd.check_cmd_line_flag("verify")) { + verify = true; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("mode", mode); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("warmup", warmup); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "86_blackwell_mixed_dtype_gemm\n\n" + << " Blackwell Mixed Data Type GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= The number of independent gemm problems with mnk shape\n" + << " --mode= The mode to run the gemm. 0 does (A @ B), 1 means A @ (scale * B), 2 means A @ (scale * B + zero-point).\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n" + << " --warmup= Number of warmup iterations to perform.\n\n" + << " --verify= Run verification.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "86_blackwell_mixed_dtype_gemm" << " --m=1024 --n=512 --k=1024 --l=10 --alpha=2 --mode=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k * l; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct MixedDtypeResult +{ + double avg_runtime_ms = 0.0; + double gflops = 0.0; + cutlass::Status status = cutlass::Status::kSuccess; + cudaError_t error = cudaSuccess; + bool passed = false; + +}; + +/// Profiling Loop +template +void mixed_dtype_profiling( + Gemm& gemm, + MixedDtypeOptions const& options, + MixedDtypeResult& result) { + + if (options.iterations <= 0) return; + + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + std::vector runtimes; + runtimes.reserve(options.iterations); + + for (int iter = 0; iter < options.warmup + options.iterations; ++iter) { + cudaEventRecord(start); + CUTLASS_CHECK(gemm.run()); + cudaEventRecord(stop); + cudaEventSynchronize(stop); + + if (iter >= options.warmup) { + float milliseconds = 0; + cudaEventElapsedTime(&milliseconds, start, stop); + runtimes.push_back(milliseconds); + } + } + + cudaEventDestroy(start); + cudaEventDestroy(stop); + + // Compute average setup and runtime and GFLOPs. + result.avg_runtime_ms = std::accumulate(runtimes.begin(), runtimes.end(), 0.0f) / runtimes.size(); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + +} + +/// Helpers to initialize a block of device data +template +bool initialize_tensor( + cutlass::DeviceAllocation& block, + uint64_t seed = 2023) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } + else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); + + return true; +} + +template +bool initialize_quant_tensor( + cutlass::DeviceAllocation& block, + uint64_t seed = 2023) { + + float scope_min = float(cutlass::platform::numeric_limits::lowest()); + float scope_max = float(cutlass::platform::numeric_limits::max()); + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); + + return true; +} + +template +bool initialize_scale( + cutlass::DeviceAllocation& block, + MixedDtypeOptions const& options, + uint64_t seed = 2023) { + + if (options.mode == MixedDtypeGemmMode::ConvertOnly) { + // No scales, so just initialize with 1 so we can use the same kernel to dequantize the data. + std::vector stage(block.size(), Element(1.0f)); + block.copy_from_host(stage.data()); + } + else { + float elt_max_f = float(cutlass::platform::numeric_limits::max()); + const float max_dequant_val = 4.f; + const float min_dequant_val = 0.5f; + + float scope_max(max_dequant_val / elt_max_f); + float scope_min(min_dequant_val / elt_max_f); + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); + } + return true; +} + +template +bool initialize_zero( + cutlass::DeviceAllocation& block, + MixedDtypeOptions const& options, + uint64_t seed = 2023) { + + if (options.mode == MixedDtypeGemmMode::ScaleWithZeroPoint) { + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(2.0f), Element(-2.0f)); + } else { + // No bias, so just initialize with 1 so we can use the same kernel to dequantize the data. + std::vector stage(block.size(), Element(0.0f)); + block.copy_from_host(stage.data()); + } + return true; +} + diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/87_blackwell_geforce_gemm_blockwise/utils.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/87_blackwell_geforce_gemm_blockwise/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..72735303883bba3d185e8d56df06273f33938ad5 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/87_blackwell_geforce_gemm_blockwise/utils.h @@ -0,0 +1,83 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/// Helper to initialize a block of device data +template +bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } + else if (bits_input == 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + throw std::runtime_error("Not implementated."); + } + + return true; +} diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_collective_bwd_tma_warpspecialized.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_collective_bwd_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d9d311a54e61ede4407b2ab9320b85f73d15ef8f --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_collective_bwd_tma_warpspecialized.hpp @@ -0,0 +1,863 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "../collective/fmha_common.hpp" +#include "../collective/fmha_collective_load.hpp" +#include "../collective/fmha_collective_softmax.hpp" +#include "../kernel/fmha_options.hpp" + +namespace cutlass::fmha::collective { + +template< + typename Element_, + typename ElementAccumulator_, + typename TileShape_, // BlockQO, BlockKV, BlockHead + class Fusion, + class... Options +> +struct FmhaBwdMainloopTmaWarpSpecialized { + + using Element = Element_; + using ElementAccumulator = ElementAccumulator_; + using TileShape = TileShape_; + + static constexpr bool kIsPersistent = false; + + static const int NumLoadWarpGroups = 1; + static constexpr int NumMmaWarpGroups = 2; + static constexpr int StageCountQ = 2 /*K, V*/ * NumMmaWarpGroups; + static constexpr int StageCount = 2 /*Q, dO*/ * 2 /* actual stages */; + + static const int kOuterLoads = 2; + using StagesQ = cutlass::gemm::collective::StageCount; + using Stages = cutlass::gemm::collective::StageCount; + using ClusterShape = Shape<_1, _1, _1>; + static_assert(StagesQ::value >= 2); + static_assert(Stages::value >= 2 * NumMmaWarpGroups); + + // 16B alignment lets us use TMA + static constexpr int Alignment = 16 / sizeof(Element); + + using TileShapeNM = Shape< // (N,M,D) + decltype(tuple_element_t<1, TileShape>{} / Int{}), + tuple_element_t<0, TileShape>, + tuple_element_t<2, TileShape>>; + + using TileShapeND = decltype(select<0,2,1>(TileShapeNM{})); // (N,D,M) + + using TileShapeMD = decltype(select<2,1,0>(TileShapeND{})); // (M,D,N) + + using CollectiveMmaNM = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Element, cute::tuple>, Alignment, + Element, cute::tuple>, Alignment, + ElementAccumulator, + TileShapeNM, ClusterShape, Stages, + cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp; + + using CollectiveMmaND = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Element, cute::tuple>, Alignment, // from register, doesn't matter + Element, cute::tuple<_1, int, cute::tuple>, Alignment, + ElementAccumulator, + TileShapeND, ClusterShape, Stages, + cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp; + + using CollectiveMmaND_SS = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Element, cute::tuple>, Alignment, // from register, doesn't matter + Element, cute::tuple<_1, int, cute::tuple>, Alignment, + ElementAccumulator, + TileShapeND, ClusterShape, Stages, + cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp; + + + using CollectiveMmaMD = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Element, cute::tuple<_1, int, cute::tuple>, Alignment, // from smem, might matter (?) + Element, cute::tuple<_1, int, cute::tuple>, Alignment, + ElementAccumulator, + TileShapeMD, ClusterShape, Stages, + cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp; + + using TiledMmaNM = typename CollectiveMmaNM::TiledMma; + using TiledMmaND_SS = typename CollectiveMmaND_SS::TiledMma; + using TiledMmaND_RS = decltype(convert_to_gmma_rs(typename CollectiveMmaND::TiledMma{})); + using TiledMmaND = TiledMmaND_RS; + using TiledMmaMD = typename CollectiveMmaMD::TiledMma; + + using SmemLayoutQ = typename CollectiveMmaNM::SmemLayoutB; + using SmemLayoutK = typename CollectiveMmaNM::SmemLayoutA; + using SmemLayoutV = typename CollectiveMmaNM::SmemLayoutA; + using SmemLayoutDO = typename CollectiveMmaNM::SmemLayoutB; + + //using SmemLayoutDQ = Layout< + // Shape< + // tuple_element_t<0, TileShapeMD>, + // Shape<_2, _4, decltype(tuple_element_t<1, TileShapeMD>{} / _8{})>, + // _2 + // >, + // Stride< + // _4, + // Stride{} * _4{}), _1, decltype(tuple_element_t<0, TileShapeMD>{} * _8{})>, + // decltype(tuple_element_t<0, TileShapeMD>{} * tuple_element_t<1, TileShapeMD>{}) + // >>; + + using SmemLayoutDQ_0 = Layout< + Shape< + tuple_element_t<0, TileShapeMD>, + tuple_element_t<1, TileShapeMD>, + _2 + >, + Stride< + tuple_element_t<1, TileShapeMD>, + _1, + decltype(tuple_element_t<0, TileShapeMD>{} * tuple_element_t<1, TileShapeMD>{}) + >>; + + using SmemAtomDQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector< + cute::GMMA::Major::K, ElementAccumulator, tuple_element_t<0, TileShapeMD>, tuple_element_t<1, TileShapeMD>>()); + using SmemLayoutDQ_1 = decltype(tile_to_shape(SmemAtomDQ{}, make_shape(get<0>(TileShapeMD{}), get<1>(TileShapeMD{}), _2{}), Step<_2, _1, _3>{})); + using SmemLayoutDQ = SmemLayoutDQ_1; + + + using PipelineDQ = cutlass::PipelineAsync<2>; + + + using SmemLayoutDS_0 = decltype(unstageSmemLayout(typename CollectiveMmaMD::SmemLayoutA{}, Int{})); + + using SmemLayoutDS = decltype(tile_to_shape(GMMA::Layout_MN_INTER_Atom{}, make_shape(size<0>(SmemLayoutDS_0{}), size<1>(SmemLayoutDS_0{}), size<2>(SmemLayoutDS_0{})), Step<_1, _2, _3>{})); + using SmemLayoutKp = typename CollectiveMmaMD::SmemLayoutB; + + using SmemLayoutQp = typename CollectiveMmaND::SmemLayoutB; + using SmemLayoutDOp = typename CollectiveMmaND::SmemLayoutB; + + using SmemLayoutLSE = Layout, Int>>; + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using MainloopPipelineQ = cutlass::PipelineTmaAsync; + + using PipelineState = typename cutlass::PipelineState; + using PipelineStateQ = typename cutlass::PipelineState; + + using TileShapePV = TileShapeND; // To work with the kernel level + using TiledMmaPV = TiledMmaND; + + static constexpr int kInnerLoadBytes = size(SmemLayoutQ{}(_,_,_0{})) * sizeof(Element) + size(SmemLayoutLSE{}(_,_0{})) * sizeof(ElementAccumulator); + static constexpr int kOuterLoadBytes = size(SmemLayoutK{}(_,_,_0{})) * sizeof(Element); + + struct SharedStorage { + // One for each consumer WG + union { + cute::array_aligned> smem_k; + cute::array_aligned> smem_kp; + cute::array_aligned> smem_v; + }; + + cute::array_aligned> smem_ds; + + // Loaded by producer, consumed by both WGs + union { + cute::array_aligned> smem_q; + cute::array_aligned> smem_do; + cute::array_aligned> smem_qp; + cute::array_aligned> smem_dop; + }; + + // Accumulated into by both consumers, potentially loaded, potentially written + cute::array_aligned> smem_dq; + + union { + cute::array_aligned> smem_lse; + cute::array_aligned> smem_sumOdO; + }; + }; + + struct Arguments { + const Element* ptr_Q; + cute::tuple dQ; + const Element* ptr_K; + cute::tuple dK; + const Element* ptr_V; + cute::tuple dV; + + const Element* ptr_dO; + cute::tuple dDO; + + const ElementAccumulator* ptr_LSE; + cute::tuple dLSE; + const ElementAccumulator* ptr_sum_OdO; + cute::tuple dSumOdO; + + ElementAccumulator* ptr_dQ; + cute::tuple dDQ; + }; + + using TMA_Q = typename CollectiveMmaNM::Params::TMA_B; + using TMA_K = typename CollectiveMmaNM::Params::TMA_A; + using TMA_V = typename CollectiveMmaNM::Params::TMA_A; + using TMA_DO = typename CollectiveMmaNM::Params::TMA_B; + + using TMA_LSE = decltype(make_tma_copy(SM90_TMA_LOAD{}, make_tensor((const ElementAccumulator*)nullptr, make_shape(1, 1, 1), make_stride(_1{}, 0, 0)), SmemLayoutLSE{}(_,_0{}))); + using TMA_ODO = TMA_LSE; + + using TMA_DQ = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{}, make_tensor((const ElementAccumulator*)nullptr, make_shape(1, 1, 1, 1), make_stride(0, _1{}, 0, 0)), SmemLayoutDQ{}(_,_,_0{}))); + + using LoadQ = CollectiveLoadTma< + LoadKind::kBwdM, + MainloopPipeline, + Element, + SmemLayoutQ, + TMA_Q + >; + + using LoadK = CollectiveLoadTma< + LoadKind::kBwdN, + MainloopPipelineQ, + Element, + SmemLayoutK, + TMA_K + >; + + using LoadV = CollectiveLoadTma< + LoadKind::kBwdN, + MainloopPipelineQ, + Element, + SmemLayoutV, + TMA_V + >; + + using LoadDO = CollectiveLoadTma< + LoadKind::kBwdM, + MainloopPipeline, + Element, + SmemLayoutDO, + TMA_DO + >; + + using LoadLSE = CollectiveLoadTma< + LoadKind::kBwdScalar, + MainloopPipeline, + ElementAccumulator, + SmemLayoutLSE, + TMA_LSE + >; + + using LoadODO = CollectiveLoadTma< + LoadKind::kBwdScalar, + MainloopPipeline, + ElementAccumulator, + SmemLayoutLSE, + TMA_ODO + >; + + struct Params { + TMA_Q tma_load_q; + TMA_K tma_load_k; + TMA_V tma_load_v; + TMA_DO tma_load_do; + + TMA_LSE tma_load_lse; + TMA_ODO tma_load_odo; + + TMA_DQ tma_red_dq; + + float scale_softmax; + float scale_softmax_log2; + }; + + static_assert(size(TiledMmaNM{}) == size(TiledMmaND{})); + static_assert(size(TiledMmaNM{}) == size(TiledMmaMD{})); + + template + static bool can_implement(ProblemShape const& problem_size, Arguments const& args) { + return true + && (get<4>(problem_size) <= get<2>(TileShape{})) + && ((get<4>(problem_size) % Alignment) == 0) + && ((get<2>(problem_size) % Alignment) == 0) + ; + } + + template + static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace) { + auto problem_shape_nm = make_shape(get<3>(problem_size), get<2>(problem_size), get<4>(problem_size), make_shape(get<0>(problem_size), get<1>(problem_size))); + + auto dK = make_stride(get<2>(args.dK), get<3>(args.dK), make_stride(get<0>(args.dK), get<1>(args.dK))); + auto dQ = make_stride(get<2>(args.dQ), get<3>(args.dQ), make_stride(get<0>(args.dQ), get<1>(args.dQ))); + auto params_nm_kq = CollectiveMmaNM::to_underlying_arguments(problem_shape_nm, + typename CollectiveMmaNM::Arguments { + args.ptr_K, dK, + args.ptr_Q, dQ, + }, /*workspace=*/ nullptr); + + auto dV = make_stride(get<2>(args.dV), get<3>(args.dV), make_stride(get<0>(args.dV), get<1>(args.dV))); + auto dDO = make_stride(get<2>(args.dDO), get<3>(args.dDO), make_stride(get<0>(args.dDO), get<1>(args.dDO))); + auto params_nm_vdo = CollectiveMmaNM::to_underlying_arguments(problem_shape_nm, + typename CollectiveMmaNM::Arguments { + args.ptr_V, dV, + args.ptr_dO, dDO, + }, /*workspace=*/ nullptr); + + + TMA_LSE tma_load_lse = make_tma_copy(SM90_TMA_LOAD{}, make_tensor(args.ptr_LSE, select<2,0,1>(problem_size), select<2,0,1>(args.dLSE)), SmemLayoutLSE{}(_,_0{})); + TMA_ODO tma_load_odo = make_tma_copy(SM90_TMA_LOAD{}, make_tensor(args.ptr_sum_OdO, select<2,0,1>(problem_size), select<2,0,1>(args.dSumOdO)), SmemLayoutLSE{}(_,_0{})); + + TMA_DQ tma_red_dq = make_tma_copy(SM90_TMA_REDUCE_ADD{}, make_tensor(args.ptr_dQ, select<2,4,0,1>(problem_size), select<2,3,0,1>(args.dDQ)), SmemLayoutDQ{}(_,_,_0{})); + + return Params{ + params_nm_kq.tma_load_b, + params_nm_kq.tma_load_a, + params_nm_vdo.tma_load_a, + params_nm_vdo.tma_load_b, + tma_load_lse, tma_load_odo, + tma_red_dq, + 1.0f / (float) std::sqrt(get<4>(problem_size)), + (float) (std::log2(std::exp(1.0)) / std::sqrt(get<4>(problem_size))) + }; + } + + template + CUTLASS_DEVICE + auto + get_inner_tile_count(BlkCoord const& blk_coord, ProblemSize const& problem_size) { + return Fusion{}.get_trip_count(blk_coord, TileShape{}, problem_size); + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_do.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_odo.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_lse.get_tma_descriptor()); + } + + template + CUTLASS_DEVICE void + load_kv_maybe_q( + int block_rank_in_cluster, + BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size, + MainloopPipeline& pipeline_inner, PipelineState& smem_pipe_write_inner, + MainloopPipelineQ& pipeline_outer, PipelineStateQ& smem_pipe_write_outer, + SharedStorage& storage, + LoadWarpBarrier& load_warp_barrier, bool do_barrier) + { + // Load pattern: + // K0 V0 K1 V1 + // Q0 DO0 Q1 DO1 Q2 DO2 ... + // K0 Q0 V0 K1 DO0 V1 ... + int lane_predicate = cute::elect_one_sync(); + + int outer_tile_count = NumMmaWarpGroups; + int inner_tile_count = get_inner_tile_count(blk_coord, problem_size); + + auto outer_tile_iter = cute::make_coord_iterator(outer_tile_count); + auto inner_tile_iter = cute::make_coord_iterator(inner_tile_count); + + uint16_t mcast_mask_b = 0; + + LoadQ load_q{params.tma_load_q, pipeline_inner, storage.smem_q}; + auto load_state_q = load_q.init_state(block_rank_in_cluster, problem_size, TileShapeNM{}, blk_coord, inner_tile_count); + + LoadDO load_do{params.tma_load_do, pipeline_inner, storage.smem_do}; + auto load_state_do = load_do.init_state(block_rank_in_cluster, problem_size, TileShapeNM{}, blk_coord, inner_tile_count); + + LoadK load_k{params.tma_load_k, pipeline_outer, storage.smem_k}; + auto load_state_k = load_k.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count); + + LoadV load_v{params.tma_load_v, pipeline_outer, storage.smem_v}; + auto load_state_v = load_v.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count); + + LoadLSE load_lse{params.tma_load_lse, pipeline_inner, storage.smem_lse}; + auto load_state_lse = load_lse.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count); + + LoadODO load_odo{params.tma_load_odo, pipeline_inner, storage.smem_sumOdO}; + auto load_state_odo = load_odo.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count); + + outer_tile_count *= 2; // K & V + inner_tile_count *= 4; // Q & dO & LSE & sumOdO + + while (inner_tile_count > 0) { + if (Fusion{}.is_contributing(make_coord(*inner_tile_iter, get<1>(blk_coord)), TileShape{}, problem_size)) { + break; + } + inner_tile_count -= 4; + ++inner_tile_iter; + } + + if constexpr (kLoadOuter) { + load_k.template step(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count); + } + + load_q.template step(inner_tile_iter, load_state_q, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b); + load_lse.template step(inner_tile_iter, load_state_lse, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b); + + if constexpr (! kLoadOuter) { + if (do_barrier) { + load_warp_barrier.arrive(); + load_warp_barrier.wait(/*phase=*/ 0); + do_barrier = false; + } + } + + if constexpr (kLoadOuter) { + load_v.template step(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count); + load_k.template step(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count); + } + + load_do.template step(inner_tile_iter, load_state_do, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b); + load_odo.template step(inner_tile_iter, load_state_odo, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b); + + if constexpr (kLoadOuter) { + load_v.template step(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count); + } + + if constexpr (kLoadOuter) { + while (outer_tile_count > 0) { + load_k.template step(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count); + load_v.template step(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count); + } + } + + CUTLASS_PRAGMA_NO_UNROLL + while (inner_tile_count > 0) { + while (inner_tile_count > 0) { + if (Fusion{}.is_contributing(make_coord(*inner_tile_iter, get<1>(blk_coord)), TileShape{}, problem_size)) { + break; + } + inner_tile_count -= 4; + ++inner_tile_iter; + } + load_q.template step(inner_tile_iter, load_state_q, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b); + load_lse.template step(inner_tile_iter, load_state_lse, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b); + + load_do.template step(inner_tile_iter, load_state_do, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b); + load_odo.template step(inner_tile_iter, load_state_odo, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b); + } + } + + template + CUTLASS_DEVICE void + load_maybe_q( + BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size, + MainloopPipelineQ& pipeline_outer, PipelineStateQ& smem_pipe_write_outer, + SharedStorage& storage, + LoadWarpBarrier& load_warp_barrier, bool do_barrier) + { + // Load pattern: + // K0 V0 K1 V1 + // Q0 DO0 Q1 DO1 Q2 DO2 ... + // K0 Q0 V0 K1 DO0 V1 ... + int lane_predicate = cute::elect_one_sync(); + + int outer_tile_count = NumMmaWarpGroups; + + auto outer_tile_iter = cute::make_coord_iterator(outer_tile_count); + + LoadK load_k{params.tma_load_k, pipeline_outer, storage.smem_k}; + auto load_state_k = load_k.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count); + + LoadV load_v{params.tma_load_v, pipeline_outer, storage.smem_v}; + auto load_state_v = load_v.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count); + + outer_tile_count *= 2; // K & V + + load_k.template step(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count); + + if (do_barrier) { + load_warp_barrier.arrive(); + load_warp_barrier.wait(/*phase=*/ 0); + do_barrier = false; + } + + load_v.template step(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count); + + while (outer_tile_count > 0) { + load_k.template step(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count); + load_v.template step(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count); + } + } + + template + CUTLASS_DEVICE void + reduce( + BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size, + MainloopPipelineReducer& pipeline_reducer, PipelineStateReducer& smem_pipe_read_reducer, + SharedStorage& storage) + { + int lane_predicate = cute::elect_one_sync(); + + Tensor mDQ_full = params.tma_red_dq.get_tma_tensor(select<2,4,0,1>(problem_size)); + Tensor gDQ_full = local_tile(mDQ_full, TileShapeMD{}, make_coord(_, _, _), Step<_1, _1, Underscore>{}); + Tensor gDQ = gDQ_full(_, _, _, _0{}, get<2,0>(blk_coord), get<2,1>(blk_coord)); + Tensor sDQ = make_tensor(make_smem_ptr(storage.smem_dq.data()), SmemLayoutDQ{}); + + auto block_tma = params.tma_red_dq.get_slice(_0{}); + + Tensor tDQsDQ = block_tma.partition_S(sDQ); + Tensor tDQgDQ = block_tma.partition_D(gDQ); + + int inner_tile_count = get_inner_tile_count(blk_coord, problem_size); + int g_index = 0; + + auto smem_pipe_release_reducer = smem_pipe_read_reducer; + bool first = true; + while (inner_tile_count > 0) { + while (inner_tile_count > 0) { + if (Fusion{}.is_contributing(make_coord(g_index, get<1>(blk_coord)), TileShape{}, problem_size)) { + break; + } + inner_tile_count -= 1; + ++g_index; + } + if (inner_tile_count == 0) break; + + pipeline_reducer.consumer_wait(smem_pipe_read_reducer); + if (lane_predicate == 1) { + tma_store_wait<1>(); + } + if (! first) { + pipeline_reducer.consumer_release(smem_pipe_release_reducer); + ++smem_pipe_release_reducer; + } else { + first = false; + } + if (lane_predicate == 1) { + copy(params.tma_red_dq, tDQsDQ(_,_,_,smem_pipe_read_reducer.index()), tDQgDQ(_,_,_,g_index)); + tma_store_arrive(); + } + ++smem_pipe_read_reducer; + --inner_tile_count; + ++g_index; + } + if (lane_predicate) { + tma_store_wait<0>(); + } + pipeline_reducer.consumer_release(smem_pipe_release_reducer); + ++smem_pipe_release_reducer; + } + + template + CUTLASS_DEVICE auto + compute( + BlkCoord const& blk_coord, BlkCoord const& wg_coord, + Params const& params, ProblemShape const& problem_size, + MainloopPipeline& pipeline_inner, PipelineState& smem_pipe_read_inner, + MainloopPipelineQ& pipeline_outer, PipelineStateQ& smem_pipe_read_outer, + MainloopPipelineReducer& pipeline_reducer, PipelineStateReducer& smem_pipe_write_reducer, + SharedStorage& storage, + MathWgOrderBarrier& math_wg_order_barrier) + { + TiledMmaND tiled_mma_nd; + + Tensor acc_DV = partition_fragment_C(tiled_mma_nd, take<0,2>(TileShapeND{})); + clear(acc_DV); + + Tensor acc_DK = partition_fragment_C(tiled_mma_nd, take<0,2>(TileShapeND{})); + clear(acc_DK); + + int thread_idx = int(threadIdx.x) % cutlass::NumThreadsPerWarpGroup; + + PipelineState smem_pipe_release_inner = smem_pipe_read_inner; + + pipeline_outer.consumer_wait(smem_pipe_read_outer); + PipelineStateQ smem_pipe_read_k = smem_pipe_read_outer; + ++smem_pipe_read_outer; + pipeline_outer.consumer_wait(smem_pipe_read_outer); + PipelineStateQ smem_pipe_read_v = smem_pipe_read_outer; + + int inner_tile_count = get_inner_tile_count(wg_coord, problem_size); + + TiledMmaNM tiled_mma_nm; + Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + auto thr_mma_nm = tiled_mma_nm.get_thread_slice(thread_idx); + Tensor tSsK = thr_mma_nm.partition_A(sK); + Tensor tSsQ = thr_mma_nm.partition_B(sQ); + Tensor tSrK = thr_mma_nm.make_fragment_A(tSsK); + Tensor tSrQ = thr_mma_nm.make_fragment_B(tSsQ); + + Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + Tensor sDO = make_tensor(make_smem_ptr(storage.smem_do.data()), SmemLayoutDO{}); + + Tensor tDPsV = thr_mma_nm.partition_A(sV); + Tensor tDPsDO = thr_mma_nm.partition_B(sDO); + Tensor tDPrV = thr_mma_nm.make_fragment_A(tDPsV); + Tensor tDPrDO = thr_mma_nm.make_fragment_B(tDPsDO); + + auto thr_mma_nd = tiled_mma_nd.get_thread_slice(thread_idx); + + Tensor sDOp = make_tensor(make_smem_ptr(storage.smem_dop.data()), SmemLayoutDOp{}); + Tensor tDV_sDO = thr_mma_nd.partition_B(sDOp); + Tensor tDVrDO = thr_mma_nd.make_fragment_B(tDV_sDO); + + + Tensor sQp = make_tensor(make_smem_ptr(storage.smem_qp.data()), SmemLayoutQp{}); + Tensor tDK_sQ = thr_mma_nd.partition_B(sQp); + Tensor tDKrQ = thr_mma_nd.make_fragment_B(tDK_sQ); + + + int wg_idx = __shfl_sync(0xffffffff, get<1>(wg_coord) % NumMmaWarpGroups, 0); + + TiledMmaMD tiled_mma_md; + auto thr_mma_md = tiled_mma_md.get_thread_slice(thread_idx); + Tensor sDS = make_tensor(make_smem_ptr(storage.smem_ds.data()), SmemLayoutDS{}); + Tensor tDQsDS = thr_mma_md.partition_A(sDS); + Tensor tDQrDS_full = thr_mma_md.make_fragment_A(tDQsDS); + Tensor tDQrDS = tDQrDS_full(_,_,_,_); + Tensor sKp = make_tensor(make_smem_ptr(storage.smem_kp.data()), SmemLayoutKp{}); + Tensor tDQsK = thr_mma_md.partition_B(sKp); + Tensor tDQrK = thr_mma_md.make_fragment_B(tDQsK); + + Tensor sLSE = make_tensor(make_smem_ptr(storage.smem_lse.data()), make_shape(get<0>(TileShapeNM{}), get<1>(TileShapeNM{}), Int{}), make_stride(_0{}, _1{}, get<1>(TileShapeNM{}))); + Tensor tSsLSE = thr_mma_nm.partition_C(sLSE); + + Tensor sODO = make_tensor(make_smem_ptr(storage.smem_sumOdO.data()), make_shape(get<0>(TileShapeNM{}), get<1>(TileShapeNM{}), Int{}), make_stride(_0{}, _1{}, get<1>(TileShapeNM{}))); + Tensor tDPsODO = thr_mma_nm.partition_C(sODO); + + Tensor cS = make_identity_tensor(take<0,2>(TileShapeNM{})); + Tensor tScS = thr_mma_nm.partition_C(cS); + int n_block = get<1>(wg_coord); + tScS.data() = tScS.data() + E<0>{} * n_block * get<0>(TileShapeNM{}); + + + // Transpose + Tensor sDSp_full = sDS.compose(make_layout(make_shape(size<1>(sDS), size<0>(sDS), size<2>(sDS)), make_stride(size<0>(sDS), _1{}, size<1>(sDS) * size<0>(sDS)))); + Tensor sDSp = sDSp_full(_,_,_); + Tensor tDPsDS = thr_mma_nm.partition_C(sDSp); + + auto thr_mma_nd_ss = TiledMmaND_SS{}.get_thread_slice(thread_idx); + Tensor tDKsDSp = thr_mma_nd_ss.partition_A(sDSp); + + Tensor tDKrDSp = thr_mma_nd_ss.make_fragment_A(tDKsDSp); + + Tensor sDQ = make_tensor(make_smem_ptr(storage.smem_dq.data()), SmemLayoutDQ{}); + auto tDQsDQ_full = thr_mma_md.partition_C(sDQ); + + + auto smem_pipe_read_k_other = smem_pipe_read_k; + smem_pipe_read_k_other.advance(2); + + int k_index = 0; + + while (inner_tile_count > 0) { + while (inner_tile_count > 0) { + if (Fusion{}.is_contributing(make_coord(k_index, get<1>(blk_coord)), TileShape{}, problem_size)) { + break; + } + inner_tile_count -= 1; + tScS.data() = tScS.data() + E<1>{} * get<1>(TileShapeNM{}); + k_index += 1; + } + if (inner_tile_count == 0) break; + + pipeline_inner.consumer_wait(smem_pipe_read_inner); + PipelineState smem_pipe_read_q = smem_pipe_read_inner; + ++smem_pipe_read_inner; + PipelineState smem_pipe_read_do = smem_pipe_read_inner; + ++smem_pipe_read_inner; + + // GEMM KQ -> S + Tensor acc_S = partition_fragment_C(tiled_mma_nm, take<0,2>(TileShapeNM{})); + + warpgroup_fence_operand(acc_S); + warpgroup_arrive(); + gemm_zero_acc(tiled_mma_nm, tSrK(_,_,_,smem_pipe_read_k.index()), tSrQ(_,_,_,smem_pipe_read_q.index()), acc_S); + warpgroup_commit_batch(); + + pipeline_inner.consumer_wait(smem_pipe_read_do); + + // GEMM VdO -> dP + Tensor acc_DP = partition_fragment_C(tiled_mma_nm, take<0,2>(TileShapeNM{})); + + warpgroup_fence_operand(acc_DP); + warpgroup_arrive(); + gemm_zero_acc(tiled_mma_nm, tDPrV(_,_,_,smem_pipe_read_v.index()), tDPrDO(_,_,_,smem_pipe_read_do.index()), acc_DP); + warpgroup_commit_batch(); + + Tensor reg_LSE = make_fragment_like(acc_S); + for (int i = 0; i < size(reg_LSE); i++) { + reg_LSE(i) = ((ElementAccumulator)std::log2(std::exp(1.0))) * tSsLSE(_,_,_,smem_pipe_read_q.index())(i); + } + + Tensor reg_ODO = make_fragment_like(acc_S); + if constexpr (decltype(get<0>(TileShape{}) != _128{})::value) { + for (int i = 0; i < size(reg_ODO); i++) { + reg_ODO(i) = tDPsODO(_,_,_,smem_pipe_read_do.index())(i); + } + } + + warpgroup_wait<1>(); + warpgroup_fence_operand(acc_S); + + math_wg_order_barrier.wait(); + // Compute S -> P + Fusion{}.before_softmax(acc_S, tScS, problem_size); + auto acc_P = make_fragment_like(acc_S); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_P); i++) { + acc_P(i) = ::exp2f(params.scale_softmax_log2 * acc_S(i) - reg_LSE(i)); + } + math_wg_order_barrier.arrive(); + + if constexpr (decltype(get<0>(TileShape{}) == _128{})::value) { + for (int i = 0; i < size(reg_ODO); i++) { + reg_ODO(i) = tDPsODO(_,_,_,smem_pipe_read_do.index())(i); + } + } + + warpgroup_wait<0>(); + warpgroup_fence_operand(acc_DP); + + // Compute dP P -> dS + auto acc_DS = make_fragment_like(acc_DP); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_DS); i++) { + // We could move the scale out and into the respective epilogues (or a final scaling step) + acc_DS(i) = acc_P(i) * params.scale_softmax * (acc_DP(i) - reg_ODO(i)); + } + + // GEMM PdO -> dV + auto op_P = make_acc_into_op(acc_P, typename TiledMmaND::LayoutA_TV{}); + warpgroup_fence_operand(acc_DV); + warpgroup_fence_operand(op_P); + warpgroup_arrive(); + cute::gemm(tiled_mma_nd, op_P, tDVrDO(_,_,_,smem_pipe_read_do.index()), acc_DV); + warpgroup_commit_batch(); + + // Store dS to smem dS' + if (wg_idx == 0) math_wg_order_barrier.wait(); + + auto recast_bits = [](auto sz, auto t) { + return recast>(t); + }; + auto tDPsDS_v = recast_bits(Int * 2>{}, tDPsDS); + auto acc_DS_v = recast_bits(Int * 2>{}, acc_DS); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_DS_v); i++) { + tDPsDS_v(_,_,_,wg_idx)(i) = acc_DS_v(i); + } + + cutlass::arch::fence_view_async_shared(); + if (wg_idx == 0) math_wg_order_barrier.arrive(); + + // GEMM dS Q -> dK + if (wg_idx == 1) { + + math_wg_order_barrier.wait(); + + // GEMM dS' K -> dQ + Tensor acc_DQ = partition_fragment_C(tiled_mma_md, take<0,2>(TileShapeMD{})); + + warpgroup_fence_operand(acc_DQ); + warpgroup_arrive(); + gemm_zero_acc(tiled_mma_md, tDQrDS(_,_,_,0), tDQrK(_,_,_,smem_pipe_read_k_other.index()), acc_DQ); + cute::gemm(tiled_mma_md, tDQrDS(_,_,_,1), tDQrK(_,_,_,smem_pipe_read_k.index()), acc_DQ); + warpgroup_commit_batch(); + + warpgroup_fence_operand(acc_DK); + warpgroup_arrive(); + cute::gemm(TiledMmaND_SS{}, tDKrDSp(_,_,_,wg_idx), tDKrQ(_,_,_,smem_pipe_read_q.index()), acc_DK); + warpgroup_commit_batch(); + + warpgroup_wait<1>(); + warpgroup_fence_operand(acc_DK); + + warpgroup_wait<1>(); + warpgroup_fence_operand(acc_DQ); + + math_wg_order_barrier.arrive(); + + pipeline_reducer.producer_acquire(smem_pipe_write_reducer); + auto tDQsDQ = tDQsDQ_full(_,_,_,smem_pipe_write_reducer.index()); + + // Store dQ to smem dQ' + // Invoke TMA reduce on dQ' + using Vec = uint_bit_t * 2>; + auto tDQsDQ_v = recast(tDQsDQ); + auto acc_DQ_v = recast(acc_DQ); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_DQ_v); i++) { + tDQsDQ_v(i) = acc_DQ_v(i); + } + + cutlass::arch::fence_view_async_shared(); + + pipeline_reducer.producer_commit(smem_pipe_write_reducer); + ++smem_pipe_write_reducer; + } else { + + warpgroup_fence_operand(acc_DK); + warpgroup_arrive(); + cute::gemm(TiledMmaND_SS{}, tDKrDSp(_,_,_,wg_idx), tDKrQ(_,_,_,smem_pipe_read_q.index()), acc_DK); + warpgroup_commit_batch(); + + warpgroup_wait<1>(); + warpgroup_fence_operand(acc_DK); + + pipeline_reducer.producer_acquire(smem_pipe_write_reducer); + pipeline_reducer.producer_commit(smem_pipe_write_reducer); + ++smem_pipe_write_reducer; + } + + --inner_tile_count; + + pipeline_inner.consumer_release(smem_pipe_release_inner); + ++smem_pipe_release_inner; + pipeline_inner.consumer_release(smem_pipe_release_inner); + ++smem_pipe_release_inner; + + tScS.data() = tScS.data() + E<1>{} * get<1>(TileShapeNM{}); + k_index += 1; + } + + pipeline_outer.consumer_release(smem_pipe_read_k); + pipeline_outer.consumer_release(smem_pipe_read_outer); + pipeline_reducer.producer_tail(smem_pipe_write_reducer); + ++smem_pipe_read_outer; + + warpgroup_wait<0>(); + warpgroup_fence_operand(acc_DK); + warpgroup_fence_operand(acc_DV); + + return make_tuple(acc_DK, acc_DV); + } +}; + +} // namespace cutlass::fmha::collective diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_collective_load.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_collective_load.hpp new file mode 100644 index 0000000000000000000000000000000000000000..55029faffbf5585ec8b41cd07781dc84ea6b7d6d --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_collective_load.hpp @@ -0,0 +1,140 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" + +namespace cutlass::fmha::collective { + +enum class LoadKind { + kQ, kK, kV, + kBwdN, kBwdM, kBwdScalar +}; + +template< + LoadKind kKind, + class Pipeline, + class Element, + class SmemLayout, + class TMA +> +struct CollectiveLoadTma { + + using Params = TMA; + using SharedStorage = cute::array_aligned>; + using PipelineState = typename cutlass::PipelineState; + + Params const& params; + Pipeline& pipeline; + SharedStorage& storage; + + CUTLASS_DEVICE + CollectiveLoadTma(Params const& params, Pipeline& pipeline, SharedStorage& storage) + : params(params), pipeline(pipeline), storage(storage) {} + + template + CUTLASS_DEVICE auto init_g(ProblemSize const& problem_size, TileShape const& tile_shape, + BlockCoord const& blk_coord, int loop_count + ) { + using X = Underscore; + if constexpr (kKind == LoadKind::kK) { + Tensor mK_full = params.get_tma_tensor(make_shape(get<3>(problem_size), get<4>(problem_size), select<0,1>(problem_size))); + Tensor gK_full = local_tile(mK_full, tile_shape, make_coord(_, _, _), Step{}); + Tensor gK = gK_full(_, _, _, _0{}, get<2>(blk_coord)); + return gK; + } else if constexpr (kKind == LoadKind::kQ) { + Tensor mQ_full = params.get_tma_tensor(make_shape(get<2>(problem_size), get<4>(problem_size), select<0,1>(problem_size))); + Tensor gQ_full = local_tile(mQ_full, tile_shape, make_coord(_, _, _), Step<_1, X, _1>{}); + Tensor gQ = gQ_full(_, _, _, _0{}, get<2>(blk_coord)); + return make_tensor(gQ.data() + loop_count * get<0>(blk_coord) * stride<2>(gQ), gQ.layout()); + } else if constexpr (kKind == LoadKind::kV) { + Tensor mV_full = params.get_tma_tensor(make_shape(get<4>(problem_size), get<3>(problem_size), select<0,1>(problem_size))); + Tensor gV_full = local_tile(mV_full, tile_shape, make_coord(_, _, _), Step{}); + Tensor gV = gV_full(_, _, _0{}, _, get<2>(blk_coord)); + return gV; + } else if constexpr (kKind == LoadKind::kBwdN) { + Tensor m_full = params.get_tma_tensor(make_shape(get<3>(problem_size), get<4>(problem_size), select<0,1>(problem_size))); + Tensor g_full = local_tile(m_full, tile_shape, make_coord(_, _, _), Step<_1, X, _1>{}); + Tensor g = g_full(_, _, _, _0{}, get<2>(blk_coord)); + return make_tensor(g.data() + loop_count * get<1>(blk_coord) * stride<2>(g), g.layout()); + } else if constexpr (kKind == LoadKind::kBwdM) { + Tensor m_full = params.get_tma_tensor(make_shape(get<2>(problem_size), get<4>(problem_size), select<0,1>(problem_size))); + Tensor g_full = local_tile(m_full, tile_shape, make_coord(_, _, _), Step{}); + Tensor g = g_full(_, _, _, _0{}, get<2>(blk_coord)); + return g; + } else if constexpr (kKind == LoadKind::kBwdScalar) { + Tensor m_full = params.get_tma_tensor(select<2,0,1>(problem_size)); + Tensor g_full = local_tile(m_full, tile_shape, make_coord(_, _, _), Step{}); + Tensor g = g_full(_, _, get<2,0>(blk_coord), get<2,1>(blk_coord)); + return g; + } + } + + template + CUTLASS_DEVICE auto init_state(ClusterRank const& block_rank_in_cluster, + ProblemSize const& problem_size, TileShape const& tile_shape, + BlockCoord const& block_coord, int loop_count + ) { + Tensor g = init_g(problem_size, tile_shape, block_coord, loop_count); + Tensor s = make_tensor(make_smem_ptr(storage.data()), SmemLayout{}); + + auto block_tma = params.get_slice(block_rank_in_cluster); + Tensor ts = block_tma.partition_D(s); + Tensor tg = block_tma.partition_S(g); + + return make_tuple(tg, ts); + } + + template + CUTLASS_DEVICE void step(TileIterator& tile_iter, State const& state, + PipelineState& smem_pipe_write, + int lane_predicate, int& tile_count, uint16_t mcast_mask = 0 + ) { + if ((lane_predicate == 1) && (tile_count > 0)) { + if constexpr (kAcquireBarrier) pipeline.producer_acquire(smem_pipe_write); + using BarrierType = typename Pipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + if constexpr (kKind == LoadKind::kBwdScalar) { + copy(params.with(*tma_barrier, mcast_mask), get<0>(state)(_,_,*tile_iter), get<1>(state)(_,_,smem_pipe_write.index())); + } else { + copy(params.with(*tma_barrier, mcast_mask), get<0>(state)(_,_,_,*tile_iter), get<1>(state)(_,_,_,smem_pipe_write.index())); + } + if constexpr (kAdvancePipe) ++smem_pipe_write; + if constexpr (kAdvanceIterator) ++tile_iter; + } + --tile_count; + } +}; + +} // namespace cutlass::fmha::collective diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_collective_softmax.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_collective_softmax.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9c958da082bc69faf04bb1db00aa48fc93353a2c --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_collective_softmax.hpp @@ -0,0 +1,305 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" + +#include "../collective/fmha_common.hpp" + +namespace cutlass::fmha::collective { + +template< + class ElementAccumulator, + class Fusion, + class Params +> +struct CollectiveSoftmax { + Params const& params; + CUTLASS_DEVICE CollectiveSoftmax(Params const& params) : params(params) {} + + using SumType = float; + using MaxType = ElementAccumulator; + + template + CUTLASS_DEVICE auto init(AccPV const& acc_pv, TiledMmaPV const& tiled_mma_pv) { + Tensor s_max = make_fragment_like(size<0>(layout_acc_mn(tiled_mma_pv, acc_pv.layout()))); + Tensor a_sum = make_fragment_like(s_max); + return make_tuple(s_max, a_sum); + } + + CUTLASS_DEVICE float overload_exp2(float f) { + return ::exp2f(f); + } + + CUTLASS_DEVICE cutlass::half_t overload_exp2(cutlass::half_t f) { + auto a = f.raw(); + decltype(a) d; + asm("ex2.approx.f16 %0, %1;" : "=h"(d) : "h"(a)); + return cutlass::half_t::bitcast(d); + } + + + CUTLASS_DEVICE float overload_max(float a, float b) { + return ::max(a, b); + } + + CUTLASS_DEVICE cutlass::half_t overload_max(cutlass::half_t a, cutlass::half_t b) { + return cutlass::half_t{__hmax_nan(a.to_half(), b.to_half())}; + } + + CUTLASS_DEVICE half overload_to_native(cutlass::half_t f) { + return f.to_half(); + } + + CUTLASS_DEVICE float overload_to_native(float f) { + return f; + } + + template + CUTLASS_DEVICE auto step(AccQK& acc_qk, TiledMmaQK const& tiled_mma_qk, CountQK const& count_qk, State& state, ProblemShape const& problem_shape) { + Fusion{}.before_softmax(acc_qk, count_qk, problem_shape); + Tensor acc_qk_mn = make_tensor(acc_qk.data(), layout_acc_mn(tiled_mma_qk, acc_qk.layout())); + auto reduction_target_qk = reduction_target_n(tiled_mma_qk); + constexpr int red_rank = decltype(rank(reduction_target_qk))::value; + + auto& s_max = get<0>(state); + auto& a_sum = get<1>(state); + + // Linear reduction is faster for the first iteration + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(acc_qk_mn); i++) { + s_max(i) = acc_qk_mn(i, 0); + } + CUTLASS_PRAGMA_UNROLL + for (int j = 1; j < size<1>(acc_qk_mn); j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(acc_qk_mn); i++) { + s_max(i) = overload_max(s_max(i), acc_qk_mn(i, j)); + } + } + + for_each(make_seq{}, [&](auto r) { + CUTLASS_PRAGMA_UNROLL + for (int j = 1; j < shape(reduction_target_qk); j *= 2) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(acc_qk_mn); i++) { + s_max(i) = overload_max(s_max(i), MaxType{__shfl_xor_sync(uint32_t(-1), overload_to_native(s_max(i)), stride(reduction_target_qk) * j)}); + } + } + }); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(acc_qk_mn); i++) { + MaxType local_max = s_max(i) == static_cast(-INFINITY) ? static_cast(0) : s_max(i); + MaxType scale = static_cast(params.scale_softmax_log2); + MaxType scale_max = scale * local_max; + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<1>(acc_qk_mn); j++) { + acc_qk_mn(i, j) = overload_exp2(scale * acc_qk_mn(i, j) - scale_max); + } + } + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(acc_qk_mn); i++) { + a_sum(i) = SumType{reduce(acc_qk_mn(i, _), cute::plus{})}; + } + } + + template + CUTLASS_DEVICE auto step_interleave_begin(AccQK& acc_qk, TiledMmaQK const& tiled_mma_qk, CountQK const& count_qk, State& state, AccPV& acc_pv, TiledMmaPV const& tiled_mma_pv, ProblemShape const& problem_shape) { + + if constexpr (kUseFusion) { + Fusion{}.before_softmax(acc_qk, count_qk, problem_shape); + } + + Tensor acc_qk_mn = make_tensor(acc_qk.data(), layout_acc_mn(tiled_mma_qk, acc_qk.layout())); + Tensor acc_pv_mn = make_tensor(acc_pv.data(), layout_acc_mn(tiled_mma_pv, acc_pv.layout())); + + static_assert(size<0>(acc_qk_mn) == size<0>(acc_pv_mn)); + auto reduction_target_qk = reduction_target_n(tiled_mma_qk); + constexpr int red_rank = decltype(rank(reduction_target_qk))::value; + + auto& s_max = get<0>(state); + auto& a_sum = get<1>(state); + + Tensor s_max_prev = make_fragment_like(s_max); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(acc_qk_mn); i++) { + s_max_prev(i) = s_max(i); + } + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(acc_qk_mn); i++) { + // Linear reduction is faster here, as well + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<1>(acc_qk_mn); j++) { + s_max(i) = overload_max(s_max(i), acc_qk_mn(i, j)); + } + } + // reduce max + for_each(make_seq{}, [&](auto r) { + CUTLASS_PRAGMA_UNROLL + for (int j = 1; j < shape(reduction_target_qk); j *= 2) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(acc_qk_mn); i++) { + s_max(i) = overload_max(s_max(i), __shfl_xor_sync(uint32_t(-1), s_max(i), stride(reduction_target_qk) * j)); + } + } + }); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(acc_pv_mn); i++) { + float s_max_cur = s_max(i) == -INFINITY ? 0.0f : s_max(i); + float scale = ::exp2f((s_max_prev(i) - s_max_cur) * params.scale_softmax_log2); + a_sum(i) *= scale; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<1>(acc_pv_mn); j++) { + acc_pv_mn(i, j) *= scale; + } + } + } + + template + CUTLASS_DEVICE auto step_interleave_step(AccQK_MN& acc_qk_mn, State& state) { + + auto& s_max = get<0>(state); + auto& a_sum = get<1>(state); + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<0>(acc_qk_mn); j++) { + float local_max = s_max(j) == -INFINITY ? 0.f : s_max(j); + float scale_max = params.scale_softmax_log2 * local_max; + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<1>(acc_qk_mn); k++) { + acc_qk_mn(j, k) = ::exp2f(params.scale_softmax_log2 * acc_qk_mn(j, k) - scale_max); + a_sum(j) += acc_qk_mn(j, k); + } + } + } + + template + CUTLASS_DEVICE auto step(AccQK& acc_qk, TiledMmaQK const& tiled_mma_qk, CountQK const& count_qk, State& state, AccPV& acc_pv, TiledMmaPV const& tiled_mma_pv, ProblemShape const& problem_shape) { + + if constexpr (kUseFusion) { + Fusion{}.before_softmax(acc_qk, count_qk, problem_shape); + } + + Tensor acc_qk_mn = make_tensor(acc_qk.data(), layout_acc_mn(tiled_mma_qk, acc_qk.layout())); + Tensor acc_pv_mn = make_tensor(acc_pv.data(), layout_acc_mn(tiled_mma_pv, acc_pv.layout())); + + static_assert(size<0>(acc_qk_mn) == size<0>(acc_pv_mn)); + auto reduction_target_qk = reduction_target_n(tiled_mma_qk); + constexpr int red_rank = decltype(rank(reduction_target_qk))::value; + + auto& s_max = get<0>(state); + auto& a_sum = get<1>(state); + + Tensor s_max_prev = make_fragment_like(s_max); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(acc_qk_mn); i++) { + s_max_prev(i) = s_max(i); + + // Linear reduction is faster here, as well + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<1>(acc_qk_mn); j++) { + s_max(i) = overload_max(s_max(i), acc_qk_mn(i, j)); + } + // reduce max + for_each(make_seq{}, [&](auto r) { + CUTLASS_PRAGMA_UNROLL + for (int j = 1; j < shape(reduction_target_qk); j *= 2) { + s_max(i) = overload_max(s_max(i), MaxType{__shfl_xor_sync(uint32_t(-1), overload_to_native(s_max(i)), stride(reduction_target_qk) * j)}); + } + }); + + MaxType local_max = s_max(i) == static_cast(-INFINITY) ? static_cast(0) : s_max(i); + MaxType scale = static_cast(params.scale_softmax_log2); + MaxType scale_max = scale * local_max; + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<1>(acc_qk_mn); j++) { + acc_qk_mn(i, j) = overload_exp2(scale * acc_qk_mn(i, j) - scale_max); + } + + MaxType s_max_cur = s_max(i) == static_cast(-INFINITY) ? static_cast(0) : s_max(i); + SumType scale_pv = overload_exp2((s_max_prev(i) - s_max_cur) * scale); + a_sum(i) *= scale_pv; + + using ElementPV = typename AccPV::value_type; + ElementPV scale_pv_ele = static_cast(scale_pv); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<1>(acc_pv_mn); j++) { + acc_pv_mn(i, j) *= scale_pv_ele; + } + a_sum(i) += SumType{reduce(acc_qk_mn(i, _), cute::plus{})}; + } + } + + + template + CUTLASS_DEVICE auto tail(State& state, AccPV& acc_pv, TiledMmaPV const& tiled_mma_pv) { + auto& s_max = get<0>(state); + auto& a_sum = get<1>(state); + + Tensor acc_pv_mn = make_tensor(acc_pv.data(), layout_acc_mn(tiled_mma_pv, acc_pv.layout())); + + auto reduction_target = reduction_target_n(tiled_mma_pv); + constexpr int red_rank = decltype(rank(reduction_target))::value; + for_each(make_seq{}, [&](auto r) { + CUTLASS_PRAGMA_UNROLL + for (int j = 1; j < shape(reduction_target); j *= 2) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(acc_pv_mn); i++) { + a_sum(i) = a_sum(i) + __shfl_xor_sync(uint32_t(-1), a_sum(i), stride(reduction_target) * j); + } + } + }); + + Tensor acc_mn = make_tensor(acc_pv.data(), layout_acc_mn(tiled_mma_pv, acc_pv.layout())); + + Tensor lse = make_fragment_like(a_sum); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(acc_mn); i++) { + float sum = a_sum(i); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : __frcp_rn(sum); + lse(i) = (sum == 0.f || sum != sum) ? INFINITY : s_max(i) * params.scale_softmax + __logf(sum); + float scale = params.rp_dropout * inv_sum; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<1>(acc_mn); j++) { + acc_mn(i, j) *= scale; + } + } + + return lse; + } +}; + +} // namespace cutlass::fmha::collective diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_collective_tma.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_collective_tma.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c3dfda63c7bdcbbcb7a711b328942101047b14cc --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_collective_tma.hpp @@ -0,0 +1,526 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "../collective/fmha_common.hpp" +#include "../collective/fmha_collective_load.hpp" +#include "../collective/fmha_collective_softmax.hpp" +#include "../kernel/fmha_options.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; +using cutlass::fmha::kernel::Tag; +using cutlass::fmha::kernel::find_option_t; + +template< + typename Element_, + typename ElementAccumulator_, + typename TileShape_, // BlockQO, BlockKV, BlockHead + class Fusion, + class... Options +> +struct FmhaMainloopTma { + + using Element = Element_; + using ElementAccumulator = ElementAccumulator_; + using TileShape = TileShape_; + + // Options + using kClusterM = find_option_t, Options...>; + static constexpr int StageCount = find_option_t, Options...>::value; + static constexpr int StageCountQ = find_option_t, Options...>::value; + + using StagesQ = cutlass::gemm::collective::StageCount; + using Stages = cutlass::gemm::collective::StageCount; + using ClusterShape = Shape; + + // 16B alignment lets us use TMA + static constexpr int Alignment = 16 / sizeof(Element); + + using TileShapeQK = TileShape; + using TileShapePV = decltype(select<0,2,1>(TileShapeQK{})); + + using LayoutQKV = cute::tuple>; + using LayoutQ = LayoutQKV; + using LayoutK = LayoutQKV; + using LayoutV = LayoutQKV; + + using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Element, LayoutQ, Alignment, + Element, LayoutK, Alignment, + ElementAccumulator, + TileShapeQK, ClusterShape, Stages, + cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp; + + using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + // the stride for A does not matter since we do not load from smem at all + Element, LayoutK, Alignment, + Element, decltype(select<1,0,2>(LayoutV{})), Alignment, + ElementAccumulator, + TileShapePV, ClusterShape, Stages, + cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp; + + using TiledMmaQK = typename CollectiveMmaQK::TiledMma; + using TiledMmaPV = decltype(convert_to_gmma_rs(typename CollectiveMmaPV::TiledMma{})); + + using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int{})); + using SmemLayoutK = typename CollectiveMmaQK::SmemLayoutB; + using SmemLayoutV = typename CollectiveMmaPV::SmemLayoutB; + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using MainloopPipelineQ = cutlass::PipelineTmaAsync; + + using PipelineState = typename cutlass::PipelineState; + using PipelineStateQ = typename cutlass::PipelineState; + + using TileShapeOut = TileShapePV; + using TiledMmaOut = TiledMmaPV; + using ElementOut = ElementAccumulator; + + struct SharedStorage { + cute::array_aligned> smem_q; + union { + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + }; + }; + + struct Arguments { + const Element* ptr_Q; + LayoutQ dQ; + const Element* ptr_K; + LayoutK dK; + const Element* ptr_V; + LayoutV dV; + }; + + using TMA_Q = typename CollectiveMmaQK::Params::TMA_A; + using TMA_K = typename CollectiveMmaQK::Params::TMA_B; + using TMA_V = typename CollectiveMmaPV::Params::TMA_B; + + struct Params { + TMA_Q tma_load_q; + TMA_K tma_load_k; + TMA_V tma_load_v; + + float scale_softmax; + float scale_softmax_log2; + float rp_dropout; + }; + + using LoadQ = cutlass::fmha::collective::CollectiveLoadTma< + cutlass::fmha::collective::LoadKind::kQ, + MainloopPipelineQ, + Element, + SmemLayoutQ, + TMA_Q + >; + + using LoadK = cutlass::fmha::collective::CollectiveLoadTma< + cutlass::fmha::collective::LoadKind::kK, + MainloopPipeline, + Element, + SmemLayoutK, + TMA_K + >; + + using LoadV = cutlass::fmha::collective::CollectiveLoadTma< + cutlass::fmha::collective::LoadKind::kV, + MainloopPipeline, + Element, + SmemLayoutV, + TMA_V + >; + + static_assert(size(typename CollectiveMmaQK::TiledMma{}) == size(typename CollectiveMmaPV::TiledMma{})); + + static const int MaxThreadsPerBlock = size(typename CollectiveMmaQK::TiledMma{}); + + template + static bool can_implement(ProblemShape const& problem_size, Arguments const& args) { + return true + && (get<4>(problem_size) <= get<2>(TileShape{})) + && ((get<4>(problem_size) % Alignment) == 0) + && ((get<2>(problem_size) % Alignment) == 0) + ; + } + + template + static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace) { + + auto problem_shape_qk = make_shape(get<2>(problem_size), get<3>(problem_size), get<4>(problem_size), make_shape(get<0>(problem_size), get<1>(problem_size))); + auto params_qk = CollectiveMmaQK::to_underlying_arguments(problem_shape_qk, + typename CollectiveMmaQK::Arguments { + args.ptr_Q, args.dQ, + args.ptr_K, args.dK, + }, /*workspace=*/ nullptr); + + auto problem_shape_pv = select<0,2,1,3>(problem_shape_qk); + auto params_pv = CollectiveMmaPV::to_underlying_arguments(problem_shape_pv, + typename CollectiveMmaPV::Arguments { + args.ptr_K, args.dK, // never used, dummy + args.ptr_V, select<1,0,2>(args.dV), + }, /*workspace=*/ nullptr); + + return Params{ + params_qk.tma_load_a, + params_qk.tma_load_b, + params_pv.tma_load_b, + 1.0f / (float) std::sqrt(get<4>(problem_size)), + (float) (std::log2(std::exp(1.0)) / std::sqrt(get<4>(problem_size))), + 1.0f + }; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor()); + } + + template + CUTLASS_DEVICE auto + compute( + int block_rank_in_cluster, + BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size, + MainloopPipeline& pipeline, PipelineState& smem_pipe_read, PipelineState& smem_pipe_write, + MainloopPipelineQ& pipeline_q, PipelineStateQ& smem_pipe_read_q, PipelineStateQ& smem_pipe_write_q, + SharedStorage& storage) + { + int warp_idx = cutlass::canonical_warp_idx_sync(); + int thread_idx = threadIdx.x; + + PipelineState smem_pipe_release = smem_pipe_read; + [[maybe_unused]] PipelineStateQ smem_pipe_release_q = smem_pipe_read_q; + + + int fusion_tile_count = Fusion{}.get_trip_count(blk_coord, TileShape{}, problem_size); + + LoadQ load_q{params.tma_load_q, pipeline_q, storage.smem_q}; + auto load_state_q = load_q.init_state(_0{}, problem_size, TileShapeQK{}, blk_coord, 1); + + LoadK load_k{params.tma_load_k, pipeline, storage.smem_k}; + auto load_state_k = load_k.init_state(block_rank_in_cluster, problem_size, TileShapeQK{}, blk_coord, fusion_tile_count); + + LoadV load_v{params.tma_load_v, pipeline, storage.smem_v}; + auto load_state_v = load_v.init_state(block_rank_in_cluster, problem_size, TileShapePV{}, blk_coord, fusion_tile_count); + + // Set predicate for the lowest lane_id in the warp + int lane_predicate = cute::elect_one_sync(); + + // Issue TmaLoads (Prologue fetches) + if (warp_idx == 0) { + auto q_tile_iter = cute::make_coord_iterator(1); + int q_tile_count = 1; + load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, q_tile_count); + } + + // Loop over K elems + auto k_tile_iter = cute::make_coord_iterator(fusion_tile_count); + + int k_tile_count_tma = 2 * fusion_tile_count; + + uint16_t mcast_mask_b = 0; + + if (warp_idx == 0 && lane_predicate == 1) { + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,_0{},Int<0>{})); + } + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < StageCount; i++) { + if (i % 2 == 0) { + load_k.template step(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b); + } else { + load_v.template step(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b); + } + } + } + + TiledMmaQK tiled_mma_qk; + auto thr_mma_qk = tiled_mma_qk.get_thread_slice(thread_idx); + + // Mainloop setup QK + Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + + Tensor tSsQ = thr_mma_qk.partition_A(sQ); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tSsK = thr_mma_qk.partition_B(sK); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tSrQ = thr_mma_qk.make_fragment_A(tSsQ); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tSrK = thr_mma_qk.make_fragment_B(tSsK); // (MMA,MMA_M,MMA_N,PIPE) + + // Prepare: MMA PV + TiledMmaPV tiled_mma_pv; + auto thr_mma_pv = tiled_mma_pv.get_thread_slice(thread_idx); + + // Mainloop setup PV + Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + + Tensor tOsV = thr_mma_pv.partition_B(sV); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tOrV = thr_mma_pv.make_fragment_B(tOsV); // (MMA,MMA_M,MMA_N,PIPE) + + int k_tile_count = Fusion{}.get_unmasked_trip_count(blk_coord, TileShape{}, problem_size); + + pipeline_q.consumer_wait(smem_pipe_read_q); + + // mapping into QK accumulator + Tensor cP = make_identity_tensor(take<0,2>(TileShapeQK{})); + Tensor tPcP = thr_mma_qk.partition_C(cP); + int m_block = get<0>(blk_coord); + tPcP.data() = tPcP.data() + E<0>{} * m_block * get<0>(TileShapeQK{}); + + // Allocate PV acc + Tensor acc_pv = partition_fragment_C(tiled_mma_pv, take<0, 2>(TileShapePV{})); + + cutlass::fmha::collective::CollectiveSoftmax softmax{params}; + auto softmax_state = softmax.init(acc_pv, tiled_mma_pv); + + if (true) + { + --k_tile_count; + // Allocate QK acc + Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{})); + + pipeline.consumer_wait(smem_pipe_read); + + // MMA QK + warpgroup_fence_operand(acc_qk); + warpgroup_arrive(); + + gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,_0{}), tSrK(_,_,_,smem_pipe_read.index()), acc_qk); + warpgroup_commit_batch(); + + ++smem_pipe_read; + + // Wait for the pipeline MMAs to drain + warpgroup_wait<0>(); + warpgroup_fence_operand(acc_qk); + + softmax.step(acc_qk, tiled_mma_qk, tPcP, softmax_state, problem_size); + + Tensor acc_qk_fixed = make_fragment_like(convert_c_layout_to_a_layout(acc_qk.layout(), shape<1>(typename decltype(tiled_mma_pv)::LayoutA_TV{}))); + + Tensor acc_qk_input = make_tensor(acc_qk_fixed.data(), acc_qk.layout()); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + acc_qk_input(i) = static_cast(acc_qk(i)); + } + + pipeline.consumer_wait(smem_pipe_read); + + // MMA PV + warpgroup_fence_operand(acc_pv); + warpgroup_fence_operand(acc_qk_fixed); + warpgroup_arrive(); + + gemm_zero_acc(tiled_mma_pv, acc_qk_fixed, tOrV(_,_,_,smem_pipe_read.index()), acc_pv); + warpgroup_commit_batch(); + + // + // Advance the pipe + // + + // Advance consumer pipeline + ++smem_pipe_read; + + pipeline.consumer_release(smem_pipe_release); + ++smem_pipe_release; + + tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{}); + } + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // Allocate QK acc + Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{})); + + pipeline.consumer_wait(smem_pipe_read); + + // MMA QK + warpgroup_fence_operand(acc_qk); + warpgroup_arrive(); + + gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,_0{}), tSrK(_,_,_,smem_pipe_read.index()), acc_qk); + warpgroup_commit_batch(); + + ++smem_pipe_read; + + if (warp_idx == 0) { + load_k.template step(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b); + } + + // Wait for the pipeline MMAs to drain + warpgroup_wait<0>(); + warpgroup_fence_operand(acc_qk); + warpgroup_fence_operand(acc_pv); + + softmax.template step_interleave_begin(acc_qk, tiled_mma_qk, tPcP, softmax_state, acc_pv, tiled_mma_pv, problem_size); + + pipeline.consumer_release(smem_pipe_release); + + ++smem_pipe_release; + + pipeline.consumer_wait(smem_pipe_read); + + // MMA PV + auto layout_qk_input = convert_c_layout_to_a_layout(acc_qk.layout(), shape<1>(typename decltype(tiled_mma_pv)::LayoutA_TV{})); + + Tensor acc_qk_input = make_tensor(acc_qk.data(), layout_qk_input); + + static_assert(decltype(size<1>(layout_qk_input) == _1{})::value); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<2>(tOrV); i++) { + Tensor acc_qk_element = make_fragment_like(layout_qk_input(_, _0{}, _0{})); + Tensor acc_qk_element_mk = tensor_op_mk_v(tiled_mma_pv, acc_qk_element); + Tensor acc_qk_input_mk = tensor_op_mk_v(tiled_mma_pv, acc_qk_input(_, _0{}, i)); + softmax.step_interleave_step(acc_qk_input_mk, softmax_state); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(acc_qk_element_mk); j++) { + acc_qk_element_mk(j) = static_cast(acc_qk_input_mk(j)); + } + warpgroup_arrive(); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<1>(tOrV); j++) { + cute::gemm(tiled_mma_pv, acc_qk_element, tOrV(_,j,i,smem_pipe_read.index()), acc_pv(_,_0{},j)); + } + } + warpgroup_commit_batch(); + + // Wait for the pipeline MMAs to drain + pipeline.consumer_release(smem_pipe_release); + ++smem_pipe_release; + + ++smem_pipe_read; + + if (warp_idx == 0) { + load_v.template step(k_tile_iter, load_state_v, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b); + } + + tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{}); + } + + k_tile_count += Fusion{}.get_masked_trip_count(blk_coord, TileShape{}, problem_size); + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // Allocate QK acc + Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{})); + + pipeline.consumer_wait(smem_pipe_read); + + // MMA QK + warpgroup_fence_operand(acc_qk); + warpgroup_arrive(); + + gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,_0{}), tSrK(_,_,_,smem_pipe_read.index()), acc_qk); + warpgroup_commit_batch(); + + ++smem_pipe_read; + + if (warp_idx == 0) { + load_k.template step(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b); + } + + // Wait for the pipeline MMAs to drain + warpgroup_wait<0>(); + warpgroup_fence_operand(acc_qk); + warpgroup_fence_operand(acc_pv); + + softmax.step_interleave_begin(acc_qk, tiled_mma_qk, tPcP, softmax_state, acc_pv, tiled_mma_pv, problem_size); + + pipeline.consumer_release(smem_pipe_release); + + ++smem_pipe_release; + + pipeline.consumer_wait(smem_pipe_read); + + // MMA PV + auto layout_qk_input = convert_c_layout_to_a_layout(acc_qk.layout(), shape<1>(typename decltype(tiled_mma_pv)::LayoutA_TV{})); + + Tensor acc_qk_input = make_tensor(acc_qk.data(), layout_qk_input); + + static_assert(decltype(size<1>(layout_qk_input) == _1{})::value); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<2>(tOrV); i++) { + Tensor acc_qk_element = make_fragment_like(layout_qk_input(_, _0{}, _0{})); + Tensor acc_qk_element_mk = tensor_op_mk_v(tiled_mma_pv, acc_qk_element); + Tensor acc_qk_input_mk = tensor_op_mk_v(tiled_mma_pv, acc_qk_input(_, _0{}, i)); + softmax.step_interleave_step(acc_qk_input_mk, softmax_state); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(acc_qk_element_mk); j++) { + acc_qk_element_mk(j) = static_cast(acc_qk_input_mk(j)); + } + warpgroup_arrive(); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<1>(tOrV); j++) { + cute::gemm(tiled_mma_pv, acc_qk_element, tOrV(_,j,i,smem_pipe_read.index()), acc_pv(_,_0{},j)); + } + } + warpgroup_commit_batch(); + + // Wait for the pipeline MMAs to drain + pipeline.consumer_release(smem_pipe_release); + ++smem_pipe_release; + + ++smem_pipe_read; + + if (warp_idx == 0) { + load_v.template step(k_tile_iter, load_state_v, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b); + } + + tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{}); + } + + // Wait for the pipeline MMAs to drain + warpgroup_wait<0>(); + warpgroup_fence_operand(acc_pv); + + Tensor lse = softmax.tail(softmax_state, acc_pv, tiled_mma_pv); + + return make_tuple(acc_pv, lse); + } +}; + +} // namespace cutlass::fmha::collective + diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_collective_tma_warpspecialized.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_collective_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9fd45baf3806c7246a4c358273bc544ca941d022 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_collective_tma_warpspecialized.hpp @@ -0,0 +1,560 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "../collective/fmha_common.hpp" +#include "../collective/fmha_collective_load.hpp" +#include "../collective/fmha_collective_softmax.hpp" +#include "../kernel/fmha_options.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; +using cutlass::fmha::kernel::Tag; +using cutlass::fmha::kernel::find_option_t; + +template< + class Element_, + class ElementAccumulatorQK_, + class ElementAccumulatorPV_, + class TileShape_, // SeqQ, SeqKV, Head + class LayoutQ_, class LayoutK_, class LayoutV_, // SeqX, Head, (Batches) + class Fusion, + class... Options +> +struct FmhaMainloopTmaWarpSpecialized { + + using Element = Element_; + using ElementAccumulatorQK = ElementAccumulatorQK_; + using ElementAccumulatorPV = ElementAccumulatorPV_; + using TileShape = TileShape_; + + using LayoutQ = LayoutQ_; + using LayoutK = LayoutK_; + using LayoutV = LayoutV_; + + // Options + static constexpr bool kIsPersistent = find_option_t::value; + static constexpr bool kIsMainloopLocked = find_option_t::value; + + static constexpr int NumLoadWarpGroups = 1; + static constexpr int NumMmaWarpGroups = find_option_t, Options...>::value; + static constexpr int StageCount = find_option_t, Options...>::value; + static constexpr int StageCountQ = find_option_t, Options...>::value; + + static const int kOuterLoads = 1; + using StagesQ = cutlass::gemm::collective::StageCount; + using Stages = cutlass::gemm::collective::StageCount; + using ClusterShape = Shape<_1, _1, _1>; + static_assert(StagesQ::value >= NumMmaWarpGroups); + static_assert(Stages::value >= 2); + + // 16B alignment lets us use TMA + static constexpr int Alignment = 16 / sizeof(Element); + + using TileShapeQK = Shape< + decltype(tuple_element_t<0, TileShape>{} / Int{}), + tuple_element_t<1, TileShape>, + tuple_element_t<2, TileShape>>; + + using TileShapePV = decltype(select<0,2,1>(TileShapeQK{})); + + using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Element, LayoutQ, Alignment, + Element, LayoutK, Alignment, + ElementAccumulatorQK, + TileShapeQK, ClusterShape, Stages, + cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp; + + using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + // the stride for A does not matter since we do not load from smem at all + Element, LayoutK, Alignment, + Element, decltype(select<1,0,2>(LayoutV{})), Alignment, + ElementAccumulatorPV, + TileShapePV, ClusterShape, Stages, + cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp; + + using TiledMmaQK = typename CollectiveMmaQK::TiledMma; + using TiledMmaPV = decltype(convert_to_gmma_rs(typename CollectiveMmaPV::TiledMma{})); + + using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int{})); + using SmemLayoutK = typename CollectiveMmaQK::SmemLayoutB; + using SmemLayoutV = typename CollectiveMmaPV::SmemLayoutB; + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using MainloopPipelineQ = cutlass::PipelineTmaAsync; + + using PipelineState = typename cutlass::PipelineState; + using PipelineStateQ = typename cutlass::PipelineState; + + static constexpr int kInnerLoadBytes = size(SmemLayoutK{}(_,_,_0{})) * sizeof(Element); + static constexpr int kOuterLoadBytes = size(SmemLayoutQ{}(_,_,_0{})) * sizeof(Element); + + using TileShapeOut = TileShapePV; + using TiledMmaOut = TiledMmaPV; + using ElementOut = ElementAccumulatorPV; + + struct SharedStorage { + cute::array_aligned> smem_q; + union { + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + }; + }; + + struct Arguments { + const Element* ptr_Q; + LayoutQ dQ; + const Element* ptr_K; + LayoutK dK; + const Element* ptr_V; + LayoutV dV; + }; + + using TMA_Q = typename CollectiveMmaQK::Params::TMA_A; + using TMA_K = typename CollectiveMmaQK::Params::TMA_B; + using TMA_V = typename CollectiveMmaPV::Params::TMA_B; + + struct Params { + TMA_Q tma_load_q; + TMA_K tma_load_k; + TMA_V tma_load_v; + + float scale_softmax; + float scale_softmax_log2; + float rp_dropout; + }; + + using LoadQ = cutlass::fmha::collective::CollectiveLoadTma< + cutlass::fmha::collective::LoadKind::kQ, + MainloopPipelineQ, + Element, + SmemLayoutQ, + TMA_Q + >; + + using LoadK = cutlass::fmha::collective::CollectiveLoadTma< + cutlass::fmha::collective::LoadKind::kK, + MainloopPipeline, + Element, + SmemLayoutK, + TMA_K + >; + + using LoadV = cutlass::fmha::collective::CollectiveLoadTma< + cutlass::fmha::collective::LoadKind::kV, + MainloopPipeline, + Element, + SmemLayoutV, + TMA_V + >; + + static_assert(size(typename CollectiveMmaQK::TiledMma{}) == size(typename CollectiveMmaPV::TiledMma{})); + + template + static bool can_implement(ProblemShape const& problem_size, Arguments const& args) { + return true + && (get<4>(problem_size) <= get<2>(TileShape{})) + && ((get<4>(problem_size) % Alignment) == 0) + && ((get<2>(problem_size) % Alignment) == 0) + ; + } + + template + static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace) { + + auto problem_shape_qk = make_shape(get<2>(problem_size), get<3>(problem_size), get<4>(problem_size), make_shape(get<0>(problem_size), get<1>(problem_size))); + auto params_qk = CollectiveMmaQK::to_underlying_arguments(problem_shape_qk, + typename CollectiveMmaQK::Arguments { + args.ptr_Q, args.dQ, + args.ptr_K, args.dK, + }, /*workspace=*/ nullptr); + + auto problem_shape_pv = select<0,2,1,3>(problem_shape_qk); + auto params_pv = CollectiveMmaPV::to_underlying_arguments(problem_shape_pv, + typename CollectiveMmaPV::Arguments { + args.ptr_K, args.dK, // never used, dummy + args.ptr_V, select<1,0,2>(args.dV), + }, /*workspace=*/ nullptr); + + return Params{ + params_qk.tma_load_a, + params_qk.tma_load_b, + params_pv.tma_load_b, + 1.0f / (float) std::sqrt(get<4>(problem_size)), + (float) (std::log2(std::exp(1.0)) / std::sqrt(get<4>(problem_size))), + 1.0f + }; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor()); + } + + template + CUTLASS_DEVICE void + load_kv_maybe_q( + int block_rank_in_cluster, + BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size, + MainloopPipeline& pipeline, PipelineState& smem_pipe_write, + MainloopPipelineQ& pipeline_q, PipelineStateQ& smem_pipe_write_q, + SharedStorage& storage, + LoadWarpBarrier& load_warp_barrier, bool do_barrier) + { + int fusion_tile_count = Fusion{}.get_trip_count(blk_coord, TileShape{}, problem_size); + + int lane_predicate = cute::elect_one_sync(); + + uint16_t mcast_mask_b = 0; + + if (lane_predicate == 1) { + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,_0{},Int<0>{})); + } + } + } + + auto q_tile_iter = cute::make_coord_iterator(Int{}); + [[maybe_unused]] int q_tile_count = NumMmaWarpGroups; + + auto k_tile_iter = cute::make_coord_iterator(fusion_tile_count); + int k_tile_count = 2 * fusion_tile_count; + + LoadQ load_q{params.tma_load_q, pipeline_q, storage.smem_q}; + auto load_state_q = load_q.init_state(_0{}, problem_size, TileShapeQK{}, blk_coord, NumMmaWarpGroups); + + LoadK load_k{params.tma_load_k, pipeline, storage.smem_k}; + auto load_state_k = load_k.init_state(block_rank_in_cluster, problem_size, TileShapeQK{}, blk_coord, fusion_tile_count); + + LoadV load_v{params.tma_load_v, pipeline, storage.smem_v}; + auto load_state_v = load_v.init_state(block_rank_in_cluster, problem_size, TileShapePV{}, blk_coord, fusion_tile_count); + + if constexpr (kLoadQ) { + load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, q_tile_count); + } + + load_k.template step(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count, mcast_mask_b); + + if constexpr (kLoadQ) { + load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, q_tile_count); + } + + if constexpr (! kLoadQ) { + if (do_barrier) { + load_warp_barrier.arrive(); + load_warp_barrier.wait(/*phase=*/ 0); + do_barrier = false; + } + } + + load_v.template step(k_tile_iter, load_state_v, smem_pipe_write, lane_predicate, k_tile_count, mcast_mask_b); + + if constexpr (kLoadQ) { + while (q_tile_count > 0) { + load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, q_tile_count); + } + } + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + load_k.template step(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count, mcast_mask_b); + load_v.template step(k_tile_iter, load_state_v, smem_pipe_write, lane_predicate, k_tile_count, mcast_mask_b); + } + } + + template + CUTLASS_DEVICE void + load_maybe_q( + BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size, + MainloopPipelineQ& pipeline_q, PipelineStateQ& smem_pipe_write_q, + SharedStorage& storage, + LoadWarpBarrier& load_warp_barrier, bool do_barrier) + { + int lane_predicate = cute::elect_one_sync(); + + LoadQ load_q{params.tma_load_q, pipeline_q, storage.smem_q}; + auto load_state_q = load_q.init_state(_0{}, problem_size, TileShapeQK{}, blk_coord, NumMmaWarpGroups); + + auto q_tile_iter = cute::make_coord_iterator(Int{}); + + CUTLASS_PRAGMA_UNROLL + for (int q_tile_count = 0; q_tile_count < NumMmaWarpGroups; q_tile_count++) { + int count = 1; + load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, count); + if (q_tile_count == 0 && do_barrier) { + load_warp_barrier.arrive(); + load_warp_barrier.wait(/*phase=*/ 0); + do_barrier = false; + } + } + } + + template + CUTLASS_DEVICE void + reduce( + BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size, + MainloopPipelineReducer& pipeline_reducer, PipelineStateReducer& smem_pipe_write_reducer, + SharedStorage& storage) + { /* no-op */ } + + template + CUTLASS_DEVICE auto + compute( + BlkCoord const& blk_coord, BlkCoord const& wg_coord, + Params const& params, ProblemShape const& problem_size, + MainloopPipeline& pipeline, PipelineState& smem_pipe_read, + MainloopPipelineQ& pipeline_q, PipelineStateQ& smem_pipe_read_q, + MainloopPipelineReducer&, PipelineStateReducer&, + SharedStorage& storage, + MathWgOrderBarrier& math_wg_order_barrier) + { + int thread_idx = int(threadIdx.x); + + PipelineState smem_pipe_release = smem_pipe_read; + PipelineStateQ smem_pipe_release_q = smem_pipe_read_q; + + TiledMmaQK tiled_mma_qk; + auto thr_mma_qk = tiled_mma_qk.get_thread_slice(thread_idx); + + // Mainloop setup QK + Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + + Tensor tSsQ = thr_mma_qk.partition_A(sQ); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tSsK = thr_mma_qk.partition_B(sK); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tSrQ = thr_mma_qk.make_fragment_A(tSsQ); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tSrK = thr_mma_qk.make_fragment_B(tSsK); // (MMA,MMA_M,MMA_N,PIPE) + + // Prepare: MMA PV + TiledMmaPV tiled_mma_pv; + auto thr_mma_pv = tiled_mma_pv.get_thread_slice(thread_idx); + + // Mainloop setup PV + Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + + Tensor tOsV = thr_mma_pv.partition_B(sV); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tOrV = thr_mma_pv.make_fragment_B(tOsV); // (MMA,MMA_M,MMA_N,PIPE) + + int k_tile_count = Fusion{}.get_unmasked_trip_count(blk_coord, TileShape{}, problem_size); + + pipeline_q.consumer_wait(smem_pipe_read_q); + + // mapping into QK accumulator + Tensor cP = make_identity_tensor(take<0,2>(TileShapeQK{})); + Tensor tPcP = thr_mma_qk.partition_C(cP); + int m_block = get<0>(wg_coord); + tPcP.data() = tPcP.data() + E<0>{} * m_block * get<0>(TileShapeQK{}); + + // Allocate PV acc + Tensor acc_pv = partition_fragment_C(tiled_mma_pv, take<0, 2>(TileShapePV{})); + + cutlass::fmha::collective::CollectiveSoftmax softmax{params}; + auto softmax_state = softmax.init(acc_pv, tiled_mma_pv); + + if (true) + { + --k_tile_count; + // Allocate QK acc + Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{})); + + pipeline.consumer_wait(smem_pipe_read); + math_wg_order_barrier.wait(); + + // MMA QK + warpgroup_fence_operand(acc_qk); + warpgroup_arrive(); + + gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,smem_pipe_read_q.index()), tSrK(_,_,_,smem_pipe_read.index()), acc_qk); + warpgroup_commit_batch(); + math_wg_order_barrier.arrive(); + + ++smem_pipe_read; + + // Wait for the pipeline MMAs to drain + warpgroup_wait<0>(); + warpgroup_fence_operand(acc_qk); + + softmax.step(acc_qk, tiled_mma_qk, tPcP, softmax_state, problem_size); + + Tensor acc_qk_fixed = make_acc_into_op(acc_qk, typename TiledMmaPV::LayoutA_TV{}); + + pipeline.consumer_wait(smem_pipe_read); + + // MMA PV + warpgroup_fence_operand(acc_pv); + warpgroup_fence_operand(acc_qk_fixed); + warpgroup_arrive(); + + gemm_zero_acc(tiled_mma_pv, acc_qk_fixed, tOrV(_,_,_,smem_pipe_read.index()), acc_pv); + warpgroup_commit_batch(); + + pipeline.consumer_release(smem_pipe_release); + ++smem_pipe_release; + + // Advance consumer pipeline + ++smem_pipe_read; + tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{}); + } + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) + { + --k_tile_count; + + // Allocate QK acc + Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{})); + + pipeline.consumer_wait(smem_pipe_read); + + // MMA QK + warpgroup_fence_operand(acc_qk); + warpgroup_arrive(); + + gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,smem_pipe_read_q.index()), tSrK(_,_,_,smem_pipe_read.index()), acc_qk); + warpgroup_commit_batch(); + + ++smem_pipe_read; + auto tok = pipeline.consumer_try_wait(smem_pipe_read); + + // Wait for the pipeline MMAs to drain + warpgroup_wait<0>(); + warpgroup_fence_operand(acc_qk); + warpgroup_fence_operand(acc_pv); + + if constexpr (kIsMainloopLocked) math_wg_order_barrier.wait(); + softmax.template step(acc_qk, tiled_mma_qk, tPcP, softmax_state, acc_pv, tiled_mma_pv, problem_size); + if constexpr (kIsMainloopLocked) math_wg_order_barrier.arrive(); + + Tensor acc_qk_fixed = make_acc_into_op(acc_qk, typename TiledMmaPV::LayoutA_TV{}); + + pipeline.consumer_wait(smem_pipe_read, tok); + + // MMA PV + warpgroup_fence_operand(acc_pv); + warpgroup_fence_operand(acc_qk_fixed); + warpgroup_arrive(); + + cute::gemm(tiled_mma_pv, acc_qk_fixed, tOrV(_,_,_,smem_pipe_read.index()), acc_pv); + warpgroup_commit_batch(); + + pipeline.consumer_release(smem_pipe_release); + ++smem_pipe_release; + + pipeline.consumer_release(smem_pipe_release); + ++smem_pipe_release; + + ++smem_pipe_read; + tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{}); + } + + k_tile_count += Fusion{}.get_masked_trip_count(blk_coord, TileShape{}, problem_size); + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) + { + --k_tile_count; + + // Allocate QK acc + Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{})); + + pipeline.consumer_wait(smem_pipe_read); + + // MMA QK + warpgroup_fence_operand(acc_qk); + warpgroup_arrive(); + + gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,smem_pipe_read_q.index()), tSrK(_,_,_,smem_pipe_read.index()), acc_qk); + warpgroup_commit_batch(); + + ++smem_pipe_read; + auto tok = pipeline.consumer_try_wait(smem_pipe_read); + + // Wait for the pipeline MMAs to drain + warpgroup_wait<0>(); + warpgroup_fence_operand(acc_qk); + warpgroup_fence_operand(acc_pv); + + //if constexpr (kIsPersistent) + // if (k_tile_count == 0) pipeline_q.consumer_release(smem_pipe_release_q); + + if constexpr (kIsMainloopLocked) math_wg_order_barrier.wait(); + softmax.step(acc_qk, tiled_mma_qk, tPcP, softmax_state, acc_pv, tiled_mma_pv, problem_size); + if constexpr (kIsMainloopLocked) math_wg_order_barrier.arrive(); + + Tensor acc_qk_fixed = make_acc_into_op(acc_qk, typename TiledMmaPV::LayoutA_TV{}); + + pipeline.consumer_wait(smem_pipe_read, tok); + + // MMA PV + warpgroup_fence_operand(acc_pv); + warpgroup_fence_operand(acc_qk_fixed); + warpgroup_arrive(); + + cute::gemm(tiled_mma_pv, acc_qk_fixed, tOrV(_,_,_,smem_pipe_read.index()), acc_pv); + warpgroup_commit_batch(); + + pipeline.consumer_release(smem_pipe_release); + ++smem_pipe_release; + + pipeline.consumer_release(smem_pipe_release); + ++smem_pipe_release; + + ++smem_pipe_read; + tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{}); + } + + if (kIsPersistent) pipeline_q.consumer_release(smem_pipe_release_q); + + // Wait for the pipeline MMAs to drain + warpgroup_wait<0>(); + warpgroup_fence_operand(acc_pv); + + if (kIsPersistent) pipeline.consumer_release(smem_pipe_release); + ++smem_pipe_release; + + Tensor lse = softmax.tail(softmax_state, acc_pv, tiled_mma_pv); + + return make_tuple(acc_pv, lse); + } +}; + +} // namespace cutlass::fmha::collective diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_common.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..40c23f48bbebb9296292149e193e6e8bfa2e3752 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_common.hpp @@ -0,0 +1,245 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/kernel_hardware_info.h" +#include "cute/tensor.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template +CUTE_DEVICE void gemm_reset_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) { + constexpr int rA = decltype(rank(tA))::value; + constexpr int rB = decltype(rank(tB))::value; + constexpr int rC = decltype(rank(tC))::value; + if constexpr (rA == 2 && rB == 2 && rC == 1) { + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<1>(tA); k_block++) { + cute::gemm(atom, tA(_,k_block), tB(_,k_block), tC); + atom.accumulate_ = GMMA::ScaleOut::One; + } + } else { + static_assert(rA == 3 && rB == 3 && rC == 3); + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tA); k_block++) { + cute::gemm(atom, tA(_,_,k_block), tB(_,_,k_block), tC); + atom.accumulate_ = GMMA::ScaleOut::One; + } + } +} + +template +CUTE_DEVICE void gemm_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) { + atom.accumulate_ = GMMA::ScaleOut::Zero; + gemm_reset_zero_acc(atom, tA, tB, tC); +} + +template +CUTE_DEVICE constexpr typename T::value_type reduce(T const& t, Fn fn) { + if constexpr (decltype(size(t) % _2{} == _0{})::value) { + auto partial = make_tensor(size(t) / _2{}); + CUTE_UNROLL + for (int i = 0; i < size(partial); i++) { + partial(i) = fn(t(i), t(i + size(partial))); + } + return reduce(partial, fn); + } else { + auto result = t(_0{}); + CUTE_UNROLL + for (int i = 1; i < size(t); i++) { + result = fn(result, t(i)); + } + return result; + } +} + +struct fmha_max { + CUTE_DEVICE float operator()(float a, float b) { return ::max(a, b); } +}; + +template +inline auto __device__ constexpr layout_separate(Threshold const& thr, + Source const& src, Reference const& ref) { + auto lt = filter(transform_layout(src, ref, [&](auto const& s, auto const& r) { + if constexpr(decltype(r < thr)::value) { + return s; + } else { + return make_layout(_1{}, _0{}); + } + })); + auto ge = filter(transform_layout(src, ref, [&](auto const& s, auto const& r) { + if constexpr(decltype(r >= thr)::value) { + return s; + } else { + return make_layout(_1{}, _0{}); + } + })); + return make_layout(lt, ge); +} + +template +inline auto __device__ constexpr layout_acc_mn(TiledMma const& tiled_mma, Acc const& acc) { + auto separated = layout_separate(get<0>(typename TiledMma::Shape_MNK{}), + get<0>(acc), stride<1>(typename TiledMma::LayoutC_TV{})); + auto V_M = get<0>(separated); + auto V_N = get<1>(separated); + return make_layout(make_layout(V_M, get<1>(acc)), make_layout(V_N, get<2>(acc))); +} + +template +inline auto __device__ constexpr layout_op_mk_v(TiledMma const& tiled_mma, Acc const& acc) { + return layout_separate(get<0>(typename TiledMma::Shape_MNK{}), + get<0>(acc), stride<1>(typename TiledMma::LayoutA_TV{})); +} + +template +inline auto __device__ constexpr tensor_op_mk_v(TiledMma const& tiled_mma, Acc&& acc) { + return make_tensor(acc.data(), layout_op_mk_v(tiled_mma, acc.layout())); +} + +template +inline auto __device__ constexpr reduction_target_n(TiledMma const& tiled_mma) { + auto separated = layout_separate(get<0>(typename TiledMma::Shape_MNK{}), + make_layout(shape<0>(typename TiledMma::LayoutC_TV{})), + stride<0>(typename TiledMma::LayoutC_TV{})); + return get<1>(separated); +} + + +template class Primitive, cute::GMMA::Major tA, cute::GMMA::Major tB, cute::GMMA::ScaleIn sA, cute::GMMA::ScaleIn sB> +inline auto __device__ constexpr convert_to_gmma_rs(cute::MMA_Atom> const& tiled_mma) { + using Atom = cute::MMA_Atom>; + using ElementA = typename Atom::ValTypeA; + using ElementB = typename Atom::ValTypeB; + using ElementC = typename Atom::ValTypeC; + using Shape_MNK = typename Atom::Shape_MNK; + using RS = decltype(cute::GMMA::rs_op_selector()); + return cute::MMA_Atom{}; +} + +template class Primitive, cute::GMMA::ScaleIn sA, cute::GMMA::ScaleIn sB> +inline auto __device__ constexpr convert_to_gmma_rs(cute::MMA_Atom> const& tiled_mma) { + using Atom = cute::MMA_Atom>; + using ElementA = typename Atom::ValTypeA; + using ElementB = typename Atom::ValTypeB; + using ElementC = typename Atom::ValTypeC; + using Shape_MNK = typename Atom::Shape_MNK; + constexpr auto tA = cute::GMMA::Major::K; + constexpr auto tB = cute::GMMA::Major::K; + using RS = decltype(cute::GMMA::rs_op_selector()); + return cute::MMA_Atom{}; +} + +template +CUTE_DEVICE auto constexpr convert_to_gmma_rs(cute::TiledMMA const& tiled_mma) { + return cute::TiledMMA{}; +} + +template +CUTE_DEVICE auto constexpr convert_c_layout_to_a_layout(CLayout const& c, AValueShape const& a) { + return make_layout( + make_shape(a, shape<1>(c), make_shape(shape<2>(c), size<0>(c) / size(a))), + make_stride(stride<0>(c), stride<1>(c), make_stride(stride<2>(c), size<2>(a) * stride<0,2>(c)))); +} + +template +CUTE_DEVICE constexpr auto unstageSmemLayout(Layout const& layout, Stages stages = {}) { + return composition(layout, make_tuple(_, _, make_layout(stages))); +} + +template +CUTE_DEVICE auto make_acc_into_op(Accumulator const& acc, OperandLayout_TV const& operand_layout_tv) { + Tensor operand = make_fragment_like(convert_c_layout_to_a_layout(acc.layout(), shape<1>(operand_layout_tv))); + Tensor operand_as_acc = make_tensor(operand.data(), acc.layout()); + + cute::copy(acc, operand_as_acc); + + if constexpr (sizeof(Element) == 1) { + + // 00 11 22 33 00 11 22 33 acc layout + // 00 00 11 11 22 22 33 33 operand layout + // BB AA AA BB AA BB BB AA conflict-free exchange pattern + // 16-bit exchange; so process two at a time potentially + int tid = threadIdx.x % 4; + auto values_u32 = recast(operand); + + CUTE_UNROLL + for (int n = 0; n < size<1>(values_u32); n++) { + CUTE_UNROLL + for (int k = 0; k < size<2>(values_u32); k++) { + CUTE_UNROLL + for (int ii = 0; ii < 8; ii += 4) { + + uint32_t values_tmp_0 = values_u32(ii / 2 + 0, n, k); + uint32_t values_tmp_1 = values_u32(ii / 2 + 1, n, k); + + // step A: + // t 1 v 0 -> t 0 v 1 + // t 2 v 0 -> t 1 v 0 + // t 0 v 1 -> t 2 v 0 + // t 3 v 1 -> t 3 v 1 + + int v_to_send = tid == 1 || tid == 2 ? 0 : 1; + int v_to_recv = v_to_send; + int t_to_recv_from = (0x3021 >> (tid * 4)) & 0xF; + + uint32_t values_tmp_a = v_to_send == 0 ? values_tmp_0 : values_tmp_1; + + values_tmp_a = __shfl_sync(0xFFFFFFFF, values_tmp_a, t_to_recv_from, 4); + + // step B: + // t 0 v 0 -> t 0 v 0 + // t 3 v 0 -> t 1 v 1 + // t 1 v 1 -> t 2 v 1 + // t 2 v 1 -> t 3 v 0 + + v_to_send = 1 - v_to_send; + v_to_recv = 1 - v_to_recv; + t_to_recv_from = (0x2130 >> (tid * 4)) & 0xF; + + uint32_t values_tmp_b = v_to_send == 0 ? values_tmp_0 : values_tmp_1; + + values_tmp_b = __shfl_sync(0xFFFFFFFF, values_tmp_b, t_to_recv_from, 4); + + values_u32(ii / 2 + 0, n, k) = __byte_perm(values_tmp_a, values_tmp_b, v_to_send == 0 ? 0x1054 : 0x5410); + values_u32(ii / 2 + 1, n, k) = __byte_perm(values_tmp_a, values_tmp_b, v_to_send == 0 ? 0x3276 : 0x7632); + } + } + } + } + + return operand; +} + +} // namespace cutlass::fmha::collective diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_epilogue.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_epilogue.hpp new file mode 100644 index 0000000000000000000000000000000000000000..246a625ff7bf3591558e11cb42d53e64750a865a --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_epilogue.hpp @@ -0,0 +1,156 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../collective/fmha_common.hpp" + +namespace cutlass::fmha::collective { + +template +struct FmhaFwdEpilogue { + + static constexpr int Alignment = 16 / sizeof(Element); + + using DefaultOperation = cutlass::epilogue::fusion::LinearCombination; + using CollectiveEpilogueTMA = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_WG, Shape<_1,_1,_1>, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + void, cute::tuple>, Alignment, + Element, cute::tuple>, Alignment, + cutlass::epilogue::TmaWarpSpecialized, + DefaultOperation + >::CollectiveOp; + + struct Arguments { + Element* ptr_O; + cute::tuple> dO; + + ElementAccumulator* ptr_LSE; + cute::tuple> dLSE; + }; + + struct Params { + ElementAccumulator* ptr_LSE; + cute::tuple> dLSE; + + typename CollectiveEpilogueTMA::Params epilogue_TMA; + }; + + using TensorStorage = typename CollectiveEpilogueTMA::TensorStorage; + using PipelineStorage = typename CollectiveEpilogueTMA::PipelineStorage; + using LoadPipeline = typename CollectiveEpilogueTMA::LoadPipeline; + static constexpr int TmaTransactionBytes = CollectiveEpilogueTMA::TmaTransactionBytes; + + template + static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace = nullptr) { + auto problem_size_o = make_shape(get<2>(problem_size), get<4>(problem_size), 1, + make_shape(get<0>(problem_size), get<1>(problem_size))); + typename CollectiveEpilogueTMA::Arguments args_tma{{}, args.ptr_O, args.dO, args.ptr_O, args.dO}; + return Params{ + args.ptr_LSE, args.dLSE, + CollectiveEpilogueTMA::to_underlying_arguments(problem_size_o, args_tma, workspace) + }; + } + + template + CUTLASS_DEVICE void operator()( + TileShape const& tile_shape, BlkCoord const& blk_coord, + ResultTuple const& result, TiledMma const& tiled_mma, + ProblemShape const& problem_size, Params const& params, + LoadPipeline epi_load_pipeline, + TensorStorage& epi_tensor_storage) + { + using X = Underscore; + + auto acc = get<0>(result); + auto lse = get<1>(result); + + auto thr_mma = tiled_mma.get_thread_slice(threadIdx.x); + + int seqlen_q = get<2>(problem_size); + int num_batch = get<0>(problem_size); + int num_heads = get<1>(problem_size); + // Epilogue for lse + Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE), + make_shape(seqlen_q, get<1>(tile_shape), make_shape(num_batch, num_heads)), + make_stride(_1{}, _0{}, get<1>(params.dLSE))); + Tensor gLSE_full = local_tile(mLSE, tile_shape, make_coord(_, _, _), Step<_1, _1, X>{}); + Tensor gLSE = gLSE_full(_, _, get<0>(blk_coord), get<1>(blk_coord), get<2>(blk_coord)); + Tensor tOgLSE = thr_mma.partition_C(gLSE); + Tensor cO = make_identity_tensor(take<0,2>(tile_shape)); + Tensor tOcO = thr_mma.partition_C(cO); + if (get<1>(tOcO(_0{})) == 0) { + auto tOgLSE_mn = make_tensor(tOgLSE.data(), layout_acc_mn(tiled_mma, tOgLSE.layout())); + auto tOcO_mn = make_tensor(tOcO.data(), layout_acc_mn(tiled_mma, tOcO.layout())); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(tOgLSE_mn); i++) { + if (get<0>(tOcO_mn(i)) + get<0>(blk_coord) * get<0>(tile_shape) < get<2>(problem_size)) { + tOgLSE_mn(i, _0{}) = lse(i); + } + } + } + auto problem_size_o = make_shape(get<2>(problem_size), get<4>(problem_size), _, + make_shape(get<0>(problem_size), get<1>(problem_size))); + + CollectiveEpilogueTMA epilogue_tma(params.epilogue_TMA, epi_tensor_storage); + + using EpiStorePipeline = typename CollectiveEpilogueTMA::StorePipeline; + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + typename CollectiveEpilogueTMA::LoadPipelineState epi_load_pipe_consumer_state; + PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = + epilogue_tma.store( + epi_load_pipeline, epi_load_pipe_consumer_state, + epi_store_pipeline, epi_store_pipe_producer_state, + problem_size_o, tile_shape, make_coord(get<0>(blk_coord), _0{}, _, get<2>(blk_coord)), + acc, tiled_mma, threadIdx.x % cutlass::NumThreadsPerWarpGroup, + epi_tensor_storage + ); + + epilogue_tma.store_tail( + epi_load_pipeline, epi_load_pipe_consumer_state_next, + epi_store_pipeline, epi_store_pipe_producer_state_next + ); + } +}; + +} // namespace cutlass::fmha::collective diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_epilogue_bwd.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_epilogue_bwd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..677722529f5ea59e5a7b1f37823b9ad96506f2a6 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_epilogue_bwd.hpp @@ -0,0 +1,157 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" + +#include "../collective/fmha_epilogue.hpp" + +namespace cutlass::fmha::collective { + +template +struct FmhaBwdEpilogueKV { + + static constexpr int Alignment = 16 / sizeof(Element); + + struct Arguments { + Element* ptr_K; + cute::tuple dK; + + Element* ptr_V; + cute::tuple dV; + }; + + //using DefaultOperation = cutlass::epilogue::fusion::LinearCombination; + static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + using DefaultOperation = cutlass::epilogue::fusion::Sm90EVT< + cutlass::epilogue::fusion::Sm90Compute, + cutlass::epilogue::fusion::Sm90AccFetch + >; + using CollectiveEpilogueTMA = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_WG, Shape<_1,_1,_1>, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + void, cute::tuple>, Alignment, + Element, cute::tuple>, Alignment, + cutlass::epilogue::TmaWarpSpecialized, + DefaultOperation + >::CollectiveOp; + + struct Params { + typename CollectiveEpilogueTMA::Params epilogue_K; + typename CollectiveEpilogueTMA::Params epilogue_V; + }; + + + using TensorStorage = typename CollectiveEpilogueTMA::TensorStorage[2]; + using PipelineStorage = typename CollectiveEpilogueTMA::PipelineStorage; + using LoadPipeline = typename CollectiveEpilogueTMA::LoadPipeline; + static constexpr int TmaTransactionBytes = CollectiveEpilogueTMA::TmaTransactionBytes; + + template + static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace = nullptr) { + auto dK = make_stride(get<2>(args.dK), get<3>(args.dK), + make_stride(get<0>(args.dK), get<1>(args.dK))); + auto dV = make_stride(get<2>(args.dV), get<3>(args.dV), + make_stride(get<0>(args.dV), get<1>(args.dV))); + + auto problem_size_kv = make_shape(get<3>(problem_size), get<4>(problem_size), 1, + make_shape(get<0>(problem_size), get<1>(problem_size))); + typename CollectiveEpilogueTMA::Arguments args_k{{}, args.ptr_K, dK, args.ptr_K, dK}; + typename CollectiveEpilogueTMA::Arguments args_v{{}, args.ptr_V, dV, args.ptr_V, dV}; + return Params{ + CollectiveEpilogueTMA::to_underlying_arguments(problem_size_kv, args_k, nullptr), + CollectiveEpilogueTMA::to_underlying_arguments(problem_size_kv, args_v, nullptr) + }; + } + + template + CUTLASS_DEVICE void operator()( + TileShape const& tile_shape, BlkCoord const& blk_coord, + ResultTuple const& result, TiledMma const& tiled_mma, + ProblemShape const& problem_size, Params const& params, + LoadPipeline epi_load_pipeline, TensorStorage& epi_tensor_storage) + { + auto acc_k = get<0>(result); + auto acc_v = get<1>(result); + + auto problem_size_kv = make_shape(get<3>(problem_size), get<4>(problem_size), _, + make_shape(get<0>(problem_size), get<1>(problem_size))); + + using EpiStorePipeline = typename CollectiveEpilogueTMA::StorePipeline; + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + typename CollectiveEpilogueTMA::LoadPipelineState epi_load_pipe_consumer_state; + PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + CollectiveEpilogueTMA epilogue_k{params.epilogue_K, epi_tensor_storage[0]}; + CollectiveEpilogueTMA epilogue_v{params.epilogue_V, epi_tensor_storage[1]}; + + { + auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = + epilogue_k.store( + epi_load_pipeline, epi_load_pipe_consumer_state, + epi_store_pipeline, epi_store_pipe_producer_state, + problem_size_kv, tile_shape, make_coord(get<1>(blk_coord), _0{}, _, get<2>(blk_coord)), + acc_k, tiled_mma, threadIdx.x % cutlass::NumThreadsPerWarpGroup, + epi_tensor_storage[0] + ); + + } + + { + auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = + epilogue_v.store( + epi_load_pipeline, epi_load_pipe_consumer_state, + epi_store_pipeline, epi_store_pipe_producer_state, + problem_size_kv, tile_shape, make_coord(get<1>(blk_coord), _0{}, _, get<2>(blk_coord)), + acc_v, tiled_mma, threadIdx.x % cutlass::NumThreadsPerWarpGroup, + epi_tensor_storage[1] + ); + + epilogue_k.store_tail( + epi_load_pipeline, epi_load_pipe_consumer_state_next, + epi_store_pipeline, epi_store_pipe_producer_state_next + ); + + epilogue_v.store_tail( + epi_load_pipeline, epi_load_pipe_consumer_state_next, + epi_store_pipeline, epi_store_pipe_producer_state_next + ); + } + } +}; + +} // namespace cutlass::fmha::collective diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_fusion.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_fusion.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ac51f250763e6890673d57d8238f0be464f70e77 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_fusion.hpp @@ -0,0 +1,283 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +struct DefaultFusion { + template + CUTLASS_DEVICE + int get_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size + ) { + return ceil_div(get<3>(problem_size), get<1>(tile_shape)); + } + + template + CUTLASS_DEVICE + int get_masked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size + ) { + return get_trip_count(blk_coord, tile_shape, problem_size); + } + + template + CUTLASS_DEVICE + int get_unmasked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size + ) { + return 0; + } + + template + CUTLASS_DEVICE + void before_softmax( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size + + ) { + return; + } +}; + +struct ResidualFusion : DefaultFusion { + + using Base = DefaultFusion; + + template + CUTLASS_DEVICE + int get_masked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size + ) { + return 1; + } + + template + CUTLASS_DEVICE + int get_unmasked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size + ) { + return get_trip_count(blk_coord, tile_shape, problem_size) - 1; + } + + template + CUTLASS_DEVICE + void before_softmax( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size + ) { + // This is useful is seqlen_k % kBlockN != 0 since it masks + // the remaining elements out from softmax. + // d % kHeadDim != 0 or seqlen_q % kBlockM do not suffer from similar + // issues as they are transparently taken care of by TMA and the + // epilogue, if it is instantiated with predication support. + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + if (get<1>(pos) >= get<3>(problem_size)) { + acc_qk(i) = -INFINITY; + } + } + } +}; + +struct CausalFusion : DefaultFusion { + + using Base = DefaultFusion; + + template + CUTLASS_DEVICE + int get_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size + ) { + // See note below on different ways to think about causal attention + // Again, we'd add the offset_q into the max_blocks_q calculation + int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size); + int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape)); + return std::min(max_blocks_k, max_blocks_q); + } + + template + CUTLASS_DEVICE + int get_masked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size + ) { + return ceil_div(get<0>(tile_shape), get<1>(tile_shape)); + } + + template + CUTLASS_DEVICE + int get_unmasked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size + ) { + return get_trip_count(blk_coord, tile_shape, problem_size) - get_masked_trip_count(blk_coord, tile_shape, problem_size); + } + + template + CUTLASS_DEVICE + void before_softmax( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size + ) { + // There are two ways to do causal if N_Q != N_K + // (1) is to assume that the Q is at the beginning of the matrix + // - this is what we demonstrate here + // (2) is that it is at the end of the matrix + // - this is usually what we want for inference settings + // where we only compute the next row and use cache for the rest + // - if you'd like this, you only need to add an offset like so: + // get<0>(pos) + offset_q < get<1>(pos) + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + if (get<0>(pos) < get<1>(pos)) { + acc_qk(i) = -INFINITY; + } + } + } + +}; + +template +struct FusionBwdAdapter { + template + CUTLASS_DEVICE + int get_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size + ) { + return Base{}.get_trip_count(select<1,0,2>(blk_coord), select<1,0,2>(tile_shape), select<0,1,3,2,4>(problem_size)); + } + + template + CUTLASS_DEVICE + void before_softmax( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size + ) { + auto index_base = index_qk(_0{}); + auto index_shape = shape(index_qk); + auto index_stride = transform_leaf(stride(index_qk), [](auto elem) { + if constexpr (is_scaled_basis::value) { + if constexpr(decltype(elem.mode() == _0{})::value) { + return ScaledBasis(elem.value()); + } else { + return ScaledBasis(elem.value()); + } + } else { + return elem; + } + }); + auto index_qk_bwd = make_tensor(make_inttuple_iter(select<1,0>(index_base)), make_layout(index_shape, index_stride)); + Base{}.before_softmax(acc_qk, index_qk_bwd, problem_size); + } + + template + CUTLASS_DEVICE + bool is_contributing( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size + ) { + return true; + } +}; + +template<> +struct FusionBwdAdapter { + template + CUTLASS_DEVICE + int get_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size + ) { + return get<2>(problem_size) / get<0>(TileShape{}); + } + + template + CUTLASS_DEVICE + void before_softmax( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size + + ) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + if (get<1>(pos) < get<0>(pos)) { + acc_qk(i) = -INFINITY; + } + } + } + + template + CUTLASS_DEVICE + bool is_contributing( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size + ) { + int max_q = get<0>(blk_coord) * get<0>(tile_shape) + get<0>(tile_shape); + int min_k = get<1>(blk_coord) * get<1>(tile_shape); + return min_k <= max_q; + } +}; + +} // namespace cutlass::fmha::collective diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/device/device_universal.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/device/device_universal.hpp new file mode 100644 index 0000000000000000000000000000000000000000..beadb688730bd3a7936a7885ab72dd60b903c326 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/device/device_universal.hpp @@ -0,0 +1,278 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! + \file + \brief An universal device layer for cutlass 3.x-style kernels. +*/ + +#pragma once + +// common +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" + +#if !defined(__CUDACC_RTC__) +#include "cutlass/cluster_launch.hpp" +#include "cutlass/trace.h" +#endif // !defined(__CUDACC_RTC__) + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::device { + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 3.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +template +class Universal { +public: + using Kernel = Kernel_; + + static int const kThreadCount = Kernel::MaxThreadsPerBlock; + + /// Argument structure: User API + using Arguments = typename Kernel::Arguments; + /// Argument structure: Kernel API + using Params = typename Kernel::Params; + +private: + + /// Kernel API parameters object + Params params_; + + bool is_initialized(bool set = false) { + static bool initialized = false; + if (set) initialized = true; + return initialized; + } + +public: + + /// Access the Params structure + Params const& params() const { + return params_; + } + + /// Determines whether the GEMM can execute the given problem. + static Status + can_implement(Arguments const& args) { + if (Kernel::can_implement(args)) { + return Status::kSuccess; + } + else { + return Status::kInvalid; + } + } + + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_bytes = 0; + workspace_bytes += Kernel::get_workspace_size(args); + return workspace_bytes; + } + + /// Computes the grid shape + static dim3 + get_grid_shape(Params const& params) { + return Kernel::get_grid_shape(params); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int /* smem_capacity */ = -1) { + CUTLASS_TRACE_HOST("Universal::maximum_active_blocks()"); + int max_active_blocks = -1; + int smem_size = Kernel::SharedStorageSize; + + // first, account for dynamic smem capacity if needed + cudaError_t result; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return -1; + } + } + + // query occupancy after setting smem size + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + device_kernel, + Kernel::MaxThreadsPerBlock, + smem_size); + + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " + << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Initializes GEMM state from arguments. + Status + initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("Universal::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + // Initialize the workspace + Status status = Kernel::initialize_workspace(args, workspace, stream); + if (status != Status::kSuccess) { + return status; + } + + // Initialize the Params structure + params_ = Kernel::to_underlying_arguments(args, workspace); + + if (is_initialized()) return Status::kSuccess; + + // account for dynamic smem capacity if needed + int smem_size = Kernel::SharedStorageSize; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + cudaError_t result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + is_initialized(true); + + return Status::kSuccess; + } + + /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params. + Status + update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("Universal()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + if (workspace_bytes > 0 && nullptr == workspace) { + return Status::kErrorWorkspaceNull; + } + + params_ = Kernel::to_underlying_arguments(args, workspace); + return Status::kSuccess; + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. + /// Supplied params struct must be construct by calling Kernel::to_underling_arguments() + static Status + run(Params& params, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("Universal::run()"); + dim3 const block = Kernel::get_block_shape(); + dim3 const grid = get_grid_shape(params); + + // configure smem size and carveout + int smem_size = Kernel::SharedStorageSize; + + Status launch_result; + // Use extended launch API only for mainloops that use it + if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) { + dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}), + cute::size<1>(typename Kernel::ClusterShape{}), + cute::size<2>(typename Kernel::ClusterShape{})); + void const* kernel = (void const*) device_kernel; + void* kernel_params[] = {¶ms}; + launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params); + } + else { + launch_result = Status::kSuccess; + cutlass::arch::synclog_setup(); + device_kernel<<>>(params); + } + + cudaError_t result = cudaGetLastError(); + if (cudaSuccess == result && Status::kSuccess == launch_result) { + return Status::kSuccess; + } + else { + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + if (Status::kSuccess == status) { + status = run(params_, stream); + } + return status; + } + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + return run(args, workspace, stream); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + run(cudaStream_t stream = nullptr) { + return run(params_, stream); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + operator()(cudaStream_t stream = nullptr) { + return run(params_, stream); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/device/fmha_device_bwd.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/device/fmha_device_bwd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..47c43b716b1e318f8504eb4220a66692a6fb1106 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/device/fmha_device_bwd.hpp @@ -0,0 +1,299 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +/*! + \file + \brief An universal device layer for cutlass 3.x-style kernels. +*/ + +// common +#include "cutlass/cutlass.h" + +#include "../device/device_universal.hpp" +#include "../collective/fmha_collective_bwd_tma_warpspecialized.hpp" +#include "../collective/fmha_fusion.hpp" +#include "../collective/fmha_epilogue_bwd.hpp" +#include "../kernel/fmha_kernel_bwd_sum_OdO.hpp" +#include "../kernel/fmha_kernel_bwd_convert.hpp" +#include "../kernel/fmha_kernel_tma_warpspecialized.hpp" +#include "../kernel/fmha_tile_scheduler.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::fmha::device { + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 3.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +template +class FmhaBwd { +public: + /// Argument structure: User API + struct Arguments { + cute::tuple problem_size; + + const Element* ptr_Q; + cute::tuple stride_Q; + const Element* ptr_K; + cute::tuple stride_K; + const Element* ptr_V; + cute::tuple stride_V; + + const Element* ptr_O; + cute::tuple stride_O; + const ElementAccumulator* ptr_LSE; + cute::tuple stride_LSE; + + const Element* ptr_dO; + cute::tuple stride_dO; + + Element* ptr_dQ; + cute::tuple stride_dQ; + Element* ptr_dK; + cute::tuple stride_dK; + Element* ptr_dV; + cute::tuple stride_dV; + + cutlass::KernelHardwareInfo hw_info; + }; + + using OperationSumOdO = cutlass::device::Universal>; + using OperationConvert = cutlass::device::Universal>; + + using Mainloop = cutlass::fmha::collective::FmhaBwdMainloopTmaWarpSpecialized< + Element, ElementAccumulator, TileShape, + cutlass::fmha::collective::FusionBwdAdapter, Options...>; + + using Epilogue = cutlass::fmha::collective::FmhaBwdEpilogueKV; + + using Operation = cutlass::device::Universal< + cutlass::fmha::kernel::FmhaKernelTmaWarpSpecialized< + Mainloop, + Epilogue, + cutlass::fmha::kernel::TileSchedulerBwdAdapter, Options...>>; + + struct Params { + OperationSumOdO op_sum_OdO; + Operation op; + OperationConvert op_convert; + ElementAccumulator* dQ_acc; + size_t dQ_acc_size; + }; + +private: + Params params_; + + static typename OperationSumOdO::Arguments to_sum_OdO_arguments(Arguments const& args, ElementAccumulator* dest = nullptr) { + auto [B, H, Q, K, D] = args.problem_size; + D = cutlass::round_up(D, 8); // Alignment + Q = cutlass::round_up(Q, 8); // Alignment + auto stride_sum_OdO = make_stride(H*Q, Q, _1{}); + return typename OperationSumOdO::Arguments { + args.problem_size, + args.ptr_O, args.stride_O, + args.ptr_dO, args.stride_dO, + dest, stride_sum_OdO + }; + } + + static typename OperationConvert::Arguments to_convert_arguments(Arguments const& args, ElementAccumulator* src = nullptr) { + auto [B, H, Q, K, D] = args.problem_size; + D = cutlass::round_up(D, 8); // Alignment + Q = cutlass::round_up(Q, 8); // Alignment + auto stride_src_dQ = make_stride(B == 1 ? 0 : (H*Q*D), Q*D, D, _1{}); + return typename OperationConvert::Arguments { + args.problem_size, + src, stride_src_dQ, + nullptr, stride_src_dQ, + nullptr, stride_src_dQ, + args.ptr_dQ, args.stride_dQ, + nullptr, args.stride_dK, + nullptr, args.stride_dV + }; + } + + static typename Operation::Arguments to_bwd_arguments( + Arguments const& args, + ElementAccumulator* sum_OdO = nullptr, cute::tuple const& stride_sum_OdO = {}, + ElementAccumulator* dQ_acc = nullptr, cute::tuple const& stride_dQ = {} + ) { + return typename Operation::Arguments{ + args.problem_size, + { args.ptr_Q, args.stride_Q, + args.ptr_K, args.stride_K, + args.ptr_V, args.stride_V, + args.ptr_dO, args.stride_dO, + args.ptr_LSE, args.stride_LSE, + sum_OdO, stride_sum_OdO, + dQ_acc, stride_dQ }, + { args.ptr_dK, args.stride_dK, + args.ptr_dV, args.stride_dV }, + args.hw_info + }; + } + +public: + + /// Determines whether the GEMM can execute the given problem. + static Status + can_implement(Arguments const& args) { + Status status = Status::kSuccess; + + status = OperationSumOdO::can_implement(to_sum_OdO_arguments(args)); + if (status != Status::kSuccess) { + return status; + } + + status = OperationConvert::can_implement(to_convert_arguments(args)); + if (status != Status::kSuccess) { + return status; + } + + status = Operation::can_implement(to_bwd_arguments(args)); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + auto [B, H, Q, K, D] = args.problem_size; + D = cutlass::round_up(D, 8); // Alignment + Q = cutlass::round_up(Q, 8); // Alignment + size_t workspace_bytes = 0; + // OdO vector + workspace_bytes += B*H*Q * sizeof(ElementAccumulator); + // FP32 versions of outputs that are churned (start off with Q only) + workspace_bytes += B*H*Q*D * sizeof(ElementAccumulator); + return workspace_bytes; + } + + /// Initializes state from arguments. + Status + initialize_split(Arguments const& args, void* workspace_dQ, void* workspace_sum_OdO, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("Universal::initialize_split() - workspace_dQ=" + << workspace_dQ << ", workspace_sum_OdO=" << workspace_sum_OdO << "stream: " << (stream ? "non-null" : "null")); + + auto [B, H, Q, K, D] = args.problem_size; + D = cutlass::round_up(D, 8); // Alignment + Q = cutlass::round_up(Q, 8); // Alignment + ElementAccumulator* sum_OdO = reinterpret_cast(workspace_sum_OdO); + ElementAccumulator* dQ_acc = reinterpret_cast(workspace_dQ); + params_.dQ_acc = dQ_acc; + params_.dQ_acc_size = B*H*Q*D * sizeof(ElementAccumulator); + auto args_sum_OdO = to_sum_OdO_arguments(args, sum_OdO); + auto args_convert = to_convert_arguments(args, dQ_acc); + params_.op_sum_OdO.initialize(args_sum_OdO, nullptr, stream); + params_.op_convert.initialize(args_convert, nullptr, stream); + auto args_bwd = to_bwd_arguments(args, sum_OdO, args_sum_OdO.stride_sum_OdO, dQ_acc, args_convert.stride_src_dQ); + params_.op.initialize(args_bwd, nullptr, stream); + + return Status::kSuccess; + } + + /// Initializes state from arguments. + Status + initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("Universal::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + auto [B, H, Q, K, D] = args.problem_size; + D = cutlass::round_up(D, 8); // Alignment + Q = cutlass::round_up(Q, 8); // Alignment + char* workspace_chr = reinterpret_cast(workspace); + ElementAccumulator* sum_OdO = reinterpret_cast(workspace_chr); + workspace_chr += B*H*Q * sizeof(ElementAccumulator); + ElementAccumulator* dQ_acc = reinterpret_cast(workspace_chr); + return initialize_split(args, dQ_acc, sum_OdO, stream); + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. + /// Supplied params struct must be construct by calling Kernel::to_underling_arguments() + static Status + run(Params& params, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("FmhaDeviceBwd::run()"); + + Status result = Status::kSuccess; + result = params.op_sum_OdO.run(stream); + if (result != Status::kSuccess) { + return result; + } + + auto cuda_result = cudaMemsetAsync(params.dQ_acc, 0, params.dQ_acc_size, stream); + if (cuda_result != cudaSuccess) { + return Status::kErrorInternal; + } + result = params.op.run(stream); + if (result != Status::kSuccess) { + return result; + } + + result = params.op_convert.run(stream); + if (result != Status::kSuccess) { + return result; + } + + return Status::kSuccess; + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + if (Status::kSuccess == status) { + status = run(params_, stream); + } + return status; + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + run(cudaStream_t stream = nullptr) { + return run(params_, stream); + } + +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_kernel_builder.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_kernel_builder.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bd0e0f9d51620e90545f390cd397a15b0e7873ed --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_kernel_builder.hpp @@ -0,0 +1,158 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "../collective/fmha_collective_tma.hpp" +#include "../collective/fmha_collective_tma_warpspecialized.hpp" +#include "../collective/fmha_epilogue.hpp" +#include "../kernel/fmha_kernel_tma.hpp" +#include "../kernel/fmha_kernel_tma_warpspecialized.hpp" +#include "../kernel/fmha_options.hpp" + +namespace cutlass::fmha::kernel { + +template< + class Element_, + class ElementAccumulatorQK_, + class ElementAccumulatorPV_, + class TileShape_, // BlockQO, BlockKV, BlockHead + class LayoutQ_, + class LayoutK_, + class LayoutV_, + class Fusion, + class DispatchPolicy, + class... Options +> +struct FmhaBuilder; + +template< + class Element, + class ElementAccumulator, + class TileShape, // BlockQO, BlockKV, BlockHead + class Fusion, + class... Options +> +struct FmhaBuilder< + Element, + ElementAccumulator, + ElementAccumulator, + TileShape, + cute::tuple>, + cute::tuple>, + cute::tuple>, + Fusion, + cutlass::gemm::KernelTma, + Options... +> { + + using CollectiveMainloop = cutlass::fmha::collective::FmhaMainloopTma; + + using CollectiveEpilogue = cutlass::fmha::collective::FmhaFwdEpilogue< + Element, ElementAccumulator, typename CollectiveMainloop::TileShapePV>; + + using Kernel = cutlass::fmha::kernel::FmhaKernelTma; +}; + +template< + class Element, + class ElementAccumulatorQK, + class ElementAccumulatorPV, + class TileShape, // BlockQO, BlockKV, BlockHead + class LayoutQ, + class LayoutK, + class LayoutV, + class Fusion, + class... Options +> +struct FmhaBuilder< + Element, + ElementAccumulatorQK, + ElementAccumulatorPV, + TileShape, + LayoutQ, + LayoutK, + LayoutV, + Fusion, + cutlass::gemm::KernelTmaWarpSpecializedCooperative, + Options... +> { + + using CollectiveMainloop = cutlass::fmha::collective::FmhaMainloopTmaWarpSpecialized< + Element, ElementAccumulatorQK, ElementAccumulatorPV, + TileShape, LayoutQ, LayoutK, LayoutV, + Fusion, Options...>; + + using CollectiveEpilogue = cutlass::fmha::collective::FmhaFwdEpilogue< + Element, ElementAccumulatorPV, typename CollectiveMainloop::TileShapePV>; + + static constexpr bool kIsPersistent = find_option_t::value; + using TileScheduler = std::conditional_t; + + using Kernel = cutlass::fmha::kernel::FmhaKernelTmaWarpSpecialized; +}; + +template< + class Element, + class ElementAccumulatorQK, + class ElementAccumulatorPV, + class TileShape, // BlockQO, BlockKV, BlockHead + class LayoutQ, + class LayoutK, + class LayoutV, + class Fusion, + class... Options +> +struct FmhaBuilder< + Element, + ElementAccumulatorQK, + ElementAccumulatorPV, + TileShape, + LayoutQ, + LayoutK, + LayoutV, + Fusion, + cutlass::gemm::KernelTmaWarpSpecializedPingpong, + Options... +> { + using Kernel = typename FmhaBuilder< + Element, ElementAccumulatorQK, ElementAccumulatorPV, + TileShape, + LayoutQ, LayoutK, LayoutV, + Fusion, + cutlass::gemm::KernelTmaWarpSpecializedCooperative, + Options..., + Option, + Option + >::Kernel; +}; + +} // namespace cutlass::fmha::kernel diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_kernel_bwd_convert.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_kernel_bwd_convert.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3ef1946ba6056c51bf882124d7d132e98b984ae2 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_kernel_bwd_convert.hpp @@ -0,0 +1,143 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; + +template +struct FmhaKernelBwdConvert { + + struct Arguments { + tuple problem_size; + + const ElementAccumulator* ptr_src_dQ; + tuple stride_src_dQ; + const ElementAccumulator* ptr_src_dK; + tuple stride_src_dK; + const ElementAccumulator* ptr_src_dV; + tuple stride_src_dV; + + Element* ptr_dest_dQ; + tuple stride_dest_dQ; + Element* ptr_dest_dK; + tuple stride_dest_dK; + Element* ptr_dest_dV; + tuple stride_dest_dV; + }; + + using Params = Arguments; + + using ClusterShape = Shape<_1, _1, _1>; + static constexpr int SharedStorageSize = 0; + + static const int MinBlocksPerMultiprocessor = 1; + static const int MaxThreadsPerBlock = 128; + using ArchTag = cutlass::arch::Sm90; + + static const int kBlockSeq = 8; + + static size_t get_workspace_size(Arguments const& args) { return 0; } + static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return cutlass::Status::kSuccess; + } + + static const int kNumThreadsD = 16; + static const int kNumThreadsSeq = MaxThreadsPerBlock / kNumThreadsD; + static const int kElementsPerLoad = 4; + + static const int kIterationsSeq = kBlockSeq / kNumThreadsSeq; + + static bool can_implement(Arguments const& args) { + return get<4>(args.problem_size) % kElementsPerLoad == 0; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(size<0>(params.problem_size), size<1>(params.problem_size), ceil_div(std::max(size<2>(params.problem_size), size<3>(params.problem_size)), kBlockSeq)); + return grid; + } + + static dim3 get_block_shape() { + dim3 block(kNumThreadsD, kNumThreadsSeq, 1); + return block; + } + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return args; + } + + template + CUTLASS_DEVICE void copy(Params const& params, const ElementAccumulator* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, int count) { + auto ptr_src_bh = ptr_src + get<0>(stride_src) * blockIdx.x + get<1>(stride_src) * blockIdx.y; + auto ptr_dest_bh = ptr_dest + get<0>(stride_dest) * blockIdx.x + get<1>(stride_dest) * blockIdx.y; + + for (int idx_s_t = threadIdx.y; idx_s_t < kBlockSeq; idx_s_t += kNumThreadsSeq) { + int idx_s = idx_s_t + kBlockSeq * blockIdx.z; + if (idx_s >= count) continue; + auto ptr_src_bhs = ptr_src_bh + idx_s * get<2>(stride_src); + auto ptr_dest_bhs = ptr_dest_bh + idx_s * get<2>(stride_dest); + + for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<4>(params.problem_size); idx_d += kElementsPerLoad * kNumThreadsD) { + ElementAccumulator value_src[kElementsPerLoad]; + Element value_dest[kElementsPerLoad]; + + using VecSrc = uint_bit_t * kElementsPerLoad>; + using VecDest = uint_bit_t * kElementsPerLoad>; + *reinterpret_cast(value_src) = *reinterpret_cast(&ptr_src_bhs[idx_d]); + + for (int v = 0; v < kElementsPerLoad; v++) { + value_dest[v] = value_src[v]; + } + + *reinterpret_cast(&ptr_dest_bhs[idx_d]) = *reinterpret_cast(value_dest); + } + } + } + + CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { + if (params.ptr_src_dQ != nullptr) { + copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<2>(params.problem_size)); + } + if (params.ptr_src_dK != nullptr) { + copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<3>(params.problem_size)); + } + if (params.ptr_src_dV != nullptr) { + copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<3>(params.problem_size)); + } + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4e276a6235107463369ce50b7ab5e50b33a55350 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp @@ -0,0 +1,134 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; + +template +struct FmhaKernelBwdSumOdO { + + struct Arguments { + cute::tuple problem_size; + + const Element* ptr_O; + cute::tuple stride_O; + const Element* ptr_dO; + cute::tuple stride_dO; + + ElementAccumulator* ptr_sum_OdO; + cute::tuple stride_sum_OdO; + }; + + using Params = Arguments; + + using ClusterShape = Shape<_1, _1, _1>; + static constexpr int SharedStorageSize = 0; + + static const int MinBlocksPerMultiprocessor = 1; + static const int MaxThreadsPerBlock = 128; + using ArchTag = cutlass::arch::Sm90; + + static size_t get_workspace_size(Arguments const& args) { return 0; } + static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return cutlass::Status::kSuccess; + } + + static const int kBlockQ = 16; + + static const int kNumThreadsD = 8; + static const int kNumThreadsQ = MaxThreadsPerBlock / kNumThreadsD; + static const int kElementsPerLoad = 2; + + static const int kIterationsQ = kBlockQ / kNumThreadsQ; + + static bool can_implement(Arguments const& args) { + return get<4>(args.problem_size) % kElementsPerLoad == 0; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(ceil_div(size<2>(params.problem_size), kBlockQ), size<1>(params.problem_size), size<0>(params.problem_size)); + return grid; + } + + static dim3 get_block_shape() { + dim3 block(kNumThreadsD, kNumThreadsQ, 1); + return block; + } + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return args; + } + + CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { + auto ptr_O_bh = params.ptr_O + blockIdx.y * get<1>(params.stride_O) + blockIdx.z * get<0>(params.stride_O); + auto ptr_dO_bh = params.ptr_dO + blockIdx.y * get<1>(params.stride_dO) + blockIdx.z * get<0>(params.stride_dO); + auto ptr_sum_OdO_bh = params.ptr_sum_OdO + blockIdx.y * get<1>(params.stride_sum_OdO) + blockIdx.z * get<0>(params.stride_sum_OdO); + + CUTLASS_PRAGMA_UNROLL + for (int idx_q_t = threadIdx.y; idx_q_t < kBlockQ; idx_q_t += kNumThreadsQ) { + int idx_q = idx_q_t + kBlockQ * blockIdx.x; + if (idx_q >= get<2>(params.problem_size)) continue; + ElementAccumulator acc = 0; + auto ptr_O_bhq = ptr_O_bh + idx_q * get<2>(params.stride_O); + auto ptr_dO_bhq = ptr_dO_bh + idx_q * get<2>(params.stride_dO); + auto ptr_sum_OdO_bhq = ptr_sum_OdO_bh + idx_q * get<2>(params.stride_sum_OdO); + + for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<4>(params.problem_size); idx_d += kElementsPerLoad * kNumThreadsD) { + Element value_O[kElementsPerLoad]; + Element value_dO[kElementsPerLoad]; + + using Vec = uint_bit_t * kElementsPerLoad>; + *reinterpret_cast(value_O) = *reinterpret_cast(&ptr_O_bhq[idx_d]); + *reinterpret_cast(value_dO) = *reinterpret_cast(&ptr_dO_bhq[idx_d]); + + for (int v = 0; v < kElementsPerLoad; v++) { + acc += value_O[v] * value_dO[v]; + } + } + + for (int i = 1; i < kNumThreadsD; i *= 2) { + acc += __shfl_xor_sync((uint32_t)-1, acc, i, kNumThreadsD); + } + + if (threadIdx.x == 0) { + *ptr_sum_OdO_bhq = acc; + } + } + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_kernel_tma.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_kernel_tma.hpp new file mode 100644 index 0000000000000000000000000000000000000000..eeee7fa286a7d1039ff0518f4446a9edd3055521 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_kernel_tma.hpp @@ -0,0 +1,226 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/arch/arch.h" + +#include "../kernel/fmha_tile_scheduler.hpp" +#include "../kernel/fmha_options.hpp" + +namespace cutlass::fmha::kernel { + +template< + class CollectiveMainloop, + class CollectiveEpilogue, + class... Options +> +struct FmhaKernelTma { + + // Options + static constexpr int kBlocksPerSM = find_option_t, Options...>::value; + + using Element = typename CollectiveMainloop::Element; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + + using TileScheduler = IndividualTileScheduler; + + using StagesQ = typename CollectiveMainloop::StagesQ; + using Stages = typename CollectiveMainloop::Stages; + + using TileShape = typename CollectiveMainloop::TileShape; + using ClusterShape = typename CollectiveMainloop::ClusterShape; + + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + using MainloopPipelineQ = typename CollectiveMainloop::MainloopPipelineQ; + + using SmemLayoutQ = typename CollectiveMainloop::SmemLayoutQ; + using SmemLayoutK = typename CollectiveMainloop::SmemLayoutK; + + struct SharedStorage { + union { + typename CollectiveMainloop::SharedStorage mainloop; + typename CollectiveEpilogue::TensorStorage epilogue; + }; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + using PipelineStorageQ = typename MainloopPipelineQ::SharedStorage; + alignas(16) PipelineStorage pipeline_storage; + alignas(16) PipelineStorageQ pipeline_storage_q; + + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + alignas(16) EpiLoadPipelineStorage epi_load; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + using ProblemShape = cute::tuple; + + struct Arguments { + ProblemShape problem_size; + typename CollectiveMainloop::Arguments mainloop; + typename CollectiveEpilogue::Arguments epilogue; + KernelHardwareInfo hw_info; + }; + + struct Params { + ProblemShape problem_size; + typename CollectiveMainloop::Params mainloop; + typename CollectiveEpilogue::Params epilogue; + typename TileScheduler::Params tile_scheduler; + }; + + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename cutlass::PipelineState; + using PipelineParamsQ = typename MainloopPipelineQ::Params; + using PipelineStateQ = typename cutlass::PipelineState; + + static const int MinBlocksPerMultiprocessor = kBlocksPerSM; + static const int MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock; + using ArchTag = cutlass::arch::Sm90; + + static size_t get_workspace_size(Arguments const& args) { return 0; } + static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return cutlass::Status::kSuccess; + } + + static bool can_implement(Arguments const& args) { + return CollectiveMainloop::can_implement(args.problem_size, args.mainloop); + } + + static dim3 get_grid_shape(Params const& params) { + return TileScheduler::get_grid_shape(params.tile_scheduler); + } + + static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return Params{ + args.problem_size, + CollectiveMainloop::to_underlying_arguments(args.problem_size, args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_size, args.epilogue, workspace), + TileScheduler::to_underlying_arguments(args.problem_size, args.hw_info, ClusterShape{}, TileShape{}) + }; + } + + CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { +#if ! defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) + printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); +#else + TileScheduler tile_scheduler{params.tile_scheduler}; + + // Shared memory. + auto& storage = *reinterpret_cast(smem); + + int thread_idx = int(threadIdx.x); + + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + int warp_group_thread_idx = thread_idx % cutlass::NumThreadsPerWarpGroup; + int lane_predicate = cute::elect_one_sync(); + + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_idx == 0) && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + } + + + PipelineParamsQ pipeline_params_q; + pipeline_params_q.transaction_bytes = size(SmemLayoutQ{}(_,_,_0{})) * sizeof(Element); // Q + pipeline_params_q.role = MainloopPipelineQ::ThreadCategory::ProducerConsumer; + pipeline_params_q.is_leader = warp_group_thread_idx == 0; + pipeline_params_q.num_consumers = cutlass::NumThreadsPerWarpGroup; + + PipelineParams pipeline_params; + pipeline_params.transaction_bytes = size(SmemLayoutK{}(_,_,_0{})) * sizeof(Element); // KV + pipeline_params.role = MainloopPipeline::ThreadCategory::ProducerConsumer; + pipeline_params.is_leader = warp_group_thread_idx == 0; + pipeline_params.num_consumers = cutlass::NumThreadsPerWarpGroup; + + MainloopPipelineQ pipeline_q(storage.pipeline_storage_q, pipeline_params_q, Shape<_1, _1, _1>{}); + MainloopPipeline pipeline(storage.pipeline_storage, pipeline_params, ClusterShape{}); + + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + typename EpiLoadPipeline::Params epi_load_pipeline_params; + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::ProducerConsumer; + epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); + epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; + epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup; + epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + EpiLoadPipeline epi_load_pipeline(storage.epi_load, epi_load_pipeline_params); + + // State variables used for iterating the circular buffer + // smem_pipe_read / release is used by the consumer of SMEM data - i.e MMA + // smem_pipe_write is used by the producer of SMEM data - i.e TMA + PipelineState smem_pipe_read; + PipelineState smem_pipe_write = cutlass::make_producer_start_state(); + + PipelineStateQ smem_pipe_read_q; + PipelineStateQ smem_pipe_write_q = cutlass::make_producer_start_state(); + + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer blocks in the Cluster + // and to finish smem init + if constexpr (size(ClusterShape{}) > 1) { + cute::cluster_arrive_relaxed(); + cute::cluster_wait(); + } + else { + __syncthreads(); + } + + auto blk_coord = tile_scheduler.get_block_coord(); + + CollectiveMainloop collective_mainloop; + auto result = collective_mainloop.compute( + block_rank_in_cluster, + blk_coord, params.mainloop, params.problem_size, + pipeline, smem_pipe_read, smem_pipe_write, + pipeline_q, smem_pipe_read_q, smem_pipe_write_q, + storage.mainloop + ); + + CollectiveEpilogue epilogue; + epilogue(typename CollectiveMainloop::TileShapePV{}, blk_coord, + result, typename CollectiveMainloop::TiledMmaPV{}, + params.problem_size, params.epilogue, + epi_load_pipeline, storage.epilogue); +#endif + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_kernel_tma_warpspecialized.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_kernel_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4b96d71135f7396af3f9295f8fcaa3d4528fc381 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_kernel_tma_warpspecialized.hpp @@ -0,0 +1,422 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/arch/arch.h" + +#include "../kernel/fmha_options.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; + +template< + class CollectiveMainloop, + class CollectiveEpilogue, + class TileScheduler, + class... Options +> +struct FmhaKernelTmaWarpSpecialized { + + // Options + static constexpr bool kIsEpilogueLocked = find_option_t::value; + static constexpr bool kLoadsQSeparately = find_option_t::value; + + + static const int NumLoadWarpGroups = 1; + static constexpr int NumMmaWarpGroups = CollectiveMainloop::NumMmaWarpGroups; + + using TileShape = typename CollectiveMainloop::TileShape; + using ClusterShape = typename CollectiveMainloop::ClusterShape; + + using MainloopPipelineOuter = typename CollectiveMainloop::MainloopPipelineQ; + using MainloopPipelineInner = typename CollectiveMainloop::MainloopPipeline; + using MainloopPipelineReducer = cutlass::PipelineAsync<2>; + + static constexpr uint32_t StagesPerMathWarpGroup = 2; + using MathWarpGroupOrderBarrier = cutlass::OrderedSequenceBarrier< + StagesPerMathWarpGroup, NumMmaWarpGroups>; + + struct TensorStorageStruct { + typename CollectiveMainloop::SharedStorage mainloop; + typename CollectiveEpilogue::TensorStorage epilogue[NumMmaWarpGroups]; + }; + union TensorStorageUnion { + typename CollectiveMainloop::SharedStorage mainloop; + typename CollectiveEpilogue::TensorStorage epilogue[NumMmaWarpGroups]; + }; + using TensorStorage = std::conditional_t; + + struct SharedStorage { + TensorStorage tensors; + + using PipelineStorageInner = typename MainloopPipelineInner::SharedStorage; + using PipelineStorageOuter = typename MainloopPipelineOuter::SharedStorage; + using PipelineStorageReducer = typename MainloopPipelineReducer::SharedStorage; + + alignas(16) PipelineStorageInner pipeline_storage_inner; + alignas(16) PipelineStorageOuter pipeline_storage_outer; + alignas(16) PipelineStorageReducer pipeline_storage_reducer; + + using MathWarpGroupOrderBarrierStorage = typename MathWarpGroupOrderBarrier::SharedStorage; + alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order; + + alignas(16) cutlass::arch::ClusterBarrier load_warp_barrier; + + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + alignas(16) EpiLoadPipelineStorage epi_load; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + using ProblemShape = cute::tuple; + + struct Arguments { + ProblemShape problem_size; + typename CollectiveMainloop::Arguments mainloop; + typename CollectiveEpilogue::Arguments epilogue; + KernelHardwareInfo hw_info; + }; + + struct Params { + ProblemShape problem_size; + typename CollectiveMainloop::Params mainloop; + typename CollectiveEpilogue::Params epilogue; + typename TileScheduler::Params tile_scheduler; + }; + + using PipelineParamsInner = typename MainloopPipelineInner::Params; + using PipelineStateInner = typename cutlass::PipelineState; + using PipelineParamsOuter = typename MainloopPipelineOuter::Params; + using PipelineStateOuter = typename cutlass::PipelineState; + using PipelineParamsReducer = typename MainloopPipelineReducer::Params; + using PipelineStateReducer = typename cutlass::PipelineState; + + static const int MinBlocksPerMultiprocessor = 1; + static const int MaxThreadsPerBlock = (NumMmaWarpGroups + NumLoadWarpGroups) * cutlass::NumThreadsPerWarpGroup; + using ArchTag = cutlass::arch::Sm90; + + static constexpr uint32_t LoadRegisterRequirement = 40 - 2 * 8; + static constexpr uint32_t TotalRegisterSupply = (64*1024 / MaxThreadsPerBlock / MinBlocksPerMultiprocessor / 8) * 8 * MaxThreadsPerBlock / cutlass::NumThreadsPerWarpGroup; + static constexpr uint32_t MmaRegisterRequirement = ((TotalRegisterSupply - LoadRegisterRequirement) / NumMmaWarpGroups / 8) * 8; + + static size_t get_workspace_size(Arguments const& args) { return 0; } + static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return cutlass::Status::kSuccess; + } + + static bool can_implement(Arguments const& args) { + return CollectiveMainloop::can_implement(args.problem_size, args.mainloop); + } + + static dim3 get_grid_shape(Params const& params) { + return TileScheduler::get_grid_shape(params.tile_scheduler); + } + + static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return Params{ + args.problem_size, + CollectiveMainloop::to_underlying_arguments(args.problem_size, args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_size, args.epilogue, workspace), + TileScheduler::to_underlying_arguments(args.problem_size, args.hw_info, ClusterShape{}, TileShape{}) + }; + } + + CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { +#if ! defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) + printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); +#else + + enum class WarpGroupRole { + Producer = 0, + Consumer0 = 1, + Consumer1 = 2, + Consumer2 = 3, + Consumer3 = 4, + }; + enum class ProducerWarpRole { + LoadKV = 1, + Reducer = 0, + MaybeLoadQ = 2, // is kLoadsQSeparately is true, this warp loads Q (otherwise warp 0 does it) + MainloopEpilogue = 3, + }; + + static constexpr ProducerWarpRole WarpRoleLoadQ = kLoadsQSeparately ? ProducerWarpRole::MaybeLoadQ : ProducerWarpRole::LoadKV; + + TileScheduler tile_scheduler{params.tile_scheduler}; + + // Shared memory. + auto& storage = *reinterpret_cast(smem); + + int lane_idx = cutlass::canonical_lane_idx(); + int warp_idx = cutlass::canonical_warp_idx_sync(); + int warp_idx_in_warp_group = warp_idx % cutlass::NumWarpsPerWarpGroup; + int warp_group_idx = cutlass::canonical_warp_group_idx(); + auto warp_group_role = WarpGroupRole(warp_group_idx); + auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); + int consumer_warp_group_idx = warp_group_idx - (int) WarpGroupRole::Consumer0; + int lane_predicate = cute::elect_one_sync(); + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_idx == 0) && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + } + + PipelineParamsOuter pipeline_params_outer; + pipeline_params_outer.transaction_bytes = CollectiveMainloop::kOuterLoadBytes; + pipeline_params_outer.is_leader = lane_predicate && (producer_warp_role == WarpRoleLoadQ); + pipeline_params_outer.num_consumers = cutlass::NumThreadsPerWarpGroup; + + PipelineParamsInner pipeline_params_inner; + pipeline_params_inner.transaction_bytes = CollectiveMainloop::kInnerLoadBytes; + pipeline_params_inner.is_leader = lane_predicate && (producer_warp_role == ProducerWarpRole::LoadKV); + pipeline_params_inner.num_consumers = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup; + + PipelineParamsReducer pipeline_params_reducer; + pipeline_params_reducer.producer_arv_count = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup; + pipeline_params_reducer.consumer_arv_count = cutlass::NumThreadsPerWarp; + + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + typename EpiLoadPipeline::Params epi_load_pipeline_params; + + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::MainloopEpilogue) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::LoadKV) { + pipeline_params_inner.role = MainloopPipelineInner::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == WarpRoleLoadQ) { + pipeline_params_outer.role = MainloopPipelineOuter::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Reducer) { + pipeline_params_reducer.role = MainloopPipelineReducer::ThreadCategory::Consumer; + } + if (warp_group_role == WarpGroupRole::Consumer0 || + warp_group_role == WarpGroupRole::Consumer1 || + warp_group_role == WarpGroupRole::Consumer2 || + warp_group_role == WarpGroupRole::Consumer3 + ) { + pipeline_params_inner.role = MainloopPipelineInner::ThreadCategory::Consumer; + pipeline_params_outer.role = MainloopPipelineOuter::ThreadCategory::Consumer; + pipeline_params_reducer.role = MainloopPipelineReducer::ThreadCategory::Producer; + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + + MainloopPipelineOuter pipeline_outer(storage.pipeline_storage_outer, pipeline_params_outer, Shape<_1, _1, _1>{}); + MainloopPipelineInner pipeline_inner(storage.pipeline_storage_inner, pipeline_params_inner, ClusterShape{}); + MainloopPipelineReducer pipeline_reducer(storage.pipeline_storage_reducer, pipeline_params_reducer); + + // State variables used for iterating the circular buffer + // smem_pipe_read / release is used by the consumer of SMEM data - i.e MMA + // smem_pipe_write is used by the producer of SMEM data - i.e TMA + PipelineStateInner smem_pipe_read_inner; + PipelineStateInner smem_pipe_write_inner = cutlass::make_producer_start_state(); + + PipelineStateOuter smem_pipe_read_outer; + PipelineStateOuter smem_pipe_write_outer = cutlass::make_producer_start_state(); + + PipelineStateReducer smem_pipe_read_reducer; + PipelineStateReducer smem_pipe_write_reducer = cutlass::make_producer_start_state(); + + typename MathWarpGroupOrderBarrier::Params params_math_wg_order_barrier; + // DMA Load WG will not participate in these Ordered Barrier syncs + params_math_wg_order_barrier.group_id = consumer_warp_group_idx; + params_math_wg_order_barrier.group_size = cutlass::NumThreadsPerWarpGroup; // Number of threads / participants in a group + MathWarpGroupOrderBarrier math_wg_order_barrier(storage.math_wg_order, params_math_wg_order_barrier); + + // Epilogue Load pipeline + epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); + epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; + epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup; + epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + EpiLoadPipeline epi_load_pipeline(storage.epi_load, epi_load_pipeline_params); + + if constexpr (kLoadsQSeparately) { + if ((warp_idx == 0) && lane_predicate) { + storage.load_warp_barrier.init(2 * cutlass::NumThreadsPerWarp); + } + cutlass::arch::fence_barrier_init(); + } + + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer blocks in the Cluster + // and to finish smem init + if constexpr (size(ClusterShape{}) > 1) { + cute::cluster_arrive_relaxed(); + cute::cluster_wait(); + } + else { + __syncthreads(); + } + + CollectiveMainloop collective_mainloop; + + if (warp_group_role == WarpGroupRole::Producer) { + cutlass::arch::warpgroup_reg_dealloc(); + if (producer_warp_role == ProducerWarpRole::LoadKV) { + bool do_barrier = kLoadsQSeparately; + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + collective_mainloop.template load_kv_maybe_q( + block_rank_in_cluster, + blk_coord, params.mainloop, params.problem_size, + pipeline_inner, smem_pipe_write_inner, + pipeline_outer, smem_pipe_write_outer, + storage.tensors.mainloop, + storage.load_warp_barrier, do_barrier + ); + do_barrier = false; + } + } + else if (kLoadsQSeparately && (producer_warp_role == ProducerWarpRole::MaybeLoadQ)) { + bool do_barrier = true; + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + collective_mainloop.load_maybe_q( + blk_coord, params.mainloop, params.problem_size, + pipeline_outer, smem_pipe_write_outer, + storage.tensors.mainloop, + storage.load_warp_barrier, do_barrier + ); + do_barrier = false; + } + } else if (producer_warp_role == ProducerWarpRole::Reducer) { + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + collective_mainloop.reduce( + blk_coord, params.mainloop, params.problem_size, + pipeline_reducer, smem_pipe_read_reducer, + storage.tensors.mainloop + ); + } + } + } + else if ( + warp_group_role == WarpGroupRole::Consumer0 || + warp_group_role == WarpGroupRole::Consumer1 || + warp_group_role == WarpGroupRole::Consumer2 || + warp_group_role == WarpGroupRole::Consumer3 + ) { + cutlass::arch::warpgroup_reg_alloc(); + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto wg_coord = blk_coord; + + constexpr int kOuterLoads = CollectiveMainloop::kOuterLoads; + + if (warp_group_role == WarpGroupRole::Consumer0) { + smem_pipe_read_outer.advance(0 * kOuterLoads); + } + else if (warp_group_role == WarpGroupRole::Consumer1) { + smem_pipe_read_outer.advance(1 * kOuterLoads); + } + else if (warp_group_role == WarpGroupRole::Consumer2) { + smem_pipe_read_outer.advance(2 * kOuterLoads); + } + else if (warp_group_role == WarpGroupRole::Consumer3) { + smem_pipe_read_outer.advance(3 * kOuterLoads); + } + + constexpr int wg_dim = is_constant<0, decltype(get<1>(wg_coord))>::value ? 0 : 1; + auto& wg_block = get(wg_coord); + if (warp_group_role == WarpGroupRole::Consumer0) { + wg_block = NumMmaWarpGroups * wg_block + 0; + } + else if (warp_group_role == WarpGroupRole::Consumer1) { + wg_block = NumMmaWarpGroups * wg_block + 1; + } + else if (warp_group_role == WarpGroupRole::Consumer2) { + wg_block = NumMmaWarpGroups * wg_block + 2; + } + else if (warp_group_role == WarpGroupRole::Consumer3) { + wg_block = NumMmaWarpGroups * wg_block + 3; + } + + auto result = collective_mainloop.compute( + blk_coord, wg_coord, + params.mainloop, params.problem_size, + pipeline_inner, smem_pipe_read_inner, + pipeline_outer, smem_pipe_read_outer, + pipeline_reducer, smem_pipe_write_reducer, + storage.tensors.mainloop, + math_wg_order_barrier + ); + + if (warp_group_role == WarpGroupRole::Consumer0) { + smem_pipe_read_outer.advance(kOuterLoads * (NumMmaWarpGroups - 0)); + } + if constexpr (NumMmaWarpGroups >= 2) { + if (warp_group_role == WarpGroupRole::Consumer1) { + smem_pipe_read_outer.advance(kOuterLoads * (NumMmaWarpGroups - 1)); + } + } + if constexpr (NumMmaWarpGroups >= 3) { + if (warp_group_role == WarpGroupRole::Consumer2) { + smem_pipe_read_outer.advance(kOuterLoads * (NumMmaWarpGroups - 2)); + } + } + if constexpr (NumMmaWarpGroups >= 4) { + if (warp_group_role == WarpGroupRole::Consumer3) { + smem_pipe_read_outer.advance(kOuterLoads * (NumMmaWarpGroups - 3)); + } + } + + if constexpr (kIsEpilogueLocked) ; math_wg_order_barrier.wait(); + + CollectiveEpilogue epilogue; + epilogue(typename CollectiveMainloop::TileShapePV{}, wg_coord, + result, typename CollectiveMainloop::TiledMmaPV{}, + params.problem_size, params.epilogue, + epi_load_pipeline, storage.tensors.epilogue[consumer_warp_group_idx]); + + if constexpr (kIsEpilogueLocked) ; math_wg_order_barrier.arrive(); + } + } +#endif + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_options.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_options.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6746bf797c5905cbc005c787852fab6a0c7f1a09 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_options.hpp @@ -0,0 +1,83 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" + +namespace cutlass::fmha::kernel { + +template +struct find_option; + +template +struct find_option { + using option_value = Default; +}; + +template +struct find_option : + std::conditional_t< + Option::tag == kTag, + Option, + find_option + > +{}; + +template +using find_option_t = typename find_option::option_value; + +enum class Tag { + kIsPersistent, + kNumMmaWarpGroups, + kLoadsQSeparately, + + kIsMainloopLocked, + kIsEpilogueLocked, + + kStagesQ, + kStagesKV, + + kEpilogueKind, + + kBlocksPerSM, + kClusterM, + + kAccQK +}; + +template +struct Option { + static constexpr auto tag = kTag; + using option_value = Value; +}; + +} // namespace cutlass::fmha::kernel diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_tile_scheduler.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_tile_scheduler.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0b46e27e142973e675a9824a4eec4e784f1615d7 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_tile_scheduler.hpp @@ -0,0 +1,204 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.h" + +namespace cutlass::fmha::kernel { + +//////////////////////////////////////////////////////////////////////////////// + +struct IndividualTileScheduler { + + struct Params { + dim3 grid; + }; + + bool valid_ = true; + + CUTLASS_DEVICE + IndividualTileScheduler(Params const&) {} + + template + static Params to_underlying_arguments( + ProblemSize const& problem_size, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, TileShape const& tile_shape) + { + using namespace cute; + dim3 grid(round_up(ceil_div(size<2>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)), size<0>(problem_size), size<1>(problem_size)); + return Params{ grid }; + } + + static dim3 get_grid_shape(Params const& params) { + return params.grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return valid_; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + return make_coord(blockIdx.x, _0{}, make_coord(blockIdx.y, blockIdx.z)); + } + + CUTLASS_DEVICE + IndividualTileScheduler& operator++() { + valid_ = false; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct PersistentTileScheduler { + + struct Params { + int num_blocks; + FastDivmod divmod_m_block; + FastDivmod divmod_b; + FastDivmod divmod_h; + + KernelHardwareInfo hw_info; + }; + + int block_idx = 0; + Params params; + + CUTLASS_DEVICE + PersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {} + + template + static Params to_underlying_arguments( + ProblemSize const& problem_size, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, TileShape const& tile_shape) + { + using namespace cute; + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + hw_info.sm_count = sm_count; + + int num_m_blocks = cutlass::round_up(ceil_div(size<2>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)); + int num_blocks = num_m_blocks * size<0>(problem_size) * size<1>(problem_size); + + return Params { + num_blocks, + { num_m_blocks}, { size<0>(problem_size) }, { size<1>(problem_size) }, + hw_info + }; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1); + return grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return block_idx < params.num_blocks; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + int block_decode = block_idx; + int m_block, bidb, bidh; + params.divmod_m_block(block_decode, m_block, block_decode); + params.divmod_b(block_decode, bidb, block_decode); + params.divmod_h(block_decode, bidh, block_decode); + return make_coord(m_block, _0{}, make_coord(bidb, bidh)); + } + + CUTLASS_DEVICE + PersistentTileScheduler& operator++() { + block_idx += gridDim.x; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +template +struct TileSchedulerBwdAdapter { + + using Params = typename Base::Params; + + Base base_; + + CUTLASS_DEVICE + TileSchedulerBwdAdapter(Params const& params) : base_(params) {} + + template + static Params to_underlying_arguments( + ProblemSize const& problem_size, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, TileShape const& tile_shape) + { + using namespace cute; + return Base::to_underlying_arguments(select<0,1,3,2,4>(problem_size), hw_info, select<1,0,2>(cluster_shape), select<1,0,2>(tile_shape)); + } + + static dim3 get_grid_shape(Params const& params) { + return Base::get_grid_shape(params); + } + + CUTLASS_DEVICE + bool is_valid() { + return base_.is_valid(); + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + return select<1,0,2>(base_.get_block_coord()); + } + + CUTLASS_DEVICE + TileSchedulerBwdAdapter& operator++() { + ++base_; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::kernel diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/reference/fmha_bwd_reference.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/reference/fmha_bwd_reference.hpp new file mode 100644 index 0000000000000000000000000000000000000000..503fb3aad4976fb02d80312a3ad7e53300c80e6c --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/reference/fmha_bwd_reference.hpp @@ -0,0 +1,357 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cute/tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + class TensorDQ, /* class TensorDK, class TensorDV, */ + class Fusion +> +void __global__ fmha_bwd_reference_dQ_kernel( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + TensorDQ mDQ, /* TensorDK mDK, TensorDV mDV, */ + Fusion fusion +) { + using namespace cute; + + using Element = typename TensorO::value_type; + using ElementAccumulator = typename TensorLSE::value_type; + + extern __shared__ char mS_mem[]; + Element* mS = reinterpret_cast(mS_mem); + + Element softmax_scale = static_cast(1.0 / sqrt(1.0 * size<1>(mO))); + + for (int idx_L = blockIdx.y; idx_L < size<2>(mDQ); idx_L += gridDim.y) { + for (int idx_Q = blockIdx.x; idx_Q < size<0>(mDQ); idx_Q += gridDim.x) { + + for (int idx_K = threadIdx.x; idx_K < size<0>(mK); idx_K += blockDim.x) { + ElementAccumulator acc_qk = 0; + ElementAccumulator acc_dov = 0; + ElementAccumulator acc_doo = 0; + for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) { + acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L); + acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L); + acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L); + } + + auto id = make_identity_tensor(make_shape(1, 1)); + auto frag = make_tensor(Shape<_1, _1>{}); + frag(0) = acc_qk; + fusion.before_softmax(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape); + acc_qk = frag(0); + + mS[idx_K] = static_cast(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo)); + } + + __syncthreads(); + + for (int idx_D = threadIdx.x; idx_D < size<1>(mDQ); idx_D += blockDim.x) { + ElementAccumulator acc = 0; + for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) { + acc += mS[idx_K] * mK(idx_K, idx_D, idx_L); + } + mDQ(idx_Q, idx_D, idx_L) = acc; + } + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + /* class TensorDQ, */ class TensorDK, /* class TensorDV, */ + class Fusion +> +void __global__ fmha_bwd_reference_dK_kernel( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + /* TensorDQ mDQ, */ TensorDK mDK, /* TensorDV mDV, */ + Fusion fusion +) { + using namespace cute; + + using Element = typename TensorO::value_type; + using ElementAccumulator = typename TensorLSE::value_type; + + extern __shared__ char mS_mem[]; + Element* mS = reinterpret_cast(mS_mem); + + Element softmax_scale = static_cast(1.0 / sqrt(1.0 * size<1>(mO))); + + for (int idx_L = blockIdx.y; idx_L < size<2>(mDK); idx_L += gridDim.y) { + for (int idx_K = blockIdx.x; idx_K < size<0>(mDK); idx_K += gridDim.x) { + + for (int idx_Q = threadIdx.x; idx_Q < size<0>(mDO); idx_Q += blockDim.x) { + ElementAccumulator acc_qk = 0; + ElementAccumulator acc_dov = 0; + ElementAccumulator acc_doo = 0; + for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) { + acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L); + acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L); + acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L); + } + + auto id = make_identity_tensor(make_shape(1, 1)); + auto frag = make_tensor(Shape<_1, _1>{}); + frag(0) = acc_qk; + fusion.before_softmax(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape); + acc_qk = frag(0); + + mS[idx_Q] = static_cast(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo)); + } + + __syncthreads(); + + for (int idx_D = threadIdx.x; idx_D < size<1>(mDK); idx_D += blockDim.x) { + ElementAccumulator acc = 0; + for (int idx_Q = 0; idx_Q < size<0>(mDO); idx_Q++) { + acc += mS[idx_Q] * mQ(idx_Q, idx_D, idx_L); + } + mDK(idx_K, idx_D, idx_L) = acc; + } + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + /* class TensorDQ, class TensorDK, */ class TensorDV, + class Fusion +> +void __global__ fmha_bwd_reference_dV_kernel( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + /* TensorDQ mDQ, TensorDK mDK, */ TensorDV mDV, + Fusion fusion +) { + using namespace cute; + + using Element = typename TensorO::value_type; + using ElementAccumulator = typename TensorLSE::value_type; + + extern __shared__ char mS_mem[]; + Element* mS = reinterpret_cast(mS_mem); + + Element softmax_scale = static_cast(1.0 / sqrt(1.0 * size<1>(mO))); + + for (int idx_L = blockIdx.y; idx_L < size<2>(mDV); idx_L += gridDim.y) { + for (int idx_K = blockIdx.x; idx_K < size<0>(mDV); idx_K += gridDim.x) { + + for (int idx_Q = threadIdx.x; idx_Q < size<0>(mDO); idx_Q += blockDim.x) { + ElementAccumulator acc_qk = 0; + for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) { + acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L); + } + + auto id = make_identity_tensor(make_shape(1, 1)); + auto frag = make_tensor(Shape<_1, _1>{}); + frag(0) = acc_qk; + fusion.before_softmax(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape); + acc_qk = frag(0); + + mS[idx_Q] = static_cast(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L))); + } + + __syncthreads(); + + for (int idx_D = threadIdx.x; idx_D < size<1>(mDV); idx_D += blockDim.x) { + ElementAccumulator acc = 0; + for (int idx_Q = 0; idx_Q < size<0>(mDO); idx_Q++) { + acc += mS[idx_Q] * mDO(idx_Q, idx_D, idx_L); + } + mDV(idx_K, idx_D, idx_L) = acc; + } + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + /**/ class TensorDQ, /** / class TensorDK, / ** / class TensorDV, / **/ + class Fusion +> +void fmha_bwd_reference_dQ( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + /**/ TensorDQ mDQ, /** / TensorDK mDK, / ** / TensorDV mDV, / **/ + Fusion fusion +) { + using namespace cute; + + dim3 grid(size<0>(mDQ), size<2>(mDQ), 1); + dim3 block(256); + int shared_mem = size<0>(mK) * sizeof(typename TensorO::value_type); + + if (shared_mem >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << shared_mem); + auto result = cudaFuncSetAttribute( + fmha_bwd_reference_dQ_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + shared_mem); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return; + } + } + + fmha_bwd_reference_dQ_kernel<<>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, fusion); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + /** / class TensorDQ, / **/ class TensorDK, /** / class TensorDV, / **/ + class Fusion +> +void fmha_bwd_reference_dK( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + /** / TensorDQ mDQ, / **/ TensorDK mDK, /** / TensorDV mDV, / **/ + Fusion fusion +) { + using namespace cute; + + dim3 grid(size<0>(mDK), size<2>(mDK), 1); + dim3 block(256); + int shared_mem = size<0>(mDO) * sizeof(typename TensorO::value_type); + + if (shared_mem >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << shared_mem); + auto result = cudaFuncSetAttribute( + fmha_bwd_reference_dK_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + shared_mem); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return; + } + } + + fmha_bwd_reference_dK_kernel<<>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDK, fusion); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + /** / class TensorDQ, / ** / class TensorDK, / **/ class TensorDV, /**/ + class Fusion +> +void fmha_bwd_reference_dV( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + /** / TensorDQ mDQ, / ** / TensorDK mDK, / **/ TensorDV mDV, /**/ + Fusion fusion +) { + using namespace cute; + + dim3 grid(size<0>(mDV), size<2>(mDV), 1); + dim3 block(256); + int shared_mem = size<0>(mDO) * sizeof(typename TensorO::value_type); + + if (shared_mem >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << shared_mem); + auto result = cudaFuncSetAttribute( + fmha_bwd_reference_dV_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + shared_mem); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return; + } + } + + fmha_bwd_reference_dV_kernel<<>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDV, fusion); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + class TensorDQ, class TensorDK, class TensorDV, + class Fusion +> +void fmha_bwd_reference( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + TensorDQ mDQ, TensorDK mDK, TensorDV mDV, + Fusion fusion +) { + fmha_bwd_reference_dQ(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, fusion); + fmha_bwd_reference_dK(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDK, fusion); + fmha_bwd_reference_dV(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDV, fusion); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/reference/fmha_reference.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/reference/fmha_reference.hpp new file mode 100644 index 0000000000000000000000000000000000000000..94f443713d6ccbbd05e36d0ace33f8534fd8db22 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/reference/fmha_reference.hpp @@ -0,0 +1,156 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, + class TensorK, + class TensorV, + class TensorO, + class TensorLSE, + class Fusion +> +void __global__ fmha_reference_kernel( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, + Fusion fusion +) { + using namespace cute; + + using Element = typename TensorO::value_type; + using ElementAccumulator = typename TensorLSE::value_type; + + extern __shared__ char mS_mem[]; + Element* mS = reinterpret_cast(mS_mem); + + ElementAccumulator softmax_scale = static_cast(1.0 / sqrt(1.0 * size<1>(mO))); + + auto id = make_identity_tensor(make_shape(1, 1)); + for (int idx_L = blockIdx.y; idx_L < size<2>(mO); idx_L += gridDim.y) { + for (int idx_Q = blockIdx.x; idx_Q < size<0>(mO); idx_Q += gridDim.x) { + for (int idx_K = threadIdx.x; idx_K < size<0>(mK); idx_K += blockDim.x) { + ElementAccumulator acc = 0; + for (int idx_D = 0; idx_D < size<1>(mK); idx_D++) { + acc += mQ(idx_Q, idx_D, idx_L) * mK(idx_K, idx_D, idx_L); + } + auto frag = make_tensor(Shape<_1, _1>{}); + frag(0) = acc; + fusion.before_softmax(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape); + mS[idx_K] = static_cast(frag(0) * softmax_scale); + } + + __syncthreads(); + + ElementAccumulator maxS = -std::numeric_limits::infinity(); + for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) { + maxS = std::max(maxS, mS[idx_K]); + } + if (maxS == -std::numeric_limits::infinity()) maxS = 0; + + __syncthreads(); + + for (int idx_K = threadIdx.x; idx_K < size<0>(mK); idx_K += blockDim.x) { + mS[idx_K] = static_cast(exp(mS[idx_K] - maxS)); + } + + __syncthreads(); + + ElementAccumulator sum = 0; + for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) { + sum += mS[idx_K]; + } + + Element scale = static_cast(1.0 / sum); + + for (int idx_D = threadIdx.x; idx_D < size<1>(mO); idx_D += blockDim.x) { + ElementAccumulator acc = 0; + for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) { + acc += mS[idx_K] * mV(idx_K, idx_D, idx_L) * scale; + } + mO(idx_Q, idx_D, idx_L) = static_cast(acc); + } + + if (threadIdx.x == 0) { + mLSE(idx_Q, idx_L) = log(sum) + maxS; + } + + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, + class TensorK, + class TensorV, + class TensorO, + class TensorLSE, + class Fusion +> +void fmha_reference( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, + Fusion fusion +) { + using namespace cute; + + dim3 grid(size<0>(mO), size<2>(mO), 1); + dim3 block(256); + int shared_mem = size<0>(mK) * sizeof(typename TensorO::value_type); + + if (shared_mem >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << shared_mem); + auto result = cudaFuncSetAttribute( + fmha_reference_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + shared_mem); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return; + } + } + + fmha_reference_kernel<<>>(problem_shape, mQ, mK, mV, mO, mLSE, fusion); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/reference/reference_abs_error.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/reference/reference_abs_error.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5eb2413afe1db1c5ff87d1bda32c9815c51d39e4 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/reference/reference_abs_error.hpp @@ -0,0 +1,129 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +#pragma once + +#include +#include "cutlass/util/device_memory.h" + +template +__global__ void reference_abs_diff_kernel( + Element* data, Element* data_ref, size_t count, + double* max_diff, double* sum_diff, + bool print_diff +) { + double thread_max_diff = 0; + double thread_sum_diff = 0; + + __shared__ double block_max_diff; + __shared__ double block_sum_diff; + + for (size_t i = threadIdx.x + blockIdx.x * blockDim.x; i < count; i += blockDim.x * gridDim.x) { + double diff = fabs(data[i] - data_ref[i]); + if (print_diff) if (diff != diff || diff > 0.01f) printf("difference at %lld: %f ... %f vs %f\n", static_cast(i), diff, (double)data[i], (double)data_ref[i]); + thread_max_diff = fmax(diff, thread_max_diff); + thread_sum_diff += diff; + } + + for (int i = 0; i < blockDim.x; i++) { + if (i == threadIdx.x) { + if (i == 0) { + block_max_diff = thread_max_diff; + block_sum_diff = thread_sum_diff; + } else { + block_max_diff = fmax(block_max_diff, thread_max_diff); + block_sum_diff += thread_sum_diff; + } + } + __syncthreads(); + } + + if (threadIdx.x == 0) { + atomicAdd(sum_diff, block_sum_diff); + + for (;;) { + unsigned long long prev = *reinterpret_cast(max_diff); + double prev_diff = reinterpret_cast(prev); + double new_max_diff = fmax(block_max_diff, prev_diff); + unsigned long long found = atomicCAS(reinterpret_cast(max_diff), prev, reinterpret_cast(new_max_diff)); + if (found == prev) break; + } + } +} + +template +void reference_abs_diff( + cutlass::DeviceAllocation const& data, + cutlass::DeviceAllocation const& data_ref, + double& max_diff, double& mean_diff +) { + static bool kPrintDiff = getenv("REF_PRINT_DIFF") && atoi(getenv("REF_PRINT_DIFF")) == 1; + + cutlass::DeviceAllocation result; + result.reset(2); + assert(data.size() == data_ref.size()); + + cudaError_t err = cudaMemset(result.get(), 0, result.size() * sizeof(double)); + if (err != cudaSuccess) { + std::cerr << "Memset failed. Last CUDA error: " + << cudaGetErrorString(err) << std::endl; + max_diff = mean_diff = 1e20; + return; + } + + dim3 block(256, 1, 1); + dim3 grid(1024, 1, 1); + reference_abs_diff_kernel<<>>( + data.get(), data_ref.get(), data.size(), + result.get(), result.get() + 1, kPrintDiff); + + err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + std::cerr << "Difference kernel failed. Last CUDA error: " + << cudaGetErrorString(err) << std::endl; + max_diff = mean_diff = 1e20; + return; + } + + double result_host[2]; + err = cudaMemcpy(result_host, result.get(), result.size() * sizeof(double), cudaMemcpyDefault); + if (err != cudaSuccess) { + std::cerr << "Copy failed. Last CUDA error: " + << cudaGetErrorString(err) << std::endl; + max_diff = mean_diff = 1e20; + return; + } + + max_diff = result_host[0]; + mean_diff = result_host[1] / static_cast(data.size()); +} diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/common/dist_gemm_helpers.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/common/dist_gemm_helpers.h new file mode 100644 index 0000000000000000000000000000000000000000..35f05442cdfe6d445f3a91f0891e27ff7cc4d330 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/common/dist_gemm_helpers.h @@ -0,0 +1,164 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Benchmark helpers for Distributed GEMM + + A delay kernel to gate all GEMMs across devices, controlled by a flag that + the host will set off once it launches DistGEMM across all devices. + + DistGpuTimer extends cutlass's existing cudaEvent-based timer to multiple devices. +*/ + +#pragma once +#include "cutlass/cutlass.h" +#include +#include +#include CUDA_STD_HEADER(atomic) + +#include "cute/layout.hpp" +#include "cute/tensor.hpp" +#include "cutlass/cuda_host_adapter.hpp" + + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Delay kernel +///////////////////////////////////////////////////////////////////////////////////////////////// + +using AtomicBoolean = cuda::atomic; + +__global__ void delay_kernel(const AtomicBoolean* atomic_flag_ptr) { + while (not atomic_flag_ptr->load()) { + __nanosleep(40); + } +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Distributed GPU Timer +/// Sets up cuda events for multiple processors. +///////////////////////////////////////////////////////////////////////////////////////////////// +template +struct DistGpuTimer { + int _primary_device; + cudaEvent_t _start[NP]; + cudaEvent_t _stop[NP]; + + /// Constructor + DistGpuTimer() + { + CUDA_CHECK(cudaGetDevice(&_primary_device)); + for (int device = 0; device < NP; ++device) { + CUDA_CHECK(cudaSetDevice(device)); + CUDA_CHECK(cudaEventCreate(&_start[device])); + CUDA_CHECK(cudaEventCreate(&_stop[device])); + } + CUDA_CHECK(cudaSetDevice(_primary_device)); + } + + /// Destructor + ~DistGpuTimer() + { + for (int device = 0; device < NP; ++device) { + CUDA_CHECK(cudaSetDevice(device)); + CUDA_CHECK(cudaEventDestroy(_start[device])); + CUDA_CHECK(cudaEventDestroy(_stop[device])); + } + CUDA_CHECK(cudaSetDevice(_primary_device)); + } + + /// Start the timer for a given stream (defaults to the default stream) + void start(int device, cudaStream_t stream) { + assert(device >= 0 && device < NP); + CUDA_CHECK(cudaEventRecord(_start[device], stream)); + } + + /// Stop the timer + void stop(int device, cudaStream_t stream) { + assert(device >= 0 && device < NP); + CUDA_CHECK(cudaEventRecord(_stop[device], stream)); + } + + /// Return the elapsed time (in milliseconds) + float elapsed_millis(int device) { + assert(device >= 0 && device < NP); + float elapsed = 0.0; + CUDA_CHECK(cudaEventSynchronize(_stop[device])); + CUDA_CHECK(cudaEventElapsedTime(&elapsed, _start[device], _stop[device])); + return elapsed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Generic device-to-device data movement kernel based for CuTe tensors. +/// +/// NOTE: this kernel assigns one element copy to every thread, and is by no means +/// an efficient way of copying tensors. It should only be used for convenience in +/// reference checks. +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +void device_copy(TensorSource tensor_source, + TensorDestination tensor_destination, + cudaStream_t stream); + + +template +__global__ void device_copy_kernel(TensorSource const tensor_source, + TensorDestination tensor_destination) { + auto linear_idx = blockIdx.x * blockDim.x + threadIdx.x; + using ElementSrc = typename TensorSource::value_type; + using ElementDst = typename TensorDestination::value_type; + NumericConverter converter; + if (linear_idx < size(tensor_source)) { + tensor_destination(linear_idx) = converter(tensor_source(linear_idx)); + } +} + +template +void device_copy(TensorSource tensor_source, + TensorDestination tensor_destination, + cudaStream_t stream) { + + assert(tensor_source.size() == tensor_destination.size()); + + auto numel = tensor_source.size(); + static constexpr int NumThreads = 128; + auto grid_size = cute::ceil_div(numel, NumThreads); + + dim3 grid(grid_size); + dim3 block(NumThreads); + device_copy_kernel<<>>(tensor_source, tensor_destination); +} + +} //namespace cutlass diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/common/gather_tensor.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/common/gather_tensor.hpp new file mode 100644 index 0000000000000000000000000000000000000000..46fb64009735e4459ac6927e999921474f1ccd2d --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/common/gather_tensor.hpp @@ -0,0 +1,215 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/layout.hpp" +#include "cute/tensor.hpp" +#include "cute/util/print.hpp" + +namespace example { + +using namespace cute; + +// Empty type used to disable gather/scatter for a GEMM argument +struct NoGather +{ + template + NoGather(Ts...) {}; +}; + +/// Function object that applies an index to its argument +template +struct IndexedGather +{ + CUTE_HOST_DEVICE constexpr + IndexedGather(Index const *indices = {}): indices_(indices) {} + + template + CUTE_HOST_DEVICE constexpr + Index + operator()(I i) const { return indices_[i]; } + + CUTE_HOST_DEVICE friend + void + print(IndexedGather const &s) { + cute::print("Indexed"); + } + + Index const *indices_; +}; + +/// Function object that applies a stride to its argument +/// Example: StridedFunc gathers every other row/column +template +struct StridedGather +{ + CUTE_HOST_DEVICE constexpr + StridedGather(Stride stride = {}): stride_(stride) {} + + template + CUTE_HOST_DEVICE constexpr + auto + operator()(I i) const { return i * stride_; } + + CUTE_HOST_DEVICE friend + void + print(StridedGather const &s) { + cute::print("Strided{"); + print(s.stride_); + cute::print("}"); + } + + Stride stride_; +}; + +/// Custom stride object that applies a function followed by a stride +template +struct CustomStride +{ + CUTE_HOST_DEVICE constexpr + CustomStride(Func const &func, Stride const &stride): func_(func), stride_(stride) {} + + template + CUTE_HOST_DEVICE constexpr friend + auto + operator*(I i, CustomStride const &s) { return s.func_(i) * s.stride_; } + + template + CUTE_HOST_DEVICE constexpr friend + auto + operator*(CustomStride const &s, I i) { return s.func_(i) * s.stride_; } + + CUTE_HOST_DEVICE friend + void + print(CustomStride const & s) { + cute::print("Custom{"); + print(s.func_); + cute::print(","); + print(s.stride_); + cute::print("}"); + } + + template + CUTE_HOST_DEVICE constexpr friend + auto + safe_div(CustomStride const &s, Div const &div) + { + return CustomStride(s.func_, safe_div(s.stride_, div)); + } + + // Circumvent the requirement on make_layout that shape and stride are integral + template + CUTE_HOST_DEVICE constexpr friend + auto + make_layout(Shape const &shape, CustomStride const &stride) + { + return Layout(shape, stride); + } + + Func func_; + Stride stride_; +}; + +template +CUTLASS_HOST_DEVICE +auto +make_custom_stride_layout(Stride const &stride, Func&& func) +{ + // Use a dummy shape and replace the first non-unit stride with a custom gather stride + auto idx = find_if(stride, [](auto x){ return not is_constant<1, decltype(x)>{}; }); + constexpr int I = decltype(idx)::value; + return make_layout(repeat_like(stride, _1{}), + replace(stride, CustomStride{static_cast(func), get(stride)})); +} + +/// Helper function to optionally create a gather tensor +template +CUTLASS_HOST_DEVICE +auto +make_gather_tensor(Iterator iter, Shape const &shape, Stride const &stride, Func &&func) +{ + if constexpr (not cutlass::platform::is_same, NoGather>::value) { + Layout matrix_layout = make_identity_layout(shape); + auto offset = as_arithmetic_tuple(repeat_like(shape, _0{})); + Layout gather_layout = make_custom_stride_layout(stride, static_cast(func)); + return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout}); + } else { + return make_tensor(iter, shape, stride); + } +} + +} // namespace example + +namespace cute +{ + +template +CUTE_HOST_DEVICE constexpr +auto +upcast(Shape const& shape, Stride const& stride) +{ + if constexpr (is_tuple::value) { + return transform_layout(shape, stride, [](auto const& s, auto const& d) { return upcast(s,d); }); + } else if constexpr (is_scaled_basis::value) { + if constexpr (Stride::mode() == I) { + return make_layout(ceil_div(shape, Int{}), ceil_div(stride, Int{})); + } else { + return make_layout(shape, stride); + } + } else { + return upcast(shape, stride); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +upcast(ComposedLayout,Offset,Layout> const& layout) +{ + // Find index of the stride-1 mode - that is the only one that requires updating inner shape and offset + auto idx = find_if(layout.layout_a().stride(), [](auto x){ return is_constant<1, decltype(x)>{}; }); + constexpr int I = decltype(idx)::value; + + // Upcast the outer layout (works as expected) + auto outer = upcast(layout.layout_a()); + + // Upcast the accumulated offset along stride-1 mode + auto offset = as_arithmetic_tuple(replace(layout.offset(), upcast(get(layout.offset())))); + + // Upcast the inner layout's shape along stride-1 mode + auto inner = upcast(layout.layout_b().shape(), layout.layout_b().stride()); + + return composition(outer, offset, inner); +} + +} // namespace example diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/common/helper.h b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/common/helper.h new file mode 100644 index 0000000000000000000000000000000000000000..aec1c7197c6e8140f3bceb00e891ca5a94cc6d92 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/common/helper.h @@ -0,0 +1,108 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cuda_runtime.h" +#include + +/** + * Panic wrapper for unwinding CUTLASS errors + */ +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + if (error != cutlass::Status::kSuccess) { \ + std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \ + << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + + +/** + * Panic wrapper for unwinding CUDA runtime errors + */ +#define CUDA_CHECK(status) \ + { \ + cudaError_t error = status; \ + if (error != cudaSuccess) { \ + std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ + << " at line: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + + +/** + * GPU timer for recording the elapsed time across kernel(s) launched in GPU stream + */ +struct GpuTimer +{ + cudaStream_t _stream_id; + cudaEvent_t _start; + cudaEvent_t _stop; + + /// Constructor + GpuTimer() : _stream_id(0) + { + CUDA_CHECK(cudaEventCreate(&_start)); + CUDA_CHECK(cudaEventCreate(&_stop)); + } + + /// Destructor + ~GpuTimer() + { + CUDA_CHECK(cudaEventDestroy(_start)); + CUDA_CHECK(cudaEventDestroy(_stop)); + } + + /// Start the timer for a given stream (defaults to the default stream) + void start(cudaStream_t stream_id = 0) + { + _stream_id = stream_id; + CUDA_CHECK(cudaEventRecord(_start, _stream_id)); + } + + /// Stop the timer + void stop() + { + CUDA_CHECK(cudaEventRecord(_stop, _stream_id)); + } + + /// Return the elapsed time (in milliseconds) + float elapsed_millis() + { + float elapsed = 0.0; + CUDA_CHECK(cudaEventSynchronize(_stop)); + CUDA_CHECK(cudaEventElapsedTime(&elapsed, _start, _stop)); + return elapsed; + } +}; diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/cute/tutorial/blackwell/example_utils.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/cute/tutorial/blackwell/example_utils.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f6332002bb77b30728071de83cae689ce5e68594 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/cute/tutorial/blackwell/example_utils.hpp @@ -0,0 +1,105 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once +#include // CuTe tensor implementation +#include + +template +void +reference_gemm(TensorA const& tensor_A, TensorB const& tensor_B, + TensorC const& tensor_C, TensorD & tensor_D, + Alpha alpha, Beta beta) +{ + using namespace cute; + for (int m = 0; m < size<0>(tensor_D); ++m) { + for (int n = 0; n < size<1>(tensor_D); ++n) { + AccType c = AccType(0.f); + for (int k = 0; k < size<1>(tensor_A); ++k) { + c += tensor_A(m,k) * tensor_B(n,k); + } + tensor_D(m,n) = alpha * c + beta * tensor_C(m,n); + } + } +} + +template +bool +compare_results(TensorA const& tensor_A, TensorB const& tensor_B, + TensorC const& tensor_C, TensorD const& tensor_D, + RefTensorD const& ref_tensor_D, + bool print_diff = false) +{ + using namespace cute; + auto norm_A = matrix_inf_norm(tensor_A); + auto norm_B = matrix_inf_norm(tensor_B); + auto norm_C = matrix_inf_norm(tensor_C); + auto norm_D = matrix_inf_norm(tensor_D); + auto norm_ref_D = matrix_inf_norm(ref_tensor_D); + auto norm_diff = matrix_diff_inf_norm(tensor_D, ref_tensor_D); + + if (print_diff) { + for (int m = 0; m < size<0>(tensor_D); ++m) { + for (int n = 0; n < size<1>(tensor_D); ++n) { + std::cout << m << "," << n << " : " << tensor_D(m,n) << " vs. " << ref_tensor_D(m,n) << std::endl; + } + } + } + + std::cout << "norm (A) : " << norm_A.inf_norm << std::endl; + std::cout << "norm (B) : " << norm_B.inf_norm << std::endl; + std::cout << "norm (C) : " << norm_C.inf_norm << std::endl; + std::cout << "norm (D) : " << norm_D.inf_norm << std::endl; + std::cout << "norm (ref_D) : " << norm_ref_D.inf_norm << std::endl; + std::cout << "norm (D-ref_D) : " << norm_diff.inf_norm << std::endl; + + return (!norm_A.found_nan) && (!norm_B.found_nan) && + (!norm_C.found_nan) && (!norm_D.found_nan) && (!norm_ref_D.found_nan) && // There are no NaNs + (norm_A.inf_norm > 0.0) && (norm_B.inf_norm > 0.0) && + (norm_C.inf_norm > 0.0) && (norm_D.inf_norm > 0.0) && (norm_ref_D.inf_norm > 0.0) && // Values in tensors aren't zeros + (norm_diff.inf_norm <= 0.0); // Diff (ref_D-D) == 0 +} + +template +void +initialize_tensor(Tensor& tensor, cute::tuple value_range = {-2, 2}) +{ + using DataType = typename Tensor::element_type; + auto [min, max] = value_range; + for (int i = 0; i < cute::size(tensor); i++) { + tensor(i) = DataType(int((max-min)*(rand() / double(RAND_MAX)) + min)); + } +} diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/call_bypass_dlpack.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/call_bypass_dlpack.py new file mode 100644 index 0000000000000000000000000000000000000000..79a23079077e8da6971e883c592def544aaa3579 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/call_bypass_dlpack.py @@ -0,0 +1,166 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import sys +import os +from typing import Tuple +import torch + +import cutlass +import cutlass.cute as cute +from cutlass.cute.runtime import make_ptr + + +""" +An Example demonstrating how to call off-the-shelf kernel by-passing dlpack protocol + +The example shows how to directly pass pointers from PyTorch tensors to off-the-shelf kernels +written by CuTe DSL with a thin customized wrapper jit function. The jit function will be +compiled with inline without introducing overhead. + +To run this example: + +.. code-block:: bash + + python examples/ampere/call_bypass_dlpack.py + + +It's worth to mention that by-passing dlpack protocol can resolve the issue that dlpack doesn't handle shape-1 +mode correctly. For example, the following code will fail, because dlpack will convert the shape-1 mode +with stride-1 which propagate alignment incorrectly. + +.. code-block:: python + + @cute.kernel + def fails_kernel(gX: cute.Tensor): + bidx, _, _ = cute.arch.block_idx() + mX = gX[None, bidx, None] # We wish to retain alignment + # assert mX.iterator.alignment == 16 + + + @cute.jit + def fails(gX_: cute.Tensor): + gX = gX_ + fails_kernel(gX).launch(grid=(1, 1, 1), block=(128, 1, 1)) + + + gX_torch = torch.rand((128, 1, 128), device="cuda", dtype=torch.bfloat16) + fails(from_dlpack(gX_torch, assumed_align=16)) + +""" + +# Add the current directory to sys.path +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +from tensorop_gemm import TensorOpGemm + + +@cute.jit +def tensor_op_gemm_wrapper( + a_ptr: cute.Pointer, + b_ptr: cute.Pointer, + c_ptr: cute.Pointer, + m: cutlass.Int32, + n: cutlass.Int32, + k: cutlass.Int32, + l: cutlass.Int32, +): + print(f"\n[DSL INFO] Input Parameters:") + print(f"[DSL INFO] mnkl: {(m, n, k, l)}") + + # Assume alignment of shape to call tensorop_gemm example + m = cute.assume(m, divby=8) + n = cute.assume(n, divby=8) + + # Torch is row major + a_layout = cute.make_ordered_layout((m, k, l), order=(0, 1, 2)) + b_layout = cute.make_ordered_layout((n, k, l), order=(0, 1, 2)) + c_layout = cute.make_ordered_layout((m, n, l), order=(1, 0, 2)) + mA = cute.make_tensor(a_ptr, layout=a_layout) + mB = cute.make_tensor(b_ptr, layout=b_layout) + mC = cute.make_tensor(c_ptr, layout=c_layout) + + print(f"[DSL INFO] mA: {mA}") + print(f"[DSL INFO] mB: {mB}") + print(f"[DSL INFO] mC: {mC}") + + tensor_op_gemm = TensorOpGemm( + a_ptr.value_type, c_ptr.value_type, cutlass.Float32, (2, 2, 1) + ) + print(f"\n[DSL INFO] Created TensorOpGemm instance") + print(f"[DSL INFO] Input dtype: {a_ptr.value_type}") + print(f"[DSL INFO] Output dtype: {c_ptr.value_type}") + print(f"[DSL INFO] Accumulation dtype: {cutlass.Float32}") + print(f"[DSL INFO] Atom layout: {(2, 2, 1)}") + + # No need to compile inside jit function + tensor_op_gemm(mA, mB, mC) + print(f"\n[DSL INFO] Executed TensorOpGemm") + + +def run_tensor_op_gemm_wrapper(mnkl: Tuple[int, int, int, int]): + print(f"\nRunning TensorOpGemm test with:") + print(f"Tensor dimensions: {mnkl}") + + # (M,K,L) + a = torch.randn( + mnkl[3], mnkl[2], mnkl[0], dtype=torch.float16, device="cuda" + ).permute(2, 1, 0) + # (N,K,L) + b = torch.randn( + mnkl[3], mnkl[2], mnkl[1], dtype=torch.float16, device="cuda" + ).permute(2, 1, 0) + # (N,M,L) + c = torch.randn( + mnkl[3], mnkl[0], mnkl[1], dtype=torch.float16, device="cuda" + ).permute(1, 2, 0) + + print(f"Input tensor shapes:") + print(f"a: {a.shape}, dtype: {a.dtype}") + print(f"b: {b.shape}, dtype: {b.dtype}") + print(f"c: {c.shape}, dtype: {c.dtype}\n") + + a_ptr = make_ptr( + cutlass.Float16, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + ) + b_ptr = make_ptr( + cutlass.Float16, b.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + ) + c_ptr = make_ptr( + cutlass.Float16, c.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + ) + tensor_op_gemm_wrapper(a_ptr, b_ptr, c_ptr, *mnkl) + torch.cuda.synchronize() + + ref = torch.einsum("mkl,nkl->mnl", a, b) + torch.testing.assert_close(c, ref, atol=1e-05, rtol=1e-05) + print(f"\n[DSL INFO] Results verified successfully!") + print(f"First few elements of result: \n{c[:3, :3, :3]}") + + +if __name__ == "__main__": + run_tensor_op_gemm_wrapper((512, 256, 128, 16)) diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/call_from_jit.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/call_from_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..ee71e53fe87931a2ddf0fed944991a3ce313a29e --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/call_from_jit.py @@ -0,0 +1,259 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +""" +Demonstrating JIT GEMM Implementation with Static Shape Wrapper + +This example illustrates how to invoke a JIT-compiled GEMM implementation through a wrapper function +with static shapes. It showcases the integration between PyTorch and CuTe tensors in a JIT context. + +Key features demonstrated: +1. Seamless conversion between PyTorch and CuTe tensors using the JitArgument protocol +2. Integration of static shape GEMM operations within a JIT-compiled wrapper function + +Core components: +- BufferWithLayout: Handles memory buffer management with configurable stride ordering +- tensor_op_gemm_wrapper: JIT-compiled entry point that orchestrates the GEMM operation + +Usage: + +.. code-block:: bash + + python examples/ampere/call_from_jit.py + +Default configuration: +- Batch dimension (L): 16 +- Matrix dimensions: M=512, N=256, K=128 +- Precision: Float16 inputs with Float32 accumulation + +Requirements: +- CUDA-capable GPU +- PyTorch with CUDA support +""" + +import os +import sys +from typing import Type, Tuple + +import torch + +import cutlass +import cutlass.cute as cute +from cutlass.torch import dtype as torch_dtype +from cutlass.cute.runtime import make_ptr + + +# Add the current directory to sys.path +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +from tensorop_gemm import TensorOpGemm + + +class BufferWithLayout: + def __init__(self, ptr: cute.Pointer, stride_order: tuple[int, int, int]): + self.ptr = ptr + + # static properties + self.stride_order = stride_order + + def to_tensor( + self, shape: tuple[int, int, int], *, loc=None, ip=None + ) -> cute.Tensor: + assert len(shape) == len(self.stride_order), ( + f"Shape {shape} and stride_order {self.stride_order} must have the " + "same rank." + ) + layout = cute.make_ordered_layout(shape, self.stride_order) + # permute (l, mn, k) -> (mn, k, l) + res = cute.make_tensor(self.ptr, cute.select(layout, mode=[1, 2, 0])) + return res + + # Implement JitArgument Protocol and DynamicExpression Protocol + + def __c_pointers__(self): + """Get the C pointers for the underlying pointer. + + This method is part of the JitArgument Protocol and returns the C pointers + from the underlying pointer object. + + This is required for user to define a custom data type which can pass to JIT function. + When JIT compiled function is called, JIT executor will call this method to get raw pointers + to underlying data object. + + Following condition must be satisfied: + + len(__c_pointers__()) == len(__get_mlir_types__()) == len(__extract_mlir_values__()) + + :return: The C pointers from the underlying pointer object + :rtype: Any + """ + return self.ptr.__c_pointers__() + + def __get_mlir_types__(self): + """Get the MLIR types for the underlying pointer. + + This method is part of the JitArgument Protocol and returns the MLIR types + used for compiler to generate code. It must match the type of the underlying pointers + returned by __c_pointers__(). + + :return: The MLIR types from the underlying pointer object + :rtype: Any + """ + return self.ptr.__get_mlir_types__() + + def __extract_mlir_values__(self): + """Extract MLIR values from the underlying pointer. + + This method is part of the DynamicExpression Protocol and extracts MLIR values + from the underlying pointer object. + + It is used by compiler to generate function call in MLIR to another JIT function. + It must match the types returned by __get_mlir_types__(). + + :return: The MLIR values extracted from the underlying pointer object + :rtype: Any + """ + return self.ptr.__extract_mlir_values__() + + def __new_from_mlir_values__(self, values): + """Create a new BufferWithLayout instance from MLIR values. + + This method is part of the JitArgument & DynamicExpression Protocol and creates a new + BufferWithLayout instance with pointer initialized from the given MLIR values. + + It is used by compiler to generate function body in MLIR called by JIT function. + It must match the types returned by __c_pointers__() and __get_mlir_types__(). + code generator takes function arguments and reconstructs python object which is legal + inside function body. + + :param values: MLIR values to initialize the underlying pointer + :type values: Any + :return: A new BufferWithLayout instance with pointer initialized from values + :rtype: BufferWithLayout + """ + return BufferWithLayout( + self.ptr.__new_from_mlir_values__(values), self.stride_order + ) + + +@cute.jit +def tensor_op_gemm_wrapper( + buffer_a: BufferWithLayout, + buffer_b: BufferWithLayout, + buffer_c: BufferWithLayout, + mnkl: cutlass.Constexpr[tuple[int, int, int, int]], + acc_dtype: Type[cutlass.Numeric], + atom_layout_mnk: cutlass.Constexpr[tuple[int, int, int]], +): + print(f"\n[DSL INFO] Input Parameters:") + print(f"[DSL INFO] mnkl: {mnkl}") + print(f"[DSL INFO] buffer_a: {buffer_a}") + print(f"[DSL INFO] buffer_b: {buffer_b}") + print(f"[DSL INFO] buffer_c: {buffer_c}") + print(f"[DSL INFO] acc_dtype: {acc_dtype}") + print(f"[DSL INFO] atom_layout_mnk: {atom_layout_mnk}") + + mA = buffer_a.to_tensor(cute.select(mnkl, mode=[3, 0, 2])) + mB = buffer_b.to_tensor(cute.select(mnkl, mode=[3, 1, 2])) + mC = buffer_c.to_tensor(cute.select(mnkl, mode=[3, 0, 1])) + + print(f"\n[DSL INFO] Created Tensors:") + print(f"[DSL INFO] mA = {mA}") + print(f"[DSL INFO] mB = {mB}") + print(f"[DSL INFO] mC = {mC}") + + tensor_op_gemm = TensorOpGemm( + buffer_a.ptr.value_type, + buffer_c.ptr.value_type, + acc_dtype, + atom_layout_mnk, + ) + print(f"\n[DSL INFO] Created TensorOpGemm instance") + print(f"[DSL INFO] Input dtype: {buffer_a.ptr.value_type}") + print(f"[DSL INFO] Output dtype: {buffer_c.ptr.value_type}") + print(f"[DSL INFO] Accumulation dtype: {acc_dtype}") + print(f"[DSL INFO] Atom layout: {atom_layout_mnk}") + + # No need to compile inside jit function + tensor_op_gemm(mA, mB, mC) + print(f"\n[DSL INFO] Executed TensorOpGemm") + + +def run_tensor_op_gemm_wrapper(mnkl: Tuple[int, int, int, int]): + print(f"\nRunning TensorOpGemm test with:") + print(f"Tensor dimensions: {mnkl}") + + ab_dtype = cutlass.Float16 + c_dtype = cutlass.Float16 + + a = torch.randn( + mnkl[3], mnkl[0], mnkl[2], dtype=torch_dtype(ab_dtype), device="cuda" + ) + b = torch.randn( + mnkl[3], mnkl[1], mnkl[2], dtype=torch_dtype(ab_dtype), device="cuda" + ) + c = torch.randn( + mnkl[3], mnkl[0], mnkl[1], dtype=torch_dtype(c_dtype), device="cuda" + ) + + print(f"Input tensor shapes:") + print(f"a: {a.shape}, dtype: {a.dtype}") + print(f"b: {b.shape}, dtype: {b.dtype}") + print(f"c: {c.shape}, dtype: {c.dtype}\n") + + buffer_a = BufferWithLayout( + make_ptr(ab_dtype, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=32), + (2, 1, 0), + ) + buffer_b = BufferWithLayout( + make_ptr(ab_dtype, b.data_ptr(), cute.AddressSpace.gmem, assumed_align=32), + (2, 1, 0), + ) + buffer_c = BufferWithLayout( + make_ptr(c_dtype, c.data_ptr(), cute.AddressSpace.gmem, assumed_align=32), + (2, 1, 0), + ) + + tensor_op_gemm_wrapper( + buffer_a, + buffer_b, + buffer_c, + mnkl, # pass shape as static value + # no stride passing + cutlass.Float32, + (2, 2, 1), + ) + torch.cuda.synchronize() + + ref = torch.einsum("lmk,lnk->lmn", a, b) + torch.testing.assert_close(c, ref, atol=1e-05, rtol=1e-05) + print(f"\n[DSL INFO] Results verified successfully!") + print(f"First few elements of result: \n{c[:3, :3, :3]}") + + +if __name__ == "__main__": + run_tensor_op_gemm_wrapper((512, 256, 128, 16)) diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/dynamic_smem_size.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/dynamic_smem_size.py new file mode 100644 index 0000000000000000000000000000000000000000..dba505487bf53c5a2d63334d340449a0f1927f0f --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/dynamic_smem_size.py @@ -0,0 +1,114 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import cutlass.cute as cute +import cutlass + +""" +Example of automatic shared memory size computation for configuring kernel launch + +This example demonstrates how to let the DSL automatically set shared memory + size for a kernel launch rather explicitly configuring it at launch time, + provided that developers are using `SmemAllocator` for all allocations. + +Usage: + python dynamic_smem_size.py # Show auto inference +""" + + +@cute.struct +class SharedData: + """A struct to demonstrate shared memory allocation.""" + + values: cute.struct.MemRange[cutlass.Float32, 64] # 256 bytes + counter: cutlass.Int32 # 4 bytes + flag: cutlass.Int8 # 1 byte + + +@cute.kernel +def kernel(): + """ + Example kernel that allocates shared memory. + The total allocation will be automatically calculated when smem=None. + """ + allocator = cutlass.utils.SmemAllocator() + + # Allocate various types of shared memory + shared_data = allocator.allocate(SharedData) + raw_buffer = allocator.allocate(512, byte_alignment=64) + int_array = allocator.allocate_array(element_type=cutlass.Int32, num_elems=128) + tensor_smem = allocator.allocate_tensor( + element_type=cutlass.Float16, + layout=cute.make_layout((32, 16)), + byte_alignment=16, + swizzle=None, + ) + return + + +@cute.kernel +def kernel_no_smem(): + """ + Example kernel that does not allocates shared memory. + The total allocation will be automatically calculated as 0 when smem=None. + """ + tidx, _, _ = cute.arch.block_idx() + if tidx == 0: + cute.printf("Hello world") + return + + +if __name__ == "__main__": + # Initialize CUDA context + cutlass.cuda.initialize_cuda_context() + + print("Launching kernel with auto smem size. (launch config `smem=None`)") + + # Compile the example + @cute.jit + def launch_kernel1(): + k = kernel() + k.launch( + grid=(1, 1, 1), + block=(1, 1, 1), + ) + print(f"Kernel recorded internal smem usage: {k.smem_usage()}") + + @cute.jit + def launch_kernel2(): + k = kernel_no_smem() + k.launch( + grid=(1, 1, 1), + block=(1, 1, 1), + ) + print(f"Kernel recorded internal smem usage: {k.smem_usage()}") + + cute.compile(launch_kernel1) + cute.compile(launch_kernel2) + + print("PASS") diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/elementwise_add.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/elementwise_add.py new file mode 100644 index 0000000000000000000000000000000000000000..596822ccec9fc5085ea4d6d2d9659992f7c7caab --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/elementwise_add.py @@ -0,0 +1,403 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +import argparse +import time +from typing import Type + +import cuda.bindings.driver as cuda +import torch + +import cutlass +import cutlass.cute as cute +import cutlass.cute.testing as testing +import cutlass.torch as cutlass_torch +from cutlass.cute.runtime import from_dlpack + +""" +An Elementwise Addition Example using CuTe DSL. + +This example kernel copies data from global memory to register memory (rmem), performs the elementwise +addition operation, and stores the result back to global memory. + +Primary goals of this example are to demonstrate how basic global memory copies can be expressed in +CuTe DSL and illustrate canonical partitioning patterns in CuTe. It also implements canonical +predication for tensors whose shape is not multiple of tile size to guard OOB reads. + +Thread-value (or TV) layouts are central to canonical partitioning patterns in CuTe. They provide a +mapping from thread and a thread's value to the set of coordinates within a tile that we have sliced +out from a data tensor. + +The input tensors are row-major layout, that leading dimension is the right most dimension. In order +to efficiently copy data from global memory, we must map threads contiguously on row dimension. + +Thread ID mapping to 2D coordinates with layout `(4,32):(32,1)`: + + +----+----+----+----+-----+----+ + | | 0 | 1 | 2 | ... | 31 | + +----+----+----+----+-----+----+ + | 0 | T0 | T1 | T2 | ... | T31| + +----+----+----+----+-----+----+ + | 1 |T32 |T33 |T34 | ... |T63 | + +----+----+----+----+-----+----+ + | 2 |T64 |T65 |T66 | ... |T95 | + +----+----+----+----+-----+----+ + | 3 |T96 |T97 |T98 | ... |T127| + +----+----+----+----+-----+----+ + +As Ampere GPU supports a maximum of 128bit per load/store instruction and each element is 32bit, we +can load 4 elements per instruction. Having additional contiguous values allows for vectorization +across threads (coalesced accesses) and is required for saturating the memory bandwidth. + +We use `(4,4):(4,1)` as the val layout in this example. Notice that the major mode is the same as +the major mode of the input tensor - without which vectorization would not be possible. + +If you already know the TV layout you want to use for your tiled copy, CuTe DSL provides utility +`cute.make_layout_tv` to build the tiled copy type around it and the atom of your choice. + +.. code-block:: python + + thr_layout = cute.make_layout((4, 32), stride=(32, 1)) + val_layout = cute.make_layout((4, 4), stride=(4, 1)) + tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout) + + # Tile input tensor to thread blocks: ((TileM,TileN),(RestM,RestN)) + gA = cute.zipped_divide(mA, tiler_mn) + +Then we can build tiled copy for input and output tensors with `cute.make_tiled_copy_tv` utility, which +infers the tiler and tv layout for the tiled copy automatically, where `tiler` is the tile size per thread +block and `tv_layout` is the TV layout which maps thread index and inter-thread index of data array per +thread to logical coordinates of elements in input and output tensors. + +.. code-block:: python + + blkA = gA[((None, None), bidx)] # (TileM,TileN) + + copy_atom_load = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gA.element_type) + tiled_copy_A = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout) + + # get slice of tiled_copy_A for current thread + thr_copy_A = tiled_copy_A.get_slice(tidx) + + # partition per thread block tensor as source of tiled copy + thrA = thr_copy_A.partition_S(blkA) + + # allocate fragment for gmem->rmem + frgA = cute.make_fragment_like(thrA) + + # copy data from global memory to register memory + cute.copy(copy_atom_load, thrA, frgA) + + +To run this example: + +.. code-block:: bash + + python examples/ampere/elementwise_add.py --M 3 --N 12 + python examples/ampere/elementwise_add.py --M 1024 --N 512 + python examples/ampere/elementwise_add.py --M 1024 --N 1024 --benchmark --warmup_iterations 2 --iterations 1000 + +To collect performance with NCU profiler: + +.. code-block:: bash + + # Don't iterate too many times when profiling with ncu + ncu python examples/ampere/elementwise_add.py --M 2048 --N 2048 --benchmark --iterations 10 --skip_ref_check +""" + + +@cute.kernel +def elementwise_add_kernel( + gA: cute.Tensor, + gB: cute.Tensor, + gC: cute.Tensor, + cC: cute.Tensor, # coordinate tensor + shape: cute.Shape, + thr_layout: cute.Layout, + val_layout: cute.Layout, +): + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + + # slice for CTAs + # logical id -> address + blk_coord = ((None, None), bidx) + blkA = gA[blk_coord] # (TileM,TileN) + blkB = gB[blk_coord] # (TileM,TileN) + blkC = gC[blk_coord] # (TileM,TileN) + blkCrd = cC[blk_coord] # (TileM, TileN) + + # Note: these prints only run at compile/jit time + print(f"[DSL INFO] Sliced Tensors per thread block:") + print(f"[DSL INFO] blkA = {blkA.type}") + print(f"[DSL INFO] blkB = {blkB.type}") + print(f"[DSL INFO] blkC = {blkC.type}") + print(f"[DSL INFO] blkCrd = {blkCrd.type}") + + # # declare the atoms which will be used later for memory copy + copy_atom_load = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gA.element_type) + copy_atom_store = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gC.element_type) + + tiled_copy_A = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout) + tiled_copy_B = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout) + tiled_copy_C = cute.make_tiled_copy_tv(copy_atom_store, thr_layout, val_layout) + + thr_copy_A = tiled_copy_A.get_slice(tidx) + thr_copy_B = tiled_copy_B.get_slice(tidx) + thr_copy_C = tiled_copy_C.get_slice(tidx) + + thrA = thr_copy_A.partition_S(blkA) + thrB = thr_copy_B.partition_S(blkB) + thrC = thr_copy_C.partition_S(blkC) + + # allocate fragments for gmem->rmem + frgA = cute.make_fragment_like(thrA) + frgB = cute.make_fragment_like(thrB) + frgC = cute.make_fragment_like(thrC) + + thrCrd = thr_copy_C.partition_S(blkCrd) + frgPred = cute.make_fragment(thrCrd.shape, cutlass.Boolean) + + print(f"[DSL INFO] Sliced Tensors per thread:") + print(f"[DSL INFO] thrA = {thrA.type}") + print(f"[DSL INFO] thrB = {thrB.type}") + print(f"[DSL INFO] thrC = {thrC.type}") + print(f"[DSL INFO] thrCrd = {thrCrd.type}") + + for i in range(0, cute.size(frgPred), 1): + val = cute.elem_less(thrCrd[i], shape) + frgPred[i] = val + + # Print per thread predicate mask + # if tidx == 0 and bidx == 0: + # cute.printf("block_dim = {}", cute.arch.grid_dim()) + # cute.printf("shape = {}", shape) + # cute.print_tensor(thrA) + # cute.print_tensor(thrB) + # cute.print_tensor(frgPred) + + ########################################################## + # Move data to reg address space + ########################################################## + + cute.copy(copy_atom_load, thrA, frgA, pred=frgPred) + cute.copy(copy_atom_load, thrB, frgB, pred=frgPred) + + # if tidx == 0 and bidx == 0: + # cute.print_tensor(frgA) + # cute.print_tensor(frgB) + + # Load data before use. The compiler will optimize the copy and load + # operations to convert some memory ld/st into register uses. + result = frgA.load() + frgB.load() + + # Save the results back to registers. Here we reuse b's registers. + frgC.store(result) + + # Copy the results back to c + cute.copy(copy_atom_store, frgC, thrC, pred=frgPred) + + +@cute.jit +def elementwise_add(mA, mB, mC, copy_bits: cutlass.Constexpr = 128): + dtype = mA.element_type + vector_size = copy_bits // dtype.width + + thr_layout = cute.make_ordered_layout((4, 32), order=(1, 0)) + val_layout = cute.make_ordered_layout((4, vector_size), order=(1, 0)) + tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout) + + print(f"[DSL INFO] Input Tensors:") + print(f"[DSL INFO] mA = {mA.type}") + print(f"[DSL INFO] mB = {mB.type}") + + print(f"[DSL INFO] Tiling Parameters:") + print(f"[DSL INFO] tiler_mn = {tiler_mn} per thread block") + print(f"[DSL INFO] tv_layout = {tv_layout}") + + gA = cute.zipped_divide(mA, tiler_mn) # ((TileM,TileN),(RestM,RestN)) + gB = cute.zipped_divide(mB, tiler_mn) # ((TileM,TileN),(RestM,RestN)) + gC = cute.zipped_divide(mC, tiler_mn) # ((TileM,TileN),(RestM,RestN)) + print(f"[DSL INFO] Tiled Tensors:") + print(f"[DSL INFO] gA = {gA.type}") + print(f"[DSL INFO] gB = {gB.type}") + print(f"[DSL INFO] gC = {gC.type}") + + idC = cute.make_identity_tensor(mC.shape) + cC = cute.zipped_divide(idC, tiler=tiler_mn) + print(f"[DSL INFO] coord tensor = {cC.type}") + + elementwise_add_kernel(gA, gB, gC, cC, mC.shape, thr_layout, val_layout).launch( + grid=[cute.size(gC, mode=[1]), 1, 1], + block=[cute.size(tv_layout, mode=[0]), 1, 1], + ) + + +def run_elementwise_add( + M, + N, + dtype: Type[cutlass.Numeric], + is_a_dynamic_layout=False, + is_b_dynamic_layout=False, + is_result_dynamic_layout=False, + skip_ref_check=False, + benchmark=True, + warmup_iterations=2, + iterations=200, +): + print(f"\nRunning Elementwise Add test with:") + print(f"Tensor dimensions: [{M}, {N}]") + print(f"Input and Output Data type: {dtype}") + + torch_dtype = cutlass_torch.dtype(dtype) + if dtype.is_integer: + a = torch.randint(0, 10, (M, N), device=torch.device("cuda"), dtype=torch_dtype) + b = torch.randint(0, 10, (M, N), device=torch.device("cuda"), dtype=torch_dtype) + else: + a = torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype) + b = torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype) + + c = torch.zeros_like(a) + + print(f"Input tensor shapes:") + print(f"a: {a.shape}, dtype: {a.dtype}") + print(f"b: {b.shape}, dtype: {b.dtype}") + print(f"c: {c.shape}, dtype: {c.dtype}\n") + + if not is_a_dynamic_layout: + a_tensor = from_dlpack(a).mark_layout_dynamic() + else: + a_tensor = a + + if not is_b_dynamic_layout: + b_tensor = from_dlpack(b).mark_layout_dynamic() + else: + b_tensor = b + + if not is_result_dynamic_layout: + c_tensor = from_dlpack(c).mark_layout_dynamic() + else: + c_tensor = c + + print("Compiling kernel with cute.compile ...") + start_time = time.time() + compiled_func = cute.compile(elementwise_add, a_tensor, b_tensor, c_tensor) + compilation_time = time.time() - start_time + print(f"Compilation time: {compilation_time:.4f} seconds") + + print("Executing vector add kernel...") + + # Get current CUstream from torch + current_stream = cutlass_torch.current_stream() + + if not skip_ref_check: + compiled_func(a_tensor, b_tensor, c_tensor) + print("Verifying results...") + torch.testing.assert_close(a + b, c) + print("Results verified successfully!") + + if not benchmark: + return + + def generate_tensors(): + if dtype.is_integer: + a = torch.randint( + 0, 10, (M, N), device=torch.device("cuda"), dtype=torch_dtype + ) + b = torch.randint( + 0, 10, (M, N), device=torch.device("cuda"), dtype=torch_dtype + ) + else: + a = torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype) + b = torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype) + + c = torch.zeros_like(a) + + if not is_a_dynamic_layout: + a_tensor = from_dlpack(a).mark_layout_dynamic() + else: + a_tensor = a + + if not is_b_dynamic_layout: + b_tensor = from_dlpack(b).mark_layout_dynamic() + else: + b_tensor = b + + if not is_result_dynamic_layout: + c_tensor = from_dlpack(c).mark_layout_dynamic() + else: + c_tensor = c + + return testing.JitArguments(a_tensor, b_tensor, c_tensor) + + avg_time_us = testing.benchmark( + compiled_func, + workspace_generator=generate_tensors, + workspace_count=10, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + # Print execution results + print(f"Kernel execution time: {avg_time_us / 1e3:.4f} ms") + print( + f"Achieved memory throughput: {(3 * a.numel() * dtype.width // 8) / (avg_time_us / 1e6) / 1e9:.2f} GB/s" + ) + print(f"First few elements of result: \n{c[:3, :3]}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="example of elementwise add to demonstrate the numpy/pytorch as input for kernels" + ) + parser.add_argument("--M", default=1024, type=int) + parser.add_argument("--N", default=1024, type=int) + parser.add_argument("--warmup_iterations", default=2, type=int) + parser.add_argument("--iterations", default=100, type=int) + parser.add_argument("--skip_ref_check", action="store_true") + parser.add_argument("--benchmark", action="store_true") + + args = parser.parse_args() + + if not torch.cuda.is_available(): + raise RuntimeError(f"Ampere GPU is required to run this example!") + + run_elementwise_add( + args.M, + args.N, + dtype=cutlass.Float32, + is_a_dynamic_layout=True, + is_b_dynamic_layout=True, + is_result_dynamic_layout=True, + skip_ref_check=args.skip_ref_check, + benchmark=args.benchmark, + warmup_iterations=args.warmup_iterations, + iterations=args.iterations, + ) + print("\nPASS") diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/elementwise_apply.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/elementwise_apply.py new file mode 100644 index 0000000000000000000000000000000000000000..43c3b5fc76ce8bca016b43198194ca5f849b4a4a --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/elementwise_apply.py @@ -0,0 +1,393 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +import argparse +import operator +import time +from typing import Type, List + +import cuda.bindings.driver as cuda +import torch + +import cutlass +import cutlass.cute as cute +import cutlass.cute.testing as testing +import cutlass.torch as cutlass_torch +from cutlass.cute.runtime import from_dlpack + +""" +An Elementwise Apply Example using CuTe DSL. + +This example kernel demonstrates the meta-programming capability of the CuTe DSL by allowing +customization of elementwise operations through lambda functions. The kernel copies data from +global memory to register memory (rmem), applies a user-defined operation to the elements, +and stores the result back to global memory. + +Primary goals of this example: +1. Demonstrate meta-programming capability by passing lambda functions to customize elementwise operations +2. Show how to apply different operations (add, multiply, etc.) using the same kernel structure +3. Illustrate how to parameterize CUDA kernels with operation types at compile time + +To run this example: + +.. code-block:: bash + + # Run with addition operation + python examples/ampere/elementwise_apply.py --M 1024 --N 512 --op add + + # Run with multiplication operation + python examples/ampere/elementwise_apply.py --M 1024 --N 512 --op mul + + # Run with subtraction operation + python examples/ampere/elementwise_apply.py --M 1024 --N 512 --op sub + + # Benchmark performance + python examples/ampere/elementwise_apply.py --M 2048 --N 2048 --op add --benchmark --warmup_iterations 2 --iterations 10 + +The example demonstrates how to express complex CUDA kernels with customizable operations +while maintaining high performance through efficient memory access patterns. +""" + + +@cute.kernel +def elementwise_apply_kernel( + op: cutlass.Constexpr, + inputs: List[cute.Tensor], + gC: cute.Tensor, + cC: cute.Tensor, # coordinate tensor + shape: cute.Shape, + tv_layout: cute.Layout, # (tid, vid) -> logic coord +): + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + + # slice for CTAs + cta_coord = ((None, None), bidx) + # logical coord -> address + # Leverage the meta-programming capability of the DSL to slice the tensors for each input + # All for loops below on input tensors would be fully unrolled automatically at compile time + ctaInputs = [t[cta_coord] for t in inputs] # (TileM, TileN) + ctaC = gC[cta_coord] # (TileM, TileN) + ctaCrd = cC[cta_coord] # (TileM, TileN) + + print(f"[DSL INFO] Sliced Tensors per thread block:") + for i in cutlass.range_constexpr(len(ctaInputs)): + print(f"[DSL INFO] ctaInputs{i} = {ctaInputs[i].type}") + print(f"[DSL INFO] ctaC = {ctaC.type}") + print(f"[DSL INFO] ctaCrd = {ctaCrd.type}") + + # compose with CTA TV layout + # (tid, vid) -> address + tidfrgInputs = [cute.composition(t, tv_layout) for t in ctaInputs] + tidfrgC = cute.composition(ctaC, tv_layout) + tidfrgCrd = cute.composition(ctaCrd, tv_layout) + # print(f"{tv_layout = }") + # print(f"{tidfrgAB[0] = }") + + thr_coord = (tidx, (None, None)) + + # slice for threads + # vid -> address + thrInputs = [t[thr_coord] for t in tidfrgInputs] # (V) + thrC = tidfrgC[thr_coord] # (V) + thrCrd = tidfrgCrd[thr_coord] + + print(f"[DSL INFO] Sliced Tensors per thread:") + for i in cutlass.range_constexpr(len(thrInputs)): + print(f"[DSL INFO] thrInputs{i} = {thrInputs[i].type}") + print(f"[DSL INFO] thrC = {thrC.type}") + print(f"[DSL INFO] thrCrd = {thrCrd.type}") + + # allocate fragments for gmem->rmem + frgInputs = [cute.make_fragment_like(t, t.element_type) for t in thrInputs] + frgC = cute.make_fragment_like(thrC, gC.element_type) + frgPred = cute.make_fragment(thrCrd.shape, cutlass.Boolean) + + for i in cutlass.range(cute.size(frgPred), unroll=1): + frgPred[i] = cute.elem_less(thrCrd[i], shape) + + # if tidx == 0 and bidx == 0: + # cute.print_tensor(frgPred) + + ########################################################## + # Move data to reg address space + ########################################################## + + # declare the atoms which will be used later for memory copy + # Compile time validation: expect same element type for all input tensors so as to reuse the copy atom for load + assert all(t.element_type == inputs[0].element_type for t in inputs) + + copy_atom_load = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + inputs[0].element_type, + num_bits_per_copy=inputs[0].element_type.width, + ) + copy_atom_store = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + gC.element_type, + num_bits_per_copy=gC.element_type.width, + ) + + for thrInput, frgInput in zip(thrInputs, frgInputs): + cute.copy(copy_atom_load, thrInput, frgInput, pred=frgPred) + + # Load data before use. The compiler will optimize the copy and load + # operations to convert some memory ld/st into register uses. + result = op(*[frgInput.load() for frgInput in frgInputs]) + + # Save the results back to registers. Here we reuse b's registers. + frgC.store(result) + + # Copy the results back to c + cute.copy(copy_atom_store, frgC, thrC, pred=frgPred) + + +@cute.jit +def elementwise_apply( + op: cutlass.Constexpr, + a: cute.Tensor, + b: cute.Tensor, + result: cute.Tensor, + stream: cuda.CUstream, +): + """CUDA kernel applying binary operator on each element of two n-D input tensors in + CuTe Python and store to result tensor. + + :param op: Binary operator or lambda function to apply element-wise + :type op: cutlass.Constexpr + :param a: First input tensor + :type a: cute.Tensor + :param b: Second input tensor + :type b: cute.Tensor + :param result: Output tensor to store the results of op(a, b) + :type result: cute.Tensor + :return: None + :rtype: None + + .. code-block:: python + + # Example 1: Adding two tensors + x = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32, device="cuda") + y = torch.tensor([[5, 6], [7, 8]], dtype=torch.float32, device="cuda") + result = torch.empty_like(x) + elementwise_apply(operator.add, from_dlpack(x), from_dlpack(y), from_dlpack(result)) + # result: + # tensor([[6.0, 8.0], + # [10.0, 12.0]], device='cuda:0') + + # Example 2: Using a lambda function + elementwise_apply(lambda a, b: a * a + b * b, from_dlpack(x), from_dlpack(y), from_dlpack(result)) + # result: + # tensor([[ 2., 8.], + # [ 54., 512.]], device='cuda:0') + """ + + # Baseline: naive TV layout + # * mA layout: (4096, 4096):(4096, 1) + # * TV layout map to (512, 4) tile + # * tidx maps to mode-0 but input layout is contiguous on mode-1, performance will be bad + # tv_layout = cute.make_layout((128, (4, 4)), stride=(4, (512, 1))) + # cta_tiler = (512, 4) + + # Opt-1: better TV layout with better 1D thread layout (SOL with 1D thread layout) + # * mA layout: (4096, 4096):(4096, 1) + # * TV layout map to (4, 512) tile + # * tidx maps to mode-1 which is leading mode of input tensor for coalesced load + # tv_layout = cute.make_layout((128, (4, 4)), stride=(16, (4, 1))) + # cta_tiler = (4, 512) + + # Opt-2: 2D tile but worse + # * mA layout: (4096, 4096):(4096, 1) + # * TV layout map to (128, 16) logical tile + # * V layout is bad as contiguous mode is not on right-most + # * `cute.copy` only supports vectorize when stride-1 of v-layout on right-most ) + # tv_layout = cute.make_layout(((32, 4), (4, 4)), stride=((4, 512), (1, 128))) + # cta_tiler = (128, 16) + + # Opt-3: SOL with 2D thread tile + # * mA layout: (4096, 4096):(4096, 1) + # * TV layout map to (16, 128) logical tile + # * tidx maps to mode-1 and input layout is contiguous on mode-1 for coalesced load-store + thr_layout = cute.make_layout((4, 32), stride=(32, 1)) + val_layout = cute.make_layout((4, 4), stride=(4, 1)) + tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout) + + print(f"[DSL INFO] Input Tensors:") + print(f"[DSL INFO] a = {a.type}") + print(f"[DSL INFO] b = {b.type}") + print(f"[DSL INFO] result = {result.type}") + + print(f"[DSL INFO] Tiling Parameters:") + print(f"[DSL INFO] tiler_mn = {tiler_mn} per thread block") + print(f"[DSL INFO] tv_layout = {tv_layout}") + + gA = cute.zipped_divide(a, tiler_mn) # ((TileM, TileN), (RestM, RestN)) + gB = cute.zipped_divide(b, tiler_mn) # ((TileM, TileN), (RestM, RestN)) + gC = cute.zipped_divide(result, tiler_mn) # ((TileM, TileN), (RestM, RestN)) + + print(f"[DSL INFO] Tiled Tensors:") + print(f"[DSL INFO] gA = {gA.type}") + print(f"[DSL INFO] gB = {gB.type}") + print(f"[DSL INFO] gC = {gC.type}") + + idC = cute.make_identity_tensor(result.shape) + cC = cute.zipped_divide(idC, tiler=tiler_mn) + print(f"[DSL INFO] coord tensor = {cC.type}") + + # Launch the kernel asynchronously + # Async token(s) can also be specified as dependencies + elementwise_apply_kernel( + op, + [gA, gB], # Group input tensors into a list as a single argument + gC, + cC, + result.shape, + tv_layout, + ).launch( + grid=[cute.size(gC, mode=[1]), 1, 1], + block=[cute.size(tv_layout, mode=[0]), 1, 1], + stream=stream, + ) + + +def run_elementwise_apply_and_verify( + op, + M, + N, + dtype: Type[cutlass.Numeric], + skip_ref_check=False, + benchmark=True, + warmup_iterations=2, + iterations=100, +): + if not torch.cuda.is_available(): + raise RuntimeError(f"Ampere GPU is required to run this example!") + + # Create non default CUDA stream from PyTorch + torch_stream = torch.cuda.Stream() + # Get the raw stream pointer as a CUstream + current_stream = cuda.CUstream(torch_stream.cuda_stream) + + print(f"\nRunning Elementwise Apply test with:") + print(f"Tensor dimensions: [{M}, {N}]") + print(f"Input and Output Data type: {dtype}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Measurement iterations: {iterations}\n") + + torch_dtype = cutlass_torch.dtype(dtype) + + # Allocate tensors with random values. + a = torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype) + b = torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype) + c = torch.zeros_like(a) + + print(f"Input tensor shapes:") + print(f"a: {a.shape}, dtype: {a.dtype}") + print(f"b: {b.shape}, dtype: {b.dtype}") + print(f"c: {c.shape}, dtype: {c.dtype}\n") + + epsilon = 1.2 + if op in (operator.truediv, operator.floordiv): + b = torch.where(b == 0, torch.tensor(epsilon), b) + + print("Executing elementwise apply kernel...") + + if not skip_ref_check: + elementwise_apply( + op, + from_dlpack(a), + from_dlpack(b), + from_dlpack(c).mark_layout_dynamic(), + current_stream, + ) + print("Verifying results...") + torch.testing.assert_close(op(a, b), c) + print("Results verified successfully!") + + if not benchmark: + return + + compiled_func = cute.compile( + elementwise_apply, + op, + from_dlpack(a), + from_dlpack(b), + from_dlpack(c).mark_layout_dynamic(), + current_stream, + ) + + # When compiled we inlined op in the kernel, so we do not pass it when benchmarking + + avg_time_us = testing.benchmark( + compiled_func, + kernel_arguments=testing.JitArguments( + from_dlpack(a), + from_dlpack(b), + from_dlpack(c).mark_layout_dynamic(), + current_stream, + ), + warmup_iterations=warmup_iterations, + iterations=iterations, + use_cuda_graphs=True, + stream=current_stream, + ) + + avg_time = avg_time_us / 1e3 + + # Print execution results + print(f"Kernel execution time: {avg_time:.4f} ms") + print( + f"Achieved memory throughput: {(3 * a.numel() * dtype.width // 8) / (avg_time / 1000) / 1e9:.2f} GB/s" + ) + print(f"First few elements of result: \n{c[:3, :3]}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="example of elementwise apply to demonstrate building elementwise kernels" + ) + parser.add_argument("--M", default=128, type=int) + parser.add_argument("--N", default=128, type=int) + parser.add_argument("--op", default="add", type=str) + parser.add_argument("--warmup_iterations", default=2, type=int) + parser.add_argument("--iterations", default=100, type=int) + parser.add_argument("--skip_ref_check", action="store_true") + parser.add_argument("--benchmark", action="store_true") + args = parser.parse_args() + run_elementwise_apply_and_verify( + getattr(operator, args.op), + args.M, + args.N, + dtype=cutlass.Float32, + warmup_iterations=args.warmup_iterations, + iterations=args.iterations, + skip_ref_check=args.skip_ref_check, + benchmark=args.benchmark, + ) + print("\nPASS") diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/flash_attention_v2.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/flash_attention_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..c6f8ff4ccdc2949724133686984334411f8e6482 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/flash_attention_v2.py @@ -0,0 +1,1334 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +from types import SimpleNamespace +from typing import Type, Union, Callable + +import torch +import cuda.bindings.driver as cuda +import cutlass.cute.testing as testing +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu import cpasync, warp +import cutlass.torch as cutlass_torch +from cutlass.cute.runtime import from_dlpack +import cutlass.utils as utils + +""" +A flash attention v2 forward pass example for NVIDIA Ampere SM80 architecture using CUTE DSL. + +- Matrix Q is BxSqxNxH, B is batch dimension, Sq is query sequence length, N is number of heads, H is head dimension +- Matrix K is BxSkxNxH, B is batch dimension, Sk is key sequence length, N is number of heads, H is head dimension +- Matrix V is BxSkxNxH, B is batch dimension, Sk is key sequence length, N is number of heads, H is head dimension +- Matrix O is BxSqxNxH, B is batch dimension, Sq is query sequence length, N is number of heads, H is head dimension + +This kernel supports the following features: + - Utilizes CpAsync for efficient memory operations + - Utilizes Ampere's tensor core for matrix multiply-accumulate (MMA) operations + - Utilizes register pipeline to overlap shared memory-to-register transfers with computations. + - Leverages DSL to implement an integrated online softmax fusion pattern. + +This kernel works as follows: +1. Load Q and K matrices from global memory (GMEM) to shared memory (SMEM) using CpAsync operations. +2. Perform matrix multiply-accumulate (MMA) operations using tensor core instructions to compute intermediate result S. +3. Apply padding mask or causal mask to S during initial iterations. +4. Apply online softmax to S and rescale O using results from previous iteration. +5. Load V matrices and perform matrix multiply-accumulate (MMA) operations to compute final result O. +6. Normalize O after all iterations complete and store result back to global memory (GMEM). + +To run this example: + +.. code-block:: bash + + python examples/ampere/flash_attention_v2.py \ + --dtype Float16 --head_dim 128 --m_block_size 128 --n_block_size 128 \ + --num_threads 128 --batch_size 1 --seqlen_q 1280 --seqlen_k 1536 \ + --num_head 16 --softmax_scale 1.0 --is_causal + +The above command configures the model to use float16 for inputs and outputs. The problem dimensions +are set to: batch size of 1, query sequence length of 1280, key sequence length of 1536, head dimension +of 128, and 16 attention heads. The softmax scale is set to 1.0 and causal masking is enabled. The computation +uses tiles of size 128x128 for m and n dimensions, and utilizes 128 parallel threads. + +To collect the performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/ampere/flash_attention_v2.py \ + --dtype Float16 --head_dim 128 --m_block_size 128 --n_block_size 128 \ + --num_threads 128 --batch_size 1 --seqlen_q 1280 --seqlen_k 1536 \ + --num_head 16 --softmax_scale 1.0 --is_causal --skip_ref_check + +There are some constraints for this example: +* Only fp16 and bf16 data types are supported. +* The contiguous dimension of each tensor must be at least 16 bytes aligned. +* The log-sum-exp(for training) is not computed in the kernel. +* The values of `m_block_size`, `n_block_size`, and `head_dim` must be selected to stay within shared memory capacity limits. +* `m_block_size * 2` must be divisible by `num_threads`, otherwise the kernel will not be able to get the correct result. +""" + + +class FlashAttentionForwardAmpere: + def __init__( + self, + head_dim: int, + m_block_size: int = 128, + n_block_size: int = 128, + num_threads: int = 128, + is_causal: bool = False, + ): + """Initializes the configuration for a flash attention v2 kernel. + + All contiguous dimensions must be at least 16 bytes aligned which indicates the head dimension + should be a multiple of 8. + + :param head_dim: head dimension + :type head_dim: int + :param m_block_size: m block size + :type m_block_size: int + :param n_block_size: n block size + :type n_block_size: int + :param num_threads: number of threads + :type num_threads: int + :param is_causal: is causal + """ + self._head_dim = head_dim + self._m_block_size = m_block_size + self._n_block_size = n_block_size + # padding head_dim to a multiple of 32 as k_block_size + self._head_dim_padded = (head_dim + 31) // 32 * 32 + self._num_threads = num_threads + self._is_causal = is_causal + + @staticmethod + def can_implement( + dtype, head_dim, m_block_size, n_block_size, num_threads, is_causal + ) -> bool: + """Check if the kernel can be implemented with the given parameters. + + :param dtype: data type + :type dtype: cutlass.Numeric + :param head_dim: head dimension + :type head_dim: int + :param m_block_size: m block size + :type m_block_size: int + :param n_block_size: n block size + :type n_block_size: int + :param num_threads: number of threads + :type num_threads: int + :param is_causal: is causal + :type is_causal: bool + + :return: True if the kernel can be implemented, False otherwise + :rtype: bool + """ + # Check if data type is fp16 or bf16 + if dtype != cutlass.Float16 and dtype != cutlass.BFloat16: + return False + + # Check if head dimension is a multiple of 8 + if head_dim % 8 != 0: + return False + + # Check if number of threads is a multiple of 32 + if num_threads % 32 != 0: + return False + + # Check if block size setting is out of shared memory capacity + # Shared memory usage: Q tile + (K tile + V tile) where K and V use the same tile size + smem_usage = (m_block_size * head_dim + n_block_size * head_dim * 2) * 2 + smem_capacity = utils.get_smem_capacity_in_bytes("sm_80") + if smem_usage > smem_capacity: + return False + + # Check if twice the block size is divisible by the number of threads + if (m_block_size * 2) % num_threads != 0: + return False + + return True + + @cute.jit + def __call__( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mO: cute.Tensor, + softmax_scale: cutlass.Float32, + stream: cuda.CUstream, + ): + """Configures and launches the flash attention v2 kernel. + + mQ/mK/mV/mO has same data types(supports fp16 and bf16) and same layout: + (batch_size, seqlen_q, num_head, head_dim):(seqlen_q * num_head * head_dim, num_head * head_dim, head_dim, 1) + + Prepares the shared memory layout, tiled copy atoms, tiled mma and shared memory storage. + Then launches the kernel function with the prepared parameters. + + :param mQ: query tensor + :type mQ: cute.Tensor + :param mK: key tensor + :type mK: cute.Tensor + :param mV: value tensor + :type mV: cute.Tensor + :param mO: output tensor + :type mO: cute.Tensor + :param softmax_scale: softmax scale + :type softmax_scale: cutlass.Float32 + """ + # Get the data type and check if it is fp16 or bf16 + if cutlass.const_expr( + not ( + mQ.element_type == mK.element_type == mV.element_type == mO.element_type + ) + ): + raise TypeError("All tensors must have the same data type") + if cutlass.const_expr( + not ( + mQ.element_type == cutlass.Float16 + or mQ.element_type == cutlass.BFloat16 + ) + ): + raise TypeError("Only Float16 or BFloat16 is supported") + self._dtype: Type[cutlass.Numeric] = mQ.element_type + # /////////////////////////////////////////////////////////////////////////////// + # Shared memory layout: Q/K/V + # /////////////////////////////////////////////////////////////////////////////// + smem_k_block_size = 64 if self._head_dim_padded % 64 == 0 else 32 + swizzle_bits = 3 if smem_k_block_size == 64 else 2 + sQ_layout_atom = cute.make_composed_layout( + cute.make_swizzle(swizzle_bits, 3, 3), + 0, + cute.make_layout((8, smem_k_block_size), stride=(smem_k_block_size, 1)), + ) + sQ_layout = cute.tile_to_shape( + sQ_layout_atom, + (self._m_block_size, self._head_dim_padded), + (0, 1), + ) + + sKV_layout_atom = sQ_layout_atom + sKV_layout = cute.tile_to_shape( + sKV_layout_atom, + (self._n_block_size, self._head_dim_padded), + (0, 1), + ) + + sO_layout = sQ_layout + + @cute.struct + class SharedStorage: + sQ: cute.struct.Align[ + cute.struct.MemRange[self._dtype, cute.cosize(sQ_layout)], 1024 + ] + sK: cute.struct.Align[ + cute.struct.MemRange[self._dtype, cute.cosize(sKV_layout)], 1024 + ] + sV: cute.struct.Align[ + cute.struct.MemRange[self._dtype, cute.cosize(sKV_layout)], 1024 + ] + + # /////////////////////////////////////////////////////////////////////////////// + # GMEM Tiled copy: + # /////////////////////////////////////////////////////////////////////////////// + # Thread layouts for copies + universal_copy_bits = 128 + async_copy_elems = universal_copy_bits // self._dtype.width + # atom_async_copy: async copy atom for QKV load + atom_async_copy = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), + self._dtype, + num_bits_per_copy=universal_copy_bits, + ) + # atom_universal_copy: universal copy atom for O store + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self._dtype, + num_bits_per_copy=universal_copy_bits, + ) + # tQKV_layout: thread layout for QKV load + tQKV_shape_dim_1 = sQ_layout_atom.outer.shape[1] // async_copy_elems + tQKV_layout = cute.make_layout( + (self._num_threads // tQKV_shape_dim_1, tQKV_shape_dim_1), + stride=(tQKV_shape_dim_1, 1), + ) + # tO_layout: thread layout for O store + tO_layout = tQKV_layout + + # Value layouts for copies + vQKV_layout = cute.make_layout((1, async_copy_elems)) + vO_layout = vQKV_layout + + # gmem_tiled_copy_QKV: tiled copy for QKV load + gmem_tiled_copy_QKV = cute.make_tiled_copy_tv( + atom_async_copy, tQKV_layout, vQKV_layout + ) + # gmem_tiled_copy_O: tiled copy for O store + gmem_tiled_copy_O = cute.make_tiled_copy_tv( + atom_universal_copy, tO_layout, vO_layout + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Tiled mma + # /////////////////////////////////////////////////////////////////////////////// + tiled_mma = cute.make_tiled_mma( + warp.MmaF16BF16Op(self._dtype, cutlass.Float32, (16, 8, 16)), + (self._num_threads // 32, 1, 1), + permutation_mnk=(self._num_threads // 32 * 16, 16, 16), + ) + + # grid_dim: (m_block, batch_size, num_head) + grid_dim = ( + cute.ceil_div(mQ.shape[1], self._m_block_size), + cute.size(mQ.shape[0]), + cute.size(mQ.shape[2]), + ) + LOG2_E = 1.4426950408889634074 + softmax_scale_log2 = softmax_scale * LOG2_E + self.kernel( + mQ, + mK, + mV, + mO, + softmax_scale_log2, + sQ_layout, + sKV_layout, + sO_layout, + gmem_tiled_copy_QKV, + gmem_tiled_copy_O, + tiled_mma, + SharedStorage, + ).launch( + grid=grid_dim, + block=[self._num_threads, 1, 1], + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mO: cute.Tensor, + softmax_scale_log2: cutlass.Float32, + sQ_layout: cute.ComposedLayout, + sKV_layout: cute.ComposedLayout, + sO_layout: cute.ComposedLayout, + gmem_tiled_copy_QKV: cute.TiledCopy, + gmem_tiled_copy_O: cute.TiledCopy, + tiled_mma: cute.TiledMma, + SharedStorage: cutlass.Constexpr, + ): + """Kernel function for flash attention v2. + + :param mQ: query tensor + :type mQ: cute.Tensor + :param mK: key tensor + :type mK: cute.Tensor + :param mV: value tensor + :type mV: cute.Tensor + :param mO: output tensor + :type mO: cute.Tensor + :param softmax_scale_log2: softmax scale log2 + :type softmax_scale_log2: cutlass.Float32 + :param sQ_layout: query layout + :type sQ_layout: cute.ComposedLayout + :param sKV_layout: key/value layout + :type sKV_layout: cute.ComposedLayout + :param sO_layout: output layout + :type sO_layout: cute.ComposedLayout + :param gmem_tiled_copy_QKV: tiled copy for QKV load + :type gmem_tiled_copy_QKV: cute.TiledCopy + :param gmem_tiled_copy_O: tiled copy for O store + :type gmem_tiled_copy_O: cute.TiledCopy + :param tiled_mma: tiled mma + :type tiled_mma: cute.TiledMma + :param SharedStorage: shared storage + :type SharedStorage: cutlass.Constexpr + """ + # Thread index, block index + tidx, _, _ = cute.arch.thread_idx() + m_block, batch_size, num_head = cute.arch.block_idx() + + n_block_max = cute.ceil_div(mK.shape[1], self._n_block_size) + if self._is_causal: + n_block_max = min( + cute.ceil_div( + (m_block + 1) * self._m_block_size, + self._n_block_size, + ), + n_block_max, + ) + n_block = n_block_max - 1 + + # /////////////////////////////////////////////////////////////////////////////// + # Get the appropriate tiles for this thread block. + # /////////////////////////////////////////////////////////////////////////////// + # (m_block_size, head_dim) + gQ = cute.local_tile( + mQ[batch_size, None, num_head, None], + (self._m_block_size, self._head_dim_padded), + (m_block, 0), + ) + # (n_block_size, head_dim, n_block) + gK = cute.local_tile( + mK[batch_size, None, num_head, None], + (self._n_block_size, self._head_dim_padded), + (None, 0), + ) + # (n_block_size, head_dim, n_block) + gV = cute.local_tile( + mV[batch_size, None, num_head, None], + (self._n_block_size, self._head_dim_padded), + (None, 0), + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Get shared memory buffer + # /////////////////////////////////////////////////////////////////////////////// + smem = cutlass.utils.SmemAllocator() + + storage = smem.allocate(SharedStorage) + sQ = storage.sQ.get_tensor(sQ_layout) + sK = storage.sK.get_tensor(sKV_layout) + sV = storage.sV.get_tensor(sKV_layout) + + # Transpose view of V to tensor with layout (head_dim, n_block_size) for tiled mma + sVt = cute.composition( + sV, + cute.make_layout( + (self._head_dim_padded, self._n_block_size), + stride=(self._n_block_size, 1), + ), + ) + + gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_slice(tidx) + # (CPY_Atom, CPY_M, CPY_K) + tQgQ = gmem_thr_copy_QKV.partition_S(gQ) + tQsQ = gmem_thr_copy_QKV.partition_D(sQ) + # (CPY_Atom, CPY_N, CPY_K, n_block) + tKgK = gmem_thr_copy_QKV.partition_S(gK) + tKsK = gmem_thr_copy_QKV.partition_D(sK) + # (CPY_Atom, CPY_N, CPY_K, n_block) + tVgV = gmem_thr_copy_QKV.partition_S(gV) + tVsV = gmem_thr_copy_QKV.partition_D(sV) + + # /////////////////////////////////////////////////////////////////////////////// + # Tile MMA compute thread partitions and allocate accumulators + # /////////////////////////////////////////////////////////////////////////////// + thr_mma = tiled_mma.get_slice(tidx) + tSrQ = thr_mma.make_fragment_A(thr_mma.partition_A(sQ)) + tSrK = thr_mma.make_fragment_B(thr_mma.partition_B(sK)) + tOrVt = thr_mma.make_fragment_B(thr_mma.partition_B(sVt)) + acc_shape_O = thr_mma.partition_shape_C( + (self._m_block_size, self._head_dim_padded) + ) + acc_O = cute.make_fragment(acc_shape_O, cutlass.Float32) + acc_O.fill(0.0) + + # /////////////////////////////////////////////////////////////////////////////// + # Smem copy atom tiling + # /////////////////////////////////////////////////////////////////////////////// + smem_copy_atom_Q = cute.make_copy_atom( + warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4), + self._dtype, + ) + smem_copy_atom_K = cute.make_copy_atom( + warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4), + self._dtype, + ) + smem_copy_atom_V = cute.make_copy_atom( + warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4), + self._dtype, + ) + smem_tiled_copy_Q = cute.make_tiled_copy_A(smem_copy_atom_Q, tiled_mma) + smem_tiled_copy_K = cute.make_tiled_copy_B(smem_copy_atom_K, tiled_mma) + smem_tiled_copy_V = cute.make_tiled_copy_B(smem_copy_atom_V, tiled_mma) + + smem_thr_copy_Q = smem_tiled_copy_Q.get_slice(tidx) + smem_thr_copy_K = smem_tiled_copy_K.get_slice(tidx) + smem_thr_copy_V = smem_tiled_copy_V.get_slice(tidx) + + tSsQ = smem_thr_copy_Q.partition_S(sQ) + tSrQ_copy_view = smem_thr_copy_Q.retile(tSrQ) + tSsK = smem_thr_copy_K.partition_S(sK) + tSrK_copy_view = smem_thr_copy_K.retile(tSrK) + tOsVt = smem_thr_copy_V.partition_S(sVt) + tOrVt_copy_view = smem_thr_copy_V.retile(tOrVt) + + # /////////////////////////////////////////////////////////////////////////////// + # Predicate: Mark indices that need to copy when problem_shape isn't a multiple + # of tile_shape + # /////////////////////////////////////////////////////////////////////////////// + # Construct identity layout for Q and KV + mcQ = cute.make_identity_tensor(mQ.layout.shape) + mcKV = cute.make_identity_tensor(mK.layout.shape) + cQ = cute.local_tile( + mcQ[batch_size, None, num_head, None], + (self._m_block_size, self._head_dim_padded), + (m_block, 0), + ) + cKV = cute.local_tile( + mcKV[batch_size, None, num_head, None], + (self._n_block_size, self._head_dim_padded), + (n_block, 0), + ) + + # Repeat the partitioning with identity layouts + tQcQ = gmem_thr_copy_QKV.partition_S(cQ) + tKVcKV = gmem_thr_copy_QKV.partition_S(cKV) + # Allocate predicate tensors for m and n, here we only allocate the tile of k, and do special process for mn. + # This is to reduce register pressure and gets 2-3% performance gain compared with allocating the whole tile. + tQpQ = cute.make_fragment( + cute.make_layout( + ( + tQsQ.shape[0][1], + cute.size(tQsQ, mode=[1]), + cute.size(tQsQ, mode=[2]), + ), + stride=(cute.size(tQsQ, mode=[2]), 0, 1), + ), + cutlass.Boolean, + ) + tKVpKV = cute.make_fragment( + cute.make_layout( + ( + tKsK.shape[0][1], + cute.size(tKsK, mode=[1]), + cute.size(tKsK, mode=[2]), + ), + stride=(cute.size(tKsK, mode=[2]), 0, 1), + ), + cutlass.Boolean, + ) + # Set predicates for head_dim bounds, seqlen_q/k bounds is processed at the first tile. + for rest_v in cutlass.range_constexpr(tQpQ.shape[0]): + for rest_k in cutlass.range_constexpr(tQpQ.shape[2]): + tQpQ[rest_v, 0, rest_k] = cute.elem_less( + tQcQ[(0, rest_v), 0, rest_k][3], mQ.layout.shape[3] + ) + for rest_v in cutlass.range_constexpr(tKVpKV.shape[0]): + for rest_k in cutlass.range_constexpr(tKVpKV.shape[2]): + tKVpKV[rest_v, 0, rest_k] = cute.elem_less( + tKVcKV[(0, rest_v), 0, rest_k][3], mK.layout.shape[3] + ) + # /////////////////////////////////////////////////////////////////////////////// + # Prefetch Prologue + # /////////////////////////////////////////////////////////////////////////////// + # Start async loads of the last mn-tile, where we take care of the mn residue + for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])): + if cute.elem_less(tQcQ[0, m, 0][1], mQ.layout.shape[1]): + cute.copy( + gmem_tiled_copy_QKV, + tQgQ[None, m, None], + tQsQ[None, m, None], + pred=tQpQ[None, m, None], + ) + else: + # Clear the smem tiles to account for predicated off loads + tQsQ[None, m, None].fill(0) + for n in cutlass.range_constexpr(cute.size(tKsK.shape[1])): + if cute.elem_less(tKVcKV[0, n, 0][1], mK.layout.shape[1]): + cute.copy( + gmem_tiled_copy_QKV, + tKgK[None, n, None, n_block], + tKsK[None, n, None], + pred=tKVpKV[None, n, None], + ) + else: + # Clear the smem tiles to account for predicated off loads + tKsK[None, n, None].fill(0) + + cute.arch.cp_async_commit_group() + # /////////////////////////////////////////////////////////////////////////////// + # Softmax intermediate result: row_max and row_sum + # /////////////////////////////////////////////////////////////////////////////// + # shape: (atom_v_m * rest_m) + row_max = cute.make_fragment( + (acc_O.shape[0][0] * acc_O.shape[1]), cutlass.Float32 + ) + # shape: (atom_v_m * rest_m) + row_sum = cute.make_fragment( + (acc_O.shape[0][0] * acc_O.shape[1]), cutlass.Float32 + ) + row_max.fill(-cutlass.Float32.inf) + row_sum.fill(0.0) + + # group parameters for compute_one_n_block + basic_params = SimpleNamespace( + m_block=m_block, + n_block=n_block, + mQ=mQ, + mK=mK, + batch_size=batch_size, + num_head=num_head, + ) + mma_params = SimpleNamespace( + thr_mma=thr_mma, + tiled_mma=tiled_mma, + tSrQ=tSrQ, + tSrK=tSrK, + tOrVt=tOrVt, + acc_O=acc_O, + ) + gmem_copy_params = SimpleNamespace( + gmem_tiled_copy_QKV=gmem_tiled_copy_QKV, + tKVcKV=tKVcKV, + tKgK=tKgK, + tKsK=tKsK, + tVgV=tVgV, + tVsV=tVsV, + tKVpKV=tKVpKV, + ) + smem_copy_params = SimpleNamespace( + smem_tiled_copy_Q=smem_tiled_copy_Q, + smem_tiled_copy_K=smem_tiled_copy_K, + smem_tiled_copy_V=smem_tiled_copy_V, + tSsQ=tSsQ, + tSrQ_copy_view=tSrQ_copy_view, + tSsK=tSsK, + tSrK_copy_view=tSrK_copy_view, + tOsVt=tOsVt, + tOrVt_copy_view=tOrVt_copy_view, + ) + softmax_params = SimpleNamespace( + row_max=row_max, + row_sum=row_sum, + softmax_scale_log2=softmax_scale_log2, + ) + + # Start processing of the first n-block. + # For performance reason, we separate out two kinds of iterations: + # those that need masking on S, and those that don't. + # We need masking on S for the very last block when K and V has length not multiple of n_block_size. + # We also need masking on S if it's causal, for the last ceil_div(m_block_size, n_block_size) blocks. + # We will have at least 1 "masking" iteration. + mask_steps = 1 + if cutlass.const_expr(self._is_causal): + mask_steps = cute.ceil_div(self._m_block_size, self._n_block_size) + + for n_tile in cutlass.range_constexpr(mask_steps): + n_block = n_block_max - n_tile - 1 + basic_params.n_block = n_block + if cutlass.const_expr(self._is_causal): + if n_block >= 0: + self.compute_one_n_block( + basic_params, + mma_params, + gmem_copy_params, + smem_copy_params, + softmax_params, + is_first_n_block=(n_tile == 0), + in_mask_steps=True, + ) + else: + self.compute_one_n_block( + basic_params, + mma_params, + gmem_copy_params, + smem_copy_params, + softmax_params, + is_first_n_block=True, + in_mask_steps=True, + ) + + # Start async loads of rest k-tiles in reverse order, no k-residue handling needed + for n_tile in range(mask_steps, n_block_max, 1): + n_block = n_block_max - n_tile - 1 + basic_params.n_block = n_block + self.compute_one_n_block( + basic_params, + mma_params, + gmem_copy_params, + smem_copy_params, + softmax_params, + is_first_n_block=False, + in_mask_steps=False, + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Epilogue + # /////////////////////////////////////////////////////////////////////////////// + # normalize acc_O by row_sum and calculate the lse + self.normalize_softmax(acc_O, row_sum) + # store acc_O + rO = cute.make_fragment_like(acc_O, self._dtype) + rO.store(acc_O.load().to(self._dtype)) + # reuse sQ's data iterator + sO = cute.make_tensor(sQ.iterator, sO_layout) + + # smem copy atom for O + smem_copy_atom_O = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), self._dtype + ) + # tiled copy atom for O + smem_tiled_copy_O = cute.make_tiled_copy_C(smem_copy_atom_O, tiled_mma) + smem_thr_copy_O = smem_tiled_copy_O.get_slice(tidx) + taccOrO = smem_thr_copy_O.retile(rO) + taccOsO = smem_thr_copy_O.partition_D(sO) + # copy acc O from rmem to smem with the smem copy atom + cute.copy( + smem_copy_atom_O, + taccOrO, + taccOsO, + ) + gO = cute.local_tile( + mO[batch_size, None, num_head, None], + (self._m_block_size, self._head_dim_padded), + (m_block, 0), + ) + + gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) + tOsO = gmem_thr_copy_O.partition_S(sO) + tOgO = gmem_thr_copy_O.partition_D(gO) + tOrO = cute.make_fragment_like(tOgO, self._dtype) + # sync before all smem stores are done. + cute.arch.barrier() + # load acc O from smem to rmem for wider vectorization + cute.copy( + gmem_tiled_copy_O, + tOsO, + tOrO, + ) + mcO = cute.make_identity_tensor(mO.layout.shape) + cO = cute.local_tile( + mcO[batch_size, None, num_head, None], + (self._m_block_size, self._head_dim_padded), + (m_block, 0), + ) + tOcO = gmem_thr_copy_O.partition_D(cO) + tOpO = cute.make_fragment( + cute.make_layout( + (tOgO.shape[0][1], tOgO.shape[1], tOgO.shape[2]), + stride=(tOgO.shape[2], 0, 1), + ), + cutlass.Boolean, + ) + for rest_v in cutlass.range_constexpr(tOpO.shape[0]): + for rest_n in cutlass.range_constexpr(cute.size(tOpO.shape[2])): + tOpO[rest_v, 0, rest_n] = cute.elem_less( + tOcO[(0, rest_v), 0, rest_n][3], mO.layout.shape[3] + ) + # copy acc O from rmem to gmem + for rest_m in cutlass.range_constexpr(cute.size(tOpO.shape[1])): + if cute.elem_less(tOcO[0, rest_m, 0][1], mO.layout.shape[1]): + cute.copy( + gmem_tiled_copy_O, + tOrO[None, rest_m, None], + tOgO[None, rest_m, None], + pred=tOpO[None, rest_m, None], + ) + + @cute.jit + def compute_one_n_block( + self, + basic_params: SimpleNamespace, + mma_params: SimpleNamespace, + gmem_copy_params: SimpleNamespace, + smem_copy_params: SimpleNamespace, + softmax_params: SimpleNamespace, + is_first_n_block: cutlass.Constexpr, + in_mask_steps: cutlass.Constexpr, + ): + """Compute one n_block of S/O. + + This function provides different variants for processing the first n block versus subsequent blocks, + as well as variants for handling masked and unmasked steps. + + :param basic_params: basic parameters + :type basic_params: SimpleNamespace + :param mma_params: mma parameters + :type mma_params: SimpleNamespace + :param gmem_copy_params: gmem copy parameters + :type gmem_copy_params: SimpleNamespace + :param smem_copy_params: smem copy parameters + :type smem_copy_params: SimpleNamespace + :param softmax_params: softmax parameters + :type softmax_params: SimpleNamespace + :param is_first_n_block: is first n block + :type is_first_n_block: cutlass.Constexpr + """ + acc_shape_S = mma_params.thr_mma.partition_shape_C( + (self._m_block_size, self._n_block_size) + ) + acc_S = cute.make_fragment(acc_shape_S, cutlass.Float32) + acc_S.fill(0.0) + + # wait for smem tile QK before mma calculation for S + cute.arch.cp_async_wait_group(0) + cute.arch.barrier() + # load smem tile V for O, special process for the first tile to avoid loading nan. + # The `if` here is a constexpr, won't be generated in the IR. + if is_first_n_block: + for n in cutlass.range_constexpr(cute.size(gmem_copy_params.tVsV.shape[1])): + if cute.elem_less( + gmem_copy_params.tKVcKV[0, n, 0][1], + basic_params.mK.layout.shape[1], + ): + cute.copy( + gmem_copy_params.gmem_tiled_copy_QKV, + gmem_copy_params.tVgV[None, n, None, basic_params.n_block], + gmem_copy_params.tVsV[None, n, None], + pred=gmem_copy_params.tKVpKV[None, n, None], + ) + else: + gmem_copy_params.tVsV[None, n, None].fill(0.0) + else: + cute.copy( + gmem_copy_params.gmem_tiled_copy_QKV, + gmem_copy_params.tVgV[None, None, None, basic_params.n_block], + gmem_copy_params.tVsV, + pred=gmem_copy_params.tKVpKV, + ) + + cute.arch.cp_async_commit_group() + # /////////////////////////////////////////////////////////////////////////////// + # S gemm calculation + # /////////////////////////////////////////////////////////////////////////////// + # load first QK k-block from smem to rmem for mma + cute.copy( + smem_copy_params.smem_tiled_copy_Q, + smem_copy_params.tSsQ[None, None, 0], + smem_copy_params.tSrQ_copy_view[None, None, 0], + ) + cute.copy( + smem_copy_params.smem_tiled_copy_K, + smem_copy_params.tSsK[None, None, 0], + smem_copy_params.tSrK_copy_view[None, None, 0], + ) + # mma for S + for k in cutlass.range_constexpr(cute.size(smem_copy_params.tSsQ.shape[2])): + # load next QK k-block from smem to rmem for mma + k_next = (k + 1) % cute.size(smem_copy_params.tSsQ.shape[2]) + cute.copy( + smem_copy_params.smem_tiled_copy_Q, + smem_copy_params.tSsQ[None, None, k_next], + smem_copy_params.tSrQ_copy_view[None, None, k_next], + ) + cute.copy( + smem_copy_params.smem_tiled_copy_K, + smem_copy_params.tSsK[None, None, k_next], + smem_copy_params.tSrK_copy_view[None, None, k_next], + ) + cute.gemm( + mma_params.tiled_mma, + acc_S, + mma_params.tSrQ[None, None, k], + mma_params.tSrK[None, None, k], + acc_S, + ) + + # wait for smem tile V for O + cute.arch.cp_async_wait_group(0) + cute.arch.barrier() + + if basic_params.n_block > 0: + cute.copy( + gmem_copy_params.gmem_tiled_copy_QKV, + gmem_copy_params.tKgK[None, None, None, basic_params.n_block - 1], + gmem_copy_params.tKsK, + pred=gmem_copy_params.tKVpKV, + ) + cute.arch.cp_async_commit_group() + # /////////////////////////////////////////////////////////////////////////////// + # online softmax + # /////////////////////////////////////////////////////////////////////////////// + self.softmax_rescale_O( + basic_params, + mma_params, + softmax_params, + acc_S, + is_first_n_block, + in_mask_steps, + ) + + rP = cute.make_fragment_like(acc_S, self._dtype) + rP.store(acc_S.load().to(self._dtype)) + # /////////////////////////////////////////////////////////////////////////////// + # O gemm calculation + # /////////////////////////////////////////////////////////////////////////////// + # Convert layout of acc_S to gemm O accept layout. + # Due to the mma instruction shape is 16x8x16, we need to convert from (4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + # (4, MMA_M, MMA_N) -> (4, MMA_M, (2, MMA_N / 2)) + rP_layout_divided = cute.logical_divide(rP.layout, (None, None, 2)) + rP_mma_view = cute.make_layout( + ( + (rP_layout_divided.shape[0], rP_layout_divided.shape[2][0]), + rP_layout_divided.shape[1], + rP_layout_divided.shape[2][1], + ), + stride=( + (rP_layout_divided.stride[0], rP_layout_divided.stride[2][0]), + rP_layout_divided.stride[1], + rP_layout_divided.stride[2][1], + ), + ) + tOrS = cute.make_tensor(rP.iterator, rP_mma_view) + + # load first V k-block from smem to rmem for mma + cute.copy( + smem_copy_params.smem_tiled_copy_V, + smem_copy_params.tOsVt[None, None, 0], + smem_copy_params.tOrVt_copy_view[None, None, 0], + ) + # mma for O + for k in cutlass.range_constexpr(cute.size(tOrS.shape[2])): + # load next V k-block from smem to rmem for mma + k_next = (k + 1) % cute.size(tOrS.shape[2]) + cute.copy( + smem_copy_params.smem_tiled_copy_V, + smem_copy_params.tOsVt[None, None, k_next], + smem_copy_params.tOrVt_copy_view[None, None, k_next], + ) + cute.gemm( + mma_params.tiled_mma, + mma_params.acc_O, + tOrS[None, None, k], + mma_params.tOrVt[None, None, k], + mma_params.acc_O, + ) + + @cute.jit + def softmax_rescale_O( + self, + basic_params: SimpleNamespace, + mma_params: SimpleNamespace, + softmax_params: SimpleNamespace, + acc_S: cute.Tensor, + is_first_n_block: cutlass.Constexpr, + in_mask_steps: cutlass.Constexpr, + ): + """Apply online softmax and rescale acc_O. + + This function provides different variants for processing the first n block versus subsequent blocks, + as well as variants for handling masked and unmasked steps. + + :param basic_params: basic parameters + :type basic_params: SimpleNamespace + :param mma_params: mma parameters + :type mma_params: SimpleNamespace + :param softmax_params: softmax parameters + :type softmax_params: SimpleNamespace + :param acc_S: acc_S tensor + :type acc_S: cute.Tensor + :param is_first_n_block: is first n_block + :type is_first_n_block: cutlass.Constexpr + :param in_mask_steps: in mask steps + :type in_mask_steps: cutlass.Constexpr + """ + # Change acc_S to M,N layout view. + acc_S_mn = self._make_acc_tensor_mn_view(acc_S) + acc_O_mn = self._make_acc_tensor_mn_view(mma_params.acc_O) + row_max_prev = None + # if it is not the first tile, load the row r of previous row_max and compare with row_max_cur_row. + if cutlass.const_expr(not is_first_n_block): + row_max_prev = cute.make_fragment_like( + softmax_params.row_max, cutlass.Float32 + ) + cute.basic_copy(softmax_params.row_max, row_max_prev) + # if it is the first tile, create a mask for residual of S to -inf for softmax. + tScS_mn = None + if cutlass.const_expr(in_mask_steps): + mcS = cute.make_identity_tensor( + ( + basic_params.mQ.shape[0], + basic_params.mQ.shape[1], + basic_params.mQ.shape[2], + basic_params.mK.shape[1], + ) + ) + cS = cute.local_tile( + mcS[basic_params.batch_size, None, basic_params.num_head, None], + (self._m_block_size, self._n_block_size), + (basic_params.m_block, basic_params.n_block), + ) + tScS = mma_params.thr_mma.partition_C(cS) + tScS_mn = self._make_acc_tensor_mn_view(tScS) + + # Each iteration processes one row of acc_S + for r in cutlass.range_constexpr(cute.size(softmax_params.row_max)): + # mask residual of S with -inf + if cutlass.const_expr(in_mask_steps): + if cutlass.const_expr(not self._is_causal): + # traverse column index. + for c in cutlass.range_constexpr(cute.size(tScS_mn.shape[1])): + if cute.elem_less( + basic_params.mK.shape[1], tScS_mn[0, c][3] + 1 + ): + acc_S_mn[r, c] = -cutlass.Float32.inf + else: + # get the column index limit based on current row. Only consider the row index, so the column index sets to 0. + col_idx_limit = cutlass.min( + tScS_mn[r, 0][1] + 1, basic_params.mK.shape[1] + ) + # traverse column index. + for c in cutlass.range_constexpr(cute.size(tScS_mn.shape[1])): + # only consider the column index, so the row index sets to 0. + if cute.elem_less(col_idx_limit, tScS_mn[0, c][3] + 1): + acc_S_mn[r, c] = -cutlass.Float32.inf + + # (n_block_size) + acc_S_row = acc_S_mn[r, None].load() + # row_max_cur_row => f32 + row_max_cur_row = acc_S_row.reduce( + cute.ReductionOp.MAX, -cutlass.Float32.inf, 0 + ) + # quad reduction for row_max + row_max_cur_row = self._threadquad_reduce_max(row_max_cur_row) + row_max_prev_row = None + # if it is not the first tile, load the row r of previous row_max and compare with row_max_cur_row. + if cutlass.const_expr(not is_first_n_block): + row_max_prev_row = row_max_prev[r] + row_max_cur_row = cute.arch.fmax(row_max_prev_row, row_max_cur_row) + if cutlass.const_expr(self._is_causal): + row_max_cur_row = ( + 0.0 if row_max_cur_row == -cutlass.Float32.inf else row_max_cur_row + ) + + # compute exp(x - max) using exp2(x * log_2(e) - max * log_2(e)) + acc_S_row_exp = cute.math.exp2( + acc_S_row * softmax_params.softmax_scale_log2 + - row_max_cur_row * softmax_params.softmax_scale_log2, + fastmath=True, + ) + # acc_S_row_sum => f32 + acc_S_row_sum = acc_S_row_exp.reduce( + cute.ReductionOp.ADD, cutlass.Float32.zero, 0 + ) + # if it is not the first tile, load the row r of previous row_max and minus row_max_cur_row to update row_sum. + if cutlass.const_expr(not is_first_n_block): + prev_minus_cur_exp = cute.math.exp2( + row_max_prev_row * softmax_params.softmax_scale_log2 + - row_max_cur_row * softmax_params.softmax_scale_log2, + fastmath=True, + ) + acc_S_row_sum = ( + acc_S_row_sum + softmax_params.row_sum[r] * prev_minus_cur_exp + ) + acc_O_mn[r, None] = acc_O_mn[r, None].load() * prev_minus_cur_exp + # update row_max, row_sum and acc_S + softmax_params.row_max[r] = row_max_cur_row + softmax_params.row_sum[r] = acc_S_row_sum + acc_S_mn[r, None] = acc_S_row_exp + + @cute.jit + def normalize_softmax( + self, + acc_O: cute.Tensor, + row_sum: cute.Tensor, + ): + """Normalize acc_O by row_sum. + + :param acc_O: input tensor + :type acc_O: cute.Tensor + :param row_sum: row_sum tensor + :type row_sum: cute.Tensor + """ + # do quad reduction for row_sum. + acc_O_mn = self._make_acc_tensor_mn_view(acc_O) + for r in cutlass.range_constexpr(cute.size(row_sum)): + row_sum[r] = self._threadquad_reduce_sum(row_sum[r]) + # if row_sum is zero or nan, set acc_O_mn_row to 1.0 + acc_O_mn_row_is_zero_or_nan = row_sum[r] == 0.0 or row_sum[r] != row_sum[r] + + scale = ( + 1.0 if acc_O_mn_row_is_zero_or_nan else cute.arch.rcp_approx(row_sum[r]) + ) + + acc_O_mn[r, None] = acc_O_mn[r, None].load() * scale + + def _make_acc_tensor_mn_view(self, acc: cute.Tensor) -> cute.Tensor: + """make acc tensor as mn layout view + + :param acc: input tensor + :type acc: cute.Tensor + :return: acc tensor mn layout view + :rtype: cute.Tensor + """ + acc_layout_col_major = cute.make_layout(acc.layout.shape) + acc_layout_mn = cute.make_layout( + ( + ( + acc_layout_col_major.shape[0][1], + acc_layout_col_major.shape[1], + ), # MMA_M + ( + acc_layout_col_major.shape[0][0], + acc_layout_col_major.shape[2], + ), # MMA_N + ), + stride=( + ( + acc_layout_col_major.stride[0][1], + acc_layout_col_major.stride[1], + ), # MMA_M + ( + acc_layout_col_major.stride[0][0], + acc_layout_col_major.stride[2], + ), # MMA_N + ), + ) + acc_layout_mn = cute.composition(acc.layout, acc_layout_mn) + return cute.make_tensor(acc.iterator, acc_layout_mn) + + def _threadquad_reduce(self, val: cutlass.Float32, op: Callable) -> cutlass.Float32: + """thread quad reduction + + :param val: register value + :type val: cutlass.Float32 + :param op: binary operator + :type op: Callable + :return: reduced value + :rtype: cutlass.Float32 + """ + val = op( + val, + cute.arch.shuffle_sync_bfly(val, offset=2, mask=-1, mask_and_clamp=31), + ) + val = op( + val, + cute.arch.shuffle_sync_bfly(val, offset=1, mask=-1, mask_and_clamp=31), + ) + return val + + def _threadquad_reduce_max(self, val: cutlass.Float32) -> cutlass.Float32: + """thread quad reduction max + + :param val: register value + :type val: cutlass.Float32 + :return: max value + :rtype: cutlass.Float32 + """ + return self._threadquad_reduce(val, lambda x, y: cute.arch.fmax(x, y)) + + def _threadquad_reduce_sum(self, val: cutlass.Float32) -> cutlass.Float32: + """thread quad reduction sum + + :param val: register value + :type val: cutlass.Float32 + :return: sum value + :rtype: cutlass.Float32 + """ + return self._threadquad_reduce(val, lambda x, y: x + y) + + +def run( + dtype: Type[cutlass.Numeric], + batch_size: int, + seqlen_q: int, + seqlen_k: int, + num_head: int, + head_dim: int, + softmax_scale: float = 1.0, + m_block_size: int = 128, + n_block_size: int = 128, + num_threads: int = 128, + is_causal: bool = False, + warmup_iterations: int = 0, + iterations: int = 1, + skip_ref_check: bool = False, + use_cold_l2: bool = False, + **kwargs, +): + # Skip unsupported testcase + if not FlashAttentionForwardAmpere.can_implement( + dtype, + head_dim, + m_block_size, + n_block_size, + num_threads, + is_causal, + ): + raise TypeError( + f"Unsupported testcase {dtype}, {head_dim}, {m_block_size}, {n_block_size}, {num_threads}, {is_causal}" + ) + + print(f"Running Ampere SM80 FlashAttentionForward test with:") + print(f" dtype: {dtype}") + print(f" batch_size: {batch_size}") + print(f" seqlen_q: {seqlen_q}") + print(f" seqlen_k: {seqlen_k}") + print(f" num_head: {num_head}") + print(f" head_dim: {head_dim}") + print(f" softmax_scale: {softmax_scale}") + print(f" m_block_size: {m_block_size}") + print(f" n_block_size: {n_block_size}") + print(f" num_threads: {num_threads}") + print(f" is_causal: {is_causal}") + print(f" warmup_iterations: {warmup_iterations}") + print(f" iterations: {iterations}") + print(f" skip_ref_check: {skip_ref_check}") + print(f" use_cold_l2: {use_cold_l2}") + + # Create tensor Q/K/V/O + def create_tensor( + batch_size: int, + seqlen: int, + num_head: int, + head_dim: int, + dtype: Type[cutlass.Numeric], + ) -> cute.Tensor: + # (batch_size, seqlen, num_head, head_dim) + shape = (batch_size, seqlen, num_head, head_dim) + torch_tensor = ( + torch.empty(*shape, dtype=torch.int32) + .random_(-2, 2) + .to(dtype=cutlass_torch.dtype(dtype)) + .cuda() + ) + # assume input is 16B aligned. + cute_tensor = ( + from_dlpack(torch_tensor, assumed_align=16) + .mark_layout_dynamic(leading_dim=3) + .mark_compact_shape_dynamic( + mode=3, + stride_order=torch_tensor.dim_order(), + divisibility=(128 // dtype.width), + ) + ) + return cute_tensor, torch_tensor + + q, q_torch = create_tensor(batch_size, seqlen_q, num_head, head_dim, dtype) + k, k_torch = create_tensor(batch_size, seqlen_k, num_head, head_dim, dtype) + v, v_torch = create_tensor(batch_size, seqlen_k, num_head, head_dim, dtype) + o, o_torch = create_tensor(batch_size, seqlen_q, num_head, head_dim, dtype) + + fa2_fwd = FlashAttentionForwardAmpere( + head_dim, + m_block_size, + n_block_size, + num_threads, + is_causal, + ) + + # Get current CUDA stream from PyTorch + torch_stream = torch.cuda.current_stream() + # Get the raw stream pointer as a CUstream + current_stream = cuda.CUstream(torch_stream.cuda_stream) + # compile the fa2 forward pass + compiled_fa2_fwd = cute.compile(fa2_fwd, q, k, v, o, softmax_scale, current_stream) + + if not skip_ref_check: + compiled_fa2_fwd(q, k, v, o, softmax_scale, current_stream) + torch.cuda.synchronize() + q_ref = q_torch.permute(0, 2, 1, 3) + k_ref = k_torch.permute(0, 2, 1, 3) + v_ref = v_torch.permute(0, 2, 1, 3) + torch.backends.cuda.enable_flash_sdp(enabled=True) + ref_o = torch.nn.functional.scaled_dot_product_attention( + q_ref, k_ref, v_ref, scale=softmax_scale, is_causal=is_causal + ).permute(0, 2, 1, 3) + torch.testing.assert_close(o_torch.cpu(), ref_o.cpu(), atol=1e-02, rtol=1e-04) + print("Results verified successfully!") + + def generate_tensors(): + q_workspace, _ = create_tensor(batch_size, seqlen_q, num_head, head_dim, dtype) + k_workspace, _ = create_tensor(batch_size, seqlen_k, num_head, head_dim, dtype) + v_workspace, _ = create_tensor(batch_size, seqlen_k, num_head, head_dim, dtype) + o_workspace, _ = create_tensor(batch_size, seqlen_q, num_head, head_dim, dtype) + return testing.JitArguments( + q_workspace, + k_workspace, + v_workspace, + o_workspace, + softmax_scale, + current_stream, + ) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + q_torch.numel() * q_torch.element_size() + + k_torch.numel() * k_torch.element_size() + + v_torch.numel() * v_torch.element_size() + + o_torch.numel() * o_torch.element_size() + ) + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + avg_time_us = testing.benchmark( + compiled_fa2_fwd, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=current_stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + return avg_time_us # Return execution time in microseconds + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="example of flash attention v2 with CuTe on GPU" + ) + parser.add_argument("--dtype", type=cutlass.dtype, default=cutlass.BFloat16) + parser.add_argument("--batch_size", type=int, default=4) + parser.add_argument("--seqlen_q", type=int, default=8192) + parser.add_argument("--seqlen_k", type=int, default=8192) + parser.add_argument("--num_head", type=int, default=16) + parser.add_argument("--head_dim", type=int, default=128) + parser.add_argument("--softmax_scale", type=float, default=0.5) + parser.add_argument("--m_block_size", type=int, default=128) + parser.add_argument("--n_block_size", type=int, default=64) + parser.add_argument("--num_threads", type=int, default=128) + parser.add_argument("--is_causal", action="store_true", help="Enable causal mask") + parser.add_argument("--warmup_iterations", type=int, default=3) + parser.add_argument("--iterations", type=int, default=10) + parser.add_argument( + "--skip_ref_check", action="store_true", help="Skip reference check" + ) + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", + ) + + args = parser.parse_args() + run( + args.dtype, + args.batch_size, + args.seqlen_q, + args.seqlen_k, + args.num_head, + args.head_dim, + args.softmax_scale, + args.m_block_size, + args.n_block_size, + args.num_threads, + args.is_causal, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.use_cold_l2, + ) + + print("PASS") diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/sgemm.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/sgemm.py new file mode 100644 index 0000000000000000000000000000000000000000..e7722a8d62a7529387dfbf738292b545738eb28a --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/sgemm.py @@ -0,0 +1,866 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +import time +from typing import Tuple + +import cuda.bindings.driver as cuda +import torch + +import cutlass +import cutlass.cute as cute +import cutlass.cute.testing as testing +import cutlass.torch as cutlass_torch +import cutlass.utils as utils +from cutlass.cute.runtime import from_dlpack + +""" +A dense FP32 SIMT GEMM (C = A * B) example using CUTE DSL. +- Matrix A is MxK, A can be row-major("K") or column-major("M") +- Matrix B is NxK, B can be row-major("N") or column-major("K") +- Matrix C is MxN, C can be row-major("N") or column-major("M") + +This GEMM kernel supports the following features: + - Utilizes FPU for matrix multiply-accumulate (MMA) operations + - Use multistage pipeline to overlap computation and memory access + * Shared memory pipeline: hides gmem-to-smem latency. + * Register pipeline: overlaps shared memory-to-register transfers with + computations and eliminates false data dependencies for + better parallelism. + - Use vectorized copies + - Add padding to reduce bank conflicts in global -> shared memory copies + - Use predication to avoid unnecessary copies or copies of stale data + +This GEMM works as follows: +1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using asynchronous copies. +2. Perform matrix multiply-accumulate (MMA) operations using simple fused multiply-add atomics. +3. Store results from registers (RMEM) to global memory (GMEM). + +To run this example: + +.. code-block:: bash + + python examples/ampere/sgemm.py \ + --mnk 8192,8192,8192 \ + --a_major m --b_major n --c_major n + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/ampere/sgemm.py \ + --mnk 8192,8192,8192 \ + --a_major m --b_major n --c_major n \ + --skip_ref_check --iterations 2 + +Constraints: +* Supported input, output, and accumulator data types: fp32 +* Default tile shape is set to be 128x128x8 +* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned +""" + + +class SGemm: + def __init__( + self, + cta_tiler: Tuple[int, int, int] = (128, 128, 8), + num_stages: int = 3, + num_threads: int = 256, + ): + self._cta_tiler = cta_tiler + self._num_stages = num_stages + self._num_threads = num_threads + assert num_threads > 0, "needs at least one thread" + assert num_threads % 16 == 0, "multiples of 16 required for MMA thread layout" + + self._bM, self._bN, self._bK = self._cta_tiler + assert self._bM % 16 == 0, "multiple of 16 required for tile dimension M" + assert self._bN % 16 == 0, "multiple of 16 required for tile dimension N" + assert self._num_stages >= 3, "num_stages must be greater than or equal to 3" + + @cute.jit + def __call__( + self, + mA: cute.Tensor, + mB: cute.Tensor, + mC: cute.Tensor, + epilogue_op: cutlass.Constexpr = lambda x: x, + stream: cuda.CUstream = cuda.CUstream(cuda.CUstream_flags.CU_STREAM_DEFAULT), + ): + self.a_major_mode = utils.LayoutEnum.from_tensor(mA) + self.b_major_mode = utils.LayoutEnum.from_tensor(mB) + self.c_major_mode = utils.LayoutEnum.from_tensor(mC) + + # /////////////////////////////////////////////////////////////////////////////// + # Create layouts for shared memory for A and B: + # - sA/sB is m/n-major to vectorized copies from shared + # memory to registers. This is because the MMA layouts + # for sA/sB are also m/n-major + # - When gA/gB is k-major, pad 4 elements to reduce bank conflicts + # /////////////////////////////////////////////////////////////////////////////// + + padding_a = 4 if self.a_major_mode == utils.LayoutEnum.ROW_MAJOR else 0 + padding_b = 4 if self.b_major_mode == utils.LayoutEnum.ROW_MAJOR else 0 + sA_layout = cute.make_layout( + (self._bM, self._bK, self._num_stages), + stride=(1, (self._bM + padding_a), self._bK * (self._bM + padding_a)), + ) + sB_layout = cute.make_layout( + (self._bN, self._bK, self._num_stages), + stride=(1, (self._bN + padding_b), self._bK * (self._bN + padding_b)), + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Create copy layouts that will be used for asynchronous + # global memory -> shared memory copies: + # - The majorness of tA/tB follows the majorness of gA/gB + # - For k-major, these layouts will copy values one-by-one from + # from global memory, without vectorizing + # - For m/n-major, it will vectorize to a 128bit copy for faster + # data transfer between global and shared memory, as long + # as the alignment of the tensor allows it. Otherwise, it + # defaults to a non-vectorized copy + # /////////////////////////////////////////////////////////////////////////////// + + tA = cute.make_layout( + (self._num_threads // self._bK, self._bK), stride=(self._bK, 1) + ) + tB = cute.make_layout( + (self._num_threads // self._bK, self._bK), stride=(self._bK, 1) + ) + vA = cute.make_layout((1, 1)) + vB = cute.make_layout((1, 1)) + atom_async_copy_A = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), + mA.element_type, + num_bits_per_copy=mA.element_type.width, + ) + atom_async_copy_B = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), + mA.element_type, + num_bits_per_copy=mB.element_type.width, + ) + + if cutlass.const_expr(self.a_major_mode == utils.LayoutEnum.COL_MAJOR): + num_vectorized = 4 if (mA.layout.max_alignment % 16 == 0) else 1 + atom_async_copy_A = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), + mA.element_type, + num_bits_per_copy=mA.element_type.width * num_vectorized, + ) + major_mode_size = self._bM // num_vectorized + tA = cute.make_layout( + (major_mode_size, self._num_threads // major_mode_size), + stride=(1, major_mode_size), + ) + vA = cute.make_layout((num_vectorized, 1)) + + if cutlass.const_expr(self.b_major_mode == utils.LayoutEnum.COL_MAJOR): + num_vectorized = 4 if (mB.layout.max_alignment % 16 == 0) else 1 + atom_async_copy_B = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), + mA.element_type, + num_bits_per_copy=mB.element_type.width * num_vectorized, + ) + major_mode_size = self._bN // num_vectorized + tB = cute.make_layout( + (major_mode_size, self._num_threads // major_mode_size), + stride=(1, major_mode_size), + ) + vB = cute.make_layout((num_vectorized, 1)) + + tiled_copy_A = cute.make_tiled_copy_tv(atom_async_copy_A, tA, vA) + tiled_copy_B = cute.make_tiled_copy_tv(atom_async_copy_B, tB, vB) + + # /////////////////////////////////////////////////////////////////////////////// + # Create layouts for GEMM: + # We tile an MMA atom across a tensor. `atoms_layout` is the layout + # of atoms in the tiled MMA. (Because we use an `MmaUniversalOp`, + # which has a trivial 1x1x1 MMA trait, `atoms_layout` is also + # simply the thread layout for C.) `permutation_tiler` reorders the + # elements of the tensor that the tiled MMA is applied to. + # Different combinations of `atoms_layout` and `permutation_tiler` + # values can create different MMA thread-value patterns. + # + # Here, the MMA layout is set so that each thread copies four + # consecutive elements from shared memory to registers. + # `permutation_tiler_M/N` maps the elements handled by each thread + # to the permuted element in the tensor. + # For increasing indices in the tensor, the thread ID that reads it is: + # - (without permutation) ==> + # 0 1 2 ... 15 0 1 2 ... 15 0 1 2 ... 15 0 1 2 ... 15 ...... + # - (with permutation) ==> + # 0 0 0 0 1 1 1 1 2 2 2 2 ... 15 15 15 15 0 0 0 0 1 1 1 1 ...... + # /////////////////////////////////////////////////////////////////////////////// + atoms_layout = cute.make_layout( + (self._num_threads // 16, 16, 1), stride=(16, 1, 0) + ) + if cutlass.const_expr(self.c_major_mode == utils.LayoutEnum.COL_MAJOR): + atoms_layout = cute.make_layout( + (16, self._num_threads // 16, 1), stride=(1, 16, 0) + ) + op = cute.nvgpu.MmaUniversalOp(cutlass.Float32) + permutation_tiler_M = cute.make_layout( + (atoms_layout.shape[0], 4), stride=(4, 1) + ) + permutation_tiler_N = cute.make_layout( + (atoms_layout.shape[1], 4), stride=(4, 1) + ) + tiled_mma = cute.make_tiled_mma( + op, + atoms_layout, + permutation_mnk=(permutation_tiler_M, permutation_tiler_N, None), + ) + + # grid_dim: ((m + BLK_M - 1) // BLK_M, (n + BLK_N - 1) // BLK_N, 1) + grid_dim = *cute.ceil_div(mC.shape, (self._bM, self._bN)), 1 + + self.kernel( + mA, + mB, + mC, + sA_layout, + sB_layout, + tiled_copy_A, + tiled_copy_B, + tiled_mma, + epilogue_op, + ).launch( + grid=grid_dim, + block=[cute.size(atoms_layout), 1, 1], + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mA: cute.Tensor, + mB: cute.Tensor, + mC: cute.Tensor, + sA_layout: cute.Layout, + sB_layout: cute.Layout, + tiled_copy_A: cute.TiledCopy, + tiled_copy_B: cute.TiledCopy, + tiled_mma: cute.TiledMma, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + # Thread and block indices + tidx, tidy, tidz = cute.arch.thread_idx() + bidx, bidy, bidz = cute.arch.block_idx() + tiler_coord = (bidx, bidy, None) + thr_mma = tiled_mma.get_slice(tidx) + + # /////////////////////////////////////////////////////////////////////////////// + # Get the appropriate tiles for this thread block. + # gA: (BLK_M, BLK_K, k), gB: (BLK_N, BLK_K, k), gC: (BLK_M, BLK_N) + # /////////////////////////////////////////////////////////////////////////////// + gA = cute.local_tile( + mA, tiler=self._cta_tiler, coord=tiler_coord, proj=(1, None, 1) + ) + gB = cute.local_tile( + mB, tiler=self._cta_tiler, coord=tiler_coord, proj=(None, 1, 1) + ) + gC = cute.local_tile( + mC, tiler=self._cta_tiler, coord=tiler_coord, proj=(1, 1, None) + ) + + # Move the pointer of gA/gB in the `-k`` direction, making the first + # tile (instead of the last one) irregular in shape when k is irregular. + # We first handle the irregular tile to avoid checking for this + # condition within the mainloop. + residue_k = mA.shape[1] - cutlass.Int32(self._bK) * gA.shape[2] + gA = cute.domain_offset((0, residue_k, 0), gA) + gB = cute.domain_offset((0, residue_k, 0), gB) + + # /////////////////////////////////////////////////////////////////////////////// + # Get the appropriate tiles for this thread. + # sA: (BLK_M, BLK_K, PIPE) , sB: (BLK_N, BLK_K, PIPE) + # tAgA: (CPY, CPY_M, CPY_K, k) , tBgB: (CPY, CPY_N, CPY_K, k) + # tAsA: (CPY, CPY_M, CPY_K, PIPE) , tBsB: (CPY, CPY_N, CPY_K, PIPE) + # /////////////////////////////////////////////////////////////////////////////// + # Create shared memory buffer + smem = cutlass.utils.SmemAllocator() + sA = smem.allocate_tensor(mA.element_type, sA_layout, 16) + sB = smem.allocate_tensor(mB.element_type, sB_layout, 16) + thr_copy_A = tiled_copy_A.get_slice(tidx) + thr_copy_B = tiled_copy_B.get_slice(tidx) + tAgA = thr_copy_A.partition_S(gA) + tAsA = thr_copy_A.partition_D(sA) + tBgB = thr_copy_B.partition_S(gB) + tBsB = thr_copy_B.partition_D(sB) + + # /////////////////////////////////////////////////////////////////////////////// + # Predicate: Mark indices that need to copy when the problem shape + # isn't a multiple of the tile shape. If tApA/B[i] is 0, then do not + # do the copy atom associated with index i. + # cA: (BLK_M, BLK_K) => (blk_m, blk_k) + # cB: (BLK_N, BLK_K) => (blk_n, blk_k) + # tAcA: (CPY, CPY_M, CPY_K) => (blk_m, blk_k) + # tBcB: (CPY, CPY_N, CPY_K) => (blk_n, blk_k) + # tApA: (rest_v, CPY_M, CPY_K), stride=(..., ..., 0) + # tBpB: (rest_v, CPY_N, CPY_K), stride=(..., ..., 0) + # CPY = (atom_v, rest_v) + # /////////////////////////////////////////////////////////////////////////////// + # Construct identity layout for sA and sB, used for predication + mcA = cute.make_identity_tensor(mA.shape) + mcB = cute.make_identity_tensor(mB.shape) + cA = cute.local_tile( + mcA, tiler=self._cta_tiler, coord=tiler_coord, proj=(1, None, 1) + ) + cB = cute.local_tile( + mcB, tiler=self._cta_tiler, coord=tiler_coord, proj=(None, 1, 1) + ) + cA = cute.domain_offset((0, residue_k, 0), cA) + cB = cute.domain_offset((0, residue_k, 0), cB) + # Repeat the partitioning with identity layouts + tAcA = thr_copy_A.partition_S(cA) + tBcB = thr_copy_B.partition_S(cB) + # Allocate predicate tensors for m and n + tApA = cute.make_fragment( + cute.make_layout( + ( + tAsA.shape[0][1], + cute.size(tAsA, mode=[1]), + cute.size(tAsA, mode=[2]), + ), + stride=(cute.size(tAsA, mode=[1]), 1, 0), + ), + cutlass.Boolean, + ) + tBpB = cute.make_fragment( + cute.make_layout( + ( + tBsB.shape[0][1], + cute.size(tBsB, mode=[1]), + cute.size(tBsB, mode=[2]), + ), + stride=(cute.size(tBsB, mode=[1]), 1, 0), + ), + cutlass.Boolean, + ) + # Allocate predicate tensors for m, n and k for residue k-tile + tApA_residue_k = cute.make_fragment( + cute.make_layout( + ( + tAsA.shape[0][1], + cute.size(tAsA, mode=[1]), + cute.size(tAsA, mode=[2]), + ), + stride=( + cute.size(tAsA, mode=[1]) * cute.size(tAsA, mode=[2]), + cute.size(tAsA, mode=[2]), + 1, + ), + ), + cutlass.Boolean, + ) + tBpB_residue_k = cute.make_fragment( + cute.make_layout( + ( + tBsB.shape[0][1], + cute.size(tBsB, mode=[1]), + cute.size(tBsB, mode=[2]), + ), + stride=( + cute.size(tBsB, mode=[1]) * cute.size(tBsB, mode=[2]), + cute.size(tBsB, mode=[2]), + 1, + ), + ), + cutlass.Boolean, + ) + # Set predicates for m/n bounds for mainloop + for rest_v in range(tApA.shape[0]): + for m in range(tApA.shape[1]): + tApA[rest_v, m, 0] = cute.elem_less( + tAcA[(0, rest_v), m, 0, 0][0], mA.shape[0] + ) + for rest_v in range(tBpB.shape[0]): + for n in range(tBpB.shape[1]): + tBpB[rest_v, n, 0] = cute.elem_less( + tBcB[(0, rest_v), n, 0, 0][0], mB.shape[0] + ) + + # Set predicates for m/n/k bounds for residue k tile + for rest_v in range(tApA_residue_k.shape[0]): + for m in range(tApA_residue_k.shape[1]): + for k in range(tApA_residue_k.shape[2]): + coord_A = tAcA[(0, rest_v), m, k, 0] + tApA_residue_k[rest_v, m, k] = cute.elem_less( + (coord_A[0], cutlass.Int32(-1)), (mA.shape[0], coord_A[1]) + ) + for rest_v in range(tBpB_residue_k.shape[0]): + for n in range(tBpB_residue_k.shape[1]): + for k in range(tBpB_residue_k.shape[2]): + coord_B = tBcB[(0, rest_v), n, k, 0] + tBpB_residue_k[rest_v, n, k] = cute.elem_less( + (coord_B[0], cutlass.Int32(-1)), (mB.shape[0], coord_B[1]) + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Prefetch Prologue + # /////////////////////////////////////////////////////////////////////////////// + # Start async loads for 0th k-tile, where we take care of the k-residue + k_pipe_max = cute.size(tAsA, mode=[3]) + k_tile_count = cute.size(tAgA, mode=[3]) + gmem_pipe_read = cutlass.Int32(0) + cute.copy( + tiled_copy_A, + tAgA[None, None, None, gmem_pipe_read], + tAsA[None, None, None, 0], + pred=tApA_residue_k, + ) + cute.copy( + tiled_copy_B, + tBgB[None, None, None, gmem_pipe_read], + tBsB[None, None, None, 0], + pred=tBpB_residue_k, + ) + cute.arch.cp_async_commit_group() + gmem_pipe_read = ( + gmem_pipe_read + 1 + if gmem_pipe_read + 1 < k_tile_count + else cutlass.Int32(0) + ) + # Start async loads for 1st k-tile onwards, no k-residue handling needed + for k_tile in range(1, k_pipe_max - 1): + if k_tile < k_tile_count: + cute.copy( + tiled_copy_A, + tAgA[None, None, None, gmem_pipe_read], + tAsA[None, None, None, k_tile], + pred=tApA, + ) + cute.copy( + tiled_copy_B, + tBgB[None, None, None, gmem_pipe_read], + tBsB[None, None, None, k_tile], + pred=tBpB, + ) + + gmem_pipe_read = ( + gmem_pipe_read + 1 + if gmem_pipe_read + 1 < k_tile_count + else cutlass.Int32(0) + ) + cute.arch.cp_async_commit_group() + + # all tiles have been copied from global memory, so clear the + # predicate tensor + if k_tile_count < k_pipe_max: + for rest_v in range(tApA.shape[0]): + for m in range(tApA.shape[1]): + tApA[rest_v, m, 0] = cutlass.Boolean(0) + for rest_v in range(tBpB.shape[0]): + for n in range(tBpB.shape[1]): + tBpB[rest_v, n, 0] = cutlass.Boolean(0) + + # /////////////////////////////////////////////////////////////////////////////// + # Define A/B partitioning and C accumulators. + # /////////////////////////////////////////////////////////////////////////////// + tCsA = thr_mma.partition_A(sA) + tCsB = thr_mma.partition_B(sB) + tCgC = thr_mma.partition_C(gC) + tCrA = tiled_mma.make_fragment_A(tCsA[None, None, None, 0]) + tCrB = tiled_mma.make_fragment_B(tCsB[None, None, None, 0]) + tCrC = tiled_mma.make_fragment_C(tCgC) + # Clear the accumulator + tCrC.fill(0.0) + + # Current pipe index in smem to read from / write to + smem_pipe_read = cutlass.Int32(0) + smem_pipe_write = cutlass.Int32(k_pipe_max - 1) + + tCsA_p = tCsA[None, None, None, smem_pipe_read] + tCsB_p = tCsB[None, None, None, smem_pipe_read] + + # /////////////////////////////////////////////////////////////////////////////// + # PREFETCH register pipeline + # /////////////////////////////////////////////////////////////////////////////// + k_block_max = cute.size(tCrA, mode=[2]) + + if k_block_max > 1: + # Wait until our first prefetched tile is loaded in + cute.arch.cp_async_wait_group(k_pipe_max - 2) + cute.arch.barrier() + # Prefetch the first rmem from the first k-tile + cute.autovec_copy(tCsA_p[None, None, 0], tCrA[None, None, 0]) + cute.autovec_copy(tCsB_p[None, None, 0], tCrB[None, None, 0]) + + # /////////////////////////////////////////////////////////////////////////////// + # Mainloop + # 1. Shared memory pipeline (gmem -> smem): + # The default smem pipeline depth is 3, meaning that for shared + # memory buffers, we allocate three times the size described by the + # CTA tiler. We prefetch 2 of these buffers before entering the main + # loop. Considering only the transfer from global memory to shared + # memory, the general structure of the mainloop is: + # (1) copy k-tile from gmem to smem; + # (2) perform gemm computation on k-tile; + # (3) wait for the next copy to finish. + # The `cute.arch.cp_async_wait_group(num_smem_stages - 2)` command + # waits for the number of unfinished 'copy' to be <= 1. The advantage + # of this approach is that it allows for simultaneous production + # (i.e., step (1)) and consumption (i.e., step (2)) of smem. + # A common misconception is to prefetch N buffers and rewrite + # the pipeline logic to wait on N-1 pending copies. The disadvantage + # of this approach is that it requires fully consuming a buffer in + # order to open an empty buffer for the next copy. + # 2. Register pipeline (smem -> register): + # Similarly, the register pipeline produces i+1, consumes i, and + # produces i+2... Notably, i and i+1 do not use the same register, + # eliminating dependencies on the same register for better parallelism. + # 3. Combining the smem and register pipelines results in the mainloop. + # /////////////////////////////////////////////////////////////////////////////// + + for _ in range(k_tile_count): + for k_block in range(k_block_max, unroll_full=True): + if k_block == k_block_max - 1: + tCsA_p = tCsA[None, None, None, smem_pipe_read] + tCsB_p = tCsB[None, None, None, smem_pipe_read] + cute.arch.cp_async_wait_group(k_pipe_max - 2) + cute.arch.barrier() + + # Load A, B from shared memory to registers for k_block + 1 + k_block_next = (k_block + 1) % k_block_max # static + cute.autovec_copy( + tCsA_p[None, None, k_block_next], + tCrA[None, None, k_block_next], + ) + cute.autovec_copy( + tCsB_p[None, None, k_block_next], + tCrB[None, None, k_block_next], + ) + + # Fetch next A: To better interleave global memory access and + # compute instructions, we intentionally use the sequence: + # copy A, perform GEMM, then copy B. + if k_block == 0: + cute.copy( + tiled_copy_A, + tAgA[None, None, None, gmem_pipe_read], + tAsA[None, None, None, smem_pipe_write], + # Use predicates because the m-mode may be irregular + pred=tApA, + ) + + # Thread-level register gemm for k_block + cute.gemm( + tiled_mma, + tCrC, + tCrA[None, None, k_block], + tCrB[None, None, k_block], + tCrC, + ) + + # Fetch next B and update smem pipeline read/write + if k_block == 0: + cute.copy( + tiled_copy_B, + tBgB[None, None, None, gmem_pipe_read], + tBsB[None, None, None, smem_pipe_write], + # Use predicates because the n-mode may be irregular + pred=tBpB, + ) + cute.arch.cp_async_commit_group() + smem_pipe_write = smem_pipe_read + smem_pipe_read = smem_pipe_read + 1 + if smem_pipe_read == k_pipe_max: + smem_pipe_read = cutlass.Int32(0) + # After copying all tiles, we avoid clearing the predicate + # tensor in the `mainloop` to prevent increasing its + # instruction count. Instead, we continue copying the + # first tile, though it won't be used. The 0-th tile is not + # copied due to its irregular shape, which could lead to + # illegal memory accesses. + gmem_pipe_read = ( + gmem_pipe_read + 1 + if gmem_pipe_read + 1 < k_tile_count + else cutlass.Int32(1) + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Epilogue + # Applies the epilogue operation to the accumulated results and copies + # them without vectorization. + # /////////////////////////////////////////////////////////////////////////////// + cute.arch.cp_async_wait_group(0) + cute.arch.barrier() + tCrC.store(epilogue_op(tCrC.load())) + + # predicate + cC = cute.make_identity_tensor(gC.shape) + tCpC = thr_mma.partition_C(cC) + predC = cute.make_fragment(tCrC.layout, cutlass.Boolean) + residue_m = mC.shape[0] - cutlass.Int32(self._bM) * bidx + residue_n = mC.shape[1] - cutlass.Int32(self._bN) * bidy + for i in range(cute.size(tCrC.shape)): + predC[i] = cute.elem_less(tCpC[i], (residue_m, residue_n)) + numIterM = cute.size(tCrC, mode=[1]) + numIterN = cute.size(tCrC, mode=[2]) + atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), mC.element_type) + cute.copy(atom, tCrC, tCgC, pred=predC) + return + + +def run( + mnk: Tuple[int, int, int], + a_major: str, + b_major: str, + c_major: str, + static_shape: bool = False, + warmup_iterations: int = 2, + iterations: int = 100, + skip_ref_check: bool = False, + use_cold_l2: bool = False, + **kwargs, +): + """Execute SIMT GEMM operation and benchmark performance. + + :param mnk: GEMM problem size (M, N, K, L) + :type mnk: Tuple[int, int, int, int] + :param a_major: Memory layout of tensor A + :type a_major: str + :param b_major: Memory layout of tensor B + :type b_major: str + :param c_major: Memory layout of tensor C + :type c_major: str + :param static_shape: Whether to use static shape optimization, defaults to False + :type static_shape: bool, optional + :param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 2 + :type warmup_iterations: int, optional + :param iterations: Number of benchmark iterations to run, defaults to 100 + :type iterations: int, optional + :param skip_ref_check: Skip validation against reference implementation, defaults to False + :type skip_ref_check: bool, optional + :param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False + :type use_cold_l2: bool, optional + :return: Execution time of the GEMM kernel in microseconds + :rtype: float + """ + print(f"Running Ampere SIMT GEMM example:") + print(f"mnk: {mnk}") + print(f"A major: {a_major}, B major: {b_major}, C major: {c_major}") + print(f"Static shape: {static_shape}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + print(f"Use cold L2: {use_cold_l2}") + M, N, K = mnk + + # Create and permute tensor A/B/C + def create_and_permute_tensor(mode0, mode1, is_mode0_major, dtype): + # is_mode0_major: (mode1, mode0) -> (mode0, mode1) + # else: (mode0, mode1) -> (mode0, mode1) + shape = (mode1, mode0) if is_mode0_major else (mode0, mode1) + permute_order = (1, 0) if is_mode0_major else (0, 1) + + return ( + torch.empty(*shape, dtype=torch.int32) + .random_(-5, 5) + .to(dtype=dtype) + .permute(permute_order) + .cuda() + ) + + a = create_and_permute_tensor(M, K, a_major == "m", torch.float32) + b = create_and_permute_tensor(N, K, b_major == "n", torch.float32) + c = create_and_permute_tensor(M, N, c_major == "m", torch.float32) + + divisibility_a = a.shape[1] if a_major == "k" else a.shape[0] + divisibility_b = b.shape[1] if b_major == "k" else b.shape[0] + divisibility_c = c.shape[1] if c_major == "n" else c.shape[0] + + a_tensor = ( + from_dlpack(a, assumed_align=16) + .mark_layout_dynamic(leading_dim=(1 if a_major == "k" else 0)) + .mark_compact_shape_dynamic( + mode=(1 if a_major == "k" else 0), + divisibility=divisibility_a, + ) + ) + + b_tensor = ( + from_dlpack(b, assumed_align=16) + .mark_layout_dynamic(leading_dim=(1 if b_major == "k" else 0)) + .mark_compact_shape_dynamic( + mode=(1 if b_major == "k" else 0), + divisibility=divisibility_b, + ) + ) + + c_tensor = ( + from_dlpack(c, assumed_align=16) + .mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0)) + .mark_compact_shape_dynamic( + mode=(1 if c_major == "n" else 0), + divisibility=divisibility_c, + ) + ) + + sgemm = SGemm() + + # Get current CUDA stream from PyTorch + torch_stream = torch.cuda.current_stream() + # Get the raw stream pointer as a CUstream + current_stream = cuda.CUstream(torch_stream.cuda_stream) + + print("Compiling kernel with cute.compile ...") + start_time = time.time() + compiled_fn = cute.compile( + sgemm, + a_tensor, + b_tensor, + c_tensor, + stream=current_stream, + ) + compilation_time = time.time() - start_time + print(f"Compilation time: {compilation_time:.4f} seconds") + + print("Executing GEMM kernel...") + + if not skip_ref_check: + compiled_fn(a_tensor, b_tensor, c_tensor) + torch.cuda.synchronize() + print("Verifying results...") + ref = torch.einsum("mk,nk->mn", a, b) + torch.testing.assert_close(c.cpu(), ref.cpu(), atol=1e-03, rtol=1e-05) + print("Results verified successfully!") + + def generate_tensors(): + # Create new tensors for each workspace to ensure cold L2 cache + a_workspace = create_and_permute_tensor(M, K, a_major == "m", torch.float32) + b_workspace = create_and_permute_tensor(N, K, b_major == "n", torch.float32) + c_workspace = create_and_permute_tensor(M, N, c_major == "m", torch.float32) + + if static_shape: + a_tensor_workspace = ( + from_dlpack(a_workspace, assumed_align=16) + .mark_layout_dynamic(leading_dim=(1 if a_major == "k" else 0)) + .mark_compact_shape_dynamic( + mode=(1 if a_major == "k" else 0), + divisibility=divisibility_a, + ) + ) + else: + a_tensor_workspace = from_dlpack(a_workspace, assumed_align=16) + + b_tensor_workspace = ( + from_dlpack(b_workspace, assumed_align=16) + .mark_layout_dynamic(leading_dim=(1 if b_major == "k" else 0)) + .mark_compact_shape_dynamic( + mode=(1 if b_major == "k" else 0), + divisibility=divisibility_b, + ) + ) + + c_tensor_workspace = ( + from_dlpack(c_workspace, assumed_align=16) + .mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0)) + .mark_compact_shape_dynamic( + mode=(1 if c_major == "n" else 0), + divisibility=divisibility_c, + ) + ) + + return testing.JitArguments( + a_tensor_workspace, b_tensor_workspace, c_tensor_workspace, current_stream + ) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + a.numel() * a.element_size() + + b.numel() * b.element_size() + + c.numel() * c.element_size() + ) + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + avg_time_us = testing.benchmark( + compiled_fn, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=current_stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + # Print execution results + print(f"Kernel execution time: {avg_time_us / 1e3:.4f} ms") + + return avg_time_us # Return execution time in microseconds + + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: + try: + return tuple(int(x.strip()) for x in s.split(",")) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + parser = argparse.ArgumentParser() + parser.add_argument( + "--mnk", type=parse_comma_separated_ints, default=(256, 256, 64) + ) + parser.add_argument("--a_major", choices=["k", "m"], default="k") + parser.add_argument("--b_major", choices=["k", "n"], default="k") + parser.add_argument("--c_major", choices=["n", "m"], default="n") + parser.add_argument("--warmup_iterations", default=2, type=int) + parser.add_argument("--iterations", default=100, type=int) + parser.add_argument("--static_shape", action="store_true") + parser.add_argument("--skip_ref_check", action="store_true") + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", + ) + + args = parser.parse_args() + print("Running SIMT GEMM example:") + + torch.manual_seed(1024) + + run( + args.mnk, + args.a_major, + args.b_major, + args.c_major, + args.static_shape, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.use_cold_l2, + ) + print("PASS") diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/smem_allocator.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/smem_allocator.py new file mode 100644 index 0000000000000000000000000000000000000000..f9f5c1e03f78c1f1656e06c34e4d9a5007af6b64 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/smem_allocator.py @@ -0,0 +1,200 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import cutlass.cute as cute +import cutlass +import torch +import numpy as np +from cutlass.cute.runtime import from_dlpack + +""" +A Shared Memory Allocator Example on NVIDIA Ampere architecture using CuTe DSL. + +This example demonstrates how to allocate and manage shared memory in JIT kernels by using the SmemAllocator in CuTe DSL. +It shows various ways to allocate different data structures in shared memory: + +1. Struct allocation with natural and strict alignment +2. Raw memory block allocation with custom alignment +3. Array allocation with automatic alignment +4. Tensor allocation with layout specification + +The example includes: +- Shared storage struct with mixed alignment requirements +- Memory allocation patterns for different data types +- Tensor operations on allocated memory + +To run this example: + +.. code-block:: bash + + python examples/ampere/smem_allocator.py + +The example will allocate shared memory, perform tensor operations, and verify the results. +""" + + +@cute.struct +class complex: + real: cutlass.Float32 + imag: cutlass.Float32 + + +# SharedStorage size is 512, alignment is 128 +@cute.struct +class SharedStorage: + # struct elements with natural alignment + a: cute.struct.MemRange[cutlass.Float32, 32] # array + b: cutlass.Int64 # saclar + c: complex # nested struct + # struct elements with strict alignment + x: cute.struct.Align[ + cute.struct.MemRange[cutlass.Float32, 32], + 128, + ] + y: cute.struct.Align[cutlass.Int32, 8] + z: cute.struct.Align[complex, 16] + + +@cute.kernel +def kernel( + const_a: cutlass.Constexpr, + dst_a: cute.Tensor, + const_b: cutlass.Constexpr, + dst_b: cute.Tensor, + const_c: cutlass.Constexpr, + dst_c: cute.Tensor, +): + # Note: SMEM_SIZE bytes (specified in kernel().launch(smem=...)) can be reserved for developer to utilize + # Note: alignment of inital allocator base ptr is 1024 + allocator = cutlass.utils.SmemAllocator() + # base ptr of allocator points at: SMEM_ADDR_START (the starting address of available shared memory) + + # -- Allocate a struct -- + # Note: when specified alignment, max(alignment, alignof(struct)) will be applied + # reserves the section of struct in smem, elements in the struct can be accessed by ptr + struct_in_smem = allocator.allocate(SharedStorage) + # base ptr of allocator now points at: SMEM_ADDR_AFTER_STRUCT = SMEM_ADDR_START + aligned_size(struct) + + # -- Allocate a block of memory -- + # reserves a section of 64 bytes in smem, align to 128 bytes, returns the section base ptr + section_in_smem = allocator.allocate(64, byte_alignment=128) + # base ptr of allocator now points at: SMEM_ADDR_AFTER_SECTION = SMEM_ADDR_AFTER_STRUCT + aligned_size(section) + + # -- Allocate an array -- + # reserves an int64 array of size 14 in smem, returns the array base ptr + array_in_smem = allocator.allocate_array(element_type=cutlass.Int64, num_elems=14) + # base ptr of allocator now points at: SMEM_ADDR_AFTER_ARRAY = SMEM_ADDR_AFTER_SECTION + aligned_size(array) + + # -- Allocate a tensor -- + # Note: use cute.ComposedLayout or cute.Layout to specify layout of tensor + # Note: iterator swizzle with swizzle layout is currently not supported + layout = cute.make_layout((16, 2)) + tensor_in_smem = allocator.allocate_tensor( + element_type=cutlass.Float32, layout=layout, byte_alignment=32, swizzle=None + ) + # base ptr of allocator now points at: SMEM_ADDR_AFTER_TENSOR = SMEM_ADDR_AFTER_ARRAY + aligned_size(tensor) + + # ptr> + # ptr> + # ptr> + print(struct_in_smem.a.data_ptr()) + print(struct_in_smem.b) + print(struct_in_smem.c.real) + # ptr> + print(section_in_smem) + # ptr> + print(array_in_smem) + # tensor> o (16,4):(1,16)> + print(tensor_in_smem) + + # fill MemRange tensor in struct and copy to dst + a_tensor = struct_in_smem.a.get_tensor(cute.make_layout((8, 4))) + a_tensor.fill(const_a) + cute.printf("cute.struct.MemRange: {}", a_tensor) + dst_a.store(a_tensor.load()) + + # convert block of smem to fill tensor and copy to dst + layout = cute.make_layout((8, 2)) + sec_ptr = cute.recast_ptr(section_in_smem, dtype=cutlass.Float32) + sec_tensor = cute.make_tensor(sec_ptr, layout) + sec_tensor.fill(const_b) + cute.printf("block of memory: {}", sec_tensor) + dst_b.store(sec_tensor.load()) + + # fill allocated tensor in smem and copy to dst + tensor_in_smem.fill(const_c) + cute.printf("tensor in smem: {}", tensor_in_smem) + dst_c.store(tensor_in_smem.load()) + + +@cute.jit +def run_allocation_kernel( + const_a: cutlass.Constexpr, + dst_a: cute.Tensor, + const_b: cutlass.Constexpr, + dst_b: cute.Tensor, + const_c: cutlass.Constexpr, + dst_c: cute.Tensor, +): + # additional size for the example, 64(section) + 112(array) + 128(tensor) < 384 + addtional_bytes = 384 + # Note: launch shared memory size is: SMEM_SIZE = 512 + 384 = 896 bytes + kernel(const_a, dst_a, const_b, dst_b, const_c, dst_c).launch( + grid=(1, 1, 1), + block=(1, 1, 1), + smem=SharedStorage.size_in_bytes() + addtional_bytes, + ) + + +def veify_allocation_kernel(const_a, const_b, const_c): + dst_a = torch.zeros((8, 4), dtype=torch.float32, device="cuda") + dst_b = torch.zeros((8, 2), dtype=torch.float32, device="cuda") + dst_c = torch.zeros((16, 2), dtype=torch.float32, device="cuda") + + run_allocation_kernel( + const_a, + from_dlpack(dst_a), + const_b, + from_dlpack(dst_b), + const_c, + from_dlpack(dst_c), + ) + + np.testing.assert_equal(const_a, dst_a.detach().cpu().numpy()[0]) + np.testing.assert_equal(const_b, dst_b.detach().cpu().numpy()[0]) + np.testing.assert_equal(const_c, dst_c.detach().cpu().numpy()[0]) + + +if __name__ == "__main__": + # prepare cuda context + cutlass.cuda.initialize_cuda_context() + # An example for shared memory allocation + const_a = 0.5 + const_b = 1.0 + const_c = 2.0 + veify_allocation_kernel(const_a, const_b, const_c) diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/tensorop_gemm.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/tensorop_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..413d1fbf0a239f8e2655ef5a530d7aa9bd6b2988 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/tensorop_gemm.py @@ -0,0 +1,1012 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +import math +import time +from typing import Tuple, Type + +import cuda.bindings.driver as cuda +import torch + +import cutlass +import cutlass.cute as cute +import cutlass.cute.testing as testing +import cutlass.torch as cutlass_torch +import cutlass.utils as utils +from cutlass.cute.runtime import from_dlpack + +""" +A dense GEMM (C = A * B) example for the NVIDIA Ampere architecture using CUTE DSL. +- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M") +- Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K") +- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M") + +This GEMM kernel supports the following features: + - Utilizes Ampere's tensor cores for matrix multiply-accumulate (MMA) operations + - Threadblock rasterization to improve data re-use + - Supports multi-stage pipeline to overlap computation and memory access + - Implements shared memory buffering for epilogue to increase coalesced global memory access + +This GEMM works as follows: +1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using asynchronous copies. +2. Perform matrix multiply-accumulate (MMA) operations. +3. Store results from registers (RMEM) to shared memory (SMEM), then to global memory (GMEM). + +The Ampere tensor core instruction used operates as follows: +- Read matrix A from SMEM +- Read matrix B from SMEM +- Perform MMA operation and store the result in Accumulator(register) + +To run this example: + +.. code-block:: bash + + python examples/ampere/tensorop_gemm.py \ + --mnkl 8192,8192,8192,1 --atom_layout_mnk 2,2,1 \ + --ab_dtype Float16 \ + --c_dtype Float16 --acc_dtype Float32 \ + --a_major m --b_major n --c_major n + +The above example command computes with M=8192, N=8192, K=8192, +batch_count=1. The atom layout's shape is 2x2x1 and the input, mma +accumulator, and output data type are set as fp16, fp32 and fp16, +respectively. + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/ampere/tensorop_gemm.py \ + --mnkl 8192,8192,8192,1 --atom_layout_mnk 2,2,1 \ + --ab_dtype Float16 \ + --c_dtype Float16 --acc_dtype Float32 \ + --a_major m --b_major n --c_major n \ + --skip_ref_check --iterations 2 + +Constraints: +* Supported input and output data types: fp16 +* Support accumulator data types: f32 +* Default tile shape is set to be 128x128x32 +* Atom layout's MNK shape is set so that tile shape can be divided by MMA + instruction shape +* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned, + i.e, number of elements is a multiple of 8 +""" + + +class TensorOpGemm: + def __init__( + self, + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + atom_layout_mnk: Tuple[int, int, int], + ): + self.ab_dtype = ab_dtype + self.c_dtype = c_dtype + self.acc_dtype = acc_dtype + self.cta_tiler = (128, 128, 32) + self.num_stages = 3 + self.atom_layout_mnk = atom_layout_mnk + atom_lay_M, atom_lay_N, atom_lay_K = self.atom_layout_mnk + self.num_threads = atom_lay_M * atom_lay_N * atom_lay_K * 32 + + self.bM, self.bN, self.bK = self.cta_tiler + self.mma_inst_shape = (16, 8, 16) + mmaM, mmaN, mmaK = self.mma_inst_shape + + assert ( + self.bM % (atom_lay_M * mmaM) == 0 + ), "bM must be divisible by MMA instruction" + assert ( + self.bN % (atom_lay_N * mmaN) == 0 + ), "bN must be divisible by MMA instruction" + assert atom_lay_K == 1, "this example does not support atom layout K > 1" + assert self.bK % mmaK == 0, "bK must be divisible by MMA instruction" + assert self.num_stages >= 3, "num_stages must be greater than or equal to 3" + + @cute.jit + def __call__( + self, + mA: cute.Tensor, + mB: cute.Tensor, + mC: cute.Tensor, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + # The grid divides the problems's M, N, and L dimensions by the + # respective modes of the tile shape (bM, bN, 1). The K dimension is + # handled within a block via a multistage process. + + self.a_major_mode = utils.LayoutEnum.from_tensor(mA) + self.b_major_mode = utils.LayoutEnum.from_tensor(mB) + self.c_major_mode = utils.LayoutEnum.from_tensor(mC) + + # /////////////////////////////////////////////////////////////////////////////// + # Shared memory layout: + # /////////////////////////////////////////////////////////////////////////////// + + # Creates a layout with the size required for the provided tile + # size and num stages (stages are used for K dimension) that is also + # sectioned into 64x8 or 8x32 layout atoms. The swizzle is set so that + # the atom for the shared memory -> register copy does not encounter + # bank conflicts + + # assume the input is 16B align + ab_copy_bits = 128 + sA_layout = self._make_smem_layout_AB( + mA.element_type, + self.a_major_mode, + ab_copy_bits, + (self.cta_tiler[0], self.cta_tiler[2], self.num_stages), + ) + sB_layout = self._make_smem_layout_AB( + mB.element_type, + self.b_major_mode, + ab_copy_bits, + (self.cta_tiler[1], self.cta_tiler[2], self.num_stages), + ) + + # Creates a similar layout but without num_stages or layout atoms + sC_layout = self._make_smem_layout_C( + mC.element_type, + self.c_major_mode, + ab_copy_bits, + (self.cta_tiler[0], self.cta_tiler[1]), + ) + + # Shared memory allocated for operations with A, B will be + # overwritten for operations on C. This is to improve performance + # by reducing the size of shared memory requested by each block + smem_size = max( + cute.size_in_bytes(mC.element_type, sC_layout), + cute.size_in_bytes(mA.element_type, sA_layout) + + cute.size_in_bytes(mB.element_type, sB_layout), + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Tiled copy: + # The majorness of tA/tB/tC follows the majorness of gA/gB/gC, + # enabling merged accesses to global memory for faster data + # transfer between global and shared memory. + # /////////////////////////////////////////////////////////////////////////////// + + # Create a copy atom for a global to shared memory asynchronous copy + atom_async_copy = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp( + cache_mode=cute.nvgpu.cpasync.LoadCacheMode.GLOBAL + ), + mA.element_type, + num_bits_per_copy=ab_copy_bits, + ) + + # Create thread layouts for tiled copy from the copy atom where the + # thread layout simply follows the leading dimension of the tensor + tiled_copy_A = self._make_gmem_tiled_copy_AB( + atom_async_copy, mA.element_type, self.a_major_mode, ab_copy_bits + ) + tiled_copy_B = self._make_gmem_tiled_copy_AB( + atom_async_copy, mB.element_type, self.b_major_mode, ab_copy_bits + ) + + # Creates a synchronous copy atom and thread layouts for the epilogue + c_copy_bits = 128 + atom_sync_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + mC.element_type, + num_bits_per_copy=c_copy_bits, + ) + tiled_copy_C = self._make_gmem_tiled_copy_C( + atom_sync_copy, mC.element_type, self.c_major_mode, c_copy_bits + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Tiled MMA + # /////////////////////////////////////////////////////////////////////////////// + + # Creates a mma atom with 16x8x16 shape for MNK + op = cute.nvgpu.warp.MmaF16BF16Op( + self.ab_dtype, self.acc_dtype, self.mma_inst_shape + ) + + permutation_mnk = ( + self.atom_layout_mnk[0] * self.mma_inst_shape[0], + # if atom layout's N-mode is 1, to leverage the largest coalesced + # shared memory -> register copy, set the tiled mma's N mode to 16 + self.atom_layout_mnk[1] * self.mma_inst_shape[1] * 2, + self.atom_layout_mnk[2] * self.mma_inst_shape[2], + ) + + # Created a tiled mma that tiles the atom according to specified layout. + # For a 2x2x1 atom layout, the mma atom is duplicated 4 times, twice + # across M and twice across N + tC = cute.make_layout(self.atom_layout_mnk) + tiled_mma = cute.make_tiled_mma( + op, + tC, + permutation_mnk=permutation_mnk, + ) + + # grid_dim: ((m + BLK_M - 1) // BLK_M, (n + BLK_N - 1) // BLK_N, l) + grid_dim = cute.ceil_div(mC.shape, (self.bM, self.bN, 1)) + + # Add threadblock rasterization to improve re-use of data + raster_factor = 1 + grid_dim_n = cute.size(grid_dim[1]) + # Thresholds picked so that it doesn't cause too many no-op CTAs + if grid_dim_n > 5: + raster_factor = 8 + elif grid_dim_n > 2: + raster_factor = 4 + elif grid_dim_n > 1: + raster_factor = 2 + rasterization_remap_grid_dim = ( + cute.size(grid_dim[0]) * raster_factor, + (cute.size(grid_dim[1]) + raster_factor - 1) // raster_factor, + cute.size(grid_dim[2]), + ) + + self.kernel( + mA, + mB, + mC, + sA_layout, + sB_layout, + sC_layout, + tiled_copy_A, + tiled_copy_B, + tiled_copy_C, + tiled_mma, + raster_factor, + epilogue_op, + ).launch( + grid=rasterization_remap_grid_dim, + block=[self.num_threads, 1, 1], + smem=smem_size, + ) + + @cute.kernel + def kernel( + self, + mA: cute.Tensor, + mB: cute.Tensor, + mC: cute.Tensor, + sA_layout: cute.ComposedLayout, + sB_layout: cute.ComposedLayout, + sC_layout: cute.ComposedLayout, + tiled_copy_A: cute.TiledCopy, + tiled_copy_B: cute.TiledCopy, + tiled_copy_C: cute.TiledCopy, + tiled_mma: cute.TiledMma, + rasterization_factor: cutlass.Int32, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + # Thread index, block index + tidx, _, _ = cute.arch.thread_idx() + bidx, bidy, bidz = cute.arch.block_idx() + grid_dim = cute.ceil_div(mC.shape, (self.bM, self.bN, 1)) + offset_tile_x, offset_tile_y = self.raster_tile( + bidx, bidy, rasterization_factor + ) + # Early exit if CTA is out of range + if grid_dim[0] <= offset_tile_x or grid_dim[1] <= offset_tile_y: + pass + else: + tiler_coord = (offset_tile_x, offset_tile_y, None) + + # /////////////////////////////////////////////////////////////////////////////// + # Get the appropriate tiles for this thread block. + # gA: (BLK_M, BLK_N, k), gB: (BLK_N, BLK_K, k), gC: (BLK_M, BLK_N) + # /////////////////////////////////////////////////////////////////////////////// + gA = cute.local_tile( + mA[None, None, bidz], + tiler=self.cta_tiler, + coord=tiler_coord, + proj=(1, None, 1), + ) + gB = cute.local_tile( + mB[None, None, bidz], + tiler=self.cta_tiler, + coord=tiler_coord, + proj=(None, 1, 1), + ) + gC = cute.local_tile( + mC[None, None, bidz], + tiler=self.cta_tiler, + coord=tiler_coord, + proj=(1, 1, None), + ) + + # By default, if the tensor k mode does not divide into the tile k + # size, then last tiles in the k dimension are irregular. + # Instead, make the first tiles irregular when k is irregular. + # This allows us to handle the irregular tile first to avoid + # checking for this condition within the mainloop. + + # residual_k is a negative number indicating the amount needed to + # shift the pointer by in dimension k + residual_k = cute.size(mA, mode=[1]) - cutlass.Int32(self.bK) * cute.size( + gA, mode=[2] + ) + + # move the pointer of gA/gB in the `-k` direction + gA = cute.domain_offset((0, residual_k, 0), gA) + gB = cute.domain_offset((0, residual_k, 0), gB) + # input is 16B aligned + gA = cute.make_tensor(gA.iterator.align(16), gA.layout) + gB = cute.make_tensor(gB.iterator.align(16), gB.layout) + + # Construct identity layout for sA and sB (mirrors global tensors, + # used for predication only) + mcA = cute.make_identity_tensor(mA.layout.shape) + mcB = cute.make_identity_tensor(mB.layout.shape) + cA = cute.local_tile( + mcA[None, None, bidz], + tiler=self.cta_tiler, + coord=tiler_coord, + proj=(1, None, 1), + ) + cB = cute.local_tile( + mcB[None, None, bidz], + tiler=self.cta_tiler, + coord=tiler_coord, + proj=(None, 1, 1), + ) + + cA = cute.domain_offset((0, residual_k, 0), cA) + cB = cute.domain_offset((0, residual_k, 0), cB) + + # /////////////////////////////////////////////////////////////////////////////// + # Create shared memory buffers and get the appropriate fragments for this thread. + # sA: (BLK_M, BLK_K, PIPE) , sB: (BLK_N, BLK_K, PIPE) + # tAgA: (CPY, CPY_M, CPY_K, k) , tBgB: (CPY, CPY_N, CPY_K, k) + # tAsA: (CPY, CPY_M, CPY_K, PIPE) , tBsB: (CPY, CPY_N, CPY_K, PIPE) + # /////////////////////////////////////////////////////////////////////////////// + # Shared memory buffer + smem = cutlass.utils.SmemAllocator() + + sA = smem.allocate_tensor(mA.element_type, sA_layout, 16) + sB = smem.allocate_tensor(mB.element_type, sB_layout, 16) + sC = cute.make_tensor( + cute.recast_ptr(sA.iterator, dtype=self.c_dtype), sC_layout + ) + + thr_copy_A = tiled_copy_A.get_slice(tidx) + thr_copy_B = tiled_copy_B.get_slice(tidx) + thr_copy_C = tiled_copy_C.get_slice(tidx) + tAgA = thr_copy_A.partition_S(gA) + tAsA = thr_copy_A.partition_D(sA) + tBgB = thr_copy_B.partition_S(gB) + tBsB = thr_copy_B.partition_D(sB) + tCsC_epilogue = thr_copy_C.partition_S(sC) + tCgC_epilogue = thr_copy_C.partition_D(gC) + + # Repeat the partitioning with identity layouts + tAcA = thr_copy_A.partition_S(cA) + tBcB = thr_copy_B.partition_S(cB) + + # /////////////////////////////////////////////////////////////////////////////// + # Predicate: Mark indices that need to copy when problem_shape isn't a multiple + # of tile_shape + # /////////////////////////////////////////////////////////////////////////////// + + # For predication over the tensors A (M/K), B (N/K), and (in the + # epilogue) C (M/N), we will compute it in a fashion similar to an + # outer product. The predication along one of the dimensions is + # evaluated and stored in a predication tensor. Then, the + # predication for the remaining dimension is handled later via an + # if/else branch at the copy. + # For A and B, predication booleans along M/N are stored in a + # predication tensor and along K is handled via a if/else branch. + + # Allocate predicate tensors for M and N. Predication is checked + # at the granularity of a copy atom, so the predicate tensor does not + # need separate booleans for individual elements within a copy + # atom (for example, the elements of tAgA.shape[0][0].) + tApA = cute.make_fragment( + cute.make_layout( + ( + tAgA.shape[0][1], + cute.size(tAgA, mode=[1]), + cute.size(tAgA, mode=[2]), + ), + stride=(cute.size(tAgA, mode=[1]), 1, 0), + ), + cutlass.Boolean, + ) + tBpB = cute.make_fragment( + cute.make_layout( + ( + tBsB.shape[0][1], + cute.size(tBsB, mode=[1]), + cute.size(tBsB, mode=[2]), + ), + stride=(cute.size(tBsB, mode=[1]), 1, 0), + ), + cutlass.Boolean, + ) + # Set predicates for M/N bounds + for rest_v in range(tApA.shape[0]): + for m in range(tApA.shape[1]): + tApA[rest_v, m, 0] = cute.elem_less( + tAcA[(0, rest_v), m, 0, 0][0], mA.shape[0] + ) + for rest_v in range(tBpB.shape[0]): + for n in range(tBpB.shape[1]): + tBpB[rest_v, n, 0] = cute.elem_less( + tBcB[(0, rest_v), n, 0, 0][0], mB.shape[0] + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Prefetch Prologue + # /////////////////////////////////////////////////////////////////////////////// + # Clear the smem tiles to account for predicated off loads + tAsA.fill(0) + tBsB.fill(0) + cute.arch.sync_threads() + # Start async loads for the first k-tile. Here we take care of the k residue + # via if/else check along the k dimension. Because we shifted the identity tensor + # by the residue_k and because the identity tensor is a coord tensor, the + # values of any identity tensor element that is poison is less than -1 + num_smem_stages = cute.size(tAsA, mode=[3]) + k_tile_count = cute.size(tAgA, mode=[3]) + k_tile_index = cutlass.Int32(0) + + for k in range(tApA.shape[2]): + if cute.elem_less(cutlass.Int32(-1), tAcA[0, 0, k, 0][1]): + cute.copy( + tiled_copy_A, + tAgA[None, None, k, k_tile_index], + tAsA[None, None, k, 0], + pred=tApA[None, None, k], + ) + for k in range(tBpB.shape[2]): + if cute.elem_less(cutlass.Int32(-1), tBcB[0, 0, k, 0][1]): + cute.copy( + tiled_copy_B, + tBgB[None, None, k, k_tile_index], + tBsB[None, None, k, 0], + pred=tBpB[None, None, k], + ) + k_tile_index = k_tile_index + 1 + cute.arch.cp_async_commit_group() + + # Start async loads for rest of the k-tiles + for k_tile in range(1, num_smem_stages - 1): + if k_tile == k_tile_count: + tApA.fill(0) + tBpB.fill(0) + cute.copy( + tiled_copy_A, + tAgA[None, None, None, k_tile_index], + tAsA[None, None, None, k_tile], + pred=tApA, + ) + cute.copy( + tiled_copy_B, + tBgB[None, None, None, k_tile_index], + tBsB[None, None, None, k_tile], + pred=tBpB, + ) + k_tile_index = k_tile_index + 1 + cute.arch.cp_async_commit_group() + + # /////////////////////////////////////////////////////////////////////////////// + # Tile MMA compute thread partitions and allocate accumulators + # /////////////////////////////////////////////////////////////////////////////// + thr_mma = tiled_mma.get_slice(tidx) + tCsA = thr_mma.partition_A(sA) + tCsB = thr_mma.partition_B(sB) + tCsC = thr_mma.partition_C(sC) + tCgC = thr_mma.partition_C(gC) + tCrA = tiled_mma.make_fragment_A(tCsA[None, None, None, 0]) + tCrB = tiled_mma.make_fragment_B(tCsB[None, None, None, 0]) + tCrC = tiled_mma.make_fragment_C(tCgC) + # Clear the accumulator + tCrC.fill(0.0) + + # /////////////////////////////////////////////////////////////////////////////// + # Copy Atom A/B retiling + # /////////////////////////////////////////////////////////////////////////////// + + # Create the copy atoms for the copy from shared memory to register + atom_copy_s2r_A = cute.make_copy_atom( + cute.nvgpu.warp.LdMatrix8x8x16bOp( + self.a_major_mode != utils.LayoutEnum.ROW_MAJOR, 4 + ), + mA.element_type, + ) + atom_copy_s2r_B = cute.make_copy_atom( + cute.nvgpu.warp.LdMatrix8x8x16bOp( + self.b_major_mode != utils.LayoutEnum.ROW_MAJOR, 4 + ), + mB.element_type, + ) + + # Creates the tiled copy so that it matches the thread-value layout + # expected by the tiled mma + tiled_copy_s2r_A = cute.make_tiled_copy_A(atom_copy_s2r_A, tiled_mma) + tiled_copy_s2r_B = cute.make_tiled_copy_B(atom_copy_s2r_B, tiled_mma) + + thr_copy_ldmatrix_A = tiled_copy_s2r_A.get_slice(tidx) + thr_copy_ldmatrix_B = tiled_copy_s2r_B.get_slice(tidx) + tCsA_copy_view = thr_copy_ldmatrix_A.partition_S(sA) + tCrA_copy_view = thr_copy_ldmatrix_A.retile(tCrA) + tCsB_copy_view = thr_copy_ldmatrix_B.partition_S(sB) + tCrB_copy_view = thr_copy_ldmatrix_B.retile(tCrB) + + # Current pipe index in smem to read from / write to + smem_pipe_read = 0 + smem_pipe_write = num_smem_stages - 1 + + tCsA_p = tCsA_copy_view[None, None, None, smem_pipe_read] + tCsB_p = tCsB_copy_view[None, None, None, smem_pipe_read] + + # /////////////////////////////////////////////////////////////////////////////// + # PREFETCH register pipeline + # /////////////////////////////////////////////////////////////////////////////// + num_k_block = cute.size(tCrA, mode=[2]) + if num_k_block > 1: + # Wait until our first prefetched tile is loaded in + cute.arch.cp_async_wait_group(num_smem_stages - 2) + cute.arch.sync_threads() + # Prefetch the first k-block rmem from the first k-tile + cute.copy( + tiled_copy_s2r_A, + tCsA_p[None, None, 0], + tCrA_copy_view[None, None, 0], + ) + cute.copy( + tiled_copy_s2r_B, + tCsB_p[None, None, 0], + tCrB_copy_view[None, None, 0], + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Mainloop + # 1. Shared memory pipeline (gmem -> smem): + # The default smem pipeline depth is 3, meaning that for shared + # memory buffers, we allocate three times the size described by the + # CTA tiler. We prefetch 2 of these buffers before entering the main + # loop. Considering only the transfer from global memory to shared + # memory, the general structure of the mainloop is: + # (1) copy k-tile from gmem to smem; + # (2) perform gemm computation on k-tile; + # (3) wait for the next copy to finish. + # The `cute.arch.cp_async_wait_group(num_smem_stages - 2)` command + # waits for the number of unfinished 'copy' to be <= 1. The advantage + # of this approach is that it allows for simultaneous production + # (i.e., step (1)) and consumption (i.e., step (2)) of smem. + # A common misconception is to prefetch N buffers and rewrite + # the pipeline logic to wait on N-1 pending copies. The disadvantage + # of this approach is that it requires fully consuming a buffer in + # order to open an empty buffer for the next copy. + # 2. Register pipeline (smem -> register): + # Similarly, the register pipeline produces i+1, consumes i, and + # produces i+2... Notably, i and i+1 do not use the same register, + # eliminating dependencies on the same register for better parallelism. + # 3. Combining the smem and register pipelines results in the mainloop. + # /////////////////////////////////////////////////////////////////////////////// + for k_tile in range(k_tile_count): + for k_block in cutlass.range(num_k_block, unroll_full=True): + if k_block == num_k_block - 1: + tCsA_p = tCsA_copy_view[None, None, None, smem_pipe_read] + tCsB_p = tCsB_copy_view[None, None, None, smem_pipe_read] + cute.arch.cp_async_wait_group(num_smem_stages - 2) + cute.arch.sync_threads() + + # Load A, B from shared memory to registers for k_block + 1 + k_block_next = (k_block + 1) % num_k_block # static + cute.copy( + tiled_copy_s2r_A, + tCsA_p[None, None, k_block_next], + tCrA_copy_view[None, None, k_block_next], + ) + cute.copy( + tiled_copy_s2r_B, + tCsB_p[None, None, k_block_next], + tCrB_copy_view[None, None, k_block_next], + ) + + # Fetch next A: To better interleave global memory access and compute + # instructions, we intentionally use the sequence: copy A, perform GEMM, + # then copy B. + if k_block == 0: + if k_tile + num_smem_stages - 1 < k_tile_count: + cute.copy( + tiled_copy_A, + tAgA[None, None, None, k_tile_index], + tAsA[None, None, None, smem_pipe_write], + pred=tApA, + ) + + # Thread-level register gemm for k_block + cute.gemm( + tiled_mma, + tCrC, + tCrA[None, None, k_block], + tCrB[None, None, k_block], + tCrC, + ) + + # Fetch next B and update smem pipeline read/write + if k_block == 0: + if k_tile + num_smem_stages - 1 < k_tile_count: + cute.copy( + tiled_copy_B, + tBgB[None, None, None, k_tile_index], + tBsB[None, None, None, smem_pipe_write], + pred=tBpB, + ) + k_tile_index = k_tile_index + 1 + cute.arch.cp_async_commit_group() + smem_pipe_write = smem_pipe_read + smem_pipe_read = smem_pipe_read + 1 + if smem_pipe_read == num_smem_stages: + smem_pipe_read = 0 + + # Sync before epilogue + cute.arch.cp_async_wait_group(0) + cute.arch.sync_threads() + + # /////////////////////////////////////////////////////////////////////////////// + # Epilogue with fusion + # /////////////////////////////////////////////////////////////////////////////// + tCrD = cute.make_fragment_like(tCrC, self.c_dtype) + tCrD[None] = epilogue_op(tCrC.load()).to(self.c_dtype) + + # Copy results of D back to shared memory + cute.autovec_copy(tCrD, tCsC) + + # Create coord tensor for C + ceilM, ceilN, _ = cute.ceil_div(mC.shape, (self.bM, self.bN, 1)) + mcC = cute.make_identity_tensor( + ( + cute.size(ceilM) * self.cta_tiler[0], + cute.size(ceilN) * self.cta_tiler[1], + 1, + ) + ) + cC = cute.local_tile( + mcC[None, None, bidz], + tiler=self.cta_tiler, + coord=tiler_coord, + proj=(1, 1, None), + ) + tCcC = thr_copy_C.partition_S(cC) + + tCrC_epilogue = cute.make_fragment_like(tCsC_epilogue) + # Wait for all writes to shared memory to finish before starting copies + # using the new layouts + cute.arch.sync_threads() + cute.autovec_copy(tCsC_epilogue, tCrC_epilogue) + + # Create predication tensor for m + tCpC = cute.make_fragment( + cute.make_layout( + ( + tCgC_epilogue.shape[0][1], + cute.size(tCgC_epilogue, mode=[1]), + cute.size(tCgC_epilogue, mode=[2]), + ), + stride=(cute.size(tCgC_epilogue, mode=[1]), 1, 0), + ), + cutlass.Boolean, + ) + for rest_v in range(tCpC.shape[0]): + for m in range(tCpC.shape[1]): + tCpC[rest_v, m, 0] = cute.elem_less( + tCcC[(0, rest_v), m, 0][0], mC.shape[0] + ) + + # Copy to global memory using better vectorization + for rest_v in range(tCpC.shape[0]): + for n in range(tCpC.shape[2]): + if cute.elem_less(tCcC[(0, rest_v), 0, n][1], mC.shape[1]): + cute.copy( + tiled_copy_C, + tCrC_epilogue[None, None, n], + tCgC_epilogue[None, None, n], + pred=tCpC[None, None, n], + ) + return + + def _make_smem_layout_AB(self, dtype, major_mode, copy_bits, smem_tiler): + major_mode_size = ( + smem_tiler[1] if major_mode == utils.LayoutEnum.ROW_MAJOR else smem_tiler[0] + ) + major_mode_size = 64 if major_mode_size >= 64 else major_mode_size + + swizzle_bits = int(math.log2(major_mode_size * dtype.width // copy_bits)) + swizzle_bits = min(swizzle_bits, 3) + + layout_atom_outer = ( + cute.make_layout((8, major_mode_size), stride=(major_mode_size, 1)) + if major_mode == utils.LayoutEnum.ROW_MAJOR + else cute.make_layout((major_mode_size, 8), stride=(1, major_mode_size)) + ) + layout_atom = cute.make_composed_layout( + cute.make_swizzle(swizzle_bits, 3, 3), + 0, + layout_atom_outer, + ) + layout = cute.tile_to_shape(layout_atom, smem_tiler, (0, 1, 2)) + return layout + + def _make_smem_layout_C(self, dtype, major_mode, copy_bits, smem_tiler): + major_mode_size = ( + smem_tiler[1] if major_mode == utils.LayoutEnum.ROW_MAJOR else smem_tiler[0] + ) + + swizzle_bits = int(math.log2(major_mode_size * dtype.width // copy_bits)) + swizzle_bits = min(swizzle_bits, 3) + + layout_atom_outer = ( + cute.make_layout((8, major_mode_size), stride=(major_mode_size, 1)) + if major_mode == utils.LayoutEnum.ROW_MAJOR + else cute.make_layout((major_mode_size, 8), stride=(1, major_mode_size)) + ) + layout_atom = cute.make_composed_layout( + cute.make_swizzle(swizzle_bits, 3, 4), + 0, + layout_atom_outer, + ) + + # Due to the thread layout of the mma, remove swizzle in C to + # prevent shared memory fragments owned by an single thread from + # holding swizzles + if major_mode == utils.LayoutEnum.COL_MAJOR: + layout_atom = cute.make_composed_layout( + cute.make_swizzle(0, 3, 4), 0, layout_atom_outer + ) + layout = cute.tile_to_shape( + layout_atom, + smem_tiler, + (0, 1), + ) + return layout + + def _make_gmem_tiled_copy_AB(self, atom_copy, dtype, major_mode, copy_bits): + copy_elems = copy_bits // dtype.width + shape_dim_1 = cute.size(self.bK) // copy_elems + # thread layout for copy + thread_layout = cute.make_layout( + (self.num_threads // shape_dim_1, shape_dim_1), stride=(shape_dim_1, 1) + ) + if major_mode != utils.LayoutEnum.ROW_MAJOR: + shape_dim_0 = cute.size(self.bM) // copy_elems + thread_layout = cute.make_layout( + (shape_dim_0, self.num_threads // shape_dim_0), stride=(1, shape_dim_0) + ) + # Value layout for copy + value_layout = ( + cute.make_layout((1, copy_elems)) + if major_mode == utils.LayoutEnum.ROW_MAJOR + else cute.make_layout((copy_elems, 1)) + ) + return cute.make_tiled_copy_tv(atom_copy, thread_layout, value_layout) + + def _make_gmem_tiled_copy_C(self, atom_copy, dtype, major_mode, copy_bits): + copy_elems = copy_bits // dtype.width + shape_dim_1 = cute.size(self.bN) // copy_elems + # thread layout for copy + thread_layout = cute.make_layout( + (self.num_threads // shape_dim_1, shape_dim_1), stride=(shape_dim_1, 1) + ) + if major_mode != utils.LayoutEnum.ROW_MAJOR: + shape_dim_0 = cute.size(self.bM) // copy_elems + thread_layout = cute.make_layout( + (shape_dim_0, self.num_threads // shape_dim_0), stride=(1, shape_dim_0) + ) + value_layout = ( + cute.make_layout((1, copy_elems)) + if major_mode == utils.LayoutEnum.ROW_MAJOR + else cute.make_layout((copy_elems, 1)) + ) + return cute.make_tiled_copy_tv(atom_copy, thread_layout, value_layout) + + def raster_tile(self, i, j, f): + new_i = i // f + new_j = (i % f) + (j * f) + return (new_i, new_j) + + +def run( + a_major: str, + b_major: str, + c_major: str, + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + mnkl: Tuple[int, int, int, int], + atom_layout_mnk: Tuple[int, int, int], + warmup_iterations: int = 2, + iterations: int = 100, + skip_ref_check: bool = False, + use_cold_l2: bool = False, + **kwargs, +): + print(f"Running Ampere tensor core GEMM example:") + print(f"mnkl: {mnkl}") + print( + f"A dtype: {ab_dtype}, B dtype: {ab_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}" + ) + print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}") + print(f"Atoms layout: {atom_layout_mnk}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + print(f"Use cold L2: {use_cold_l2}") + M, N, K, L = mnkl + + # Create and permute tensor A/B/C + def create_and_permute_tensor(l, mode0, mode1, is_mode0_major, dtype): + # is_mode0_major: (l, mode1, mode0) -> (mode0, mode1, l) + # else: (l, mode0, mode1) -> (mode0, mode1, l) + shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1) + permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0) + torch_tensor = ( + torch.empty(*shape, dtype=torch.int32) + .random_(-2, 2) + .to(dtype=cutlass_torch.dtype(dtype)) + .permute(permute_order) + .cuda() + ) + # assume input is 16B aligned + cute_tensor = ( + from_dlpack(torch_tensor, assumed_align=16) + .mark_layout_dynamic(leading_dim=(1 if not is_mode0_major else 0)) + .mark_compact_shape_dynamic( + mode=(1 if not is_mode0_major else 0), + stride_order=(2, 0, 1) if not is_mode0_major else (2, 1, 0), + divisibility=(128 // dtype.width), + ) + ) + return cute_tensor, torch_tensor + + mA, a_torch = create_and_permute_tensor(L, M, K, a_major == "m", ab_dtype) + mB, b_torch = create_and_permute_tensor(L, N, K, b_major == "n", ab_dtype) + mC, c_torch = create_and_permute_tensor(L, M, N, c_major == "m", c_dtype) + + tensor_op_gemm = TensorOpGemm( + ab_dtype, + c_dtype, + acc_dtype, + atom_layout_mnk, + ) + + print("Compiling kernel with cute.compile ...") + compiled_gemm = cute.compile(tensor_op_gemm, mA, mB, mC) + + print("Executing GEMM kernel...") + + if not skip_ref_check: + ref = torch.einsum( + "mkl,nkl->mnl", + a_torch.to(dtype=torch.float32), + b_torch.to(dtype=torch.float32), + ).to(cutlass_torch.dtype(c_dtype)) + compiled_gemm(mA, mB, mC) + print("Verifying results...") + torch.testing.assert_close(c_torch.cpu(), ref.cpu(), atol=1e-03, rtol=1e-05) + print("Results verified successfully!") + + def generate_tensors(): + a_workspace, _ = create_and_permute_tensor(L, M, K, a_major == "m", ab_dtype) + b_workspace, _ = create_and_permute_tensor(L, N, K, b_major == "n", ab_dtype) + c_workspace, _ = create_and_permute_tensor(L, M, N, c_major == "m", c_dtype) + return testing.JitArguments(a_workspace, b_workspace, c_workspace) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + a_torch.numel() * a_torch.element_size() + + b_torch.numel() * b_torch.element_size() + + c_torch.numel() * c_torch.element_size() + ) + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + avg_time_us = testing.benchmark( + compiled_gemm, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + warmup_iterations=warmup_iterations, + iterations=iterations, + use_cuda_graphs=False, + ) + + return avg_time_us # Return execution time in microseconds + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: + try: + return tuple(int(x.strip()) for x in s.split(",")) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + parser = argparse.ArgumentParser( + description="example of multistage block matmul with CuTe on GPU" + ) + parser.add_argument( + "--mnkl", type=parse_comma_separated_ints, default=(112, 136, 40, 1) + ) + parser.add_argument( + "--atom_layout_mnk", type=parse_comma_separated_ints, default=(2, 2, 1) + ) + parser.add_argument( + "--ab_dtype", + type=cutlass.dtype, + choices=[cutlass.Float16], + default=cutlass.Float16, + ) + parser.add_argument( + "--acc_dtype", + type=cutlass.dtype, + choices=[cutlass.Float32], + default=cutlass.Float32, + ) + parser.add_argument( + "--c_dtype", + type=cutlass.dtype, + choices=[cutlass.Float16], + default=cutlass.Float16, + ) + parser.add_argument("--a_major", choices=["k", "m"], default="m") + parser.add_argument("--b_major", choices=["k", "n"], default="n") + parser.add_argument("--c_major", choices=["n", "m"], default="n") + parser.add_argument("--warmup_iterations", default=2, type=int) + parser.add_argument("--iterations", default=100, type=int) + parser.add_argument("--skip_ref_check", action="store_true") + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", + ) + + args = parser.parse_args() + run( + args.a_major, + args.b_major, + args.c_major, + args.ab_dtype, + args.c_dtype, + args.acc_dtype, + args.mnkl, + args.atom_layout_mnk, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.use_cold_l2, + ) + print("PASS") diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py new file mode 100644 index 0000000000000000000000000000000000000000..c9c24b84265834ffe4062e14aaa255bc0c4610c9 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py @@ -0,0 +1,2466 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +from typing import Optional, Type, Tuple, Union + +import cuda.bindings.driver as cuda +import torch + +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.torch as cutlass_torch +import cutlass.utils as utils +import cutlass.pipeline as pipeline +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.utils.blockscaled_layout as blockscaled_utils +from cutlass.cute.runtime import from_dlpack + +""" +This example provides an experimental implementation of the SM100 batched dense blockscaled GEMM kernel, please note that the APIs and implementation details related to this kernel may change in future releases. + +A high-performance persistent batched dense blockscaled GEMM example for the NVIDIA Blackwell SM100 architecture +using CUTE DSL. +- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M") for MXF8 input type and can only be row-major("K") for MXF4/NVF4 input type +- Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K") for MXF8 input type and can only be row-major("K") for MXF4/NVF4 input type +- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M") +- Matrix SFA layout is filled internally according to A shape and BlockScaledBasicChunk, which has M×ceil_div(K, sf_vec_size)×L elements respectively +- Matrix SFB layout is filled internally according to B shape and BlockScaledBasicChunk, which has N×ceil_div(K, sf_vec_size)×L elements respectively + +This GEMM kernel supports the following features: + - Utilizes Tensor Memory Access (TMA) for efficient memory operations + - Utilizes Blackwell's tcgen05.mma for matrix multiply-accumulate (MMA) operations (including 2cta mma instructions) + - Implements TMA multicast with cluster to reduce L2 memory traffic + - Support persistent tile scheduling to better overlap memory load/store with mma between tiles + - Support warp specialization to avoid explicit pipelining between mainloop load and mma + +This GEMM works as follows: +1. DMA warp: Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations. +2. MMA warp: + - Load scale factor A/B from shared memory (SMEM) to tensor memory (TMEM) using tcgen05.cp instruction. + - Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction. +3. EPILOGUE warp: + - Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld. + - Type convert C matrix to output type. + - Optionally store C matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations, + or directly store C matrix from registers (RMEM) to global memory (GMEM) without TMA operations. + - Optionally accept an elementwise lambda function epilogue_op to apply to the output tensor: + e.g., relu can set epilogue_op = lambda x: cute.where(x > 0, x, cute.full_like(x, 0)) + +SM100 tcgen05.mma.kind.block_scale instructions operate as follows: +- Read matrix A from SMEM +- Read matrix B from SMEM +- Read scalefactor A from TMEM +- Read scalefactor B from TMEM +- Write accumulator to TMEM +The accumulator in TMEM must then be loaded to registers before writing back to GMEM. + +Input arguments to this example is shown below: + +.. code-block:: bash + + python examples/blackwell/dense_blockscaled_gemm_persistent.py \ + --ab_dtype Float4E2M1FN --sf_dtype Float8E8M0FNU --sf_vec_size 16 \ + --c_dtype Float16 \ + --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \ + --mnkl 8192,8192,1024,1 + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/blackwell/dense_blockscaled_gemm_persistent.py \ + --ab_dtype Float4E2M1FN --sf_dtype Float8E8M0FNU --sf_vec_size 16 \ + --c_dtype Float16 \ + --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \ + --mnkl 8192,8192,1024,1 \ + --warmup_iterations 1 --iterations 10 --skip_ref_check + + +Constraints: +* Supported input data types: mxf8, mxf4, nvf4 + see detailed valid dtype combinations in below Sm100BlockScaledPersistentDenseGemmKernel class documentation +* A/B tensor must have the same data type, mixed data type is not supported (e.g., mxf8 x mxf4) +* Mma tiler M must be 128 or 256(use_2cta_instrs) +* Mma tiler N must be 128 or 256 +* Cluster shape M/N must be positive and power of 2, total cluster size <= 16 +* Cluster shape M must be multiple of 2 if Mma tiler M is 256(use_2cta_instrs) +* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned, + i.e, number of elements is a multiple of 16 and 32 for Float8 and Float4, respectively. +""" + + +class Sm100BlockScaledPersistentDenseGemmKernel: + """This class implements batched matrix multiplication (C = A x SFA x B x SFB) with support for various data types + and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization. + + :param sf_vec_size: Scalefactor vector size. + :type sf_vec_size: int + :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N) + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing + :type cluster_shape_mn: Tuple[int, int] + + :note: In current version, A and B tensor must have the same data type + - i.e., Float8E4M3FN for A and Float8E5M2 for B is not supported + + :note: Supported combinations of A/B data types, SF data typs and SF vector size: + - MXF8: A/B: Float8E5M2/Float8E4M3FN + SF: Float8E8M0FNU + sf_vec_size: 32 + - MXF4: A/B: Float4E2M1FN + SF: Float8E8M0FNU + sf_vec_size: 32 + - NVF4: A/B: Float4E2M1FN + SF: Float8E8M0FNU/Float8E4M3FN + sf_vec_size: 16 + + :note: Supported accumulator data types: + - Float32 + + :note: Supported C data types: + - Float32 + - Float16/BFloat16 + - Float8E4M3FN/Float8E5M2 + :note: Constraints: + - MMA tiler M must be 128 or 256 (use_2cta_instrs) + - MMA tiler N must be 128/256 + - Cluster shape M must be multiple of 2 if Mma tiler M is 256 + - Cluster shape M/N must be positive and power of 2, total cluster size <= 16 + - Also, Cluster shape M/N must be <= 4 for scale factor multicasts due to limited size of scale factors + + Example: + >>> gemm = Sm100BlockScaledPersistentDenseGemmKernel( + ... sf_vec_size=16, + ... mma_tiler_mn=(256, 128), + ... cluster_shape_mn=(2, 1) + ... ) + >>> gemm(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor, max_active_clusters, stream) + """ + + def __init__( + self, + sf_vec_size: int, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ): + """Initializes the configuration for a Blackwell dense GEMM kernel. + + This configuration includes several key aspects: + + 1. MMA Instruction Settings (tcgen05): + - acc_dtype: Data types for MMA accumulator, always set to Float32 + - sf_vec_size: Scalefactor A/B vector size. + - mma_tiler_mn: The (M, N) shape of the MMA instruction tiler. + + 2. Cluster Shape: + - cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster. + + :param sf_vec_size: Scalefactor vector size. + :type sf_vec_size: int + :param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction. + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster. + :type cluster_shape_mn: Tuple[int, int] + """ + + self.acc_dtype = cutlass.Float32 + self.sf_vec_size = sf_vec_size + self.use_2cta_instrs = mma_tiler_mn[0] == 256 + self.cluster_shape_mn = cluster_shape_mn + # K dimension is deferred in _setup_attributes + self.mma_tiler = (*mma_tiler_mn, 1) + + self.cta_group = ( + tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE + ) + + self.occupancy = 1 + # Set specialized warp ids + self.epilog_warp_id = ( + 0, + 1, + 2, + 3, + ) + self.mma_warp_id = 4 + self.tma_warp_id = 5 + self.threads_per_cta = 32 * len( + (self.mma_warp_id, self.tma_warp_id, *self.epilog_warp_id) + ) + # Set barrier id for cta sync, epilogue sync and tmem ptr sync + self.cta_sync_bar_id = 0 + self.epilog_sync_bar_id = 1 + self.tmem_ptr_sync_bar_id = 2 + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + SM100_TMEM_CAPACITY_COLUMNS = 512 + self.num_tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS + + def _setup_attributes(self): + """Set up configurations that are dependent on GEMM inputs + + This method configures various attributes based on the input tensor properties + (data types, leading dimensions) and kernel settings: + - Configuring tiled MMA + - Computing MMA/cluster/tile shapes + - Computing cluster layout + - Computing multicast CTAs for A/B/SFA/SFB + - Computing epilogue subtile + - Setting up A/B/SFA/SFB/C stage counts in shared memory + - Computing A/B/SFA/SFB/C shared memory layout + - Computing tensor memory allocation columns + """ + # Compute mma instruction shapes + mma_inst_bits_k = 256 + # (MMA_Tile_Shape_M, MMA_Tile_Shape_N, MMA_Inst_Shape_K) + self.mma_inst_shape_mnk = ( + self.mma_tiler[0], + self.mma_tiler[1], + mma_inst_bits_k // self.a_dtype.width, + ) + # (CTA_Tile_Shape_M, Round_Up(MMA_Tile_Shape_N, 128), MMA_Inst_Shape_K) + self.mma_inst_shape_mnk_sfb = ( + self.mma_inst_shape_mnk[0] // (2 if self.use_2cta_instrs else 1), + cute.round_up(self.mma_inst_shape_mnk[1], 128), + self.mma_inst_shape_mnk[2], + ) + + tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + self.cta_group, + self.mma_inst_shape_mnk[:2], + ) + + tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + cute.nvgpu.tcgen05.CtaGroup.ONE, + self.mma_inst_shape_mnk_sfb[:2], + ) + + # Compute mma/cluster/tile shapes + mma_inst_tile_k = 4 + self.mma_tiler = ( + self.mma_inst_shape_mnk[0], + self.mma_inst_shape_mnk[1], + self.mma_inst_shape_mnk[2] * mma_inst_tile_k, + ) + self.mma_tiler_sfb = ( + self.mma_inst_shape_mnk_sfb[0], + self.mma_inst_shape_mnk_sfb[1], + self.mma_inst_shape_mnk_sfb[2] * mma_inst_tile_k, + ) + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + + # Compute cluster layout + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + self.cluster_layout_sfb_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma_sfb.thr_id.shape,), + ) + + # Compute number of multicast CTAs for A/B + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.num_mcast_ctas_sfb = cute.size(self.cluster_layout_sfb_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + self.is_sfb_mcast = self.num_mcast_ctas_sfb > 1 + + # Compute epilogue subtile + self.epi_tile = sm100_utils.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, + self.use_2cta_instrs, + self.c_layout, + self.c_dtype, + ) + + # Setup A/B/C stage count in shared memory and ACC stage count in tensor memory + self.num_acc_stage, self.num_ab_stage, self.num_c_stage = self._compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.a_major_mode, + self.b_dtype, + self.b_major_mode, + self.epi_tile, + self.c_dtype, + self.c_layout, + self.sf_dtype, + self.sf_vec_size, + self.smem_capacity, + self.occupancy, + ) + + # Compute A/B/SFA/SFB/C shared memory layout + self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.num_ab_stage, + ) + self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + self.mma_tiler, + self.b_dtype, + self.num_ab_stage, + ) + self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_ab_stage, + ) + self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_ab_stage, + ) + self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi( + self.c_dtype, + self.c_layout, + self.epi_tile, + self.num_c_stage, + ) + + @cute.jit + def __call__( + self, + a_tensor: cute.Tensor, + b_tensor: cute.Tensor, + sfa_tensor: cute.Tensor, + sfb_tensor: cute.Tensor, + c_tensor: cute.Tensor, + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + """Execute the GEMM operation in steps: + - Setup static attributes before smem/grid/tma computation + - Setup TMA load/store atoms and tensors + - Compute grid size with regard to hardware constraints + - Define shared storage for kernel + - Launch the kernel synchronously + + :param a_tensor: Input tensor A + :type a_tensor: cute.Tensor + :param b_tensor: Input tensor B + :type b_tensor: cute.Tensor + :param sfa_tensor: Scale factor tensor A + :type sfa_tensor: cute.Tensor + :param sfb_tensor: Scale factor tensor B + :type sfb_tensor: cute.Tensor + :param c_tensor: Output tensor C + :type c_tensor: cute.Tensor + :param max_active_clusters: Maximum number of active clusters + :type max_active_clusters: cutlass.Constexpr + :param stream: CUDA stream for asynchronous execution + :type stream: cuda.CUstream + :param epilogue_op: Optional elementwise lambda function to apply to the output tensor + :type epilogue_op: cutlass.Constexpr + :raises TypeError: If input data types are incompatible with the MMA instruction. + """ + # Setup static attributes before smem/grid/tma computation + self.a_dtype: Type[cutlass.Numeric] = a_tensor.element_type + self.b_dtype: Type[cutlass.Numeric] = b_tensor.element_type + self.sf_dtype: Type[cutlass.Numeric] = sfa_tensor.element_type + self.c_dtype: Type[cutlass.Numeric] = c_tensor.element_type + self.a_major_mode = utils.LayoutEnum.from_tensor(a_tensor).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(b_tensor).mma_major_mode() + self.c_layout = utils.LayoutEnum.from_tensor(c_tensor) + + # Check if input data types are compatible with MMA instruction + if cutlass.const_expr(self.a_dtype != self.b_dtype): + raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}") + + # Setup attributes that dependent on gemm inputs + self._setup_attributes() + + # Setup sfa/sfb tensor by filling A/B tensor to scale factor atom layout + # ((Atom_M, Rest_M),(Atom_K, Rest_K),RestL) + sfa_layout = blockscaled_utils.tile_atom_to_shape_SF( + a_tensor.shape, self.sf_vec_size + ) + sfa_tensor = cute.make_tensor(sfa_tensor.iterator, sfa_layout) + + # ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL) + sfb_layout = blockscaled_utils.tile_atom_to_shape_SF( + b_tensor.shape, self.sf_vec_size + ) + sfb_tensor = cute.make_tensor(sfb_tensor.iterator, sfb_layout) + + tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + self.cta_group, + self.mma_inst_shape_mnk[:2], + ) + + tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + cute.nvgpu.tcgen05.CtaGroup.ONE, + self.mma_inst_shape_mnk_sfb[:2], + ) + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # Setup TMA load for A + a_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + a_tensor, + a_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # Setup TMA load for B + b_op = sm100_utils.cluster_shape_to_tma_atom_B( + self.cluster_shape_mn, tiled_mma.thr_id + ) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b_tensor, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # Setup TMA load for SFA + sfa_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + sfa_smem_layout = cute.slice_( + self.sfa_smem_layout_staged, (None, None, None, 0) + ) + tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A( + sfa_op, + sfa_tensor, + sfa_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + + # Setup TMA load for SFB + sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB( + self.cluster_shape_mn, tiled_mma.thr_id + ) + sfb_smem_layout = cute.slice_( + self.sfb_smem_layout_staged, (None, None, None, 0) + ) + tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( + sfb_op, + sfb_tensor, + sfb_smem_layout, + self.mma_tiler_sfb, + tiled_mma_sfb, + self.cluster_layout_sfb_vmnk.shape, + internal_type=cutlass.Int16, + ) + + a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + sfa_copy_size = cute.size_in_bytes(self.sf_dtype, sfa_smem_layout) + sfb_copy_size = cute.size_in_bytes(self.sf_dtype, sfb_smem_layout) + self.num_tma_load_bytes = ( + a_copy_size + b_copy_size + sfa_copy_size + sfb_copy_size + ) * atom_thr_size + + # Setup TMA store for C + epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + c_tensor, + epi_smem_layout, + self.epi_tile, + ) + + # Compute grid size + self.tile_sched_params, grid = self._compute_grid( + c_tensor, + self.cta_tile_shape_mnk, + self.cluster_shape_mn, + max_active_clusters, + ) + + self.buffer_align_bytes = 1024 + + # Define shared storage for kernel + @cute.struct + class SharedStorage: + ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC: cute.struct.Align[ + cute.struct.MemRange[ + self.c_dtype, + cute.cosize(self.c_smem_layout_staged.outer), + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sA: cute.struct.Align[ + cute.struct.MemRange[ + self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sB: cute.struct.Align[ + cute.struct.MemRange[ + self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sSFA: cute.struct.Align[ + cute.struct.MemRange[ + self.sf_dtype, cute.cosize(self.sfa_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sSFB: cute.struct.Align[ + cute.struct.MemRange[ + self.sf_dtype, cute.cosize(self.sfb_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + tiled_mma, + tiled_mma_sfb, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_sfa, + tma_tensor_sfa, + tma_atom_sfb, + tma_tensor_sfb, + tma_atom_c, + tma_tensor_c, + self.cluster_layout_vmnk, + self.cluster_layout_sfb_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.sfa_smem_layout_staged, + self.sfb_smem_layout_staged, + self.c_smem_layout_staged, + self.epi_tile, + self.tile_sched_params, + epilogue_op, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + stream=stream, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tiled_mma_sfb: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_sfa: cute.CopyAtom, + mSFA_mkl: cute.Tensor, + tma_atom_sfb: cute.CopyAtom, + mSFB_nkl: cute.Tensor, + tma_atom_c: Optional[cute.CopyAtom], + mC_mnl: cute.Tensor, + cluster_layout_vmnk: cute.Layout, + cluster_layout_sfb_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + sfa_smem_layout_staged: cute.Layout, + sfb_smem_layout_staged: cute.Layout, + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None], + epi_tile: cute.Tile, + tile_sched_params: utils.PersistentTileSchedulerParams, + epilogue_op: cutlass.Constexpr, + ): + """ + GPU device kernel performing the Persistent batched GEMM computation. + """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + # + # Prefetch tma desc + # + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + cpasync.prefetch_descriptor(tma_atom_sfa) + cpasync.prefetch_descriptor(tma_atom_sfb) + cpasync.prefetch_descriptor(tma_atom_c) + + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + # + # Setup cta/thread coordinates + # + # Coords inside cluster + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + # Coord inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Alloc and init: a+b full/empty, accumulator full/empty, tensor memory dealloc barrier + # + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr + tmem_holding_buf = storage.tmem_holding_buf + + # Initialize mainloop ab_pipeline (barrier) and states + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + ab_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_tma_producer + ) + ab_pipeline = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_full_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + # Initialize acc_pipeline (barrier) and states + acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_acc_consumer_threads = len(self.epilog_warp_id) * ( + 2 if use_2cta_instrs else 1 + ) + acc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_acc_consumer_threads + ) + acc_pipeline = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_full_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=acc_pipeline_producer_group, + consumer_group=acc_pipeline_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + # Tensor memory dealloc barrier init + if use_2cta_instrs: + if warp_idx == self.tma_warp_id: + num_tmem_dealloc_threads = 32 + with cute.arch.elect_one(): + cute.arch.mbarrier_init( + tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads + ) + cute.arch.mbarrier_init_fence() + + # Cluster arrive after barrier init + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_arrive_relaxed() + + # + # Setup smem tensor A/B/SFA/SFB/C + # + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC = storage.sC.get_tensor( + c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner + ) + # (MMA, MMA_M, MMA_K, STAGE) + sA = storage.sA.get_tensor( + a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner + ) + # (MMA, MMA_N, MMA_K, STAGE) + sB = storage.sB.get_tensor( + b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner + ) + # (MMA, MMA_M, MMA_K, STAGE) + sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged) + # (MMA, MMA_N, MMA_K, STAGE) + sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged) + + # + # Compute multicast mask for A/B/SFA/SFB buffer full + # + a_full_mcast_mask = None + b_full_mcast_mask = None + sfa_full_mcast_mask = None + sfb_full_mcast_mask = None + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): + a_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + b_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + sfa_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + sfb_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_sfb_vmnk, block_in_cluster_coord_sfb_vmnk, mcast_mode=1 + ) + + # + # Local_tile partition global tensors + # + # (bM, bK, RestM, RestK, RestL) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + # (bM, bK, RestM, RestK, RestL) + gSFA_mkl = cute.local_tile( + mSFA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gSFB_nkl = cute.local_tile( + mSFB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + # (bM, bN, RestM, RestN, RestL) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) + ) + k_tile_cnt = cute.size(gA_mkl, mode=[3]) + + # + # Partition global tensor for TiledMMA_A/B/C + # + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + thr_mma_sfb = tiled_mma_sfb.get_slice(mma_tile_coord_v) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgA = thr_mma.partition_A(gA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgB = thr_mma.partition_B(gB_nkl) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgSFA = thr_mma.partition_A(gSFA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgSFB = thr_mma_sfb.partition_B(gSFB_nkl) + # (MMA, MMA_M, MMA_N, RestM, RestN, RestL) + tCgC = thr_mma.partition_C(gC_mnl) + + # + # Partition global/shared tensor for TMA load A/B + # + # TMA load A partition_S/D + a_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA load B partition_S/D + b_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # TMA load SFA partition_S/D + sfa_cta_layout = a_cta_layout + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsSFA, tAgSFA = cute.nvgpu.cpasync.tma_partition( + tma_atom_sfa, + block_in_cluster_coord_vmnk[2], + sfa_cta_layout, + cute.group_modes(sSFA, 0, 3), + cute.group_modes(tCgSFA, 0, 3), + ) + tAsSFA = cute.filter_zeros(tAsSFA) + tAgSFA = cute.filter_zeros(tAgSFA) + + # TMA load SFB partition_S/D + sfb_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsSFB, tBgSFB = cute.nvgpu.cpasync.tma_partition( + tma_atom_sfb, + block_in_cluster_coord_sfb_vmnk[1], + sfb_cta_layout, + cute.group_modes(sSFB, 0, 3), + cute.group_modes(tCgSFB, 0, 3), + ) + tBsSFB = cute.filter_zeros(tBsSFB) + tBgSFB = cute.filter_zeros(tBgSFB) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/C + # + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_fake = tiled_mma.make_fragment_C( + cute.append(acc_shape, self.num_acc_stage) + ) + + # + # Cluster wait before tensor memory alloc + # + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_wait() + else: + cute.arch.barrier( + barrier_id=self.cta_sync_bar_id, number_of_threads=self.threads_per_cta + ) + + # + # Specialized TMA load warp + # + if warp_idx == self.tma_warp_id: + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + ab_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_ab_stage + ) + + while work_tile.is_valid_tile: + + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + + # + # Slice to per mma tile index + # + # ((atom_v, rest_v), RestK) + tAgA_slice = tAgA[ + (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) + ] + # ((atom_v, rest_v), RestK) + tBgB_slice = tBgB[ + (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) + ] + + # ((atom_v, rest_v), RestK) + tAgSFA_slice = tAgSFA[ + (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) + ] + # ((atom_v, rest_v), RestK) + tBgSFB_slice = tBgSFB[ + (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) + ] + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + ab_producer_state.reset_count() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) + # + # Tma load loop + # + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + # Conditionally wait for AB buffer empty + ab_pipeline.producer_acquire( + ab_producer_state, peek_ab_empty_status + ) + + # TMA load A/B/SFA/SFB + cute.copy( + tma_atom_a, + tAgA_slice[(None, ab_producer_state.count)], + tAsA[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=a_full_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB_slice[(None, ab_producer_state.count)], + tBsB[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atom_sfa, + tAgSFA_slice[(None, ab_producer_state.count)], + tAsSFA[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=sfa_full_mcast_mask, + ) + cute.copy( + tma_atom_sfb, + tBgSFB_slice[(None, ab_producer_state.count)], + tBsSFB[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=sfb_full_mcast_mask, + ) + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 + ab_producer_state.advance() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # + # Wait A/B buffer empty + # + ab_pipeline.producer_tail(ab_producer_state) + + # + # Specialized MMA warp + # + if warp_idx == self.mma_warp_id: + # + # Bar sync for retrieve tensor memory ptr from shared mem + # + tmem_ptr_read_threads = 32 * len((self.mma_warp_id, *self.epilog_warp_id)) + cute.arch.barrier( + barrier_id=self.tmem_ptr_sync_bar_id, + number_of_threads=tmem_ptr_read_threads, + ) + + # + # Retrieving tensor memory ptr and make accumulator/SFA/SFB tensor + # + # Make accumulator tmem tensor + acc_tmem_ptr = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, + alignment=16, + ptr_to_buffer_holding_addr=tmem_holding_buf, + ) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + + # Make SFA tmem tensor + sfa_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base), + dtype=self.sf_dtype, + ) + # (MMA, MMA_M, MMA_K) + tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout) + + # Make SFB tmem tensor + sfb_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base) + + tcgen05.find_tmem_tensor_col_offset(tCtSFA), + dtype=self.sf_dtype, + ) + # (MMA, MMA_N, MMA_K) + tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout) + # + # Partition for S2T copy of SFA/SFB + # + tiled_copy_s2t_sfa, tCsSFA_compact_s2t, tCtSFA_compact_s2t = ( + self.mainloop_s2t_copy_and_partition(sSFA, tCtSFA) + ) + tiled_copy_s2t_sfb, tCsSFB_compact_s2t, tCtSFB_compact_s2t = ( + self.mainloop_s2t_copy_and_partition(sSFB, tCtSFB) + ) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + ab_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_ab_stage + ) + acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_acc_stage + ) + + while work_tile.is_valid_tile: + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + + # Set tensor memory buffer for current tile + # (MMA, MMA_M, MMA_N) + tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)] + + # Peek (try_wait) AB buffer full for k_tile = 0 + ab_consumer_state.reset_count() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt and is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait( + ab_consumer_state + ) + + # + # Wait for accumulator buffer empty + # + if is_leader_cta: + acc_pipeline.producer_acquire(acc_producer_state) + + # + # Reset the ACCUMULATE field for each tile + # + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + + # + # Mma mainloop + # + for k_tile in range(k_tile_cnt): + if is_leader_cta: + # Conditionally wait for AB buffer full + ab_pipeline.consumer_wait( + ab_consumer_state, peek_ab_full_status + ) + + # Copy SFA/SFB from smem to tmem + s2t_stage_coord = ( + None, + None, + None, + None, + ab_consumer_state.index, + ) + tCsSFA_compact_s2t_staged = tCsSFA_compact_s2t[s2t_stage_coord] + tCsSFB_compact_s2t_staged = tCsSFB_compact_s2t[s2t_stage_coord] + cute.copy( + tiled_copy_s2t_sfa, + tCsSFA_compact_s2t_staged, + tCtSFA_compact_s2t, + ) + cute.copy( + tiled_copy_s2t_sfb, + tCsSFB_compact_s2t_staged, + tCtSFB_compact_s2t, + ) + + # tCtAcc += tCrA * tCrSFA * tCrB * tCrSFB + num_kblocks = cute.size(tCrA, mode=[2]) + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = ( + None, + None, + kblock_idx, + ab_consumer_state.index, + ) + + # Set SFA/SFB tensor to tiled_mma + sf_kblock_coord = (None, None, kblock_idx) + tiled_mma.set( + tcgen05.Field.SFA, + tCtSFA[sf_kblock_coord].iterator, + ) + tiled_mma.set( + tcgen05.Field.SFB, + tCtSFB[sf_kblock_coord].iterator, + ) + + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[kblock_coord], + tCrB[kblock_coord], + tCtAcc, + ) + + # Enable accumulate on tCtAcc after first kblock + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Async arrive AB buffer empty + ab_pipeline.consumer_release(ab_consumer_state) + + # Peek (try_wait) AB buffer full for k_tile = k_tile + 1 + ab_consumer_state.advance() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt: + if is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait( + ab_consumer_state + ) + + # + # Async arrive accumulator buffer full + # + if is_leader_cta: + acc_pipeline.producer_commit(acc_producer_state) + acc_producer_state.advance() + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # + # Wait for accumulator buffer empty + # + acc_pipeline.producer_tail(acc_producer_state) + # + # Specialized epilogue warps + # + if warp_idx < self.mma_warp_id: + # + # Alloc tensor memory buffer + # + if warp_idx == self.epilog_warp_id[0]: + cute.arch.alloc_tmem( + self.num_tmem_alloc_cols, + tmem_holding_buf, + is_two_cta=use_2cta_instrs, + ) + + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + tmem_ptr_read_threads = 32 * len((self.mma_warp_id, *self.epilog_warp_id)) + cute.arch.barrier( + barrier_id=self.tmem_ptr_sync_bar_id, + number_of_threads=tmem_ptr_read_threads, + ) + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + acc_tmem_ptr = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, + alignment=16, + ptr_to_buffer_holding_addr=tmem_holding_buf, + ) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + + # + # Partition for epilogue + # + epi_tidx = tidx + tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = ( + self.epilog_tmem_copy_and_partition( + epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs + ) + ) + + tTR_rC = cute.make_fragment(tTR_rAcc.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( + tiled_copy_t2r, tTR_rC, epi_tidx, sC + ) + tma_atom_c, bSG_sC, bSG_gC_partitioned = ( + self.epilog_gmem_copy_and_partition( + epi_tidx, tma_atom_c, tCgC, epi_tile, sC + ) + ) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage + ) + + # Threads/warps participating in tma store pipeline + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + 32 * len(self.epilog_warp_id), + 32 * len(self.epilog_warp_id), + ) + c_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_c_stage, + producer_group=c_producer_group, + ) + + while work_tile.is_valid_tile: + + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + + # + # Slice to per mma tile index + # + # ((ATOM_V, REST_V), EPI_M, EPI_N) + bSG_gC = bSG_gC_partitioned[ + ( + None, + None, + None, + *mma_tile_coord_mnl, + ) + ] + + # Set tensor memory buffer for current tile + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = tTR_tAcc_base[ + (None, None, None, None, None, acc_consumer_state.index) + ] + + # + # Wait for accumulator buffer full + # + acc_pipeline.consumer_wait(acc_consumer_state) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + + # + # Store accumulator to global memory in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt + for subtile_idx in cutlass.range(subtile_cnt): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # + # Convert to C type + # + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) + tRS_rC.store(acc_vec) + + # + # Store C to shared memory + # + c_buffer = (num_prev_subtiles + subtile_idx) % self.num_c_stage + cute.copy( + tiled_copy_r2s, + tRS_rC, + tRS_sC[(None, None, None, c_buffer)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + epilog_threads = 32 * len(self.epilog_warp_id) + cute.arch.barrier( + barrier_id=self.epilog_sync_bar_id, + number_of_threads=epilog_threads, + ) + + # + # TMA store C to global memory + # + if warp_idx == self.epilog_warp_id[0]: + cute.copy( + tma_atom_c, + bSG_sC[(None, c_buffer)], + bSG_gC[(None, subtile_idx)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + c_pipeline.producer_commit() + c_pipeline.producer_acquire() + cute.arch.barrier( + barrier_id=self.epilog_sync_bar_id, + number_of_threads=epilog_threads, + ) + + # + # Async arrive accumulator buffer empty + # + with cute.arch.elect_one(): + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # + # Dealloc the tensor memory buffer + # + if warp_idx == self.epilog_warp_id[0]: + cute.arch.relinquish_tmem_alloc_permit(is_two_cta=use_2cta_instrs) + epilog_threads = 32 * len(self.epilog_warp_id) + cute.arch.barrier( + barrier_id=self.epilog_sync_bar_id, number_of_threads=epilog_threads + ) + if warp_idx == self.epilog_warp_id[0]: + if use_2cta_instrs: + cute.arch.mbarrier_arrive( + tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1 + ) + cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0) + cute.arch.dealloc_tmem( + acc_tmem_ptr, self.num_tmem_alloc_cols, is_two_cta=use_2cta_instrs + ) + # + # Wait for C store complete + # + c_pipeline.producer_tail() + + def mainloop_s2t_copy_and_partition( + self, + sSF: cute.Tensor, + tSF: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for smem to tmem load for scale factor tensor, then use it to partition smem memory (source) and tensor memory (destination). + + :param sSF: The scale factor tensor in smem + :type sSF: cute.Tensor + :param tSF: The scale factor tensor in tmem + :type tSF: cute.Tensor + + :return: A tuple containing (tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t) where: + - tiled_copy_s2t: The tiled copy operation for smem to tmem load for scale factor tensor(s2t) + - tCsSF_compact_s2t: The partitioned scale factor tensor in smem + - tSF_compact_s2t: The partitioned scale factor tensor in tmem + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # (MMA, MMA_MN, MMA_K, STAGE) + tCsSF_compact = cute.filter_zeros(sSF) + # (MMA, MMA_MN, MMA_K) + tCtSF_compact = cute.filter_zeros(tSF) + + # Make S2T CopyAtom and tiledCopy + copy_atom_s2t = cute.make_copy_atom( + tcgen05.Cp4x32x128bOp(self.cta_group), + self.sf_dtype, + ) + tiled_copy_s2t = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSF_compact) + thr_copy_s2t = tiled_copy_s2t.get_slice(0) + + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSF_compact_s2t_ = thr_copy_s2t.partition_S(tCsSF_compact) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t, tCsSF_compact_s2t_ + ) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) + tCtSF_compact_s2t = thr_copy_s2t.partition_D(tCtSF_compact) + + return tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t + + def epilog_tmem_copy_and_partition( + self, + tidx: cutlass.Int32, + tAcc: cute.Tensor, + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + use_2cta_instrs: Union[cutlass.Boolean, bool], + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination). + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param tAcc: The accumulator tensor to be copied and partitioned + :type tAcc: cute.Tensor + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param use_2cta_instrs: Whether use_2cta_instrs is enabled + :type use_2cta_instrs: bool + + :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where: + - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + - tTR_tAcc: The partitioned accumulator tensor + - tTR_rAcc: The accumulated tensor in register used to hold t2r results + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # Make tiledCopy for tensor memory load + copy_atom_t2r = sm100_utils.get_tmem_load_op( + self.cta_tile_shape_mnk, + self.c_layout, + self.c_dtype, + self.acc_dtype, + epi_tile, + use_2cta_instrs, + ) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE) + tAcc_epi = cute.flat_divide( + tAcc[((None, None), 0, 0, None)], + epi_tile, + ) + # (EPI_TILE_M, EPI_TILE_N) + tiled_copy_t2r = tcgen05.make_tmem_copy( + copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)] + ) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE) + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + gC_mnl_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) + tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) + # (T2R, T2R_M, T2R_N) + tTR_rAcc = cute.make_fragment( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype + ) + return tiled_copy_t2r, tTR_tAcc, tTR_rAcc + + def epilog_smem_copy_and_partition( + self, + tiled_copy_t2r: cute.TiledCopy, + tTR_rC: cute.Tensor, + tidx: cutlass.Int32, + sC: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination). + + :param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + :type tiled_copy_t2r: cute.TiledCopy + :param tTR_rC: The partitioned accumulator tensor + :type tTR_rC: cute.Tensor + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + :type sepi: cute.Tensor + + :return: A tuple containing (tiled_copy_r2s, tRS_rC, tRS_sC) where: + - tiled_copy_r2s: The tiled copy operation for register to smem copy(r2s) + - tRS_rC: The partitioned tensor C (register source) + - tRS_sC: The partitioned tensor C (smem destination) + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + copy_atom_r2s = sm100_utils.get_smem_store_op( + self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r + ) + tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) + # (R2S, R2S_M, R2S_N, PIPE_D) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sC = thr_copy_r2s.partition_D(sC) + # (R2S, R2S_M, R2S_N) + tRS_rC = tiled_copy_r2s.retile(tTR_rC) + return tiled_copy_r2s, tRS_rC, tRS_sC + + def epilog_gmem_copy_and_partition( + self, + tidx: cutlass.Int32, + atom: Union[cute.CopyAtom, cute.TiledCopy], + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + sC: cute.Tensor, + ) -> Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]: + """Make tiledCopy for global memory store, then use it to: + partition shared memory (source) and global memory (destination) for TMA store version. + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param atom: The copy_atom_c to be used for TMA store version, or tiled_copy_t2r for none TMA store version + :type atom: cute.CopyAtom or cute.TiledCopy + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + + :return: A tuple containing (tma_atom_c, bSG_sC, bSG_gC) where: + - tma_atom_c: The TMA copy atom + - bSG_sC: The partitioned shared memory tensor C + - bSG_gC: The partitioned global tensor C + :rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] + """ + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + gC_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + + tma_atom_c = atom + sC_for_tma_partition = cute.group_modes(sC, 0, 2) + gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2) + # ((ATOM_V, REST_V), EPI_M, EPI_N) + # ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL) + bSG_sC, bSG_gC = cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + sC_for_tma_partition, + gC_for_tma_partition, + ) + return tma_atom_c, bSG_sC, bSG_gC + + @staticmethod + def _compute_stages( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: Tuple[int, int, int], + a_dtype: Type[cutlass.Numeric], + a_major_mode: tcgen05.OperandMajorMode, + b_dtype: Type[cutlass.Numeric], + b_major_mode: tcgen05.OperandMajorMode, + epi_tile: cute.Tile, + c_dtype: Type[cutlass.Numeric], + c_layout: utils.LayoutEnum, + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + smem_capacity: int, + occupancy: int, + ) -> Tuple[int, int, int]: + """Computes the number of stages for A/B/C operands based on heuristics. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler. + :type mma_tiler_mnk: tuple[int, int, int] + :param a_dtype: Data type of operand A. + :type a_dtype: type[cutlass.Numeric] + :param a_major_mode: Major mode of operand A. + :type a_major_mode: tcgen05.OperandMajorMode + :param b_dtype: Data type of operand B. + :type b_dtype: type[cutlass.Numeric] + :param b_major_mode: Major mode of operand B. + :type b_major_mode: tcgen05.OperandMajorMode + :param epi_tile: The epilogue tile shape. + :type epi_tile: cute.Tile + :param c_dtype: Data type of operand C (output). + :type c_dtype: type[cutlass.Numeric] + :param c_layout: Layout enum of operand C. + :type c_layout: utils.LayoutEnum + :param sf_dtype: Data type of Scale factor. + :type sf_dtype: type[cutlass.Numeric] + :param sf_vec_size: Scale factor vector size. + :type sf_vec_size: int + :param smem_capacity: Total available shared memory capacity in bytes. + :type smem_capacity: int + :param occupancy: Target number of CTAs per SM (occupancy). + :type occupancy: int + + :return: A tuple containing the computed number of stages for: + (ACC stages, A/B operand stages, C stages) + :rtype: tuple[int, int, int] + """ + # ACC stages + num_acc_stage = 1 if mma_tiler_mnk[1] == 256 else 2 + + # Default C stages + num_c_stage = 2 + + # Calculate smem layout and size for one stage of A, B, SFA, SFB and C + a_smem_layout_stage_one = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + a_dtype, + 1, # a tmp 1 stage is provided + ) + b_smem_layout_staged_one = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + b_dtype, + 1, # a tmp 1 stage is provided + ) + sfa_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + 1, # a tmp 1 stage is provided + ) + sfb_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + 1, # a tmp 1 stage is provided + ) + + c_smem_layout_staged_one = sm100_utils.make_smem_layout_epi( + c_dtype, + c_layout, + epi_tile, + 1, + ) + + ab_bytes_per_stage = ( + cute.size_in_bytes(a_dtype, a_smem_layout_stage_one) + + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) + + cute.size_in_bytes(sf_dtype, sfa_smem_layout_staged_one) + + cute.size_in_bytes(sf_dtype, sfb_smem_layout_staged_one) + ) + mbar_helpers_bytes = 1024 + c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) + c_bytes = c_bytes_per_stage * num_c_stage + + # Calculate A/B/SFA/SFB stages: + # Start with total smem per CTA (capacity / occupancy) + # Subtract reserved bytes and initial C stages bytes + # Divide remaining by bytes needed per A/B/SFA/SFB stage + num_ab_stage = ( + smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes) + ) // ab_bytes_per_stage + + # Refine epilogue stages: + # Calculate remaining smem after allocating for A/B/SFA/SFB stages and reserved bytes + # Add remaining unused smem to epilogue + num_c_stage += ( + smem_capacity + - occupancy * ab_bytes_per_stage * num_ab_stage + - occupancy * (mbar_helpers_bytes + c_bytes) + ) // (occupancy * c_bytes_per_stage) + + return num_acc_stage, num_ab_stage, num_c_stage + + @staticmethod + def _compute_grid( + c: cute.Tensor, + cta_tile_shape_mnk: Tuple[int, int, int], + cluster_shape_mn: Tuple[int, int], + max_active_clusters: cutlass.Constexpr, + ) -> Tuple[utils.PersistentTileSchedulerParams, Tuple[int, int, int]]: + """Use persistent tile scheduler to compute the grid size for the output tensor C. + + :param c: The output tensor C + :type c: cute.Tensor + :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type cta_tile_shape_mnk: tuple[int, int, int] + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. + :type cluster_shape_mn: tuple[int, int] + :param max_active_clusters: Maximum number of active clusters. + :type max_active_clusters: cutlass.Constexpr + + :return: A tuple containing: + - tile_sched_params: Parameters for the persistent tile scheduler. + - grid: Grid shape for kernel launch. + :rtype: Tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]] + """ + c_shape = cute.slice_(cta_tile_shape_mnk, (None, None, 0)) + gc = cute.zipped_divide(c, tiler=c_shape) + num_ctas_mnl = gc[(0, (None, None, None))].shape + cluster_shape_mnl = (*cluster_shape_mn, 1) + + tile_sched_params = utils.PersistentTileSchedulerParams( + num_ctas_mnl, cluster_shape_mnl + ) + grid = utils.StaticPersistentTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + + return tile_sched_params, grid + + @staticmethod + def is_valid_dtypes_and_scale_factor_vec_size( + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + c_dtype: Type[cutlass.Numeric], + ) -> bool: + """ + Check if the dtypes and sf_vec_size are valid combinations + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param sf_dtype: The data type of the scale factor + :type sf_dtype: Type[cutlass.Numeric] + :param sf_vec_size: The vector size of the scale factor + :type sf_vec_size: int + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + + :return: True if the dtypes and sf_vec_size are valid, False otherwise + :rtype: bool + """ + is_valid = True + + # Check valid ab_dtype + if ab_dtype not in { + cutlass.Float4E2M1FN, + cutlass.Float8E5M2, + cutlass.Float8E4M3FN, + }: + is_valid = False + + # Check valid sf_vec_size + if sf_vec_size not in {16, 32}: + is_valid = False + + # Check valid sf_dtype + if sf_dtype not in {cutlass.Float8E8M0FNU, cutlass.Float8E4M3FN}: + is_valid = False + + # Check valid sf_dtype and sf_vec_size combinations + if sf_dtype == cutlass.Float8E4M3FN and sf_vec_size == 32: + is_valid = False + if ab_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN} and sf_vec_size == 16: + is_valid = False + + # Check valid c_dtype + if c_dtype not in { + cutlass.Float32, + cutlass.Float16, + cutlass.BFloat16, + cutlass.Float8E5M2, + cutlass.Float8E4M3FN, + }: + is_valid = False + + return is_valid + + @staticmethod + def is_valid_layouts( + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the dtypes and sf_vec_size are valid combinations + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: The major dimension of the A tensor + :type a_major: str + :param b_major: The major dimension of the B tensor + :type b_major: str + :param c_major: The major dimension of the C tensor + :type c_major: str + + :return: True if the layouts are valid, False otherwise + :rtype: bool + """ + is_valid = True + + if ab_dtype is cutlass.Float4E2M1FN and not (a_major == "k" and b_major == "k"): + is_valid = False + return is_valid + + @staticmethod + def is_valid_mma_tiler_and_cluster_shape( + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ) -> bool: + """ + Check if the mma tiler and cluster shape are valid + + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + + :return: True if the mma tiler and cluster shape are valid, False otherwise + :rtype: bool + """ + is_valid = True + # Skip invalid mma tile shape + if not mma_tiler_mn[0] in [128, 256]: + is_valid = False + if not mma_tiler_mn[1] in [128, 256]: + is_valid = False + # Skip illegal cluster shape + if cluster_shape_mn[0] % (2 if mma_tiler_mn[0] == 256 else 1) != 0: + is_valid = False + # Skip invalid cluster shape + is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0 + if ( + cluster_shape_mn[0] * cluster_shape_mn[1] > 16 + or cluster_shape_mn[0] <= 0 + or cluster_shape_mn[1] <= 0 + # Special cluster shape check for scale factor multicasts. + # Due to limited size of scale factors, we can't multicast among more than 4 CTAs. + or cluster_shape_mn[0] > 4 + or cluster_shape_mn[1] > 4 + or not is_power_of_2(cluster_shape_mn[0]) + or not is_power_of_2(cluster_shape_mn[1]) + ): + is_valid = False + return is_valid + + @staticmethod + def is_valid_tensor_alignment( + m: int, + n: int, + k: int, + l: int, + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the tensor alignment is valid + + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the problem shape is valid, False otherwise + :rtype: bool + """ + is_valid = True + + def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape): + major_mode_idx = 0 if is_mode0_major else 1 + num_major_elements = tensor_shape[major_mode_idx] + num_contiguous_elements = 16 * 8 // dtype.width + return num_major_elements % num_contiguous_elements == 0 + + if ( + not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, l)) + or not check_contigous_16B_alignment(ab_dtype, b_major == "n", (n, k, l)) + or not check_contigous_16B_alignment(c_dtype, c_major == "m", (m, n, l)) + ): + is_valid = False + return is_valid + + @staticmethod + def can_implement( + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + c_dtype: Type[cutlass.Numeric], + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + m: int, + n: int, + k: int, + l: int, + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the gemm can be implemented + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param sf_dtype: The data type of the scale factor tensor + :type sf_dtype: Type[cutlass.Numeric] + :param sf_vec_size: The vector size + :type sf_vec_size: int + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the gemm can be implemented, False otherwise + :rtype: bool + """ + can_implement = True + # Skip unsupported types + if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_dtypes_and_scale_factor_vec_size( + ab_dtype, sf_dtype, sf_vec_size, c_dtype + ): + can_implement = False + # Skip unsupported layouts + if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_layouts( + ab_dtype, c_dtype, a_major, b_major, c_major + ): + can_implement = False + # Skip invalid mma tile shape and cluster shape + if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_mma_tiler_and_cluster_shape( + mma_tiler_mn, cluster_shape_mn + ): + can_implement = False + # Skip illegal problem shape for load/store alignment + if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_tensor_alignment( + m, n, k, l, ab_dtype, c_dtype, a_major, b_major, c_major + ): + can_implement = False + return can_implement + + +@cute.jit +def cvt_sf_MKL_to_M32x4xrm_K4xrk_L( + sf_ref_tensor: cute.Tensor, + sf_mma_tensor: cute.Tensor, +): + """Convert scale factor tensor from MKL layout to mma specification M(32x4xrest_m)xK(4xrest_k)xL layout""" + # sf_mma_tensor has flatten shape (32, 4, rest_m, 4, rest_k, l) + # group to ((32, 4, rest_m), (4, rest_k), l) + sf_mma_tensor = cute.group_modes(sf_mma_tensor, 0, 3) + sf_mma_tensor = cute.group_modes(sf_mma_tensor, 1, 3) + for i in cutlass.range(cute.size(sf_ref_tensor)): + mkl_coord = sf_ref_tensor.layout.get_hier_coord(i) + sf_mma_tensor[mkl_coord] = sf_ref_tensor[mkl_coord] + + +def run( + mnkl: Tuple[int, int, int, int], + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + tolerance: float = 1e-01, + warmup_iterations: int = 0, + iterations: int = 1, + skip_ref_check: bool = False, + use_cold_l2: bool = False, + **kwargs, +): + """Execute a persistent batched dense blockscaled GEMM operation on Blackwell architecture with performance benchmarking. + + This function prepares input tensors, configures and launches the persistent GEMM kernel, + optionally performs reference validation, and benchmarks the execution performance. + + :param mnkl: Problem size (M, N, K, L) + :type mnkl: Tuple[int, int, int, int] + :param ab_dtype: Data type for input tensors A and B + :type ab_dtype: Type[cutlass.Numeric] + :param sf_dtype: Data type for scale factor tensor + :type sf_dtype: Type[cutlass.Numeric] + :param sf_vec_size: Vector size for scale factor tensor + :type sf_vec_size: int + :param c_dtype: Data type for output tensor C + :type c_dtype: Type[cutlass.Numeric] + :param a_major/b_major/c_major: Memory layout of tensor A/B/C + :type a_major/b_major/c_major: str + :param mma_tiler_mn: MMA tiling size. + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster shape. + :type cluster_shape_mn: Tuple[int, int] + :param tolerance: Tolerance value for reference validation comparison, defaults to 1e-01 + :type tolerance: float, optional + :param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0 + :type warmup_iterations: int, optional + :param iterations: Number of benchmark iterations to run, defaults to 1 + :type iterations: int, optional + :param skip_ref_check: Whether to skip reference result validation, defaults to False + :type skip_ref_check: bool, optional + :param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False + :type use_cold_l2: bool, optional + :raises RuntimeError: If CUDA GPU is not available + :raises ValueError: If the configuration is invalid or unsupported by the kernel + :return: Execution time of the GEMM kernel + :rtype: float + """ + print(f"Running Sm100 Persistent Dense BlockScaled GEMM test with:") + print(f"mnkl: {mnkl}") + print(f"AB dtype: {ab_dtype}, SF dtype: {sf_dtype}, SF Vec size: {sf_vec_size}") + print(f"C dtype: {c_dtype}") + print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}") + print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}") + print(f"Tolerance: {tolerance}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}") + + # Unpack parameters + m, n, k, l = mnkl + + # Skip unsupported testcase + if not Sm100BlockScaledPersistentDenseGemmKernel.can_implement( + ab_dtype, + sf_dtype, + sf_vec_size, + c_dtype, + mma_tiler_mn, + cluster_shape_mn, + m, + n, + k, + l, + a_major, + b_major, + c_major, + ): + raise TypeError( + f"Unsupported testcase {ab_dtype}, {sf_dtype}, {sf_vec_size}, {c_dtype}, {mma_tiler_mn}, {cluster_shape_mn}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {c_major}" + ) + + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + torch.manual_seed(1111) + + # Create tensor A/B/C + a_ref = cutlass_torch.matrix(l, m, k, a_major == "m", cutlass.Float32) + b_ref = cutlass_torch.matrix(l, n, k, b_major == "n", cutlass.Float32) + c_ref = cutlass_torch.matrix(l, m, n, c_major == "m", cutlass.Float32) + + a_tensor, a_torch = cutlass_torch.cute_tensor_like( + a_ref, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + b_tensor, b_torch = cutlass_torch.cute_tensor_like( + b_ref, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + c_tensor, c_torch = cutlass_torch.cute_tensor_like( + c_ref, c_dtype, is_dynamic_layout=True, assumed_align=16 + ) + + # Mark tensor to be byte aligned + a_tensor.mark_compact_shape_dynamic( + mode=1 if a_major == "k" else 0, + stride_order=(2, 0, 1) if a_major == "k" else (2, 1, 0), + divisibility=2 if ab_dtype == cutlass.Float4E2M1FN else 1, + ) + b_tensor.mark_compact_shape_dynamic( + mode=1 if b_major == "k" else 0, + stride_order=(2, 0, 1) if b_major == "k" else (2, 1, 0), + divisibility=2 if ab_dtype == cutlass.Float4E2M1FN else 1, + ) + c_tensor.mark_compact_shape_dynamic( + mode=1 if c_major == "n" else 0, + stride_order=(2, 0, 1) if c_major == "n" else (2, 1, 0), + divisibility=2 if c_dtype == cutlass.Float4E2M1FN else 1, + ) + + # Create scale factor tensor SFA/SFB + def create_scale_factor_tensor(l, mn, k, sf_vec_size, dtype): + def ceil_div(a, b): + return (a + b - 1) // b + + sf_k = ceil_div(k, sf_vec_size) + ref_shape = (l, mn, sf_k) + + atom_m = (32, 4) + atom_k = 4 + mma_shape = ( + l, + ceil_div(mn, atom_m[0] * atom_m[1]), + ceil_div(sf_k, atom_k), + atom_m[0], + atom_m[1], + atom_k, + ) + + ref_permute_order = (1, 2, 0) + mma_permute_order = (3, 4, 1, 5, 2, 0) + + # Create f32 ref torch tensor (cpu) + ref_f32_torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( + ref_shape, + torch.float32, + permute_order=ref_permute_order, + init_type=cutlass_torch.TensorInitType.RANDOM, + init_config=cutlass_torch.RandomInitConfig( + min_val=1, + max_val=3, + ), + ) + + # Create f32 cute torch tensor (cpu) + cute_f32_torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( + mma_shape, + torch.float32, + permute_order=mma_permute_order, + init_type=cutlass_torch.TensorInitType.RANDOM, + init_config=cutlass_torch.RandomInitConfig( + min_val=0, + max_val=1, + ), + ) + + # convert ref f32 tensor to cute f32 tensor + cvt_sf_MKL_to_M32x4xrm_K4xrk_L( + from_dlpack(ref_f32_torch_tensor_cpu), + from_dlpack(cute_f32_torch_tensor_cpu), + ) + cute_f32_torch_tensor = cute_f32_torch_tensor_cpu.cuda() + + # reshape makes memory contiguous + ref_f32_torch_tensor_cpu = ( + ref_f32_torch_tensor_cpu.permute(2, 0, 1) + .unsqueeze(-1) + .expand(l, mn, sf_k, sf_vec_size) + .reshape(l, mn, sf_k * sf_vec_size) + .permute(*ref_permute_order) + ) + # prune to mkl for reference check. + ref_f32_torch_tensor_cpu = ref_f32_torch_tensor_cpu[:, :k, :] + + # Create dtype cute torch tensor (cpu) + cute_tensor, cute_torch_tensor = cutlass_torch.cute_tensor_like( + cute_f32_torch_tensor_cpu, + dtype, + is_dynamic_layout=True, + assumed_align=16, + ) + + # Convert f32 cute tensor to dtype cute tensor + cute_tensor = cutlass_torch.convert_cute_tensor( + cute_f32_torch_tensor, + cute_tensor, + dtype, + is_dynamic_layout=True, + ) + return ref_f32_torch_tensor_cpu, cute_tensor, cute_torch_tensor + + sfa_ref, sfa_tensor, sfa_torch = create_scale_factor_tensor( + l, m, k, sf_vec_size, sf_dtype + ) + sfb_ref, sfb_tensor, sfb_torch = create_scale_factor_tensor( + l, n, k, sf_vec_size, sf_dtype + ) + + # Configure gemm kernel + gemm = Sm100BlockScaledPersistentDenseGemmKernel( + sf_vec_size, + mma_tiler_mn, + cluster_shape_mn, + ) + + # Compute max active clusters on current device + hardware_info = cutlass.utils.HardwareInfo() + max_active_clusters = hardware_info.get_max_active_clusters( + cluster_shape_mn[0] * cluster_shape_mn[1] + ) + + # Initialize Stream + current_stream = cutlass_torch.default_stream() + + # Compile gemm kernel + compiled_gemm = cute.compile( + gemm, + a_tensor, + b_tensor, + sfa_tensor, + sfb_tensor, + c_tensor, + max_active_clusters, + current_stream, + ) + + # Compute reference result + if not skip_ref_check: + # Execute kernel once for reference checking + compiled_gemm( + a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor, current_stream + ) + print("Verifying results...") + res_a = torch.einsum("mkl,mkl->mkl", a_ref, sfa_ref) + res_b = torch.einsum("nkl,nkl->nkl", b_ref, sfb_ref) + ref = torch.einsum("mkl,nkl->mnl", res_a, res_b) + + # Convert c back to f32 for comparison. + c_ref_device = c_ref.cuda() + cute.testing.convert( + c_tensor, + from_dlpack(c_ref_device, assumed_align=16).mark_layout_dynamic( + leading_dim=(1 if c_major == "n" else 0) + ), + ) + c_ref = c_ref_device.cpu() + + if c_dtype in (cutlass.Float32, cutlass.Float16, cutlass.BFloat16): + torch.testing.assert_close(c_ref, ref, atol=tolerance, rtol=1e-02) + elif c_dtype in (cutlass.Float8E5M2, cutlass.Float8E4M3FN): + # Convert ref : f32 -> f8 -> f32 + ref_f8_ = torch.empty(*(l, m, n), dtype=torch.uint8, device="cuda").permute( + 1, 2, 0 + ) + ref_f8 = from_dlpack(ref_f8_, assumed_align=16).mark_layout_dynamic( + leading_dim=1 + ) + ref_f8.element_type = c_dtype + ref_device = ref.permute(2, 0, 1).contiguous().permute(1, 2, 0).cuda() + ref_tensor = from_dlpack(ref_device, assumed_align=16).mark_layout_dynamic( + leading_dim=1 + ) + cute.testing.convert(ref_tensor, ref_f8) + cute.testing.convert(ref_f8, ref_tensor) + ref = ref_device.cpu() + torch.testing.assert_close(c_ref, ref, atol=tolerance, rtol=1e-02) + def generate_tensors(): + a_tensor, _ = cutlass_torch.cute_tensor_like( + a_ref, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + b_tensor, _ = cutlass_torch.cute_tensor_like( + b_ref, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + c_tensor, _ = cutlass_torch.cute_tensor_like( + c_ref, c_dtype, is_dynamic_layout=True, assumed_align=16 + ) + + # Mark tensor to be byte aligned + a_tensor.mark_compact_shape_dynamic( + mode=1 if a_major == "k" else 0, + stride_order=(2, 0, 1) if a_major == "k" else (2, 1, 0), + divisibility=2 if ab_dtype == cutlass.Float4E2M1FN else 1, + ) + b_tensor.mark_compact_shape_dynamic( + mode=1 if b_major == "k" else 0, + stride_order=(2, 0, 1) if b_major == "k" else (2, 1, 0), + divisibility=2 if ab_dtype == cutlass.Float4E2M1FN else 1, + ) + c_tensor.mark_compact_shape_dynamic( + mode=1 if c_major == "n" else 0, + stride_order=(2, 0, 1) if c_major == "n" else (2, 1, 0), + divisibility=2 if c_dtype == cutlass.Float4E2M1FN else 1, + ) + + _, sfa_tensor, _ = create_scale_factor_tensor(l, m, k, sf_vec_size, sf_dtype) + _, sfb_tensor, _ = create_scale_factor_tensor(l, n, k, sf_vec_size, sf_dtype) + return cute.testing.JitArguments( + a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor, current_stream + ) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + a_torch.numel() * a_torch.element_size() + + b_torch.numel() * b_torch.element_size() + + sfa_torch.numel() * sfa_torch.element_size() + + sfb_torch.numel() * sfb_torch.element_size() + + c_torch.numel() * c_torch.element_size() + ) + workspace_count = cute.testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + exec_time = cute.testing.benchmark( + compiled_gemm, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=current_stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + return exec_time # Return execution time in microseconds + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: + try: + return tuple(int(x.strip()) for x in s.split(",")) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + parser = argparse.ArgumentParser( + description="Example of Sm100 Dense Persistent BlockScaled GEMM." + ) + + parser.add_argument( + "--mnkl", + type=parse_comma_separated_ints, + default=(512, 256, 256, 1), + help="mnkl dimensions (comma-separated)", + ) + parser.add_argument( + "--mma_tiler_mn", + type=parse_comma_separated_ints, + default=(128, 128), + help="Mma tile shape (comma-separated)", + ) + parser.add_argument( + "--cluster_shape_mn", + type=parse_comma_separated_ints, + default=(1, 1), + help="Cluster shape (comma-separated)", + ) + parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.Float4E2M1FN) + parser.add_argument("--sf_dtype", type=cutlass.dtype, default=cutlass.Float8E8M0FNU) + parser.add_argument("--sf_vec_size", type=int, default=16) + parser.add_argument("--c_dtype", type=cutlass.dtype, default=cutlass.Float16) + parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k") + parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k") + parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n") + parser.add_argument( + "--tolerance", type=float, default=1e-01, help="Tolerance for validation" + ) + parser.add_argument( + "--warmup_iterations", type=int, default=0, help="Warmup iterations" + ) + parser.add_argument( + "--iterations", + type=int, + default=1, + help="Number of iterations to run the kernel", + ) + parser.add_argument( + "--skip_ref_check", action="store_true", help="Skip reference checking" + ) + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", + ) + + args = parser.parse_args() + + if len(args.mnkl) != 4: + parser.error("--mnkl must contain exactly 4 values") + + if len(args.mma_tiler_mn) != 2: + parser.error("--mma_tiler_mn must contain exactly 2 values") + + if len(args.cluster_shape_mn) != 2: + parser.error("--cluster_shape_mn must contain exactly 2 values") + + run( + args.mnkl, + args.ab_dtype, + args.sf_dtype, + args.sf_vec_size, + args.c_dtype, + args.a_major, + args.b_major, + args.c_major, + args.mma_tiler_mn, + args.cluster_shape_mn, + args.tolerance, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.use_cold_l2, + ) + print("PASS") diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/dense_gemm.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/dense_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..f5a8372950cb194521dd474a52a9dfeeff02fa83 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/dense_gemm.py @@ -0,0 +1,1890 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +from typing import Optional, Type, Tuple, Union +import cuda.bindings.driver as cuda + +import torch + +import cutlass +import cutlass.cute as cute +import cutlass.utils as utils +import cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.torch as cutlass_torch +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cute.runtime import from_dlpack + +""" +A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Blackwell SM100 architecture +using CUTE DSL. +- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M") +- Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K") +- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M") + +This GEMM kernel supports the following features: + - Utilizes Tensor Memory Access (TMA) for efficient memory operations + - Utilizes Blackwell's tcgen05.mma for matrix multiply-accumulate (MMA) operations (including 2cta mma instructions) + - Implements TMA multicast with cluster to reduce L2 memory traffic + - Supports multi-stage pipeline to overlap computation and memory access + +This GEMM works as follows: +1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations. +2. Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction. +3. Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld. +4. Type convert C matrix to output type. +5. Optionally store C matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations, + or directly store C matrix from registers (RMEM) to global memory (GMEM) without TMA operations. +6. Optionally accept an elementwise lambda function epilogue_op to apply to the output tensor: + e.g., relu can set epilogue_op = lambda x: cute.where(x > 0, x, cute.full_like(x, 0)) + +SM100 tcgen05.mma instructions operate as follows: +- Read matrix A from SMEM +- Read matrix B from SMEM +- Write accumulator to TMEM +The accumulator in TMEM must then be loaded to registers before writing back to GMEM. + +To run this example: + +.. code-block:: bash + + python examples/blackwell/dense_gemm.py \ + --ab_dtype Float16 --c_dtype Float16 --acc_dtype Float32 \ + --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \ + --mnkl 8192,8192,8192,1 \ + --use_tma_store --use_2cta_instrs + +The above example command compute batched gemm with M=8192, N=8192, K=8192, +batch_count=1. The Blackwell tcgen05 MMA tile shape used 2 cta with 256x128 +MMA tile and the cluster shape is (2,1). The input, mma accumulator and output +data type are set as fp16, fp32 and fp16, respectively. + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/blackwell/dense_gemm.py \ + --ab_dtype Float16 --c_dtype Float16 --acc_dtype Float32 \ + --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \ + --mnkl 8192,8192,8192,1 \ + --use_tma_store --use_2cta_instrs + +Constraints: +* Supported input data types: fp16, bf16, tf32, int8, uint8, fp8 (e4m3fn, e5m2), + see detailed valid dtype combinations in below DenseGemmKernel class documentation +* A/B tensor must have the same data type +* Mma tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True) +* Mma tiler N must be 32-256, step 32 +* Cluster shape M/N must be positive and power of 2, total cluster size <= 16 +* Cluster shape M must be multiple of 2 if use_2cta_instrs=True +* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned, + i.e, number of elements is a multiple of 4, 8, and 16 for TFloat32, + Float16/BFloat16, and Int8/Uint8/Float8, respectively. +* OOB tiles are not allowed when TMA store is disabled +""" + + +class DenseGemmKernel: + """ + This class implements batched matrix multiplication (C = A x B) with support for various data types + and architectural features specific to Blackwell GPUs. + + :param acc_dtype: Data type for accumulation during computation + :type acc_dtype: type[cutlass.Numeric] + :param use_2cta_instrs: Whether to use CTA group 2 for advanced thread cooperation + :type use_2cta_instrs: bool + :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tiler (M,N) + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing + :type cluster_shape_mn: Tuple[int, int] + :param use_tma_store: Whether to use Tensor Memory Access (TMA) for storing results + :type use_tma_store: bool + + :note: In current version, A and B tensor must have the same data type + - i.e., Float8E4M3FN for A and Float8E5M2 for B is not supported + + :note: Supported A/B data types: + - TFloat32 + - Float16/BFloat16 + - Int8/Uint8 + - Float8E4M3FN/Float8E5M2 + + :note: Supported accumulator data types: + - Float32 (for all floating point A/B data types) + - Float16 (only for fp16 and fp8 A/B data types) + - Int32 (only for uint8/int8 A/B data types) + + :note: Supported C data types: + - Float32 (for float32 and int32 accumulator data types) + - Int32 (for float32 and int32 accumulator data types) + - Float16/BFloat16 (for fp16 and fp8 accumulator data types) + - Int8/Uint8 (for uint8/int8 accumulator data types) + - Float8E4M3FN/Float8E5M2 (for float32 accumulator data types) + + :note: Constraints: + - MMA tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True) + - MMA tiler N must be 32-256, step 32 + - Cluster shape M must be multiple of 2 if use_2cta_instrs=True + - Cluster shape M/N must be positive and power of 2, total cluster size <= 16 + + Example: + >>> gemm = DenseGemmKernel( + ... acc_dtype=cutlass.Float32, + ... use_2cta_instrs=True, + ... mma_tiler_mn=(128, 128), + ... cluster_shape_mn=(2, 2) + ... ) + >>> gemm(a_tensor, b_tensor, c_tensor, stream) + """ + + def __init__( + self, + acc_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + use_tma_store: bool, + ): + """Initializes the configuration for a Blackwell dense GEMM kernel. + + This configuration includes several key aspects: + + 1. MMA Instruction Settings (tcgen05): + - acc_dtype: Data types for MMA accumulator. + - mma_tiler_mn: The (M, N) shape of the MMA instruction tiler. + - use_2cta_instrs: Boolean indicating if the tcgen05 MMA variant + with cta_group=2 should be used. + + 2. Cluster Shape: + - cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster. + + 3. Output C tensor store mode: + - use_tma_store: Boolean indicating whether to use Tensor Memory Access (TMA) for storing results. + + :param acc_dtype: Data type of the accumulator. + :type acc_dtype: type[cutlass.Numeric] + :param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction. + :type mma_tiler_mn: Tuple[int, int] + :param use_2cta_instrs: Boolean, True to use cta_group=2 MMA variant. + :type use_2cta_instrs: bool + :param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster. + :type cluster_shape_mn: Tuple[int, int] + :param use_tma_store: Use Tensor Memory Access (TMA) or normal store for output C tensor. + :type use_tma_store: bool + """ + + self.acc_dtype: Type[cutlass.Numeric] = acc_dtype + self.use_2cta_instrs = use_2cta_instrs + self.cluster_shape_mn = cluster_shape_mn + # K dimension is deferred in _setup_attributes + self.mma_tiler = (*mma_tiler_mn, 1) + self.use_tma_store = use_tma_store + + self.cta_group = ( + tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE + ) + + self.occupancy = 1 + self.threads_per_cta = 128 + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + + def _setup_attributes(self): + """Set up configurations that are dependent on GEMM inputs + + This method configures various attributes based on the input tensor properties + (data types, leading dimensions) and kernel settings: + - Configuring tiled MMA + - Computing MMA/cluster/tile shapes + - Computing cluster layout + - Computing multicast CTAs for A/B + - Computing epilogue subtile + - Setting up A/B/C stage counts in shared memory + - Computing A/B/C shared memory layout + - Computing tensor memory allocation columns + """ + # Configure tiled mma + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + + # Compute mma/cluster/tile shapes + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + mma_inst_tile_k = 4 + self.mma_tiler = ( + self.mma_tiler[0], + self.mma_tiler[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + + # Compute cluster layout + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + + # Compute number of multicast CTAs for A/B + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + + # Compute epilogue subtile + if cutlass.const_expr(self.use_tma_store): + self.epi_tile = sm100_utils.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, + self.use_2cta_instrs, + self.c_layout, + self.c_dtype, + ) + else: + self.epi_tile = self.cta_tile_shape_mnk[:2] + + # Setup A/B/C stage count in shared memory + self.num_acc_stage, self.num_ab_stage, self.num_c_stage = self._compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.b_dtype, + self.epi_tile, + self.c_dtype, + self.c_layout, + self.smem_capacity, + self.occupancy, + self.use_tma_store, + ) + + # Compute A/B/C shared memory layout + self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.num_ab_stage, + ) + self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + self.mma_tiler, + self.b_dtype, + self.num_ab_stage, + ) + self.c_smem_layout_staged = ( + sm100_utils.make_smem_layout_epi( + self.c_dtype, + self.c_layout, + self.epi_tile, + self.num_c_stage, + ) + if self.use_tma_store + else None + ) + + # Compute the number of tensor memory allocation columns + self.num_tmem_alloc_cols = self._compute_num_tmem_alloc_cols( + tiled_mma, self.mma_tiler + ) + + @cute.jit + def __call__( + self, + a: cute.Tensor, + b: cute.Tensor, + c: cute.Tensor, + stream: cuda.CUstream, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + """Execute the GEMM operation in steps: + - Setup static attributes + - Setup TMA load/store atoms and tensors + - Compute grid size + - Define shared storage for kernel + - Launch the kernel synchronously + + :param a: Input tensor A + :type a: cute.Tensor + :param b: Input tensor B + :type b: cute.Tensor + :param c: Output tensor C + :type c: cute.Tensor + :param stream: CUDA stream for asynchronous execution + :type stream: cuda.CUstream + :param epilogue_op: Optional elementwise lambda function to apply to the output tensor + :type epilogue_op: cutlass.Constexpr + :raises TypeError: If input data types are incompatible with the MMA instruction. + :raises AssertionError: If OOB (Out-Of-Bounds) tiles are present when TMA store is disabled. + """ + # Setup static attributes before smem/grid/tma computation + self.a_dtype: Type[cutlass.Numeric] = a.element_type + self.b_dtype: Type[cutlass.Numeric] = b.element_type + self.c_dtype: Type[cutlass.Numeric] = c.element_type + self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode() + self.c_layout = utils.LayoutEnum.from_tensor(c) + + # Check if input data types are compatible with MMA instruction + if cutlass.const_expr(self.a_dtype != self.b_dtype): + raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}") + + # Setup attributes that dependent on gemm inputs + self._setup_attributes() + + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # Setup TMA load for A + a_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + a, + a_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=( + cutlass.TFloat32 if a.element_type is cutlass.Float32 else None + ), + ) + + # Setup TMA load for B + b_op = sm100_utils.cluster_shape_to_tma_atom_B( + self.cluster_shape_mn, tiled_mma.thr_id + ) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=( + cutlass.TFloat32 if b.element_type is cutlass.Float32 else None + ), + ) + + a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + self.num_tma_load_bytes = (a_copy_size + b_copy_size) * atom_thr_size + + # Setup store for C + tma_atom_c = None + tma_tensor_c = None + if cutlass.const_expr(self.use_tma_store): + c_cta_v_layout = cute.composition( + cute.make_identity_layout(c.shape), self.epi_tile + ) + epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + c, + epi_smem_layout, + c_cta_v_layout, + ) + + # Compute grid size + grid = self._compute_grid(c, self.cta_tile_shape_mnk, self.cluster_shape_mn) + + self.buffer_align_bytes = 1024 + + c_smem_size = ( + cute.cosize(self.c_smem_layout_staged.outer) if self.use_tma_store else 0 + ) + + # Define shared storage for kernel + @cute.struct + class SharedStorage: + ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC: cute.struct.Align[ + cute.struct.MemRange[ + self.c_dtype, + c_smem_size, + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sA: cute.struct.Align[ + cute.struct.MemRange[ + self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sB: cute.struct.Align[ + cute.struct.MemRange[ + self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + tiled_mma, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_c, + tma_tensor_c if self.use_tma_store else c, + self.cluster_layout_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.c_smem_layout_staged, + self.epi_tile, + epilogue_op, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + stream=stream, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_c: Optional[cute.CopyAtom], + mC_mnl: cute.Tensor, + cluster_layout_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None], + epi_tile: cute.Tile, + epilogue_op: cutlass.Constexpr, + ): + """ + GPU device kernel performing the batched GEMM computation. + """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + # + # Prefetch tma descriptor + # + if warp_idx == 0: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + if cutlass.const_expr(self.use_tma_store): + cpasync.prefetch_descriptor(tma_atom_c) + + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + # + # Setup cta/thread coordinates + # + # Coords inside cluster + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + # Coords outside cluster + cta_coord = (bidx, bidy, bidz) + mma_tile_coord_mnl = ( + cta_coord[0] // cute.size(tiled_mma.thr_id.shape), + cta_coord[1], + cta_coord[2], + ) + # Coords inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Alloc and init: a+b full/empty, accumulator full, tensor memory dealloc barrier + # + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr + tmem_holding_buf = storage.tmem_holding_buf + + # Initialize mainloop ab_pipeline (barrier) and states + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + ab_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_tma_producer + ) + ab_pipeline = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_full_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + ) + ab_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_ab_stage + ) + ab_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_ab_stage + ) + + # Initialize acc_pipeline (barrier) and states + acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + acc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, self.threads_per_cta, self.threads_per_cta + ) + acc_pipeline = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_full_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=acc_pipeline_producer_group, + consumer_group=acc_pipeline_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + ) + acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_acc_stage + ) + acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage + ) + + # Tensor memory dealloc barrier init + if use_2cta_instrs: + if warp_idx == 0: + num_tmem_dealloc_threads = 32 + with cute.arch.elect_one(): + cute.arch.mbarrier_init( + tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads + ) + cute.arch.mbarrier_init_fence() + + # Cluster arrive after barrier init + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_arrive_relaxed() + + # + # Setup smem tensor A/B/C + # + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC = ( + storage.sC.get_tensor( + c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner + ) + if self.use_tma_store + else None + ) + # (MMA, MMA_M, MMA_K, STAGE) + sA = storage.sA.get_tensor( + a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner + ) + # (MMA, MMA_N, MMA_K, STAGE) + sB = storage.sB.get_tensor( + b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner + ) + + # + # Compute multicast mask for A/B buffer full + # + a_full_mcast_mask = None + b_full_mcast_mask = None + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): + a_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + b_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + + # + # Local_tile partition global tensors + # + # (bM, bK, RestM, RestK, RestL) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + # (bM, bN, RestM, RestN, RestL) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) + ) + k_tile_cnt = cute.size(gA_mkl, mode=[3]) + + # + # Partition global tensor for TiledMMA_A/B/C + # + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgA = thr_mma.partition_A(gA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgB = thr_mma.partition_B(gB_nkl) + # (MMA, MMA_M, MMA_N, RestM, RestN, RestL) + tCgC = thr_mma.partition_C(gC_mnl) + + # + # Partition global/shared tensor for TMA load A/B + # + # TMA load A partition_S/D + a_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA load B partition_S/D + b_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/C + # + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + # (MMA, MMA_M, MMA_N) + tCtAcc_fake = tiled_mma.make_fragment_C(acc_shape) + + # + # Cluster wait before tensor memory alloc + # + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_wait() + + # + # Alloc tensor memory buffer + # + if warp_idx == 0: + cute.arch.alloc_tmem( + self.num_tmem_alloc_cols, tmem_holding_buf, is_two_cta=use_2cta_instrs + ) + + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + cute.arch.barrier() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, alignment=16, ptr_to_buffer_holding_addr=tmem_holding_buf + ) + # (MMA, MMA_M, MMA_N) + tCtAcc = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + # + # Partition for epilogue + # + tiled_copy_t2r, tTR_tAcc, tTR_rAcc = self.epilog_tmem_copy_and_partition( + tidx, tCtAcc, tCgC, epi_tile, use_2cta_instrs + ) + + tTR_rC = None + tiled_copy_r2s = None + simt_atom = None + tRS_rC = None + tRS_sC = None + bSG_sC = None + bSG_gC = None + tTR_gC = None + if cutlass.const_expr(self.use_tma_store): + tTR_rC = cute.make_fragment(tTR_rAcc.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( + tiled_copy_t2r, tTR_rC, tidx, sC + ) + tma_atom_c, bSG_sC, bSG_gC = self.epilog_gmem_copy_and_partition( + tidx, tma_atom_c, tCgC, epi_tile, sC + ) + else: + simt_atom, tTR_rC, tTR_gC = self.epilog_gmem_copy_and_partition( + tidx, tiled_copy_t2r, tCgC, epi_tile, sC + ) + + # + # Slice to per mma tile index + # + # ((atom_v, rest_v), RestK) + tAgA = tAgA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] + # ((atom_v, rest_v), RestK) + tBgB = tBgB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] + if cutlass.const_expr(self.use_tma_store): + # ((ATOM_V, REST_V), EPI_M, EPI_N) + bSG_gC = bSG_gC[(None, None, None, *mma_tile_coord_mnl)] + else: + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + tTR_gC = tTR_gC[(None, None, None, None, None, *mma_tile_coord_mnl)] + + # + # Pipelining TMA load A/B and MMA mainloop + # + prefetch_k_tile_cnt = cutlass.min(self.num_ab_stage - 2, k_tile_cnt) + + if warp_idx == 0: + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) + # + # Prefetch TMA load A/B + # + for prefetch_idx in cutlass.range(prefetch_k_tile_cnt, unroll=1): + # Conditionally wait for AB buffer empty + ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status) + + # TMA load A/B + cute.copy( + tma_atom_a, + tAgA[(None, ab_producer_state.count)], + tAsA[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=a_full_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB[(None, ab_producer_state.count)], + tBsB[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=b_full_mcast_mask, + ) + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 + ab_producer_state.advance() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) + + # Peek (try_wait) AB buffer full for k_tile = 0 + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt and is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state) + + # + # MMA mainloop + # + for k_tile in range(k_tile_cnt): + # Conditionally wait for AB buffer empty + ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status) + + if ab_producer_state.count < k_tile_cnt: + # TMA load A/B + cute.copy( + tma_atom_a, + tAgA[(None, ab_producer_state.count)], + tAsA[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=a_full_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB[(None, ab_producer_state.count)], + tBsB[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=b_full_mcast_mask, + ) + + if is_leader_cta: + # Conditionally wait for AB buffer full + ab_pipeline.consumer_wait(ab_consumer_state, peek_ab_full_status) + + # tCtAcc += tCrA * tCrB + num_kblocks = cute.size(tCrA, mode=[2]) + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = (None, None, kblock_idx, ab_consumer_state.index) + + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[kblock_coord], + tCrB[kblock_coord], + tCtAcc, + ) + # Enable accumulate on tCtAcc after first kblock + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Async arrive AB buffer empty + ab_pipeline.consumer_release(ab_consumer_state) + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 + ab_producer_state.advance() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) + + # Peek (try_wait) AB buffer full for k_tile = k_tile + 1 + ab_consumer_state.advance() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt: + if is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait( + ab_consumer_state + ) + + # Async arrive accumulator buffer full + if is_leader_cta: + acc_pipeline.producer_commit(acc_producer_state) + + # + # Epilogue + # + + # Release tensor memory allocation lock + if warp_idx == 0: + cute.arch.relinquish_tmem_alloc_permit(is_two_cta=use_2cta_instrs) + + # Wait for accumulator buffer full + acc_pipeline.consumer_wait(acc_consumer_state) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + if cutlass.const_expr(self.use_tma_store): + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + else: + tTR_gC = cute.group_modes(tTR_gC, 3, cute.rank(tTR_gC)) + + c_pipeline = None + if cutlass.const_expr(self.use_tma_store): + # Initialize tma store c_pipeline + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, self.threads_per_cta, self.threads_per_cta + ) + c_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_c_stage, + producer_group=c_producer_group, + ) + + # + # Store accumulator to global memory in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + for subtile_idx in range(subtile_cnt): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + if cutlass.const_expr(self.use_tma_store): + # + # Perform epilogue op on accumulator and convert to C type + # + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) + tRS_rC.store(acc_vec) + + # + # Store C to shared memory + # + c_buffer = subtile_idx % self.num_c_stage + cute.copy(tiled_copy_r2s, tRS_rC, tRS_sC[(None, None, None, c_buffer)]) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + cute.arch.barrier() + + # + # TMA store C to global memory + # + if warp_idx == 0: + cute.copy( + tma_atom_c, + bSG_sC[(None, c_buffer)], + bSG_gC[(None, subtile_idx)], + ) + # Fence and barrier to make sure TMA store is completed to recollect C buffer + c_pipeline.producer_commit() + c_pipeline.producer_acquire() + cute.arch.barrier() + else: + # + # Perform epilogue op on accumulator and convert to C type + # + acc_vec = tTR_rAcc.load() + acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) + tTR_rC.store(acc_vec) + + # + # Store C to global memory + # + cute.copy(simt_atom, tTR_rC, tTR_gC[(None, None, None, subtile_idx)]) + + # + # Dealloc the tensor memory buffer + # + cute.arch.barrier() + if warp_idx == 0: + if use_2cta_instrs: + cute.arch.mbarrier_arrive( + tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1 + ) + cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0) + cute.arch.dealloc_tmem( + tmem_ptr, self.num_tmem_alloc_cols, is_two_cta=use_2cta_instrs + ) + + # + # Wait for C store complete + # + if cutlass.const_expr(self.use_tma_store): + c_pipeline.producer_tail() + + # + # Wait A/B buffer empty + # + if warp_idx == 0: + # Reverse prefetch_k_tile_cnt times to next available buffer + for i in range(prefetch_k_tile_cnt): + ab_producer_state.reverse() + ab_pipeline.producer_tail(ab_producer_state) + return + + def epilog_tmem_copy_and_partition( + self, + tidx: cutlass.Int32, + tAcc: cute.Tensor, + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + use_2cta_instrs: Union[cutlass.Boolean, bool], + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination). + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param tAcc: The accumulator tensor to be copied and partitioned + :type tAcc: cute.Tensor + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param use_2cta_instrs: Whether use_2cta_instrs is enabled + :type use_2cta_instrs: bool + + :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where: + - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + - tTR_tAcc: The partitioned accumulator tensor + - tTR_rAcc: The accumulated tensor in register used to hold t2r results + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # Make tiledCopy for tensor memory load + copy_atom_t2r = sm100_utils.get_tmem_load_op( + self.cta_tile_shape_mnk, + self.c_layout, + self.c_dtype, + self.acc_dtype, + epi_tile, + use_2cta_instrs, + ) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N) + tAcc_epi = cute.flat_divide( + tAcc[((None, None), 0, 0)], + epi_tile, + ) + # (EPI_TILE_M, EPI_TILE_N) + tiled_copy_t2r = tcgen05.make_tmem_copy( + copy_atom_t2r, tAcc_epi[(None, None, 0, 0)] + ) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + gC_mnl_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) + tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) + # (T2R, T2R_M, T2R_N) + tTR_rAcc = cute.make_fragment( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype + ) + return tiled_copy_t2r, tTR_tAcc, tTR_rAcc + + def epilog_smem_copy_and_partition( + self, + tiled_copy_t2r: cute.TiledCopy, + tTR_rC: cute.Tensor, + tidx: cutlass.Int32, + sC: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination). + + :param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + :type tiled_copy_t2r: cute.TiledCopy + :param tTR_rC: The partitioned accumulator tensor + :type tTR_rC: cute.Tensor + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + + :return: A tuple containing (tiled_copy_r2s, tRS_rC, tRS_sC) where: + - tiled_copy_r2s: The tiled copy operation for register to smem copy(r2s) + - tRS_rC: The partitioned tensor C (register source) + - tRS_sC: The partitioned tensor C (smem destination) + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + copy_atom_r2s = sm100_utils.get_smem_store_op( + self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r + ) + tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) + # (R2S, R2S_M, R2S_N, PIPE_D) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sC = thr_copy_r2s.partition_D(sC) + # (R2S, R2S_M, R2S_N) + tRS_rC = tiled_copy_r2s.retile(tTR_rC) + return tiled_copy_r2s, tRS_rC, tRS_sC + + def epilog_gmem_copy_and_partition( + self, + tidx: cutlass.Int32, + atom: Union[cute.CopyAtom, cute.TiledCopy], + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + sC: cute.Tensor, + ) -> Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]: + """Make tiledCopy for global memory store, then use it to: + - partition register array (source) and global memory (destination) for none TMA store version; + - partition shared memory (source) and global memory (destination) for TMA store version. + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param atom: The copy_atom_c to be used for TMA store version, or tiled_copy_t2r for none TMA store version + :type atom: cute.CopyAtom or cute.TiledCopy + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + + :return: A tuple containing either: + - For TMA store: (tma_atom_c, bSG_sC, bSG_gC) where: + - tma_atom_c: The TMA copy atom + - bSG_sC: The partitioned shared memory tensor C + - bSG_gC: The partitioned global tensor C + - For non-TMA store: (simt_atom, tTR_rC, tTR_gC) where: + - simt_atom: The SIMT copy atom + - tTR_rC: The register tensor C + - tTR_gC: The partitioned global tensor C + :rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] + """ + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + gC_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + if cutlass.const_expr(self.use_tma_store): + tma_atom_c = atom + sC_for_tma_partition = cute.group_modes(sC, 0, 2) + gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2) + # ((ATOM_V, REST_V), EPI_M, EPI_N) + # ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL) + bSG_sC, bSG_gC = cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + sC_for_tma_partition, + gC_for_tma_partition, + ) + return tma_atom_c, bSG_sC, bSG_gC + else: + tiled_copy_t2r = atom + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + tTR_gC = thr_copy_t2r.partition_D(gC_epi) + # (T2R, T2R_M, T2R_N) + tTR_rC = cute.make_fragment( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.c_dtype + ) + simt_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), self.c_dtype) + return simt_atom, tTR_rC, tTR_gC + + @staticmethod + def _compute_stages( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: Tuple[int, int, int], + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + epi_tile: cute.Tile, + c_dtype: Type[cutlass.Numeric], + c_layout: utils.LayoutEnum, + smem_capacity: int, + occupancy: int, + use_tma_store: bool, + ) -> Tuple[int, int, int]: + """Computes the number of stages for A/B/C operands based on heuristics. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The shape (M, N, K) of the MMA tile. + :type mma_tiler_mnk: tuple[int, int, int] + :param a_dtype: Data type of operand A. + :type a_dtype: type[cutlass.Numeric] + :param b_dtype: Data type of operand B. + :type b_dtype: type[cutlass.Numeric] + :param epi_tile: The epilogue tile shape. + :type epi_tile: cute.Tile + :param c_dtype: Data type of operand C (output). + :type c_dtype: type[cutlass.Numeric] + :param c_layout: Layout enum of operand C in global memory. + :type c_layout: utils.LayoutEnum + :param smem_capacity: Total available shared memory capacity in bytes. + :type smem_capacity: int + :param occupancy: Target number of CTAs per SM (occupancy). + :type occupancy: int + :param use_tma_store: Whether TMA store is enabled. + :type use_tma_store: bool + + :return: A tuple containing the computed number of stages for: + (ACC stages, A/B operand stages, epilogue stages) + :rtype: tuple[int, int, int] + """ + # Default ACC stages + num_acc_stage = 1 + # Default C stages + num_c_stage = 2 if use_tma_store else 0 + + # Calculate smem layout and size for one stage of A, B, and C + a_smem_layout_stage_one = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + a_dtype, + 1, # a tmp 1 stage is provided + ) + b_smem_layout_staged_one = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + b_dtype, + 1, # a tmp 1 stage is provided + ) + c_smem_layout_staged_one = ( + sm100_utils.make_smem_layout_epi( + c_dtype, + c_layout, + epi_tile, + 1, + ) + if use_tma_store + else None + ) + ab_bytes_per_stage = cute.size_in_bytes( + a_dtype, a_smem_layout_stage_one + ) + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) + mbar_helpers_bytes = 1024 + c_bytes_per_stage = ( + cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) + if use_tma_store + else 0 + ) + c_bytes = c_bytes_per_stage * num_c_stage + + # Calculate A/B stages: + # Start with total smem per CTA (capacity / occupancy) + # Subtract reserved bytes and initial C stages bytes + # Divide remaining by bytes needed per A/B stage + num_ab_stage = ( + smem_capacity - (occupancy + 1) * (mbar_helpers_bytes + c_bytes) + ) // ab_bytes_per_stage + + # Refine epilogue stages: + # Calculate remaining smem after allocating for A/B stages and reserved bytes + # Add remaining unused smem to epilogue + if use_tma_store: + num_c_stage += ( + smem_capacity + - ab_bytes_per_stage * num_ab_stage + - (occupancy + 1) * (mbar_helpers_bytes + c_bytes) + ) // ((occupancy + 1) * c_bytes_per_stage) + return num_acc_stage, num_ab_stage, num_c_stage + + @staticmethod + def _compute_grid( + c: cute.Tensor, + cta_tile_shape_mnk: Tuple[int, int, int], + cluster_shape_mn: Tuple[int, int], + ) -> Tuple[int, int, int]: + """Compute grid shape for the output tensor C. + + :param c: The output tensor C + :type c: cute.Tensor + :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type cta_tile_shape_mnk: tuple[int, int, int] + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. + :type cluster_shape_mn: tuple[int, int] + + :return: Grid shape for kernel launch. + :rtype: tuple[int, int, int] + """ + + cluster_shape_mnl = (*cluster_shape_mn, 1) + + grid = cute.round_up( + ( + cute.ceil_div(c.layout.shape[0], cta_tile_shape_mnk[0]), + cute.ceil_div(c.layout.shape[1], cta_tile_shape_mnk[1]), + c.layout.shape[2], + ), + cluster_shape_mnl, + ) + + return grid + + @staticmethod + def _compute_num_tmem_alloc_cols( + tiled_mma: cute.TiledMma, mma_tiler: Tuple[int, int, int] + ) -> int: + """ + Compute the number of tensor memory allocation columns. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler: The shape (M, N, K) of the MMA tile. + :type mma_tiler: tuple[int, int, int] + + :return: The number of tensor memory allocation columns. + :rtype: int + """ + acc_shape = tiled_mma.partition_shape_C(mma_tiler[:2]) + tCtAcc_fake = tiled_mma.make_fragment_C(acc_shape) + return sm100_utils.get_num_tmem_alloc_cols(tCtAcc_fake) + + @staticmethod + def is_valid_dtypes( + ab_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + ) -> bool: + """ + Check if the dtypes are valid + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + + :return: True if the dtypes are valid, False otherwise + :rtype: bool + """ + is_valid = True + if ab_dtype not in { + cutlass.Float16, + cutlass.BFloat16, + cutlass.TFloat32, + cutlass.Uint8, + cutlass.Int8, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: + is_valid = False + if ( + acc_dtype not in {cutlass.Float32, cutlass.Float16, cutlass.Int32} + or acc_dtype == cutlass.Float16 + and ab_dtype + not in {cutlass.Float16, cutlass.Float8E4M3FN, cutlass.Float8E5M2} + or acc_dtype == cutlass.Int32 + and ab_dtype not in {cutlass.Uint8, cutlass.Int8} + ): + is_valid = False + if ( + acc_dtype == cutlass.Float32 + and c_dtype + not in { + cutlass.Float32, + cutlass.Float16, + cutlass.BFloat16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + cutlass.Int32, + cutlass.Int8, + cutlass.Uint8, + } + or acc_dtype == cutlass.Float16 + and c_dtype + not in { + cutlass.BFloat16, + cutlass.Float16, + } + or acc_dtype == cutlass.Int32 + and c_dtype + not in { + cutlass.BFloat16, + cutlass.Float16, + cutlass.Float32, + cutlass.Int32, + cutlass.Int8, + cutlass.Uint8, + } + ): + is_valid = False + return is_valid + + @staticmethod + def is_valid_mma_tiler_and_cluster_shape( + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ) -> bool: + """ + Check if the mma tiler and cluster shape are valid + + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + + :return: True if the mma tiler and cluster shape are valid, False otherwise + :rtype: bool + """ + is_valid = True + # Skip invalid mma tile shape + if not ( + (not use_2cta_instrs and mma_tiler_mn[0] in [64, 128]) + or (use_2cta_instrs and mma_tiler_mn[0] in [128, 256]) + ): + is_valid = False + if mma_tiler_mn[1] not in range(32, 257, 32): + is_valid = False + # Skip illegal cluster shape + if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0: + is_valid = False + # Skip invalid cluster shape + is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0 + if ( + cluster_shape_mn[0] * cluster_shape_mn[1] > 16 + or cluster_shape_mn[0] <= 0 + or cluster_shape_mn[1] <= 0 + or not is_power_of_2(cluster_shape_mn[0]) + or not is_power_of_2(cluster_shape_mn[1]) + ): + is_valid = False + return is_valid + + @staticmethod + def is_valid_tensor_alignment( + m: int, + n: int, + k: int, + l: int, + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the tensor alignment is valid + + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the problem shape is valid, False otherwise + :rtype: bool + """ + is_valid = True + + def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape): + major_mode_idx = 0 if is_mode0_major else 1 + num_major_elements = tensor_shape[major_mode_idx] + num_contiguous_elements = 16 * 8 // dtype.width + return num_major_elements % num_contiguous_elements == 0 + + if ( + not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, l)) + or not check_contigous_16B_alignment(ab_dtype, b_major == "n", (n, k, l)) + or not check_contigous_16B_alignment(c_dtype, c_major == "m", (m, n, l)) + ): + is_valid = False + return is_valid + + @staticmethod + def is_valid_epilog_store_option( + use_2cta_instrs: bool, + use_tma_store: bool, + m: int, + n: int, + mma_tiler_mn: Tuple[int, int], + ) -> bool: + """ + Check if the epilogue store option is valid + + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param use_tma_store: Whether to use TMA store + :type use_tma_store: bool + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + + :return: True if the epilogue store option is valid, False otherwise + :rtype: bool + """ + + is_valid = True + # None TMA store version does not have predication, can not support OOB tiles + cta_tile_shape_mn = ( + mma_tiler_mn[0] // (2 if use_2cta_instrs else 1), + mma_tiler_mn[1], + ) + if not use_tma_store: + if not (m % cta_tile_shape_mn[0] == 0 and n % cta_tile_shape_mn[1] == 0): + is_valid = False + return is_valid + + @staticmethod + def can_implement( + ab_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + use_tma_store: bool, + m: int, + n: int, + k: int, + l: int, + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the gemm can be implemented + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + :param use_tma_store: Whether to use TMA store + :type use_tma_store: bool + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the gemm can be implemented, False otherwise + :rtype: bool + """ + can_implement = True + # Skip unsupported types + if not DenseGemmKernel.is_valid_dtypes(ab_dtype, acc_dtype, c_dtype): + can_implement = False + # Skip invalid mma tile shape and cluster shape + if not DenseGemmKernel.is_valid_mma_tiler_and_cluster_shape( + use_2cta_instrs, mma_tiler_mn, cluster_shape_mn + ): + can_implement = False + # Skip illegal problem shape for load/store alignment + if not DenseGemmKernel.is_valid_tensor_alignment( + m, n, k, l, ab_dtype, c_dtype, a_major, b_major, c_major + ): + can_implement = False + # Skip invalid epilogue store option + if not DenseGemmKernel.is_valid_epilog_store_option( + use_2cta_instrs, use_tma_store, m, n, mma_tiler_mn + ): + can_implement = False + return can_implement + + +def run_dense_gemm( + mnkl: Tuple[int, int, int, int], + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + use_2cta_instrs: bool, + use_tma_store: bool, + tolerance: float, + warmup_iterations: int = 0, + iterations: int = 1, + skip_ref_check: bool = False, + measure_launch_overhead=False, +): + """ + Prepare A/B/C tensors, launch GPU kernel, and reference checking. + """ + print(f"Running B100 Dense GEMM test with:") + print(f"mnkl: {mnkl}") + print(f"AB dtype: {ab_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}") + print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}") + print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}") + print(f"2CTA MMA instructions: {'True' if use_2cta_instrs else 'False'}") + print(f"Use TMA Store: {'True' if use_tma_store else 'False'}") + print(f"Tolerance: {tolerance}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + + # Unpack parameters + m, n, k, l = mnkl + + # Skip unsupported testcase + if not DenseGemmKernel.can_implement( + ab_dtype, + acc_dtype, + c_dtype, + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + use_tma_store, + m, + n, + k, + l, + a_major, + b_major, + c_major, + ): + raise TypeError( + f"Unsupported testcase {ab_dtype}, {acc_dtype}, {c_dtype}, {use_2cta_instrs}, {mma_tiler_mn}, {cluster_shape_mn}, {use_tma_store}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {c_major}" + ) + + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + torch.manual_seed(1111) + + # Create and permute tensor A/B/C + def create_and_permute_tensor( + l, mode0, mode1, is_mode0_major, dtype, is_dynamic_layout=True + ): + # is_mode0_major: (l, mode1, mode0) -> (mode0, mode1, l) + # else: (l, mode0, mode1) -> (mode0, mode1, l) + shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1) + permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0) + is_unsigned = dtype in {cutlass.Uint8} + # Temporarily use uint8 as torch does not support fp8 type + torch_dtype = ( + cutlass_torch.dtype(dtype) + if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN} + else torch.uint8 + ) + + # Create dtype torch tensor (cpu) + torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( + shape, + torch_dtype, + permute_order=permute_order, + init_type=cutlass_torch.TensorInitType.RANDOM, + init_config=cutlass_torch.RandomInitConfig( + min_val=0 if is_unsigned else -2, max_val=4 if is_unsigned else 2 + ), + ) + # Create dtype torch tensor (gpu) + torch_tensor = torch_tensor_cpu.cuda() + + # Create f32 torch tensor (cpu) + f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32) + + # Create dtype cute tensor (gpu) + cute_tensor = from_dlpack(torch_tensor, assumed_align=16) + cute_tensor.element_type = dtype + if is_dynamic_layout: + cute_tensor = cute_tensor.mark_layout_dynamic( + leading_dim=(0 if is_mode0_major else 1) + ) + cute_tensor = cutlass_torch.convert_cute_tensor( + f32_torch_tensor, + cute_tensor, + dtype, + is_dynamic_layout=is_dynamic_layout, + ) + + return f32_torch_tensor, cute_tensor, torch_tensor + + a_ref, a_tensor, a_torch = create_and_permute_tensor( + l, m, k, a_major == "m", ab_dtype, is_dynamic_layout=True + ) + b_ref, b_tensor, b_torch = create_and_permute_tensor( + l, n, k, b_major == "n", ab_dtype, is_dynamic_layout=True + ) + c_ref, c_tensor, c_torch = create_and_permute_tensor( + l, m, n, c_major == "m", c_dtype, is_dynamic_layout=True + ) + + # Configure gemm kernel + gemm = DenseGemmKernel( + acc_dtype, + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + use_tma_store, + ) + + torch_stream = torch.cuda.Stream() + stream = cuda.CUstream(torch_stream.cuda_stream) + # Compile gemm kernel + compiled_gemm = cute.compile(gemm, a_tensor, b_tensor, c_tensor, stream) + + # Launch GPU kernel + # Warm up + for i in range(warmup_iterations): + compiled_gemm(a_tensor, b_tensor, c_tensor, stream) + # Execution + for i in range(iterations): + compiled_gemm(a_tensor, b_tensor, c_tensor, stream) + + # Compute reference result + if not skip_ref_check: + if ab_dtype in { + cutlass.Int8, + cutlass.Uint8, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: + ref = torch.einsum("mkl,nkl->mnl", a_ref.cpu(), b_ref.cpu()) + else: + ref = (torch.einsum("mkl,nkl->mnl", a_ref, b_ref)).cpu() + + # Copy gpu result back + gpu_c = c_torch.cpu() + + # Convert ref to c_type + if c_dtype == cutlass.Float32: + ref_c = ref + elif c_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}: + # m major: (l, n, m) -> (m, n, l) + # n major: (l, m, n) -> (m, n, l) + permute_order = (1, 2, 0) if c_major == "n" else (2, 1, 0) + shape = (l, m, n) if c_major == "n" else (l, n, m) + f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor( + shape, + torch.uint8, + permute_order=permute_order, + init_type=cutlass_torch.TensorInitType.SKIP, + ).cuda() + # Create dtype cute tensor (gpu) + ref_c_tensor = from_dlpack( + f8_torch_tensor, assumed_align=16 + ).mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0)) + ref_c_tensor.element_type = c_dtype + ref_c_tensor = cutlass_torch.convert_cute_tensor( + ref, + ref_c_tensor, + c_dtype, + is_dynamic_layout=True, + ) + + ref_c = f8_torch_tensor.cpu() + else: + ref_c = ref.to(cutlass_torch.dtype(c_dtype)) + + # Reference checking ref_c and gpu_c + torch.testing.assert_close( + gpu_c, + ref_c, + atol=tolerance, + rtol=1e-05, + ) + + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: + try: + return tuple(int(x.strip()) for x in s.split(",")) + # or: return tuple([int(x.strip()) for x in s.split(",")]) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + parser = argparse.ArgumentParser( + description="Example of MxNxKxL GEMM on Blackwell." + ) + + parser.add_argument( + "--mnkl", + type=parse_comma_separated_ints, + default=(256, 256, 512, 1), + help="mnkl dimensions (comma-separated)", + ) + parser.add_argument( + "--mma_tiler_mn", + type=parse_comma_separated_ints, + default=(128, 128), + help="Mma tiler (comma-separated)", + ) + parser.add_argument( + "--cluster_shape_mn", + type=parse_comma_separated_ints, + default=(1, 1), + help="Cluster shape (comma-separated)", + ) + parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.TFloat32) + parser.add_argument("--c_dtype", type=cutlass.dtype, default=cutlass.Float32) + parser.add_argument("--acc_dtype", type=cutlass.dtype, default=cutlass.Float32) + parser.add_argument( + "--use_2cta_instrs", + action="store_true", + help="Enable 2CTA MMA instructions feature", + ) + parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k") + parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k") + parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n") + parser.add_argument( + "--use_tma_store", action="store_true", help="Use tma store or not" + ) + parser.add_argument( + "--tolerance", type=float, default=1e-01, help="Tolerance for validation" + ) + parser.add_argument( + "--warmup_iterations", type=int, default=0, help="Warmup iterations" + ) + parser.add_argument("--iterations", type=int, default=1, help="Iterations") + parser.add_argument( + "--skip_ref_check", action="store_true", help="Skip reference checking" + ) + + args = parser.parse_args() + + if len(args.mnkl) != 4: + parser.error("--mnkl must contain exactly 4 values") + + if len(args.mma_tiler_mn) != 2: + parser.error("--mma_tiler_mn must contain exactly 2 values") + + if len(args.cluster_shape_mn) != 2: + parser.error("--cluster_shape_mn must contain exactly 2 values") + + run_dense_gemm( + args.mnkl, + args.ab_dtype, + args.c_dtype, + args.acc_dtype, + args.a_major, + args.b_major, + args.c_major, + args.mma_tiler_mn, + args.cluster_shape_mn, + args.use_2cta_instrs, + args.use_tma_store, + args.tolerance, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + ) + print("PASS") diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py new file mode 100644 index 0000000000000000000000000000000000000000..f5022a82966bc93da97bbde24ab29cb598d5a914 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py @@ -0,0 +1,2193 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +from typing import Optional, Type, Tuple, Union + +import cuda.bindings.driver as cuda +import torch + +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.torch as cutlass_torch +import cutlass.utils as utils +import cutlass.pipeline as pipeline +import cutlass.cute.testing as testing +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cute.runtime import from_dlpack + + +""" +A high-performance persistent batched dense GEMM example for the NVIDIA Blackwell SM100 architecture +using CUTE DSL. +- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M") +- Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K") +- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M") + +This GEMM kernel supports the following features: + - Utilizes Tensor Memory Access (TMA) for efficient memory operations + - Utilizes Blackwell's tcgen05.mma for matrix multiply-accumulate (MMA) operations (including 2cta mma instructions) + - Implements TMA multicast with cluster to reduce L2 memory traffic + - Support persistent tile scheduling to better overlap memory load/store with mma between tiles + - Support warp specialization to avoid explicit pipelining between mainloop load and mma + +This GEMM works as follows: +1. DMA warp: Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations. +2. MMA warp: Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction. +3. EPILOGUE warp: + - Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld. + - Type convert C matrix to output type. + - Optionally store C matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations, + or directly store C matrix from registers (RMEM) to global memory (GMEM) without TMA operations. + - Optionally accept an elementwise lambda function epilogue_op to apply to the output tensor: + e.g., relu can set epilogue_op = lambda x: cute.where(x > 0, x, cute.full_like(x, 0)) + +SM100 tcgen05.mma instructions operate as follows: +- Read matrix A from SMEM +- Read matrix B from SMEM +- Write accumulator to TMEM +The accumulator in TMEM must then be loaded to registers before writing back to GMEM. + +Input arguments to this example is same as dense_gemm.py. + +.. code-block:: bash + + python examples/blackwell/dense_gemm_persistent.py \ + --ab_dtype Float16 --c_dtype Float16 --acc_dtype Float32 \ + --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \ + --mnkl 8192,8192,8192,1 \ + --use_tma_store --use_2cta_instrs + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/blackwell/dense_gemm_persistent.py \ + --ab_dtype Float16 --c_dtype Float16 --acc_dtype Float32 \ + --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \ + --mnkl 8192,8192,8192,1 \ + --use_tma_store --use_2cta_instrs \ + --warmup_iterations 1 --iterations 10 --skip_ref_check + + +Constraints are same as dense_gemm.py: +* Supported input data types: fp16, bf16, tf32, int8, uint8, fp8 (e4m3fn, e5m2), + see detailed valid dtype combinations in below PersistentDenseGemmKernel class documentation +* A/B tensor must have the same data type +* Mma tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True) +* Mma tiler N must be 32-256, step 32 +* Cluster shape M/N must be positive and power of 2, total cluster size <= 16 +* Cluster shape M must be multiple of 2 if use_2cta_instrs=True +* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned, + i.e, number of elements is a multiple of 4, 8, and 16 for TFloat32, + Float16/BFloat16, and Int8/Uint8/Float8, respectively. +* OOB tiles are not allowed when TMA store is disabled +""" + + +class PersistentDenseGemmKernel: + """This class implements batched matrix multiplication (C = A x B) with support for various data types + and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization. + + :param acc_dtype: Data type for accumulation during computation + :type acc_dtype: type[cutlass.Numeric] + :param use_2cta_instrs: Whether to use CTA group 2 for advanced thread cooperation + :type use_2cta_instrs: bool + :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N) + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing + :type cluster_shape_mn: Tuple[int, int] + :param use_tma_store: Whether to use Tensor Memory Access (TMA) for storing results + :type use_tma_store: bool + + :note: In current version, A and B tensor must have the same data type + - i.e., Float8E4M3FN for A and Float8E5M2 for B is not supported + + :note: Supported A/B data types: + - TFloat32 + - Float16/BFloat16 + - Int8/Uint8 + - Float8E4M3FN/Float8E5M2 + + :note: Supported accumulator data types: + - Float32 (for all floating point A/B data types) + - Float16 (only for fp16 and fp8 A/B data types) + - Int32 (only for uint8/int8 A/B data types) + + :note: Supported C data types: + - Float32 (for float32 and int32 accumulator data types) + - Int32 (for float32 and int32 accumulator data types) + - Float16/BFloat16 (for fp16 and fp8 accumulator data types) + - Int8/Uint8 (for uint8/int8 accumulator data types) + - Float8E4M3FN/Float8E5M2 (for float32 accumulator data types) + + :note: Constraints: + - MMA tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True) + - MMA tiler N must be 32-256, step 32 + - Cluster shape M must be multiple of 2 if use_2cta_instrs=True + - Cluster shape M/N must be positive and power of 2, total cluster size <= 16 + + Example: + >>> gemm = PersistentDenseGemmKernel( + ... acc_dtype=cutlass.Float32, + ... use_2cta_instrs=True, + ... mma_tiler_mn=(128, 128), + ... cluster_shape_mn=(2, 2) + ... ) + >>> gemm(a_tensor, b_tensor, c_tensor, max_active_clusters, stream) + """ + + def __init__( + self, + acc_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + use_tma_store: bool, + ): + """Initializes the configuration for a Blackwell dense GEMM kernel. + + This configuration includes several key aspects: + + 1. MMA Instruction Settings (tcgen05): + - acc_dtype: Data types for MMA accumulator. + - mma_tiler_mn: The (M, N) shape of the MMA instruction tiler. + - use_2cta_instrs: Boolean indicating if the tcgen05 MMA variant + with cta_group=2 should be used. + + 2. Cluster Shape: + - cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster. + + 3. Output C tensor store mode: + - use_tma_store: Boolean indicating whether to use Tensor Memory Access (TMA) for storing results. + + :param acc_dtype: Data type of the accumulator. + :type acc_dtype: type[cutlass.Numeric] + :param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction. + :type mma_tiler_mn: Tuple[int, int] + :param use_2cta_instrs: Boolean, True to use cta_group=2 MMA variant. + :type use_2cta_instrs: bool + :param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster. + :type cluster_shape_mn: Tuple[int, int] + :param use_tma_store: Use Tensor Memory Access (TMA) or normal store for output C tensor. + :type use_tma_store: bool + """ + + self.acc_dtype: Type[cutlass.Numeric] = acc_dtype + self.use_2cta_instrs = use_2cta_instrs + self.cluster_shape_mn = cluster_shape_mn + # K dimension is deferred in _setup_attributes + self.mma_tiler = (*mma_tiler_mn, 1) + self.use_tma_store = use_tma_store + + self.cta_group = ( + tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + ) + + self.occupancy = 1 + # Set specialized warp ids + self.epilog_warp_id = ( + 0, + 1, + 2, + 3, + ) + self.mma_warp_id = 4 + self.tma_warp_id = 5 + self.threads_per_cta = 32 * len( + (self.mma_warp_id, self.tma_warp_id, *self.epilog_warp_id) + ) + # Set barrier id for cta sync, epilogue sync and tmem ptr sync + self.cta_sync_bar_id = 0 + self.epilog_sync_bar_id = 1 + self.tmem_ptr_sync_bar_id = 2 + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + + def _setup_attributes(self): + """Set up configurations that are dependent on GEMM inputs + + This method configures various attributes based on the input tensor properties + (data types, leading dimensions) and kernel settings: + - Configuring tiled MMA + - Computing MMA/cluster/tile shapes + - Computing cluster layout + - Computing multicast CTAs for A/B + - Computing epilogue subtile + - Setting up A/B/C stage counts in shared memory + - Computing A/B/C shared memory layout + - Computing tensor memory allocation columns + """ + # Configure tiled mma + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + + # Compute mma/cluster/tile shapes + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + mma_inst_tile_k = 4 + self.mma_tiler = ( + self.mma_tiler[0], + self.mma_tiler[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + + # Compute cluster layout + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + + # Compute number of multicast CTAs for A/B + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + + # Compute epilogue subtile + if cutlass.const_expr(self.use_tma_store): + self.epi_tile = sm100_utils.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, + self.use_2cta_instrs, + self.c_layout, + self.c_dtype, + ) + else: + self.epi_tile = self.cta_tile_shape_mnk[:2] + + # Setup A/B/C stage count in shared memory and ACC stage count in tensor memory + self.num_acc_stage, self.num_ab_stage, self.num_c_stage = self._compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.b_dtype, + self.epi_tile, + self.c_dtype, + self.c_layout, + self.smem_capacity, + self.occupancy, + self.use_tma_store, + ) + + # Compute A/B/C shared memory layout + self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.num_ab_stage, + ) + self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + self.mma_tiler, + self.b_dtype, + self.num_ab_stage, + ) + self.c_smem_layout_staged = ( + sm100_utils.make_smem_layout_epi( + self.c_dtype, + self.c_layout, + self.epi_tile, + self.num_c_stage, + ) + if cutlass.const_expr(self.use_tma_store) + else None + ) + + # Compute the number of tensor memory allocation columns + self.num_tmem_alloc_cols = self._compute_num_tmem_alloc_cols( + tiled_mma, self.mma_tiler, self.num_acc_stage + ) + + @cute.jit + def __call__( + self, + a: cute.Tensor, + b: cute.Tensor, + c: cute.Tensor, + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + """Execute the GEMM operation in steps: + - Setup static attributes before smem/grid/tma computation + - Setup TMA load/store atoms and tensors + - Compute grid size with regard to hardware constraints + - Define shared storage for kernel + - Launch the kernel synchronously + + :param a: Input tensor A + :type a: cute.Tensor + :param b: Input tensor B + :type b: cute.Tensor + :param c: Output tensor C + :type c: cute.Tensor + :param max_active_clusters: Maximum number of active clusters + :type max_active_clusters: cutlass.Constexpr + :param stream: CUDA stream for asynchronous execution + :type stream: cuda.CUstream + :param epilogue_op: Optional elementwise lambda function to apply to the output tensor + :type epilogue_op: cutlass.Constexpr + :raises TypeError: If input data types are incompatible with the MMA instruction. + :raises AssertionError: If OOB (Out-Of-Bounds) tiles are present when TMA store is disabled. + """ + # Setup static attributes before smem/grid/tma computation + self.a_dtype: Type[cutlass.Numeric] = a.element_type + self.b_dtype: Type[cutlass.Numeric] = b.element_type + self.c_dtype: Type[cutlass.Numeric] = c.element_type + self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode() + self.c_layout = utils.LayoutEnum.from_tensor(c) + + # Check if input data types are compatible with MMA instruction + if cutlass.const_expr(self.a_dtype != self.b_dtype): + raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}") + + # Setup attributes that dependent on gemm inputs + self._setup_attributes() + + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # Setup TMA load for A + a_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + a, + a_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=( + cutlass.TFloat32 if a.element_type is cutlass.Float32 else None + ), + ) + + # Setup TMA load for B + b_op = sm100_utils.cluster_shape_to_tma_atom_B( + self.cluster_shape_mn, tiled_mma.thr_id + ) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=( + cutlass.TFloat32 if b.element_type is cutlass.Float32 else None + ), + ) + + a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + self.num_tma_load_bytes = (a_copy_size + b_copy_size) * atom_thr_size + + # Setup TMA store for C + tma_atom_c = None + tma_tensor_c = None + if cutlass.const_expr(self.use_tma_store): + c_cta_v_layout = cute.composition( + cute.make_identity_layout(c.shape), self.epi_tile + ) + epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + c, + epi_smem_layout, + c_cta_v_layout, + ) + + # Compute grid size + self.tile_sched_params, grid = self._compute_grid( + c, self.cta_tile_shape_mnk, self.cluster_shape_mn, max_active_clusters + ) + + self.buffer_align_bytes = 1024 + + c_smem_size = ( + cute.cosize(self.c_smem_layout_staged.outer) + if cutlass.const_expr(self.use_tma_store) + else 0 + ) + + # Define shared storage for kernel + @cute.struct + class SharedStorage: + ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC: cute.struct.Align[ + cute.struct.MemRange[ + self.c_dtype, + c_smem_size, + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sA: cute.struct.Align[ + cute.struct.MemRange[ + self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sB: cute.struct.Align[ + cute.struct.MemRange[ + self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + tiled_mma, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_c, + tma_tensor_c if cutlass.const_expr(self.use_tma_store) else c, + self.cluster_layout_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.c_smem_layout_staged, + self.epi_tile, + self.tile_sched_params, + epilogue_op, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + stream=stream, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_c: Optional[cute.CopyAtom], + mC_mnl: cute.Tensor, + cluster_layout_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None], + epi_tile: cute.Tile, + tile_sched_params: utils.PersistentTileSchedulerParams, + epilogue_op: cutlass.Constexpr, + ): + """ + GPU device kernel performing the Persistent batched GEMM computation. + """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + # + # Prefetch tma desc + # + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + if cutlass.const_expr(self.use_tma_store): + cpasync.prefetch_descriptor(tma_atom_c) + + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + # + # Setup cta/thread coordinates + # + # Coords inside cluster + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + # Coord inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Alloc and init: a+b full/empty, accumulator full/empty, tensor memory dealloc barrier + # + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr + tmem_holding_buf = storage.tmem_holding_buf + + # Initialize mainloop ab_pipeline (barrier) and states + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + ab_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_tma_producer + ) + ab_pipeline = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_full_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + # Initialize acc_pipeline (barrier) and states + acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_acc_consumer_threads = len(self.epilog_warp_id) * ( + 2 if use_2cta_instrs else 1 + ) + acc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_acc_consumer_threads + ) + acc_pipeline = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_full_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=acc_pipeline_producer_group, + consumer_group=acc_pipeline_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + # Tensor memory dealloc barrier init + if use_2cta_instrs: + if warp_idx == self.tma_warp_id: + num_tmem_dealloc_threads = 32 + with cute.arch.elect_one(): + cute.arch.mbarrier_init( + tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads + ) + cute.arch.mbarrier_init_fence() + + # Cluster arrive after barrier init + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_arrive_relaxed() + + # + # Setup smem tensor A/B/C + # + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC = ( + storage.sC.get_tensor( + c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner + ) + if cutlass.const_expr(self.use_tma_store) + else None + ) + # (MMA, MMA_M, MMA_K, STAGE) + sA = storage.sA.get_tensor( + a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner + ) + # (MMA, MMA_N, MMA_K, STAGE) + sB = storage.sB.get_tensor( + b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner + ) + + # + # Compute multicast mask for A/B buffer full + # + a_full_mcast_mask = None + b_full_mcast_mask = None + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): + a_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + b_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + + # + # Local_tile partition global tensors + # + # (bM, bK, RestM, RestK, RestL) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + # (bM, bN, RestM, RestN, RestL) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) + ) + k_tile_cnt = cute.size(gA_mkl, mode=[3]) + + # + # Partition global tensor for TiledMMA_A/B/C + # + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgA = thr_mma.partition_A(gA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgB = thr_mma.partition_B(gB_nkl) + # (MMA, MMA_M, MMA_N, RestM, RestN, RestL) + tCgC = thr_mma.partition_C(gC_mnl) + + # + # Partition global/shared tensor for TMA load A/B + # + # TMA load A partition_S/D + a_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA load B partition_S/D + b_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/C + # + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_fake = tiled_mma.make_fragment_C( + cute.append(acc_shape, self.num_acc_stage) + ) + + # + # Cluster wait before tensor memory alloc + # + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_wait() + else: + cute.arch.barrier( + barrier_id=self.cta_sync_bar_id, number_of_threads=self.threads_per_cta + ) + + # + # Specialized TMA load warp + # + + if warp_idx == self.tma_warp_id: + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + ab_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_ab_stage + ) + + while work_tile.is_valid_tile: + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + + # + # Slice to per mma tile index + # + # ((atom_v, rest_v), RestK) + tAgA_slice = tAgA[ + (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) + ] + # ((atom_v, rest_v), RestK) + tBgB_slice = tBgB[ + (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) + ] + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + ab_producer_state.reset_count() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) + # + # Tma load loop + # + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + # Conditionally wait for AB buffer empty + ab_pipeline.producer_acquire( + ab_producer_state, peek_ab_empty_status + ) + + # TMA load A/B + cute.copy( + tma_atom_a, + tAgA_slice[(None, ab_producer_state.count)], + tAsA[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=a_full_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB_slice[(None, ab_producer_state.count)], + tBsB[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=b_full_mcast_mask, + ) + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 + ab_producer_state.advance() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # + # Wait A/B buffer empty + # + ab_pipeline.producer_tail(ab_producer_state) + + # + # Specialized MMA warp + # + if warp_idx == self.mma_warp_id: + # + # Bar sync for retrieve tensor memory ptr from shared mem + # + tmem_ptr_read_threads = 32 * len((self.mma_warp_id, *self.epilog_warp_id)) + cute.arch.barrier( + barrier_id=self.tmem_ptr_sync_bar_id, + number_of_threads=tmem_ptr_read_threads, + ) + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, + alignment=16, + ptr_to_buffer_holding_addr=tmem_holding_buf, + ) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + ab_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_ab_stage + ) + acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_acc_stage + ) + + while work_tile.is_valid_tile: + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + + # Set tensor memory buffer for current tile + # (MMA, MMA_M, MMA_N) + tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)] + + # Peek (try_wait) AB buffer full for k_tile = 0 + ab_consumer_state.reset_count() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt and is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait( + ab_consumer_state + ) + + # + # Wait for accumulator buffer empty + # + if is_leader_cta: + acc_pipeline.producer_acquire(acc_producer_state) + + # + # Reset the ACCUMULATE field for each tile + # + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + + # + # Mma mainloop + # + for k_tile in range(k_tile_cnt): + if is_leader_cta: + # Conditionally wait for AB buffer full + ab_pipeline.consumer_wait( + ab_consumer_state, peek_ab_full_status + ) + + # tCtAcc += tCrA * tCrB + num_kblocks = cute.size(tCrA, mode=[2]) + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = ( + None, + None, + kblock_idx, + ab_consumer_state.index, + ) + + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[kblock_coord], + tCrB[kblock_coord], + tCtAcc, + ) + # Enable accumulate on tCtAcc after first kblock + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Async arrive AB buffer empty + ab_pipeline.consumer_release(ab_consumer_state) + + # Peek (try_wait) AB buffer full for k_tile = k_tile + 1 + ab_consumer_state.advance() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt: + if is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait( + ab_consumer_state + ) + + # + # Async arrive accumulator buffer full + # + if is_leader_cta: + acc_pipeline.producer_commit(acc_producer_state) + acc_producer_state.advance() + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # + # Wait for accumulator buffer empty + # + acc_pipeline.producer_tail(acc_producer_state) + # + # Specialized epilogue warps + # + if warp_idx < self.mma_warp_id: + # + # Alloc tensor memory buffer + # + if warp_idx == self.epilog_warp_id[0]: + cute.arch.alloc_tmem( + self.num_tmem_alloc_cols, + tmem_holding_buf, + is_two_cta=use_2cta_instrs, + ) + + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + tmem_ptr_read_threads = 32 * len((self.mma_warp_id, *self.epilog_warp_id)) + cute.arch.barrier( + barrier_id=self.tmem_ptr_sync_bar_id, + number_of_threads=tmem_ptr_read_threads, + ) + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, + alignment=16, + ptr_to_buffer_holding_addr=tmem_holding_buf, + ) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + # + # Partition for epilogue + # + epi_tidx = tidx + ( + tiled_copy_t2r, + tTR_tAcc_base, + tTR_rAcc, + ) = self.epilog_tmem_copy_and_partition( + epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs + ) + + tTR_rC = None + tiled_copy_r2s = None + simt_atom = None + tRS_rC = None + tRS_sC = None + bSG_sC = None + bSG_gC_partitioned = None + tTR_gC_partitioned = None + if cutlass.const_expr(self.use_tma_store): + tTR_rC = cute.make_fragment(tTR_rAcc.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( + tiled_copy_t2r, tTR_rC, epi_tidx, sC + ) + ( + tma_atom_c, + bSG_sC, + bSG_gC_partitioned, + ) = self.epilog_gmem_copy_and_partition( + epi_tidx, tma_atom_c, tCgC, epi_tile, sC + ) + else: + ( + simt_atom, + tTR_rC, + tTR_gC_partitioned, + ) = self.epilog_gmem_copy_and_partition( + epi_tidx, tiled_copy_t2r, tCgC, epi_tile, sC + ) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage + ) + + c_pipeline = None + if cutlass.const_expr(self.use_tma_store): + # Threads/warps participating in tma store pipeline + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + 32 * len(self.epilog_warp_id), + 32 * len(self.epilog_warp_id), + ) + c_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_c_stage, + producer_group=c_producer_group, + ) + + while work_tile.is_valid_tile: + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + + # + # Slice to per mma tile index + # + bSG_gC = None + tTR_gC = None + if cutlass.const_expr(self.use_tma_store): + # ((ATOM_V, REST_V), EPI_M, EPI_N) + bSG_gC = bSG_gC_partitioned[ + ( + None, + None, + None, + *mma_tile_coord_mnl, + ) + ] + else: + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + tTR_gC = tTR_gC_partitioned[ + ( + None, + None, + None, + None, + None, + *mma_tile_coord_mnl, + ) + ] + + # Set tensor memory buffer for current tile + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = tTR_tAcc_base[ + (None, None, None, None, None, acc_consumer_state.index) + ] + + # + # Wait for accumulator buffer full + # + acc_pipeline.consumer_wait(acc_consumer_state) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + if cutlass.const_expr(self.use_tma_store): + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + else: + tTR_gC = cute.group_modes(tTR_gC, 3, cute.rank(tTR_gC)) + + # + # Store accumulator to global memory in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt + for subtile_idx in cutlass.range(subtile_cnt): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + if cutlass.const_expr(self.use_tma_store): + # + # Convert to C type + # + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) + tRS_rC.store(acc_vec) + + # + # Store C to shared memory + # + c_buffer = (num_prev_subtiles + subtile_idx) % self.num_c_stage + cute.copy( + tiled_copy_r2s, + tRS_rC, + tRS_sC[(None, None, None, c_buffer)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + epilog_threads = 32 * len(self.epilog_warp_id) + cute.arch.barrier( + barrier_id=self.epilog_sync_bar_id, + number_of_threads=epilog_threads, + ) + + # + # TMA store C to global memory + # + if warp_idx == self.epilog_warp_id[0]: + cute.copy( + tma_atom_c, + bSG_sC[(None, c_buffer)], + bSG_gC[(None, subtile_idx)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + c_pipeline.producer_commit() + c_pipeline.producer_acquire() + cute.arch.barrier( + barrier_id=self.epilog_sync_bar_id, + number_of_threads=epilog_threads, + ) + else: + # + # Convert to C type + # + acc_vec = tTR_rAcc.load() + acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) + tTR_rC.store(acc_vec) + + # + # Store C to global memory + # + cute.copy( + simt_atom, tTR_rC, tTR_gC[(None, None, None, subtile_idx)] + ) + + # + # Async arrive accumulator buffer empty + # + with cute.arch.elect_one(): + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # + # Dealloc the tensor memory buffer + # + if warp_idx == self.epilog_warp_id[0]: + cute.arch.relinquish_tmem_alloc_permit(is_two_cta=use_2cta_instrs) + epilog_threads = 32 * len(self.epilog_warp_id) + cute.arch.barrier( + barrier_id=self.epilog_sync_bar_id, number_of_threads=epilog_threads + ) + if warp_idx == self.epilog_warp_id[0]: + if use_2cta_instrs: + cute.arch.mbarrier_arrive( + tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1 + ) + cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0) + cute.arch.dealloc_tmem( + tmem_ptr, self.num_tmem_alloc_cols, is_two_cta=use_2cta_instrs + ) + # + # Wait for C store complete + # + if cutlass.const_expr(self.use_tma_store): + c_pipeline.producer_tail() + + def epilog_tmem_copy_and_partition( + self, + tidx: cutlass.Int32, + tAcc: cute.Tensor, + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + use_2cta_instrs: Union[cutlass.Boolean, bool], + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination). + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param tAcc: The accumulator tensor to be copied and partitioned + :type tAcc: cute.Tensor + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param use_2cta_instrs: Whether use_2cta_instrs is enabled + :type use_2cta_instrs: bool + + :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where: + - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + - tTR_tAcc: The partitioned accumulator tensor + - tTR_rAcc: The accumulated tensor in register used to hold t2r results + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # Make tiledCopy for tensor memory load + copy_atom_t2r = sm100_utils.get_tmem_load_op( + self.cta_tile_shape_mnk, + self.c_layout, + self.c_dtype, + self.acc_dtype, + epi_tile, + use_2cta_instrs, + ) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE) + tAcc_epi = cute.flat_divide( + tAcc[((None, None), 0, 0, None)], + epi_tile, + ) + # (EPI_TILE_M, EPI_TILE_N) + tiled_copy_t2r = tcgen05.make_tmem_copy( + copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)] + ) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE) + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + gC_mnl_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) + tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) + # (T2R, T2R_M, T2R_N) + tTR_rAcc = cute.make_fragment( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype + ) + return tiled_copy_t2r, tTR_tAcc, tTR_rAcc + + def epilog_smem_copy_and_partition( + self, + tiled_copy_t2r: cute.TiledCopy, + tTR_rC: cute.Tensor, + tidx: cutlass.Int32, + sC: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination). + + :param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + :type tiled_copy_t2r: cute.TiledCopy + :param tTR_rC: The partitioned accumulator tensor + :type tTR_rC: cute.Tensor + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + :type sepi: cute.Tensor + + :return: A tuple containing (tiled_copy_r2s, tRS_rC, tRS_sC) where: + - tiled_copy_r2s: The tiled copy operation for register to smem copy(r2s) + - tRS_rC: The partitioned tensor C (register source) + - tRS_sC: The partitioned tensor C (smem destination) + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + copy_atom_r2s = sm100_utils.get_smem_store_op( + self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r + ) + tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) + # (R2S, R2S_M, R2S_N, PIPE_D) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sC = thr_copy_r2s.partition_D(sC) + # (R2S, R2S_M, R2S_N) + tRS_rC = tiled_copy_r2s.retile(tTR_rC) + return tiled_copy_r2s, tRS_rC, tRS_sC + + def epilog_gmem_copy_and_partition( + self, + tidx: cutlass.Int32, + atom: Union[cute.CopyAtom, cute.TiledCopy], + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + sC: cute.Tensor, + ) -> Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]: + """Make tiledCopy for global memory store, then use it to: + - partition register array (source) and global memory (destination) for none TMA store version; + - partition shared memory (source) and global memory (destination) for TMA store version. + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param atom: The copy_atom_c to be used for TMA store version, or tiled_copy_t2r for none TMA store version + :type atom: cute.CopyAtom or cute.TiledCopy + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + + :return: A tuple containing either: + - For TMA store: (tma_atom_c, bSG_sC, bSG_gC) where: + - tma_atom_c: The TMA copy atom + - bSG_sC: The partitioned shared memory tensor C + - bSG_gC: The partitioned global tensor C + - For non-TMA store: (simt_atom, tTR_rC, tTR_gC) where: + - simt_atom: The SIMT copy atom + - tTR_rC: The register tensor C + - tTR_gC: The partitioned global tensor C + :rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] + """ + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + gC_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + if cutlass.const_expr(self.use_tma_store): + tma_atom_c = atom + sC_for_tma_partition = cute.group_modes(sC, 0, 2) + gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2) + # ((ATOM_V, REST_V), EPI_M, EPI_N) + # ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL) + bSG_sC, bSG_gC = cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + sC_for_tma_partition, + gC_for_tma_partition, + ) + return tma_atom_c, bSG_sC, bSG_gC + else: + tiled_copy_t2r = atom + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + tTR_gC = thr_copy_t2r.partition_D(gC_epi) + # (T2R, T2R_M, T2R_N) + tTR_rC = cute.make_fragment( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.c_dtype + ) + simt_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), self.c_dtype) + return simt_atom, tTR_rC, tTR_gC + + @staticmethod + def _compute_stages( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: Tuple[int, int, int], + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + epi_tile: cute.Tile, + c_dtype: Type[cutlass.Numeric], + c_layout: utils.LayoutEnum, + smem_capacity: int, + occupancy: int, + use_tma_store: bool, + ) -> Tuple[int, int, int]: + """Computes the number of stages for A/B/C operands based on heuristics. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler. + :type mma_tiler_mnk: tuple[int, int, int] + :param a_dtype: Data type of operand A. + :type a_dtype: type[cutlass.Numeric] + :param b_dtype: Data type of operand B. + :type b_dtype: type[cutlass.Numeric] + :param epi_tile: The epilogue tile shape. + :type epi_tile: cute.Tile + :param c_dtype: Data type of operand C (output). + :type c_dtype: type[cutlass.Numeric] + :param c_layout: Layout enum of operand C. + :type c_layout: utils.LayoutEnum + :param smem_capacity: Total available shared memory capacity in bytes. + :type smem_capacity: int + :param occupancy: Target number of CTAs per SM (occupancy). + :type occupancy: int + :param use_tma_store: Whether TMA store is enabled. + :type use_tma_store: bool + + :return: A tuple containing the computed number of stages for: + (ACC stages, A/B operand stages, C stages) + :rtype: tuple[int, int, int] + """ + # Default ACC stages + num_acc_stage = 2 + + # Default C stages + num_c_stage = 2 if use_tma_store else 0 + + # Calculate smem layout and size for one stage of A, B, and C + a_smem_layout_stage_one = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + a_dtype, + 1, # a tmp 1 stage is provided + ) + b_smem_layout_staged_one = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + b_dtype, + 1, # a tmp 1 stage is provided + ) + c_smem_layout_staged_one = ( + sm100_utils.make_smem_layout_epi( + c_dtype, + c_layout, + epi_tile, + 1, + ) + if use_tma_store + else None + ) + ab_bytes_per_stage = cute.size_in_bytes( + a_dtype, a_smem_layout_stage_one + ) + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) + mbar_helpers_bytes = 1024 + c_bytes_per_stage = ( + cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) + if use_tma_store + else 0 + ) + c_bytes = c_bytes_per_stage * num_c_stage + + # Calculate A/B stages: + # Start with total smem per CTA (capacity / occupancy) + # Subtract reserved bytes and initial C stages bytes + # Divide remaining by bytes needed per A/B stage + num_ab_stage = ( + smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes) + ) // ab_bytes_per_stage + + # Refine epilogue stages: + # Calculate remaining smem after allocating for A/B stages and reserved bytes + # Add remaining unused smem to epilogue + if use_tma_store: + num_c_stage += ( + smem_capacity + - occupancy * ab_bytes_per_stage * num_ab_stage + - occupancy * (mbar_helpers_bytes + c_bytes) + ) // (occupancy * c_bytes_per_stage) + return num_acc_stage, num_ab_stage, num_c_stage + + @staticmethod + def _compute_grid( + c: cute.Tensor, + cta_tile_shape_mnk: Tuple[int, int, int], + cluster_shape_mn: Tuple[int, int], + max_active_clusters: cutlass.Constexpr, + ) -> Tuple[utils.PersistentTileSchedulerParams, Tuple[int, int, int]]: + """Use persistent tile scheduler to compute the grid size for the output tensor C. + + :param c: The output tensor C + :type c: cute.Tensor + :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type cta_tile_shape_mnk: tuple[int, int, int] + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. + :type cluster_shape_mn: tuple[int, int] + :param max_active_clusters: Maximum number of active clusters. + :type max_active_clusters: cutlass.Constexpr + + :return: A tuple containing: + - tile_sched_params: Parameters for the persistent tile scheduler. + - grid: Grid shape for kernel launch. + :rtype: Tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]] + """ + c_shape = cute.slice_(cta_tile_shape_mnk, (None, None, 0)) + gc = cute.zipped_divide(c, tiler=c_shape) + num_ctas_mnl = gc[(0, (None, None, None))].shape + cluster_shape_mnl = (*cluster_shape_mn, 1) + + tile_sched_params = utils.PersistentTileSchedulerParams( + num_ctas_mnl, cluster_shape_mnl + ) + grid = utils.StaticPersistentTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + + return tile_sched_params, grid + + @staticmethod + def _compute_num_tmem_alloc_cols( + tiled_mma: cute.TiledMma, + mma_tiler: Tuple[int, int, int], + num_acc_stage: int, + ) -> int: + """ + Compute the number of tensor memory allocation columns. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler: The shape (M, N, K) of the MMA tile. + :type mma_tiler: tuple[int, int, int] + :param num_acc_stage: The stage of the accumulator tensor. + :type num_acc_stage: int + + :return: The number of tensor memory allocation columns. + :rtype: int + """ + acc_shape = tiled_mma.partition_shape_C(mma_tiler[:2]) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, num_acc_stage)) + num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtAcc_fake) + + return num_tmem_alloc_cols + + @staticmethod + def is_valid_dtypes( + ab_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + ) -> bool: + """ + Check if the dtypes are valid + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + + :return: True if the dtypes are valid, False otherwise + :rtype: bool + """ + is_valid = True + if ab_dtype not in { + cutlass.Float16, + cutlass.BFloat16, + cutlass.TFloat32, + cutlass.Uint8, + cutlass.Int8, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: + is_valid = False + if ( + acc_dtype not in {cutlass.Float32, cutlass.Float16, cutlass.Int32} + or acc_dtype == cutlass.Float16 + and ab_dtype + not in {cutlass.Float16, cutlass.Float8E4M3FN, cutlass.Float8E5M2} + or acc_dtype == cutlass.Int32 + and ab_dtype not in {cutlass.Uint8, cutlass.Int8} + ): + is_valid = False + if ( + acc_dtype == cutlass.Float32 + and c_dtype + not in { + cutlass.Float32, + cutlass.Float16, + cutlass.BFloat16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + cutlass.Int32, + cutlass.Int8, + cutlass.Uint8, + } + or acc_dtype == cutlass.Float16 + and c_dtype + not in { + cutlass.BFloat16, + cutlass.Float16, + } + or acc_dtype == cutlass.Int32 + and c_dtype + not in { + cutlass.BFloat16, + cutlass.Float16, + cutlass.Float32, + cutlass.Int32, + cutlass.Int8, + cutlass.Uint8, + } + ): + is_valid = False + return is_valid + + @staticmethod + def is_valid_mma_tiler_and_cluster_shape( + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ) -> bool: + """ + Check if the mma tiler and cluster shape are valid + + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + + :return: True if the mma tiler and cluster shape are valid, False otherwise + :rtype: bool + """ + is_valid = True + # Skip invalid mma tile shape + if not ( + (not use_2cta_instrs and mma_tiler_mn[0] in [64, 128]) + or (use_2cta_instrs and mma_tiler_mn[0] in [128, 256]) + ): + is_valid = False + if mma_tiler_mn[1] not in range(32, 257, 32): + is_valid = False + # Skip illegal cluster shape + if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0: + is_valid = False + # Skip invalid cluster shape + is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0 + if ( + cluster_shape_mn[0] * cluster_shape_mn[1] > 16 + or cluster_shape_mn[0] <= 0 + or cluster_shape_mn[1] <= 0 + or not is_power_of_2(cluster_shape_mn[0]) + or not is_power_of_2(cluster_shape_mn[1]) + ): + is_valid = False + return is_valid + + @staticmethod + def is_valid_tensor_alignment( + m: int, + n: int, + k: int, + l: int, + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the tensor alignment is valid + + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the problem shape is valid, False otherwise + :rtype: bool + """ + is_valid = True + + def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape): + major_mode_idx = 0 if is_mode0_major else 1 + num_major_elements = tensor_shape[major_mode_idx] + num_contiguous_elements = 16 * 8 // dtype.width + return num_major_elements % num_contiguous_elements == 0 + + if ( + not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, l)) + or not check_contigous_16B_alignment(ab_dtype, b_major == "n", (n, k, l)) + or not check_contigous_16B_alignment(c_dtype, c_major == "m", (m, n, l)) + ): + is_valid = False + return is_valid + + @staticmethod + def is_valid_epilog_store_option( + use_2cta_instrs: bool, + use_tma_store: bool, + m: int, + n: int, + mma_tiler_mn: Tuple[int, int], + ) -> bool: + """ + Check if the epilogue store option is valid + + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param use_tma_store: Whether to use TMA store + :type use_tma_store: bool + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + + :return: True if the epilogue store option is valid, False otherwise + :rtype: bool + """ + + is_valid = True + # None TMA store version does not have predication, can not support OOB tiles + cta_tile_shape_mn = ( + mma_tiler_mn[0] // (2 if use_2cta_instrs else 1), + mma_tiler_mn[1], + ) + if not use_tma_store: + if not (m % cta_tile_shape_mn[0] == 0 and n % cta_tile_shape_mn[1] == 0): + is_valid = False + return is_valid + + @staticmethod + def can_implement( + ab_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + use_tma_store: bool, + m: int, + n: int, + k: int, + l: int, + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the gemm can be implemented + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + :param use_tma_store: Whether to use TMA store + :type use_tma_store: bool + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the gemm can be implemented, False otherwise + :rtype: bool + """ + can_implement = True + # Skip unsupported types + if not PersistentDenseGemmKernel.is_valid_dtypes(ab_dtype, acc_dtype, c_dtype): + can_implement = False + # Skip invalid mma tile shape and cluster shape + if not PersistentDenseGemmKernel.is_valid_mma_tiler_and_cluster_shape( + use_2cta_instrs, mma_tiler_mn, cluster_shape_mn + ): + can_implement = False + # Skip illegal problem shape for load/store alignment + if not PersistentDenseGemmKernel.is_valid_tensor_alignment( + m, n, k, l, ab_dtype, c_dtype, a_major, b_major, c_major + ): + can_implement = False + # Skip invalid epilogue store option + if not PersistentDenseGemmKernel.is_valid_epilog_store_option( + use_2cta_instrs, use_tma_store, m, n, mma_tiler_mn + ): + can_implement = False + return can_implement + + +def run( + mnkl: Tuple[int, int, int, int], + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + mma_tiler_mn: Tuple[int, int] = (256, 256), + cluster_shape_mn: Tuple[int, int] = (2, 1), + use_2cta_instrs: bool = True, + use_tma_store: bool = True, + tolerance: float = 1e-01, + warmup_iterations: int = 0, + iterations: int = 1, + skip_ref_check: bool = False, + use_cold_l2: bool = False, + **kwargs, +): + """Execute a persistent batched dense GEMM operation on Blackwell architecture with performance benchmarking. + + This function prepares input tensors, configures and launches the persistent GEMM kernel, + optionally performs reference validation, and benchmarks the execution performance. + + :param mnkl: Problem size (M, N, K, L) + :type mnkl: Tuple[int, int, int, int] + :param ab_dtype: Data type for input tensors A and B + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: Data type for output tensor C + :type c_dtype: Type[cutlass.Numeric] + :param acc_dtype: Data type for accumulation during matrix multiplication + :type acc_dtype: Type[cutlass.Numeric] + :param a_major/b_major/c_major: Memory layout of tensor A/B/C + :type a_major/b_major/c_major: str + :param mma_tiler_mn: MMA tiling size. If not specified in the decorator parameters, the autotuner will use the + default value of (256, 256). Otherwise, the autotuner will use the value specified in the decorator parameters. + :type mma_tiler_mn: Tuple[int, int], optional + :param cluster_shape_mn: Cluster shape. If not specified in the decorator parameters, the autotuner will use the + default value of (2, 1). Otherwise, the autotuner will use the value specified in the decorator parameters. + :type cluster_shape_mn: Tuple[int, int], optional + :param use_2cta_instrs: Whether to use 2CTA instructions. If not specified in the decorator parameters, the autotuner + will use the default value of True. Otherwise, the autotuner will use the value specified in the decorator parameters. + :type use_2cta_instrs: bool, optional + :param use_tma_store: Whether to use TMA store. If not specified in the decorator parameters, the autotuner will use + the default value of True. Otherwise, the autotuner will use the value specified in the decorator parameters. + :type use_tma_store: bool, optional + :param tolerance: Tolerance value for reference validation comparison, defaults to 1e-01 + :type tolerance: float, optional + :param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0 + :type warmup_iterations: int, optional + :param iterations: Number of benchmark iterations to run, defaults to 1 + :type iterations: int, optional + :param skip_ref_check: Whether to skip reference result validation, defaults to False + :type skip_ref_check: bool, optional + :param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False + :type use_cold_l2: bool, optional + :raises RuntimeError: If CUDA GPU is not available + :raises ValueError: If the configuration is invalid or unsupported by the kernel + :return: Execution time of the GEMM kernel + :rtype: float + """ + print(f"Running Blackwell Persistent Dense GEMM test with:") + print(f"mnkl: {mnkl}") + print(f"AB dtype: {ab_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}") + print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}") + print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}") + print(f"2CTA MMA instructions: {'True' if use_2cta_instrs else 'False'}") + print(f"Use TMA Store: {'True' if use_tma_store else 'False'}") + print(f"Tolerance: {tolerance}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}") + + # Unpack parameters + m, n, k, l = mnkl + + # Skip unsupported testcase + if not PersistentDenseGemmKernel.can_implement( + ab_dtype, + acc_dtype, + c_dtype, + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + use_tma_store, + m, + n, + k, + l, + a_major, + b_major, + c_major, + ): + raise TypeError( + f"Unsupported testcase {ab_dtype}, {acc_dtype}, {c_dtype}, {use_2cta_instrs}, {mma_tiler_mn}, {cluster_shape_mn}, {use_tma_store}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {c_major}" + ) + + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + torch.manual_seed(1111) + + # Create and permute tensor A/B/C + def create_and_permute_tensor( + l, mode0, mode1, is_mode0_major, dtype, is_dynamic_layout=True + ): + # is_mode0_major: (l, mode1, mode0) -> (mode0, mode1, l) + # else: (l, mode0, mode1) -> (mode0, mode1, l) + shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1) + permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0) + is_unsigned = dtype in {cutlass.Uint8} + # Temporarily use uint8 as torch does not support fp8 type + torch_dtype = ( + cutlass_torch.dtype(dtype) + if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN} + else torch.uint8 + ) + + # Create dtype torch tensor (cpu) + torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( + shape, + torch_dtype, + permute_order=permute_order, + init_type=cutlass_torch.TensorInitType.RANDOM, + init_config=cutlass_torch.RandomInitConfig( + min_val=0 if is_unsigned else -2, max_val=4 if is_unsigned else 2 + ), + ) + # Create dtype torch tensor (gpu) + torch_tensor = torch_tensor_cpu.cuda() + + # Create f32 torch tensor (cpu) + f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32) + + # Create dtype cute tensor (gpu) + cute_tensor = from_dlpack(torch_tensor, assumed_align=16) + cute_tensor.element_type = dtype + if is_dynamic_layout: + cute_tensor = cute_tensor.mark_layout_dynamic( + leading_dim=(0 if is_mode0_major else 1) + ) + cute_tensor = cutlass_torch.convert_cute_tensor( + f32_torch_tensor, + cute_tensor, + dtype, + is_dynamic_layout=is_dynamic_layout, + ) + + return f32_torch_tensor, cute_tensor, torch_tensor, torch_tensor_cpu + + a_ref, a_tensor, a_torch, a_torch_cpu = create_and_permute_tensor( + l, m, k, a_major == "m", ab_dtype, is_dynamic_layout=True + ) + b_ref, b_tensor, b_torch, b_torch_cpu = create_and_permute_tensor( + l, n, k, b_major == "n", ab_dtype, is_dynamic_layout=True + ) + c_ref, c_tensor, c_torch, c_torch_cpu = create_and_permute_tensor( + l, m, n, c_major == "m", c_dtype, is_dynamic_layout=True + ) + + # Configure gemm kernel + gemm = PersistentDenseGemmKernel( + acc_dtype, + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + use_tma_store, + ) + + # Compute max active clusters on current device + hardware_info = cutlass.utils.HardwareInfo() + max_active_clusters = hardware_info.get_max_active_clusters( + cluster_shape_mn[0] * cluster_shape_mn[1] + ) + + # Get current CUDA stream from PyTorch + torch_stream = torch.cuda.current_stream() + # Get the raw stream pointer as a CUstream + current_stream = cuda.CUstream(torch_stream.cuda_stream) + # Compile gemm kernel + compiled_gemm = cute.compile( + gemm, a_tensor, b_tensor, c_tensor, max_active_clusters, current_stream + ) + + if not skip_ref_check: + compiled_gemm(a_tensor, b_tensor, c_tensor, current_stream) + if ab_dtype in { + cutlass.Int8, + cutlass.Uint8, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: + ref = torch.einsum("mkl,nkl->mnl", a_ref.cpu(), b_ref.cpu()) + else: + ref = (torch.einsum("mkl,nkl->mnl", a_ref, b_ref)).cpu() + + # Copy gpu result back + gpu_c = c_torch.cpu() + + # Convert ref to c_type + if c_dtype == cutlass.Float32: + ref_c = ref + elif c_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}: + # m major: (l, n, m) -> (m, n, l) + # n major: (l, m, n) -> (m, n, l) + permute_order = (1, 2, 0) if c_major == "n" else (2, 1, 0) + shape = (l, m, n) if c_major == "n" else (l, n, m) + f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor( + shape, + torch.uint8, + permute_order=permute_order, + init_type=cutlass_torch.TensorInitType.SKIP, + ).cuda() + # Create dtype cute tensor (gpu) + ref_c_tensor = from_dlpack( + f8_torch_tensor, assumed_align=16 + ).mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0)) + ref_c_tensor.element_type = c_dtype + ref_c_tensor = cutlass_torch.convert_cute_tensor( + ref, + ref_c_tensor, + c_dtype, + is_dynamic_layout=True, + ) + + ref_c = f8_torch_tensor.cpu() + else: + ref_c = ref.to(cutlass_torch.dtype(c_dtype)) + + # Reference checking ref_c and gpu_c + torch.testing.assert_close( + gpu_c, + ref_c, + atol=tolerance, + rtol=1e-05, + ) + + def generate_tensors(): + a_tensor, _ = cutlass_torch.cute_tensor_like( + a_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + b_tensor, _ = cutlass_torch.cute_tensor_like( + b_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + c_tensor, _ = cutlass_torch.cute_tensor_like( + c_torch_cpu, c_dtype, is_dynamic_layout=True, assumed_align=16 + ) + return testing.JitArguments(a_tensor, b_tensor, c_tensor, current_stream) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + a_torch_cpu.numel() * a_torch_cpu.element_size() + + b_torch_cpu.numel() * b_torch_cpu.element_size() + + c_torch_cpu.numel() * c_torch_cpu.element_size() + ) + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + exec_time = testing.benchmark( + compiled_gemm, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=current_stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + return exec_time # Return execution time in microseconds + + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: + try: + return tuple(int(x.strip()) for x in s.split(",")) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + parser = argparse.ArgumentParser( + description="Example of Dense Persistent GEMM on Blackwell." + ) + + parser.add_argument( + "--mnkl", + type=parse_comma_separated_ints, + default=(256, 256, 512, 1), + help="mnkl dimensions (comma-separated)", + ) + parser.add_argument( + "--mma_tiler_mn", + type=parse_comma_separated_ints, + default=(128, 128), + help="Mma tile shape (comma-separated)", + ) + parser.add_argument( + "--cluster_shape_mn", + type=parse_comma_separated_ints, + default=(1, 1), + help="Cluster shape (comma-separated)", + ) + parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.TFloat32) + parser.add_argument("--c_dtype", type=cutlass.dtype, default=cutlass.Float32) + parser.add_argument("--acc_dtype", type=cutlass.dtype, default=cutlass.Float32) + parser.add_argument( + "--use_2cta_instrs", + action="store_true", + help="Enable 2CTA MMA instructions feature", + ) + parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k") + parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k") + parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n") + parser.add_argument( + "--use_tma_store", action="store_true", help="Use tma store or not" + ) + parser.add_argument( + "--tolerance", type=float, default=1e-01, help="Tolerance for validation" + ) + parser.add_argument( + "--warmup_iterations", type=int, default=0, help="Warmup iterations" + ) + parser.add_argument( + "--iterations", + type=int, + default=1, + help="Number of iterations to run the kernel", + ) + parser.add_argument( + "--skip_ref_check", action="store_true", help="Skip reference checking" + ) + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", + ) + + args = parser.parse_args() + + if len(args.mnkl) != 4: + parser.error("--mnkl must contain exactly 4 values") + + if len(args.mma_tiler_mn) != 2: + parser.error("--mma_tiler_mn must contain exactly 2 values") + + if len(args.cluster_shape_mn) != 2: + parser.error("--cluster_shape_mn must contain exactly 2 values") + + run( + args.mnkl, + args.ab_dtype, + args.c_dtype, + args.acc_dtype, + args.a_major, + args.b_major, + args.c_major, + args.mma_tiler_mn, + args.cluster_shape_mn, + args.use_2cta_instrs, + args.use_tma_store, + args.tolerance, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.use_cold_l2, + ) + print("PASS") diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/dense_gemm_software_pipeline.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/dense_gemm_software_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..8f5e172efaf7c7d13627a8fd7c7e5f05b8d90d90 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/dense_gemm_software_pipeline.py @@ -0,0 +1,1831 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +from typing import Optional, Type, Tuple, Union +import cuda.bindings.driver as cuda + +import torch + +import cutlass +import cutlass.cute as cute +import cutlass.utils as utils +import cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.torch as cutlass_torch +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cute.runtime import from_dlpack + +""" +A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Blackwell SM100 architecture +using CUTE DSL with compiler generated software pipeline. +- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M") +- Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K") +- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M") + +This GEMM kernel supports the following features: + - Utilizes Tensor Memory Access (TMA) for efficient memory operations + - Utilizes Blackwell's tcgen05.mma for matrix multiply-accumulate (MMA) operations (including 2cta mma instructions) + - Implements TMA multicast with cluster to reduce L2 memory traffic + - Supports multi-stage pipeline to overlap computation and memory access + +This GEMM works as follows: +1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations. +2. Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction. +3. Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld. +4. Type convert C matrix to output type. +5. Optionally store C matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations, + or directly store C matrix from registers (RMEM) to global memory (GMEM) without TMA operations. +6. Optionally accept an elementwise lambda function epilogue_op to apply to the output tensor: + e.g., relu can set epilogue_op = lambda x: cute.where(x > 0, x, cute.full_like(x, 0)) + +SM100 tcgen05.mma instructions operate as follows: +- Read matrix A from SMEM +- Read matrix B from SMEM +- Write accumulator to TMEM +The accumulator in TMEM must then be loaded to registers before writing back to GMEM. + +To run this example: + +.. code-block:: bash + + python examples/blackwell/dense_gemm_software_pipeline.py \ + --ab_dtype Float16 --c_dtype Float16 --acc_dtype Float32 \ + --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \ + --mnkl 8192,8192,8192,1 \ + --use_tma_store --use_2cta_instrs + +The above example command compute batched gemm with M=8192, N=8192, K=8192, +batch_count=1. The Blackwell tcgen05 MMA tile shape used 2 cta with 256x128 +MMA tile and the cluster shape is (2,1). The input, mma accumulator and output +data type are set as fp16, fp32 and fp16, respectively. + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/blackwell/dense_gemm_software_pipeline.py \ + --ab_dtype Float16 --c_dtype Float16 --acc_dtype Float32 \ + --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \ + --mnkl 8192,8192,8192,1 \ + --use_tma_store --use_2cta_instrs + +Constraints: +* Supported input data types: fp16, bf16, tf32, int8, uint8, fp8 (e4m3fn, e5m2), + see detailed valid dtype combinations in below DenseGemmKernel class documentation +* A/B tensor must have the same data type +* Mma tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True) +* Mma tiler N must be 32-256, step 32 +* Cluster shape M/N must be positive and power of 2, total cluster size <= 16 +* Cluster shape M must be multiple of 2 if use_2cta_instrs=True +* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned, + i.e, number of elements is a multiple of 4, 8, and 16 for TFloat32, + Float16/BFloat16, and Int8/Uint8/Float8, respectively. +* OOB tiles are not allowed when TMA store is disabled +""" + + +class DenseGemmKernel: + """ + This class implements batched matrix multiplication (C = A x B) with support for various data types + and architectural features specific to Blackwell GPUs. + + :param acc_dtype: Data type for accumulation during computation + :type acc_dtype: type[cutlass.Numeric] + :param use_2cta_instrs: Whether to use CTA group 2 for advanced thread cooperation + :type use_2cta_instrs: bool + :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tiler (M,N) + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing + :type cluster_shape_mn: Tuple[int, int] + :param use_tma_store: Whether to use Tensor Memory Access (TMA) for storing results + :type use_tma_store: bool + + :note: In current version, A and B tensor must have the same data type + - i.e., Float8E4M3FN for A and Float8E5M2 for B is not supported + + :note: Supported A/B data types: + - TFloat32 + - Float16/BFloat16 + - Int8/Uint8 + - Float8E4M3FN/Float8E5M2 + + :note: Supported accumulator data types: + - Float32 (for all floating point A/B data types) + - Float16 (only for fp16 and fp8 A/B data types) + - Int32 (only for uint8/int8 A/B data types) + + :note: Supported C data types: + - Float32 (for float32 and int32 accumulator data types) + - Int32 (for float32 and int32 accumulator data types) + - Float16/BFloat16 (for fp16 and fp8 accumulator data types) + - Int8/Uint8 (for uint8/int8 accumulator data types) + - Float8E4M3FN/Float8E5M2 (for float32 accumulator data types) + + :note: Constraints: + - MMA tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True) + - MMA tiler N must be 32-256, step 32 + - Cluster shape M must be multiple of 2 if use_2cta_instrs=True + - Cluster shape M/N must be positive and power of 2, total cluster size <= 16 + + Example: + >>> gemm = DenseGemmKernel( + ... acc_dtype=cutlass.Float32, + ... use_2cta_instrs=True, + ... mma_tiler_mn=(128, 128), + ... cluster_shape_mn=(2, 2) + ... ) + >>> gemm(a_tensor, b_tensor, c_tensor, stream) + """ + + def __init__( + self, + acc_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + use_tma_store: bool, + ): + """Initializes the configuration for a Blackwell dense GEMM kernel. + + This configuration includes several key aspects: + + 1. MMA Instruction Settings (tcgen05): + - acc_dtype: Data types for MMA accumulator. + - mma_tiler_mn: The (M, N) shape of the MMA instruction tiler. + - use_2cta_instrs: Boolean indicating if the tcgen05 MMA variant + with cta_group=2 should be used. + + 2. Cluster Shape: + - cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster. + + 3. Output C tensor store mode: + - use_tma_store: Boolean indicating whether to use Tensor Memory Access (TMA) for storing results. + + :param acc_dtype: Data type of the accumulator. + :type acc_dtype: type[cutlass.Numeric] + :param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction. + :type mma_tiler_mn: Tuple[int, int] + :param use_2cta_instrs: Boolean, True to use cta_group=2 MMA variant. + :type use_2cta_instrs: bool + :param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster. + :type cluster_shape_mn: Tuple[int, int] + :param use_tma_store: Use Tensor Memory Access (TMA) or normal store for output C tensor. + :type use_tma_store: bool + """ + + self.acc_dtype: Type[cutlass.Numeric] = acc_dtype + self.use_2cta_instrs = use_2cta_instrs + self.cluster_shape_mn = cluster_shape_mn + # K dimension is deferred in _setup_attributes + self.mma_tiler = (*mma_tiler_mn, 1) + self.use_tma_store = use_tma_store + + self.cta_group = ( + tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE + ) + + self.occupancy = 1 + self.threads_per_cta = 128 + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + + def _setup_attributes(self): + """Set up configurations that are dependent on GEMM inputs + + This method configures various attributes based on the input tensor properties + (data types, leading dimensions) and kernel settings: + - Configuring tiled MMA + - Computing MMA/cluster/tile shapes + - Computing cluster layout + - Computing multicast CTAs for A/B + - Computing epilogue subtile + - Setting up A/B/C stage counts in shared memory + - Computing A/B/C shared memory layout + - Computing tensor memory allocation columns + """ + # Configure tiled mma + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + + # Compute mma/cluster/tile shapes + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + mma_inst_tile_k = 4 + self.mma_tiler = ( + self.mma_tiler[0], + self.mma_tiler[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + + # Compute cluster layout + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + + # Compute number of multicast CTAs for A/B + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + + # Compute epilogue subtile + if cutlass.const_expr(self.use_tma_store): + self.epi_tile = sm100_utils.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, + self.use_2cta_instrs, + self.c_layout, + self.c_dtype, + ) + else: + self.epi_tile = self.cta_tile_shape_mnk[:2] + + # Setup A/B/C stage count in shared memory + self.num_acc_stage, self.num_ab_stage, self.num_c_stage = self._compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.b_dtype, + self.epi_tile, + self.c_dtype, + self.c_layout, + self.smem_capacity, + self.occupancy, + self.use_tma_store, + ) + + # Compute A/B/C shared memory layout + self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.num_ab_stage, + ) + self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + self.mma_tiler, + self.b_dtype, + self.num_ab_stage, + ) + self.c_smem_layout_staged = ( + sm100_utils.make_smem_layout_epi( + self.c_dtype, + self.c_layout, + self.epi_tile, + self.num_c_stage, + ) + if self.use_tma_store + else None + ) + + # Compute the number of tensor memory allocation columns + self.num_tmem_alloc_cols = self._compute_num_tmem_alloc_cols( + tiled_mma, self.mma_tiler + ) + + @cute.jit + def __call__( + self, + a: cute.Tensor, + b: cute.Tensor, + c: cute.Tensor, + stream: cuda.CUstream, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + """Execute the GEMM operation in steps: + - Setup static attributes + - Setup TMA load/store atoms and tensors + - Compute grid size + - Define shared storage for kernel + - Launch the kernel synchronously + + :param a: Input tensor A + :type a: cute.Tensor + :param b: Input tensor B + :type b: cute.Tensor + :param c: Output tensor C + :type c: cute.Tensor + :param stream: CUDA stream for asynchronous execution + :type stream: cuda.CUstream + :param epilogue_op: Optional elementwise lambda function to apply to the output tensor + :type epilogue_op: cutlass.Constexpr + :raises TypeError: If input data types are incompatible with the MMA instruction. + :raises AssertionError: If OOB (Out-Of-Bounds) tiles are present when TMA store is disabled. + """ + # Setup static attributes before smem/grid/tma computation + self.a_dtype: Type[cutlass.Numeric] = a.element_type + self.b_dtype: Type[cutlass.Numeric] = b.element_type + self.c_dtype: Type[cutlass.Numeric] = c.element_type + self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode() + self.c_layout = utils.LayoutEnum.from_tensor(c) + + # Check if input data types are compatible with MMA instruction + if cutlass.const_expr(self.a_dtype != self.b_dtype): + raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}") + + # Setup attributes that dependent on gemm inputs + self._setup_attributes() + + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # Setup TMA load for A + a_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + a, + a_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=( + cutlass.TFloat32 if a.element_type is cutlass.Float32 else None + ), + ) + + # Setup TMA load for B + b_op = sm100_utils.cluster_shape_to_tma_atom_B( + self.cluster_shape_mn, tiled_mma.thr_id + ) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=( + cutlass.TFloat32 if b.element_type is cutlass.Float32 else None + ), + ) + + a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + self.num_tma_load_bytes = (a_copy_size + b_copy_size) * atom_thr_size + + # Setup store for C + tma_atom_c = None + tma_tensor_c = None + if cutlass.const_expr(self.use_tma_store): + c_cta_v_layout = cute.composition( + cute.make_identity_layout(c.shape), self.epi_tile + ) + epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + c, + epi_smem_layout, + c_cta_v_layout, + ) + + # Compute grid size + grid = self._compute_grid(c, self.cta_tile_shape_mnk, self.cluster_shape_mn) + + self.buffer_align_bytes = 1024 + + c_smem_size = ( + cute.cosize(self.c_smem_layout_staged.outer) if self.use_tma_store else 0 + ) + + # Define shared storage for kernel + @cute.struct + class SharedStorage: + ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC: cute.struct.Align[ + cute.struct.MemRange[ + self.c_dtype, + c_smem_size, + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sA: cute.struct.Align[ + cute.struct.MemRange[ + self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sB: cute.struct.Align[ + cute.struct.MemRange[ + self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + tiled_mma, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_c, + tma_tensor_c if self.use_tma_store else c, + self.cluster_layout_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.c_smem_layout_staged, + self.epi_tile, + epilogue_op, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + stream=stream, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_c: Optional[cute.CopyAtom], + mC_mnl: cute.Tensor, + cluster_layout_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None], + epi_tile: cute.Tile, + epilogue_op: cutlass.Constexpr, + ): + """ + GPU device kernel performing the batched GEMM computation. + """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + # + # Prefetch tma descriptor + # + if warp_idx == 0: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + if cutlass.const_expr(self.use_tma_store): + cpasync.prefetch_descriptor(tma_atom_c) + + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + # + # Setup cta/thread coordinates + # + # Coords inside cluster + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + # Coords outside cluster + cta_coord = (bidx, bidy, bidz) + mma_tile_coord_mnl = ( + cta_coord[0] // cute.size(tiled_mma.thr_id.shape), + cta_coord[1], + cta_coord[2], + ) + # Coords inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Alloc and init: a+b full/empty, accumulator full, tensor memory dealloc barrier + # + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr + tmem_holding_buf = storage.tmem_holding_buf + + # Initialize mainloop ab_pipeline (barrier) and states + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + ab_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_tma_producer + ) + ab_pipeline = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_full_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + ) + ab_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_ab_stage + ) + ab_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_ab_stage + ) + + # Initialize acc_pipeline (barrier) and states + acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + acc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, self.threads_per_cta, self.threads_per_cta + ) + acc_pipeline = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_full_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=acc_pipeline_producer_group, + consumer_group=acc_pipeline_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + ) + acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_acc_stage + ) + acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage + ) + + # Tensor memory dealloc barrier init + if use_2cta_instrs: + if warp_idx == 0: + num_tmem_dealloc_threads = 32 + with cute.arch.elect_one(): + cute.arch.mbarrier_init( + tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads + ) + cute.arch.mbarrier_init_fence() + + # Cluster arrive after barrier init + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_arrive_relaxed() + + # + # Setup smem tensor A/B/C + # + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC = ( + storage.sC.get_tensor( + c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner + ) + if self.use_tma_store + else None + ) + # (MMA, MMA_M, MMA_K, STAGE) + sA = storage.sA.get_tensor( + a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner + ) + # (MMA, MMA_N, MMA_K, STAGE) + sB = storage.sB.get_tensor( + b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner + ) + + # + # Compute multicast mask for A/B buffer full + # + a_full_mcast_mask = None + b_full_mcast_mask = None + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): + a_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + b_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + + # + # Local_tile partition global tensors + # + # (bM, bK, RestM, RestK, RestL) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + # (bM, bN, RestM, RestN, RestL) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) + ) + k_tile_cnt = cute.size(gA_mkl, mode=[3]) + + # + # Partition global tensor for TiledMMA_A/B/C + # + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgA = thr_mma.partition_A(gA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgB = thr_mma.partition_B(gB_nkl) + # (MMA, MMA_M, MMA_N, RestM, RestN, RestL) + tCgC = thr_mma.partition_C(gC_mnl) + + # + # Partition global/shared tensor for TMA load A/B + # + # TMA load A partition_S/D + a_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA load B partition_S/D + b_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/C + # + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + # (MMA, MMA_M, MMA_N) + tCtAcc_fake = tiled_mma.make_fragment_C(acc_shape) + + # + # Cluster wait before tensor memory alloc + # + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_wait() + + # + # Alloc tensor memory buffer + # + if warp_idx == 0: + cute.arch.alloc_tmem( + self.num_tmem_alloc_cols, tmem_holding_buf, is_two_cta=use_2cta_instrs + ) + + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + cute.arch.barrier() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, alignment=16, ptr_to_buffer_holding_addr=tmem_holding_buf + ) + # (MMA, MMA_M, MMA_N) + tCtAcc = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + # + # Partition for epilogue + # + tiled_copy_t2r, tTR_tAcc, tTR_rAcc = self.epilog_tmem_copy_and_partition( + tidx, tCtAcc, tCgC, epi_tile, use_2cta_instrs + ) + + tTR_rC = None + tiled_copy_r2s = None + simt_atom = None + tRS_rC = None + tRS_sC = None + bSG_sC = None + bSG_gC = None + tTR_gC = None + if cutlass.const_expr(self.use_tma_store): + tTR_rC = cute.make_fragment(tTR_rAcc.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( + tiled_copy_t2r, tTR_rC, tidx, sC + ) + tma_atom_c, bSG_sC, bSG_gC = self.epilog_gmem_copy_and_partition( + tidx, tma_atom_c, tCgC, epi_tile, sC + ) + else: + simt_atom, tTR_rC, tTR_gC = self.epilog_gmem_copy_and_partition( + tidx, tiled_copy_t2r, tCgC, epi_tile, sC + ) + + # + # Slice to per mma tile index + # + # ((atom_v, rest_v), RestK) + tAgA = tAgA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] + # ((atom_v, rest_v), RestK) + tBgB = tBgB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] + if cutlass.const_expr(self.use_tma_store): + # ((ATOM_V, REST_V), EPI_M, EPI_N) + bSG_gC = bSG_gC[(None, None, None, *mma_tile_coord_mnl)] + else: + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + tTR_gC = tTR_gC[(None, None, None, None, None, *mma_tile_coord_mnl)] + + # /////////////////////////////////////////////////////////////////////////////// + # MAINLOOP + # /////////////////////////////////////////////////////////////////////////////// + prefetch_k_tile_cnt = cutlass.min(self.num_ab_stage - 2, k_tile_cnt) + if warp_idx == 0: + for k_tile in cutlass.range( + k_tile_cnt, + prefetch_stages=self.num_ab_stage - 2, + ): + # wait for AB buffer empty + ab_pipeline.producer_acquire(ab_producer_state) + + # TMA load A/B + cute.copy( + tma_atom_a, + tAgA[(None, ab_producer_state.count)], + tAsA[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=a_full_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB[(None, ab_producer_state.count)], + tBsB[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=b_full_mcast_mask, + ) + + if is_leader_cta: + # Wait for AB buffer full + ab_pipeline.consumer_wait(ab_consumer_state) + + # tCtAcc += tCrA * tCrB + num_kblocks = cute.size(tCrA, mode=[2]) + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = (None, None, kblock_idx, ab_consumer_state.index) + + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[kblock_coord], + tCrB[kblock_coord], + tCtAcc, + ) + # Enable accumulate on tCtAcc after first kblock + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Async arrive AB buffer empty + ab_pipeline.consumer_release(ab_consumer_state) + + ab_producer_state.advance() + ab_consumer_state.advance() + + # Async arrive accumulator buffer full + if is_leader_cta: + acc_pipeline.producer_commit(acc_producer_state) + + # + # Epilogue + # + + # Release tensor memory allocation lock + if warp_idx == 0: + cute.arch.relinquish_tmem_alloc_permit(is_two_cta=use_2cta_instrs) + + # Wait for accumulator buffer full + acc_pipeline.consumer_wait(acc_consumer_state) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + if cutlass.const_expr(self.use_tma_store): + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + else: + tTR_gC = cute.group_modes(tTR_gC, 3, cute.rank(tTR_gC)) + + c_pipeline = None + if cutlass.const_expr(self.use_tma_store): + # Initialize tma store c_pipeline + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, self.threads_per_cta, self.threads_per_cta + ) + c_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_c_stage, + producer_group=c_producer_group, + ) + + # + # Store accumulator to global memory in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + for subtile_idx in range(subtile_cnt): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + if cutlass.const_expr(self.use_tma_store): + # + # Perform epilogue op on accumulator and convert to C type + # + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) + tRS_rC.store(acc_vec) + + # + # Store C to shared memory + # + c_buffer = subtile_idx % self.num_c_stage + cute.copy(tiled_copy_r2s, tRS_rC, tRS_sC[(None, None, None, c_buffer)]) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + cute.arch.barrier() + + # + # TMA store C to global memory + # + if warp_idx == 0: + cute.copy( + tma_atom_c, + bSG_sC[(None, c_buffer)], + bSG_gC[(None, subtile_idx)], + ) + # Fence and barrier to make sure TMA store is completed to recollect C buffer + c_pipeline.producer_commit() + c_pipeline.producer_acquire() + cute.arch.barrier() + else: + # + # Perform epilogue op on accumulator and convert to C type + # + acc_vec = tTR_rAcc.load() + acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) + tTR_rC.store(acc_vec) + + # + # Store C to global memory + # + cute.copy(simt_atom, tTR_rC, tTR_gC[(None, None, None, subtile_idx)]) + + # + # Dealloc the tensor memory buffer + # + cute.arch.barrier() + if warp_idx == 0: + if use_2cta_instrs: + cute.arch.mbarrier_arrive( + tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1 + ) + cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0) + cute.arch.dealloc_tmem( + tmem_ptr, self.num_tmem_alloc_cols, is_two_cta=use_2cta_instrs + ) + + # + # Wait for C store complete + # + if cutlass.const_expr(self.use_tma_store): + c_pipeline.producer_tail() + + # + # Wait A/B buffer empty + # + if warp_idx == 0: + # Reverse prefetch_k_tile_cnt times to next available buffer + for i in range(prefetch_k_tile_cnt): + ab_producer_state.reverse() + ab_pipeline.producer_tail(ab_producer_state) + return + + def epilog_tmem_copy_and_partition( + self, + tidx: cutlass.Int32, + tAcc: cute.Tensor, + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + use_2cta_instrs: Union[cutlass.Boolean, bool], + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination). + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param tAcc: The accumulator tensor to be copied and partitioned + :type tAcc: cute.Tensor + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param use_2cta_instrs: Whether use_2cta_instrs is enabled + :type use_2cta_instrs: bool + + :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where: + - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + - tTR_tAcc: The partitioned accumulator tensor + - tTR_rAcc: The accumulated tensor in register used to hold t2r results + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # Make tiledCopy for tensor memory load + copy_atom_t2r = sm100_utils.get_tmem_load_op( + self.cta_tile_shape_mnk, + self.c_layout, + self.c_dtype, + self.acc_dtype, + epi_tile, + use_2cta_instrs, + ) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N) + tAcc_epi = cute.flat_divide( + tAcc[((None, None), 0, 0)], + epi_tile, + ) + # (EPI_TILE_M, EPI_TILE_N) + tiled_copy_t2r = tcgen05.make_tmem_copy( + copy_atom_t2r, tAcc_epi[(None, None, 0, 0)] + ) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + gC_mnl_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) + tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) + # (T2R, T2R_M, T2R_N) + tTR_rAcc = cute.make_fragment( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype + ) + return tiled_copy_t2r, tTR_tAcc, tTR_rAcc + + def epilog_smem_copy_and_partition( + self, + tiled_copy_t2r: cute.TiledCopy, + tTR_rC: cute.Tensor, + tidx: cutlass.Int32, + sC: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination). + + :param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + :type tiled_copy_t2r: cute.TiledCopy + :param tTR_rC: The partitioned accumulator tensor + :type tTR_rC: cute.Tensor + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + + :return: A tuple containing (tiled_copy_r2s, tRS_rC, tRS_sC) where: + - tiled_copy_r2s: The tiled copy operation for register to smem copy(r2s) + - tRS_rC: The partitioned tensor C (register source) + - tRS_sC: The partitioned tensor C (smem destination) + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + copy_atom_r2s = sm100_utils.get_smem_store_op( + self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r + ) + tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) + # (R2S, R2S_M, R2S_N, PIPE_D) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sC = thr_copy_r2s.partition_D(sC) + # (R2S, R2S_M, R2S_N) + tRS_rC = tiled_copy_r2s.retile(tTR_rC) + return tiled_copy_r2s, tRS_rC, tRS_sC + + def epilog_gmem_copy_and_partition( + self, + tidx: cutlass.Int32, + atom: Union[cute.CopyAtom, cute.TiledCopy], + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + sC: cute.Tensor, + ) -> Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]: + """Make tiledCopy for global memory store, then use it to: + - partition register array (source) and global memory (destination) for none TMA store version; + - partition shared memory (source) and global memory (destination) for TMA store version. + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param atom: The copy_atom_c to be used for TMA store version, or tiled_copy_t2r for none TMA store version + :type atom: cute.CopyAtom or cute.TiledCopy + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + + :return: A tuple containing either: + - For TMA store: (tma_atom_c, bSG_sC, bSG_gC) where: + - tma_atom_c: The TMA copy atom + - bSG_sC: The partitioned shared memory tensor C + - bSG_gC: The partitioned global tensor C + - For non-TMA store: (simt_atom, tTR_rC, tTR_gC) where: + - simt_atom: The SIMT copy atom + - tTR_rC: The register tensor C + - tTR_gC: The partitioned global tensor C + :rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] + """ + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + gC_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + if cutlass.const_expr(self.use_tma_store): + tma_atom_c = atom + sC_for_tma_partition = cute.group_modes(sC, 0, 2) + gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2) + # ((ATOM_V, REST_V), EPI_M, EPI_N) + # ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL) + bSG_sC, bSG_gC = cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + sC_for_tma_partition, + gC_for_tma_partition, + ) + return tma_atom_c, bSG_sC, bSG_gC + else: + tiled_copy_t2r = atom + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + tTR_gC = thr_copy_t2r.partition_D(gC_epi) + # (T2R, T2R_M, T2R_N) + tTR_rC = cute.make_fragment( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.c_dtype + ) + simt_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), self.c_dtype) + return simt_atom, tTR_rC, tTR_gC + + @staticmethod + def _compute_stages( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: Tuple[int, int, int], + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + epi_tile: cute.Tile, + c_dtype: Type[cutlass.Numeric], + c_layout: utils.LayoutEnum, + smem_capacity: int, + occupancy: int, + use_tma_store: bool, + ) -> Tuple[int, int, int]: + """Computes the number of stages for A/B/C operands based on heuristics. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The shape (M, N, K) of the MMA tile. + :type mma_tiler_mnk: tuple[int, int, int] + :param a_dtype: Data type of operand A. + :type a_dtype: type[cutlass.Numeric] + :param b_dtype: Data type of operand B. + :type b_dtype: type[cutlass.Numeric] + :param epi_tile: The epilogue tile shape. + :type epi_tile: cute.Tile + :param c_dtype: Data type of operand C (output). + :type c_dtype: type[cutlass.Numeric] + :param c_layout: Layout enum of operand C in global memory. + :type c_layout: utils.LayoutEnum + :param smem_capacity: Total available shared memory capacity in bytes. + :type smem_capacity: int + :param occupancy: Target number of CTAs per SM (occupancy). + :type occupancy: int + :param use_tma_store: Whether TMA store is enabled. + :type use_tma_store: bool + + :return: A tuple containing the computed number of stages for: + (ACC stages, A/B operand stages, epilogue stages) + :rtype: tuple[int, int, int] + """ + # Default ACC stages + num_acc_stage = 1 + # Default C stages + num_c_stage = 2 if use_tma_store else 0 + + # Calculate smem layout and size for one stage of A, B, and C + a_smem_layout_stage_one = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + a_dtype, + 1, # a tmp 1 stage is provided + ) + b_smem_layout_staged_one = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + b_dtype, + 1, # a tmp 1 stage is provided + ) + c_smem_layout_staged_one = ( + sm100_utils.make_smem_layout_epi( + c_dtype, + c_layout, + epi_tile, + 1, + ) + if use_tma_store + else None + ) + ab_bytes_per_stage = cute.size_in_bytes( + a_dtype, a_smem_layout_stage_one + ) + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) + mbar_helpers_bytes = 1024 + c_bytes_per_stage = ( + cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) + if use_tma_store + else 0 + ) + c_bytes = c_bytes_per_stage * num_c_stage + + # Calculate A/B stages: + # Start with total smem per CTA (capacity / occupancy) + # Subtract reserved bytes and initial C stages bytes + # Divide remaining by bytes needed per A/B stage + num_ab_stage = ( + smem_capacity - (occupancy + 1) * (mbar_helpers_bytes + c_bytes) + ) // ab_bytes_per_stage + + # Refine epilogue stages: + # Calculate remaining smem after allocating for A/B stages and reserved bytes + # Add remaining unused smem to epilogue + if use_tma_store: + num_c_stage += ( + smem_capacity + - ab_bytes_per_stage * num_ab_stage + - (occupancy + 1) * (mbar_helpers_bytes + c_bytes) + ) // ((occupancy + 1) * c_bytes_per_stage) + return num_acc_stage, num_ab_stage, num_c_stage + + @staticmethod + def _compute_grid( + c: cute.Tensor, + cta_tile_shape_mnk: Tuple[int, int, int], + cluster_shape_mn: Tuple[int, int], + ) -> Tuple[int, int, int]: + """Compute grid shape for the output tensor C. + + :param c: The output tensor C + :type c: cute.Tensor + :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type cta_tile_shape_mnk: tuple[int, int, int] + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. + :type cluster_shape_mn: tuple[int, int] + + :return: Grid shape for kernel launch. + :rtype: tuple[int, int, int] + """ + + cluster_shape_mnl = (*cluster_shape_mn, 1) + + grid = cute.round_up( + ( + cute.ceil_div(c.layout.shape[0], cta_tile_shape_mnk[0]), + cute.ceil_div(c.layout.shape[1], cta_tile_shape_mnk[1]), + c.layout.shape[2], + ), + cluster_shape_mnl, + ) + + return grid + + @staticmethod + def _compute_num_tmem_alloc_cols( + tiled_mma: cute.TiledMma, mma_tiler: Tuple[int, int, int] + ) -> int: + """ + Compute the number of tensor memory allocation columns. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler: The shape (M, N, K) of the MMA tile. + :type mma_tiler: tuple[int, int, int] + + :return: The number of tensor memory allocation columns. + :rtype: int + """ + acc_shape = tiled_mma.partition_shape_C(mma_tiler[:2]) + tCtAcc_fake = tiled_mma.make_fragment_C(acc_shape) + return sm100_utils.get_num_tmem_alloc_cols(tCtAcc_fake) + + @staticmethod + def is_valid_dtypes( + ab_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + ) -> bool: + """ + Check if the dtypes are valid + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + + :return: True if the dtypes are valid, False otherwise + :rtype: bool + """ + is_valid = True + if ab_dtype not in { + cutlass.Float16, + cutlass.BFloat16, + cutlass.TFloat32, + cutlass.Uint8, + cutlass.Int8, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: + is_valid = False + if ( + acc_dtype not in {cutlass.Float32, cutlass.Float16, cutlass.Int32} + or acc_dtype == cutlass.Float16 + and ab_dtype + not in {cutlass.Float16, cutlass.Float8E4M3FN, cutlass.Float8E5M2} + or acc_dtype == cutlass.Int32 + and ab_dtype not in {cutlass.Uint8, cutlass.Int8} + ): + is_valid = False + if ( + acc_dtype == cutlass.Float32 + and c_dtype + not in { + cutlass.Float32, + cutlass.Float16, + cutlass.BFloat16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + cutlass.Int32, + cutlass.Int8, + cutlass.Uint8, + } + or acc_dtype == cutlass.Float16 + and c_dtype + not in { + cutlass.BFloat16, + cutlass.Float16, + } + or acc_dtype == cutlass.Int32 + and c_dtype + not in { + cutlass.BFloat16, + cutlass.Float16, + cutlass.Float32, + cutlass.Int32, + cutlass.Int8, + cutlass.Uint8, + } + ): + is_valid = False + return is_valid + + @staticmethod + def is_valid_mma_tiler_and_cluster_shape( + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ) -> bool: + """ + Check if the mma tiler and cluster shape are valid + + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + + :return: True if the mma tiler and cluster shape are valid, False otherwise + :rtype: bool + """ + is_valid = True + # Skip invalid mma tile shape + if not ( + (not use_2cta_instrs and mma_tiler_mn[0] in [64, 128]) + or (use_2cta_instrs and mma_tiler_mn[0] in [128, 256]) + ): + is_valid = False + if mma_tiler_mn[1] not in range(32, 257, 32): + is_valid = False + # Skip illegal cluster shape + if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0: + is_valid = False + # Skip invalid cluster shape + is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0 + if ( + cluster_shape_mn[0] * cluster_shape_mn[1] > 16 + or cluster_shape_mn[0] <= 0 + or cluster_shape_mn[1] <= 0 + or not is_power_of_2(cluster_shape_mn[0]) + or not is_power_of_2(cluster_shape_mn[1]) + ): + is_valid = False + return is_valid + + @staticmethod + def is_valid_tensor_alignment( + m: int, + n: int, + k: int, + l: int, + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the tensor alignment is valid + + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the problem shape is valid, False otherwise + :rtype: bool + """ + is_valid = True + + def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape): + major_mode_idx = 0 if is_mode0_major else 1 + num_major_elements = tensor_shape[major_mode_idx] + num_contiguous_elements = 16 * 8 // dtype.width + return num_major_elements % num_contiguous_elements == 0 + + if ( + not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, l)) + or not check_contigous_16B_alignment(ab_dtype, b_major == "n", (n, k, l)) + or not check_contigous_16B_alignment(c_dtype, c_major == "m", (m, n, l)) + ): + is_valid = False + return is_valid + + @staticmethod + def is_valid_epilog_store_option( + use_2cta_instrs: bool, + use_tma_store: bool, + m: int, + n: int, + mma_tiler_mn: Tuple[int, int], + ) -> bool: + """ + Check if the epilogue store option is valid + + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param use_tma_store: Whether to use TMA store + :type use_tma_store: bool + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + + :return: True if the epilogue store option is valid, False otherwise + :rtype: bool + """ + + is_valid = True + # None TMA store version does not have predication, can not support OOB tiles + cta_tile_shape_mn = ( + mma_tiler_mn[0] // (2 if use_2cta_instrs else 1), + mma_tiler_mn[1], + ) + if not use_tma_store: + if not (m % cta_tile_shape_mn[0] == 0 and n % cta_tile_shape_mn[1] == 0): + is_valid = False + return is_valid + + @staticmethod + def can_implement( + ab_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + use_tma_store: bool, + m: int, + n: int, + k: int, + l: int, + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the gemm can be implemented + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + :param use_tma_store: Whether to use TMA store + :type use_tma_store: bool + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the gemm can be implemented, False otherwise + :rtype: bool + """ + can_implement = True + # Skip unsupported types + if not DenseGemmKernel.is_valid_dtypes(ab_dtype, acc_dtype, c_dtype): + can_implement = False + # Skip invalid mma tile shape and cluster shape + if not DenseGemmKernel.is_valid_mma_tiler_and_cluster_shape( + use_2cta_instrs, mma_tiler_mn, cluster_shape_mn + ): + can_implement = False + # Skip illegal problem shape for load/store alignment + if not DenseGemmKernel.is_valid_tensor_alignment( + m, n, k, l, ab_dtype, c_dtype, a_major, b_major, c_major + ): + can_implement = False + # Skip invalid epilogue store option + if not DenseGemmKernel.is_valid_epilog_store_option( + use_2cta_instrs, use_tma_store, m, n, mma_tiler_mn + ): + can_implement = False + return can_implement + + +def run_dense_gemm( + mnkl: Tuple[int, int, int, int], + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + use_2cta_instrs: bool, + use_tma_store: bool, + tolerance: float, + warmup_iterations: int = 0, + iterations: int = 1, + skip_ref_check: bool = False, +): + """ + Prepare A/B/C tensors, launch GPU kernel, and reference checking. + """ + print(f"Running B100 software pipeline Dense GEMM test with:") + print(f"mnkl: {mnkl}") + print(f"AB dtype: {ab_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}") + print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}") + print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}") + print(f"2CTA MMA instructions: {'True' if use_2cta_instrs else 'False'}") + print(f"Use TMA Store: {'True' if use_tma_store else 'False'}") + print(f"Tolerance: {tolerance}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + + # Unpack parameters + m, n, k, l = mnkl + + # Skip unsupported testcase + if not DenseGemmKernel.can_implement( + ab_dtype, + acc_dtype, + c_dtype, + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + use_tma_store, + m, + n, + k, + l, + a_major, + b_major, + c_major, + ): + raise TypeError( + f"Unsupported testcase {ab_dtype}, {acc_dtype}, {c_dtype}, {use_2cta_instrs}, {mma_tiler_mn}, {cluster_shape_mn}, {use_tma_store}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {c_major}" + ) + + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + torch.manual_seed(1111) + + # Create and permute tensor A/B/C + def create_and_permute_tensor( + l, mode0, mode1, is_mode0_major, dtype, is_dynamic_layout=True + ): + # is_mode0_major: (l, mode1, mode0) -> (mode0, mode1, l) + # else: (l, mode0, mode1) -> (mode0, mode1, l) + shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1) + permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0) + is_unsigned = dtype in {cutlass.Uint8} + # Temporarily use uint8 as torch does not support fp8 type + torch_dtype = ( + cutlass_torch.dtype(dtype) + if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN} + else torch.uint8 + ) + + # Create dtype torch tensor (cpu) + torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( + shape, + torch_dtype, + permute_order=permute_order, + init_type=cutlass_torch.TensorInitType.RANDOM, + init_config=cutlass_torch.RandomInitConfig( + min_val=0 if is_unsigned else -2, max_val=4 if is_unsigned else 2 + ), + ) + # Create dtype torch tensor (gpu) + torch_tensor = torch_tensor_cpu.cuda() + + # Create f32 torch tensor (cpu) + f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32) + + # Create dtype cute tensor (gpu) + cute_tensor = from_dlpack(torch_tensor, assumed_align=16) + cute_tensor.element_type = dtype + if is_dynamic_layout: + cute_tensor = cute_tensor.mark_layout_dynamic( + leading_dim=(0 if is_mode0_major else 1) + ) + cute_tensor = cutlass_torch.convert_cute_tensor( + f32_torch_tensor, + cute_tensor, + dtype, + is_dynamic_layout=is_dynamic_layout, + ) + + return f32_torch_tensor, cute_tensor, torch_tensor + + a_ref, a_tensor, a_torch = create_and_permute_tensor( + l, m, k, a_major == "m", ab_dtype, is_dynamic_layout=True + ) + b_ref, b_tensor, b_torch = create_and_permute_tensor( + l, n, k, b_major == "n", ab_dtype, is_dynamic_layout=True + ) + c_ref, c_tensor, c_torch = create_and_permute_tensor( + l, m, n, c_major == "m", c_dtype, is_dynamic_layout=True + ) + + # Configure gemm kernel + gemm = DenseGemmKernel( + acc_dtype, + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + use_tma_store, + ) + + torch_stream = torch.cuda.Stream() + stream = cuda.CUstream(torch_stream.cuda_stream) + # Compile gemm kernel + compiled_gemm = cute.compile(gemm, a_tensor, b_tensor, c_tensor, stream) + + # Launch GPU kernel + # Warm up + for i in range(warmup_iterations): + compiled_gemm(a_tensor, b_tensor, c_tensor, stream) + # Execution + for i in range(iterations): + compiled_gemm(a_tensor, b_tensor, c_tensor, stream) + + # Compute reference result + if not skip_ref_check: + if ab_dtype in { + cutlass.Int8, + cutlass.Uint8, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: + ref = torch.einsum("mkl,nkl->mnl", a_ref.cpu(), b_ref.cpu()) + else: + ref = (torch.einsum("mkl,nkl->mnl", a_ref, b_ref)).cpu() + + # Copy gpu result back + gpu_c = c_torch.cpu() + + # Convert ref to c_type + if c_dtype == cutlass.Float32: + ref_c = ref + elif c_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}: + # m major: (l, n, m) -> (m, n, l) + # n major: (l, m, n) -> (m, n, l) + permute_order = (1, 2, 0) if c_major == "n" else (2, 1, 0) + shape = (l, m, n) if c_major == "n" else (l, n, m) + f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor( + shape, + torch.uint8, + permute_order=permute_order, + init_type=cutlass_torch.TensorInitType.SKIP, + ).cuda() + # Create dtype cute tensor (gpu) + ref_c_tensor = from_dlpack( + f8_torch_tensor, assumed_align=16 + ).mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0)) + ref_c_tensor.element_type = c_dtype + ref_c_tensor = cutlass_torch.convert_cute_tensor( + ref, + ref_c_tensor, + c_dtype, + is_dynamic_layout=True, + ) + + ref_c = f8_torch_tensor.cpu() + else: + ref_c = ref.to(cutlass_torch.dtype(c_dtype)) + + # Reference checking ref_c and gpu_c + torch.testing.assert_close( + gpu_c, + ref_c, + atol=tolerance, + rtol=1e-05, + ) + + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: + try: + return tuple(int(x.strip()) for x in s.split(",")) + # or: return tuple([int(x.strip()) for x in s.split(",")]) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + parser = argparse.ArgumentParser( + description="Example of MxNxKxL GEMM on Blackwell." + ) + + parser.add_argument( + "--mnkl", + type=parse_comma_separated_ints, + default=(256, 256, 512, 1), + help="mnkl dimensions (comma-separated)", + ) + parser.add_argument( + "--mma_tiler_mn", + type=parse_comma_separated_ints, + default=(128, 128), + help="Mma tiler (comma-separated)", + ) + parser.add_argument( + "--cluster_shape_mn", + type=parse_comma_separated_ints, + default=(1, 1), + help="Cluster shape (comma-separated)", + ) + parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.TFloat32) + parser.add_argument("--c_dtype", type=cutlass.dtype, default=cutlass.Float32) + parser.add_argument("--acc_dtype", type=cutlass.dtype, default=cutlass.Float32) + parser.add_argument( + "--use_2cta_instrs", + action="store_true", + help="Enable 2CTA MMA instructions feature", + ) + parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k") + parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k") + parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n") + parser.add_argument( + "--use_tma_store", action="store_true", help="Use tma store or not" + ) + parser.add_argument( + "--tolerance", type=float, default=1e-01, help="Tolerance for validation" + ) + parser.add_argument( + "--warmup_iterations", type=int, default=0, help="Warmup iterations" + ) + parser.add_argument("--iterations", type=int, default=1, help="Iterations") + parser.add_argument( + "--skip_ref_check", action="store_true", help="Skip reference checking" + ) + + args = parser.parse_args() + + if len(args.mnkl) != 4: + parser.error("--mnkl must contain exactly 4 values") + + if len(args.mma_tiler_mn) != 2: + parser.error("--mma_tiler_mn must contain exactly 2 values") + + if len(args.cluster_shape_mn) != 2: + parser.error("--cluster_shape_mn must contain exactly 2 values") + + run_dense_gemm( + args.mnkl, + args.ab_dtype, + args.c_dtype, + args.acc_dtype, + args.a_major, + args.b_major, + args.c_major, + args.mma_tiler_mn, + args.cluster_shape_mn, + args.use_2cta_instrs, + args.use_tma_store, + args.tolerance, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + ) + print("PASS") diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/fmha.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/fmha.py new file mode 100644 index 0000000000000000000000000000000000000000..560129dfdcd4d065710b4d0bb08e638df3d51acc --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/fmha.py @@ -0,0 +1,3100 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +import enum +import math +import time +from typing import Type, Tuple + +import torch +import torch.nn.functional as F +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +import cutlass.cute.nvgpu.tcgen05 as tcgen05 +import cutlass.utils as utils +import cutlass.pipeline as pipeline +import cutlass.torch as cutlass_torch +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.cute.testing as testing +from cutlass.cute.runtime import from_dlpack +from cutlass.cute.typing import Int32, Int64, Float32, Boolean + +""" +A fused multi-head attention (FMHA) example for the NVIDIA Blackwell SM100 architecture using CUTE DSL + +This example demonstrates an implementation of fused multi-head attention using a TMA + Blackwell SM100 +TensorCore warp-specialized persistent kernel. The implementation integrates the Q*K^T matrix multiplication, +softmax normalization, and softmax(Q*K^T)*V into a single kernel, avoiding intermediate data movement between +global memory and shared memory, thus improving computational efficiency. + +The kernel implements key optimizations including: +- Warp specialization for different computation phases (load, MMA, softmax, correction, epilogue) +- Pipeline stages between different warps for overlapping computation and memory access +- Support for different precision data types +- Optional causal masking for autoregressive models + +To run this example: + +.. code-block:: bash + + python examples/blackwell/fmha.py \ + --qk_acc_dtype Float32 --pv_acc_dtype Float32 \ + --mma_tiler_mn 128,128 \ + --q_shape 4,1024,8,64 --k_shape 4,1024,8,64 \ + --is_persistent + +The above example runs FMHA with batch size 4, sequence length 1024, 8 attention heads, and head +dimension 64. The Blackwell tcgen05 MMA tile shape is (128, 128), and the kernel uses fp16 for input/output +with fp32 for accumulation. + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/blackwell/fmha.py \ + --qk_acc_dtype Float32 --pv_acc_dtype Float32 \ + --mma_tiler_mn 128,128 \ + --q_shape 4,1024,8,64 --k_shape 4,1024,8,64 \ + --is_persistent --warmup_iterations 10 \ + --iterations 10 --skip_ref_check + +Constraints for this example: +* Supported head dimensions: 32, 64, and 128 +* Number of heads in Q must be divisible by number of heads in K +* mma_tiler_mn must be 128,128 +* Batch size must be the same for Q, K, and V tensors +* For causal masking, use --is_causal (note: specify without =True/False) +* For persistent scheduling, use --is_persistent (note: specify without =True/False) +""" + +class FmhaStaticTileSchedulerParams: + def __init__( + self, + is_persistent: bool, + problem_shape_mbh: cute.Shape, + *, + loc=None, + ip=None, + ): + self.is_persistent = is_persistent + self.problem_shape_mbh = problem_shape_mbh + self._loc = loc + self._ip = ip + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.is_persistent, self.problem_shape_mbh]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip( + [self.is_persistent, self.problem_shape_mbh], self._values_pos + ): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return FmhaStaticTileSchedulerParams(*(tuple(obj_list)), loc=self._loc) + + +def create_fmha_static_tile_scheduler_params( + is_persistent: bool, + problem_shape_mbh: cute.Shape, +) -> FmhaStaticTileSchedulerParams: + return FmhaStaticTileSchedulerParams(is_persistent, problem_shape_mbh) + +class FmhaStaticTileScheduler: + + def __init__( + self, + params: FmhaStaticTileSchedulerParams, + current_work_linear_idx: Int32, + blk_coord: cute.Coord, + grid_shape: cute.Shape, + *, + loc=None, + ip=None, + ): + self._params = params + self._blk_coord = blk_coord + self._grid_shape = grid_shape + self._is_persistent = params.is_persistent + self._current_work_linear_idx = current_work_linear_idx + self._problem_shape_mbh = cute.make_layout( + params.problem_shape_mbh, loc=loc, ip=ip + ) + self._num_blocks = cute.size(self._problem_shape_mbh, loc=loc, ip=ip) + self._is_first_block = True + self.num_persistent_sm = cute.size(grid_shape, loc=loc, ip=ip) + self._loc = loc + self._ip = ip + + # called by host + @staticmethod + def get_grid_shape( + params: FmhaStaticTileSchedulerParams, + *, + loc=None, + ip=None, + ) -> cute.Shape: + if params.is_persistent: + hardware_info = cutlass.utils.HardwareInfo() + sm_count = hardware_info.get_device_multiprocessor_count() + return ( + cutlass.min( + sm_count, cute.size(params.problem_shape_mbh, loc=loc, ip=ip) + ), + 1, + 1, + ) + else: + return params.problem_shape_mbh + + @staticmethod + def check_valid_work_for_seqlen_q( + q_tiler: int, + current_idx: Int32, + seqlen_q: Int32, + ) -> Boolean: + return current_idx * q_tiler < seqlen_q + + def get_current_work(self, *, loc=None, ip=None) -> utils.WorkTileInfo: + is_valid = ( + self._current_work_linear_idx < self._num_blocks + if self._is_persistent + else self._is_first_block + ) + + blk_coord = (0, 0, 0) + if self._is_persistent: + blk_coord = self._problem_shape_mbh.get_hier_coord( + self._current_work_linear_idx, loc=loc, ip=ip + ) + else: + blk_coord = self._blk_coord + + # cur_tile_coord is (mid, 0, (bid, hid)) + cur_tile_coord = ( + blk_coord[0], + 0, + (blk_coord[1], blk_coord[2]), + ) + + return utils.WorkTileInfo(cur_tile_coord, is_valid) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + def advance_to_next_work(self, *, advance_count=1, loc=None, ip=None): + if self._is_persistent: + self._current_work_linear_idx += advance_count * self.num_persistent_sm + self._is_first_block = False + + def __extract_mlir_values__(self): + values = cutlass.extract_mlir_values(self._params) + values.extend(cutlass.extract_mlir_values(self._current_work_linear_idx)) + values.extend(cutlass.extract_mlir_values(self._blk_coord)) + values.extend(cutlass.extract_mlir_values(self._grid_shape)) + return values + + def __new_from_mlir_values__(self, values): + assert len(values) == 10 + new_params = cutlass.new_from_mlir_values(self._params, values[0:3]) + new_current_work_linear_idx = cutlass.new_from_mlir_values( + self._current_work_linear_idx, [values[3]] + ) + new_blk_coord = cutlass.new_from_mlir_values(self._blk_coord, values[4:7]) + new_grid_shape = cutlass.new_from_mlir_values(self._grid_shape, values[7:]) + return FmhaStaticTileScheduler( + new_params, new_current_work_linear_idx, new_blk_coord, new_grid_shape + ) + + +def create_fmha_static_tile_scheduler( + params: FmhaStaticTileSchedulerParams, + blk_coord: cute.Coord, + grid_shape: cute.Shape, +) -> FmhaStaticTileScheduler: + return FmhaStaticTileScheduler(params, blk_coord[0], blk_coord, grid_shape) + + +class MaskType(enum.Enum): + NO_MASK = enum.auto() + RESIDUAL_MASK = enum.auto() + CAUSAL_MASK = enum.auto() + + +def make_thread_cooperative_group(size: int): + return pipeline.CooperativeGroup(pipeline.Agent.Thread, size, size) + + +class BlackwellFusedMultiHeadAttentionForward: + def __init__( + self, + qk_acc_dtype: Type[cutlass.Numeric], + pv_acc_dtype: Type[cutlass.Numeric], + mma_tiler: Tuple[int, int, int], + is_persistent: bool, + mask_type: MaskType, + ): + """Initializes the configuration for a Blackwell Fused Multi-Head Attention (FMHA) kernel. + + This configuration includes several key aspects: + + 1. Data Type Settings: + - qk_acc_dtype: Data type for Q*K^T matrix multiplication accumulator + - pv_acc_dtype: Data type for P*V matrix multiplication accumulator + + 2. MMA Instruction Settings: + - mma_tiler: The (M, N, K) shape of the MMA instruction unit + - qk_mma_tiler: MMA shape for Q*K^T computation + - pv_mma_tiler: MMA shape for P*V computation + + 3. Kernel Execution Mode: + - is_persistent: Boolean indicating whether to use persistent kernel mode + - mask_type: Specifies the type of mask to use (no mask, residual mask, or causal mask) + + :param qk_acc_dtype: Data type for Q*K^T matrix multiplication accumulator + :type qk_acc_dtype: Type[cutlass.Numeric] + :param pv_acc_dtype: Data type for P*V matrix multiplication accumulator + :type pv_acc_dtype: Type[cutlass.Numeric] + :param mma_tiler: The (M, N, K) shape of the MMA instruction + :type mma_tiler: Tuple[int, int, int] + :param is_persistent: Whether to use persistent kernel mode + :type is_persistent: bool + :param mask_type: Type of mask to use + :type mask_type: MaskType + """ + + self.qk_acc_dtype = qk_acc_dtype + self.pv_acc_dtype = pv_acc_dtype + self.cta_tiler = ( + 2 * mma_tiler[0], # 2 Q tile per CTA + mma_tiler[1], + mma_tiler[2], + ) + self.qk_mma_tiler = mma_tiler + self.pv_mma_tiler = ( + mma_tiler[0], + mma_tiler[2], + mma_tiler[1], + ) + self.cluster_shape_mn = (1, 1) + self.is_persistent = is_persistent + self.mask_type = mask_type + self.softmax0_warp_ids = (0, 1, 2, 3) + self.softmax1_warp_ids = (4, 5, 6, 7) + self.correction_warp_ids = (8, 9, 10, 11) + self.mma_warp_id = 12 + self.load_warp_id = 13 + self.epilogue_warp_id = 14 + self.empty_warp_id = 15 + SM100_TMEM_CAPACITY_COLUMNS = 512 + self.tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS + + self.threads_per_warp = 32 + self.threads_per_cta = self.threads_per_warp * len( + ( + *self.softmax0_warp_ids, + *self.softmax1_warp_ids, + *self.correction_warp_ids, + self.mma_warp_id, + self.load_warp_id, + self.epilogue_warp_id, + self.empty_warp_id, + ) + ) + + self.cta_sync_bar_id = 0 + self.tmem_alloc_sync_bar_id = 1 + + self.tmem_s0_offset = 0 + self.tmem_s1_offset = 128 + self.tmem_o0_offset = 256 + self.tmem_o1_offset = 384 + self.tmem_p0_offset = 32 + self.tmem_p1_offset = 160 + + # vec buffer for row_max & row_sum + self.tmem_vec0_offset = 0 + self.tmem_vec1_offset = 128 + + self.num_regs_softmax = 192 + self.num_regs_correction = 96 + self.num_regs_other = 32 + self.num_regs_empty = 24 + + self.buffer_align_bytes = 1024 + + num_warps_per_warpgroup = 4 + self.softmax_warpgroup_count = ( + len((*self.softmax0_warp_ids, *self.softmax1_warp_ids)) + // num_warps_per_warpgroup + ) + + def _setup_attributes(self): + """Set up configurations and parameters for the FMHA kernel operation. + + This method initializes and configures various attributes required for the + execution of the fused multi-head attention kernel, mainly about the pipeline stages: + + - Sets up staging parameters for Q, K, V inputs and accumulator data + - Configures pipeline stages for softmax, correction, and epilogue operations + """ + + self.q_stage = 2 + self.kv_stage = 4 if self.q_dtype.width == 8 else 3 + self.acc_stage = 1 + self.softmax_corr_stage = 1 + self.mma_corr_stage = 2 + self.mma_softmax_stage = 1 + self.epi_stage = 2 + + @cute.jit + def __call__( + self, + q_iter: cute.Pointer, + k_iter: cute.Pointer, + v_iter: cute.Pointer, + o_iter: cute.Pointer, + problem_size: Tuple[Int32, Int32, Int32, Int32, Int32, Int32], + cum_seqlen_q: cute.Tensor | None, + cum_seqlen_k: cute.Tensor | None, + scale_softmax_log2: Float32, + scale_output: Float32, + stream: cuda.CUstream, + ): + """Execute the Fused Multi-Head Attention operation on the provided tensors. + + This method prepares the input tensors for processing, validates their shapes and types, + configures the computation parameters, and launches the CUDA kernel. + + The method handles: + 1. Tensor layout transformations for specific memory access patterns + 2. Validation of tensor shapes and data types + 3. Initialization of hardware-specific parameters and memory layouts + 4. Configuration of TMA (Tensor Memory Access) operations + 5. Grid and work scheduling computation + 6. Kernel launch with appropriate parameters + + :param q_iter: The query tensor pointer + :type q_iter: cute.Pointer + :param k_iter: The key tensor pointer + :type k_iter: cute.Pointer + :param v_iter: The value tensor pointer + :type v_iter: cute.Pointer + :param o_iter: The output tensor pointer + :type o_iter: cute.Pointer + :param problem_size: The problem size with shape [b, s_q, s_k, h_q, h_k, d]. If cum_seqlen_q or cum_seqlen_k is not None, s_q and s_k are the max of the cumulative sequence length respectively. + :type problem_size: Tuple[Int32, Int32, Int32, Int32, Int32, Int32] + :param cum_seqlen_q: The cumulative sequence length tensor for query + :type cum_seqlen_q: cute.Tensor | None + :param cum_seqlen_k: The cumulative sequence length tensor for key + :type cum_seqlen_k: cute.Tensor | None + :param scale_softmax_log2: The log2 scale factor for softmax + :type scale_softmax_log2: Float32 + :param scale_output: The scale factor for the output + :type scale_output: Float32 + :param stream: The CUDA stream to execute the kernel on + :type stream: cuda.CUstream + :raises TypeError: If tensor data types don't match or aren't supported + :raises RuntimeError: If tensor layouts aren't in supported formats + """ + b, s_q, s_k, h_q, h_k, d = problem_size + h_r = h_q // h_k + qo_offset = 0 if cum_seqlen_q is None else -s_q * d * h_r * h_k + kv_offset = 0 if cum_seqlen_k is None else -s_k * d * h_k + b_qo = b if cum_seqlen_q is None else s_q * (1 + b) + b_kv = b if cum_seqlen_k is None else s_k * (1 + b) + stride_b_qo = h_r * h_k * s_q * d if cum_seqlen_q is None else d * h_r * h_k + stride_b_kv = h_k * s_k * d if cum_seqlen_k is None else d * h_k + + # (s, d, ((h_r, h_k), b)) + q_layout = cute.make_layout( + (s_q, d, ((h_r, h_k), b_qo)), + stride=(d * h_r * h_k, 1, ((d, d * h_r), stride_b_qo)), + ) + q = cute.make_tensor(q_iter + qo_offset, q_layout) + # (s, d, ((h_r, h_k), b)), 0-stride for h_r to broadcast + k_layout = cute.make_layout( + (s_k, d, ((h_r, h_k), b_kv)), + stride=(d * h_k, 1, ((0, d), stride_b_kv)), + ) + k = cute.make_tensor(k_iter + kv_offset, k_layout) + # (d, s, ((h_r, h_k), b)), 0-stride for h_r to broadcast + v_layout = cute.make_layout( + (d, s_k, ((h_r, h_k), b_kv)), + stride=(1, d * h_k, ((0, d), stride_b_kv)), + ) + v = cute.make_tensor(v_iter + kv_offset, v_layout) + # (s, d, ((h_r, h_k), b)) + o_layout = cute.make_layout( + (s_q, d, ((h_r, h_k), b_qo)), + stride=(d * h_r * h_k, 1, ((d, d * h_r), stride_b_qo)), + ) + o = cute.make_tensor(o_iter + qo_offset, o_layout) + + # setup static attributes before smem/grid/tma computation + self.q_dtype = q.element_type + self.k_dtype = k.element_type + self.v_dtype = v.element_type + self.o_dtype = o.element_type + + self.tile_sched_params, grid = self._compute_grid( + cute.shape((s_q, d, ((h_r, h_k), b))), + self.cta_tiler, + self.is_persistent, + ) + + self.q_major_mode = utils.LayoutEnum.from_tensor(q).mma_major_mode() + self.k_major_mode = utils.LayoutEnum.from_tensor(k).mma_major_mode() + self.v_major_mode = utils.LayoutEnum.from_tensor(v).mma_major_mode() + self.o_layout = utils.LayoutEnum.from_tensor(o) + + if cutlass.const_expr(self.q_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of q is not supported") + if cutlass.const_expr(self.k_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of k is not supported") + if cutlass.const_expr(self.v_major_mode != tcgen05.OperandMajorMode.MN): + raise RuntimeError("The layout of v is not supported") + + # check type consistency + if cutlass.const_expr(self.q_dtype != self.k_dtype): + raise TypeError(f"Type mismatch: {self.q_dtype} != {self.k_dtype}") + if cutlass.const_expr(self.q_dtype != self.v_dtype): + raise TypeError(f"Type mismatch: {self.q_dtype} != {self.v_dtype}") + self._setup_attributes() + + cta_group = tcgen05.CtaGroup.ONE + # the intermediate tensor p is from tmem & k-major + p_source = tcgen05.OperandSource.TMEM + p_major_mode = tcgen05.OperandMajorMode.K + qk_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.q_dtype, + self.q_major_mode, + self.k_major_mode, + self.qk_acc_dtype, + cta_group, + self.qk_mma_tiler[:2], + ) + pv_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.v_dtype, + p_major_mode, + self.v_major_mode, + self.pv_acc_dtype, + cta_group, + self.pv_mma_tiler[:2], + p_source, + ) + + self.cluster_shape_mnk = (*self.cluster_shape_mn, 1) + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), + (qk_tiled_mma.thr_id.shape,), + ) + + self.epi_tile = self.pv_mma_tiler[:2] + + q_smem_layout_staged = sm100_utils.make_smem_layout_a( + qk_tiled_mma, + self.qk_mma_tiler, + self.q_dtype, + self.q_stage, + ) + k_smem_layout_staged = sm100_utils.make_smem_layout_b( + qk_tiled_mma, + self.qk_mma_tiler, + self.k_dtype, + self.kv_stage, + ) + p_tmem_layout_staged = sm100_utils.make_smem_layout_a( + pv_tiled_mma, + self.pv_mma_tiler, + self.q_dtype, + self.acc_stage, + ) + v_smem_layout_staged = sm100_utils.make_smem_layout_b( + pv_tiled_mma, + self.pv_mma_tiler, + self.v_dtype, + self.kv_stage, + ) + o_smem_layout_staged = sm100_utils.make_smem_layout_epi( + self.o_dtype, + self.o_layout, + self.epi_tile, + self.epi_stage, + ) + + # TMA load for Q + tma_load_op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp(cta_group) + tma_store_op = cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp() + + q_smem_layout = cute.select(q_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_q, tma_tensor_q = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + q, + q_smem_layout, + self.qk_mma_tiler, + qk_tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # TMA load for K + k_smem_layout = cute.select(k_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_k, tma_tensor_k = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + k, + k_smem_layout, + self.qk_mma_tiler, + qk_tiled_mma, + self.cluster_layout_vmnk.shape, + ) + # TMA load for V + v_smem_layout = cute.select(v_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_v, tma_tensor_v = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + v, + v_smem_layout, + self.pv_mma_tiler, + pv_tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + o_cta_v_layout = cute.composition( + cute.make_identity_layout(o.shape), self.epi_tile + ) + o_smem_layout = cute.select(o_smem_layout_staged, mode=[0, 1]) + + tma_atom_o, tma_tensor_o = cute.nvgpu.cpasync.make_tiled_tma_atom( + tma_store_op, + o, + o_smem_layout, + o_cta_v_layout, + ) + + q_copy_size = cute.size_in_bytes(self.q_dtype, q_smem_layout) + k_copy_size = cute.size_in_bytes(self.k_dtype, k_smem_layout) + self.tma_copy_q_bytes = q_copy_size + self.tma_copy_kv_bytes = k_copy_size + + @cute.struct + class SharedStorage: + # Pipeline barriers + load_q_mbar_ptr: cute.struct.MemRange[Int64, self.q_stage * 2] + load_kv_mbar_ptr: cute.struct.MemRange[Int64, self.kv_stage * 2] + mma_s0_mbar_ptr: cute.struct.MemRange[Int64, self.mma_softmax_stage * 2] + mma_s1_mbar_ptr: cute.struct.MemRange[Int64, self.mma_softmax_stage * 2] + s0_corr_mbar_ptr: cute.struct.MemRange[Int64, self.softmax_corr_stage * 2] + s1_corr_mbar_ptr: cute.struct.MemRange[Int64, self.softmax_corr_stage * 2] + s0_s1_sequence_mbar_ptr: cute.struct.MemRange[ + Int64, self.softmax_warpgroup_count + ] + corr_epi_mbar_ptr: cute.struct.MemRange[Int64, self.epi_stage * 2] + mma_corr_mbar_ptr: cute.struct.MemRange[Int64, self.mma_corr_stage * 2] + tmem_dealloc_mbar_ptr: cute.struct.MemRange[Int64, 1] + # Tmem holding buffer + tmem_holding_buf: Int32 + # Smem tensors + sO: cute.struct.Align[ + cute.struct.MemRange[self.o_dtype, cute.cosize(o_smem_layout_staged)], + self.buffer_align_bytes, + ] + sQ: cute.struct.Align[ + cute.struct.MemRange[self.q_dtype, cute.cosize(q_smem_layout_staged)], + self.buffer_align_bytes, + ] + sK: cute.struct.Align[ + cute.struct.MemRange[self.k_dtype, cute.cosize(k_smem_layout_staged)], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + qk_tiled_mma, + pv_tiled_mma, + tma_atom_q, + tma_tensor_q, + tma_atom_k, + tma_tensor_k, + tma_atom_v, + tma_tensor_v, + tma_atom_o, + tma_tensor_o, + cum_seqlen_q, + cum_seqlen_k, + scale_softmax_log2, + scale_output, + q_smem_layout_staged, + k_smem_layout_staged, + p_tmem_layout_staged, + v_smem_layout_staged, + o_smem_layout_staged, + self.tile_sched_params, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=self.cluster_shape_mnk, + stream=stream, + min_blocks_per_mp=1, + ) + + # GPU device kernel + @cute.kernel + def kernel( + self, + qk_tiled_mma: cute.TiledMma, + pv_tiled_mma: cute.TiledMma, + tma_atom_q: cute.CopyAtom, + mQ_qdl: cute.Tensor, + tma_atom_k: cute.CopyAtom, + mK_kdl: cute.Tensor, + tma_atom_v: cute.CopyAtom, + mV_dkl: cute.Tensor, + tma_atom_o: cute.CopyAtom, + mO_qdl: cute.Tensor, + cum_seqlen_q: cute.Tensor | None, + cum_seqlen_k: cute.Tensor | None, + scale_softmax_log2: Float32, + scale_output: Float32, + q_smem_layout_staged: cute.ComposedLayout, + k_smem_layout_staged: cute.ComposedLayout, + p_tmem_layout_staged: cute.ComposedLayout, + v_smem_layout_staged: cute.ComposedLayout, + o_smem_layout_staged: cute.ComposedLayout, + tile_sched_params: FmhaStaticTileSchedulerParams, + ): + """The device kernel implementation of the Fused Multi-Head Attention. + + This kernel coordinates multiple specialized warps to perform different phases of the FMHA computation: + 1. Load warp: Loads Q, K, V data from global memory to shared memory using TMA + 2. MMA warp: Performs matrix multiplications (Q*K^T and P*V) + 3. Softmax warps: Compute softmax normalization on attention scores + 4. Correction warps: Apply adjustments to intermediate results + 5. Epilogue warp: Handles final output transformation and storage + + The kernel implements a complex pipeline with overlapping computation and memory operations, + using tensor memory access (TMA) for efficient data loading, warp specialization for different + computation phases, and optional attention masking. + + :param qk_tiled_mma: Tiled MMA for Q*K^T + :type qk_tiled_mma: cute.TiledMma + :param pv_tiled_mma: Tiled MMA for P*V + :type pv_tiled_mma: cute.TiledMma + :param tma_atom_q: TMA copy atom for query tensor + :type tma_atom_q: cute.CopyAtom + :param mQ_qdl: Partitioned query tensor + :type mQ_qdl: cute.Tensor + :param tma_atom_k: TMA copy atom for key tensor + :type tma_atom_k: cute.CopyAtom + :param mK_kdl: Partitioned key tensor + :type mK_kdl: cute.Tensor + :param tma_atom_v: TMA copy atom for value tensor + :type tma_atom_v: cute.CopyAtom + :param mV_dkl: Partitioned value tensor + :type mV_dkl: cute.Tensor + :param tma_atom_o: TMA copy atom for output tensor + :type tma_atom_o: cute.CopyAtom + :param mO_qdl: Partitioned output tensor + :type mO_qdl: cute.Tensor + :param scale_softmax_log2: The log2 scale factor for softmax + :type scale_softmax_log2: Float32 + :param scale_output: The scale factor for the output + :type scale_output: Float32 + :param q_smem_layout_staged: Shared memory layout for query tensor + :type q_smem_layout_staged: cute.ComposedLayout + :param k_smem_layout_staged: Shared memory layout for key tensor + :type k_smem_layout_staged: cute.ComposedLayout + :param p_tmem_layout_staged: Tensor memory layout for probability matrix + :type p_tmem_layout_staged: cute.ComposedLayout + :param v_smem_layout_staged: Shared memory layout for value tensor + :type v_smem_layout_staged: cute.ComposedLayout + :param o_smem_layout_staged: Shared memory layout for output tensor + :type o_smem_layout_staged: cute.ComposedLayout + :param tile_sched_params: Scheduling parameters for work distribution + :type tile_sched_params: FmhaStaticTileSchedulerParams + """ + + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + # coord inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Prefetch tma desc + # + if warp_idx == self.load_warp_id: + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_q) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_k) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_v) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_o) + + # Alloc + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + load_q_producer, load_q_consumer = pipeline.PipelineTmaUmma.create( + num_stages=self.q_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + tx_count=self.tma_copy_q_bytes, + barrier_storage=storage.load_q_mbar_ptr.data_ptr(), + ).make_participants() + load_kv_producer, load_kv_consumer = pipeline.PipelineTmaUmma.create( + num_stages=self.kv_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + tx_count=self.tma_copy_kv_bytes, + barrier_storage=storage.load_kv_mbar_ptr.data_ptr(), + ).make_participants() + mma_s0_producer, mma_s0_consumer = pipeline.PipelineUmmaAsync.create( + num_stages=self.mma_softmax_stage, + producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.softmax0_warp_ids) + ), + barrier_storage=storage.mma_s0_mbar_ptr.data_ptr(), + ).make_participants() + mma_s1_producer, mma_s1_consumer = pipeline.PipelineUmmaAsync.create( + num_stages=self.mma_softmax_stage, + producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.softmax1_warp_ids) + ), + barrier_storage=storage.mma_s1_mbar_ptr.data_ptr(), + ).make_participants() + s0_corr_producer, s0_corr_consumer = pipeline.PipelineAsync.create( + num_stages=self.softmax_corr_stage, + producer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.softmax0_warp_ids) + ), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.correction_warp_ids) + ), + barrier_storage=storage.s0_corr_mbar_ptr.data_ptr(), + ).make_participants() + s1_corr_producer, s1_corr_consumer = pipeline.PipelineAsync.create( + num_stages=self.softmax_corr_stage, + producer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.softmax1_warp_ids) + ), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.correction_warp_ids) + ), + barrier_storage=storage.s1_corr_mbar_ptr.data_ptr(), + ).make_participants() + corr_epi_producer, corr_epi_consumer = pipeline.PipelineAsync.create( + num_stages=self.epi_stage, + producer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.correction_warp_ids) + ), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * len([self.epilogue_warp_id]) + ), + barrier_storage=storage.corr_epi_mbar_ptr.data_ptr(), + ).make_participants() + mma_corr_producer, mma_corr_consumer = pipeline.PipelineUmmaAsync.create( + num_stages=self.mma_corr_stage, + producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.correction_warp_ids) + ), + barrier_storage=storage.mma_corr_mbar_ptr.data_ptr(), + ).make_participants() + s0_s1_sequence_producer, s0_s1_sequence_consumer = ( + pipeline.PipelineAsync.create( + num_stages=1, + producer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.softmax0_warp_ids) + ), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.softmax1_warp_ids) + ), + barrier_storage=storage.s0_s1_sequence_mbar_ptr.data_ptr(), + ).make_participants() + ) + tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr.data_ptr() + + # Correction & Epilogue & tmem barrier init + if warp_idx == self.empty_warp_id: + cute.arch.mbarrier_init( + tmem_dealloc_mbar_ptr, + self.threads_per_warp + * len( + ( + *self.softmax0_warp_ids, + *self.softmax1_warp_ids, + *self.correction_warp_ids, + ) + ), + ) + cute.arch.mbarrier_init_fence() + + # Generate smem tensor Q/K/V/O + # (MMA, MMA_Q, MMA_D, PIPE) + sQ = storage.sQ.get_tensor( + q_smem_layout_staged.outer, swizzle=q_smem_layout_staged.inner + ) + # (MMA, MMA_K, MMA_D, PIPE) + sK = storage.sK.get_tensor( + k_smem_layout_staged.outer, swizzle=k_smem_layout_staged.inner + ) + # (MMA, MMA_K, MMA_D, PIPE) + # Strip swizzle info to reuse smem + sV_ptr = cute.recast_ptr(sK.iterator, v_smem_layout_staged.inner) + sV = cute.make_tensor(sV_ptr, v_smem_layout_staged.outer) + sO = storage.sO.get_tensor( + o_smem_layout_staged.outer, swizzle=o_smem_layout_staged.inner + ) + qk_thr_mma = qk_tiled_mma.get_slice(0) # default 1sm + pv_thr_mma = pv_tiled_mma.get_slice(0) # default 1sm + tSrQ = qk_thr_mma.make_fragment_A(sQ) + tSrK = qk_thr_mma.make_fragment_B(sK) + tOrV = pv_thr_mma.make_fragment_B(sV) + qk_acc_shape = qk_thr_mma.partition_shape_C( + (self.qk_mma_tiler[0], self.qk_mma_tiler[1]) + ) + tStS = qk_thr_mma.make_fragment_C(qk_acc_shape) + pv_acc_shape = pv_thr_mma.partition_shape_C( + (self.pv_mma_tiler[0], self.pv_mma_tiler[1]) + ) + tOtO = pv_thr_mma.make_fragment_C(pv_acc_shape) + + tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + tStS1 = cute.make_tensor(tStS.iterator + self.tmem_s1_offset, tStS.layout) + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + tOtO1 = cute.make_tensor(tOtO.iterator + self.tmem_o1_offset, tOtO.layout) + + tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer) + tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0] + tOrP0 = cute.make_tensor( + tOrP.iterator + + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, + tOrP.layout, + ) + tOrP1 = cute.make_tensor( + tOrP.iterator + + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p1_offset, + tOrP.layout, + ) + cute.arch.barrier( + barrier_id=self.cta_sync_bar_id, + number_of_threads=self.threads_per_cta, + ) + # /////////////////////////////////////////////////////////////////////////////// + # EMPTY + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx == self.empty_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_empty) + + # /////////////////////////////////////////////////////////////////////////////// + # LOAD + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx == self.load_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + + tile_sched = create_fmha_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + while work_tile.is_valid_tile: + curr_block_coord = work_tile.tile_idx + batch_coord = curr_block_coord[2][1] + continue_cond = False + cuseqlen_q = Int32(0) + seqlen_q = mQ_qdl.shape[0] + if cutlass.const_expr(cum_seqlen_q is not None): + cuseqlen_q = cum_seqlen_q[batch_coord] + seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q + continue_cond = ( + not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.cta_tiler[0], + curr_block_coord[0], + seqlen_q, + ) + ) + if not continue_cond: + mQ_qdl_ = mQ_qdl + mK_kdl_ = mK_kdl + mV_dkl_ = mV_dkl + seqlen_k = mK_kdl.shape[0] + curr_block_coord_q = curr_block_coord + curr_block_coord_kv = curr_block_coord + + if cutlass.const_expr(cum_seqlen_q is not None): + logical_offset_mQ = ( + mQ_qdl.shape[0] - seqlen_q, + 0, + (0, cuseqlen_q + seqlen_q), + ) + mQ_qdl_ = cute.domain_offset(logical_offset_mQ, mQ_qdl) + curr_block_coord_q = ( + curr_block_coord[0], + curr_block_coord[1], + (curr_block_coord[2][0], Int32(0)), + ) + + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + logical_offset_mK = ( + mK_kdl.shape[0] - seqlen_k, + 0, + (0, cuseqlen_k + seqlen_k), + ) + logical_offset_mV = ( + 0, + mK_kdl.shape[0] - seqlen_k, + (0, cuseqlen_k + seqlen_k), + ) + mK_kdl_ = cute.domain_offset(logical_offset_mK, mK_kdl) + mV_dkl_ = cute.domain_offset(logical_offset_mV, mV_dkl) + curr_block_coord_kv = ( + curr_block_coord[0], + curr_block_coord[1], + (curr_block_coord[2][0], Int32(0)), + ) + + # Local tile partition global tensors + # (bM, bK, loopM, loopK, loopL) + gQ_qdl = cute.flat_divide( + mQ_qdl_, cute.select(self.qk_mma_tiler, mode=[0, 2]) + ) + tSgQ_qdl = qk_thr_mma.partition_A(gQ_qdl) + tQsQ, tQgQ_qdl = cute.nvgpu.cpasync.tma_partition( + tma_atom_q, + 0, # no multicast + cute.make_layout(1), + cute.group_modes(sQ, 0, 3), + cute.group_modes(tSgQ_qdl, 0, 3), + ) + tQgQ = tQgQ_qdl[None, None, 0, curr_block_coord_q[2]] + + gK_kdl = cute.flat_divide( + mK_kdl_, cute.select(self.qk_mma_tiler, mode=[1, 2]) + ) + tSgK_kdl = qk_thr_mma.partition_B(gK_kdl) + tKsK, tKgK_kdl = cute.nvgpu.cpasync.tma_partition( + tma_atom_k, + 0, # no multicast + cute.make_layout(1), + cute.group_modes(sK, 0, 3), + cute.group_modes(tSgK_kdl, 0, 3), + ) + tKgK = tKgK_kdl[None, None, 0, curr_block_coord_kv[2]] + + gV_dkl = cute.flat_divide( + mV_dkl_, cute.select(self.pv_mma_tiler, mode=[1, 2]) + ) + tSgV_dkl = pv_thr_mma.partition_B(gV_dkl) + tVsV, tVgV_dkl = cute.nvgpu.cpasync.tma_partition( + tma_atom_v, + 0, # no multicast + cute.make_layout(1), + cute.group_modes(sV, 0, 3), + cute.group_modes(tSgV_dkl, 0, 3), + ) + tVgV = tVgV_dkl[None, 0, None, curr_block_coord_kv[2]] + + # Q0 + q0_coord = 2 * curr_block_coord_q[0] + q0_handle = load_q_producer.acquire_and_advance() + cute.copy( + tma_atom_q, + tQgQ[None, q0_coord], + tQsQ[None, q0_handle.index], + tma_bar_ptr=q0_handle.barrier, + ) + # K0 + kv_coord = 0 # seqlen_kv_loop + k_handle = load_kv_producer.acquire_and_advance() + cute.copy( + tma_atom_k, + tKgK[None, kv_coord], + tKsK[None, k_handle.index], + tma_bar_ptr=k_handle.barrier, + ) + # Q1 + q1_coord = q0_coord + 1 + q1_handle = load_q_producer.acquire_and_advance() + cute.copy( + tma_atom_q, + tQgQ[None, q1_coord], + tQsQ[None, q1_handle.index], + tma_bar_ptr=q1_handle.barrier, + ) + # V0 + v_handle = load_kv_producer.acquire_and_advance() + cute.copy( + tma_atom_v, + tVgV[None, kv_coord], + tVsV[None, v_handle.index], + tma_bar_ptr=v_handle.barrier, + ) + kv_coord += 1 + + seqlen_kv_loop_steps = ( + self.get_trip_count(curr_block_coord, self.cta_tiler, seqlen_k) + - 1 + ) + for i in cutlass.range(0, seqlen_kv_loop_steps, 1, unroll=1): + # Ki + k_handle = load_kv_producer.acquire_and_advance() + cute.copy( + tma_atom_k, + tKgK[None, kv_coord], + tKsK[None, k_handle.index], + tma_bar_ptr=k_handle.barrier, + ) + # Vi + v_handle = load_kv_producer.acquire_and_advance() + cute.copy( + tma_atom_v, + tVgV[None, kv_coord], + tVsV[None, v_handle.index], + tma_bar_ptr=v_handle.barrier, + ) + kv_coord += 1 + # End of seqlen_kv loop + + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + # End of persistent scheduler loop + + # /////////////////////////////////////////////////////////////////////////////// + # MMA + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx == self.mma_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + + # Alloc tmem buffer + tmem_alloc_cols = Int32(self.tmem_alloc_cols) + cute.arch.alloc_tmem(tmem_alloc_cols, storage.tmem_holding_buf) + cute.arch.barrier( + barrier_id=self.tmem_alloc_sync_bar_id, + number_of_threads=self.threads_per_warp, + ) + tile_sched = create_fmha_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + while work_tile.is_valid_tile: + curr_block_coord = work_tile.tile_idx + batch_coord = curr_block_coord[2][1] + continue_cond = False + if cutlass.const_expr(cum_seqlen_q is not None): + cuseqlen_q = cum_seqlen_q[batch_coord] + seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q + continue_cond = ( + not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.cta_tiler[0], + curr_block_coord[0], + seqlen_q, + ) + ) + + if not continue_cond: + seqlen_k = mK_kdl.shape[0] + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + + # GEMM_QK00 (Q0 * K0 -> S0) + # 1. wait for Q0 + q0_handle = load_q_consumer.wait_and_advance() + tSrQ0 = tSrQ[None, None, None, q0_handle.index] + # 2. wait for K0 + k_handle = load_kv_consumer.wait_and_advance() + tSrK0 = tSrK[None, None, None, k_handle.index] + # 3. acquire empty S0 buffer + s0_handle = mma_s0_producer.acquire_and_advance() + # 4. gemm + num_kphases = cute.size(tSrQ0, mode=[2]) + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord_0 = (None, None, kphase_idx) + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) + cute.gemm( + qk_tiled_mma, + tStS0, + tSrQ0[kphase_coord_0], + tSrK0[kphase_coord_0], + tStS0, + ) + # 5. release S0 + s0_handle.commit() + # End of GEMM (Q0 * K0 -> S0) + + # GEMM_QK10 (Q1 * K0 -> S1), K0 is ready in GEMM_QK00 + # 1. wait for Q1 + q1_handle = load_q_consumer.wait_and_advance() + tSrQ1 = tSrQ[None, None, None, q1_handle.index] + # 2. acquire empty S1 + s1_handle = mma_s1_producer.acquire_and_advance() + # 3. gemm + num_kphases = cute.size(tSrQ1, mode=[2]) + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord_1 = (None, None, kphase_idx) + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) + cute.gemm( + qk_tiled_mma, + tStS1, + tSrQ1[kphase_coord_1], + tSrK0[kphase_coord_1], + tStS1, + ) + # 4. release S1 + s1_handle.commit() + # 5. release K0 + k_handle.release() + # End of GEMM (Q1 * K0 -> S1) + # Note: Q0 & Q1 are still needed in the seqlen_kv loop + # so we need to release them after the seqlen_kv loop + + # GEMM_PV00 (P0 * V0 -> O0_partial), O0 needs to be accumulated in the seqlen_kv loop + # 1. wait for V0 + v_handle = load_kv_consumer.wait_and_advance() + tOrVi = tOrV[None, None, None, v_handle.index] + # 2. acquire corrected O0_partial + # Note: acquire corr first to take it out of the critical + # path since softmax takes longer + o0_handle = mma_corr_producer.acquire_and_advance() + # 3. acquire P0 + # this acquire returns the ownership of all of S0 to the mma warp + # including the P0 part (inplaced in S0) + s0_handle = mma_s0_producer.acquire_and_advance() + # 4. gemm + num_kphases = cute.size(tOrP0, mode=[2]) + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord_2 = (None, None, kphase_idx) + pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) + cute.gemm( + pv_tiled_mma, + tOtO0, + tOrP0[kphase_coord_2], + tOrVi[kphase_coord_2], + tOtO0, + ) + # 5. release accumulated O0_partial + o0_handle.commit() + # End of GEMM_PV00 (P0 * V0 -> O0_partial) + + seqlen_kv_loop_steps = ( + self.get_trip_count(curr_block_coord, self.cta_tiler, seqlen_k) + - 1 + ) + + # O1 hasn't been accumulated yet, its first MMA calculation doesn't need to accumulate + pv_whether_acc = False + for i in cutlass.range(0, seqlen_kv_loop_steps, 1, unroll=1): + # GEMM_QK0i (Q0 * Ki -> S0) + # 1. wait for Ki + k_handle = load_kv_consumer.wait_and_advance() + tSrKi = tSrK[None, None, None, k_handle.index] + # 2. gemm + inner_num_kphases = cute.size(tSrQ0, mode=[2]) + for kphase_idx in cutlass.range( + inner_num_kphases, unroll_full=True + ): + kphase_coord_3 = (None, None, kphase_idx) + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) + cute.gemm( + qk_tiled_mma, + tStS0, + tSrQ0[kphase_coord_3], + tSrKi[kphase_coord_3], + tStS0, + ) + # 3. release S0 + s0_handle.commit() + # End of GEMM_QK0i (Q0 * Ki -> S0) + + # GEMM_PV1(i-1) (P1 * V(i-1) -> O1_partial), V(i-1) is ready in GEMM_PV0(i-1) + # 1. acquire corrected O1_partial + o1_handle = mma_corr_producer.acquire_and_advance() + # 2. acquire P1 + s1_handle = mma_s1_producer.acquire_and_advance() + # 3. gemm + inner_num_kphases = cute.size(tOrP0, mode=[2]) + for kphase_idx in cutlass.range( + inner_num_kphases, unroll_full=True + ): + kphase_coord_4 = (None, None, kphase_idx) + pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, pv_whether_acc) + cute.gemm( + pv_tiled_mma, + tOtO1, + tOrP1[kphase_coord_4], + tOrVi[kphase_coord_4], + tOtO1, + ) + pv_whether_acc = True + # 4. release accumulated O1_partial + o1_handle.commit() + # 5. release V(i-1) + v_handle.release() + # End of GEMM_PV1(i-1) (P1 * V(i-1) -> O1_partial) + + # GEMM_QK1i (Q1 * Ki -> S1), Q1 is ready in GEMM_QK10; Ki is ready in GEMM_QK0i + # 1. gemm + inner_num_kphases = cute.size(tSrQ1, mode=[2]) + for kphase_idx in cutlass.range( + inner_num_kphases, unroll_full=True + ): + kphase_coord_5 = (None, None, kphase_idx) + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) + cute.gemm( + qk_tiled_mma, + tStS1, + tSrQ1[kphase_coord_5], + tSrKi[kphase_coord_5], + tStS1, + ) + s1_handle.commit() + # 2. release Ki + k_handle.release() + # End of GEMM_QK1i (Q1 * Ki -> S1) + + # GEMM_PV0i (P0 * Vi -> O0_partial) + # 1. wait for Vi + v_handle = load_kv_consumer.wait_and_advance() + tOrVi = tOrV[None, None, None, v_handle.index] + # 2. acquire corrected O0_partial + o0_handle = mma_corr_producer.acquire_and_advance() + # 3. acquire P0 + s0_handle = mma_s0_producer.acquire_and_advance() + # 4. gemm + inner_num_kphases = cute.size(tOrP0, mode=[2]) + for kphase_idx in cutlass.range( + inner_num_kphases, unroll_full=True + ): + kphase_coord_6 = (None, None, kphase_idx) + pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + cute.gemm( + pv_tiled_mma, + tOtO0, + tOrP0[kphase_coord_6], + tOrVi[kphase_coord_6], + tOtO0, + ) + # 5. release accumulated O0_partial + o0_handle.commit() + # End of GEMM_PV0i (P0 * Vi -> O0_partial) + # End of seqlen_kv loop + + # release Q0 & Q1 + q0_handle.release() + q1_handle.release() + + # GEMM_PV1(i_end) (P1 * Vi_end -> O1) + # 1. acquire corrected O1_partial + o1_handle = mma_corr_producer.acquire_and_advance() + # 2. acquire P1 + s1_handle = mma_s1_producer.acquire_and_advance() + # 3. gemm + num_kphases = cute.size(tOrP1, mode=[2]) + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord_7 = (None, None, kphase_idx) + pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + cute.gemm( + pv_tiled_mma, + tOtO1, + tOrP1[kphase_coord_7], + tOrVi[kphase_coord_7], + tOtO1, + ) + # 4. commit accumulated O1 + o1_handle.commit() + # 5. release Vi_end + v_handle.release() + # End of GEMM_PV1(i_end) (P1 * Vi_end -> O1) + + # Commit S0 and S1 + s0_handle.commit() + s1_handle.commit() + + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + # End of persistent scheduler loop + + # dealloc tmem buffer + cute.arch.relinquish_tmem_alloc_permit() + cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0) + tmem_alloc_cols = Int32(self.tmem_alloc_cols) + # Retrieving tmem ptr and make acc + tmem_ptr = cute.arch.retrieve_tmem_ptr( + Float32, + alignment=16, + ptr_to_buffer_holding_addr=storage.tmem_holding_buf, + ) + cute.arch.dealloc_tmem(tmem_ptr, tmem_alloc_cols) + + # /////////////////////////////////////////////////////////////////////////////// + # Epilogue + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx == self.epilogue_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + tile_sched = create_fmha_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + while work_tile.is_valid_tile: + curr_block_coord = work_tile.tile_idx + batch_coord = curr_block_coord[2][1] + continue_cond = False + cuseqlen_q = Int32(0) + seqlen_q = mQ_qdl.shape[0] + + if cutlass.const_expr(cum_seqlen_q is not None): + cuseqlen_q = cum_seqlen_q[batch_coord] + seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q + continue_cond = ( + not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.cta_tiler[0], + curr_block_coord[0], + seqlen_q, + ) + ) + if not continue_cond: + curr_block_coord_o = curr_block_coord + mO_qdl_ = mO_qdl + if cutlass.const_expr(cum_seqlen_q is not None): + logical_offset_mO = ( + mO_qdl_.shape[0] - seqlen_q, + 0, + (0, cuseqlen_q + seqlen_q), + ) + mO_qdl_ = cute.domain_offset(logical_offset_mO, mO_qdl_) + curr_block_coord_o = ( + curr_block_coord[0], + curr_block_coord[1], + (curr_block_coord[2][0], 0), + ) + + o0_coord = 2 * curr_block_coord_o[0] + o1_coord = o0_coord + 1 + gO_qdl = cute.flat_divide( + mO_qdl_, cute.select(self.pv_mma_tiler, mode=[0, 1]) + ) + gO = gO_qdl[None, None, None, 0, curr_block_coord_o[2]] + tOsO, tOgO = cute.nvgpu.cpasync.tma_partition( + tma_atom_o, + 0, + cute.make_layout(1), + cute.group_modes(sO, 0, 2), + cute.group_modes(gO, 0, 2), + ) + + # O0 O1 using the same pipeline + # wait from corr, issue tma store on smem + # O0 + # 1. wait for O0 final + o0_handle = corr_epi_consumer.wait_and_advance() + # 2. copy O0 to gmem + cute.copy(tma_atom_o, tOsO[None, 0], tOgO[None, o0_coord]) + cute.arch.cp_async_bulk_commit_group() + # O1 + # 1. wait for O1 final + o1_handle = corr_epi_consumer.wait_and_advance() + # 2. copy O1 to gmem + cute.copy(tma_atom_o, tOsO[None, 1], tOgO[None, o1_coord]) + cute.arch.cp_async_bulk_commit_group() + + # Ensure O0 buffer is ready to be released + cute.arch.cp_async_bulk_wait_group(1, read=True) + o0_handle.release() + # Ensure O1 buffer is ready to be released + cute.arch.cp_async_bulk_wait_group(0, read=True) + o1_handle.release() + + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + # End of persistent scheduler loop + + # /////////////////////////////////////////////////////////////////////////////// + # Softmax0 + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx < self.softmax1_warp_ids[0]: + # increase register after decreasing + cute.arch.warpgroup_reg_alloc(self.num_regs_softmax) + + self.softmax( + stage=0, + seqlen_k=mK_kdl.shape[0], + cum_seqlen_q=cum_seqlen_q, + cum_seqlen_k=cum_seqlen_k, + scale_softmax_log2=scale_softmax_log2, + qk_thr_mma=qk_thr_mma, + tStS=tStS, + tStSi=tStS0, + mma_si_consumer=mma_s0_consumer, + si_corr_producer=s0_corr_producer, + s0_s1_sequence_consumer=s0_s1_sequence_consumer, + s0_s1_sequence_producer=s0_s1_sequence_producer, + tile_sched_params=tile_sched_params, + ) + cute.arch.mbarrier_arrive(tmem_dealloc_mbar_ptr) + + # /////////////////////////////////////////////////////////////////////////////// + # Softmax1 + # /////////////////////////////////////////////////////////////////////////////// + if ( + warp_idx < self.correction_warp_ids[0] + and warp_idx >= self.softmax1_warp_ids[0] + ): + # increase register after decreasing + cute.arch.warpgroup_reg_alloc(self.num_regs_softmax) + + self.softmax( + stage=1, + seqlen_k=mK_kdl.shape[0], + cum_seqlen_q=cum_seqlen_q, + cum_seqlen_k=cum_seqlen_k, + scale_softmax_log2=scale_softmax_log2, + qk_thr_mma=qk_thr_mma, + tStS=tStS, + tStSi=tStS1, + mma_si_consumer=mma_s1_consumer, + si_corr_producer=s1_corr_producer, + s0_s1_sequence_consumer=s0_s1_sequence_consumer, + s0_s1_sequence_producer=s0_s1_sequence_producer, + tile_sched_params=tile_sched_params, + ) + cute.arch.mbarrier_arrive(tmem_dealloc_mbar_ptr) + + # /////////////////////////////////////////////////////////////////////////////// + # Correction + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx >= self.correction_warp_ids[0] and warp_idx < self.mma_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_correction) + + cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tScS = qk_thr_mma.partition_C(cS) + + tStS_vec_layout = cute.composition(tStS.layout, cute.make_layout((128, 2))) + + tStS_vec0 = cute.make_tensor( + tStS.iterator + self.tmem_vec0_offset, tStS_vec_layout + ) + tStS_vec1 = cute.make_tensor( + tStS.iterator + self.tmem_vec1_offset, tStS_vec_layout + ) + + tScS_vec_layout = cute.composition(tScS.layout, cute.make_layout((128, 2))) + tScS_vec = cute.make_tensor(tScS.iterator, tScS_vec_layout) + + tmem_load_v_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(2)), + self.qk_acc_dtype, + ) + + tiled_tmem_load_vec = tcgen05.make_tmem_copy(tmem_load_v_atom, tStS_vec0) + thread_idx = tidx % (self.threads_per_warp * len(self.correction_warp_ids)) + thr_tmem_load_vec = tiled_tmem_load_vec.get_slice(thread_idx) + + tTMEM_LOAD_VECtS0 = thr_tmem_load_vec.partition_S(tStS_vec0) + tTMEM_LOAD_VECtS1 = thr_tmem_load_vec.partition_S(tStS_vec1) + tTMEM_LOAD_VECcS = thr_tmem_load_vec.partition_D(tScS_vec) + + tile_sched = create_fmha_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + while work_tile.is_valid_tile: + curr_block_coord = work_tile.tile_idx + batch_coord = curr_block_coord[2][1] + seqlen_k = mK_kdl.shape[0] + continue_cond = False + + if cutlass.const_expr(cum_seqlen_q is not None): + cuseqlen_q = cum_seqlen_q[batch_coord] + seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q + continue_cond = ( + not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.cta_tiler[0], + curr_block_coord[0], + seqlen_q, + ) + ) + + if not continue_cond: + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + # Ignore first signal from softmax as no correction is required + vec0_handle = s0_corr_consumer.wait_and_advance() + vec0_handle.release() + vec1_handle = s1_corr_consumer.wait_and_advance() + seqlen_kv_loop_steps = ( + self.get_trip_count(curr_block_coord, self.cta_tiler, seqlen_k) + - 1 + ) + for i in cutlass.range(0, seqlen_kv_loop_steps, 1, unroll=1): + # wait for vec0 (row_wise current max & previous max) + vec0_handle = s0_corr_consumer.wait_and_advance() + tTMEM_LOAD_VECrS = cute.make_fragment( + tTMEM_LOAD_VECcS.shape, self.qk_acc_dtype + ) + cute.copy( + tiled_tmem_load_vec, tTMEM_LOAD_VECtS0, tTMEM_LOAD_VECrS + ) + scale_ = scale_softmax_log2 * ( + tTMEM_LOAD_VECrS[0] - tTMEM_LOAD_VECrS[1] + ) + scale = cute.math.exp2(scale_, fastmath=True) + # wait for o0 + o0_handle = mma_corr_consumer.wait_and_advance() + self.correction_rescale(pv_thr_mma, tOtO0, scale) + # release vec1 & o0 + vec1_handle.release() + cute.arch.fence_view_async_tmem_store() + o0_handle.release() + + # wait for vec1 (row_wise current max & previous max) + vec1_handle = s1_corr_consumer.wait_and_advance() + cute.copy( + tiled_tmem_load_vec, tTMEM_LOAD_VECtS1, tTMEM_LOAD_VECrS + ) + scale_ = scale_softmax_log2 * ( + tTMEM_LOAD_VECrS[0] - tTMEM_LOAD_VECrS[1] + ) + scale = cute.math.exp2(scale_, fastmath=True) + o1_handle = mma_corr_consumer.wait_and_advance() + self.correction_rescale(pv_thr_mma, tOtO1, scale) + vec0_handle.release() + cute.arch.fence_view_async_tmem_store() + o1_handle.release() + # End of seqlen_corr_loop_steps + vec1_handle.release() + + # wait for vec0 (row_wise global sum) + vec0_handle = s0_corr_consumer.wait_and_advance() + tTMEM_LOAD_VECrS = cute.make_fragment( + tTMEM_LOAD_VECcS.shape, self.qk_acc_dtype + ) + cute.copy(tiled_tmem_load_vec, tTMEM_LOAD_VECtS0, tTMEM_LOAD_VECrS) + cute.arch.fence_view_async_tmem_load() + vec0_handle.release() + # wait for o0 + o0_handle = mma_corr_consumer.wait_and_advance() + o0_final_handle = corr_epi_producer.acquire_and_advance() + self.correction_epilog( + pv_thr_mma, + tOtO0, + scale_output / tTMEM_LOAD_VECrS[0], + sO[None, None, 0], + ) + o0_handle.release() + o0_final_handle.commit() + + # wait for vec1 (row_wise global sum) + vec1_handle = s1_corr_consumer.wait_and_advance() + cute.copy(tiled_tmem_load_vec, tTMEM_LOAD_VECtS1, tTMEM_LOAD_VECrS) + cute.arch.fence_view_async_tmem_load() + vec1_handle.release() + # wait for o1 + o1_handle = mma_corr_consumer.wait_and_advance() + o1_final_handle = corr_epi_producer.acquire_and_advance() + self.correction_epilog( + pv_thr_mma, + tOtO1, + scale_output / tTMEM_LOAD_VECrS[0], + sO[None, None, 1], + ) + o1_handle.release() + o1_final_handle.commit() + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + # End of persistent scheduler loop + cute.arch.mbarrier_arrive(tmem_dealloc_mbar_ptr) + return + + @cute.jit + def softmax_step( + self, + stage: int, + need_apply_mask: bool, + iter_args: tuple, + value_args: tuple, + pipeline_args: tuple, + atom_args: tuple, + tensor_args: tuple, + ) -> Tuple[ + Float32, + Float32, + pipeline.PipelineProducer.ImmutableResourceHandle, + pipeline.PipelineConsumer, + pipeline.PipelineProducer, + pipeline.PipelineConsumer, + pipeline.PipelineProducer, + ]: + """Perform a single step of the softmax computation on a block of attention scores. + + This method processes one block of the attention matrix, computing numerically stable + softmax by first finding the row maximum, subtracting it from all elements, applying + exponential function, and then normalizing by the sum of exponentials. It also handles + optional masking of attention scores. + + The method involves several key operations: + 1. Loading attention scores from tensor memory + 2. Applying optional masking based on position + 3. Computing row-wise maximum values for numerical stability + 4. Transforming scores using exp2(x*scale - max*scale) + 5. Computing row sums for normalization + 6. Coordinating pipeline synchronization between different processing stages + + :param stage: Processing stage (0 for first half, 1 for second half) + :type stage: int + :param need_apply_mask: Whether to apply attention masking + :type need_apply_mask: bool + :param iter_args: Tuple containing the counting tensor, row_max, row_sum, and vector buffer's handle for current iteration + :type iter_args: tuple + :param value_args: Tuple containing seqlen_k and scale_softmax_log2 + :type value_args: tuple + :param pipeline_args: Tuple containing pipeline related arguments for MMA, correction, and sequence synchronization + :type pipeline_args: tuple + :param atom_args: Tuple containing mma & copy atoms + :type atom_args: tuple + :param tensor_args: Tuple containing softmax related tensors + :type tensor_args: tuple + :return: Updated state values (row_max, row_sum, and pipeline related arguments) + :rtype: tuple + """ + cS, row_max, row_sum, vec_i_handle = iter_args + seqlen_k, scale_softmax_log2 = value_args + ( + mma_si_consumer, + si_corr_producer, + s0_s1_sequence_consumer, + s0_s1_sequence_producer, + ) = pipeline_args + ( + qk_thr_mma, + tiled_tmem_load, + tiled_tmem_store, + tiled_tmem_store_vec, + thr_tmem_load, + thr_tmem_store, + thr_tmem_store_vec, + ) = atom_args + ( + tTMEM_LOADtS, + tTMEM_STORE_VECtS, + tTMEM_STOREtS_x4, + ) = tensor_args + + tilePlikeFP32 = self.qk_mma_tiler[1] // Float32.width * self.o_dtype.width + tScS = qk_thr_mma.partition_C(cS) + tScS_vec_layout = cute.composition(tScS.layout, cute.make_layout((128, 2))) + tScS_vec = cute.make_tensor(tScS.iterator, tScS_vec_layout) + + tScS_P_layout = cute.composition( + tScS.layout, cute.make_layout((128, tilePlikeFP32)) + ) + tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout) + tTMEM_LOADcS = thr_tmem_load.partition_D(tScS) + tTMEM_STORE_VECcS = thr_tmem_store_vec.partition_S(tScS_vec) + tTMEM_STOREcS = thr_tmem_store.partition_S(tScS_P) + + # Wait for Si + si_handle = mma_si_consumer.wait_and_advance() + tTMEM_LOADrS = cute.make_fragment(tTMEM_LOADcS.shape, self.qk_acc_dtype) + cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS) + if need_apply_mask: + self.apply_mask(tTMEM_LOADrS, tTMEM_LOADcS, seqlen_k) + + old_row_max = row_max + row_max = tTMEM_LOADrS.load().reduce(cute.ReductionOp.MAX, row_max, 0) + row_max_safe = row_max + if row_max == -cutlass.Float32.inf: + row_max_safe = 0.0 + tTMEM_STORE_VECrS = cute.make_fragment( + tTMEM_STORE_VECcS.shape, self.qk_acc_dtype + ) + tTMEM_STORE_VECrS[0] = old_row_max + tTMEM_STORE_VECrS[1] = row_max_safe + cute.copy(tiled_tmem_store_vec, tTMEM_STORE_VECrS, tTMEM_STORE_VECtS) + cute.arch.fence_view_async_tmem_store() + # Notify correction wg that row_max is ready + vec_i_handle.commit() + + tTMEM_STORErS_x4 = cute.make_fragment(tTMEM_STOREcS.shape, self.qk_acc_dtype) + tTMEM_STORErS_x4_e = cute.make_tensor( + cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype), + tTMEM_LOADrS.layout, + ) + + scale = scale_softmax_log2 + minus_row_max_scale = (0.0 - row_max_safe) * scale + + # Sequence barrier wait + if cutlass.const_expr(stage == 0): + sequence_producer_handle = s0_s1_sequence_producer.acquire_and_advance() + else: + sequence_consumer_handle = s0_s1_sequence_consumer.wait_and_advance() + frg_cnt = 4 + frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt + tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile)) + tTMEM_STORErS_x4_e_frg = cute.logical_divide( + tTMEM_STORErS_x4_e, cute.make_layout(frg_tile) + ) + for j in range(frg_cnt): + for k in range(0, cute.size(tTMEM_LOADrS_frg, mode=[0]), 2): + tTMEM_LOADrS_frg[k, j], tTMEM_LOADrS_frg[k + 1, j] = ( + cute.arch.fma_packed_f32x2( + (tTMEM_LOADrS_frg[k, j], tTMEM_LOADrS_frg[k + 1, j]), + (scale, scale), + (minus_row_max_scale, minus_row_max_scale), + ) + ) + tTMEM_LOADrS_frg[k, j] = cute.math.exp2( + tTMEM_LOADrS_frg[k, j], fastmath=True + ) + tTMEM_LOADrS_frg[k + 1, j] = cute.math.exp2( + tTMEM_LOADrS_frg[k + 1, j], fastmath=True + ) + s_vec = tTMEM_LOADrS_frg[None, j].load() + tTMEM_STORErS_x4_e_frg[None, j].store(s_vec.to(self.q_dtype)) + # Sequence barrier arrive + if cutlass.const_expr(stage == 0): + sequence_producer_handle.commit() + else: + sequence_consumer_handle.release() + cute.copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4) + cute.arch.fence_view_async_tmem_store() + # Notify tensor core warp that softmax(S->P) is ready + si_handle.release() + + vec_i_handle = si_corr_producer.acquire_and_advance() + acc_scale_ = scale * (old_row_max - row_max_safe) + acc_scale = cute.math.exp2(acc_scale_, fastmath=True) * 0.5 + row_sum *= acc_scale + local_row_sum_0 = (row_sum, row_sum) + local_row_sum_1 = (0.0, 0.0) + local_row_sum_2 = (0.0, 0.0) + local_row_sum_3 = (0.0, 0.0) + + reduction_unroll = 4 + frg_tile = cute.size(tTMEM_LOADrS) // reduction_unroll + tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile)) + + for j in cutlass.range_constexpr(0, cute.size(tTMEM_LOADrS_frg, mode=[0]), 2): + local_row_sum_0 = cute.arch.add_packed_f32x2( + local_row_sum_0, (tTMEM_LOADrS_frg[j, 0], tTMEM_LOADrS_frg[j + 1, 0]) + ) + local_row_sum_1 = cute.arch.add_packed_f32x2( + local_row_sum_1, (tTMEM_LOADrS_frg[j, 1], tTMEM_LOADrS_frg[j + 1, 1]) + ) + local_row_sum_2 = cute.arch.add_packed_f32x2( + local_row_sum_2, (tTMEM_LOADrS_frg[j, 2], tTMEM_LOADrS_frg[j + 1, 2]) + ) + local_row_sum_3 = cute.arch.add_packed_f32x2( + local_row_sum_3, (tTMEM_LOADrS_frg[j, 3], tTMEM_LOADrS_frg[j + 1, 3]) + ) + + local_row_sum_0 = cute.arch.add_packed_f32x2(local_row_sum_0, local_row_sum_1) + local_row_sum_2 = cute.arch.add_packed_f32x2(local_row_sum_2, local_row_sum_3) + local_row_sum_0 = cute.arch.add_packed_f32x2(local_row_sum_0, local_row_sum_2) + row_sum = local_row_sum_0[0] + local_row_sum_0[1] + + return ( + row_max, + row_sum, + vec_i_handle, + mma_si_consumer, + si_corr_producer, + s0_s1_sequence_consumer, + s0_s1_sequence_producer, + ) + + # For both softmax0 and softmax1 warp group + @cute.jit + def softmax( + self, + stage: int, + seqlen_k: Int32, + cum_seqlen_q: cute.Tensor | None, + cum_seqlen_k: cute.Tensor | None, + scale_softmax_log2: Float32, + qk_thr_mma: cute.core.ThrMma, + tStS: cute.Tensor, + tStSi: cute.Tensor, + mma_si_consumer: pipeline.PipelineConsumer, + si_corr_producer: pipeline.PipelineProducer, + s0_s1_sequence_consumer: pipeline.PipelineConsumer, + s0_s1_sequence_producer: pipeline.PipelineProducer, + tile_sched_params: FmhaStaticTileSchedulerParams, + ): + """Compute softmax on attention scores from QK matrix multiplication. + + This method handles the softmax computation for either the first or second half of the + attention matrix, depending on the 'stage' parameter. It calculates row-wise maximum + and sum values needed for stable softmax computation, applies optional masking, and + transforms raw attention scores into probability distributions. + + The implementation uses specialized memory access patterns and efficient math operations + for computing exp(x) using exp2 functions. It also coordinates pipeline + synchronization between MMA, correction, and sequence processing stages. + + :param stage: Processing stage (0 for first half, 1 for second half of attention matrix) + :type stage: int + :param scale_softmax_log2: Log2 scale factor for softmax operation + :type scale_softmax_log2: Float32 + :param qk_thr_mma: Thread MMA operation for QK matrix multiplication + :type qk_thr_mma: cute.core.ThrMma + :param tStS: Shared tensor for softmax input/output + :type tStS: cute.Tensor + :param tStSi: Input tensor containing attention scores + :type tStSi: cute.Tensor + :param mma_si_pipeline: Pipeline for synchronizing with MMA operations + :type mma_si_pipeline: pipeline.PipelineAsync + :param si_corr_pipeline: Pipeline for synchronizing with correction operations + :type si_corr_pipeline: pipeline.PipelineAsync + :param s0_s1_sequence_pipeline: Pipeline for synchronizing between stage 0 and 1 + :type s0_s1_sequence_pipeline: pipeline.PipelineAsync + :param tile_sched_params: Parameters for tile scheduling + :type tile_sched_params: FmhaStaticTileSchedulerParams + """ + tidx, _, _ = cute.arch.thread_idx() + thread_idx = tidx % ( + self.threads_per_warp + * ( + len(self.softmax0_warp_ids) + if stage == 0 + else len(self.softmax1_warp_ids) + ) + ) + + cS_base = cute.make_identity_tensor( + (self.qk_mma_tiler[0], self.qk_mma_tiler[1]) + ) + tilePlikeFP32 = self.qk_mma_tiler[1] // 32 * self.o_dtype.width + tScS = qk_thr_mma.partition_C(cS_base) + tStS_vec_layout = cute.composition(tStS.layout, cute.make_layout((128, 2))) + tmem_vec_offset = self.tmem_vec0_offset if stage == 0 else self.tmem_vec1_offset + tStS_vec = cute.make_tensor(tStS.iterator + tmem_vec_offset, tStS_vec_layout) + tScS_vec_layout = cute.composition(tScS.layout, cute.make_layout((128, 2))) + tScS_vec = cute.make_tensor(tScS.iterator, tScS_vec_layout) + tStS_P_layout = cute.composition( + tStS.layout, cute.make_layout((128, tilePlikeFP32)) + ) + tmem_p_offset = self.tmem_p0_offset if stage == 0 else self.tmem_p1_offset + tStS_P = cute.make_tensor(tStS.iterator + tmem_p_offset, tStS_P_layout) + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), + self.qk_acc_dtype, + ) + tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStSi) + thread_idx = tidx % ( + self.threads_per_warp + * ( + len(self.softmax0_warp_ids) + if stage == 0 + else len(self.softmax1_warp_ids) + ) + ) + thr_tmem_load = tiled_tmem_load.get_slice(thread_idx) + tTMEM_LOADtS = thr_tmem_load.partition_S(tStSi) + tmem_store_vec_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(2)), + self.qk_acc_dtype, + ) + tiled_tmem_store_vec = tcgen05.make_tmem_copy(tmem_store_vec_atom, tStS_vec) + thr_tmem_store_vec = tiled_tmem_store_vec.get_slice(thread_idx) + tTMEM_STORE_VECtS = thr_tmem_store_vec.partition_D(tStS_vec) + tTMEM_STORE_VECcS = thr_tmem_store_vec.partition_S(tScS_vec) + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), + self.qk_acc_dtype, + ) + tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P) + thr_tmem_store = tiled_tmem_store.get_slice(thread_idx) + tTMEM_STOREtS_x4 = thr_tmem_store.partition_D(tStS_P) + + tile_sched = create_fmha_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + while work_tile.is_valid_tile: + curr_block_coord = work_tile.tile_idx + batch_coord = curr_block_coord[2][1] + seqlen_k_ = seqlen_k + continue_cond = False + + if cutlass.const_expr(cum_seqlen_q is not None): + cuseqlen_q = cum_seqlen_q[batch_coord] + seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q + continue_cond = ( + not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.cta_tiler[0], + curr_block_coord[0], + seqlen_q, + ) + ) + + if not continue_cond: + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k_ = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + row_max = -Float32.inf + row_sum = 0.0 + value_args = (seqlen_k_, scale_softmax_log2) + atom_args = ( + qk_thr_mma, + tiled_tmem_load, + tiled_tmem_store, + tiled_tmem_store_vec, + thr_tmem_load, + thr_tmem_store, + thr_tmem_store_vec, + ) + tensor_args = ( + tTMEM_LOADtS, + tTMEM_STORE_VECtS, + tTMEM_STOREtS_x4, + ) + + logical_offset = ( + curr_block_coord[0] * self.cta_tiler[0] + + stage * self.qk_mma_tiler[0], + 0, + ) + cS = cute.domain_offset(logical_offset, cS_base) + vec_i_handle = si_corr_producer.acquire_and_advance() + unmask_count = self.get_unmasked_trip_count( + curr_block_coord, + self.cta_tiler, + seqlen_k_, + ) + for i in cutlass.range(0, unmask_count, 1, unroll=1): + cS_iter = cute.domain_offset((0, i * self.qk_mma_tiler[1]), cS) + iter_args = (cS_iter, row_max, row_sum, vec_i_handle) + pipeline_args = ( + mma_si_consumer, + si_corr_producer, + s0_s1_sequence_consumer, + s0_s1_sequence_producer, + ) + ( + row_max, + row_sum, + vec_i_handle, + mma_si_consumer, + si_corr_producer, + s0_s1_sequence_consumer, + s0_s1_sequence_producer, + ) = self.softmax_step( + stage, + False, + iter_args, + value_args, + pipeline_args, + atom_args, + tensor_args, + ) + mask_count = self.get_masked_trip_count( + curr_block_coord, + self.cta_tiler, + seqlen_k_, + ) + + for i in cutlass.range( + unmask_count, unmask_count + mask_count, 1, unroll=1 + ): + cS_iter = cute.domain_offset((0, i * self.qk_mma_tiler[1]), cS) + iter_args = (cS_iter, row_max, row_sum, vec_i_handle) + pipeline_args = ( + mma_si_consumer, + si_corr_producer, + s0_s1_sequence_consumer, + s0_s1_sequence_producer, + ) + ( + row_max, + row_sum, + vec_i_handle, + mma_si_consumer, + si_corr_producer, + s0_s1_sequence_consumer, + s0_s1_sequence_producer, + ) = self.softmax_step( + stage, + True, + iter_args, + value_args, + pipeline_args, + atom_args, + tensor_args, + ) + si_handle = mma_si_consumer.wait_and_advance() + tTMEM_STORE_VECrS = cute.make_fragment( + tTMEM_STORE_VECcS.shape, self.qk_acc_dtype + ) + tTMEM_STORE_VECrS[0] = row_sum + tTMEM_STORE_VECrS[1] = row_max + cute.copy(tiled_tmem_store_vec, tTMEM_STORE_VECrS, tTMEM_STORE_VECtS) + cute.arch.fence_view_async_tmem_store() + vec_i_handle.commit() + si_corr_producer.acquire() + # Empty step to sync against pipe s + si_handle.release() + + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + # End of persistent scheduler loop + + @cute.jit + def correction_rescale( + self, + thr_mma: cute.core.ThrMma, + tOtO: cute.Tensor, + scale: Float32, + ): + """Rescale intermediate attention results based on softmax normalization factor. + + This method performs a crucial correction step in the attention computation pipeline. + When processing attention in blocks, the softmax normalization factors may change + as new blocks are processed. This method rescales previously computed partial + output values to account for updated normalization factors. + + The implementation uses efficient tensor memory operations to: + 1. Load existing partial attention output from tensor memory + 2. Apply the scaling factor to all elements + 3. Store the rescaled results back to tensor memory + + :param thr_mma: Thread MMA operation for the computation + :type thr_mma: cute.core.ThrMma + :param tOtO: Tensor representing partial attention output to be rescaled + :type tOtO: cute.Tensor + :param scale: Scaling factor to apply to the partial results + :type scale: Float32 + """ + pv_tiled_mma_shape = ( + self.pv_mma_tiler[0], + self.pv_mma_tiler[1], + ) + cO = cute.make_identity_tensor(pv_tiled_mma_shape) + tOcO = thr_mma.partition_C(cO) + + corr_tile_size = 16 # tuneable parameter + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), + self.pv_acc_dtype, + ) + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), + self.pv_acc_dtype, + ) + + tOtO_i_layout = cute.composition( + tOtO.layout, cute.make_layout((128, corr_tile_size)) + ) + tOcO_i_layout = cute.composition( + tOcO.layout, cute.make_layout((128, corr_tile_size)) + ) + + tOtO_i = cute.make_tensor(tOtO.iterator, tOtO_i_layout) + tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout) + + tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tOtO_i) + tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tOtO_i) + tidx, _, _ = cute.arch.thread_idx() + thread_idx = tidx % (self.threads_per_warp * len(self.correction_warp_ids)) + thr_tmem_load = tiled_tmem_load.get_slice(thread_idx) + thr_tmem_store = tiled_tmem_store.get_slice(thread_idx) + + tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i) + tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i) + + tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i) + + tTMrO = cute.make_fragment( + (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.pv_acc_dtype + ) + for i in range(self.cta_tiler[2] // corr_tile_size): + tTMrO_i_ = tTMrO[None, i] + tTMrO_i_layout = cute.composition( + tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0]) + ) + tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout) + tTMEM_LOADtO_i = cute.make_tensor( + tTMEM_LOADtO.iterator + i * corr_tile_size, tTMEM_LOADtO.layout + ) + tTMEM_STOREtO_i = cute.make_tensor( + tTMEM_STOREtO.iterator + i * corr_tile_size, tTMEM_STOREtO.layout + ) + + cute.copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO_i) + for j in range(0, cute.size(tTMrO_i), 2): + tTMrO_i[j], tTMrO_i[j + 1] = cute.arch.mul_packed_f32x2( + (tTMrO_i[j], tTMrO_i[j + 1]), + (scale, scale), + ) + cute.copy(tiled_tmem_store, tTMrO_i, tTMEM_STOREtO_i) + + @cute.jit + def correction_epilog( + self, + thr_mma: cute.core.ThrMma, + tOtO: cute.Tensor, + scale: Float32, + sO: cute.Tensor, + ): + """Apply final scaling and transformation to attention output before writing to global memory. + + This correction_epilog function handles the final processing step for attention output values. + It applies a scaling factor to the accumulated attention results and prepares the + data for efficient transfer back to global memory. + + The method performs: + 1. Loading of accumulated attention results from tensor memory + 2. Application of the final output scaling factor + 3. Type conversion if necessary (typically from higher precision accumulator to output precision) + 4. Reorganization of data for optimal memory access patterns + 5. Preparation for efficient TMA store operations + + :param thr_mma: Thread MMA operation for the computation + :type thr_mma: cute.core.ThrMma + :param tOtO: Tensor containing accumulated attention output + :type tOtO: cute.Tensor + :param scale: Final scaling factor to apply to the output + :type scale: Float32 + :param sO: Shared memory tensor for the final output + :type sO: cute.Tensor + """ + + pv_tiled_mma_shape = ( + self.pv_mma_tiler[0], + self.pv_mma_tiler[1], + ) + cO = cute.make_identity_tensor(pv_tiled_mma_shape) + + corr_tile_size = 32 * 8 // self.o_dtype.width + tOsO = thr_mma.partition_C(sO) + tOcO = thr_mma.partition_C(cO) + + tOtO_i = cute.logical_divide(tOtO, cute.make_layout((128, corr_tile_size))) + tOcO_i = cute.logical_divide(tOcO, cute.make_layout((128, corr_tile_size))) + tOsO_i = cute.logical_divide(tOsO, cute.make_layout((128, corr_tile_size))) + tidx, _, _ = cute.arch.thread_idx() + thread_idx = tidx % (self.threads_per_warp * len(self.correction_warp_ids)) + + epi_subtile = (self.epi_tile[0], corr_tile_size) + tmem_copy_atom = sm100_utils.get_tmem_load_op( + self.pv_mma_tiler, + self.o_layout, + self.o_dtype, + self.pv_acc_dtype, + epi_subtile, + use_2cta_instrs=False, + ) + + tiled_tmem_load = tcgen05.make_tmem_copy( + tmem_copy_atom, tOtO_i[(None, None), 0] + ) + + thr_tmem_load = tiled_tmem_load.get_slice(thread_idx) + smem_copy_atom = sm100_utils.get_smem_store_op( + self.o_layout, self.o_dtype, self.pv_acc_dtype, tiled_tmem_load + ) + tiled_smem_store = cute.make_tiled_copy_D(smem_copy_atom, tiled_tmem_load) + + tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i[(None, None), None]) + tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i[(None, None), None]) + tTMEM_LOADoO = thr_tmem_load.partition_D(tOcO_i[(None, None), None]) + + for i in range(self.cta_tiler[2] // corr_tile_size): + tTMEM_LOADtO_i = tTMEM_LOADtO[None, 0, 0, i] + tTMEM_LOADsO_i = tTMEM_LOADsO[None, 0, 0, i] + tTMrO = cute.make_fragment( + tTMEM_LOADoO[None, 0, 0, i].shape, self.pv_acc_dtype + ) + cute.copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO) + for j in range(0, cute.size(tTMrO), 2): + tTMrO[j], tTMrO[j + 1] = cute.arch.mul_packed_f32x2( + (tTMrO[j], tTMrO[j + 1]), + (scale, scale), + ) + tSMrO = cute.make_fragment(tTMrO.shape, self.o_dtype) + o_vec = tTMrO.load() + tSMrO.store(o_vec.to(self.o_dtype)) + cute.copy(tiled_smem_store, tSMrO, tTMEM_LOADsO_i) + + # fence view async shared + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + + def get_trip_count( + self, + blk_coord: cute.Coord, + tile_shape: cute.Shape, + seqlen_k: Int32, + ) -> Int32: + result = 0 + if ( + self.mask_type == MaskType.NO_MASK + or self.mask_type == MaskType.RESIDUAL_MASK + ): + result = cute.ceil_div(seqlen_k, tile_shape[1]) + elif self.mask_type == MaskType.CAUSAL_MASK: + max_blocks_k = cute.ceil_div(seqlen_k, tile_shape[1]) + max_blocks_q = cute.ceil_div( + (blk_coord[0] + 1) * tile_shape[0], tile_shape[1] + ) + result = cutlass.min(max_blocks_k, max_blocks_q) + return result + + @cute.jit + def get_masked_trip_count( + self, + blk_coord: cute.Coord, + tile_shape: cute.Shape, + seqlen_k: Int32, + ) -> Int32: + result = 0 + if self.mask_type == MaskType.NO_MASK: + result = 0 + elif self.mask_type == MaskType.RESIDUAL_MASK: + if seqlen_k % tile_shape[1] != 0: + result = 1 + else: + result = 0 + elif self.mask_type == MaskType.CAUSAL_MASK: + trip_count = self.get_trip_count(blk_coord, tile_shape, seqlen_k) + result = cutlass.min( + trip_count, + cute.ceil_div(tile_shape[0], tile_shape[1]), + ) + return result + + @cute.jit + def get_unmasked_trip_count( + self, + blk_coord: cute.Coord, + tile_shape: cute.Shape, + seqlen_k: Int32, + ) -> Int32: + result = 0 + if self.mask_type == MaskType.NO_MASK: + result = self.get_trip_count(blk_coord, tile_shape, seqlen_k) + elif self.mask_type == MaskType.RESIDUAL_MASK: + if seqlen_k % tile_shape[1] != 0: + result = self.get_trip_count(blk_coord, tile_shape, seqlen_k) - 1 + else: + result = self.get_trip_count(blk_coord, tile_shape, seqlen_k) + elif self.mask_type == MaskType.CAUSAL_MASK: + result = self.get_trip_count( + blk_coord, tile_shape, seqlen_k + ) - self.get_masked_trip_count(blk_coord, tile_shape, seqlen_k) + return result + + @cute.jit + def apply_mask( + self, + acc_qk: cute.Tensor, + index_qk: cute.Tensor, + seqlen_k: Int32, + ): + if self.mask_type == MaskType.RESIDUAL_MASK: + for i in range(cute.size(acc_qk)): + pos = index_qk[i] + if pos[1] >= seqlen_k: + acc_qk[i] = -Float32.inf + elif self.mask_type == MaskType.CAUSAL_MASK: + for i in range(cute.size(acc_qk)): + pos = index_qk[i] + if pos[0] < pos[1] or pos[1] >= seqlen_k: + acc_qk[i] = -Float32.inf + + @staticmethod + def _compute_grid( + o_shape: cute.Shape, + cta_tiler: Tuple[int, int, int], + is_persistent: bool, + ) -> Tuple[FmhaStaticTileSchedulerParams, Tuple[int, int, int]]: + tile_sched_params = create_fmha_static_tile_scheduler_params( + is_persistent, + ( + cute.ceil_div(cute.size(o_shape[0]), cta_tiler[0]), + cute.size(o_shape[2][0]), + cute.size(o_shape[2][1]), + ), + ) + grid = FmhaStaticTileScheduler.get_grid_shape(tile_sched_params) + return tile_sched_params, grid + + +def run( + q_shape: Tuple[int, int, int, int] | Tuple[int, Tuple[int, ...], int, int], + k_shape: Tuple[int, int, int, int] | Tuple[int, Tuple[int, ...], int, int], + in_dtype: Type[cutlass.Numeric], + out_dtype: Type[cutlass.Numeric], + qk_acc_dtype: Type[cutlass.Numeric], + pv_acc_dtype: Type[cutlass.Numeric], + mma_tiler_mn: Tuple[int, int], + is_persistent: bool, + is_causal: bool, + scale_q: float, + scale_k: float, + scale_v: float, + inv_scale_o: float, + scale_softmax: float, + tolerance: float, + warmup_iterations: int, + iterations: int, + skip_ref_check: bool, + use_cold_l2: bool = False, + **kwargs, +): + """Execute Fused Multi-Head Attention (FMHA) on Blackwell architecture and validate results. + + This function creates random input tensors for query, key, and value, then performs the + complete FMHA computation pipeline. It supports configurable data types, tiling parameters, + and various attention masking options. Results can be validated against a PyTorch reference + implementation or run multiple times for performance measurement. + + The implementation leverages specialized tensor memory operations and efficient math + operations optimized for Blackwell architecture, including pipelined computation stages + for maximum throughput. + + :param q_shape: Query tensor shape (B, S_q, H, D) where B=batch size, S_q=query sequence length, + H=number of heads, D=head dimension. + If S_q is a tuple, it is the variable sequence length. + :type q_shape: Tuple[int, int, int, int] | Tuple[int, Tuple[int, ...], int, int] + :param k_shape: Key tensor shape (B, S_k, H_k, D) where B=batch size, S_k=key sequence length, + H_k=number of key heads (H must be divisible by H_k), D=head dimension. + If S_k is a tuple, it is the variable sequence length. + :type k_shape: Tuple[int, int, int, int] | Tuple[int, Tuple[int, ...], int, int] + :param in_dtype: Input data type for query, key and value tensors + :type in_dtype: Type[cutlass.Numeric] + :param out_dtype: Output data type for attention output + :type out_dtype: Type[cutlass.Numeric] + :param qk_acc_dtype: Accumulator data type for query-key matrix multiplication + :type qk_acc_dtype: Type[cutlass.Numeric] + :param pv_acc_dtype: Accumulator data type for probability-value matrix multiplication + :type pv_acc_dtype: Type[cutlass.Numeric] + :param mma_tiler_mn: Matrix multiply accumulate tile shape (M, N) + :type mma_tiler_mn: Tuple[int, int] + :param is_persistent: Whether to use persistent kernel optimization + :type is_persistent: bool + :param is_causal: Whether to apply causal masking + :type is_causal: bool + :param scale_q: Scaling factor for query tensor + :type scale_q: float + :param scale_k: Scaling factor for key tensor + :type scale_k: float + :param scale_v: Scaling factor for value tensor + :type scale_v: float + :param inv_scale_o: Inverse scaling factor for output tensor + :type inv_scale_o: float + :param scale_softmax: Attention score scaling factor (defaults to 1/sqrt(D) if set to 0) + :type scale_softmax: float + :param tolerance: Maximum acceptable error for validation + :type tolerance: float + :param warmup_iterations: Number of warmup iterations + :type warmup_iterations: int + :param iterations: Number of iterations to run for performance testing + :type iterations: int + :param skip_ref_check: Skip validation against reference implementation + :type skip_ref_check: bool + :param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache + :type use_cold_l2: bool + + :raises ValueError: If input shapes are incompatible or head dimension is unsupported + :raises RuntimeError: If GPU is unavailable for computation + :return: Execution time of the FMHA kernel in microseconds + :rtype: float + """ + + print(f"Running Blackwell SM100 FMHA test with:") + print(f" q_shape: {q_shape}") + print(f" k_shape: {k_shape}") + print(f" in_dtype: {in_dtype}") + print(f" out_dtype: {out_dtype}") + print(f" qk_acc_dtype: {qk_acc_dtype}") + print(f" pv_acc_dtype: {pv_acc_dtype}") + print(f" mma_tiler_mn: {mma_tiler_mn}") + print(f" is_persistent: {is_persistent}") + print(f" is_causal: {is_causal}") + print(f" scale_q: {scale_q}") + print(f" scale_k: {scale_k}") + print(f" scale_v: {scale_v}") + print(f" inv_scale_o: {inv_scale_o}") + print(f" scale_softmax: {scale_softmax}") + print(f" tolerance: {tolerance}") + print(f" warmup_iterations: {warmup_iterations}") + print(f" iterations: {iterations}") + print(f" skip_ref_check: {skip_ref_check}") + print(f" use_cold_l2: {use_cold_l2}") + + # Unpack parameters + b, s_q, h_q, d = q_shape + b_, s_k, h_k, d_ = k_shape + + if b != b_: + raise ValueError("q & k must have the same batch size") + + if d != d_: + raise ValueError("q & k must have the same head dimension") + + if d not in {32, 64, 128}: + raise ValueError("head dimension must be 32, 64, or 128") + + if h_q % h_k != 0: + raise ValueError("h_q must be divisible by h_k") + + if isinstance(s_q, tuple) and len(s_q) != b: + raise ValueError("variable_seqlen s_q must have the length of batch size") + if isinstance(s_k, tuple) and len(s_k) != b: + raise ValueError("variable_seqlen s_k must have the length of batch size") + + if in_dtype not in {cutlass.Float8E4M3FN, cutlass.Float16}: + raise ValueError("in_dtype must be Float8E4M3FN or Float16") + + if out_dtype not in {cutlass.Float8E4M3FN, cutlass.Float16}: + raise ValueError("out_dtype must be Float8E4M3FN or Float16") + + if qk_acc_dtype not in {Float32}: + raise ValueError("qk_acc_dtype must be Float32") + + if pv_acc_dtype not in {Float32}: + raise ValueError("pv_acc_dtype must be Float32") + + if iterations < 1: + raise ValueError("iterations must be at least 1") + + h_r = h_q // h_k + + # Prepare pytorch tensors: Q, K, V (random from 0 to 2) and O (all zero) + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + torch.manual_seed(1111) + + def create_cumulative_sequence_lengths(s): + s_cumsum = [0] + for i in range(len(s)): + s_cumsum.append(s_cumsum[-1] + s[i]) + + s_cumsum_cute_tensor, s_cumsum_torch_tensor = cutlass_torch.cute_tensor_like( + torch.tensor(s_cumsum, dtype=torch.int32), + Int32, + is_dynamic_layout=True, + assumed_align=16, + ) + + return s_cumsum_cute_tensor, s_cumsum_torch_tensor + + cum_seqlen_q, cum_seqlen_q_torch = ( + create_cumulative_sequence_lengths(s_q) + if isinstance(s_q, tuple) + else (None, None) + ) + cum_seqlen_k, cum_seqlen_k_torch = ( + create_cumulative_sequence_lengths(s_k) + if isinstance(s_k, tuple) + else (None, None) + ) + + def create_and_pad_tensor( + shape, padding, dtype, s_cumsum=None, is_dynamic_layout=True + ): + # (b, s, h, d) + shape_ = tuple(map(lambda x, y: x + y, shape, padding)) + if s_cumsum is not None: + if shape_[0] != 1 or padding[0] != 0: + raise ValueError("Invalid tensor creation for variable sequence length") + # (s_total + padding, h, d) + shape_ = shape_[1:] + padding = padding[1:] + + # Create f32 torch tensor (cpu) + f32_torch_tensor_full = cutlass_torch.create_and_permute_torch_tensor( + shape_, + torch.float32, + permute_order=None, + init_type=cutlass.torch.TensorInitType.RANDOM, + init_config=cutlass.torch.RandomInitConfig( + min_val=-2 if dtype.is_float or dtype.signed else 0, max_val=2 + ), + ) + # Create dtype cute & torch tensor (gpu) + _, torch_tensor_full = cutlass_torch.cute_tensor_like( + f32_torch_tensor_full, + dtype, + is_dynamic_layout, + assumed_align=16, + ) + + # Offset the tensor + slices = tuple(slice(s, e) for s, e in zip(padding, shape_)) + torch_tensor = torch_tensor_full[slices].detach() + f32_torch_tensor = f32_torch_tensor_full[slices].detach() + torch_tensor._keep_alive = torch_tensor_full + f32_torch_tensor._keep_alive = f32_torch_tensor_full + + # Create dtype cute tensor with offset (gpu) + cute_tensor = from_dlpack(torch_tensor, assumed_align=16) + cute_tensor.element_type = dtype + + # From ragged to jagged + if s_cumsum is not None: + torch_tensor = torch.nested.nested_tensor_from_jagged( + values=torch_tensor, offsets=s_cumsum + ) + f32_torch_tensor = torch.nested.nested_tensor_from_jagged( + values=f32_torch_tensor, offsets=s_cumsum.cpu() + ) + + return ( + f32_torch_tensor, + cute_tensor, + torch_tensor, + ) + + qo_shape = (b, s_q, h_r * h_k, d) + kv_shape = (b, s_k, h_k, d) + qo_padding = (0, 0, 0, 0, 0) + kv_padding = (0, 0, 0, 0, 0) + + if isinstance(s_q, tuple): + qo_shape = (1, sum(s_q), h_r * h_k, d) + qo_padding = (0, max(s_q), 0, 0, 0) + + if isinstance(s_k, tuple): + kv_shape = (1, sum(s_k), h_k, d) + kv_padding = (0, max(s_k), 0, 0, 0) + + q_ref, q_tensor, q_torch = create_and_pad_tensor( + qo_shape, + qo_padding, + in_dtype, + s_cumsum=cum_seqlen_q_torch, + is_dynamic_layout=True, + ) + k_ref, k_tensor, k_torch = create_and_pad_tensor( + kv_shape, + kv_padding, + in_dtype, + s_cumsum=cum_seqlen_k_torch, + is_dynamic_layout=True, + ) + v_ref, v_tensor, v_torch = create_and_pad_tensor( + kv_shape, + kv_padding, + in_dtype, + s_cumsum=cum_seqlen_k_torch, + is_dynamic_layout=True, + ) + _, o_tensor, o_torch = create_and_pad_tensor( + qo_shape, + qo_padding, + out_dtype, + s_cumsum=cum_seqlen_q_torch, + is_dynamic_layout=True, + ) + + mma_tiler = (*mma_tiler_mn, d) + + mask_type = MaskType.NO_MASK + if is_causal: + mask_type = MaskType.CAUSAL_MASK + else: + if isinstance(s_k, tuple): + for i in range(len(s_k)): + if s_k[i] % mma_tiler_mn[1] != 0: + mask_type = MaskType.RESIDUAL_MASK + else: + if s_k % mma_tiler_mn[1] != 0: + mask_type = MaskType.RESIDUAL_MASK + + fmha = BlackwellFusedMultiHeadAttentionForward( + qk_acc_dtype, + pv_acc_dtype, + mma_tiler, + is_persistent, + mask_type, + ) + + # Initialize Stream + current_stream = cutlass_torch.default_stream() + + if scale_softmax == 0.0: # default to 1/sqrt(d) + scale_softmax = 1.0 / math.sqrt(d) + log2_e = math.log2( + math.exp(1.0) + ) # gpu uses exp2 for perf concerns, we need an extra factor 'log2_e' here + + scale_softmax = scale_q * scale_k * scale_softmax + scale_softmax_log2 = scale_softmax * log2_e + scale_output = scale_v * inv_scale_o + + problem_size = ( + b, + max(s_q) if isinstance(s_q, tuple) else s_q, + max(s_k) if isinstance(s_k, tuple) else s_k, + h_q, + h_k, + d, + ) + + print("Compiling kernel with cute.compile ...") + start_time = time.time() + # compile fmha kernel + compiled_fmha = cute.compile( + fmha, + q_tensor.iterator, + k_tensor.iterator, + v_tensor.iterator, + o_tensor.iterator, + problem_size, + cum_seqlen_q, + cum_seqlen_k, + scale_softmax_log2, + scale_output, + current_stream, + ) + compilation_time = time.time() - start_time + print(f"Compilation time: {compilation_time:.4f} seconds") + + def run_torch_fmha(q, k, v, scale_softmax=1.0, scale_output=1.0, is_causal=False): + h_q = q.shape[2] + h_k = k.shape[2] + + # as we initialize q, k, v with shape (b, s, h, d) and SDPA of torch needs them to be (b, h, s, d) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # For the situation that torch has not supported, we need to handle it manually + situation1 = (h_q != h_k or is_causal) and (q.is_nested or k.is_nested) + situation2 = (q.is_nested and not k.is_nested) or ( + not q.is_nested and k.is_nested + ) + if situation1 or situation2: + # Once torch supports the situation, we can remove this fallback + batch_size = q.size(0) + ref_list = [] + for batch_idx in range(batch_size): + q_i = q[batch_idx] + k_i = k[batch_idx] + v_i = v[batch_idx] + + ref_i = F.scaled_dot_product_attention( + q_i, + k_i, + v_i, + attn_mask=None, + dropout_p=0.0, + scale=scale_softmax, + is_causal=is_causal, + enable_gqa=(h_q != h_k), + ) + ref_i = ref_i.transpose(0, 1) * scale_output + ref_list.append(ref_i) + if q.is_nested: + ref = torch.nested.nested_tensor(ref_list, layout=torch.jagged) + else: + ref = torch.stack(ref_list) + else: + ref = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + dropout_p=0.0, + scale=scale_softmax, + is_causal=is_causal, + enable_gqa=(h_q != h_k), + ) + ref = ref.transpose(1, 2) * scale_output + return ref + + if not skip_ref_check: + # Execute kernel once for reference checking + compiled_fmha( + q_tensor.iterator, + k_tensor.iterator, + v_tensor.iterator, + o_tensor.iterator, + problem_size, + cum_seqlen_q, + cum_seqlen_k, + scale_softmax_log2, + scale_output, + current_stream, + ) + print("Verifying results...") + o_ref = run_torch_fmha( + q_ref, k_ref, v_ref, scale_softmax, scale_output, is_causal + ) + + if o_ref.is_nested: + o_ref = o_ref.values() + + if o_torch.is_nested: + o_torch = o_torch.values() + + # convert o back to f32 for comparison + o_fp32, o_fp32_torch = cutlass_torch.cute_tensor_like( + torch.empty(*o_torch.shape, dtype=torch.float32), + Float32, + is_dynamic_layout=True, + assumed_align=16, + ) + cute.testing.convert(o_tensor, o_fp32) + o_result = o_fp32_torch.cpu() + + if out_dtype.is_float and out_dtype.width <= 8: + ref_narrow_precision, _ = cutlass_torch.cute_tensor_like( + torch.empty(*o_ref.shape, dtype=torch.uint8), + out_dtype, + is_dynamic_layout=True, + assumed_align=16, + ) + + ref_o_f32, ref_o_f32_torch = cutlass_torch.cute_tensor_like( + o_ref, + cutlass.Float32, + is_dynamic_layout=True, + assumed_align=16, + ) + + # convert ref : f32 -> fp4/fp8 -> f32 + cute.testing.convert(ref_o_f32, ref_narrow_precision) + cute.testing.convert(ref_narrow_precision, ref_o_f32) + + o_ref = ref_o_f32_torch.cpu() + + # override tolerance + tolerance = 0.13 + + # Assert close results + torch.testing.assert_close(o_result, o_ref, atol=tolerance, rtol=1e-05) + print("Results verified successfully!") + + def generate_tensors(): + _, q_tensor_workspace, _ = create_and_pad_tensor( + qo_shape, + qo_padding, + in_dtype, + s_cumsum=cum_seqlen_q_torch, + is_dynamic_layout=True, + ) + _, k_tensor_workspace, _ = create_and_pad_tensor( + kv_shape, + kv_padding, + in_dtype, + s_cumsum=cum_seqlen_k_torch, + is_dynamic_layout=True, + ) + _, v_tensor_workspace, _ = create_and_pad_tensor( + kv_shape, + kv_padding, + in_dtype, + s_cumsum=cum_seqlen_k_torch, + is_dynamic_layout=True, + ) + _, o_tensor_workspace, _ = create_and_pad_tensor( + qo_shape, + qo_padding, + out_dtype, + s_cumsum=cum_seqlen_q_torch, + is_dynamic_layout=True, + ) + return testing.JitArguments( + q_tensor_workspace.iterator, + k_tensor_workspace.iterator, + v_tensor_workspace.iterator, + o_tensor_workspace.iterator, + problem_size, + cum_seqlen_q, + cum_seqlen_k, + scale_softmax_log2, + scale_output, + current_stream, + ) + + workspace_count = 1 + if use_cold_l2: + q_torch_effective = q_torch.values() if q_torch.is_nested else q_torch + k_torch_effective = k_torch.values() if k_torch.is_nested else k_torch + v_torch_effective = v_torch.values() if v_torch.is_nested else v_torch + o_torch_effective = o_torch.values() if o_torch.is_nested else o_torch + one_workspace_bytes = ( + q_torch_effective.numel() * q_torch_effective.element_size() + + k_torch_effective.numel() * k_torch_effective.element_size() + + v_torch_effective.numel() * v_torch_effective.element_size() + + o_torch_effective.numel() * o_torch_effective.element_size() + ) + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + exec_time = testing.benchmark( + compiled_fmha, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=current_stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + return exec_time # Return execution time in microseconds + + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str): + try: + return tuple(int(x.strip()) for x in s.split(",")) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + def parse_nested_comma_separated_ints(s: str): + try: + s = s.strip() + if "(" not in s: + return tuple(int(x.strip()) for x in s.split(",")) + + start = s.find("(") + end = s.find(")") + if start == -1 or end == -1: + raise ValueError("Mismatched parentheses") + + before = s[:start].strip().rstrip(",") + middle = s[start + 1 : end].strip() + after = s[end + 1 :].strip().lstrip(",") + + result = [] + if before: + result.extend(int(x.strip()) for x in before.split(",")) + + if middle: + nested_tuple = tuple(int(x.strip()) for x in middle.split(",")) + result.append(nested_tuple) + + if after: + result.extend(int(x.strip()) for x in after.split(",")) + + return tuple(result) + + except ValueError as e: + if str(e) == "Mismatched parentheses": + raise argparse.ArgumentTypeError("Mismatched parentheses in input") + else: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers with optional parentheses for nested tuple." + ) + + parser = argparse.ArgumentParser(description="Example of FMHA on Blackwell.") + + parser.add_argument( + "--in_dtype", + type=cutlass.dtype, + default=cutlass.Float16, + help="Input data type", + ) + + parser.add_argument( + "--out_dtype", + type=cutlass.dtype, + default=cutlass.Float16, + help="Output data type", + ) + + parser.add_argument( + "--qk_acc_dtype", + type=cutlass.dtype, + default=Float32, + help="QK accumulator data type", + ) + + parser.add_argument( + "--pv_acc_dtype", + type=cutlass.dtype, + default=Float32, + help="PV accumulator data type", + ) + + parser.add_argument( + "--mma_tiler_mn", + type=parse_comma_separated_ints, + default=(128, 128), + help="MMA tile shape (M, N)", + ) + + parser.add_argument( + "--is_persistent", + action="store_true", + help="Is persistent", + ) + + parser.add_argument( + "--is_causal", + action="store_true", + help="Whether to use casual mask", + ) + + parser.add_argument( + "--q_shape", + type=parse_nested_comma_separated_ints, + default=(1, 256, 8, 128), + help="Shape of Q (B, S_q, H, D)", + ) + + parser.add_argument( + "--k_shape", + type=parse_nested_comma_separated_ints, + default=(1, 256, 8, 128), + help="Shape of K (B, S_k, H_k, D)", + ) + + parser.add_argument( + "--scale_q", + type=float, + default=1.0, + help="Scaling factors to dequantize Q", + ) + + parser.add_argument( + "--scale_k", + type=float, + default=1.0, + help="Scaling factors to dequantize K", + ) + + parser.add_argument( + "--scale_v", + type=float, + default=1.0, + help="Scaling factors to dequantize V", + ) + + parser.add_argument( + "--inv_scale_o", + type=float, + default=1.0, + help="Scaling factor to quantize O", + ) + + parser.add_argument( + "--scale_softmax", + type=float, + default=0.0, + help="Scaling factor to scale S (i.e. Q*K); if zero, defaults to 1/sqrt(D)", + ) + + parser.add_argument( + "--tolerance", type=float, default=1e-01, help="Tolerance for validation" + ) + + parser.add_argument( + "--warmup_iterations", + type=int, + default=0, + help="Number of iterations for warmup", + ) + + parser.add_argument( + "--iterations", + type=int, + default=1, + help="Number of iterations after warmup", + ) + + parser.add_argument( + "--skip_ref_check", + action="store_true", + help="Skip reference check", + ) + + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", + ) + + args = parser.parse_args() + + if len(args.q_shape) != 4: + parser.error("--q_shape must contain exactly 4 values") + + if len(args.k_shape) != 4: + parser.error("--k_shape must contain exactly 4 values") + + if len(args.mma_tiler_mn) != 2: + parser.error("--mma_tiler_mn must contain exactly 2 values") + + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + torch.manual_seed(1111) + + run( + args.q_shape, + args.k_shape, + args.in_dtype, + args.out_dtype, + args.qk_acc_dtype, + args.pv_acc_dtype, + args.mma_tiler_mn, + args.is_persistent, + args.is_causal, + args.scale_q, + args.scale_k, + args.scale_v, + args.inv_scale_o, + args.scale_softmax, + args.tolerance, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.use_cold_l2, + ) + + print("PASS") diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..c279f7588fa697c516c5acae340c9fa735437ba7 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py @@ -0,0 +1,2445 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +import functools +from typing import List, Type, Union +from inspect import isclass + +import torch +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +import cutlass.cute.testing as testing +import cutlass.utils as utils +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.torch as cutlass_torch + +""" +A grouped GEMM example for the NVIDIA Blackwell SM100 architecture using CUTE DSL + +This example demonstrates an implementation of grouped GEMM using a TMA plus Blackwell SM100 TensorCore +warp-specialized persistent kernel. +The grouped GEMM workload computes a batch of GEMM operations with distinct problem sizes. Pointers to matrices +in global memory are passed to the kernel in an array (also held in global memory). Similarly, problem shapes and +strides are also stored in arrays in GMEM. + +This differs from "Batched Array" GEMM since the size of each GEMM problem in the grouped GEMM concept may be distinct. + +To run this example: + +.. code-block:: bash + + python examples/blackwell/grouped_gemm.py \ + --ab_dtype Float16 --c_dtype Float16 --acc_dtype Float32 \ + --mma_tiler_mn 128,64 --cluster_shape_mn 1,1 \ + --problem_sizes_mnkl "(8192,1280,32,1),(16,384,1536,1),(640,1280,16,1),(640,160,16,1)" \ + --num_groups 4 --tensormap_update_mode SMEM + +The above example command makes 4 groups of different m, n, k sizes. The Blackwell tcgen05 MMA tile shape +is specified as (128, 64) and the cluster shape is (1,1). The input, mma accumulator and output data type +are set as fp16, fp32 and fp16, respectively. + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/blackwell/grouped_gemm.py \ + --ab_dtype Float16 --c_dtype Float16 --acc_dtype Float32 \ + --mma_tiler_mn 128,64 --cluster_shape_mn 1,1 \ + --problem_sizes_mnkl "(8192,1280,32,1),(16,384,1536,1),(640,1280,16,1),(640,160,16,1)" \ + --num_groups 4 --tensormap_update_mode SMEM \ + --warmup_iterations 1 --iterations 10 --skip_ref_check + +There are some constrains for this example. Besides the constrains from the Balckwell dense GEMM persistent example, +there are also the following constrains: +* Only fp16 and bf16 data types are supported as inputs. +* Output data types could be fp16, bf16 or fp32. +* The contiguous dimension of each tensor must be at least 16 bytes aligned. +* The l mode(aka, batch size) for each group must be 1. +* The majorness for A, B and C must be the same across all groups. +""" + + +class GroupedGemmKernel: + def __init__( + self, + acc_dtype: type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: tuple[int, int], + cluster_shape_mn: tuple[int, int], + tensormap_update_mode: utils.TensorMapUpdateMode = utils.TensorMapUpdateMode.SMEM, + ): + """Initializes the configuration for a Blackwell grouped GEMM kernel. + + Besides configurations for dense persistent GEMM, there is an extra config specific to grouped GEMM: + + Tensormap Update Mode: + - tensormap_update_mode: Specifies whether the tensormap is + updated in global memory(GMEM) or shared memory(SMEM). + The 2 modes are functionally equivalent and the difference are: + - We buffer 3 tensormaps in SMEM for A, B, and C tensors (each TMA descriptor takes 128B) when TMA updates performed on SMEM. + - Performance varies between modes depending on problem size; optimal choice differs across workloads. + + :param acc_dtype: Data type of the accumulator. + :type acc_dtype: type[cutlass.Numeric] + :param use_2cta_instrs: Boolean, True to use cta_group=2 MMA variant. + :type use_2cta_instrs: bool + :param mma_tiler_mn: tuple (M, N) shape of the MMA instruction. + :type mma_tiler_mn: tuple[int, int] + :param cluster_shape_mn: tuple (ClusterM, ClusterN) shape of the cluster. + :type cluster_shape_mn: tuple[int, int] + :param tensormap_update_mode: Mode for updating the tensormap (GMEM or SMEM), defaults to SMEM. + :type tensormap_update_mode: utils.TensorMapUpdateMode, optional + """ + self.acc_dtype: Type[cutlass.Numeric] = acc_dtype + self.use_2cta_instrs = use_2cta_instrs + self.cluster_shape_mn = cluster_shape_mn + # K dimension is deferred in _setup_attributes + self.mma_tiler = (*mma_tiler_mn, 1) + self.cta_group = ( + tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + ) + + self.tensormap_update_mode = tensormap_update_mode + # Delegate tensormap ab initialization to MMA warp when SMEM mode is used for better latency hiding + self.delegate_tensormap_ab_init = ( + tensormap_update_mode == utils.TensorMapUpdateMode.SMEM + ) + + self.num_mcast_ctas_a = 1 + self.num_mcast_ctas_b = 1 + self.is_a_mcast = False + self.is_b_mcast = False + + self.occupancy = 1 + # Set specialized warp ids + self.epilog_warp_id = ( + 0, + 1, + 2, + 3, + ) + self.mma_warp_id = 4 + self.tma_warp_id = 5 + self.threads_per_cta = 32 * len( + (self.mma_warp_id, self.tma_warp_id, *self.epilog_warp_id) + ) + # Set barrier id for cta sync, epilog sync, tmem ptr sync and tensormap update sync + self.cta_sync_bar_id = 0 + self.epilog_sync_bar_id = 1 + self.tmem_ptr_sync_bar_id = 2 + # Barrier ID used by MMA/TMA warps to signal A/B tensormap initialization completion + self.tensormap_ab_init_bar_id = 4 + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + self.num_tma_load_bytes = 0 + + def _setup_attributes(self): + """Set up configurations that are dependent on GEMM inputs + + Most of the implementation follows standard dense GEMM patterns, + with the key difference being additional consideration for SMEM + buffer needed for tensormap updates. + """ + # Configure tiled mma + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + + # Compute mma/cluster/tile shapes + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + mma_inst_tile_k = 4 + self.mma_tiler = ( + self.mma_tiler[0], + self.mma_tiler[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + self.cluster_tile_shape_mnk = tuple( + x * y for x, y in zip(self.cta_tile_shape_mnk, (*self.cluster_shape_mn, 1)) + ) + + # Compute cluster layout + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + + # Compute number of multicast CTAs for A/B + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + + # Compute epilogue subtile + self.epi_tile = utils.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, + self.use_2cta_instrs, + self.c_layout, + self.c_dtype, + ) + + # Setup A/B/C stage count in shared memory and ACC stage count in tensor memory + ( + self.num_acc_stage, + self.num_ab_stage, + self.num_epi_stage, + ) = self._compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.b_dtype, + self.epi_tile, + self.c_dtype, + self.c_layout, + self.smem_capacity, + self.occupancy, + ) + + self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.num_ab_stage, + ) + self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + self.mma_tiler, + self.b_dtype, + self.num_ab_stage, + ) + self.epi_smem_layout_staged = sm100_utils.make_smem_layout_epi( + self.c_dtype, + self.c_layout, + self.epi_tile, + self.num_epi_stage, + ) + + tensor_smem_bytes = self._get_tensor_smem_bytes( + self.a_smem_layout_staged, + self.a_dtype, + self.b_smem_layout_staged, + self.b_dtype, + self.epi_smem_layout_staged, + self.c_dtype, + ) + mbar_smem_bytes = self._get_mbar_smem_bytes( + num_acc_stage=self.num_acc_stage, + num_ab_stage=self.num_ab_stage, + num_epi_stage=self.num_epi_stage, + ) + tensormap_smem_bytes = self._get_tensormap_smem_bytes( + self.tensormap_update_mode + ) + if ( + mbar_smem_bytes + + tensormap_smem_bytes + + GroupedGemmKernel.tensor_memory_management_bytes + > self.reserved_smem_bytes + ): + raise ValueError( + f"smem consumption for mbar and tensormap {mbar_smem_bytes + tensormap_smem_bytes} exceeds the " + f"reserved smem bytes {self.reserved_smem_bytes}" + ) + + # Compute the number of tensor memory allocation columns + self.num_tmem_alloc_cols = self._compute_num_tmem_alloc_cols( + tiled_mma, self.mma_tiler, self.num_acc_stage + ) + + @cute.jit + def __call__( + self, + initial_a: cute.Tensor, + initial_b: cute.Tensor, + initial_c: cute.Tensor, + group_count: cutlass.Constexpr[int], + problem_shape_mnkl: cute.Tensor, + strides_abc: cute.Tensor, + tensor_address_abc: cute.Tensor, + total_num_clusters: cutlass.Constexpr[int], + tensormap_cute_tensor: cute.Tensor, + max_active_clusters: cutlass.Constexpr[int], + stream: cuda.CUstream, + ): + """Execute the GEMM operation in steps: + - Setup static attributes before smem/grid/tma computation + - Setup TMA load/store atoms and tensors + - Compute grid size with regard to hardware constraints + - Define shared storage for kernel + - Launch the kernel synchronously + + For grouped GEMM, tensor shapes, tensor strides, and tensor address are all provided + by different tensors in global memory. The "initial" tensors only carry data type and + majorness information. + + :param initial_a: Initial tensor A, used for data type and majorness information. + :type initial_a: cute.Tensor + :param initial_b: Initial tensor B, used for data type and majorness information. + :type initial_b: cute.Tensor + :param initial_c: Initial tensor C, used for data type and majorness information. + :type initial_c: cute.Tensor + :param group_count: The number of GEMM groups. + :type group_count: cutlass.Constexpr[int] + :param problem_shape_mnkl: Tensor containing the (M, N, K, L) shape for each group. + :type problem_shape_mnkl: cute.Tensor + :param strides_abc: Tensor containing the strides for A, B, and C for each group. + :type strides_abc: cute.Tensor + :param tensor_address_abc: Tensor containing the base addresses for A, B, and C for each group. + :type tensor_address_abc: cute.Tensor + :param total_num_clusters: Total number of clusters needed for all groups. + :type total_num_clusters: cutlass.Constexpr[int] + :param tensormap_cute_tensor: Tensor for storing tensormaps. + :type tensormap_cute_tensor: cute.Tensor + :param max_active_clusters: Maximum number of active clusters. + :type max_active_clusters: cutlass.Constexpr[int] + :param stream: CUDA stream for asynchronous execution. + :type stream: cuda.CUstream + :raises TypeError: If A and B data types do not match. + """ + self.a_dtype = initial_a.element_type + self.b_dtype = initial_b.element_type + self.c_dtype = initial_c.element_type + self.a_major_mode = utils.LayoutEnum.from_tensor(initial_a).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(initial_b).mma_major_mode() + self.c_layout = utils.LayoutEnum.from_tensor(initial_c) + if cutlass.const_expr(self.a_dtype != self.b_dtype): + raise TypeError(f"Type mismatch: {self.a_dtype} != {self.b_dtype}") + + # Setup attributes that dependent on gemm inputs + self._setup_attributes() + + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # Setup TMA load for A + a_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + initial_a, + a_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # Setup TMA load for B + b_op = sm100_utils.cluster_shape_to_tma_atom_B( + self.cluster_shape_mn, tiled_mma.thr_id + ) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + initial_b, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + self.num_tma_load_bytes = (a_copy_size + b_copy_size) * atom_thr_size + + # Setup TMA store for C + tma_atom_c = None + tma_tensor_c = None + c_cta_v_layout = cute.composition( + cute.make_identity_layout(initial_c.shape), self.epi_tile + ) + epi_smem_layout = cute.slice_(self.epi_smem_layout_staged, (None, None, 0)) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + initial_c, + epi_smem_layout, + c_cta_v_layout, + ) + + self.tile_sched_params, grid = self._compute_grid( + total_num_clusters, self.cluster_shape_mn, max_active_clusters + ) + + self.buffer_align_bytes = 1024 + self.size_tensormap_in_i64 = ( + 0 + if self.tensormap_update_mode == utils.TensorMapUpdateMode.GMEM + else GroupedGemmKernel.num_tensormaps + * GroupedGemmKernel.bytes_per_tensormap + // 8 + ) + + # Define shared storage for kernel + @cute.struct + class SharedStorage: + tensormap_buffer: cute.struct.MemRange[ + cutlass.Int64, self.size_tensormap_in_i64 + ] + ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC: cute.struct.Align[ + cute.struct.MemRange[ + self.c_dtype, + cute.cosize(self.epi_smem_layout_staged.outer), + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sA: cute.struct.Align[ + cute.struct.MemRange[ + self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sB: cute.struct.Align[ + cute.struct.MemRange[ + self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + tiled_mma, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_c, + tma_tensor_c, + self.cluster_layout_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.epi_smem_layout_staged, + self.epi_tile, + self.tile_sched_params, + group_count, + problem_shape_mnkl, + strides_abc, + tensor_address_abc, + tensormap_cute_tensor, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + stream=stream, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_c: cute.CopyAtom, + mC_mnl: cute.Tensor, + cluster_layout_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + epi_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout], + epi_tile: cute.Tile, + tile_sched_params: utils.PersistentTileSchedulerParams, + group_count: cutlass.Constexpr[int], + problem_sizes_mnkl: cute.Tensor, + strides_abc: cute.Tensor, + ptrs_abc: cute.Tensor, + tensormaps: cute.Tensor, + ): + """ + GPU device kernel performing the grouped GEMM computation. + """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + # + # Prefetch tma desc + # + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + cpasync.prefetch_descriptor(tma_atom_c) + + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + # + # Setup cta/thread coordinates + # + # Coord inside cluster + bid = cute.arch.block_idx() + mma_tile_coord_v = bid[0] % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + # Coord inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Alloc and init: tensormap buffer, a+b full/empty, accumulator full/empty, tensor memory dealloc barrier + # + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + tensormap_a_smem_ptr = None + tensormap_b_smem_ptr = None + tensormap_c_smem_ptr = None + if cutlass.const_expr( + self.tensormap_update_mode == utils.TensorMapUpdateMode.SMEM + ): + tensormap_smem_ptr = storage.tensormap_buffer.data_ptr() + tensormap_a_smem_ptr = tensormap_smem_ptr + tensormap_b_smem_ptr = ( + tensormap_a_smem_ptr + GroupedGemmKernel.bytes_per_tensormap // 8 + ) + tensormap_c_smem_ptr = ( + tensormap_b_smem_ptr + GroupedGemmKernel.bytes_per_tensormap // 8 + ) + ab_full_mbar_ptr = storage.ab_full_mbar_ptr.data_ptr() + ab_empty_mbar_ptr = storage.ab_empty_mbar_ptr.data_ptr() + acc_full_mbar_ptr = storage.acc_full_mbar_ptr.data_ptr() + acc_empty_mbar_ptr = storage.acc_empty_mbar_ptr.data_ptr() + tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr + tmem_holding_buf = storage.tmem_holding_buf + + # init barrier for loading A, B with TMA + if warp_idx == self.epilog_warp_id[0]: + for k_stage in range(self.num_ab_stage): + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + with cute.arch.elect_one(): + cute.arch.mbarrier_init(ab_full_mbar_ptr + k_stage, 1) + cute.arch.mbarrier_init( + ab_empty_mbar_ptr + k_stage, num_tma_producer + ) + # Accumulator barrier init + if warp_idx == self.mma_warp_id: + for acc_stage in range(self.num_acc_stage): + with cute.arch.elect_one(): + cute.arch.mbarrier_init(acc_full_mbar_ptr + acc_stage, 1) + cute.arch.mbarrier_init( + acc_empty_mbar_ptr + acc_stage, 8 if use_2cta_instrs else 4 + ) + # Tensor memory dealloc barrier init + if use_2cta_instrs: + if warp_idx == self.tma_warp_id: + num_tmem_dealloc_threads = 32 + with cute.arch.elect_one(): + cute.arch.mbarrier_init( + tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads + ) + cute.arch.mbarrier_init_fence() + + # Cluster arrive after barrier init + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_arrive_relaxed() + + # + # Setup smem tensor A/B/C + # + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC = storage.sC.get_tensor( + epi_smem_layout_staged.outer, swizzle=epi_smem_layout_staged.inner + ) + # (MMA, MMA_M, MMA_K, STAGE) + sA = storage.sA.get_tensor( + a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner + ) + # (MMA, MMA_N, MMA_K, STAGE) + sB = storage.sB.get_tensor( + b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner + ) + + # + # Compute multicast mask for A/B buffer full and empty + # + a_full_mcast_mask = None + b_full_mcast_mask = None + ab_empty_mcast_mask = None + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): + a_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + b_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + ab_empty_mcast_mask = a_full_mcast_mask | b_full_mcast_mask + acc_full_mcast_mask = None + if cutlass.const_expr(use_2cta_instrs): + acc_full_mcast_mask = cute.make_layout_image_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mode=0 + ) + block_in_cluster_coord_vmnk_peer = ( + block_in_cluster_coord_vmnk[0] ^ 1, + *block_in_cluster_coord_vmnk[1:], + ) + a_full_mcast_mask_peer = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk_peer, mcast_mode=2 + ) + b_full_mcast_mask_peer = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk_peer, mcast_mode=1 + ) + ab_empty_mcast_mask = ( + a_full_mcast_mask_peer + | b_full_mcast_mask_peer + | cutlass.Int16( + 0 if ab_empty_mcast_mask is None else ab_empty_mcast_mask + ) + ) + + # + # Local_tile partition global tensors + # + # (bM, bK, RestM, RestK, RestL) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + # (bM, bN, RestM, RestN, RestL) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) + ) + + # + # Partition global tensor for TiledMMA_A/B/C + # + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgA = thr_mma.partition_A(gA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgB = thr_mma.partition_B(gB_nkl) + # (MMA, MMA_M, MMA_N, RestM, RestN, RestL) + tCgC = thr_mma.partition_C(gC_mnl) + + # + # Partition global/shared tensor for load A, B with TMA + # + a_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA load B partition_S/D + b_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/C + # + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_fake = tiled_mma.make_fragment_C( + cute.append(acc_shape, self.num_acc_stage) + ) + + # + # Cluster wait before tensor memory alloc + # + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_wait() + else: + cute.arch.barrier( + barrier_id=self.cta_sync_bar_id, number_of_threads=self.threads_per_cta + ) + + # + # Get tensormap buffer address + # + grid_dim = cute.arch.grid_dim() + tensormap_workspace_idx = ( + bid[2] * grid_dim[1] * grid_dim[0] + bid[1] * grid_dim[0] + bid[0] + ) + + tensormap_manager = utils.TensorMapManager( + self.tensormap_update_mode, GroupedGemmKernel.bytes_per_tensormap + ) + tensormap_a_ptr = tensormap_manager.get_tensormap_ptr( + tensormaps[(tensormap_workspace_idx, 0, None)].iterator + ) + tensormap_b_ptr = tensormap_manager.get_tensormap_ptr( + tensormaps[(tensormap_workspace_idx, 1, None)].iterator + ) + tensormap_c_ptr = tensormap_manager.get_tensormap_ptr( + tensormaps[(tensormap_workspace_idx, 2, None)].iterator + ) + # Setup tensormap initialization pointer based on the mode + if cutlass.const_expr( + self.tensormap_update_mode == utils.TensorMapUpdateMode.SMEM + ): + tensormap_a_init_ptr = tensormap_a_smem_ptr + tensormap_b_init_ptr = tensormap_b_smem_ptr + tensormap_c_init_ptr = tensormap_c_smem_ptr + else: + tensormap_a_init_ptr = tensormap_a_ptr + tensormap_b_init_ptr = tensormap_b_ptr + tensormap_c_init_ptr = tensormap_c_ptr + + # + # Specialized TMA load warp + # + if warp_idx == self.tma_warp_id: + # Initialize tensormaps for A, B + if cutlass.const_expr(self.delegate_tensormap_ab_init == False): + tensormap_manager.init_tensormap_from_atom( + tma_atom_a, tensormap_a_init_ptr, self.tma_warp_id + ) + tensormap_manager.init_tensormap_from_atom( + tma_atom_b, tensormap_b_init_ptr, self.tma_warp_id + ) + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, bid, grid_dim + ) + # grouped gemm tile scheduler helper will compute the group index for the tile we're working on + group_gemm_ts_helper = utils.GroupedGemmTileSchedulerHelper( + group_count, + tile_sched_params, + self.cluster_tile_shape_mnk, + utils.create_initial_search_state(), + ) + tensormap_init_done = cutlass.Boolean(False) + # tile count we have searched + total_k_tile_cnt = cutlass.Int32(0) + # group index of last tile + last_group_idx = cutlass.Int32(-1) + work_tile = tile_sched.initial_work_tile_info() + while work_tile.is_valid_tile: + cur_tile_coord = work_tile.tile_idx + grouped_gemm_cta_tile_info = group_gemm_ts_helper.delinearize_z( + cur_tile_coord, + problem_sizes_mnkl, + ) + cur_k_tile_cnt = grouped_gemm_cta_tile_info.cta_tile_count_k + cur_group_idx = grouped_gemm_cta_tile_info.group_idx + is_group_changed = cur_group_idx != last_group_idx + # skip tensormap update if we're working on the same group + if is_group_changed: + real_tensor_a = self.make_tensor_for_tensormap_update( + cur_group_idx, + self.a_dtype, + ( + grouped_gemm_cta_tile_info.problem_shape_m, + grouped_gemm_cta_tile_info.problem_shape_n, + grouped_gemm_cta_tile_info.problem_shape_k, + ), + strides_abc, + ptrs_abc, + 0, # 0 for tensor A + ) + real_tensor_b = self.make_tensor_for_tensormap_update( + cur_group_idx, + self.b_dtype, + ( + grouped_gemm_cta_tile_info.problem_shape_m, + grouped_gemm_cta_tile_info.problem_shape_n, + grouped_gemm_cta_tile_info.problem_shape_k, + ), + strides_abc, + ptrs_abc, + 1, # 1 for tensor B + ) + # wait tensormap initialization complete before update + if tensormap_init_done == False: + if cutlass.const_expr(self.delegate_tensormap_ab_init): + cute.arch.barrier( + barrier_id=self.tensormap_ab_init_bar_id, + number_of_threads=64, + ) + tensormap_manager.fence_tensormap_initialization() + tensormap_init_done = True + + tensormap_manager.update_tensormap( + (real_tensor_a, real_tensor_b), + (tma_atom_a, tma_atom_b), + (tensormap_a_ptr, tensormap_b_ptr), + self.tma_warp_id, + (tensormap_a_smem_ptr, tensormap_b_smem_ptr), + ) + + mma_tile_coord_mnl = ( + grouped_gemm_cta_tile_info.cta_tile_idx_m + // cute.size(tiled_mma.thr_id.shape), + grouped_gemm_cta_tile_info.cta_tile_idx_n, + 0, + ) + + # + # Slice to per mma tile index + # + # ((atom_v, rest_v), RestK) + tAgA_slice = tAgA[ + (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) + ] + # ((atom_v, rest_v), RestK) + tBgB_slice = tBgB[ + (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) + ] + + num_prev_k_blk = total_k_tile_cnt + total_k_tile_cnt += cur_k_tile_cnt + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + tma_wr_k_tile = cutlass.Int32(0) + smem_wr_buffer = (num_prev_k_blk + tma_wr_k_tile) % self.num_ab_stage + tma_wr_ab_empty_phase = ( + num_prev_k_blk + tma_wr_k_tile + ) // self.num_ab_stage % 2 ^ 1 + peek_ab_empty_status = cute.arch.mbarrier_conditional_try_wait( + tma_wr_k_tile < cur_k_tile_cnt, + ab_empty_mbar_ptr + smem_wr_buffer, + tma_wr_ab_empty_phase, + ) + # ensure the update to tensormap has completed before using it + if is_group_changed: + tensormap_manager.fence_tensormap_update(tensormap_a_ptr) + tensormap_manager.fence_tensormap_update(tensormap_b_ptr) + # + # Tma load loop + # + for k_tile in cutlass.range(0, cur_k_tile_cnt, 1, unroll=1): + tma_wr_k_tile_next = tma_wr_k_tile + 1 + smem_wr_buffer_next = ( + num_prev_k_blk + tma_wr_k_tile_next + ) % self.num_ab_stage + tma_wr_ab_empty_phase_next = ( + tma_wr_ab_empty_phase ^ 1 + if smem_wr_buffer_next == 0 + else tma_wr_ab_empty_phase + ) + + smem_full_mbar_ptr = ab_full_mbar_ptr + smem_wr_buffer + + # Wait for AB buffer empty + if peek_ab_empty_status == 0: + cute.arch.mbarrier_wait( + ab_empty_mbar_ptr + smem_wr_buffer, tma_wr_ab_empty_phase + ) + + # Arrive AB buffer and expect full transaction bytes + if is_leader_cta: + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx( + smem_full_mbar_ptr, self.num_tma_load_bytes + ) + + # Load A/B with TMA + cute.copy( + tma_atom_a, + tAgA_slice[(None, tma_wr_k_tile)], + tAsA[(None, smem_wr_buffer)], + tma_bar_ptr=smem_full_mbar_ptr, + mcast_mask=a_full_mcast_mask, + tma_desc_ptr=tensormap_manager.get_tensormap_ptr( + tensormap_a_ptr, + cute.AddressSpace.generic, + ), + ) + cute.copy( + tma_atom_b, + tBgB_slice[(None, tma_wr_k_tile)], + tBsB[(None, smem_wr_buffer)], + tma_bar_ptr=smem_full_mbar_ptr, + mcast_mask=b_full_mcast_mask, + tma_desc_ptr=tensormap_manager.get_tensormap_ptr( + tensormap_b_ptr, + cute.AddressSpace.generic, + ), + ) + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 + peek_ab_empty_status = cute.arch.mbarrier_conditional_try_wait( + tma_wr_k_tile_next < cur_k_tile_cnt, + ab_empty_mbar_ptr + smem_wr_buffer_next, + tma_wr_ab_empty_phase_next, + ) + + tma_wr_k_tile = tma_wr_k_tile_next + smem_wr_buffer = smem_wr_buffer_next + tma_wr_ab_empty_phase = tma_wr_ab_empty_phase_next + + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + last_group_idx = cur_group_idx + + # + # Specialized MMA warp + # + if warp_idx == self.mma_warp_id: + # initialize tensormap A, B for TMA warp + if cutlass.const_expr(self.delegate_tensormap_ab_init): + tensormap_manager.init_tensormap_from_atom( + tma_atom_a, tensormap_a_init_ptr, self.mma_warp_id + ) + tensormap_manager.init_tensormap_from_atom( + tma_atom_b, tensormap_b_init_ptr, self.mma_warp_id + ) + # signal tensormap initialization has finished + cute.arch.barrier( + barrier_id=self.tensormap_ab_init_bar_id, number_of_threads=64 + ) + # Bar sync for retrieve tmem ptr from shared mem + tmem_ptr_read_threads = 32 * len((self.mma_warp_id, *self.epilog_warp_id)) + cute.arch.barrier( + barrier_id=self.tmem_ptr_sync_bar_id, + number_of_threads=tmem_ptr_read_threads, + ) + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, + alignment=16, + ptr_to_buffer_holding_addr=tmem_holding_buf, + ) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, bid, grid_dim + ) + # grouped gemm tile scheduler helper will compute the group index for the tile we're working on + group_gemm_ts_helper = utils.GroupedGemmTileSchedulerHelper( + group_count, + tile_sched_params, + self.cluster_tile_shape_mnk, + utils.create_initial_search_state(), + ) + + work_tile = tile_sched.initial_work_tile_info() + # tile count we have searched + total_k_tile_cnt = cutlass.Int32(0) + while work_tile.is_valid_tile: + cur_tile_coord = work_tile.tile_idx + # MMA warp is only interested in number of tiles along K dimension + ( + cur_k_tile_cnt, + cur_group_idx, + ) = group_gemm_ts_helper.search_cluster_tile_count_k( + cur_tile_coord, + problem_sizes_mnkl, + ) + # Set tensor memory buffer for current tile + acc_buf_idx = tile_sched.num_tiles_executed % self.num_acc_stage + # (MMA, MMA_M, MMA_N) + tCtAcc = tCtAcc_base[(None, None, None, acc_buf_idx)] + + num_prev_k_blk = total_k_tile_cnt + total_k_tile_cnt += cur_k_tile_cnt + + # Peek (try_wait) AB buffer full for k_tile = 0 + mma_rd_k_tile = cutlass.Int32(0) + smem_rd_buffer = (num_prev_k_blk + mma_rd_k_tile) % self.num_ab_stage + need_check_rd_buffer_full = ( + mma_rd_k_tile < cur_k_tile_cnt and is_leader_cta + ) + mma_rd_ab_full_phase = ( + (num_prev_k_blk + mma_rd_k_tile) // self.num_ab_stage % 2 + ) + peek_ab_full_status = cute.arch.mbarrier_conditional_try_wait( + need_check_rd_buffer_full, + ab_full_mbar_ptr + smem_rd_buffer, + mma_rd_ab_full_phase, + ) + + # + # Wait for accumulator buffer empty + # + if is_leader_cta: + acc_empty_phase = ( + tile_sched.num_tiles_executed // self.num_acc_stage % 2 ^ 1 + ) + cute.arch.mbarrier_wait( + acc_empty_mbar_ptr + acc_buf_idx, acc_empty_phase + ) + + # + # Reset the ACCUMULATE field for each tile + # + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + + # + # Mma mainloop + # + for k_tile in range(cur_k_tile_cnt): + mma_rd_k_tile_next = cutlass.Int32(k_tile + 1) + smem_rd_buffer_next = ( + num_prev_k_blk + mma_rd_k_tile_next + ) % self.num_ab_stage + mma_rd_ab_full_phase_next = ( + mma_rd_ab_full_phase ^ 1 + if smem_rd_buffer_next == 0 + else mma_rd_ab_full_phase + ) + if is_leader_cta: + # Wait for AB buffer full + if peek_ab_full_status == 0: + cute.arch.mbarrier_wait( + ab_full_mbar_ptr + smem_rd_buffer, mma_rd_ab_full_phase + ) + + # tCtAcc += tCrA * tCrB + num_kblocks = cute.size(tCrA, mode=[2]) + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = (None, None, kblock_idx, smem_rd_buffer) + + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[kblock_coord], + tCrB[kblock_coord], + tCtAcc, + ) + # Enable accumulate on tCtAcc after first kblock + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Async arrive AB buffer empty + with cute.arch.elect_one(): + tcgen05.commit( + ab_empty_mbar_ptr + smem_rd_buffer, + ab_empty_mcast_mask, + self.cta_group, + ) + + # Peek (try_wait) AB buffer full for k_tile = k_tile + 1 + need_check_rd_buffer_full = ( + mma_rd_k_tile_next < cur_k_tile_cnt and is_leader_cta + ) + + peek_ab_full_status = cute.arch.mbarrier_conditional_try_wait( + need_check_rd_buffer_full, + ab_full_mbar_ptr + smem_rd_buffer_next, + mma_rd_ab_full_phase_next, + ) + + mma_rd_k_tile = mma_rd_k_tile_next + smem_rd_buffer = smem_rd_buffer_next + mma_rd_ab_full_phase = mma_rd_ab_full_phase_next + + # + # Async arrive accumulator buffer full + # + if is_leader_cta: + with cute.arch.elect_one(): + tcgen05.commit( + acc_full_mbar_ptr + acc_buf_idx, + acc_full_mcast_mask, + self.cta_group, + ) + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # + # Specialized epilogue warps + # + if warp_idx < self.mma_warp_id: + # initialize tensorap for C + tensormap_manager.init_tensormap_from_atom( + tma_atom_c, + tensormap_c_init_ptr, + self.epilog_warp_id[0], + ) + # Alloc tensor memory buffer + if warp_idx == self.epilog_warp_id[0]: + cute.arch.alloc_tmem( + self.num_tmem_alloc_cols, + tmem_holding_buf, + is_two_cta=use_2cta_instrs, + ) + + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + tmem_ptr_read_threads = 32 * len((self.mma_warp_id, *self.epilog_warp_id)) + cute.arch.barrier( + barrier_id=self.tmem_ptr_sync_bar_id, + number_of_threads=tmem_ptr_read_threads, + ) + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, + alignment=16, + ptr_to_buffer_holding_addr=tmem_holding_buf, + ) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + epi_tidx = tidx + # + # Partition for epilogue + # + ( + tiled_copy_t2r, + tTR_tAcc_base, + tTR_rAcc, + ) = self.epilog_tmem_copy_and_partition( + epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs + ) + + tTR_rC = cute.make_fragment(tTR_rAcc.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( + tiled_copy_t2r, tTR_rC, epi_tidx, sC + ) + ( + tma_atom_c, + bSG_sC, + bSG_gC_partitioned, + ) = self.epilog_gmem_copy_and_partition(tma_atom_c, tCgC, epi_tile, sC) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, bid, grid_dim + ) + # grouped gemm tile scheduler helper will compute the group index for the tile we're working on + group_gemm_ts_helper = utils.GroupedGemmTileSchedulerHelper( + group_count, + tile_sched_params, + self.cluster_tile_shape_mnk, + utils.create_initial_search_state(), + ) + + work_tile = tile_sched.initial_work_tile_info() + # wait tensormap initialization complete before update + tensormap_manager.fence_tensormap_initialization() + # tile count we have searched + total_k_tile_cnt = cutlass.Int32(0) + # group index of last tile + last_group_idx = cutlass.Int32(-1) + while work_tile.is_valid_tile: + cur_tile_coord = work_tile.tile_idx + grouped_gemm_cta_tile_info = group_gemm_ts_helper.delinearize_z( + cur_tile_coord, + problem_sizes_mnkl, + ) + cur_group_idx = grouped_gemm_cta_tile_info.group_idx + is_group_changed = cur_group_idx != last_group_idx + if is_group_changed: + # construct tensor C based on real address, shape and stride information + real_tensor_c = self.make_tensor_for_tensormap_update( + cur_group_idx, + self.c_dtype, + ( + grouped_gemm_cta_tile_info.problem_shape_m, + grouped_gemm_cta_tile_info.problem_shape_n, + grouped_gemm_cta_tile_info.problem_shape_k, + ), + strides_abc, + ptrs_abc, + 2, # 2 for tensor C + ) + tensormap_manager.update_tensormap( + ((real_tensor_c),), + ((tma_atom_c),), + ((tensormap_c_ptr),), + self.epilog_warp_id[0], + (tensormap_c_smem_ptr,), + ) + + mma_tile_coord_mnl = ( + grouped_gemm_cta_tile_info.cta_tile_idx_m + // cute.size(tiled_mma.thr_id.shape), + grouped_gemm_cta_tile_info.cta_tile_idx_n, + 0, + ) + cur_k_tile_cnt = grouped_gemm_cta_tile_info.cta_tile_count_k + total_k_tile_cnt += cur_k_tile_cnt + + # + # Slice to per mma tile index + # + # ((ATOM_V, REST_V), EPI_M, EPI_N) + bSG_gC = bSG_gC_partitioned[ + ( + None, + None, + None, + *mma_tile_coord_mnl, + ) + ] + + # Set tensor memory buffer for current tile + acc_buf_idx = tile_sched.num_tiles_executed % self.num_acc_stage + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_buf_idx)] + + # + # Wait for accumulator buffer full + # + acc_full_phase = tile_sched.num_tiles_executed // self.num_acc_stage % 2 + cute.arch.mbarrier_wait(acc_full_mbar_ptr + acc_buf_idx, acc_full_phase) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + # ensure the update to tensormap has completed before using it + if is_group_changed: + if warp_idx == self.epilog_warp_id[0]: + tensormap_manager.fence_tensormap_update(tensormap_c_ptr) + # + # Store accumulator to global memory in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt + for subtile_idx in range(subtile_cnt): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # + # Convert to output type + # + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + tRS_rC.store(acc_vec.to(self.c_dtype)) + # + # Store C to shared memory + # + epi_buffer = (num_prev_subtiles + subtile_idx) % self.num_epi_stage + cute.copy( + tiled_copy_r2s, + tRS_rC, + tRS_sC[(None, None, None, epi_buffer)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + epilog_threads = 32 * len(self.epilog_warp_id) + cute.arch.barrier( + barrier_id=self.epilog_sync_bar_id, + number_of_threads=epilog_threads, + ) + # + # store C to global memory with TMA + # + if warp_idx == self.epilog_warp_id[0]: + cute.copy( + tma_atom_c, + bSG_sC[(None, epi_buffer)], + bSG_gC[(None, subtile_idx)], + tma_desc_ptr=tensormap_manager.get_tensormap_ptr( + tensormap_c_ptr, + cute.AddressSpace.generic, + ), + ) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group( + self.num_epi_stage - 1, read=True + ) + cute.arch.barrier( + barrier_id=self.epilog_sync_bar_id, + number_of_threads=epilog_threads, + ) + # + # Async arrive accumulator buffer empty + # + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive( + acc_empty_mbar_ptr + acc_buf_idx, + cta_rank_in_cluster // 2 * 2 if use_2cta_instrs else None, + ) + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + last_group_idx = cur_group_idx + + # + # Dealloc the tensor memory buffer + # + if warp_idx == self.epilog_warp_id[0]: + cute.arch.relinquish_tmem_alloc_permit(is_two_cta=use_2cta_instrs) + epilog_threads = 32 * len(self.epilog_warp_id) + cute.arch.barrier( + barrier_id=self.epilog_sync_bar_id, number_of_threads=epilog_threads + ) + if warp_idx == self.epilog_warp_id[0]: + if use_2cta_instrs: + cute.arch.mbarrier_arrive( + tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1 + ) + cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0) + cute.arch.dealloc_tmem( + tmem_ptr, self.num_tmem_alloc_cols, is_two_cta=use_2cta_instrs + ) + + # + # Wait a/b buffer empty + # + if warp_idx == self.epilog_warp_id[0]: + cute.arch.mbarrier_wait( + (ab_empty_mbar_ptr + ((total_k_tile_cnt - 1) % self.num_ab_stage)), + (((total_k_tile_cnt - 1) // self.num_ab_stage) % 2), + ) + + @cute.jit + def make_tensor_for_tensormap_update( + self, + group_idx: cutlass.Int32, + dtype: Type[cutlass.Numeric], + problem_shape_mnk: tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32], + strides_abc: cute.Tensor, + tensor_address_abc: cute.Tensor, + tensor_index: int, + ): + """Extract stride and tensor address for a given group and construct a global tensor. + + This function is used within the kernel to dynamically create a CUTE tensor + representing A, B, or C for the current group being processed, using the + group-specific address, shape, and stride information. + + :param group_idx: The index of the current group within the grouped GEMM. + :type group_idx: cutlass.Int32 + :param dtype: The data type of the tensor elements (e.g., cutlass.Float16). + :type dtype: Type[cutlass.Numeric] + :param problem_shape_mnk: The (M, N, K) problem shape for the current group. + :type problem_shape_mnk: tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32] + :param strides_abc: Tensor containing strides for A, B, C for all groups. Layout: (group_count, 3, 2). + :type strides_abc: cute.Tensor + :param tensor_address_abc: Tensor containing global memory addresses for A, B, C for all groups. Layout: (group_count, 3). + :type tensor_address_abc: cute.Tensor + :param tensor_index: Specifies which tensor to create: 0 for A, 1 for B, 2 for C. + :type tensor_index: int + :return: A CUTE tensor representing the requested global memory tensor (A, B, or C) for the specified group. + :rtype: cute.Tensor + :raises TypeError: If the provided dtype is not a subclass of cutlass.Numeric. + """ + ptr_i64 = tensor_address_abc[(group_idx, tensor_index)] + if cutlass.const_expr( + not isclass(dtype) or not issubclass(dtype, cutlass.Numeric) + ): + raise TypeError( + f"dtype must be a type of cutlass.Numeric, got {type(dtype)}" + ) + tensor_gmem_ptr = cute.make_ptr( + dtype, ptr_i64, cute.AddressSpace.gmem, assumed_align=16 + ) + + strides_tensor_gmem = strides_abc[(group_idx, tensor_index, None)] + strides_tensor_reg = cute.make_fragment( + cute.make_layout(2), + strides_abc.element_type, + ) + cute.autovec_copy(strides_tensor_gmem, strides_tensor_reg) + stride_mn = strides_tensor_reg[0] + stride_k = strides_tensor_reg[1] + c1 = cutlass.Int32(1) + c0 = cutlass.Int32(0) + + if cutlass.const_expr(tensor_index == 0): # tensor A + m = problem_shape_mnk[0] + k = problem_shape_mnk[2] + return cute.make_tensor( + tensor_gmem_ptr, + cute.make_layout((m, k, c1), stride=(stride_mn, stride_k, c0)), + ) + elif cutlass.const_expr(tensor_index == 1): # tensor B + n = problem_shape_mnk[1] + k = problem_shape_mnk[2] + return cute.make_tensor( + tensor_gmem_ptr, + cute.make_layout((n, k, c1), stride=(stride_mn, stride_k, c0)), + ) + else: # tensor C + m = problem_shape_mnk[0] + n = problem_shape_mnk[1] + return cute.make_tensor( + tensor_gmem_ptr, + cute.make_layout((m, n, c1), stride=(stride_mn, stride_k, c0)), + ) + + def epilog_tmem_copy_and_partition( + self, + tidx: cutlass.Int32, + tAcc: cute.Tensor, + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + use_2cta_instrs: Union[cutlass.Boolean, bool], + ) -> tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination). + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param tAcc: The accumulator tensor to be copied and partitioned + :type tAcc: cute.Tensor + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param use_2cta_instrs: Whether use_2cta_instrs is enabled + :type use_2cta_instrs: bool + + :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where: + - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + - tTR_tAcc: The partitioned accumulator tensor + - tTR_rAcc: The accumulated tensor in register used to hold t2r results + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # Make tiledCopy for tensor memory load(t2r) + copy_atom_t2r = sm100_utils.get_tmem_load_op( + self.cta_tile_shape_mnk, + self.c_layout, + self.c_dtype, + self.acc_dtype, + epi_tile, + use_2cta_instrs, + ) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE) + tAcc_epi = cute.flat_divide( + tAcc[((None, None), 0, 0, None)], + epi_tile, + ) + # (EPI_TILE_M, EPI_TILE_N) + tiled_copy_t2r = tcgen05.make_tmem_copy( + copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)] + ) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE) + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + gC_mnl_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) + tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) + # (T2R, T2R_M, T2R_N) + tTR_rAcc = cute.make_fragment( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype + ) + return tiled_copy_t2r, tTR_tAcc, tTR_rAcc + + def epilog_smem_copy_and_partition( + self, + tiled_copy_t2r: cute.TiledCopy, + tTR_rC: cute.Tensor, + tidx: cutlass.Int32, + sC: cute.Tensor, + ) -> tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination). + + :param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + :type tiled_copy_t2r: cute.TiledCopy + :param tTR_rC: The partitioned accumulator tensor + :type tTR_rC: cute.Tensor + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + + :return: A tuple containing (tiled_copy_r2s, tRS_rC, tRS_sC) where: + - tiled_copy_r2s: The tiled copy operation for register to smem copy(r2s) + - tRS_rC: The partitioned tensor C (register source) + - tRS_sC: The partitioned tensor C (smem destination) + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + copy_atom_r2s = sm100_utils.get_smem_store_op( + self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r + ) + tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) + # (R2S, R2S_M, R2S_N, PIPE_D) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sC = thr_copy_r2s.partition_D(sC) + # (R2S, R2S_M, R2S_N) + tRS_rC = tiled_copy_r2s.retile(tTR_rC) + return tiled_copy_r2s, tRS_rC, tRS_sC + + def epilog_gmem_copy_and_partition( + self, + tma_atom_c: cute.CopyAtom, + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + sC: cute.Tensor, + ) -> tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]: + """Make tiledCopy for global memory store, then use it to partition + shared memory (source) and global memory (destination) for TMA store version. + + :param tma_atom_c: The TMA copy atom configured for storing tensor C. + :type tma_atom_c: cute.CopyAtom + :param gC_mnl: The global memory tensor C. + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler defining the granularity of the operation. + :type epi_tile: cute.Tile + :param sC: The shared memory epilogue buffer tensor. + :type sC: cute.Tensor + :return: A tuple containing: + - tma_atom_c: The input TMA copy atom (passed through). + - bSG_sC: The source shared memory tensor partitioned for the TMA operation. + - tCgC: The destination global memory tensor partitioned for the TMA operation. + :rtype: tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] + """ + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + gC_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + sC_for_tma_partition = cute.group_modes(sC, 0, 2) + gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2) + # ((ATOM_V, REST_V), EPI_M, EPI_N) + # ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL) + bSG_sC, bSG_gC = cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + sC_for_tma_partition, + gC_for_tma_partition, + ) + return tma_atom_c, bSG_sC, bSG_gC + + @staticmethod + def _compute_stages( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: tuple[int, int, int], + a_dtype: type[cutlass.Numeric], + b_dtype: type[cutlass.Numeric], + epi_tile: cute.Tile, + c_dtype: type[cutlass.Numeric], + c_layout: utils.LayoutEnum, + smem_capacity: int, + occupancy: int, + ) -> tuple[int, int, int]: + """Computes the number of stages for accumulator, A/B operands, and epilogue based on heuristics. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler. + :type mma_tiler_mnk: tuple[int, int, int] + :param a_dtype: Data type of operand A. + :type a_dtype: type[cutlass.Numeric] + :param b_dtype: Data type of operand B. + :type b_dtype: type[cutlass.Numeric] + :param epi_tile: The epilogue tile shape. + :type epi_tile: cute.Tile + :param c_dtype: Data type of operand C (output). + :type c_dtype: type[cutlass.Numeric] + :param c_layout: Layout enum of operand C in global memory. + :type c_layout: utils.LayoutEnum + :param smem_capacity: Total available shared memory capacity in bytes. + :type smem_capacity: int + :param occupancy: Target number of CTAs per SM (occupancy). + :type occupancy: int + + :return: A tuple containing the computed number of stages for: + (accumulator stages, A/B operand stages, epilogue stages) + :rtype: tuple[int, int, int] + """ + # Default accumulator and epilogue stages + num_acc_stage = 2 + num_epi_stage = 2 + + # Calculate smem layout and size for one stage of A, B, and Epilogue + a_smem_layout_stage_one = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + a_dtype, + 1, # stage=1 + ) + b_smem_layout_staged_one = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + b_dtype, + 1, # stage=1 + ) + epi_smem_layout_staged_one = sm100_utils.make_smem_layout_epi( + c_dtype, + c_layout, + epi_tile, + 1, # stage=1 + ) + ab_bytes_per_stage = cute.size_in_bytes( + a_dtype, a_smem_layout_stage_one + ) + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) + + epi_bytes_per_stage = cute.size_in_bytes(c_dtype, epi_smem_layout_staged_one) + epi_bytes = epi_bytes_per_stage * num_epi_stage + + # Calculate A/B stages: + # Start with total smem per CTA (capacity / occupancy) + # Subtract reserved bytes and initial epilogue bytes + # Divide remaining by bytes needed per A/B stage + num_ab_stage = ( + smem_capacity // occupancy + - GroupedGemmKernel.reserved_smem_bytes + - epi_bytes + ) // ab_bytes_per_stage + + # Refine epilogue stages: + # Calculate remaining smem after allocating for A/B stages and reserved bytes + # Add remaining unused smem to epilogue + remaining_smem = ( + smem_capacity + - occupancy * ab_bytes_per_stage * num_ab_stage + - occupancy * (GroupedGemmKernel.reserved_smem_bytes + epi_bytes) + ) + num_epi_stage += remaining_smem // (occupancy * epi_bytes_per_stage) + return num_acc_stage, num_ab_stage, num_epi_stage + + @staticmethod + def _compute_grid( + total_num_clusters: int, + cluster_shape_mn: tuple[int, int], + max_active_clusters: cutlass.Constexpr[int], + ) -> tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]]: + """Compute tile scheduler parameters and grid shape for grouped GEMM operations. + + :param total_num_clusters: Total number of clusters to process across all groups. + :type total_num_clusters: int + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. + :type cluster_shape_mn: tuple[int, int] + :param max_active_clusters: Maximum number of active clusters. + :type max_active_clusters: cutlass.Constexpr[int] + + :return: A tuple containing: + - tile_sched_params: Parameters for the persistent tile scheduler. + - grid: Grid shape for kernel launch. + :rtype: tuple[utils.PersistentTileSchedulerParams, tuple[int, ...]] + """ + # Create problem shape with M, N dimensions from cluster shape + # and L dimension representing the total number of clusters. + problem_shape_ntile_mnl = ( + cluster_shape_mn[0], + cluster_shape_mn[1], + cutlass.Int32(total_num_clusters), + ) + + tile_sched_params = utils.PersistentTileSchedulerParams( + problem_shape_ntile_mnl, (*cluster_shape_mn, 1) + ) + + grid = utils.StaticPersistentTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + + return tile_sched_params, grid + + @staticmethod + def _get_mbar_smem_bytes(**kwargs_stages: int) -> int: + """Calculate shared memory consumption for memory barriers based on provided stages. + + Each stage requires 2 barriers, and each barrier consumes 8 bytes of shared memory. + The total consumption is the sum across all provided stages. This function calculates the total + shared memory needed for these barriers. + + :param kwargs_stages: Variable keyword arguments where each key is a stage name + (e.g., num_acc_stage, num_ab_stage) and each value is the + number of stages of that type. + :type kwargs_stages: int + :return: Total shared memory bytes required for all memory barriers. + :rtype: int + """ + num_barriers_per_stage = 2 + num_bytes_per_barrier = 8 + mbar_smem_consumption = sum( + [ + num_barriers_per_stage * num_bytes_per_barrier * stage + for stage in kwargs_stages.values() + ] + ) + return mbar_smem_consumption + + @staticmethod + def _get_tensormap_smem_bytes( + tensormap_update_mode: utils.TensorMapUpdateMode, + ) -> int: + """Get the SMEM consumption for the tensormap buffer based on the update mode. + + :param tensormap_update_mode: Specifies whether tensormaps are updated in GMEM or SMEM. + :type tensormap_update_mode: utils.TensorMapUpdateMode + :return: The shared memory bytes required for the tensormap buffer. Returns 0 if mode is GMEM. + :rtype: int + :raises ValueError: If an invalid tensormap update mode is provided. + """ + if tensormap_update_mode == utils.TensorMapUpdateMode.GMEM: + return 0 + elif tensormap_update_mode == utils.TensorMapUpdateMode.SMEM: + return ( + GroupedGemmKernel.bytes_per_tensormap * GroupedGemmKernel.num_tensormaps + ) + else: + raise ValueError(f"Invalid tensormap update mode: {tensormap_update_mode}") + + @staticmethod + def _get_tensor_smem_bytes( + a_smem_layout_staged: cute.Layout, + a_dtype: Type[cutlass.Numeric], + b_smem_layout_staged: cute.Layout, + b_dtype: Type[cutlass.Numeric], + epi_smem_layout_staged: cute.Layout, + c_dtype: Type[cutlass.Numeric], + ) -> int: + """Compute the total SMEM consumption for tensor A, B and C.""" + ab_bytes = cute.size_in_bytes( + a_dtype, a_smem_layout_staged + ) + cute.size_in_bytes(b_dtype, b_smem_layout_staged) + + epi_bytes = cute.size_in_bytes(c_dtype, epi_smem_layout_staged) + return ab_bytes + epi_bytes + + @staticmethod + def _compute_num_tmem_alloc_cols( + tiled_mma: cute.TiledMma, + mma_tiler: tuple[int, int, int], + num_acc_stage: int, + ) -> int: + """ + Compute the number of tensor memory allocation columns. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler: The shape (M, N, K) of the MMA tile. + :type mma_tiler: tuple[int, int, int] + :param acc_stage: The stage of the accumulator tensor. + :type acc_stage: int + + :return: The number of tensor memory allocation columns. + :rtype: int + """ + acc_shape = tiled_mma.partition_shape_C(mma_tiler[:2]) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, num_acc_stage)) + num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtAcc_fake) + + return num_tmem_alloc_cols + + # Size of smem we reserved for mbarrier, tensor memory management and tensormap update + reserved_smem_bytes = 1024 + bytes_per_tensormap = 128 + num_tensormaps = 3 + # size of smem used for tensor memory management + tensor_memory_management_bytes = 12 + + +# Create tensor and return the pointer, tensor, and stride +def create_tensor_and_stride( + l: int, + mode0: int, + mode1: int, + is_mode0_major: bool, + dtype: type[cutlass.Numeric], + is_dynamic_layout: bool = True, + torch_tensor_cpu: torch.Tensor = None, +) -> tuple[int, torch.Tensor, cute.Tensor, torch.Tensor, tuple[int, int]]: + """Create a GPU tensor from scratch or based on an existing CPU tensor. + + :param torch_tensor_cpu: Optional existing CPU tensor to reuse. If None, creates a new one. + :type torch_tensor_cpu: torch.Tensor, optional + """ + if torch_tensor_cpu is None: + # Create new CPU tensor + torch_tensor_cpu = cutlass_torch.matrix(l, mode0, mode1, is_mode0_major, dtype) + + # Create GPU tensor from CPU tensor (new or existing) + cute_tensor, torch_tensor = cutlass_torch.cute_tensor_like( + torch_tensor_cpu, dtype, is_dynamic_layout, assumed_align=16 + ) + return ( + torch_tensor.data_ptr(), + torch_tensor, + cute_tensor, + torch_tensor_cpu, + torch_tensor.stride()[:-1], + ) + + +def create_tensors_for_all_groups( + problem_sizes_mnkl: List[tuple[int, int, int, int]], + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + torch_fp32_tensors_abc: List[List[torch.Tensor]] = None, +) -> tuple[ + List[List[int]], + List[List[torch.Tensor]], + List[tuple], + List[List[tuple]], + List[List[torch.Tensor]], +]: + if torch_fp32_tensors_abc is not None and len(torch_fp32_tensors_abc) != len( + problem_sizes_mnkl + ): + raise ValueError("torch_fp32_tensors_abc must have one entry per group") + + # Initialize lists to store tensors for all groups + new_torch_fp32_tensors_abc = ( + [] if torch_fp32_tensors_abc is None else torch_fp32_tensors_abc + ) + torch_tensors_abc = [] + cute_tensors_abc = [] + strides_abc = [] + ptrs_abc = [] + + # Iterate through all groups and create tensors for each group + for group_idx, (m, n, k, l) in enumerate(problem_sizes_mnkl): + # Get existing CPU tensors if available, otherwise None + existing_cpu_a = ( + torch_fp32_tensors_abc[group_idx][0] if torch_fp32_tensors_abc else None + ) + existing_cpu_b = ( + torch_fp32_tensors_abc[group_idx][1] if torch_fp32_tensors_abc else None + ) + existing_cpu_c = ( + torch_fp32_tensors_abc[group_idx][2] if torch_fp32_tensors_abc else None + ) + + # Create tensors (reusing CPU tensors if provided) + ( + ptr_a, + torch_tensor_a, + cute_tensor_a, + tensor_fp32_a, + stride_mk_a, + ) = create_tensor_and_stride( + l, m, k, a_major == "m", ab_dtype, torch_tensor_cpu=existing_cpu_a + ) + ( + ptr_b, + torch_tensor_b, + cute_tensor_b, + tensor_fp32_b, + stride_nk_b, + ) = create_tensor_and_stride( + l, n, k, b_major == "n", ab_dtype, torch_tensor_cpu=existing_cpu_b + ) + ( + ptr_c, + torch_tensor_c, + cute_tensor_c, + tensor_fp32_c, + stride_mn_c, + ) = create_tensor_and_stride( + l, m, n, c_major == "m", c_dtype, torch_tensor_cpu=existing_cpu_c + ) + + # Only append to new_torch_fp32_tensors_abc if we created new CPU tensors + if torch_fp32_tensors_abc is None: + new_torch_fp32_tensors_abc.append( + [tensor_fp32_a, tensor_fp32_b, tensor_fp32_c] + ) + + ptrs_abc.append([ptr_a, ptr_b, ptr_c]) + torch_tensors_abc.append([torch_tensor_a, torch_tensor_b, torch_tensor_c]) + strides_abc.append([stride_mk_a, stride_nk_b, stride_mn_c]) + cute_tensors_abc.append( + ( + cute_tensor_a, + cute_tensor_b, + cute_tensor_c, + ) + ) + + return ( + ptrs_abc, + torch_tensors_abc, + cute_tensors_abc, + strides_abc, + new_torch_fp32_tensors_abc, + ) + + +def run( + num_groups: int, + problem_sizes_mnkl: tuple[int, int, int, int], + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + mma_tiler_mn: tuple[int, int], + cluster_shape_mn: tuple[int, int], + use_2cta_instrs: bool, + tensormap_update_mode: utils.TensorMapUpdateMode, + tolerance: float, + warmup_iterations: int, + iterations: int, + skip_ref_check: bool, + use_cold_l2: bool = False, + **kwargs, +): + """Run grouped GEMM example with specified configurations. + + :param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False + :type use_cold_l2: bool, optional + :return: Execution time of the GEMM kernel in microseconds + :rtype: float + """ + print(f"Running Blackwell Grouped GEMM test with:") + print(f"{num_groups} groups") + for i, (m, n, k, l) in enumerate(problem_sizes_mnkl): + print(f"Group {i}: {m}x{n}x{k}x{l}") + print(f"AB dtype: {ab_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}") + print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}") + print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}") + print(f"2CTA MMA instructions: {'True' if use_2cta_instrs else 'False'}") + print(f"Tensor map update mode: {tensormap_update_mode}") + print(f"Tolerance: {tolerance}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}") + + # Skip unsupported types + if ab_dtype not in { + cutlass.Float16, + cutlass.BFloat16, + }: + raise ValueError(f"Skip unsupported ab_dtype {ab_dtype}") + if c_dtype not in {cutlass.Float16, cutlass.BFloat16, cutlass.Float32}: + raise ValueError(f"Skip unsupported c_dtype {c_dtype}") + # Skip unsupported acc dtype + if acc_dtype not in {cutlass.Float32, cutlass.Float16}: + raise ValueError(f"Skip unsupported acc_dtype {acc_dtype}") + # Skip invalid ab_dtype and acc_dtype combination + if ab_dtype == cutlass.BFloat16 and acc_dtype == cutlass.Float16: + raise ValueError("Skip invalid ab_dtype and acc_dtype combination") + # Skip invalid mma tile shape + if not ( + (not use_2cta_instrs and mma_tiler_mn[0] in [64, 128]) + or (use_2cta_instrs and mma_tiler_mn[0] in [128, 256]) + ): + raise ValueError(f"Skip invalid mma tiler M {mma_tiler_mn[0]}") + if mma_tiler_mn[1] not in range(32, 257, 32): + raise ValueError(f"Skip invalid mma tiler N {mma_tiler_mn[1]}") + # Skip illegal cluster shape + if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0: + raise ValueError( + f"cluster_shape_m need align with use_2cta_instrs config {cluster_shape_mn}" + ) + # Skip invalid cluster shape + is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0 + if ( + cluster_shape_mn[0] * cluster_shape_mn[1] > 16 + or cluster_shape_mn[0] <= 0 + or cluster_shape_mn[1] <= 0 + or not is_power_of_2(cluster_shape_mn[0]) + or not is_power_of_2(cluster_shape_mn[1]) + ): + raise ValueError(f"Skip invalid cluster shape {cluster_shape_mn}") + + # Skip illegal problem shape for load/store alignment + def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape): + major_mode_idx = 0 if is_mode0_major else 1 + num_major_elements = tensor_shape[major_mode_idx] + num_contiguous_elements = 16 * 8 // dtype.width + return num_major_elements % num_contiguous_elements == 0 + + if ( + not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, l)) + or not check_contigous_16B_alignment(ab_dtype, b_major == "n", (n, k, l)) + or not check_contigous_16B_alignment(c_dtype, c_major == "m", (m, n, l)) + ): + raise ValueError("Skip invalid problem alignment") + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + # Create tensors for all groups using the new function + ( + ptrs_abc, + torch_tensors_abc, + cute_tensors_abc, + strides_abc, + torch_fp32_tensors_abc, + ) = create_tensors_for_all_groups( + problem_sizes_mnkl, + ab_dtype, + c_dtype, + a_major, + b_major, + c_major, + ) + + # Choose A, B, C with the smallest size to create initial tensormaps + key_size_a = lambda item: item[1][0] * item[1][2] + key_size_b = lambda item: item[1][1] * item[1][2] + key_size_c = lambda item: item[1][0] * item[1][1] + # Find the indices of the groups with the smallest tensor sizes + min_a_idx, _ = min(enumerate(problem_sizes_mnkl), key=key_size_a) + min_b_idx, _ = min(enumerate(problem_sizes_mnkl), key=key_size_b) + min_c_idx, _ = min(enumerate(problem_sizes_mnkl), key=key_size_c) + initial_cute_tensors_abc = [ + cute_tensors_abc[min_a_idx][0], # A with smallest (m, k) + cute_tensors_abc[min_b_idx][1], # B with smallest (n, k) + cute_tensors_abc[min_c_idx][2], # C with smallest (m, n) + ] + + hardware_info = utils.HardwareInfo() + sm_count = hardware_info.get_max_active_clusters(1) + max_active_clusters = hardware_info.get_max_active_clusters( + cluster_shape_mn[0] * cluster_shape_mn[1] + ) + # Prepare tensormap buffer for each SM + num_tensormap_buffers = sm_count + tensormap_shape = ( + num_tensormap_buffers, + GroupedGemmKernel.num_tensormaps, + GroupedGemmKernel.bytes_per_tensormap // 8, + ) + tensor_of_tensormap, tensor_of_tensormap_torch = cutlass_torch.cute_tensor_like( + torch.empty(tensormap_shape, dtype=torch.int64), + cutlass.Int64, + is_dynamic_layout=False, + ) + + grouped_gemm = GroupedGemmKernel( + acc_dtype, + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + tensormap_update_mode, + ) + + # layout (num_groups, 4):(4, 1) + ( + tensor_of_dim_size_mnkl, + tensor_of_dim_size_mnkl_torch, + ) = cutlass_torch.cute_tensor_like( + torch.tensor(problem_sizes_mnkl, dtype=torch.int32), + cutlass.Int32, + is_dynamic_layout=False, + assumed_align=16, + ) + # layout (num_groups, 3, 2):(6, 2, 1) + tensor_of_strides_abc, tensor_of_strides_abc_torch = cutlass_torch.cute_tensor_like( + torch.tensor(strides_abc, dtype=torch.int32), + cutlass.Int32, + is_dynamic_layout=False, + assumed_align=16, + ) + + # layout (num_groups,3):(3, 1) + tensor_of_ptrs_abc, tensor_of_ptrs_abc_torch = cutlass_torch.cute_tensor_like( + torch.tensor(ptrs_abc, dtype=torch.int64), + cutlass.Int64, + is_dynamic_layout=False, + assumed_align=16, + ) + + # Compute total number of cluster tiles we need to compute for given grouped GEMM problem + def compute_total_num_clusters( + problem_sizes_mnkl: List[tuple[int, int, int, int]], + cluster_tile_shape_mn: tuple[int, int], + ) -> int: + total_num_clusters = 0 + for m, n, _, _ in problem_sizes_mnkl: + num_clusters_mn = tuple( + (x + y - 1) // y for x, y in zip((m, n), cluster_tile_shape_mn) + ) + total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn) + return total_num_clusters + + # Compute cluster tile shape + def compute_cluster_tile_shape( + mma_tiler_mn: tuple[int, int], + cluster_shape_mn: tuple[int, int], + use_2cta_instrs: bool, + ) -> tuple[int, int]: + cta_tile_shape_mn = list(mma_tiler_mn) + if use_2cta_instrs: + cta_tile_shape_mn[0] = cta_tile_shape_mn[0] // 2 + return tuple(x * y for x, y in zip(cta_tile_shape_mn, cluster_shape_mn)) + + cluster_tile_shape_mn = compute_cluster_tile_shape( + mma_tiler_mn, cluster_shape_mn, use_2cta_instrs + ) + total_num_clusters = compute_total_num_clusters( + problem_sizes_mnkl, cluster_tile_shape_mn + ) + + # Initialize Stream + current_stream = cutlass_torch.default_stream() + + # Compile grouped GEMM kernel + compiled_grouped_gemm = cute.compile( + grouped_gemm, + initial_cute_tensors_abc[0], + initial_cute_tensors_abc[1], + initial_cute_tensors_abc[2], + num_groups, + tensor_of_dim_size_mnkl, + tensor_of_strides_abc, + tensor_of_ptrs_abc, + total_num_clusters, + tensor_of_tensormap, + max_active_clusters, + current_stream, + ) + + if not skip_ref_check: + compiled_grouped_gemm( + initial_cute_tensors_abc[0], + initial_cute_tensors_abc[1], + initial_cute_tensors_abc[2], + tensor_of_dim_size_mnkl, + tensor_of_strides_abc, + tensor_of_ptrs_abc, + tensor_of_tensormap, + current_stream, + ) + + # Compute reference result + for i, (a, b, c) in enumerate(torch_tensors_abc): + ref = torch.einsum( + "mkl,nkl->mnl", + a.cpu().to(dtype=torch.float32), + b.cpu().to(dtype=torch.float32), + ) + print(f"checking group {i}") + torch.testing.assert_close( + c.cpu(), + ref.to(cutlass_torch.dtype(c_dtype)), + atol=tolerance, + rtol=1e-05, + ) + + def generate_tensors(): + # Reuse existing CPU tensors and create new GPU tensors from them + ( + ptrs_abc_workspace, + torch_tensors_abc_workspace, + cute_tensors_abc_workspace, + strides_abc_workspace, + _, + ) = create_tensors_for_all_groups( + problem_sizes_mnkl, + ab_dtype, + c_dtype, + a_major, + b_major, + c_major, + torch_fp32_tensors_abc, + ) + + initial_cute_tensors_abc_workspace = [ + cute_tensors_abc_workspace[min_a_idx][0], # A with smallest (m, k) + cute_tensors_abc_workspace[min_b_idx][1], # B with smallest (n, k) + cute_tensors_abc_workspace[min_c_idx][2], # C with smallest (m, n) + ] + + # Create new tensors for this workspace + tensor_of_strides_abc_workspace, _ = cutlass_torch.cute_tensor_like( + torch.tensor(strides_abc_workspace, dtype=torch.int32), + cutlass.Int32, + is_dynamic_layout=False, + assumed_align=16, + ) + + tensor_of_ptrs_abc_workspace, _ = cutlass_torch.cute_tensor_like( + torch.tensor(ptrs_abc_workspace, dtype=torch.int64), + cutlass.Int64, + is_dynamic_layout=False, + assumed_align=16, + ) + + tensormap_workspace, _ = cutlass_torch.cute_tensor_like( + torch.empty(tensormap_shape, dtype=torch.int64), + cutlass.Int64, + is_dynamic_layout=False, + ) + + return testing.JitArguments( + initial_cute_tensors_abc_workspace[0], + initial_cute_tensors_abc_workspace[1], + initial_cute_tensors_abc_workspace[2], + tensor_of_dim_size_mnkl, + tensor_of_strides_abc_workspace, + tensor_of_ptrs_abc_workspace, + tensormap_workspace, + current_stream, + ) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + sum( + [ + sum( + [ + torch_tensor.numel() * torch_tensor.element_size() + for torch_tensor in group_tensors + ] + ) + for group_tensors in torch_tensors_abc + ] + ) + + + # Add size of strides tensor + tensor_of_strides_abc_torch.numel() + * tensor_of_strides_abc_torch.element_size() + + + # Add size of ptrs tensor + tensor_of_ptrs_abc_torch.numel() * tensor_of_ptrs_abc_torch.element_size() + + + # Add size of tensormap tensor + tensor_of_tensormap_torch.numel() * tensor_of_tensormap_torch.element_size() + ) + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + exec_time = testing.benchmark( + compiled_grouped_gemm, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=current_stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + return exec_time # Return execution time in microseconds + + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str) -> tuple[int, ...]: + try: + return tuple(int(x.strip()) for x in s.split(",")) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + def parse_comma_separated_tuples(s: str) -> List[tuple[int, ...]]: + if s.strip().startswith("("): + # Split on ),( to separate tuples + tuples = s.strip("()").split("),(") + result = [] + tuple_len = None + + for t in tuples: + # Parse individual tuple + nums = [int(x.strip()) for x in t.split(",")] + + # Validate tuple length consistency + if tuple_len is None: + tuple_len = len(nums) + elif len(nums) != tuple_len: + raise argparse.ArgumentTypeError( + "All tuples must have the same length" + ) + + result.append(tuple(nums)) + return result + + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers or list of tuples" + ) + + parser = argparse.ArgumentParser( + description="Example of Grouped GEMM on Blackwell." + ) + parser.add_argument( + "--num_groups", + type=int, + default=2, + help="Number of groups", + ) + parser.add_argument( + "--problem_sizes_mnkl", + type=parse_comma_separated_tuples, + default=((128, 128, 128, 1), (128, 128, 128, 1)), + help="a tuple of problem sizes for each group (comma-separated tuples)", + ) + parser.add_argument( + "--mma_tiler_mn", + type=parse_comma_separated_ints, + default=(128, 128), + help="Mma tile shape (comma-separated)", + ) + parser.add_argument( + "--cluster_shape_mn", + type=parse_comma_separated_ints, + default=(1, 1), + help="Cluster shape (comma-separated)", + ) + parser.add_argument( + "--tensormap_update_mode", + type=str, + default="SMEM", + help="Tensor map update mode", + ) + parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.Float16) + parser.add_argument("--c_dtype", type=cutlass.dtype, default=cutlass.Float16) + parser.add_argument("--acc_dtype", type=cutlass.dtype, default=cutlass.Float32) + parser.add_argument( + "--use_2cta_instrs", + action="store_true", + help="Enable 2CTA MMA instructions feature", + ) + parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k") + parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k") + parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n") + parser.add_argument( + "--tolerance", type=float, default=1e-01, help="Tolerance for validation" + ) + parser.add_argument( + "--warmup_iterations", type=int, default=0, help="Warmup iterations" + ) + parser.add_argument( + "--iterations", + type=int, + default=1, + help="Number of iterations to run the kernel", + ) + parser.add_argument( + "--skip_ref_check", action="store_true", help="Skip reference checking" + ) + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", + ) + + args = parser.parse_args() + + if ( + len(args.problem_sizes_mnkl) != 0 + and len(args.problem_sizes_mnkl) != args.num_groups + ): + parser.error("--problem_sizes_mnkl must contain exactly num_groups tuples") + + # l mode must be 1 for all groups + for _, _, _, l in args.problem_sizes_mnkl: + if l != 1: + parser.error("l must be 1 for all groups") + + if len(args.mma_tiler_mn) != 2: + parser.error("--mma_tiler_mn must contain exactly 2 values") + + if len(args.cluster_shape_mn) != 2: + parser.error("--cluster_shape_mn must contain exactly 2 values") + + if args.tensormap_update_mode not in ["GMEM", "SMEM"]: + parser.error("--tensormap_update_mode must be GMEM or SMEM") + + if args.tensormap_update_mode == "GMEM": + tensormap_update_mode = utils.TensorMapUpdateMode.GMEM + else: + tensormap_update_mode = utils.TensorMapUpdateMode.SMEM + + torch.manual_seed(2025) + + run( + args.num_groups, + args.problem_sizes_mnkl, + args.ab_dtype, + args.c_dtype, + args.acc_dtype, + args.a_major, + args.b_major, + args.c_major, + args.mma_tiler_mn, + args.cluster_shape_mn, + args.use_2cta_instrs, + tensormap_update_mode, + args.tolerance, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.use_cold_l2, + ) + print("PASS") diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py new file mode 100644 index 0000000000000000000000000000000000000000..829d6b7ecaf456c0b5782a82b0f8621370c1e532 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py @@ -0,0 +1,3784 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +import argparse +from typing import List, Type, Tuple, Optional +import cuda.bindings.driver as cuda + +import torch +import torch.nn.functional as F + +import cutlass +import cutlass.cute as cute +import cutlass.cute.testing as testing +import cutlass.utils as utils +import cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.torch as cutlass_torch +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cute.runtime import from_dlpack + +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).resolve().parent)) +from mamba2_ssd_reference import ( + ssd_reference_fp32_all, + ssd_reference_lowprecision_intermediates, + analyze_relative_diffs, +) +from mamba2_ssd_tile_scheduler import ( + Mamba2SSDTileSchedulerParams, + Mamba2SSDTileScheduler, +) + + +class SSDKernel: + def __init__( + self, + io_dtype: Type[cutlass.Numeric], + cumsum_delta_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + L: int, + D: int, + N: int, + has_d: bool, + d_has_hdim: bool, + ): + self.io_dtype: Type[cutlass.Numeric] = io_dtype + self.acc_dtype: Type[cutlass.Numeric] = acc_dtype + self.cumsum_delta_dtype: Type[cutlass.Numeric] = cumsum_delta_dtype + # has_d means epilog warp performs Y += X*D fusion + self.has_d: bool = has_d + # d_has_hdim = True means D is (D, EH) shape and loaded by TMA + # d_has_hdim = False means D is (1, EH) shape and loaded directly to register + self.d_has_hdim: bool = d_has_hdim + self.tile_shape = (L, D, N) + + assert io_dtype in { + cutlass.Float16, + cutlass.BFloat16, + }, "Do not support other I/O types." + assert acc_dtype in {cutlass.Float32}, "Do not support other ACC types." + assert cumsum_delta_dtype in { + cutlass.Float32 + }, "Do not support other cumsum types." + assert not (not has_d and d_has_hdim), "D cannot have Hdim if has_d is False" + + # Hardcode default setting + self.use_2cta_instrs = False + self.cluster_shape_mnk = (1, 1, 1) + self.epi_tile = (128, 32) + + # Setup mma tile shapes + self.tile_shape_mnk_intra1 = (L, L, N) + self.tile_shape_mnk_intra2 = (L, D, L) + self.tile_shape_mnk_inter1 = (N, D, L) + self.tile_shape_mnk_inter2 = (L, D, N) + + self.cta_group = ( + tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE + ) + + # Launch config + self.occupancy = 1 + self.mma_inter_warp_id = 0 + self.mma_intra_warp_id = 1 + self.tma_b_c_warp_id = 2 + self.tma_deltas_x_d_warp_id = 3 + self.pre_inter_warp_id = [4, 5, 6, 7] + self.pre_intra_warp_id = [8, 9, 10, 11] + self.epilog_warp_id = [12, 13, 14, 15] + self.threads_per_cta = 32 * len( + ( + self.mma_inter_warp_id, + self.mma_intra_warp_id, + self.tma_b_c_warp_id, + self.tma_deltas_x_d_warp_id, + *self.pre_inter_warp_id, + *self.pre_intra_warp_id, + *self.epilog_warp_id, + ) + ) + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + + # Named barriers + self.pre_inter_sync_bar_id = 1 + self.epilog_sync_bar_id = 2 + self.pre_intra_sync_bar_id = 3 + self.tmem_dealloc_sync_bar_id = 4 + + # Number of registers used by each warp + self.num_regs_uniform_warps = 24 + self.num_regs_pre_inter_warps = 168 + self.num_regs_pre_intra_warps = 208 + self.num_regs_epilogue_warps = 112 + + # Shared storage + self.shared_storage = None + + # TMEM buffer offsets + self.tmem_intra1_acc_offset = 0 + self.tmem_intra2_q_offset = 0 + self.tmem_intra2_acc_offset = 0 + self.tmem_inter1_acc_offset = 0 + self.tmem_inter2_acc_offset = 0 + self.num_tmem_cols_total = 0 + + def _setup_attributes(self): + ( + tiled_mma_intra1, + tiled_mma_intra2, + tiled_mma_inter1, + tiled_mma_inter2, + ) = self.make_tiled_mmas( + self.io_dtype, + self.acc_dtype, + self.cta_group, + self.tile_shape_mnk_intra1, + self.tile_shape_mnk_intra2, + self.tile_shape_mnk_inter1, + self.tile_shape_mnk_inter2, + ) + + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), + (tiled_mma_intra1.thr_id.shape,), + ) + + # Setup stages + ( + self.input_stages, + self.output_stages, + self.internal_stages, + self.intra1_acc_stages, + ) = self._compute_stages( + self.smem_capacity, + ) + + # Setup smem layouts + # X is B operand (from smem) of INTRA2_MMA and INTER1_MMA + self.x_smem_layout = sm100_utils.make_smem_layout_b( + tiled_mma_intra2, + self.tile_shape_mnk_intra2, + self.io_dtype, + self.input_stages, + ) + self.num_x_load_bytes = cute.size_in_bytes( + self.io_dtype, cute.slice_(self.x_smem_layout, (None, None, None, 0)) + ) + + # XT is same shape as ACC operand of INTER2_MMA, before postprocessing by EPILOG + self.xt_smem_layout = sm100_utils.make_smem_layout_epi( + self.io_dtype, + utils.LayoutEnum.COL_MAJOR, + self.tile_shape_mnk_intra2[:2], + self.input_stages, + ) + + # B is B operand (from smem) of INTRA1_MMA + self.b_smem_layout = sm100_utils.make_smem_layout_b( + tiled_mma_intra1, + self.tile_shape_mnk_intra1, + self.io_dtype, + self.input_stages, + ) + self.num_b_load_bytes = cute.size_in_bytes( + self.io_dtype, cute.slice_(self.b_smem_layout, (None, None, None, 0)) + ) + + # B_INTERNAL is also A operand (from smem) of INTER1_MMA, after preprocessed by PRE_INTER + self.bt_internal_smem_layout = sm100_utils.make_smem_layout_a( + tiled_mma_inter1, + self.tile_shape_mnk_inter1, + self.io_dtype, + self.internal_stages, + ) + + # B needs to be proprocessed to be used as A operand of INTER1_MMA + self.bt_smem_layout = cute.coalesce( + sm100_utils.make_smem_layout_epi( + self.io_dtype, + utils.LayoutEnum.ROW_MAJOR, + (self.tile_shape_mnk_inter1[0], self.tile_shape_mnk_inter1[2]), + self.input_stages, + ), + target_profile=(1, 1, 1), + ) + + # C is A operand (from smem) of INTRA1_MMA and INTER2_MMA + self.c_smem_layout = sm100_utils.make_smem_layout_a( + tiled_mma_intra1, + self.tile_shape_mnk_intra1, + self.io_dtype, + self.input_stages, + ) + self.num_c_load_bytes = cute.size_in_bytes( + self.io_dtype, cute.slice_(self.c_smem_layout, (None, None, None, 0)) + ) + + # P is B operand (from smem) of INTER2_MMA, after preprocessed by PRE_INTER + self.p_smem_layout = sm100_utils.make_smem_layout_b( + tiled_mma_inter2, + self.tile_shape_mnk_inter2, + self.io_dtype, + self.internal_stages, + ) + + # PT is ACC operand (from tmem) of INTER1_MMA, after postprocessed by PRE_INTER + self.pt_smem_layout = sm100_utils.make_smem_layout_epi( + self.io_dtype, + utils.LayoutEnum.COL_MAJOR, + self.tile_shape_mnk_inter1[:2], + self.internal_stages, + ) + + # Q is A operand (from tmem) of INTRA2_MMA, after preprocessed by PRE_INTRA + self.q_tmem_layout = sm100_utils.make_smem_layout_a( + tiled_mma_intra2, + self.tile_shape_mnk_intra2, + self.io_dtype, + self.internal_stages, + ) + + # P is ACC operand (from tmem) of INTER1_MMA, to be TMA stored by PRE_INTER + self.p_smem_layout_store = sm100_utils.make_smem_layout_epi( + self.io_dtype, + utils.LayoutEnum.ROW_MAJOR, + self.tile_shape_mnk_inter2[1:], + self.internal_stages, + ) + + # Y is ACC operand (from smem) of INTER2_MMA and INTRA2_MMA, after postprocessed and TMA stored by EPILOG + self.y_smem_layout = sm100_utils.make_smem_layout_epi( + self.io_dtype, + utils.LayoutEnum.COL_MAJOR, + self.epi_tile, + self.output_stages, + ) + + # Delta is linear smem layouts for pre/post processing + self.delta_linear_smem_layout = cute.make_layout( + (self.tile_shape_mnk_inter1[2], self.input_stages) + ) + self.num_delta_load_bytes = cute.size_in_bytes( + self.io_dtype, cute.slice_(self.delta_linear_smem_layout, (None, 0)) + ) + + # Cumsum delta is linear smem layouts for pre/post processing + self.cumsum_delta_linear_smem_layout = cute.make_layout( + (self.tile_shape_mnk_inter1[2], self.input_stages) + ) + self.num_cumsum_delta_load_bytes = cute.size_in_bytes( + self.cumsum_delta_dtype, + cute.slice_(self.cumsum_delta_linear_smem_layout, (None, 0)), + ) + + # D is linear smem layouts when d_has_hdim is True + self.d_linear_smem_layout = ( + cute.make_layout((self.tile_shape_mnk_inter2[1], self.input_stages)) + if self.d_has_hdim + else None + ) + self.num_d_load_bytes = ( + cute.size_in_bytes( + self.io_dtype, + cute.slice_(self.d_linear_smem_layout, (None, 0)), + ) + if self.d_has_hdim + else 0 + ) + + # Setup tmem offsets + ( + self.tmem_intra1_acc_offset, + self.tmem_intra2_q_offset, + self.tmem_intra2_acc_offset, + self.tmem_inter1_acc_offset, + self.tmem_inter2_acc_offset, + self.num_tmem_cols_total, + ) = self._plan_tmem_offsets( + tiled_mma_intra1, + self.tile_shape_mnk_intra1, + tiled_mma_intra2, + self.tile_shape_mnk_intra2, + tiled_mma_inter1, + self.tile_shape_mnk_inter1, + tiled_mma_inter2, + self.tile_shape_mnk_inter2, + self.internal_stages, + self.q_tmem_layout, + self.io_dtype, + self.internal_stages, + self.intra1_acc_stages, + ) + + return + + @cute.jit + def __call__( + self, + x: cute.Tensor, + cumsum_delta: cute.Tensor, + delta: cute.Tensor, + b: cute.Tensor, + c: cute.Tensor, + y: cute.Tensor, + fstate: cute.Tensor, + d: cute.Tensor, + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + ): + self._setup_attributes() + ( + tiled_mma_intra1, + tiled_mma_intra2, + tiled_mma_inter1, + tiled_mma_inter2, + ) = self.make_tiled_mmas( + self.io_dtype, + self.acc_dtype, + self.cta_group, + self.tile_shape_mnk_intra1, + self.tile_shape_mnk_intra2, + self.tile_shape_mnk_inter1, + self.tile_shape_mnk_inter2, + ) + + # Setup TMA atoms and convert TMA tensors + # TMA load for A + x_op = sm100_utils.cluster_shape_to_tma_atom_B( + self.cluster_shape_mnk, tiled_mma_intra2.thr_id + ) + tma_atom_x, tma_tensor_x = cute.nvgpu.make_tiled_tma_atom_B( + x_op, + x, + cute.slice_(self.x_smem_layout, (None, None, None, 0)), + self.tile_shape_mnk_intra2, + tiled_mma_intra2, + self.cluster_layout_vmnk.shape, + internal_type=( + cutlass.TFloat32 if x.element_type is cutlass.Float32 else None + ), + ) + + # TMA load for B + b_op = sm100_utils.cluster_shape_to_tma_atom_B( + self.cluster_shape_mnk, tiled_mma_intra1.thr_id + ) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b, + cute.slice_(self.b_smem_layout, (None, None, None, 0)), + self.tile_shape_mnk_intra1, + tiled_mma_intra1, + self.cluster_layout_vmnk.shape, + internal_type=( + cutlass.TFloat32 if b.element_type is cutlass.Float32 else None + ), + ) + + # TMA load for C + c_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mnk, tiled_mma_intra1.thr_id + ) + tma_atom_c, tma_tensor_c = cute.nvgpu.make_tiled_tma_atom_A( + c_op, + c, + cute.slice_(self.c_smem_layout, (None, None, None, 0)), + self.tile_shape_mnk_intra1, + tiled_mma_intra1, + self.cluster_layout_vmnk.shape, + internal_type=( + cutlass.TFloat32 if c.element_type is cutlass.Float32 else None + ), + ) + + # TMA load for delta + # TODO: use bulkcp instead of tma + delta_cta_v_layout = cute.slice_( + cute.make_identity_layout(delta.shape), (None, 0, 0, 0) + ) + delta_linear_smem_layout = cute.slice_(self.delta_linear_smem_layout, (None, 0)) + tma_atom_delta, tma_tensor_delta = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileG2SOp(), + delta, + delta_linear_smem_layout, + delta_cta_v_layout, + ) + + # TMA load for cumsum_delta + cumsum_delta_cta_v_layout = cute.slice_( + cute.make_identity_layout(cumsum_delta.shape), (None, 0, 0, 0) + ) + cumsum_delta_linear_smem_layout = cute.slice_( + self.cumsum_delta_linear_smem_layout, (None, 0) + ) + ( + tma_atom_cumsum_delta, + tma_tensor_cumsum_delta, + ) = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileG2SOp(), + cumsum_delta, + cumsum_delta_linear_smem_layout, + cumsum_delta_cta_v_layout, + ) + + tma_atom_d = None + tma_tensor_d = d + # TMA load for D + if cutlass.const_expr(self.d_has_hdim): + d_cta_v_layout = cute.slice_(cute.make_identity_layout(d.shape), (None, 0)) + d_linear_smem_layout = cute.slice_(self.d_linear_smem_layout, (None, 0)) + ( + tma_atom_d, + tma_tensor_d, + ) = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileG2SOp(), + d, + d_linear_smem_layout, + d_cta_v_layout, + ) + + # TMA store for y + y_cta_v_layout = cute.composition( + cute.make_identity_layout(y.shape), self.epi_tile + ) + y_smem_layout = cute.slice_(self.y_smem_layout, (None, None, 0)) + tma_atom_y, tma_tensor_y = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + y, + y_smem_layout, + y_cta_v_layout, + ) + + # TMA store for fstate(p) + p_cta_v_layout = cute.slice_( + cute.make_identity_layout(fstate.shape), (None, None, 0, 0) + ) + p_smem_layout_store = cute.slice_(self.p_smem_layout_store, (None, None, 0)) + tma_atom_p, tma_tensor_p = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + fstate, + p_smem_layout_store, + p_cta_v_layout, + ) + + # Compute grid size + tile_sched_params, grid = self._compute_grid(y, b, max_active_clusters) + + # Plan shared memory storage + swizzle_buffer_align_bytes = 1024 + nonswizzle_buffer_align_bytes = 128 + + @cute.struct + class SharedStorage: + # Input stage barriers + x_full: cute.struct.MemRange[cutlass.Int64, self.input_stages] # type: ignore + x_empty: cute.struct.MemRange[cutlass.Int64, self.input_stages] # type: ignore + b_full: cute.struct.MemRange[cutlass.Int64, self.input_stages] # type: ignore + b_empty: cute.struct.MemRange[cutlass.Int64, self.input_stages] # type: ignore + c_full: cute.struct.MemRange[cutlass.Int64, self.input_stages] # type: ignore + c_empty: cute.struct.MemRange[cutlass.Int64, self.input_stages] # type: ignore + deltas_full: cute.struct.MemRange[cutlass.Int64, self.input_stages] # type: ignore + deltas_empty: cute.struct.MemRange[cutlass.Int64, self.input_stages] # type: ignore + d_full: cute.struct.MemRange[cutlass.Int64, self.input_stages] # type: ignore + d_empty: cute.struct.MemRange[cutlass.Int64, self.input_stages] # type: ignore + # Intra1 acc stage barriers + intra1_acc_full: cute.struct.MemRange[cutlass.Int64, self.intra1_acc_stages] # type: ignore + intra1_acc_empty: cute.struct.MemRange[cutlass.Int64, self.intra1_acc_stages] # type: ignore + # Internal stage barriers + intra2_q_full: cute.struct.MemRange[cutlass.Int64, self.internal_stages] # type: ignore + intra2_q_empty: cute.struct.MemRange[cutlass.Int64, self.internal_stages] # type: ignore + intra2_acc_full: cute.struct.MemRange[cutlass.Int64, self.internal_stages] # type: ignore + intra2_acc_empty: cute.struct.MemRange[cutlass.Int64, self.internal_stages] # type: ignore + inter1_b_full: cute.struct.MemRange[cutlass.Int64, self.internal_stages] # type: ignore + inter1_b_empty: cute.struct.MemRange[cutlass.Int64, self.internal_stages] # type: ignore + inter1_acc_full: cute.struct.MemRange[cutlass.Int64, self.internal_stages] # type: ignore + inter1_acc_empty: cute.struct.MemRange[cutlass.Int64, self.internal_stages] # type: ignore + inter2_p_full: cute.struct.MemRange[cutlass.Int64, self.internal_stages] # type: ignore + inter2_p_empty: cute.struct.MemRange[cutlass.Int64, self.internal_stages] # type: ignore + inter2_acc_full: cute.struct.MemRange[cutlass.Int64, self.internal_stages] # type: ignore + inter2_acc_empty: cute.struct.MemRange[cutlass.Int64, self.internal_stages] # type: ignore + # Tmem holding buffer + tmem_holding_buf: cutlass.Int32 + # Smem tensors + smem_x: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(self.x_smem_layout)], + swizzle_buffer_align_bytes, + ] + smem_b: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(self.b_smem_layout)], + swizzle_buffer_align_bytes, + ] + smem_bt_internal: cute.struct.Align[ + cute.struct.MemRange[ + self.io_dtype, cute.cosize(self.bt_internal_smem_layout) + ], + swizzle_buffer_align_bytes, + ] + smem_c: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(self.c_smem_layout)], + swizzle_buffer_align_bytes, + ] + smem_p: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(self.p_smem_layout)], + swizzle_buffer_align_bytes, + ] + smem_y: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(self.y_smem_layout)], + swizzle_buffer_align_bytes, + ] + smem_cumsum_delta: cute.struct.Align[ + cute.struct.MemRange[ + self.cumsum_delta_dtype, + cute.cosize(self.cumsum_delta_linear_smem_layout), + ], + nonswizzle_buffer_align_bytes, + ] + smem_delta: cute.struct.Align[ + cute.struct.MemRange[ + self.io_dtype, cute.cosize(self.delta_linear_smem_layout) + ], + nonswizzle_buffer_align_bytes, + ] + smem_d: cute.struct.Align[ + cute.struct.MemRange[ + self.io_dtype, + cute.cosize(self.d_linear_smem_layout) if self.d_has_hdim else 0, + ], + nonswizzle_buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + if cutlass.const_expr(self.shared_storage.size_in_bytes() > self.smem_capacity): + raise ValueError( + f"SharedStorage size {self.shared_storage.size_in_bytes()} exceeds smem_capacity {self.smem_capacity}" + ) + + # Launch the kernel synchronously + self.kernel( + tma_atom_x, + tma_tensor_x, + tma_atom_b, + tma_tensor_b, + tma_atom_c, + tma_tensor_c, + tma_atom_p, + tma_tensor_p, + tma_atom_y, + tma_tensor_y, + tma_atom_delta, + tma_tensor_delta, + tma_atom_cumsum_delta, + tma_tensor_cumsum_delta, + tma_atom_d, + tma_tensor_d, + self.cluster_layout_vmnk, + self.x_smem_layout, + self.xt_smem_layout, + self.b_smem_layout, + self.bt_smem_layout, + self.bt_internal_smem_layout, + self.c_smem_layout, + self.pt_smem_layout, + self.p_smem_layout, + self.q_tmem_layout, + self.p_smem_layout_store, + self.y_smem_layout, + self.delta_linear_smem_layout, + self.cumsum_delta_linear_smem_layout, + self.d_linear_smem_layout, + self.epi_tile, + tile_sched_params, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=self.cluster_shape_mnk, + min_blocks_per_mp=1, + stream=stream, + ) + + # GPU device kernel + @cute.kernel + def kernel( + self, + tma_atom_x: cute.CopyAtom, + tma_tensor_x: cute.Tensor, + tma_atom_b: cute.CopyAtom, + tma_tensor_b: cute.Tensor, + tma_atom_c: cute.CopyAtom, + tma_tensor_c: cute.Tensor, + tma_atom_p: cute.CopyAtom, + tma_tensor_p: cute.Tensor, + tma_atom_y: cute.CopyAtom, + tma_tensor_y: cute.Tensor, + tma_atom_delta: cute.CopyAtom, + tma_tensor_delta: cute.Tensor, + tma_atom_cumsum_delta: cute.CopyAtom, + tma_tensor_cumsum_delta: cute.Tensor, + tma_atom_d: Optional[cute.CopyAtom], + tma_tensor_d: cute.Tensor, + cluster_layout_vmnk: cute.Layout, + x_smem_layout: cute.ComposedLayout, + xt_smem_layout: cute.ComposedLayout, + b_smem_layout: cute.ComposedLayout, + bt_smem_layout: cute.ComposedLayout, + bt_internal_smem_layout: cute.ComposedLayout, + c_smem_layout: cute.ComposedLayout, + pt_smem_layout: cute.ComposedLayout, + p_smem_layout: cute.ComposedLayout, + q_tmem_layout: cute.ComposedLayout, + p_smem_layout_store: cute.ComposedLayout, + y_smem_layout: cute.ComposedLayout, + delta_linear_smem_layout: cute.Layout, + cumsum_delta_linear_smem_layout: cute.Layout, + d_linear_smem_layout: Optional[cute.Layout], + epi_tile: cute.Tile, + tile_sched_params: Mamba2SSDTileSchedulerParams, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + + # Prefetch tma descriptor + if warp_idx == 0: + tma_atoms = [ + tma_atom_x, + tma_atom_b, + tma_atom_c, + tma_atom_p, + tma_atom_y, + tma_atom_delta, + tma_atom_cumsum_delta, + ] + if cutlass.const_expr(self.d_has_hdim): + tma_atoms.append(tma_atom_d) + for tma_atom in tma_atoms: + cpasync.prefetch_descriptor(tma_atom) + + # Static consts + D = cute.size(tma_tensor_x, mode=[0]) + L = cute.size(tma_tensor_x, mode=[1]) + N = cute.size(tma_tensor_b, mode=[1]) + # Dynamic values + C = cute.size(tma_tensor_x, mode=[2]) + EH = cute.size(tma_tensor_x, mode=[3]) + B = cute.size(tma_tensor_x, mode=[4]) + G = cute.size(tma_tensor_b, mode=[3]) + NGROUP_RATIO = EH // G + + # Make TiledMma + ( + tiled_mma_intra1, + tiled_mma_intra2, + tiled_mma_inter1, + tiled_mma_inter2, + ) = self.make_tiled_mmas( + self.io_dtype, + self.acc_dtype, + self.cta_group, + self.tile_shape_mnk_intra1, + self.tile_shape_mnk_intra2, + self.tile_shape_mnk_inter1, + self.tile_shape_mnk_inter2, + ) + + # Setup cta/thread coordinates + # Block coord + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma_intra1.thr_id.shape) + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + # Workload coord + tile_sched = Mamba2SSDTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + # Thread/warp coord + tidx, _, _ = cute.arch.thread_idx() + # Thread coord inside specialized warps + local_tidx = tidx % 128 + local_warp_idx = cute.arch.make_warp_uniform(local_tidx // 32) + + # Alloc and init smem tensors and pipelines + smem = utils.SmemAllocator() + smem_storage = smem.allocate(self.shared_storage) + + # Setup smem tensors + smem_x = smem_storage.smem_x.get_tensor( + x_smem_layout.outer, swizzle=x_smem_layout.inner + ) + smem_xt = smem_storage.smem_x.get_tensor( + xt_smem_layout.outer, swizzle=xt_smem_layout.inner + ) + smem_b = smem_storage.smem_b.get_tensor( + b_smem_layout.outer, swizzle=b_smem_layout.inner + ) + smem_bt = smem_storage.smem_b.get_tensor( + bt_smem_layout.outer, swizzle=bt_smem_layout.inner + ) + smem_bt_internal = smem_storage.smem_bt_internal.get_tensor( + bt_internal_smem_layout.outer, swizzle=bt_internal_smem_layout.inner + ) + smem_c = smem_storage.smem_c.get_tensor( + c_smem_layout.outer, swizzle=c_smem_layout.inner + ) + smem_p = smem_storage.smem_p.get_tensor( + p_smem_layout.outer, swizzle=p_smem_layout.inner + ) + smem_pt = smem_storage.smem_p.get_tensor( + pt_smem_layout.outer, swizzle=pt_smem_layout.inner + ) + smem_p_store = smem_storage.smem_p.get_tensor( + p_smem_layout_store.outer, swizzle=p_smem_layout_store.inner + ) + smem_y = smem_storage.smem_y.get_tensor( + y_smem_layout.outer, swizzle=y_smem_layout.inner + ) + smem_cumsum_delta = smem_storage.smem_cumsum_delta.get_tensor( + cumsum_delta_linear_smem_layout + ) + smem_delta = smem_storage.smem_delta.get_tensor(delta_linear_smem_layout) + smem_d = None + if cutlass.const_expr(self.d_has_hdim): + smem_d = smem_storage.smem_d.get_tensor(d_linear_smem_layout) + + # Init mbarrier for pipeline + x_pipeline = self.make_and_init_x_pipeline(smem_storage.x_full.data_ptr()) + b_pipeline = self.make_and_init_b_pipeline(smem_storage.b_full.data_ptr()) + c_pipeline = self.make_and_init_c_pipeline(smem_storage.c_full.data_ptr()) + deltas_pipeline = self.make_and_init_deltas_pipeline( + smem_storage.deltas_full.data_ptr() + ) + d_pipeline = self.make_and_init_d_pipeline(smem_storage.d_full.data_ptr()) + intra1_acc_pipeline = self.make_and_init_intra1_acc_pipeline( + smem_storage.intra1_acc_full.data_ptr() + ) + intra2_q_pipeline = self.make_and_init_intra2_q_pipeline( + smem_storage.intra2_q_full.data_ptr() + ) + intra2_acc_pipeline = self.make_and_init_intra2_acc_pipeline( + smem_storage.intra2_acc_full.data_ptr() + ) + inter1_b_pipeline = self.make_and_init_inter1_b_pipeline( + smem_storage.inter1_b_full.data_ptr() + ) + inter1_acc_pipeline = self.make_and_init_inter1_acc_pipeline( + smem_storage.inter1_acc_full.data_ptr() + ) + inter2_p_pipeline = self.make_and_init_inter2_p_pipeline( + smem_storage.inter2_p_full.data_ptr() + ) + inter2_acc_pipeline = self.make_and_init_inter2_acc_pipeline( + smem_storage.inter2_acc_full.data_ptr() + ) + + # Cluster arrive after barrier init + if cute.size(self.cluster_shape_mnk) > 1: + cute.arch.cluster_arrive_relaxed() + + # Cluster wait before tmem alloc + if cute.size(self.cluster_shape_mnk) > 1: + cute.arch.cluster_wait() + + # Alloc tmem buffer + if warp_idx == self.epilog_warp_id[0]: + cute.arch.alloc_tmem( + self.num_tmem_cols_total, + smem_storage.tmem_holding_buf, + is_two_cta=self.use_2cta_instrs, + ) + + # Bar sync before retrieving tmem ptr from shared mem + cute.arch.barrier() + + # Retrieve tmem ptr + tmem_ptr_base = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, + alignment=16, + ptr_to_buffer_holding_addr=smem_storage.tmem_holding_buf, + ) + + # Specialized TMA load Delta/CumsumDelta/X warp + if warp_idx == self.tma_deltas_x_d_warp_id: + # Dealloc regs for pre-inter/pre-intra warps + cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps) + + # ((ATOM_V, REST_V), INPUT_STAGE) + # ((ATOM_V, REST_V), 1, 1, C, EH, B) + tXsX, tXgX_pre_slice = self.tma_partition_for_mma_b_operand( + tma_atom_x, + tma_tensor_x, + smem_x, + tiled_mma_intra2, + cluster_layout_vmnk, + mma_tile_coord_v, + block_in_cluster_coord_vmnk, + ) + + # ((ATOM_V, REST_V), INPUT_STAGE) + # ((ATOM_V, REST_V), 1, C, EH, B) + tDeltasDelta, tDeltagDelta_pre_slice = self.tma_partition_with_shape( + tma_atom_delta, + tma_tensor_delta, + smem_delta, + (self.tile_shape_mnk_inter1[2],), + ) + # ((ATOM_V, REST_V), INPUT_STAGE) + # ((ATOM_V, REST_V), 1, C, EH, B) + ( + tDeltasCumsumDelta, + tDeltagCumsumDelta_pre_slice, + ) = self.tma_partition_with_shape( + tma_atom_cumsum_delta, + tma_tensor_cumsum_delta, + smem_cumsum_delta, + (self.tile_shape_mnk_inter1[2],), + ) + + tDsD = None + tDgD_pre_slice = None + if cutlass.const_expr(self.d_has_hdim): + # Partition global/shared tensor for D + # ((ATOM_V, REST_V), INPUT_STAGE) + # ((ATOM_V, REST_V), 1, EH) + tDsD, tDgD_pre_slice = self.tma_partition_with_shape( + tma_atom_d, tma_tensor_d, smem_d, (self.tile_shape_mnk_inter2[1],) + ) + + # Pipeline X/Delta/CumsumDelta/D producer state + x_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.input_stages + ) + deltas_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.input_stages + ) + d_producer_state = None + if cutlass.const_expr(self.d_has_hdim): + # D is loaded by TMA only when d_has_hdim is True + d_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.input_stages + ) + + while work_tile.is_valid_tile: + b_idx, eh_idx, g_idx = work_tile.tile_idx + + # Slice global tensor to current tile idx + # ((ATOM_V, REST_V), C) + tXgX = tXgX_pre_slice[None, 0, 0, None, eh_idx, b_idx] + tDeltagDelta = tDeltagDelta_pre_slice[None, 0, None, eh_idx, b_idx] + tDeltagCumsumDelta = tDeltagCumsumDelta_pre_slice[ + None, 0, None, eh_idx, b_idx + ] + tDgD = None + if cutlass.const_expr(self.d_has_hdim): + # ((ATOM_V, REST_V)) + tDgD = tDgD_pre_slice[None, 0, eh_idx] + + # Reset count for pipeline state + x_producer_state.reset_count() + deltas_producer_state.reset_count() + if cutlass.const_expr(self.d_has_hdim): + d_producer_state.reset_count() + + # Peek (try_wait) X/deltas buffer empty status + peek_x_empty_status = self.conditional_producer_try_acquire( + x_producer_state, x_pipeline, C + ) + peek_deltas_empty_status = self.conditional_producer_try_acquire( + deltas_producer_state, deltas_pipeline, C + ) + + if cutlass.const_expr(self.d_has_hdim): + # Wait for D buffer empty + d_pipeline.producer_acquire(d_producer_state) + # TMA load D + cute.copy( + tma_atom_d, + tDgD, + tDsD[None, d_producer_state.index], + tma_bar_ptr=d_pipeline.producer_get_barrier(d_producer_state), + ) + # Advance D producer state + d_producer_state.advance() + + # Batched load over C dimension + for chunk_idx in cutlass.range(C, unroll=1): + # Conditionally wait for X buffer empty + x_pipeline.producer_acquire(x_producer_state, peek_x_empty_status) + + # TMA load X + cute.copy( + tma_atom_x, + tXgX[None, x_producer_state.count], + tXsX[None, x_producer_state.index], + tma_bar_ptr=x_pipeline.producer_get_barrier(x_producer_state), + ) + + # Conditionally wait for deltas buffer empty + deltas_pipeline.producer_acquire( + deltas_producer_state, peek_deltas_empty_status + ) + + # TMA load Delta/CumsumDelta + cute.copy( + tma_atom_delta, + tDeltagDelta[None, deltas_producer_state.count], + tDeltasDelta[None, deltas_producer_state.index], + tma_bar_ptr=deltas_pipeline.producer_get_barrier( + deltas_producer_state + ), + ) + cute.copy( + tma_atom_cumsum_delta, + tDeltagCumsumDelta[None, deltas_producer_state.count], + tDeltasCumsumDelta[None, deltas_producer_state.index], + tma_bar_ptr=deltas_pipeline.producer_get_barrier( + deltas_producer_state + ), + ) + + # Advance X/deltas producer state + x_producer_state.advance() + deltas_producer_state.advance() + + # Peek (try_wait) X/deltas buffer empty status + peek_x_empty_status = self.conditional_producer_try_acquire( + x_producer_state, x_pipeline, C + ) + peek_deltas_empty_status = self.conditional_producer_try_acquire( + deltas_producer_state, deltas_pipeline, C + ) + # END of for chunk_idx in cutlass.range(C, unroll=1) + + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + # END of while work_tile.is_valid_tile + + # Producer tail for X/Deltas/D + x_pipeline.producer_tail(x_producer_state) + deltas_pipeline.producer_tail(deltas_producer_state) + if cutlass.const_expr(self.d_has_hdim): + d_pipeline.producer_tail(d_producer_state) + # END of specialized tma load X/Deltas/D warp + + # Specialized TMA load B/C warp + elif warp_idx == self.tma_b_c_warp_id: + # Dealloc regs for pre-inter/pre-intra warps + cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps) + + # ((ATOM_V, REST_V), INPUT_STAGE) + # ((ATOM_V, REST_V), 1, 1, C, G, B) + tBsB, tBgB_pre_slice = self.tma_partition_for_mma_b_operand( + tma_atom_b, + tma_tensor_b, + smem_b, + tiled_mma_intra1, + cluster_layout_vmnk, + mma_tile_coord_v, + block_in_cluster_coord_vmnk, + ) + + # ((ATOM_V, REST_V), INPUT_STAGE) + # ((ATOM_V, REST_V), 1, 1, C, G, B) + tCsC, tCgC_pre_slice = self.tma_partition_for_mma_a_operand( + tma_atom_c, + tma_tensor_c, + smem_c, + tiled_mma_intra1, + cluster_layout_vmnk, + mma_tile_coord_v, + block_in_cluster_coord_vmnk, + ) + + # Pipeline B/C producer state + b_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.input_stages + ) + c_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.input_stages + ) + + while work_tile.is_valid_tile: + b_idx, eh_idx, g_idx = work_tile.tile_idx + + # Slice global tensor to current tile idx + # ((ATOM_V, REST_V), C) + tBgB = tBgB_pre_slice[None, 0, 0, None, g_idx, b_idx] + tCgC = tCgC_pre_slice[None, 0, 0, None, g_idx, b_idx] + + # Reset count for pipeline state + b_producer_state.reset_count() + c_producer_state.reset_count() + + # Peek (try_wait) B/C buffer empty status + peek_b_empty_status = self.conditional_producer_try_acquire( + b_producer_state, b_pipeline, C + ) + peek_c_empty_status = self.conditional_producer_try_acquire( + c_producer_state, c_pipeline, C + ) + + # Batched load over C dimension + for chunk_idx in cutlass.range(C, unroll=1): + # Conditionally wait for B buffer empty + b_pipeline.producer_acquire(b_producer_state, peek_b_empty_status) + + # TMA load B + cute.copy( + tma_atom_b, + tBgB[None, b_producer_state.count], + tBsB[None, b_producer_state.index], + tma_bar_ptr=b_pipeline.producer_get_barrier(b_producer_state), + ) + + # Conditionally wait for C buffer empty + c_pipeline.producer_acquire(c_producer_state, peek_c_empty_status) + + # TMA load C + cute.copy( + tma_atom_c, + tCgC[None, c_producer_state.count], + tCsC[None, c_producer_state.index], + tma_bar_ptr=c_pipeline.producer_get_barrier(c_producer_state), + ) + + # Advance B/C producer state + b_producer_state.advance() + c_producer_state.advance() + + # Peek (try_wait) B/C buffer empty status + peek_b_empty_status = self.conditional_producer_try_acquire( + b_producer_state, b_pipeline, C + ) + peek_c_empty_status = self.conditional_producer_try_acquire( + c_producer_state, c_pipeline, C + ) + # END of for chunk_idx in cutlass.range(C, unroll=1) + + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + # END of while work_tile.is_valid_tile + + # Producer tail for B/C + b_pipeline.producer_tail(b_producer_state) + c_pipeline.producer_tail(c_producer_state) + # END of specialized tma load B/C warp + + # Specialized MMA Intra warp + elif warp_idx == self.mma_intra_warp_id: + # Dealloc regs for pre-inter/pre-intra warps + cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps) + + # Make shared/tmem fragments for INTRA_MMA1 B/C/ACC + # (MMA, MMA_N, MMA_K, INPUT_STAGE) + # (MMA, MMA_M, MMA_K, INPUT_STAGE) + # (MMA, MMA_M, MMA_N, INTRA1_ACC_STAGE) + tCrC, tCrB, tCtAccIntra1 = self.mma_partition_ss( + tiled_mma_intra1, + self.tile_shape_mnk_intra1, + smem_c, + smem_b, + tmem_ptr_base + self.tmem_intra1_acc_offset, + self.intra1_acc_stages, + ) + + # Make shared/tmem fragments for INTRA_MMA2 X/Q/ACC + # (MMA, MMA_M, MMA_K, INTERNAL_STAGE) + # (MMA, MMA_N, MMA_K, INPUT_STAGE) + # (MMA, MMA_M, MMA_N, INTERNAL_STAGE) + tCrQ, tCrX, tCtAccIntra2 = self.mma_partition_ts( + tiled_mma_intra2, + self.tile_shape_mnk_intra2, + q_tmem_layout, + smem_x, + tmem_ptr_base + self.tmem_intra2_q_offset, + tmem_ptr_base + self.tmem_intra2_acc_offset, + self.internal_stages, + ) + + # Pipeline B/C/X/INTRA2_Q consumer state + b_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.input_stages + ) + c_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.input_stages + ) + x_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.input_stages + ) + intra2_q_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.internal_stages + ) + + # Pipeline INTRA1_ACC/INTRA2_ACC producer state + intra1_acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.intra1_acc_stages + ) + intra2_acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.internal_stages + ) + + while work_tile.is_valid_tile: + # Reset count for pipeline state + b_consumer_state.reset_count() + c_consumer_state.reset_count() + intra1_acc_producer_state.reset_count() + x_consumer_state.reset_count() + intra2_q_consumer_state.reset_count() + intra2_acc_producer_state.reset_count() + + # Peek (try_wait) B/C/X/INTRA1_ACC buffer full/full/full/empty status + peek_b_full_status = self.conditional_consumer_try_wait( + b_consumer_state, b_pipeline, C + ) + peek_c_full_status = self.conditional_consumer_try_wait( + c_consumer_state, c_pipeline, C + ) + peek_wr_intra1_acc_empty_status = self.conditional_producer_try_acquire( + intra1_acc_producer_state, intra1_acc_pipeline, C + ) + peek_x_full_status = self.conditional_consumer_try_wait( + x_consumer_state, x_pipeline, C + ) + + # Manual pipeline: unrolled INTRA_MMA1 chunk_idx = 0 loop + # Conditionally wait for B/C/INTRA1_ACC buffer full/full/empty + b_pipeline.consumer_wait(b_consumer_state, peek_b_full_status) + c_pipeline.consumer_wait(c_consumer_state, peek_c_full_status) + intra1_acc_pipeline.producer_acquire( + intra1_acc_producer_state, peek_wr_intra1_acc_empty_status + ) + + # INTRA_MMA1 + tiled_mma_intra1 = self.exec_mma( + tiled_mma_intra1, + tCtAccIntra1, + tCrC, + tCrB, + intra1_acc_producer_state, + c_consumer_state, + b_consumer_state, + ) + + # Async arrive B/C/INTRA1_ACC buffer empty/empty/full + b_pipeline.consumer_release( + b_consumer_state, pipeline.PipelineOp.TCGen05Mma + ) + c_pipeline.consumer_release(c_consumer_state) + intra1_acc_pipeline.producer_commit(intra1_acc_producer_state) + + # Advance B/C/INTRA1_ACC state + b_consumer_state.advance() + c_consumer_state.advance() + intra1_acc_producer_state.advance() + + # Peek (try_wait) B/C/INTRA1_ACC buffer full/full/empty for chunk_idx = chunk_idx + 1 + peek_b_full_status = self.conditional_consumer_try_wait( + b_consumer_state, b_pipeline, C + ) + peek_c_full_status = self.conditional_consumer_try_wait( + c_consumer_state, c_pipeline, C + ) + peek_wr_intra1_acc_empty_status = self.conditional_producer_try_acquire( + intra1_acc_producer_state, intra1_acc_pipeline, C + ) + + # Manual pipeline: batched gemm over C-1 dimension + for chunk_idx in cutlass.range(C - 1, unroll=1): + # Conditionally wait for B/C/INTRA1_ACC buffer full/full/empty + b_pipeline.consumer_wait(b_consumer_state, peek_b_full_status) + c_pipeline.consumer_wait(c_consumer_state, peek_c_full_status) + intra1_acc_pipeline.producer_acquire( + intra1_acc_producer_state, peek_wr_intra1_acc_empty_status + ) + + # INTRA_MMA1 + tiled_mma_intra1 = self.exec_mma( + tiled_mma_intra1, + tCtAccIntra1, + tCrC, + tCrB, + intra1_acc_producer_state, + c_consumer_state, + b_consumer_state, + ) + + # Async arrive B/C/INTRA1_ACC buffer empty/empty/full + b_pipeline.consumer_release( + b_consumer_state, pipeline.PipelineOp.TCGen05Mma + ) + c_pipeline.consumer_release(c_consumer_state) + intra1_acc_pipeline.producer_commit(intra1_acc_producer_state) + + # Conditionally wait for X/INTRA2_Q/INTRA2_ACC buffer full/full/empty + x_pipeline.consumer_wait(x_consumer_state, peek_x_full_status) + intra2_q_pipeline.consumer_wait(intra2_q_consumer_state) + intra2_acc_pipeline.producer_acquire(intra2_acc_producer_state) + + # INTRA_MMA2 + tiled_mma_intra2 = self.exec_mma( + tiled_mma_intra2, + tCtAccIntra2, + tCrQ, + tCrX, + intra2_acc_producer_state, + intra2_q_consumer_state, + x_consumer_state, + ) + + # Async arrive X/INTRA2_Q/INTRA2_ACC buffer empty/empty/full + if cutlass.const_expr(self.has_d): + x_pipeline.consumer_release( + x_consumer_state, pipeline.PipelineOp.TCGen05Mma + ) + else: + x_pipeline.consumer_release(x_consumer_state) + intra2_q_pipeline.consumer_release(intra2_q_consumer_state) + intra2_acc_pipeline.producer_commit(intra2_acc_producer_state) + + # Advance B/C/INTRA1_ACC cstate + b_consumer_state.advance() + c_consumer_state.advance() + intra1_acc_producer_state.advance() + + # Peek (try_wait) B/C/INTRA1_ACC buffer full/full/empty for chunk_idx = chunk_idx + 1 + peek_b_full_status = self.conditional_consumer_try_wait( + b_consumer_state, b_pipeline, C + ) + peek_c_full_status = self.conditional_consumer_try_wait( + c_consumer_state, c_pipeline, C + ) + peek_wr_intra1_acc_empty_status = ( + self.conditional_producer_try_acquire( + intra1_acc_producer_state, intra1_acc_pipeline, C + ) + ) + + # Advance X/INTRA2_Q/INTRA2_ACC state + x_consumer_state.advance() + intra2_q_consumer_state.advance() + intra2_acc_producer_state.advance() + + # Peek (try_wait) X buffer full for chunk_idx = chunk_idx + 1 + peek_x_full_status = self.conditional_consumer_try_wait( + x_consumer_state, x_pipeline, C + ) + # END of for chunk_idx in cutlass.range(C-1, unroll=1) + + # Manual pipeline: unrolled INTRA_MMA2 chunk_idx = C-1 loop + # Conditionally wait for X/INTRA2_Q/INTRA2_ACC buffer full/full/empty + x_pipeline.consumer_wait(x_consumer_state, peek_x_full_status) + intra2_q_pipeline.consumer_wait(intra2_q_consumer_state) + intra2_acc_pipeline.producer_acquire(intra2_acc_producer_state) + + # INTRA_MMA2 + tiled_mma_intra2 = self.exec_mma( + tiled_mma_intra2, + tCtAccIntra2, + tCrQ, + tCrX, + intra2_acc_producer_state, + intra2_q_consumer_state, + x_consumer_state, + ) + + # Async arrive X/INTRA2_Q/INTRA2_ACC buffer empty/empty/full + if cutlass.const_expr(self.has_d): + x_pipeline.consumer_release( + x_consumer_state, pipeline.PipelineOp.TCGen05Mma + ) + else: + x_pipeline.consumer_release(x_consumer_state) + intra2_q_pipeline.consumer_release(intra2_q_consumer_state) + intra2_acc_pipeline.producer_commit(intra2_acc_producer_state) + + # Advance X/INTRA2_Q/INTRA2_ACC state + x_consumer_state.advance() + intra2_q_consumer_state.advance() + intra2_acc_producer_state.advance() + + # Peek (try_wait) X buffer full for chunk_idx = chunk_idx + 1 + peek_x_full_status = self.conditional_consumer_try_wait( + x_consumer_state, x_pipeline, C + ) + + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + # END of while work_tile.is_valid_tile + + # Producer tail for INTRA1_ACC/INTRA2_ACC + intra1_acc_pipeline.producer_tail(intra1_acc_producer_state) + intra2_acc_pipeline.producer_tail(intra2_acc_producer_state) + # END of specialized mma-intra warp + + # Specialized MMA Inter warp + elif warp_idx == self.mma_inter_warp_id: + # Dealloc regs for pre-inter/pre-intra warps + cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps) + + # Make shared/tmem fragments for INTER_MMA1 X/B/ACC + # (MMA, MMA_N, MMA_K, INPUT_STAGE) + # (MMA, MMA_M, MMA_K, INTERNAL_STAGE) + # (MMA, MMA_M, MMA_N, INTERNAL_STAGE) + tCrB, tCrX, tCtAccInter1 = self.mma_partition_ss( + tiled_mma_inter1, + self.tile_shape_mnk_inter1, + smem_bt_internal, + smem_x, + tmem_ptr_base + self.tmem_inter1_acc_offset, + self.internal_stages, + ) + + # Make shared/tmem fragments for INTER_MMA2 C/P/ACC + # (MMA, MMA_M, MMA_K, INPUT_STAGE) + # (MMA, MMA_N, MMA_K, INTERNAL_STAGE) + # (MMA, MMA_M, MMA_N, INTERNAL_STAGE) + tCrC, tCrP, tCtAccInter2 = self.mma_partition_ss( + tiled_mma_inter2, + self.tile_shape_mnk_inter2, + smem_c, + smem_p, + tmem_ptr_base + self.tmem_inter2_acc_offset, + self.internal_stages, + ) + + # Pipeline X/C/INTER1_B/INTER2_P consumer state + x_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.input_stages + ) + c_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.input_stages + ) + inter1_b_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.internal_stages + ) + inter2_p_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.internal_stages + ) + + # Pipeline INTER1_ACC/INTER2_ACC producer state + inter1_acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.internal_stages + ) + inter2_acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.internal_stages + ) + + while work_tile.is_valid_tile: + # Reset count for pipeline state + x_consumer_state.reset_count() + c_consumer_state.reset_count() + inter1_acc_producer_state.reset_count() + inter1_b_consumer_state.reset_count() + inter2_p_consumer_state.reset_count() + inter2_acc_producer_state.reset_count() + + # Peek (try_wait) C/INTER2_P/INTER2_ACC buffer full/full/empty status + peek_c_full_status = self.conditional_consumer_try_wait( + c_consumer_state, c_pipeline, C + ) + peek_inter2_p_full_status = self.conditional_consumer_try_wait( + inter2_p_consumer_state, inter2_p_pipeline, C + ) + peek_inter2_acc_empty_status = self.conditional_producer_try_acquire( + inter2_acc_producer_state, inter2_acc_pipeline, C + ) + + # Batched gemm over C dimension + for chunk_idx in cutlass.range(C, unroll=1): + # Conditionally wait for C/INTER2_P/INTER2_ACC buffer full/full/empty + c_pipeline.consumer_wait(c_consumer_state, peek_c_full_status) + inter2_p_pipeline.consumer_wait( + inter2_p_consumer_state, peek_inter2_p_full_status + ) + inter2_acc_pipeline.producer_acquire( + inter2_acc_producer_state, peek_inter2_acc_empty_status + ) + + # INTER MMA2 + tiled_mma_inter2 = self.exec_mma( + tiled_mma_inter2, + tCtAccInter2, + tCrC, + tCrP, + inter2_acc_producer_state, + c_consumer_state, + inter2_p_consumer_state, + ) + + # Async arrive C/INTER2_P/INTER2_ACC buffer empty/empty/full + c_pipeline.consumer_release(c_consumer_state) + inter2_p_pipeline.consumer_release(inter2_p_consumer_state) + inter2_acc_pipeline.producer_commit(inter2_acc_producer_state) + + # Wait for X/INTER1_B/INTER1_ACC buffer full/full/empty + x_pipeline.consumer_wait(x_consumer_state) + inter1_b_pipeline.consumer_wait(inter1_b_consumer_state) + inter1_acc_pipeline.producer_acquire(inter1_acc_producer_state) + + # INTER MMA1 + tiled_mma_inter1 = self.exec_mma( + tiled_mma_inter1, + tCtAccInter1, + tCrB, + tCrX, + inter1_acc_producer_state, + inter1_b_consumer_state, + x_consumer_state, + ) + + # Async arrive X/INTER1_B/INTER1_ACC buffer empty/empty/full + if cutlass.const_expr(self.has_d): + x_pipeline.consumer_release( + x_consumer_state, pipeline.PipelineOp.TCGen05Mma + ) + else: + x_pipeline.consumer_release(x_consumer_state) + inter1_b_pipeline.consumer_release(inter1_b_consumer_state) + inter1_acc_pipeline.producer_commit(inter1_acc_producer_state) + + # Advance X/C/INTER1_B/INTER1_ACC/INTER2_P/INTER2_ACC state + x_consumer_state.advance() + c_consumer_state.advance() + inter1_b_consumer_state.advance() + inter1_acc_producer_state.advance() + inter2_p_consumer_state.advance() + inter2_acc_producer_state.advance() + + # Peek (try_wait) C/INTER2_P/INTER2_ACC buffer full/full/empty for chunk_idx = chunk_idx + 1 + peek_c_full_status = self.conditional_consumer_try_wait( + c_consumer_state, c_pipeline, C + ) + peek_inter2_p_full_status = self.conditional_consumer_try_wait( + inter2_p_consumer_state, inter2_p_pipeline, C + ) + peek_inter2_acc_empty_status = ( + self.conditional_producer_try_acquire( + inter2_acc_producer_state, inter2_acc_pipeline, C + ) + ) + + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # Producer tail for INTER1_ACC/INTER2_ACC + inter1_acc_pipeline.producer_tail(inter1_acc_producer_state) + inter2_acc_pipeline.producer_tail(inter2_acc_producer_state) + + # Specialized Pre-Inter warp + elif ( + warp_idx == self.pre_inter_warp_id[0] + or warp_idx == self.pre_inter_warp_id[1] + or warp_idx == self.pre_inter_warp_id[2] + or warp_idx == self.pre_inter_warp_id[3] + ): + # Alloc regs in pre_inter warps + cute.arch.warpgroup_reg_alloc(self.num_regs_pre_inter_warps) + + # Make tiledCopy and partition smem/register tensor for smem load Bt + # ((S2R_ATOM_V, S2R_REST_V), S2R_M, S2R_N, INPUT_STAGE) + # ((S2R_ATOM_V, S2R_REST_V), S2R_M, S2R_N) + tiled_s2r_b, tBsB_s2r, tBrB_s2r = self.pre_inter_smem_load_and_partition_b( + local_tidx, smem_bt + ) + + # Partition shared tensor for smem store Bt + smem_bt_internal_ = cute.make_tensor( + smem_bt_internal.iterator, smem_bt.layout + ) + # Make tiledCopy and partition register/smem tensor for smem store Bt + # ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N) + # ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N, INTERNAL_STAGE) + tiled_r2s_b, tBrB_r2s, tBsB_r2s = self.pre_inter_smem_store_and_partition_b( + local_tidx, + smem_bt_internal_, + tiled_s2r_b, + tBrB_s2r, + ) + + # (MMA, MMA_M, MMA_K, INPUT_STAGE) + sDelta = self.pre_inter_make_delta(smem_delta, smem_bt.layout) + sDeltaA = self.pre_inter_make_delta(smem_cumsum_delta, smem_bt.layout) + + # Make copy_atom and partition register/smem tensor for smem load/store of Delta/DeltaA + # ((S2R_ATOM_V, S2R_REST_V), S2R_M, S2R_N, INPUT_STAGE) + # ((S2R_ATOM_V, S2R_REST_V), S2R_M, S2R_N) + ( + s2r_atom_delta, + tBsDelta_s2r, + tBrDelta_s2r, + ) = self.smem_load_and_partition_delta_d( + tiled_s2r_b, local_tidx, sDelta, (None, None, None, 0) + ) + ( + s2r_atom_cumsum, + tBsDeltaA_s2r, + tBrDeltaA_s2r, + ) = self.smem_load_and_partition_delta_d( + tiled_s2r_b, local_tidx, sDeltaA, (None, None, None, 0) + ) + + # ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N) + thr_r2s_b = tiled_r2s_b.get_slice(local_tidx) + tBrDelta_r2s = thr_r2s_b.retile(tBrDelta_s2r) + tBrDeltaA_r2s = thr_r2s_b.retile(tBrDeltaA_s2r) + + # Make tmem fragment for INTER1_ACC + # (MMA, MMA_M, MMA_N, INTERNAL_STAGE) + tCtAccInter1 = self.mma_partition_c( + tiled_mma_inter1, + self.tile_shape_mnk_inter1, + tmem_ptr_base + self.tmem_inter1_acc_offset, + self.internal_stages, + ) + # (M_PER_MMA, N_PER_MMA, INTERNAL_STAGE) + tInter1 = tCtAccInter1[((None, None), 0, 0, None)] + + # Make tiledCopy and partition tmem/register tensor for tmem load INTER1_ACC + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N, INTERNAL_STAGE) + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N) + ( + tiled_t2r_inter1, + tTR_tP, + tTR_rP, + ) = self.pre_inter_tmem_load_and_partition_p(local_tidx, tInter1, smem_pt) + + # Make fragment for register to hold P after post-processing (in acc dtype) + tState = cute.make_fragment(tTR_rP.shape, self.acc_dtype) + + # Make tiledCopy and partition smem/register tensor for smem store INTER2_P + # ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N) + # ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N, INTERNAL_STAGE) + tiled_r2s_p, tRS_rP, tRS_sP = self.smem_store_and_partition_p_y( + local_tidx, smem_pt, tiled_t2r_inter1 + ) + + # Partition global/shared tensor for P (State) + # ((ATOM_V, REST_V), INTERNAL_STAGE) + # ((ATOM_V, REST_V), 1, 1, EH, B) + bSG_sP, bSG_gP_pre_slice = self.tma_partition_with_shape( + tma_atom_p, + tma_tensor_p, + smem_p_store, + self.tile_shape_mnk_inter2[1:], + ) + + # Pipeline B/Delta/INTER1_ACC consumer state + b_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.input_stages + ) + deltas_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.input_stages + ) + inter1_acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.internal_stages + ) + + # Pipeline INTER1_B/INTER2_P producer state + inter1_b_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.internal_stages + ) + inter2_p_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.internal_stages + ) + + # Pipeline TMA store P + tma_p_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.internal_stages, + producer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, 32 * len(self.pre_inter_warp_id), 128 + ), + ) + + while work_tile.is_valid_tile: + b_idx, eh_idx, g_idx = work_tile.tile_idx + + # Slice global tensor to current tile idx + # ((ATOM_V, REST_V)) + bSG_gP = bSG_gP_pre_slice[(None, 0, 0, eh_idx, b_idx)] + + # Reset count for pipeline state + b_consumer_state.reset_count() + deltas_consumer_state.reset_count() + inter1_b_producer_state.reset_count() + inter1_acc_consumer_state.reset_count() + inter2_p_producer_state.reset_count() + + # State (P) init + tState.fill(0.0) + + # Peek (try_wait) B/Delta/INTER1_B buffer full/full/empty status + peek_b_full_status = self.conditional_consumer_try_wait( + b_consumer_state, b_pipeline, C + ) + peek_deltas_full_status = self.conditional_consumer_try_wait( + deltas_consumer_state, deltas_pipeline, C + ) + peek_wr_inter1_b_empty_status = self.conditional_producer_try_acquire( + inter1_b_producer_state, inter1_b_pipeline, C + ) + + # Prefill INTER2_P with 0 + # Wait for INTER2_P buffer empty + inter2_p_pipeline.producer_acquire(inter2_p_producer_state) + + tRS_rP.fill(0.0) + # Copy INTER2_P from register to smem + inter2_p_coord = (None, None, None, inter2_p_producer_state.index) + cute.copy(tiled_r2s_p, tRS_rP, tRS_sP[inter2_p_coord]) + + # Fence for shared memory + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + # Async arrive INTER2_P buffer full + inter2_p_pipeline.producer_commit(inter2_p_producer_state) + # Advance INTER2_P producer state + inter2_p_producer_state.advance() + + # Batched processing over C dimension + for chunk_idx in cutlass.range(C, unroll=1): + # Conditionally wait for B/Delta/B_TMEM buffer full/full/empty + b_pipeline.consumer_wait(b_consumer_state, peek_b_full_status) + deltas_pipeline.consumer_wait( + deltas_consumer_state, peek_deltas_full_status + ) + inter1_b_pipeline.producer_acquire( + inter1_b_producer_state, peek_wr_inter1_b_empty_status + ) + + # Load B/Delta/DeltaA/last_column + b_coord = (None, None, None, b_consumer_state.index) + delta_coord = (None, None, None, deltas_consumer_state.index) + cute.copy(tiled_s2r_b, tBsB_s2r[b_coord], tBrB_s2r) + cute.copy(s2r_atom_delta, tBsDelta_s2r[delta_coord], tBrDelta_s2r) + cute.copy( + s2r_atom_cumsum, tBsDeltaA_s2r[delta_coord], tBrDeltaA_s2r + ) + last_column = smem_cumsum_delta[ + smem_cumsum_delta.shape[0] - 1, deltas_consumer_state.index + ] + + # Fence for shared memory + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + + # Combine B/Delta/DeltaA/last_column + tScaledB = self.pre_inter_scale_bt_with_delta( + tBrB_s2r, tBrDelta_r2s, tBrDeltaA_r2s, last_column + ) + + # Store scaled B to tBrB_r2s + for reg_idx in range(cute.size(tBrB_r2s)): + tBrB_r2s[reg_idx] = tScaledB[reg_idx].to(self.io_dtype) + + # Store tBrB_r2s to bt_smem_internal + inter1_b_coord = (None, None, None, inter1_b_producer_state.index) + cute.copy(tiled_r2s_b, tBrB_r2s, tBsB_r2s[inter1_b_coord]) + + # Fence for shared memory + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + + # Async arrive B/Delta/B_TMEM buffer empty/empty/full + b_pipeline.consumer_release( + b_consumer_state, pipeline.PipelineOp.AsyncThread + ) + deltas_pipeline.consumer_release(deltas_consumer_state) + inter1_b_pipeline.producer_commit(inter1_b_producer_state) + + # Wait for INTER1_ACC/INTER2_P buffer full/empty + inter1_acc_pipeline.consumer_wait(inter1_acc_consumer_state) + inter2_p_pipeline.producer_acquire(inter2_p_producer_state) + + # Load INTER1_ACC + inter1_acc_coord = ( + None, + None, + None, + inter1_acc_consumer_state.index, + ) + cute.copy(tiled_t2r_inter1, tTR_tP[inter1_acc_coord], tTR_rP) + + # Fence for TMEM load + cute.arch.fence_view_async_tmem_load() + + # Combine INTER1_ACC/last_column/State + exp_last_column = cute.math.exp(last_column, fastmath=True) + for reg_idx in range(0, cute.size(tTR_rP), 2): + ( + tTR_rP[reg_idx], + tTR_rP[reg_idx + 1], + ) = cute.arch.fma_packed_f32x2( + (exp_last_column, exp_last_column), + (tState[reg_idx], tState[reg_idx + 1]), + (tTR_rP[reg_idx], tTR_rP[reg_idx + 1]), + ) + + # Store scaled P to tRS_rP + for reg_idx in range(cute.size(tTR_rP)): + tRS_rP[reg_idx] = tTR_rP[reg_idx].to(self.io_dtype) + + # Update old state + tState.store(tTR_rP.load()) + + # Store INTER2_P + inter2_p_coord = (None, None, None, inter2_p_producer_state.index) + cute.copy(tiled_r2s_p, tRS_rP, tRS_sP[inter2_p_coord]) + + # Fence for shared memory + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + + # Async arrive INTER1_ACC/INTER2_P buffer empty/full + inter1_acc_pipeline.consumer_release(inter1_acc_consumer_state) + # Last iteration consumer is PRE_INTER warp itself, not MMA_INTER warp + if inter2_p_producer_state.count < C: + inter2_p_pipeline.producer_commit(inter2_p_producer_state) + + # Advance B/Delta/INTER1_B/INTER1_ACC state + b_consumer_state.advance() + deltas_consumer_state.advance() + inter1_b_producer_state.advance() + inter1_acc_consumer_state.advance() + # Peek (try_wait) B/Delta/INTER1_B buffer full/full./empty for chunk_idx = chunk_idx + 1 + peek_b_full_status = self.conditional_consumer_try_wait( + b_consumer_state, b_pipeline, C + ) + peek_deltas_full_status = self.conditional_consumer_try_wait( + deltas_consumer_state, deltas_pipeline, C + ) + peek_wr_inter1_b_empty_status = ( + self.conditional_producer_try_acquire( + inter1_b_producer_state, inter1_b_pipeline, C + ) + ) + + # Last iteration producer is PRE_INTER warp itself, not MMA_INTER warp + if inter2_p_producer_state.count < C: + # Advance INTER2_P producer state + inter2_p_producer_state.advance() + # END of for chunk_idx in cutlass.range(C, unroll=1) + + # Store last INTER2_P (State) from smem to gmem + # Wait for all previous stores to smem to be done + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + cute.arch.barrier( + barrier_id=self.pre_inter_sync_bar_id, + number_of_threads=len(self.pre_inter_warp_id) * 32, + ) + + if local_warp_idx == 0: + # TMA store P + cute.copy( + tma_atom_p, + bSG_sP[(None, inter2_p_producer_state.index)], + bSG_gP, + ) + # Wait for TMA store done + tma_p_pipeline.producer_commit() + tma_p_pipeline.producer_acquire() + + cute.arch.barrier( + barrier_id=self.pre_inter_sync_bar_id, + number_of_threads=len(self.pre_inter_warp_id) * 32, + ) + tma_p_pipeline.producer_tail() + + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + # END of while work_tile.is_valid_tile + + # Producer tail for INTER1_B/INTER2_P/TMA store P + inter1_b_pipeline.producer_tail(inter1_b_producer_state) + inter2_p_pipeline.producer_tail(inter2_p_producer_state) + # END of specialized pre-inter warp + + # Specialized Pre-Intra warp + elif ( + warp_idx == self.pre_intra_warp_id[0] + or warp_idx == self.pre_intra_warp_id[1] + or warp_idx == self.pre_intra_warp_id[2] + or warp_idx == self.pre_intra_warp_id[3] + ): + # Alloc regs in pre_inter warps + cute.arch.warpgroup_reg_alloc(self.num_regs_pre_intra_warps) + + # Make tmem fragment for INTRA1_ACC + # (MMA, MMA_M, MMA_N, INTRA1_ACC_STAGE) + tCtAccIntra1 = self.mma_partition_c( + tiled_mma_intra1, + self.tile_shape_mnk_intra1, + tmem_ptr_base + self.tmem_intra1_acc_offset, + self.intra1_acc_stages, + ) + # (M_PER_MMA, N_PER_MMA, INTRA1_ACC_STAGE) + tIntra1 = tCtAccIntra1[((None, None), 0, 0, None)] + + # Make tiledCopy and partition tmem/register tensor for tensor memory load INTRA1_ACC + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N, INTERNAL_STAGE) + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N) + tiled_t2r_intra1, tTR_tQ, tTR_rQ = self.pre_intra_tmem_load_and_partition_q( + tIntra1, local_tidx + ) + + # Broadcast delta/delta_cumsum smem tensor from LxINPUT_STAGE to LxLxINPUT_STAGE + sDeltaA_Row = self.pre_intra_make_delta(smem_cumsum_delta, 0) + sDeltaA_Col = self.pre_intra_make_delta(smem_cumsum_delta, 1) + sDelta = self.pre_intra_make_delta(smem_delta, 0) + + # Make tiledCopy and partition smem/register tensor for smem memory load delta/delta_cumsum + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N, INPUT_STAGE) + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N) + ( + s2r_atom_cumsum, + tQsDeltaA_Row, + tQrDeltaA_Row, + ) = self.smem_load_and_partition_delta_d( + tiled_t2r_intra1, local_tidx, sDeltaA_Row, (None, None, None, 0) + ) + ( + s2r_atom_cumsum, + tQsDeltaA_Col, + tQrDeltaA_Col, + ) = self.smem_load_and_partition_delta_d( + tiled_t2r_intra1, local_tidx, sDeltaA_Col, (None, None, None, 0) + ) + ( + s2r_atom_delta, + tQsDelta, + tQrDelta, + ) = self.smem_load_and_partition_delta_d( + tiled_t2r_intra1, local_tidx, sDelta, (None, None, None, 0) + ) + + # Make and partition coord tensor for delta_cumsum load + # (L, L) + coord_tensor = cute.make_identity_tensor( + cute.dice(self.tile_shape_mnk_intra1, (1, 1, None)) + ) + thr_t2r_intra1 = tiled_t2r_intra1.get_slice(local_tidx) + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N) + tCoord = thr_t2r_intra1.partition_D(coord_tensor) + + # Make tmem tensor for INTRA2_Q + # (MMA, MMA_M, MMA_K, INTERNAL_STAGE) + tCrQ = self.mma_partition_a_tmem( + tiled_mma_intra2, + q_tmem_layout, + tmem_ptr_base + self.tmem_intra2_q_offset, + ) + + # Make tiledCopy and partition tmem/register tensor for tensor memory store INTRA2_Q + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N, ...) + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N, ..., INTERNAL_STAGE) + tiled_r2t_q, tRT_rQ, tRT_tQ = self.pre_intra_tmem_store_and_partition_q( + local_tidx, tCrQ + ) + + # Pipeline DELTA/INTRA1_ACC consumer state + deltas_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.input_stages + ) + intra1_acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.intra1_acc_stages + ) + # Pipeline INTRA2_Q producer state + intra2_q_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.internal_stages + ) + + while work_tile.is_valid_tile: + # Reset count for pipeline state + deltas_consumer_state.reset_count() + intra1_acc_consumer_state.reset_count() + intra2_q_producer_state.reset_count() + + # Peek (try_wait) DELTA/INTRA1_ACC buffer full + peek_deltas_full_status = self.conditional_consumer_try_wait( + deltas_consumer_state, deltas_pipeline, C + ) + peek_rd_intra1_acc_full_status = self.conditional_consumer_try_wait( + intra1_acc_consumer_state, intra1_acc_pipeline, C + ) + + # Batched processing over C dimension + for chunk_idx in cutlass.range(C, unroll=1): + # Conditionally wait for Delta/INTRA1_ACC buffer full + deltas_pipeline.consumer_wait( + deltas_consumer_state, peek_deltas_full_status + ) + intra1_acc_pipeline.consumer_wait( + intra1_acc_consumer_state, peek_rd_intra1_acc_full_status + ) + + # Load Q from tmem + intra1_coord = (None, None, None, intra1_acc_consumer_state.index) + cute.copy(tiled_t2r_intra1, tTR_tQ[intra1_coord], tTR_rQ) + cute.arch.fence_view_async_tmem_load() + + # Load tQsDeltaA_Row/tQsDeltaA_Col/tQsDelta from smem + delta_coord = (None, None, None, deltas_consumer_state.index) + cute.copy( + s2r_atom_cumsum, tQsDeltaA_Row[delta_coord], tQrDeltaA_Row + ) + cute.copy( + s2r_atom_cumsum, tQsDeltaA_Col[delta_coord], tQrDeltaA_Col + ) + cute.copy(s2r_atom_delta, tQsDelta[delta_coord], tQrDelta) + + # SegSum + tRT_rQ = self.pre_intra_segsum( + tTR_rQ, tQrDeltaA_Row, tQrDeltaA_Col, tQrDelta, tCoord, tRT_rQ + ) + + # Wait for INTRA2_Q buffer empty + # Delay producer_acquire to right before data store + intra2_q_pipeline.producer_acquire(intra2_q_producer_state) + + # Store Q from reg to tmem + q_coord = (None, None, None, None, intra2_q_producer_state.index) + cute.copy(tiled_r2t_q, tRT_rQ, tRT_tQ[q_coord]) + + # Async arrive Delta/INTRA1_ACC buffer empty + intra1_acc_pipeline.consumer_release(intra1_acc_consumer_state) + deltas_pipeline.consumer_release(deltas_consumer_state) + + cute.arch.fence_view_async_tmem_store() + + # Async arrive INTRA2_Q buffer full + intra2_q_pipeline.producer_commit(intra2_q_producer_state) + + # Advance deltas/intra1_acc/intra2_q states + deltas_consumer_state.advance() + intra1_acc_consumer_state.advance() + intra2_q_producer_state.advance() + + # Peek (try_wait) Delta/INTRA1_ACC buffer full for chunk_idx = chunk_idx + 1 + peek_deltas_full_status = self.conditional_consumer_try_wait( + deltas_consumer_state, deltas_pipeline, C + ) + peek_rd_intra1_acc_full_status = self.conditional_consumer_try_wait( + intra1_acc_consumer_state, intra1_acc_pipeline, C + ) + # END of for chunk_idx in cutlass.range(C, unroll=1) + + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + # END of while work_tile.is_valid_tile + + # Producer tail for INTRA2_Q + intra2_q_pipeline.producer_tail(intra2_q_producer_state) + # END of specialized pre-intra warp + + # Specialized Epilogue warp + else: + # Dealloc regs for pre-inter/pre-intra warps + cute.arch.warpgroup_reg_dealloc(self.num_regs_epilogue_warps) + + # (L, D, INPUT_STAGE) + sDeltaA = self.epilog_make_delta(smem_cumsum_delta) + + # Make tmem tensor for INTRA2_ACC/INTER2_ACC + # (MMA, MMA_M, MMA_K, INTERNAL_STAGE) + tCtAccIntra2 = self.mma_partition_c( + tiled_mma_intra2, + self.tile_shape_mnk_intra2, + tmem_ptr_base + self.tmem_intra2_acc_offset, + self.internal_stages, + ) + # (M_PER_MMA, N_PER_MMA, INTERNAL_STAGE) + tIntra2 = tCtAccIntra2[((None, None), 0, 0, None)] + # (MMA, MMA_M, MMA_K, INTERNAL_STAGE) + tCtAccInter2 = self.mma_partition_c( + tiled_mma_inter2, + self.tile_shape_mnk_inter2, + tmem_ptr_base + self.tmem_inter2_acc_offset, + self.internal_stages, + ) + # (M_PER_MMA, N_PER_MMA, INTERNAL_STAGE) + tInter2 = tCtAccInter2[((None, None), 0, 0, None)] + + # Subtiling INTRA2_ACC/INTER2_ACC/Delta/Y + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, INTERNAL_STAGE) + tIntra_epi = cute.flat_divide(tIntra2, epi_tile) + tInter_epi = cute.flat_divide(tInter2, epi_tile) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, INPUT_STAGE) + sDeltaA_epi = cute.flat_divide(sDeltaA, epi_tile) + + # Make tiled copy and partition tmem/reg tensor w.r.t tensor memory load + # ((T2R_ATOM_V, T2R_REST_V), REST_M, REST_N, EPI_M, EPI_N, INTERNAL_STAGE) + # ((T2R_ATOM_V, T2R_REST_V), REST_M, REST_N) + ( + tiled_t2r_intra2, + tTR_tIntra, + tTR_rIntra, + ) = self.epilog_tmem_load_and_partition_acc(local_tidx, tIntra_epi, smem_y) + ( + tiled_t2r_inter2, + tTR_tInter2, + tTR_rInter, + ) = self.epilog_tmem_load_and_partition_acc(local_tidx, tInter_epi, smem_y) + + # Make tiled copy and partition smem/reg tensor w.r.t smem load Delta + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N, EPI_M, EPI_N, INPUT_STAGE) + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N) + ( + s2r_atom_delta, + tTR_sDeltaA, + tTR_rDeltaA, + ) = self.smem_load_and_partition_delta_d( + tiled_t2r_inter2, local_tidx, sDeltaA_epi, (None, None, None, 0, 0, 0) + ) + + # Make tiled copy and Partition smem/register tensor w.r.t smem store Y + # ((R2S_ATOM_V, R2S_REST_V), REST_M, REST_N, OUTPUT_STAGE) + # ((R2S_ATOM_V, R2S_REST_V), REST_M, REST_N) + tiled_r2s_y, tRS_rY, tRS_sY = self.smem_store_and_partition_p_y( + local_tidx, smem_y, tiled_t2r_inter2 + ) + + tRS_rCompute = cute.make_fragment(tRS_rY.shape, self.acc_dtype) + + tiled_s2r_x = None + tSR_sX = None + tSR_rX = None + if cutlass.const_expr(self.has_d): + # Make TiledCopy/smem/register tensor for smem load X + # (R2S_ATOM, R2S_M, R2S_N, EPI_M, EPI_N, INPUT_STAGES) + # (R2S_ATOM, R2S_M, R2S_N) + tiled_s2r_x, tSR_sX, tSR_rX = self.epilog_smem_load_and_partition_x( + tiled_t2r_inter2, local_tidx, smem_xt, epi_tile + ) + + tRS_sD = None + tRS_rD = None + s2r_atom_d = None + if cutlass.const_expr(self.d_has_hdim): + # (L, D, INPUT_STAGE) + sD = self.epilog_make_d(smem_d) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, INPUT_STAGE) + tD_sepi = cute.flat_divide(sD, epi_tile) + + # Make tiled copy and partition smem/reg tensor w.r.t smem load D + # ((T2R_ATOM_V, T2R_REST_V), REST_M, REST_N, EPI_M, EPI_N, INPUT_STAGE) + # ((T2R_ATOM_V, T2R_REST_V), REST_M, REST_N) + s2r_atom_d, tRS_sD, tRS_rD = self.smem_load_and_partition_delta_d( + tiled_t2r_inter2, local_tidx, tD_sepi, (None, None, None, 0, 0, 0) + ) + + elif cutlass.const_expr(self.has_d): + tRS_rD = cutlass.Float32(0.0).to(self.io_dtype) + + # Partition global/shared tensor for TMA store Y + # ((ATOM_V, REST_V), INPUT_STAGE) + # ((ATOM_V, REST_V), EPI_M, EPI_N, 1, 1, C, EH, B) + bSG_sY, bSG_gY_pre_slice = self.epilog_tma_partition_y( + tma_tensor_y, tma_atom_y, smem_y, epi_tile + ) + + # Make TMA store pipeline Y + tma_y_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.output_stages, + producer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, 32 * len(self.epilog_warp_id), 128 + ), + ) + + # Make consumer pipeline states for Delta/INTRA2_ACC/INTER2_ACC/X/D buffer + deltas_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.input_stages + ) + intra2_acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.internal_stages + ) + inter2_acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.internal_stages + ) + x_consumer_state = None + if cutlass.const_expr(self.has_d): + x_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.input_stages + ) + d_consumer_state = None + if cutlass.const_expr(self.d_has_hdim): + d_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.input_stages + ) + + while work_tile.is_valid_tile: + b_idx, eh_idx, g_idx = work_tile.tile_idx + + # Slice global tensor to current tile idx + # ((ATOM_V, REST_V), EPI_M, EPI_N, C) + bSG_gY = bSG_gY_pre_slice[(None, None, None, 0, 0, None, eh_idx, b_idx)] + if cutlass.const_expr(self.has_d and not self.d_has_hdim): + tRS_rD = tma_tensor_d[0, eh_idx] + + # Reset count for pipeline state + deltas_consumer_state.reset_count() + intra2_acc_consumer_state.reset_count() + inter2_acc_consumer_state.reset_count() + if cutlass.const_expr(self.has_d): + x_consumer_state.reset_count() + if cutlass.const_expr(self.d_has_hdim): + d_consumer_state.reset_count() + + # Peek Delta/INTRA2_ACC/INTER2_ACC buffer status + peek_deltas_full_status = self.conditional_consumer_try_wait( + deltas_consumer_state, deltas_pipeline, C + ) + peek_rd_intra2_acc_full_status = self.conditional_consumer_try_wait( + intra2_acc_consumer_state, intra2_acc_pipeline, C + ) + peek_rd_inter2_acc_full_status = self.conditional_consumer_try_wait( + inter2_acc_consumer_state, inter2_acc_pipeline, C + ) + peek_rd_x_full_status = None + if cutlass.const_expr(self.has_d): + peek_rd_x_full_status = self.conditional_consumer_try_wait( + x_consumer_state, x_pipeline, C + ) + + if cutlass.const_expr(self.d_has_hdim): + d_pipeline.consumer_wait(d_consumer_state) + + # Batched processing over C dimension + for chunk_idx in cutlass.range(C, unroll=1): + # Conditionally wait for Delta/INTRA2_ACC/INTER2_ACC/X buffer full + deltas_pipeline.consumer_wait( + deltas_consumer_state, peek_deltas_full_status + ) + intra2_acc_pipeline.consumer_wait( + intra2_acc_consumer_state, peek_rd_intra2_acc_full_status + ) + inter2_acc_pipeline.consumer_wait( + inter2_acc_consumer_state, peek_rd_inter2_acc_full_status + ) + if cutlass.const_expr(self.has_d): + x_pipeline.consumer_wait( + x_consumer_state, peek_rd_x_full_status + ) + # Loop over EPI_M and EPI_N subtiles + for epi_n in range(cute.size(tTR_tIntra, mode=[4])): + for epi_m in range(cute.size(tTR_tIntra, mode=[3])): + epi_iter_cnt = ( + epi_n * cute.size(tTR_tIntra, mode=[3]) + epi_m + ) + epi_buffer_idx = epi_iter_cnt % self.output_stages + + # Load INTRA2_ACC/INTER2_ACC from tmem + subtile_coord = ( + None, + None, + None, + epi_m, + epi_n, + ) + intra2_coord = subtile_coord + ( + intra2_acc_consumer_state.index, + ) + cute.copy( + tiled_t2r_intra2, + tTR_tIntra[intra2_coord], + tTR_rIntra, + ) + inter2_coord = subtile_coord + ( + inter2_acc_consumer_state.index, + ) + cute.copy( + tiled_t2r_inter2, + tTR_tInter2[inter2_coord], + tTR_rInter, + ) + # Fence for T2R load + cute.arch.fence_view_async_tmem_load() + + # Load Delta from smem + delta_coord = subtile_coord + (deltas_consumer_state.index,) + cute.copy( + s2r_atom_delta, tTR_sDeltaA[delta_coord], tTR_rDeltaA + ) + + # Load X from smem + if cutlass.const_expr(self.has_d): + x_coord = subtile_coord + (x_consumer_state.index,) + cute.copy(tiled_s2r_x, tSR_sX[x_coord], tSR_rX) + + # Load D from smem + if cutlass.const_expr(self.d_has_hdim): + # Load vector D from smem (d_has_hdim = True) + d_coord = subtile_coord + (d_consumer_state.index,) + cute.copy(s2r_atom_d, tRS_sD[d_coord], tRS_rD) + + # Combine INTRA2_ACC/INTER2_ACC/Delta/X/D + for reg_idx in range(0, cute.size(tRS_rCompute), 2): + ( + tRS_rCompute[reg_idx], + tRS_rCompute[reg_idx + 1], + ) = cute.arch.fma_packed_f32x2( + (tTR_rInter[reg_idx], tTR_rInter[reg_idx + 1]), + ( + cute.math.exp( + tTR_rDeltaA[reg_idx], fastmath=True + ), + cute.math.exp( + tTR_rDeltaA[reg_idx + 1], fastmath=True + ), + ), + (tTR_rIntra[reg_idx], tTR_rIntra[reg_idx + 1]), + ) + # Fuse Y += X * D + if cutlass.const_expr(self.d_has_hdim): + ( + tRS_rCompute[reg_idx], + tRS_rCompute[reg_idx + 1], + ) = cute.arch.fma_packed_f32x2( + ( + tRS_rD[reg_idx].to(self.acc_dtype), + tRS_rD[reg_idx + 1].to(self.acc_dtype), + ), + ( + tSR_rX[reg_idx].to(self.acc_dtype), + tSR_rX[reg_idx + 1].to(self.acc_dtype), + ), + ( + tRS_rCompute[reg_idx], + tRS_rCompute[reg_idx + 1], + ), + ) + elif cutlass.const_expr(self.has_d): + ( + tRS_rCompute[reg_idx], + tRS_rCompute[reg_idx + 1], + ) = cute.arch.fma_packed_f32x2( + ( + tRS_rD.to(self.acc_dtype), + tRS_rD.to(self.acc_dtype), + ), + ( + tSR_rX[reg_idx].to(self.acc_dtype), + tSR_rX[reg_idx + 1].to(self.acc_dtype), + ), + ( + tRS_rCompute[reg_idx], + tRS_rCompute[reg_idx + 1], + ), + ) + + tRS_rY.store(tRS_rCompute.load().to(self.io_dtype)) + + # Store Y to smem + cute.copy( + tiled_r2s_y, + tRS_rY, + tRS_sY[None, None, None, epi_buffer_idx], + ) + + # Fence for R2S store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + # Sync before TMA store + cute.arch.barrier( + barrier_id=self.epilog_sync_bar_id, + number_of_threads=len(self.epilog_warp_id) * 32, + ) + + # Async arrive Delta/INTRA2_ACC/INTER2_ACC buffer empty + if ( + epi_iter_cnt + == cute.size(tTR_tIntra, mode=[4]) + * cute.size(tTR_tIntra, mode=[3]) + - 1 + ): + deltas_pipeline.consumer_release(deltas_consumer_state) + intra2_acc_pipeline.consumer_release( + intra2_acc_consumer_state + ) + inter2_acc_pipeline.consumer_release( + inter2_acc_consumer_state + ) + if cutlass.const_expr(self.has_d): + x_pipeline.consumer_release( + x_consumer_state, + pipeline.PipelineOp.AsyncThread, + ) + + # TMA store Y to global memory + if local_warp_idx == 0: + cute.copy( + tma_atom_y, + bSG_sY[None, epi_buffer_idx], + bSG_gY[None, epi_m, epi_n, chunk_idx], + ) + + # Commit TMA store + tma_y_pipeline.producer_commit() + # Wait for TMA store + tma_y_pipeline.producer_acquire() + # Sync before smem store + cute.arch.barrier( + barrier_id=self.epilog_sync_bar_id, + number_of_threads=len(self.epilog_warp_id) * 32, + ) + + # Advance deltas/intra2_acc/inter2_acc consumer states + deltas_consumer_state.advance() + intra2_acc_consumer_state.advance() + inter2_acc_consumer_state.advance() + + # Peek (try_wait) Delta/INTRA2_ACC/INTER2_ACC buffer full for chunk_idx = chunk_idx + 1 + peek_deltas_full_status = self.conditional_consumer_try_wait( + deltas_consumer_state, deltas_pipeline, C + ) + peek_rd_intra2_acc_full_status = self.conditional_consumer_try_wait( + intra2_acc_consumer_state, intra2_acc_pipeline, C + ) + peek_rd_inter2_acc_full_status = self.conditional_consumer_try_wait( + inter2_acc_consumer_state, inter2_acc_pipeline, C + ) + + if cutlass.const_expr(self.has_d): + # Advance x consumer states + x_consumer_state.advance() + # Peek (try_wait) X buffer full for chunk_idx = chunk_idx + 1 + peek_rd_x_full_status = self.conditional_consumer_try_wait( + x_consumer_state, x_pipeline, C + ) + + if cutlass.const_expr(self.d_has_hdim): + d_pipeline.consumer_release(d_consumer_state) + d_consumer_state.advance() + + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # Producer tail for TMA store Y + tma_y_pipeline.producer_tail() + + # Dealloc tmem buffer + if warp_idx == self.epilog_warp_id[0]: + cute.arch.barrier( + barrier_id=self.tmem_dealloc_sync_bar_id, + number_of_threads=self.threads_per_cta, + ) + cute.arch.dealloc_tmem( + tmem_ptr_base, + self.num_tmem_cols_total, + is_two_cta=self.use_2cta_instrs, + ) + else: + cute.arch.barrier_arrive( + barrier_id=self.tmem_dealloc_sync_bar_id, + number_of_threads=self.threads_per_cta, + ) + + return + + @staticmethod + def _compute_stages(smem_capacity): + return 2, 2, 1, 2 # input, output, internal, intra1_acc + + @staticmethod + def _compute_grid(y, b, max_active_clusters): + B = cute.size(y, mode=[4]) + EH = cute.size(y, mode=[3]) + G = cute.size(b, mode=[3]) + NGROUP_RATIO = EH // G + num_blocks = B * EH + + tile_sched_params = Mamba2SSDTileSchedulerParams(num_blocks, EH, NGROUP_RATIO) + grid = Mamba2SSDTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + return tile_sched_params, grid + + @staticmethod + def _plan_tmem_offsets( + tiled_mma_intra1, + tile_shape_mnk_intra1, + tiled_mma_intra2, + tile_shape_mnk_intra2, + tiled_mma_inter1, + tile_shape_mnk_inter1, + tiled_mma_inter2, + tile_shape_mnk_inter2, + acc_stages, + intra2_a_tmem_layout, + a_dtype, + internal_stages, + intra1_acc_stages, + ): + SM100_TMEM_CAPACITY_COLUMNS = 512 + BITS_PER_TMEM_COL = 32 + # (MMA, MMA_M, MMA_N) + acc_shape_intra1 = tiled_mma_intra1.partition_shape_C(tile_shape_mnk_intra1[:2]) + # (MMA, MMA_M, MMA_N) + tCtAccIntra1_fake = tiled_mma_intra1.make_fragment_C( + cute.append(acc_shape_intra1, intra1_acc_stages) + ) + num_intra1_acc_cols = tcgen05.find_tmem_tensor_col_offset(tCtAccIntra1_fake) + assert tile_shape_mnk_intra1[1] * intra1_acc_stages == num_intra1_acc_cols + # (MMA, MMA_N, MMA_K, STAGE) + tCrQ_fake = tiled_mma_intra2.make_fragment_A(intra2_a_tmem_layout.outer.shape) + num_intra2_a_cols = tcgen05.find_tmem_tensor_col_offset(tCrQ_fake) + assert ( + tile_shape_mnk_intra2[2] + * internal_stages + * a_dtype.width + // BITS_PER_TMEM_COL + == num_intra2_a_cols + ) + # (MMA, MMA_M, MMA_N) + acc_shape_intra2 = tiled_mma_intra2.partition_shape_C(tile_shape_mnk_intra2[:2]) + # (MMA, MMA_M, MMA_N) + tCtAccIntra2_fake = tiled_mma_intra2.make_fragment_C( + cute.append(acc_shape_intra2, acc_stages) + ) + num_intra2_acc_cols = tcgen05.find_tmem_tensor_col_offset(tCtAccIntra2_fake) + assert tile_shape_mnk_intra2[1] * acc_stages == num_intra2_acc_cols + + # (MMA, MMA_M, MMA_N) + acc_shape_inter1 = tiled_mma_inter1.partition_shape_C(tile_shape_mnk_inter1[:2]) + # (MMA, MMA_M, MMA_N) + tCtAccInter1_fake = tiled_mma_inter1.make_fragment_C( + cute.append(acc_shape_inter1, acc_stages) + ) + num_inter1_acc_cols = tcgen05.find_tmem_tensor_col_offset(tCtAccInter1_fake) + assert tile_shape_mnk_inter1[1] * acc_stages == num_inter1_acc_cols + + # (MMA, MMA_M, MMA_N) + acc_shape_inter2 = tiled_mma_inter2.partition_shape_C(tile_shape_mnk_inter2[:2]) + # (MMA, MMA_M, MMA_N) + tCtAccInter2_fake = tiled_mma_inter2.make_fragment_C( + cute.append(acc_shape_inter2, acc_stages) + ) + num_inter2_acc_cols = tcgen05.find_tmem_tensor_col_offset(tCtAccInter2_fake) + assert tile_shape_mnk_inter2[1] * acc_stages == num_inter2_acc_cols + + tmem_intra1_acc_offset = 0 + tmem_intra2_q_offset = tmem_intra1_acc_offset + num_intra1_acc_cols + tmem_intra2_acc_offset = tmem_intra2_q_offset + num_intra2_a_cols + tmem_inter1_acc_offset = tmem_intra2_acc_offset + num_intra2_acc_cols + tmem_inter2_acc_offset = tmem_inter1_acc_offset + num_inter1_acc_cols + num_tmem_cols_total_tmp = tmem_inter2_acc_offset + num_inter2_acc_cols + # Turn num_tmem_cols_total to the nearest power of 2 + num_tmem_cols_total = 1 + while num_tmem_cols_total < num_tmem_cols_total_tmp: + num_tmem_cols_total *= 2 + assert num_tmem_cols_total <= SM100_TMEM_CAPACITY_COLUMNS + + return ( + tmem_intra1_acc_offset, + tmem_intra2_q_offset, + tmem_intra2_acc_offset, + tmem_inter1_acc_offset, + tmem_inter2_acc_offset, + num_tmem_cols_total, + ) + + @staticmethod + def make_tiled_mmas( + io_dtype, + acc_dtype, + cta_group, + tile_shape_mnk_intra1, + tile_shape_mnk_intra2, + tile_shape_mnk_inter1, + tile_shape_mnk_inter2, + ): + tiled_mma_intra1 = sm100_utils.make_trivial_tiled_mma( + io_dtype, + tcgen05.OperandMajorMode("mn"), + tcgen05.OperandMajorMode("mn"), + acc_dtype, + cta_group, + tile_shape_mnk_intra1[:2], + tcgen05.OperandSource.SMEM, + ) + tiled_mma_intra2 = sm100_utils.make_trivial_tiled_mma( + io_dtype, + tcgen05.OperandMajorMode("k"), + tcgen05.OperandMajorMode("k"), + acc_dtype, + cta_group, + tile_shape_mnk_intra2[:2], + tcgen05.OperandSource.TMEM, + ) + tiled_mma_inter1 = sm100_utils.make_trivial_tiled_mma( + io_dtype, + tcgen05.OperandMajorMode("k"), + tcgen05.OperandMajorMode("k"), + acc_dtype, + cta_group, + tile_shape_mnk_inter1[:2], + tcgen05.OperandSource.SMEM, + ) + tiled_mma_inter2 = sm100_utils.make_trivial_tiled_mma( + io_dtype, + tcgen05.OperandMajorMode("mn"), + tcgen05.OperandMajorMode("k"), + acc_dtype, + cta_group, + tile_shape_mnk_inter2[:2], + tcgen05.OperandSource.SMEM, + ) + return tiled_mma_intra1, tiled_mma_intra2, tiled_mma_inter1, tiled_mma_inter2 + + def make_and_init_x_pipeline(self, x_full_mbar_ptr): + x_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.tma_deltas_x_d_warp_id]) + ) + if not self.has_d: + x_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + len([self.mma_intra_warp_id, self.mma_inter_warp_id]), + ) + return pipeline.PipelineTmaUmma.create( + num_stages=self.input_stages, + producer_group=x_producer_group, + consumer_group=x_consumer_group, + tx_count=self.num_x_load_bytes, + barrier_storage=x_full_mbar_ptr, + ) + else: + x_consumer_group_umma = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + len([self.mma_intra_warp_id, self.mma_inter_warp_id]), + ) + x_consumer_group_async = pipeline.CooperativeGroup( + pipeline.Agent.Thread, 32 * len(self.epilog_warp_id), 128 + ) + return pipeline.PipelineTmaMultiConsumersAsync.create( + num_stages=self.input_stages, + producer_group=x_producer_group, + consumer_group_umma=x_consumer_group_umma, + consumer_group_async=x_consumer_group_async, + tx_count=self.num_x_load_bytes, + barrier_storage=x_full_mbar_ptr, + ) + + def make_and_init_b_pipeline(self, b_full_mbar_ptr): + b_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.tma_b_c_warp_id]) + ) + b_consumer_group_umma = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_intra_warp_id]) + ) + b_consumer_group_async = pipeline.CooperativeGroup( + pipeline.Agent.Thread, 32 * len(self.pre_inter_warp_id), 128 + ) + return pipeline.PipelineTmaMultiConsumersAsync.create( + num_stages=self.input_stages, + producer_group=b_producer_group, + consumer_group_umma=b_consumer_group_umma, + consumer_group_async=b_consumer_group_async, + tx_count=self.num_b_load_bytes, + barrier_storage=b_full_mbar_ptr, + ) + + def make_and_init_c_pipeline(self, c_full_mbar_ptr): + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.tma_b_c_warp_id]) + ) + c_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_intra_warp_id, self.mma_inter_warp_id]) + ) + return pipeline.PipelineTmaUmma.create( + num_stages=self.input_stages, + producer_group=c_producer_group, + consumer_group=c_consumer_group, + tx_count=self.num_c_load_bytes, + barrier_storage=c_full_mbar_ptr, + ) + + def make_and_init_deltas_pipeline(self, deltas_full_mbar_ptr): + deltas_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.tma_deltas_x_d_warp_id]) + ) + deltas_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + len( + [*self.pre_inter_warp_id, *self.pre_intra_warp_id, *self.epilog_warp_id] + ), + len( + [*self.pre_inter_warp_id, *self.pre_intra_warp_id, *self.epilog_warp_id] + ), + ) + + return pipeline.PipelineTmaAsync.create( + num_stages=self.input_stages, + producer_group=deltas_producer_group, + consumer_group=deltas_consumer_group, + tx_count=self.num_delta_load_bytes + self.num_cumsum_delta_load_bytes, + barrier_storage=deltas_full_mbar_ptr, + ) + + def make_and_init_d_pipeline(self, d_full_mbar_ptr): + if not self.d_has_hdim: + return None + else: + d_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.tma_deltas_x_d_warp_id]) + ) + d_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + len(self.epilog_warp_id), + len(self.epilog_warp_id), + ) + + return pipeline.PipelineTmaAsync.create( + num_stages=self.input_stages, + producer_group=d_producer_group, + consumer_group=d_consumer_group, + tx_count=self.num_d_load_bytes, + barrier_storage=d_full_mbar_ptr, + ) + + def make_and_init_intra1_acc_pipeline(self, intra1_acc_full_mbar_ptr): + intra1_acc_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_intra_warp_id]) + ) + intra1_acc_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, 32 * len(self.pre_intra_warp_id), 128 + ) + return pipeline.PipelineUmmaAsync.create( + num_stages=self.intra1_acc_stages, + producer_group=intra1_acc_producer_group, + consumer_group=intra1_acc_consumer_group, + barrier_storage=intra1_acc_full_mbar_ptr, + ) + + def make_and_init_intra2_q_pipeline(self, intra2_q_full_mbar_ptr): + intra2_q_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, 32 * len(self.pre_intra_warp_id), 128 + ) + intra2_q_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_intra_warp_id]) + ) + return pipeline.PipelineAsyncUmma.create( + num_stages=self.internal_stages, + producer_group=intra2_q_producer_group, + consumer_group=intra2_q_consumer_group, + barrier_storage=intra2_q_full_mbar_ptr, + ) + + def make_and_init_intra2_acc_pipeline(self, intra2_acc_full_mbar_ptr): + intra2_acc_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_intra_warp_id]) + ) + intra2_acc_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, 32 * len(self.epilog_warp_id), 128 + ) + return pipeline.PipelineUmmaAsync.create( + num_stages=self.internal_stages, + producer_group=intra2_acc_producer_group, + consumer_group=intra2_acc_consumer_group, + barrier_storage=intra2_acc_full_mbar_ptr, + ) + + def make_and_init_inter1_b_pipeline(self, inter1_b_full_mbar_ptr): + inter1_b_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, 32 * len(self.pre_inter_warp_id), 128 + ) + inter1_b_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_inter_warp_id]) + ) + return pipeline.PipelineAsyncUmma.create( + num_stages=self.internal_stages, + producer_group=inter1_b_producer_group, + consumer_group=inter1_b_consumer_group, + barrier_storage=inter1_b_full_mbar_ptr, + ) + + def make_and_init_inter1_acc_pipeline(self, inter1_acc_full_mbar_ptr): + inter1_acc_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_inter_warp_id]) + ) + inter1_acc_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, 32 * len(self.pre_inter_warp_id), 128 + ) + return pipeline.PipelineUmmaAsync.create( + num_stages=self.internal_stages, + producer_group=inter1_acc_producer_group, + consumer_group=inter1_acc_consumer_group, + barrier_storage=inter1_acc_full_mbar_ptr, + ) + + def make_and_init_inter2_p_pipeline(self, inter2_p_full_mbar_ptr): + inter2_p_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, 32 * len(self.pre_inter_warp_id), 128 + ) + inter2_p_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_inter_warp_id]) + ) + return pipeline.PipelineAsyncUmma.create( + num_stages=self.internal_stages, + producer_group=inter2_p_producer_group, + consumer_group=inter2_p_consumer_group, + barrier_storage=inter2_p_full_mbar_ptr, + ) + + def make_and_init_inter2_acc_pipeline(self, inter2_acc_full_mbar_ptr): + inter2_acc_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_inter_warp_id]) + ) + inter2_acc_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, 32 * len(self.epilog_warp_id), 128 + ) + return pipeline.PipelineUmmaAsync.create( + num_stages=self.internal_stages, + producer_group=inter2_acc_producer_group, + consumer_group=inter2_acc_consumer_group, + barrier_storage=inter2_acc_full_mbar_ptr, + ) + + def tma_partition_for_mma_b_operand( + self, + tma_atom_x, + tma_tensor_x, + smem_x, + tiled_mma_intra2, + cluster_layout_vmnk, + mma_tile_coord_v, + block_in_cluster_coord_vmnk, + ): + # Local_tile partition global tensors + # (D, L, 1, 1, C, EH, B) + gX = cute.local_tile( + tma_tensor_x, + self.tile_shape_mnk_intra2[1:], + (None, None, None, None, None), + ) + # Partition global tensor with regard to TiledMMA + thr_mma_intra2 = tiled_mma_intra2.get_slice(mma_tile_coord_v) + # (MMA, MMA_N, MMA_K, 1, 1, C, EH, B) + tCgX = thr_mma_intra2.partition_B(gX) + + # Partition global/shared tensor for X + x_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + # ((ATOM_V, REST_V), INPUT_STAGE) + # ((ATOM_V, REST_V), 1, 1, C, EH, B) + tXsX, tXgX_pre_slice = cpasync.tma_partition( + tma_atom_x, + block_in_cluster_coord_vmnk[2], + x_cta_layout, + cute.group_modes(smem_x, 0, 3), + cute.group_modes(tCgX, 0, 3), + ) + return tXsX, tXgX_pre_slice + + def tma_partition_for_mma_a_operand( + self, + tma_atom_c, + tma_tensor_c, + smem_c, + tiled_mma_intra1, + cluster_layout_vmnk, + mma_tile_coord_v, + block_in_cluster_coord_vmnk, + ): + # Local_tile partition global tensors + # (L, N, 1, 1, C, G, B) + gC = cute.local_tile( + tma_tensor_c, + cute.slice_(self.tile_shape_mnk_intra1, (None, 0, None)), + (None, None, None, None, None), + ) + # Partition global tensor with regard to TiledMMA + thr_mma_intra1 = tiled_mma_intra1.get_slice(mma_tile_coord_v) + # (MMA, MMA_M/N, MMA_K, 1, 1, C, G, B) + tCgC = thr_mma_intra1.partition_A(gC) + + # Partition global/shared tensor for TMA C + c_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + # ((ATOM_V, REST_V), INPUT_STAGE) + # ((ATOM_V, REST_V), 1, 1, C, G, B) + tCsC, tCgC_pre_slice = cpasync.tma_partition( + tma_atom_c, + block_in_cluster_coord_vmnk[1], + c_cta_layout, + cute.group_modes(smem_c, 0, 3), + cute.group_modes(tCgC, 0, 3), + ) + return tCsC, tCgC_pre_slice + + def tma_partition_with_shape( + self, tma_atom_delta, tma_tensor_delta, smem_delta, shape + ): + # Local_tile partition global tensors + # (L, 1, C, EH, B) + gDelta = cute.local_tile( + tma_tensor_delta, + shape, + (None,) * cute.rank(tma_tensor_delta), + ) + # Partition global/shared tensor for DELTA + # ((ATOM_V, REST_V), INPUT_STAGE) + # ((ATOM_V, REST_V), 1, C, EH, B) + tDeltasDelta, tDeltagDelta_pre_slice = cpasync.tma_partition( + tma_atom_delta, + 0, + cute.make_layout(1), + cute.group_modes(smem_delta, 0, cute.rank(shape)), + cute.group_modes(gDelta, 0, cute.rank(shape)), + ) + + return tDeltasDelta, tDeltagDelta_pre_slice + + def mma_partition_ss( + self, + tiled_mma, + tile_shape_mnk, + smem_a, + smem_b, + tmem_acc_ptr, + acc_stages, + ): + # (MMA, MMA_M, MMA_K, INPUT_STAGE) + tCrA = tiled_mma.make_fragment_A(smem_a) + # (MMA, MMA_N, MMA_K, INPUT_STAGE) + tCrB = tiled_mma.make_fragment_B(smem_b) + # (MMA, MMA_M, MMA_N, ACC_STAGE) + tCtAcc = self.mma_partition_c( + tiled_mma, tile_shape_mnk, tmem_acc_ptr, acc_stages + ) + return tCrA, tCrB, tCtAcc + + def mma_partition_ts( + self, + tiled_mma, + tile_shape_mnk, + a_tmem_layout, + smem_b, + tmem_a_ptr, + tmem_acc_ptr, + acc_stages, + ): + # (MMA, MMA_M, MMA_K, INTERNAL_STAGE) + tCrA = self.mma_partition_a_tmem(tiled_mma, a_tmem_layout, tmem_a_ptr) + # (MMA, MMA_N, MMA_K, INPUT_STAGE) + tCrB = tiled_mma.make_fragment_B(smem_b) + # (MMA, MMA_M, MMA_N, INTERNAL_STAGE) + tCtAcc = self.mma_partition_c( + tiled_mma, tile_shape_mnk, tmem_acc_ptr, acc_stages + ) + return tCrA, tCrB, tCtAcc + + def mma_partition_a_tmem(self, tiled_mma, a_tmem_layout, tmem_a_ptr): + tCrA_fake = tiled_mma.make_fragment_A(a_tmem_layout.outer.shape) + tCrA = cute.make_tensor( + cute.recast_ptr( + tmem_a_ptr, + dtype=tCrA_fake.element_type, + ), + tCrA_fake.layout, + ) + return tCrA + + def mma_partition_c(self, tiled_mma, tile_shape_mnk, tmem_acc_ptr, acc_stages): + acc_shape = tiled_mma.partition_shape_C(tile_shape_mnk[:2]) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, acc_stages)) + # (MMA, MMA_M, MMA_N, INTERNAL_STAGE) + tCtAcc = cute.make_tensor(tmem_acc_ptr, tCtAcc_fake.layout) + return tCtAcc + + @cute.jit + def exec_mma( + self, + tiled_mma, + tCtAcc, + tCrA, + tCrB, + acc_producer_state, + a_consumer_state, + b_consumer_state, + ): + for kphase_idx in cutlass.range(cute.size(tCrB, mode=[2]), unroll_full=True): + # set accu = 1 + tiled_mma.set( + tcgen05.Field.ACCUMULATE, + cutlass.Boolean(kphase_idx != 0), + ) + cute.gemm( + tiled_mma, + tCtAcc[None, None, None, acc_producer_state.index], + tCrA[None, None, kphase_idx, a_consumer_state.index], + tCrB[None, None, kphase_idx, b_consumer_state.index], + tCtAcc[None, None, None, acc_producer_state.index], + ) + return tiled_mma + + @cute.jit + def conditional_consumer_try_wait(self, b_consumer_state, b_pipeline, C): + peek_b_full_status = cutlass.Boolean(1) + if b_consumer_state.count < C: + peek_b_full_status = b_pipeline.consumer_try_wait(b_consumer_state) + return peek_b_full_status + + @cute.jit + def conditional_producer_try_acquire( + self, intra1_acc_producer_state, intra1_acc_pipeline, C + ): + peek_wr_intra1_acc_empty_status = cutlass.Boolean(1) + if intra1_acc_producer_state.count < C: + peek_wr_intra1_acc_empty_status = intra1_acc_pipeline.producer_try_acquire( + intra1_acc_producer_state + ) + return peek_wr_intra1_acc_empty_status + + def pre_intra_tmem_load_and_partition_q(self, tIntra1, local_tidx): + copy_atom_t2r_intra1 = cute.make_copy_atom( + tcgen05.Ld16x256bOp(tcgen05.Repetition(16), tcgen05.Pack.NONE), + self.acc_dtype, + ) + # (L, L) + fake_sQ = cute.make_tensor( + cute.make_ptr(self.io_dtype, 0, cute.AddressSpace.smem), + cute.dice(self.tile_shape_mnk_intra1, (1, 1, None)), + ) + return self.make_tmem_load_and_partition( + copy_atom_t2r_intra1, tIntra1, (None, None, 0), local_tidx, fake_sQ + ) + + def pre_intra_make_delta(self, smem_delta, extend_on_row_or_col): + smem_iterator = smem_delta.iterator + delta_linear_smem_layout = smem_delta.layout + # extend L linear layout to LxL + extend_layout = cute.make_layout(delta_linear_smem_layout.shape[0], stride=0) + if extend_on_row_or_col == 0: + # (L, L, INPUT_STAGE):(0, 1, L) + sDelta = cute.make_tensor( + smem_iterator, + cute.prepend( + delta_linear_smem_layout, + extend_layout, + ), + ) + else: + # (L, L, INPUT_STAGE):(1, 0, L) + sDelta = cute.make_tensor( + smem_iterator, + cute.append( + cute.append( + cute.get(delta_linear_smem_layout, mode=[0]), + extend_layout, + ), + cute.get(delta_linear_smem_layout, mode=[1]), + ), + ) + return sDelta + + def pre_intra_tmem_store_and_partition_q(self, local_tidx, tCrQ): + dtype = tCrQ.element_type + # Make tiledCopy for tensor memory store INTRA2_Q + copy_atom_r2t_q = cute.make_copy_atom( + tcgen05.St16x128bOp(tcgen05.Repetition(16), tcgen05.Unpack.NONE), + dtype, + ) + tiled_r2t_q = tcgen05.make_tmem_copy(copy_atom_r2t_q, tCrQ) + thr_r2t_q = tiled_r2t_q.get_slice(local_tidx) + + # Partition tmem/register tensor for tensor memory store INTRA2_Q + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N, ...) + tRT_rQ = cute.make_fragment( + cute.slice_(thr_r2t_q.partition_S(tCrQ).shape, (None, None, None, None, 0)), + dtype, + ) + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N, ..., INTERNAL_STAGE) + tRT_tQ = thr_r2t_q.partition_D(tCrQ) + + return tiled_r2t_q, tRT_rQ, tRT_tQ + + @cute.jit + def pre_intra_segsum( + self, tTR_rQ, tQrDeltaA_Row, tQrDeltaA_Col, tQrDelta, tCoord, tRT_rQ + ): + # Make tmp acc type fragments + tCrDeltaA_Row = cute.make_fragment(tQrDeltaA_Row.shape, self.acc_dtype) + tCrDeltaA_Col = cute.make_fragment(tQrDeltaA_Col.shape, self.acc_dtype) + tCrDelta = cute.make_fragment(tQrDelta.shape, self.acc_dtype) + tCompute = cute.make_fragment(tRT_rQ.shape, self.acc_dtype) + + # Combine tTR_rQ/tCrDeltaA_Row/tCrDeltaA_Col/tCrDelta + tCrDeltaA_Row.store(tQrDeltaA_Row.load().to(self.acc_dtype)) + tCrDeltaA_Col.store(tQrDeltaA_Col.load().to(self.acc_dtype)) + tCrDelta.store(tQrDelta.load().to(self.acc_dtype)) + + # SegSum + # fadd2 + fsel + fmul2/mufu + fmul2 + for subtile_idx in cutlass.range(0, cute.size(tTR_rQ), 2, unroll_full=True): + ( + tCompute[subtile_idx], + tCompute[subtile_idx + 1], + ) = cute.arch.add_packed_f32x2( + (tCrDeltaA_Col[subtile_idx], tCrDeltaA_Col[subtile_idx + 1]), + (-tCrDeltaA_Row[subtile_idx], -tCrDeltaA_Row[subtile_idx + 1]), + ) + for subtile_idx in cutlass.range(cute.size(tTR_rQ), unroll_full=True): + m, n = tCoord[subtile_idx] + if m < n: + tCompute[subtile_idx] = cutlass.Float32(-float("inf")) + LOG2_E = cutlass.Float32(1.4426950408889634) + for subtile_idx in cutlass.range(0, cute.size(tTR_rQ), 2, unroll_full=True): + # TODO: use math.exp directly + tCompute_log2e = cute.arch.mul_packed_f32x2( + (tCompute[subtile_idx], tCompute[subtile_idx + 1]), (LOG2_E, LOG2_E) + ) + ( + tCompute[subtile_idx], + tCompute[subtile_idx + 1], + ) = cute.arch.mul_packed_f32x2( + ( + cute.math.exp2(tCompute_log2e[0], fastmath=True), + cute.math.exp2(tCompute_log2e[1], fastmath=True), + ), + (tCrDelta[subtile_idx], tCrDelta[subtile_idx + 1]), + ) + ( + tCompute[subtile_idx], + tCompute[subtile_idx + 1], + ) = cute.arch.mul_packed_f32x2( + (tCompute[subtile_idx], tCompute[subtile_idx + 1]), + (tTR_rQ[subtile_idx], tTR_rQ[subtile_idx + 1]), + ) + + tRT_rQ.store(tCompute.load().to(self.io_dtype)) + return tRT_rQ + + def pre_inter_smem_load_and_partition_b(self, local_tidx, smem_bt): + dtype = smem_bt.element_type + copy_atom_s2r_b = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + dtype, + num_bits_per_copy=128, + ) + num_elements_per_thread = 128 // dtype.width + num_threads_per_row = self.tile_shape_mnk_inter1[2] // num_elements_per_thread + num_threads_per_col = 128 // num_threads_per_row + thread_layout = cute.make_layout( + (num_threads_per_col, num_threads_per_row), + stride=(num_threads_per_row, 1), + ) + val_layout = cute.make_layout((1, num_elements_per_thread)) + tiled_s2r_b = cute.make_tiled_copy_tv( + copy_atom_s2r_b, + thread_layout, + val_layout, + ) + thr_s2r_b = tiled_s2r_b.get_slice(local_tidx) + + # Partition shared tensor for smem load Bt + # ((S2R_ATOM_V, S2R_REST_V), S2R_M, S2R_N, INPUT_STAGE) + tBsB_s2r = thr_s2r_b.partition_S(smem_bt) + + # ((S2R_ATOM_V, S2R_REST_V), S2R_M, S2R_N) + tBrB_s2r = cute.make_fragment( + cute.slice_(tBsB_s2r.shape, (None, None, None, 0)), + dtype, + ) + return tiled_s2r_b, tBsB_s2r, tBrB_s2r + + def pre_inter_smem_store_and_partition_b( + self, local_tidx, smem_bt_internal, tiled_s2r_b, tBrB_s2r + ): + dtype = smem_bt_internal.element_type + # Make tiledCopy from register to smem store Bt + copy_atom_r2s_b = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + dtype, + num_bits_per_copy=128, + ) + tiled_r2s_b = cute.make_tiled_copy_S(copy_atom_r2s_b, tiled_s2r_b) + thr_r2s_b = tiled_r2s_b.get_slice(local_tidx) + + # Partition shared tensor for smem store Bt + # ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N, INTERNAL_STAGE) + tBsB_r2s = thr_r2s_b.partition_D(smem_bt_internal) + + # Make register fragments for smem load/store Bt + # ((S2R_ATOM_V, S2R_REST_V), S2R_M, S2R_N) + tBrB_r2s = thr_r2s_b.retile(tBrB_s2r) + return tiled_r2s_b, tBrB_r2s, tBsB_r2s + + def smem_load_and_partition_delta_d( + self, tiled_s2r_b, local_tidx, smem_delta, smem_tile_coord + ): + dtype = smem_delta.element_type + s2r_atom_delta = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), dtype) + + thr_s2r_b = tiled_s2r_b.get_slice(local_tidx) + # ((S2R_ATOM_V, S2R_REST_V), S2R_M, S2R_N, INPUT_STAGE) + tBsDelta_s2r = thr_s2r_b.partition_D(smem_delta) + + # Make register fragments for smem load/store of Delta/DeltaA + # ((S2R_ATOM_V, S2R_REST_V), S2R_M, S2R_N) + tBrDelta_s2r = cute.make_fragment(tBsDelta_s2r[smem_tile_coord].shape, dtype) + return s2r_atom_delta, tBsDelta_s2r, tBrDelta_s2r + + def pre_inter_tmem_load_and_partition_p(self, local_tidx, tInter1, smem_pt): + copy_atom_t2r_inter1 = cute.make_copy_atom( + tcgen05.Ld16x256bOp(tcgen05.Repetition(8), tcgen05.Pack.NONE), + self.acc_dtype, + ) + return self.make_tmem_load_and_partition( + copy_atom_t2r_inter1, + tInter1, + (None, None, 0), + local_tidx, + smem_pt[None, None, 0], + ) + + def make_tmem_load_and_partition( + self, copy_atom_t2r, tmem_tensor, tmem_tile_coord, local_tidx, smem_tensor + ): + dtype = tmem_tensor.element_type + tiled_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tmem_tensor[tmem_tile_coord]) + thr_t2r = tiled_t2r.get_slice(local_tidx) + # Partition tmem/shared tensor for tmem load INTER1_ACC + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N) + tTR_t = thr_t2r.partition_S(tmem_tensor) + tTR_s = thr_t2r.partition_D(smem_tensor) + # Make register fragments for tmem load INTER1_ACC + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N) + tTR_r = cute.make_fragment( + tTR_s.shape, + dtype, + ) + return tiled_t2r, tTR_t, tTR_r + + def smem_store_and_partition_p_y(self, local_tidx, smem_pt, tiled_t2r_inter1): + dtype = smem_pt.element_type + copy_atom_r2s_p = cute.make_copy_atom( + cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=True, num_matrices=4), + dtype, + ) + tiled_r2s_p = cute.make_tiled_copy_D(copy_atom_r2s_p, tiled_t2r_inter1) + thr_r2s_p = tiled_r2s_p.get_slice(local_tidx) + # Partition smem/register tensor for smem store INTER2_P + # ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N, INTERNAL_STAGE) + tRS_sP = thr_r2s_p.partition_D(smem_pt) + # ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N) + tRS_rP = cute.make_fragment( + cute.slice_(tRS_sP.shape, (None, None, None, 0)), self.io_dtype + ) + return tiled_r2s_p, tRS_rP, tRS_sP + + def pre_inter_make_delta(self, smem_delta, smem_bt_layout): + # Broadcast Delta/DeltaA to Bt shape on M dimension + # before: (128,(64,2),2):(64,(1,8192),16384) + # after : (128,(64,2),2):(0,(1,64),128) + # (MMA, MMA_M, MMA_K, INPUT_STAGE) + sDeltaA = cute.make_tensor( + smem_delta.iterator, + cute.make_layout( + smem_bt_layout.shape, + stride=( + 0, + (1, cute.get(smem_bt_layout.shape, mode=[1, 0])), + smem_delta.layout.stride[1], + ), + ), + ) + return sDeltaA + + def pre_inter_scale_bt_with_delta( + self, tBrB_s2r, tBrDelta_s2r, tBrDeltaA_s2r, last_column + ): + tCompute = cute.make_fragment(tBrB_s2r.shape, self.acc_dtype) + tBrB_Compute = cute.make_fragment(tBrB_s2r.shape, self.acc_dtype) + tBrDelta_Compute = cute.make_fragment(tBrDelta_s2r.shape, self.acc_dtype) + tBrDeltaA_Compute = cute.make_fragment(tBrDeltaA_s2r.shape, self.acc_dtype) + + tBrB_Compute.store(tBrB_s2r.load().to(self.acc_dtype)) + tBrDelta_Compute.store(tBrDelta_s2r.load().to(self.acc_dtype)) + tBrDeltaA_Compute.store(tBrDeltaA_s2r.load().to(self.acc_dtype)) + + for reg_idx in range(0, cute.size(tBrB_Compute), 2): + tCompute[reg_idx], tCompute[reg_idx + 1] = cute.arch.mul_packed_f32x2( + ( + cute.math.exp( + (last_column - tBrDeltaA_Compute[reg_idx]), fastmath=True + ), + cute.math.exp( + (last_column - tBrDeltaA_Compute[reg_idx + 1]), fastmath=True + ), + ), + (tBrDelta_Compute[reg_idx], tBrDelta_Compute[reg_idx + 1]), + ) + tCompute[reg_idx], tCompute[reg_idx + 1] = cute.arch.mul_packed_f32x2( + (tCompute[reg_idx], tCompute[reg_idx + 1]), + (tBrB_Compute[reg_idx], tBrB_Compute[reg_idx + 1]), + ) + return tCompute + + def epilog_make_delta(self, smem_cumsum_delta): + # Broadcast cumsum delta from LxINPUT_STAGE to LxDxINPUT_STAGE + sDeltaA = cute.make_tensor( + smem_cumsum_delta.iterator, + cute.make_layout( + (*self.tile_shape_mnk_inter2[:2], self.input_stages), + stride=(1, 0, smem_cumsum_delta.layout.shape[0]), + ), + ) + return sDeltaA + + def epilog_make_d(self, smem_d): + # Broadcast d from DxINPUT_STAGE to LxDxINPUT_STAGE + sD = cute.make_tensor( + smem_d.iterator, + cute.make_layout( + (*self.tile_shape_mnk_inter2[:2], self.input_stages), + stride=(0, 1, smem_d.layout.shape[0]), + ), + ) + return sD + + def epilog_tma_partition_y(self, tma_tensor_y, tma_atom_y, smem_y, epi_tile): + # Local_tile partition global tensors + # (L, D, 1, 1, C, EH, B) + gY = cute.local_tile( + tma_tensor_y, + cute.slice_(self.tile_shape_mnk_inter2, (None, None, 0)), + (None, None, None, None, None), + ) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, 1, 1, C, EH, B) + gY_epi = cute.flat_divide(gY, epi_tile) + # ((ATOM_V, REST_V), INPUT_STAGE) + # ((ATOM_V, REST_V), EPI_M, EPI_N, 1, 1, C, EH, B) + bSG_sY, bSG_gY_pre_slice = cpasync.tma_partition( + tma_atom_y, + 0, + cute.make_layout(1), + cute.group_modes(smem_y, 0, 2), + cute.group_modes(gY_epi, 0, 2), + ) + return bSG_sY, bSG_gY_pre_slice + + def epilog_smem_load_and_partition_x( + self, tiled_t2r_inter2_intra2, local_tidx, smem_xt, epi_tile + ): + dtype = smem_xt.element_type + copy_atom_s2r_x = cute.make_copy_atom( + cute.nvgpu.warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4), + dtype, + ) + tiled_s2r_x = cute.make_tiled_copy_D(copy_atom_s2r_x, tiled_t2r_inter2_intra2) + thr_s2r_x = tiled_s2r_x.get_slice(local_tidx) + # Partition smem/register tensor for smem store INTER2_P + # (R2S_ATOM, R2S_M, R2S_N, EPI_M, EPI_N, INPUT_STAGES) + tSR_sX = thr_s2r_x.partition_S(cute.flat_divide(smem_xt, epi_tile)) + # (R2S_ATOM, R2S_M, R2S_N) + tSR_rX = cute.make_fragment( + cute.slice_(tSR_sX.shape, (None, None, None, 0, 0, 0)), dtype + ) + return tiled_s2r_x, tSR_sX, tSR_rX + + def epilog_tmem_load_and_partition_acc(self, local_tidx, tIntra, smem_y): + copy_atom_t2r_inter2_intra2 = cute.make_copy_atom( + tcgen05.Ld16x256bOp(tcgen05.Repetition(4), tcgen05.Pack.NONE), + self.acc_dtype, + ) + return self.make_tmem_load_and_partition( + copy_atom_t2r_inter2_intra2, + tIntra, + (None, None, 0, 0, 0), + local_tidx, + smem_y[None, None, 0], + ) + + +def run( + gbehcdln: Tuple[int, int, int, int, int, int, int, int], + io_dtype: Type[cutlass.Numeric], + cumsum_delta_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + fuse_scale_d: str, + tolerance: float, + print_rtol_stats: bool, + ref_lower_precision: bool, + warmup_iterations: int, + iterations: int, + skip_ref_check: bool, + use_cold_l2: bool = False, + **kwargs, +): + has_d = fuse_scale_d != "none" + d_has_hdim = fuse_scale_d == "vector" + + print(f"Running B100 Mamba2 SSD with:") + print(f"GBEHCDLN: {gbehcdln}") + print( + f"Input/Output dtype: {io_dtype}, Intermediate delta dtype: {cumsum_delta_dtype}, Acc dtype: {acc_dtype}" + ) + print( + f"Has D (True means fuse Y+=X*D): {has_d}, D has Hdim (True means D.shape DxEH, False means 1xEH): {d_has_hdim}" + ) + print(f"Tolerance: {tolerance}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}") + + # Unpack parameters + G, B, E, H, C, D, L, N = gbehcdln + EH = E * H + + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + # Match same seed in ssd_reference.py for reference check + torch.manual_seed(42) + + # Create and permute tensor A/B/C + def create_and_permute_tensor( + shape, permute_order, dtype, dt_or_a=0, dynamic_modes=None, ref_tensor=None + ): + # Build fp32 reference torch tensor + if ref_tensor is None: + ref_tensor = ( + torch.empty(*shape, dtype=torch.float32) + # .random_(-1, 1) + .normal_(0, 0.5) + # .uniform_(-1,1) + .permute(permute_order) + ) + if dt_or_a == 1: # dt: + ref_tensor = F.softplus(ref_tensor - 4) + elif dt_or_a == 2: # A: + ref_tensor = -torch.exp(ref_tensor) + + # Build torch_dtype torch tensor + torch_dtype = cutlass_torch.dtype(dtype) + + dst_tensor = ref_tensor.to(torch_dtype).cuda() + cute_tensor = from_dlpack(dst_tensor, assumed_align=16) + for mode in dynamic_modes: + cute_tensor = cute_tensor.mark_compact_shape_dynamic( + mode=mode, stride_order=dst_tensor.dim_order() + ) + + return ref_tensor, cute_tensor, dst_tensor + + # INPUT tensors + # x: (D, L, C, EH, B):(C*L, 1, L, D*C*L, EH*D*C*L) + x_ref, x_tensor, x_torch = create_and_permute_tensor( + [B, EH, D, C, L], [2, 4, 3, 1, 0], io_dtype, dynamic_modes=[2, 3, 4] + ) + # delta/delta_a/cumsum_delta: (L, C, EH, B):(1, L, C*L, EH*C*L) + delta_ref, delta_tensor, delta_torch = create_and_permute_tensor( + [B, EH, C, L], [3, 2, 1, 0], io_dtype, dt_or_a=1, dynamic_modes=[1, 2, 3] + ) + # a: (EH):(1) + a_ref, a_tensor, a_torch = create_and_permute_tensor( + [EH], [0], io_dtype, dt_or_a=2, dynamic_modes=[0] + ) + + if has_d: + # d: (D, EH):(1, D) or (1, EH):(0, 1) + d_ref, d_tensor, d_torch = create_and_permute_tensor( + [EH, D if d_has_hdim else 1], [1, 0], io_dtype, dynamic_modes=[1] + ) + else: + d_ref = None + d_tensor = None + + # b/c: (L, N, C, G, B):(1, C*L, L, N*C*L, G*N*C*L) + b_ref, b_tensor, b_torch = create_and_permute_tensor( + [B, G, N, C, L], [4, 2, 3, 1, 0], io_dtype, dynamic_modes=[2, 3, 4] + ) + c_ref, c_tensor, c_torch = create_and_permute_tensor( + [B, G, N, C, L], [4, 2, 3, 1, 0], io_dtype, dynamic_modes=[2, 3, 4] + ) + + # OUTPUT tensors + # y: (L, D, C, EH, B):(1, C*L, L, D*C*L, EH*D*C*L) + y_ref, y_tensor, y_torch = create_and_permute_tensor( + [B, EH, D, C, L], [4, 2, 3, 1, 0], io_dtype, dynamic_modes=[2, 3, 4] + ) + # fstate: (D, N, EH, B):(N, 1, D*N, EH*D*N) + fstate_ref, fstate_tensor, fstate_torch = create_and_permute_tensor( + [B, EH, D, N], [2, 3, 1, 0], io_dtype, dynamic_modes=[2, 3] + ) + + # Call pytorch reference on cpu + if not ref_lower_precision: + ssd_reference_fp32_all( + x_ref, + a_ref, + delta_ref, + b_ref, + c_ref, + y_ref, + fstate_ref, + d_ref, + has_d, + d_has_hdim, + ) + else: + ssd_reference_lowprecision_intermediates( + x_ref, + a_ref, + delta_ref, + b_ref, + c_ref, + y_ref, + fstate_ref, + cutlass_torch.dtype(io_dtype), + d_ref, + has_d, + d_has_hdim, + ) + + # Compute cumsum with pytorch on cpu + delta_a_ref = delta_ref * a_ref.view(1, 1, -1, 1) + cumsum_delta_ref = torch.empty([B, EH, C, L], dtype=torch.float32).permute( + [3, 2, 1, 0] + ) + cumsum_delta_ref.copy_(torch.cumsum(delta_a_ref, dim=0).permute([0, 1, 2, 3])) + # Copy cumsum_delta_ref to cumsum_delta_tensor + ( + cumsum_delta_ref, + cumsum_delta_tensor, + cumsum_delta_torch, + ) = create_and_permute_tensor( + [B, EH, C, L], + [3, 2, 1, 0], + cumsum_delta_dtype, + ref_tensor=cumsum_delta_ref, + dynamic_modes=[1, 2, 3], + ) + + # Call fused ssd kernel + ssd = SSDKernel( + io_dtype, + cumsum_delta_dtype, + acc_dtype, + L, + D, + N, + has_d, + d_has_hdim, + ) + + # Compute max active clusters on current device + hardware_info = cutlass.utils.HardwareInfo() + max_active_clusters = hardware_info.get_max_active_clusters(1) + + stream = cutlass.cuda.default_stream() + # Compile ssd kernel + compiled_ssd = cute.compile( + ssd, + x_tensor, + cumsum_delta_tensor, + delta_tensor, + b_tensor, + c_tensor, + y_tensor, + fstate_tensor, + d_tensor, + max_active_clusters, + stream, + ) + + # Launch compiled ssd kernel for reference check + if not skip_ref_check: + compiled_ssd( + x_tensor, + cumsum_delta_tensor, + delta_tensor, + b_tensor, + c_tensor, + y_tensor, + fstate_tensor, + d_tensor, + stream, + ) + + # Reference check + if print_rtol_stats: + print("\nY's Relative diffs:") + analyze_relative_diffs( + y_torch.cpu(), y_ref.to(cutlass_torch.dtype(io_dtype)) + ) + print("\nFstate's Relative diffs:") + analyze_relative_diffs( + fstate_torch.cpu(), fstate_ref.to(cutlass_torch.dtype(io_dtype)) + ) + torch.testing.assert_close( + y_torch.cpu(), + y_ref.to(cutlass_torch.dtype(io_dtype)), + atol=tolerance, + rtol=1e-02, + ) + torch.testing.assert_close( + fstate_torch.cpu(), + fstate_ref.to(cutlass_torch.dtype(io_dtype)), + atol=tolerance, + rtol=1e-05, + ) + + def generate_tensors(): + # Reuse existing CPU reference tensors and create new GPU tensors from them + _, x_tensor_new, _ = create_and_permute_tensor( + [B, EH, D, C, L], + [2, 4, 3, 1, 0], + io_dtype, + ref_tensor=x_ref, + dynamic_modes=[2, 3, 4], + ) + _, cumsum_delta_tensor_new, _ = create_and_permute_tensor( + [B, EH, C, L], + [3, 2, 1, 0], + cumsum_delta_dtype, + ref_tensor=cumsum_delta_ref, + dynamic_modes=[1, 2, 3], + ) + _, delta_tensor_new, _ = create_and_permute_tensor( + [B, EH, C, L], + [3, 2, 1, 0], + io_dtype, + ref_tensor=delta_ref, + dynamic_modes=[1, 2, 3], + ) + _, b_tensor_new, _ = create_and_permute_tensor( + [B, G, N, C, L], + [4, 2, 3, 1, 0], + io_dtype, + ref_tensor=b_ref, + dynamic_modes=[2, 3, 4], + ) + _, c_tensor_new, _ = create_and_permute_tensor( + [B, G, N, C, L], + [4, 2, 3, 1, 0], + io_dtype, + ref_tensor=c_ref, + dynamic_modes=[2, 3, 4], + ) + _, y_tensor_new, _ = create_and_permute_tensor( + [B, EH, D, C, L], + [4, 2, 3, 1, 0], + io_dtype, + ref_tensor=y_ref, + dynamic_modes=[2, 3, 4], + ) + _, fstate_tensor_new, _ = create_and_permute_tensor( + [B, EH, D, N], + [2, 3, 1, 0], + io_dtype, + ref_tensor=fstate_ref, + dynamic_modes=[2, 3], + ) + + if has_d: + _, d_tensor_new, _ = create_and_permute_tensor( + [EH, D if d_has_hdim else 1], + [1, 0], + io_dtype, + ref_tensor=d_ref, + dynamic_modes=[1], + ) + else: + d_tensor_new = d_tensor + + return testing.JitArguments( + x_tensor_new, + cumsum_delta_tensor_new, + delta_tensor_new, + b_tensor_new, + c_tensor_new, + y_tensor_new, + fstate_tensor_new, + d_tensor_new, + stream, + ) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + x_torch.numel() * x_torch.element_size() + + cumsum_delta_torch.numel() * cumsum_delta_torch.element_size() + + delta_torch.numel() * delta_torch.element_size() + + b_torch.numel() * b_torch.element_size() + + c_torch.numel() * c_torch.element_size() + + y_torch.numel() * y_torch.element_size() + + fstate_torch.numel() * fstate_torch.element_size() + ) + if has_d: + one_workspace_bytes += d_torch.numel() * d_torch.element_size() + + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + exec_time = testing.benchmark( + compiled_ssd, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + return exec_time # Return execution time in microseconds + + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str) -> List[int]: + try: + return [int(x.strip()) for x in s.split(",")] + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + parser = argparse.ArgumentParser( + description="Example of MxNxKxL GEMM on Blackwell." + ) + + parser.add_argument( + "--gbehcdln", + type=parse_comma_separated_ints, + default=[2, 4, 2, 40, 32, 64, 128, 128], + # default=[2, 3, 2, 2, 8, 64, 128, 128], + # default=[1, 2, 1, 4, 8, 64, 128, 128], + help="gbehcdln dimensions (comma-separated)", + ) + parser.add_argument("--io_dtype", type=cutlass.dtype, default=cutlass.BFloat16) + parser.add_argument( + "--cumsum_delta_dtype", type=cutlass.dtype, default=cutlass.Float32 + ) + parser.add_argument("--acc_dtype", type=cutlass.dtype, default=cutlass.Float32) + parser.add_argument( + "--fuse_scale_d", + type=str, + choices=["none", "scalar", "vector"], + default="vector", + help="Fuse scale type: none (no Y+=X*D fusion), scalar (Y+=X*D fusion with D.shape=1xEH), or vector (Y+=X*D fusion with D.shape=DxEH)", + ) + parser.add_argument( + "--ref_lower_precision", + action="store_true", + default=True, + help="Use lower precision for reference check", + ) + parser.add_argument( + "--no-ref_lower_precision", + action="store_false", + dest="ref_lower_precision", + default=False, + help="Disable lower precision for reference check", + ) + parser.add_argument( + "--tolerance", type=float, default=5e-02, help="Tolerance for validation" + ) + parser.add_argument( + "--print_rtol_stats", + action="store_true", + default=True, + help="Enable print rtol stats", + ) + parser.add_argument( + "--no-print_rtol_stats", + action="store_false", + dest="print_rtol_stats", + default=False, + help="Disable print rtol stats", + ) + parser.add_argument( + "--warmup_iterations", + type=int, + default=0, + help="Number of warmup iterations", + ) + parser.add_argument( + "--iterations", + type=int, + default=1, + help="Number of iterations", + ) + parser.add_argument( + "--skip_ref_check", action="store_true", help="Skip reference checking" + ) + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", + ) + + args = parser.parse_args() + + if len(args.gbehcdln) != 8: + parser.error("--gbehcdln must contain exactly 8 values") + + run( + args.gbehcdln, + args.io_dtype, + args.cumsum_delta_dtype, + args.acc_dtype, + args.fuse_scale_d, + args.tolerance, + args.print_rtol_stats, + args.ref_lower_precision, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.use_cold_l2, + ) + print("PASS") diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd_reference.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd_reference.py new file mode 100644 index 0000000000000000000000000000000000000000..ced5c6f2f51fb395b4efac96ee2b655710fca463 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd_reference.py @@ -0,0 +1,397 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import torch +import torch.nn.functional as F + + +def ssd_reference_fp32_all(x, a, delta, B, C, Y_out, Fstate_out, D, has_d, d_has_hdim): + """ + Rearrange tensor dimensions from cuda layout to reference layout, then directly call TriDao's ssd implementation + Arguments: + X/x: (D, L, C, H, B):(C*L, 1, L, D*C*L, H*D*C*L) + A/delta: (L, C, H, B):(1, L, C*L, H*C*L) + a: (H):(1) + B/C: (L, N, C, G, B):(1, C*L, L, N*C*L, G*N*C*L) + D: (1, H):(0, 1) or (D, H):(1, D) + has_d: bool + d_has_hdim: bool + Return: + Y_out: (L, D, C, H, B):(1, C*L, L, D*C*L, H*D*C*L) + Fstate_out: (D, N, H, B):(N, 1, D*N, H*D*N) + """ + assert x.dtype == a.dtype == delta.dtype == B.dtype == C.dtype + + A = delta * a.view(1, 1, -1, 1) + X = x * delta.unsqueeze(0) + + # Rearrange to match cutlass layout to tridao's layout + block_len = A.shape[0] + initial_states = None + # A: l c h b-> b c l h + A = A.permute(3, 1, 0, 2) + # X: p l c h b -> b c l h p + X = X.permute(4, 2, 1, 3, 0) + # B: l n c g b -> b c l g n + B = B.permute(4, 2, 0, 3, 1) + # C: l n c g b -> b c l g n + C = C.permute(4, 2, 0, 3, 1) + # X/A/B/C: b c l ... -> b (c l) ... + X, A, B, C = [x.reshape(x.shape[0], -1, *x.shape[3:]) for x in (X, A, B, C)] + + # Ngroup (g to h) mapping + B_val, CL_val, G_val, N_val = B.shape + H_val = X.shape[2] + ngroup_ratio = H_val // G_val + # B/C: (B, CL, H, N) + h_to_g_mapping = torch.arange(H_val, device=B.device) // ngroup_ratio + B = B.gather(2, h_to_g_mapping.view(1, 1, -1, 1).expand(B_val, CL_val, -1, N_val)) + C = C.gather(2, h_to_g_mapping.view(1, 1, -1, 1).expand(B_val, CL_val, -1, N_val)) + + ################################################################### + # Call reference implementation from Tri Dao ssd_minimal_discrete + Y, final_state = ssd_minimal_discrete_fp32_all( + X, A, B, C, block_len, initial_states + ) + ################################################################### + + if has_d: + D_val = Y.shape[3] + if not d_has_hdim: + D = D.expand(D_val, -1) + Y = Y + torch.einsum("bchp,ph->bchp", X, D) + + # Rearrange to match tridao's layout to cutlass layout + # Y: b (c l) h p -> b c l h p + Y = Y.reshape(Y.shape[0], -1, block_len, Y.shape[2], Y.shape[3]) + # Y: b c l h p -> l p c h b + Y = Y.permute(2, 4, 1, 3, 0) + # Fstate_out: b h p n -> p n h b + Fstate_out.copy_(final_state.permute(2, 3, 1, 0)) + Y_out.copy_(Y) + return + + +def ssd_reference_lowprecision_intermediates( + x, a, delta, B, C, Y_out, Fstate_out, intermediate_dtype, D, has_d, d_has_hdim +): + """ + Rearrange tensor dimensions from cuda layout to reference layout, then call a reduced intermediate dtype version of ssd implementation + Arguments: + X/x: (D, L, C, H, B):(C*L, 1, L, D*C*L, H*D*C*L) + A/delta: (L, C, H, B):(1, L, C*L, H*C*L) + a: (H):(1) + B/C: (L, N, C, G, B):(1, C*L, L, N*C*L, G*N*C*L) + intermediate_dtype: input and intermediate data type + D: (1, H):(0, 1) or (D, H):(1, D) + has_d: bool + d_has_hdim: bool + Return: + Y_out: (L, D, C, H, B):(1, C*L, L, D*C*L, H*D*C*L) + Fstate_out: (D, N, H, B):(N, 1, D*N, H*D*N) + """ + assert x.dtype == a.dtype == delta.dtype == B.dtype == C.dtype + + A = delta * a.view(1, 1, -1, 1) + + # Rearrange to match cutlass layout to tridao's layout + block_len = A.shape[0] + initial_states = None + # A: l c h b-> b c l h + A = A.permute(3, 1, 0, 2) + # delta: l c h b-> b c l h + delta = delta.permute(3, 1, 0, 2) + # x: p l c h b -> b c l h p + x = x.permute(4, 2, 1, 3, 0) + # B: l n c g b -> b c l g n + B = B.permute(4, 2, 0, 3, 1) + # C: l n c g b -> b c l g n + C = C.permute(4, 2, 0, 3, 1) + # x/A/delta/B/C: b c l ... -> b (c l) ... + x, A, delta, B, C = [ + tensor.reshape(tensor.shape[0], -1, *tensor.shape[3:]) + for tensor in (x, A, delta, B, C) + ] + + # Ngroup (g to h) mapping + B_val, CL_val, G_val, N_val = B.shape + H_val = x.shape[2] + ngroup_ratio = H_val // G_val + # B/C: (B, CL, H, N) + h_to_g_mapping = torch.arange(H_val, device=B.device) // ngroup_ratio + B = B.gather(2, h_to_g_mapping.view(1, 1, -1, 1).expand(B_val, CL_val, -1, N_val)) + C = C.gather(2, h_to_g_mapping.view(1, 1, -1, 1).expand(B_val, CL_val, -1, N_val)) + + # Type convert input tensors to input dtype (same as intermediate dtype) + x = x.to(intermediate_dtype).to(torch.float32) + A = A.to(intermediate_dtype).to(torch.float32) + delta = delta.to(intermediate_dtype).to(torch.float32) + B = B.to(intermediate_dtype).to(torch.float32) + C = C.to(intermediate_dtype).to(torch.float32) + + ######################################################################### + # Call reference implementation ssd_minimal_discrete_bf16_intermediates + Y, final_state = ssd_minimal_discrete_lowprecision_intermediates( + x, A, delta, B, C, block_len, intermediate_dtype, initial_states + ) + ######################################################################### + + if has_d: + D = D.to(intermediate_dtype).to(torch.float32) + D_val = Y.shape[3] + if not d_has_hdim: + D = D.expand(D_val, -1) + Y = Y + torch.einsum("bchp,ph->bchp", x, D) + + # Type convert output tensors to output dtype (same as intermediate dtype) + Y = Y.to(intermediate_dtype).to(torch.float32) + final_state = final_state.to(intermediate_dtype).to(torch.float32) + + # Rearrange to match tridao's layout to cutlass layout + # Y: b (c l) h p -> b c l h p + Y = Y.reshape(Y.shape[0], -1, block_len, Y.shape[2], Y.shape[3]) + # Y: b c l h p -> l p c h b + Y = Y.permute(2, 4, 1, 3, 0) + # Fstate_out: b h p n -> p n h b + Fstate_out.copy_(final_state.permute(2, 3, 1, 0)) + Y_out.copy_(Y) + return + + +def analyze_relative_diffs(actual, expected): + """ + Print statistics of relative differences between actual and expected tensors + """ + # Calculate relative differences + abs_diff = (actual - expected).abs() + rel_diff = abs_diff / (torch.maximum(expected.abs(), actual.abs()) + 0.00001) + + total_elements = rel_diff.numel() + + # Handle special cases first + nan_mask = torch.isnan(rel_diff) + inf_mask = torch.isinf(rel_diff) + nan_count = nan_mask.sum().item() + inf_count = inf_mask.sum().item() + + # Find position and value of maximum relative difference + max_rel_diff = ( + rel_diff[~nan_mask & ~inf_mask].max() + if (~nan_mask & ~inf_mask).any() + else float("nan") + ) + max_rel_diff_pos = ( + rel_diff[~nan_mask & ~inf_mask].argmax() + if (~nan_mask & ~inf_mask).any() + else -1 + ) + + # Print max relative difference info + print(f"Maximum relative difference:") + print(f"Position: {max_rel_diff_pos}") + print(f"Value: {max_rel_diff:.6e}") + print(f"Actual value: {actual.flatten()[max_rel_diff_pos]}") + print(f"Expected value: {expected.flatten()[max_rel_diff_pos]}") + print(f"NaN values: {nan_count} ({100.0 * nan_count / total_elements:.2f}%)") + print(f"Inf values: {inf_count} ({100.0 * inf_count / total_elements:.2f}%)\n") + + # Check different rtol thresholds + rtol_levels = [1e-5, 1e-4, 1e-3, 1e-2, 5e-02, 1e-01] + + for i, rtol in enumerate(rtol_levels): + if i == 0: + mask = rel_diff <= rtol + else: + mask = (rel_diff <= rtol) & (rel_diff > rtol_levels[i - 1]) + + count = mask.sum().item() + percentage = (count / total_elements) * 100 + + if i == 0: + print(f"Elements with rtol <= {rtol:.0e}: {count} ({percentage:.2f}%)") + else: + print( + f"Elements with {rtol_levels[i-1]:.0e} < rtol <= {rtol:.0e}: {count} ({percentage:.2f}%)" + ) + + # Print elements exceeding the largest rtol + mask = rel_diff > rtol_levels[-1] + count = mask.sum().item() + percentage = (count / total_elements) * 100 + print(f"Elements with rtol > {rtol_levels[-1]:.0e}: {count} ({percentage:.2f}%)\n") + + +def segsum(x): + """ + More stable segment sum calculation. + x: b h c l + """ + T = x.size(-1) + # x: b h c l -> b h c l l + x = x.unsqueeze(-1).expand(*x.shape, T) + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1) + x = x.masked_fill(~mask, 0) + x_segsum = torch.cumsum(x, dim=-2) + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0) + x_segsum = x_segsum.masked_fill(~mask, -torch.inf) + return x_segsum + + +def ssd_minimal_discrete_fp32_all(X, A, B, C, block_len, initial_states=None): + """ + This is same with https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/ssd_minimal.py + (all accumulation and intermediate results in fp32) + + Arguments: + X: (batch(B), length(C*L), n_heads(H), d_head(D)) + A: (batch(B), length(C*L), n_heads(H)) + B: (batch(B), length(C*L), n_heads(H), d_state(N)) + C: (batch(B), length(C*L), n_heads(H), d_state(N)) + Return: + Y: (batch(B), length(C*L), n_heads(H), d_head(D)) + final_state: (B, H, D, N) + """ + assert X.dtype == A.dtype == B.dtype == C.dtype + assert X.shape[1] % block_len == 0 + + # Rearrange into blocks/chunks + # X/A/B/C:b (c l) ... -> b c l ... + X, A, B, C = [ + x.reshape(x.shape[0], -1, block_len, *x.shape[2:]) for x in (X, A, B, C) + ] + + # A: b c l h -> b h c l + A = A.permute(0, 3, 1, 2) + # A_cumsum: (B, H, C, L) + A_cumsum = torch.cumsum(A, dim=-1) + + # 1. Compute the output for each intra-chunk (diagonal blocks) + segsum_A = segsum(A) + L = torch.exp(segsum_A) + Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) + states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + if initial_states is None: + initial_states = torch.zeros_like(states[:, :1]) + states = torch.cat([initial_states, states], dim=1) + decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)))) + new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states) + states, final_state = new_states[:, :-1], new_states[:, -1] + + # 4. Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = torch.exp(A_cumsum) + Y_off = torch.einsum("bclhn,bchpn,bhcl->bclhp", C, states, state_decay_out) + + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + # Y: b c l h p -> b (c l) h p + Y = (Y_diag + Y_off).reshape(Y_diag.shape[0], -1, Y_diag.shape[3], Y_diag.shape[4]) + return Y, final_state + + +def ssd_minimal_discrete_lowprecision_intermediates( + X, A, delta, B, C, block_len, intermediate_dtype, initial_states=None +): + """ + This is adjusted from ssd_minimal_discrete_fp32_all, with exceptions: + 1. accumulation in fp32 but intermediates Q/b_tmem/P are in intermediate_dtype + 2. delta is not pre-multiplied with X, delta was applied to generate Q/b_tmem to match GPU implementation + + Arguments: + X: (batch(B), length(C*L), n_heads(H), d_head(D)) + A: (batch(B), length(C*L), n_heads(H)) + delta: (batch(B), length(C*L), n_heads(H)) + B: (batch(B), length(C*L), n_heads(H), d_state(N)) + C: (batch(B), length(C*L), n_heads(H), d_state(N)) + Return: + Y: (batch(B), length(C*L), n_heads(H), d_head(D)) + final_state: (B, H, D, N) + """ + assert X.dtype == A.dtype == B.dtype == C.dtype + assert X.shape[1] % block_len == 0 + + # Rearrange into blocks/chunks + # X/A/delta/B/C: b (c l) ... -> b c l ... + X, A, delta, B, C = [ + x.reshape(x.shape[0], -1, block_len, *x.shape[2:]) for x in (X, A, delta, B, C) + ] + + # A: b c l h -> b h c l + A = A.permute(0, 3, 1, 2) + # delta: b c l h -> b h c l + delta = delta.permute(0, 3, 1, 2) + # A_cumsum: (B, H, C, L) + A_cumsum = torch.cumsum(A, dim=-1) + + # 1. Compute the output for each intra-chunk (diagonal blocks) + segsum_A = segsum(A) + L = torch.exp(segsum_A) + intra_acc_0 = torch.einsum("bclhn,bcshn->bclhs", C, B) + Q = torch.einsum("bclhs,bhcls,bhcs->bclhs", intra_acc_0, L, delta) + Y_diag = torch.einsum( + "bclhs,bcshp->bclhp", Q.to(intermediate_dtype).to(torch.float32), X + ) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) + b_tmem = torch.einsum("bclhn,bhcl,bhcl->bclhn", B, decay_states, delta) + states = torch.einsum( + "bclhn,bclhp->bchpn", b_tmem.to(intermediate_dtype).to(torch.float32), X + ) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + if initial_states is None: + initial_states = torch.zeros_like(states[:, :1]) + states = torch.cat([initial_states, states], dim=1) + decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)))) + new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states) + states, final_state = new_states[:, :-1], new_states[:, -1] + final_state = final_state + + # 4. Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = torch.exp(A_cumsum) + Y_off_tmp = torch.einsum( + "bclhn,bchpn->bclhp", C, states.to(intermediate_dtype).to(torch.float32) + ) + Y_off = torch.einsum("bclhp,bhcl->bclhp", Y_off_tmp, state_decay_out) + + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + # Y: b c l h p -> b (c l) h p + Y = (Y_diag + Y_off).reshape( + Y_diag.shape[0], -1, Y_diag.shape[3], Y_diag.shape[4] + ) # b (c l) h p + return Y, final_state diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd_tile_scheduler.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd_tile_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..544b477228132922c38c9855d44ad85d94f2effb --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd_tile_scheduler.py @@ -0,0 +1,200 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from typing import Tuple + +from cutlass.cutlass_dsl import ( + Boolean, + Integer, + Int32, + min, + extract_mlir_values, + new_from_mlir_values, + dsl_user_op, +) +from cutlass._mlir import ir +import cutlass.cute as cute +from cutlass.utils import WorkTileInfo + + +class Mamba2SSDTileSchedulerParams: + def __init__( + self, + problem_shape_ntiles: int, + eh: int, + ngroup_ratio: int, + *, + loc=None, + ip=None, + ): + self.problem_shape_ntiles = problem_shape_ntiles + self.eh = eh + self.ngroup_ratio = ngroup_ratio + self._loc = loc + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.problem_shape_ntiles, self.eh, self.ngroup_ratio]: + obj_values = extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip( + [self.problem_shape_ntiles, self.eh, self.ngroup_ratio], self._values_pos + ): + obj_list.append(new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return Mamba2SSDTileSchedulerParams(*(tuple(obj_list)), loc=self._loc) + + @dsl_user_op + def get_grid_shape( + self, max_active_clusters: Int32, *, loc=None, ip=None + ) -> Tuple[Integer, Integer, Integer]: + return (min(self.problem_shape_ntiles, max_active_clusters), 1, 1) + + +class Mamba2SSDTileScheduler: + def __init__( + self, + params: Mamba2SSDTileSchedulerParams, + num_persistent_ctas: Int32, + current_work_linear_idx: Int32, + num_tiles_executed: Int32, + ): + self.params = params + self.num_persistent_ctas = num_persistent_ctas + self._current_work_linear_idx = current_work_linear_idx + self._num_tiles_executed = num_tiles_executed + + def __extract_mlir_values__(self) -> list[ir.Value]: + values = extract_mlir_values(self.num_persistent_ctas) + values.extend(extract_mlir_values(self._current_work_linear_idx)) + values.extend(extract_mlir_values(self._num_tiles_executed)) + return values + + def __new_from_mlir_values__( + self, values: list[ir.Value] + ) -> "Mamba2SSDTileScheduler": + assert len(values) == 3 + new_num_persistent_ctas = new_from_mlir_values( + self.num_persistent_ctas, [values[0]] + ) + new_current_work_linear_idx = new_from_mlir_values( + self._current_work_linear_idx, [values[1]] + ) + new_num_tiles_executed = new_from_mlir_values( + self._num_tiles_executed, [values[2]] + ) + return Mamba2SSDTileScheduler( + self.params, + new_num_persistent_ctas, + new_current_work_linear_idx, + new_num_tiles_executed, + ) + + # called by host + @dsl_user_op + @staticmethod + def create( + params: Mamba2SSDTileSchedulerParams, + block_idx: Tuple[Integer, Integer, Integer], + grid_dim: Tuple[Integer, Integer, Integer], + *, + loc=None, + ip=None, + ): + params = params + + # Calculate the number of persistent clusters by dividing the total grid size + # by the number of CTAs per cluster + num_persistent_ctas = Int32(cute.size(grid_dim, loc=loc, ip=ip)) + + bidx, bidy, bidz = block_idx + + # Initialize workload index equals to the cluster index in the grid + current_work_linear_idx = Int32(bidx) + + # Initialize number of tiles executed to zero + num_tiles_executed = Int32(0) + return Mamba2SSDTileScheduler( + params, + num_persistent_ctas, + current_work_linear_idx, + num_tiles_executed, + ) + + # called by host + @staticmethod + def get_grid_shape( + params: Mamba2SSDTileSchedulerParams, + max_active_clusters: Int32, + *, + loc=None, + ip=None, + ) -> Tuple[Integer, Integer, Integer]: + return params.get_grid_shape(max_active_clusters, loc=loc, ip=ip) + + # private method + def _get_current_work_for_linear_idx( + self, current_work_linear_idx: Int32, *, loc=None, ip=None + ) -> WorkTileInfo: + is_valid = current_work_linear_idx < cute.size( + self.params.problem_shape_ntiles, loc=loc, ip=ip + ) + + eh_idx = current_work_linear_idx % self.params.eh + b_idx = current_work_linear_idx // self.params.eh + g_idx = eh_idx // self.params.ngroup_ratio + # cur_tile_coord is (b_idx, eh_idx, g_idx) + cur_tile_coord = tuple(Int32(x) for x in (b_idx, eh_idx, g_idx)) + + return WorkTileInfo(cur_tile_coord, is_valid) + + @dsl_user_op + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + return self._get_current_work_for_linear_idx( + self._current_work_linear_idx, loc=loc, ip=ip + ) + + @dsl_user_op + def initial_work_tile_info(self, *, loc=None, ip=None) -> WorkTileInfo: + return self.get_current_work(loc=loc, ip=ip) + + @dsl_user_op + def advance_to_next_work(self, *, advance_count: int = 1, loc=None, ip=None): + self._current_work_linear_idx += Int32(advance_count) * Int32( + self.num_persistent_ctas + ) + self._num_tiles_executed += Int32(1) + + @property + def num_tiles_executed(self) -> Int32: + return self._num_tiles_executed diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/cute/ffi/jit_argument.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/cute/ffi/jit_argument.py new file mode 100644 index 0000000000000000000000000000000000000000..acdb42efb997a943fbf3c442a985eda1bfa20e24 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/cute/ffi/jit_argument.py @@ -0,0 +1,305 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Example of accessing POD (Plain Old Data) from C or other languages via LLVM operations. + +This example demonstrates a basic approach to building customized interfaces as C-structures between user code +and JIT compiled functions. It provides a minimal-cost solution for calling JIT functions +and can be used to build AOT (Ahead-of-Time) launchers for JIT compiled functions. + +The C-structure is defined as: + +.. code-block:: c + + struct Tensor { + void *ptr; // Pointer to tensor data + int32_t shape[3]; // Tensor dimensions + int32_t strides[3]; // Memory strides for each dimension + }; + +The example defines Tensor and TensorValue classes that wrap C structs for view of a tensor with its data pointer, +shape, and strides, enabling efficient data passing between different language boundaries. + +.. note:: + Future development may include automated code generation flows. +""" + +import cutlass +import cutlass.cute as cute + +from cutlass._mlir import ir +from cutlass._mlir.dialects import llvm +import cutlass._mlir.extras.types as T + + +class ExampleTensorValue(ir.Value): + """A wrapper class for tensor values in MLIR. + + This class extends ir.Value to provide convenient access to tensor data pointer, + shape, and strides through MLIR operations. + + :type: ir.Value + """ + + def __init__(self, v): + """Initialize a new TensorValue. + + :param v: The underlying MLIR value to wrap + :type v: ir.Value + """ + super().__init__(v) + + @property + def data_ptr(self, *, loc=None, ip=None): + """Get the data pointer from the tensor value. + + Extracts the data pointer (first field) from the LLVM struct value. + + :param loc: Optional location information for MLIR operations + :type loc: Optional[ir.Location] + :param ip: Optional insertion point for MLIR operations + :type ip: Optional[ir.InsertionPoint] + :return: An integer value representing the data pointer + :rtype: ir.Value + """ + # Extract the data pointer from the LLVM struct value + # The data pointer is the first field (index 0) in the struct + + # Use llvm.extractvalue to get the pointer field from the struct + ptr_val = llvm.extractvalue( + llvm.PointerType.get(), + self, + [0], # Extract the first field (index 0) + loc=loc, + ip=ip, + ) + + return cute.make_ptr(cutlass.Float32, ptr_val) + + @property + def shape(self): + """Get the shape of the tensor. + + Extracts the shape (second field) from the LLVM struct value. + + :return: A tuple of integers representing the tensor dimensions + :rtype: tuple[ir.Value, ...] + """ + i32_type = ir.IntegerType.get_signless(32) + # Extract the shape field from the LLVM struct value + # The shape is the second field (index 1) in the struct + shape_val = llvm.extractvalue( + llvm.StructType.get_literal([i32_type] * 3), + self, + [1], # Extract the second field (index 1) + ) + + # Extract each dimension from the shape struct + return tuple(llvm.extractvalue(i32_type, shape_val, [i]) for i in range(3)) + + @property + def stride(self): + """Get the strides of the tensor. + + Extracts the strides (third field) from the LLVM struct value. + + :return: A tuple of integers representing the tensor strides + :rtype: tuple[ir.Value, ...] + """ + i32_type = ir.IntegerType.get_signless(32) + # Extract the strides field from the LLVM struct value + # The strides are the third field (index 2) in the struct + strides_val = llvm.extractvalue( + llvm.StructType.get_literal([i32_type] * 3), + self, + [2], # Extract the third field (index 2) + ) + + # Extract each dimension from the strides struct + return tuple(llvm.extractvalue(i32_type, strides_val, [i]) for i in range(3)) + + +class ExampleTensor: + """A class representing a tensor with its data pointer, shape, and strides. + + This class provides a Python interface to create and manipulate tensor structures + that can be passed to CUTE JIT compiled functions. + + :ivar _c_struct_p: The C struct pointer for the tensor + :ivar _rank: The number of dimensions in the tensor + """ + + def __init__(self, c_struct_p, rank): + """Initialize a new Tensor. + + :param c_struct_p: The C struct pointer for the tensor + :type c_struct_p: int + :param rank: The number of dimensions in the tensor + :type rank: int + """ + self._c_struct_p = c_struct_p + self._rank = rank + + def __get_mlir_types__(self): + """Get the MLIR types for this tensor. + + Creates an LLVM structure type representing a C-structure with: + + .. code-block:: c + + struct Tensor { + void *ptr; + int32_t shape[3]; + int32_t strides[3]; + }; + + :return: A list containing the MLIR struct type + :rtype: list[llvm.StructType] + + Create an LLVM structure type that represents a C-structure like: + """ + + # Get the number of dimensions from the shape + ndim = self._rank + + # Create the pointer type (void*) + ptr_type = llvm.PointerType.get() + + # Create array types for shape and strides (int32_t[ndim]) + int32_type = ir.IntegerType.get_signless(32) + shape_type = llvm.StructType.get_literal([int32_type] * ndim) + strides_type = llvm.StructType.get_literal([int32_type] * ndim) + + # Create the structure type + struct_type = llvm.StructType.get_literal([ptr_type, shape_type, strides_type]) + + return [struct_type] + + def __new_from_mlir_values__(self, values): + """Create a new TensorValue from MLIR values. + + :param values: A list of MLIR values + :type values: list[ir.Value] + :return: A new TensorValue instance + :rtype: TensorValue + """ + return ExampleTensorValue(values[0]) + + def __c_pointers__(self): + """Get the C pointers for this tensor. + + :return: A list containing the C struct pointer + :rtype: list[int] + """ + return [self._c_struct_p] + + +@cute.jit +def foo(tensor): + """Example JIT function that prints tensor information. + + :param tensor: A Tensor instance to print information about + :type tensor: Tensor + """ + cute.printf("data_ptr: {}", tensor.data_ptr) + cute.printf("shape: {}", tensor.shape) + cute.printf("stride: {}", tensor.stride) + + mA = cute.make_tensor( + tensor.data_ptr, cute.make_layout(tensor.shape, stride=tensor.stride) + ) + cute.print_tensor(mA) + + +import sys +import os +import subprocess +import shutil +import tempfile +import torch + + +def run_test(tmpdir=None): + # Skip cleanup if user provides tmpdir + cleanup = tmpdir is None + # Initialize temporary build directory + tmpdir = tmpdir or tempfile.mkdtemp() + + try: + current_dir = os.path.dirname(os.path.abspath(__file__)) + + subprocess.run(["cmake", "-B", tmpdir, current_dir], check=True) + subprocess.run(["cmake", "--build", tmpdir], check=True) + + sys.path.append(tmpdir) + + from tensor import make_tensor, pycapsule_get_pointer + + # Mock test tensor and corresponding C structure for this example + # In production, this may come from external library + x = torch.arange(2 * 8 * 4).to(torch.float32).reshape(2, 8, 4) + c_struct = make_tensor(x.data_ptr(), x.shape, x.stride()) + c_struct_p = pycapsule_get_pointer(c_struct) + + # Initialize tensor wrapper and compile test function + tensor = ExampleTensor(c_struct_p, len(x.shape)) + compiled_func = cute.compile(foo, tensor) + + # Benchmark pointer access performance + from time import time + + start = time() + # Measure performance of critical path pointer access + # get C pointers is on critical path to call JIT compiled function + for _ in range(1000): + tensor.__c_pointers__() + end = time() + print(f"__c_pointers__: {(end - start) * 1000} us") + + # Execute compiled function + compiled_func(tensor) + except Exception as e: + print(e) + finally: + if cleanup: + # Clean up the temporary directory + shutil.rmtree(tmpdir) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Set temporary directory for building C modules" + ) + parser.add_argument( + "--tmp-dir", type=str, help="Temporary directory path for building C modules" + ) + args = parser.parse_args() + + run_test(args.tmp_dir) diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/hopper/dense_gemm.py b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/hopper/dense_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..c59ace02d6bf9e9d96e1fa04a62c745386998e44 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/hopper/dense_gemm.py @@ -0,0 +1,1595 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +from typing import Tuple, Type +import math +import cuda.bindings.driver as cuda + +import torch + +import cutlass +import cutlass.cute as cute +import cutlass.cute.testing as testing +import cutlass.utils as utils +import cutlass.pipeline as pipeline +import cutlass.torch as cutlass_torch +from cutlass.cute.runtime import from_dlpack +import cutlass.utils.hopper_helpers as sm90_utils + +""" +A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Hopper architecture +using CuTe DSL. +- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M") +- Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K") +- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M") + +This GEMM kernel supports the following features: + - Utilizes Tensor Memory Access (TMA) for efficient memory operations + - Utilizes Hopper's WGMMA for matrix multiply-accumulate (MMA) operations + - Implements TMA multicast with cluster to reduce L2 memory traffic + - Supports multi-stage pipeline to overlap computation and memory access + +This GEMM works as follows: +1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations. +2. Perform matrix multiply-accumulate (MMA) operations using WGMMA instruction. +3. Store results from registers (RMEM) to shared memory (SMEM), then to global memory (GMEM) with TMA operations. + +Hopper WGMMA instructions operate as follows: +- Read matrix A from SMEM +- Read matrix B from SMEM +- Perform MMA operation and store the result in Accumulator(register) + +To run this example: + +.. code-block:: bash + + python examples/hopper/dense_gemm.py \ + --mnkl 8192,8192,8192,1 --tile_shape_mn 128,256 \ + --cluster_shape_mn 1,1 --a_dtype Float16 --b_dtype Float16 \ + --c_dtype Float16 --acc_dtype Float32 \ + --a_major k --b_major k --c_major n + +The above example command compute batched gemm with M=8192, N=8192, K=8192, +batch_count=1. The Hopper WGMMA tile shape is 128x256x64 and the cluster shape +is (1,1). The input, mma accumulator and output data type are set as fp16, fp32 +and fp16, respectively. + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/hopper/dense_gemm.py \ + --mnkl 8192,8192,8192,1 --tile_shape_mn 128,256 \ + --cluster_shape_mn 1,1 --a_dtype Float16 --b_dtype Float16 \ + --c_dtype Float16 --acc_dtype Float32 \ + --a_major k --b_major k --c_major n + +Constraints: +* Supported input data types: fp16, fp8 (e4m3fn, e5m2) +* For fp16 types, A and B must have the same data type +* For fp8 types, A and B can have different types (e4m3fn or e5m2) but both must be 8-bit +* Fp8 types only support k-major layout +* CTA tile shape M must be 64/128 +* CTA tile shape N must be 64/128/256 +* Cluster shape M/N must be positive and power of 2, total cluster size <= 4 +* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned, + i.e, number of elements is a multiple of 8, 16 for Float16, and Float8, respectively. +""" + + +# ///////////////////////////////////////////////////////////////////////////// +# Helpers to parse args +# ///////////////////////////////////////////////////////////////////////////// +def parse_comma_separated_ints(s: str): + try: + return tuple([int(x.strip()) for x in s.split(",")]) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + +def parse_arguments() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Example of MxNxKxL GEMM on Hopper.") + + parser.add_argument( + "--mnkl", + type=parse_comma_separated_ints, + default=(4096, 4096, 4096, 1), + help="mnkl dimensions (comma-separated)", + ) + parser.add_argument( + "--tile_shape_mn", + type=parse_comma_separated_ints, + choices=[(128, 128), (128, 256), (128, 64), (64, 64)], + default=(128, 128), + help="Cta tile shape (comma-separated)", + ) + parser.add_argument( + "--cluster_shape_mn", + type=parse_comma_separated_ints, + choices=[(1, 1), (2, 1), (1, 2), (2, 2)], + default=(1, 1), + help="Cluster shape (comma-separated)", + ) + parser.add_argument( + "--a_dtype", + type=cutlass.dtype, + default=cutlass.Float16, + ) + parser.add_argument( + "--b_dtype", + type=cutlass.dtype, + default=cutlass.Float16, + ) + parser.add_argument( + "--c_dtype", + type=cutlass.dtype, + default=cutlass.Float16, + ) + parser.add_argument( + "--acc_dtype", + type=cutlass.dtype, + default=cutlass.Float32, + ) + parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k") + parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k") + parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n") + parser.add_argument( + "--tolerance", type=float, default=1e-01, help="Tolerance for validation" + ) + parser.add_argument( + "--warmup_iterations", type=int, default=0, help="Warmup iterations" + ) + parser.add_argument( + "--iterations", + type=int, + default=1, + help="Number of iterations to run the kernel", + ) + parser.add_argument( + "--skip_ref_check", action="store_true", help="Skip reference checking" + ) + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", + ) + + args = parser.parse_args() + + if len(args.mnkl) != 4: + parser.error("--mnkl must contain exactly 4 values") + if len(args.tile_shape_mn) != 2: + parser.error("--tile_shape_mn must contain exactly 2 values") + if len(args.cluster_shape_mn) != 2: + parser.error("--cluster_shape_mn must contain exactly 2 values") + + return args + + +# ///////////////////////////////////////////////////////////////////////////// +# Host setup and device kernel launch +# ///////////////////////////////////////////////////////////////////////////// + + +class HopperWgmmaGemmKernel: + """ + This class implements batched matrix multiplication (C = A x B) with support for various data types + and architectural features specific to Hopper GPUs. + + :param acc_dtype: Data type for accumulation during computation + :type acc_dtype: type[cutlass.Numeric] + :param tile_shape_mn: Shape of the CTA tile (M,N) + :type tile_shape_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing + :type cluster_shape_mn: Tuple[int, int] + + :note: Data type requirements: + - For 16-bit types: A and B must have the same data type + - For 8-bit types: A and B can have different types (Float8E4M3FN/Float8E5M2) as long as both are 8-bit + - Float8 types only support k-major layout + + :note: Supported data types: + - Float16 + - Float8E4M3FN/Float8E5M2 + + :note: Supported accumulation types: + - Float32 (for all floating point inputs) + + :note: Constraints: + - CTA tile M must be 64/128 + - CTA tile N must be 64/128/256 + - CTA tile K must be 64 + - Cluster shape M/N must be positive and power of 2, total cluster size <= 4 + + Example: + >>> gemm = HopperWgmmaGemmKernel( + ... acc_dtype=cutlass.Float32, + ... tile_shape_mn=(128, 256), + ... cluster_shape_mn=(1, 1) + ... ) + >>> gemm(a_tensor, b_tensor, c_tensor, stream) + """ + + def __init__( + self, + acc_dtype: type[cutlass.Numeric], + tile_shape_mn: tuple[int, int], + cluster_shape_mn: tuple[int, int], + ): + """ + Initializes the configuration for a Hopper dense GEMM kernel. + + This configuration includes data types for operands, tile shape, cluster configuration, + and thread layout. + + :param acc_dtype: Data type for accumulation during computation + :type acc_dtype: type[cutlass.Numeric] + :param tile_shape_mn: Shape of the CTA tile (M,N) + :type tile_shape_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing + :type cluster_shape_mn: Tuple[int, int] + """ + + self.acc_dtype = acc_dtype + + self.cluster_shape_mn = cluster_shape_mn + self.mma_inst_shape_mn = None + # K dimension is deferred in _setup_attributes + self.tile_shape_mnk = (*tile_shape_mn, 1) + # For large tile size, using two warp groups is preferred because using only one warp + # group may result in register spill + self.atom_layout_mnk = ( + (2, 1, 1) + if self.tile_shape_mnk[0] > 64 and self.tile_shape_mnk[1] > 128 + else (1, 1, 1) + ) + self.num_mcast_ctas_a = None + self.num_mcast_ctas_b = None + self.is_a_mcast = False + self.is_b_mcast = False + self.tiled_mma = None + + self.occupancy = 1 + self.mma_warp_groups = math.prod(self.atom_layout_mnk) + self.num_threads_per_warp_group = 128 + self.threads_per_cta = self.mma_warp_groups * self.num_threads_per_warp_group + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_90") + + self.ab_stage = None + self.epi_stage = None + + self.a_smem_layout_staged = None + self.b_smem_layout_staged = None + self.epi_smem_layout_staged = None + self.epi_tile = None + + self.shared_storage = None + self.buffer_align_bytes = 1024 + + def _setup_attributes(self): + """Set up configurations that are dependent on GEMM inputs + + This method configures various attributes based on the input tensor properties + (data types, leading dimensions) and kernel settings: + - Configuring tiled MMA + - Computing MMA/cluster/tile shapes + - Computing cluster layout + - Computing multicast CTAs for A/B + - Computing epilogue subtile + - Setting up A/B/C stage counts in shared memory + - Computing A/B/C shared memory layout + """ + + # check the cta tile shape + if self.tile_shape_mnk[0] not in [64, 128]: + raise ValueError("CTA tile shape M must be 64/128") + if self.tile_shape_mnk[1] not in [64, 128, 256]: + raise ValueError("CTA tile shape N must be 64/128/256") + + self.tiled_mma = sm90_utils.make_trivial_tiled_mma( + self.a_dtype, + self.b_dtype, + self.a_layout.sm90_mma_major_mode(), + self.b_layout.sm90_mma_major_mode(), + self.acc_dtype, + self.atom_layout_mnk, + tiler_mn=(64, self.tile_shape_mnk[1]), + ) + mma_inst_shape_k = cute.size(self.tiled_mma.shape_mnk, mode=[2]) + mma_inst_tile_k = 4 + self.tile_shape_mnk = ( + self.tile_shape_mnk[0], + self.tile_shape_mnk[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + + self.cta_layout_mnk = cute.make_layout((*self.cluster_shape_mn, 1)) + self.num_mcast_ctas_a = self.cluster_shape_mn[1] + self.num_mcast_ctas_b = self.cluster_shape_mn[0] + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + + is_cooperative = self.atom_layout_mnk == (2, 1, 1) + self.epi_tile = self._sm90_compute_tile_shape_or_override( + self.tile_shape_mnk, self.c_dtype, is_cooperative=is_cooperative + ) + + # Compute stage before compute smem layout + self.ab_stage, self.epi_stage = self._compute_stages( + self.tile_shape_mnk, + self.a_dtype, + self.b_dtype, + self.smem_capacity, + self.occupancy, + ) + + ( + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.epi_smem_layout_staged, + ) = self._make_smem_layouts( + self.tile_shape_mnk, + self.epi_tile, + self.a_dtype, + self.a_layout, + self.b_dtype, + self.b_layout, + self.ab_stage, + self.c_dtype, + self.c_layout, + self.epi_stage, + ) + + @cute.jit + def __call__( + self, + a: cute.Tensor, + b: cute.Tensor, + c: cute.Tensor, + stream: cuda.CUstream, + ): + """Execute the GEMM operation in steps: + - Setup static attributes + - Setup TMA load/store atoms and tensors + - Compute grid size + - Define shared storage for kernel + - Launch the kernel synchronously + + :param a: Input tensor A + :type a: cute.Tensor + :param b: Input tensor B + :type b: cute.Tensor + :param c: Output tensor C + :type c: cute.Tensor + :param stream: CUDA stream for asynchronous execution + :type stream: cuda.CUstream + """ + + # setup static attributes before smem/grid/tma computation + self.a_dtype = a.element_type + self.b_dtype = b.element_type + self.c_dtype = c.element_type + self.a_layout = utils.LayoutEnum.from_tensor(a) + self.b_layout = utils.LayoutEnum.from_tensor(b) + self.c_layout = utils.LayoutEnum.from_tensor(c) + + if cutlass.const_expr( + self.a_dtype.width == 16 and self.a_dtype != self.b_dtype + ): + raise TypeError(f"Type mismatch: {self.a_dtype} != {self.b_dtype}") + if cutlass.const_expr(self.a_dtype.width != self.b_dtype.width): + raise TypeError( + f"Type width mismatch: {self.a_dtype.width} != {self.b_dtype.width}" + ) + if cutlass.const_expr(self.a_dtype.width != 16 and self.a_dtype.width != 8): + raise TypeError(f"a_dtype should be float16 or float8") + + self._setup_attributes() + + tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors( + a, + self.a_smem_layout_staged, + (self.tile_shape_mnk[0], self.tile_shape_mnk[2]), + self.cluster_shape_mn[1], + ) + + tma_atom_b, tma_tensor_b = self._make_tma_atoms_and_tensors( + b, + self.b_smem_layout_staged, + (self.tile_shape_mnk[1], self.tile_shape_mnk[2]), + self.cluster_shape_mn[0], + ) + + tma_atom_c, tma_tensor_c = self._make_tma_store_atoms_and_tensors( + c, + self.epi_smem_layout_staged, + self.epi_tile, + ) + + grid = self._compute_grid(c, self.tile_shape_mnk, self.cluster_shape_mn) + + @cute.struct + class SharedStorage: + mainloop_pipeline_array_ptr: cute.struct.MemRange[ + cutlass.Int64, self.ab_stage * 2 + ] + sA: cute.struct.Align[ + cute.struct.MemRange[ + self.a_dtype, cute.cosize(self.a_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + sB: cute.struct.Align[ + cute.struct.MemRange[ + self.b_dtype, cute.cosize(self.b_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_c, + tma_tensor_c, + self.tiled_mma, + self.cta_layout_mnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.epi_smem_layout_staged, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + stream=stream, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_c: cute.CopyAtom, + mC_mnl: cute.Tensor, + tiled_mma: cute.TiledMma, + cta_layout_mnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + epi_smem_layout_staged: cute.ComposedLayout, + ): + """ + GPU device kernel performing the batched GEMM computation. + + :param tma_atom_a: TMA copy atom for A tensor + :type tma_atom_a: cute.CopyAtom + :param mA_mkl: Input tensor A + :type mA_mkl: cute.Tensor + :param tma_atom_b: TMA copy atom for B tensor + :type tma_atom_b: cute.CopyAtom + :param mB_nkl: Input tensor B + :type mB_nkl: cute.Tensor + :param tma_atom_c: TMA copy atom for C tensor + :type tma_atom_c: cute.CopyAtom + :param mC_mnl: Output tensor C + :type mC_mnl: cute.Tensor + :param tiled_mma: Tiled MMA object + :type tiled_mma: cute.TiledMma + :param cta_layout_mnk: CTA layout + :type cta_layout_mnk: cute.Layout + :param a_smem_layout_staged: Shared memory layout for A + :type a_smem_layout_staged: cute.ComposedLayout + :param b_smem_layout_staged: Shared memory layout for B + :type b_smem_layout_staged: cute.ComposedLayout + :param epi_smem_layout_staged: Shared memory layout for epilogue + :type epi_smem_layout_staged: cute.ComposedLayout + """ + + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + # ///////////////////////////////////////////////////////////////////////////// + # Prefetch Tma desc + # ///////////////////////////////////////////////////////////////////////////// + if warp_idx == 0: + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_a) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_b) + + # /////////////////////////////////////////////////////////////////////////////// + # Get cta/warp/thread idx + # /////////////////////////////////////////////////////////////////////////////// + bidx, bidy, bidz = cute.arch.block_idx() + tidx, _, _ = cute.arch.thread_idx() + + cidx, cidy, _ = cute.arch.cluster_idx() + cdimx, cdimy, _ = cute.arch.cluster_dim() + cluster_id = cidx + cdimx * cidy + + # CTA Swizzle to promote L2 data reuse + group_size_m = 8 + s_shape = ( + (group_size_m, cdimx // group_size_m), + cdimy, + ) + s_stride = ((1, cdimy * group_size_m), group_size_m) + s_layout = cute.make_layout(s_shape, stride=s_stride) + num_reg_cids = cute.size(s_shape) + cid_m, cid_n = s_layout.get_flat_coord(cluster_id % num_reg_cids) + + # Deal with the tail part + if cluster_id >= num_reg_cids: + tail_size_m = cdimx % group_size_m + tail_layout = cute.make_layout( + (tail_size_m, cdimy), stride=(1, tail_size_m) + ) + tail_cid = cluster_id - num_reg_cids + tail_cid_m, tail_cid_n = tail_layout.get_flat_coord(tail_cid) + cid_m = cute.size(s_shape, mode=[0]) + tail_cid_m + cid_n = tail_cid_n + + # Get the pid from cluster id + bidx_in_cluster = cute.arch.block_in_cluster_idx() + pid_m = cid_m * self.cluster_shape_mn[0] + bidx_in_cluster[0] + pid_n = cid_n * self.cluster_shape_mn[1] + bidx_in_cluster[1] + + tile_coord_mnkl = (pid_m, pid_n, None, bidz) + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + cluster_coord_mnk = cta_layout_mnk.get_flat_coord(cta_rank_in_cluster) + + # /////////////////////////////////////////////////////////////////////////////// + # Get mcast mask + # /////////////////////////////////////////////////////////////////////////////// + a_mcast_mask = cute.make_layout_image_mask( + cta_layout_mnk, cluster_coord_mnk, mode=1 + ) + b_mcast_mask = cute.make_layout_image_mask( + cta_layout_mnk, cluster_coord_mnk, mode=0 + ) + + a_mcast_mask = a_mcast_mask if self.is_a_mcast else 0 + b_mcast_mask = b_mcast_mask if self.is_b_mcast else 0 + a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, 0)) + b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, 0)) + tma_copy_bytes = cute.size_in_bytes( + self.a_dtype, a_smem_layout + ) + cute.size_in_bytes(self.b_dtype, b_smem_layout) + + # ///////////////////////////////////////////////////////////////////////////// + # Alloc and init AB full/empty + ACC full mbar (pipeline) + # ///////////////////////////////////////////////////////////////////////////// + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + # mbar arrays + mainloop_pipeline_array_ptr = storage.mainloop_pipeline_array_ptr.data_ptr() + + # Threads/warps participating in this pipeline + mainloop_pipeline_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread + ) + # Each warp will constribute to the arrive count with the number of mcast size + mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + num_warps = self.threads_per_cta // 32 + consumer_arrive_cnt = mcast_size * num_warps + mainloop_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, consumer_arrive_cnt + ) + + cta_layout_vmnk = cute.make_layout((1, *cta_layout_mnk.shape)) + mainloop_pipeline = pipeline.PipelineTmaAsync.create( + barrier_storage=mainloop_pipeline_array_ptr, + num_stages=self.ab_stage, + producer_group=mainloop_pipeline_producer_group, + consumer_group=mainloop_pipeline_consumer_group, + tx_count=tma_copy_bytes, + cta_layout_vmnk=cta_layout_vmnk, + ) + + # Cluster arrive after barrier init + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_arrive_relaxed() + + # /////////////////////////////////////////////////////////////////////////////// + # Generate smem tensor A/B + # /////////////////////////////////////////////////////////////////////////////// + sA = storage.sA.get_tensor( + a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner + ) + sB = storage.sB.get_tensor( + b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner + ) + sC_ptr = cute.recast_ptr( + sA.iterator, epi_smem_layout_staged.inner, dtype=self.c_dtype + ) + sC = cute.make_tensor(sC_ptr, epi_smem_layout_staged.outer) + + # /////////////////////////////////////////////////////////////////////////////// + # Local_tile partition global tensors + # /////////////////////////////////////////////////////////////////////////////// + # (bM, bK, RestK) + gA_mkl = cute.local_tile( + mA_mkl, self.tile_shape_mnk, tile_coord_mnkl, proj=(1, None, 1) + ) + # (bN, bK, RestK) + gB_nkl = cute.local_tile( + mB_nkl, self.tile_shape_mnk, tile_coord_mnkl, proj=(None, 1, 1) + ) + # (bM, bN) + gC_mnl = cute.local_tile( + mC_mnl, self.tile_shape_mnk, tile_coord_mnkl, proj=(1, 1, None) + ) + + # ////////////////////////////////////////////////////////////////////////////// + # Partition global tensor for TiledMMA_A/B/C + # ////////////////////////////////////////////////////////////////////////////// + warp_group_idx = cute.arch.make_warp_uniform( + tidx // self.num_threads_per_warp_group + ) + warp_group_thread_layout = cute.make_layout( + self.mma_warp_groups, stride=self.num_threads_per_warp_group + ) + thr_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)) + + tCgC = thr_mma.partition_C(gC_mnl) + + # ////////////////////////////////////////////////////////////////////////////// + # Partition shared tensor for TMA load A/B + # ////////////////////////////////////////////////////////////////////////////// + # TMA load A partition_S/D + a_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (0, None, 0)).shape) + a_cta_crd = cluster_coord_mnk[1] + sA_for_tma_partition = cute.group_modes(sA, 0, 2) + gA_for_tma_partition = cute.group_modes(gA_mkl, 0, 2) + tAsA, tAgA_mkl = cute.nvgpu.cpasync.tma_partition( + tma_atom_a, + a_cta_crd, + a_cta_layout, + sA_for_tma_partition, + gA_for_tma_partition, + ) + + # TMA load B partition_S/D + b_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (None, 0, 0)).shape) + b_cta_crd = cluster_coord_mnk[0] + sB_for_tma_partition = cute.group_modes(sB, 0, 2) + gB_for_tma_partition = cute.group_modes(gB_nkl, 0, 2) + tBsB, tBgB_nkl = cute.nvgpu.cpasync.tma_partition( + tma_atom_b, + b_cta_crd, + b_cta_layout, + sB_for_tma_partition, + gB_for_tma_partition, + ) + + # ////////////////////////////////////////////////////////////////////////////// + # Make fragments + # ////////////////////////////////////////////////////////////////////////////// + tCsA = thr_mma.partition_A(sA) + tCsB = thr_mma.partition_B(sB) + tCrA = tiled_mma.make_fragment_A(tCsA) + tCrB = tiled_mma.make_fragment_B(tCsB) + + acc_shape = tCgC.shape + accumulators = cute.make_fragment(acc_shape, self.acc_dtype) + + # /////////////////////////////////////////////////////////////////////////////// + # Cluster wait + # /////////////////////////////////////////////////////////////////////////////// + # cluster wait for barrier init + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_wait() + else: + cute.arch.sync_threads() + # ///////////////////////////////////////////////////////////////////////////// + # Prefetch + # ///////////////////////////////////////////////////////////////////////////// + k_tile_cnt = cute.size(gA_mkl, mode=[2]) + prefetch_k_tile_cnt = cutlass.max(cutlass.min(self.ab_stage, k_tile_cnt), 0) + + mainloop_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.ab_stage + ) + if warp_idx == 0: + # ///////////////////////////////////////////////////////////////////////////// + # Prefetch TMA load + # ///////////////////////////////////////////////////////////////////////////// + for prefetch_idx in cutlass.range(prefetch_k_tile_cnt, unroll=1): + # ///////////////////////////////////////////////////////////////////////////// + # Wait for A/B buffers to be empty before loading into them + # Also sets the transaction barrier for the A/B buffers + # ///////////////////////////////////////////////////////////////////////////// + mainloop_pipeline.producer_acquire(mainloop_producer_state) + # ///////////////////////////////////////////////////////////////////////////// + # Slice to global/shared memref to current k_tile + # ///////////////////////////////////////////////////////////////////////////// + tAgA_k = tAgA_mkl[(None, mainloop_producer_state.count)] + tAsA_pipe = tAsA[(None, mainloop_producer_state.index)] + + tBgB_k = tBgB_nkl[(None, mainloop_producer_state.count)] + tBsB_pipe = tBsB[(None, mainloop_producer_state.index)] + + # ///////////////////////////////////////////////////////////////////////////// + # TMA load A/B + # ///////////////////////////////////////////////////////////////////////////// + cute.copy( + tma_atom_a, + tAgA_k, + tAsA_pipe, + tma_bar_ptr=mainloop_pipeline.producer_get_barrier( + mainloop_producer_state + ), + mcast_mask=a_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB_k, + tBsB_pipe, + tma_bar_ptr=mainloop_pipeline.producer_get_barrier( + mainloop_producer_state + ), + mcast_mask=b_mcast_mask, + ) + # Mainloop pipeline's producer commit is a NOP + mainloop_pipeline.producer_commit(mainloop_producer_state) + mainloop_producer_state.advance() + + # ///////////////////////////////////////////////////////////////////////////// + # Prologue MMAs + # ///////////////////////////////////////////////////////////////////////////// + k_pipe_mmas = 1 + + mainloop_consumer_read_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.ab_stage + ) + mainloop_consumer_release_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.ab_stage + ) + + peek_ab_full_status = cutlass.Boolean(1) + if mainloop_consumer_read_state.count < k_tile_cnt: + peek_ab_full_status = mainloop_pipeline.consumer_try_wait( + mainloop_consumer_read_state + ) + + tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, False) + num_k_blocks = cute.size(tCrA, mode=[2]) + for k_tile in cutlass.range_constexpr(k_pipe_mmas): + # Wait for A/B buffer to be ready + mainloop_pipeline.consumer_wait( + mainloop_consumer_read_state, peek_ab_full_status + ) + + cute.nvgpu.warpgroup.fence() + for k_block_idx in cutlass.range(num_k_blocks, unroll_full=True): + k_block_coord = ( + None, + None, + k_block_idx, + mainloop_consumer_read_state.index, + ) + tCrA_1phase = tCrA[k_block_coord] + tCrB_1phase = tCrB[k_block_coord] + + cute.gemm( + tiled_mma, + accumulators, + tCrA_1phase, + tCrB_1phase, + accumulators, + ) + tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True) + + cute.nvgpu.warpgroup.commit_group() + mainloop_consumer_read_state.advance() + peek_ab_full_status = cutlass.Boolean(1) + if mainloop_consumer_read_state.count < k_tile_cnt: + peek_ab_full_status = mainloop_pipeline.consumer_try_wait( + mainloop_consumer_read_state + ) + + # ///////////////////////////////////////////////////////////////////////////// + # MAINLOOP + # ///////////////////////////////////////////////////////////////////////////// + for k_tile in cutlass.range(k_pipe_mmas, k_tile_cnt, 1, unroll=1): + # ///////////////////////////////////////////////////////////////////////////// + # Wait for TMA copies to complete + # ///////////////////////////////////////////////////////////////////////////// + mainloop_pipeline.consumer_wait( + mainloop_consumer_read_state, peek_ab_full_status + ) + # ///////////////////////////////////////////////////////////////////////////// + # WGMMA + # ///////////////////////////////////////////////////////////////////////////// + cute.nvgpu.warpgroup.fence() + for k_block_idx in cutlass.range(num_k_blocks, unroll_full=True): + k_block_coord = ( + None, + None, + k_block_idx, + mainloop_consumer_read_state.index, + ) + tCrA_1phase = tCrA[k_block_coord] + tCrB_1phase = tCrB[k_block_coord] + + cute.gemm( + tiled_mma, + accumulators, + tCrA_1phase, + tCrB_1phase, + accumulators, + ) + + cute.nvgpu.warpgroup.commit_group() + # Wait on the wgmma barrier for previous k_pipe_mmas wgmmas to complete + cute.nvgpu.warpgroup.wait_group(k_pipe_mmas) + + mainloop_pipeline.consumer_release(mainloop_consumer_release_state) + + mainloop_consumer_read_state.advance() + mainloop_consumer_release_state.advance() + + peek_ab_full_status = cutlass.Boolean(1) + if mainloop_consumer_read_state.count < k_tile_cnt: + peek_ab_full_status = mainloop_pipeline.consumer_try_wait( + mainloop_consumer_read_state + ) + # ///////////////////////////////////////////////////////////////////////////// + # TMA load + # ///////////////////////////////////////////////////////////////////////////// + if warp_idx == 0 and mainloop_producer_state.count < k_tile_cnt: + # ///////////////////////////////////////////////////////////////////////////// + # Wait for A/B buffers to be empty before loading into them + # Also sets the transaction barrier for the A/B buffers + # ///////////////////////////////////////////////////////////////////////////// + mainloop_pipeline.producer_acquire(mainloop_producer_state) + + # ///////////////////////////////////////////////////////////////////////////// + # Slice to global/shared memref to current k_tile + # ///////////////////////////////////////////////////////////////////////////// + tAgA_k = tAgA_mkl[(None, mainloop_producer_state.count)] + tAsA_pipe = tAsA[(None, mainloop_producer_state.index)] + + tBgB_k = tBgB_nkl[(None, mainloop_producer_state.count)] + tBsB_pipe = tBsB[(None, mainloop_producer_state.index)] + + # ///////////////////////////////////////////////////////////////////////////// + # TMA load A/B + # ///////////////////////////////////////////////////////////////////////////// + cute.copy( + tma_atom_a, + tAgA_k, + tAsA_pipe, + tma_bar_ptr=mainloop_pipeline.producer_get_barrier( + mainloop_producer_state + ), + mcast_mask=a_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB_k, + tBsB_pipe, + tma_bar_ptr=mainloop_pipeline.producer_get_barrier( + mainloop_producer_state + ), + mcast_mask=b_mcast_mask, + ) + # Mainloop pipeline's producer commit is a NOP + mainloop_pipeline.producer_commit(mainloop_producer_state) + mainloop_producer_state.advance() + + # ///////////////////////////////////////////////////////////////////////////// + # EPILOG + # ///////////////////////////////////////////////////////////////////////////// + cute.nvgpu.warpgroup.wait_group(0) + + if cute.size(self.cluster_shape_mn) > 1: + # Wait for all threads in the cluster to finish, avoid early release of smem + cute.arch.cluster_arrive() + cute.arch.cluster_wait() + else: + # For cluster that has a single thread block, it might have more than one warp groups. + # Wait for all warp groups in the thread block to finish, because smem for tensor A in + # the mainloop is reused in the epilogue. + cute.arch.sync_threads() + + copy_atom_r2s = sm90_utils.sm90_get_smem_store_op( + self.c_layout, + elem_ty_d=self.c_dtype, + elem_ty_acc=self.acc_dtype, + ) + + copy_atom_C = cute.make_copy_atom( + cute.nvgpu.warp.StMatrix8x8x16bOp( + self.c_layout.is_m_major_c(), + 4, + ), + self.c_dtype, + ) + + tiled_copy_C_Atom = cute.make_tiled_copy_C_atom(copy_atom_C, tiled_mma) + + tiled_copy_r2s = cute.make_tiled_copy_S( + copy_atom_r2s, + tiled_copy_C_Atom, + ) + + # (R2S, R2S_M, R2S_N, PIPE_D) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sD = thr_copy_r2s.partition_D(sC) + # (R2S, R2S_M, R2S_N) + tRS_rAcc = tiled_copy_r2s.retile(accumulators) + + # Allocate D registers. + rD_shape = cute.shape(thr_copy_r2s.partition_S(sC)) + tRS_rD_layout = cute.make_layout(rD_shape[:3]) + tRS_rD = cute.make_fragment_like(tRS_rD_layout, self.acc_dtype) + size_tRS_rD = cute.size(tRS_rD) + + sepi_for_tma_partition = cute.group_modes(sC, 0, 2) + tCgC_for_tma_partition = cute.zipped_divide(gC_mnl, self.epi_tile) + + bSG_sD, bSG_gD = cute.nvgpu.cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + sepi_for_tma_partition, + tCgC_for_tma_partition, + ) + + epi_tile_num = cute.size(tCgC_for_tma_partition, mode=[1]) + epi_tile_shape = tCgC_for_tma_partition.shape[1] + epi_tile_layout = cute.make_layout( + epi_tile_shape, stride=(epi_tile_shape[1], 1) + ) + + # Initialize tma store c_pipeline + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, self.threads_per_cta, self.threads_per_cta + ) + c_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.epi_stage, + producer_group=c_producer_group, + ) + + for epi_idx in cutlass.range_constexpr(epi_tile_num): + # Copy from accumulators to D registers + for epi_v in cutlass.range_constexpr(size_tRS_rD): + tRS_rD[epi_v] = tRS_rAcc[epi_idx * size_tRS_rD + epi_v] + + # Type conversion + tRS_rD_out = cute.make_fragment_like(tRS_rD_layout, self.c_dtype) + acc_vec = tRS_rD.load() + tRS_rD_out.store(acc_vec.to(self.c_dtype)) + + # Copy from D registers to shared memory + epi_buffer = epi_idx % cute.size(tRS_sD, mode=[3]) + cute.copy( + tiled_copy_r2s, tRS_rD_out, tRS_sD[(None, None, None, epi_buffer)] + ) + + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + # barrier for sync + cute.arch.barrier() + + gmem_coord = epi_tile_layout.get_hier_coord(epi_idx) + # Copy from shared memory to global memory + if warp_idx == 0: + cute.copy( + tma_atom_c, + bSG_sD[(None, epi_buffer)], + bSG_gD[(None, gmem_coord)], + ) + c_pipeline.producer_commit() + c_pipeline.producer_acquire() + + cute.arch.barrier() + + if warp_idx == 0: + c_pipeline.producer_tail() + + return + + @staticmethod + def _compute_stages( + tile_shape_mnk: tuple[int, int, int], + a_dtype: type[cutlass.Numeric], + b_dtype: type[cutlass.Numeric], + smem_capacity: int, + occupancy: int, + ) -> tuple[int, int]: + """Computes the number of stages for A/B/C operands based on heuristics. + + :param tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type tile_shape_mnk: tuple[int, int, int] + :param a_dtype: Data type of operand A. + :type a_dtype: type[cutlass.Numeric] + :param b_dtype: Data type of operand B. + :type b_dtype: type[cutlass.Numeric] + :param smem_capacity: Total available shared memory capacity in bytes. + :type smem_capacity: int + :param occupancy: Target number of CTAs per SM (occupancy). + :type occupancy: int + + :return: A tuple containing the computed number of stages for: + (A/B operand stages, epilogue stages) + :rtype: tuple[int, int] + """ + + epi_stage = 4 + # epi_smem will reuse smem ab. + epi_bytes = 0 + + a_shape = cute.slice_(tile_shape_mnk, (None, 0, None)) + b_shape = cute.slice_(tile_shape_mnk, (0, None, None)) + ab_bytes_per_stage = ( + cute.size(a_shape) * a_dtype.width // 8 + + cute.size(b_shape) * b_dtype.width // 8 + ) + mbar_helpers_bytes = 1024 + + ab_stage = ( + smem_capacity // occupancy - mbar_helpers_bytes - epi_bytes + ) // ab_bytes_per_stage + return ab_stage, epi_stage + + @staticmethod + def _sm90_compute_tile_shape_or_override( + tile_shape_mnk: tuple[int, int, int], + element_type: type[cutlass.Numeric], + is_cooperative: bool = False, + epi_tile_override: tuple[int, int] | None = None, + ) -> tuple[int, int]: + """Compute the epilogue tile shape or use override if provided. + + :param tile_shape_mnk: CTA tile shape (M,N,K) + :type tile_shape_mnk: Tuple[int, int, int] + :param element_type: Data type of elements + :type element_type: type[cutlass.Numeric] + :param is_cooperative: Whether to use cooperative approach + :type is_cooperative: bool + :param epi_tile_override: Optional override for epilogue tile shape + :type epi_tile_override: Tuple[int, int] or None + + :return: Computed epilogue tile shape + :rtype: Tuple[int, int] + """ + if epi_tile_override is not None: + return epi_tile_override + if is_cooperative: + tile_m = min(128, cute.size(tile_shape_mnk, mode=[0])) + tile_n = min(32, cute.size(tile_shape_mnk, mode=[1])) + return (tile_m, tile_n) + else: + n_perf = 64 if element_type.width == 8 else 32 + tile_m = min(64, cute.size(tile_shape_mnk, mode=[0])) + tile_n = min(n_perf, cute.size(tile_shape_mnk, mode=[1])) + return (tile_m, tile_n) + + @staticmethod + def _make_smem_layouts( + tile_shape_mnk: tuple[int, int, int], + epi_tile: tuple[int, int], + a_dtype: type[cutlass.Numeric], + a_layout: utils.LayoutEnum, + b_dtype: type[cutlass.Numeric], + b_layout: utils.LayoutEnum, + ab_stage: int, + c_dtype: type[cutlass.Numeric], + c_layout: utils.LayoutEnum, + epi_stage: int, + ) -> tuple[cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout]: + """Create shared memory layouts for A, B, and C tensors. + + :param tile_shape_mnk: CTA tile shape (M,N,K) + :type tile_shape_mnk: Tuple[int, int, int] + :param epi_tile: Epilogue tile shape + :type epi_tile: Tuple[int, int] + :param a_dtype: Data type for matrix A + :type a_dtype: type[cutlass.Numeric] + :param a_layout: Layout enum for matrix A + :type a_layout: utils.LayoutEnum + :param b_dtype: Data type for matrix B + :type b_dtype: type[cutlass.Numeric] + :param b_layout: Layout enum for matrix B + :type b_layout: utils.LayoutEnum + :param ab_stage: Number of stages for A/B tensors + :type ab_stage: int + :param c_dtype: Data type for output matrix C + :type c_dtype: type[cutlass.Numeric] + :param c_layout: Layout enum for the output matrix C + :type c_layout: utils.LayoutEnum + :param epi_stage: Number of epilogue stages + :type epi_stage: int + + :return: Tuple of shared memory layouts for A, B, and C + :rtype: Tuple[cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout] + """ + a_smem_shape = cute.slice_(tile_shape_mnk, (None, 0, None)) + + a_is_k_major = ( + a_layout.sm90_mma_major_mode() == cute.nvgpu.warpgroup.OperandMajorMode.K + ) + b_is_k_major = ( + b_layout.sm90_mma_major_mode() == cute.nvgpu.warpgroup.OperandMajorMode.K + ) + a_major_mode_size = tile_shape_mnk[2 if a_is_k_major else 0] + a_smem_layout_atom = cute.nvgpu.warpgroup.make_smem_layout_atom( + sm90_utils.get_smem_layout_atom( + a_layout, + a_dtype, + a_major_mode_size, + ), + a_dtype, + ) + a_smem_layout_staged = cute.tile_to_shape( + a_smem_layout_atom, + cute.append(a_smem_shape, ab_stage), + order=(0, 1, 2) if a_is_k_major else (1, 0, 2), + ) + + b_smem_shape = cute.slice_(tile_shape_mnk, (0, None, None)) + + b_major_mode_size = tile_shape_mnk[2 if b_is_k_major else 1] + b_smem_layout_atom = cute.nvgpu.warpgroup.make_smem_layout_atom( + sm90_utils.get_smem_layout_atom( + b_layout, + b_dtype, + b_major_mode_size, + ), + b_dtype, + ) + b_smem_layout_staged = cute.tile_to_shape( + b_smem_layout_atom, + cute.append(b_smem_shape, ab_stage), + order=(0, 1, 2) if b_is_k_major else (1, 0, 2), + ) + + c_smem_shape = epi_tile + c_major_mode_size = epi_tile[1] if c_layout.is_n_major_c() else epi_tile[0] + c_smem_layout_atom = cute.nvgpu.warpgroup.make_smem_layout_atom( + sm90_utils.get_smem_layout_atom( + c_layout, + c_dtype, + c_major_mode_size, + ), + c_dtype, + ) + epi_smem_layout_staged = cute.tile_to_shape( + c_smem_layout_atom, + cute.append(c_smem_shape, epi_stage), + order=(1, 0, 2) if c_layout.is_m_major_c() else (0, 1, 2), + ) + + return a_smem_layout_staged, b_smem_layout_staged, epi_smem_layout_staged + + @staticmethod + def _compute_grid( + c: cute.Tensor, + tile_shape_mnk: tuple[int, int, int], + cluster_shape_mn: tuple[int, int], + ) -> tuple[int, int, int]: + """Compute grid shape for the output tensor C. + + :param c: The output tensor C + :type c: cute.Tensor + :param tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type tile_shape_mnk: tuple[int, int, int] + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. + :type cluster_shape_mn: tuple[int, int] + + :return: Grid shape for kernel launch. + :rtype: tuple[int, int, int] + """ + + c_shape = (tile_shape_mnk[0], tile_shape_mnk[1]) + gc = cute.zipped_divide(c, tiler=c_shape) + cluster_shape_mnl = (*cluster_shape_mn, 1) + clusters = cute.ceil_div(cute.get(gc.layout, mode=[1]).shape, cluster_shape_mnl) + grid = tuple(x * y for x, y in zip(clusters, cluster_shape_mnl)) + return grid + + @staticmethod + def _make_tma_store_atoms_and_tensors( + tensor_c: cute.Tensor, + epi_smem_layout_staged: cute.ComposedLayout, + epi_tile: tuple[int, int], + ) -> tuple[cute.CopyAtom, cute.Tensor]: + """Create TMA atoms and tensors for C tensor storage. + + :param tensor_c: Output tensor C + :type tensor_c: cute.Tensor + :param epi_smem_layout_staged: Shared memory layout for epilogue + :type epi_smem_layout_staged: cute.ComposedLayout + :param epi_tile: Epilogue tile shape + :type epi_tile: Tuple[int, int] + + :return: TMA atom and tensor for C + :rtype: Tuple[cute.CopyAtom, cute.Tensor] + """ + epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0)) + c_cta_v_layout = cute.composition( + cute.make_identity_layout(tensor_c.shape), epi_tile + ) + tma_atom_c, tma_tensor_c = cute.nvgpu.cpasync.make_tiled_tma_atom( + cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp(), + tensor_c, + epi_smem_layout, + c_cta_v_layout, + ) + + return tma_atom_c, tma_tensor_c + + @staticmethod + def _make_tma_atoms_and_tensors( + tensor: cute.Tensor, + smem_layout_staged: cute.ComposedLayout, + smem_tile: tuple[int, int], + mcast_dim: int, + ) -> tuple[cute.CopyAtom, cute.Tensor]: + """Create TMA atoms and tensors for input tensors. + + :param tensor: Input tensor (A or B) + :type tensor: cute.Tensor + :param smem_layout_staged: Shared memory layout for the tensor + :type smem_layout_staged: cute.ComposedLayout + :param smem_tile: Shared memory tile shape + :type smem_tile: Tuple[int, int] + :param mcast_dim: Multicast dimension + :type mcast_dim: int + + :return: TMA atom and tensor + :rtype: Tuple[cute.CopyAtom, cute.Tensor] + """ + op = ( + cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp() + if mcast_dim == 1 + else cute.nvgpu.cpasync.CopyBulkTensorTileG2SMulticastOp() + ) + + smem_layout = cute.slice_(smem_layout_staged, (None, None, 0)) + tma_atom, tma_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom( + op, + tensor, + smem_layout, + smem_tile, + num_multicast=mcast_dim, + ) + return tma_atom, tma_tensor + + @staticmethod + def is_valid_dtypes( + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + ) -> bool: + """ + Check if the dtypes are valid + + :param a_dtype: The data type of tensor A + :type a_dtype: Type[cutlass.Numeric] + :param b_dtype: The data type of tensor B + :type b_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: major mode of tensor A + :type a_major: str + :param b_major: major mode of tensor B + :type b_major: str + + :return: True if the dtypes are valid, False otherwise + :rtype: bool + """ + is_valid = True + # tested a_dtype + if a_dtype not in { + cutlass.Float16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: + is_valid = False + # tested b_dtype + if b_dtype not in { + cutlass.Float16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: + is_valid = False + # tested acc_dtype + if acc_dtype not in {cutlass.Float32, cutlass.Float16}: + is_valid = False + # tested c_dtype + if c_dtype not in { + cutlass.Float32, + cutlass.Float16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: + is_valid = False + # make sure a_dtype == b_dtype for Float16 + if a_dtype.width == 16 and a_dtype != b_dtype: + is_valid = False + # make sure a_dtype.width == b_dtype.width (i.e, Float8E4M3FN or Float8E5M2) + if a_dtype.width != b_dtype.width: + is_valid = False + + # for Float8 types, this implementation only supports k-major layout + if (a_dtype.width == 8 and a_major != "k") or ( + b_dtype.width == 8 and b_major != "k" + ): + is_valid = False + + return is_valid + + +def run( + mnkl: Tuple[int, int, int, int], + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + tile_shape_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + tolerance: float, + warmup_iterations: int, + iterations: int, + skip_ref_check: bool, + use_cold_l2: bool = False, + **kwargs, +): + """ + Prepare A/B/C tensors, launch GPU kernel, and reference checking. + + :param mnkl: Problem size (M, N, K, L) + :type mnkl: Tuple[int, int, int, int] + :param a_dtype: Data type for input tensor A + :type a_dtype: Type[cutlass.Numeric] + :param b_dtype: Data type for input tensor B + :type b_dtype: Type[cutlass.Numeric] + :param c_dtype: Data type for output tensor C + :type c_dtype: Type[cutlass.Numeric] + :param acc_dtype: Data type for accumulation during matrix multiplication + :type acc_dtype: Type[cutlass.Numeric] + :param a_major/b_major/c_major: Memory layout of tensor A/B/C + :type a_major/b_major/c_major: str + :param tile_shape_mn: CTA tile shape (M, N) + :type tile_shape_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster shape (M, N) + :type cluster_shape_mn: Tuple[int, int] + :param tolerance: Tolerance value for reference validation comparison + :type tolerance: float + :param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0 + :type warmup_iterations: int, optional + :param iterations: Number of benchmark iterations to run, defaults to 1 + :type iterations: int, optional + :param skip_ref_check: Whether to skip reference result validation, defaults to False + :type skip_ref_check: bool, optional + :param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False + :type use_cold_l2: bool, optional + :return: Execution time of the GEMM kernel in microseconds + :rtype: float + """ + + print(f"Running Hopper Dense GEMM with:") + print(f"mnkl: {mnkl}") + print( + f"A dtype: {a_dtype}, B dtype: {b_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}" + ) + print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}") + print(f"Tile Shape: {tile_shape_mn}, Cluster Shape: {cluster_shape_mn}") + print(f"Tolerance: {tolerance}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + print(f"Use cold L2: {use_cold_l2}") + + # Unpack parameters + m, n, k, l = mnkl + + # Skip unsupported types + if not HopperWgmmaGemmKernel.is_valid_dtypes( + a_dtype, b_dtype, acc_dtype, c_dtype, a_major, b_major + ): + raise TypeError( + f"Skipping due to unsupported combination of types and majors: {a_dtype}, {b_dtype}, {acc_dtype}, {c_dtype}, {a_major=}, {b_major=}" + ) + + # Prepare pytorch tensors: A, B (random from 0 to 2) and C (all zero) + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + torch.manual_seed(1111) + + # Create and permute tensor A/B/C + def create_and_permute_tensor( + l, mode0, mode1, is_mode0_major, dtype, is_dynamic_layout=True + ): + # is_mode0_major: (l, mode1, mode0) -> (mode0, mode1, l) + # else : (l, mode0, mode1) -> (mode0, mode1, l) + shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1) + permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0) + is_unsigned = dtype in {cutlass.Uint8} + # Temporarily use uint8 as torch does not support fp8 type + torch_dtype = ( + cutlass_torch.dtype(dtype) + if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN} + else torch.uint8 + ) + + # Create dtype torch tensor (cpu) + torch_tensor_cpu = cutlass.torch.create_and_permute_torch_tensor( + shape, + torch_dtype, + permute_order=permute_order, + init_type=cutlass.torch.TensorInitType.RANDOM, + init_config=cutlass.torch.RandomInitConfig( + min_val=0 if is_unsigned else -2, max_val=4 if is_unsigned else 2 + ), + ) + # Create dtype torch tensor (gpu) + torch_tensor = torch_tensor_cpu.cuda() + + # Create f32 torch tensor (cpu) + f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32) + + # Create dtype cute tensor (gpu) + cute_tensor = from_dlpack(torch_tensor, assumed_align=16) + cute_tensor.element_type = dtype + if is_dynamic_layout: + cute_tensor = cute_tensor.mark_layout_dynamic( + leading_dim=(0 if is_mode0_major else 1) + ) + cute_tensor = cutlass.torch.convert_cute_tensor( + f32_torch_tensor, + cute_tensor, + dtype, + is_dynamic_layout=is_dynamic_layout, + ) + + return f32_torch_tensor, cute_tensor, torch_tensor + + a, mA, a_torch = create_and_permute_tensor(l, m, k, a_major == "m", a_dtype) + b, mB, b_torch = create_and_permute_tensor(l, n, k, b_major == "n", b_dtype) + c, mC, c_torch = create_and_permute_tensor(l, m, n, c_major == "m", c_dtype) + + gemm = HopperWgmmaGemmKernel(acc_dtype, tile_shape_mn, cluster_shape_mn) + + torch_stream = torch.cuda.Stream() + stream = cuda.CUstream(torch_stream.cuda_stream) + # compile gemm kernel + compiled_gemm = cute.compile(gemm, mA, mB, mC, stream) + + if not skip_ref_check: + # execution + compiled_gemm(mA, mB, mC, stream) + + torch.cuda.synchronize() + + # Ref check + ref = (torch.einsum("mkl,nkl->mnl", a, b)).cpu() + + if c_dtype in (cutlass.Float8E4M3FN, cutlass.Float8E5M2): + # m major: (l, n, m) -> (m, n, l) + # n major: (l, m, n) -> (m, n, l) + permute_order = (1, 2, 0) if c_major == "n" else (2, 1, 0) + shape = (l, m, n) if c_major == "n" else (l, n, m) + f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor( + shape, + torch.uint8, + permute_order=permute_order, + init_type=cutlass_torch.TensorInitType.SKIP, + ).cuda() + # Create dtype cute tensor (gpu) + ref_c_tensor = from_dlpack( + f8_torch_tensor, assumed_align=16 + ).mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0)) + ref_c_tensor.element_type = c_dtype + ref_c_tensor = cutlass_torch.convert_cute_tensor( + ref, + ref_c_tensor, + c_dtype, + is_dynamic_layout=True, + ) + ref_c = f8_torch_tensor.cpu() + else: + ref_c = ref.to(cutlass_torch.dtype(c_dtype)) + + torch.testing.assert_close(c_torch.cpu(), ref_c, atol=tolerance, rtol=1e-03) + + def generate_tensors(): + _, mA_workspace, _ = create_and_permute_tensor(l, m, k, a_major == "m", a_dtype) + _, mB_workspace, _ = create_and_permute_tensor(l, n, k, b_major == "n", b_dtype) + _, mC_workspace, _ = create_and_permute_tensor(l, m, n, c_major == "m", c_dtype) + return testing.JitArguments(mA_workspace, mB_workspace, mC_workspace, stream) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + a_torch.numel() * a_torch.element_size() + + b_torch.numel() * b_torch.element_size() + + c_torch.numel() * c_torch.element_size() + ) + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + exec_time = testing.benchmark( + compiled_gemm, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + return exec_time # Return execution time in microseconds + + +if __name__ == "__main__": + args = parse_arguments() + run( + args.mnkl, + args.a_dtype, + args.b_dtype, + args.c_dtype, + args.acc_dtype, + args.a_major, + args.b_major, + args.c_major, + args.tile_shape_mn, + args.cluster_shape_mn, + args.tolerance, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.use_cold_l2, + ) + print("PASS") diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/axpby.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/axpby.hpp new file mode 100644 index 0000000000000000000000000000000000000000..60d5b468e2f28dd917f2776a88451385aa237d1c --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/axpby.hpp @@ -0,0 +1,94 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +namespace cute +{ + +// +// Accept mutable temporaries +// +template > +CUTE_HOST_DEVICE +void +axpby(Alpha const& alpha, + Tensor const& x, + Beta const& beta, + Tensor && y, + PrdTensor const& p = {}) +{ + return axpby(alpha, x, beta, y, p); +} + +// +// AXPBY +// +template > +CUTE_HOST_DEVICE +void +axpby(Alpha const& alpha, + Tensor const& x, + Beta const& beta, + Tensor & y, + PrdTensor const& p = {}) +{ + auto isBetaZero = [&] () { + if constexpr (is_complex::value) { + return beta.real() == Int<0>{} && beta.imag() == Int<0>{}; + } + else { + return beta == Int<0>{}; + } + + CUTE_GCC_UNREACHABLE; + } (); + + CUTE_UNROLL + for (int i = 0; i < size(x); ++i) { + if (p(i)) { + y(i) = (isBetaZero ? alpha * x(i) : alpha * x(i) + beta * y(i)); + } + } +} + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/clear.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/clear.hpp new file mode 100644 index 0000000000000000000000000000000000000000..225c46ec878215ccfae57504021df5352fbf10dd --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/clear.hpp @@ -0,0 +1,64 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::Tensor +#include // cute::fill + +namespace cute +{ + +// +// Accept mutable temporaries +// +template +CUTE_HOST_DEVICE +void +clear(Tensor&& tensor) +{ + return clear(tensor); +} + +// +// Set elements to zero +// +template +CUTE_HOST_DEVICE +void +clear(Tensor& tensor) +{ + using T = typename Tensor::value_type; + + fill(tensor, T{}); +} + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/cooperative_copy.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/cooperative_copy.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2653916e3ce3eff3a26c44f77c1d67f8f92d0282 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/cooperative_copy.hpp @@ -0,0 +1,338 @@ +/*************************************************************************************************** +* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +* SPDX-License-Identifier: BSD-3-Clause +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* 1. Redistributions of source code must retain the above copyright notice, this +* list of conditions and the following disclaimer. +* +* 2. Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* +* 3. Neither the name of the copyright holder nor the names of its +* contributors may be used to endorse or promote products derived from +* this software without specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +* +**************************************************************************************************/ +#pragma once + +#include +#include +#include // cute::logical_divide +#include // cute::Swizzle +#include // cute::get_nonswizzle_portion +#include // cute::Tensor +#include +#include + +namespace cute +{ + +template +CUTE_HOST_DEVICE void +naive_cooperative_copy(uint32_t const& tid, + Tensor const& src, + Tensor & dst) +{ + auto N = size(dst); + auto R = N % Int{}; + if (R > 0 && tid < R) { // Likely static condition && Residue in-bounds + dst[tid] = src[tid]; + } + CUTE_UNROLL + for (uint32_t i = uint32_t(R); i < uint32_t(N); i += NumThreads) { // All in-bounds + dst[tid + i] = src[tid + i]; + } +} + +// Accept mutable temporaries +template +CUTE_HOST_DEVICE void +naive_cooperative_copy(uint32_t const& tid, + Tensor const& src, + Tensor && dst) +{ + return naive_cooperative_copy(tid, src, dst); +} + +// A heuristic to determine a "good" permutation of two tensors for later vectorization and thr-assignment +template +CUTE_HOST_DEVICE constexpr +auto +heuristic_permutation(Tensor const& a, + Tensor const& b) +{ + constexpr bool swizzleA = get_swizzle_t::num_bits != 0 or + get_swizzle_t::num_bits != 0; + constexpr bool swizzleB = get_swizzle_t::num_bits != 0 or + get_swizzle_t::num_bits != 0; + auto a_inv = right_inverse(get_nonswizzle_portion(a.layout())); + auto b_inv = right_inverse(get_nonswizzle_portion(b.layout())); + + constexpr uint8_t scoreA = (uint8_t(swizzleA) << 2) | + (uint8_t(is_smem::value) << 1) | + (uint8_t(size(a_inv) > size(b_inv)) << 0); + + constexpr uint8_t scoreB = (uint8_t(swizzleB) << 2) | + (uint8_t(is_smem::value) << 1) | + (uint8_t(size(b_inv) > size(a_inv)) << 0); + + if constexpr (scoreA >= scoreB) { + return a_inv; + } else { + return b_inv; + } +} + +// cooperative_copy(thr_idx, src, dst) +// Use NumThreads to copy Tensor src to Tensor dst with element-wise vectorization up to MaxVecBits. +// @pre 0 <= @a tid < NumThreads +// @pre Tensors @a src and @a dst are aligned up to MaxVecBits. +// That is, pointers and dynamic strides are assumed to be aligned up to MaxVecBits. +// +template +CUTE_HOST_DEVICE +void +cooperative_copy(uint32_t const& tid, + Tensor const& src, + Tensor & dst, + CopyPolicy const& cpy = {}) +{ + // Assumes the shapes are static, can generalize/fallback + CUTE_STATIC_ASSERT_V(is_static{} && is_static{}); + CUTE_STATIC_ASSERT_V(size(src) == size(dst)); + // Assumes the types are the same, can generalize/fallback + static_assert(cute::is_same::value); + static_assert(MaxVecBits == sizeof_bits_v || + MaxVecBits == 8 || MaxVecBits == 16 || MaxVecBits == 32 || MaxVecBits == 64 || MaxVecBits == 128, + "Expected MaxVecBits to be value size or 8 or 16 or 32 or 64 or 128 for alignment and performance."); + // Check that the tensors are likely shared across threads: either gmem or smem + static_assert((is_gmem::value || is_smem::value), + "cooperative_copy expects shared gmem or smem source tensor."); + static_assert((is_gmem::value || is_smem::value), + "cooperative_copy expects shared gmem or smem destination tensor."); + // Precondition on tid in DEBUG + assert(tid < NumThreads); + // Precondition on pointer alignment in DEBUG + assert(is_byte_aligned(raw_pointer_cast(src.data()))); + assert(is_byte_aligned(raw_pointer_cast(dst.data()))); + +#if 0 + if (thread0()) { + print(" "); print("cooperative_copy\n"); + print(" "); print("NumThreads: "); print(NumThreads); print("\n"); + print(" "); print("MaxVecBits: "); print(MaxVecBits); print("\n"); + print(" "); print("src: "); print(src); print("\n"); + print(" "); print("dst: "); print(dst); print("\n"); + } +#ifdef __CUDA_ARCH__ + __syncthreads(); +#endif +#endif + + // The common layout of the two tensors that can be vectorized over elements and threads + // vidx -> coord + auto common_layout = heuristic_permutation(src, dst); + + // Apply + // (V, rest) + Tensor src_a = coalesce(logical_divide(src, common_layout), Shape<_1,_1>{}); + Tensor dst_a = coalesce(logical_divide(dst, common_layout), Shape<_1,_1>{}); + + // + // Determine vectorization of elems and thrs based on src/dst size and number of threads + // NOTE: This heuristic promotes parallelization over vectorization + // + + // The number of elements and number of bits + constexpr int elem_bits = sizeof_bits_v; + constexpr int total_elem = size(SrcLayout{}); + + // The number of elements that can be vectorized in values + constexpr int common_elem = decltype(max_common_vector(src_a, dst_a))::value; + +#if 0 + if (thread0()) { + print(" "); print("common_layout: "); print(common_layout); print("\n"); + print(" "); print("src_a: "); print(src_a); print("\n"); + print(" "); print("dst_a: "); print(dst_a); print("\n"); + } +#ifdef __CUDA_ARCH__ + __syncthreads(); +#endif +#endif + + // + if constexpr (total_elem % NumThreads != 0) { + // Not attempting to find a partitioning pattern, fallback to dynamically indexed slowpath + + if constexpr (common_elem > 1 && MaxVecBits > elem_bits) { + // If the vectorization is non-trivial and divides the maximum vectorizations, then vectorize + constexpr auto max_align_src = elem_bits * decltype(max_alignment(src_a.layout()))::value; + constexpr auto max_align_dst = elem_bits * decltype(max_alignment(dst_a.layout()))::value; + constexpr auto vec_bits = gcd(max_align_src, max_align_dst, MaxVecBits); + using VecType = uint_bit_t; + + static_assert(vec_bits % elem_bits == 0, "Expected divisibility"); + static_assert((vec_bits >= 8), "No support for subbyte copying"); + + Tensor src_v = recast(src_a); + Tensor dst_v = recast(dst_a); + +#if 0 + if (thread0()) { + print(" "); print("cooperative_copy -- naive\n"); + print(" "); print("src_v: "); print(src_v); print("\n"); + print(" "); print("dst_v: "); print(dst_v); print("\n"); + } +#ifdef __CUDA_ARCH__ + __syncthreads(); +#endif +#endif + + naive_cooperative_copy(tid, src_v, dst_v); + } else { + naive_cooperative_copy(tid, src_a, dst_a); + } + } else { + // If the tensors can be equally partitioned by the threads, + // compute vectorization widths in elements and threads. + + // If there are too many threads to allow a full vectorized copy, trunc the vectorization + constexpr int total_bits = total_elem * elem_bits; + constexpr int max_bits_per_thr = total_bits / NumThreads; + // At least elem_bits, at most common_bits + constexpr int common_bits = common_elem * elem_bits; + constexpr int vec_bits = cute::max(elem_bits, cute::gcd(common_bits, int(MaxVecBits), max_bits_per_thr)); + + // Should account for vec_bits < 8 and/or vec_elem <= 1 + // And also account for subbyte types, which could cause race conditions + // Want to ENFORCE sufficient vectorization in those cases + static_assert(vec_bits % elem_bits == 0, "Expected divisibility"); + static_assert(vec_bits >= 8, "No support for subbyte copying"); + + using VecType = uint_bit_t; + constexpr int vec_elem = vec_bits / elem_bits; + + constexpr int vec_thrs = cute::min(int(NumThreads), total_elem / vec_elem); + + // + // Determine the partitioning patterns for the vec_elems and vec_thrs + // + + // Distribute the rest of the V*T to some consistent portion outside of the common_layout, if needed + auto common_domain_src = domain_distribute(shape(src_a), Int{}); + auto common_domain_dst = domain_distribute(shape(dst_a), Int{}); + + // Make sure for now, could fall back here instead + CUTE_STATIC_ASSERT_V(size(common_domain_src) == Int{}); + CUTE_STATIC_ASSERT_V(compatible(common_domain_src, common_domain_dst) || + compatible(common_domain_dst, common_domain_src)); + // Use the "more specific" domain for the extra elements of V*T + auto common_domain = conditional_return(compatible(common_domain_src, common_domain_dst), + common_domain_dst, common_domain_src); + + // Construct the tiler + auto tiler_vt = common_domain.with_shape(Int{}, Int{}); + + // Apply and slice + Tensor src_v = logical_divide(src_a, tiler_vt)(make_coord(_,tid),_); + Tensor dst_v = logical_divide(dst_a, tiler_vt)(make_coord(_,tid),_); + +#if 0 + if (thread0()) { + print(" "); print("cooperative_copy -- vec\n"); + print(" "); print("Used vector: "); print(vec_elem); print("\n"); + print(" "); print("Used threads: "); print(vec_thrs); print("\n"); + print(" "); print("tiler_vt: "); print(tiler_vt); print("\n"); + print(" "); print("src_v: "); print(src_v); print("\n"); + print(" "); print("dst_v: "); print(dst_v); print("\n"); + print(" "); print("recast(src_v): "); print(recast(src_v)); print("\n"); + print(" "); print("recast(dst_v): "); print(recast(dst_v)); print("\n"); + } +#ifdef __CUDA_ARCH__ + __syncthreads(); +#endif +#endif + + // If we're using all threads (static) or the tid is in-range (dynamic) + if (vec_thrs == NumThreads or tid < vec_thrs) { + auto src_c = recast(src_v); + auto dst_c = recast(dst_v); + return copy(cpy, src_c, dst_c); + } + } +} + + +// Default max-vectorization size to value_type size +template +CUTE_HOST_DEVICE +void +cooperative_copy(uint32_t const& tid, + Tensor const& src, + Tensor & dst, + CopyPolicy const& cpy = {}) +{ + constexpr uint32_t MaxVecBits = sizeof_bits_v; + return cooperative_copy(tid, src, dst, cpy); +} + +// +// Accept mutable temporaries +// + +template +CUTE_HOST_DEVICE +void +cooperative_copy(uint32_t const& tid, + Tensor const& src, + Tensor && dst, + CopyPolicy const& cpy = {}) +{ + return cooperative_copy(tid, src, dst, cpy); +} + +template +CUTE_HOST_DEVICE +void +cooperative_copy(uint32_t const& tid, + Tensor const& src, + Tensor && dst, + CopyPolicy const& cpy = {}) +{ + return cooperative_copy(tid, src, dst, cpy); +} + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/cooperative_gemm.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/cooperative_gemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f0a993593d7d90eae9905f0f5bbbf890dcac2873 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/cooperative_gemm.hpp @@ -0,0 +1,553 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include + +#include + +#include +#include +#include + +#include + +namespace cute +{ + +// +// Cooperative Shared-Memory GEMMs +// + +namespace detail { + +// Slow fallback path: +template +CUTE_HOST_DEVICE +void +epilogue_predication(ThrMMA const& thr_mma, + Alpha const& alpha, + Tensor & tCrC, + Beta const& beta, + Tensor & sC, + Tensor & tCsC, + CLoadTransformOp const& sC_load_op, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op) // transforms results before they are stored to C +{ + using InputTypeC = typename TSC::value_type; + using ComputeTypeC = typename ThrMMA::ValTypeC; + CUTE_STATIC_ASSERT(CUTE_STL_NAMESPACE::is_same_v); + + // Create coordinate tensors for the problem + Tensor cC = make_identity_tensor(shape(sC)); // (M,N) -> (m,n) + // Repeat partitioning with thr_mma + Tensor tCcC = thr_mma.partition_C(cC); // (MMA,MMA_M,MMA_N) -> (m,n) + + const bool isBetaZero = [&] () { + if constexpr (is_complex::value) { + return beta.real() == Int<0>{} && beta.imag() == Int<0>{}; + } + else { + return beta == Int<0>{}; + } + CUTE_GCC_UNREACHABLE; + } (); + + // Custom axpby_if for now + CUTE_UNROLL + for (int i = 0; i < size(tCrC); ++i) + { + if (elem_less(tCcC(i), shape(sC))) + { + tCsC(i) = sC_store_op(isBetaZero ? alpha * tCrC(i) + : alpha * tCrC(i) + + beta * static_cast(sC_load_op(tCsC(i)))); + } + } +} + +template +CUTE_HOST_DEVICE +void +epilogue_no_predication(uint32_t thread_idx, + ThrMMA const& thr_mma, + Alpha const& alpha, + Tensor & tCrC, + Beta const& beta, + Tensor & sC, + CLoadTransformOp const& sC_load_op, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op, // transforms results before they are stored to C + SmemCopyLdOpC const& sC_copy_ld_op, + SmemCopyStOpC const& sC_copy_st_op) +{ + using InputTypeC = typename TSC::value_type; + using ComputeTypeC = typename TRC::value_type; + + const bool isBetaZero = [&] () { + if constexpr (is_complex::value) { + return beta.real() == Int<0>{} && beta.imag() == Int<0>{}; + } + else { + return beta == Int<0>{}; + } + CUTE_GCC_UNREACHABLE; + } (); + + Tensor tCrD = make_fragment_like(tCrC); + Tensor tCrDi = make_fragment_like(tCrD); + + if(!isBetaZero) { + auto smem_tiled_copy_C = make_tiled_copy_C(Copy_Atom{}, thr_mma); + auto smem_thr_copy_C = smem_tiled_copy_C.get_thread_slice(thread_idx); + Tensor tCsC = smem_thr_copy_C.partition_S(sC); + Tensor tCrDi_copy_view = smem_thr_copy_C.retile_D(tCrDi); + CUTE_STATIC_ASSERT_V(size<1>(tCsC) == size<1>(tCrDi_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsC) == size<2>(tCrDi_copy_view)); // CPY_N + copy(smem_tiled_copy_C, tCsC, tCrDi_copy_view); + + // Transform C on/after load + cute::transform(tCrDi, tCrD, sC_load_op); + } + // C = alpha * (A * B) + beta * C + axpby(alpha, tCrC, beta, tCrD); + // Transform C before/on store + cute::transform(tCrD, tCrDi, sC_store_op); + + auto smem_tiled_copy_C = make_tiled_copy_C(Copy_Atom{}, thr_mma); + auto smem_thr_copy_C = smem_tiled_copy_C.get_thread_slice(thread_idx); + Tensor tCsC = smem_thr_copy_C.partition_D(sC); + Tensor tCrDi_copy_view = smem_thr_copy_C.retile_S(tCrDi); + CUTE_STATIC_ASSERT_V(size<1>(tCsC) == size<1>(tCrDi_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsC) == size<2>(tCrDi_copy_view)); // CPY_N + copy(smem_tiled_copy_C, tCrDi_copy_view, tCsC); +} + +// Predicated Cooperative GEMM +template +CUTE_HOST_DEVICE +void +cooperative_gemm_predication(ThrMMA const& thr_mma, + Tensor const& sA, + Tensor const& sB, + Tensor & tCrC, + ALoadTransformOp const& sA_load_op, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op) // transforms B values before use in GEMM +{ + using InputTypeA = typename TA::value_type; + using InputTypeB = typename TB::value_type; + using InputTypeC = typename TC::value_type; + using ComputeTypeA = typename ThrMMA::ValTypeA; + using ComputeTypeB = typename ThrMMA::ValTypeB; + using ComputeTypeC = typename ThrMMA::ValTypeC; + + // + // MMA Partitioning + // + + // Partition the sA, sB, and sC tiles across the threads for the MMA + Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K) + Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K) + + // Create register tensors for the MMA to operate on + Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K) + Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K) + // + // PREDICATION + // + + // Create coordinate tensors for the problem + Tensor cA = make_identity_tensor(shape(sA)); // (M,K) -> (m,k) + Tensor cB = make_identity_tensor(shape(sB)); // (N,K) -> (n,k) + + // Repeat partitioning with thr_mma + Tensor tCcA = thr_mma.partition_A(cA); // (MMA,MMA_M,MMA_K) -> (m,k) + Tensor tCcB = thr_mma.partition_B(cB); // (MMA,MMA_N,MMA_K) -> (n,k) + + // Allocate the preds for MMA- and MMA_MN-modes + Tensor tCpA = make_tensor(make_shape(size<0>(tCsA), size<1>(tCsA))); + Tensor tCpB = make_tensor(make_shape(size<0>(tCsB), size<1>(tCsB))); + // Populate the predicates on M and N + CUTE_UNROLL + for (int i = 0; i < size(tCpA); ++i) { + tCpA(i) = elem_less(get<0>(tCcA(_,_,Int<0>{})(i)), shape<0>(sA)); + } + CUTE_UNROLL + for (int i = 0; i < size(tCpB); ++i) { + tCpB(i) = elem_less(get<0>(tCcB(_,_,Int<0>{})(i)), shape<0>(sB)); + } + // + // PREFETCH k_block = 0 + // Condition the k-predication on (static) k_block == K_BLOCK_MAX-1, the last k_block + // Assumes the MMA-tiling in K is trivial + // + + constexpr int K_BLOCK_MAX = size<2>(tCrA); + + CUTE_UNROLL + for (int m = 0; m < size<1>(tCrA); ++m) { // Copy MMA_M + CUTE_UNROLL + for (int i = 0; i < size<0>(tCrA); ++i) { // Copy MMA_I + tCrA(i,m,0) = (tCpA(i,m) && (0 < K_BLOCK_MAX-1 || elem_less(get<1>(tCcA(i,m,0)), shape<1>(sA)))) ? static_cast(sA_load_op(tCsA(i,m,0))) : ComputeTypeA{}; + } + } + CUTE_UNROLL + for (int n = 0; n < size<1>(tCrB); ++n) { // Copy MMA_N + CUTE_UNROLL + for (int i = 0; i < size<0>(tCrB); ++i) { // Copy MMA_I + tCrB(i,n,0) = (tCpB(i,n) && (0 < K_BLOCK_MAX-1 || elem_less(get<1>(tCcB(i,n,0)), shape<1>(sB)))) ? static_cast(sB_load_op(tCsB(i,n,0))) : ComputeTypeB{}; + } + } + // + // MAINLOOP + // + + CUTE_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) + { + if (k_block < K_BLOCK_MAX-1) // static-if not the last k_block + { + int k_next = k_block + 1; // Load k_next block + + // Condition the k-predication on (static) k_block == K_BLOCK_MAX-1, the last k_block + // Assumes the MMA-tiling in K is trivial + + CUTE_UNROLL + for (int m = 0; m < size<1>(tCrA); ++m) { // Copy MMA_M + CUTE_UNROLL + for (int i = 0; i < size<0>(tCrA); ++i) { // Copy MMA_I + tCrA(i,m,k_next) = (tCpA(i,m) && (k_next < K_BLOCK_MAX-1 || elem_less(get<1>(tCcA(i,m,k_next)), shape<1>(sA)))) ? static_cast(sA_load_op(tCsA(i,m,k_next))) : ComputeTypeA{}; + } + } + CUTE_UNROLL + for (int n = 0; n < size<1>(tCrB); ++n) { // Copy MMA_N + CUTE_UNROLL + for (int i = 0; i < size<0>(tCrB); ++i) { // Copy MMA_I + tCrB(i,n,k_next) = (tCpB(i,n) && (k_next < K_BLOCK_MAX-1 || elem_less(get<1>(tCcB(i,n,k_next)), shape<1>(sB)))) ? static_cast(sB_load_op(tCsB(i,n,k_next))) : ComputeTypeB{}; + } + } + } + // GEMM on k_block in registers + gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + } +} + +// Unpredicated Cooperative GEMM +template +CUTE_HOST_DEVICE +void +cooperative_gemm_no_predication(uint32_t thread_idx, + ThrMMA const& thr_mma, + Tensor const& sA, + Tensor const& sB, + Tensor & tCrC, + ALoadTransformOp const& sA_load_op, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op, // transforms B values before use in GEMM + SmemCopyOpA const& sA_copy_op, + SmemCopyOpB const& sB_copy_op) +{ + using InputTypeA = typename TA::value_type; + using InputTypeB = typename TB::value_type; + using InputTypeC = typename TC::value_type; + using ComputeTypeA = typename ThrMMA::ValTypeA; + using ComputeTypeB = typename ThrMMA::ValTypeB; + using ComputeTypeC = typename ThrMMA::ValTypeC; + + + // + // MMA Partitioning + // + + // Create register tensors for the MMA to operate on + Tensor tCrA = thr_mma.partition_fragment_A(sA); // (MMA,MMA_M,MMA_K) + Tensor tCrAi = make_fragment_like(tCrA); + Tensor tCrB = thr_mma.partition_fragment_B(sB); // (MMA,MMA_N,MMA_K) + Tensor tCrBi = make_fragment_like(tCrB); + + using CopyOpAType = SmemCopyOpA; + using CopyOpBType = SmemCopyOpB; + + auto smem_tiled_copy_A = make_tiled_copy_A(Copy_Atom{}, thr_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); + Tensor tCsA = smem_thr_copy_A.partition_S(sA); + Tensor tCrAi_copy_view = smem_thr_copy_A.retile_D(tCrAi); + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrAi_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrAi_copy_view)); // CPY_K + + auto smem_tiled_copy_B = make_tiled_copy_B(Copy_Atom{}, thr_mma); + auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); + Tensor tCsB = smem_thr_copy_B.partition_S(sB); + Tensor tCrBi_copy_view = smem_thr_copy_B.retile_D(tCrBi); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrBi_copy_view)); // CPY_N + CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCrBi_copy_view)); // CPY_K + // + // PREFETCH + // + + copy(smem_tiled_copy_A, tCsA(_,_,Int<0>{}), tCrAi_copy_view(_,_,Int<0>{})); + copy(smem_tiled_copy_B, tCsB(_,_,Int<0>{}), tCrBi_copy_view(_,_,Int<0>{})); + // + // MAINLOOP + // + + constexpr int K_BLOCK_MAX = size<2>(tCrA); + + CUTE_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) + { + // static-if load the next k_block. No k-predication required on these loads. + if (k_block < K_BLOCK_MAX-1) + { + // Load the next k_block + int k_next = k_block + 1; // statically unrolled + copy(smem_tiled_copy_A, tCsA(_,_,k_next), tCrAi_copy_view(_,_,k_next)); + copy(smem_tiled_copy_B, tCsB(_,_,k_next), tCrBi_copy_view(_,_,k_next)); + } + + // Transform A and B, relying on the compiler to remove in case of identity ops + cute::transform(tCrAi(_,_,k_block), tCrA(_,_,k_block), sA_load_op); + cute::transform(tCrBi(_,_,k_block), tCrB(_,_,k_block), sB_load_op); + + // GEMM on k_block in registers + gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + } +} + +} // end namespace detail + +// C passed as a shared memory tensor +// Epilogue included +template +CUTE_HOST_DEVICE +void +cooperative_gemm(uint32_t thread_idx, + TiledMMA const& tiled_mma, + Alpha const& alpha, + Tensor const& sA, + Tensor const& sB, + Beta const& beta, + Tensor & sC, + ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM + CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op = {}, // transforms results before they are stored to C + SmemCopyOpA const& sA_copy_op = {}, + SmemCopyOpB const& sB_copy_op = {}, + SmemCopyLdOpC const& sC_copy_ld_op = {}, + SmemCopyStOpC const& sC_copy_st_op = {}) +{ + CUTE_STATIC_ASSERT_V(rank(sA) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(sB) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(sC) == Int<2>{}); + + CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC)); // AM == CM + CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC)); // BN == CN + CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK + + using InputTypeA = typename TA::value_type; + using InputTypeB = typename TB::value_type; + using InputTypeC = typename TC::value_type; + using ComputeTypeA = typename TiledMMA::ValTypeA; + using ComputeTypeB = typename TiledMMA::ValTypeB; + using ComputeTypeC = typename TiledMMA::ValTypeC; + + auto compat = evenly_divides(make_shape(size<0>(sA), size<0>(sB), size<1>(sA)), + tile_shape(TiledMMA{})); + + // ThrMMA + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor tCsC = thr_mma.partition_C(sC); // (MMA,MMA_M,MMA_N) :: InputTypeC + Tensor tCrC = thr_mma.make_fragment_C(tCsC); // (MMA,MMA_M,MMA_N) :: ComputeTypeC + + // Clear accumulators + clear(tCrC); + if constexpr (is_constant::value) { + detail::cooperative_gemm_no_predication( + thread_idx, thr_mma, sA, sB, tCrC, sA_load_op, sB_load_op, sA_copy_op, sB_copy_op + ); + detail::epilogue_no_predication( + thread_idx, thr_mma,alpha, tCrC, beta, sC, sC_load_op, sC_store_op, sC_copy_ld_op, sC_copy_st_op + ); + } else { + detail::cooperative_gemm_predication( + thr_mma, sA, sB, tCrC, sA_load_op, sB_load_op + ); + detail::epilogue_predication( + thr_mma, alpha, tCrC, beta, sC, tCsC, sC_load_op, sC_store_op + ); + } +} + +// C already partitioned into registers on input +// It can be passed non-empty +// Epilogue not included +template +CUTE_HOST_DEVICE +void +cooperative_gemm(uint32_t thread_idx, + TiledMMA const& tiled_mma, + Tensor const& sA, + Tensor const& sB, + Tensor & tCrC, + ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM + SmemCopyOpA const& sA_copy_op = {}, + SmemCopyOpB const& sB_copy_op = {}) +{ + CUTE_STATIC_ASSERT_V(rank(sA) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(sB) == Int<2>{}); + + CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK + + using InputTypeA = typename TA::value_type; + using InputTypeB = typename TB::value_type; + using InputTypeC = typename TC::value_type; + using ComputeTypeA = typename TiledMMA::ValTypeA; + using ComputeTypeB = typename TiledMMA::ValTypeB; + using ComputeTypeC = typename TiledMMA::ValTypeC; + + // Check if input C fragment is compatible with thr_mma and problem size + using ref_c_frag = decltype(partition_shape_C(tiled_mma, make_shape(size<0>(sA), size<0>(sB)))); + CUTE_STATIC_ASSERT_V(compatible(shape(ref_c_frag{}), shape(tCrC))); + + auto compat = evenly_divides(make_shape(size<0>(sA), size<0>(sB), size<1>(sA)), + tile_shape(TiledMMA{})); + + // ThrMMA + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + + if constexpr (is_constant::value) { + detail::cooperative_gemm_no_predication( + thread_idx, thr_mma, sA, sB, tCrC, sA_load_op, sB_load_op, sA_copy_op, sB_copy_op + ); + } else { + detail::cooperative_gemm_predication( + thr_mma, sA, sB, tCrC, sA_load_op, sB_load_op + ); + } +} + +// Accept mutable temporaries +template +CUTE_HOST_DEVICE +void +cooperative_gemm(uint32_t thread_idx, + TiledMMA const& tiled_mma, + Alpha const& alpha, + Tensor const& sA, + Tensor const& sB, + Beta const& beta, + Tensor && sC, + ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM + CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op = {}, // transforms results before they are stored to C + SmemCopyOpA const& sA_copy_op = {}, + SmemCopyOpB const& sB_copy_op = {}, + SmemCopyLdOpC const& sC_copy_ld_op = {}, + SmemCopyStOpC const& sC_copy_st_op = {}) +{ + cooperative_gemm(thread_idx, tiled_mma, alpha, sA, sB, beta, sC, + sA_load_op, sB_load_op, sC_load_op, sC_store_op, + sA_copy_op, sB_copy_op, sC_copy_ld_op, sC_copy_st_op); +} + +// Legacy overload of cute::gemm for backwards-compatibility +template +CUTE_HOST_DEVICE +void +gemm(ThrMMA const& thr_mma, + Alpha const& alpha, + Tensor const& sA, + Tensor const& sB, + Beta const& beta, + Tensor & sC, + ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM + CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op = {}) // transforms results before they are stored to C +{ + CUTE_STATIC_ASSERT_V(rank(sA) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(sB) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(sC) == Int<2>{}); + + CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC)); // AM == CM + CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC)); // BN == CN + CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK + + Tensor tCsC = thr_mma.partition_C(sC); // (MMA,MMA_M,MMA_N) + Tensor tCrC = thr_mma.make_fragment_C(tCsC); // (MMA,MMA_M,MMA_N) + + // Goes directly to the slow path to avoid getting thread_idx from thr_mma + detail::cooperative_gemm_predication( + thr_mma, sA, sB, sC, sA_load_op, sB_load_op + ); + + detail::epilogue_predication( + thr_mma, alpha, tCrC, beta, sC, tCsC, sC_load_op, sC_store_op + ); +} + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/copy.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/copy.hpp new file mode 100644 index 0000000000000000000000000000000000000000..66fc49c0e67f5b6197bd488276989c0050a8c149 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/copy.hpp @@ -0,0 +1,558 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::Tensor +#include // cute::Copy_Atom + +namespace cute +{ + +// +// copy_if -- Predicated Copy +// + +template +CUTE_HOST_DEVICE +void +copy_if(PrdTensor const& pred, + Tensor const& src, + Tensor & dst) +{ + using SrcType = typename SrcEngine::value_type; + using DstType = typename DstEngine::value_type; + + CUTE_UNROLL + for (int i = 0; i < size(dst); ++i) { + if (pred(i)) { + dst(i) = static_cast(static_cast(src(i))); + } + } +} + +// +// copy_if -- Predicated CopyAtom +// + +// Predicate Tensor is an Actual Tensor +template +CUTE_HOST_DEVICE +void +copy_if(Copy_Atom const& copy_atom, + Tensor const& prd, // ([V],Rest...) + Tensor const& src, // ( V, Rest...) + Tensor & dst) // ( V, Rest...) +{ + if constexpr (PrdLayout::rank == SrcLayout::rank - 1) { + // Back-compat ONLY -- Delete? + copy_if(copy_atom, make_tensor(prd.data(), prepend(prd.layout(), Layout<_1,_0>{})), src, dst); + } else { + static_assert(SrcLayout::rank == DstLayout::rank, "CopyAtom rank-mismatch."); + static_assert(SrcLayout::rank == PrdLayout::rank, "CopyAtom rank-mismatch."); + + if constexpr (SrcLayout::rank == 1) { // Dispatch the copy + copy_atom.call(prd, src, dst); + } else { // Loop over all but the first mode + constexpr int R = SrcLayout::rank; + Tensor prd_v = group_modes<1,R>(prd); + Tensor src_v = group_modes<1,R>(src); + Tensor dst_v = group_modes<1,R>(dst); + CUTE_UNROLL + for (int i = 0; i < size<1>(dst_v); ++i) { + copy_atom.call(prd_v(_,i), src_v(_,i), dst_v(_,i)); + } + } + } +} + +template +[[deprecated("Use a bool-tensor or transform-tensor as predication.")]] +CUTE_HOST_DEVICE +void +copy_if(Copy_Atom const& copy_atom, + PredTensor const& pred, // (Rest...) + Tensor const& src, // (V,Rest...) + Tensor & dst) // (V,Rest...) +{ + Tensor tpred = cute::lazy::transform(make_tensor(counting_iterator{}, replace<0>(shape(dst), _1{})), pred); + return copy_if(copy_atom, tpred, src, dst); +} + +// +// copy_if -- AutoCopyAsync +// + +template +CUTE_HOST_DEVICE +void +copy_if(AutoCopyAsync const& cpy, + PrdTensor const& pred, + Tensor const& src, + Tensor & dst) +{ + using SrcElemWithConst = remove_reference_t; + using SrcType = typename SrcEngine::value_type; + using DstType = typename DstEngine::value_type; + + auto copy_op = []() { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + if constexpr (is_gmem::value && is_smem::value && + sizeof(SrcType) == sizeof(DstType)) { + if constexpr (is_const_v && sizeof(SrcType) == 16) { + return SM80_CP_ASYNC_CACHEGLOBAL{}; + } else if constexpr (sizeof(SrcType) == 4 || sizeof(SrcType) == 8 || sizeof(SrcType) == 16) { + return SM80_CP_ASYNC_CACHEALWAYS{}; + } else { + return UniversalCopy{}; + } + } else { + return UniversalCopy{}; + } + + CUTE_GCC_UNREACHABLE; +#else + return UniversalCopy{}; +#endif + }(); + + CUTE_UNROLL + for (int i = 0; i < size(dst); ++i) { + if (pred(i)) { + copy_op.copy(src(i), dst(i)); + } + } +} + +// +// copy -- AutoCopyAsync +// + +template +CUTE_HOST_DEVICE +void +copy(AutoCopyAsync const& cpy, + Tensor const& src, // (V,Rest...) + Tensor & dst) // (V,Rest...) +{ + copy_if(cpy, constant_fn{}, src, dst); +} + +// +// copy -- CopyAtom +// + +template +CUTE_HOST_DEVICE +void +copy(Copy_Atom const& copy_atom, + Tensor const& src, // (V,Rest...) + Tensor & dst) // (V,Rest...) +{ + static_assert(SrcLayout::rank == DstLayout::rank, "CopyAtom rank-mismatch."); + + if constexpr (SrcLayout::rank == 1) { // Dispatch the copy + copy_atom.call(src, dst); + } else { // Loop over all but the first mode + constexpr int R = SrcLayout::rank; + Tensor src_v = group_modes<1,R>(src); + Tensor dst_v = group_modes<1,R>(dst); + + if constexpr (is_static::value && is_static::value) { + CUTE_STATIC_ASSERT_V(size<1>(src_v) == size<1>(dst_v)); + + // AutoFilter on the Rest-mode + auto dst_null = nullspace(layout<1>(dst_v)); + + Tensor dst_n = zipped_divide(dst_v, make_tile(shape<0>(dst_v), dst_null)); // ((V, NLL), (_1, Rest)) + Tensor src_n = zipped_divide(src_v, make_tile(shape<0>(src_v), dst_null)); // ((V, NLL), (_1, Rest)) + + CUTE_STATIC_ASSERT_V(size<1>(src_n) == size<1>(dst_n)); + CUTE_STATIC_ASSERT_V((cosize<0,1>(dst_n.layout()) == Int<1>{}), "Nullspace definition error"); + CUTE_STATIC_ASSERT_V((cosize<0,1>(src_n.layout()) == Int<1>{}), "Error: Ambiguous scatter detected in copy"); + CUTE_STATIC_ASSERT_V((size<1,0>(dst_n) == Int<1>{})); + CUTE_STATIC_ASSERT_V((size<1,0>(src_n) == Int<1>{})); + + Tensor dst_c = dst_n(make_coord(_,Int<0>{}),make_coord(Int<0>{},_)); // (V, Rest) + Tensor src_c = src_n(make_coord(_,Int<0>{}),make_coord(Int<0>{},_)); // (V, Rest) + + CUTE_STATIC_ASSERT_V( size<1>(src_c) == size<1>(dst_c)); + CUTE_STATIC_ASSERT_V(shape<0>(dst_c) == shape<0>(dst)); + CUTE_STATIC_ASSERT_V(shape<0>(src_c) == shape<0>(src)); + + CUTE_UNROLL + for (int i = 0; i < size<1>(dst_c); ++i) { + copy_atom.call(src_c(_,i), dst_c(_,i)); + } + } else { + CUTE_UNROLL + for (int i = 0; i < size<1>(dst_v); ++i) { + copy_atom.call(src_v(_,i), dst_v(_,i)); + } + } + } +} + +//////////////////////////////////////////////////////// +// Special Auto-Vectorizing, Auto-Filtering Overloads // +//////////////////////////////////////////////////////// + +// Specialization for AutoVectorizingCopyAssumedAlignment +template +CUTE_HOST_DEVICE +void +copy(AutoVectorizingCopyWithAssumedAlignment const&, + Tensor const& src, + Tensor & dst) +{ + constexpr int common_elem = CUTE_STATIC_V(max_common_vector(src, dst)); + static_assert(is_integral{} * sizeof_bits_v)>::value, "Error: Attempting a subbit write!"); + + if constexpr (common_elem > 1) + { + constexpr int align_bits = CUTE_STATIC_V(gcd(max_alignment(src), max_alignment(dst), Int{})); + constexpr int vec_bits = gcd(common_elem * sizeof_bits_v, align_bits); + + if constexpr ((vec_bits % 8) == 0 && sizeof_bits_v < Int{}) + { + // If more than one element vectorizes to a multiple of 8bits that is larger than the value_type, then recast and copy + using VecType = uint_bit_t; + + // Recast + Tensor src_v = recast(src); + Tensor dst_v = recast(dst); + return copy_if(constant_fn{}, src_v, dst_v); + } else { + return copy_if(constant_fn{}, src, dst); + } + } else { + return copy_if(constant_fn{}, src, dst); + } +} + +template +struct AutoFilter { + Base const& base; + CUTE_HOST_DEVICE AutoFilter(Base const& b) : base(b) {} +}; + +// Specialization for AutoFilter +template +CUTE_HOST_DEVICE +void +copy(AutoFilter const& copy_op, + Tensor const& src, + Tensor & dst) +{ + if constexpr (is_constant::value) { + auto dst_null = nullspace(dst.layout()); + + Tensor dst_n = zipped_divide(dst, dst_null); + Tensor src_n = zipped_divide(src, dst_null); + + CUTE_STATIC_ASSERT_V(cosize<0>(dst_n.layout()) == Int<1>{}, "Nullspace definition error"); + CUTE_STATIC_ASSERT_V(cosize<0>(src_n.layout()) == Int<1>{}, "Error: Ambiguous race-condition detected."); + + copy(copy_op.base, src_n(Int<0>{},_), dst_n(Int<0>{},_)); + } else { + copy(copy_op.base, src, dst); + } +} + +// Auto-vectorizing copy for static layouts +template +CUTE_HOST_DEVICE +void +copy(Tensor const& src, + Tensor & dst) +{ + if constexpr (is_static::value && is_static::value) { + // Assume Tensors with static layouts (e.g. registers) have pointers that are 128b aligned + return copy(AutoFilter(AutoVectorizingCopyWithAssumedAlignment<128>{}), src, dst); + } else + if constexpr (is_static::value && is_static::value) { + // Tensors with static shapes can be filtered, but do not assume that dynamic layouts are aligned. + return copy(AutoFilter(AutoVectorizingCopyWithAssumedAlignment<8>{}), src, dst); + } else { + // Do not assume that dynamic layouts are aligned. + return copy(AutoVectorizingCopyWithAssumedAlignment<8>{}, src, dst); + } +} + +// Auto-vectorizing copy with assumed alignment up to 128bit. +template +CUTE_HOST_DEVICE +void +copy_aligned(Tensor const& src, + Tensor & dst) +{ + if constexpr (is_static::value && is_static::value) { + // Tensors with static shapes can be filtered + return copy(AutoFilter(AutoVectorizingCopyWithAssumedAlignment<128>{}), src, dst); + } else { + return copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, src, dst); + } +} + +// Specializaton for Atom AutoVectorizingCopyAssumedAlignment +template +CUTE_HOST_DEVICE +void +copy(Copy_Atom, Args...> const&, + Tensor const& src, + Tensor & dst) +{ + return copy(AutoVectorizingCopyWithAssumedAlignment{}, src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy(Copy_Atom>, Args...> const&, + Tensor const& src, + Tensor & dst) +{ + return copy(AutoVectorizingCopyWithAssumedAlignment{}, src, dst); +} + +#if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED) +template +CUTE_HOST_DEVICE +void +copy(Copy_Traits const& atom, // Copy_Traits may or may not have the memory barrier in it already + Tensor const& src, + Tensor & dst) +{ + using SrcType = typename SrcEngine::value_type; + using DstType = typename DstEngine::value_type; + static_assert(cute::is_same::value); + static_assert((is_gmem::value && is_smem::value) || + (is_smem::value && is_gmem::value), + "Bulk Copy only supports gmem -> smem or smem -> gmem movement."); + // G2S or S2G dispatch + using BULK_COPY_OP = conditional_t::value, + SM90_BULK_COPY_G2S, + SM90_BULK_COPY_S2G>; + + // Find the common subtensor of src and dst + auto tiler = max_common_layout(src, dst); + constexpr int vec_elem = decltype(size(tiler))::value; + constexpr int vec_bits = vec_elem * sizeof_bits_v; + static_assert(vec_bits >= 128, "Expected at least 128-bits for BLKCP"); + + // Construct a new concrete Atom of the vector size + using BulkAtom = Copy_Atom, CT_Args...>, SrcType>; + auto bulk_atom = apply(atom.opargs_, [](auto const&... args) { return BulkAtom{args...}; }); + return copy(bulk_atom, logical_divide(src, tiler), logical_divide(dst, tiler)); +} + +// Backwards-compat. Throw out any extra Copy_Atom args. +template +CUTE_HOST_DEVICE +void +copy(Copy_Atom, CA_Args...> const& atom, + Tensor const& src, + Tensor & dst) +{ + return copy(static_cast const&>(atom), src, dst); +} +#endif // #if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED) + +// +// Decay TiledCopy to CopyAtom +// + +template +CUTE_HOST_DEVICE +void +copy_if(TiledCopy const& tiled_copy, + PrdTensor const& pred, + Tensor const& src, + Tensor & dst) +{ + return copy_if(static_cast(tiled_copy), pred, src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy(TiledCopy const& tiled_copy, + Tensor const& src, + Tensor & dst) +{ + return copy(static_cast(tiled_copy), src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy_if(ThrCopy const& thr_copy, + PrdTensor const& pred, + Tensor const& src, + Tensor & dst) = delete; + +template +CUTE_HOST_DEVICE +void +copy(ThrCopy const& thr_copy, + Tensor const& src, + Tensor & dst) = delete; + +// +// Catch uncaught policies +// + +template +CUTE_HOST_DEVICE +void +copy_if(CopyPolicy const& cpy, + PredTensor const& prd, + Tensor const& src, + Tensor & dst) +{ + static_assert(dependent_false, "Unrecognized CopyPolicy."); +} + +template +CUTE_HOST_DEVICE +void +copy(CopyPolicy const& cpy, + Tensor const& src, + Tensor & dst) +{ + static_assert(dependent_false, "Unrecognized CopyPolicy."); +} + +// +// Accept mutable temporaries +// + +template +CUTE_HOST_DEVICE +void +copy_if(PrdTensor const& pred, + Tensor const& src, + Tensor && dst) +{ + return copy_if(pred, src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy_if(CopyPolicy const& copy_policy, + PrdTensor const& pred, + Tensor const& src, + Tensor && dst) +{ + return copy_if(copy_policy, pred, src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy(Tensor const& src, + Tensor && dst) +{ + return copy(src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy(CopyPolicy const& copy_policy, + Tensor const& src, + Tensor && dst) +{ + return copy(copy_policy, src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy_aligned(Tensor const& src, + Tensor && dst) +{ + return copy_aligned(src, dst); +} + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/fill.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/fill.hpp new file mode 100644 index 0000000000000000000000000000000000000000..37b97f18734e74bcdcc39fede3bb6aeaa26c1338 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/fill.hpp @@ -0,0 +1,87 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute +{ + +// +// Accept mutable temporaries +// +template +CUTE_HOST_DEVICE +void +fill(Tensor&& tensor, T const& value) +{ + return fill(tensor, value); +} + +namespace detail +{ + +// Prefer fill(tensor.data(), value), if possible +template +CUTE_HOST_DEVICE +auto +fill(Tensor& tensor, T const& value, prefer<1>) + -> decltype(fill(tensor.data(), value)) +{ + fill(tensor.data(), value); +} + +// Default implementation +template +CUTE_HOST_DEVICE +void +fill(Tensor& tensor, T const& value, prefer<0>) +{ + CUTE_UNROLL + for (int i = 0; i < size(tensor); ++i) { + tensor(i) = value; + } +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE +void +fill(Tensor& tensor, T const& value) +{ + return detail::fill(tensor, value, prefer<1>{}); +} + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/functional.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/functional.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5c56eb5cc80be460895802df09363eb7254a8455 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/functional.hpp @@ -0,0 +1,290 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::max, cute::min +#include // cute::conj + +/** C++14 extensions */ + +namespace cute { + +/**************/ +/** Identity **/ +/**************/ + +struct identity { + template + CUTE_HOST_DEVICE constexpr + decltype(auto) operator()(T&& arg) const { + return static_cast(arg); + } +}; + +template +struct constant_fn { + template + CUTE_HOST_DEVICE constexpr + decltype(auto) operator()(T&&...) const { + return r_; + } + R r_; +}; + +/***********/ +/** Unary **/ +/***********/ + +#define CUTE_LEFT_UNARY_OP(NAME,OP) \ + struct NAME { \ + template \ + CUTE_HOST_DEVICE constexpr \ + decltype(auto) operator()(T&& arg) const { \ + return OP static_cast(arg); \ + } \ + } +#define CUTE_RIGHT_UNARY_OP(NAME,OP) \ + struct NAME { \ + template \ + CUTE_HOST_DEVICE constexpr \ + decltype(auto) operator()(T&& arg) const { \ + return static_cast(arg) OP ; \ + } \ + } +#define CUTE_NAMED_UNARY_OP(NAME,OP) \ + struct NAME { \ + template \ + CUTE_HOST_DEVICE constexpr \ + decltype(auto) operator()(T&& arg) const { \ + return OP (static_cast(arg)); \ + } \ + } + +CUTE_LEFT_UNARY_OP(unary_plus, +); +CUTE_LEFT_UNARY_OP(negate, -); +CUTE_LEFT_UNARY_OP(bit_not, ~); +CUTE_LEFT_UNARY_OP(logical_not, !); +CUTE_LEFT_UNARY_OP(dereference, *); +CUTE_LEFT_UNARY_OP(address_of, &); +CUTE_LEFT_UNARY_OP(pre_increment, ++); +CUTE_LEFT_UNARY_OP(pre_decrement, --); + +CUTE_RIGHT_UNARY_OP(post_increment, ++); +CUTE_RIGHT_UNARY_OP(post_decrement, --); + +CUTE_NAMED_UNARY_OP(abs_fn, abs); +CUTE_NAMED_UNARY_OP(conjugate, cute::conj); + +#undef CUTE_LEFT_UNARY_OP +#undef CUTE_RIGHT_UNARY_OP +#undef CUTE_NAMED_UNARY_OP + +template +struct shift_right_const { + static constexpr int Shift = Shift_; + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) operator()(T&& arg) const { + return static_cast(arg) >> Shift; + } +}; + +template +struct shift_left_const { + static constexpr int Shift = Shift_; + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) operator()(T&& arg) const { + return static_cast(arg) << Shift; + } +}; + +/************/ +/** Binary **/ +/************/ + +#define CUTE_BINARY_OP(NAME,OP) \ + struct NAME { \ + template \ + CUTE_HOST_DEVICE constexpr \ + decltype(auto) operator()(T&& lhs, U&& rhs) const { \ + return static_cast(lhs) OP static_cast(rhs); \ + } \ + } +#define CUTE_NAMED_BINARY_OP(NAME,OP) \ + struct NAME { \ + template \ + CUTE_HOST_DEVICE constexpr \ + decltype(auto) operator()(T&& lhs, U&& rhs) const { \ + return OP (static_cast(lhs), static_cast(rhs)); \ + } \ + } + + +CUTE_BINARY_OP(plus, +); +CUTE_BINARY_OP(minus, -); +CUTE_BINARY_OP(multiplies, *); +CUTE_BINARY_OP(divides, /); +CUTE_BINARY_OP(modulus, %); + +CUTE_BINARY_OP(plus_assign, +=); +CUTE_BINARY_OP(minus_assign, -=); +CUTE_BINARY_OP(multiplies_assign, *=); +CUTE_BINARY_OP(divides_assign, /=); +CUTE_BINARY_OP(modulus_assign, %=); + +CUTE_BINARY_OP(bit_and, &); +CUTE_BINARY_OP(bit_or, |); +CUTE_BINARY_OP(bit_xor, ^); +CUTE_BINARY_OP(left_shift, <<); +CUTE_BINARY_OP(right_shift, >>); + +CUTE_BINARY_OP(bit_and_assign, &=); +CUTE_BINARY_OP(bit_or_assign, |=); +CUTE_BINARY_OP(bit_xor_assign, ^=); +CUTE_BINARY_OP(left_shift_assign, <<=); +CUTE_BINARY_OP(right_shift_assign, >>=); + +CUTE_BINARY_OP(logical_and, &&); +CUTE_BINARY_OP(logical_or, ||); + +CUTE_BINARY_OP(equal_to, ==); +CUTE_BINARY_OP(not_equal_to, !=); +CUTE_BINARY_OP(greater, >); +CUTE_BINARY_OP(less, <); +CUTE_BINARY_OP(greater_equal, >=); +CUTE_BINARY_OP(less_equal, <=); + +CUTE_NAMED_BINARY_OP(max_fn, cute::max); +CUTE_NAMED_BINARY_OP(min_fn, cute::min); + +#undef CUTE_BINARY_OP +#undef CUTE_NAMED_BINARY_OP + +/**********/ +/** Fold **/ +/**********/ + +#define CUTE_FOLD_OP(NAME,OP) \ + struct NAME##_unary_rfold { \ + template \ + CUTE_HOST_DEVICE constexpr \ + auto operator()(T&&... t) const { \ + return (t OP ...); \ + } \ + }; \ + struct NAME##_unary_lfold { \ + template \ + CUTE_HOST_DEVICE constexpr \ + auto operator()(T&&... t) const { \ + return (... OP t); \ + } \ + }; \ + struct NAME##_binary_rfold { \ + template \ + CUTE_HOST_DEVICE constexpr \ + auto operator()(U&& u, T&&... t) const { \ + return (t OP ... OP u); \ + } \ + }; \ + struct NAME##_binary_lfold { \ + template \ + CUTE_HOST_DEVICE constexpr \ + auto operator()(U&& u, T&&... t) const { \ + return (u OP ... OP t); \ + } \ + } + +CUTE_FOLD_OP(plus, +); +CUTE_FOLD_OP(minus, -); +CUTE_FOLD_OP(multiplies, *); +CUTE_FOLD_OP(divides, /); +CUTE_FOLD_OP(modulus, %); + +CUTE_FOLD_OP(plus_assign, +=); +CUTE_FOLD_OP(minus_assign, -=); +CUTE_FOLD_OP(multiplies_assign, *=); +CUTE_FOLD_OP(divides_assign, /=); +CUTE_FOLD_OP(modulus_assign, %=); + +CUTE_FOLD_OP(bit_and, &); +CUTE_FOLD_OP(bit_or, |); +CUTE_FOLD_OP(bit_xor, ^); +CUTE_FOLD_OP(left_shift, <<); +CUTE_FOLD_OP(right_shift, >>); + +CUTE_FOLD_OP(bit_and_assign, &=); +CUTE_FOLD_OP(bit_or_assign, |=); +CUTE_FOLD_OP(bit_xor_assign, ^=); +CUTE_FOLD_OP(left_shift_assign, <<=); +CUTE_FOLD_OP(right_shift_assign, >>=); + +CUTE_FOLD_OP(logical_and, &&); +CUTE_FOLD_OP(logical_or, ||); + +CUTE_FOLD_OP(equal_to, ==); +CUTE_FOLD_OP(not_equal_to, !=); +CUTE_FOLD_OP(greater, >); +CUTE_FOLD_OP(less, <); +CUTE_FOLD_OP(greater_equal, >=); +CUTE_FOLD_OP(less_equal, <=); + +#undef CUTE_FOLD_OP + +/**********/ +/** Meta **/ +/**********/ + +template +struct bound_fn { + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + operator()(T&& arg) { + return fn_(arg_, static_cast(arg)); + } + + Fn fn_; + Arg arg_; +}; + +template +CUTE_HOST_DEVICE constexpr +auto +bind(Fn const& fn, Arg const& arg) { + return bound_fn{fn, arg}; +} + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/gemm.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/gemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..97839c044f66b7f955605cc6c58efaa0ec45e6ff --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/gemm.hpp @@ -0,0 +1,500 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +#include + +#include + +/** The gemm algorithm takes four (or three) tensors and computes + * D = A * B + C + * It dispatches based on the number of modes each tensor has: + * + * 1. `(V) x (V) => (V)`. + * The element-wise product of vectors. Dispatches to FMA or MMA. + * 2. `(M) x (N) => (M,N)`. + * The outer product of vectors. Dispatches to [3] with new mode K=(1). + * 3. `(M,K) x (N,K) => (M,N)`. + * The product of matrices. Dispatches to [5] with MMA vector-mode V. + * 4. `(V,M) x (V,N) => (V,M,N)`. + * The batched outer product of vectors. Accounts for register reuse and dispatches to [1] for each (m,n). + * 5. `(V,M,K) x (V,N,K) => (V,M,N)`. + * The batched product of matrices. Dispatches to [4] for each (k). + */ + +namespace cute +{ + +// +// Three arguments to four +// + +template +CUTE_HOST_DEVICE +void +gemm(Tensor const& A, + Tensor const& B, + Tensor & C) +{ + return gemm(C, A, B, C); +} + +template +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor const& A, + Tensor const& B, + Tensor & C) +{ + return gemm(mma, C, A, B, C); +} + +// +// Accept mutable temporaries +// + +template +CUTE_HOST_DEVICE +void +gemm(Tensor const& A, + Tensor const& B, + Tensor && C) +{ + return gemm(C, A, B, C); +} + +template +CUTE_HOST_DEVICE +void +gemm(Tensor && D, + Tensor const& A, + Tensor const& B, + Tensor const& C) +{ + return gemm(D, A, B, C); +} + +template +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor const& A, + Tensor const& B, + Tensor && C) +{ + return gemm(mma, C, A, B, C); +} + +template +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor && D, + Tensor const& A, + Tensor const& B, + Tensor const& C) +{ + return gemm(mma, D, A, B, C); +} + +// +// Default MMA is UniversalFMA +// + +template +CUTE_HOST_DEVICE +void +gemm(Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) +{ + using MMA = MMA_Atom::value_type, + typename Tensor::value_type, + typename Tensor::value_type, + typename Tensor::value_type>>; + + return gemm(MMA{}, D, A, B, C); +} + +// +// Thread-Local Register-Memory GEMMs +// + +// Dispatch [1]: (V) x (V) => (V) +template ::value && + ALayout::rank == 1 && is_rmem::value && + BLayout::rank == 1 && is_rmem::value && + CLayout::rank == 1 && is_rmem::value)> +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor & D, // (V) Logical data + Tensor const& A, // (V) Logical data + Tensor const& B, // (V) Logical data + Tensor const& C) // (V) Logical data +{ + // No static assertions on (V), MMA checks compatibility + mma.call(D, A, B, C); +} + +// Dispatch [2]: (M) x (N) => (M,N) +template ::value && + ALayout::rank == 1 && is_rmem::value && + BLayout::rank == 1 && is_rmem::value && + CLayout::rank == 2 && is_rmem::value)> +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor & D, // (M,N) Logical data + Tensor const& A, // (M) Logical data + Tensor const& B, // (N) Logical data + Tensor const& C) // (M,N) Logical data +{ + CUTE_STATIC_ASSERT_V(size<0>(A) == size<0>(C)); // AM == CM + CUTE_STATIC_ASSERT_V(size<0>(B) == size<1>(C)); // BN == CN + CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D)); + gemm(mma, + D, // (M,N) + make_tensor(A.data(), append<2>(A.layout())), // (M,1) + make_tensor(B.data(), append<2>(B.layout())), // (N,1) + C); // (M,N) +} + +// Dispatch [3]: (M,K) x (N,K) => (M,N) +template ::value && + ALayout::rank == 2 && is_rmem::value && + BLayout::rank == 2 && is_rmem::value && + CLayout::rank == 2 && is_rmem::value)> +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor & D, // (M,N) Logical data + Tensor const& A, // (M,K) Logical data + Tensor const& B, // (N,K) Logical data + Tensor const& C) // (M,N) Logical data +{ + CUTE_STATIC_ASSERT_V(size<0>(A) == size<0>(C)); // AM == CM + CUTE_STATIC_ASSERT_V(size<0>(B) == size<1>(C)); // BN == CN + CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(B)); // AK == BK + CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D)); + + // Assert this is a 1-value MMA + CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutC_TV{}) == Int<1>{}); + CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutA_TV{}) == Int<1>{}); + CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutB_TV{}) == Int<1>{}); + + gemm(mma, + make_tensor(D.data(), prepend<3>(D.layout())), // (1,M,N) + make_tensor(A.data(), prepend<3>(A.layout())), // (1,M,K) + make_tensor(B.data(), prepend<3>(B.layout())), // (1,N,K) + make_tensor(C.data(), prepend<3>(C.layout()))); // (1,M,N) +} + +// Dispatch [4]: (V,M) x (V,N) => (V,M,N) +template ::value && + ALayout::rank == 2 && is_rmem::value && + BLayout::rank == 2 && is_rmem::value && + CLayout::rank == 3 && is_rmem::value)> +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor & D, // (V,M,N) Logical data + Tensor const& A, // (V,M) Logical data + Tensor const& B, // (V,N) Logical data + Tensor const& C) // (V,M,N) Logical data +{ + CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(C)); // AM == CM + CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C)); // BN == CN + CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D)); + auto M = size<1>(A); + auto N = size<1>(B); + // REGISTER .reuse OPTIMIZATIONS + // 64-bit traversal specialization -- serpentine path + if constexpr (decltype(size<0>(A))::value * sizeof(typename TA::value_type) == 8 && + decltype(size<0>(B))::value * sizeof(typename TB::value_type) == 8) + { +#if 1 // NOTE: Row- vs Col- major could depend on the C-matrix order... (which we can test) + // Row-major serpentine iteration + CUTE_UNROLL + for (int m = 0; m < M; ++m) { + CUTE_UNROLL + for (int n = 0; n < N; ++n) { + int ns = (m & 1) ? N-1-n : n; // Serpentine coordinate + gemm(mma, D(_,m,ns), A(_,m), B(_,ns), C(_,m,ns)); + } + } +#else + // Col-major serpentine iteration + CUTE_UNROLL + for (int n = 0; n < N; ++n) { + CUTE_UNROLL + for (int m = 0; m < M; ++m) { + int ms = (n & 1) ? M-1-m : m; // Serpentine coordinate + gemm(mma, D(_,ms,n), A(_,ms), B(_,n), C(_,ms,n)); + } + } +#endif + } else + // 32-bit traversal specialization -- kinked serpentine path + if constexpr (decltype(size<0>(A))::value * sizeof(typename TA::value_type) == 4 && + decltype(size<0>(B))::value * sizeof(typename TB::value_type) == 4) + { +#if 1 // NOTE: Row- vs Col- major could depend on the C-matrix order... (which we can test) + // Row-major kinked serpentine iteration + CUTE_UNROLL + for (int m = 0; m < M; m += 2) { + CUTE_UNROLL + for (int n = 0; n < N; ++n) { + int ns = (m & 2) ? N-1-n : n; + gemm(mma, D(_,m+0,ns), A(_,m+0), B(_,ns), C(_,m+0,ns)); + + if (m+1 < M) { + gemm(mma, D(_,m+1,ns), A(_,m+1), B(_,ns), C(_,m+1,ns)); + } + } + } +#else + // Col-major kinked serpentine iteration + CUTE_UNROLL + for (int n = 0; n < N; n += 2) { + CUTE_UNROLL + for (int m = 0; m < M; ++m) { + // Kinked serpentine traversal for maximum register reuse + int ms = (n & 2) ? M-1-m : m; + gemm(mma, D(_,ms,n+0), A(_,ms), B(_,n+0), C(_,ms,n+0)); + + if (n+1 < N) { + gemm(mma, D(_,ms,n+1), A(_,ms), B(_,n+1), C(_,ms,n+1)); + } + } + } +#endif + } else + // 64-bit + 32-bit traversal order -- keep A (64-bit) in the outer loop and serpentine B + if constexpr (decltype(size<0>(A))::value * sizeof(typename TA::value_type) == 8 && + decltype(size<0>(B))::value * sizeof(typename TB::value_type) == 4) { + // Row-major serpentine iteration + CUTE_UNROLL + for (int m = 0; m < M; ++m) { + CUTE_UNROLL + for (int n = 0; n < N; ++n) { + int ns = (m & 1) ? N-1-n : n; // Serpentine coordinate + gemm(mma, D(_,m,ns), A(_,m), B(_,ns), C(_,m,ns)); + } + } + } else + // 32-bit + 64-bit traversal order -- keep B (64-bit) in the outer loop and serpentine A + if constexpr (decltype(size<0>(A))::value * sizeof(typename TA::value_type) == 4 && + decltype(size<0>(B))::value * sizeof(typename TB::value_type) == 8) { + // Col-major serpentine iteration + CUTE_UNROLL + for (int n = 0; n < N; ++n) { + CUTE_UNROLL + for (int m = 0; m < M; ++m) { + int ms = (n & 1) ? M-1-m : m; // Serpentine coordinate + gemm(mma, D(_,ms,n), A(_,ms), B(_,n), C(_,ms,n)); + } + } + } else + // Fallback to serpentine loop + { + // Col-major serpentine iteration + CUTE_UNROLL + for (int n = 0; n < N; ++n) { + CUTE_UNROLL + for (int m = 0; m < M; ++m) { + int ms = (n & 1) ? M-1-m : m; // Serpentine coordinate + gemm(mma, D(_,ms,n), A(_,ms), B(_,n), C(_,ms,n)); + } + } + } +} + +// Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N) +template ::value && + ALayout::rank == 3 && is_rmem::value && + BLayout::rank == 3 && is_rmem::value && + CLayout::rank == 3 && is_rmem::value)> +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor & D, // (V,M,N) Logical data + Tensor const& A, // (V,M,K) Logical data + Tensor const& B, // (V,N,K) Logical data + Tensor const& C) // (V,M,N) Logical data +{ + CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(C)); // AM == CM + CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C)); // BN == CN + CUTE_STATIC_ASSERT_V(size<2>(A) == size<2>(B)); // AK == BK + CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D)); + auto K = size<2>(A); + + CUTE_UNROLL + for (int k = 0; k < K; ++k) { + gemm(mma, D, A(_,_,k), B(_,_,k), C); + } +} + +// +// Thread-Local Shared-Memory GEMMs +// + +// Dispatch [1]: (V) x (V) => (V) +// Dispatch [2]: (M) x (N) => (M,N) +// Dispatch [3]: (M,K) x (N,K) => (M,N) +// Dispatch [4]: (V,M) x (V,N) => (V,M,N) +// Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N) +// Dispatch [3]: (M,K) x (N,K) => (M,N) +template ::value && + ALayout::rank == 2 && is_smem::value && + BLayout::rank == 2 && is_smem::value && + CLayout::rank == 2 && is_rmem::value)> +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor & D, // (M,N) Logical data + Tensor const& A, // (M,K) Logical data + Tensor const& B, // (N,K) Logical data + Tensor const& C) // (M,N) Logical data +{ + CUTE_STATIC_ASSERT_V(size<0>(A) == size<0>(C)); // AM == CM + CUTE_STATIC_ASSERT_V(size<0>(B) == size<1>(C)); // BN == CN + CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(B)); // AK == BK + CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D)); + + // Assert this is a 1-value MMA + CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutC_TV{}) == Int<1>{}); + CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutA_TV{}) == Int<1>{}); + CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutB_TV{}) == Int<1>{}); + + gemm(mma, + make_tensor(D.data(), prepend<3>(D.layout())), // (1,M,N) + make_tensor(A.data(), prepend<3>(A.layout())), // (1,M,K) + make_tensor(B.data(), prepend<3>(B.layout())), // (1,N,K) + make_tensor(C.data(), prepend<3>(C.layout()))); // (1,M,N) +} + +// Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N) +template ::value && + ALayout::rank == 3 && is_smem::value && + BLayout::rank == 3 && is_smem::value && + CLayout::rank == 3 && is_rmem::value)> +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor & D, // (V,M,N) Logical data + Tensor const& A, // (V,M,K) Logical data + Tensor const& B, // (V,N,K) Logical data + Tensor const& C) // (V,M,N) Logical data +{ + CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(C)); // AM == CM + CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C)); // BN == CN + CUTE_STATIC_ASSERT_V(size<2>(A) == size<2>(B)); // AK == BK + CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D)); + + auto rA = MMA_Atom::make_fragment_A(A); + auto rB = MMA_Atom::make_fragment_B(B); + + auto K = size<2>(A); + + CUTE_UNROLL + for (int k = 0; k < K; ++k) + { + copy(A(_,_,k), rA(_,_,k)); + copy(B(_,_,k), rB(_,_,k)); + // Thread-level register gemm for k + gemm(mma, D, rA(_,_,k), rB(_,_,k), C); + } +} + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/prefer.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/prefer.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0a1c53e7d26fa9beb15cbabc61418dfe2a3acbc4 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/prefer.hpp @@ -0,0 +1,46 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +namespace cute +{ + +// Infinite types that inherit from each other +template +struct prefer : prefer {}; + +template <> +struct prefer<0> {}; + +// Can be used to preferencially overload implementations +// Higher N in prefer have higher priority. + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/prefetch.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/prefetch.hpp new file mode 100644 index 0000000000000000000000000000000000000000..265da128f169e03798d196fb05d0031f67fc8288 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/prefetch.hpp @@ -0,0 +1,146 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::Tensor +#include // cute::Copy_Atom + +namespace cute +{ + +// +// Prefetch global tensors into L2 +// + +template +CUTE_HOST_DEVICE +void +cooperative_prefetch(uint32_t const& tid, + Tensor const& src) +{ + static_assert(is_gmem::value, "Expected global tensor for prefetch"); + + constexpr int V = decltype(max_common_vector(src, src))::value; + + if constexpr (V > 1) { + // L2 sector is 32B, default fetch granularity is 64B + using VecType = conditional_t<(V * sizeof_bits_v) < (FetchBytes * 8), + ArrayEngine, + uint8_t[FetchBytes] >; + + Tensor src_v = recast(src); + CUTE_UNROLL + for (int i = tid; i < size(src_v); i += NumThreads) { + prefetch(raw_pointer_cast(&src_v(i))); + } + } else { + CUTE_UNROLL + for (int i = tid; i < size(src); i += NumThreads) { + prefetch(raw_pointer_cast(&src(i))); + } + } +} + +template +CUTE_HOST_DEVICE +void +prefetch(Tensor const& src) +{ + return cooperative_prefetch<1>(0, src); +} + +// Prefetch with copy atom +namespace detail { + +template +constexpr bool has_prefetch = false; + +template +constexpr bool has_prefetch> = true; + +} // end namespace detail + +template +CUTE_HOST_DEVICE +void +prefetch(Copy_Atom, CopyType> const& atom, + Tensor const& src) +{ + if constexpr (detail::has_prefetch) { + using Prefetch_Traits = Copy_Traits; + using Prefetch_Atom = Copy_Atom; + Prefetch_Atom prefetch_atom{atom}; + //auto& dst = const_cast&>(src); // dst is ignored for prefetch atoms + Tensor dst = make_tensor(make_smem_ptr(nullptr), shape(src)); + return copy(prefetch_atom, src, dst); + } else { + return prefetch(src); + } +} + +#if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED) +template +CUTE_HOST_DEVICE +void +prefetch(Copy_Traits const& atom, + Tensor const& src) +{ + using SrcType = typename SrcEngine::value_type; + static_assert(is_gmem::value, "Expected global tensor for L2 prefetch"); + + auto tiler = max_common_layout(src, src); + constexpr int vec_elem = decltype(size(tiler))::value; + constexpr int vec_bits = vec_elem * sizeof_bits_v; + static_assert(vec_bits >= 128, "Expected at least 128-bits for BLKCP"); + + // Construct a new concrete Atom of the vector size + auto bulk_atom = Copy_Atom>, SrcType>{}; + + return prefetch(bulk_atom, logical_divide(src, tiler)); +} + +// Backwards-compat. Throw out any extra Copy_Atom args. +template +CUTE_HOST_DEVICE +void +prefetch(Copy_Atom, CA_Args...> const& atom, + Tensor const& src) +{ + return prefetch(static_cast const&>(atom), src); +} +#endif // #if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED) + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/tensor_algorithms.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/tensor_algorithms.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f47becb18d96c7bcbaf3380ade318b5abaa4dc2d --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/tensor_algorithms.hpp @@ -0,0 +1,178 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/** Common algorithms on (hierarchical) tensors */ + +#pragma once + +#include +#include + +namespace cute +{ + +// +// for_each +// + +template +CUTE_HOST_DEVICE constexpr +void +for_each(Tensor const& tensor, UnaryOp&& op) +{ + CUTE_UNROLL + for (int i = 0; i < size(tensor); ++i) { + op(tensor(i)); + } +} + +template +CUTE_HOST_DEVICE constexpr +void +for_each(Tensor& tensor, UnaryOp&& op) +{ + CUTE_UNROLL + for (int i = 0; i < size(tensor); ++i) { + op(tensor(i)); + } +} + +// Accept mutable temporaries +template +CUTE_HOST_DEVICE constexpr +void +for_each(Tensor&& tensor, UnaryOp&& op) +{ + return for_each(tensor, op); +} + +// +// transform +// + +// Similar to std::transform but does not return number of elements affected +template +CUTE_HOST_DEVICE constexpr +void +transform(Tensor& tensor, UnaryOp&& op) +{ + CUTE_UNROLL + for (int i = 0; i < size(tensor); ++i) { + tensor(i) = op(tensor(i)); + } +} + +// Accept mutable temporaries +template +CUTE_HOST_DEVICE constexpr +void +transform(Tensor&& tensor, UnaryOp&& op) +{ + return transform(tensor, op); +} + +// Similar to std::transform transforms one tensors and assigns it to another +template +CUTE_HOST_DEVICE constexpr +void +transform(Tensor const& tensor_in, + Tensor & tensor_out, + UnaryOp&& op) +{ + CUTE_UNROLL + for (int i = 0; i < size(tensor_in); ++i) { + tensor_out(i) = op(tensor_in(i)); + } +} + +// Accept mutable temporaries +template +CUTE_HOST_DEVICE constexpr +void +transform(Tensor const& tensor_in, + Tensor && tensor_out, + UnaryOp&& op) +{ + return transform(tensor_in, tensor_out, op); +} + +// Similar to std::transform with a binary operation +// Takes two tensors as input and one tensor as output. +// Applies the binary_op to tensor_in1 and tensor_in2 and +// assigns it to tensor_out +template +CUTE_HOST_DEVICE constexpr +void +transform(Tensor const& tensor_in1, + Tensor const& tensor_in2, + Tensor & tensor_out, + BinaryOp&& op) +{ + CUTE_UNROLL + for (int i = 0; i < size(tensor_in1); ++i) { + tensor_out(i) = op(tensor_in1(i), tensor_in2(i)); + } +} + +// Accept mutable temporaries +template +CUTE_HOST_DEVICE constexpr +void +transform(Tensor const& tensor_in1, + Tensor const& tensor_in2, + Tensor && tensor_out, + BinaryOp&& op) +{ + return transform(tensor_in1, tensor_in2, tensor_out, op); +} + +namespace lazy { + +template +CUTE_HOST_DEVICE constexpr +auto +transform(cute::Tensor const& t, Fn const& fn) +{ + return cute::make_tensor(cute::make_transform_iter(fn, t.data()), t.layout()); +} + +} // end namespace lazy + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/tensor_reduce.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/tensor_reduce.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a6f13735c214bec708cf80bf45f956e5d1df29d3 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/tensor_reduce.hpp @@ -0,0 +1,107 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include +#include +#include + +namespace cute +{ + +// Reduce @src tensor using binary reduction operator @op and initial value @init and return a scalar. +template +CUTE_HOST_DEVICE constexpr +T +reduce(Tensor const& src, T init, BinaryOp op = {}) +{ + for (auto i = 0; i < size(src); ++i) { + init = op(init, src(i)); + } + return init; +} + +// Reduce @src tensor RedMode using binary reduction operator @op and store the result in @dst tensor +// for each index in @dst/BatchMode. +// @pre @src tensor has rank 2 +// @pre size of @src batch mode is equal to size of @dst batch mode +template +CUTE_HOST_DEVICE constexpr +void +batch_reduce(Tensor const& src, // (RedMode, BatchMode) + Tensor & dst, // (BatchMode) + BinaryOp op = {}) +{ + // Precondition + CUTE_STATIC_ASSERT_V(rank(src) == Int<2>{}); + assert(size<1>(src) == size(dst)); + + for (int i = 0; i < size(dst); ++i) { + dst(i) = reduce(src(_,i), dst(i), op); + } +} + + +// Reduce @src tensor along selected modes specified in @target_profile using binary reduction operator @op +// and store the result in @dst tensor. @target_profile is a tuple where '_' indicates modes to keep and +// integers indicates modes to reduce. +// @pre @target_profile is compatible with @src layout +template +CUTE_HOST_DEVICE constexpr +void +logical_reduce(Tensor const& src, + Tensor & dst, + TargetProfile const& target_profile, + BinaryOp op = {}) +{ + // Precondition + assert(compatible(target_profile, shape(src))); + + auto diced_layout = dice(target_profile, src.layout()); + auto sliced_layout = slice(target_profile, src.layout()); + + auto red_mode = conditional_return{}>(Layout<_1,_0>{}, diced_layout); + auto batch_mode = conditional_return{}>(Layout<_1,_0>{}, sliced_layout); + + auto src_tensor = make_tensor(src.data(), make_layout(red_mode, batch_mode)); + + batch_reduce(src_tensor, dst, op); +} + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/tuple_algorithms.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/tuple_algorithms.hpp new file mode 100644 index 0000000000000000000000000000000000000000..311e10557f127b6f1f2771bf5379b75cc32cf824 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/tuple_algorithms.hpp @@ -0,0 +1,1053 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +/// @file tuple_algorithms.hpp +/// @brief Common algorithms on (hierarchical) tuples +/// +/// Code guidelines and style preferences: +/// +/// For perfect forwarding, don't use std::forward, because it may not +/// be defined in device code when compiling with NVRTC. Instead, use +/// `static_cast(parameter_name)`. +/// +/// CuTe generally does not bother forwarding functions, as +/// reference-qualified member functions are rare in this code base. +/// +/// Throughout CUTLASS, cute::make_tuple always needs to be called +/// namespace-qualified, EVEN If inside the cute namespace and/or in +/// scope of a "using namespace cute" declaration. Otherwise, the +/// compiler may select std::make_tuple instead of cute::make_tuple, +/// due to argument-dependent lookup. + +namespace cute +{ + +// +// Apply (Unpack) +// (t, f) => f(t_0,t_1,...,t_n) +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +apply(T&& t, F&& f, seq) +{ + return f(get(static_cast(t))...); +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +apply(T&& t, F&& f) +{ + return detail::apply(static_cast(t), f, tuple_seq{}); +} + +// +// Transform Apply +// (t, f, g) => g(f(t_0),f(t_1),...) +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +tapply(T&& t, F&& f, G&& g, seq) +{ + return g(f(get(static_cast(t)))...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tapply(T0&& t0, T1&& t1, F&& f, G&& g, seq) +{ + return g(f(get(static_cast(t0)), + get(static_cast(t1)))...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tapply(T0&& t0, T1&& t1, T2&& t2, F&& f, G&& g, seq) +{ + return g(f(get(static_cast(t0)), + get(static_cast(t1)), + get(static_cast(t2)))...); +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +transform_apply(T&& t, F&& f, G&& g) +{ + if constexpr (is_tuple>::value) { + return detail::tapply(static_cast(t), f, g, tuple_seq{}); + } else { + return g(f(static_cast(t))); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform_apply(T0&& t0, T1&& t1, F&& f, G&& g) +{ + if constexpr (is_tuple>::value) { + return detail::tapply(static_cast(t0), static_cast(t1), f, g, tuple_seq{}); + } else { + return g(f(static_cast(t0), static_cast(t1))); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform_apply(T0&& t0, T1&& t1, T2&& t2, F&& f, G&& g) +{ + if constexpr (is_tuple>::value) { + return detail::tapply(static_cast(t0), static_cast(t1), static_cast(t2), f, g, tuple_seq{}); + } else { + return g(f(static_cast(t0), static_cast(t1), static_cast(t2))); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// For Each +// (t, f) => f(t_0),f(t_1),...,f(t_n) +// + +template +CUTE_HOST_DEVICE constexpr +void +for_each(T&& t, F&& f) +{ + if constexpr (is_tuple>::value) { + return detail::apply(t, [&](auto&&... a) { (f(static_cast(a)), ...); }, tuple_seq{}); + } else { + return f(static_cast(t)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +for_each_leaf(T&& t, F&& f) +{ + if constexpr (is_tuple>::value) { + return detail::apply(static_cast(t), [&](auto&&... a){ return (for_each_leaf(static_cast(a), f), ...); }, tuple_seq{}); + } else { + return f(static_cast(t)); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Transform +// (t, f) => (f(t_0),f(t_1),...,f(t_n)) +// + +template +CUTE_HOST_DEVICE constexpr +auto +transform(T const& t, F&& f) +{ + if constexpr (is_tuple::value) { + return detail::tapply(t, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq{}); + } else { + return f(t); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform(T0 const& t0, T1 const& t1, F&& f) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Mismatched tuple_size"); + return detail::tapply(t0, t1, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq{}); + } else { + return f(t0, t1); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform(T0 const& t0, T1 const& t1, T2 const& t2, F&& f) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Mismatched tuple_size"); + static_assert(tuple_size::value == tuple_size::value, "Mismatched tuple_size"); + return detail::tapply(t0, t1, t2, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq{}); + } else { + return f(t0, t1, t2); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform_leaf(T const& t, F&& f) +{ + if constexpr (is_tuple::value) { + return transform(t, [&](auto const& a) { return transform_leaf(a, f); }); + } else { + return f(t); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform_leaf(T0 const& t0, T1 const& t1, F&& f) +{ + if constexpr (is_tuple::value) { + return transform(t0, t1, [&](auto const& a, auto const& b) { return transform_leaf(a, b, f); }); + } else { + return f(t0, t1); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// find and find_if +// + +template +CUTE_HOST_DEVICE constexpr +auto +find_if(T const& t, F&& f) +{ + if constexpr (is_tuple::value) { + return detail::tapply(t, f, [] (auto... a) { return cute::C>{}; }, tuple_seq{}); + } else { + return cute::C{}; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +find(T const& t, X const& x) +{ + return find_if(t, [&](auto const& v) { return v == x; }); // This should always return a static true/false +} + +template +CUTE_HOST_DEVICE constexpr +auto +any_of(T const& t, F&& f) +{ + if constexpr (is_tuple::value) { + return detail::tapply(t, f, [] (auto... a) { return (false_type{} || ... || a); }, tuple_seq{}); + } else { + return f(t); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +all_of(T const& t, F&& f) +{ + if constexpr (is_tuple::value) { + return detail::tapply(t, f, [] (auto... a) { return (true_type{} && ... && a); }, tuple_seq{}); + } else { + return f(t); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +none_of(T const& t, F&& f) +{ + return not any_of(t, f); +} + +// +// Filter +// (t, f) => +// + +template +CUTE_HOST_DEVICE constexpr +auto +filter_tuple(T const& t, F&& f) +{ + return transform_apply(t, f, [](auto const&... a) { return cute::tuple_cat(a...); }); +} + +template +CUTE_HOST_DEVICE constexpr +auto +filter_tuple(T0 const& t0, T1 const& t1, F&& f) +{ + return transform_apply(t0, t1, f, [](auto const&... a) { return cute::tuple_cat(a...); }); +} + +template +CUTE_HOST_DEVICE constexpr +auto +filter_tuple(T0 const& t0, T1 const& t1, T2 const& t2, F&& f) +{ + return transform_apply(t0, t1, t2, f, [](auto const&... a) { return cute::tuple_cat(a...); }); +} + +// +// Fold (Reduce, Accumulate) +// (t, v, f) => f(...f(f(v,t_0),t_1),...,t_n) +// + +namespace detail { + +template +struct FoldAdaptor { + template + CUTE_HOST_DEVICE constexpr auto operator|(X&& x) { + auto r = fn_(val_, static_cast(x)); + return FoldAdaptor{fn_, r}; + } + Fn fn_; + Val val_; +}; + +template +CUTE_HOST_DEVICE constexpr +auto +fold(T&& t, V const& v, F&& f, seq) +{ + return (FoldAdaptor{f,v} | ... | get(static_cast(t))).val_; +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +fold(T&& t, V const& v, F&& f) +{ + if constexpr (is_tuple>::value) { + return detail::fold(static_cast(t), v, f, tuple_seq{}); + } else { + return f(v, static_cast(t)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +fold_first(T&& t, F&& f) +{ + if constexpr (is_tuple>::value) { + return detail::fold(static_cast(t), get<0>(t), f, make_range<1,tuple_size>::value>{}); + } else { + return t; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// front, back, take, select, unwrap +// + +// Get the first non-tuple element in a hierarchical tuple +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +front(T&& t) +{ + if constexpr (is_tuple>::value) { + return front(get<0>(static_cast(t))); + } else { + return static_cast(t); + } + + CUTE_GCC_UNREACHABLE; +} + +// Get the last non-tuple element in a hierarchical tuple +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +back(T&& t) +{ + if constexpr (is_tuple>::value) { + constexpr int N = tuple_size>::value; + + // MSVC needs a bit of extra help here deducing return types. + // We help it by peeling off the nonrecursive case a level "early." + if constexpr (! is_tuple(static_cast(t)))>>::value) { + return get(static_cast(t)); + } else { + return back(get(static_cast(t))); + } + } else { + return static_cast(t); + } + + CUTE_GCC_UNREACHABLE; +} + +// Takes the elements in the range [B,E) +template +CUTE_HOST_DEVICE constexpr +auto +take(T const& t) +{ + if constexpr (E == -1) { + if constexpr (is_tuple::value) { + return take::value>(t); + } else { + return take(t); + } + } else + if constexpr (B <= E) { + return detail::apply(t, [](auto const&... a) { return cute::make_tuple(a...); }, make_range{}); + } else { + static_assert(B <= E); + } + + CUTE_GCC_UNREACHABLE; +} + +// Select tuple elements with given indices. +template +CUTE_HOST_DEVICE constexpr +auto +select(T const& t) +{ + return cute::make_tuple(get(t)...); +} + +// Wrap non-tuples into rank-1 tuples or forward +template +CUTE_HOST_DEVICE constexpr +auto +wrap(T const& t) +{ + if constexpr (is_tuple::value) { + return t; + } else { + return cute::make_tuple(t); + } + + CUTE_GCC_UNREACHABLE; +} + +// Unwrap rank-1 tuples until we're left with a rank>1 tuple or a non-tuple +template +CUTE_HOST_DEVICE constexpr +auto +unwrap(T const& t) +{ + if constexpr (is_tuple::value) { + if constexpr (tuple_size::value == 1) { + return unwrap(get<0>(t)); + } else { + return t; + } + } else { + return t; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Flatten and Unflatten +// + +template +struct is_flat : true_type {}; + +template +struct is_flat> : bool_constant<(true && ... && (not is_tuple::value))> {}; + +// Flatten a hierarchical tuple to a tuple of depth one +// and wrap non-tuples into a rank-1 tuple. +template +CUTE_HOST_DEVICE constexpr +auto +flatten_to_tuple(T const& t) +{ + if constexpr (is_tuple::value) { + if constexpr (is_flat::value) { // Shortcut for perf + return t; + } else { + return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); }); + } + } else { + return cute::make_tuple(t); + } + + CUTE_GCC_UNREACHABLE; +} + +// Flatten a hierarchical tuple to a tuple of depth one +// and leave non-tuple untouched. +template +CUTE_HOST_DEVICE constexpr +auto +flatten(T const& t) +{ + if constexpr (is_tuple::value) { + if constexpr (is_flat::value) { // Shortcut for perf + return t; + } else { + return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); }); + } + } else { + return t; + } + + CUTE_GCC_UNREACHABLE; +} + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +unflatten_impl(FlatTuple const& flat_tuple, TargetProfile const& target_profile) +{ + if constexpr (is_tuple::value) { + return fold(target_profile, cute::make_tuple(cute::make_tuple(), flat_tuple), [](auto const& v, auto const& t) { + auto [result, remaining_tuple] = v; + auto [sub_result, sub_tuple] = unflatten_impl(remaining_tuple, t); + return cute::make_tuple(append(result, sub_result), sub_tuple); + }); + } else { + return cute::make_tuple(get<0>(flat_tuple), take<1, decltype(rank(flat_tuple))::value>(flat_tuple)); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +// Unflatten a flat tuple into a hierarchical tuple +// @pre flatten(@a flat_tuple) == @a flat_tuple +// @pre rank(flatten(@a target_profile)) == rank(@a flat_tuple) +// @post congruent(@a result, @a target_profile) +// @post flatten(@a result) == @a flat_tuple +template +CUTE_HOST_DEVICE constexpr +auto +unflatten(FlatTuple const& flat_tuple, TargetProfile const& target_profile) +{ + auto [unflatten_tuple, flat_remainder] = detail::unflatten_impl(flat_tuple, target_profile); + CUTE_STATIC_ASSERT_V(rank(flat_remainder) == Int<0>{}); + return unflatten_tuple; +} + +// +// insert and remove and replace +// + +namespace detail { + +// Shortcut around cute::tuple_cat for common insert/remove/repeat cases +template +CUTE_HOST_DEVICE constexpr +auto +construct(T const& t, X const& x, seq, seq, seq) +{ + return cute::make_tuple(get(t)..., (void(J),x)..., get(t)...); +} + +} // end namespace detail + +// Insert x into the Nth position of the tuple +template +CUTE_HOST_DEVICE constexpr +auto +insert(T const& t, X const& x) +{ + return detail::construct(t, x, make_seq{}, seq<0>{}, make_range::value>{}); +} + +// Remove the Nth element of the tuple +template +CUTE_HOST_DEVICE constexpr +auto +remove(T const& t) +{ + return detail::construct(t, 0, make_seq{}, seq<>{}, make_range::value>{}); +} + +// Replace the Nth element of the tuple with x +template +CUTE_HOST_DEVICE constexpr +auto +replace(T const& t, X const& x) +{ + if constexpr (is_tuple::value) { + return detail::construct(t, x, make_seq{}, seq<0>{}, make_range::value>{}); + } else { + static_assert(N == 0); + return x; + } + + CUTE_GCC_UNREACHABLE; +} + +// Replace the first element of the tuple with x +template +CUTE_HOST_DEVICE constexpr +auto +replace_front(T const& t, X const& x) +{ + if constexpr (is_tuple::value) { + return detail::construct(t, x, seq<>{}, seq<0>{}, make_range<1,tuple_size::value>{}); + } else { + return x; + } + + CUTE_GCC_UNREACHABLE; +} + +// Replace the last element of the tuple with x +template +CUTE_HOST_DEVICE constexpr +auto +replace_back(T const& t, X const& x) +{ + if constexpr (is_tuple::value) { + return detail::construct(t, x, make_seq::value-1>{}, seq<0>{}, seq<>{}); + } else { + return x; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Make a tuple of Xs of tuple_size N +// + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_repeat(X const& x) +{ + return detail::construct(0, x, seq<>{}, make_seq{}, seq<>{}); +} + +// +// Make repeated Xs of rank N +// + +template +CUTE_HOST_DEVICE constexpr +auto +repeat(X const& x) +{ + if constexpr (N == 1) { + return x; + } else { + return detail::construct(0, x, seq<>{}, make_seq{}, seq<>{}); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Make a tuple of Xs the same profile as tuple T +// + +template +CUTE_HOST_DEVICE constexpr +auto +repeat_like(T const& t, X const& x) +{ + if constexpr (is_tuple::value) { + return transform(t, [&](auto const& a) { return repeat_like(a,x); }); + } else { + return x; + } + + CUTE_GCC_UNREACHABLE; +} + +// Group the elements [B,E) of a T into a single element +// e.g. group<2,4>(T<_1,_2,_3,_4,_5,_6>{}) +// => T<_1,_2,T<_3,_4>,_5,_6>{} +template +CUTE_HOST_DEVICE constexpr +auto +group(T const& t) +{ + if constexpr (not is_tuple::value) { + if constexpr (E == -1) { + return group(t); + } else { + return detail::construct(t, take(t), make_seq{}, make_seq<(B < E)>{}, make_range{}); + } + } else + if constexpr (E == -1) { + return group::value>(t); + } else + if constexpr (B <= E) { + return detail::construct(t, take(t), make_seq{}, make_seq<(B < E)>{}, make_range::value>{}); + } else { + static_assert(B <= E); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Extend a T to rank N by appending/prepending an element +// + +template +CUTE_HOST_DEVICE constexpr +auto +append(T const& a, X const& x) +{ + if constexpr (is_tuple::value) { + if constexpr (N == tuple_size::value) { + return a; + } else { + static_assert(N > tuple_size::value); + return detail::construct(a, x, make_seq::value>{}, make_seq::value>{}, seq<>{}); + } + } else { + if constexpr (N == 1) { + return a; + } else { + return detail::construct(cute::make_tuple(a), x, seq<0>{}, make_seq{}, seq<>{}); + } + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +append(T const& a, X const& x) +{ + if constexpr (is_tuple::value) { + return detail::construct(a, x, make_seq::value>{}, seq<0>{}, seq<>{}); + } else { + return cute::make_tuple(a, x); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +prepend(T const& a, X const& x) +{ + if constexpr (is_tuple::value) { + if constexpr (N == tuple_size::value) { + return a; + } else { + static_assert(N > tuple_size::value); + return detail::construct(a, x, seq<>{}, make_seq::value>{}, make_seq::value>{}); + } + } else { + if constexpr (N == 1) { + return a; + } else { + static_assert(N > 1); + return detail::construct(cute::make_tuple(a), x, seq<>{}, make_seq{}, seq<0>{}); + } + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +prepend(T const& a, X const& x) +{ + if constexpr (is_tuple::value) { + return detail::construct(a, x, seq<>{}, seq<0>{}, make_seq::value>{}); + } else { + return cute::make_tuple(x, a); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Inclusive scan (prefix sum) +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +iscan(T const& t, V const& v, F&& f, seq) +{ + // Apply the function to v and the element at I + auto v_next = f(v, get(t)); + // Replace I with v_next + auto t_next = replace(t, v_next); + +#if 0 + std::cout << "ISCAN i" << I << std::endl; + std::cout << " t " << t << std::endl; + std::cout << " i " << v << std::endl; + std::cout << " f(i,t) " << v_next << std::endl; + std::cout << " t_n " << t_next << std::endl; +#endif + + if constexpr (sizeof...(Is) == 0) { + return t_next; + } else { + return iscan(t_next, v_next, f, seq{}); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +iscan(T const& t, V const& v, F&& f) +{ + return detail::iscan(t, v, f, tuple_seq{}); +} + +// +// Exclusive scan (prefix sum) +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +escan(T const& t, V const& v, F&& f, seq) +{ + if constexpr (sizeof...(Is) == 0) { + // Replace I with v + return replace(t, v); + } else { + // Apply the function to v and the element at I + auto v_next = f(v, get(t)); + // Replace I with v + auto t_next = replace(t, v); + +#if 0 + std::cout << "ESCAN i" << I << std::endl; + std::cout << " t " << t << std::endl; + std::cout << " i " << v << std::endl; + std::cout << " f(i,t) " << v_next << std::endl; + std::cout << " t_n " << t_next << std::endl; +#endif + + // Recurse + return escan(t_next, v_next, f, seq{}); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +escan(T const& t, V const& v, F&& f) +{ + return detail::escan(t, v, f, tuple_seq{}); +} + +// +// Zip (Transpose) +// + +// Take ((a,b,c,...),(x,y,z,...),...) rank-R0 x rank-R1 input +// to produce ((a,x,...),(b,y,...),(c,z,...),...) rank-R1 x rank-R0 output + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +zip_(Ts const&... ts) +{ + return cute::make_tuple(get(ts)...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +zip(T const& t, seq, seq) +{ + static_assert(conjunction>::value == tuple_size>::value>...>::value, "Mismatched Ranks"); + return cute::make_tuple(zip_(get(t)...)...); +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +zip(T const& t) +{ + if constexpr (is_tuple::value) { + if constexpr (is_tuple>::value) { + return detail::zip(t, tuple_seq{}, tuple_seq>{}); + } else { + return cute::make_tuple(t); + } + } else { + return t; + } + + CUTE_GCC_UNREACHABLE; +} + +// Convenient to pass them in separately +template +CUTE_HOST_DEVICE constexpr +auto +zip(T0 const& t0, T1 const& t1, Ts const&... ts) +{ + return zip(cute::make_tuple(t0, t1, ts...)); +} + +// +// zip2_by -- A guided zip for rank-2 tuples +// Take a tuple like ((A,a),((B,b),(C,c)),d) +// and produce a tuple ((A,(B,C)),(a,(b,c),d)) +// where the rank-2 modes are selected by the terminals of the guide (X,(X,X)) +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +zip2_by(T const& t, TG const& guide, seq, seq) +{ + // zip2_by produces the modes like ((A,a),(B,b),...) + auto split = cute::make_tuple(zip2_by(get(t), get(guide))...); + + // Rearrange and append missing modes from t to make ((A,B,...),(a,b,...,x,y)) + return cute::make_tuple(cute::make_tuple(get<0>(get(split))...), + cute::make_tuple(get<1>(get(split))..., get(t)...)); +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +zip2_by(T const& t, TG const& guide) +{ + if constexpr (is_tuple::value) { + constexpr int TR = tuple_size::value; + constexpr int GR = tuple_size::value; + static_assert(TR >= GR, "Mismatched ranks"); + return detail::zip2_by(t, guide, + make_range< 0, GR>{}, + make_range{}); + } else { + static_assert(tuple_size::value == 2, "Mismatched ranks"); + return t; + } + + CUTE_GCC_UNREACHABLE; +} + +/// @return A tuple of the elements of @c t in reverse order. +template +CUTE_HOST_DEVICE constexpr +auto +reverse(T const& t) +{ + if constexpr (is_tuple::value) { + return detail::apply(t, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_rseq{}); + } else { + return t; + } +} + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/cluster_sm100.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/cluster_sm100.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0bcf19b0f218688edc6cc16a41a8e8968e06e611 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/cluster_sm100.hpp @@ -0,0 +1,108 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// + +// + +#pragma once + +#include + + +namespace cute { + + +// +// Cluster launch utility +// +CUTE_HOST +bool +initialize_preferred_cluster_launch(void const* const kernel_function, + dim3 const& grid_dims, + dim3 const& cluster_dims_preferred, + dim3 const& cluster_dims_fallback) +{ + // + // Validate cluster_dims + // + + // Total number of cluster cannot be greater than 32 (hardware requirement) + if (cluster_dims_preferred.x * cluster_dims_preferred.y * cluster_dims_preferred.z <= 0 || + cluster_dims_preferred.x * cluster_dims_preferred.y * cluster_dims_preferred.z > 32) { + std::cout << "Invalid preferred cluster dimensions: Attempting to init preferred cluster (" << cluster_dims_preferred.x << "," << cluster_dims_preferred.y << "," << cluster_dims_preferred.z + << ") [" << (cluster_dims_preferred.x * cluster_dims_preferred.y * cluster_dims_preferred.z) << "] which must be within (0,32]." << std::endl; + return false; + } + + // Total number of cluster cannot be greater than 32 (hardware requirement) + if (cluster_dims_fallback.x * cluster_dims_fallback.y * cluster_dims_fallback.z <= 0 || + cluster_dims_fallback.x * cluster_dims_fallback.y * cluster_dims_fallback.z > 32) { + std::cout << "Invalid cluster dimensions: Attempting to init fallback cluster (" << cluster_dims_fallback.x << "," << cluster_dims_fallback.y << "," << cluster_dims_fallback.z + << ") [" << (cluster_dims_fallback.x * cluster_dims_fallback.y * cluster_dims_fallback.z) << "] which must be within (0,32]." << std::endl; + return false; + } + + // Total grid dimensions must be within (2^32, 2^16, 2^16) + if (grid_dims.y > (1 << 16) || grid_dims.z > (1 << 16)) { + std::cout << "Invalid grid dimensions: Attempting to init grid dimensions (" << grid_dims.x << "," << grid_dims.y << "," << grid_dims.z + << ") which must be within (2^32, 2^16, 2^16)." << std::endl; + return false; + } + + // grid_dims should be divisible by cluster_dims_preferred + if (grid_dims.x % cluster_dims_preferred.x != 0 || + grid_dims.y % cluster_dims_preferred.y != 0 || + grid_dims.z % cluster_dims_preferred.z != 0) { + std::cout << "Invalid grid dimensions: Preferred cluster (" << cluster_dims_preferred.x << "," << cluster_dims_preferred.y << "," << cluster_dims_preferred.z + << ") does not divide Grid (" << grid_dims.x << "," << grid_dims.y << "," << grid_dims.z << ")." << std::endl; + return false; + } + + // cluster_dims_preferred should be divisible by cluster_dims_fallback + if (cluster_dims_preferred.x % cluster_dims_fallback.x != 0 || + cluster_dims_preferred.y % cluster_dims_fallback.y != 0 || + cluster_dims_preferred.z % cluster_dims_fallback.z != 0) { + std::cout << "Invalid cluster dimensions: Fallback cluster (" << cluster_dims_fallback.x << "," << cluster_dims_fallback.y << "," << cluster_dims_fallback.z + << ") does not divide Preferred cluster (" << cluster_dims_preferred.x << "," << cluster_dims_preferred.y << "," << cluster_dims_preferred.z << ")." << std::endl; + return false; + } + + // Both cluster dimenions should have the same depth + if (cluster_dims_preferred.z != cluster_dims_fallback.z) { + std::cout << "Invalid cluster dimensions: Fallback cluster (" << cluster_dims_fallback.x << "," << cluster_dims_fallback.y << "," << cluster_dims_fallback.z + << ") and Preferred cluster (" << cluster_dims_preferred.x << "," << cluster_dims_preferred.y << "," << cluster_dims_preferred.z << ") does not have the same depth." << std::endl; + return false; + } + + return true; +} +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/cluster_sm90.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/cluster_sm90.hpp new file mode 100644 index 0000000000000000000000000000000000000000..524a47efb11b04465e22541b8c2c12c6f4720af2 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/cluster_sm90.hpp @@ -0,0 +1,246 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && \ + ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8)))) +# define CUTE_ARCH_CLUSTER_SM90_ENABLED +#endif + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)) +# define CUTE_ARCH_ELECT_ONE_SM90_ENABLED +#endif + +namespace cute { + +CUTE_DEVICE void cluster_arrive_relaxed() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + asm volatile("barrier.cluster.arrive.relaxed.aligned;\n" : : ); +#else + CUTE_INVALID_CONTROL_PATH("CUTE_ARCH_CLUSTER_SM90_ENABLED is not defined"); +#endif +} + +CUTE_DEVICE void cluster_arrive() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + asm volatile("barrier.cluster.arrive.aligned;\n" : : ); +#else + CUTE_INVALID_CONTROL_PATH("CUTE_ARCH_CLUSTER_SM90_ENABLED is not defined"); +#endif +} + +CUTE_DEVICE void cluster_wait() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + asm volatile("barrier.cluster.wait.aligned;\n" : : ); +#else + CUTE_INVALID_CONTROL_PATH("CUTE_ARCH_CLUSTER_SM90_ENABLED is not defined"); +#endif +} + +CUTE_DEVICE void cluster_sync() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + cluster_arrive(); + cluster_wait(); +#else + CUTE_INVALID_CONTROL_PATH("CUTE_ARCH_CLUSTER_SM90_ENABLED is not defined"); +#endif +} + +// Returns the dim3 grid size in terms of number of clusters. +CUTE_DEVICE dim3 cluster_grid_dims() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + uint32_t x, y, z; + asm volatile("mov.u32 %0, %%nclusterid.x;\n" : "=r"(x) : ); + asm volatile("mov.u32 %0, %%nclusterid.y;\n" : "=r"(y) : ); + asm volatile("mov.u32 %0, %%nclusterid.z;\n" : "=r"(z) : ); + return {x, y, z}; +#elif defined(__CUDA_ARCH__) + // MSVC requires protecting use of gridDim with __CUDA_ARCH__. + return gridDim; +#elif defined(_MSC_VER) + CUTE_INVALID_CONTROL_PATH("cluster_grid_dims() can only be called on device"); + return {0, 0, 0}; +#else + return {0, 0, 0}; +#endif +} + +// Returns the dim3 cluster rank in the grid. +CUTE_DEVICE dim3 cluster_id_in_grid() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + uint32_t x, y, z; + asm volatile("mov.u32 %0, %%clusterid.x;\n" : "=r"(x) : ); + asm volatile("mov.u32 %0, %%clusterid.y;\n" : "=r"(y) : ); + asm volatile("mov.u32 %0, %%clusterid.z;\n" : "=r"(z) : ); + return {x, y, z}; +#elif defined(__CUDA_ARCH__) + // MSVC requires protecting use of blockIdx with __CUDA_ARCH__. + return blockIdx; +#elif defined(_MSC_VER) + CUTE_INVALID_CONTROL_PATH("cluster_id_in_grid() can only be called on device"); + return {0, 0, 0}; +#else + return {0, 0, 0}; +#endif +} + +// Returns the relative dim3 block rank local to the cluster. +CUTE_DEVICE dim3 block_id_in_cluster() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + uint32_t x, y, z; + asm volatile("mov.u32 %0, %%cluster_ctaid.x;\n" : "=r"(x) : ); + asm volatile("mov.u32 %0, %%cluster_ctaid.y;\n" : "=r"(y) : ); + asm volatile("mov.u32 %0, %%cluster_ctaid.z;\n" : "=r"(z) : ); + return {x, y, z}; +#else + return {0,0,0}; +#endif +} + +// Returns the dim3 cluster shape. +CUTE_DEVICE dim3 cluster_shape() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + uint32_t x, y, z; + asm volatile("mov.u32 %0, %%cluster_nctaid.x;\n" : "=r"(x) : ); + asm volatile("mov.u32 %0, %%cluster_nctaid.y;\n" : "=r"(y) : ); + asm volatile("mov.u32 %0, %%cluster_nctaid.z;\n" : "=r"(z) : ); + return {x, y, z}; +#else + return {1,1,1}; +#endif +} + +// Get 1D ctaid in a cluster. +CUTE_DEVICE uint32_t block_rank_in_cluster() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + uint32_t rank; + asm volatile("mov.u32 %0, %%cluster_ctarank;\n" : "=r"(rank) :); + return rank; +#else + return 0; +#endif +} + +// Set the destination block-ID in cluster for a given SMEM Address +CUTE_DEVICE uint32_t set_block_rank(uint32_t smemAddr, uint32_t rank) +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + uint32_t result; + asm volatile("mapa.shared::cluster.u32 %0, %1, %2;\n" + : "=r"(result) + : "r"(smemAddr), "r"(rank)); + return result; +#else + return smemAddr; +#endif +} + +// Elect one thread in the warp. The elected thread gets its predicate set to true, all others obtain false. +CUTE_HOST_DEVICE uint32_t elect_one_sync() +{ +#if defined(CUTE_ARCH_ELECT_ONE_SM90_ENABLED) + uint32_t pred = 0; + uint32_t laneid = 0; + asm volatile( + "{\n" + ".reg .b32 %%rx;\n" + ".reg .pred %%px;\n" + " elect.sync %%rx|%%px, %2;\n" + "@%%px mov.s32 %1, 1;\n" + " mov.s32 %0, %%rx;\n" + "}\n" + : "+r"(laneid), "+r"(pred) + : "r"(0xFFFFFFFF)); + return pred; +#elif defined(__CUDA_ARCH__) + return (threadIdx.x % 32) == 0; +#else + return true; +#endif +} + +struct ElectOneLaneIdReturnType { + uint32_t is_leader; + uint32_t leader_lane_id; +}; + +CUTE_HOST_DEVICE +ElectOneLaneIdReturnType +elect_one_leader_sync() +{ +#if defined(CUTE_ARCH_ELECT_ONE_SM90_ENABLED) + uint32_t pred = 0; + uint32_t laneid = 0; + asm volatile( + "{\n" + ".reg .b32 %%rx;\n" + ".reg .pred %%px;\n" + " elect.sync %%rx|%%px, %2;\n" + "@%%px mov.s32 %1, 1;\n" + " mov.s32 %0, %%rx;\n" + "}\n" + : "+r"(laneid), "+r"(pred) + : "r"(0xFFFFFFFF)); + return {pred, laneid}; +#elif defined(__CUDA_ARCH__) + return {(threadIdx.x % 32) == 0, 0}; +#else + return {true, 0}; +#endif +} + +// Store value to remote shared memory in the cluster +CUTE_DEVICE +void +store_shared_remote(uint32_t value, uint32_t smem_addr, uint32_t mbarrier_addr, uint32_t dst_cta_rank) +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + uint32_t dsmem_addr = set_block_rank(smem_addr, dst_cta_rank); + uint32_t remote_barrier_addr = set_block_rank(mbarrier_addr, dst_cta_rank); + asm volatile("st.async.shared::cluster.mbarrier::complete_tx::bytes.u32 [%0], %1, [%2];" + : : "r"(dsmem_addr), "r"(value), "r"(remote_barrier_addr)); +#endif +} + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/config.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/config.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7ce05f30d3acb0db4b9c9f1531e1e9666432db61 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/config.hpp @@ -0,0 +1,216 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTLASS_ARCH_MMA_SMxx_ENABLED + +// MMA SM90A +#if defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) +# define CUTE_ARCH_MMA_SM90A_ENABLED +#endif + +// TMA instructions +#if defined(CUTLASS_ARCH_MMA_SM90_ENABLED) +# define CUTE_ARCH_TMA_SM90_ENABLED +#endif + +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_ENABLED) +# define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED +#endif + +// STSM +#if defined(CUTLASS_ARCH_MMA_SM90_ENABLED) +# define CUTE_ARCH_STSM_SM90_ENABLED +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM103A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM121A_ENABLED)) +# define CUTE_ARCH_TMA_SM90_ENABLED +# define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED +# define CUTE_ARCH_STSM_SM90_ENABLED +#endif + +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM103A_ENABLED)) +# define CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED +#endif + +#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM103A_ENABLED) +# define CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED +#endif + +#if (defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM103F_ENABLED)) +# define CUTE_ARCH_TMA_SM90_ENABLED +# define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED +# define CUTE_ARCH_STSM_SM90_ENABLED +# define CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED +#endif + +#if defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM103F_ENABLED) +# define CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED +#endif + +#if (defined(CUTLASS_ARCH_MMA_SM120F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121F_ENABLED)) +# define CUTE_ARCH_TMA_SM90_ENABLED +# define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED +# define CUTE_ARCH_STSM_SM90_ENABLED +#endif + + +// SM110 specific configs +#if (defined(CUTLASS_ARCH_MMA_SM110A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM110F_ENABLED)) +# define CUTE_ARCH_TMA_SM90_ENABLED +# define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED +# define CUTE_ARCH_STSM_SM90_ENABLED +# define CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_S8_MMA_ENABLED +# define CUTE_ARCH_LDSM_SM100A_ENABLED +# define CUTE_ARCH_STSM_SM100A_ENABLED +# define CUTE_ARCH_TCGEN05_TMEM_ENABLED +# define CUTE_ARCH_TMA_SM100_ENABLED +# define CUTE_ARCH_LOAD256_SM100A_ENABLED +# define CUTE_ARCH_STORE256_SM100A_ENABLED +# define CUTE_ARCH_FLOAT2_MATH_ENABLED +#endif + +#if (defined(CUTLASS_ARCH_MMA_SM110A_ENABLED)) +# define CUTE_ARCH_TCGEN05_S8_MMA_ENABLED +#endif + +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED)) +# define CUTE_ARCH_TCGEN05_S8_MMA_ENABLED +#endif + +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM103A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121A_ENABLED)) +# define CUTE_ARCH_LDSM_SM100A_ENABLED +# define CUTE_ARCH_STSM_SM100A_ENABLED +#endif + +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM103A_ENABLED)) +# define CUTE_ARCH_TCGEN05_TMEM_ENABLED +#endif + +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM103A_ENABLED)) +# define CUTE_ARCH_TMA_SM100_ENABLED +#endif + +// {add, mul, fma}.f32x2 PTX +#if defined(CUTLASS_ARCH_MMA_SM100_ENABLED) || defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) + // Enable CuTe MMA Atoms +# define CUTE_ARCH_FFMA2_SM100_ENABLED + // Enable f32x2 PTX generation +# define CUTE_ARCH_FLOAT2_MATH_ENABLED +#endif + +#if (defined(CUTLASS_ARCH_MMA_SM120_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM121_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121A_ENABLED)) +# define CUTE_ARCH_MMA_SM120_ENABLED +# define CUTE_ARCH_TMA_SM120_ENABLED +#endif + +#if (defined(CUTLASS_ARCH_MMA_SM120_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120A_ENABLED)) +# if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)) +# define CUTE_ARCH_F8F6F4_MMA_ENABLED +# define CUTE_ARCH_MXF8F6F4_MMA_ENABLED +# define CUTE_ARCH_MXF4NVF4_2X_UE8M0_MMA_ENABLED +# define CUTE_ARCH_MXF4NVF4_4X_UE4M3_MMA_ENABLED +# endif +#endif + +#if (defined(CUTLASS_ARCH_MMA_SM121_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121A_ENABLED)) +# if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)) +# define CUTE_ARCH_F8F6F4_MMA_ENABLED +# define CUTE_ARCH_MXF8F6F4_MMA_ENABLED +# define CUTE_ARCH_MXF4NVF4_2X_UE8M0_MMA_ENABLED +# define CUTE_ARCH_MXF4NVF4_4X_UE4M3_MMA_ENABLED +# endif +#endif + +#if defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM103F_ENABLED) +# define CUTE_ARCH_LDSM_SM100A_ENABLED +# define CUTE_ARCH_STSM_SM100A_ENABLED +# define CUTE_ARCH_TCGEN05_TMEM_ENABLED +# define CUTE_ARCH_TMA_SM100_ENABLED +# define CUTE_ARCH_FLOAT2_MATH_ENABLED +#endif + +#if defined(CUTLASS_ARCH_MMA_SM101F_ENABLED) +# define CUTE_ARCH_LDSM_SM100A_ENABLED +# define CUTE_ARCH_STSM_SM100A_ENABLED +# define CUTE_ARCH_TCGEN05_TMEM_ENABLED +# define CUTE_ARCH_TMA_SM100_ENABLED +#endif + +#if (defined(CUTLASS_ARCH_MMA_SM120F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121F_ENABLED)) +# define CUTE_ARCH_LDSM_SM100A_ENABLED +# define CUTE_ARCH_STSM_SM100A_ENABLED +#endif + +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM103A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM103F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM121A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121F_ENABLED)) +# if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)) +# define CUTE_ARCH_LOAD256_SM100A_ENABLED +# define CUTE_ARCH_STORE256_SM100A_ENABLED +# endif +#endif + +// {add, mul, fma}.f32x2 PTX +#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) + #define CUTE_ARCH_FLOAT2_MATH_ENABLED +#endif + +#if defined(CUTLASS_ARCH_MMA_SM103_ENABLED) || defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) +# define CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ULTRA_ENABLED +#endif + diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8b62fa91f9937afe5dd1dcedfd983743dd6ba076 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy.hpp @@ -0,0 +1,107 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute +{ + +// +// Direct Copy for any specific types +// + +template +struct UniversalCopy +{ + using SRegisters = S[1]; + using DRegisters = D[1]; + + // Sanity + static_assert(sizeof_bits_v >= 8); + static_assert(sizeof_bits_v >= 8); + + CUTE_HOST_DEVICE static constexpr void + copy(S const& src, + D & dst) + { + dst = src; + } +}; + +// +// Placeholder for the copy algorithm's stronger auto-vectorizing behavior +// that assumes alignment of pointers and dynamic layouts up to MaxVecBits +// + +template +struct AutoVectorizingCopyWithAssumedAlignment + : UniversalCopy> +{ + static_assert(MaxVecBits == 8 || MaxVecBits == 16 || MaxVecBits == 32 || MaxVecBits == 64 || MaxVecBits == 128, + "Expected MaxVecBits to be 8 or 16 or 32 or 64 or 128 for alignment and performance."); +}; + +// +// AutoVectorizingCopy alias assumes maximal alignment of pointers and dynamic strides. +// If this is not the case then AutoVectorizingCopyWithAssumedAlignment should be used instead +// + +using AutoVectorizingCopy = AutoVectorizingCopyWithAssumedAlignment<128>; + +// +// DefaultCopy alias does not assume alignment of pointers or dynamic strides. +// + +using DefaultCopy = AutoVectorizingCopyWithAssumedAlignment<8>; + +// +// Copy policy automatically selecting between +// UniversalCopy and cp.async , based on type and memory space. +// +struct AutoCopyAsync {}; + +// +// Global memory prefetch into L2 +// + +CUTE_HOST_DEVICE static void +prefetch(void const* gmem_ptr) +{ +#if defined(__CUDA_ARCH__) + asm volatile("prefetch.global.L2 [%0];\n" : : "l"(gmem_ptr) : "memory"); +#endif +} + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm100.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm100.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f8a9a67b1d86dc27a30eee2a133ff0b6b696395b --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm100.hpp @@ -0,0 +1,7615 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cute { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Global Memory Load and Store PTX definitions +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_LOAD_256bit_CACHE_NOALLOCATION +{ + using SRegisters = uint256_t[1]; + using DRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + copy(uint256_t const& gmem_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3, + uint32_t& dst4, uint32_t& dst5, uint32_t& dst6, uint32_t& dst7) + { + #if defined(CUTE_ARCH_LOAD256_SM100A_ENABLED) + asm volatile("ld.global.L1::no_allocate.v8.f32 {%0, %1, %2, %3, %4, %5, %6, %7}, [%8];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3), "=r"(dst4), "=r"(dst5), "=r"(dst6), "=r"(dst7) + : "l"(&gmem_addr) ); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use LOAD.256 without CUTE_ARCH_LOAD256_SM100A_ENABLED."); + #endif + } +}; + +struct SM100_STORE_256bit_CACHE_NOALLOCATION +{ + using SRegisters = uint32_t[8]; + using DRegisters = uint256_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& src4, uint32_t const& src5, uint32_t const& src6, uint32_t const& src7, + uint256_t& gmem_addr) + { + #if defined(CUTE_ARCH_STORE256_SM100A_ENABLED) + asm volatile("st.global.L1::no_allocate.v8.f32 [%0], {%1, %2, %3, %4, %5, %6, %7, %8};\n" + :: "l"(&gmem_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), "r"(src4), "r"(src5), "r"(src6), "r"(src7)); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use stg.256 without CUTE_ARCH_STORE256_SM100A_ENABLED."); + #endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// LDSM PTX definitions +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_U8x8_LDSM_T +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_LDSM_SM100A_ENABLED) + uint32_t tmp0, tmp1; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.m16n16.x1.trans.shared.b8 {%0, %1}, [%2];\n" + : "=r"(reinterpret_cast(tmp0)), "=r"(reinterpret_cast(tmp1)) + : "r"(smem_int_ptr)); + // RefLayout of ldmatrix.m16n16.x1.trans won't match stmatrix.m16n8.x2.trans without additional transformations + // Do this here so we don't need to add an additional reg to reg copy at the collective layer + uchar4& tmp0_ = reinterpret_cast(tmp0); + uchar4& tmp1_ = reinterpret_cast(tmp1); + uchar4 dst0_{tmp0_.x, tmp0_.y, tmp1_.x, tmp1_.y}; + uchar4 dst1_{tmp0_.z, tmp0_.w, tmp1_.z, tmp1_.w}; + dst0 = reinterpret_cast(dst0_); + dst1 = reinterpret_cast(dst1_); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM100A_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_U8x16_LDSM_T +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_LDSM_SM100A_ENABLED) + uint32_t tmp0, tmp1, tmp2, tmp3; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.m16n16.x2.trans.shared.b8 {%0, %1, %2, %3}, [%4];\n" + : "=r"(reinterpret_cast(tmp0)), "=r"(reinterpret_cast(tmp1)), + "=r"(reinterpret_cast(tmp2)), "=r"(reinterpret_cast(tmp3)) + : "r"(smem_int_ptr)); + uchar4& tmp0_ = reinterpret_cast(tmp0); + uchar4& tmp1_ = reinterpret_cast(tmp1); + uchar4& tmp2_ = reinterpret_cast(tmp2); + uchar4& tmp3_ = reinterpret_cast(tmp3); + uchar4 dst0_{tmp0_.x, tmp0_.y, tmp1_.x, tmp1_.y}; + uchar4 dst1_{tmp0_.z, tmp0_.w, tmp1_.z, tmp1_.w}; + uchar4 dst2_{tmp2_.x, tmp2_.y, tmp3_.x, tmp3_.y}; + uchar4 dst3_{tmp2_.z, tmp2_.w, tmp3_.z, tmp3_.w}; + dst0 = reinterpret_cast(dst0_); + dst1 = reinterpret_cast(dst1_); + dst2 = reinterpret_cast(dst2_); + dst3 = reinterpret_cast(dst3_); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM100A_ENABLED."); +#endif + } +}; + +struct SM100_SU4_DU8x16_x1_LDSM_N +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0) + { +#if defined(CUTE_ARCH_LDSM_SM100A_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.m8n16.x1.shared.b8x16.b4x16_p64 {%0}, [%1];\n" + : "=r"(dst0) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM100A_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_SU6_DU8x16_x1_LDSM_N +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0) + { +#if defined(CUTE_ARCH_LDSM_SM100A_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.m8n16.x1.shared.b8x16.b6x16_p32 {%0}, [%1];\n" + : "=r"(dst0) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM100A_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_SU4_DU8x16_x2_LDSM_N +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_LDSM_SM100A_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.m8n16.x2.shared.b8x16.b4x16_p64 {%0, %1}, [%2];\n" + : "=r"(dst0), "=r"(dst1) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM100A_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_SU6_DU8x16_x2_LDSM_N +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_LDSM_SM100A_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.m8n16.x2.shared.b8x16.b6x16_p32 {%0, %1}, [%2];\n" + : "=r"(dst0), "=r"(dst1) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM100A_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_SU4_DU8x16_x4_LDSM_N +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_LDSM_SM100A_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.m8n16.x4.shared.b8x16.b4x16_p64 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM100A_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_SU6_DU8x16_x4_LDSM_N +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_LDSM_SM100A_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.m8n16.x4.shared.b8x16.b6x16_p32 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM100A_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// STSM PTX definitions +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_U8x4_STSM_T +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src, + uint128_t& smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM100A_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.m16n8.x1.trans.shared.b8 [%0], {%1};\n" + :: "r"(smem_int_ptr), + "r"(src)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM100A_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_U8x8_STSM_T +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, + uint128_t& smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM100A_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.m16n8.x2.trans.shared.b8 [%0], {%1, %2};\n" + :: "r"(smem_int_ptr), + "r"(src0), "r"(src1)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM100A_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_U8x16_STSM_T +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint128_t& smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM100A_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.m16n8.x4.trans.shared.b8 [%0], {%1, %2, %3, %4};\n" + :: "r"(smem_int_ptr), + "r"(src0), "r"(src1), "r"(src2), "r"(src3)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM100A_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// UTCCP PTX definitions +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace SM100::TMEM::UTCCP { + +// 128 data path lanes, 256-bit pattern, 1cta mode +struct SM100_UTCCP_128dp256bit_1cta +{ + using SRegisters = uint64_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint64_t const& src_addr, uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.cp.cta_group::1.128x256b [%0], %1;" + : + : "r"(dst_addr) "l"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use UTCCP without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +// 128 data path lanes, 256-bit pattern, 2cta mode +struct SM100_UTCCP_128dp256bit_2cta +{ + using SRegisters = uint64_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint64_t const& src_addr, uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.cp.cta_group::2.128x256b [%0], %1;" + : + : "r"(dst_addr) "l"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use UTCCP without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +struct SM100_UTCCP_128dp128bit_1cta +{ + using SRegisters = uint64_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint64_t const& src_addr, uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.cp.cta_group::1.128x128b [%0], %1;" + : + : "r"(dst_addr) "l"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use UTCCP without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +struct SM100_UTCCP_128dp128bit_2cta +{ + using SRegisters = uint64_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint64_t const& src_addr, uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.cp.cta_group::2.128x128b [%0], %1;" + : + : "r"(dst_addr) "l"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use UTCCP without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + + +// 4 data path lanes, 256-bit pattern, 1cta mode +struct SM100_UTCCP_4dp256bit_1cta +{ + using SRegisters = uint64_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint64_t const& src_addr, uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.cp.cta_group::1.4x256b [%0], %1;" + : + : "r"(dst_addr) "l"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use UTCCP without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +// 4 data path lanes, 256-bit pattern, 2cta mode +struct SM100_UTCCP_4dp256bit_2cta +{ + using SRegisters = uint64_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint64_t const& src_addr, uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.cp.cta_group::2.4x256b [%0], %1;" + : + : "r"(dst_addr) "l"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use UTCCP without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +// 4x32 data path lanes (broadcast), 128-bit pattern, 1cta mode +struct SM100_UTCCP_4x32dp128bit_1cta +{ + using SRegisters = uint64_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint64_t const& src_addr, uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.cp.cta_group::1.32x128b.warpx4 [%0], %1;" + : + : "r"(dst_addr) "l"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use UTCCP without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +// 4x32 data path lanes (broadcast), 128-bit pattern, 2cta mode +struct SM100_UTCCP_4x32dp128bit_2cta +{ + using SRegisters = uint64_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint64_t const& src_addr, uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.cp.cta_group::2.32x128b.warpx4 [%0], %1;" + : + : "r"(dst_addr) "l"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use UTCCP without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +// 2x64 data path lanes (broadcast like 4x32dp), 128-bit pattern, 1cta mode +struct SM100_UTCCP_2x64dp128bitlw0213_1cta +{ + using SRegisters = uint64_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint64_t const& src_addr, uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.cp.cta_group::1.64x128b.warpx2::02_13 [%0], %1;" + : + : "r"(dst_addr) "l"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use UTCCP without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +// 2x64 data path lanes (broadcast like 4x32dp), 128-bit pattern, 2cta mode +struct SM100_UTCCP_2x64dp128bitlw0213_2cta +{ + using SRegisters = uint64_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint64_t const& src_addr, uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.cp.cta_group::2.64x128b.warpx2::02_13 [%0], %1;" + : + : "r"(dst_addr) "l"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use UTCCP without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +// 2x64 data path lanes (broadcast seperately in upper and lower 64dp), 128-bit pattern, 1cta mode +// data_row[0:31] -> DP[0:63] +// data_row[32:63] -> DP[64:127] +struct SM100_UTCCP_2x64dp128bitlw0123_1cta +{ + using SRegisters = uint64_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint64_t const& src_addr, uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.cp.cta_group::1.64x128b.warpx2::01_23 [%0], %1;" + : + : "r"(dst_addr) "l"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use UTCCP without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +// 2x64 data path lanes (broadcast seperately in upper and lower 64dp), 128-bit pattern, 2cta mode +// data_row[0:31] -> DP[0:63] +// data_row[32:63] -> DP[64:127] +struct SM100_UTCCP_2x64dp128bitlw0123_2cta +{ + using SRegisters = uint64_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint64_t const& src_addr, uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.cp.cta_group::2.64x128b.warpx2::01_23 [%0], %1;" + : + : "r"(dst_addr) "l"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use UTCCP without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +} // end namespace SM100::TMEM::UTCCP + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace SM100::TMEM::LOAD { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// TMEM LOAD PTX definitions +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 1 times +struct SM100_TMEM_LOAD_16dp256b1x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x256b.x1.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 1 times, packed 16b read +struct SM100_TMEM_LOAD_16dp256b1x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x256b.x1.pack::16b.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 2 times +struct SM100_TMEM_LOAD_16dp256b2x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3, + uint32_t& dst4, uint32_t& dst5, uint32_t& dst6, uint32_t& dst7) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x256b.x2.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3), + "=r"(dst4), "=r"(dst5), "=r"(dst6), "=r"(dst7) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 2 times, packed 16b read +struct SM100_TMEM_LOAD_16dp256b2x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3, + uint32_t& dst4, uint32_t& dst5, uint32_t& dst6, uint32_t& dst7) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x256b.x2.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3), + "=r"(dst4), "=r"(dst5), "=r"(dst6), "=r"(dst7) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 4 times +struct SM100_TMEM_LOAD_16dp256b4x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x256b.x4.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15}," + "[%16];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 4 times, packed 16b read +struct SM100_TMEM_LOAD_16dp256b4x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x256b.x4.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15}," + "[%16];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 8 times +struct SM100_TMEM_LOAD_16dp256b8x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x256b.x8.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 8 times, packed 16b read +struct SM100_TMEM_LOAD_16dp256b8x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x256b.x8.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 16 times +struct SM100_TMEM_LOAD_16dp256b16x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31, + uint32_t& dst32, uint32_t& dst33, uint32_t& dst34, uint32_t& dst35, + uint32_t& dst36, uint32_t& dst37, uint32_t& dst38, uint32_t& dst39, + uint32_t& dst40, uint32_t& dst41, uint32_t& dst42, uint32_t& dst43, + uint32_t& dst44, uint32_t& dst45, uint32_t& dst46, uint32_t& dst47, + uint32_t& dst48, uint32_t& dst49, uint32_t& dst50, uint32_t& dst51, + uint32_t& dst52, uint32_t& dst53, uint32_t& dst54, uint32_t& dst55, + uint32_t& dst56, uint32_t& dst57, uint32_t& dst58, uint32_t& dst59, + uint32_t& dst60, uint32_t& dst61, uint32_t& dst62, uint32_t& dst63) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x256b.x16.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31), + "=r"(dst32), "=r"(dst33), "=r"(dst34), "=r"(dst35), + "=r"(dst36), "=r"(dst37), "=r"(dst38), "=r"(dst39), + "=r"(dst40), "=r"(dst41), "=r"(dst42), "=r"(dst43), + "=r"(dst44), "=r"(dst45), "=r"(dst46), "=r"(dst47), + "=r"(dst48), "=r"(dst49), "=r"(dst50), "=r"(dst51), + "=r"(dst52), "=r"(dst53), "=r"(dst54), "=r"(dst55), + "=r"(dst56), "=r"(dst57), "=r"(dst58), "=r"(dst59), + "=r"(dst60), "=r"(dst61), "=r"(dst62), "=r"(dst63) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 16 times, packed 16b read +struct SM100_TMEM_LOAD_16dp256b16x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31, + uint32_t& dst32, uint32_t& dst33, uint32_t& dst34, uint32_t& dst35, + uint32_t& dst36, uint32_t& dst37, uint32_t& dst38, uint32_t& dst39, + uint32_t& dst40, uint32_t& dst41, uint32_t& dst42, uint32_t& dst43, + uint32_t& dst44, uint32_t& dst45, uint32_t& dst46, uint32_t& dst47, + uint32_t& dst48, uint32_t& dst49, uint32_t& dst50, uint32_t& dst51, + uint32_t& dst52, uint32_t& dst53, uint32_t& dst54, uint32_t& dst55, + uint32_t& dst56, uint32_t& dst57, uint32_t& dst58, uint32_t& dst59, + uint32_t& dst60, uint32_t& dst61, uint32_t& dst62, uint32_t& dst63) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x256b.x16.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31), + "=r"(dst32), "=r"(dst33), "=r"(dst34), "=r"(dst35), + "=r"(dst36), "=r"(dst37), "=r"(dst38), "=r"(dst39), + "=r"(dst40), "=r"(dst41), "=r"(dst42), "=r"(dst43), + "=r"(dst44), "=r"(dst45), "=r"(dst46), "=r"(dst47), + "=r"(dst48), "=r"(dst49), "=r"(dst50), "=r"(dst51), + "=r"(dst52), "=r"(dst53), "=r"(dst54), "=r"(dst55), + "=r"(dst56), "=r"(dst57), "=r"(dst58), "=r"(dst59), + "=r"(dst60), "=r"(dst61), "=r"(dst62), "=r"(dst63) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 32 times +struct SM100_TMEM_LOAD_16dp256b32x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst000, uint32_t& dst001, uint32_t& dst002, uint32_t& dst003, + uint32_t& dst004, uint32_t& dst005, uint32_t& dst006, uint32_t& dst007, + uint32_t& dst008, uint32_t& dst009, uint32_t& dst010, uint32_t& dst011, + uint32_t& dst012, uint32_t& dst013, uint32_t& dst014, uint32_t& dst015, + uint32_t& dst016, uint32_t& dst017, uint32_t& dst018, uint32_t& dst019, + uint32_t& dst020, uint32_t& dst021, uint32_t& dst022, uint32_t& dst023, + uint32_t& dst024, uint32_t& dst025, uint32_t& dst026, uint32_t& dst027, + uint32_t& dst028, uint32_t& dst029, uint32_t& dst030, uint32_t& dst031, + uint32_t& dst032, uint32_t& dst033, uint32_t& dst034, uint32_t& dst035, + uint32_t& dst036, uint32_t& dst037, uint32_t& dst038, uint32_t& dst039, + uint32_t& dst040, uint32_t& dst041, uint32_t& dst042, uint32_t& dst043, + uint32_t& dst044, uint32_t& dst045, uint32_t& dst046, uint32_t& dst047, + uint32_t& dst048, uint32_t& dst049, uint32_t& dst050, uint32_t& dst051, + uint32_t& dst052, uint32_t& dst053, uint32_t& dst054, uint32_t& dst055, + uint32_t& dst056, uint32_t& dst057, uint32_t& dst058, uint32_t& dst059, + uint32_t& dst060, uint32_t& dst061, uint32_t& dst062, uint32_t& dst063, + uint32_t& dst064, uint32_t& dst065, uint32_t& dst066, uint32_t& dst067, + uint32_t& dst068, uint32_t& dst069, uint32_t& dst070, uint32_t& dst071, + uint32_t& dst072, uint32_t& dst073, uint32_t& dst074, uint32_t& dst075, + uint32_t& dst076, uint32_t& dst077, uint32_t& dst078, uint32_t& dst079, + uint32_t& dst080, uint32_t& dst081, uint32_t& dst082, uint32_t& dst083, + uint32_t& dst084, uint32_t& dst085, uint32_t& dst086, uint32_t& dst087, + uint32_t& dst088, uint32_t& dst089, uint32_t& dst090, uint32_t& dst091, + uint32_t& dst092, uint32_t& dst093, uint32_t& dst094, uint32_t& dst095, + uint32_t& dst096, uint32_t& dst097, uint32_t& dst098, uint32_t& dst099, + uint32_t& dst100, uint32_t& dst101, uint32_t& dst102, uint32_t& dst103, + uint32_t& dst104, uint32_t& dst105, uint32_t& dst106, uint32_t& dst107, + uint32_t& dst108, uint32_t& dst109, uint32_t& dst110, uint32_t& dst111, + uint32_t& dst112, uint32_t& dst113, uint32_t& dst114, uint32_t& dst115, + uint32_t& dst116, uint32_t& dst117, uint32_t& dst118, uint32_t& dst119, + uint32_t& dst120, uint32_t& dst121, uint32_t& dst122, uint32_t& dst123, + uint32_t& dst124, uint32_t& dst125, uint32_t& dst126, uint32_t& dst127) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x256b.x32.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63," + "%64, %65, %66, %67," + "%68, %69, %70, %71," + "%72, %73, %74, %75," + "%76, %77, %78, %79," + "%80, %81, %82, %83," + "%84, %85, %86, %87," + "%88, %89, %90, %91," + "%92, %93, %94, %95," + "%96, %97, %98, %99," + "%100, %101, %102, %103," + "%104, %105, %106, %107," + "%108, %109, %110, %111," + "%112, %113, %114, %115," + "%116, %117, %118, %119," + "%120, %121, %122, %123," + "%124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst000), "=r"(dst001), "=r"(dst002), "=r"(dst003), + "=r"(dst004), "=r"(dst005), "=r"(dst006), "=r"(dst007), + "=r"(dst008), "=r"(dst009), "=r"(dst010), "=r"(dst011), + "=r"(dst012), "=r"(dst013), "=r"(dst014), "=r"(dst015), + "=r"(dst016), "=r"(dst017), "=r"(dst018), "=r"(dst019), + "=r"(dst020), "=r"(dst021), "=r"(dst022), "=r"(dst023), + "=r"(dst024), "=r"(dst025), "=r"(dst026), "=r"(dst027), + "=r"(dst028), "=r"(dst029), "=r"(dst030), "=r"(dst031), + "=r"(dst032), "=r"(dst033), "=r"(dst034), "=r"(dst035), + "=r"(dst036), "=r"(dst037), "=r"(dst038), "=r"(dst039), + "=r"(dst040), "=r"(dst041), "=r"(dst042), "=r"(dst043), + "=r"(dst044), "=r"(dst045), "=r"(dst046), "=r"(dst047), + "=r"(dst048), "=r"(dst049), "=r"(dst050), "=r"(dst051), + "=r"(dst052), "=r"(dst053), "=r"(dst054), "=r"(dst055), + "=r"(dst056), "=r"(dst057), "=r"(dst058), "=r"(dst059), + "=r"(dst060), "=r"(dst061), "=r"(dst062), "=r"(dst063), + "=r"(dst064), "=r"(dst065), "=r"(dst066), "=r"(dst067), + "=r"(dst068), "=r"(dst069), "=r"(dst070), "=r"(dst071), + "=r"(dst072), "=r"(dst073), "=r"(dst074), "=r"(dst075), + "=r"(dst076), "=r"(dst077), "=r"(dst078), "=r"(dst079), + "=r"(dst080), "=r"(dst081), "=r"(dst082), "=r"(dst083), + "=r"(dst084), "=r"(dst085), "=r"(dst086), "=r"(dst087), + "=r"(dst088), "=r"(dst089), "=r"(dst090), "=r"(dst091), + "=r"(dst092), "=r"(dst093), "=r"(dst094), "=r"(dst095), + "=r"(dst096), "=r"(dst097), "=r"(dst098), "=r"(dst099), + "=r"(dst100), "=r"(dst101), "=r"(dst102), "=r"(dst103), + "=r"(dst104), "=r"(dst105), "=r"(dst106), "=r"(dst107), + "=r"(dst108), "=r"(dst109), "=r"(dst110), "=r"(dst111), + "=r"(dst112), "=r"(dst113), "=r"(dst114), "=r"(dst115), + "=r"(dst116), "=r"(dst117), "=r"(dst118), "=r"(dst119), + "=r"(dst120), "=r"(dst121), "=r"(dst122), "=r"(dst123), + "=r"(dst124), "=r"(dst125), "=r"(dst126), "=r"(dst127) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 32 times, packed 16b read +struct SM100_TMEM_LOAD_16dp256b32x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst000, uint32_t& dst001, uint32_t& dst002, uint32_t& dst003, + uint32_t& dst004, uint32_t& dst005, uint32_t& dst006, uint32_t& dst007, + uint32_t& dst008, uint32_t& dst009, uint32_t& dst010, uint32_t& dst011, + uint32_t& dst012, uint32_t& dst013, uint32_t& dst014, uint32_t& dst015, + uint32_t& dst016, uint32_t& dst017, uint32_t& dst018, uint32_t& dst019, + uint32_t& dst020, uint32_t& dst021, uint32_t& dst022, uint32_t& dst023, + uint32_t& dst024, uint32_t& dst025, uint32_t& dst026, uint32_t& dst027, + uint32_t& dst028, uint32_t& dst029, uint32_t& dst030, uint32_t& dst031, + uint32_t& dst032, uint32_t& dst033, uint32_t& dst034, uint32_t& dst035, + uint32_t& dst036, uint32_t& dst037, uint32_t& dst038, uint32_t& dst039, + uint32_t& dst040, uint32_t& dst041, uint32_t& dst042, uint32_t& dst043, + uint32_t& dst044, uint32_t& dst045, uint32_t& dst046, uint32_t& dst047, + uint32_t& dst048, uint32_t& dst049, uint32_t& dst050, uint32_t& dst051, + uint32_t& dst052, uint32_t& dst053, uint32_t& dst054, uint32_t& dst055, + uint32_t& dst056, uint32_t& dst057, uint32_t& dst058, uint32_t& dst059, + uint32_t& dst060, uint32_t& dst061, uint32_t& dst062, uint32_t& dst063, + uint32_t& dst064, uint32_t& dst065, uint32_t& dst066, uint32_t& dst067, + uint32_t& dst068, uint32_t& dst069, uint32_t& dst070, uint32_t& dst071, + uint32_t& dst072, uint32_t& dst073, uint32_t& dst074, uint32_t& dst075, + uint32_t& dst076, uint32_t& dst077, uint32_t& dst078, uint32_t& dst079, + uint32_t& dst080, uint32_t& dst081, uint32_t& dst082, uint32_t& dst083, + uint32_t& dst084, uint32_t& dst085, uint32_t& dst086, uint32_t& dst087, + uint32_t& dst088, uint32_t& dst089, uint32_t& dst090, uint32_t& dst091, + uint32_t& dst092, uint32_t& dst093, uint32_t& dst094, uint32_t& dst095, + uint32_t& dst096, uint32_t& dst097, uint32_t& dst098, uint32_t& dst099, + uint32_t& dst100, uint32_t& dst101, uint32_t& dst102, uint32_t& dst103, + uint32_t& dst104, uint32_t& dst105, uint32_t& dst106, uint32_t& dst107, + uint32_t& dst108, uint32_t& dst109, uint32_t& dst110, uint32_t& dst111, + uint32_t& dst112, uint32_t& dst113, uint32_t& dst114, uint32_t& dst115, + uint32_t& dst116, uint32_t& dst117, uint32_t& dst118, uint32_t& dst119, + uint32_t& dst120, uint32_t& dst121, uint32_t& dst122, uint32_t& dst123, + uint32_t& dst124, uint32_t& dst125, uint32_t& dst126, uint32_t& dst127) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x256b.x32.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63," + "%64, %65, %66, %67," + "%68, %69, %70, %71," + "%72, %73, %74, %75," + "%76, %77, %78, %79," + "%80, %81, %82, %83," + "%84, %85, %86, %87," + "%88, %89, %90, %91," + "%92, %93, %94, %95," + "%96, %97, %98, %99," + "%100, %101, %102, %103," + "%104, %105, %106, %107," + "%108, %109, %110, %111," + "%112, %113, %114, %115," + "%116, %117, %118, %119," + "%120, %121, %122, %123," + "%124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst000), "=r"(dst001), "=r"(dst002), "=r"(dst003), + "=r"(dst004), "=r"(dst005), "=r"(dst006), "=r"(dst007), + "=r"(dst008), "=r"(dst009), "=r"(dst010), "=r"(dst011), + "=r"(dst012), "=r"(dst013), "=r"(dst014), "=r"(dst015), + "=r"(dst016), "=r"(dst017), "=r"(dst018), "=r"(dst019), + "=r"(dst020), "=r"(dst021), "=r"(dst022), "=r"(dst023), + "=r"(dst024), "=r"(dst025), "=r"(dst026), "=r"(dst027), + "=r"(dst028), "=r"(dst029), "=r"(dst030), "=r"(dst031), + "=r"(dst032), "=r"(dst033), "=r"(dst034), "=r"(dst035), + "=r"(dst036), "=r"(dst037), "=r"(dst038), "=r"(dst039), + "=r"(dst040), "=r"(dst041), "=r"(dst042), "=r"(dst043), + "=r"(dst044), "=r"(dst045), "=r"(dst046), "=r"(dst047), + "=r"(dst048), "=r"(dst049), "=r"(dst050), "=r"(dst051), + "=r"(dst052), "=r"(dst053), "=r"(dst054), "=r"(dst055), + "=r"(dst056), "=r"(dst057), "=r"(dst058), "=r"(dst059), + "=r"(dst060), "=r"(dst061), "=r"(dst062), "=r"(dst063), + "=r"(dst064), "=r"(dst065), "=r"(dst066), "=r"(dst067), + "=r"(dst068), "=r"(dst069), "=r"(dst070), "=r"(dst071), + "=r"(dst072), "=r"(dst073), "=r"(dst074), "=r"(dst075), + "=r"(dst076), "=r"(dst077), "=r"(dst078), "=r"(dst079), + "=r"(dst080), "=r"(dst081), "=r"(dst082), "=r"(dst083), + "=r"(dst084), "=r"(dst085), "=r"(dst086), "=r"(dst087), + "=r"(dst088), "=r"(dst089), "=r"(dst090), "=r"(dst091), + "=r"(dst092), "=r"(dst093), "=r"(dst094), "=r"(dst095), + "=r"(dst096), "=r"(dst097), "=r"(dst098), "=r"(dst099), + "=r"(dst100), "=r"(dst101), "=r"(dst102), "=r"(dst103), + "=r"(dst104), "=r"(dst105), "=r"(dst106), "=r"(dst107), + "=r"(dst108), "=r"(dst109), "=r"(dst110), "=r"(dst111), + "=r"(dst112), "=r"(dst113), "=r"(dst114), "=r"(dst115), + "=r"(dst116), "=r"(dst117), "=r"(dst118), "=r"(dst119), + "=r"(dst120), "=r"(dst121), "=r"(dst122), "=r"(dst123), + "=r"(dst124), "=r"(dst125), "=r"(dst126), "=r"(dst127) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 1 times +struct SM100_TMEM_LOAD_16dp128b1x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x1.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst0), "=r"(dst1) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 1 times, packed 16b read +struct SM100_TMEM_LOAD_16dp128b1x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x1.pack::16b.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst0), "=r"(dst1) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 2 times +struct SM100_TMEM_LOAD_16dp128b2x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x2.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 2 times, packed 16b read +struct SM100_TMEM_LOAD_16dp128b2x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x2.pack::16b.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 4 times +struct SM100_TMEM_LOAD_16dp128b4x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3, + uint32_t& dst4, uint32_t& dst5, uint32_t& dst6, uint32_t& dst7) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x4.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3), + "=r"(dst4), "=r"(dst5), "=r"(dst6), "=r"(dst7) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 4 times, packed 16b read +struct SM100_TMEM_LOAD_16dp128b4x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3, + uint32_t& dst4, uint32_t& dst5, uint32_t& dst6, uint32_t& dst7) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x4.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3), + "=r"(dst4), "=r"(dst5), "=r"(dst6), "=r"(dst7) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 8 times +struct SM100_TMEM_LOAD_16dp128b8x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x8.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15}," + "[%16];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 8 times, packed 16b read +struct SM100_TMEM_LOAD_16dp128b8x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x8.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15}," + "[%16];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 16 times +struct SM100_TMEM_LOAD_16dp128b16x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x16.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 16 times, packed 16b read +struct SM100_TMEM_LOAD_16dp128b16x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x16.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 32 times +struct SM100_TMEM_LOAD_16dp128b32x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31, + uint32_t& dst32, uint32_t& dst33, uint32_t& dst34, uint32_t& dst35, + uint32_t& dst36, uint32_t& dst37, uint32_t& dst38, uint32_t& dst39, + uint32_t& dst40, uint32_t& dst41, uint32_t& dst42, uint32_t& dst43, + uint32_t& dst44, uint32_t& dst45, uint32_t& dst46, uint32_t& dst47, + uint32_t& dst48, uint32_t& dst49, uint32_t& dst50, uint32_t& dst51, + uint32_t& dst52, uint32_t& dst53, uint32_t& dst54, uint32_t& dst55, + uint32_t& dst56, uint32_t& dst57, uint32_t& dst58, uint32_t& dst59, + uint32_t& dst60, uint32_t& dst61, uint32_t& dst62, uint32_t& dst63) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x32.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31), + "=r"(dst32), "=r"(dst33), "=r"(dst34), "=r"(dst35), + "=r"(dst36), "=r"(dst37), "=r"(dst38), "=r"(dst39), + "=r"(dst40), "=r"(dst41), "=r"(dst42), "=r"(dst43), + "=r"(dst44), "=r"(dst45), "=r"(dst46), "=r"(dst47), + "=r"(dst48), "=r"(dst49), "=r"(dst50), "=r"(dst51), + "=r"(dst52), "=r"(dst53), "=r"(dst54), "=r"(dst55), + "=r"(dst56), "=r"(dst57), "=r"(dst58), "=r"(dst59), + "=r"(dst60), "=r"(dst61), "=r"(dst62), "=r"(dst63) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 32 times, packed 16b read +struct SM100_TMEM_LOAD_16dp128b32x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31, + uint32_t& dst32, uint32_t& dst33, uint32_t& dst34, uint32_t& dst35, + uint32_t& dst36, uint32_t& dst37, uint32_t& dst38, uint32_t& dst39, + uint32_t& dst40, uint32_t& dst41, uint32_t& dst42, uint32_t& dst43, + uint32_t& dst44, uint32_t& dst45, uint32_t& dst46, uint32_t& dst47, + uint32_t& dst48, uint32_t& dst49, uint32_t& dst50, uint32_t& dst51, + uint32_t& dst52, uint32_t& dst53, uint32_t& dst54, uint32_t& dst55, + uint32_t& dst56, uint32_t& dst57, uint32_t& dst58, uint32_t& dst59, + uint32_t& dst60, uint32_t& dst61, uint32_t& dst62, uint32_t& dst63) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x32.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31), + "=r"(dst32), "=r"(dst33), "=r"(dst34), "=r"(dst35), + "=r"(dst36), "=r"(dst37), "=r"(dst38), "=r"(dst39), + "=r"(dst40), "=r"(dst41), "=r"(dst42), "=r"(dst43), + "=r"(dst44), "=r"(dst45), "=r"(dst46), "=r"(dst47), + "=r"(dst48), "=r"(dst49), "=r"(dst50), "=r"(dst51), + "=r"(dst52), "=r"(dst53), "=r"(dst54), "=r"(dst55), + "=r"(dst56), "=r"(dst57), "=r"(dst58), "=r"(dst59), + "=r"(dst60), "=r"(dst61), "=r"(dst62), "=r"(dst63) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 64 times +struct SM100_TMEM_LOAD_16dp128b64x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst000, uint32_t& dst001, uint32_t& dst002, uint32_t& dst003, + uint32_t& dst004, uint32_t& dst005, uint32_t& dst006, uint32_t& dst007, + uint32_t& dst008, uint32_t& dst009, uint32_t& dst010, uint32_t& dst011, + uint32_t& dst012, uint32_t& dst013, uint32_t& dst014, uint32_t& dst015, + uint32_t& dst016, uint32_t& dst017, uint32_t& dst018, uint32_t& dst019, + uint32_t& dst020, uint32_t& dst021, uint32_t& dst022, uint32_t& dst023, + uint32_t& dst024, uint32_t& dst025, uint32_t& dst026, uint32_t& dst027, + uint32_t& dst028, uint32_t& dst029, uint32_t& dst030, uint32_t& dst031, + uint32_t& dst032, uint32_t& dst033, uint32_t& dst034, uint32_t& dst035, + uint32_t& dst036, uint32_t& dst037, uint32_t& dst038, uint32_t& dst039, + uint32_t& dst040, uint32_t& dst041, uint32_t& dst042, uint32_t& dst043, + uint32_t& dst044, uint32_t& dst045, uint32_t& dst046, uint32_t& dst047, + uint32_t& dst048, uint32_t& dst049, uint32_t& dst050, uint32_t& dst051, + uint32_t& dst052, uint32_t& dst053, uint32_t& dst054, uint32_t& dst055, + uint32_t& dst056, uint32_t& dst057, uint32_t& dst058, uint32_t& dst059, + uint32_t& dst060, uint32_t& dst061, uint32_t& dst062, uint32_t& dst063, + uint32_t& dst064, uint32_t& dst065, uint32_t& dst066, uint32_t& dst067, + uint32_t& dst068, uint32_t& dst069, uint32_t& dst070, uint32_t& dst071, + uint32_t& dst072, uint32_t& dst073, uint32_t& dst074, uint32_t& dst075, + uint32_t& dst076, uint32_t& dst077, uint32_t& dst078, uint32_t& dst079, + uint32_t& dst080, uint32_t& dst081, uint32_t& dst082, uint32_t& dst083, + uint32_t& dst084, uint32_t& dst085, uint32_t& dst086, uint32_t& dst087, + uint32_t& dst088, uint32_t& dst089, uint32_t& dst090, uint32_t& dst091, + uint32_t& dst092, uint32_t& dst093, uint32_t& dst094, uint32_t& dst095, + uint32_t& dst096, uint32_t& dst097, uint32_t& dst098, uint32_t& dst099, + uint32_t& dst100, uint32_t& dst101, uint32_t& dst102, uint32_t& dst103, + uint32_t& dst104, uint32_t& dst105, uint32_t& dst106, uint32_t& dst107, + uint32_t& dst108, uint32_t& dst109, uint32_t& dst110, uint32_t& dst111, + uint32_t& dst112, uint32_t& dst113, uint32_t& dst114, uint32_t& dst115, + uint32_t& dst116, uint32_t& dst117, uint32_t& dst118, uint32_t& dst119, + uint32_t& dst120, uint32_t& dst121, uint32_t& dst122, uint32_t& dst123, + uint32_t& dst124, uint32_t& dst125, uint32_t& dst126, uint32_t& dst127) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x64.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63," + "%64, %65, %66, %67," + "%68, %69, %70, %71," + "%72, %73, %74, %75," + "%76, %77, %78, %79," + "%80, %81, %82, %83," + "%84, %85, %86, %87," + "%88, %89, %90, %91," + "%92, %93, %94, %95," + "%96, %97, %98, %99," + "%100, %101, %102, %103," + "%104, %105, %106, %107," + "%108, %109, %110, %111," + "%112, %113, %114, %115," + "%116, %117, %118, %119," + "%120, %121, %122, %123," + "%124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst000), "=r"(dst001), "=r"(dst002), "=r"(dst003), + "=r"(dst004), "=r"(dst005), "=r"(dst006), "=r"(dst007), + "=r"(dst008), "=r"(dst009), "=r"(dst010), "=r"(dst011), + "=r"(dst012), "=r"(dst013), "=r"(dst014), "=r"(dst015), + "=r"(dst016), "=r"(dst017), "=r"(dst018), "=r"(dst019), + "=r"(dst020), "=r"(dst021), "=r"(dst022), "=r"(dst023), + "=r"(dst024), "=r"(dst025), "=r"(dst026), "=r"(dst027), + "=r"(dst028), "=r"(dst029), "=r"(dst030), "=r"(dst031), + "=r"(dst032), "=r"(dst033), "=r"(dst034), "=r"(dst035), + "=r"(dst036), "=r"(dst037), "=r"(dst038), "=r"(dst039), + "=r"(dst040), "=r"(dst041), "=r"(dst042), "=r"(dst043), + "=r"(dst044), "=r"(dst045), "=r"(dst046), "=r"(dst047), + "=r"(dst048), "=r"(dst049), "=r"(dst050), "=r"(dst051), + "=r"(dst052), "=r"(dst053), "=r"(dst054), "=r"(dst055), + "=r"(dst056), "=r"(dst057), "=r"(dst058), "=r"(dst059), + "=r"(dst060), "=r"(dst061), "=r"(dst062), "=r"(dst063), + "=r"(dst064), "=r"(dst065), "=r"(dst066), "=r"(dst067), + "=r"(dst068), "=r"(dst069), "=r"(dst070), "=r"(dst071), + "=r"(dst072), "=r"(dst073), "=r"(dst074), "=r"(dst075), + "=r"(dst076), "=r"(dst077), "=r"(dst078), "=r"(dst079), + "=r"(dst080), "=r"(dst081), "=r"(dst082), "=r"(dst083), + "=r"(dst084), "=r"(dst085), "=r"(dst086), "=r"(dst087), + "=r"(dst088), "=r"(dst089), "=r"(dst090), "=r"(dst091), + "=r"(dst092), "=r"(dst093), "=r"(dst094), "=r"(dst095), + "=r"(dst096), "=r"(dst097), "=r"(dst098), "=r"(dst099), + "=r"(dst100), "=r"(dst101), "=r"(dst102), "=r"(dst103), + "=r"(dst104), "=r"(dst105), "=r"(dst106), "=r"(dst107), + "=r"(dst108), "=r"(dst109), "=r"(dst110), "=r"(dst111), + "=r"(dst112), "=r"(dst113), "=r"(dst114), "=r"(dst115), + "=r"(dst116), "=r"(dst117), "=r"(dst118), "=r"(dst119), + "=r"(dst120), "=r"(dst121), "=r"(dst122), "=r"(dst123), + "=r"(dst124), "=r"(dst125), "=r"(dst126), "=r"(dst127) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 64 times, packed 16b read +struct SM100_TMEM_LOAD_16dp128b64x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst000, uint32_t& dst001, uint32_t& dst002, uint32_t& dst003, + uint32_t& dst004, uint32_t& dst005, uint32_t& dst006, uint32_t& dst007, + uint32_t& dst008, uint32_t& dst009, uint32_t& dst010, uint32_t& dst011, + uint32_t& dst012, uint32_t& dst013, uint32_t& dst014, uint32_t& dst015, + uint32_t& dst016, uint32_t& dst017, uint32_t& dst018, uint32_t& dst019, + uint32_t& dst020, uint32_t& dst021, uint32_t& dst022, uint32_t& dst023, + uint32_t& dst024, uint32_t& dst025, uint32_t& dst026, uint32_t& dst027, + uint32_t& dst028, uint32_t& dst029, uint32_t& dst030, uint32_t& dst031, + uint32_t& dst032, uint32_t& dst033, uint32_t& dst034, uint32_t& dst035, + uint32_t& dst036, uint32_t& dst037, uint32_t& dst038, uint32_t& dst039, + uint32_t& dst040, uint32_t& dst041, uint32_t& dst042, uint32_t& dst043, + uint32_t& dst044, uint32_t& dst045, uint32_t& dst046, uint32_t& dst047, + uint32_t& dst048, uint32_t& dst049, uint32_t& dst050, uint32_t& dst051, + uint32_t& dst052, uint32_t& dst053, uint32_t& dst054, uint32_t& dst055, + uint32_t& dst056, uint32_t& dst057, uint32_t& dst058, uint32_t& dst059, + uint32_t& dst060, uint32_t& dst061, uint32_t& dst062, uint32_t& dst063, + uint32_t& dst064, uint32_t& dst065, uint32_t& dst066, uint32_t& dst067, + uint32_t& dst068, uint32_t& dst069, uint32_t& dst070, uint32_t& dst071, + uint32_t& dst072, uint32_t& dst073, uint32_t& dst074, uint32_t& dst075, + uint32_t& dst076, uint32_t& dst077, uint32_t& dst078, uint32_t& dst079, + uint32_t& dst080, uint32_t& dst081, uint32_t& dst082, uint32_t& dst083, + uint32_t& dst084, uint32_t& dst085, uint32_t& dst086, uint32_t& dst087, + uint32_t& dst088, uint32_t& dst089, uint32_t& dst090, uint32_t& dst091, + uint32_t& dst092, uint32_t& dst093, uint32_t& dst094, uint32_t& dst095, + uint32_t& dst096, uint32_t& dst097, uint32_t& dst098, uint32_t& dst099, + uint32_t& dst100, uint32_t& dst101, uint32_t& dst102, uint32_t& dst103, + uint32_t& dst104, uint32_t& dst105, uint32_t& dst106, uint32_t& dst107, + uint32_t& dst108, uint32_t& dst109, uint32_t& dst110, uint32_t& dst111, + uint32_t& dst112, uint32_t& dst113, uint32_t& dst114, uint32_t& dst115, + uint32_t& dst116, uint32_t& dst117, uint32_t& dst118, uint32_t& dst119, + uint32_t& dst120, uint32_t& dst121, uint32_t& dst122, uint32_t& dst123, + uint32_t& dst124, uint32_t& dst125, uint32_t& dst126, uint32_t& dst127) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x64.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63," + "%64, %65, %66, %67," + "%68, %69, %70, %71," + "%72, %73, %74, %75," + "%76, %77, %78, %79," + "%80, %81, %82, %83," + "%84, %85, %86, %87," + "%88, %89, %90, %91," + "%92, %93, %94, %95," + "%96, %97, %98, %99," + "%100, %101, %102, %103," + "%104, %105, %106, %107," + "%108, %109, %110, %111," + "%112, %113, %114, %115," + "%116, %117, %118, %119," + "%120, %121, %122, %123," + "%124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst000), "=r"(dst001), "=r"(dst002), "=r"(dst003), + "=r"(dst004), "=r"(dst005), "=r"(dst006), "=r"(dst007), + "=r"(dst008), "=r"(dst009), "=r"(dst010), "=r"(dst011), + "=r"(dst012), "=r"(dst013), "=r"(dst014), "=r"(dst015), + "=r"(dst016), "=r"(dst017), "=r"(dst018), "=r"(dst019), + "=r"(dst020), "=r"(dst021), "=r"(dst022), "=r"(dst023), + "=r"(dst024), "=r"(dst025), "=r"(dst026), "=r"(dst027), + "=r"(dst028), "=r"(dst029), "=r"(dst030), "=r"(dst031), + "=r"(dst032), "=r"(dst033), "=r"(dst034), "=r"(dst035), + "=r"(dst036), "=r"(dst037), "=r"(dst038), "=r"(dst039), + "=r"(dst040), "=r"(dst041), "=r"(dst042), "=r"(dst043), + "=r"(dst044), "=r"(dst045), "=r"(dst046), "=r"(dst047), + "=r"(dst048), "=r"(dst049), "=r"(dst050), "=r"(dst051), + "=r"(dst052), "=r"(dst053), "=r"(dst054), "=r"(dst055), + "=r"(dst056), "=r"(dst057), "=r"(dst058), "=r"(dst059), + "=r"(dst060), "=r"(dst061), "=r"(dst062), "=r"(dst063), + "=r"(dst064), "=r"(dst065), "=r"(dst066), "=r"(dst067), + "=r"(dst068), "=r"(dst069), "=r"(dst070), "=r"(dst071), + "=r"(dst072), "=r"(dst073), "=r"(dst074), "=r"(dst075), + "=r"(dst076), "=r"(dst077), "=r"(dst078), "=r"(dst079), + "=r"(dst080), "=r"(dst081), "=r"(dst082), "=r"(dst083), + "=r"(dst084), "=r"(dst085), "=r"(dst086), "=r"(dst087), + "=r"(dst088), "=r"(dst089), "=r"(dst090), "=r"(dst091), + "=r"(dst092), "=r"(dst093), "=r"(dst094), "=r"(dst095), + "=r"(dst096), "=r"(dst097), "=r"(dst098), "=r"(dst099), + "=r"(dst100), "=r"(dst101), "=r"(dst102), "=r"(dst103), + "=r"(dst104), "=r"(dst105), "=r"(dst106), "=r"(dst107), + "=r"(dst108), "=r"(dst109), "=r"(dst110), "=r"(dst111), + "=r"(dst112), "=r"(dst113), "=r"(dst114), "=r"(dst115), + "=r"(dst116), "=r"(dst117), "=r"(dst118), "=r"(dst119), + "=r"(dst120), "=r"(dst121), "=r"(dst122), "=r"(dst123), + "=r"(dst124), "=r"(dst125), "=r"(dst126), "=r"(dst127) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 1 times +struct SM100_TMEM_LOAD_16dp64b1x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x1.b32" + "{%0}," + "[%1];\n" + : "=r"(dst0) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 1 times, packed 16b read +struct SM100_TMEM_LOAD_16dp64b1x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x1.pack::16b.b32" + "{%0}," + "[%1];\n" + : "=r"(dst0) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 2 times +struct SM100_TMEM_LOAD_16dp64b2x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x2.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst0), "=r"(dst1) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 2 times, packed 16b read +struct SM100_TMEM_LOAD_16dp64b2x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x2.pack::16b.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst0), "=r"(dst1) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 4 times +struct SM100_TMEM_LOAD_16dp64b4x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x4.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 4 times, packed 16b read +struct SM100_TMEM_LOAD_16dp64b4x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x4.pack::16b.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 8 times +struct SM100_TMEM_LOAD_16dp64b8x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3, + uint32_t& dst4, uint32_t& dst5, uint32_t& dst6, uint32_t& dst7) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x8.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3), + "=r"(dst4), "=r"(dst5), "=r"(dst6), "=r"(dst7) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 8 times, packed 16b read +struct SM100_TMEM_LOAD_16dp64b8x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3, + uint32_t& dst4, uint32_t& dst5, uint32_t& dst6, uint32_t& dst7) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x8.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3), + "=r"(dst4), "=r"(dst5), "=r"(dst6), "=r"(dst7) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 16 times +struct SM100_TMEM_LOAD_16dp64b16x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x16.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15}," + "[%16];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 16 times, packed 16b read +struct SM100_TMEM_LOAD_16dp64b16x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x16.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15}," + "[%16];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 32 times +struct SM100_TMEM_LOAD_16dp64b32x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x32.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 32 times, packed 16b read +struct SM100_TMEM_LOAD_16dp64b32x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x32.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 64 times +struct SM100_TMEM_LOAD_16dp64b64x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31, + uint32_t& dst32, uint32_t& dst33, uint32_t& dst34, uint32_t& dst35, + uint32_t& dst36, uint32_t& dst37, uint32_t& dst38, uint32_t& dst39, + uint32_t& dst40, uint32_t& dst41, uint32_t& dst42, uint32_t& dst43, + uint32_t& dst44, uint32_t& dst45, uint32_t& dst46, uint32_t& dst47, + uint32_t& dst48, uint32_t& dst49, uint32_t& dst50, uint32_t& dst51, + uint32_t& dst52, uint32_t& dst53, uint32_t& dst54, uint32_t& dst55, + uint32_t& dst56, uint32_t& dst57, uint32_t& dst58, uint32_t& dst59, + uint32_t& dst60, uint32_t& dst61, uint32_t& dst62, uint32_t& dst63) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x64.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31), + "=r"(dst32), "=r"(dst33), "=r"(dst34), "=r"(dst35), + "=r"(dst36), "=r"(dst37), "=r"(dst38), "=r"(dst39), + "=r"(dst40), "=r"(dst41), "=r"(dst42), "=r"(dst43), + "=r"(dst44), "=r"(dst45), "=r"(dst46), "=r"(dst47), + "=r"(dst48), "=r"(dst49), "=r"(dst50), "=r"(dst51), + "=r"(dst52), "=r"(dst53), "=r"(dst54), "=r"(dst55), + "=r"(dst56), "=r"(dst57), "=r"(dst58), "=r"(dst59), + "=r"(dst60), "=r"(dst61), "=r"(dst62), "=r"(dst63) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 64 times, packed 16b read +struct SM100_TMEM_LOAD_16dp64b64x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31, + uint32_t& dst32, uint32_t& dst33, uint32_t& dst34, uint32_t& dst35, + uint32_t& dst36, uint32_t& dst37, uint32_t& dst38, uint32_t& dst39, + uint32_t& dst40, uint32_t& dst41, uint32_t& dst42, uint32_t& dst43, + uint32_t& dst44, uint32_t& dst45, uint32_t& dst46, uint32_t& dst47, + uint32_t& dst48, uint32_t& dst49, uint32_t& dst50, uint32_t& dst51, + uint32_t& dst52, uint32_t& dst53, uint32_t& dst54, uint32_t& dst55, + uint32_t& dst56, uint32_t& dst57, uint32_t& dst58, uint32_t& dst59, + uint32_t& dst60, uint32_t& dst61, uint32_t& dst62, uint32_t& dst63) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x64.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31), + "=r"(dst32), "=r"(dst33), "=r"(dst34), "=r"(dst35), + "=r"(dst36), "=r"(dst37), "=r"(dst38), "=r"(dst39), + "=r"(dst40), "=r"(dst41), "=r"(dst42), "=r"(dst43), + "=r"(dst44), "=r"(dst45), "=r"(dst46), "=r"(dst47), + "=r"(dst48), "=r"(dst49), "=r"(dst50), "=r"(dst51), + "=r"(dst52), "=r"(dst53), "=r"(dst54), "=r"(dst55), + "=r"(dst56), "=r"(dst57), "=r"(dst58), "=r"(dst59), + "=r"(dst60), "=r"(dst61), "=r"(dst62), "=r"(dst63) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 128 times +struct SM100_TMEM_LOAD_16dp64b128x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst000, uint32_t& dst001, uint32_t& dst002, uint32_t& dst003, + uint32_t& dst004, uint32_t& dst005, uint32_t& dst006, uint32_t& dst007, + uint32_t& dst008, uint32_t& dst009, uint32_t& dst010, uint32_t& dst011, + uint32_t& dst012, uint32_t& dst013, uint32_t& dst014, uint32_t& dst015, + uint32_t& dst016, uint32_t& dst017, uint32_t& dst018, uint32_t& dst019, + uint32_t& dst020, uint32_t& dst021, uint32_t& dst022, uint32_t& dst023, + uint32_t& dst024, uint32_t& dst025, uint32_t& dst026, uint32_t& dst027, + uint32_t& dst028, uint32_t& dst029, uint32_t& dst030, uint32_t& dst031, + uint32_t& dst032, uint32_t& dst033, uint32_t& dst034, uint32_t& dst035, + uint32_t& dst036, uint32_t& dst037, uint32_t& dst038, uint32_t& dst039, + uint32_t& dst040, uint32_t& dst041, uint32_t& dst042, uint32_t& dst043, + uint32_t& dst044, uint32_t& dst045, uint32_t& dst046, uint32_t& dst047, + uint32_t& dst048, uint32_t& dst049, uint32_t& dst050, uint32_t& dst051, + uint32_t& dst052, uint32_t& dst053, uint32_t& dst054, uint32_t& dst055, + uint32_t& dst056, uint32_t& dst057, uint32_t& dst058, uint32_t& dst059, + uint32_t& dst060, uint32_t& dst061, uint32_t& dst062, uint32_t& dst063, + uint32_t& dst064, uint32_t& dst065, uint32_t& dst066, uint32_t& dst067, + uint32_t& dst068, uint32_t& dst069, uint32_t& dst070, uint32_t& dst071, + uint32_t& dst072, uint32_t& dst073, uint32_t& dst074, uint32_t& dst075, + uint32_t& dst076, uint32_t& dst077, uint32_t& dst078, uint32_t& dst079, + uint32_t& dst080, uint32_t& dst081, uint32_t& dst082, uint32_t& dst083, + uint32_t& dst084, uint32_t& dst085, uint32_t& dst086, uint32_t& dst087, + uint32_t& dst088, uint32_t& dst089, uint32_t& dst090, uint32_t& dst091, + uint32_t& dst092, uint32_t& dst093, uint32_t& dst094, uint32_t& dst095, + uint32_t& dst096, uint32_t& dst097, uint32_t& dst098, uint32_t& dst099, + uint32_t& dst100, uint32_t& dst101, uint32_t& dst102, uint32_t& dst103, + uint32_t& dst104, uint32_t& dst105, uint32_t& dst106, uint32_t& dst107, + uint32_t& dst108, uint32_t& dst109, uint32_t& dst110, uint32_t& dst111, + uint32_t& dst112, uint32_t& dst113, uint32_t& dst114, uint32_t& dst115, + uint32_t& dst116, uint32_t& dst117, uint32_t& dst118, uint32_t& dst119, + uint32_t& dst120, uint32_t& dst121, uint32_t& dst122, uint32_t& dst123, + uint32_t& dst124, uint32_t& dst125, uint32_t& dst126, uint32_t& dst127) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x128.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63," + "%64, %65, %66, %67," + "%68, %69, %70, %71," + "%72, %73, %74, %75," + "%76, %77, %78, %79," + "%80, %81, %82, %83," + "%84, %85, %86, %87," + "%88, %89, %90, %91," + "%92, %93, %94, %95," + "%96, %97, %98, %99," + "%100, %101, %102, %103," + "%104, %105, %106, %107," + "%108, %109, %110, %111," + "%112, %113, %114, %115," + "%116, %117, %118, %119," + "%120, %121, %122, %123," + "%124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst000), "=r"(dst001), "=r"(dst002), "=r"(dst003), + "=r"(dst004), "=r"(dst005), "=r"(dst006), "=r"(dst007), + "=r"(dst008), "=r"(dst009), "=r"(dst010), "=r"(dst011), + "=r"(dst012), "=r"(dst013), "=r"(dst014), "=r"(dst015), + "=r"(dst016), "=r"(dst017), "=r"(dst018), "=r"(dst019), + "=r"(dst020), "=r"(dst021), "=r"(dst022), "=r"(dst023), + "=r"(dst024), "=r"(dst025), "=r"(dst026), "=r"(dst027), + "=r"(dst028), "=r"(dst029), "=r"(dst030), "=r"(dst031), + "=r"(dst032), "=r"(dst033), "=r"(dst034), "=r"(dst035), + "=r"(dst036), "=r"(dst037), "=r"(dst038), "=r"(dst039), + "=r"(dst040), "=r"(dst041), "=r"(dst042), "=r"(dst043), + "=r"(dst044), "=r"(dst045), "=r"(dst046), "=r"(dst047), + "=r"(dst048), "=r"(dst049), "=r"(dst050), "=r"(dst051), + "=r"(dst052), "=r"(dst053), "=r"(dst054), "=r"(dst055), + "=r"(dst056), "=r"(dst057), "=r"(dst058), "=r"(dst059), + "=r"(dst060), "=r"(dst061), "=r"(dst062), "=r"(dst063), + "=r"(dst064), "=r"(dst065), "=r"(dst066), "=r"(dst067), + "=r"(dst068), "=r"(dst069), "=r"(dst070), "=r"(dst071), + "=r"(dst072), "=r"(dst073), "=r"(dst074), "=r"(dst075), + "=r"(dst076), "=r"(dst077), "=r"(dst078), "=r"(dst079), + "=r"(dst080), "=r"(dst081), "=r"(dst082), "=r"(dst083), + "=r"(dst084), "=r"(dst085), "=r"(dst086), "=r"(dst087), + "=r"(dst088), "=r"(dst089), "=r"(dst090), "=r"(dst091), + "=r"(dst092), "=r"(dst093), "=r"(dst094), "=r"(dst095), + "=r"(dst096), "=r"(dst097), "=r"(dst098), "=r"(dst099), + "=r"(dst100), "=r"(dst101), "=r"(dst102), "=r"(dst103), + "=r"(dst104), "=r"(dst105), "=r"(dst106), "=r"(dst107), + "=r"(dst108), "=r"(dst109), "=r"(dst110), "=r"(dst111), + "=r"(dst112), "=r"(dst113), "=r"(dst114), "=r"(dst115), + "=r"(dst116), "=r"(dst117), "=r"(dst118), "=r"(dst119), + "=r"(dst120), "=r"(dst121), "=r"(dst122), "=r"(dst123), + "=r"(dst124), "=r"(dst125), "=r"(dst126), "=r"(dst127) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 128 times, packed 16b read +struct SM100_TMEM_LOAD_16dp64b128x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst000, uint32_t& dst001, uint32_t& dst002, uint32_t& dst003, + uint32_t& dst004, uint32_t& dst005, uint32_t& dst006, uint32_t& dst007, + uint32_t& dst008, uint32_t& dst009, uint32_t& dst010, uint32_t& dst011, + uint32_t& dst012, uint32_t& dst013, uint32_t& dst014, uint32_t& dst015, + uint32_t& dst016, uint32_t& dst017, uint32_t& dst018, uint32_t& dst019, + uint32_t& dst020, uint32_t& dst021, uint32_t& dst022, uint32_t& dst023, + uint32_t& dst024, uint32_t& dst025, uint32_t& dst026, uint32_t& dst027, + uint32_t& dst028, uint32_t& dst029, uint32_t& dst030, uint32_t& dst031, + uint32_t& dst032, uint32_t& dst033, uint32_t& dst034, uint32_t& dst035, + uint32_t& dst036, uint32_t& dst037, uint32_t& dst038, uint32_t& dst039, + uint32_t& dst040, uint32_t& dst041, uint32_t& dst042, uint32_t& dst043, + uint32_t& dst044, uint32_t& dst045, uint32_t& dst046, uint32_t& dst047, + uint32_t& dst048, uint32_t& dst049, uint32_t& dst050, uint32_t& dst051, + uint32_t& dst052, uint32_t& dst053, uint32_t& dst054, uint32_t& dst055, + uint32_t& dst056, uint32_t& dst057, uint32_t& dst058, uint32_t& dst059, + uint32_t& dst060, uint32_t& dst061, uint32_t& dst062, uint32_t& dst063, + uint32_t& dst064, uint32_t& dst065, uint32_t& dst066, uint32_t& dst067, + uint32_t& dst068, uint32_t& dst069, uint32_t& dst070, uint32_t& dst071, + uint32_t& dst072, uint32_t& dst073, uint32_t& dst074, uint32_t& dst075, + uint32_t& dst076, uint32_t& dst077, uint32_t& dst078, uint32_t& dst079, + uint32_t& dst080, uint32_t& dst081, uint32_t& dst082, uint32_t& dst083, + uint32_t& dst084, uint32_t& dst085, uint32_t& dst086, uint32_t& dst087, + uint32_t& dst088, uint32_t& dst089, uint32_t& dst090, uint32_t& dst091, + uint32_t& dst092, uint32_t& dst093, uint32_t& dst094, uint32_t& dst095, + uint32_t& dst096, uint32_t& dst097, uint32_t& dst098, uint32_t& dst099, + uint32_t& dst100, uint32_t& dst101, uint32_t& dst102, uint32_t& dst103, + uint32_t& dst104, uint32_t& dst105, uint32_t& dst106, uint32_t& dst107, + uint32_t& dst108, uint32_t& dst109, uint32_t& dst110, uint32_t& dst111, + uint32_t& dst112, uint32_t& dst113, uint32_t& dst114, uint32_t& dst115, + uint32_t& dst116, uint32_t& dst117, uint32_t& dst118, uint32_t& dst119, + uint32_t& dst120, uint32_t& dst121, uint32_t& dst122, uint32_t& dst123, + uint32_t& dst124, uint32_t& dst125, uint32_t& dst126, uint32_t& dst127) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x128.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63," + "%64, %65, %66, %67," + "%68, %69, %70, %71," + "%72, %73, %74, %75," + "%76, %77, %78, %79," + "%80, %81, %82, %83," + "%84, %85, %86, %87," + "%88, %89, %90, %91," + "%92, %93, %94, %95," + "%96, %97, %98, %99," + "%100, %101, %102, %103," + "%104, %105, %106, %107," + "%108, %109, %110, %111," + "%112, %113, %114, %115," + "%116, %117, %118, %119," + "%120, %121, %122, %123," + "%124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst000), "=r"(dst001), "=r"(dst002), "=r"(dst003), + "=r"(dst004), "=r"(dst005), "=r"(dst006), "=r"(dst007), + "=r"(dst008), "=r"(dst009), "=r"(dst010), "=r"(dst011), + "=r"(dst012), "=r"(dst013), "=r"(dst014), "=r"(dst015), + "=r"(dst016), "=r"(dst017), "=r"(dst018), "=r"(dst019), + "=r"(dst020), "=r"(dst021), "=r"(dst022), "=r"(dst023), + "=r"(dst024), "=r"(dst025), "=r"(dst026), "=r"(dst027), + "=r"(dst028), "=r"(dst029), "=r"(dst030), "=r"(dst031), + "=r"(dst032), "=r"(dst033), "=r"(dst034), "=r"(dst035), + "=r"(dst036), "=r"(dst037), "=r"(dst038), "=r"(dst039), + "=r"(dst040), "=r"(dst041), "=r"(dst042), "=r"(dst043), + "=r"(dst044), "=r"(dst045), "=r"(dst046), "=r"(dst047), + "=r"(dst048), "=r"(dst049), "=r"(dst050), "=r"(dst051), + "=r"(dst052), "=r"(dst053), "=r"(dst054), "=r"(dst055), + "=r"(dst056), "=r"(dst057), "=r"(dst058), "=r"(dst059), + "=r"(dst060), "=r"(dst061), "=r"(dst062), "=r"(dst063), + "=r"(dst064), "=r"(dst065), "=r"(dst066), "=r"(dst067), + "=r"(dst068), "=r"(dst069), "=r"(dst070), "=r"(dst071), + "=r"(dst072), "=r"(dst073), "=r"(dst074), "=r"(dst075), + "=r"(dst076), "=r"(dst077), "=r"(dst078), "=r"(dst079), + "=r"(dst080), "=r"(dst081), "=r"(dst082), "=r"(dst083), + "=r"(dst084), "=r"(dst085), "=r"(dst086), "=r"(dst087), + "=r"(dst088), "=r"(dst089), "=r"(dst090), "=r"(dst091), + "=r"(dst092), "=r"(dst093), "=r"(dst094), "=r"(dst095), + "=r"(dst096), "=r"(dst097), "=r"(dst098), "=r"(dst099), + "=r"(dst100), "=r"(dst101), "=r"(dst102), "=r"(dst103), + "=r"(dst104), "=r"(dst105), "=r"(dst106), "=r"(dst107), + "=r"(dst108), "=r"(dst109), "=r"(dst110), "=r"(dst111), + "=r"(dst112), "=r"(dst113), "=r"(dst114), "=r"(dst115), + "=r"(dst116), "=r"(dst117), "=r"(dst118), "=r"(dst119), + "=r"(dst120), "=r"(dst121), "=r"(dst122), "=r"(dst123), + "=r"(dst124), "=r"(dst125), "=r"(dst126), "=r"(dst127) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 1 times +struct SM100_TMEM_LOAD_16dp32b1x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x1.b32" + "{%0}," + "[%1], 1;\n" + : "=r"(dst0) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 1 times, packed 16b read +struct SM100_TMEM_LOAD_16dp32b1x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x1.pack::16b.b32" + "{%0}," + "[%1], 2;\n" + : "=r"(dst0) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 2 times +struct SM100_TMEM_LOAD_16dp32b2x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x2.b32" + "{%0, %1}," + "[%2], 2;\n" + : "=r"(dst0), "=r"(dst1) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 2 times, packed 16b read +struct SM100_TMEM_LOAD_16dp32b2x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x2.pack::16b.b32" + "{%0, %1}," + "[%2], 4;\n" + : "=r"(dst0), "=r"(dst1) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 4 times +struct SM100_TMEM_LOAD_16dp32b4x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x4.b32" + "{%0, %1, %2, %3}," + "[%4], 4;\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 4 times, packed 16b read +struct SM100_TMEM_LOAD_16dp32b4x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x4.pack::16b.b32" + "{%0, %1, %2, %3}," + "[%4], 8;\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 8 times +struct SM100_TMEM_LOAD_16dp32b8x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3, + uint32_t& dst4, uint32_t& dst5, uint32_t& dst6, uint32_t& dst7) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x8.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7}," + "[%8], 8;\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3), + "=r"(dst4), "=r"(dst5), "=r"(dst6), "=r"(dst7) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 8 times, packed 16b read +struct SM100_TMEM_LOAD_16dp32b8x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3, + uint32_t& dst4, uint32_t& dst5, uint32_t& dst6, uint32_t& dst7) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x8.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7}," + "[%8], 16;\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3), + "=r"(dst4), "=r"(dst5), "=r"(dst6), "=r"(dst7) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 16 times +struct SM100_TMEM_LOAD_16dp32b16x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x16.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15}," + "[%16], 16;\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 16 times, packed 16b read +struct SM100_TMEM_LOAD_16dp32b16x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x16.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15}," + "[%16], 32;\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 32 times +struct SM100_TMEM_LOAD_16dp32b32x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x32.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31}," + "[%32], 32;\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 32 times, packed 16b read +struct SM100_TMEM_LOAD_16dp32b32x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x32.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31}," + "[%32], 64;\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 64 times +struct SM100_TMEM_LOAD_16dp32b64x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31, + uint32_t& dst32, uint32_t& dst33, uint32_t& dst34, uint32_t& dst35, + uint32_t& dst36, uint32_t& dst37, uint32_t& dst38, uint32_t& dst39, + uint32_t& dst40, uint32_t& dst41, uint32_t& dst42, uint32_t& dst43, + uint32_t& dst44, uint32_t& dst45, uint32_t& dst46, uint32_t& dst47, + uint32_t& dst48, uint32_t& dst49, uint32_t& dst50, uint32_t& dst51, + uint32_t& dst52, uint32_t& dst53, uint32_t& dst54, uint32_t& dst55, + uint32_t& dst56, uint32_t& dst57, uint32_t& dst58, uint32_t& dst59, + uint32_t& dst60, uint32_t& dst61, uint32_t& dst62, uint32_t& dst63) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x64.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63}," + "[%64], 64;\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31), + "=r"(dst32), "=r"(dst33), "=r"(dst34), "=r"(dst35), + "=r"(dst36), "=r"(dst37), "=r"(dst38), "=r"(dst39), + "=r"(dst40), "=r"(dst41), "=r"(dst42), "=r"(dst43), + "=r"(dst44), "=r"(dst45), "=r"(dst46), "=r"(dst47), + "=r"(dst48), "=r"(dst49), "=r"(dst50), "=r"(dst51), + "=r"(dst52), "=r"(dst53), "=r"(dst54), "=r"(dst55), + "=r"(dst56), "=r"(dst57), "=r"(dst58), "=r"(dst59), + "=r"(dst60), "=r"(dst61), "=r"(dst62), "=r"(dst63) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 64 times, packed 16b read +struct SM100_TMEM_LOAD_16dp32b64x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31, + uint32_t& dst32, uint32_t& dst33, uint32_t& dst34, uint32_t& dst35, + uint32_t& dst36, uint32_t& dst37, uint32_t& dst38, uint32_t& dst39, + uint32_t& dst40, uint32_t& dst41, uint32_t& dst42, uint32_t& dst43, + uint32_t& dst44, uint32_t& dst45, uint32_t& dst46, uint32_t& dst47, + uint32_t& dst48, uint32_t& dst49, uint32_t& dst50, uint32_t& dst51, + uint32_t& dst52, uint32_t& dst53, uint32_t& dst54, uint32_t& dst55, + uint32_t& dst56, uint32_t& dst57, uint32_t& dst58, uint32_t& dst59, + uint32_t& dst60, uint32_t& dst61, uint32_t& dst62, uint32_t& dst63) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x64.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63}," + "[%64], 128;\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31), + "=r"(dst32), "=r"(dst33), "=r"(dst34), "=r"(dst35), + "=r"(dst36), "=r"(dst37), "=r"(dst38), "=r"(dst39), + "=r"(dst40), "=r"(dst41), "=r"(dst42), "=r"(dst43), + "=r"(dst44), "=r"(dst45), "=r"(dst46), "=r"(dst47), + "=r"(dst48), "=r"(dst49), "=r"(dst50), "=r"(dst51), + "=r"(dst52), "=r"(dst53), "=r"(dst54), "=r"(dst55), + "=r"(dst56), "=r"(dst57), "=r"(dst58), "=r"(dst59), + "=r"(dst60), "=r"(dst61), "=r"(dst62), "=r"(dst63) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 128 times +struct SM100_TMEM_LOAD_16dp32b128x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst000, uint32_t& dst001, uint32_t& dst002, uint32_t& dst003, + uint32_t& dst004, uint32_t& dst005, uint32_t& dst006, uint32_t& dst007, + uint32_t& dst008, uint32_t& dst009, uint32_t& dst010, uint32_t& dst011, + uint32_t& dst012, uint32_t& dst013, uint32_t& dst014, uint32_t& dst015, + uint32_t& dst016, uint32_t& dst017, uint32_t& dst018, uint32_t& dst019, + uint32_t& dst020, uint32_t& dst021, uint32_t& dst022, uint32_t& dst023, + uint32_t& dst024, uint32_t& dst025, uint32_t& dst026, uint32_t& dst027, + uint32_t& dst028, uint32_t& dst029, uint32_t& dst030, uint32_t& dst031, + uint32_t& dst032, uint32_t& dst033, uint32_t& dst034, uint32_t& dst035, + uint32_t& dst036, uint32_t& dst037, uint32_t& dst038, uint32_t& dst039, + uint32_t& dst040, uint32_t& dst041, uint32_t& dst042, uint32_t& dst043, + uint32_t& dst044, uint32_t& dst045, uint32_t& dst046, uint32_t& dst047, + uint32_t& dst048, uint32_t& dst049, uint32_t& dst050, uint32_t& dst051, + uint32_t& dst052, uint32_t& dst053, uint32_t& dst054, uint32_t& dst055, + uint32_t& dst056, uint32_t& dst057, uint32_t& dst058, uint32_t& dst059, + uint32_t& dst060, uint32_t& dst061, uint32_t& dst062, uint32_t& dst063, + uint32_t& dst064, uint32_t& dst065, uint32_t& dst066, uint32_t& dst067, + uint32_t& dst068, uint32_t& dst069, uint32_t& dst070, uint32_t& dst071, + uint32_t& dst072, uint32_t& dst073, uint32_t& dst074, uint32_t& dst075, + uint32_t& dst076, uint32_t& dst077, uint32_t& dst078, uint32_t& dst079, + uint32_t& dst080, uint32_t& dst081, uint32_t& dst082, uint32_t& dst083, + uint32_t& dst084, uint32_t& dst085, uint32_t& dst086, uint32_t& dst087, + uint32_t& dst088, uint32_t& dst089, uint32_t& dst090, uint32_t& dst091, + uint32_t& dst092, uint32_t& dst093, uint32_t& dst094, uint32_t& dst095, + uint32_t& dst096, uint32_t& dst097, uint32_t& dst098, uint32_t& dst099, + uint32_t& dst100, uint32_t& dst101, uint32_t& dst102, uint32_t& dst103, + uint32_t& dst104, uint32_t& dst105, uint32_t& dst106, uint32_t& dst107, + uint32_t& dst108, uint32_t& dst109, uint32_t& dst110, uint32_t& dst111, + uint32_t& dst112, uint32_t& dst113, uint32_t& dst114, uint32_t& dst115, + uint32_t& dst116, uint32_t& dst117, uint32_t& dst118, uint32_t& dst119, + uint32_t& dst120, uint32_t& dst121, uint32_t& dst122, uint32_t& dst123, + uint32_t& dst124, uint32_t& dst125, uint32_t& dst126, uint32_t& dst127) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x128.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63," + "%64, %65, %66, %67," + "%68, %69, %70, %71," + "%72, %73, %74, %75," + "%76, %77, %78, %79," + "%80, %81, %82, %83," + "%84, %85, %86, %87," + "%88, %89, %90, %91," + "%92, %93, %94, %95," + "%96, %97, %98, %99," + "%100, %101, %102, %103," + "%104, %105, %106, %107," + "%108, %109, %110, %111," + "%112, %113, %114, %115," + "%116, %117, %118, %119," + "%120, %121, %122, %123," + "%124, %125, %126, %127}," + "[%128], 128;\n" + : "=r"(dst000), "=r"(dst001), "=r"(dst002), "=r"(dst003), + "=r"(dst004), "=r"(dst005), "=r"(dst006), "=r"(dst007), + "=r"(dst008), "=r"(dst009), "=r"(dst010), "=r"(dst011), + "=r"(dst012), "=r"(dst013), "=r"(dst014), "=r"(dst015), + "=r"(dst016), "=r"(dst017), "=r"(dst018), "=r"(dst019), + "=r"(dst020), "=r"(dst021), "=r"(dst022), "=r"(dst023), + "=r"(dst024), "=r"(dst025), "=r"(dst026), "=r"(dst027), + "=r"(dst028), "=r"(dst029), "=r"(dst030), "=r"(dst031), + "=r"(dst032), "=r"(dst033), "=r"(dst034), "=r"(dst035), + "=r"(dst036), "=r"(dst037), "=r"(dst038), "=r"(dst039), + "=r"(dst040), "=r"(dst041), "=r"(dst042), "=r"(dst043), + "=r"(dst044), "=r"(dst045), "=r"(dst046), "=r"(dst047), + "=r"(dst048), "=r"(dst049), "=r"(dst050), "=r"(dst051), + "=r"(dst052), "=r"(dst053), "=r"(dst054), "=r"(dst055), + "=r"(dst056), "=r"(dst057), "=r"(dst058), "=r"(dst059), + "=r"(dst060), "=r"(dst061), "=r"(dst062), "=r"(dst063), + "=r"(dst064), "=r"(dst065), "=r"(dst066), "=r"(dst067), + "=r"(dst068), "=r"(dst069), "=r"(dst070), "=r"(dst071), + "=r"(dst072), "=r"(dst073), "=r"(dst074), "=r"(dst075), + "=r"(dst076), "=r"(dst077), "=r"(dst078), "=r"(dst079), + "=r"(dst080), "=r"(dst081), "=r"(dst082), "=r"(dst083), + "=r"(dst084), "=r"(dst085), "=r"(dst086), "=r"(dst087), + "=r"(dst088), "=r"(dst089), "=r"(dst090), "=r"(dst091), + "=r"(dst092), "=r"(dst093), "=r"(dst094), "=r"(dst095), + "=r"(dst096), "=r"(dst097), "=r"(dst098), "=r"(dst099), + "=r"(dst100), "=r"(dst101), "=r"(dst102), "=r"(dst103), + "=r"(dst104), "=r"(dst105), "=r"(dst106), "=r"(dst107), + "=r"(dst108), "=r"(dst109), "=r"(dst110), "=r"(dst111), + "=r"(dst112), "=r"(dst113), "=r"(dst114), "=r"(dst115), + "=r"(dst116), "=r"(dst117), "=r"(dst118), "=r"(dst119), + "=r"(dst120), "=r"(dst121), "=r"(dst122), "=r"(dst123), + "=r"(dst124), "=r"(dst125), "=r"(dst126), "=r"(dst127) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 128 times, packed 16b read +struct SM100_TMEM_LOAD_16dp32b128x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst000, uint32_t& dst001, uint32_t& dst002, uint32_t& dst003, + uint32_t& dst004, uint32_t& dst005, uint32_t& dst006, uint32_t& dst007, + uint32_t& dst008, uint32_t& dst009, uint32_t& dst010, uint32_t& dst011, + uint32_t& dst012, uint32_t& dst013, uint32_t& dst014, uint32_t& dst015, + uint32_t& dst016, uint32_t& dst017, uint32_t& dst018, uint32_t& dst019, + uint32_t& dst020, uint32_t& dst021, uint32_t& dst022, uint32_t& dst023, + uint32_t& dst024, uint32_t& dst025, uint32_t& dst026, uint32_t& dst027, + uint32_t& dst028, uint32_t& dst029, uint32_t& dst030, uint32_t& dst031, + uint32_t& dst032, uint32_t& dst033, uint32_t& dst034, uint32_t& dst035, + uint32_t& dst036, uint32_t& dst037, uint32_t& dst038, uint32_t& dst039, + uint32_t& dst040, uint32_t& dst041, uint32_t& dst042, uint32_t& dst043, + uint32_t& dst044, uint32_t& dst045, uint32_t& dst046, uint32_t& dst047, + uint32_t& dst048, uint32_t& dst049, uint32_t& dst050, uint32_t& dst051, + uint32_t& dst052, uint32_t& dst053, uint32_t& dst054, uint32_t& dst055, + uint32_t& dst056, uint32_t& dst057, uint32_t& dst058, uint32_t& dst059, + uint32_t& dst060, uint32_t& dst061, uint32_t& dst062, uint32_t& dst063, + uint32_t& dst064, uint32_t& dst065, uint32_t& dst066, uint32_t& dst067, + uint32_t& dst068, uint32_t& dst069, uint32_t& dst070, uint32_t& dst071, + uint32_t& dst072, uint32_t& dst073, uint32_t& dst074, uint32_t& dst075, + uint32_t& dst076, uint32_t& dst077, uint32_t& dst078, uint32_t& dst079, + uint32_t& dst080, uint32_t& dst081, uint32_t& dst082, uint32_t& dst083, + uint32_t& dst084, uint32_t& dst085, uint32_t& dst086, uint32_t& dst087, + uint32_t& dst088, uint32_t& dst089, uint32_t& dst090, uint32_t& dst091, + uint32_t& dst092, uint32_t& dst093, uint32_t& dst094, uint32_t& dst095, + uint32_t& dst096, uint32_t& dst097, uint32_t& dst098, uint32_t& dst099, + uint32_t& dst100, uint32_t& dst101, uint32_t& dst102, uint32_t& dst103, + uint32_t& dst104, uint32_t& dst105, uint32_t& dst106, uint32_t& dst107, + uint32_t& dst108, uint32_t& dst109, uint32_t& dst110, uint32_t& dst111, + uint32_t& dst112, uint32_t& dst113, uint32_t& dst114, uint32_t& dst115, + uint32_t& dst116, uint32_t& dst117, uint32_t& dst118, uint32_t& dst119, + uint32_t& dst120, uint32_t& dst121, uint32_t& dst122, uint32_t& dst123, + uint32_t& dst124, uint32_t& dst125, uint32_t& dst126, uint32_t& dst127) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x128.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63," + "%64, %65, %66, %67," + "%68, %69, %70, %71," + "%72, %73, %74, %75," + "%76, %77, %78, %79," + "%80, %81, %82, %83," + "%84, %85, %86, %87," + "%88, %89, %90, %91," + "%92, %93, %94, %95," + "%96, %97, %98, %99," + "%100, %101, %102, %103," + "%104, %105, %106, %107," + "%108, %109, %110, %111," + "%112, %113, %114, %115," + "%116, %117, %118, %119," + "%120, %121, %122, %123," + "%124, %125, %126, %127}," + "[%128], 256;\n" + : "=r"(dst000), "=r"(dst001), "=r"(dst002), "=r"(dst003), + "=r"(dst004), "=r"(dst005), "=r"(dst006), "=r"(dst007), + "=r"(dst008), "=r"(dst009), "=r"(dst010), "=r"(dst011), + "=r"(dst012), "=r"(dst013), "=r"(dst014), "=r"(dst015), + "=r"(dst016), "=r"(dst017), "=r"(dst018), "=r"(dst019), + "=r"(dst020), "=r"(dst021), "=r"(dst022), "=r"(dst023), + "=r"(dst024), "=r"(dst025), "=r"(dst026), "=r"(dst027), + "=r"(dst028), "=r"(dst029), "=r"(dst030), "=r"(dst031), + "=r"(dst032), "=r"(dst033), "=r"(dst034), "=r"(dst035), + "=r"(dst036), "=r"(dst037), "=r"(dst038), "=r"(dst039), + "=r"(dst040), "=r"(dst041), "=r"(dst042), "=r"(dst043), + "=r"(dst044), "=r"(dst045), "=r"(dst046), "=r"(dst047), + "=r"(dst048), "=r"(dst049), "=r"(dst050), "=r"(dst051), + "=r"(dst052), "=r"(dst053), "=r"(dst054), "=r"(dst055), + "=r"(dst056), "=r"(dst057), "=r"(dst058), "=r"(dst059), + "=r"(dst060), "=r"(dst061), "=r"(dst062), "=r"(dst063), + "=r"(dst064), "=r"(dst065), "=r"(dst066), "=r"(dst067), + "=r"(dst068), "=r"(dst069), "=r"(dst070), "=r"(dst071), + "=r"(dst072), "=r"(dst073), "=r"(dst074), "=r"(dst075), + "=r"(dst076), "=r"(dst077), "=r"(dst078), "=r"(dst079), + "=r"(dst080), "=r"(dst081), "=r"(dst082), "=r"(dst083), + "=r"(dst084), "=r"(dst085), "=r"(dst086), "=r"(dst087), + "=r"(dst088), "=r"(dst089), "=r"(dst090), "=r"(dst091), + "=r"(dst092), "=r"(dst093), "=r"(dst094), "=r"(dst095), + "=r"(dst096), "=r"(dst097), "=r"(dst098), "=r"(dst099), + "=r"(dst100), "=r"(dst101), "=r"(dst102), "=r"(dst103), + "=r"(dst104), "=r"(dst105), "=r"(dst106), "=r"(dst107), + "=r"(dst108), "=r"(dst109), "=r"(dst110), "=r"(dst111), + "=r"(dst112), "=r"(dst113), "=r"(dst114), "=r"(dst115), + "=r"(dst116), "=r"(dst117), "=r"(dst118), "=r"(dst119), + "=r"(dst120), "=r"(dst121), "=r"(dst122), "=r"(dst123), + "=r"(dst124), "=r"(dst125), "=r"(dst126), "=r"(dst127) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 1 times +struct SM100_TMEM_LOAD_32dp32b1x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x1.b32" + "{%0}," + "[%1];\n" + : "=r"(dst0) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 1 times, packed 16b read +struct SM100_TMEM_LOAD_32dp32b1x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x1.pack::16b.b32" + "{%0}," + "[%1];\n" + : "=r"(dst0) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 2 times +struct SM100_TMEM_LOAD_32dp32b2x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x2.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst0), "=r"(dst1) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 2 times, packed 16b read +struct SM100_TMEM_LOAD_32dp32b2x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x2.pack::16b.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst0), "=r"(dst1) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 4 times +struct SM100_TMEM_LOAD_32dp32b4x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x4.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 4 times, packed 16b read +struct SM100_TMEM_LOAD_32dp32b4x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x4.pack::16b.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 8 times +struct SM100_TMEM_LOAD_32dp32b8x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3, + uint32_t& dst4, uint32_t& dst5, uint32_t& dst6, uint32_t& dst7) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x8.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3), + "=r"(dst4), "=r"(dst5), "=r"(dst6), "=r"(dst7) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 8 times, packed 16b read +struct SM100_TMEM_LOAD_32dp32b8x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3, + uint32_t& dst4, uint32_t& dst5, uint32_t& dst6, uint32_t& dst7) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x8.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3), + "=r"(dst4), "=r"(dst5), "=r"(dst6), "=r"(dst7) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 16 times +struct SM100_TMEM_LOAD_32dp32b16x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x16.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15}," + "[%16];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 16 times, packed 16b read +struct SM100_TMEM_LOAD_32dp32b16x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x16.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15}," + "[%16];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 32 times +struct SM100_TMEM_LOAD_32dp32b32x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x32.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 32 times, packed 16b read +struct SM100_TMEM_LOAD_32dp32b32x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x32.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 64 times +struct SM100_TMEM_LOAD_32dp32b64x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31, + uint32_t& dst32, uint32_t& dst33, uint32_t& dst34, uint32_t& dst35, + uint32_t& dst36, uint32_t& dst37, uint32_t& dst38, uint32_t& dst39, + uint32_t& dst40, uint32_t& dst41, uint32_t& dst42, uint32_t& dst43, + uint32_t& dst44, uint32_t& dst45, uint32_t& dst46, uint32_t& dst47, + uint32_t& dst48, uint32_t& dst49, uint32_t& dst50, uint32_t& dst51, + uint32_t& dst52, uint32_t& dst53, uint32_t& dst54, uint32_t& dst55, + uint32_t& dst56, uint32_t& dst57, uint32_t& dst58, uint32_t& dst59, + uint32_t& dst60, uint32_t& dst61, uint32_t& dst62, uint32_t& dst63) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x64.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31), + "=r"(dst32), "=r"(dst33), "=r"(dst34), "=r"(dst35), + "=r"(dst36), "=r"(dst37), "=r"(dst38), "=r"(dst39), + "=r"(dst40), "=r"(dst41), "=r"(dst42), "=r"(dst43), + "=r"(dst44), "=r"(dst45), "=r"(dst46), "=r"(dst47), + "=r"(dst48), "=r"(dst49), "=r"(dst50), "=r"(dst51), + "=r"(dst52), "=r"(dst53), "=r"(dst54), "=r"(dst55), + "=r"(dst56), "=r"(dst57), "=r"(dst58), "=r"(dst59), + "=r"(dst60), "=r"(dst61), "=r"(dst62), "=r"(dst63) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 64 times, packed 16b read +struct SM100_TMEM_LOAD_32dp32b64x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31, + uint32_t& dst32, uint32_t& dst33, uint32_t& dst34, uint32_t& dst35, + uint32_t& dst36, uint32_t& dst37, uint32_t& dst38, uint32_t& dst39, + uint32_t& dst40, uint32_t& dst41, uint32_t& dst42, uint32_t& dst43, + uint32_t& dst44, uint32_t& dst45, uint32_t& dst46, uint32_t& dst47, + uint32_t& dst48, uint32_t& dst49, uint32_t& dst50, uint32_t& dst51, + uint32_t& dst52, uint32_t& dst53, uint32_t& dst54, uint32_t& dst55, + uint32_t& dst56, uint32_t& dst57, uint32_t& dst58, uint32_t& dst59, + uint32_t& dst60, uint32_t& dst61, uint32_t& dst62, uint32_t& dst63) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x64.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31), + "=r"(dst32), "=r"(dst33), "=r"(dst34), "=r"(dst35), + "=r"(dst36), "=r"(dst37), "=r"(dst38), "=r"(dst39), + "=r"(dst40), "=r"(dst41), "=r"(dst42), "=r"(dst43), + "=r"(dst44), "=r"(dst45), "=r"(dst46), "=r"(dst47), + "=r"(dst48), "=r"(dst49), "=r"(dst50), "=r"(dst51), + "=r"(dst52), "=r"(dst53), "=r"(dst54), "=r"(dst55), + "=r"(dst56), "=r"(dst57), "=r"(dst58), "=r"(dst59), + "=r"(dst60), "=r"(dst61), "=r"(dst62), "=r"(dst63) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 128 times +struct SM100_TMEM_LOAD_32dp32b128x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst000, uint32_t& dst001, uint32_t& dst002, uint32_t& dst003, + uint32_t& dst004, uint32_t& dst005, uint32_t& dst006, uint32_t& dst007, + uint32_t& dst008, uint32_t& dst009, uint32_t& dst010, uint32_t& dst011, + uint32_t& dst012, uint32_t& dst013, uint32_t& dst014, uint32_t& dst015, + uint32_t& dst016, uint32_t& dst017, uint32_t& dst018, uint32_t& dst019, + uint32_t& dst020, uint32_t& dst021, uint32_t& dst022, uint32_t& dst023, + uint32_t& dst024, uint32_t& dst025, uint32_t& dst026, uint32_t& dst027, + uint32_t& dst028, uint32_t& dst029, uint32_t& dst030, uint32_t& dst031, + uint32_t& dst032, uint32_t& dst033, uint32_t& dst034, uint32_t& dst035, + uint32_t& dst036, uint32_t& dst037, uint32_t& dst038, uint32_t& dst039, + uint32_t& dst040, uint32_t& dst041, uint32_t& dst042, uint32_t& dst043, + uint32_t& dst044, uint32_t& dst045, uint32_t& dst046, uint32_t& dst047, + uint32_t& dst048, uint32_t& dst049, uint32_t& dst050, uint32_t& dst051, + uint32_t& dst052, uint32_t& dst053, uint32_t& dst054, uint32_t& dst055, + uint32_t& dst056, uint32_t& dst057, uint32_t& dst058, uint32_t& dst059, + uint32_t& dst060, uint32_t& dst061, uint32_t& dst062, uint32_t& dst063, + uint32_t& dst064, uint32_t& dst065, uint32_t& dst066, uint32_t& dst067, + uint32_t& dst068, uint32_t& dst069, uint32_t& dst070, uint32_t& dst071, + uint32_t& dst072, uint32_t& dst073, uint32_t& dst074, uint32_t& dst075, + uint32_t& dst076, uint32_t& dst077, uint32_t& dst078, uint32_t& dst079, + uint32_t& dst080, uint32_t& dst081, uint32_t& dst082, uint32_t& dst083, + uint32_t& dst084, uint32_t& dst085, uint32_t& dst086, uint32_t& dst087, + uint32_t& dst088, uint32_t& dst089, uint32_t& dst090, uint32_t& dst091, + uint32_t& dst092, uint32_t& dst093, uint32_t& dst094, uint32_t& dst095, + uint32_t& dst096, uint32_t& dst097, uint32_t& dst098, uint32_t& dst099, + uint32_t& dst100, uint32_t& dst101, uint32_t& dst102, uint32_t& dst103, + uint32_t& dst104, uint32_t& dst105, uint32_t& dst106, uint32_t& dst107, + uint32_t& dst108, uint32_t& dst109, uint32_t& dst110, uint32_t& dst111, + uint32_t& dst112, uint32_t& dst113, uint32_t& dst114, uint32_t& dst115, + uint32_t& dst116, uint32_t& dst117, uint32_t& dst118, uint32_t& dst119, + uint32_t& dst120, uint32_t& dst121, uint32_t& dst122, uint32_t& dst123, + uint32_t& dst124, uint32_t& dst125, uint32_t& dst126, uint32_t& dst127) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x128.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63," + "%64, %65, %66, %67," + "%68, %69, %70, %71," + "%72, %73, %74, %75," + "%76, %77, %78, %79," + "%80, %81, %82, %83," + "%84, %85, %86, %87," + "%88, %89, %90, %91," + "%92, %93, %94, %95," + "%96, %97, %98, %99," + "%100, %101, %102, %103," + "%104, %105, %106, %107," + "%108, %109, %110, %111," + "%112, %113, %114, %115," + "%116, %117, %118, %119," + "%120, %121, %122, %123," + "%124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst000), "=r"(dst001), "=r"(dst002), "=r"(dst003), + "=r"(dst004), "=r"(dst005), "=r"(dst006), "=r"(dst007), + "=r"(dst008), "=r"(dst009), "=r"(dst010), "=r"(dst011), + "=r"(dst012), "=r"(dst013), "=r"(dst014), "=r"(dst015), + "=r"(dst016), "=r"(dst017), "=r"(dst018), "=r"(dst019), + "=r"(dst020), "=r"(dst021), "=r"(dst022), "=r"(dst023), + "=r"(dst024), "=r"(dst025), "=r"(dst026), "=r"(dst027), + "=r"(dst028), "=r"(dst029), "=r"(dst030), "=r"(dst031), + "=r"(dst032), "=r"(dst033), "=r"(dst034), "=r"(dst035), + "=r"(dst036), "=r"(dst037), "=r"(dst038), "=r"(dst039), + "=r"(dst040), "=r"(dst041), "=r"(dst042), "=r"(dst043), + "=r"(dst044), "=r"(dst045), "=r"(dst046), "=r"(dst047), + "=r"(dst048), "=r"(dst049), "=r"(dst050), "=r"(dst051), + "=r"(dst052), "=r"(dst053), "=r"(dst054), "=r"(dst055), + "=r"(dst056), "=r"(dst057), "=r"(dst058), "=r"(dst059), + "=r"(dst060), "=r"(dst061), "=r"(dst062), "=r"(dst063), + "=r"(dst064), "=r"(dst065), "=r"(dst066), "=r"(dst067), + "=r"(dst068), "=r"(dst069), "=r"(dst070), "=r"(dst071), + "=r"(dst072), "=r"(dst073), "=r"(dst074), "=r"(dst075), + "=r"(dst076), "=r"(dst077), "=r"(dst078), "=r"(dst079), + "=r"(dst080), "=r"(dst081), "=r"(dst082), "=r"(dst083), + "=r"(dst084), "=r"(dst085), "=r"(dst086), "=r"(dst087), + "=r"(dst088), "=r"(dst089), "=r"(dst090), "=r"(dst091), + "=r"(dst092), "=r"(dst093), "=r"(dst094), "=r"(dst095), + "=r"(dst096), "=r"(dst097), "=r"(dst098), "=r"(dst099), + "=r"(dst100), "=r"(dst101), "=r"(dst102), "=r"(dst103), + "=r"(dst104), "=r"(dst105), "=r"(dst106), "=r"(dst107), + "=r"(dst108), "=r"(dst109), "=r"(dst110), "=r"(dst111), + "=r"(dst112), "=r"(dst113), "=r"(dst114), "=r"(dst115), + "=r"(dst116), "=r"(dst117), "=r"(dst118), "=r"(dst119), + "=r"(dst120), "=r"(dst121), "=r"(dst122), "=r"(dst123), + "=r"(dst124), "=r"(dst125), "=r"(dst126), "=r"(dst127) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 128 times, packed 16b read +struct SM100_TMEM_LOAD_32dp32b128x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst000, uint32_t& dst001, uint32_t& dst002, uint32_t& dst003, + uint32_t& dst004, uint32_t& dst005, uint32_t& dst006, uint32_t& dst007, + uint32_t& dst008, uint32_t& dst009, uint32_t& dst010, uint32_t& dst011, + uint32_t& dst012, uint32_t& dst013, uint32_t& dst014, uint32_t& dst015, + uint32_t& dst016, uint32_t& dst017, uint32_t& dst018, uint32_t& dst019, + uint32_t& dst020, uint32_t& dst021, uint32_t& dst022, uint32_t& dst023, + uint32_t& dst024, uint32_t& dst025, uint32_t& dst026, uint32_t& dst027, + uint32_t& dst028, uint32_t& dst029, uint32_t& dst030, uint32_t& dst031, + uint32_t& dst032, uint32_t& dst033, uint32_t& dst034, uint32_t& dst035, + uint32_t& dst036, uint32_t& dst037, uint32_t& dst038, uint32_t& dst039, + uint32_t& dst040, uint32_t& dst041, uint32_t& dst042, uint32_t& dst043, + uint32_t& dst044, uint32_t& dst045, uint32_t& dst046, uint32_t& dst047, + uint32_t& dst048, uint32_t& dst049, uint32_t& dst050, uint32_t& dst051, + uint32_t& dst052, uint32_t& dst053, uint32_t& dst054, uint32_t& dst055, + uint32_t& dst056, uint32_t& dst057, uint32_t& dst058, uint32_t& dst059, + uint32_t& dst060, uint32_t& dst061, uint32_t& dst062, uint32_t& dst063, + uint32_t& dst064, uint32_t& dst065, uint32_t& dst066, uint32_t& dst067, + uint32_t& dst068, uint32_t& dst069, uint32_t& dst070, uint32_t& dst071, + uint32_t& dst072, uint32_t& dst073, uint32_t& dst074, uint32_t& dst075, + uint32_t& dst076, uint32_t& dst077, uint32_t& dst078, uint32_t& dst079, + uint32_t& dst080, uint32_t& dst081, uint32_t& dst082, uint32_t& dst083, + uint32_t& dst084, uint32_t& dst085, uint32_t& dst086, uint32_t& dst087, + uint32_t& dst088, uint32_t& dst089, uint32_t& dst090, uint32_t& dst091, + uint32_t& dst092, uint32_t& dst093, uint32_t& dst094, uint32_t& dst095, + uint32_t& dst096, uint32_t& dst097, uint32_t& dst098, uint32_t& dst099, + uint32_t& dst100, uint32_t& dst101, uint32_t& dst102, uint32_t& dst103, + uint32_t& dst104, uint32_t& dst105, uint32_t& dst106, uint32_t& dst107, + uint32_t& dst108, uint32_t& dst109, uint32_t& dst110, uint32_t& dst111, + uint32_t& dst112, uint32_t& dst113, uint32_t& dst114, uint32_t& dst115, + uint32_t& dst116, uint32_t& dst117, uint32_t& dst118, uint32_t& dst119, + uint32_t& dst120, uint32_t& dst121, uint32_t& dst122, uint32_t& dst123, + uint32_t& dst124, uint32_t& dst125, uint32_t& dst126, uint32_t& dst127) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x128.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63," + "%64, %65, %66, %67," + "%68, %69, %70, %71," + "%72, %73, %74, %75," + "%76, %77, %78, %79," + "%80, %81, %82, %83," + "%84, %85, %86, %87," + "%88, %89, %90, %91," + "%92, %93, %94, %95," + "%96, %97, %98, %99," + "%100, %101, %102, %103," + "%104, %105, %106, %107," + "%108, %109, %110, %111," + "%112, %113, %114, %115," + "%116, %117, %118, %119," + "%120, %121, %122, %123," + "%124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst000), "=r"(dst001), "=r"(dst002), "=r"(dst003), + "=r"(dst004), "=r"(dst005), "=r"(dst006), "=r"(dst007), + "=r"(dst008), "=r"(dst009), "=r"(dst010), "=r"(dst011), + "=r"(dst012), "=r"(dst013), "=r"(dst014), "=r"(dst015), + "=r"(dst016), "=r"(dst017), "=r"(dst018), "=r"(dst019), + "=r"(dst020), "=r"(dst021), "=r"(dst022), "=r"(dst023), + "=r"(dst024), "=r"(dst025), "=r"(dst026), "=r"(dst027), + "=r"(dst028), "=r"(dst029), "=r"(dst030), "=r"(dst031), + "=r"(dst032), "=r"(dst033), "=r"(dst034), "=r"(dst035), + "=r"(dst036), "=r"(dst037), "=r"(dst038), "=r"(dst039), + "=r"(dst040), "=r"(dst041), "=r"(dst042), "=r"(dst043), + "=r"(dst044), "=r"(dst045), "=r"(dst046), "=r"(dst047), + "=r"(dst048), "=r"(dst049), "=r"(dst050), "=r"(dst051), + "=r"(dst052), "=r"(dst053), "=r"(dst054), "=r"(dst055), + "=r"(dst056), "=r"(dst057), "=r"(dst058), "=r"(dst059), + "=r"(dst060), "=r"(dst061), "=r"(dst062), "=r"(dst063), + "=r"(dst064), "=r"(dst065), "=r"(dst066), "=r"(dst067), + "=r"(dst068), "=r"(dst069), "=r"(dst070), "=r"(dst071), + "=r"(dst072), "=r"(dst073), "=r"(dst074), "=r"(dst075), + "=r"(dst076), "=r"(dst077), "=r"(dst078), "=r"(dst079), + "=r"(dst080), "=r"(dst081), "=r"(dst082), "=r"(dst083), + "=r"(dst084), "=r"(dst085), "=r"(dst086), "=r"(dst087), + "=r"(dst088), "=r"(dst089), "=r"(dst090), "=r"(dst091), + "=r"(dst092), "=r"(dst093), "=r"(dst094), "=r"(dst095), + "=r"(dst096), "=r"(dst097), "=r"(dst098), "=r"(dst099), + "=r"(dst100), "=r"(dst101), "=r"(dst102), "=r"(dst103), + "=r"(dst104), "=r"(dst105), "=r"(dst106), "=r"(dst107), + "=r"(dst108), "=r"(dst109), "=r"(dst110), "=r"(dst111), + "=r"(dst112), "=r"(dst113), "=r"(dst114), "=r"(dst115), + "=r"(dst116), "=r"(dst117), "=r"(dst118), "=r"(dst119), + "=r"(dst120), "=r"(dst121), "=r"(dst122), "=r"(dst123), + "=r"(dst124), "=r"(dst125), "=r"(dst126), "=r"(dst127) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace SM100::TMEM::LOAD + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace SM100::TMEM::STORE { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// TMEM STORE PTX definitions +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 1 times +struct SM100_TMEM_STORE_16dp256b1x +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x256b.x1.b32" + "[%0]," + "{%1, %2, %3, %4};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 1 times, expand 16b write +struct SM100_TMEM_STORE_16dp256b1x_16b +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x256b.x1.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 2 times +struct SM100_TMEM_STORE_16dp256b2x +{ + using SRegisters = uint32_t[8]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& src4, uint32_t const& src5, uint32_t const& src6, uint32_t const& src7, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x256b.x2.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), + "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 2 times, expand 16b write +struct SM100_TMEM_STORE_16dp256b2x_16b +{ + using SRegisters = uint32_t[8]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& src4, uint32_t const& src5, uint32_t const& src6, uint32_t const& src7, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x256b.x2.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), + "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 4 times +struct SM100_TMEM_STORE_16dp256b4x +{ + using SRegisters = uint32_t[16]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x256b.x4.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 4 times, expand 16b write +struct SM100_TMEM_STORE_16dp256b4x_16b +{ + using SRegisters = uint32_t[16]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x256b.x4.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 8 times +struct SM100_TMEM_STORE_16dp256b8x +{ + using SRegisters = uint32_t[32]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x256b.x8.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 8 times, expand 16b write +struct SM100_TMEM_STORE_16dp256b8x_16b +{ + using SRegisters = uint32_t[32]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x256b.x8.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 16 times +struct SM100_TMEM_STORE_16dp256b16x +{ + using SRegisters = uint32_t[64]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& src32, uint32_t const& src33, uint32_t const& src34, uint32_t const& src35, + uint32_t const& src36, uint32_t const& src37, uint32_t const& src38, uint32_t const& src39, + uint32_t const& src40, uint32_t const& src41, uint32_t const& src42, uint32_t const& src43, + uint32_t const& src44, uint32_t const& src45, uint32_t const& src46, uint32_t const& src47, + uint32_t const& src48, uint32_t const& src49, uint32_t const& src50, uint32_t const& src51, + uint32_t const& src52, uint32_t const& src53, uint32_t const& src54, uint32_t const& src55, + uint32_t const& src56, uint32_t const& src57, uint32_t const& src58, uint32_t const& src59, + uint32_t const& src60, uint32_t const& src61, uint32_t const& src62, uint32_t const& src63, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x256b.x16.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31), + "r"(src32), "r"(src33), "r"(src34), "r"(src35), + "r"(src36), "r"(src37), "r"(src38), "r"(src39), + "r"(src40), "r"(src41), "r"(src42), "r"(src43), + "r"(src44), "r"(src45), "r"(src46), "r"(src47), + "r"(src48), "r"(src49), "r"(src50), "r"(src51), + "r"(src52), "r"(src53), "r"(src54), "r"(src55), + "r"(src56), "r"(src57), "r"(src58), "r"(src59), + "r"(src60), "r"(src61), "r"(src62), "r"(src63) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 16 times, expand 16b write +struct SM100_TMEM_STORE_16dp256b16x_16b +{ + using SRegisters = uint32_t[64]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& src32, uint32_t const& src33, uint32_t const& src34, uint32_t const& src35, + uint32_t const& src36, uint32_t const& src37, uint32_t const& src38, uint32_t const& src39, + uint32_t const& src40, uint32_t const& src41, uint32_t const& src42, uint32_t const& src43, + uint32_t const& src44, uint32_t const& src45, uint32_t const& src46, uint32_t const& src47, + uint32_t const& src48, uint32_t const& src49, uint32_t const& src50, uint32_t const& src51, + uint32_t const& src52, uint32_t const& src53, uint32_t const& src54, uint32_t const& src55, + uint32_t const& src56, uint32_t const& src57, uint32_t const& src58, uint32_t const& src59, + uint32_t const& src60, uint32_t const& src61, uint32_t const& src62, uint32_t const& src63, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x256b.x16.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31), + "r"(src32), "r"(src33), "r"(src34), "r"(src35), + "r"(src36), "r"(src37), "r"(src38), "r"(src39), + "r"(src40), "r"(src41), "r"(src42), "r"(src43), + "r"(src44), "r"(src45), "r"(src46), "r"(src47), + "r"(src48), "r"(src49), "r"(src50), "r"(src51), + "r"(src52), "r"(src53), "r"(src54), "r"(src55), + "r"(src56), "r"(src57), "r"(src58), "r"(src59), + "r"(src60), "r"(src61), "r"(src62), "r"(src63) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 32 times +struct SM100_TMEM_STORE_16dp256b32x +{ + using SRegisters = uint32_t[128]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src000, uint32_t const& src001, uint32_t const& src002, uint32_t const& src003, + uint32_t const& src004, uint32_t const& src005, uint32_t const& src006, uint32_t const& src007, + uint32_t const& src008, uint32_t const& src009, uint32_t const& src010, uint32_t const& src011, + uint32_t const& src012, uint32_t const& src013, uint32_t const& src014, uint32_t const& src015, + uint32_t const& src016, uint32_t const& src017, uint32_t const& src018, uint32_t const& src019, + uint32_t const& src020, uint32_t const& src021, uint32_t const& src022, uint32_t const& src023, + uint32_t const& src024, uint32_t const& src025, uint32_t const& src026, uint32_t const& src027, + uint32_t const& src028, uint32_t const& src029, uint32_t const& src030, uint32_t const& src031, + uint32_t const& src032, uint32_t const& src033, uint32_t const& src034, uint32_t const& src035, + uint32_t const& src036, uint32_t const& src037, uint32_t const& src038, uint32_t const& src039, + uint32_t const& src040, uint32_t const& src041, uint32_t const& src042, uint32_t const& src043, + uint32_t const& src044, uint32_t const& src045, uint32_t const& src046, uint32_t const& src047, + uint32_t const& src048, uint32_t const& src049, uint32_t const& src050, uint32_t const& src051, + uint32_t const& src052, uint32_t const& src053, uint32_t const& src054, uint32_t const& src055, + uint32_t const& src056, uint32_t const& src057, uint32_t const& src058, uint32_t const& src059, + uint32_t const& src060, uint32_t const& src061, uint32_t const& src062, uint32_t const& src063, + uint32_t const& src064, uint32_t const& src065, uint32_t const& src066, uint32_t const& src067, + uint32_t const& src068, uint32_t const& src069, uint32_t const& src070, uint32_t const& src071, + uint32_t const& src072, uint32_t const& src073, uint32_t const& src074, uint32_t const& src075, + uint32_t const& src076, uint32_t const& src077, uint32_t const& src078, uint32_t const& src079, + uint32_t const& src080, uint32_t const& src081, uint32_t const& src082, uint32_t const& src083, + uint32_t const& src084, uint32_t const& src085, uint32_t const& src086, uint32_t const& src087, + uint32_t const& src088, uint32_t const& src089, uint32_t const& src090, uint32_t const& src091, + uint32_t const& src092, uint32_t const& src093, uint32_t const& src094, uint32_t const& src095, + uint32_t const& src096, uint32_t const& src097, uint32_t const& src098, uint32_t const& src099, + uint32_t const& src100, uint32_t const& src101, uint32_t const& src102, uint32_t const& src103, + uint32_t const& src104, uint32_t const& src105, uint32_t const& src106, uint32_t const& src107, + uint32_t const& src108, uint32_t const& src109, uint32_t const& src110, uint32_t const& src111, + uint32_t const& src112, uint32_t const& src113, uint32_t const& src114, uint32_t const& src115, + uint32_t const& src116, uint32_t const& src117, uint32_t const& src118, uint32_t const& src119, + uint32_t const& src120, uint32_t const& src121, uint32_t const& src122, uint32_t const& src123, + uint32_t const& src124, uint32_t const& src125, uint32_t const& src126, uint32_t const& src127, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x256b.x32.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64," + "%65, %66, %67, %68," + "%69, %70, %71, %72," + "%73, %74, %75, %76," + "%77, %78, %79, %80," + "%81, %82, %83, %84," + "%85, %86, %87, %88," + "%89, %90, %91, %92," + "%93, %94, %95, %96," + "%97, %98, %99, %100," + "%101, %102, %103, %104," + "%105, %106, %107, %108," + "%109, %110, %111, %112," + "%113, %114, %115, %116," + "%117, %118, %119, %120," + "%121, %122, %123, %124," + "%125, %126, %127, %128};\n" + : + : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), + "r"(src004), "r"(src005), "r"(src006), "r"(src007), + "r"(src008), "r"(src009), "r"(src010), "r"(src011), + "r"(src012), "r"(src013), "r"(src014), "r"(src015), + "r"(src016), "r"(src017), "r"(src018), "r"(src019), + "r"(src020), "r"(src021), "r"(src022), "r"(src023), + "r"(src024), "r"(src025), "r"(src026), "r"(src027), + "r"(src028), "r"(src029), "r"(src030), "r"(src031), + "r"(src032), "r"(src033), "r"(src034), "r"(src035), + "r"(src036), "r"(src037), "r"(src038), "r"(src039), + "r"(src040), "r"(src041), "r"(src042), "r"(src043), + "r"(src044), "r"(src045), "r"(src046), "r"(src047), + "r"(src048), "r"(src049), "r"(src050), "r"(src051), + "r"(src052), "r"(src053), "r"(src054), "r"(src055), + "r"(src056), "r"(src057), "r"(src058), "r"(src059), + "r"(src060), "r"(src061), "r"(src062), "r"(src063), + "r"(src064), "r"(src065), "r"(src066), "r"(src067), + "r"(src068), "r"(src069), "r"(src070), "r"(src071), + "r"(src072), "r"(src073), "r"(src074), "r"(src075), + "r"(src076), "r"(src077), "r"(src078), "r"(src079), + "r"(src080), "r"(src081), "r"(src082), "r"(src083), + "r"(src084), "r"(src085), "r"(src086), "r"(src087), + "r"(src088), "r"(src089), "r"(src090), "r"(src091), + "r"(src092), "r"(src093), "r"(src094), "r"(src095), + "r"(src096), "r"(src097), "r"(src098), "r"(src099), + "r"(src100), "r"(src101), "r"(src102), "r"(src103), + "r"(src104), "r"(src105), "r"(src106), "r"(src107), + "r"(src108), "r"(src109), "r"(src110), "r"(src111), + "r"(src112), "r"(src113), "r"(src114), "r"(src115), + "r"(src116), "r"(src117), "r"(src118), "r"(src119), + "r"(src120), "r"(src121), "r"(src122), "r"(src123), + "r"(src124), "r"(src125), "r"(src126), "r"(src127) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 32 times, expand 16b write +struct SM100_TMEM_STORE_16dp256b32x_16b +{ + using SRegisters = uint32_t[128]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src000, uint32_t const& src001, uint32_t const& src002, uint32_t const& src003, + uint32_t const& src004, uint32_t const& src005, uint32_t const& src006, uint32_t const& src007, + uint32_t const& src008, uint32_t const& src009, uint32_t const& src010, uint32_t const& src011, + uint32_t const& src012, uint32_t const& src013, uint32_t const& src014, uint32_t const& src015, + uint32_t const& src016, uint32_t const& src017, uint32_t const& src018, uint32_t const& src019, + uint32_t const& src020, uint32_t const& src021, uint32_t const& src022, uint32_t const& src023, + uint32_t const& src024, uint32_t const& src025, uint32_t const& src026, uint32_t const& src027, + uint32_t const& src028, uint32_t const& src029, uint32_t const& src030, uint32_t const& src031, + uint32_t const& src032, uint32_t const& src033, uint32_t const& src034, uint32_t const& src035, + uint32_t const& src036, uint32_t const& src037, uint32_t const& src038, uint32_t const& src039, + uint32_t const& src040, uint32_t const& src041, uint32_t const& src042, uint32_t const& src043, + uint32_t const& src044, uint32_t const& src045, uint32_t const& src046, uint32_t const& src047, + uint32_t const& src048, uint32_t const& src049, uint32_t const& src050, uint32_t const& src051, + uint32_t const& src052, uint32_t const& src053, uint32_t const& src054, uint32_t const& src055, + uint32_t const& src056, uint32_t const& src057, uint32_t const& src058, uint32_t const& src059, + uint32_t const& src060, uint32_t const& src061, uint32_t const& src062, uint32_t const& src063, + uint32_t const& src064, uint32_t const& src065, uint32_t const& src066, uint32_t const& src067, + uint32_t const& src068, uint32_t const& src069, uint32_t const& src070, uint32_t const& src071, + uint32_t const& src072, uint32_t const& src073, uint32_t const& src074, uint32_t const& src075, + uint32_t const& src076, uint32_t const& src077, uint32_t const& src078, uint32_t const& src079, + uint32_t const& src080, uint32_t const& src081, uint32_t const& src082, uint32_t const& src083, + uint32_t const& src084, uint32_t const& src085, uint32_t const& src086, uint32_t const& src087, + uint32_t const& src088, uint32_t const& src089, uint32_t const& src090, uint32_t const& src091, + uint32_t const& src092, uint32_t const& src093, uint32_t const& src094, uint32_t const& src095, + uint32_t const& src096, uint32_t const& src097, uint32_t const& src098, uint32_t const& src099, + uint32_t const& src100, uint32_t const& src101, uint32_t const& src102, uint32_t const& src103, + uint32_t const& src104, uint32_t const& src105, uint32_t const& src106, uint32_t const& src107, + uint32_t const& src108, uint32_t const& src109, uint32_t const& src110, uint32_t const& src111, + uint32_t const& src112, uint32_t const& src113, uint32_t const& src114, uint32_t const& src115, + uint32_t const& src116, uint32_t const& src117, uint32_t const& src118, uint32_t const& src119, + uint32_t const& src120, uint32_t const& src121, uint32_t const& src122, uint32_t const& src123, + uint32_t const& src124, uint32_t const& src125, uint32_t const& src126, uint32_t const& src127, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x256b.x32.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64," + "%65, %66, %67, %68," + "%69, %70, %71, %72," + "%73, %74, %75, %76," + "%77, %78, %79, %80," + "%81, %82, %83, %84," + "%85, %86, %87, %88," + "%89, %90, %91, %92," + "%93, %94, %95, %96," + "%97, %98, %99, %100," + "%101, %102, %103, %104," + "%105, %106, %107, %108," + "%109, %110, %111, %112," + "%113, %114, %115, %116," + "%117, %118, %119, %120," + "%121, %122, %123, %124," + "%125, %126, %127, %128};\n" + : + : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), + "r"(src004), "r"(src005), "r"(src006), "r"(src007), + "r"(src008), "r"(src009), "r"(src010), "r"(src011), + "r"(src012), "r"(src013), "r"(src014), "r"(src015), + "r"(src016), "r"(src017), "r"(src018), "r"(src019), + "r"(src020), "r"(src021), "r"(src022), "r"(src023), + "r"(src024), "r"(src025), "r"(src026), "r"(src027), + "r"(src028), "r"(src029), "r"(src030), "r"(src031), + "r"(src032), "r"(src033), "r"(src034), "r"(src035), + "r"(src036), "r"(src037), "r"(src038), "r"(src039), + "r"(src040), "r"(src041), "r"(src042), "r"(src043), + "r"(src044), "r"(src045), "r"(src046), "r"(src047), + "r"(src048), "r"(src049), "r"(src050), "r"(src051), + "r"(src052), "r"(src053), "r"(src054), "r"(src055), + "r"(src056), "r"(src057), "r"(src058), "r"(src059), + "r"(src060), "r"(src061), "r"(src062), "r"(src063), + "r"(src064), "r"(src065), "r"(src066), "r"(src067), + "r"(src068), "r"(src069), "r"(src070), "r"(src071), + "r"(src072), "r"(src073), "r"(src074), "r"(src075), + "r"(src076), "r"(src077), "r"(src078), "r"(src079), + "r"(src080), "r"(src081), "r"(src082), "r"(src083), + "r"(src084), "r"(src085), "r"(src086), "r"(src087), + "r"(src088), "r"(src089), "r"(src090), "r"(src091), + "r"(src092), "r"(src093), "r"(src094), "r"(src095), + "r"(src096), "r"(src097), "r"(src098), "r"(src099), + "r"(src100), "r"(src101), "r"(src102), "r"(src103), + "r"(src104), "r"(src105), "r"(src106), "r"(src107), + "r"(src108), "r"(src109), "r"(src110), "r"(src111), + "r"(src112), "r"(src113), "r"(src114), "r"(src115), + "r"(src116), "r"(src117), "r"(src118), "r"(src119), + "r"(src120), "r"(src121), "r"(src122), "r"(src123), + "r"(src124), "r"(src125), "r"(src126), "r"(src127) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 1 times +struct SM100_TMEM_STORE_16dp128b1x +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x1.b32" + "[%0]," + "{%1, %2};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 1 times, expand 16b write +struct SM100_TMEM_STORE_16dp128b1x_16b +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x1.unpack::16b.b32" + "[%0]," + "{%1, %2};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 2 times +struct SM100_TMEM_STORE_16dp128b2x +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x2.b32" + "[%0]," + "{%1, %2, %3, %4};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 2 times, expand 16b write +struct SM100_TMEM_STORE_16dp128b2x_16b +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x2.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 4 times +struct SM100_TMEM_STORE_16dp128b4x +{ + using SRegisters = uint32_t[8]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& src4, uint32_t const& src5, uint32_t const& src6, uint32_t const& src7, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x4.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), + "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 4 times, expand 16b write +struct SM100_TMEM_STORE_16dp128b4x_16b +{ + using SRegisters = uint32_t[8]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& src4, uint32_t const& src5, uint32_t const& src6, uint32_t const& src7, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x4.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), + "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 8 times +struct SM100_TMEM_STORE_16dp128b8x +{ + using SRegisters = uint32_t[16]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x8.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 8 times, expand 16b write +struct SM100_TMEM_STORE_16dp128b8x_16b +{ + using SRegisters = uint32_t[16]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x8.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 16 times +struct SM100_TMEM_STORE_16dp128b16x +{ + using SRegisters = uint32_t[32]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x16.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 16 times, expand 16b write +struct SM100_TMEM_STORE_16dp128b16x_16b +{ + using SRegisters = uint32_t[32]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x16.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 32 times +struct SM100_TMEM_STORE_16dp128b32x +{ + using SRegisters = uint32_t[64]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& src32, uint32_t const& src33, uint32_t const& src34, uint32_t const& src35, + uint32_t const& src36, uint32_t const& src37, uint32_t const& src38, uint32_t const& src39, + uint32_t const& src40, uint32_t const& src41, uint32_t const& src42, uint32_t const& src43, + uint32_t const& src44, uint32_t const& src45, uint32_t const& src46, uint32_t const& src47, + uint32_t const& src48, uint32_t const& src49, uint32_t const& src50, uint32_t const& src51, + uint32_t const& src52, uint32_t const& src53, uint32_t const& src54, uint32_t const& src55, + uint32_t const& src56, uint32_t const& src57, uint32_t const& src58, uint32_t const& src59, + uint32_t const& src60, uint32_t const& src61, uint32_t const& src62, uint32_t const& src63, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x32.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31), + "r"(src32), "r"(src33), "r"(src34), "r"(src35), + "r"(src36), "r"(src37), "r"(src38), "r"(src39), + "r"(src40), "r"(src41), "r"(src42), "r"(src43), + "r"(src44), "r"(src45), "r"(src46), "r"(src47), + "r"(src48), "r"(src49), "r"(src50), "r"(src51), + "r"(src52), "r"(src53), "r"(src54), "r"(src55), + "r"(src56), "r"(src57), "r"(src58), "r"(src59), + "r"(src60), "r"(src61), "r"(src62), "r"(src63) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 32 times, expand 16b write +struct SM100_TMEM_STORE_16dp128b32x_16b +{ + using SRegisters = uint32_t[64]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& src32, uint32_t const& src33, uint32_t const& src34, uint32_t const& src35, + uint32_t const& src36, uint32_t const& src37, uint32_t const& src38, uint32_t const& src39, + uint32_t const& src40, uint32_t const& src41, uint32_t const& src42, uint32_t const& src43, + uint32_t const& src44, uint32_t const& src45, uint32_t const& src46, uint32_t const& src47, + uint32_t const& src48, uint32_t const& src49, uint32_t const& src50, uint32_t const& src51, + uint32_t const& src52, uint32_t const& src53, uint32_t const& src54, uint32_t const& src55, + uint32_t const& src56, uint32_t const& src57, uint32_t const& src58, uint32_t const& src59, + uint32_t const& src60, uint32_t const& src61, uint32_t const& src62, uint32_t const& src63, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x32.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31), + "r"(src32), "r"(src33), "r"(src34), "r"(src35), + "r"(src36), "r"(src37), "r"(src38), "r"(src39), + "r"(src40), "r"(src41), "r"(src42), "r"(src43), + "r"(src44), "r"(src45), "r"(src46), "r"(src47), + "r"(src48), "r"(src49), "r"(src50), "r"(src51), + "r"(src52), "r"(src53), "r"(src54), "r"(src55), + "r"(src56), "r"(src57), "r"(src58), "r"(src59), + "r"(src60), "r"(src61), "r"(src62), "r"(src63) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 64 times +struct SM100_TMEM_STORE_16dp128b64x +{ + using SRegisters = uint32_t[128]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src000, uint32_t const& src001, uint32_t const& src002, uint32_t const& src003, + uint32_t const& src004, uint32_t const& src005, uint32_t const& src006, uint32_t const& src007, + uint32_t const& src008, uint32_t const& src009, uint32_t const& src010, uint32_t const& src011, + uint32_t const& src012, uint32_t const& src013, uint32_t const& src014, uint32_t const& src015, + uint32_t const& src016, uint32_t const& src017, uint32_t const& src018, uint32_t const& src019, + uint32_t const& src020, uint32_t const& src021, uint32_t const& src022, uint32_t const& src023, + uint32_t const& src024, uint32_t const& src025, uint32_t const& src026, uint32_t const& src027, + uint32_t const& src028, uint32_t const& src029, uint32_t const& src030, uint32_t const& src031, + uint32_t const& src032, uint32_t const& src033, uint32_t const& src034, uint32_t const& src035, + uint32_t const& src036, uint32_t const& src037, uint32_t const& src038, uint32_t const& src039, + uint32_t const& src040, uint32_t const& src041, uint32_t const& src042, uint32_t const& src043, + uint32_t const& src044, uint32_t const& src045, uint32_t const& src046, uint32_t const& src047, + uint32_t const& src048, uint32_t const& src049, uint32_t const& src050, uint32_t const& src051, + uint32_t const& src052, uint32_t const& src053, uint32_t const& src054, uint32_t const& src055, + uint32_t const& src056, uint32_t const& src057, uint32_t const& src058, uint32_t const& src059, + uint32_t const& src060, uint32_t const& src061, uint32_t const& src062, uint32_t const& src063, + uint32_t const& src064, uint32_t const& src065, uint32_t const& src066, uint32_t const& src067, + uint32_t const& src068, uint32_t const& src069, uint32_t const& src070, uint32_t const& src071, + uint32_t const& src072, uint32_t const& src073, uint32_t const& src074, uint32_t const& src075, + uint32_t const& src076, uint32_t const& src077, uint32_t const& src078, uint32_t const& src079, + uint32_t const& src080, uint32_t const& src081, uint32_t const& src082, uint32_t const& src083, + uint32_t const& src084, uint32_t const& src085, uint32_t const& src086, uint32_t const& src087, + uint32_t const& src088, uint32_t const& src089, uint32_t const& src090, uint32_t const& src091, + uint32_t const& src092, uint32_t const& src093, uint32_t const& src094, uint32_t const& src095, + uint32_t const& src096, uint32_t const& src097, uint32_t const& src098, uint32_t const& src099, + uint32_t const& src100, uint32_t const& src101, uint32_t const& src102, uint32_t const& src103, + uint32_t const& src104, uint32_t const& src105, uint32_t const& src106, uint32_t const& src107, + uint32_t const& src108, uint32_t const& src109, uint32_t const& src110, uint32_t const& src111, + uint32_t const& src112, uint32_t const& src113, uint32_t const& src114, uint32_t const& src115, + uint32_t const& src116, uint32_t const& src117, uint32_t const& src118, uint32_t const& src119, + uint32_t const& src120, uint32_t const& src121, uint32_t const& src122, uint32_t const& src123, + uint32_t const& src124, uint32_t const& src125, uint32_t const& src126, uint32_t const& src127, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x64.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64," + "%65, %66, %67, %68," + "%69, %70, %71, %72," + "%73, %74, %75, %76," + "%77, %78, %79, %80," + "%81, %82, %83, %84," + "%85, %86, %87, %88," + "%89, %90, %91, %92," + "%93, %94, %95, %96," + "%97, %98, %99, %100," + "%101, %102, %103, %104," + "%105, %106, %107, %108," + "%109, %110, %111, %112," + "%113, %114, %115, %116," + "%117, %118, %119, %120," + "%121, %122, %123, %124," + "%125, %126, %127, %128};\n" + : + : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), + "r"(src004), "r"(src005), "r"(src006), "r"(src007), + "r"(src008), "r"(src009), "r"(src010), "r"(src011), + "r"(src012), "r"(src013), "r"(src014), "r"(src015), + "r"(src016), "r"(src017), "r"(src018), "r"(src019), + "r"(src020), "r"(src021), "r"(src022), "r"(src023), + "r"(src024), "r"(src025), "r"(src026), "r"(src027), + "r"(src028), "r"(src029), "r"(src030), "r"(src031), + "r"(src032), "r"(src033), "r"(src034), "r"(src035), + "r"(src036), "r"(src037), "r"(src038), "r"(src039), + "r"(src040), "r"(src041), "r"(src042), "r"(src043), + "r"(src044), "r"(src045), "r"(src046), "r"(src047), + "r"(src048), "r"(src049), "r"(src050), "r"(src051), + "r"(src052), "r"(src053), "r"(src054), "r"(src055), + "r"(src056), "r"(src057), "r"(src058), "r"(src059), + "r"(src060), "r"(src061), "r"(src062), "r"(src063), + "r"(src064), "r"(src065), "r"(src066), "r"(src067), + "r"(src068), "r"(src069), "r"(src070), "r"(src071), + "r"(src072), "r"(src073), "r"(src074), "r"(src075), + "r"(src076), "r"(src077), "r"(src078), "r"(src079), + "r"(src080), "r"(src081), "r"(src082), "r"(src083), + "r"(src084), "r"(src085), "r"(src086), "r"(src087), + "r"(src088), "r"(src089), "r"(src090), "r"(src091), + "r"(src092), "r"(src093), "r"(src094), "r"(src095), + "r"(src096), "r"(src097), "r"(src098), "r"(src099), + "r"(src100), "r"(src101), "r"(src102), "r"(src103), + "r"(src104), "r"(src105), "r"(src106), "r"(src107), + "r"(src108), "r"(src109), "r"(src110), "r"(src111), + "r"(src112), "r"(src113), "r"(src114), "r"(src115), + "r"(src116), "r"(src117), "r"(src118), "r"(src119), + "r"(src120), "r"(src121), "r"(src122), "r"(src123), + "r"(src124), "r"(src125), "r"(src126), "r"(src127) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 64 times, expand 16b write +struct SM100_TMEM_STORE_16dp128b64x_16b +{ + using SRegisters = uint32_t[128]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src000, uint32_t const& src001, uint32_t const& src002, uint32_t const& src003, + uint32_t const& src004, uint32_t const& src005, uint32_t const& src006, uint32_t const& src007, + uint32_t const& src008, uint32_t const& src009, uint32_t const& src010, uint32_t const& src011, + uint32_t const& src012, uint32_t const& src013, uint32_t const& src014, uint32_t const& src015, + uint32_t const& src016, uint32_t const& src017, uint32_t const& src018, uint32_t const& src019, + uint32_t const& src020, uint32_t const& src021, uint32_t const& src022, uint32_t const& src023, + uint32_t const& src024, uint32_t const& src025, uint32_t const& src026, uint32_t const& src027, + uint32_t const& src028, uint32_t const& src029, uint32_t const& src030, uint32_t const& src031, + uint32_t const& src032, uint32_t const& src033, uint32_t const& src034, uint32_t const& src035, + uint32_t const& src036, uint32_t const& src037, uint32_t const& src038, uint32_t const& src039, + uint32_t const& src040, uint32_t const& src041, uint32_t const& src042, uint32_t const& src043, + uint32_t const& src044, uint32_t const& src045, uint32_t const& src046, uint32_t const& src047, + uint32_t const& src048, uint32_t const& src049, uint32_t const& src050, uint32_t const& src051, + uint32_t const& src052, uint32_t const& src053, uint32_t const& src054, uint32_t const& src055, + uint32_t const& src056, uint32_t const& src057, uint32_t const& src058, uint32_t const& src059, + uint32_t const& src060, uint32_t const& src061, uint32_t const& src062, uint32_t const& src063, + uint32_t const& src064, uint32_t const& src065, uint32_t const& src066, uint32_t const& src067, + uint32_t const& src068, uint32_t const& src069, uint32_t const& src070, uint32_t const& src071, + uint32_t const& src072, uint32_t const& src073, uint32_t const& src074, uint32_t const& src075, + uint32_t const& src076, uint32_t const& src077, uint32_t const& src078, uint32_t const& src079, + uint32_t const& src080, uint32_t const& src081, uint32_t const& src082, uint32_t const& src083, + uint32_t const& src084, uint32_t const& src085, uint32_t const& src086, uint32_t const& src087, + uint32_t const& src088, uint32_t const& src089, uint32_t const& src090, uint32_t const& src091, + uint32_t const& src092, uint32_t const& src093, uint32_t const& src094, uint32_t const& src095, + uint32_t const& src096, uint32_t const& src097, uint32_t const& src098, uint32_t const& src099, + uint32_t const& src100, uint32_t const& src101, uint32_t const& src102, uint32_t const& src103, + uint32_t const& src104, uint32_t const& src105, uint32_t const& src106, uint32_t const& src107, + uint32_t const& src108, uint32_t const& src109, uint32_t const& src110, uint32_t const& src111, + uint32_t const& src112, uint32_t const& src113, uint32_t const& src114, uint32_t const& src115, + uint32_t const& src116, uint32_t const& src117, uint32_t const& src118, uint32_t const& src119, + uint32_t const& src120, uint32_t const& src121, uint32_t const& src122, uint32_t const& src123, + uint32_t const& src124, uint32_t const& src125, uint32_t const& src126, uint32_t const& src127, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x64.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64," + "%65, %66, %67, %68," + "%69, %70, %71, %72," + "%73, %74, %75, %76," + "%77, %78, %79, %80," + "%81, %82, %83, %84," + "%85, %86, %87, %88," + "%89, %90, %91, %92," + "%93, %94, %95, %96," + "%97, %98, %99, %100," + "%101, %102, %103, %104," + "%105, %106, %107, %108," + "%109, %110, %111, %112," + "%113, %114, %115, %116," + "%117, %118, %119, %120," + "%121, %122, %123, %124," + "%125, %126, %127, %128};\n" + : + : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), + "r"(src004), "r"(src005), "r"(src006), "r"(src007), + "r"(src008), "r"(src009), "r"(src010), "r"(src011), + "r"(src012), "r"(src013), "r"(src014), "r"(src015), + "r"(src016), "r"(src017), "r"(src018), "r"(src019), + "r"(src020), "r"(src021), "r"(src022), "r"(src023), + "r"(src024), "r"(src025), "r"(src026), "r"(src027), + "r"(src028), "r"(src029), "r"(src030), "r"(src031), + "r"(src032), "r"(src033), "r"(src034), "r"(src035), + "r"(src036), "r"(src037), "r"(src038), "r"(src039), + "r"(src040), "r"(src041), "r"(src042), "r"(src043), + "r"(src044), "r"(src045), "r"(src046), "r"(src047), + "r"(src048), "r"(src049), "r"(src050), "r"(src051), + "r"(src052), "r"(src053), "r"(src054), "r"(src055), + "r"(src056), "r"(src057), "r"(src058), "r"(src059), + "r"(src060), "r"(src061), "r"(src062), "r"(src063), + "r"(src064), "r"(src065), "r"(src066), "r"(src067), + "r"(src068), "r"(src069), "r"(src070), "r"(src071), + "r"(src072), "r"(src073), "r"(src074), "r"(src075), + "r"(src076), "r"(src077), "r"(src078), "r"(src079), + "r"(src080), "r"(src081), "r"(src082), "r"(src083), + "r"(src084), "r"(src085), "r"(src086), "r"(src087), + "r"(src088), "r"(src089), "r"(src090), "r"(src091), + "r"(src092), "r"(src093), "r"(src094), "r"(src095), + "r"(src096), "r"(src097), "r"(src098), "r"(src099), + "r"(src100), "r"(src101), "r"(src102), "r"(src103), + "r"(src104), "r"(src105), "r"(src106), "r"(src107), + "r"(src108), "r"(src109), "r"(src110), "r"(src111), + "r"(src112), "r"(src113), "r"(src114), "r"(src115), + "r"(src116), "r"(src117), "r"(src118), "r"(src119), + "r"(src120), "r"(src121), "r"(src122), "r"(src123), + "r"(src124), "r"(src125), "r"(src126), "r"(src127) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 1 times +struct SM100_TMEM_STORE_16dp64b1x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x1.b32" + "[%0]," + "{%1};\n" + : + : "r"(dst_addr), "r"(src0) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 1 times, expand 16b write +struct SM100_TMEM_STORE_16dp64b1x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x1.unpack::16b.b32" + "[%0]," + "{%1};\n" + : + : "r"(dst_addr), "r"(src0) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 2 times +struct SM100_TMEM_STORE_16dp64b2x +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x2.b32" + "[%0]," + "{%1, %2};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 2 times, expand 16b write +struct SM100_TMEM_STORE_16dp64b2x_16b +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x2.unpack::16b.b32" + "[%0]," + "{%1, %2};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 4 times +struct SM100_TMEM_STORE_16dp64b4x +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x4.b32" + "[%0]," + "{%1, %2, %3, %4};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 4 times, expand 16b write +struct SM100_TMEM_STORE_16dp64b4x_16b +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x4.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 8 times +struct SM100_TMEM_STORE_16dp64b8x +{ + using SRegisters = uint32_t[8]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& src4, uint32_t const& src5, uint32_t const& src6, uint32_t const& src7, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x8.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), + "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 8 times, expand 16b write +struct SM100_TMEM_STORE_16dp64b8x_16b +{ + using SRegisters = uint32_t[8]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& src4, uint32_t const& src5, uint32_t const& src6, uint32_t const& src7, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x8.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), + "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 16 times +struct SM100_TMEM_STORE_16dp64b16x +{ + using SRegisters = uint32_t[16]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x16.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 16 times, expand 16b write +struct SM100_TMEM_STORE_16dp64b16x_16b +{ + using SRegisters = uint32_t[16]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x16.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 32 times +struct SM100_TMEM_STORE_16dp64b32x +{ + using SRegisters = uint32_t[32]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x32.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 32 times, expand 16b write +struct SM100_TMEM_STORE_16dp64b32x_16b +{ + using SRegisters = uint32_t[32]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x32.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 64 times +struct SM100_TMEM_STORE_16dp64b64x +{ + using SRegisters = uint32_t[64]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& src32, uint32_t const& src33, uint32_t const& src34, uint32_t const& src35, + uint32_t const& src36, uint32_t const& src37, uint32_t const& src38, uint32_t const& src39, + uint32_t const& src40, uint32_t const& src41, uint32_t const& src42, uint32_t const& src43, + uint32_t const& src44, uint32_t const& src45, uint32_t const& src46, uint32_t const& src47, + uint32_t const& src48, uint32_t const& src49, uint32_t const& src50, uint32_t const& src51, + uint32_t const& src52, uint32_t const& src53, uint32_t const& src54, uint32_t const& src55, + uint32_t const& src56, uint32_t const& src57, uint32_t const& src58, uint32_t const& src59, + uint32_t const& src60, uint32_t const& src61, uint32_t const& src62, uint32_t const& src63, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x64.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31), + "r"(src32), "r"(src33), "r"(src34), "r"(src35), + "r"(src36), "r"(src37), "r"(src38), "r"(src39), + "r"(src40), "r"(src41), "r"(src42), "r"(src43), + "r"(src44), "r"(src45), "r"(src46), "r"(src47), + "r"(src48), "r"(src49), "r"(src50), "r"(src51), + "r"(src52), "r"(src53), "r"(src54), "r"(src55), + "r"(src56), "r"(src57), "r"(src58), "r"(src59), + "r"(src60), "r"(src61), "r"(src62), "r"(src63) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 64 times, expand 16b write +struct SM100_TMEM_STORE_16dp64b64x_16b +{ + using SRegisters = uint32_t[64]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& src32, uint32_t const& src33, uint32_t const& src34, uint32_t const& src35, + uint32_t const& src36, uint32_t const& src37, uint32_t const& src38, uint32_t const& src39, + uint32_t const& src40, uint32_t const& src41, uint32_t const& src42, uint32_t const& src43, + uint32_t const& src44, uint32_t const& src45, uint32_t const& src46, uint32_t const& src47, + uint32_t const& src48, uint32_t const& src49, uint32_t const& src50, uint32_t const& src51, + uint32_t const& src52, uint32_t const& src53, uint32_t const& src54, uint32_t const& src55, + uint32_t const& src56, uint32_t const& src57, uint32_t const& src58, uint32_t const& src59, + uint32_t const& src60, uint32_t const& src61, uint32_t const& src62, uint32_t const& src63, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x64.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31), + "r"(src32), "r"(src33), "r"(src34), "r"(src35), + "r"(src36), "r"(src37), "r"(src38), "r"(src39), + "r"(src40), "r"(src41), "r"(src42), "r"(src43), + "r"(src44), "r"(src45), "r"(src46), "r"(src47), + "r"(src48), "r"(src49), "r"(src50), "r"(src51), + "r"(src52), "r"(src53), "r"(src54), "r"(src55), + "r"(src56), "r"(src57), "r"(src58), "r"(src59), + "r"(src60), "r"(src61), "r"(src62), "r"(src63) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 128 times +struct SM100_TMEM_STORE_16dp64b128x +{ + using SRegisters = uint32_t[128]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src000, uint32_t const& src001, uint32_t const& src002, uint32_t const& src003, + uint32_t const& src004, uint32_t const& src005, uint32_t const& src006, uint32_t const& src007, + uint32_t const& src008, uint32_t const& src009, uint32_t const& src010, uint32_t const& src011, + uint32_t const& src012, uint32_t const& src013, uint32_t const& src014, uint32_t const& src015, + uint32_t const& src016, uint32_t const& src017, uint32_t const& src018, uint32_t const& src019, + uint32_t const& src020, uint32_t const& src021, uint32_t const& src022, uint32_t const& src023, + uint32_t const& src024, uint32_t const& src025, uint32_t const& src026, uint32_t const& src027, + uint32_t const& src028, uint32_t const& src029, uint32_t const& src030, uint32_t const& src031, + uint32_t const& src032, uint32_t const& src033, uint32_t const& src034, uint32_t const& src035, + uint32_t const& src036, uint32_t const& src037, uint32_t const& src038, uint32_t const& src039, + uint32_t const& src040, uint32_t const& src041, uint32_t const& src042, uint32_t const& src043, + uint32_t const& src044, uint32_t const& src045, uint32_t const& src046, uint32_t const& src047, + uint32_t const& src048, uint32_t const& src049, uint32_t const& src050, uint32_t const& src051, + uint32_t const& src052, uint32_t const& src053, uint32_t const& src054, uint32_t const& src055, + uint32_t const& src056, uint32_t const& src057, uint32_t const& src058, uint32_t const& src059, + uint32_t const& src060, uint32_t const& src061, uint32_t const& src062, uint32_t const& src063, + uint32_t const& src064, uint32_t const& src065, uint32_t const& src066, uint32_t const& src067, + uint32_t const& src068, uint32_t const& src069, uint32_t const& src070, uint32_t const& src071, + uint32_t const& src072, uint32_t const& src073, uint32_t const& src074, uint32_t const& src075, + uint32_t const& src076, uint32_t const& src077, uint32_t const& src078, uint32_t const& src079, + uint32_t const& src080, uint32_t const& src081, uint32_t const& src082, uint32_t const& src083, + uint32_t const& src084, uint32_t const& src085, uint32_t const& src086, uint32_t const& src087, + uint32_t const& src088, uint32_t const& src089, uint32_t const& src090, uint32_t const& src091, + uint32_t const& src092, uint32_t const& src093, uint32_t const& src094, uint32_t const& src095, + uint32_t const& src096, uint32_t const& src097, uint32_t const& src098, uint32_t const& src099, + uint32_t const& src100, uint32_t const& src101, uint32_t const& src102, uint32_t const& src103, + uint32_t const& src104, uint32_t const& src105, uint32_t const& src106, uint32_t const& src107, + uint32_t const& src108, uint32_t const& src109, uint32_t const& src110, uint32_t const& src111, + uint32_t const& src112, uint32_t const& src113, uint32_t const& src114, uint32_t const& src115, + uint32_t const& src116, uint32_t const& src117, uint32_t const& src118, uint32_t const& src119, + uint32_t const& src120, uint32_t const& src121, uint32_t const& src122, uint32_t const& src123, + uint32_t const& src124, uint32_t const& src125, uint32_t const& src126, uint32_t const& src127, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x128.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64," + "%65, %66, %67, %68," + "%69, %70, %71, %72," + "%73, %74, %75, %76," + "%77, %78, %79, %80," + "%81, %82, %83, %84," + "%85, %86, %87, %88," + "%89, %90, %91, %92," + "%93, %94, %95, %96," + "%97, %98, %99, %100," + "%101, %102, %103, %104," + "%105, %106, %107, %108," + "%109, %110, %111, %112," + "%113, %114, %115, %116," + "%117, %118, %119, %120," + "%121, %122, %123, %124," + "%125, %126, %127, %128};\n" + : + : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), + "r"(src004), "r"(src005), "r"(src006), "r"(src007), + "r"(src008), "r"(src009), "r"(src010), "r"(src011), + "r"(src012), "r"(src013), "r"(src014), "r"(src015), + "r"(src016), "r"(src017), "r"(src018), "r"(src019), + "r"(src020), "r"(src021), "r"(src022), "r"(src023), + "r"(src024), "r"(src025), "r"(src026), "r"(src027), + "r"(src028), "r"(src029), "r"(src030), "r"(src031), + "r"(src032), "r"(src033), "r"(src034), "r"(src035), + "r"(src036), "r"(src037), "r"(src038), "r"(src039), + "r"(src040), "r"(src041), "r"(src042), "r"(src043), + "r"(src044), "r"(src045), "r"(src046), "r"(src047), + "r"(src048), "r"(src049), "r"(src050), "r"(src051), + "r"(src052), "r"(src053), "r"(src054), "r"(src055), + "r"(src056), "r"(src057), "r"(src058), "r"(src059), + "r"(src060), "r"(src061), "r"(src062), "r"(src063), + "r"(src064), "r"(src065), "r"(src066), "r"(src067), + "r"(src068), "r"(src069), "r"(src070), "r"(src071), + "r"(src072), "r"(src073), "r"(src074), "r"(src075), + "r"(src076), "r"(src077), "r"(src078), "r"(src079), + "r"(src080), "r"(src081), "r"(src082), "r"(src083), + "r"(src084), "r"(src085), "r"(src086), "r"(src087), + "r"(src088), "r"(src089), "r"(src090), "r"(src091), + "r"(src092), "r"(src093), "r"(src094), "r"(src095), + "r"(src096), "r"(src097), "r"(src098), "r"(src099), + "r"(src100), "r"(src101), "r"(src102), "r"(src103), + "r"(src104), "r"(src105), "r"(src106), "r"(src107), + "r"(src108), "r"(src109), "r"(src110), "r"(src111), + "r"(src112), "r"(src113), "r"(src114), "r"(src115), + "r"(src116), "r"(src117), "r"(src118), "r"(src119), + "r"(src120), "r"(src121), "r"(src122), "r"(src123), + "r"(src124), "r"(src125), "r"(src126), "r"(src127) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 128 times, expand 16b write +struct SM100_TMEM_STORE_16dp64b128x_16b +{ + using SRegisters = uint32_t[128]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src000, uint32_t const& src001, uint32_t const& src002, uint32_t const& src003, + uint32_t const& src004, uint32_t const& src005, uint32_t const& src006, uint32_t const& src007, + uint32_t const& src008, uint32_t const& src009, uint32_t const& src010, uint32_t const& src011, + uint32_t const& src012, uint32_t const& src013, uint32_t const& src014, uint32_t const& src015, + uint32_t const& src016, uint32_t const& src017, uint32_t const& src018, uint32_t const& src019, + uint32_t const& src020, uint32_t const& src021, uint32_t const& src022, uint32_t const& src023, + uint32_t const& src024, uint32_t const& src025, uint32_t const& src026, uint32_t const& src027, + uint32_t const& src028, uint32_t const& src029, uint32_t const& src030, uint32_t const& src031, + uint32_t const& src032, uint32_t const& src033, uint32_t const& src034, uint32_t const& src035, + uint32_t const& src036, uint32_t const& src037, uint32_t const& src038, uint32_t const& src039, + uint32_t const& src040, uint32_t const& src041, uint32_t const& src042, uint32_t const& src043, + uint32_t const& src044, uint32_t const& src045, uint32_t const& src046, uint32_t const& src047, + uint32_t const& src048, uint32_t const& src049, uint32_t const& src050, uint32_t const& src051, + uint32_t const& src052, uint32_t const& src053, uint32_t const& src054, uint32_t const& src055, + uint32_t const& src056, uint32_t const& src057, uint32_t const& src058, uint32_t const& src059, + uint32_t const& src060, uint32_t const& src061, uint32_t const& src062, uint32_t const& src063, + uint32_t const& src064, uint32_t const& src065, uint32_t const& src066, uint32_t const& src067, + uint32_t const& src068, uint32_t const& src069, uint32_t const& src070, uint32_t const& src071, + uint32_t const& src072, uint32_t const& src073, uint32_t const& src074, uint32_t const& src075, + uint32_t const& src076, uint32_t const& src077, uint32_t const& src078, uint32_t const& src079, + uint32_t const& src080, uint32_t const& src081, uint32_t const& src082, uint32_t const& src083, + uint32_t const& src084, uint32_t const& src085, uint32_t const& src086, uint32_t const& src087, + uint32_t const& src088, uint32_t const& src089, uint32_t const& src090, uint32_t const& src091, + uint32_t const& src092, uint32_t const& src093, uint32_t const& src094, uint32_t const& src095, + uint32_t const& src096, uint32_t const& src097, uint32_t const& src098, uint32_t const& src099, + uint32_t const& src100, uint32_t const& src101, uint32_t const& src102, uint32_t const& src103, + uint32_t const& src104, uint32_t const& src105, uint32_t const& src106, uint32_t const& src107, + uint32_t const& src108, uint32_t const& src109, uint32_t const& src110, uint32_t const& src111, + uint32_t const& src112, uint32_t const& src113, uint32_t const& src114, uint32_t const& src115, + uint32_t const& src116, uint32_t const& src117, uint32_t const& src118, uint32_t const& src119, + uint32_t const& src120, uint32_t const& src121, uint32_t const& src122, uint32_t const& src123, + uint32_t const& src124, uint32_t const& src125, uint32_t const& src126, uint32_t const& src127, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x128.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64," + "%65, %66, %67, %68," + "%69, %70, %71, %72," + "%73, %74, %75, %76," + "%77, %78, %79, %80," + "%81, %82, %83, %84," + "%85, %86, %87, %88," + "%89, %90, %91, %92," + "%93, %94, %95, %96," + "%97, %98, %99, %100," + "%101, %102, %103, %104," + "%105, %106, %107, %108," + "%109, %110, %111, %112," + "%113, %114, %115, %116," + "%117, %118, %119, %120," + "%121, %122, %123, %124," + "%125, %126, %127, %128};\n" + : + : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), + "r"(src004), "r"(src005), "r"(src006), "r"(src007), + "r"(src008), "r"(src009), "r"(src010), "r"(src011), + "r"(src012), "r"(src013), "r"(src014), "r"(src015), + "r"(src016), "r"(src017), "r"(src018), "r"(src019), + "r"(src020), "r"(src021), "r"(src022), "r"(src023), + "r"(src024), "r"(src025), "r"(src026), "r"(src027), + "r"(src028), "r"(src029), "r"(src030), "r"(src031), + "r"(src032), "r"(src033), "r"(src034), "r"(src035), + "r"(src036), "r"(src037), "r"(src038), "r"(src039), + "r"(src040), "r"(src041), "r"(src042), "r"(src043), + "r"(src044), "r"(src045), "r"(src046), "r"(src047), + "r"(src048), "r"(src049), "r"(src050), "r"(src051), + "r"(src052), "r"(src053), "r"(src054), "r"(src055), + "r"(src056), "r"(src057), "r"(src058), "r"(src059), + "r"(src060), "r"(src061), "r"(src062), "r"(src063), + "r"(src064), "r"(src065), "r"(src066), "r"(src067), + "r"(src068), "r"(src069), "r"(src070), "r"(src071), + "r"(src072), "r"(src073), "r"(src074), "r"(src075), + "r"(src076), "r"(src077), "r"(src078), "r"(src079), + "r"(src080), "r"(src081), "r"(src082), "r"(src083), + "r"(src084), "r"(src085), "r"(src086), "r"(src087), + "r"(src088), "r"(src089), "r"(src090), "r"(src091), + "r"(src092), "r"(src093), "r"(src094), "r"(src095), + "r"(src096), "r"(src097), "r"(src098), "r"(src099), + "r"(src100), "r"(src101), "r"(src102), "r"(src103), + "r"(src104), "r"(src105), "r"(src106), "r"(src107), + "r"(src108), "r"(src109), "r"(src110), "r"(src111), + "r"(src112), "r"(src113), "r"(src114), "r"(src115), + "r"(src116), "r"(src117), "r"(src118), "r"(src119), + "r"(src120), "r"(src121), "r"(src122), "r"(src123), + "r"(src124), "r"(src125), "r"(src126), "r"(src127) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 1 times +struct SM100_TMEM_STORE_16dp32b1x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x1.b32" + "[%0] , 1," + "{%1};\n" + : + : "r"(dst_addr), "r"(src0) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 1 times, expand 16b write +struct SM100_TMEM_STORE_16dp32b1x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x1.unpack::16b.b32" + "[%0] , 2," + "{%1};\n" + : + : "r"(dst_addr), "r"(src0) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 2 times +struct SM100_TMEM_STORE_16dp32b2x +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x2.b32" + "[%0] , 2," + "{%1, %2};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 2 times, expand 16b write +struct SM100_TMEM_STORE_16dp32b2x_16b +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x2.unpack::16b.b32" + "[%0] , 4," + "{%1, %2};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 4 times +struct SM100_TMEM_STORE_16dp32b4x +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x4.b32" + "[%0] , 4," + "{%1, %2, %3, %4};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 4 times, expand 16b write +struct SM100_TMEM_STORE_16dp32b4x_16b +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x4.unpack::16b.b32" + "[%0] , 8," + "{%1, %2, %3, %4};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 8 times +struct SM100_TMEM_STORE_16dp32b8x +{ + using SRegisters = uint32_t[8]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& src4, uint32_t const& src5, uint32_t const& src6, uint32_t const& src7, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x8.b32" + "[%0] , 8," + "{%1, %2, %3, %4," + "%5, %6, %7, %8};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), + "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 8 times, expand 16b write +struct SM100_TMEM_STORE_16dp32b8x_16b +{ + using SRegisters = uint32_t[8]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& src4, uint32_t const& src5, uint32_t const& src6, uint32_t const& src7, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x8.unpack::16b.b32" + "[%0] , 16," + "{%1, %2, %3, %4," + "%5, %6, %7, %8};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), + "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 16 times +struct SM100_TMEM_STORE_16dp32b16x +{ + using SRegisters = uint32_t[16]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x16.b32" + "[%0] , 16," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 16 times, expand 16b write +struct SM100_TMEM_STORE_16dp32b16x_16b +{ + using SRegisters = uint32_t[16]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x16.unpack::16b.b32" + "[%0] , 32," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 32 times +struct SM100_TMEM_STORE_16dp32b32x +{ + using SRegisters = uint32_t[32]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x32.b32" + "[%0] , 32," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 32 times, expand 16b write +struct SM100_TMEM_STORE_16dp32b32x_16b +{ + using SRegisters = uint32_t[32]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x32.unpack::16b.b32" + "[%0] , 64," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 64 times +struct SM100_TMEM_STORE_16dp32b64x +{ + using SRegisters = uint32_t[64]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& src32, uint32_t const& src33, uint32_t const& src34, uint32_t const& src35, + uint32_t const& src36, uint32_t const& src37, uint32_t const& src38, uint32_t const& src39, + uint32_t const& src40, uint32_t const& src41, uint32_t const& src42, uint32_t const& src43, + uint32_t const& src44, uint32_t const& src45, uint32_t const& src46, uint32_t const& src47, + uint32_t const& src48, uint32_t const& src49, uint32_t const& src50, uint32_t const& src51, + uint32_t const& src52, uint32_t const& src53, uint32_t const& src54, uint32_t const& src55, + uint32_t const& src56, uint32_t const& src57, uint32_t const& src58, uint32_t const& src59, + uint32_t const& src60, uint32_t const& src61, uint32_t const& src62, uint32_t const& src63, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x64.b32" + "[%0] , 64," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31), + "r"(src32), "r"(src33), "r"(src34), "r"(src35), + "r"(src36), "r"(src37), "r"(src38), "r"(src39), + "r"(src40), "r"(src41), "r"(src42), "r"(src43), + "r"(src44), "r"(src45), "r"(src46), "r"(src47), + "r"(src48), "r"(src49), "r"(src50), "r"(src51), + "r"(src52), "r"(src53), "r"(src54), "r"(src55), + "r"(src56), "r"(src57), "r"(src58), "r"(src59), + "r"(src60), "r"(src61), "r"(src62), "r"(src63) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 64 times, expand 16b write +struct SM100_TMEM_STORE_16dp32b64x_16b +{ + using SRegisters = uint32_t[64]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& src32, uint32_t const& src33, uint32_t const& src34, uint32_t const& src35, + uint32_t const& src36, uint32_t const& src37, uint32_t const& src38, uint32_t const& src39, + uint32_t const& src40, uint32_t const& src41, uint32_t const& src42, uint32_t const& src43, + uint32_t const& src44, uint32_t const& src45, uint32_t const& src46, uint32_t const& src47, + uint32_t const& src48, uint32_t const& src49, uint32_t const& src50, uint32_t const& src51, + uint32_t const& src52, uint32_t const& src53, uint32_t const& src54, uint32_t const& src55, + uint32_t const& src56, uint32_t const& src57, uint32_t const& src58, uint32_t const& src59, + uint32_t const& src60, uint32_t const& src61, uint32_t const& src62, uint32_t const& src63, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x64.unpack::16b.b32" + "[%0] , 128," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31), + "r"(src32), "r"(src33), "r"(src34), "r"(src35), + "r"(src36), "r"(src37), "r"(src38), "r"(src39), + "r"(src40), "r"(src41), "r"(src42), "r"(src43), + "r"(src44), "r"(src45), "r"(src46), "r"(src47), + "r"(src48), "r"(src49), "r"(src50), "r"(src51), + "r"(src52), "r"(src53), "r"(src54), "r"(src55), + "r"(src56), "r"(src57), "r"(src58), "r"(src59), + "r"(src60), "r"(src61), "r"(src62), "r"(src63) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 128 times +struct SM100_TMEM_STORE_16dp32b128x +{ + using SRegisters = uint32_t[128]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src000, uint32_t const& src001, uint32_t const& src002, uint32_t const& src003, + uint32_t const& src004, uint32_t const& src005, uint32_t const& src006, uint32_t const& src007, + uint32_t const& src008, uint32_t const& src009, uint32_t const& src010, uint32_t const& src011, + uint32_t const& src012, uint32_t const& src013, uint32_t const& src014, uint32_t const& src015, + uint32_t const& src016, uint32_t const& src017, uint32_t const& src018, uint32_t const& src019, + uint32_t const& src020, uint32_t const& src021, uint32_t const& src022, uint32_t const& src023, + uint32_t const& src024, uint32_t const& src025, uint32_t const& src026, uint32_t const& src027, + uint32_t const& src028, uint32_t const& src029, uint32_t const& src030, uint32_t const& src031, + uint32_t const& src032, uint32_t const& src033, uint32_t const& src034, uint32_t const& src035, + uint32_t const& src036, uint32_t const& src037, uint32_t const& src038, uint32_t const& src039, + uint32_t const& src040, uint32_t const& src041, uint32_t const& src042, uint32_t const& src043, + uint32_t const& src044, uint32_t const& src045, uint32_t const& src046, uint32_t const& src047, + uint32_t const& src048, uint32_t const& src049, uint32_t const& src050, uint32_t const& src051, + uint32_t const& src052, uint32_t const& src053, uint32_t const& src054, uint32_t const& src055, + uint32_t const& src056, uint32_t const& src057, uint32_t const& src058, uint32_t const& src059, + uint32_t const& src060, uint32_t const& src061, uint32_t const& src062, uint32_t const& src063, + uint32_t const& src064, uint32_t const& src065, uint32_t const& src066, uint32_t const& src067, + uint32_t const& src068, uint32_t const& src069, uint32_t const& src070, uint32_t const& src071, + uint32_t const& src072, uint32_t const& src073, uint32_t const& src074, uint32_t const& src075, + uint32_t const& src076, uint32_t const& src077, uint32_t const& src078, uint32_t const& src079, + uint32_t const& src080, uint32_t const& src081, uint32_t const& src082, uint32_t const& src083, + uint32_t const& src084, uint32_t const& src085, uint32_t const& src086, uint32_t const& src087, + uint32_t const& src088, uint32_t const& src089, uint32_t const& src090, uint32_t const& src091, + uint32_t const& src092, uint32_t const& src093, uint32_t const& src094, uint32_t const& src095, + uint32_t const& src096, uint32_t const& src097, uint32_t const& src098, uint32_t const& src099, + uint32_t const& src100, uint32_t const& src101, uint32_t const& src102, uint32_t const& src103, + uint32_t const& src104, uint32_t const& src105, uint32_t const& src106, uint32_t const& src107, + uint32_t const& src108, uint32_t const& src109, uint32_t const& src110, uint32_t const& src111, + uint32_t const& src112, uint32_t const& src113, uint32_t const& src114, uint32_t const& src115, + uint32_t const& src116, uint32_t const& src117, uint32_t const& src118, uint32_t const& src119, + uint32_t const& src120, uint32_t const& src121, uint32_t const& src122, uint32_t const& src123, + uint32_t const& src124, uint32_t const& src125, uint32_t const& src126, uint32_t const& src127, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x128.b32" + "[%0] , 128," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64," + "%65, %66, %67, %68," + "%69, %70, %71, %72," + "%73, %74, %75, %76," + "%77, %78, %79, %80," + "%81, %82, %83, %84," + "%85, %86, %87, %88," + "%89, %90, %91, %92," + "%93, %94, %95, %96," + "%97, %98, %99, %100," + "%101, %102, %103, %104," + "%105, %106, %107, %108," + "%109, %110, %111, %112," + "%113, %114, %115, %116," + "%117, %118, %119, %120," + "%121, %122, %123, %124," + "%125, %126, %127, %128};\n" + : + : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), + "r"(src004), "r"(src005), "r"(src006), "r"(src007), + "r"(src008), "r"(src009), "r"(src010), "r"(src011), + "r"(src012), "r"(src013), "r"(src014), "r"(src015), + "r"(src016), "r"(src017), "r"(src018), "r"(src019), + "r"(src020), "r"(src021), "r"(src022), "r"(src023), + "r"(src024), "r"(src025), "r"(src026), "r"(src027), + "r"(src028), "r"(src029), "r"(src030), "r"(src031), + "r"(src032), "r"(src033), "r"(src034), "r"(src035), + "r"(src036), "r"(src037), "r"(src038), "r"(src039), + "r"(src040), "r"(src041), "r"(src042), "r"(src043), + "r"(src044), "r"(src045), "r"(src046), "r"(src047), + "r"(src048), "r"(src049), "r"(src050), "r"(src051), + "r"(src052), "r"(src053), "r"(src054), "r"(src055), + "r"(src056), "r"(src057), "r"(src058), "r"(src059), + "r"(src060), "r"(src061), "r"(src062), "r"(src063), + "r"(src064), "r"(src065), "r"(src066), "r"(src067), + "r"(src068), "r"(src069), "r"(src070), "r"(src071), + "r"(src072), "r"(src073), "r"(src074), "r"(src075), + "r"(src076), "r"(src077), "r"(src078), "r"(src079), + "r"(src080), "r"(src081), "r"(src082), "r"(src083), + "r"(src084), "r"(src085), "r"(src086), "r"(src087), + "r"(src088), "r"(src089), "r"(src090), "r"(src091), + "r"(src092), "r"(src093), "r"(src094), "r"(src095), + "r"(src096), "r"(src097), "r"(src098), "r"(src099), + "r"(src100), "r"(src101), "r"(src102), "r"(src103), + "r"(src104), "r"(src105), "r"(src106), "r"(src107), + "r"(src108), "r"(src109), "r"(src110), "r"(src111), + "r"(src112), "r"(src113), "r"(src114), "r"(src115), + "r"(src116), "r"(src117), "r"(src118), "r"(src119), + "r"(src120), "r"(src121), "r"(src122), "r"(src123), + "r"(src124), "r"(src125), "r"(src126), "r"(src127) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 128 times, expand 16b write +struct SM100_TMEM_STORE_16dp32b128x_16b +{ + using SRegisters = uint32_t[128]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src000, uint32_t const& src001, uint32_t const& src002, uint32_t const& src003, + uint32_t const& src004, uint32_t const& src005, uint32_t const& src006, uint32_t const& src007, + uint32_t const& src008, uint32_t const& src009, uint32_t const& src010, uint32_t const& src011, + uint32_t const& src012, uint32_t const& src013, uint32_t const& src014, uint32_t const& src015, + uint32_t const& src016, uint32_t const& src017, uint32_t const& src018, uint32_t const& src019, + uint32_t const& src020, uint32_t const& src021, uint32_t const& src022, uint32_t const& src023, + uint32_t const& src024, uint32_t const& src025, uint32_t const& src026, uint32_t const& src027, + uint32_t const& src028, uint32_t const& src029, uint32_t const& src030, uint32_t const& src031, + uint32_t const& src032, uint32_t const& src033, uint32_t const& src034, uint32_t const& src035, + uint32_t const& src036, uint32_t const& src037, uint32_t const& src038, uint32_t const& src039, + uint32_t const& src040, uint32_t const& src041, uint32_t const& src042, uint32_t const& src043, + uint32_t const& src044, uint32_t const& src045, uint32_t const& src046, uint32_t const& src047, + uint32_t const& src048, uint32_t const& src049, uint32_t const& src050, uint32_t const& src051, + uint32_t const& src052, uint32_t const& src053, uint32_t const& src054, uint32_t const& src055, + uint32_t const& src056, uint32_t const& src057, uint32_t const& src058, uint32_t const& src059, + uint32_t const& src060, uint32_t const& src061, uint32_t const& src062, uint32_t const& src063, + uint32_t const& src064, uint32_t const& src065, uint32_t const& src066, uint32_t const& src067, + uint32_t const& src068, uint32_t const& src069, uint32_t const& src070, uint32_t const& src071, + uint32_t const& src072, uint32_t const& src073, uint32_t const& src074, uint32_t const& src075, + uint32_t const& src076, uint32_t const& src077, uint32_t const& src078, uint32_t const& src079, + uint32_t const& src080, uint32_t const& src081, uint32_t const& src082, uint32_t const& src083, + uint32_t const& src084, uint32_t const& src085, uint32_t const& src086, uint32_t const& src087, + uint32_t const& src088, uint32_t const& src089, uint32_t const& src090, uint32_t const& src091, + uint32_t const& src092, uint32_t const& src093, uint32_t const& src094, uint32_t const& src095, + uint32_t const& src096, uint32_t const& src097, uint32_t const& src098, uint32_t const& src099, + uint32_t const& src100, uint32_t const& src101, uint32_t const& src102, uint32_t const& src103, + uint32_t const& src104, uint32_t const& src105, uint32_t const& src106, uint32_t const& src107, + uint32_t const& src108, uint32_t const& src109, uint32_t const& src110, uint32_t const& src111, + uint32_t const& src112, uint32_t const& src113, uint32_t const& src114, uint32_t const& src115, + uint32_t const& src116, uint32_t const& src117, uint32_t const& src118, uint32_t const& src119, + uint32_t const& src120, uint32_t const& src121, uint32_t const& src122, uint32_t const& src123, + uint32_t const& src124, uint32_t const& src125, uint32_t const& src126, uint32_t const& src127, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x128.unpack::16b.b32" + "[%0] , 256," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64," + "%65, %66, %67, %68," + "%69, %70, %71, %72," + "%73, %74, %75, %76," + "%77, %78, %79, %80," + "%81, %82, %83, %84," + "%85, %86, %87, %88," + "%89, %90, %91, %92," + "%93, %94, %95, %96," + "%97, %98, %99, %100," + "%101, %102, %103, %104," + "%105, %106, %107, %108," + "%109, %110, %111, %112," + "%113, %114, %115, %116," + "%117, %118, %119, %120," + "%121, %122, %123, %124," + "%125, %126, %127, %128};\n" + : + : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), + "r"(src004), "r"(src005), "r"(src006), "r"(src007), + "r"(src008), "r"(src009), "r"(src010), "r"(src011), + "r"(src012), "r"(src013), "r"(src014), "r"(src015), + "r"(src016), "r"(src017), "r"(src018), "r"(src019), + "r"(src020), "r"(src021), "r"(src022), "r"(src023), + "r"(src024), "r"(src025), "r"(src026), "r"(src027), + "r"(src028), "r"(src029), "r"(src030), "r"(src031), + "r"(src032), "r"(src033), "r"(src034), "r"(src035), + "r"(src036), "r"(src037), "r"(src038), "r"(src039), + "r"(src040), "r"(src041), "r"(src042), "r"(src043), + "r"(src044), "r"(src045), "r"(src046), "r"(src047), + "r"(src048), "r"(src049), "r"(src050), "r"(src051), + "r"(src052), "r"(src053), "r"(src054), "r"(src055), + "r"(src056), "r"(src057), "r"(src058), "r"(src059), + "r"(src060), "r"(src061), "r"(src062), "r"(src063), + "r"(src064), "r"(src065), "r"(src066), "r"(src067), + "r"(src068), "r"(src069), "r"(src070), "r"(src071), + "r"(src072), "r"(src073), "r"(src074), "r"(src075), + "r"(src076), "r"(src077), "r"(src078), "r"(src079), + "r"(src080), "r"(src081), "r"(src082), "r"(src083), + "r"(src084), "r"(src085), "r"(src086), "r"(src087), + "r"(src088), "r"(src089), "r"(src090), "r"(src091), + "r"(src092), "r"(src093), "r"(src094), "r"(src095), + "r"(src096), "r"(src097), "r"(src098), "r"(src099), + "r"(src100), "r"(src101), "r"(src102), "r"(src103), + "r"(src104), "r"(src105), "r"(src106), "r"(src107), + "r"(src108), "r"(src109), "r"(src110), "r"(src111), + "r"(src112), "r"(src113), "r"(src114), "r"(src115), + "r"(src116), "r"(src117), "r"(src118), "r"(src119), + "r"(src120), "r"(src121), "r"(src122), "r"(src123), + "r"(src124), "r"(src125), "r"(src126), "r"(src127) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 1 times +struct SM100_TMEM_STORE_32dp32b1x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x1.b32" + "[%0]," + "{%1};\n" + : + : "r"(dst_addr), "r"(src0) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 1 times, expand 16b write +struct SM100_TMEM_STORE_32dp32b1x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x1.unpack::16b.b32" + "[%0]," + "{%1};\n" + : + : "r"(dst_addr), "r"(src0) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 2 times +struct SM100_TMEM_STORE_32dp32b2x +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x2.b32" + "[%0]," + "{%1, %2};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 2 times, expand 16b write +struct SM100_TMEM_STORE_32dp32b2x_16b +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x2.unpack::16b.b32" + "[%0]," + "{%1, %2};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 4 times +struct SM100_TMEM_STORE_32dp32b4x +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x4.b32" + "[%0]," + "{%1, %2, %3, %4};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 4 times, expand 16b write +struct SM100_TMEM_STORE_32dp32b4x_16b +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x4.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 8 times +struct SM100_TMEM_STORE_32dp32b8x +{ + using SRegisters = uint32_t[8]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& src4, uint32_t const& src5, uint32_t const& src6, uint32_t const& src7, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x8.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), + "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 8 times, expand 16b write +struct SM100_TMEM_STORE_32dp32b8x_16b +{ + using SRegisters = uint32_t[8]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& src4, uint32_t const& src5, uint32_t const& src6, uint32_t const& src7, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x8.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), + "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 16 times +struct SM100_TMEM_STORE_32dp32b16x +{ + using SRegisters = uint32_t[16]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x16.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 16 times, expand 16b write +struct SM100_TMEM_STORE_32dp32b16x_16b +{ + using SRegisters = uint32_t[16]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x16.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 32 times +struct SM100_TMEM_STORE_32dp32b32x +{ + using SRegisters = uint32_t[32]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x32.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 32 times, expand 16b write +struct SM100_TMEM_STORE_32dp32b32x_16b +{ + using SRegisters = uint32_t[32]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x32.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 64 times +struct SM100_TMEM_STORE_32dp32b64x +{ + using SRegisters = uint32_t[64]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& src32, uint32_t const& src33, uint32_t const& src34, uint32_t const& src35, + uint32_t const& src36, uint32_t const& src37, uint32_t const& src38, uint32_t const& src39, + uint32_t const& src40, uint32_t const& src41, uint32_t const& src42, uint32_t const& src43, + uint32_t const& src44, uint32_t const& src45, uint32_t const& src46, uint32_t const& src47, + uint32_t const& src48, uint32_t const& src49, uint32_t const& src50, uint32_t const& src51, + uint32_t const& src52, uint32_t const& src53, uint32_t const& src54, uint32_t const& src55, + uint32_t const& src56, uint32_t const& src57, uint32_t const& src58, uint32_t const& src59, + uint32_t const& src60, uint32_t const& src61, uint32_t const& src62, uint32_t const& src63, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x64.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31), + "r"(src32), "r"(src33), "r"(src34), "r"(src35), + "r"(src36), "r"(src37), "r"(src38), "r"(src39), + "r"(src40), "r"(src41), "r"(src42), "r"(src43), + "r"(src44), "r"(src45), "r"(src46), "r"(src47), + "r"(src48), "r"(src49), "r"(src50), "r"(src51), + "r"(src52), "r"(src53), "r"(src54), "r"(src55), + "r"(src56), "r"(src57), "r"(src58), "r"(src59), + "r"(src60), "r"(src61), "r"(src62), "r"(src63) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 64 times, expand 16b write +struct SM100_TMEM_STORE_32dp32b64x_16b +{ + using SRegisters = uint32_t[64]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& src32, uint32_t const& src33, uint32_t const& src34, uint32_t const& src35, + uint32_t const& src36, uint32_t const& src37, uint32_t const& src38, uint32_t const& src39, + uint32_t const& src40, uint32_t const& src41, uint32_t const& src42, uint32_t const& src43, + uint32_t const& src44, uint32_t const& src45, uint32_t const& src46, uint32_t const& src47, + uint32_t const& src48, uint32_t const& src49, uint32_t const& src50, uint32_t const& src51, + uint32_t const& src52, uint32_t const& src53, uint32_t const& src54, uint32_t const& src55, + uint32_t const& src56, uint32_t const& src57, uint32_t const& src58, uint32_t const& src59, + uint32_t const& src60, uint32_t const& src61, uint32_t const& src62, uint32_t const& src63, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x64.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31), + "r"(src32), "r"(src33), "r"(src34), "r"(src35), + "r"(src36), "r"(src37), "r"(src38), "r"(src39), + "r"(src40), "r"(src41), "r"(src42), "r"(src43), + "r"(src44), "r"(src45), "r"(src46), "r"(src47), + "r"(src48), "r"(src49), "r"(src50), "r"(src51), + "r"(src52), "r"(src53), "r"(src54), "r"(src55), + "r"(src56), "r"(src57), "r"(src58), "r"(src59), + "r"(src60), "r"(src61), "r"(src62), "r"(src63) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 128 times +struct SM100_TMEM_STORE_32dp32b128x +{ + using SRegisters = uint32_t[128]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src000, uint32_t const& src001, uint32_t const& src002, uint32_t const& src003, + uint32_t const& src004, uint32_t const& src005, uint32_t const& src006, uint32_t const& src007, + uint32_t const& src008, uint32_t const& src009, uint32_t const& src010, uint32_t const& src011, + uint32_t const& src012, uint32_t const& src013, uint32_t const& src014, uint32_t const& src015, + uint32_t const& src016, uint32_t const& src017, uint32_t const& src018, uint32_t const& src019, + uint32_t const& src020, uint32_t const& src021, uint32_t const& src022, uint32_t const& src023, + uint32_t const& src024, uint32_t const& src025, uint32_t const& src026, uint32_t const& src027, + uint32_t const& src028, uint32_t const& src029, uint32_t const& src030, uint32_t const& src031, + uint32_t const& src032, uint32_t const& src033, uint32_t const& src034, uint32_t const& src035, + uint32_t const& src036, uint32_t const& src037, uint32_t const& src038, uint32_t const& src039, + uint32_t const& src040, uint32_t const& src041, uint32_t const& src042, uint32_t const& src043, + uint32_t const& src044, uint32_t const& src045, uint32_t const& src046, uint32_t const& src047, + uint32_t const& src048, uint32_t const& src049, uint32_t const& src050, uint32_t const& src051, + uint32_t const& src052, uint32_t const& src053, uint32_t const& src054, uint32_t const& src055, + uint32_t const& src056, uint32_t const& src057, uint32_t const& src058, uint32_t const& src059, + uint32_t const& src060, uint32_t const& src061, uint32_t const& src062, uint32_t const& src063, + uint32_t const& src064, uint32_t const& src065, uint32_t const& src066, uint32_t const& src067, + uint32_t const& src068, uint32_t const& src069, uint32_t const& src070, uint32_t const& src071, + uint32_t const& src072, uint32_t const& src073, uint32_t const& src074, uint32_t const& src075, + uint32_t const& src076, uint32_t const& src077, uint32_t const& src078, uint32_t const& src079, + uint32_t const& src080, uint32_t const& src081, uint32_t const& src082, uint32_t const& src083, + uint32_t const& src084, uint32_t const& src085, uint32_t const& src086, uint32_t const& src087, + uint32_t const& src088, uint32_t const& src089, uint32_t const& src090, uint32_t const& src091, + uint32_t const& src092, uint32_t const& src093, uint32_t const& src094, uint32_t const& src095, + uint32_t const& src096, uint32_t const& src097, uint32_t const& src098, uint32_t const& src099, + uint32_t const& src100, uint32_t const& src101, uint32_t const& src102, uint32_t const& src103, + uint32_t const& src104, uint32_t const& src105, uint32_t const& src106, uint32_t const& src107, + uint32_t const& src108, uint32_t const& src109, uint32_t const& src110, uint32_t const& src111, + uint32_t const& src112, uint32_t const& src113, uint32_t const& src114, uint32_t const& src115, + uint32_t const& src116, uint32_t const& src117, uint32_t const& src118, uint32_t const& src119, + uint32_t const& src120, uint32_t const& src121, uint32_t const& src122, uint32_t const& src123, + uint32_t const& src124, uint32_t const& src125, uint32_t const& src126, uint32_t const& src127, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x128.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64," + "%65, %66, %67, %68," + "%69, %70, %71, %72," + "%73, %74, %75, %76," + "%77, %78, %79, %80," + "%81, %82, %83, %84," + "%85, %86, %87, %88," + "%89, %90, %91, %92," + "%93, %94, %95, %96," + "%97, %98, %99, %100," + "%101, %102, %103, %104," + "%105, %106, %107, %108," + "%109, %110, %111, %112," + "%113, %114, %115, %116," + "%117, %118, %119, %120," + "%121, %122, %123, %124," + "%125, %126, %127, %128};\n" + : + : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), + "r"(src004), "r"(src005), "r"(src006), "r"(src007), + "r"(src008), "r"(src009), "r"(src010), "r"(src011), + "r"(src012), "r"(src013), "r"(src014), "r"(src015), + "r"(src016), "r"(src017), "r"(src018), "r"(src019), + "r"(src020), "r"(src021), "r"(src022), "r"(src023), + "r"(src024), "r"(src025), "r"(src026), "r"(src027), + "r"(src028), "r"(src029), "r"(src030), "r"(src031), + "r"(src032), "r"(src033), "r"(src034), "r"(src035), + "r"(src036), "r"(src037), "r"(src038), "r"(src039), + "r"(src040), "r"(src041), "r"(src042), "r"(src043), + "r"(src044), "r"(src045), "r"(src046), "r"(src047), + "r"(src048), "r"(src049), "r"(src050), "r"(src051), + "r"(src052), "r"(src053), "r"(src054), "r"(src055), + "r"(src056), "r"(src057), "r"(src058), "r"(src059), + "r"(src060), "r"(src061), "r"(src062), "r"(src063), + "r"(src064), "r"(src065), "r"(src066), "r"(src067), + "r"(src068), "r"(src069), "r"(src070), "r"(src071), + "r"(src072), "r"(src073), "r"(src074), "r"(src075), + "r"(src076), "r"(src077), "r"(src078), "r"(src079), + "r"(src080), "r"(src081), "r"(src082), "r"(src083), + "r"(src084), "r"(src085), "r"(src086), "r"(src087), + "r"(src088), "r"(src089), "r"(src090), "r"(src091), + "r"(src092), "r"(src093), "r"(src094), "r"(src095), + "r"(src096), "r"(src097), "r"(src098), "r"(src099), + "r"(src100), "r"(src101), "r"(src102), "r"(src103), + "r"(src104), "r"(src105), "r"(src106), "r"(src107), + "r"(src108), "r"(src109), "r"(src110), "r"(src111), + "r"(src112), "r"(src113), "r"(src114), "r"(src115), + "r"(src116), "r"(src117), "r"(src118), "r"(src119), + "r"(src120), "r"(src121), "r"(src122), "r"(src123), + "r"(src124), "r"(src125), "r"(src126), "r"(src127) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 128 times, expand 16b write +struct SM100_TMEM_STORE_32dp32b128x_16b +{ + using SRegisters = uint32_t[128]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src000, uint32_t const& src001, uint32_t const& src002, uint32_t const& src003, + uint32_t const& src004, uint32_t const& src005, uint32_t const& src006, uint32_t const& src007, + uint32_t const& src008, uint32_t const& src009, uint32_t const& src010, uint32_t const& src011, + uint32_t const& src012, uint32_t const& src013, uint32_t const& src014, uint32_t const& src015, + uint32_t const& src016, uint32_t const& src017, uint32_t const& src018, uint32_t const& src019, + uint32_t const& src020, uint32_t const& src021, uint32_t const& src022, uint32_t const& src023, + uint32_t const& src024, uint32_t const& src025, uint32_t const& src026, uint32_t const& src027, + uint32_t const& src028, uint32_t const& src029, uint32_t const& src030, uint32_t const& src031, + uint32_t const& src032, uint32_t const& src033, uint32_t const& src034, uint32_t const& src035, + uint32_t const& src036, uint32_t const& src037, uint32_t const& src038, uint32_t const& src039, + uint32_t const& src040, uint32_t const& src041, uint32_t const& src042, uint32_t const& src043, + uint32_t const& src044, uint32_t const& src045, uint32_t const& src046, uint32_t const& src047, + uint32_t const& src048, uint32_t const& src049, uint32_t const& src050, uint32_t const& src051, + uint32_t const& src052, uint32_t const& src053, uint32_t const& src054, uint32_t const& src055, + uint32_t const& src056, uint32_t const& src057, uint32_t const& src058, uint32_t const& src059, + uint32_t const& src060, uint32_t const& src061, uint32_t const& src062, uint32_t const& src063, + uint32_t const& src064, uint32_t const& src065, uint32_t const& src066, uint32_t const& src067, + uint32_t const& src068, uint32_t const& src069, uint32_t const& src070, uint32_t const& src071, + uint32_t const& src072, uint32_t const& src073, uint32_t const& src074, uint32_t const& src075, + uint32_t const& src076, uint32_t const& src077, uint32_t const& src078, uint32_t const& src079, + uint32_t const& src080, uint32_t const& src081, uint32_t const& src082, uint32_t const& src083, + uint32_t const& src084, uint32_t const& src085, uint32_t const& src086, uint32_t const& src087, + uint32_t const& src088, uint32_t const& src089, uint32_t const& src090, uint32_t const& src091, + uint32_t const& src092, uint32_t const& src093, uint32_t const& src094, uint32_t const& src095, + uint32_t const& src096, uint32_t const& src097, uint32_t const& src098, uint32_t const& src099, + uint32_t const& src100, uint32_t const& src101, uint32_t const& src102, uint32_t const& src103, + uint32_t const& src104, uint32_t const& src105, uint32_t const& src106, uint32_t const& src107, + uint32_t const& src108, uint32_t const& src109, uint32_t const& src110, uint32_t const& src111, + uint32_t const& src112, uint32_t const& src113, uint32_t const& src114, uint32_t const& src115, + uint32_t const& src116, uint32_t const& src117, uint32_t const& src118, uint32_t const& src119, + uint32_t const& src120, uint32_t const& src121, uint32_t const& src122, uint32_t const& src123, + uint32_t const& src124, uint32_t const& src125, uint32_t const& src126, uint32_t const& src127, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x128.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64," + "%65, %66, %67, %68," + "%69, %70, %71, %72," + "%73, %74, %75, %76," + "%77, %78, %79, %80," + "%81, %82, %83, %84," + "%85, %86, %87, %88," + "%89, %90, %91, %92," + "%93, %94, %95, %96," + "%97, %98, %99, %100," + "%101, %102, %103, %104," + "%105, %106, %107, %108," + "%109, %110, %111, %112," + "%113, %114, %115, %116," + "%117, %118, %119, %120," + "%121, %122, %123, %124," + "%125, %126, %127, %128};\n" + : + : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), + "r"(src004), "r"(src005), "r"(src006), "r"(src007), + "r"(src008), "r"(src009), "r"(src010), "r"(src011), + "r"(src012), "r"(src013), "r"(src014), "r"(src015), + "r"(src016), "r"(src017), "r"(src018), "r"(src019), + "r"(src020), "r"(src021), "r"(src022), "r"(src023), + "r"(src024), "r"(src025), "r"(src026), "r"(src027), + "r"(src028), "r"(src029), "r"(src030), "r"(src031), + "r"(src032), "r"(src033), "r"(src034), "r"(src035), + "r"(src036), "r"(src037), "r"(src038), "r"(src039), + "r"(src040), "r"(src041), "r"(src042), "r"(src043), + "r"(src044), "r"(src045), "r"(src046), "r"(src047), + "r"(src048), "r"(src049), "r"(src050), "r"(src051), + "r"(src052), "r"(src053), "r"(src054), "r"(src055), + "r"(src056), "r"(src057), "r"(src058), "r"(src059), + "r"(src060), "r"(src061), "r"(src062), "r"(src063), + "r"(src064), "r"(src065), "r"(src066), "r"(src067), + "r"(src068), "r"(src069), "r"(src070), "r"(src071), + "r"(src072), "r"(src073), "r"(src074), "r"(src075), + "r"(src076), "r"(src077), "r"(src078), "r"(src079), + "r"(src080), "r"(src081), "r"(src082), "r"(src083), + "r"(src084), "r"(src085), "r"(src086), "r"(src087), + "r"(src088), "r"(src089), "r"(src090), "r"(src091), + "r"(src092), "r"(src093), "r"(src094), "r"(src095), + "r"(src096), "r"(src097), "r"(src098), "r"(src099), + "r"(src100), "r"(src101), "r"(src102), "r"(src103), + "r"(src104), "r"(src105), "r"(src106), "r"(src107), + "r"(src108), "r"(src109), "r"(src110), "r"(src111), + "r"(src112), "r"(src113), "r"(src114), "r"(src115), + "r"(src116), "r"(src117), "r"(src118), "r"(src119), + "r"(src120), "r"(src121), "r"(src122), "r"(src123), + "r"(src124), "r"(src125), "r"(src126), "r"(src127) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace SM100::TMEM::STORE + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm100_tma.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm100_tma.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ef6cfe01026946101e9468fd45da549c9d5bcb86 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm100_tma.hpp @@ -0,0 +1,666 @@ +/*************************************************************************************************** + * Copyright (c) 2020 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +#pragma once + +#include + +#include +#include +#include "cutlass/arch/synclog.hpp" + +namespace cute +{ + +constexpr uint32_t Sm100MmaPeerBitMask = 0xFEFFFFFF; +constexpr uint64_t Sm100MemDescDefault = uint64_t(0x1000000000000000); + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// UTMA_LOAD : Initiates a TMA copy from global memory to shared memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_TMA_2SM_LOAD_1D +{ + CUTE_HOST_DEVICE static void + copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint, + [[maybe_unused]] void * smem_ptr, + [[maybe_unused]] int32_t const& crd0) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.1d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3}], [%2], %4;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_2D +{ + CUTE_HOST_DEVICE static void + copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint, + [[maybe_unused]] void * smem_ptr, + [[maybe_unused]] int32_t const& crd0, int32_t const& crd1) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.2d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4}], [%2], %5;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_3D +{ + CUTE_HOST_DEVICE static void + copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint, + [[maybe_unused]] void * smem_ptr, + [[maybe_unused]] int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.3d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5}], [%2], %6;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.4d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6}], [%2], %7;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.5d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0) + { + return SM100_TMA_2SM_LOAD_1D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { + return SM100_TMA_2SM_LOAD_2D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { + return SM100_TMA_2SM_LOAD_3D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { + return SM100_TMA_2SM_LOAD_4D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2, crd3); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { + return SM100_TMA_2SM_LOAD_5D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2, crd3, crd4); + } + + using PREFETCH = typename SM90_TMA_LOAD::PREFETCH; +}; + + + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_LOAD_MULTICAST: Initiates a TMA copy from global memory to shared memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_TMA_2SM_LOAD_MULTICAST_1D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.1d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4}], [%2], %3, %5;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "h"(multicast_mask), + "r"(crd0), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_MULTICAST_2D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.2d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4, %5}], [%2], %3, %6;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "h"(multicast_mask), + "r"(crd0), "r"(crd1), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_MULTICAST_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) +uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.3d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4, %5, %6}], [%2], %3, %7;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "h"(multicast_mask), + "r"(crd0), "r"(crd1), "r"(crd2), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_MULTICAST_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.4d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4, %5, %6, %7}], [%2], %3, %8;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "h"(multicast_mask), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_MULTICAST_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.5d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4, %5, %6, %7, %8}], [%2], %3, %9;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "h"(multicast_mask), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM0_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_MULTICAST +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0) + { + return SM100_TMA_2SM_LOAD_MULTICAST_1D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { + return SM100_TMA_2SM_LOAD_MULTICAST_2D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0, crd1); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { + return SM100_TMA_2SM_LOAD_MULTICAST_3D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0, crd1, crd2); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { + return SM100_TMA_2SM_LOAD_MULTICAST_4D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0, crd1, crd2, crd3); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { + return SM100_TMA_2SM_LOAD_MULTICAST_5D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0, crd1, crd2, crd3, crd4); + } + + using PREFETCH = typename SM90_TMA_LOAD::PREFETCH; +}; +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +struct SM100_TMA_2SM_LOAD_IM2COL_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.3d.im2col.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5}], [%2], {%6}, %7;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_n), + "h"(offset_w), "l"(Sm100MemDescDefault) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_IM2COL_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, + uint16_t const& offset_h) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.4d.im2col.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6}], [%2], {%7, %8}, %9;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n), + "h"(offset_w), "h"(offset_h), "l"(Sm100MemDescDefault) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_IM2COL_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, + uint16_t const& offset_h, + uint16_t const& offset_d) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.5d.im2col.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], {%8, %9, %10}, %11;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_d), "r"(coord_n), + "h"(offset_w), "h"(offset_h), "h"(offset_d), "l"(Sm100MemDescDefault) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_IM2COL +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { + return SM100_TMA_2SM_LOAD_IM2COL_3D::copy(desc_ptr, mbar_ptr, smem_ptr, + coord_c, coord_w, coord_n, + offset_w); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, + uint16_t const& offset_h) + { + return SM100_TMA_2SM_LOAD_IM2COL_4D::copy(desc_ptr, mbar_ptr, smem_ptr, + coord_c, coord_w, coord_h, coord_n, + offset_w, offset_h); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, + uint16_t const& offset_h, + uint16_t const& offset_d) + { + return SM100_TMA_2SM_LOAD_IM2COL_5D::copy(desc_ptr, mbar_ptr, smem_ptr, + coord_c, coord_w, coord_h, coord_d, coord_n, + offset_w, offset_h, offset_d); + } + + using PREFETCH = typename SM90_TMA_LOAD_IM2COL::PREFETCH; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_TMA_2SM_LOAD_IM2COL_MULTICAST_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.3d.im2col.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%3, %4, %5}], [%2], {%6}, %7, %8;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_n), + "h"(offset_w), + "h"(multicast_mask), + "l"(Sm100MemDescDefault) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_IM2COL_MULTICAST_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, + uint16_t const& offset_h) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.4d.im2col.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6}], [%2], {%7, %8}, %9, %10;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n), + "h"(offset_w), "h"(offset_h), + "h"(multicast_mask), + "l"(Sm100MemDescDefault) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_IM2COL_MULTICAST_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, + uint16_t const& offset_h, + uint16_t const& offset_d) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.5d.im2col.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], {%8, %9, %10}, %11, %12;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_d), "r"(coord_n), + "h"(offset_w), "h"(offset_h), "h"(offset_d), + "h"(multicast_mask), + "l"(Sm100MemDescDefault) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_IM2COL_MULTICAST +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { + return SM100_TMA_2SM_LOAD_IM2COL_MULTICAST_3D::copy(desc_ptr, mbar_ptr, multicast_mask, + smem_ptr, + coord_c, coord_w, coord_n, + offset_w); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h) + { + return SM100_TMA_2SM_LOAD_IM2COL_MULTICAST_4D::copy(desc_ptr, mbar_ptr, multicast_mask, + smem_ptr, + coord_c, coord_w, coord_h, coord_n, + offset_w, offset_h); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h, uint16_t const& offset_d) + { + return SM100_TMA_2SM_LOAD_IM2COL_MULTICAST_5D::copy(desc_ptr, mbar_ptr, multicast_mask, + smem_ptr, + coord_c, coord_w, coord_h, coord_d, coord_n, + offset_w, offset_h, offset_d); + } + + using PREFETCH = typename SM90_TMA_LOAD_IM2COL::PREFETCH; +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm50.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm50.hpp new file mode 100644 index 0000000000000000000000000000000000000000..12518fccd0781f0444faf771c929252ba952face --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm50.hpp @@ -0,0 +1,98 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 500 + #define CUTE_ARCH_WARP_SHUFFLE_ENABLED 1 +#endif + +namespace cute +{ +// Shuffle data between thread pair (0, 1), (2, 3), etc. +struct SM50_Shuffle_U32_2x2Trans_XOR1 +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_WARP_SHUFFLE_ENABLED) + uint32_t x0 = src0; + uint32_t y0 = __shfl_xor_sync(0xffffffff, x0, 1); + + uint32_t x1 = src1; + uint32_t y1 = __shfl_xor_sync(0xffffffff, x1, 1); + + if (threadIdx.x % 2 == 0) { + dst1 = y0; + } + else { + dst0 = y1; + } +#else + CUTE_INVALID_CONTROL_PATH("Trying to use __shfl_xor_sync without CUTE_ARCH_WARP_SHUFFLE_ENABLED."); +#endif + } +}; + +// Shuffle data between thread pair (0, 4), (1, 5), etc. +struct SM50_Shuffle_U32_2x2Trans_XOR4 +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_WARP_SHUFFLE_ENABLED) + uint32_t x0 = threadIdx.x & 4 ? src0 : src1; + uint32_t y0 = __shfl_xor_sync(0xffffffff, x0, 4); + + // Replace detination register with shuffle result. + if (threadIdx.x & 0x4) { + dst0 = y0; + } + else { + dst1 = y0; + } +#else + CUTE_INVALID_CONTROL_PATH("Trying to use __shfl_xor_sync without CUTE_ARCH_WARP_SHUFFLE_ENABLED."); +#endif + } +}; + + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm75.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm75.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a6cb2387cd824da13cc16e7efeb2c721969adf4a --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm75.hpp @@ -0,0 +1,269 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +// Config +#if defined(__clang__) && defined(__CUDA__) + // ldmatrix PTX instructions added in Clang 14: https://reviews.llvm.org/D107046 + // ... but will not work until Clang 15: + // * https://reviews.llvm.org/D121666 + // * https://reviews.llvm.org/D126846 + #define CUTE_ARCH_CLANG_SUPPORTS_LDSM_SM75 (__clang_major__ >= 15) + #define CUTE_ARCH_CLANG_SUPPORTS_MOVM_SM75 (__clang_major__ >= 15) +#endif + +#if defined(__NVCC__) || defined(__CUDACC_RTC__) + // ldmatrix PTX instruction added in CUDA 10.2+ + #define CUTE_ARCH_NVCC_SUPPORTS_LDSM_SM75 ((__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2) || __CUDACC_VER_MAJOR__ >= 11) + #define CUTE_ARCH_NVCC_SUPPORTS_MOVM_SM75 ((__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2) || __CUDACC_VER_MAJOR__ >= 11) +#endif + +#if ! defined(CUTE_ARCH_LDSM_SM75_SUPPORTED) + #define CUTE_ARCH_LDSM_SM75_SUPPORTED (CUTE_ARCH_NVCC_SUPPORTS_LDSM_SM75 || CUTE_ARCH_CLANG_SUPPORTS_LDSM_SM75) +#endif + +#if ! defined(CUTE_ARCH_LDSM_SM75_ENABLED) + #define CUTE_ARCH_LDSM_SM75_ENABLED (CUTE_ARCH_LDSM_SM75_SUPPORTED) +#endif + +#if (CUTE_ARCH_LDSM_SM75_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + #define CUTE_ARCH_LDSM_SM75_ACTIVATED 1 +#endif + +#if ! defined(CUTE_ARCH_MOVM_SM75_SUPPORTED) + #define CUTE_ARCH_MOVM_SM75_SUPPORTED (CUTE_ARCH_NVCC_SUPPORTS_MOVM_SM75 || CUTE_ARCH_CLANG_SUPPORTS_MOVM_SM75) +#endif + +#if ! defined(CUTE_ARCH_MOVM_SM75_ENABLED) + #define CUTE_ARCH_MOVM_SM75_ENABLED (CUTE_ARCH_MOVM_SM75_SUPPORTED) +#endif + +#if (CUTE_ARCH_MOVM_SM75_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + #define CUTE_ARCH_MOVM_SM75_ACTIVATED 1 +#endif + + +namespace cute +{ + +struct SM75_U32x1_LDSM_N +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst) + { +#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n" + : "=r"(dst) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED."); +#endif + } +}; + +struct SM75_U32x2_LDSM_N +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" + : "=r"(dst0), "=r"(dst1) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED."); +#endif + } +}; + +struct SM75_U32x4_LDSM_N +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED."); +#endif + } +}; + +struct SM75_U16x2_LDSM_T +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst) + { +#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n" + : "=r"(dst) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED."); +#endif + } +}; + +struct SM75_U16x4_LDSM_T +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n" + : "=r"(dst0), "=r"(dst1) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED."); +#endif + } +}; + +struct SM75_U16x8_LDSM_T +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED."); +#endif + } +}; + +struct SM75_U32x1_MOVM_T +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t src, + uint32_t &dst) + { +#if CUTE_ARCH_MOVM_SM75_ACTIVATED + asm volatile("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;\n" + : "=r"(dst) + : "r"(src)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use movmatrix without CUTE_ARCH_MOVM_SM75_ACTIVATED."); +#endif + } +}; +// +// Legacy LDSM interfaces that aren't very useful +// + +template +CUTE_HOST_DEVICE +void +copy_ldsm(uint128_t const* const smem_ptr, + T* rmem_ptr) +{ + uint32_t* reg_ptr = reinterpret_cast(rmem_ptr); + + // if constexpr + if (sizeof(T) == 4) { + SM75_U32x1_LDSM_N::copy(smem_ptr[0], reg_ptr[0]); + } + else if (sizeof(T) == 8) { + SM75_U32x2_LDSM_N::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1]); + } + else if (sizeof(T) == 16) { + SM75_U32x4_LDSM_N::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3]); + } + else { + static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported"); + } +} + +template +CUTE_HOST_DEVICE +void +copy_ldsm_trans(uint128_t const* const smem_ptr, + T* rmem_ptr) +{ + uint32_t* reg_ptr = reinterpret_cast(rmem_ptr); + + // if constexpr + if (sizeof(T) == 4) { + SM75_U16x2_LDSM_T::copy(smem_ptr[0], reg_ptr[0]); + } + else if (sizeof(T) == 8) { + SM75_U16x4_LDSM_T::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1]); + } + else if (sizeof(T) == 16) { + SM75_U16x8_LDSM_T::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3]); + } + else { + static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported"); + } +} + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm80.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm80.hpp new file mode 100644 index 0000000000000000000000000000000000000000..71a7b3a46616869d79e67586c6672fcd6eaac666 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm80.hpp @@ -0,0 +1,198 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) +# define CUTE_ARCH_CP_ASYNC_SM80_ENABLED +#endif + +namespace cute +{ + +/// Copy via cp.async with caching at all levels +template +struct SM80_CP_ASYNC_CACHEALWAYS +{ + using SRegisters = TS[1]; + using DRegisters = TD[1]; + + static_assert(sizeof(TS) == sizeof(TD), "cp.async requires sizeof(src_value_type) == sizeof(dst_value_type)"); + static_assert(sizeof(TS) == 4 || sizeof(TS) == 8 || sizeof(TS) == 16, "cp.async sizeof(TS) is not supported"); + + CUTE_HOST_DEVICE static void + copy(TS const& gmem_src, + TD & smem_dst) + { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + TS const* gmem_ptr = &gmem_src; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" + :: "r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(sizeof(TS))); +#else + CUTE_INVALID_CONTROL_PATH("Support for cp.async instructions has not been enabled"); +#endif + } +}; + +/// Copy via cp.async with caching at global level +template +struct SM80_CP_ASYNC_CACHEGLOBAL +{ + using SRegisters = TS[1]; + using DRegisters = TD[1]; + + static_assert(sizeof(TS) == sizeof(TD), "cp.async requires sizeof(src_value_type) == sizeof(dst_value_type)"); + static_assert(sizeof(TS) == 16, "cp.async sizeof(TS) is not supported"); + + CUTE_HOST_DEVICE static void + copy(TS const& gmem_src, + TD & smem_dst) + { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + TS const* gmem_ptr = &gmem_src; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n" + :: "r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(sizeof(TS))); +#else + CUTE_INVALID_CONTROL_PATH("Support for cp.async instructions has not been enabled"); +#endif + } +}; + +/// Copy via cp.async with caching at all levels +template +struct SM80_CP_ASYNC_CACHEALWAYS_ZFILL +{ + using SRegisters = TS[1]; + using DRegisters = TD[1]; + + static_assert(sizeof(TS) == sizeof(TD), "cp.async requires sizeof(src_value_type) == sizeof(dst_value_type)"); + static_assert(sizeof(TS) == 4 || sizeof(TS) == 8 || sizeof(TS) == 16, "cp.async sizeof(TS) is not supported"); + + CUTE_HOST_DEVICE static void + copy(TS const& gmem_src, + TD & smem_dst, + bool pred) + { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + TS const* gmem_ptr = &gmem_src; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + int src_size = pred ? sizeof(TS) : 0; + asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2, %3;\n" + :: "r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(sizeof(TS)), + "r"(src_size)); +#else + CUTE_INVALID_CONTROL_PATH("Support for cp.async instructions has not been enabled"); +#endif + } +}; + +/// Copy via cp.async with caching at global level +template +struct SM80_CP_ASYNC_CACHEGLOBAL_ZFILL +{ + using SRegisters = TS[1]; + using DRegisters = TD[1]; + + static_assert(sizeof(TS) == sizeof(TD), "cp.async requires sizeof(src_value_type) == sizeof(dst_value_type)"); + static_assert(sizeof(TS) == 16, "cp.async sizeof(TS) is not supported"); + + CUTE_HOST_DEVICE static void + copy(TS const& gmem_src, + TD & smem_dst, + bool pred) + { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + TS const* gmem_ptr = &gmem_src; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + int src_size = pred ? sizeof(TS) : 0; + asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" + :: "r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(sizeof(TS)), + "r"(src_size)); +#else + CUTE_INVALID_CONTROL_PATH("Support for cp.async instructions has not been enabled"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block. +CUTE_HOST_DEVICE +void +cp_async_fence() +{ +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + asm volatile("cp.async.commit_group;\n" ::); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Blocks until all but N previous cp.async.commit_group operations have committed. +template +CUTE_HOST_DEVICE +void +cp_async_wait() +{ +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + if constexpr (N == 0) { + asm volatile("cp.async.wait_all;\n" ::); + } else { + asm volatile("cp.async.wait_group %0;\n" :: "n"(N)); + } +#endif +} + +template +CUTE_HOST_DEVICE +void +cp_async_wait(Int) +{ + return cp_async_wait(); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm90.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm90.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5c0745da0a205f304fed9c9ec257bf0bcf16a7cb --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm90.hpp @@ -0,0 +1,219 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE +#include // CUTE_ARCH_TMA_SMxx_ENABLED +#include + +namespace cute +{ + +struct SM90_U32x1_STSM_N +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src, + uint128_t & smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.x1.m8n8.shared.b16 [%0], {%1};\n" + :: "r"(smem_int_ptr), + "r"(src)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); +#endif + } +}; + +struct SM90_U32x2_STSM_N +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, + uint128_t& smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" + :: "r"(smem_int_ptr), + "r"(src0), "r"(src1)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); +#endif + } +}; + +struct SM90_U32x4_STSM_N +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint128_t& smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" + :: "r"(smem_int_ptr), + "r"(src0), "r"(src1), "r"(src2), "r"(src3)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); +#endif + } +}; + +struct SM90_U16x2_STSM_T +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src, + uint128_t& smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [%0], {%1};\n" + :: "r"(smem_int_ptr), + "r"(src)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); +#endif + } +}; + +struct SM90_U16x4_STSM_T +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, + uint128_t& smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [%0], {%1, %2};\n" + :: "r"(smem_int_ptr), + "r"(src0), "r"(src1)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); +#endif + } +}; + +struct SM90_U16x8_STSM_T +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint128_t& smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" + :: "r"(smem_int_ptr), + "r"(src0), "r"(src1), "r"(src2), "r"(src3)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); +#endif + } +}; + +// +// Legacy STSM interfaces that aren't very useful +// + +template +CUTE_HOST_DEVICE +void +copy_stsm(T const* const rmem_ptr, + uint128_t* const smem_ptr) +{ + uint32_t const* reg_ptr = reinterpret_cast(rmem_ptr); + + // if constexpr + if (sizeof(T) == 4) { + SM90_U32x1_STSM_N::copy(reg_ptr[0], smem_ptr[0]); + } + else if (sizeof(T) == 8) { + SM90_U32x2_STSM_N::copy(reg_ptr[0], reg_ptr[1], smem_ptr[0]); + } + else if (sizeof(T) == 16) { + SM90_U32x4_STSM_N::copy(reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3], smem_ptr[0]); + } + else { + static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported"); + } +} + +template +CUTE_HOST_DEVICE +void +copy_stsm_trans(T const* const rmem_ptr, + uint128_t* const smem_ptr) +{ + uint32_t const* reg_ptr = reinterpret_cast(rmem_ptr); + + // if constexpr + if (sizeof(T) == 4) { + SM90_U16x2_STSM_T::copy(reg_ptr[0], smem_ptr[0]); + } + else if (sizeof(T) == 8) { + SM90_U16x4_STSM_T::copy(reg_ptr[0], reg_ptr[1], smem_ptr[0]); + } + else if (sizeof(T) == 16) { + SM90_U16x8_STSM_T::copy(reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3], smem_ptr[0]); + } + else { + static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported"); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm90_desc.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm90_desc.hpp new file mode 100644 index 0000000000000000000000000000000000000000..63bfb563b1a886b40e8b798d217c70e2db073df7 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm90_desc.hpp @@ -0,0 +1,475 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/numeric_types.h" + +#if !defined(__CUDACC_RTC__) +#include +#include +#endif + +#include + +#include // cute::cast_smem_ptr_to_uint +#include // CUTE_ARCH_TMA_SMxx_ENABLED +#include +#include + +#include +#include +#include +#include + +namespace cute +{ + +////////////////////////////////////////////////////////////////////////////////////////////////////// +/// Barriers are 64-bit of user-managed information used in broadly two types syncronization patterns +/// 1) arrive/wait on threads (usage: cp.async and warp-specialized kernels) +/// 2) transaction-based (usage: TMA transaction where a CTA issues one transaction) +////////////////////////////////////////////////////////////////////////////////////////////////////// + +// Initialize barrier present in shared memory +CUTE_HOST_DEVICE +void +initialize_barrier(uint64_t& smem_barrier, // 64 bits user-manged barrier in smem + int thread_count = 1) // Thread count expected to arrive/wait on this barrier +{ +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier); + asm volatile ("mbarrier.init.shared::cta.b64 [%0], %1;\n" + :: "r"(smem_int_ptr), + "r"(thread_count)); +#endif +} + +// Set the number of bytes transfered per transaction and perform an arrive operation as well +CUTE_HOST_DEVICE +void +set_barrier_transaction_bytes(uint64_t& smem_barrier, // 64 bits user-manged barrier in smem + uint32_t bytes) // Number of bytes transfered by per TMA transaction +{ +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier); + asm volatile ("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%0], %1;\n" + :: "r"(smem_int_ptr), + "r"(bytes)); +#endif +} + +// Barrier wait +CUTE_HOST_DEVICE +void +wait_barrier(uint64_t& smem_barrier, // 64 bits user-manged barrier in smem + int phase_bit) // Current phase bit the barrier waiting to flip +{ +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier); + asm volatile( + "{\n" + ".reg .pred P1;\n" + "LAB_WAIT:\n" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n" + "@P1 bra DONE;\n" + "bra LAB_WAIT;\n" + "DONE:\n" + "}\n" + :: "r"(smem_int_ptr), + "r"(phase_bit)); + +#endif +} + +// Barrier arrive +CUTE_HOST_DEVICE +void +arrive_barrier(uint64_t& smem_barrier) // 64 bits user-manged barrier in smem +{ +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier); + asm volatile( + "{\n" + ".reg .b64 state; \n" + "mbarrier.arrive.shared::cta.b64 state, [%0];\n" + "}\n" + :: "r"(smem_int_ptr)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// TMA Descriptor and utilities +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace TMA { + +enum class SmemSwizzleBits : uint8_t { + DISABLE = 0, + B32 = 1, + B64 = 2, + B128 = 3, +}; + +enum class SmemSwizzleBase : uint8_t { + SWIZZLE_BASE_16B = 0, + + SWIZZLE_BASE_32B = 1, + SWIZZLE_BASE_32B_FLIP_8B = 2, + SWIZZLE_BASE_64B = 3, + +}; + +enum class OOBFill : uint8_t { + ZERO = 0, + CONSTANT = 1, +}; + +CUTE_HOST_DEVICE char const* to_string(OOBFill const& t) { + switch (t) { + case OOBFill::ZERO: return "ZERO"; + case OOBFill::CONSTANT: return "CONSTANT"; + } + return nullptr; +} + +enum class L2Promotion : uint8_t { + DISABLE = 0, + B64 = 1, + B128 = 2, + B256 = 3, +}; + +CUTE_HOST_DEVICE char const* to_string(L2Promotion const& t) { + switch (t) { + case L2Promotion::DISABLE: return "DISABLE"; + case L2Promotion::B64: return "B64"; + case L2Promotion::B128: return "B128"; + case L2Promotion::B256: return "B256"; + } + return nullptr; +} + +// Aux parameters which are independent with the problem size +struct DescriptorAuxParams { + OOBFill oobfill_ = OOBFill::ZERO; + L2Promotion l2promo_ = L2Promotion::DISABLE; +}; + +enum class CacheHintSm90 : uint64_t { + EVICT_NORMAL = 0x1000000000000000, + EVICT_FIRST = 0x12F0000000000000, + EVICT_LAST = 0x14F0000000000000, +}; + + +enum class CacheHintSm100 : uint64_t { + EVICT_NORMAL = 0x1000000000000000, + EVICT_FIRST = 0x12F0000000000000, + EVICT_LAST = 0x14F0000000000000, +}; + + +#if (__CUDACC_VER_MAJOR__ >= 12) + +#if !defined(__CUDACC_RTC__) +/// @return The TMA descriptor datatype enum corresponding to T. +template +inline CUtensorMapDataType +to_CUtensorMapDataType() { + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8;} else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT16; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT32; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT64; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_INT32; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_INT64; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT32; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT64; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_TFLOAT32; } else + #if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ > 6))) + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B; } else + #endif + + { static_assert(sizeof(T) < 0, "Unknown TMA Format!"); } +} + +inline CUtensorMapSwizzle +to_CUtensorMapSwizzle(SmemSwizzleBits const& t, SmemSwizzleBase const& b) { + switch (t) { + default: throw std::runtime_error("Unsupported pair of SmemSwizzleBits and SmemSwizzleBase!"); + case SmemSwizzleBits::DISABLE: + assert((b == SmemSwizzleBase::SWIZZLE_BASE_16B) && "Expected 16B swizzle base for 0B swizzle bits."); + return CU_TENSOR_MAP_SWIZZLE_NONE; + case SmemSwizzleBits::B32: + assert((b == SmemSwizzleBase::SWIZZLE_BASE_16B) && "Expected 16B swizzle base for 32B swizzle bits."); + return CU_TENSOR_MAP_SWIZZLE_32B; + case SmemSwizzleBits::B64: + assert((b == SmemSwizzleBase::SWIZZLE_BASE_16B) && "Expected 16B swizzle base for 64B swizzle bits."); + return CU_TENSOR_MAP_SWIZZLE_64B; + case SmemSwizzleBits::B128: + switch (b) { + default: throw std::runtime_error("Unsupported pair of SmemSwizzleBits and SmemSwizzleBase!"); + case SmemSwizzleBase::SWIZZLE_BASE_16B: return CU_TENSOR_MAP_SWIZZLE_128B; + + #if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ > 6))) + case SmemSwizzleBase::SWIZZLE_BASE_32B: return CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B; + case SmemSwizzleBase::SWIZZLE_BASE_64B: return CU_TENSOR_MAP_SWIZZLE_128B_ATOM_64B; + #endif + } + } +} + +inline CUtensorMapFloatOOBfill +to_CUtensorMapFloatOOBfill(OOBFill const& t) { + switch(t) { + default: throw std::runtime_error("Unknown OOBFill!"); + case OOBFill::ZERO: return CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; + case OOBFill::CONSTANT: return CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA; + } +} + +inline CUtensorMapL2promotion +to_CUtensorMapL2promotion(L2Promotion const& t) { + switch(t) { + default: throw std::runtime_error("Unknown L2Promotion!"); + case L2Promotion::DISABLE: return CU_TENSOR_MAP_L2_PROMOTION_NONE; + case L2Promotion::B64: return CU_TENSOR_MAP_L2_PROMOTION_L2_64B; + case L2Promotion::B128: return CU_TENSOR_MAP_L2_PROMOTION_L2_128B; + case L2Promotion::B256: return CU_TENSOR_MAP_L2_PROMOTION_L2_256B; + } +} + +#endif // !defined(__CUDACC_RTC__) + +#endif // (__CUDACC_VER_MAJOR__ >= 12) + +} // end namespace TMA + +#if (__CUDACC_VER_MAJOR__ >= 12) && !defined(__CUDACC_RTC__) + using TmaDescriptor = CUtensorMap; + using Im2ColTmaDescriptor = CUtensorMap; +#else + using TmaDescriptor = struct alignas(64) { char bytes[128]; }; + using Im2ColTmaDescriptor = struct alignas(64) { char bytes[128]; }; +#endif +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Initiates a TensorMap Prefetch +//////////////////////////////////////////////////////////////////////////////////////////////////// + +CUTE_HOST_DEVICE +void +prefetch_tma_descriptor(TmaDescriptor const* desc_ptr) +{ +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Prefetch TMA Descriptor using generic addressing (i.e. no specific state space: const or param) + asm volatile ( + "prefetch.tensormap [%0];" + : + : "l"(gmem_int_desc) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMA Descriptor Prefetch without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Perform a TensorMap modification (by each field) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Replace tensor pointer directly in GMEM +CUTE_HOST_DEVICE +void +tma_descriptor_replace_addr_in_global_mem(TmaDescriptor const* desc_ptr, + void const* const new_tensor_ptr) +{ +#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint64_t const new_desc_addr = reinterpret_cast(new_tensor_ptr); + asm volatile ( + "tensormap.replace.tile.global_address.global.b1024.b64 [%0], %1;" + :: "l"(gmem_int_desc), "l"(new_desc_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3"); +#endif +} + +// Replace tensor pointer by bringing the tensormap from GMEM into the shared memory +CUTE_HOST_DEVICE +void +tma_descriptor_replace_addr_in_shared_mem(TmaDescriptor& smem_desc, + void const* const new_tensor_ptr) +{ +#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED) + uint32_t smem_int_desc = cast_smem_ptr_to_uint(&smem_desc); + uint64_t const new_desc_addr = reinterpret_cast(new_tensor_ptr); + asm volatile ( + "tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;" + :: "r"(smem_int_desc), "l"(new_desc_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3"); +#endif +} + +// Replace tensor dims and strides for GEMMs by bringing the tensormap from GMEM into the shared memory +CUTE_HOST_DEVICE +void +tma_descriptor_replace_dims_strides_in_shared_mem(TmaDescriptor & smem_desc, + cute::array const& prob_shape, + cute::array const& prob_stride) +{ +#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED) + uint32_t smem_int_desc = cast_smem_ptr_to_uint(&smem_desc); + uint64_t const smem_int64_desc = 0; + asm volatile ( + "cvt.u64.u32 %0, %1;" + :: "l"(smem_int64_desc), "r"(smem_int_desc)); + asm volatile ( + "tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 0, %1;" + :: "l"(smem_int64_desc), "r"(prob_shape[0])); + asm volatile ( + "tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 1, %1;" + :: "l"(smem_int64_desc), "r"(prob_shape[1])); + asm volatile ( + "tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 2, %1;" + :: "l"(smem_int64_desc), "r"(prob_shape[2])); + asm volatile ( + "tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 3, %1;" + :: "l"(smem_int64_desc), "r"(prob_shape[3])); + asm volatile ( + "tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 4, %1;" + :: "l"(smem_int64_desc), "r"(prob_shape[4])); + // Strides must be a multiple of 16. Also, stride for the intermost dimension is implicitly 1 + #if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 5))) + asm volatile ( + "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" + :: "l"(smem_int64_desc), "l"(prob_stride[1])); + asm volatile ( + "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 1, %1;" + :: "l"(smem_int64_desc), "l"(prob_stride[2])); + asm volatile ( + "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 2, %1;" + :: "l"(smem_int64_desc), "l"(prob_stride[3])); + asm volatile ( + "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 3, %1;" + :: "l"(smem_int64_desc), "l"(prob_stride[4])); + #else + // 4 LSBs are not included + asm volatile ( + "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" + :: "l"(smem_int64_desc), "l"(prob_stride[1] >> 4)); + asm volatile ( + "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 1, %1;" + :: "l"(smem_int64_desc), "l"(prob_stride[2] >> 4)); + asm volatile ( + "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 2, %1;" + :: "l"(smem_int64_desc), "l"(prob_stride[3] >> 4)); + asm volatile ( + "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 3, %1;" + :: "l"(smem_int64_desc), "l"(prob_stride[4] >> 4)); + #endif +#else + CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3"); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Perform a fused copy and fence operation (needed when modifying tensormap in shared memory) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +CUTE_HOST_DEVICE +void +tma_descriptor_cp_fence_release(TmaDescriptor const* gmem_desc_ptr, TmaDescriptor& smem_desc) +{ +#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(gmem_desc_ptr); + uint32_t smem_int_desc = cast_smem_ptr_to_uint(&smem_desc); + asm volatile ( + "tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned [%0], [%1], 128;" + :: "l"(gmem_int_desc), "r"(smem_int_desc)); +#else + CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3"); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Perform a release fence operation (needed when modifying tensormap directly in GMEM) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +CUTE_HOST_DEVICE +void +tma_descriptor_fence_release() +{ +#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED) + asm volatile ("fence.proxy.tensormap::generic.release.gpu;"); +#else + CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3"); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Perform a acquire fence operation +//////////////////////////////////////////////////////////////////////////////////////////////////// + +CUTE_HOST_DEVICE +void +tma_descriptor_fence_acquire(TmaDescriptor const* desc_ptr) +{ +#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile ( + "fence.proxy.tensormap::generic.acquire.gpu [%0], 128;" + : + : "l"(gmem_int_desc) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3"); +#endif +} + +/////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm90_tma.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm90_tma.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ec1564499ec85d8ae3aa5b8a687ac37520b2ed7f --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm90_tma.hpp @@ -0,0 +1,1496 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include // CUTE_ARCH_TMA_SMxx_ENABLED +#include +#include +#include "cutlass/arch/synclog.hpp" + +namespace cute +{ + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_LOAD : Initiates a TMA copy from global memory to shared memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_LOAD_1D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + asm volatile ( + "cp.async.bulk.tensor.1d.shared::cta.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3}], [%2], %4;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "l"(cache_hint) + : "memory"); +#else + asm volatile ( + "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3}], [%2], %4;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "l"(cache_hint) + : "memory"); +#endif +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0) + { + #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile ( + "cp.async.bulk.prefetch.tensor.1d.L2.global" + " [%0, {%1}];" + : + : "l"(gmem_int_desc), + "r"(crd0) + : "memory"); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); + #endif + } + }; +}; + +struct SM90_TMA_LOAD_2D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + asm volatile ( + "cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4}], [%2], %5;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "l"(cache_hint) + : "memory"); +#else + asm volatile ( + "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4}], [%2], %5;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "l"(cache_hint) + : "memory"); +#endif +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0, int32_t const& crd1) + { + #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile ( + "cp.async.bulk.prefetch.tensor.2d.L2.global" + " [%0, {%1, %2}];" + : + : "l"(gmem_int_desc), + "r"(crd0), "r"(crd1) + : "memory"); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); + #endif + } + }; +}; + +struct SM90_TMA_LOAD_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + asm volatile ( + "cp.async.bulk.tensor.3d.shared::cta.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5}], [%2], %6;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "l"(cache_hint) + : "memory"); +#else + asm volatile ( + "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5}], [%2], %6;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "l"(cache_hint) + : "memory"); +#endif +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { + #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile ( + "cp.async.bulk.prefetch.tensor.3d.L2.global" + " [%0, {%1, %2, %3}];" + : + : "l"(gmem_int_desc), + "r"(crd0), "r"(crd1), "r"(crd2) + : "memory"); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); + #endif + } + }; +}; + +struct SM90_TMA_LOAD_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + asm volatile ( + "cp.async.bulk.tensor.4d.shared::cta.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6}], [%2], %7;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "l"(cache_hint) + : "memory"); +#else + asm volatile ( + "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6}], [%2], %7;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "l"(cache_hint) + : "memory"); +#endif +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { + #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile ( + "cp.async.bulk.prefetch.tensor.4d.L2.global" + " [%0, {%1, %2, %3, %4}];" + : + : "l"(gmem_int_desc), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3) + : "memory"); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); + #endif + } + }; +}; + +struct SM90_TMA_LOAD_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + asm volatile ( + "cp.async.bulk.tensor.5d.shared::cta.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4), "l"(cache_hint) + : "memory"); +#else + asm volatile ( + "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4), "l"(cache_hint) + : "memory"); +#endif +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { + #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile ( + "cp.async.bulk.prefetch.tensor.5d.L2.global" + " [%0, {%1, %2, %3, %4, %5}];" + : + : "l"(gmem_int_desc), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4) + : "memory"); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); + #endif + } + }; +}; + +struct SM90_TMA_LOAD +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0) + { + return SM90_TMA_LOAD_1D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { + return SM90_TMA_LOAD_2D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { + return SM90_TMA_LOAD_3D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { + return SM90_TMA_LOAD_4D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2, crd3); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { + return SM90_TMA_LOAD_5D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2, crd3, crd4); + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0) + { + return SM90_TMA_LOAD_1D::PREFETCH::copy(desc_ptr, crd0); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0, int32_t const& crd1) + { + return SM90_TMA_LOAD_2D::PREFETCH::copy(desc_ptr, crd0, crd1); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { + return SM90_TMA_LOAD_3D::PREFETCH::copy(desc_ptr, crd0, crd1, crd2); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { + return SM90_TMA_LOAD_4D::PREFETCH::copy(desc_ptr, crd0, crd1, crd2, crd3); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { + return SM90_TMA_LOAD_5D::PREFETCH::copy(desc_ptr, crd0, crd1, crd2, crd3, crd4); + } + }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_LOAD im2col: Initiates a TMA copy, in im2col mode, from global memory to shared memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_LOAD_IM2COL_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + // Copy from global to shared::cluster. + asm volatile ( + "cp.async.bulk.tensor.3d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes" + " [%0], [%1, {%3, %4, %5}], [%2], {%6};" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_n), + "h"(offset_w) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { + #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile ( + "cp.async.bulk.prefetch.tensor.3d.L2.global.im2col" + " [%0, {%1, %2, %3}], {%4};" + : + : "l"(gmem_int_desc), + "r"(coord_c), "r"(coord_w), "r"(coord_n), + "h"(offset_w) + : "memory"); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); + #endif + } + }; +}; + +struct SM90_TMA_LOAD_IM2COL_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + // Copy from global to shared::cluster. + asm volatile ( + "cp.async.bulk.tensor.4d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes" + " [%0], [%1, {%3, %4, %5, %6}], [%2], {%7, %8};" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n), + "h"(offset_w), "h"(offset_h) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h) + { + #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile ( + "cp.async.bulk.prefetch.tensor.4d.L2.global.im2col" + " [%0, {%1, %2, %3, %4}], {%5, %6};" + : + : "l"(gmem_int_desc), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n), + "h"(offset_w), "h"(offset_h) + : "memory"); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); + #endif + } + }; +}; + +struct SM90_TMA_LOAD_IM2COL_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h, uint16_t const& offset_d) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + // Copy from global to shared::cluster. + asm volatile ( + "cp.async.bulk.tensor.5d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], {%8, %9, %10};" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_d), "r"(coord_n), + "h"(offset_w), "h"(offset_h), "h"(offset_d) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h, uint16_t const& offset_d) + { + #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile ( + "cp.async.bulk.prefetch.tensor.5d.L2.global.im2col" + " [%0, {%1, %2, %3, %4, %5}], {%6, %7, %8};" + : + : "l"(gmem_int_desc), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_d), "r"(coord_n), + "h"(offset_w), "h"(offset_h), "h"(offset_d) + : "memory"); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); + #endif + } + }; +}; + +struct SM90_TMA_LOAD_IM2COL +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { + return SM90_TMA_LOAD_IM2COL_3D::copy(desc_ptr, mbar_ptr, smem_ptr, + coord_c, coord_w, coord_n, + offset_w); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h) + { + return SM90_TMA_LOAD_IM2COL_4D::copy(desc_ptr, mbar_ptr, smem_ptr, + coord_c, coord_w, coord_h, coord_n, + offset_w, offset_h); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h, uint16_t const& offset_d) + { + return SM90_TMA_LOAD_IM2COL_5D::copy(desc_ptr, mbar_ptr, smem_ptr, + coord_c, coord_w, coord_h, coord_d, coord_n, + offset_w, offset_h, offset_d); + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { + return SM90_TMA_LOAD_IM2COL_3D::PREFETCH::copy(desc_ptr, + coord_c, coord_w, coord_n, + offset_w); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h) + { + return SM90_TMA_LOAD_IM2COL_4D::PREFETCH::copy(desc_ptr, + coord_c, coord_w, coord_h, coord_n, + offset_w, offset_h); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h, uint16_t const& offset_d) + { + return SM90_TMA_LOAD_IM2COL_5D::PREFETCH::copy(desc_ptr, + coord_c, coord_w, coord_h, coord_d, coord_n, + offset_w, offset_h, offset_d); + } + }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_LOAD_MULTICAST: Initiates a TMA copy from global memory to shared memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_LOAD_MULTICAST_1D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4}], [%2], %3, %5;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "h"(multicast_mask), + "r"(crd0), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_MULTICAST_2D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4, %5}], [%2], %3, %6;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "h"(multicast_mask), + "r"(crd0), "r"(crd1), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_MULTICAST_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4, %5, %6}], [%2], %3, %7;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "h"(multicast_mask), + "r"(crd0), "r"(crd1), "r"(crd2), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_MULTICAST_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4, %5, %6, %7}], [%2], %3, %8;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "h"(multicast_mask), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_MULTICAST_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4, %5, %6, %7, %8}], [%2], %3, %9;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "h"(multicast_mask), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_MULTICAST +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0) + { + return SM90_TMA_LOAD_MULTICAST_1D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { + return SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0, crd1); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { + return SM90_TMA_LOAD_MULTICAST_3D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0, crd1, crd2); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { + return SM90_TMA_LOAD_MULTICAST_4D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0, crd1, crd2, crd3); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { + return SM90_TMA_LOAD_MULTICAST_5D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0, crd1, crd2, crd3, crd4); + } + + using PREFETCH = typename SM90_TMA_LOAD::PREFETCH; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_LOAD_MULTICAST im2col: Initiates a TMA copy, in im2col mode, from global memory to shared memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_LOAD_IM2COL_MULTICAST_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + // Copy from global to shared::cluster. + asm volatile ( + "cp.async.bulk.tensor.3d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes.multicast::cluster" + " [%0], [%1, {%3, %4, %5}], [%2], {%6}, %7;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_n), + "h"(offset_w), + "h"(multicast_mask) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_IM2COL_MULTICAST_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + // Copy from global to shared::cluster. + asm volatile ( + "cp.async.bulk.tensor.4d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes.multicast::cluster" + " [%0], [%1, {%3, %4, %5, %6}], [%2], {%7, %8}, %9;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n), + "h"(offset_w), "h"(offset_h), + "h"(multicast_mask) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_IM2COL_MULTICAST_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h, uint16_t const& offset_d) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + // Copy from global to shared::cluster. + asm volatile ( + "cp.async.bulk.tensor.5d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes.multicast::cluster" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], {%8, %9, %10}, %11;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_d), "r"(coord_n), + "h"(offset_w), "h"(offset_h), "h"(offset_d), + "h"(multicast_mask) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_IM2COL_MULTICAST +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { + return SM90_TMA_LOAD_IM2COL_MULTICAST_3D::copy(desc_ptr, mbar_ptr, multicast_mask, + smem_ptr, + coord_c, coord_w, coord_n, + offset_w); + } + + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h) + { + return SM90_TMA_LOAD_IM2COL_MULTICAST_4D::copy(desc_ptr, mbar_ptr, multicast_mask, + smem_ptr, + coord_c, coord_w, coord_h, coord_n, + offset_w, offset_h); + } + + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h, uint16_t const& offset_d) + { + return SM90_TMA_LOAD_IM2COL_MULTICAST_5D::copy(desc_ptr, mbar_ptr, multicast_mask, + smem_ptr, + coord_c, coord_w, coord_h, coord_d, coord_n, + offset_w, offset_h, offset_d); + } + + using PREFETCH = typename SM90_TMA_LOAD_IM2COL::PREFETCH; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_STORE : Initiates a TMA copy from shared memory to global memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_STORE_1D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [%0, {%2}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE_2D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, {%2, %3}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0), "r"(crd1) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.3d.global.shared::cta.bulk_group [%0, {%2, %3, %4}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0), "r"(crd1), "r"(crd2) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.4d.global.shared::cta.bulk_group [%0, {%2, %3, %4, %5}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.5d.global.shared::cta.bulk_group [%0, {%2, %3, %4, %5, %6}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0) + { + return SM90_TMA_STORE_1D::copy(desc_ptr, smem_ptr, crd0); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { + return SM90_TMA_STORE_2D::copy(desc_ptr, smem_ptr, crd0, crd1); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { + return SM90_TMA_STORE_3D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { + return SM90_TMA_STORE_4D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2, crd3); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { + return SM90_TMA_STORE_5D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2, crd3, crd4); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_STORE im2col: Initiates a TMA copy, in im2col mode, from shared memory to global memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_STORE_IM2COL_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.3d.global.shared::cta.im2col_no_offs.bulk_group" + " [%0, {%2, %3, %4}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(coord_c), "r"(coord_w), "r"(coord_n) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE_IM2COL_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.4d.global.shared::cta.im2col_no_offs.bulk_group" + " [%0, {%2, %3, %4, %5}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE_IM2COL_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.5d.global.shared::cta.im2col_no_offs.bulk_group" + " [%0, {%2, %3, %4, %5, %6}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_d), "r"(coord_n) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE_IM2COL +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n) + { + return SM90_TMA_STORE_IM2COL_3D::copy(desc_ptr, smem_ptr, coord_c, coord_w, coord_n); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n) + { + return SM90_TMA_STORE_IM2COL_4D::copy(desc_ptr, smem_ptr, coord_c, coord_w, coord_h, coord_n); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n) + { + return SM90_TMA_STORE_IM2COL_5D::copy(desc_ptr, smem_ptr, coord_c, coord_w, coord_h, coord_d, coord_n); + } +}; + +// Fence for smem stores for subsequent TMA_STORE +CUTE_HOST_DEVICE static void +tma_store_fence() { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + cutlass::arch::synclog_emit_fence_view_async_shared(__LINE__); + asm volatile ("fence.proxy.async.shared::cta;"); +#elif defined(__CUDA_ARCH__) + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif +} + +// Indicate arrival of warp issuing TMA_STORE +CUTE_HOST_DEVICE static void +tma_store_arrive() { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + cutlass::arch::synclog_emit_tma_store_arrive(__LINE__); + asm volatile("cp.async.bulk.commit_group;"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif +} + + +CUTE_HOST_DEVICE static void +tma_desc_commit_group() { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + asm volatile("cp.async.bulk.commit_group;"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif +} + + +// Wait until at most Count committed TMA_STOREs are pending and all prior commits are complete +template +CUTE_HOST_DEVICE static void +tma_store_wait() { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + asm volatile( + "cp.async.bulk.wait_group.read %0;" + : + : "n"(Count) + : "memory"); + cutlass::arch::synclog_emit_tma_store_wait(__LINE__, Count); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif +} + + +// Wait until all TMA descriptor previously issued are safe to be modified after tma_desc_commit_group() +CUTE_HOST_DEVICE static void +tma_desc_wait_group() { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + asm volatile( + "cp.async.bulk.wait_group.read %0;" + : + : "n"(0) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif +} + + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_REDUCE_ADD : Initiates a TMA reduce-add from shared memory to global memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_REDUCE_ADD_1D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.reduce.async.bulk.tensor.1d.global.shared::cta.add.bulk_group [%0, {%2}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_REDUCE_ADD_2D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.reduce.async.bulk.tensor.2d.global.shared::cta.add.bulk_group [%0, {%2, %3}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0), "r"(crd1) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_REDUCE_ADD_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.reduce.async.bulk.tensor.3d.global.shared::cta.add.bulk_group [%0, {%2, %3, %4}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0), "r"(crd1), "r"(crd2) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_REDUCE_ADD_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.reduce.async.bulk.tensor.4d.global.shared::cta.add.bulk_group [%0, {%2, %3, %4, %5}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_REDUCE_ADD_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.reduce.async.bulk.tensor.5d.global.shared::cta.add.bulk_group [%0, {%2, %3, %4, %5, %6}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_REDUCE_ADD +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0) + { + return SM90_TMA_REDUCE_ADD_1D::copy(desc_ptr, smem_ptr, crd0); + } + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { + return SM90_TMA_REDUCE_ADD_2D::copy(desc_ptr, smem_ptr, crd0, crd1); + } + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { + return SM90_TMA_REDUCE_ADD_3D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2); + } + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { + return SM90_TMA_REDUCE_ADD_4D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2, crd3); + } + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { + return SM90_TMA_REDUCE_ADD_5D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2, crd3, crd4); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// BULK_COPY : Copy a bulk of memory between shared memory and global memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_BULK_COPY_G2S +{ + CUTE_HOST_DEVICE static void + copy(void const* gmem_ptr, uint64_t* mbar_ptr, + void * smem_ptr, int32_t load_bytes) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile("cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];\n" + : + : "r"(smem_int_ptr), "l"(gmem_ptr), "r"(load_bytes), "r"(smem_int_mbar) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use BULK_COPY without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* gmem_ptr, int32_t load_bytes) + { + #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + asm volatile("cp.async.bulk.prefetch.L2.global [%0], %1;\n" + : + : "l"(gmem_ptr), "r"(load_bytes) + : "memory"); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use BULK_COPY without CUTE_ARCH_TMA_SM90_ENABLED."); + #endif + } + }; +}; + +struct SM90_BULK_COPY_S2G +{ + CUTE_HOST_DEVICE static void + copy(void const* smem_ptr, + void * gmem_ptr, int32_t store_bytes) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile("cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;\n" + : + : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(store_bytes) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use BULK_COPY without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_BULK_COPY_AUTO {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8b97f50ab4a9a9faaa6e1065c65938ccfc8cfa59 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma.hpp @@ -0,0 +1,64 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::fma +#include // cute::fma + +namespace cute +{ + +// +// Direct FMA for any type +// + +template +struct UniversalFMA +{ + using DRegisters = D[1]; + using ARegisters = A[1]; + using BRegisters = B[1]; + using CRegisters = C[1]; + + CUTE_HOST_DEVICE static constexpr void + fma(D & d, + A const& a, + B const& b, + C const& c) + { + // Forward to an ADL/cute free function for these types + using cute::fma; + fma(d, a, b, c); + } +}; + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm100.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm100.hpp new file mode 100644 index 0000000000000000000000000000000000000000..749da8167e68d19a7c332d771da71d9172e49fd6 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm100.hpp @@ -0,0 +1,83 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +// +// + +#pragma once + +#include +#include + +#include + +namespace cute { + +struct SM100_2x1x1_F32F32F32F32 { + using DRegisters = float2[1]; + using ARegisters = float2[1]; + using BRegisters = float[1]; + using CRegisters = float2[1]; + + CUTE_HOST_DEVICE static void + fma(float2 & d01, + float2 const& a01, + float const& b0, + float2 const& c01) + { +#if defined(CUTE_ARCH_FFMA2_SM100_ENABLED) + cute::fma(d01, a01, make_float2(b0, b0), c01); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_2x1x1_F32F32F32F32 without CUTE_ARCH_FLOAT2_MATH_ENABLED"); +#endif + } +}; + +struct SM100_1x2x1_F32F32F32F32 { + using DRegisters = float2[1]; + using ARegisters = float[1]; + using BRegisters = float2[1]; + using CRegisters = float2[1]; + + CUTE_HOST_DEVICE static void + fma(float2 & d01, + float const& a0, + float2 const& b01, + float2 const& c01) + { +#if defined(CUTE_ARCH_FFMA2_SM100_ENABLED) + cute::fma(d01, make_float2(a0, a0), b01, c01); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_1x2x1_F32F32F32F32 without CUTE_ARCH_FFMA2_SM100_ENABLED"); +#endif + } +}; + +} // namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm100_desc.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm100_desc.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d255c41f179b9f76da96f093e2feef131c17cd87 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm100_desc.hpp @@ -0,0 +1,651 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +// + +// + +#pragma once + +#if !defined(__CUDACC_RTC__) +#include +#endif + +#include + +#include + +#include +#include // cute::array + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cute { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// UMMA Descriptor and utilities + +// UMMA enums and utilities +namespace UMMA +{ + +enum class Major : uint8_t { + K = 0, + MN = 1 +}; + +enum class ScaleIn : uint8_t { + One = 0, + Neg = 1 +}; + +enum class ScaleOut : uint8_t { + Zero = 0, + One = 1 +}; + +enum class Saturate : uint8_t { + False = 0, + True = 1 +}; + +enum class LayoutType : uint8_t { + SWIZZLE_NONE = 0, + SWIZZLE_128B_BASE32B = 1, + SWIZZLE_128B = 2, + SWIZZLE_64B = 4, + SWIZZLE_32B = 6 +}; + +CUTE_HOST_DEVICE char const* to_string(LayoutType const& t) { + switch (t) { + case LayoutType::SWIZZLE_NONE: return "SWIZZLE_NONE"; + case LayoutType::SWIZZLE_128B_BASE32B: return "SWIZZLE_128B_BASE32B"; + case LayoutType::SWIZZLE_128B: return "SWIZZLE_128B"; + case LayoutType::SWIZZLE_64B: return "SWIZZLE_64B"; + case LayoutType::SWIZZLE_32B: return "SWIZZLE_32B"; + } + return nullptr; +} + +union SmemDescriptor +{ + uint64_t desc_ = 0; + // Bitfield implementation avoids the need for shifts in assignment + struct { + // start_address, bit [0,14), 4LSB not included + uint16_t start_address_ : 14, : 2; // 14 bits [0,14), 2 bits unused + // leading dimension byte offset, bit [16,30), 4LSB not included + uint16_t leading_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused + // stride dimension byte offset, bit [32,46), 4LSB not included + uint16_t stride_byte_offset_ : 14, version_ : 2; // 14 bits [0,14), 2 bits [14,16) + // base_offset, bit [49,52). leading_byte_offset_mode, bit [52,53). + uint8_t : 1, base_offset_ : 3, lbo_mode_ : 1, : 3; // 1 bit unused, 3 bits [1,4), 1 bit [4,5), 3 bits unused + // layout type, bit [61,64), SWIZZLE_NONE matrix descriptor = 0, SWIZZLE_128B matrix descriptor = 2, SWIZZLE_64B descriptor = 4, SWIZZLE_32B descriptor = 6, SWIZZLE_128B_BASE32B = 1, N/A = 3, N/A = 5, N/A = 7 + uint8_t : 5, layout_type_ : 3; // 6 bits unused, 3 bits [5,8) + }; + // Seperate the field, as we may only update one part of desc + struct { + uint32_t lo; + uint32_t hi; + }; + + // Decay to a uint64_t + CUTE_HOST_DEVICE constexpr + operator uint64_t() const noexcept { return desc_; } +}; + +enum class F16F32Format : uint8_t { + F16 = 0, + BF16 = 1, + TF32 = 2, +}; + +CUTE_HOST_DEVICE char const* to_string(F16F32Format const& t) { + switch (t) { + case F16F32Format::F16: return "F16"; + case F16F32Format::BF16: return "BF16"; + case F16F32Format::TF32: return "TF32"; + } + return nullptr; +} + +template +CUTE_HOST_DEVICE constexpr F16F32Format to_F16F32Format() { + if constexpr (is_same_v) { return F16F32Format::F16; } else + if constexpr (is_same_v) { return F16F32Format::BF16; } else + if constexpr (is_same_v) { return F16F32Format::TF32; } else + { static_assert(sizeof(T) == 0, "Unknown type for F16F32Format"); } +} + +enum class S8Format : uint8_t { + UINT8 = 0, + INT8 = 1, +}; + +CUTE_HOST_DEVICE char const* to_string(S8Format const& t) { + switch (t) { + case S8Format::UINT8: return "UINT8"; + case S8Format::INT8: return "INT8"; + } + return nullptr; +} + +template +CUTE_HOST_DEVICE constexpr S8Format to_S8Format() { + if constexpr (is_same_v) { return S8Format::UINT8; } else + if constexpr (is_same_v) { return S8Format::INT8; } else + { static_assert(sizeof(T) == 0, "Unknown type for S8Format"); } +} + +enum class MXF8F6F4Format : uint8_t { + E4M3 = 0, + E5M2 = 1, + E2M3 = 3, + E3M2 = 4, + E2M1 = 5, + INVALID = 7 // an invalid datatype for runtime proxy type +}; + +CUTE_HOST_DEVICE char const* to_string(MXF8F6F4Format const& t) { + switch (t) { + case MXF8F6F4Format::E4M3: return "E4M3"; + case MXF8F6F4Format::E5M2: return "E5M2"; + case MXF8F6F4Format::E2M3: return "E2M3"; + case MXF8F6F4Format::E3M2: return "E3M2"; + case MXF8F6F4Format::E2M1: return "E2M1"; + case MXF8F6F4Format::INVALID: return "INVALID"; + } + return nullptr; +} + +template +CUTE_HOST_DEVICE constexpr MXF8F6F4Format to_MXF8F6F4Format() { + if constexpr (is_same_v) { return MXF8F6F4Format::E4M3; } else + if constexpr (is_same_v) { return MXF8F6F4Format::E5M2; } else + if constexpr (is_same_v) { return MXF8F6F4Format::E2M3; } else + if constexpr (is_same_v) { return MXF8F6F4Format::E3M2; } else + if constexpr (is_same_v) { return MXF8F6F4Format::E2M1; } else + { static_assert(sizeof(T) == 0, "Unknown type for MXF8F6F4Format"); } +} + +enum class MXF4Format : uint8_t { + E2M1 = 1, +}; + +CUTE_HOST_DEVICE char const* to_string(MXF4Format const& t) { + switch (t) { + case MXF4Format::E2M1: return "E2M1"; + } + return nullptr; +} + +template +CUTE_HOST_DEVICE constexpr MXF4Format to_MXF4Format() { + if constexpr (is_same_v) { return MXF4Format::E2M1; } else + { static_assert(sizeof(T) == 0, "Unknown type for MXF4Format"); } +} + +enum class ScaleFormat : uint8_t { + UE4M3 = 0, + UE8M0 = 1, +}; + +CUTE_HOST_DEVICE char const* to_string(ScaleFormat const& t) { + switch (t) { + case ScaleFormat::UE4M3: return "UE4M3"; + case ScaleFormat::UE8M0: return "UE8M0"; + } + return nullptr; +} + +template +CUTE_HOST_DEVICE constexpr ScaleFormat to_ScaleFormat() { + if constexpr (is_same_v) { return ScaleFormat::UE4M3; } else + if constexpr (is_same_v) { return ScaleFormat::UE8M0; } else + { static_assert(sizeof(T) == 0, "Unknown type for ScaleFormat"); } +} + +enum class CFormat : uint8_t { + F16 = 0, + F32 = 1, + S32 = 2, +}; + +CUTE_HOST_DEVICE char const* to_string(CFormat const& t) { + switch (t) { + case CFormat::F16: return "F16"; + case CFormat::F32: return "F32"; + case CFormat::S32: return "S32"; + } + return nullptr; +} + +enum class MaxShift : uint8_t { + NoShift = 0, + MaxShift8 = 1, + MaxShift16 = 2, + MaxShift32 = 3 +}; + +enum class BMatrixBufferId : uint8_t { + Zero = 0u, + One = 1u, + Two = 2u, + Three = 3u +}; + +enum class BMatrixBufferReuse : uint8_t { + Keep = 1u, + Reuse = 2u, + ReuseAndKeep = 3u +}; + +// using MaskAndShiftB = uint32_t[2]; +union MaskAndShiftB +{ + uint32_t uri[2]; + + struct { + // Bitfield implementation avoids the need for shifts in assignment + uint8_t start_count_ [4]; // bit [ 0:32) : 8 bits each. Specifies the start count for mask generation. + uint32_t first_span_ : 4, // bit [32:36) : 1 bit each. 0 = start where B is used. 1 = start with where B is skipped(0 value is used). + : 3, // + nzm_ : 1, // bit [39:40) : 0 = Enable the mask. 1 = Disable the mask. + skip_span_ : 8, // bit [40:48) : Count-1 (zero encoded in this field specifies use span of 1) of consecutive columns where 0 value is used. + use_span_ : 8, // bit [48:55) : Count-1 (zero encoded in this field specifies use span of 1) of consecutive columns where B matrix data is used. + shift_ : 6, // bit [56:62) : Shift value for B matrix data. + : 2; + }; +}; + +template +CUTE_HOST_DEVICE constexpr auto +make_column_zero_mask(ShapeType conv_q, int32_t cta_coord_q, int32_t num_pixels_skip_left) { + + static_assert(cute::is_same_v || cute::is_integral::value); + + cute::array column_zero_masks{}; + + static_assert(FLT_S == 3, "Filter size not supported."); + constexpr int MAX_USE_SPAN_COUNT = 256; + constexpr int MAX_SKIP_SPAN_COUNT = 256; + + // conv_q_int used for non-divmod case (add/minus/..) + // conv_q used for divmod case (div/mod/...) + int32_t conv_q_int = int(conv_q); + auto [_, cta_q] = divmod(cta_coord_q * CTA_N, conv_q); + + int step_q = CTA_M == 128 ? CTA_N / 1 + : CTA_M == 64 ? CTA_N / 2 + : CTA_M == 32 ? CTA_N / 4 + : 0; + + for (int mask_iter = 0; mask_iter < int(CTA_N / step_q); ++mask_iter) { + + for (int s_iter = 0; s_iter < FLT_S; s_iter += 1) { + + int32_t skip_span{0}, use_span{0}, nzm{1}, first_span{0}, start_count{0}, shift{0}; + + shift = s_iter; + + // Examples for CZM setting + // CASE0: (skip_span_ < 0) + // | padding |<- conv_q ->| + // |skip_span_|<- use_span ->|skip_span_| + // -skip_span 0 ^cta_q conv_q-1 + // 0 ^index + // + // CASE1: (skip_span_ > 0) + // |<- conv_q ->| + // |skip_span_|<- use_span ->|skip_span_| + // 0 ^cta_q conv_q-1 + // 0 ^index + // + // line 0 an input vector from 0 to conv_q with the padding + // line 1 shows the different spans we need to skip or load + // lines 2-3 show the different coordinates of different boundaries. + // CTQ_q is the coordinate of the present cta. + + int32_t skip_span_ = num_pixels_skip_left - shift; + int32_t index{0}; + if (skip_span_ > 0) { + auto [_, index_mod] = divmod(cta_q, conv_q); + index = index_mod; + } else if (skip_span_ < 0) { + auto [_, index_mod] = divmod((cta_q - skip_span_), conv_q); + index = index_mod; + } else { + nzm = 0; + } + skip_span = cute::max(cute::abs(skip_span_), 1); + use_span = cute::min(conv_q_int - static_cast(skip_span), MAX_USE_SPAN_COUNT); + if (use_span > 0) { + first_span = index >= skip_span ? 0 : 1; + if ((first_span == 0) && (index + CTA_N < conv_q_int + skip_span)) { + nzm = 0; + } else { + start_count = first_span == 0 ? (use_span - (conv_q_int - index)) : index; + } + } else { + skip_span = MAX_SKIP_SPAN_COUNT; + use_span = 1; + first_span = 1; + start_count = 0; + } + + column_zero_masks[s_iter].start_count_[mask_iter] = start_count; + column_zero_masks[s_iter].first_span_ |= first_span << mask_iter; + column_zero_masks[s_iter].nzm_ |= nzm; + column_zero_masks[s_iter].skip_span_ = skip_span - 1; + column_zero_masks[s_iter].use_span_ = use_span - 1; + column_zero_masks[s_iter].shift_ = shift; + + } + + cta_q += step_q; + } + + return column_zero_masks; +} + +template +CUTE_HOST_DEVICE constexpr auto to_UMMAFormat() { + if constexpr (is_same_v) { return F16F32Format::F16; } else + if constexpr (is_same_v) { return F16F32Format::BF16; } else + if constexpr (is_same_v) { return F16F32Format::TF32; } else + if constexpr (is_same_v) { return S8Format::UINT8; } else + if constexpr (is_same_v) { return S8Format::INT8; } else + if constexpr (is_same_v) {return MXF8F6F4Format::INVALID; } else + + if constexpr (is_same_v) {return MXF8F6F4Format::INVALID; } else + if constexpr (is_same_v) {return MXF8F6F4Format::INVALID; } else + if constexpr (is_same_v) {return MXF8F6F4Format::INVALID; } else + + if constexpr (is_same_v) { return MXF8F6F4Format::E4M3; } else + if constexpr (is_same_v) { return MXF8F6F4Format::E5M2; } else + if constexpr (is_same_v) {return MXF8F6F4Format::INVALID; } else + if constexpr (is_same_v) { return MXF8F6F4Format::E2M3; } else + if constexpr (is_same_v) { return MXF8F6F4Format::E3M2; } else + if constexpr (is_same_v) { return MXF8F6F4Format::E2M3; } else + if constexpr (is_same_v) { return MXF8F6F4Format::E3M2; } else + if constexpr (is_same_v) { return MXF8F6F4Format::E2M1; } else + if constexpr (is_same_v) { return MXF4Format::E2M1; } else + { static_assert(sizeof(T) == 0, "Unknown type for UMMAFormat"); } +} + +template +CUTE_HOST_DEVICE constexpr CFormat to_CFormat() { + if constexpr (is_same_v) { return CFormat::F16; } else + if constexpr (is_same_v) { return CFormat::F32; } else + if constexpr (is_same_v) { return CFormat::S32; } else + { static_assert(sizeof(T) == 0, "Unknown type for CFormat"); } +} + +union InstrDescriptor +{ + uint32_t desc_; + + struct { + // Bitfield implementation avoids the need for shifts in assignment + uint16_t sparse_id2_ : 2, // bit [ 0, 2) : Sparse meta data id2 + sparse_flag_ : 1, // bit [ 2, 3) : 0 = dense. 1 = sparse. 1 value valid only for F32F16/S8/MXF8F6F4 + saturate_ : 1, // bit [ 3, 4) : 0 = no saturate. 1 = saturate. 1 value valid only for S8 + c_format_ : 2, // bit [ 4, 6) : 0 = F16. 1 = F32, 2 = S32 + : 1, // + a_format_ : 3, // bit [ 7,10) : MXF8F6F4Format:0 = E4M3, 1 = E5M2, 3 = E2M3, 4 = E3M2, 5 = E2M1. F32F16Format: 0 = F16, 1 = BF16, 2 = TF32. S8: 0 unsigned 8 bit, 1 signed 8 bit. Boolean MMA: 0 Boolean + b_format_ : 3, // bit [10,13) : MXF8F6F4Format:0 = E4M3, 1 = E5M2, 3 = E2M3, 4 = E3M2, 5 = E2M1. F32F16Format: 0 = F16, 1 = BF16, 2 = TF32. S8: 0 unsigned 8 bit, 1 signed 8 bit. Boolean MMA: 0 Boolean + a_negate_ : 1, // bit [13,14) : 0 = no negate. 1 = negate. 1 value valid only for F32F16Format and MXF8F6F4Format + b_negate_ : 1, // bit [14,15) : 0 = no negate. 1 = negate. 1 value valid only for F32F16Format and MXF8F6F4Format + a_major_ : 1; // bit [15,16) : 0 = K-major. 1 = MN-major. Major value of 1 is only valid for E4M3, E5M2, INT8 (signed and unsigned), F16, BF16 and TF32 source formats + uint16_t b_major_ : 1, // bit [16,17) : 0 = K-major. 1 = MN-major. Major value of 1 is only valid for E4M3, E5M2, INT8 (signed and unsigned), F16, BF16 and TF32 source formats + n_dim_ : 6, // bit [17,23) : 3 LSBs not included. Valid values range from 1 (N=8) to 32 (N=256). All values are not valid for all instruction formats + : 1, // + m_dim_ : 5, // bit [24,29) : 4 LSBs not included. Valid values are: 4 (M=64), 8 (M=128), 16 (M=256) + : 1, // + max_shift_ : 2; // bit [30,32) : Maximum shift for WS instruction. Encoded as follows: 0 = no shift, 1 = maximum shift of 8, 2 = maximum shift of 16, 3 = maximum shift of 32. + }; + + // Decay to a uint32_t + CUTE_HOST_DEVICE constexpr explicit + operator uint32_t() const noexcept { return desc_; } +}; + +union InstrDescriptorBlockScaled +{ + uint32_t desc_; + + struct { + // Bitfield implementation avoids the need for shifts in assignment + uint16_t sparse_id2_ : 2, // bit [ 0, 2) : Sparse meta data id2 + sparse_flag_ : 1, // bit [ 2, 3) : 0 = dense. 1 = sparse. 1 value valid only for F32F16/S8/MXF8F6F4 + : 1, // + b_sf_id_ : 2, // bit [ 4, 6) : Matrix B Scale Factor ID + : 1, // + a_format_ : 3, // bit [ 7, 9) : MXF8F6F4Format:0 = E4M3, 1 = E5M2, 3 = E2M3, 4 = E3M2, 5 = E2M1. F32F16Format: 0 = F16, 1 = BF16, 2 = TF32. S8: 0 unsigned 8 bit, 1 signed 8 bit. BMMA: 0 Boolean + b_format_ : 3, // bit [10,12) : MXF8F6F4Format:0 = E4M3, 1 = E5M2, 3 = E2M3, 4 = E3M2, 5 = E2M1. F32F16Format: 0 = F16, 1 = BF16, 2 = TF32. S8: 0 unsigned 8 bit, 1 signed 8 bit. BMMA: 0 Boolean + a_negate_ : 1, // bit [13,14) : 0 = no negate. 1 = negate. 1 value valid only for F32F16Format and MXF8F6F4Format + b_negate_ : 1, // bit [14,15) : 0 = no negate. 1 = negate. 1 value valid only for F32F16Format and MXF8F6F4Format + a_major_ : 1; // bit [15,16) : 0 = K-major. 1 = MN-major. Major value of 1 is only valid for E4M3, E5M2, INT8 (signed and unsigned), F16, BF16 and TF32 source formats + uint16_t b_major_ : 1, // bit [16,17) : 0 = K-major. 1 = MN-major. Major value of 1 is only valid for E4M3, E5M2, INT8 (signed and unsigned), F16, BF16 and TF32 source formats + n_dim_ : 6, // bit [17,23) : 3 LSBs not included. Valid values range from 1 (N=8) to 32 (N=256). All values are not valid for all instruction formats + scale_format_ : 1, // bit [23,24) : 0=E4M3, 1=E8M0 + m_dim_ : 5, // bit [24,29) : 4 LSBs not included. Valid values are: 4 (M=64), 8 (M=128), 16 (M=256) + a_sf_id_ : 2, // bit [29,31) : Matrix A Scale Factor ID + k_size_ : 1; // bit [31,32) : MMA-K Dim. MXF8F6F4Format: 0=[dense: K32, sparse: K64]. S8Format: 0=[dense: K32, sparse: invalid]. MXF4Format: 0=[dense: K64, sparse: K128], 1=[dense: K96, sparse: invalid]. + }; + + // Decay to a uint32_t + CUTE_HOST_DEVICE constexpr + operator uint32_t() const noexcept { return desc_; } +}; + +template +CUTE_HOST_DEVICE constexpr +UMMA::InstrDescriptor +make_instr_desc() +{ + UMMA::InstrDescriptor desc_i = {}; + + desc_i.a_format_ = uint8_t(UMMA::to_UMMAFormat()); + desc_i.b_format_ = uint8_t(UMMA::to_UMMAFormat()); + desc_i.c_format_ = uint8_t(UMMA::to_CFormat()); + + desc_i.m_dim_ = (M >> 4); + desc_i.n_dim_ = (N >> 3); + + desc_i.a_major_ = uint8_t(a_major); + desc_i.b_major_ = uint8_t(b_major); + + desc_i.a_negate_ = uint8_t(a_neg); + desc_i.b_negate_ = uint8_t(b_neg); + desc_i.saturate_ = uint8_t(c_sat); + + desc_i.sparse_flag_ = is_sparse; // 1 = Sparse + desc_i.sparse_id2_ = 0; + + desc_i.max_shift_ = uint8_t(max_shift); + + return desc_i; +} + +template +CUTE_HOST_DEVICE +constexpr uint64_t +make_runtime_instr_desc(uint16_t sparse_id2 = 0u, uint32_t tmem_e = 0u) { + UMMA::InstrDescriptor desc_i = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat, is_sparse, + max_shift>(); + + if constexpr (is_sparse) { + desc_i.sparse_id2_ = sparse_id2; + } + else { + assert(sparse_id2 == 0u); + } + // In current compiler exposure, idescE is a uint64_t. It should contain: + // - Lower 32b URe: Specifies the tmem address that stores the sparse metadata. + // Only needed for Sparse MMA instructions. Otherwise, ignored. + // - Upper 32b URh: Specifies the instruction descriptor. + uint64_t idescE = (static_cast(static_cast(desc_i)) << 32); + + return idescE; +} + +template +CUTE_HOST_DEVICE +constexpr uint64_t +make_runtime_instr_desc(UMMA::InstrDescriptor desc_i, uint16_t sparse_id2 = 0u, uint32_t tmem_e = 0u) +{ + if constexpr (is_sparse) { + desc_i.sparse_id2_ = sparse_id2; + } + else { + assert(sparse_id2 == 0u); + } + // In current compiler exposure, idescE is a uint64_t. It should contain: + // - Lower 32b URe: Specifies the tmem address that stores the sparse metadata. + // Only needed for Sparse MMA instructions. Otherwise, ignored. + // - Upper 32b URh: Specifies the instruction descriptor. + uint64_t idescE = (static_cast(static_cast(desc_i)) << 32); + + return idescE; +} + +template +CUTE_HOST_DEVICE constexpr +UMMA::InstrDescriptorBlockScaled +make_instr_desc_block_scaled() +{ + UMMA::InstrDescriptorBlockScaled desc_i = {}; + + desc_i.a_format_ = uint8_t(UMMA::to_UMMAFormat()); + desc_i.b_format_ = uint8_t(UMMA::to_UMMAFormat()); + + desc_i.scale_format_ = uint8_t(UMMA::to_ScaleFormat()); + desc_i.a_sf_id_ = 0; + desc_i.b_sf_id_ = 0; + + desc_i.m_dim_ = (M >> 4); + desc_i.n_dim_ = (N >> 3); + + desc_i.a_major_ = uint8_t(a_major); + desc_i.b_major_ = uint8_t(b_major); + + desc_i.a_negate_ = uint8_t(a_neg); + desc_i.b_negate_ = uint8_t(b_neg); + desc_i.sparse_flag_ = is_sparse; // 1 = Sparse + desc_i.sparse_id2_ = 0; + + // Below would bring some warnings. +#if defined(__GNUC__) +# pragma GCC diagnostic ignored "-Wconversion" +#endif + return desc_i; +} + +template +CUTE_HOST_DEVICE +constexpr uint64_t +make_runtime_instr_desc_block_scaled(uint32_t const tmem_sfa_addr, uint32_t const tmem_sfb_addr, + uint16_t const sparse_id2 = 0u, uint32_t const tmem_e = 0u) +{ + UMMA::InstrDescriptorBlockScaled desc_i = UMMA::make_instr_desc_block_scaled< + a_type, b_type, c_type, sf_type, M, N, + a_major, b_major, + a_neg, b_neg, + is_sparse>(); + + // The first 2-bits of TMEM address includes byte address. + desc_i.a_sf_id_ = (tmem_sfa_addr & 0xC0000000) >> 30; + desc_i.b_sf_id_ = (tmem_sfb_addr & 0xC0000000) >> 30; + + if constexpr (is_sparse) { + desc_i.sparse_id2_ = sparse_id2; + } + else { + assert(sparse_id2 == 0u); + } + + // In current compiler exposure, idescE is a uint64_t. It should contain: + // - Lower 32b URe: Specifies the tmem address that stores the sparse metadata. + // Only needed for Sparse MMA instructions. Otherwise, ignored. + // - Upper 32b URh: Specifies the instruction descriptor. + uint64_t idescE = (static_cast(static_cast(desc_i)) << 32); + + return idescE; +} + +template +CUTE_HOST_DEVICE +constexpr uint64_t +make_runtime_instr_desc_block_scaled(UMMA::InstrDescriptorBlockScaled desc_i, + uint32_t const tmem_sfa_addr, uint32_t const tmem_sfb_addr, + uint16_t const sparse_id2 = 0u, uint32_t const tmem_e = 0u) +{ + // The first 2-bits of TMEM address includes byte address. + desc_i.a_sf_id_ = (tmem_sfa_addr & 0xC0000000) >> 30; + desc_i.b_sf_id_ = (tmem_sfb_addr & 0xC0000000) >> 30; + + if constexpr (is_sparse) { + desc_i.sparse_id2_ = sparse_id2; + } + else { + assert(sparse_id2 == 0u); + } + + // In current compiler exposure, idescE is a uint64_t. It should contain: + // - Lower 32b URe: Specifies the tmem address that stores the sparse metadata. + // Only needed for Sparse MMA instructions. Otherwise, ignored. + // - Upper 32b URh: Specifies the instruction descriptor. + uint64_t idescE = (static_cast(static_cast(desc_i)) << 32); + + return idescE; +} + +} // end namespace UMMA +} // namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm100_umma.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm100_umma.hpp new file mode 100644 index 0000000000000000000000000000000000000000..25b504dc0b2bfc36f494e70b78801a5b10c3c553 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm100_umma.hpp @@ -0,0 +1,1777 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include + +namespace cute +{ + +template +struct SM100_MMA_TF32_SS +{ + static_assert(M == 64 || M == 128, "SM100_MMA_TF32 M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), + "SM100_MMA_TF32 N-mode size should be a multiple of 8 between 8 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::tf32 [%0], %1, %2, %3, {%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_TF32_SS without CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_F16BF16_SS +{ + static_assert(M == 64 || M == 128, "SM100_MMA_F16BF16 M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), + "SM100_MMA_F16BF16 N-mode size should be a multiple of 8 between 8 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, {%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_SS without CUTE_ARCH_MMA_SM100A_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_TF32_TS +{ + static_assert(M == 64 || M == 128, "SM100_MMA_TF32 M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((M == 64 && (N % 8 == 0) && (8 <= N) && (N <= 256)) || + (M == 128 && (N % 16 == 0) && (16 <= N) && (N <= 256)), + "SM100_MMA_TF32 N-mode size should be a multiple of 8 between 8 and 256 for M=64,\ + or a multiple of 16 between 16 and 256 for M=128."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_TF32 A from TMEM can't be transposed"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED) + uint32_t mask[4] = {0, 0, 0, 0}; + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::tf32 [%0], [%1], %2, %3, {%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_TF32_TS without CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_F16BF16_TS +{ + static_assert(M == 64 || M == 128, "SM100_MMA_F16BF16 M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((M == 64 && (N % 8 == 0) && (8 <= N) && (N <= 256)) || + (M == 128 && (N % 16 == 0) && (16 <= N) && (N <= 256)), + "SM100_MMA_F16BF16 N-mode size should be a multiple of 8 between 8 and 256 for M=64,\ + or a multiple of 16 between 16 and 256 for M=128."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_F16BF16 A from TMEM can't be transposed"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED) + uint32_t mask[4] = {0, 0, 0, 0}; + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f16 [%0], [%1], %2, %3, {%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_TS without CUTE_ARCH_MMA_SM100A_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_F16BF16_SS_SCALED +{ + static_assert(M == 64 || M == 128, "SM100_MMA_F16BF16_SS_SCALED M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), + "SM100_MMA_F16BF16_SS_SCALED N-mode size should be a multiple of 8 between 8 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& accumulate, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED) + if (cute::elect_one_sync()) { + // ScaleC input should be a literal or compile time constant + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, {%5, %6, %7, %8}, p, %9; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(accumulate), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), "n"(ScaleC)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_SS without CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_F16BF16_TS_SCALED +{ + static_assert(M == 64 || M == 128, "SM100_MMA_F16BF16_TS_SCALED M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((M == 64 && (N % 8 == 0) && (8 <= N) && (N <= 256)) || + (M == 128 && (N % 16 == 0) && (16 <= N) && (N <= 256)), + "SM100_MMA_F16BF16_TS_SCALED N-mode size should be a multiple of 8 between 8 and 256 for M=64,\ + or a multiple of 16 between 16 and 256 for M=128."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_F16BF16_TS_SCALED A from TMEM can't be transposed"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& accumulate, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED) + if (cute::elect_one_sync()) { + // ScaleC input should be a literal or compile time constant + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f16 [%0], [%1], %2, %3, {%5, %6, %7, %8}, p, %9; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(accumulate), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), "n"(ScaleC)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_TS_SCALED without CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_TF32_SS_SPARSE +{ + static_assert(M == 64 || M == 128, "SM100_MMA_TF32_SS_SPARSE M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), + "SM100_MMA_TF32_SS_SPARSE N-mode size should be a multiple of 8 between 8 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tmem_e) + { +#if defined(CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.sp.cta_group::1.kind::tf32 [%0], %1, %2, [%9], %3, {%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_TF32_SS_SPARSE without CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_F16BF16_SS_SPARSE +{ + static_assert(M == 64 || M == 128, "SM100_MMA_F16BF16_SS_SPARSE M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), + "SM100_MMA_F16BF16_SS_SPARSE N-mode size should be a multiple of 8 between 8 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tmem_e) + { +#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.sp.cta_group::1.kind::f16 [%0], %1, %2, [%9], %3, {%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_SS_SPARSE without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_TF32_2x1SM_SS +{ + static_assert(M == 128 || M == 256, "SM100_MMA_TF32_2x1SM_SS M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_TF32_2x1SM_SS N-mode size should be a multiple of 16 between 16 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::tf32 [%0], %1, %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_TF32_2x1SM_SS without CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_F16BF16_2x1SM_SS +{ + static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16_2x1SM_SS M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_SS N-mode size should be a multiple of 16 between 16 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_SS without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_TF32_2x1SM_TS +{ + static_assert(M == 128 || M == 256, "SM100_MMA_TF32_2x1SM_TS M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_TF32_2x1SM_TS N-mode size should be a multiple of 32 between 32 and 256."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_TF32_2x1SM_TS A from TMEM can't be transposed"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::tf32 [%0], [%1], %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_TF32_2x1SM_TS without CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_F16BF16_2x1SM_TS +{ + static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16_2x1SM_TS M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_TS N-mode size should be a multiple of 32 between 32 and 256."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_F16BF16_2x1SM_TS A from TMEM can't be transposed"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f16 [%0], [%1], %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_TS without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_F16BF16_2x1SM_SS_SCALED +{ + static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16_2x1SM_SS_SCALED M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_SS_SCALED N-mode size should be a multiple of 16 between 16 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& accumulate, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED) + if (cute::elect_one_sync()) { + // ScaleC input should be a literal or compile time constant + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p, %13; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(accumulate), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7]), "n"(ScaleC)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_SS_SCALED without CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_F16BF16_2x1SM_TS_SCALED +{ + static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16_2x1SM_TS_SCALED M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_TS_SCALED N-mode size should be a multiple of 32 between 32 and 256."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_F16BF16_2x1SM_TS_SCALED A from TMEM can't be transposed"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& accumulate, + uint64_t idescE) + { +#if defined(CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED) + if (cute::elect_one_sync()) { + // ScaleC input should be a literal or compile time constant + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f16 [%0], [%1], %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p, %13; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(accumulate), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7]), "n"(ScaleC)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_TS_SCALED without CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_TF32_2x1SM_SS_SPARSE +{ + static_assert(M == 128 || M == 256, "SM100_MMA_TF32_2x1SM_SS_SPARSE M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_TF32_2x1SM_SS_SPARSE N-mode size should be a multiple of 16 between 16 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tmem_e) + { +#if defined(CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.sp.cta_group::2.kind::tf32 [%0], %1, %2, [%13], %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7]), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_TF32_2x1SM_SS_SPARSE without CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_F16BF16_2x1SM_SS_SPARSE +{ + static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16_2x1SM_SS_SPARSE M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_SS_SPARSE N-mode size should be a multiple of 16 between 16 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tmem_e) + { +#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.sp.cta_group::2.kind::f16 [%0], %1, %2, [%13], %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7]), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_SS_SPARSE without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_S8_SS +{ + static_assert(is_same_v, "SM100_MMA_S8_SS result type can only be int32_t."); + static_assert(M == 64 || M == 128, "SM100_MMA_S8_SS M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert(N == 8 || ((N % 16 == 0) && (16 <= N) && (N <= 256)), "SM100_MMA_S8_SS N-mode size should be 8 or a multiple of 16 between 16 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_S8_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::i8 [%0], %1, %2, %3, {%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_S8_SS without CUTE_ARCH_TCGEN05_S8_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_S8_TS +{ + static_assert(M == 64 || M == 128, "SM100_MMA_S8_TS M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert(N == 8 || ((N % 16 == 0) && (16 <= N) && (N <= 256)), "SM100_MMA_S8_TS N-mode size should be 8 or a multiple of 16 between 16 and 256."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_S8_TS A from TMEM can't be transposed"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_S8_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::i8 [%0], [%1], %2, %3, {%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_S8_TS without CUTE_ARCH_TCGEN05_S8_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_S8_SS_SPARSE +{ + static_assert(is_same_v, "SM100_MMA_S8_SS_SPARSE result type can only be int32_t."); + static_assert(M == 64 || M == 128, "SM100_MMA_S8_SS_SPARSE M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert(N == 8 || ((N % 16 == 0) && (16 <= N) && (N <= 256)), "SM100_MMA_S8_SS_SPARSE N-mode size should be 8 or a multiple of 16 between 16 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tmem_e) + { +#if defined(CUTE_ARCH_TCGEN05_S8_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.sp.cta_group::1.kind::i8 [%0], %1, %2, [%9], %3, {%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_S8_SS_SPARSE without CUTE_ARCH_TCGEN05_S8_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_S8_2x1SM_SS +{ + static_assert(M == 128 || M == 256, "SM100_MMA_S8_2x1SM_SS M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_S8_2x1SM_SS N-mode size should be a multiple of 32 between 32 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_S8_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::i8 [%0], %1, %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_S8_2x1SM_SS without CUTE_ARCH_TCGEN05_S8_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_S8_2x1SM_TS +{ + static_assert(M == 128 || M == 256, "SM100_MMA_S8_2x1SM_TS M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_S8_2x1SM_TS N-mode size should be a multiple of 32 between 32 and 256."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_S8_2x1SM_TS A from TMEM can't be transposed"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_S8_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::i8 [%0], [%1], %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_S8_2x1SM_TS without CUTE_ARCH_TCGEN05_S8_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_S8_2x1SM_SS_SPARSE +{ + static_assert(M == 128 || M == 256, "SM100_MMA_S8 M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_S8 N-mode size should be a multiple of 32 between 32 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tmem_e) + { +#if defined(CUTE_ARCH_TCGEN05_S8_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.sp.cta_group::2.kind::i8 [%0], %1, %2, [%13], %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7]), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_S8_2x1SM_SS_SPARSE without CUTE_ARCH_TCGEN05_S8_MMA_ENABLED"); +#endif + } +}; + +struct SM100_MMA_F8F6F4_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, {%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F8F6F4_SS without CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_MXF8F6F4_SS +{ + static_assert(M == 128, "SM100_MMA_MXF8F6F4_SS M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), "SM100_MMA_MXF8F6F4_SS N-mode size should be a multiple of 8 between 8 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + using SFARegisters = uint32_t[1]; + using SFBRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tsfa_addr, + uint32_t const& tsfb_addr) + { +#if defined(CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED) + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), "r"(tsfa_addr), "r"(tsfb_addr)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F8F6F4_SS without CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_F8F6F4_TS +{ + static_assert(M == 64 || M == 128, "SM100_MMA_F8F6F4_TS M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((M == 64 && (N % 8 == 0) && (8 <= N) && (N <= 256)) || + (M == 128 && (N % 16 == 0) && (16 <= N) && (N <= 256)), + "SM100_MMA_F8F6F4_TS N-mode size should be a multiple of 8 between 8 and 256 for M=64,\ + or a multiple of 16 between 16 and 256 for M=128."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_F8F6F4_TS A from TMEM can't be transposed"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f8f6f4 [%0], [%1], %2, %3, {%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F8F6F4_TS without CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_F8F6F4_2x1SM_TS +{ + static_assert(M == 128 || M == 256, "SM100_MMA_F8F6F4_2x1SM_TS M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F8F6F4_2x1SM_TS N-mode size should be a multiple of 32 between 32 and 256."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_F8F6F4_2x1SM_TS A from TMEM can't be transposed"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f8f6f4 [%0], [%1], %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F8F6F4_TS without CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_F8F6F4_SS_SPARSE +{ + static_assert(M == 64 || M == 128, "SM100_MMA_F8F6F4_SS_SPARSE M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((M == 64 && (N % 8 == 0) && (8 <= N) && (N <= 256)) || + (M == 128 && (N % 16 == 0) && (16 <= N) && (N <= 256)), + "SM100_MMA_F8F6F4_SS_SPARSE N-mode size should be a multiple of 8 between 8 and 256 for M=64,\ + or a multiple of 16 between 16 and 256 for M=128."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tmem_e) + { +#if defined(CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[4] = {0, 0, 0, 0}; // %5, %6, %7, %8 + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.sp.cta_group::1.kind::f8f6f4 [%0], %1, %2, [%9], %3, {%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F8F6F4_SS_SPARSE without CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_MXF8F6F4_SS_SPARSE +{ + static_assert(M == 128, "SM100_MMA_MXF8F6F4_SS_SPARSE M-mode size should be 128 for 1 CTA cluster MMA."); + static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), "SM100_MMA_MXF8F6F4_SS_SPARSE N-mode size should be a multiple of 8 between 8 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + using SFARegisters = uint32_t[1]; + using SFBRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tsfa_addr, + uint32_t const& tsfb_addr, + uint32_t const& tmem_e) + { +#if defined(CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED) + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.sp.cta_group::1.kind::mxf8f6f4.block_scale [%0], %1, %2, [%7], %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), "r"(tsfa_addr), "r"(tsfb_addr), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_MXF8F6F4_SS_SPARSE without CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +struct SM100_MMA_F8F6F4_2x1SM_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f8f6f4 [%0], %1, %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F8F6F4_2x1SM_SS without CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_MXF8F6F4_2x1SM_SS_SPARSE +{ + static_assert(M == 256, "SM100_MMA_MXF8F6F4_2x1SM_SS_SPARSE M-mode size should be 256 for 2 CTA cluster MMA."); + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_MXF8F6F4_2x1SM_SS_SPARSE N-mode size should be a multiple of 16 between 16 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tsfa_addr, + uint32_t const& tsfb_addr, + uint32_t const& tmem_e) + { +#if defined(CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED) + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.sp.cta_group::2.kind::mxf8f6f4.block_scale [%0], %1, %2, [%7], %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_MXF8F6F4_2x1SM_SS_SPARSE without CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_MXF8F6F4_2x1SM_SS +{ + static_assert(M == 256, "SM100_MMA_MXF8F6F4_2x1SM_SS M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_MXF8F6F4_2x1SM_SS N-mode size should be a multiple of 16 between 16 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tsfa_addr, + uint32_t const& tsfb_addr) + { +#if defined(CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED) + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_MXF8F6F4_2x1SM_SS without CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_F8F6F4_2x1SM_SS_SPARSE +{ + static_assert(M == 128 || M == 256, "SM100_MMA_F8F6F4_2x1SM_SS_SPARSE M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F8F6F4_2x1SM_SS_SPARSE N-mode size should be a multiple of 32 between 32 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tmem_e) + { +#if defined(CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.sp.cta_group::2.kind::f8f6f4 [%0], %1, %2, [%13], %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7]), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F8F6F4_2x1SM_SS_SPARSE without CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_MXF4_SS +{ + static_assert(M == 128, "SM100_MMA_MXF4_SS M-mode size should be 128 for 1 CTA cluster MMA."); + static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), "SM100_MMA_MXF4_SS N-mode size should be a multiple of 8 between 8 and 256."); + static_assert((VS == 16) || (VS == 32), "SM100_MMA_MXF4_SS Vector size can only be 16 or 32."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + using SFARegisters = uint32_t[1]; + using SFBRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tsfa_addr, + uint32_t const& tsfb_addr) + { + if constexpr (VS == 16) { +#if defined(CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED) + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.block16 [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::4X [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#endif + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_MXF4_SS (VS = 16) without CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED"); +#endif + } + if constexpr (VS == 32) { +#if defined(CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED) + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.cta_group::1.kind::mxf4.block_scale.block32 [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.cta_group::1.kind::mxf4.block_scale.scale_vec::2X [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#endif + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_MXF4_SS (VS = 32) without CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED"); +#endif + } + } + +}; + +template +struct SM100_MMA_MXF4NVF4_SS_SPARSE +{ + static_assert(M == 128, "SM100_MMA_MXF4NVF4_SS_SPARSE M-mode size should be 128 for 1 CTA cluster MMA."); + static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), "SM100_MMA_MXF4NVF4_SS_SPARSE N-mode size should be a multiple of 8 between 8 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + using SFARegisters = uint32_t[1]; + using SFBRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tsfa_addr, + uint32_t const& tsfb_addr, + uint32_t const& tmem_e) + { + if constexpr (VS == 32) { +#if defined(CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED) + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.sp.cta_group::1.kind::mxf4nvf4.block_scale.block16 [%0], %1, %2, [%7], %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.sp.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::4X [%0], %1, %2, [%7], %3, [%5], [%6], p; \n\t" +#endif + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_MXF4NVF4_SS_SPARSE (VS = 32) without CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED"); +#endif + } + + if constexpr (VS == 64) { +#if defined(CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED) + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.sp.cta_group::1.kind::mxf4.block_scale.block32 [%0], %1, %2, [%7], %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.sp.cta_group::1.kind::mxf4.block_scale.scale_vec::2X [%0], %1, %2, [%7], %3, [%5], [%6], p; \n\t" +#endif + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_MXF4NVF4_SS_SPARSE (VS = 64) without CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED"); +#endif + } + } +}; + +template +struct SM100_MMA_MXF4_2x1SM_SS +{ + static_assert(M == 128 || M == 256, "SM100_MMA_MXF4_2x1SM_SS M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_MXF4_2x1SM_SS N-mode size should be a multiple of 16 between 16 and 256."); + static_assert((VS == 16) || (VS == 32), "SM100_MMA_MXF4_2x1SM_SS Vector size can only be 16 or 32."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + using SFARegisters = uint32_t[1]; + using SFBRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tsfa_addr, + uint32_t const& tsfb_addr) + { + if constexpr (VS == 16) { +#if defined(CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED) + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.cta_group::2.kind::mxf4nvf4.block_scale.block16 [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.cta_group::2.kind::mxf4nvf4.block_scale.scale_vec::4X [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#endif + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_MXF4_2x1SM_SS (VS = 16) without CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED"); +#endif + } + if constexpr (VS == 32) { +#if defined(CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED) + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.cta_group::2.kind::mxf4.block_scale.block32 [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.cta_group::2.kind::mxf4.block_scale.scale_vec::2X [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#endif + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_MXF4_2x1SM_SS (VS = 32) without CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED"); +#endif + } + } +}; + +template +struct SM100_MMA_MXF4NVF4_2x1SM_SS_SPARSE +{ + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_MXF4NVF4_2x1SM_SS_SPARSE N-mode size should be a multiple of 16 between 16 and 256."); + static_assert((VS == 32) || (VS == 64), "SM100_MMA_MXF4NVF4_2x1SM_SS_SPARSE Vector size can only be 32 or 64."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + using SFARegisters = uint32_t[1]; + using SFBRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tsfa_addr, + uint32_t const& tsfb_addr, + uint32_t const& tmem_e) + { + if constexpr (VS == 32) { +#if defined(CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED) + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.sp.cta_group::2.kind::mxf4nvf4.block_scale.block16 [%0], %1, %2, [%7], %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.sp.cta_group::2.kind::mxf4nvf4.block_scale.scale_vec::4X [%0], %1, %2, [%7], %3, [%5], [%6], p; \n\t" +#endif + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_MXF4NVF4_2x1SM_SS_SPARSE (VS = 32) without CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED"); +#endif + } + + if constexpr (VS == 64) { +#if defined(CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED) + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.sp.cta_group::2.kind::mxf4.block_scale.block32 [%0], %1, %2, [%7], %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.sp.cta_group::2.kind::mxf4.block_scale.scale_vec::2X [%0], %1, %2, [%7], %3, [%5], [%6], p; \n\t" +#endif + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_MXF4NVF4_2x1SM_SS_SPARSE (VS = 64) without CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED"); +#endif + } + } +}; + +namespace SM103 { +template +struct SM103_MXF4_ULTRA_SS_VS +{ + static_assert(M == 128, "MMA M-mode size should be 128 for 1 CTA cluster MMA."); + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "MMA N-mode size should be a multiple of 16 between 16 and 256."); + static_assert(((VS == 32) & (is_same_v && is_same_v)) || (VS == 16), + "Vector size can only be 4x mode (VS=16) or 2x mode (VS=32) for MMA. 2x mode only supports float_e2m1_t for a/b types and ue8m0_t for sf type"); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + using SFARegisters = uint32_t[1]; + using SFBRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tsfa_addr, + uint32_t const& tsfb_addr) + { +#if defined(CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ULTRA_ENABLED) + if constexpr (VS == 16) { + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.block16 [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::4X [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#endif + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr)); + } + } + else if constexpr (VS == 32) { + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.block32 [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::2X [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#endif + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr)); + } + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM103_MXF4_ULTRA_SS_VS without CUTE_ARCH_MMA_SM103A_ENABLED"); +#endif + } + +}; + + +template +struct SM103_MXF4_ULTRA_2x1SM_SS_VS +{ + static_assert(M == 128 || M == 256, "MMA M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "MMA N-mode size should be a multiple of 16 between 16 and 256."); + static_assert(((VS == 32) & (is_same_v && is_same_v)) || (VS == 16), + "Vector size can only be 4x mode (VS=16) or 2x mode (VS=32) for MMA. 2x mode only supports float_e2m1_t for a/b types and ue8m0_t for sf type"); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + using SFARegisters = uint32_t[1]; + using SFBRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tsfa_addr, + uint32_t const& tsfb_addr) + { +#if defined(CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ULTRA_ENABLED) + if constexpr (VS == 16) { + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.cta_group::2.kind::mxf4nvf4.block_scale.block16 [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.cta_group::2.kind::mxf4nvf4.block_scale.scale_vec::4X [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#endif + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr)); + } + } + else if constexpr (VS == 32) { + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.cta_group::2.kind::mxf4nvf4.block_scale.block32 [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.cta_group::2.kind::mxf4nvf4.block_scale.scale_vec::2X [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#endif + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr)); + } + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM103_MXF4_ULTRA_2x1SM_SS_VS without CUTE_ARCH_MMA_SM103A_ENABLED"); +#endif + } + +}; +} // namespace SM103 + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm120.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm120.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1433a2c8d0b558ab12d2d6a0fbe028d77a63e689 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm120.hpp @@ -0,0 +1,3254 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include +#include +#include // cute::float_e4m3_t, etc +#include + +namespace cute { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SM120_16x8x32_TN +{ + static_assert(cutlass::detail::dependent_false, "No MMA matches SM120_16x8x32_TN for given data types."); +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E2M1 x E2M1 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m1.e2m1.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E2M1 x E3M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m1.e3m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E2M1 x E2M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m1.e2m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E2M1 x E4M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m1.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E2M1 x E5M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m1.e5m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// MMA 16x8x32 TN E3M2 x E2M1 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e3m2.e2m1.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E3M2 x E3M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e3m2.e3m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E3M2 x E2M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e3m2.e2m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E3M2 x E4M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e3m2.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E3M2 x E5M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e3m2.e5m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + + +// MMA 16x8x32 TN E2M3 x E2M1 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m3.e2m1.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E2M3 x E3M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m3.e3m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E2M3 x E2M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m3.e2m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E2M3 x E4M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m3.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E2M3 x E5M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m3.e5m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + + +// MMA 16x8x32 TN E4M3 x E2M1 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e2m1.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E4M3 x E3M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e3m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E4M3 x E2M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e2m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E4M3 x E4M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E4M3 x E5M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e5m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + + +// MMA 16x8x32 TN E5M2 x E2M1 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e5m2.e2m1.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E5M2 x E3M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e5m2.e3m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E5M2 x E2M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e5m2.e2m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E5M2 x E4M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e5m2.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E5M2 x E5M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e5m2.e5m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +// MMA.F16 16x8x32 TN E2M1 x E2M1 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e2m1.e2m1.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E2M1 x E3M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e2m1.e3m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E2M1 x E2M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e2m1.e2m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E2M1 x E4M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e2m1.e4m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E2M1 x E5M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e2m1.e5m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E3M2 x E2M1 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e3m2.e2m1.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E3M2 x E3M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e3m2.e3m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E3M2 x E2M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e3m2.e2m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E3M2 x E4M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e3m2.e4m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E3M2 x E5M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e3m2.e5m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E2M3 x E2M1 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e2m3.e2m1.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E2M3 x E3M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e2m3.e3m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E2M3 x E2M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e2m3.e2m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E2M3 x E4M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e2m3.e4m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E2M3 x E5M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e2m3.e5m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +// MMA.F16 16x8x32 TN E4M3 x E2M1 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e4m3.e2m1.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E4M3 x E3M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e4m3.e3m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E4M3 x E2M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e4m3.e2m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E4M3 x E4M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e4m3.e4m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E4M3 x E5M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e4m3.e5m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E5M2 x E2M1 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e5m2.e2m1.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E5M2 x E3M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e5m2.e3m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E5M2 x E2M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e5m2.e2m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E5M2 x E4M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e5m2.e4m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E5M2 x E5M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e5m2.e5m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace SM120::BLOCKSCALED { +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SM120_16x8x32_TN_VS +{ + static_assert(cutlass::detail::dependent_false, "No MMA matches SM120_16x8x32_TN_VS for given data types."); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SM120_16x8x64_TN_VS +{ + static_assert(cutlass::detail::dependent_false, "No MMA matches SM120_16x8x64_TN_VS for given data types."); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E2M1 x E2M1 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m1.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)), "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)), "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E2M1 x E3M2 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m1.e3m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E2M1 x E2M3 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m1.e2m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E2M1 x E4M3 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m1.e4m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E2M1 x E5M2 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m1.e5m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E3M2 x E2M1 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e3m2.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E3M2 x E3M2 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e3m2.e3m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E3M2 x E2M3 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e3m2.e2m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E3M2 x E4M3 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e3m2.e4m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E3M2 x E5M2 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e3m2.e5m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E2M3 x E2M1 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m3.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E2M3 x E3M2 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m3.e3m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E2M3 x E2M3 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m3.e2m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E2M3 x E4M3 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m3.e4m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E2M3 x E5M2 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m3.e5m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E4M3 x E2M1 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e4m3.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E4M3 x E3M2 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e4m3.e3m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E4M3 x E2M3 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e4m3.e2m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E4M3 x E4M3 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e4m3.e4m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E4M3 x E5M2 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e4m3.e5m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E5M2 x E2M1 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e5m2.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E5M2 x E3M2 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e5m2.e3m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E5M2 x E2M3 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e5m2.e2m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E5M2 x E4M3 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e5m2.e4m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E5M2 x E5M2 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e5m2.e5m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x64 TN E2M1 x E2M1 with SF UE8M0 +template +struct SM120_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + static constexpr int SFBits = (VS == 32) ? 16 : 32; + using RegTypeSF = uint_bit_t; + + using SFARegisters = RegTypeSF[1]; + using SFBRegisters = RegTypeSF[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + RegTypeSF const& sfa0, + RegTypeSF const& sfb0) + { + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + CUTE_STATIC_ASSERT(VS == 16 || VS == 32, "Scaling factor vector size has to be 16 or 32 for MXF4NVF4 MMA."); + +#if defined(CUTE_ARCH_MXF4NVF4_2X_UE8M0_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::mxf4nvf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)), "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x64_TN_VS without CUTE_ARCH_MXF4NVF4_2X_UE8M0_MMA_ENABLED"); +#endif + } +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +// MMA.SF 16x8x64 TN E2M1 x E2M1 with SF E4M3 +template +struct SM120_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + static constexpr int SFBits = (VS == 32) ? 16 : 32; + using RegTypeSF = uint_bit_t; + + using SFARegisters = RegTypeSF[1]; + using SFBRegisters = RegTypeSF[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + RegTypeSF const& sfa0, + RegTypeSF const& sfb0) + { + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 16 || VS == 32, "Scaling factor vector size has to be 16 or 32 for MXF4NVF4 MMA."); +#if defined(CUTE_ARCH_MXF4NVF4_4X_UE4M3_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::mxf4nvf4.block_scale.scale_vec::4X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x64_TN_VS without CUTE_ARCH_MXF4NVF4_4X_UE4M3_MMA_ENABLED"); +#endif + + } +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace SM120::BLOCKSCALED + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class ElementA, + class ElementB, + class ElementC +> +CUTE_HOST_DEVICE constexpr +auto +rr_op_selector_sm120() +{ + return SM120_16x8x32_TN{}; +} + +template < + class ElementA, + class ElementB, + class ElementC, + class ElementSF, + int SFVecSize, + bool UseF8F6F4 +> +CUTE_HOST_DEVICE constexpr +auto +rr_blockscaled_op_selector_sm120() +{ + if constexpr (UseF8F6F4) { + return SM120::BLOCKSCALED::SM120_16x8x32_TN_VS{}; + } + else{ + return SM120::BLOCKSCALED::SM120_16x8x64_TN_VS{}; + } +} + +} // namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm120_sparse.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm120_sparse.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a695003216d2c7553ab4b62ddd12594eabc576f2 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm120_sparse.hpp @@ -0,0 +1,3444 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include +#include // cute::float_e4m3_t, etc +#include +#include + +namespace cute { + +namespace SM120::SPARSE { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SM120_SPARSE_16x8x64_TN +{ + static_assert(cutlass::detail::dependent_false, "No MMA matches SM120_SPARSE_16x8x64_TN for given data types."); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E2M1 x E2M1 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e2m1.e2m1.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E2M1 x E3M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e2m1.e3m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E2M1 x E2M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e2m1.e2m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E2M1 x E4M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e2m1.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E2M1 x E5M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e2m1.e5m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E3M2 x E2M1 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e3m2.e2m1.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E3M2 x E3M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e3m2.e3m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E3M2 x E2M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e3m2.e2m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E3M2 x E4M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e3m2.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E3M2 x E5M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e3m2.e5m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E2M3 x E2M1 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e2m3.e2m1.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E2M3 x E3M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e2m3.e3m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E2M3 x E2M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e2m3.e2m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E2M3 x E4M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e2m3.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E2M3 x E5M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e2m3.e5m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E4M3 x E2M1 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e4m3.e2m1.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E4M3 x E3M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e4m3.e3m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E4M3 x E2M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e4m3.e2m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E4M3 x E4M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e4m3.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E4M3 x E5M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e4m3.e5m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E5M2 x E2M1 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e5m2.e2m1.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E5M2 x E3M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e5m2.e3m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E5M2 x E2M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e5m2.e2m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E5M2 x E4M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e5m2.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E5M2 x E5M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e5m2.e5m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// MMA 16x8x64 TN with FP16 ACC and inputs E2M1 x E2M1 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e2m1.e2m1.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E2M1 x E3M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e2m1.e3m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E2M1 x E2M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e2m1.e2m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E2M1 x E4M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e2m1.e4m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E2M1 x E5M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e2m1.e5m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E3M2 x E2M1 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e3m2.e2m1.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E3M2 x E3M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e3m2.e3m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E3M2 x E2M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e3m2.e2m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E3M2 x E4M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e3m2.e4m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E3M2 x E5M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e3m2.e5m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E2M3 x E2M1 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e2m3.e2m1.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E2M3 x E3M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e2m3.e3m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E2M3 x E2M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e2m3.e2m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E2M3 x E4M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e2m3.e4m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E2M3 x E5M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e2m3.e5m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E4M3 x E2M1 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e4m3.e2m1.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E4M3 x E3M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e4m3.e3m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E4M3 x E2M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e4m3.e2m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E4M3 x E4M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e4m3.e4m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E4M3 x E5M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e4m3.e5m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E5M2 x E2M1 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e5m2.e2m1.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E5M2 x E3M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e5m2.e3m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E5M2 x E2M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e5m2.e2m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E5M2 x E4M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e5m2.e4m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E5M2 x E5M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e5m2.e5m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +} // end namespace SM120::SPARSE + +namespace SM120::BLOCKSCALED::SPARSE { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + static_assert(cutlass::detail::dependent_false, "No MMA matches SM120_SPARSE_16x8x64_TN_VS for given data types."); +}; + +template +struct SM120_SPARSE_16x8x128_TN_VS +{ + static_assert(cutlass::detail::dependent_false, "No MMA matches SM120_SPARSE_16x8x128_TN_VS for given data types."); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E2M1 x E2M1, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E2M1 x E3M2, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e2m1.e3m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E2M1 x E2M3, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e2m1.e2m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E2M1 x E4M3, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e2m1.e4m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E2M1 x E5M2, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e2m1.e5m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E3M2 x E2M1, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e3m2.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E3M2 x E3M2, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e3m2.e3m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E3M2 x E2M3, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e3m2.e2m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E3M2 x E4M3, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e3m2.e4m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E3M2 x E5M2, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e3m2.e5m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E2M3 x E2M1, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e2m3.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E2M3 x E3M2, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e2m3.e3m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E2M3 x E2M3, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e2m3.e2m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E2M3 x E4M3, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e2m3.e4m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E2M3 x E5M2, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e2m3.e5m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E4M3 x E2M1, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e4m3.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E4M3 x E3M2, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e4m3.e3m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E4M3 x E2M3, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e4m3.e2m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E4M3 x E4M3, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e4m3.e4m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E4M3 x E5M2, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e4m3.e5m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E5M2 x E2M1, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e5m2.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E5M2 x E3M2, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e5m2.e3m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E5M2 x E2M3, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e5m2.e2m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E5M2 x E4M3, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e5m2.e4m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E5M2 x E5M2, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e5m2.e5m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x128 TN E2M1 x E2M1 with SF UE8M0 +template +struct SM120_SPARSE_16x8x128_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + static constexpr int SFBits = (VS == 64) ? 16 : 32; + using RegTypeSF = uint_bit_t; + using SFARegisters = RegTypeSF[1]; + using SFBRegisters = RegTypeSF[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, RegTypeSF const& sfa, RegTypeSF const& sfb) + { + + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64 || VS == 32, "Scaling factor vector size has to be 64 or 32 for MXF4NVF4."); +#if defined(CUTE_ARCH_MXF4NVF4_2X_UE8M0_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::mxf4nvf4.sp::ordered_metadata.block_scale.scale_vec::2X.m16n8k128.row.col.f32.e2m1.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}, {%18, %19}," + "{%20}, {%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"(uint32_t(sfa)), "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb)), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::SPARSE::SM120_SPARSE_16x8x128_TN_VS without CUTE_ARCH_MXF4NVF4_2X_UE8M0_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x128 TN E2M1 x E2M1 with SF E4M3 +template +struct SM120_SPARSE_16x8x128_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + static constexpr int SFBits = (VS == 64) ? 16 : 32; + using RegTypeSF = uint_bit_t; + using SFARegisters = RegTypeSF[1]; + using SFBRegisters = RegTypeSF[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, RegTypeSF const& sfa, RegTypeSF const& sfb) + { +#if defined(CUTE_ARCH_MXF4NVF4_4X_UE4M3_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for NVF4 with e2m1 and scale factor e4m3."); + asm volatile( + "mma.sync.aligned.kind::mxf4nvf4.sp::ordered_metadata.block_scale.scale_vec::4X.m16n8k128.row.col.f32.e2m1.e2m1.f32.ue4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}, {%18, %19}," + "{%20}, {%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"(uint32_t(sfa)), "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb)), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::SPARSE::SM120_SPARSE_16x8x128_TN_VS without CUTE_ARCH_MXF4NVF4_4X_UE4M3_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace SM120::BLOCKSCALED::SPARSE + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class ElementA, + class ElementB, + class ElementC +> +CUTE_HOST_DEVICE constexpr +auto +rr_sparse_op_selector_sm120() +{ + // Get MMA SPARSE OP + return SM120::SPARSE::SM120_SPARSE_16x8x64_TN{}; +} + +template < + class ElementA, + class ElementB, + class ElementC, + class ElementSF, + int SFVecSize, + bool UseF8F6F4 +> +CUTE_HOST_DEVICE constexpr +auto +rr_blockscaled_sparse_op_selector_sm120() +{ + if constexpr (UseF8F6F4) { + return SM120::BLOCKSCALED::SPARSE::SM120_SPARSE_16x8x64_TN_VS{}; + } + else { + return SM120::BLOCKSCALED::SPARSE::SM120_SPARSE_16x8x128_TN_VS{}; + } +} + +} // namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm61.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm61.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fce9a47f02a982a1c02c438fc582aaec3e99b157 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm61.hpp @@ -0,0 +1,87 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)) +# define CUTE_ARCH_MMA_SM61_ENABLED +#endif + +namespace cute +{ + +struct SM61_DP4A +{ + using DRegisters = int32_t[1]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = int32_t[1]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(int32_t& d, uint32_t const& a, uint32_t const& b, int32_t const& c) + { +#if defined(CUTE_ARCH_MMA_SM61_ENABLED) + asm volatile("dp4a.s32.s32 %0, %1, %2, %3;" + : "=r"(d) + : "r"(a), "r"(b), "r"(c)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM61_DP4A without CUTE_ARCH_MMA_SM61_ENABLED"); +#endif + } +}; + +struct SM61_DP2A +{ + using DRegisters = int32_t[1]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = int32_t[1]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(int32_t& d, uint32_t const& a, uint32_t const& b, int32_t const& c) + { +#if defined(CUTE_ARCH_MMA_SM61_ENABLED) + asm volatile("dp2a.s32.s32 %0, %1, %2, %3;" + : "=r"(d) + : "r"(a), "r"(b), "r"(c)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM61_DP2A without CUTE_ARCH_MMA_SM61_ENABLED"); +#endif + } +}; + +} // namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm70.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm70.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2acd421dda8f8247201f7c399cae220b3355105f --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm70.hpp @@ -0,0 +1,329 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +// Config +#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1)) +# define CUTE_ARCH_MMA_SM70_SUPPORTED +# if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)) +# define CUTE_ARCH_MMA_SM70_ENABLED +# endif +#endif + +namespace cute +{ + +// +// SM70 MMA 884 F16F16F16 +// + +struct SM70_8x8x4_F16F16F16F16_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16" + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6, %7}," + "{%8, %9, %10, %11};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F16F16F16F16_TN without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM70_8x8x4_F16F16F16F16_NT +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.col.row.f16.f16.f16.f16" + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6, %7}," + "{%8, %9, %10, %11};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F16F16F16F16_NT without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM70_8x8x4_F16F16F16F16_NN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.col.col.f16.f16.f16.f16" + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6, %7}," + "{%8, %9, %10, %11};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F16F16F16F16_NN without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM70_8x8x4_F16F16F16F16_TT +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16" + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6, %7}," + "{%8, %9, %10, %11};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F16F16F16F16_TT without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// SM70 MMA 884 F16F16F32 +// + +struct SM70_8x8x4_F32F16F16F32_TN +{ + using DRegisters = float[8]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = float[8]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + float const& c0, float const& c1, float const& c2, float const& c3, + float const& c4, float const& c5, float const& c6, float const& c7) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11}," + "{%12, %13, %14, %15, %16, %17, %18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), + "=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "f"(c4), "f"(c5), "f"(c6), "f"(c7)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F32F16F16F32_TN without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM70_8x8x4_F32F16F16F32_NT +{ + using DRegisters = float[8]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = float[8]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + float const& c0, float const& c1, float const& c2, float const& c3, + float const& c4, float const& c5, float const& c6, float const& c7) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11}," + "{%12, %13, %14, %15, %16, %17, %18, %19};" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), + "=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "f"(c4), "f"(c5), "f"(c6), "f"(c7)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F32F16F16F32_NT without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM70_8x8x4_F32F16F16F32_NN +{ + using DRegisters = float[8]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = float[8]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + float const& c0, float const& c1, float const& c2, float const& c3, + float const& c4, float const& c5, float const& c6, float const& c7) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.col.col.f32.f16.f16.f32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11}," + "{%12, %13, %14, %15, %16, %17, %18, %19};" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), + "=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "f"(c4), "f"(c5), "f"(c6), "f"(c7)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F32F16F16F32_NN without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM70_8x8x4_F32F16F16F32_TT +{ + using DRegisters = float[8]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = float[8]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + float const& c0, float const& c1, float const& c2, float const& c3, + float const& c4, float const& c5, float const& c6, float const& c7) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11}," + "{%12, %13, %14, %15, %16, %17, %18, %19};" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), + "=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "f"(c4), "f"(c5), "f"(c6), "f"(c7)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F32F16F16F32_TT without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm75.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm75.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7d34e2b45fbd3352d7b52ed0f62c8fc7134add4c --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm75.hpp @@ -0,0 +1,120 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +// Config +#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)) +# define CUTE_ARCH_MMA_SM75_SUPPORTED +# if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)) +# define CUTE_ARCH_MMA_SM75_ENABLED +# endif +#endif + +namespace cute +{ + +// +// SM75 MMA 1688 F16F16F32 +// + +struct SM75_16x8x8_F32F16F16F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = float[4]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + float const& c0, float const& c1, float const& c2, float const& c3) + { +#if defined(CUTE_ARCH_MMA_SM75_ENABLED) + asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM75_16x8x8_F32F16F16F32_TN without CUTE_ARCH_MMA_SM75_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// SM75 MMA 8816 S8S8S32 +// + +struct SM75_8x8x16_S32S8S8S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM75_ENABLED) + asm volatile("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32" + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM75_8x8x16_S32S8S8S32_TN without CUTE_ARCH_MMA_SM75_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm80.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm80.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e0dbc630760210989eb073576e7ec1693f44f240 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm80.hpp @@ -0,0 +1,2241 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) +# define CUTE_ARCH_MMA_SM80_ENABLED + +#if (__CUDA_ARCH__ <= 900) +#define CUTE_ARCH_MMA_B1_AND_SM80_ENABLED +#endif + +#if (__CUDA_ARCH__ <= 890) +#define CUTE_ARCH_MMA_B1_XOR_SM80_ENABLED +#endif + +#endif + + + +namespace cute { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x8 TN +struct SM80_16x8x8_F16F16F16F16_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3}," + "{%4}," + "{%5, %6};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x8_F16F16F16F16_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_F16F16F16F16_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_F16F16F16F16_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x8 TN +struct SM80_16x8x8_F32F16F16F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x8_F32F16F16F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_F32F16F16F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_F32F16F16F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x8 TN +struct SM80_16x8x8_F32BF16BF16F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x8_F32BF16BF16F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_F32BF16BF16F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_F32BF16BF16F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x4 TN +struct SM80_16x8x4_F32TF32TF32F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x4_F32TF32TF32F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x8 TN +struct SM80_16x8x8_F32TF32TF32F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x8_F32TF32TF32F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x4 TN +struct SM80_8x8x4_F64F64F64F64_TN +{ + using DRegisters = double[2]; + using ARegisters = double[1]; + using BRegisters = double[1]; + using CRegisters = double[2]; + + CUTE_HOST_DEVICE static void + fma(double & d0, double & d1, + double const& a0, + double const& b0, + double const& c0, double const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=d"(d0), "=d"(d1) + : "d"(a0), + "d"(b0), + "d"(c0), "d"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x4_F64F64F64F64_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +// MMA 8x8x4 TN with Planar Complex multiplication +struct SM80_8x8x4_C64C64C64C64_TN +{ + using DRegisters = complex[2]; + using ARegisters = complex[1]; + using BRegisters = complex[1]; + using CRegisters = complex[2]; + + CUTE_HOST_DEVICE static void + fma(complex & d0, complex & d1, + complex const& a0, + complex const& b0, + complex const& c0, complex const& c1) + { + // Because thrust::complex does not provide a mutable ref + double& rd0 = reinterpret_cast(d0)[0]; + double& id0 = reinterpret_cast(d0)[1]; + double& rd1 = reinterpret_cast(d1)[0]; + double& id1 = reinterpret_cast(d1)[1]; + + // d.real() = a.real() * b.real() + c.real(); + SM80_8x8x4_F64F64F64F64_TN::fma( + rd0, rd1, + a0.real(), + b0.real(), + c0.real(), c1.real()); + + // d.imag() = a.imag() * b.real() + c.imag(); + SM80_8x8x4_F64F64F64F64_TN::fma( + id0, id1, + a0.imag(), + b0.real(), + c0.imag(), c1.imag()); + + // d.real() = -a.imag() * b.imag() + d.real(); + SM80_8x8x4_F64F64F64F64_TN::fma( + rd0, rd1, + -a0.imag(), + b0.imag(), + d0.real(), d1.real()); + + // d.imag() = a.real() * b.imag() + d.imag(); + SM80_8x8x4_F64F64F64F64_TN::fma( + id0, id1, + a0.real(), + b0.imag(), + d0.imag(), d1.imag()); + } +}; + +// MMA 8x8x4 TN with Gaussian Complex multiplication: +// (a + bi)*(c + di) +// yields +// t0 += a*c +// t1 += b*d +// t2 += (a+b)*(c+d) +// then +// re = t0 - t1 +// im = t2 - t0 - t1 +struct SM80_8x8x4_GC64C64C64GC64_TN +{ + struct GaussComplex { + double t0, t1, t2; + + CUTE_HOST_DEVICE //constexpr + operator complex() const { return complex(t0 - t1, t2 - t0 - t1); } + + CUTE_HOST_DEVICE friend //constexpr + complex operator*(GaussComplex const& a, complex const& b) { return static_cast>(a) * b; } + CUTE_HOST_DEVICE friend //constexpr + complex operator*(complex const& a, GaussComplex const& b) { return b * a; } + + CUTE_HOST_DEVICE friend //constexpr + complex operator+(GaussComplex const& a, complex const& b) { return static_cast>(a) + b; } + CUTE_HOST_DEVICE friend //constexpr + complex operator+(complex const& a, GaussComplex const& b) { return b + a; } + }; + + using DRegisters = GaussComplex[2]; + using ARegisters = complex[1]; + using BRegisters = complex[1]; + using CRegisters = GaussComplex[2]; + + CUTE_HOST_DEVICE static void + fma(GaussComplex & d0, GaussComplex & d1, + complex const& a0, + complex const& b0, + GaussComplex const& c0, GaussComplex const& c1) + { + SM80_8x8x4_F64F64F64F64_TN::fma(d0.t0, d1.t0, + a0.real(), + b0.real(), + c0.t0, c1.t0); + SM80_8x8x4_F64F64F64F64_TN::fma(d0.t1, d1.t1, + a0.imag(), + b0.imag(), + c0.t1, c1.t1); + SM80_8x8x4_F64F64F64F64_TN::fma(d0.t2, d1.t2, + a0.real() + a0.imag(), + b0.real() + b0.imag(), + c0.t2, c1.t2); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32S8S8S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32S8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32S8S8S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32S8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32S8S8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32S8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32S8S8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32S8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32S8S8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32S8S8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32S8U8S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.u8.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32S8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32S8U8S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.u8.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32S8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32S8U8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32S8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32S8U8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32S8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32S8U8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32S8U8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32U8S8S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.u8.s8.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32U8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32U8S8S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.u8.s8.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32U8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32U8S8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32U8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32U8S8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32U8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U8S8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U8S8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32U8U8S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.u8.u8.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32U8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32U8U8S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.u8.u8.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32U8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32U8U8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32U8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32U8U8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32U8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U8U8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U8U8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32S4S4S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.s4.s4.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32S4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32S4S4S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.s4.s4.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32S4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32S4S4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s4.s4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32S4S4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s4.s4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32S4S4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32S4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32S4S4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32S4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32S4U4S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.s4.u4.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32S4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32S4U4S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.s4.u4.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32S4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32S4U4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s4.u4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32S4U4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s4.u4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32S4U4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32S4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32S4U4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32S4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32U4S4S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.u4.s4.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32U4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32U4S4S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.u4.s4.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32U4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U4S4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u4.s4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U4S4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u4.s4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32U4S4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32U4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32U4S4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32U4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32U4U4S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.u4.u4.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32U4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32U4U4S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.u4.u4.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32U4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U4U4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u4.u4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U4U4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u4.u4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32U4U4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32U4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32U4U4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32U4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x128 TN +struct SM80_8x8x128_S32U1U1S32_TN_ANDPOPC +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.and.popc " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x128_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x128 TN +struct SM80_16x8x128_S32U1U1S32_TN_ANDPOPC +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.and.popc " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x128_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x256 TN +struct SM80_16x8x256_S32U1U1S32_TN_ANDPOPC +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_B1_AND_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x256_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x128 TN +struct SM80_8x8x128_S32U1U1S32_TN_XORPOPC +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_B1_XOR_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.xor.popc " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x128_S32U1U1S32_TN_XORPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x128 TN +struct SM80_16x8x128_S32U1U1S32_TN_XORPOPC +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_B1_XOR_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.xor.popc " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x128_S32U1U1S32_TN_XORPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x256 TN +struct SM80_16x8x256_S32U1U1S32_TN_XORPOPC +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_B1_XOR_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x256_S32U1U1S32_TN_XORPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm89.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm89.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b6c440944afecc029cea55e29ab79c8d563c7ded --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm89.hpp @@ -0,0 +1,296 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// + +// +#pragma once + +#include +#include + +//////////////////////////////////////////////////////////////////////////////// + +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4) +# define CUTE_ARCH_MMA_F32_SM89_SUPPORTED +#endif + +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8) +# define CUTE_ARCH_MMA_F16_SM89_SUPPORTED +#endif + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) +# if defined(CUTE_ARCH_MMA_F32_SM89_SUPPORTED) +# define CUTE_ARCH_MMA_F32_SM89_ENABLED +# endif + +# if defined(CUTE_ARCH_MMA_F16_SM89_SUPPORTED) +# define CUTE_ARCH_MMA_F16_SM89_ENABLED +# endif +#endif + +//////////////////////////////////////////////////////////////////////////////// + +namespace cute { +// MMA 16x8x32 TN +struct SM89_16x8x32_F32E4M3E4M3F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const& c0, float const& c1, float const& c2, float const& c3) + { +#if defined(CUTE_ARCH_MMA_F32_SM89_ENABLED) + asm( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : + "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3) + ); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM89_16x8x32_F32E4M3E4M3F32_TN without CUTE_ARCH_MMA_F32_SM89_ENABLED"); +#endif + } +}; + +struct SM89_16x8x32_F32E4M3E5M2F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const& c0, float const& c1, float const& c2, float const& c3) + { +#if defined(CUTE_ARCH_MMA_F32_SM89_ENABLED) + asm( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e5m2.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : + "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3) + ); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM89_16x8x32_F32E4M3E5M2F32_TN without CUTE_ARCH_MMA_F32_SM89_ENABLED"); +#endif + } +}; + +struct SM89_16x8x32_F32E5M2E5M2F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const& c0, float const& c1, float const& c2, float const& c3) + { +#if defined(CUTE_ARCH_MMA_F32_SM89_ENABLED) + asm( + "mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : + "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3) + ); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM89_16x8x32_F32E5M2E5M2F32_TN without CUTE_ARCH_MMA_F32_SM89_ENABLED"); +#endif + } +}; + +struct SM89_16x8x32_F32E5M2E4M3F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const& c0, float const& c1, float const& c2, float const& c3) + { +#if defined(CUTE_ARCH_MMA_F32_SM89_ENABLED) + asm( + "mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : + "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3) + ); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM89_16x8x32_F32E5M2E4M3F32_TN without CUTE_ARCH_MMA_F32_SM89_ENABLED"); +#endif + } +}; + +struct SM89_16x8x32_F16E4M3E4M3F16_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_F16_SM89_ENABLED) + asm( + "mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e4m3.f16 " + "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n" + : "=r"(d0), "=r"(d1) + : + "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1) + ); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM89_16x8x32_F16E4M3E4M3F16_TN without CUTE_ARCH_MMA_F16_SM89_ENABLED"); +#endif + } +}; + +struct SM89_16x8x32_F16E4M3E5M2F16_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_F16_SM89_ENABLED) + asm( + "mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e5m2.f16 " + "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n" + : "=r"(d0), "=r"(d1) + : + "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1) + ); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM89_16x8x32_F16E4M3E5M2F16_TN without CUTE_ARCH_MMA_F16_SM89_ENABLED"); +#endif + } +}; + +struct SM89_16x8x32_F16E5M2E4M3F16_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_F16_SM89_ENABLED) + asm( + "mma.sync.aligned.m16n8k32.row.col.f16.e5m2.e4m3.f16 " + "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n" + : "=r"(d0), "=r"(d1) + : + "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1) + ); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM89_16x8x32_F16E5M2E4M3F16_TN without CUTE_ARCH_MMA_F16_SM89_ENABLED"); +#endif + } +}; + +struct SM89_16x8x32_F16E5M2E5M2F16_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_F16_SM89_ENABLED) + asm( + "mma.sync.aligned.m16n8k32.row.col.f16.e5m2.e5m2.f16 " + "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n" + : "=r"(d0), "=r"(d1) + : + "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1) + ); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM89_16x8x32_F16E5M2E5M2F16_TN without CUTE_ARCH_MMA_F16_SM89_ENABLED"); +#endif + } +}; + +} // namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm90.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm90.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6aa5fba2b2da6d04d3aea259d7b2047c12bd960e --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm90.hpp @@ -0,0 +1,9331 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include + +// Config +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) +# define CUTE_ARCH_MMA_SM90_ENABLED +# define CUTE_ARCH_MMA_F64_SM90_ENABLED +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cute { + +namespace SM90 { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x4 TN +struct MMA_16x8x4_F64F64F64F64_TN +{ + using DRegisters = double[4]; + using ARegisters = double[2]; + using BRegisters = double[1]; + using CRegisters = double[4]; + + CUTE_HOST_DEVICE static void + fma(double & d0, double & d1, double & d2, double & d3, + double const& a0, double const& a1, + double const& b0, + double const& c0, double const& c1, double const& c2, double const& c3) + { +#if defined(CUTE_ARCH_MMA_F64_SM90_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64" + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=d"(d0), "=d"(d1), "=d"(d2), "=d"(d3) + : "d"(a0), "d"(a1), + "d"(b0), + "d"(c0), "d"(c1), "d"(c2), "d"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_16x8x4_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x8 TN +struct MMA_16x8x8_F64F64F64F64_TN +{ + using DRegisters = double[4]; + using ARegisters = double[4]; + using BRegisters = double[2]; + using CRegisters = double[4]; + + CUTE_HOST_DEVICE static void + fma(double & d0, double & d1, double & d2, double & d3, + double const& a0, double const& a1, double const& a2, double const& a3, + double const& b0, double const& b1, + double const& c0, double const& c1, double const& c2, double const& c3) + { +#if defined(CUTE_ARCH_MMA_F64_SM90_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f64.f64.f64.f64" + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=d"(d0), "=d"(d1), "=d"(d2), "=d"(d3) + : "d"(a0), "d"(a1), "d"(a2), "d"(a3), + "d"(b0), "d"(b1), + "d"(c0), "d"(c1), "d"(c2), "d"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_16x8x8_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct MMA_16x8x16_F64F64F64F64_TN +{ + using DRegisters = double[4]; + using ARegisters = double[8]; + using BRegisters = double[4]; + using CRegisters = double[4]; + + CUTE_HOST_DEVICE static void + fma(double & d0, double & d1, double & d2, double & d3, + double const& a0, double const& a1, double const& a2, double const& a3, + double const& a4, double const& a5, double const& a6, double const& a7, + double const& b0, double const& b1, double const& b2, double const& b3, + double const& c0, double const& c1, double const& c2, double const& c3) + { +#if defined(CUTE_ARCH_MMA_F64_SM90_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f64.f64.f64.f64" + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7, %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16, %17, %18, %19};\n" + : "=d"(d0), "=d"(d1), "=d"(d2), "=d"(d3) + : "d"(a0), "d"(a1), "d"(a2), "d"(a3), + "d"(a4), "d"(a5), "d"(a6), "d"(a7), + "d"(b0), "d"(b1), "d"(b2), "d"(b3), + "d"(c0), "d"(c1), "d"(c2), "d"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_16x8x16_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x4 TN +struct MMA_16x8x4_C64C64C64C64_TN +{ + using DRegisters = complex[4]; + using ARegisters = complex[2]; + using BRegisters = complex[1]; + using CRegisters = complex[4]; + + CUTE_HOST_DEVICE static void + fma(complex & d0, complex & d1, + complex & d2, complex & d3, + complex const& a0, complex const& a1, + complex const& b0, + complex const& c0, complex const& c1, + complex const& c2, complex const& c3) + { + // Because thrust::complex does not provide a mutable ref + double& rd0 = reinterpret_cast(d0)[0]; + double& id0 = reinterpret_cast(d0)[1]; + double& rd1 = reinterpret_cast(d1)[0]; + double& id1 = reinterpret_cast(d1)[1]; + double& rd2 = reinterpret_cast(d2)[0]; + double& id2 = reinterpret_cast(d2)[1]; + double& rd3 = reinterpret_cast(d3)[0]; + double& id3 = reinterpret_cast(d3)[1]; + + // d.real() = a.real() * b.real() + c.real(); + MMA_16x8x4_F64F64F64F64_TN::fma( + rd0, rd1, rd2, rd3, + a0.real(), a1.real(), + b0.real(), + c0.real(), c1.real(), c2.real(), c3.real()); + + // d.imag() = a.imag() * b.real() + c.imag(); + MMA_16x8x4_F64F64F64F64_TN::fma( + id0, id1, id2, id3, + a0.imag(), a1.imag(), + b0.real(), + c0.imag(), c1.imag(), c2.imag(), c3.imag()); + + // d.real() = -a.imag() * b.imag() + d.real(); + MMA_16x8x4_F64F64F64F64_TN::fma( + rd0, rd1, rd2, rd3, + -a0.imag(), -a1.imag(), + b0.imag(), + d0.real(), d1.real(), d2.real(), d3.real()); + + // d.imag() = a.real() * b.imag() + d.imag(); + MMA_16x8x4_F64F64F64F64_TN::fma( + id0, id1, id2, id3, + a0.real(), a1.real(), + b0.imag(), + d0.imag(), d1.imag(), d2.imag(), d3.imag()); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x8 TN +struct MMA_16x8x8_C64C64C64C64_TN +{ + using DRegisters = complex[4]; + using ARegisters = complex[4]; + using BRegisters = complex[2]; + using CRegisters = complex[4]; + + CUTE_HOST_DEVICE static void + fma(complex & d0, complex & d1, + complex & d2, complex & d3, + complex const& a0, complex const& a1, + complex const& a2, complex const& a3, + complex const& b0, complex const& b1, + complex const& c0, complex const& c1, + complex const& c2, complex const& c3) + { + // Because thrust::complex does not provide a mutable ref + double& rd0 = reinterpret_cast(d0)[0]; + double& id0 = reinterpret_cast(d0)[1]; + double& rd1 = reinterpret_cast(d1)[0]; + double& id1 = reinterpret_cast(d1)[1]; + double& rd2 = reinterpret_cast(d2)[0]; + double& id2 = reinterpret_cast(d2)[1]; + double& rd3 = reinterpret_cast(d3)[0]; + double& id3 = reinterpret_cast(d3)[1]; + + // d.real() = a.real() * b.real() + c.real(); + MMA_16x8x8_F64F64F64F64_TN::fma( + rd0, rd1, rd2, rd3, + a0.real(), a1.real(), a2.real(), a3.real(), + b0.real(), b1.real(), + c0.real(), c1.real(), c2.real(), c3.real()); + + // d.imag() = a.imag() * b.real() + c.imag(); + MMA_16x8x8_F64F64F64F64_TN::fma( + id0, id1, id2, id3, + a0.imag(), a1.imag(), a2.imag(), a3.imag(), + b0.real(), b1.real(), + c0.imag(), c1.imag(), c2.imag(), c3.imag()); + + // d.real() = -a.imag() * b.imag() + d.real(); + MMA_16x8x8_F64F64F64F64_TN::fma( + rd0, rd1, rd2, rd3, + -a0.imag(), -a1.imag(), -a2.imag(), -a3.imag(), + b0.imag(), b1.imag(), + d0.real(), d1.real(), d2.real(), d3.real()); + + // d.imag() = a.real() * b.imag() + d.imag(); + MMA_16x8x8_F64F64F64F64_TN::fma( + id0, id1, id2, id3, + a0.real(), a1.real(), a2.real(), a3.real(), + b0.imag(), b1.imag(), + d0.imag(), d1.imag(), d2.imag(), d3.imag()); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct MMA_16x8x16_C64C64C64C64_TN +{ + using DRegisters = complex[4]; + using ARegisters = complex[8]; + using BRegisters = complex[4]; + using CRegisters = complex[4]; + + CUTE_HOST_DEVICE static void + fma(complex & d0, complex & d1, + complex & d2, complex & d3, + complex const& a0, complex const& a1, + complex const& a2, complex const& a3, + complex const& a4, complex const& a5, + complex const& a6, complex const& a7, + complex const& b0, complex const& b1, + complex const& b2, complex const& b3, + complex const& c0, complex const& c1, + complex const& c2, complex const& c3) + { + // Because thrust::complex does not provide a mutable ref + double& rd0 = reinterpret_cast(d0)[0]; + double& id0 = reinterpret_cast(d0)[1]; + double& rd1 = reinterpret_cast(d1)[0]; + double& id1 = reinterpret_cast(d1)[1]; + double& rd2 = reinterpret_cast(d2)[0]; + double& id2 = reinterpret_cast(d2)[1]; + double& rd3 = reinterpret_cast(d3)[0]; + double& id3 = reinterpret_cast(d3)[1]; + + // d.real() = a.real() * b.real() + c.real(); + MMA_16x8x16_F64F64F64F64_TN::fma( + rd0, rd1, rd2, rd3, + a0.real(), a1.real(), a2.real(), a3.real(), + a4.real(), a5.real(), a6.real(), a7.real(), + b0.real(), b1.real(), b2.real(), b3.real(), + c0.real(), c1.real(), c2.real(), c3.real()); + + // d.imag() = a.imag() * b.real() + c.imag(); + MMA_16x8x16_F64F64F64F64_TN::fma( + id0, id1, id2, id3, + a0.imag(), a1.imag(), a2.imag(), a3.imag(), + a4.imag(), a5.imag(), a6.imag(), a7.imag(), + b0.real(), b1.real(), b2.real(), b3.real(), + c0.imag(), c1.imag(), c2.imag(), c3.imag()); + + // d.real() = -a.imag() * b.imag() + d.real(); + MMA_16x8x16_F64F64F64F64_TN::fma( + rd0, rd1, rd2, rd3, + -a0.imag(), -a1.imag(), -a2.imag(), -a3.imag(), + -a4.imag(), -a5.imag(), -a6.imag(), -a7.imag(), + b0.imag(), b1.imag(), b2.imag(), b3.imag(), + d0.real(), d1.real(), d2.real(), d3.real()); + + // d.imag() = a.real() * b.imag() + d.imag(); + MMA_16x8x16_F64F64F64F64_TN::fma( + id0, id1, id2, id3, + a0.real(), a1.real(), a2.real(), a3.real(), + a4.real(), a5.real(), a6.real(), a7.real(), + b0.imag(), b1.imag(), b2.imag(), b3.imag(), + d0.imag(), d1.imag(), d2.imag(), d3.imag()); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} + +} // namespace cute + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include // cute::size +#include // cute::is_static +#include // cute::half_t, cute::float_e4m3_t, cute::tfloat32_t, etc +#include // cute::is_same_v + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cute { +namespace SM90::GMMA { + +template < + class ElementA, + class ElementB, + class ElementC, + class TileShape_MNK, + GMMA::Major MajorA = GMMA::Major::K, + GMMA::Major MajorB = GMMA::Major::K, + auto... Args // e.g. GMMA::ScaleOut::One, [GMMA::ScaleIn::One, GMMA::ScaleIn::One] + // But most commonly leave empty for defaults +> +CUTE_HOST_DEVICE constexpr +auto +ss_op_selector() +{ + static_assert(is_static::value, "TileShape_MNK must be static."); + static_assert(rank(TileShape_MNK{}) == 3, "TileShape_MNK must be rank 3."); + static_assert(size<0>(TileShape_MNK{}) % 64 == 0, "Tile_M must be a multiple of 64."); + auto Tile_N = size<1>(TileShape_MNK{}); + + // F16 accumulator + if constexpr (is_same_v) { + + // Input A: half_t ; Input B: half_t + if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x16_F16F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x16_F16F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x16_F16F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x16_F16F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x16_F16F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x16_F16F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x16_F16F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x16_F16F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x16_F16F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x16_F16F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x16_F16F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x16_F16F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x16_F16F16F16_SS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x16_F16F16F16_SS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F16E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F16E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F16E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F16E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F16E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F16E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F16E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F16E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F16E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F16E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F16E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F16E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F16E4M3E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F16E4M3E4M3_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F16E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F16E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F16E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F16E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F16E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F16E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F16E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F16E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F16E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F16E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F16E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F16E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F16E4M3E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F16E4M3E5M2_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F16E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F16E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F16E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F16E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F16E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F16E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F16E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F16E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F16E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F16E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F16E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F16E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F16E5M2E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F16E5M2E4M3_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F16E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F16E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F16E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F16E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F16E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F16E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F16E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F16E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F16E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F16E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F16E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F16E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F16E5M2E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F16E5M2E5M2_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // F32 accumulator + else if constexpr (is_same_v) { + + // Input A: half_t ; Input B: half_t + if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x16_F32F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x16_F32F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x16_F32F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x16_F32F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x16_F32F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x16_F32F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x16_F32F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x16_F32F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x16_F32F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x16_F32F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x16_F32F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x16_F32F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x16_F32F16F16_SS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x16_F32F16F16_SS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: bfloat16_t ; Input B: bfloat16_t + else if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x16_F32BF16BF16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x16_F32BF16BF16_SS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x16_F32BF16BF16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x16_F32BF16BF16_SS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x16_F32BF16BF16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x16_F32BF16BF16_SS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x16_F32BF16BF16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x16_F32BF16BF16_SS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x16_F32BF16BF16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x16_F32BF16BF16_SS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x16_F32BF16BF16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x16_F32BF16BF16_SS{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x16_F32BF16BF16_SS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x16_F32BF16BF16_SS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: tfloat32_t ; Input B: tfloat32_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 8 == 0, "Tile_K must be a multiple of 8."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x8_F32TF32TF32_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x8_F32TF32TF32_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x8_F32TF32TF32_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x8_F32TF32TF32_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x8_F32TF32TF32_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x8_F32TF32TF32_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x8_F32TF32TF32_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x8_F32TF32TF32_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x8_F32TF32TF32_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x8_F32TF32TF32_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x8_F32TF32TF32_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x8_F32TF32TF32_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x8_F32TF32TF32_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x8_F32TF32TF32_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F32E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F32E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F32E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F32E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F32E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F32E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F32E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F32E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F32E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F32E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F32E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F32E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F32E4M3E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F32E4M3E4M3_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F32E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F32E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F32E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F32E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F32E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F32E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F32E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F32E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F32E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F32E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F32E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F32E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F32E4M3E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F32E4M3E5M2_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F32E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F32E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F32E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F32E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F32E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F32E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F32E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F32E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F32E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F32E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F32E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F32E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F32E5M2E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F32E5M2E4M3_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F32E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F32E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F32E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F32E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F32E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F32E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F32E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F32E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F32E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F32E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F32E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F32E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F32E5M2E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F32E5M2E5M2_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // S32 accumulator + else if constexpr (is_same_v) { + + // Input A: int8_t ; Input B: int8_t + if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_S32S8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_S32S8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_S32S8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_S32S8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_S32S8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_S32S8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_S32S8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_S32S8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_S32S8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_S32S8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_S32S8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_S32S8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_S32S8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_S32S8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_S32S8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_S32S8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_S32S8S8_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_S32S8S8_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: int8_t ; Input B: uint8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_S32S8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_S32S8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_S32S8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_S32S8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_S32S8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_S32S8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_S32S8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_S32S8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_S32S8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_S32S8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_S32S8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_S32S8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_S32S8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_S32S8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_S32S8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_S32S8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_S32S8U8_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_S32S8U8_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: uint8_t ; Input B: int8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_S32U8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_S32U8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_S32U8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_S32U8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_S32U8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_S32U8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_S32U8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_S32U8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_S32U8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_S32U8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_S32U8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_S32U8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_S32U8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_S32U8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_S32U8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_S32U8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_S32U8S8_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_S32U8S8_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: uint8_t ; Input B: uint8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_S32U8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_S32U8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_S32U8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_S32U8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_S32U8U8_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_S32U8U8_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // Unknown accumulator type + else { + static_assert(sizeof(ElementC) == 0, "Unknown ElementC accumulator type."); + } +} + +template < + class ElementA, + class ElementB, + class ElementC, + class TileShape_MNK, + GMMA::Major MajorA = GMMA::Major::K, + GMMA::Major MajorB = GMMA::Major::K, + auto... Args // e.g. GMMA::ScaleOut::One, [GMMA::ScaleIn::One, GMMA::ScaleIn::One] + // But most commonly leave empty for defaults +> +CUTE_HOST_DEVICE constexpr +auto +ss_op_selector_sparse() +{ + static_assert(is_static::value, "TileShape_MNK must be static."); + static_assert(rank(TileShape_MNK{}) == 3, "TileShape_MNK must be rank 3."); + static_assert(size<0>(TileShape_MNK{}) % 64 == 0, "Tile_M must be a multiple of 64."); + auto Tile_N = size<1>(TileShape_MNK{}); + + // F16 accumulator + if constexpr (is_same_v) { + + // Input A: half_t ; Input B: half_t + if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x32_F16F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x32_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x32_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x32_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x32_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x32_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x32_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x32_F16F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x32_F16F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x32_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x32_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x32_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x32_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x32_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x32_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x32_F16F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x32_F16F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x32_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x32_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x32_F16F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x32_F16F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x32_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x32_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x32_F16F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x32_F16F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x32_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x32_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x32_F16F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x32_F16F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x32_F16F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x32_F16F16F16_SS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x32_F16F16F16_SS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F16E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F16E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F16E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F16E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F16E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F16E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E4M3_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F16E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F16E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F16E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F16E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F16E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F16E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E5M2_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F16E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F16E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F16E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F16E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F16E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F16E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E4M3_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F16E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F16E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F16E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F16E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F16E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F16E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E5M2_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // F32 accumulator + else if constexpr (is_same_v) { + + // Input A: half_t ; Input B: half_t + if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x32_F32F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x32_F32F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x32_F32F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x32_F32F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x32_F32F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x32_F32F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x32_F32F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x32_F32F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x32_F32F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x32_F32F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x32_F32F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x32_F32F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x32_F32F16F16_SS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x32_F32F16F16_SS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: bfloat16_t ; Input B: bfloat16_t + else if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x32_F32BF16BF16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x32_F32BF16BF16_SS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x32_F32BF16BF16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x32_F32BF16BF16_SS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x32_F32BF16BF16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x32_F32BF16BF16_SS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x32_F32BF16BF16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x32_F32BF16BF16_SS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x32_F32BF16BF16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x32_F32BF16BF16_SS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x32_F32BF16BF16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x32_F32BF16BF16_SS{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x32_F32BF16BF16_SS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x32_F32BF16BF16_SS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: tfloat32_t ; Input B: tfloat32_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x16_F32TF32TF32_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x16_F32TF32TF32_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x16_F32TF32TF32_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x16_F32TF32TF32_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x16_F32TF32TF32_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x16_F32TF32TF32_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x16_F32TF32TF32_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x16_F32TF32TF32_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x16_F32TF32TF32_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x16_F32TF32TF32_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x16_F32TF32TF32_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x16_F32TF32TF32_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x16_F32TF32TF32_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x16_F32TF32TF32_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F32E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F32E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F32E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F32E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F32E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F32E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E4M3_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F32E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F32E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F32E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F32E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F32E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F32E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E5M2_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F32E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F32E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F32E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F32E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F32E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F32E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E4M3_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F32E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F32E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F32E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F32E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F32E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F32E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E5M2_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // S32 accumulator + else if constexpr (is_same_v) { + + // Input A: int8_t ; Input B: int8_t + if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8S8_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: int8_t ; Input B: uint8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: uint8_t ; Input B: int8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8S8_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: uint8_t ; Input B: uint8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // Unknown accumulator type + else { + static_assert(sizeof(ElementC) == 0, "Unknown ElementC accumulator type."); + } +} + +template < + class ElementA, + class ElementB, + class ElementC, + class TileShape_MNK, + GMMA::Major MajorA = GMMA::Major::K, + GMMA::Major MajorB = GMMA::Major::K, + auto... Args // e.g. GMMA::ScaleOut::One, [GMMA::ScaleIn::One, GMMA::ScaleIn::One] + // But most commonly leave empty for defaults +> +CUTE_HOST_DEVICE constexpr +auto +rs_op_selector() +{ + static_assert(is_static::value, "TileShape_MNK must be static."); + static_assert(rank(TileShape_MNK{}) == 3, "TileShape_MNK must be rank 3."); + static_assert(size<0>(TileShape_MNK{}) % 64 == 0, "Tile_M must be a multiple of 64."); + static_assert(MajorA == GMMA::Major::K, "Register source A operand GMMAs must have K-major A layout."); + auto Tile_N = size<1>(TileShape_MNK{}); + + // F16 accumulator + if constexpr (is_same_v) { + + // Input A: half_t ; Input B: half_t + if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x16_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x16_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x16_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x16_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x16_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x16_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x16_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x16_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x16_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x16_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x16_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x16_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x16_F16F16F16_RS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x16_F16F16F16_RS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F16E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F16E4M3E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F16E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F16E4M3E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F16E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F16E5M2E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F16E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F16E5M2E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // F32 accumulator + else if constexpr (is_same_v) { + + // Input A: half_t ; Input B: half_t + if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x16_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x16_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x16_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x16_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x16_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x16_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x16_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x16_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x16_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x16_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x16_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x16_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x16_F32F16F16_RS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x16_F32F16F16_RS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: bfloat16_t ; Input B: bfloat16_t + else if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x16_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x16_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x16_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x16_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x16_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x16_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x16_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x16_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x16_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x16_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x16_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x16_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x16_F32BF16BF16_RS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x16_F32BF16BF16_RS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: tfloat32_t ; Input B: tfloat32_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 8 == 0, "Tile_K must be a multiple of 8."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x8_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x8_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x8_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x8_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x8_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x8_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x8_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x8_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x8_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x8_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x8_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x8_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x8_F32TF32TF32_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x8_F32TF32TF32_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F32E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F32E4M3E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F32E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F32E4M3E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F32E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F32E5M2E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F32E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F32E5M2E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // S32 accumulator + else if constexpr (is_same_v) { + + // Input A: int8_t ; Input B: int8_t + if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_S32S8S8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_S32S8S8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: int8_t ; Input B: uint8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_S32S8U8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_S32S8U8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: uint8_t ; Input B: int8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_S32U8S8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_S32U8S8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: uint8_t ; Input B: uint8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_S32U8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_S32U8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_S32U8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_S32U8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_S32U8U8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_S32U8U8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // Unknown accumulator type + else { + static_assert(sizeof(ElementC) == 0, "Unknown ElementC accumulator type."); + } +} + +template < + class ElementA, + class ElementB, + class ElementC, + class TileShape_MNK, + GMMA::Major MajorA = GMMA::Major::K, + GMMA::Major MajorB = GMMA::Major::K, + auto... Args // e.g. GMMA::ScaleOut::One, [GMMA::ScaleIn::One, GMMA::ScaleIn::One] + // But most commonly leave empty for defaults +> +CUTE_HOST_DEVICE constexpr +auto +rs_op_selector_sparse() +{ + static_assert(is_static::value, "TileShape_MNK must be static."); + static_assert(rank(TileShape_MNK{}) == 3, "TileShape_MNK must be rank 3."); + static_assert(size<0>(TileShape_MNK{}) % 64 == 0, "Tile_M must be a multiple of 64."); + static_assert(MajorA == GMMA::Major::K, "Register source A operand GMMAs must have K-major A layout."); + auto Tile_N = size<1>(TileShape_MNK{}); + + // F16 accumulator + if constexpr (is_same_v) { + + // Input A: half_t ; Input B: half_t + if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x32_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x32_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x32_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x32_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x32_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x32_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x32_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x32_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x32_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x32_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x32_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x32_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x32_F16F16F16_RS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x32_F16F16F16_RS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // F32 accumulator + else if constexpr (is_same_v) { + + // Input A: half_t ; Input B: half_t + if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x32_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x32_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x32_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x32_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x32_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x32_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x32_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x32_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x32_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x32_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x32_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x32_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x32_F32F16F16_RS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x32_F32F16F16_RS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: bfloat16_t ; Input B: bfloat16_t + else if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x32_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x32_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x32_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x32_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x32_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x32_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x32_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x32_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x32_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x32_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x32_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x32_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x32_F32BF16BF16_RS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x32_F32BF16BF16_RS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: tfloat32_t ; Input B: tfloat32_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x16_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x16_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x16_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x16_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x16_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x16_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x16_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x16_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x16_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x16_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x16_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x16_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x16_F32TF32TF32_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x16_F32TF32TF32_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // S32 accumulator + else if constexpr (is_same_v) { + + // Input A: int8_t ; Input B: int8_t + if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8S8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: int8_t ; Input B: uint8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: uint8_t ; Input B: int8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8S8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: uint8_t ; Input B: uint8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // Unknown accumulator type + else { + static_assert(sizeof(ElementC) == 0, "Unknown ElementC accumulator type."); + } +} + +} // end namespace SM90::GMMA +} // end namespace cute + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm90_desc.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm90_desc.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5f65746b34f44edc95c6f3af4b284e7fa8a0cd18 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm90_desc.hpp @@ -0,0 +1,151 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include + +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cute { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// GMMA Descriptor and utilities + +// GMMA enums and utilities +namespace SM90::GMMA { + +enum class LayoutType : uint8_t { + INTERLEAVE = 0, + B128 = 1, + B64 = 2, + B32 = 3, +}; + +CUTE_HOST_DEVICE char const* to_string(LayoutType const& t) { + switch (t) { + case LayoutType::INTERLEAVE: return "INTERLEAVE"; + case LayoutType::B128: return "B128"; + case LayoutType::B64: return "B64"; + case LayoutType::B32: return "B32"; + } + return nullptr; +} + +#if !defined(__CUDACC_RTC__) +// Output operator for all enums in this namespace +CUTE_HOST std::ostream& operator<<(std::ostream& os, LayoutType const& t) { + char const* s = to_string(t); + if (s) { + std::operator<<(os, s); // Explicit call to avoid ambiguity + } else { + os.setstate(std::ios_base::failbit); + } + return os; +} +#endif // !defined(__CUDACC_RTC__) + +} // end namespace SM90::GMMA + +union GmmaDescriptor +{ + CUTE_HOST_DEVICE constexpr + GmmaDescriptor() noexcept : desc_(0) {} + CUTE_HOST_DEVICE constexpr + GmmaDescriptor(uint64_t desc) noexcept : desc_(desc) {} + CUTE_HOST_DEVICE constexpr + GmmaDescriptor(GmmaDescriptor const& t) noexcept : desc_(t.desc_) {} + CUTE_HOST_DEVICE constexpr + GmmaDescriptor(GmmaDescriptor && t) noexcept : desc_(t.desc_) {} + + CUTE_HOST_DEVICE constexpr + GmmaDescriptor& operator=(GmmaDescriptor const& t) noexcept { + desc_ = t.desc_; + return *this; + } + + CUTE_HOST_DEVICE constexpr + GmmaDescriptor& operator=(GmmaDescriptor && t) noexcept { + desc_ = t.desc_; + return *this; + } + + uint64_t desc_; + uint32_t reg32_[2]; + uint16_t reg16_[4]; + + // Bitfield implementation avoids the need for shifts in assignment + struct { + // start_address, bit [0,14), 4LSB not included + uint16_t start_address_ : 14, : 2; // 14 bits [0,14), 2 bits unused + // leading dimension byte offset, bit [16,30), 4LSB not included + // For N: This is the stride from the first col to the second col of the 8x2 brick in INTERLEAVED + // Unused for all SWIZZLE_* layouts (and assumed to be 1) + // For T: This is the stride from the first 8 rows to the next 8 rows. + uint16_t leading_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused + // stride dimension byte offset, bit [32,46), 4LSB not included + // For N: This is the stride from the first 8 rows to the next 8 rows. + // For T: This is the stride fro mthe first 8 cols to the next 8 cols. + uint16_t stride_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused + // base_offset, bit [49,52) + // Valid only for SWIZZLE_128B and SWIZZLE_64B + uint8_t : 1, base_offset_ : 3, : 4; // 1 bit unused, 3 bits [1,4), 4 bits unused + // layout type, bit [62,64) + // SWIZZLE_NONE = 0, SWIZZLE_32B = 3, SWIZZLE_64B = 2, SWIZZLE_128B = 1 + uint8_t : 6, layout_type_ : 2; // 6 bits unused, 2 bits [6,8) + } bitfield; + + // Decay to a uint64_t + CUTE_HOST_DEVICE constexpr + operator uint64_t() const noexcept { return desc_; } +}; + +// Printer +CUTE_HOST_DEVICE void +print(GmmaDescriptor const& t) +{ +#if !defined(__CUDACC_RTC__) + printf("GmmaDescriptor: 0x%016llx\n", static_cast(t.desc_)); + printf(" start_addr : 0x%04x\n", t.bitfield.start_address_); + printf(" leading_off: 0x%04x (%d)\n", t.bitfield.leading_byte_offset_, t.bitfield.leading_byte_offset_); + printf(" stride_off : 0x%04x (%d)\n", t.bitfield.stride_byte_offset_, t.bitfield.stride_byte_offset_); + printf(" base_offset: 0x%01x\n", t.bitfield.base_offset_); + printf(" layout_type: 0x%01x (%s)\n", t.bitfield.layout_type_, to_string(static_cast(t.bitfield.layout_type_))); +#endif // !defined(__CUDACC_RTC__) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cute + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm90_gmma.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm90_gmma.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c0ce4a2e7198ac59c9c6a013a2b75a352b50a5b6 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm90_gmma.hpp @@ -0,0 +1,20974 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE + +#include "cutlass/arch/synclog.hpp" + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) +# define CUTE_ARCH_MMA_SM90A_ENABLED +#endif + +namespace cute { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Warpgroup sync primitives + +CUTE_HOST_DEVICE +void +warpgroup_arrive() +{ +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_warpgroup_arrive(__LINE__); + asm volatile ("wgmma.fence.sync.aligned;\n" ::: "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use wgmma.fence without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif +} + +template +CUTE_HOST_DEVICE +void +warpgroup_wait() +{ + static_assert(N >= 0 && N <= 7, "WGMMA wait: N must be in range [0, 7]"); +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_warpgroup_wait(__LINE__, N); + asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use wgmma.wait_group without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif +} + +// Marks the commit point for one or more sized batch of warpgroup MMAs. +CUTE_HOST_DEVICE +void +warpgroup_commit_batch() +{ +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_warpgroup_commit_batch(__LINE__); + asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use wgmma.commit_group without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif +} + +CUTE_HOST_DEVICE +void +warpgroup_fence_operand(uint32_t& reg) { + // MSVC emits a build error for 'asm volatile' + // even if it only occurs in a __device__ function. + // This prevents the error. +#if defined(__CUDA_ARCH__) + asm volatile("" : "+r"(reg) :: "memory"); +#endif +} + +CUTE_HOST_DEVICE +void +warpgroup_fence_operand(float& reg) { +#if defined(__CUDA_ARCH__) + asm volatile("" : "+f"(reg) :: "memory"); +#endif +} + +namespace SM90::GMMA { + +enum class Major { + K = 0, + MN = 1 +}; + +enum class ScaleOut { + Zero = 0, + One = 1 +}; + +enum class ScaleIn { + Neg = -1, + One = 1 +}; + +enum class SparseSel { + Zero = 0, + One = 1 +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// GMMA PTX definitions: C = (scaleA * A) * (scaleB * B) + (scaleD * C) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %4, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 " + "{%0, %1}," + " %2," + " %3," + " p, %5, %6, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %7, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " p, %8, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k16.f16.f16.f16 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k16.f16.f16.f16 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12, %13, %14;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15, %16;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68, %69, %70;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71, %72;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8, %9, %10;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15, %16;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23, %24;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39, %40;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55, %56;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71, %72;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p, %102, %103, %104;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p, %131, %132, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p, %134, %135, %136;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8, %9, %10;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15, %16;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23, %24;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39, %40;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55, %56;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71, %72;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p, %102, %103, %104;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p, %131, %132, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p, %134, %135, %136;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p, %102, %103;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p, %131, %132;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p, %134, %135;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=S8*S8 +struct MMA_64x8x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=S8*S8 +struct MMA_64x8x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=S8*S8 +struct MMA_64x16x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=S8*S8 +struct MMA_64x16x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=S8*S8 +struct MMA_64x32x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=S8*S8 +struct MMA_64x32x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=S8*S8 +struct MMA_64x64x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=S8*S8 +struct MMA_64x64x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=S8*S8 +struct MMA_64x96x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=S8*S8 +struct MMA_64x96x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=S8*S8 +struct MMA_64x128x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=S8*S8 +struct MMA_64x128x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=S8*S8 +struct MMA_64x192x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=S8*S8 +struct MMA_64x192x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=S8*S8 +struct MMA_64x256x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=S8*S8 +struct MMA_64x256x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=S8*S8 +struct MMA_64x8x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=S8*S8 +struct MMA_64x8x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=S8*S8 +struct MMA_64x16x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=S8*S8 +struct MMA_64x16x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=S8*S8 +struct MMA_64x32x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=S8*S8 +struct MMA_64x32x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=S8*S8 +struct MMA_64x64x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=S8*S8 +struct MMA_64x64x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=S8*S8 +struct MMA_64x96x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=S8*S8 +struct MMA_64x96x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=S8*S8 +struct MMA_64x128x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=S8*S8 +struct MMA_64x128x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=S8*S8 +struct MMA_64x192x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=S8*S8 +struct MMA_64x192x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=S8*S8 +struct MMA_64x256x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=S8*S8 +struct MMA_64x256x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=S8*U8 +struct MMA_64x8x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=S8*U8 +struct MMA_64x8x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=S8*U8 +struct MMA_64x16x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=S8*U8 +struct MMA_64x16x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=S8*U8 +struct MMA_64x32x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=S8*U8 +struct MMA_64x32x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=S8*U8 +struct MMA_64x64x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=S8*U8 +struct MMA_64x64x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=S8*U8 +struct MMA_64x96x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=S8*U8 +struct MMA_64x96x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=S8*U8 +struct MMA_64x128x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=S8*U8 +struct MMA_64x128x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=S8*U8 +struct MMA_64x192x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=S8*U8 +struct MMA_64x192x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=S8*U8 +struct MMA_64x256x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=S8*U8 +struct MMA_64x256x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=S8*U8 +struct MMA_64x8x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=S8*U8 +struct MMA_64x8x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=S8*U8 +struct MMA_64x16x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=S8*U8 +struct MMA_64x16x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=S8*U8 +struct MMA_64x32x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=S8*U8 +struct MMA_64x32x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=S8*U8 +struct MMA_64x64x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=S8*U8 +struct MMA_64x64x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=S8*U8 +struct MMA_64x96x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=S8*U8 +struct MMA_64x96x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=S8*U8 +struct MMA_64x128x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=S8*U8 +struct MMA_64x128x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=S8*U8 +struct MMA_64x192x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=S8*U8 +struct MMA_64x192x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=S8*U8 +struct MMA_64x256x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=S8*U8 +struct MMA_64x256x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=U8*S8 +struct MMA_64x8x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=U8*S8 +struct MMA_64x8x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=U8*S8 +struct MMA_64x16x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=U8*S8 +struct MMA_64x16x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=U8*S8 +struct MMA_64x32x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=U8*S8 +struct MMA_64x32x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=U8*S8 +struct MMA_64x64x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=U8*S8 +struct MMA_64x64x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=U8*S8 +struct MMA_64x96x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=U8*S8 +struct MMA_64x96x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=U8*S8 +struct MMA_64x128x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=U8*S8 +struct MMA_64x128x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=U8*S8 +struct MMA_64x192x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=U8*S8 +struct MMA_64x192x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=U8*S8 +struct MMA_64x256x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=U8*S8 +struct MMA_64x256x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=U8*S8 +struct MMA_64x8x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=U8*S8 +struct MMA_64x8x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=U8*S8 +struct MMA_64x16x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=U8*S8 +struct MMA_64x16x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=U8*S8 +struct MMA_64x32x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=U8*S8 +struct MMA_64x32x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=U8*S8 +struct MMA_64x64x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=U8*S8 +struct MMA_64x64x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=U8*S8 +struct MMA_64x96x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=U8*S8 +struct MMA_64x96x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=U8*S8 +struct MMA_64x128x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=U8*S8 +struct MMA_64x128x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=U8*S8 +struct MMA_64x192x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=U8*S8 +struct MMA_64x192x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=U8*S8 +struct MMA_64x256x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=U8*S8 +struct MMA_64x256x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=U8*U8 +struct MMA_64x8x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=U8*U8 +struct MMA_64x8x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=U8*U8 +struct MMA_64x16x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=U8*U8 +struct MMA_64x16x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=U8*U8 +struct MMA_64x32x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=U8*U8 +struct MMA_64x32x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=U8*U8 +struct MMA_64x64x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=U8*U8 +struct MMA_64x64x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=U8*U8 +struct MMA_64x96x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=U8*U8 +struct MMA_64x96x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=U8*U8 +struct MMA_64x128x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=U8*U8 +struct MMA_64x128x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=U8*U8 +struct MMA_64x192x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=U8*U8 +struct MMA_64x192x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=U8*U8 +struct MMA_64x256x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=U8*U8 +struct MMA_64x256x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=U8*U8 +struct MMA_64x8x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=U8*U8 +struct MMA_64x8x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=U8*U8 +struct MMA_64x16x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=U8*U8 +struct MMA_64x16x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=U8*U8 +struct MMA_64x32x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=U8*U8 +struct MMA_64x32x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=U8*U8 +struct MMA_64x64x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=U8*U8 +struct MMA_64x64x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=U8*U8 +struct MMA_64x96x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=U8*U8 +struct MMA_64x96x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=U8*U8 +struct MMA_64x128x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=U8*U8 +struct MMA_64x128x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=U8*U8 +struct MMA_64x192x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=U8*U8 +struct MMA_64x192x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=U8*U8 +struct MMA_64x256x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=U8*U8 +struct MMA_64x256x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %4, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e4m3 " + "{%0, %1}," + " %2," + " %3," + " p, %5, %6;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %7, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e4m3 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " p, %8, %9;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p, %102, %103;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p, %131, %132;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p, %134, %135;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %4, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e5m2 " + "{%0, %1}," + " %2," + " %3," + " p, %5, %6;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %7, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e5m2 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " p, %8, %9;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p, %102, %103;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p, %131, %132;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p, %134, %135;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %4, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e4m3 " + "{%0, %1}," + " %2," + " %3," + " p, %5, %6;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %7, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e4m3 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " p, %8, %9;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p, %102, %103;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p, %131, %132;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p, %134, %135;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %4, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e5m2 " + "{%0, %1}," + " %2," + " %3," + " p, %5, %6;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %7, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e5m2 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " p, %8, %9;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p, %102, %103;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p, %131, %132;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p, %134, %135;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace SM90::GMMA + +} // namespace cute + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +#include "mma_sm90_gmma_ext.hpp" +#endif diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm90_gmma_ext.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm90_gmma_ext.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f697b94912f843b55e86c86849a20eb1625cb097 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm90_gmma_ext.hpp @@ -0,0 +1,56445 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include // CUTE_HOST_DEVICE + +#include "cutlass/arch/synclog.hpp" + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) +# define CUTE_ARCH_MMA_SM90A_ENABLED +#endif + +namespace cute { + +namespace SM90::GMMA { + +// GMMA 64x24x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5}," + " %6," + " %7," + " p, %9, %10, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " p, %12, %13, %14;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " p, %13, %14, %15, %16;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " p, %16, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " p, %17, %18, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + "{%14, %15, %16, %17}," + " %18," + " p, %20, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + " %18," + " %19," + " p, %21, %22, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " p, %24, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + " %22," + " %23," + " p, %25, %26, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " p, %28, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + " %26," + " %27," + " p, %29, %30, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + "{%26, %27, %28, %29}," + " %30," + " p, %32, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " p, %33, %34, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + "{%30, %31, %32, %33}," + " %34," + " p, %36, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " p, %37, %38, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," + " p, %40, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + " %38," + " %39," + " p, %41, %42, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + "{%38, %39, %40, %41}," + " %42," + " p, %44, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + " %42," + " %43," + " p, %45, %46, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + "{%42, %43, %44, %45}," + " %46," + " p, %48, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + " %46," + " %47," + " p, %49, %50, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " p, %52, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " p, %53, %54, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " p, %56, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + " %54," + " %55," + " p, %57, %58, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + "{%54, %55, %56, %57}," + " %58," + " p, %60, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + " %58," + " %59," + " p, %61, %62, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " p, %64, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + " %62," + " %63," + " p, %65, %66, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " p, %68, %69, %70;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16, %17, %18;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24, %25, %26;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27, %28;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31, %32;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40, %41, %42;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43, %44;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48, %49, %50;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56, %57, %58;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59, %60;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64, %65, %66;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %70, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " p, %71, %72, %73, %74;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %73, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " p, %74, %75, %76;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p, %75, %76, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p, %78, %79, %80;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %78, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " p, %79, %80, %81, %82;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %81, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " p, %82, %83, %84;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p, %83, %84, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p, %86, %87, %88;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %86, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " p, %87, %88, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %89, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " p, %90, %91, %92;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p, %91, %92, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p, %94, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %94, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " p, %95, %96, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %97, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " p, %98, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %102, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " p, %103, %104, %105, %106;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %105, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " p, %106, %107, %108;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p, %107, %108, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p, %110, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %110, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " p, %111, %112, %113, %114;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %113, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " p, %114, %115, %116;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p, %115, %116, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p, %118, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %118, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " p, %119, %120, %121, %122;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %121, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " p, %122, %123, %124;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p, %123, %124, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p, %126, %127, %128;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %126, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " p, %127, %128, %129, %130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %129, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " p, %130, %131, %132;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16, %17, %18;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24, %25, %26;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27, %28;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31, %32;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40, %41, %42;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43, %44;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48, %49, %50;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56, %57, %58;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59, %60;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64, %65, %66;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %70, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " p, %71, %72, %73, %74;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %73, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " p, %74, %75, %76;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p, %75, %76, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p, %78, %79, %80;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %78, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " p, %79, %80, %81, %82;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %81, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " p, %82, %83, %84;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p, %83, %84, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p, %86, %87, %88;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %86, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " p, %87, %88, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %89, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " p, %90, %91, %92;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p, %91, %92, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p, %94, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %94, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " p, %95, %96, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %97, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " p, %98, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %102, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " p, %103, %104, %105, %106;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %105, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " p, %106, %107, %108;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p, %107, %108, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p, %110, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %110, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " p, %111, %112, %113, %114;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %113, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " p, %114, %115, %116;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p, %115, %116, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p, %118, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %118, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " p, %119, %120, %121, %122;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %121, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " p, %122, %123, %124;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p, %123, %124, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p, %126, %127, %128;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %126, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " p, %127, %128, %129, %130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %129, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " p, %130, %131, %132;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %70, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " p, %71, %72;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %73, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " p, %74, %75;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p, %75, %76;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p, %78, %79;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %78, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " p, %79, %80;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %81, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " p, %82, %83;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p, %83, %84;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p, %86, %87;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %86, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " p, %87, %88;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %89, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " p, %90, %91;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p, %91, %92;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p, %94, %95;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %94, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " p, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %97, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " p, %98, %99;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %102, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " p, %103, %104;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %105, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " p, %106, %107;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p, %107, %108;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p, %110, %111;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %110, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " p, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %113, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " p, %114, %115;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p, %115, %116;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p, %118, %119;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %118, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " p, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %121, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " p, %122, %123;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p, %123, %124;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p, %126, %127;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %126, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " p, %127, %128;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %129, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " p, %130, %131;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=S8*S8 +struct MMA_64x24x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=S8*S8 +struct MMA_64x24x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=S8*S8 +struct MMA_64x48x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=S8*S8 +struct MMA_64x48x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=S8*S8 +struct MMA_64x80x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=S8*S8 +struct MMA_64x80x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=S8*S8 +struct MMA_64x112x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=S8*S8 +struct MMA_64x112x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=S8*S8 +struct MMA_64x144x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=S8*S8 +struct MMA_64x144x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=S8*S8 +struct MMA_64x160x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=S8*S8 +struct MMA_64x160x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=S8*S8 +struct MMA_64x176x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=S8*S8 +struct MMA_64x176x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=S8*S8 +struct MMA_64x208x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=S8*S8 +struct MMA_64x208x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=S8*S8 +struct MMA_64x224x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=S8*S8 +struct MMA_64x224x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=S8*S8 +struct MMA_64x240x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=S8*S8 +struct MMA_64x240x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=S8*S8 +struct MMA_64x24x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=S8*S8 +struct MMA_64x24x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=S8*S8 +struct MMA_64x48x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=S8*S8 +struct MMA_64x48x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=S8*S8 +struct MMA_64x80x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=S8*S8 +struct MMA_64x80x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=S8*S8 +struct MMA_64x112x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=S8*S8 +struct MMA_64x112x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=S8*S8 +struct MMA_64x144x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=S8*S8 +struct MMA_64x144x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=S8*S8 +struct MMA_64x160x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=S8*S8 +struct MMA_64x160x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=S8*S8 +struct MMA_64x176x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=S8*S8 +struct MMA_64x176x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=S8*S8 +struct MMA_64x208x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=S8*S8 +struct MMA_64x208x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=S8*S8 +struct MMA_64x224x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=S8*S8 +struct MMA_64x224x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=S8*S8 +struct MMA_64x240x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=S8*S8 +struct MMA_64x240x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=S8*U8 +struct MMA_64x24x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=S8*U8 +struct MMA_64x24x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=S8*U8 +struct MMA_64x48x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=S8*U8 +struct MMA_64x48x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=S8*U8 +struct MMA_64x80x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=S8*U8 +struct MMA_64x80x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=S8*U8 +struct MMA_64x112x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=S8*U8 +struct MMA_64x112x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=S8*U8 +struct MMA_64x144x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=S8*U8 +struct MMA_64x144x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=S8*U8 +struct MMA_64x160x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=S8*U8 +struct MMA_64x160x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=S8*U8 +struct MMA_64x176x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=S8*U8 +struct MMA_64x176x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=S8*U8 +struct MMA_64x208x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=S8*U8 +struct MMA_64x208x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=S8*U8 +struct MMA_64x224x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=S8*U8 +struct MMA_64x224x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=S8*U8 +struct MMA_64x240x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=S8*U8 +struct MMA_64x240x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=S8*U8 +struct MMA_64x24x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=S8*U8 +struct MMA_64x24x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=S8*U8 +struct MMA_64x48x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=S8*U8 +struct MMA_64x48x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=S8*U8 +struct MMA_64x80x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=S8*U8 +struct MMA_64x80x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=S8*U8 +struct MMA_64x112x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=S8*U8 +struct MMA_64x112x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=S8*U8 +struct MMA_64x144x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=S8*U8 +struct MMA_64x144x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=S8*U8 +struct MMA_64x160x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=S8*U8 +struct MMA_64x160x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=S8*U8 +struct MMA_64x176x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=S8*U8 +struct MMA_64x176x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=S8*U8 +struct MMA_64x208x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=S8*U8 +struct MMA_64x208x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=S8*U8 +struct MMA_64x224x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=S8*U8 +struct MMA_64x224x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=S8*U8 +struct MMA_64x240x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=S8*U8 +struct MMA_64x240x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=U8*S8 +struct MMA_64x24x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=U8*S8 +struct MMA_64x24x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=U8*S8 +struct MMA_64x48x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=U8*S8 +struct MMA_64x48x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=U8*S8 +struct MMA_64x80x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=U8*S8 +struct MMA_64x80x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=U8*S8 +struct MMA_64x112x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=U8*S8 +struct MMA_64x112x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=U8*S8 +struct MMA_64x144x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=U8*S8 +struct MMA_64x144x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=U8*S8 +struct MMA_64x160x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=U8*S8 +struct MMA_64x160x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=U8*S8 +struct MMA_64x176x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=U8*S8 +struct MMA_64x176x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=U8*S8 +struct MMA_64x208x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=U8*S8 +struct MMA_64x208x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=U8*S8 +struct MMA_64x224x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=U8*S8 +struct MMA_64x224x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=U8*S8 +struct MMA_64x240x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=U8*S8 +struct MMA_64x240x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=U8*S8 +struct MMA_64x24x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=U8*S8 +struct MMA_64x24x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=U8*S8 +struct MMA_64x48x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=U8*S8 +struct MMA_64x48x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=U8*S8 +struct MMA_64x80x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=U8*S8 +struct MMA_64x80x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=U8*S8 +struct MMA_64x112x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=U8*S8 +struct MMA_64x112x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=U8*S8 +struct MMA_64x144x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=U8*S8 +struct MMA_64x144x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=U8*S8 +struct MMA_64x160x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=U8*S8 +struct MMA_64x160x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=U8*S8 +struct MMA_64x176x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=U8*S8 +struct MMA_64x176x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=U8*S8 +struct MMA_64x208x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=U8*S8 +struct MMA_64x208x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=U8*S8 +struct MMA_64x224x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=U8*S8 +struct MMA_64x224x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=U8*S8 +struct MMA_64x240x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=U8*S8 +struct MMA_64x240x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=U8*U8 +struct MMA_64x24x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=U8*U8 +struct MMA_64x24x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=U8*U8 +struct MMA_64x48x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=U8*U8 +struct MMA_64x48x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=U8*U8 +struct MMA_64x80x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=U8*U8 +struct MMA_64x80x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=U8*U8 +struct MMA_64x112x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=U8*U8 +struct MMA_64x112x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=U8*U8 +struct MMA_64x144x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=U8*U8 +struct MMA_64x144x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=U8*U8 +struct MMA_64x160x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=U8*U8 +struct MMA_64x160x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=U8*U8 +struct MMA_64x176x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=U8*U8 +struct MMA_64x176x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=U8*U8 +struct MMA_64x208x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=U8*U8 +struct MMA_64x208x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=U8*U8 +struct MMA_64x224x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=U8*U8 +struct MMA_64x224x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=U8*U8 +struct MMA_64x240x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=U8*U8 +struct MMA_64x240x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=U8*U8 +struct MMA_64x24x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=U8*U8 +struct MMA_64x24x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=U8*U8 +struct MMA_64x48x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=U8*U8 +struct MMA_64x48x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=U8*U8 +struct MMA_64x80x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=U8*U8 +struct MMA_64x80x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=U8*U8 +struct MMA_64x112x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=U8*U8 +struct MMA_64x112x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=U8*U8 +struct MMA_64x144x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=U8*U8 +struct MMA_64x144x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=U8*U8 +struct MMA_64x160x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=U8*U8 +struct MMA_64x160x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=U8*U8 +struct MMA_64x176x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=U8*U8 +struct MMA_64x176x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=U8*U8 +struct MMA_64x208x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=U8*U8 +struct MMA_64x208x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=U8*U8 +struct MMA_64x224x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=U8*U8 +struct MMA_64x224x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=U8*U8 +struct MMA_64x240x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=U8*U8 +struct MMA_64x240x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5}," + " %6," + " %7," + " p, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " p, %12, %13;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " p, %13, %14;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " p, %16, %17;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " p, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + "{%14, %15, %16, %17}," + " %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + " %18," + " %19," + " p, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " p, %24, %25;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + " %22," + " %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + " %26," + " %27," + " p, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + "{%26, %27, %28, %29}," + " %30," + " p, %32, %33;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + "{%30, %31, %32, %33}," + " %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " p, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," + " p, %40, %41;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %70, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " p, %71, %72;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %73, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " p, %74, %75;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p, %75, %76;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p, %78, %79;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + " %38," + " %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + "{%38, %39, %40, %41}," + " %42," + " p, %44, %45;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %78, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " p, %79, %80;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %81, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " p, %82, %83;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p, %83, %84;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p, %86, %87;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + " %42," + " %43," + " p, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + "{%42, %43, %44, %45}," + " %46," + " p, %48, %49;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %86, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " p, %87, %88;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %89, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " p, %90, %91;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p, %91, %92;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p, %94, %95;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + " %46," + " %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %94, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " p, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %97, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " p, %98, %99;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " p, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " p, %56, %57;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %102, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " p, %103, %104;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %105, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " p, %106, %107;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p, %107, %108;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p, %110, %111;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + " %54," + " %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + "{%54, %55, %56, %57}," + " %58," + " p, %60, %61;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %110, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " p, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %113, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " p, %114, %115;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p, %115, %116;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p, %118, %119;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + " %58," + " %59," + " p, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " p, %64, %65;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %118, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " p, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %121, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " p, %122, %123;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p, %123, %124;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p, %126, %127;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + " %62," + " %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %126, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " p, %127, %128;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %129, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " p, %130, %131;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5}," + " %6," + " %7," + " p, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " p, %12, %13;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " p, %13, %14;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " p, %16, %17;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " p, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + "{%14, %15, %16, %17}," + " %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + " %18," + " %19," + " p, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " p, %24, %25;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + " %22," + " %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + " %26," + " %27," + " p, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + "{%26, %27, %28, %29}," + " %30," + " p, %32, %33;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + "{%30, %31, %32, %33}," + " %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " p, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," + " p, %40, %41;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %70, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " p, %71, %72;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %73, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " p, %74, %75;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p, %75, %76;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p, %78, %79;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + " %38," + " %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + "{%38, %39, %40, %41}," + " %42," + " p, %44, %45;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %78, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " p, %79, %80;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %81, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " p, %82, %83;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p, %83, %84;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p, %86, %87;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + " %42," + " %43," + " p, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + "{%42, %43, %44, %45}," + " %46," + " p, %48, %49;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %86, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " p, %87, %88;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %89, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " p, %90, %91;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p, %91, %92;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p, %94, %95;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + " %46," + " %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %94, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " p, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %97, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " p, %98, %99;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " p, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " p, %56, %57;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %102, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " p, %103, %104;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %105, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " p, %106, %107;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p, %107, %108;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p, %110, %111;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + " %54," + " %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + "{%54, %55, %56, %57}," + " %58," + " p, %60, %61;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %110, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " p, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %113, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " p, %114, %115;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p, %115, %116;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p, %118, %119;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + " %58," + " %59," + " p, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " p, %64, %65;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %118, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " p, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %121, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " p, %122, %123;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p, %123, %124;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p, %126, %127;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + " %62," + " %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %126, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " p, %127, %128;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %129, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " p, %130, %131;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5}," + " %6," + " %7," + " p, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " p, %12, %13;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " p, %13, %14;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " p, %16, %17;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " p, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + "{%14, %15, %16, %17}," + " %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + " %18," + " %19," + " p, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " p, %24, %25;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + " %22," + " %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + " %26," + " %27," + " p, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + "{%26, %27, %28, %29}," + " %30," + " p, %32, %33;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + "{%30, %31, %32, %33}," + " %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " p, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," + " p, %40, %41;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %70, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " p, %71, %72;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %73, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " p, %74, %75;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p, %75, %76;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p, %78, %79;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + " %38," + " %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + "{%38, %39, %40, %41}," + " %42," + " p, %44, %45;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %78, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " p, %79, %80;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %81, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " p, %82, %83;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p, %83, %84;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p, %86, %87;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + " %42," + " %43," + " p, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + "{%42, %43, %44, %45}," + " %46," + " p, %48, %49;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %86, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " p, %87, %88;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %89, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " p, %90, %91;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p, %91, %92;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p, %94, %95;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + " %46," + " %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %94, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " p, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %97, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " p, %98, %99;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " p, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " p, %56, %57;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %102, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " p, %103, %104;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %105, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " p, %106, %107;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p, %107, %108;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p, %110, %111;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + " %54," + " %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + "{%54, %55, %56, %57}," + " %58," + " p, %60, %61;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %110, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " p, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %113, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " p, %114, %115;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p, %115, %116;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p, %118, %119;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + " %58," + " %59," + " p, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " p, %64, %65;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %118, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " p, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %121, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " p, %122, %123;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p, %123, %124;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p, %126, %127;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + " %62," + " %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %126, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " p, %127, %128;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %129, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " p, %130, %131;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5}," + " %6," + " %7," + " p, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " p, %12, %13;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " p, %13, %14;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " p, %16, %17;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " p, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + "{%14, %15, %16, %17}," + " %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + " %18," + " %19," + " p, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " p, %24, %25;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + " %22," + " %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + " %26," + " %27," + " p, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + "{%26, %27, %28, %29}," + " %30," + " p, %32, %33;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + "{%30, %31, %32, %33}," + " %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " p, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," + " p, %40, %41;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %70, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " p, %71, %72;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %73, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " p, %74, %75;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p, %75, %76;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p, %78, %79;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + " %38," + " %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + "{%38, %39, %40, %41}," + " %42," + " p, %44, %45;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %78, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " p, %79, %80;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %81, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " p, %82, %83;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p, %83, %84;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p, %86, %87;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + " %42," + " %43," + " p, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + "{%42, %43, %44, %45}," + " %46," + " p, %48, %49;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %86, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " p, %87, %88;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %89, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " p, %90, %91;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p, %91, %92;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p, %94, %95;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + " %46," + " %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %94, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " p, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %97, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " p, %98, %99;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " p, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " p, %56, %57;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %102, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " p, %103, %104;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %105, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " p, %106, %107;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p, %107, %108;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p, %110, %111;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + " %54," + " %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + "{%54, %55, %56, %57}," + " %58," + " p, %60, %61;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %110, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " p, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %113, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " p, %114, %115;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p, %115, %116;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p, %118, %119;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + " %58," + " %59," + " p, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " p, %64, %65;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %118, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " p, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %121, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " p, %122, %123;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p, %123, %124;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p, %126, %127;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + " %62," + " %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %126, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " p, %127, %128;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %129, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " p, %130, %131;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace SM90::GMMA + +} // namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm90_gmma_sparse.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm90_gmma_sparse.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cd97a3dd64ae07ad21640689eeef0d63030e52e7 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm90_gmma_sparse.hpp @@ -0,0 +1,22745 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include // CUTE_HOST_DEVICE +#include // GMMA::Major, etc. + +#include "cutlass/arch/synclog.hpp" + +namespace cute { + +namespace SM90::GMMA::SPARSE { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// GMMA PTX definitions: C = (scaleA * A) * (scaleB * B) + (scaleD * C) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k32.f16.f16.f16 " + "{%0, %1}," + " %2," + " %3," + " %4, %5," + " p, %7, %8, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k32.f16.f16.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " %7, %8," + " p, %10, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k32.f16.f16.f16 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k32.f16.f16.f16 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13, %14;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14, %15, %16;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17, %18;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70, %71, %72;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73, %74;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k32.f32.f16.f16 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k32.f32.f16.f16 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14, %15, %16;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17, %18;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22, %23, %24;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25, %26;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38, %39, %40;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41, %42;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54, %55, %56;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57, %58;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70, %71, %72;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73, %74;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102, %103, %104;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105, %106;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134, %135, %136;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137, %138;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k32.f32.bf16.bf16 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k32.f32.bf16.bf16 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14, %15, %16;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17, %18;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22, %23, %24;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25, %26;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38, %39, %40;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41, %42;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54, %55, %56;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57, %58;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70, %71, %72;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73, %74;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102, %103, %104;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105, %106;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134, %135, %136;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137, %138;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k16.f32.tf32.tf32 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k16.f32.tf32.tf32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.s8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.s8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.u8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.u8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.s8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.s8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.u8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.u8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e4m3.e4m3 " + "{%0, %1}," + " %2," + " %3," + " %4, %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e4m3.e4m3 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " %7, %8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e4m3.e5m2 " + "{%0, %1}," + " %2," + " %3," + " %4, %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e4m3.e5m2 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " %7, %8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e5m2.e4m3 " + "{%0, %1}," + " %2," + " %3," + " %4, %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e5m2.e4m3 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " %7, %8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e5m2.e5m2 " + "{%0, %1}," + " %2," + " %3," + " %4, %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e5m2.e5m2 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " %7, %8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace SM90::GMMA::SPARSE + +} // namespace cute + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +#include "mma_sm90_gmma_sparse_ext.hpp" +#endif diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm90_gmma_sparse_ext.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm90_gmma_sparse_ext.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8945551a1c94801d3a902ce47c57efea0e7cafdf --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm90_gmma_sparse_ext.hpp @@ -0,0 +1,60445 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include // CUTE_HOST_DEVICE + +#include "cutlass/arch/synclog.hpp" + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) +# define CUTE_ARCH_MMA_SM90A_ENABLED +#endif + +namespace cute { + +namespace SM90::GMMA::SPARSE { + +// SPARSE GMMA 64x24x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5}," + " %6," + " %7," + " %8, %9," + " p, %11, %12, %13, %14;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " %11, %12," + " p, %14, %15, %16;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " %12, %13," + " p, %15, %16, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " %15, %16," + " p, %18, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " %16, %17," + " p, %19, %20, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + "{%14, %15, %16, %17}," + " %18," + " %19, %20," + " p, %22, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + " %18," + " %19," + " %20, %21," + " p, %23, %24, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " %23, %24," + " p, %26, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + " %22," + " %23," + " %24, %25," + " p, %27, %28, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " %27, %28," + " p, %30, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + " %26," + " %27," + " %28, %29," + " p, %31, %32, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + "{%26, %27, %28, %29}," + " %30," + " %31, %32," + " p, %34, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " %32, %33," + " p, %35, %36, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + "{%30, %31, %32, %33}," + " %34," + " %35, %36," + " p, %38, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " %36, %37," + " p, %39, %40, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," + " %39, %40," + " p, %42, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + " %38," + " %39," + " %40, %41," + " p, %43, %44, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + "{%38, %39, %40, %41}," + " %42," + " %43, %44," + " p, %46, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + " %42," + " %43," + " %44, %45," + " p, %47, %48, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + "{%42, %43, %44, %45}," + " %46," + " %47, %48," + " p, %50, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + " %46," + " %47," + " %48, %49," + " p, %51, %52, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " %51, %52," + " p, %54, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " %52, %53," + " p, %55, %56, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " %55, %56," + " p, %58, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + " %54," + " %55," + " %56, %57," + " p, %59, %60, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + "{%54, %55, %56, %57}," + " %58," + " %59, %60," + " p, %62, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + " %58," + " %59," + " %60, %61," + " p, %63, %64, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " %63, %64," + " p, %66, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69, %70;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + " %62," + " %63," + " %64, %65," + " p, %67, %68, %69, %70;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " %67, %68," + " p, %70, %71, %72;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26, %27, %28;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30, %31, %32;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42, %43, %44;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49, %50;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58, %59, %60;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65, %66;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %72, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " %70, %71," + " p, %73, %74, %75, %76;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %75, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " %73, %74," + " p, %76, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, %77, %78, %79, %80;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81, %82;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %80, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " %78, %79," + " p, %81, %82, %83, %84;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %83, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " %81, %82," + " p, %84, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86, %87, %88;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %88, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " %86, %87," + " p, %89, %90, %91, %92;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %91, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " %89, %90," + " p, %92, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %96, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " %94, %95," + " p, %97, %98, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %99, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " %97, %98," + " p, %100, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %104, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " %102, %103," + " p, %105, %106, %107, %108;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %107, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " %105, %106," + " p, %108, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p, %112, %113, %114;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %112, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " %110, %111," + " p, %113, %114, %115, %116;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %115, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " %113, %114," + " p, %116, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121, %122;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %120, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " %118, %119," + " p, %121, %122, %123, %124;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %123, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " %121, %122," + " p, %124, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126, %127, %128;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129, %130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %128, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " %126, %127," + " p, %129, %130, %131, %132;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %131, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " %129, %130," + " p, %132, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26, %27, %28;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30, %31, %32;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42, %43, %44;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49, %50;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58, %59, %60;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65, %66;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %72, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " %70, %71," + " p, %73, %74, %75, %76;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %75, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " %73, %74," + " p, %76, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, %77, %78, %79, %80;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81, %82;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %80, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " %78, %79," + " p, %81, %82, %83, %84;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %83, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " %81, %82," + " p, %84, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86, %87, %88;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %88, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " %86, %87," + " p, %89, %90, %91, %92;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %91, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " %89, %90," + " p, %92, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %96, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " %94, %95," + " p, %97, %98, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %99, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " %97, %98," + " p, %100, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %104, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " %102, %103," + " p, %105, %106, %107, %108;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %107, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " %105, %106," + " p, %108, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p, %112, %113, %114;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %112, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " %110, %111," + " p, %113, %114, %115, %116;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %115, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " %113, %114," + " p, %116, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121, %122;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %120, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " %118, %119," + " p, %121, %122, %123, %124;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %123, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " %121, %122," + " p, %124, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126, %127, %128;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129, %130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %128, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " %126, %127," + " p, %129, %130, %131, %132;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %131, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " %129, %130," + " p, %132, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %72, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " %70, %71," + " p, %73, %74;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %75, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " %73, %74," + " p, %76, %77;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %80, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " %78, %79," + " p, %81, %82;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %83, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " %81, %82," + " p, %84, %85;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %88, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " %86, %87," + " p, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %91, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " %89, %90," + " p, %92, %93;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %96, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " %94, %95," + " p, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %99, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " %97, %98," + " p, %100, %101;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %104, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " %102, %103," + " p, %105, %106;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %107, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " %105, %106," + " p, %108, %109;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p, %112, %113;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %112, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " %110, %111," + " p, %113, %114;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %115, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " %113, %114," + " p, %116, %117;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %120, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " %118, %119," + " p, %121, %122;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %123, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " %121, %122," + " p, %124, %125;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %128, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " %126, %127," + " p, %129, %130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %131, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " %129, %130," + " p, %132, %133;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5}," + " %6," + " %7," + " %8, %9," + " p, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " %11, %12," + " p, %14, %15;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " %12, %13," + " p, %15, %16;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " %15, %16," + " p, %18, %19;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " %16, %17," + " p, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + "{%14, %15, %16, %17}," + " %18," + " %19, %20," + " p, %22, %23;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + " %18," + " %19," + " %20, %21," + " p, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " %23, %24," + " p, %26, %27;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + " %22," + " %23," + " %24, %25," + " p, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " %27, %28," + " p, %30, %31;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + " %26," + " %27," + " %28, %29," + " p, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + "{%26, %27, %28, %29}," + " %30," + " %31, %32," + " p, %34, %35;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " %32, %33," + " p, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + "{%30, %31, %32, %33}," + " %34," + " %35, %36," + " p, %38, %39;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " %36, %37," + " p, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," + " %39, %40," + " p, %42, %43;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %72, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " %70, %71," + " p, %73, %74;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %75, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " %73, %74," + " p, %76, %77;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + " %38," + " %39," + " %40, %41," + " p, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + "{%38, %39, %40, %41}," + " %42," + " %43, %44," + " p, %46, %47;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %80, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " %78, %79," + " p, %81, %82;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %83, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " %81, %82," + " p, %84, %85;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + " %42," + " %43," + " %44, %45," + " p, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + "{%42, %43, %44, %45}," + " %46," + " %47, %48," + " p, %50, %51;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %88, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " %86, %87," + " p, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %91, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " %89, %90," + " p, %92, %93;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + " %46," + " %47," + " %48, %49," + " p, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " %51, %52," + " p, %54, %55;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %96, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " %94, %95," + " p, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %99, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " %97, %98," + " p, %100, %101;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " %52, %53," + " p, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " %55, %56," + " p, %58, %59;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %104, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " %102, %103," + " p, %105, %106;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %107, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " %105, %106," + " p, %108, %109;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p, %112, %113;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + " %54," + " %55," + " %56, %57," + " p, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + "{%54, %55, %56, %57}," + " %58," + " %59, %60," + " p, %62, %63;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %112, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " %110, %111," + " p, %113, %114;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %115, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " %113, %114," + " p, %116, %117;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + " %58," + " %59," + " %60, %61," + " p, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " %63, %64," + " p, %66, %67;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %120, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " %118, %119," + " p, %121, %122;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %123, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " %121, %122," + " p, %124, %125;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + " %62," + " %63," + " %64, %65," + " p, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " %67, %68," + " p, %70, %71;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %128, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " %126, %127," + " p, %129, %130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %131, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " %129, %130," + " p, %132, %133;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5}," + " %6," + " %7," + " %8, %9," + " p, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " %11, %12," + " p, %14, %15;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " %12, %13," + " p, %15, %16;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " %15, %16," + " p, %18, %19;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " %16, %17," + " p, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + "{%14, %15, %16, %17}," + " %18," + " %19, %20," + " p, %22, %23;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + " %18," + " %19," + " %20, %21," + " p, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " %23, %24," + " p, %26, %27;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + " %22," + " %23," + " %24, %25," + " p, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " %27, %28," + " p, %30, %31;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + " %26," + " %27," + " %28, %29," + " p, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + "{%26, %27, %28, %29}," + " %30," + " %31, %32," + " p, %34, %35;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " %32, %33," + " p, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + "{%30, %31, %32, %33}," + " %34," + " %35, %36," + " p, %38, %39;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " %36, %37," + " p, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," + " %39, %40," + " p, %42, %43;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %72, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " %70, %71," + " p, %73, %74;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %75, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " %73, %74," + " p, %76, %77;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + " %38," + " %39," + " %40, %41," + " p, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + "{%38, %39, %40, %41}," + " %42," + " %43, %44," + " p, %46, %47;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %80, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " %78, %79," + " p, %81, %82;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %83, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " %81, %82," + " p, %84, %85;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + " %42," + " %43," + " %44, %45," + " p, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + "{%42, %43, %44, %45}," + " %46," + " %47, %48," + " p, %50, %51;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %88, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " %86, %87," + " p, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %91, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " %89, %90," + " p, %92, %93;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + " %46," + " %47," + " %48, %49," + " p, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " %51, %52," + " p, %54, %55;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %96, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " %94, %95," + " p, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %99, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " %97, %98," + " p, %100, %101;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " %52, %53," + " p, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " %55, %56," + " p, %58, %59;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %104, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " %102, %103," + " p, %105, %106;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %107, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " %105, %106," + " p, %108, %109;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p, %112, %113;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + " %54," + " %55," + " %56, %57," + " p, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + "{%54, %55, %56, %57}," + " %58," + " %59, %60," + " p, %62, %63;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %112, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " %110, %111," + " p, %113, %114;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %115, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " %113, %114," + " p, %116, %117;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + " %58," + " %59," + " %60, %61," + " p, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " %63, %64," + " p, %66, %67;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %120, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " %118, %119," + " p, %121, %122;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %123, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " %121, %122," + " p, %124, %125;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + " %62," + " %63," + " %64, %65," + " p, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " %67, %68," + " p, %70, %71;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %128, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " %126, %127," + " p, %129, %130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %131, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " %129, %130," + " p, %132, %133;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5}," + " %6," + " %7," + " %8, %9," + " p, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " %11, %12," + " p, %14, %15;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " %12, %13," + " p, %15, %16;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " %15, %16," + " p, %18, %19;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " %16, %17," + " p, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + "{%14, %15, %16, %17}," + " %18," + " %19, %20," + " p, %22, %23;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + " %18," + " %19," + " %20, %21," + " p, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " %23, %24," + " p, %26, %27;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + " %22," + " %23," + " %24, %25," + " p, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " %27, %28," + " p, %30, %31;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + " %26," + " %27," + " %28, %29," + " p, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + "{%26, %27, %28, %29}," + " %30," + " %31, %32," + " p, %34, %35;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " %32, %33," + " p, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + "{%30, %31, %32, %33}," + " %34," + " %35, %36," + " p, %38, %39;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " %36, %37," + " p, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," + " %39, %40," + " p, %42, %43;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %72, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " %70, %71," + " p, %73, %74;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %75, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " %73, %74," + " p, %76, %77;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + " %38," + " %39," + " %40, %41," + " p, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + "{%38, %39, %40, %41}," + " %42," + " %43, %44," + " p, %46, %47;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %80, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " %78, %79," + " p, %81, %82;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %83, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " %81, %82," + " p, %84, %85;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + " %42," + " %43," + " %44, %45," + " p, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + "{%42, %43, %44, %45}," + " %46," + " %47, %48," + " p, %50, %51;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %88, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " %86, %87," + " p, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %91, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " %89, %90," + " p, %92, %93;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + " %46," + " %47," + " %48, %49," + " p, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " %51, %52," + " p, %54, %55;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %96, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " %94, %95," + " p, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %99, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " %97, %98," + " p, %100, %101;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " %52, %53," + " p, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " %55, %56," + " p, %58, %59;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %104, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " %102, %103," + " p, %105, %106;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %107, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " %105, %106," + " p, %108, %109;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p, %112, %113;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + " %54," + " %55," + " %56, %57," + " p, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + "{%54, %55, %56, %57}," + " %58," + " %59, %60," + " p, %62, %63;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %112, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " %110, %111," + " p, %113, %114;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %115, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " %113, %114," + " p, %116, %117;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + " %58," + " %59," + " %60, %61," + " p, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " %63, %64," + " p, %66, %67;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %120, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " %118, %119," + " p, %121, %122;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %123, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " %121, %122," + " p, %124, %125;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + " %62," + " %63," + " %64, %65," + " p, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " %67, %68," + " p, %70, %71;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %128, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " %126, %127," + " p, %129, %130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %131, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " %129, %130," + " p, %132, %133;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5}," + " %6," + " %7," + " %8, %9," + " p, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " %11, %12," + " p, %14, %15;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " %12, %13," + " p, %15, %16;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " %15, %16," + " p, %18, %19;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " %16, %17," + " p, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + "{%14, %15, %16, %17}," + " %18," + " %19, %20," + " p, %22, %23;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + " %18," + " %19," + " %20, %21," + " p, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " %23, %24," + " p, %26, %27;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + " %22," + " %23," + " %24, %25," + " p, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " %27, %28," + " p, %30, %31;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + " %26," + " %27," + " %28, %29," + " p, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + "{%26, %27, %28, %29}," + " %30," + " %31, %32," + " p, %34, %35;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " %32, %33," + " p, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + "{%30, %31, %32, %33}," + " %34," + " %35, %36," + " p, %38, %39;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " %36, %37," + " p, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," + " %39, %40," + " p, %42, %43;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %72, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " %70, %71," + " p, %73, %74;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %75, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " %73, %74," + " p, %76, %77;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + " %38," + " %39," + " %40, %41," + " p, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + "{%38, %39, %40, %41}," + " %42," + " %43, %44," + " p, %46, %47;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %80, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " %78, %79," + " p, %81, %82;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %83, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " %81, %82," + " p, %84, %85;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + " %42," + " %43," + " %44, %45," + " p, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + "{%42, %43, %44, %45}," + " %46," + " %47, %48," + " p, %50, %51;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %88, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " %86, %87," + " p, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %91, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " %89, %90," + " p, %92, %93;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + " %46," + " %47," + " %48, %49," + " p, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " %51, %52," + " p, %54, %55;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %96, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " %94, %95," + " p, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %99, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " %97, %98," + " p, %100, %101;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " %52, %53," + " p, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " %55, %56," + " p, %58, %59;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %104, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " %102, %103," + " p, %105, %106;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %107, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " %105, %106," + " p, %108, %109;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p, %112, %113;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + " %54," + " %55," + " %56, %57," + " p, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + "{%54, %55, %56, %57}," + " %58," + " %59, %60," + " p, %62, %63;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %112, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " %110, %111," + " p, %113, %114;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %115, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " %113, %114," + " p, %116, %117;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + " %58," + " %59," + " %60, %61," + " p, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " %63, %64," + " p, %66, %67;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %120, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " %118, %119," + " p, %121, %122;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %123, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " %121, %122," + " p, %124, %125;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + " %62," + " %63," + " %64, %65," + " p, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " %67, %68," + " p, %70, %71;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %128, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " %126, %127," + " p, %129, %130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %131, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " %129, %130," + " p, %132, %133;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace SM90::GMMA::SPARSE + +} // namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/simd_sm100.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/simd_sm100.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1c07a31e6d3134e1c0437c438a5183081f50815d --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/simd_sm100.hpp @@ -0,0 +1,96 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +// + +// + +#pragma once + +#include +#include +#include + +namespace cute { + +CUTE_HOST_DEVICE +void +add(float2 & c, + float2 const& a, + float2 const& b) +{ +#if defined(CUTE_ARCH_FLOAT2_MATH_ENABLED) + asm volatile("add.f32x2 %0, %1, %2;\n" + : "=l"(reinterpret_cast(c)) + : "l"(reinterpret_cast(a)), + "l"(reinterpret_cast(b))); +#else + add(c.x, a.x, b.x); + add(c.y, a.y, b.y); +#endif +} + +CUTE_HOST_DEVICE +void +mul(float2 & c, + float2 const& a, + float2 const& b) +{ +#if defined(CUTE_ARCH_FLOAT2_MATH_ENABLED) + asm volatile("mul.f32x2 %0, %1, %2;\n" + : "=l"(reinterpret_cast(c)) + : "l"(reinterpret_cast(a)), + "l"(reinterpret_cast(b))); +#else + mul(c.x, a.x, b.x); + mul(c.y, a.y, b.y); +#endif +} + +CUTE_HOST_DEVICE +void +fma(float2 & d, + float2 const& a, + float2 const& b, + float2 const& c) +{ +#if defined(CUTE_ARCH_FLOAT2_MATH_ENABLED) + asm volatile("fma.rn.f32x2 %0, %1, %2, %3;\n" + : "=l"(reinterpret_cast(d)) + : "l"(reinterpret_cast(a)), + "l"(reinterpret_cast(b)), + "l"(reinterpret_cast(c))); +#else + fma(d.x, a.x, b.x, c.x); + fma(d.y, a.y, b.y, c.y); +#endif +} + +} // namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/tmem_allocator_sm100.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/tmem_allocator_sm100.hpp new file mode 100644 index 0000000000000000000000000000000000000000..680e237f76d158846236c3317668b150fe162098 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/tmem_allocator_sm100.hpp @@ -0,0 +1,183 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include +#include + +namespace cute::TMEM { + +// +// TMEM Addressing Constants +// + +// 128 DP x 512 COL x uint32_t-addressing +using MAX_CAPACITY_BITS = Int<128*512*32>; + +// TMEM DP stride in bit-addressing (shift by 5 for conversion from uint32_t) +using DP_b = cute::constant; + +// TMEM DP stride in type-T addressing +template +using DP = cute::constant::OffsetShift)>; + +// +// TMEM Allocators +// + +// All operations of this class require that only a single warp uniformly participates +class Allocator1Sm { +public: + static constexpr int ColumnsPerAllocationSlice = 32; + static constexpr int Sm100TmemCapacityColumns = 512; + + __device__ Allocator1Sm() { } + + /** + * Performs a non-blocking allocation of TMEM. + * @param num_columns Number of columns being freed. Must be 32 <= num_columns <= 512 and power of 2. + * @param dst_ptr Pointer to shared memory to which to write the result tmem pointer to. + * @pre Must be issued by a single fully active warp of the CTA. + * @pre Must never be issued by more than one warp at the same time. + * @pre For repeated allocations, the same warp must be used to issue all allocations. + **/ + CUTE_HOST_DEVICE void + allocate(int num_columns, uint32_t* dst_ptr) { + #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + uint32_t dst_intptr = cute::cast_smem_ptr_to_uint(dst_ptr); + asm volatile( + "tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;" + : + : "r"(dst_intptr), "r"(num_columns)); + #else + CUTE_INVALID_CONTROL_PATH("Attempting to use TMEM allocation PTX without CUTE_ARCH_TCGEN05_TMEM_ENABLED"); + #endif + } + + __device__ + void + free(uint32_t tmem_ptr, int num_columns) { + #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile( + "{\n\t" + "tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1; \n\t" + "}" + : + : "r"(tmem_ptr), "r"(num_columns)); + #else + CUTE_INVALID_CONTROL_PATH("Attempting to use TMEM allocation PTX without CUTE_ARCH_TCGEN05_TMEM_ENABLED"); + #endif + } + + __device__ void + release_allocation_lock() { + #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile("tcgen05.relinquish_alloc_permit.cta_group::1.sync.aligned;" ::); + #else + CUTE_INVALID_CONTROL_PATH("Attempting to use TMEM allocation PTX without CUTE_ARCH_TCGEN05_TMEM_ENABLED"); + #endif + } +}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +class Allocator2Sm { +public: + static constexpr int ColumnsPerAllocationSlice = 32; + static constexpr int Sm100TmemCapacityColumns = 512; + + __device__ Allocator2Sm() { } + + /** + * Performs a non-blocking allocation of TMEM. + * @param num_columns Number of columns being freed. Must be 32 <= num_columns <= 512 and power of 2. + * @param dst_ptr Pointer to shared memory to which to write the result tmem pointer to. + * Both CTAs _must_ provide the exact same dst_ptr for correctness. + * @pre Must be issued by a single fully active warp of the CTA. + * @pre Must never be issued by more than one warp at the same time. + * @pre For repeated allocations, the same warp must be used to issue all allocations. + * @pre The 2 warps from participating CTAs have the same logical warp ID. + **/ + CUTE_HOST_DEVICE void + allocate(int num_columns, uint32_t* dst_ptr) { + #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + uint32_t dst_intptr = cute::cast_smem_ptr_to_uint(dst_ptr); + asm volatile( + "tcgen05.alloc.cta_group::2.sync.aligned.shared::cta.b32 [%0], %1;" + : + : "r"(dst_intptr), "r"(num_columns)); + #else + CUTE_INVALID_CONTROL_PATH("Attempting to use TMEM allocation PTX without CUTE_ARCH_TCGEN05_TMEM_ENABLED"); + #endif + } + + /** + * Frees the TMEM corresponding to the pointer and slice count provided. + * Release the TMEM after checking that the CTA issuing the free does indeed own the corresponding slices. + * @param tmem_ptr Base address of the TMEM address space being freed. + * @param num_columns Number of columns being freed. Must be 32 <= num_columns <= 512 and power of 2. + * @pre Must be issued by a single fully active warp of the CTA. + * @pre Must never be issued by more than one warp at the same time. + * @pre The 2 warps from participating CTAs have the same logical warp ID. + * @returns true + **/ + __device__ + void + free(uint32_t tmem_ptr, int num_columns) { + #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile( + "{\n\t" + "tcgen05.dealloc.cta_group::2.sync.aligned.b32 %0, %1; \n\t" + "}" + : + : "r"(tmem_ptr), "r"(num_columns)); + #else + CUTE_INVALID_CONTROL_PATH("Attempting to use TMEM allocation PTX without CUTE_ARCH_TCGEN05_TMEM_ENABLED"); + #endif + } + + __device__ + void + release_allocation_lock() { + #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile("tcgen05.relinquish_alloc_permit.cta_group::2.sync.aligned;" ::); + #else + CUTE_INVALID_CONTROL_PATH("Attempting to use TMEM allocation PTX without CUTE_ARCH_TCGEN05_TMEM_ENABLED"); + #endif + } +}; + +} // namespace cute::TMEM diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/util.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/util.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6a3883e85a45b41b36bcbdf1784668f93892ec2c --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/arch/util.hpp @@ -0,0 +1,320 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include + +#if defined(__clang__) && defined(__CUDA__) + // __cvta_generic_to_shared was added in Clang 14: https://reviews.llvm.org/D111665 + #if __clang_major__ >= 14 + #define CUTE_CLANG_SUPPORTS_CVTA_GENERIC_TO_SHARED 1 + #endif + + // __nvvm_get_smem_pointer added in Clang 14: https://reviews.llvm.org/D111665 + // ... but will not work on Windows until Clang 15: https://reviews.llvm.org/D122897 + #if (!defined(_WIN32) && __clang_major__ >= 14) || __clang_major__ >= 15 + #define CUTE_CLANG_SUPPORTS_NVVM_GET_SMEM_POINTER 1 + #endif +#endif + +#if defined(__NVCC__) || defined(__CUDACC_RTC__) + // __cvta_generic_to_shared added in CUDA 11+ + #if __CUDACC_VER_MAJOR__ >= 11 + #define CUTE_NVCC_SUPPORTS_CVTA_GENERIC_TO_SHARED 1 + #endif + + // __nvvm_get_smem_pointer added in CUDA 10.2 + #if __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2 + #define CUTE_NVCC_SUPPORTS_NVVM_GET_SMEM_POINTER 1 + #endif +#endif + +#if CUTE_NVCC_SUPPORTS_CVTA_GENERIC_TO_SHARED || CUTE_CLANG_SUPPORTS_CVTA_GENERIC_TO_SHARED + #define CUTE_CVTA_GENERIC_TO_SHARED_SUPPORTED 1 +#endif + +#if !defined(CUTE_CVTA_GENERIC_TO_SHARED_ACTIVATED) && CUTE_CVTA_GENERIC_TO_SHARED_SUPPORTED && defined(__CUDA_ARCH__) + #define CUTE_CVTA_GENERIC_TO_SHARED_ACTIVATED 1 +#endif + +#if CUTE_NVCC_SUPPORTS_NVVM_GET_SMEM_POINTER || CUTE_CLANG_SUPPORTS_NVVM_GET_SMEM_POINTER + #define CUTE_NVVM_GET_SMEM_POINTER_SUPPORTED 1 +#endif + +#if !defined(CUTE_NVVM_GET_SMEM_POINTER_ACTIVATED) && CUTE_NVVM_GET_SMEM_POINTER_SUPPORTED && defined(__CUDA_ARCH__) + #define CUTE_NVVM_GET_SMEM_POINTER_ACTIVATED 1 +#endif + +// Clang 14+ provides a declaration of __nvvm_get_smem_pointer, so we only need +// to provide one for NVCC +#if CUTE_NVCC_SUPPORTS_NVVM_GET_SMEM_POINTER + extern "C" { + // This NVVM intrinsic is subject to change in future versions of CUDA. + // Clients should not call it directly. + CUTE_DEVICE uint32_t __nvvm_get_smem_pointer(void*); + } +#endif + +namespace cute +{ + +/// CUTE helper to cast SMEM pointer to unsigned +CUTE_HOST_DEVICE +uint32_t +cast_smem_ptr_to_uint(void const* const ptr) +{ +// We prefer to use the new CVTA intrinsics if they are available, otherwise we will fall back to +// the previous internal intrinsics if they are available. +#if CUTE_CVTA_GENERIC_TO_SHARED_ACTIVATED + // + // This NVVM intrinsic converts an address in shared memory to a plain + // unsigned integer. This is necessary to pass to shared memory instructions + // in inline PTX. + // + // In CUDA 11 and beyond, this replaces __nvvm_get_smem_pointer() [only available in 10.2]. + // + //__device__ size_t __cvta_generic_to_shared(void* ptr); + + /// CUTE helper to get SMEM pointer + return static_cast(__cvta_generic_to_shared(ptr)); + +#elif CUTE_NVVM_GET_SMEM_POINTER_ACTIVATED + + return __nvvm_get_smem_pointer(ptr); + +#elif defined(__CUDA_ARCH__) + + uint32_t smem_ptr; + + asm( + "{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n" + : "=r"(smem_ptr) : "l"(ptr)); + + return smem_ptr; + +#else + + + (void) ptr; + printf("ERROR: cast_smem_ptr_to_uint not supported but used.\n"); + return 0; + +#endif +} + +namespace detail { + +// +// Wrapper for MMAOp::fma +// + +template +struct CallFMA { + template + CUTE_HOST_DEVICE constexpr void + operator()(Args&&... args) const { + return MmaOp::fma(static_cast(args)...); + } +}; + +// +// Wrapper for CopyOp::copy +// + +template +struct CallCOPY { + template + CUTE_HOST_DEVICE constexpr void + operator()(Args&&... args) const { + return CopyOp::copy(static_cast(args)...); + } +}; + +// +// Utility for exploding pointers/arrays/tensors into functions +// + +template +CUTE_HOST_DEVICE constexpr +void +explode(Fn fn, + PtrA&& a, int_sequence) +{ + return fn(a[I]...); +} + +template +CUTE_HOST_DEVICE constexpr +void +explode(Fn fn, + PtrS&& s, int_sequence, + PtrD&& d, int_sequence) +{ + return fn(s[Is]..., d[Id]...); +} + +template +CUTE_HOST_DEVICE constexpr +void +explode(Fn fn, + PtrA&& a, int_sequence, + PtrB&& b, int_sequence, + PtrC&& c, int_sequence) +{ + return fn(a[Ia]..., b[Ib]..., c[Ic]...); +} + +template +CUTE_HOST_DEVICE constexpr +void +explode(Fn fn, + PtrD&& d, int_sequence, + PtrA&& a, int_sequence, + PtrB&& b, int_sequence, + PtrC&& c, int_sequence) +{ + return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]...); +} + +template +CUTE_HOST_DEVICE constexpr +void +explode(Fn fn, + PtrD&& d, int_sequence, + PtrA&& a, int_sequence, + PtrB&& b, int_sequence, + PtrC&& c, int_sequence, + PtrE&& e, int_sequence) +{ + return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]..., e[Ie]...); +} + +template +CUTE_HOST_DEVICE constexpr +void +explode(Fn fn, + PtrD&& d, int_sequence, + PtrA&& a, int_sequence, + PtrB&& b, int_sequence, + PtrC&& c, int_sequence, + PtrE&& e, int_sequence, + PtrF&& f, int_sequence) +{ + return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]..., e[Ie]..., f[If]...); +} + +template +CUTE_HOST_DEVICE constexpr +void +explode(Fn fn, + PtrD&& d, int_sequence, + PtrA&& a, int_sequence, + PtrB&& b, int_sequence, + PtrC&& c, int_sequence, + PtrE&& e, int_sequence, + PtrF&& f, int_sequence, + PtrG&& g, int_sequence) +{ + return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]..., e[Ie]..., f[If]..., g[Ig]...); +} + +// +// Utility for exploding tuples into functions +// + +template +CUTE_HOST_DEVICE constexpr +void +explode_tuple(Fn fn, + TupleA&& a, int_sequence) +{ + return fn(get(a)...); +} + +template +CUTE_HOST_DEVICE constexpr +void +explode_tuple(Fn fn, + TupleA&& a, int_sequence, + TupleB&& b, int_sequence) +{ + return fn(get(a)..., get(b)...); +} + +template +CUTE_HOST_DEVICE constexpr +void +explode_tuple(Fn fn, + TupleA&& a, int_sequence, + TupleB&& b, int_sequence, + TupleC&& c, int_sequence) +{ + return fn(get(a)..., get(b)..., get(c)...); +} + +} // end namespace detail + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_atom.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_atom.hpp new file mode 100644 index 0000000000000000000000000000000000000000..718b342124bbfc110da65d60af3563c60b4378b5 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_atom.hpp @@ -0,0 +1,691 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::Tensor +#include // cute::__CUTE_REQUIRES +#include // cute::is_tuple +#include // cute::is_constant, cute::is_integral +#include // cute::Copy_Traits +#include // cute::TiledMMA + +namespace cute +{ + +template +struct Copy_Atom; + +template +struct Copy_Atom : Copy_Atom, CopyInternalType> +{}; + +template +struct Copy_Atom, CopyInternalType> + : Copy_Traits +{ + using Traits = Copy_Traits; + + // Bit and Thr layouts from the Copy_Traits + using ThrID = typename Traits::ThrID; + using BitLayoutSrc = typename Traits::SrcLayout; + using BitLayoutDst = typename Traits::DstLayout; + using BitLayoutRef = typename Traits::RefLayout; + + using ValType = CopyInternalType; + + using ValLayoutSrc = decltype(recast_layout(BitLayoutSrc{})); + using ValLayoutDst = decltype(recast_layout(BitLayoutDst{})); + using ValLayoutRef = decltype(recast_layout(BitLayoutRef{})); + + CUTE_STATIC_ASSERT_V(size<0>(ValLayoutSrc{}) == size(ThrID{}), "CopyOperation is not valid for Src of ValType."); + CUTE_STATIC_ASSERT_V(size<0>(ValLayoutDst{}) == size(ThrID{}), "CopyOperation is not valid for Dst of ValType."); + CUTE_STATIC_ASSERT_V(size<0>(ValLayoutRef{}) == size(ThrID{}), "CopyOperation is not valid for Ref of ValType."); + + static constexpr int NumValSrc = size<1>(ValLayoutSrc{}); + static constexpr int NumValDst = size<1>(ValLayoutDst{}); + + // Additional Trait parameters/transformations + template + CUTE_HOST_DEVICE + auto + with(TraitsArgs&&... args) const { + auto traits = Traits::with(static_cast(args)...); + return Copy_Atom{traits}; + } + + // + // Tensor call interfaces + // + + // Check and call instruction, or recurse + template + CUTE_HOST_DEVICE + void + call(Tensor const& src, + Tensor & dst) const + { + static_assert(SLayout::rank == 1, "Expected rank-1 src tensor"); + static_assert(DLayout::rank == 1, "Expected rank-1 dst tensor"); + + if constexpr (is_constant::value || + is_constant::value) { + // Dispatch to unpack to execute instruction + return copy_unpack(static_cast(*this), src, dst); + } else if constexpr (is_tuple::value && + is_tuple::value) { + // If the size of the src/dst doesn't match the instruction, + // recurse this rank-1 layout by peeling off the mode + // ((A,B,C,...)) -> (A,B,C,...) + return copy(*this, tensor<0>(src), tensor<0>(dst)); + } else { + static_assert(dependent_false, + "CopyAtom: Src/Dst partitioning does not match the instruction requirement."); + } + } + + // Accept mutable temporaries + template + CUTE_HOST_DEVICE + void + call(Tensor const& src, + Tensor && dst) const + { + return call(src, dst); + } + + // Check and call instruction, or recurse + template + CUTE_HOST_DEVICE + void + call(Tensor const& prd, + Tensor const& src, + Tensor & dst) const + { + static_assert(PLayout::rank == 1, "Expected rank-1 prd tensor"); + static_assert(SLayout::rank == 1, "Expected rank-1 src tensor"); + static_assert(DLayout::rank == 1, "Expected rank-1 dst tensor"); + + if constexpr (is_constant::value || + is_constant::value) { + // Dispatch to unpack to execute instruction + Traits const& traits = static_cast(*this); + auto has_with_bool = cute::is_valid([](auto t)->void_t{}, traits); + if constexpr (has_with_bool) { + copy_unpack(traits.with(prd(Int<0>{})), src, dst); + } else { + if (prd(Int<0>{})) { copy_unpack(traits, src, dst); } + } + } else if constexpr (is_tuple::value && + is_tuple::value && + is_tuple::value) { + // If the size of the src/dst doesn't match the instruction, + // recurse this rank-1 layout by peeling off the mode + // ((A,B,C,...)) -> (A,B,C,...) + return copy_if(*this, tensor<0>(prd), tensor<0>(src), tensor<0>(dst)); + } else { + static_assert(dependent_false, + "CopyAtom: Src/Dst partitioning does not match the instruction requirement."); + } + } + + // Accept mutable temporaries + template + CUTE_HOST_DEVICE + void + call(Tensor const& prd, + Tensor const& src, + Tensor && dst) const + { + return call(prd, src, dst); + } +}; + +// +// A tiling of copy atoms +// + +template +struct ThrCopy; + +template coord [Need not be 2D...] + class ShapeTiler_MN> // coord space +struct TiledCopy : Copy_Atom +{ + // Layout information from the CopyAtom + using AtomThrID = typename Copy_Atom::ThrID; // thrid -> thr_idx + using AtomLayoutSrc = typename Copy_Atom::ValLayoutSrc; // (thr,val) -> offset + using AtomLayoutDst = typename Copy_Atom::ValLayoutDst; // (thr,val) -> offset + using AtomLayoutRef = typename Copy_Atom::ValLayoutRef; // (thr,val) -> offset + + using AtomNumThr = decltype(size<0>(AtomLayoutRef{})); + using AtomNumVal = decltype(size<1>(AtomLayoutRef{})); + + // Layout information for the TiledCopy + using Tiler_MN = ShapeTiler_MN; + using TiledLayout_TV = LayoutCopy_TV; + using TiledNumThr = decltype(size<0>(TiledLayout_TV{})); + using TiledNumVal = decltype(size<1>(TiledLayout_TV{})); + + CUTE_STATIC_ASSERT_V(TiledNumThr{} % AtomNumThr{} == Int<0>{}, "TiledCopy uses too few thrs for selected CopyAtom"); + CUTE_STATIC_ASSERT_V(TiledNumVal{} % AtomNumVal{} == Int<0>{}, "TiledCopy uses too few vals for selected CopyAtom"); + + // Tile a tensor or a layout from shape + // (M,N,...) + // to shape + // (Thr,(FrgV,FrgX),(RestM,RestN,...)) + // where + // Thr: The logical threads within the tiled copy. + // FrgV: The values local to a COPY_ATOM Src. + // FrgX: The values tiled across COPY_ATOMs Src. + // RestM: The values tiled in M. + // RestN: The values tiled in N. + template + CUTE_HOST_DEVICE constexpr static + auto + tidfrg_S(STensor&& stensor) + { + CUTE_STATIC_ASSERT_V(rank(stensor) >= rank(Tiler_MN{}), "Rank of tensor to be partitioned too small."); + + // Tile the stensor and compute the (src-thr, src-val) -> (ref-thr, ref-val) layout + return tile2thrfrg(zipped_divide(stensor,Tiler_MN{}), right_inverse(AtomLayoutRef{}).compose(AtomLayoutSrc{})); + } + + // Tile a tensor or a layout from shape + // (M,N,...) + // to shape + // (Thr,(FrgV,FrgX),(RestM,RestN,...)) + // where + // Thr: The logical threads within the tiled copy. + // FrgV: The values local to a COPY_ATOM Dst. + // FrgX: The values tiled across COPY_ATOMs Dst. + // RestM: The values tiled in M. + // RestN: The values tiled in N. + template + CUTE_HOST_DEVICE constexpr static + auto + tidfrg_D(DTensor&& dtensor) + { + CUTE_STATIC_ASSERT_V(rank(dtensor) >= rank(Tiler_MN{}), "Rank of tensor to be partitioned too small."); + + // Tile the dtensor and compute the (dst-thr, dst-val) -> (ref-thr, ref-val) layout + return tile2thrfrg(zipped_divide(dtensor,Tiler_MN{}), right_inverse(AtomLayoutRef{}).compose(AtomLayoutDst{})); + } + + // Tile a tensor or a layout from shape + // ((TileM,TileN,...), (RestM,RestN,...)) + // to shape + // (Thr,(FrgV,FrgX),(RestM,RestN,...)) + template + CUTE_HOST_DEVICE constexpr static + auto + tile2thrfrg(Tensor&& tensor, Ref2TrgLayout const& ref2trg) + { + // Take the thrs/vals that the atom is interested in + // NOTE: Assumes the AtomNumThr are contiguous and identity within TiledThrID + auto atom_layout_TV = zipped_divide(TiledLayout_TV{}, make_shape(AtomNumThr{}, AtomNumVal{})); + // ((atom_tid,atom_val),(rest_tid,rest_val)) -> (m,n) + + // Transform to the trg layout + auto trg_layout_TV = atom_layout_TV.compose(ref2trg, _); + // ((trg_tid,trg_val),(rest_tid,rest_val)) -> (m,n) + + // Transform the thrs mode from thrid to thr_idx + // NOTE: Assumes the AtomNumThr are contiguous and identity within TiledThrID + auto thrval2mn = coalesce(zip(trg_layout_TV), Shape<_1,Shape<_1,_1>>{}); + // ((trg_tid,rest_tid),(trg_val,rest_val)) -> (m,n) + + /// ================== + + // Transform the tile mode + auto tv_tensor = tensor.compose(thrval2mn, _); + // ((thrid,val),(RestM,RestN,...)) + + // Unfold and return + return tv_tensor(make_coord(_,_), _); + } + + // retile_S and retile_D assume they are working with the reference layout -- they are the same + template + CUTE_HOST_DEVICE constexpr static + auto + retile(Tensor&& tensor) + { + constexpr int R = remove_cvref_t::rank; + // Assert that AtomLayoutSrc|Dst is identity so we can skip the Ref transformation + + // Assume the first size<0>(tensor) elements are the first val_ids in TiledLayout_TV. + // Then, we only need the shape+layout of those size<0>(tensor) elements in TiledLayout_TV + // and that shape is what we gather from the other modes of tensor + + auto V = size<0>(tensor); + + auto frg_layout_mn = upcast(right_inverse(TiledLayout_TV{}).with_shape(shape(Tiler_MN{}))); + // (m,n) -> v_idx -- The shape and order of the V inside of TiledLayout_TV + + auto frg_layout_v = zipped_divide(logical_product(make_layout(V), right_inverse(frg_layout_mn)), make_layout(AtomNumVal{})); + // (atom_vals,rest_vals) -> (v,m,n) + + /// ======= + + // Tile the tensor for TileFrg + auto t_tensor = zipped_divide(tensor, prepend(product_each(shape(frg_layout_mn)), V)); + // ((TileV,TileM,TileN,...),(1,RestM,RestN,...)) + + // Transform the tile mode + auto v_tensor = t_tensor.compose(frg_layout_v, _); + // ((atom_vals,rest_vals),(1,RM,RN,...)) + + // Unfold and return + return v_tensor(_, append(Int<0>{},_)); + } + + CUTE_HOST_DEVICE constexpr static + auto + get_layoutS_TV() + { + // (M,N) -> (M,N) + auto ref_S = make_layout(make_shape(shape(Tiler_MN{}), Int<1>{})); + // (thr_idx,val_idx) -> (M,N) + return tile2thrfrg(ref_S, right_inverse(AtomLayoutRef{}).compose(AtomLayoutSrc{}))(_,_,Int<0>{}); + } + + CUTE_HOST_DEVICE constexpr static + auto + get_layoutD_TV() + { + // (M,N) -> (M,N) + auto ref_D = make_layout(make_shape(shape(Tiler_MN{}), Int<1>{})); + // (thr_idx,val_idx) -> (M,N) + return tile2thrfrg(ref_D, right_inverse(AtomLayoutRef{}).compose(AtomLayoutDst{}))(_,_,Int<0>{}); + } + + template ::value)> + CUTE_HOST_DEVICE static + auto + get_slice(ThrIdx const& thr_idx) + { + return ThrCopy(thr_idx); + } + + template ::value)> + CUTE_HOST_DEVICE static + auto + get_thread_slice(ThrIdx const& thr_idx) + { + return get_slice(thr_idx); + } +}; + +template +struct ThrCopy +{ + ThrIdx thr_idx_; + + CUTE_HOST_DEVICE + ThrCopy(ThrIdx const& thr_idx) : thr_idx_(thr_idx) {} + + template + CUTE_HOST_DEVICE + auto + partition_S(STensor&& stensor) const { + //static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename TiledCopy::ValType), + // "Expected ValType for tiling SrcTensor."); + auto thr_tensor = make_tensor(static_cast(stensor).data(), TiledCopy::tidfrg_S(stensor.layout())); + return thr_tensor(thr_idx_, _, repeat>(_)); + } + + template + CUTE_HOST_DEVICE + auto + partition_D(DTensor&& dtensor) const { + //static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename TiledCopy::ValType), + // "Expected ValType for tiling DstTensor."); + auto thr_tensor = make_tensor(static_cast(dtensor).data(), TiledCopy::tidfrg_D(dtensor.layout())); + return thr_tensor(thr_idx_, _, repeat>(_)); + } + + template + CUTE_HOST_DEVICE static + auto + retile_S(STensor&& stensor) { + // static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename TiledCopy::ValType), + // "Expected ValType for tiling SrcTensor."); + return make_tensor(static_cast(stensor).data(), TiledCopy::retile(stensor.layout())); + } + + template + CUTE_HOST_DEVICE static + auto + retile_D(DTensor&& dtensor) { + // static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename TiledCopy::ValType), + // "Expected ValType for tiling DstTensor."); + return make_tensor(static_cast(dtensor).data(), TiledCopy::retile(dtensor.layout())); + } +}; + + +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_impl(Copy_Atom const& atom, + LayoutCopy_TV const&, + Tiler const&) +{ + return TiledCopy, LayoutCopy_TV, Tiler>{atom}; +} + +// +// These tile the Copy_Atom as a whole +// + +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_A(Copy_Atom const& copy_atom, + TiledMMA const& mma) +{ + return make_tiled_copy_impl(copy_atom, mma.get_layoutA_TV(), make_shape(tile_size<0>(mma),tile_size<2>(mma))); +} + +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_B(Copy_Atom const& copy_atom, + TiledMMA const& mma) +{ + return make_tiled_copy_impl(copy_atom, mma.get_layoutB_TV(), make_shape(tile_size<1>(mma),tile_size<2>(mma))); +} + +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_C(Copy_Atom const& copy_atom, + TiledMMA const& mma) +{ + return make_tiled_copy_impl(copy_atom, mma.get_layoutC_TV(), make_shape(tile_size<0>(mma),tile_size<1>(mma))); +} + +// returns the smallest tiled copy that can retile LayoutC_TV +// for use with pipelined epilogues with subtiled stores +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_C_atom(Copy_Atom const& copy_atom, + TiledMMA const& mma) +{ + // Truncate the V-layout to just the Copy_Atom, keep the V-order + auto layoutC_TV = mma.get_layoutC_TV(); + auto copy_V = Int::NumValSrc>{}; + CUTE_STATIC_ASSERT_V(copy_V <= size<1>(layoutC_TV)); + auto layout_TV = composition(layoutC_TV, make_layout(make_shape(size<0>(layoutC_TV), copy_V))); + + // Recompute tiler and restride the TV layout for the new tiler + + // Tiler -- Find the active elements in the MMA tensor and generate a tiler to extract them + // Convert to the awkward by-mode tiler to preserve the modes of the tiled MMA + auto mma_tiler = make_shape(tile_size<0>(mma),tile_size<1>(mma)); + auto mma_zeros = repeat_like(mma_tiler, Int<0>{}); + + auto tiler = transform(make_seq{}, [&](auto i) { + return filter(composition(make_layout(mma_tiler, replace(mma_zeros, Int<1>{})), layout_TV)); + }); + + // Layout_TV -- Find the (tid,vid) -> tile coord transformation + // Apply the tiler to a reference and transform the codomain + // tile_coord -> mma_coord + auto tile2mma = composition(make_layout(mma_tiler), tiler); + + // (tid,vid) -> tile_coord + auto layout_tv = composition(left_inverse(tile2mma), layout_TV); + + return make_tiled_copy_impl(copy_atom, layout_tv, tiler); +} + +/** Produce a TiledCopy from logical thread and values layouts. + * The thread and value layouts map coordinates to thr_idx and val_idx. + * The product of these layouts is taken to produce the TV layout and the Tiler. + * Useful when threads and values need very specific mappings onto coordinates + * in the target tensors. + */ +template > +CUTE_HOST_DEVICE +auto +make_tiled_copy(Copy_Atom const& copy_atom, + ThrLayout const& thr_layout = {}, // (m,n) -> thr_idx + ValLayout const& val_layout = {}) // (m,n) -> val_idx +{ + // Take the raked_products to compute the Layout_MN + // (M,N) -> (thr_idx, val_idx) + auto layout_mn = raked_product(thr_layout, val_layout); + // (thr_idx, val_idx) -> (M,N) + auto layout_tv = right_inverse(layout_mn).with_shape(make_shape(size(thr_layout), size(val_layout))); + // Tiler for extracting relevant elements + // (M,N) -> tensor coord + auto tiler = product_each(shape(layout_mn)); + +#if 0 + print("thr_layout: "); print(thr_layout); print("\n"); + print("val_layout: "); print(val_layout); print("\n"); + print("layout_mn : "); print(layout_mn); print("\n"); + print("layout_tv : "); print(layout_tv); print("\n"); + print("tiler : "); print(tiler); print("\n"); +#endif + + return make_tiled_copy_impl(copy_atom, layout_tv, tiler); +} + +/** Produce a TiledCopy from thread and value offset maps. + * The TV Layout maps threads and values to the codomain of the data_layout. + * It is verified that the intended codomain is valid within data_layout. + * Useful when threads and values don't care about owning specific coordinates, but + * care more about the vector-width and offsets between them. + */ +template +CUTE_HOST_DEVICE constexpr +auto +make_cotiled_copy(Copy_Atom const& copy_atom, + AtomTVLayout const& atom_tv_layout, // atom (thr,val) -> data addr + DataLayout const& data_layout) // coord -> data addr The target layout +{ + static_assert(is_static::value); + static_assert(is_static::value); + + // data addr -> data coord Append 1:0 so off-the-ends get the stride-0 + auto inv_data_layout = make_layout(left_inverse(data_layout), Layout<_1,_0>{}); + + // (tid,vid) -> data_coord + auto layout_tv_data = composition(inv_data_layout, atom_tv_layout); + + // Check validity + // Append 1:0 to data_layout so that OOB coordinates get the stride-0 + CUTE_STATIC_ASSERT_V(coalesce(composition(make_layout(data_layout, Layout<_1,_0>{}), layout<1>(layout_tv_data))) == coalesce(layout<1>(atom_tv_layout)), + "The memory pointed to by AtomTVLayout does not exist in the DataLayout."); + // + // Tiler -- Find the active elements in the DATA tensor and generate a tiler to extract them + // + + // Convert to the awkward by-mode tiler to preserve the modes of the tiled DATA + auto flat_data_shape = product_each(shape(data_layout)); + auto flat_data_zeros = repeat(Int<0>{}); + + auto tiler = transform(make_seq{}, [&](auto i) { + return filter(composition(make_layout(flat_data_shape, replace(flat_data_zeros, Int<1>{})), layout_tv_data)); + }); + + // + // Layout_TV -- Find the (tid,vid) -> tile coord transformation + // + + // Apply the tiler to a reference and transform the codomain + // tile_coord -> data_coord + auto tile2data = composition(make_layout(flat_data_shape), tiler); + + // (tid,vid) -> tile_coord + auto layout_tv = composition(left_inverse(tile2data), layout_tv_data); + return make_tiled_copy_impl(copy_atom, layout_tv, tiler); +} + +// Make a TiledCopy out of the copy_atom that matches the Src-Layout of tiled_copy +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_S(Copy_Atom const& copy_atom, + TiledCopy const& tiled_copy) +{ + return make_tiled_copy_impl(copy_atom, tiled_copy.get_layoutS_TV(), typename TiledCopy::Tiler_MN{}); +} + +// Make a TiledCopy out of the copy_atom that matches the Dst-Layout of tiled_copy +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_D(Copy_Atom const& copy_atom, + TiledCopy const& tiled_copy) +{ + return make_tiled_copy_impl(copy_atom, tiled_copy.get_layoutD_TV(), typename TiledCopy::Tiler_MN{}); +} + +// +// Size +// + +// The logical size of a TileCopy +template +CUTE_HOST_DEVICE constexpr +auto +tile_size(TiledCopy const&) +{ + return size(typename TiledCopy::Tiler_MN{}); +} + +// The number of threads involved in a TiledCopy +template +CUTE_HOST_DEVICE constexpr +auto +size(TiledCopy const&) +{ + return typename TiledCopy::TiledNumThr{}; +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE +void +print(Copy_Atom, T> const&) +{ + using Atom = Copy_Atom, T>; + print("Copy_Atom\n"); + print(" ThrID: "); print(typename Atom::ThrID{}); print("\n"); + print(" ValLayoutSrc: "); print(typename Atom::ValLayoutSrc{}); print("\n"); + print(" ValLayoutDst: "); print(typename Atom::ValLayoutDst{}); print("\n"); + print(" ValLayoutRef: "); print(typename Atom::ValLayoutRef{}); print("\n"); + print(" ValueType: "); print(sizeof_bits::value); print("b\n"); +} + +template +CUTE_HOST_DEVICE +void +print(TiledCopy const& copy, char const* pad = "") +{ + using Copy = TiledCopy; + print("TiledCopy\n"); + print(" Tiler_MN: "); print(typename Copy::Tiler_MN{}); print("\n"); + print(" TiledLayout_TV: "); print(typename Copy::TiledLayout_TV{}); print("\n"); + print(static_cast(copy)); +} + +template +CUTE_HOST_DEVICE +void +print(ThrCopy const& thr_copy) +{ + print("ThrCopy\n"); + print(" ThrIdx: "); print(thr_copy.thr_idx_); print("\n"); + print(TiledCopy{}); +} + +} // end namespace cute + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include +#include + + +// Config +#if (__CUDACC_VER_MAJOR__ >= 12) +# define CUTE_COPY_ATOM_TMA_SM90_ENABLED +# define CUTE_COPY_ATOM_TMA_SM100_ENABLED +#endif + + +#if (!defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED)) +# define CUTE_COPY_ATOM_TMA_SM90_ENABLED +#endif + +#if (!defined(CUTE_COPY_ATOM_TMA_SM100_ENABLED)) +# define CUTE_COPY_ATOM_TMA_SM100_ENABLED +#endif + + +#if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED) +#include +#endif + + +#if defined(CUTE_COPY_ATOM_TMA_SM100_ENABLED) +#include +#endif + + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9117a1fb13de70e8a8ea74c8401e63e55164eef5 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits.hpp @@ -0,0 +1,162 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +namespace cute +{ + +/** + * concept Copy_Traits + * { + * using ThrID = // Logical thread id (tid) -> tidx + * + * using SrcLayout = // (Logical src thread id (tid), Logical src value id (vid)) -> bit + * using DstLayout = // (Logical dst thread id (tid), Logical dst value id (vid)) -> bit + * using RefLayout = // (Logical ref thread id (tid), Logical ref value id (vid)) -> bit + * }; + * + * The abstract bit ordering of the Copy_Traits (the codomain of SrcLayout, DstLayout, and RefLayout) + * is arbitrary and only used to construct maps + * (ref-tid,ref-vid) -> (src-tid,src-vid) + * (ref-tid,ref-vid) -> (dst-tid,dst-vid) + * in TiledCopy. The Layout_TV in TiledCopy is in accordance with the RefLayout of a Traits, then mapped to + * the Src or Dst (tid,vid) representation on demand. + * + */ + +template +struct Copy_Traits +{ + static_assert(dependent_false, "Copy_Traits not implemented for this CopyOperation."); +}; + +template +struct Copy_Traits> +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout::value>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout::value>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template +struct Copy_Traits> +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride<_0,_0>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Stride<_0,_0>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +// Extract a CPY_Op from a CPY_Traits +template +struct CPY_Op {}; + +template +struct CPY_Op> { + using type = CPY_Op_Arg; +}; + +// +// Generic copy_unpack for common argument-based Copy_Traits +// + +template +CUTE_HOST_DEVICE constexpr +void +copy_unpack(AnyCPYTraits const&, + Tensor const& src, + Tensor & dst) +{ + using CopyOp = typename CPY_Op::type; + using RegistersSrc = typename CopyOp::SRegisters; + using RegistersDst = typename CopyOp::DRegisters; + using RegTypeSrc = typename remove_extent::type; + using RegTypeDst = typename remove_extent::type; + constexpr int RegNumSrc = extent::value; + constexpr int RegNumDst = extent::value; + + Tensor rS = recast(src); + Tensor rD = recast(dst); + + CUTE_STATIC_ASSERT_V(size(rS) == Int{}, + "Copy_Traits: src failed to vectorize into registers. Layout is incompatible with this CopyOp."); + CUTE_STATIC_ASSERT_V(size(rD) == Int{}, + "Copy_Traits: dst failed to vectorize into registers. Layout is incompatible with this CopyOp."); + + detail::explode(detail::CallCOPY{}, + rS, make_int_sequence{}, + rD, make_int_sequence{}); +} + +// Accept mutable temporaries +template +CUTE_HOST_DEVICE constexpr +void +copy_unpack(AnyCPYTraits const& traits, + Tensor const& src, + Tensor && dst) +{ + copy_unpack(traits, src, dst); +} + +namespace detail { + +template +constexpr bool is_prefetch = false; + +template +constexpr bool is_prefetch> = is_same_v; + +} // end namespace detail + + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm100.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm100.hpp new file mode 100644 index 0000000000000000000000000000000000000000..216d739e47dbeb95a2be1bd88e5ac5e1ded2be29 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm100.hpp @@ -0,0 +1,3798 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include + +#include +#include +#include +#include + +#include + +namespace cute +{ +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,_128>, + Stride, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_8, _2, _2, _2>>, + Stride,Stride<_1,_128,_64,_1024>>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,_128>, + Stride, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_8, _2, _2, _4>>, + Stride,Stride<_1,_128,_64,_1024>>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,_128>, + Stride, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride<_32, _1>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,_128>, + Stride, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride<_32, _1>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,_128>, + Stride, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1,_1024>>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,_128>, + Stride, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1,_1024>>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride<_128, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1,_1024>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride<_128, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1,_1024>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,Shape <_8, _2, _2>>, + Stride,Stride<_1,_128,_64>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,_128>, + Stride, _1>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,Shape <_8, _2, _2, _2>>, + Stride,Stride<_1,_128,_64,_1024>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,_128>, + Stride, _1>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,Shape <_8, _2, _2, _4>>, + Stride,Stride<_1,_128,_64,_1024>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride<_128, _1>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// TMEM Traits and Utilities +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Copy_Atom; + +/** Generate a TiledCopy from a CopyAtom and a TMEM tensor + * Example: + * Tensor gmem_tensor = ... // (M,N,...) + * Tensor tmem_tensor = ... // (M,N,...) + * auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD_Operation, tmem_tensor); + * auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + * + * Tensor tDtC = thr_tmem_load.partition_S(tmem_tensor); // (TMEM_LOAD,TMEM_LOAD_M,TMEM_LOAD_N,...) + * Tensor tDgC = thr_tmem_load.partition_D(gmem_tensor); // (TMEM_LOAD,TMEM_LOAD_M,TMEM_LOAD_N,...) + * Tensor tDrC = make_tensor(shape(tDgD)); // (TMEM_LOAD,TMEM_LOAD_M,TMEM_LOAD_N,...) + * + * copy(tiled_tmem_load, tDtC, tDrC); // tmem -> rmem + * copy(tDrC, tDgC); // rmem -> gmem + */ +template +CUTE_HOST_DEVICE constexpr +auto +make_tmem_copy(Copy_Atom const& atom, + Tensor const& tmem) +{ + static_assert(is_tmem::value, "Expected TMEM tensor."); + using T = typename TEngine::value_type; + using Traits = typename Copy_Atom::Traits; + static_assert(sizeof_bits_v == sizeof_bits_v, + "Expected a CopyAtom with the same type-width as the Tensor."); + + // atom thr idx -> tmem addr 4warps where each warp points to the same position within it's own subpartition + auto atom_t_layout = Layout, Stride<_0, decltype(Int<32>{} * TMEM::DP{})>>{}; + // atom val idx -> tmem addr Cast the CopyOp's value ids to the proper data width + auto atom_v_layout = coalesce(upcast::value>(typename Traits::ValID{})); + + return make_cotiled_copy(atom, make_layout(atom_t_layout, atom_v_layout), tmem.layout()); +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_tmem_copy(CopyOp const&, + Tensor const& tmem) +{ + return make_tmem_copy(Copy_Atom{}, tmem); +} + +/** Generate a TV_Tiler from a TMEM tensor + * Example: + * Tensor gmem_tensor = ... // (M,N,...) + * Tensor tmem_tensor = ... // (M,N,...) + * auto tmem_tiler = make_tmem_warp_partitioner(tmem_tensor); + * auto warp_tiler = tmem_tiler.get_slice(warp_idx); + * + * Tensor tWtC = warp_tiler.partition(tmem_tensor); // (WARP_M,WARP_N,...) + * Tensor tWgC = warp_tiler.partition(gmem_tensor); // (WARP_M,WARP_N,...) + */ +template +CUTE_HOST_DEVICE constexpr +auto +make_tmem_warp_partitioner(Tensor const& tmem) +{ + static_assert(is_tmem::value, "Expected TMEM tensor."); + using T = typename TEngine::value_type; + + // warp idx -> tmem addr This is the T in the Layout_TV + auto atom_t_layout = Layout<_4, decltype(Int<32>{} * TMEM::DP{})>{}; + + // tmem coord -> tmem addr + auto tmem_layout = tmem.layout(); + // tmem addr -> tmem coord Append 1:0 so off-the-ends get the stride-0 + auto inv_tmem_layout = make_layout(left_inverse(tmem_layout), Layout<_1,_0>{}); + + // wid -> tmem_coord + auto layout_t_tmem = composition(inv_tmem_layout, atom_t_layout); + // + // Tiler -- Find the active elements in the TMEM tensor and generate a tiler to extract them + // + + // Convert to the awkward by-mode tiler to preserve the modes of the tiled TMEM + auto flat_tmem_shape = product_each(shape(tmem_layout)); + auto flat_tmem_zeros = repeat(Int<0>{}); + + auto tiler = transform(make_seq{}, [&](auto i) { + return filter(composition(make_layout(flat_tmem_shape, replace(flat_tmem_zeros, Int<1>{})), layout_t_tmem)); + }); + + // + // Layout_TV -- Find the (tid,vid) -> tile coord transformation + // + + // Apply the tiler to a reference and transform the codomain + // tile_coord -> tmem_coord + auto tile2tmem = composition(make_layout(flat_tmem_shape), tiler); + + // wid -> tile_coord + auto layout_tv = composition(left_inverse(tile2tmem), layout_t_tmem); + return make_tiler_impl(layout_tv, tiler); +} + +namespace SM100::TMEM::LOAD { + +// +// Specialized copy_unpack implementation for SM100::TMEM::LOAD instructions +// + +template +CUTE_HOST_DEVICE constexpr +void +copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) +{ + static_assert(is_tmem::value, "Expected TMEM src."); + static_assert(is_rmem::value, "Expected RMEM dst."); + + using SrcType = typename TS::value_type; + CUTE_STATIC_ASSERT_V((coalesce(layout(src)) == coalesce(upcast::value>(typename Copy_Traits::ValID{}))), + "Expected src to have the specific TMEM layout required by CopyOp."); + + uint32_t tmem_addr = raw_pointer_cast(src.data()); + + using RegTypeDst = typename remove_extent::type; + Tensor rD = recast(dst); + + constexpr int RegNumDst = extent::value; + CUTE_STATIC_ASSERT_V(size(rD) == Int{}, + "In CopyAtom, dst layout doesn't vectorize into registers. This dst layout is incompatible with this CopyOp."); + + // thread idx <=> DP lane assert. + // ASSERT TMEM_LOAD thread attemping to access DP lane within sub-partition. +#if defined(__CUDA_ARCH__) && !defined(NDEBUG) + assert(((uint32_t(threadIdx.x) / 32) % 4) == (((tmem_addr >> 16) / 32) % 4)); +#endif + + detail::explode(CopyOp::copy, + &tmem_addr, seq<0>{}, + rD, make_seq{}); +} + +} // end namespace SM100::TMEM::LOAD + +namespace SM100::TMEM::STORE { + +// +// Specialized copy_unpack implementation for SM100::TMEM::STORE instructions +// + +template +CUTE_HOST_DEVICE constexpr +void +copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) +{ + static_assert(is_rmem::value, "Expected RMEM src."); + static_assert(is_tmem::value, "Expected TMEM dst."); + + using RegTypeSrc = typename remove_extent::type; + Tensor rS = recast(src); + + constexpr int RegNumSrc = extent::value; + CUTE_STATIC_ASSERT_V(size(rS) == Int{}, + "In CopyAtom, src layout doesn't vectorize into registers. This src layout is incompatible with this tiled copy."); + + using DstType = typename TD::value_type; + CUTE_STATIC_ASSERT_V((coalesce(layout(dst)) == coalesce(upcast::value>(typename Copy_Traits::ValID{}))), + "Expected dst to have the specific TMEM layout required by CopyOp."); + + uint32_t tmem_addr = raw_pointer_cast(dst.data()); + + // thread idx <=> DP lane assert. + // ASSERT TMEM_LOAD thread attemping to access DP lane within sub-partition. +#if defined(__CUDA_ARCH__) && !defined(NDEBUG) + assert(((uint32_t(threadIdx.x) / 32) % 4) == (((tmem_addr >> 16) / 32) % 4)); +#endif + + detail::explode(CopyOp::copy, + rS, make_seq{}, + &tmem_addr, seq<0>{}); +} + +} // end namespace SM100::TMEM::STORE + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// TMEM_LOAD Copy Traits +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b1x; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + // Logical bit id to bit idx (address) + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_64, _2>>, + Stride,Stride< _1,_2048>>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b1x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2>>, + Stride,Stride< _1,_2048>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b2x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _2>>, + Stride,Stride< _1,_4096,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b2x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _2>>, + Stride,Stride< _1,_4096,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b4x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _4>>, + Stride,Stride< _1,_8192,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b4x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _4>>, + Stride,Stride< _1,_8192,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b8x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _8>>, + Stride,Stride< _1,_16384,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b8x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _8>>, + Stride,Stride< _1,_16384,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b16x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _16>>, + Stride,Stride< _1,_32768,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b16x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _16>>, + Stride,Stride< _1,_32768,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b32x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _32>>, + Stride,Stride< _1,_65536,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b32x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _32>>, + Stride,Stride< _1,_65536,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b1x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2>>, + Stride,Stride< _1,_1024>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b1x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2>>, + Stride,Stride< _1,_1024>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b2x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _2>>, + Stride,Stride< _1,_2048,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b2x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _2>>, + Stride,Stride< _1,_2048,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b4x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _4>>, + Stride,Stride< _1,_4096,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b4x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _4>>, + Stride,Stride< _1,_4096,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b8x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _8>>, + Stride,Stride< _1,_8192,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b8x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _8>>, + Stride,Stride< _1,_8192,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b16x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _16>>, + Stride,Stride< _1,_16384,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b16x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _16>>, + Stride,Stride< _1,_16384,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b32x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _32>>, + Stride,Stride< _1,_32768,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b32x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _32>>, + Stride,Stride< _1,_32768,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b64x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _64>>, + Stride,Stride< _1,_65536,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b64x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _64>>, + Stride,Stride< _1,_65536,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b1x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_32>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b1x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_32>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b2x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b2x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b4x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _4>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b4x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _4>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b8x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _8>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b8x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _8>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b16x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32,_16>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b16x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32,_16>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b32x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32,_32>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b32x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32,_32>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b64x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32,_64>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b64x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32,_64>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b128x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32,_128>>, + Stride,Stride< _1, _64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b128x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32,_128>>, + Stride,Stride< _1, _64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b1x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_32>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b1x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_32>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b2x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_64>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b2x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_64>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b4x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_128>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b4x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_128>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b8x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_256>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b8x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_256>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b16x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_512>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b16x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_512>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b32x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_1024>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b32x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_1024>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b64x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_2048>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b64x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_2048>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b128x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_4096>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b128x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_4096>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b1x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_32, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b1x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _32>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_32, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b2x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_64, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b2x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _32>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_64, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b4x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_128, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b4x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _32>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_128, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b8x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_256, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b8x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _32>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_256, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b16x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_512, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b16x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _32>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_512, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b32x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_1024, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b32x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _32>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_1024, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_2048, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _32>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_2048, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b128x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_4096, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b128x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _32>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_4096, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// TMEM_STORE Copy Traits +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b1x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b1x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b2x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b2x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b4x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b4x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b8x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b8x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b16x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b16x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b32x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b32x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b1x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b1x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b2x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b2x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b4x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b4x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b8x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b8x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b16x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b16x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b32x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b32x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b64x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b64x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b1x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b1x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b2x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b2x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b4x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b4x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b8x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b8x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b16x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b16x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b32x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b32x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b64x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b64x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b128x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b128x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b1x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b1x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b2x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b2x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b4x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b4x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b8x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b8x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b16x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b16x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b32x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b32x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b64x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b64x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b128x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b128x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b1x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b1x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b2x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b2x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b4x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b4x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b8x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b8x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b16x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b16x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b32x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b32x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b64x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b64x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b128x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b128x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace TMEM { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Given a 1x tmem copy op, returns the widest repeated variant that divides the specified bits in the N-mode +template +CUTE_HOST_DEVICE constexpr +auto +op_repeater() +{ + if constexpr (cute::is_same_v) { + if constexpr (bits_n % (256 * 32) == 0) { + return SM100_TMEM_LOAD_16dp256b32x{}; + } + else if constexpr (bits_n % (256 * 16) == 0) { + return SM100_TMEM_LOAD_16dp256b16x{}; + } + else if constexpr (bits_n % (256 * 8) == 0) { + return SM100_TMEM_LOAD_16dp256b8x{}; + } + else if constexpr (bits_n % (256 * 4) == 0) { + return SM100_TMEM_LOAD_16dp256b4x{}; + } + else if constexpr (bits_n % (256 * 2) == 0) { + return SM100_TMEM_LOAD_16dp256b2x{}; + } + else if constexpr (bits_n % (256 * 1) == 0) { + return SM100_TMEM_LOAD_16dp256b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (256 * 32) == 0) { + return SM100_TMEM_LOAD_16dp256b32x_16b{}; + } + else if constexpr (bits_n % (256 * 16) == 0) { + return SM100_TMEM_LOAD_16dp256b16x_16b{}; + } + else if constexpr (bits_n % (256 * 8) == 0) { + return SM100_TMEM_LOAD_16dp256b8x_16b{}; + } + else if constexpr (bits_n % (256 * 4) == 0) { + return SM100_TMEM_LOAD_16dp256b4x_16b{}; + } + else if constexpr (bits_n % (256 * 2) == 0) { + return SM100_TMEM_LOAD_16dp256b2x_16b{}; + } + else if constexpr (bits_n % (256 * 1) == 0) { + return SM100_TMEM_LOAD_16dp256b1x_16b{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (128 * 64) == 0) { + return SM100_TMEM_LOAD_16dp128b64x{}; + } + else if constexpr (bits_n % (128 * 32) == 0) { + return SM100_TMEM_LOAD_16dp128b32x{}; + } + else if constexpr (bits_n % (128 * 16) == 0) { + return SM100_TMEM_LOAD_16dp128b16x{}; + } + else if constexpr (bits_n % (128 * 8) == 0) { + return SM100_TMEM_LOAD_16dp128b8x{}; + } + else if constexpr (bits_n % (128 * 4) == 0) { + return SM100_TMEM_LOAD_16dp128b4x{}; + } + else if constexpr (bits_n % (128 * 2) == 0) { + return SM100_TMEM_LOAD_16dp128b2x{}; + } + else if constexpr (bits_n % (128 * 1) == 0) { + return SM100_TMEM_LOAD_16dp128b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (128 * 64) == 0) { + return SM100_TMEM_LOAD_16dp128b64x_16b{}; + } + else if constexpr (bits_n % (128 * 32) == 0) { + return SM100_TMEM_LOAD_16dp128b32x_16b{}; + } + else if constexpr (bits_n % (128 * 16) == 0) { + return SM100_TMEM_LOAD_16dp128b16x_16b{}; + } + else if constexpr (bits_n % (128 * 8) == 0) { + return SM100_TMEM_LOAD_16dp128b8x_16b{}; + } + else if constexpr (bits_n % (128 * 4) == 0) { + return SM100_TMEM_LOAD_16dp128b4x_16b{}; + } + else if constexpr (bits_n % (128 * 2) == 0) { + return SM100_TMEM_LOAD_16dp128b2x_16b{}; + } + else if constexpr (bits_n % (128 * 1) == 0) { + return SM100_TMEM_LOAD_16dp128b1x_16b{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (64 * 128) == 0) { + return SM100_TMEM_LOAD_16dp64b128x{}; + } + else if constexpr (bits_n % (64 * 64) == 0) { + return SM100_TMEM_LOAD_16dp64b64x{}; + } + else if constexpr (bits_n % (64 * 32) == 0) { + return SM100_TMEM_LOAD_16dp64b32x{}; + } + else if constexpr (bits_n % (64 * 16) == 0) { + return SM100_TMEM_LOAD_16dp64b16x{}; + } + else if constexpr (bits_n % (64 * 8) == 0) { + return SM100_TMEM_LOAD_16dp64b8x{}; + } + else if constexpr (bits_n % (64 * 4) == 0) { + return SM100_TMEM_LOAD_16dp64b4x{}; + } + else if constexpr (bits_n % (64 * 2) == 0) { + return SM100_TMEM_LOAD_16dp64b2x{}; + } + else if constexpr (bits_n % (64 * 1) == 0) { + return SM100_TMEM_LOAD_16dp64b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (64 * 128) == 0) { + return SM100_TMEM_LOAD_16dp64b128x_16b{}; + } + else if constexpr (bits_n % (64 * 64) == 0) { + return SM100_TMEM_LOAD_16dp64b64x_16b{}; + } + else if constexpr (bits_n % (64 * 32) == 0) { + return SM100_TMEM_LOAD_16dp64b32x_16b{}; + } + else if constexpr (bits_n % (64 * 16) == 0) { + return SM100_TMEM_LOAD_16dp64b16x_16b{}; + } + else if constexpr (bits_n % (64 * 8) == 0) { + return SM100_TMEM_LOAD_16dp64b8x_16b{}; + } + else if constexpr (bits_n % (64 * 4) == 0) { + return SM100_TMEM_LOAD_16dp64b4x_16b{}; + } + else if constexpr (bits_n % (64 * 2) == 0) { + return SM100_TMEM_LOAD_16dp64b2x_16b{}; + } + else if constexpr (bits_n % (64 * 1) == 0) { + return SM100_TMEM_LOAD_16dp64b1x_16b{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (64 * 128) == 0) { + return SM100_TMEM_LOAD_16dp32b128x{}; + } + else if constexpr (bits_n % (64 * 64) == 0) { + return SM100_TMEM_LOAD_16dp32b64x{}; + } + else if constexpr (bits_n % (64 * 32) == 0) { + return SM100_TMEM_LOAD_16dp32b32x{}; + } + else if constexpr (bits_n % (64 * 16) == 0) { + return SM100_TMEM_LOAD_16dp32b16x{}; + } + else if constexpr (bits_n % (64 * 8) == 0) { + return SM100_TMEM_LOAD_16dp32b8x{}; + } + else if constexpr (bits_n % (64 * 4) == 0) { + return SM100_TMEM_LOAD_16dp32b4x{}; + } + else if constexpr (bits_n % (64 * 2) == 0) { + return SM100_TMEM_LOAD_16dp32b2x{}; + } + else if constexpr (bits_n % (64 * 1) == 0) { + return SM100_TMEM_LOAD_16dp32b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (64 * 128) == 0) { + return SM100_TMEM_LOAD_16dp32b128x_16b{}; + } + else if constexpr (bits_n % (64 * 64) == 0) { + return SM100_TMEM_LOAD_16dp32b64x_16b{}; + } + else if constexpr (bits_n % (64 * 32) == 0) { + return SM100_TMEM_LOAD_16dp32b32x_16b{}; + } + else if constexpr (bits_n % (64 * 16) == 0) { + return SM100_TMEM_LOAD_16dp32b16x_16b{}; + } + else if constexpr (bits_n % (64 * 8) == 0) { + return SM100_TMEM_LOAD_16dp32b8x_16b{}; + } + else if constexpr (bits_n % (64 * 4) == 0) { + return SM100_TMEM_LOAD_16dp32b4x_16b{}; + } + else if constexpr (bits_n % (64 * 2) == 0) { + return SM100_TMEM_LOAD_16dp32b2x_16b{}; + } + else if constexpr (bits_n % (64 * 1) == 0) { + return SM100_TMEM_LOAD_16dp32b1x_16b{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (32 * 128) == 0) { + return SM100_TMEM_LOAD_32dp32b128x{}; + } + else if constexpr (bits_n % (32 * 64) == 0) { + return SM100_TMEM_LOAD_32dp32b64x{}; + } + else if constexpr (bits_n % (32 * 32) == 0) { + return SM100_TMEM_LOAD_32dp32b32x{}; + } + else if constexpr (bits_n % (32 * 16) == 0) { + return SM100_TMEM_LOAD_32dp32b16x{}; + } + else if constexpr (bits_n % (32 * 8) == 0) { + return SM100_TMEM_LOAD_32dp32b8x{}; + } + else if constexpr (bits_n % (32 * 4) == 0) { + return SM100_TMEM_LOAD_32dp32b4x{}; + } + else if constexpr (bits_n % (32 * 2) == 0) { + return SM100_TMEM_LOAD_32dp32b2x{}; + } + else if constexpr (bits_n % (32 * 1) == 0) { + return SM100_TMEM_LOAD_32dp32b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (32 * 128) == 0) { + return SM100_TMEM_LOAD_32dp32b128x_16b{}; + } + else if constexpr (bits_n % (32 * 64) == 0) { + return SM100_TMEM_LOAD_32dp32b64x_16b{}; + } + else if constexpr (bits_n % (32 * 32) == 0) { + return SM100_TMEM_LOAD_32dp32b32x_16b{}; + } + else if constexpr (bits_n % (32 * 16) == 0) { + return SM100_TMEM_LOAD_32dp32b16x_16b{}; + } + else if constexpr (bits_n % (32 * 8) == 0) { + return SM100_TMEM_LOAD_32dp32b8x_16b{}; + } + else if constexpr (bits_n % (32 * 4) == 0) { + return SM100_TMEM_LOAD_32dp32b4x_16b{}; + } + else if constexpr (bits_n % (32 * 2) == 0) { + return SM100_TMEM_LOAD_32dp32b2x_16b{}; + } + else if constexpr (bits_n % (32 * 1) == 0) { + return SM100_TMEM_LOAD_32dp32b1x_16b{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (256 * 32) == 0) { + return SM100_TMEM_STORE_16dp256b32x{}; + } + else if constexpr (bits_n % (256 * 16) == 0) { + return SM100_TMEM_STORE_16dp256b16x{}; + } + else if constexpr (bits_n % (256 * 8) == 0) { + return SM100_TMEM_STORE_16dp256b8x{}; + } + else if constexpr (bits_n % (256 * 4) == 0) { + return SM100_TMEM_STORE_16dp256b4x{}; + } + else if constexpr (bits_n % (256 * 2) == 0) { + return SM100_TMEM_STORE_16dp256b2x{}; + } + else if constexpr (bits_n % (256 * 1) == 0) { + return SM100_TMEM_STORE_16dp256b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (256 * 32) == 0) { + return SM100_TMEM_STORE_16dp256b32x_16b{}; + } + else if constexpr (bits_n % (256 * 16) == 0) { + return SM100_TMEM_STORE_16dp256b16x_16b{}; + } + else if constexpr (bits_n % (256 * 8) == 0) { + return SM100_TMEM_STORE_16dp256b8x_16b{}; + } + else if constexpr (bits_n % (256 * 4) == 0) { + return SM100_TMEM_STORE_16dp256b4x_16b{}; + } + else if constexpr (bits_n % (256 * 2) == 0) { + return SM100_TMEM_STORE_16dp256b2x_16b{}; + } + else if constexpr (bits_n % (256 * 1) == 0) { + return SM100_TMEM_STORE_16dp256b1x_16b{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (128 * 64) == 0) { + return SM100_TMEM_STORE_16dp128b64x{}; + } + else if constexpr (bits_n % (128 * 32) == 0) { + return SM100_TMEM_STORE_16dp128b32x{}; + } + else if constexpr (bits_n % (128 * 16) == 0) { + return SM100_TMEM_STORE_16dp128b16x{}; + } + else if constexpr (bits_n % (128 * 8) == 0) { + return SM100_TMEM_STORE_16dp128b8x{}; + } + else if constexpr (bits_n % (128 * 4) == 0) { + return SM100_TMEM_STORE_16dp128b4x{}; + } + else if constexpr (bits_n % (128 * 2) == 0) { + return SM100_TMEM_STORE_16dp128b2x{}; + } + else if constexpr (bits_n % (128 * 1) == 0) { + return SM100_TMEM_STORE_16dp128b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (128 * 64) == 0) { + return SM100_TMEM_STORE_16dp128b64x_16b{}; + } + else if constexpr (bits_n % (128 * 32) == 0) { + return SM100_TMEM_STORE_16dp128b32x_16b{}; + } + else if constexpr (bits_n % (128 * 16) == 0) { + return SM100_TMEM_STORE_16dp128b16x_16b{}; + } + else if constexpr (bits_n % (128 * 8) == 0) { + return SM100_TMEM_STORE_16dp128b8x_16b{}; + } + else if constexpr (bits_n % (128 * 4) == 0) { + return SM100_TMEM_STORE_16dp128b4x_16b{}; + } + else if constexpr (bits_n % (128 * 2) == 0) { + return SM100_TMEM_STORE_16dp128b2x_16b{}; + } + else if constexpr (bits_n % (128 * 1) == 0) { + return SM100_TMEM_STORE_16dp128b1x_16b{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (64 * 128) == 0) { + return SM100_TMEM_STORE_16dp64b128x{}; + } + else if constexpr (bits_n % (64 * 64) == 0) { + return SM100_TMEM_STORE_16dp64b64x{}; + } + else if constexpr (bits_n % (64 * 32) == 0) { + return SM100_TMEM_STORE_16dp64b32x{}; + } + else if constexpr (bits_n % (64 * 16) == 0) { + return SM100_TMEM_STORE_16dp64b16x{}; + } + else if constexpr (bits_n % (64 * 8) == 0) { + return SM100_TMEM_STORE_16dp64b8x{}; + } + else if constexpr (bits_n % (64 * 4) == 0) { + return SM100_TMEM_STORE_16dp64b4x{}; + } + else if constexpr (bits_n % (64 * 2) == 0) { + return SM100_TMEM_STORE_16dp64b2x{}; + } + else if constexpr (bits_n % (64 * 1) == 0) { + return SM100_TMEM_STORE_16dp64b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (64 * 128) == 0) { + return SM100_TMEM_STORE_16dp64b128x_16b{}; + } + else if constexpr (bits_n % (64 * 64) == 0) { + return SM100_TMEM_STORE_16dp64b64x_16b{}; + } + else if constexpr (bits_n % (64 * 32) == 0) { + return SM100_TMEM_STORE_16dp64b32x_16b{}; + } + else if constexpr (bits_n % (64 * 16) == 0) { + return SM100_TMEM_STORE_16dp64b16x_16b{}; + } + else if constexpr (bits_n % (64 * 8) == 0) { + return SM100_TMEM_STORE_16dp64b8x_16b{}; + } + else if constexpr (bits_n % (64 * 4) == 0) { + return SM100_TMEM_STORE_16dp64b4x_16b{}; + } + else if constexpr (bits_n % (64 * 2) == 0) { + return SM100_TMEM_STORE_16dp64b2x_16b{}; + } + else if constexpr (bits_n % (64 * 1) == 0) { + return SM100_TMEM_STORE_16dp64b1x_16b{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (64 * 128) == 0) { + return SM100_TMEM_STORE_16dp32b128x{}; + } + else if constexpr (bits_n % (64 * 64) == 0) { + return SM100_TMEM_STORE_16dp32b64x{}; + } + else if constexpr (bits_n % (64 * 32) == 0) { + return SM100_TMEM_STORE_16dp32b32x{}; + } + else if constexpr (bits_n % (64 * 16) == 0) { + return SM100_TMEM_STORE_16dp32b16x{}; + } + else if constexpr (bits_n % (64 * 8) == 0) { + return SM100_TMEM_STORE_16dp32b8x{}; + } + else if constexpr (bits_n % (64 * 4) == 0) { + return SM100_TMEM_STORE_16dp32b4x{}; + } + else if constexpr (bits_n % (64 * 2) == 0) { + return SM100_TMEM_STORE_16dp32b2x{}; + } + else if constexpr (bits_n % (64 * 1) == 0) { + return SM100_TMEM_STORE_16dp32b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (64 * 128) == 0) { + return SM100_TMEM_STORE_16dp32b128x_16b{}; + } + else if constexpr (bits_n % (64 * 64) == 0) { + return SM100_TMEM_STORE_16dp32b64x_16b{}; + } + else if constexpr (bits_n % (64 * 32) == 0) { + return SM100_TMEM_STORE_16dp32b32x_16b{}; + } + else if constexpr (bits_n % (64 * 16) == 0) { + return SM100_TMEM_STORE_16dp32b16x_16b{}; + } + else if constexpr (bits_n % (64 * 8) == 0) { + return SM100_TMEM_STORE_16dp32b8x_16b{}; + } + else if constexpr (bits_n % (64 * 4) == 0) { + return SM100_TMEM_STORE_16dp32b4x_16b{}; + } + else if constexpr (bits_n % (64 * 2) == 0) { + return SM100_TMEM_STORE_16dp32b2x_16b{}; + } + else if constexpr (bits_n % (64 * 1) == 0) { + return SM100_TMEM_STORE_16dp32b1x_16b{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (32 * 128) == 0) { + return SM100_TMEM_STORE_32dp32b128x{}; + } + else if constexpr (bits_n % (32 * 64) == 0) { + return SM100_TMEM_STORE_32dp32b64x{}; + } + else if constexpr (bits_n % (32 * 32) == 0) { + return SM100_TMEM_STORE_32dp32b32x{}; + } + else if constexpr (bits_n % (32 * 16) == 0) { + return SM100_TMEM_STORE_32dp32b16x{}; + } + else if constexpr (bits_n % (32 * 8) == 0) { + return SM100_TMEM_STORE_32dp32b8x{}; + } + else if constexpr (bits_n % (32 * 4) == 0) { + return SM100_TMEM_STORE_32dp32b4x{}; + } + else if constexpr (bits_n % (32 * 2) == 0) { + return SM100_TMEM_STORE_32dp32b2x{}; + } + else if constexpr (bits_n % (32 * 1) == 0) { + return SM100_TMEM_STORE_32dp32b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (32 * 128) == 0) { + return SM100_TMEM_STORE_32dp32b128x_16b{}; + } + else if constexpr (bits_n % (32 * 64) == 0) { + return SM100_TMEM_STORE_32dp32b64x_16b{}; + } + else if constexpr (bits_n % (32 * 32) == 0) { + return SM100_TMEM_STORE_32dp32b32x_16b{}; + } + else if constexpr (bits_n % (32 * 16) == 0) { + return SM100_TMEM_STORE_32dp32b16x_16b{}; + } + else if constexpr (bits_n % (32 * 8) == 0) { + return SM100_TMEM_STORE_32dp32b8x_16b{}; + } + else if constexpr (bits_n % (32 * 4) == 0) { + return SM100_TMEM_STORE_32dp32b4x_16b{}; + } + else if constexpr (bits_n % (32 * 2) == 0) { + return SM100_TMEM_STORE_32dp32b2x_16b{}; + } + else if constexpr (bits_n % (32 * 1) == 0) { + return SM100_TMEM_STORE_32dp32b1x_16b{}; + } + } + else { + static_assert(dependent_false, "Must pass 1x tmem copy operator"); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Select TMEM store corresponding to the provided TMEM load +template +CUTE_HOST_DEVICE constexpr auto +tmem_load_to_store(CopyOp) { + if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b1x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b1x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b2x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b2x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b4x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b4x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b8x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b8x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b16x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b16x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b32x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b32x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b1x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b1x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b2x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b2x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b4x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b4x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b8x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b8x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b16x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b16x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b32x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b32x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b64x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b64x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b1x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b1x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b2x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b2x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b4x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b4x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b8x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b8x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b16x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b16x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b32x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b32x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b64x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b64x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b128x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b128x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b1x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b1x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b2x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b2x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b4x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b4x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b8x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b8x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b16x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b16x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b32x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b32x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b64x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b64x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b128x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b128x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b1x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b1x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b2x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b2x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b4x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b4x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b8x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b8x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b16x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b16x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b32x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b32x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b64x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b64x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b128x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b128x_16b{}; + } + else { + static_assert(dependent_false, "No TMEM_STORE matching for provided TMEM_LOAD"); + } +} + +} // namespace TMEM + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// UTCCP Copy Traits +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace SM100::TMEM::UTCCP { + +// +// Specialized copy_unpack implementation for SM100::TMEM::UTCCP instructions +// + +template +CUTE_HOST_DEVICE constexpr +void +copy_unpack(Copy_Traits const&, + Tensor const& src, + Tensor & dst) +{ + static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); + static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); + CopyOp::copy(src[0], raw_pointer_cast(dst.data())); +} + +} // end namespace SM100::TMEM::UTCCP + +// In the following UTCCP traits, the ValID is representing: +// logical_bit_idx -> tmem_addr_offset. +// And the logical_bit_idx is numbered in the order of: +// [core_matrix_strided, core_matrix_leading, broadcast, repeat]. +// The first two modes provide convenience for smem_desc construtction. +// The last two modes provide boradcast transformation for 4x32DP and 2x64DP. +// With above, the strides of first two modes are neccessary to be TMEM::DP_b and 1. +// And the stride of the third mode in the SrcLayout must be zero. + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::UTCCP::SM100_UTCCP_128dp256bit_1cta; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_1>; + using ValID = Layout, + Stride>; + using SrcLayout = Layout, + Stride<_0, _1>>; + using DstLayout = Layout, + Stride<_0,_1>>; + using RefLayout = DstLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::UTCCP::SM100_UTCCP_128dp256bit_2cta; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_2>; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = Layout, + Stride<_0, _1>>; + using DstLayout = Layout, + Stride<_0, _1>>; + using RefLayout = DstLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::UTCCP::SM100_UTCCP_128dp128bit_1cta; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_1>; + using ValID = Layout, + Stride>; + using SrcLayout = Layout, + Stride<_0, _1>>; + using DstLayout = Layout, + Stride<_0,_1>>; + using RefLayout = DstLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::UTCCP::SM100_UTCCP_128dp128bit_2cta; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_2>; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = Layout, + Stride<_0, _1>>; + using DstLayout = Layout, + Stride<_0, _1>>; + using RefLayout = DstLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::UTCCP::SM100_UTCCP_4dp256bit_1cta; + +template <> +struct Copy_Traits +{ + /* + 4DP is really hard to model if we consider this instruction as a "copy" instruction. + But, if we take it as "TMEM refresh" instruction, then everything goes out naturally. + 4DP utccp is designed to refresh the last 4 lanes of each tmem subpartition. + So, in the kernel implementation, we usually only don't need to iterate on MMA_M dimension, + but only need to iterate on MMA_K dimension. + And in each refresh, logically we are refreshing MMA's 128 rows M + 256bit K. + So the "atom_v" should be (refresh_m, refresh_k) instead of (copy_m, copy_k). + And the Src/DstLayout below is: copy_bits -> logical_refresh_bits. + */ + + using ThrID = Layout<_1>; + using ValID = Layout, + Stride>; + using SrcLayout = Layout>, + Stride<_0,Stride<_32,_128>>>; + using DstLayout = Layout>, + Stride<_0,Stride<_32,_128>>>; + using RefLayout = DstLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::UTCCP::SM100_UTCCP_4dp256bit_2cta; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_2>; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = Layout>, + Stride<_0,Stride<_32,_128>>>; + using DstLayout = Layout>, + Stride<_0,Stride<_32,_128>>>; + using RefLayout = DstLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::UTCCP::SM100_UTCCP_4x32dp128bit_1cta; + +template <> +struct Copy_Traits +{ + using _DP = TMEM::DP_b; + using _DPx32 = Int<_DP{}*32>; + + using ThrID = Layout<_1>; + // logical bit_idx -> tmem_addr + // [core_matrix_strided, core_matrix_leading, broadcast] + using ValID = Layout, + Stride<_DP,_1, _DPx32>>; + using SrcLayout = Layout>, + Stride<_0,Stride<_1, _32, _0>>>; + using DstLayout = Layout, + Stride<_0,_1>>; + using RefLayout = DstLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::UTCCP::SM100_UTCCP_4x32dp128bit_2cta; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_2>; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = Layout>, + Stride<_0,Stride<_1, _32, _0>>>; + using DstLayout = Layout, + Stride<_0,_1>>; + using RefLayout = DstLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::UTCCP::SM100_UTCCP_2x64dp128bitlw0213_1cta; + +template <> +struct Copy_Traits +{ + using _DP = TMEM::DP_b; + using _DPx64 = Int<_DP{}*64>; + + using ThrID = Layout<_1>; + // logical bit_idx -> tmem_addr + // [core_matrix_strided, core_matrix_leading, broadcast] + using ValID = Layout, + Stride<_DP,_1, _DPx64>>; + using SrcLayout = Layout>, + Stride<_0,Stride<_1, _64, _0>>>; + using DstLayout = Layout, + Stride<_0, _1>>; + using RefLayout = DstLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::UTCCP::SM100_UTCCP_2x64dp128bitlw0213_2cta; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_2>; + using ValID = typename Copy_Traits::ValID; + + using SrcLayout = Layout>, + Stride<_0,Stride<_1, _64, _0>>>; + using DstLayout = Layout, + Stride<_0, _1>>; + using RefLayout = DstLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::UTCCP::SM100_UTCCP_2x64dp128bitlw0123_1cta; + +template <> +struct Copy_Traits +{ + using _DP = TMEM::DP_b; + using _DPx32 = Int<_DP{}*32>; + using _DPx64 = Int<_DP{}*64>; + + using ThrID = Layout<_1>; + // logical bit_idx -> tmem_addr + // [core_matrix_strided, core_matrix_leading, repeat, broadcast] + using ValID = Layout, + Stride<_DP,_1 ,_DPx64,_DPx32>>; + + using SrcLayout = Layout>, + Stride<_0,Stride<_1, _32,_4096,_0>>>; + using DstLayout = Layout, + Stride<_0, _1>>; + using RefLayout = DstLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::UTCCP::SM100_UTCCP_2x64dp128bitlw0123_2cta; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_2>; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = Layout>, + Stride<_0,Stride<_1, _32, _4096,_0>>>; + using DstLayout = Layout, + Stride<_0,_1>>; + using RefLayout = DstLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTE_HOST_DEVICE constexpr +auto +make_utccp_copy(CopyOp const&, + Tensor const& tmem) +{ + static_assert(is_tmem::value, "Expected TMEM tensor."); + using T = typename TEngine::value_type; + using Traits = Copy_Traits; + using Atom = Copy_Atom; + + // atom thr idx -> tmem addr This is the T in the Layout_TV + auto atom_t_layout = make_layout(size(typename Traits::ThrID{}), Int<0>{}); + // atom val idx -> tmem addr Cast the CopyOp's value ids to the proper data width + auto atom_v_layout = coalesce(upcast::value>(typename Traits::ValID{})); + + return make_cotiled_copy(Atom{}, make_layout(atom_t_layout, atom_v_layout), tmem.layout()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cute + diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm100_im2col.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm100_im2col.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cd3bf98bbe4c4bff03d1ba1ad33111e12edf3db9 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm100_im2col.hpp @@ -0,0 +1,488 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +/*! \file + \brief im2col make_tma_copy + +*/ + +#include "cute/arch/copy_sm90.hpp" +#include "cute/arch/copy_sm90_desc.hpp" +#include "cute/atom/copy_traits_sm90_im2col.hpp" +#include "cute/tensor.hpp" + +namespace cute { + +struct SM100_TMA_2SM_LOAD_IM2COL_OP : SM100_TMA_2SM_LOAD_IM2COL {}; + +/// @brief Non-executable specialization of Copy_Traits for SM100 +/// im2col TMA load, with TMA descriptor but no barrier. +/// +/// Use `.with(memory_barrier)` to construct an executable version. +template +struct Copy_Traits +{ + using ThrID = Layout<_2>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Stride>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + Im2ColTmaDescriptor tma_desc_; + TMATensor tma_tensor_; + + CUTE_HOST_DEVICE constexpr + Im2ColTmaDescriptor const* + get_tma_descriptor() const + { + return &tma_desc_; + } + + template + CUTE_HOST_DEVICE constexpr + TMATensor const + get_tma_tensor(GShape const&) const + { + return tma_tensor_; + } + + /// @brief Get an executable specialization. + /// + /// Copy_Traits specializations with SM100_TMA_2SM_LOAD_IM2COL are not + /// directly executable. Instead, call this "with" member function + /// to get an executable specialization. "Executable" means that + /// @c copy_unpack works. + /// + /// @param tma_mbar Memory barrier for synchronization + /// + /// @param multicast_mask Multicast mask (unused; only exists + /// for consistency with the actual multicast Copy_Traits + /// specialization) + /// + /// @return Executable specialization of @c Copy_Traits + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask = 0) const + { + return {{}, {&tma_desc_, &tma_mbar}}; + } + + // Copy_Traits specializations with SM100_TMA_2SM_LOAD_IM2COL + // are not directly executable. Instead, call .with + // to get an executable specialization. + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) = delete; +}; + +/// TMA load, with TMA descriptor and barrier. +template +struct Copy_Traits + : TMA_LOAD_IM2COL_Unpack +{ + using ThrID = Layout<_2>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Stride>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM100_TMA_2SM_LOAD_IM2COL arguments + tuple< + Im2ColTmaDescriptor const*, + uint64_t* // smem mbarrier + > const opargs_; +}; + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// TMA_LOAD_MULTICAST ///////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +struct SM100_TMA_2SM_LOAD_IM2COL_MULTICAST_OP : SM100_TMA_2SM_LOAD_IM2COL_MULTICAST {}; + +/// @brief Non-executable specialization of Copy_Traits for SM100 +/// im2col TMA load, with TMA descriptor but no barrier or multicast +/// mask. +/// +/// Use `.with(memory_barrier)` to construct an executable version. +template +struct Copy_Traits +{ + using ThrID = Layout<_2>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Stride>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + Im2ColTmaDescriptor tma_desc_; + TMATensor tma_tensor_; + + CUTE_HOST_DEVICE constexpr + Im2ColTmaDescriptor const* + get_tma_descriptor() const + { + return &tma_desc_; + } + + template + CUTE_HOST_DEVICE constexpr + TMATensor const + get_tma_tensor(GShape const&) const + { + return tma_tensor_; + } + + /// @brief Get an executable specialization. + /// + /// Copy_Traits specializations with SM100_TMA_2SM_LOAD_IM2COL_MULTICAST + /// are not directly executable. Instead, call this "with" member + /// function to get an executable specialization. "Executable" + /// means that @c copy_unpack works. + /// + /// @param tma_mbar Memory barrier for synchronization + /// + /// @param multicast_mask Multicast mask (defaults to a single CTA) + /// + /// @return Executable specialization of @c Copy_Traits + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(uint64_t& tma_mbar, uint16_t const& multicast_mask) const + { + return {{}, {&tma_desc_, &tma_mbar, multicast_mask}}; + } + + // Copy_Traits specializations with SM100_TMA_LOAD_IM2COL_MULTICAST + // are not directly executable. Instead, call .with to get an + // executable specialization. + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) = delete; +}; + +/// @brief Executable specialization of Copy_Traits for SM100 multicast +/// im2col TMA load, with TMA descriptor, barrier, and multicast mask. +template +struct Copy_Traits + : TMA_LOAD_IM2COL_Unpack +{ + using ThrID = Layout<_2>; + // Map from (src-thr,src-val) to bit. + using SrcLayout = Layout, Stride>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Stride>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM100_TMA_2SM_LOAD_IM2COL_MULTICAST arguments + tuple< + Im2ColTmaDescriptor const*, + uint64_t*, // smem mbarrier + uint16_t // multicast mask + > const opargs_; +}; + +//////////////////////////////////// +// Make TMA +/////////////////////////////////// + +#if !defined(__CUDACC_RTC__) +/** Make a CuTe CTA-collective TiledCopy for a TMA operation. + * + * @param CopyOp The target copy operation: SM100_TMA_2SM_LOAD + * @param gtensor The GMEM Tensor to be involved in the TMA. + * @param slayout The SMEM Layout to be involved in the TMA. + * @param cluster_tile The Cluster-local tile that each Cluster will be tiling GMEM with. + * This is often the cluster_tile_shape that is used to tile the GMEM: + * local_tile(gtensor, cluster_tile_shape, cluster_coord) + * -> Cluster-local tile of GMEM + * @param mma The TiledMMA that defines the Cluster-Tile to Block-Tile partitioning. + * + * This code attempts to maximize the TMA box size. It does this by tracing + * the SMEM "vector" -- the inverse of the smem layout -- to find the largest + * contiguous array of smem that can be written to/from global memory given + * the constraints that the TMA instruction imposes. + * + * This is accomplished by assigning "basis" strides to the GMEM to track which + * modes of SMEM map to which modes of GMEM, then reordering the modes of GMEM according + * to the SMEM vector, and then using those GMEM/SMEM modes to fill in the desc. + * + * Examples: + */ +template +CUTE_HOST +auto +make_im2col_tma_copy_A_sm100(CopyOp const& copy_op, + Tensor const& gtensor, // (M,K,...) + SLayout const& slayout, // (MMA, MMA_M, MMA_K) + Cluster_Tile const& cluster_tile, // (TILE_M,TILE_N,TILE_K) + TiledMMA const& mma, + LowerCornerStride const& lower_corner_whd, + UpperCornerStride const& upper_corner_whd, + LowerPaddingStride const& lower_padding_whd, + UpperPaddingStride const& upper_padding_whd, + TraversalStride const& stride_whd, + LowerSRTStride const& lower_srt, + DilationStride const& stride_srt, + TMA::DescriptorAuxParams const& aux_params = {}) +{ + constexpr int R = GLayout::rank; + // Keep only MK modes from MNK + auto cluster_tile_shape = append(make_shape(get<0>(cluster_tile), get<2>(cluster_tile)), Int<1>{}); + auto cluster_layout = make_identity_layout(cluster_tile_shape); + // cta val idx -> gmem mode + auto cta_v_tile = layout<1>(mma.thrfrg_A(cluster_layout))(_, repeat(_)); + + auto cta_t_vmnk_strides = [](){ + if constexpr (is_same_v || + is_same_v) { + return Stride<_0,_0,_1,_0>{}; // VMNK: Use only the N-CTAs in the Multicast + } else + if constexpr (is_same_v || + is_same_v) { + return Stride<_0,_0,_0,_0>{}; // VMNK: Use no CTAs in Non-Multicast + } else { + static_assert(dependent_false, "Unsupported TMA"); + } + }(); + + auto cta_t_shape = shape(mma.get_thr_layout_vmnk()); + // cta rank -> logical cta idx + auto cta_t_map = make_layout(cta_t_shape, compact_col_major(cta_t_shape, cta_t_vmnk_strides)); + + return detail::make_tma_copy_im2col(copy_op, gtensor, slayout, + cta_t_map, cta_v_tile, + lower_corner_whd, upper_corner_whd, lower_padding_whd, upper_padding_whd, stride_whd, + lower_srt, stride_srt, aux_params); +} + +template +CUTE_HOST +auto +make_im2col_tma_copy_B_sm100(CopyOp const& copy_op, + Tensor const& gtensor, // (N,K,...) + SLayout const& slayout, // (MMA, MMA_N, MMA_K) + Cluster_Tile const& cluster_tile, // (TILE_M,TILE_N,TILE_K) + TiledMMA const& mma, + LowerCornerStride const& lower_corner_whd, + UpperCornerStride const& upper_corner_whd, + LowerPaddingStride const& lower_padding_whd, + UpperPaddingStride const& upper_padding_whd, + TraversalStride const& stride_whd, + LowerSRTStride const& lower_srt, + DilationStride const& stride_srt, + TMA::DescriptorAuxParams const& aux_params = {}) +{ + constexpr int R = GLayout::rank; + // Keep only NK modes from MNK + auto cluster_tile_shape = append(make_shape(get<1>(cluster_tile), get<2>(cluster_tile)), Int<1>{}); + auto cluster_layout = make_identity_layout(cluster_tile_shape); + // cta val idx -> gmem mode + auto cta_v_tile = layout<1>(mma.thrfrg_B(cluster_layout))(_, repeat(_)); + + auto cta_t_vmnk_strides = [](){ + if constexpr (is_same_v || + is_same_v) { + return Stride<_0,_1,_0,_0>{}; // VMNK: Use only the M-CTAs in the Multicast + } else + if constexpr (is_same_v || + is_same_v) { + return Stride<_0,_0,_0,_0>{}; // VMNK: Use no CTAs in Non-Multicast + } else { + static_assert(dependent_false, "Unsupported TMA"); + } + }(); + + auto cta_t_shape = shape(mma.get_thr_layout_vmnk()); + // cta rank -> logical cta idx + auto cta_t_map = make_layout(cta_t_shape, compact_col_major(cta_t_shape, cta_t_vmnk_strides)); + + return detail::make_tma_copy_im2col(copy_op, gtensor, slayout, + cta_t_map, cta_v_tile, + lower_corner_whd, upper_corner_whd, lower_padding_whd, upper_padding_whd, stride_whd, + lower_srt, stride_srt, aux_params); +} + +///////////////////////////////////// +// Experimental Make Im2col TMA Atom +///////////////////////////////////// + +template +CUTE_HOST +auto +make_im2col_tma_atom_A_sm100(CopyOp const& copy_op, + Tensor const& gtensor, // (M, K, ...) + SLayout const& slayout, // (MMA, MMA_M, MMA_K, ...) + MMA_Tiler const& mma_tiler, // (TILE_M, TILE_N, TILE_K, ...) + TiledMMA const& mma, + ClusterShapeVMNK const& cluster_shape, // (CTA_V, CTA_M, CTA_N, CTA_K) + LowerCornerStride const& lower_corner_whd, + UpperCornerStride const& upper_corner_whd, + LowerPaddingStride const& lower_padding_whd, + UpperPaddingStride const& upper_padding_whd, + TraversalStride const& stride_whd, + LowerSRTStride const& lower_srt, + DilationStride const& stride_srt, + TMA::DescriptorAuxParams const& aux_params = {}) +{ + constexpr int R = GLayout::rank; + // Keep only MK modes from MNK + auto cluster_tile_shape = append(make_shape(get<0>(mma_tiler), get<2>(mma_tiler)), Int<1>{}); + auto cluster_layout = make_identity_layout(cluster_tile_shape); + // cta val idx -> gmem mode + auto cta_v_tile = layout<1>(mma.thrfrg_A(cluster_layout))(_, repeat(_)); + + // The size of the multicasting + auto num_multicast = [&](){ + if constexpr (is_same_v || + is_same_v) { + return size<2>(cluster_shape); // VMNK: Use only the N-CTAs in the Multicast + } else + if constexpr (is_same_v || + is_same_v || + is_same_v) { + return Int<1>{}; // VMNK: Use no CTAs in Non-Multicast + } else { + static_assert(dependent_false, "Unsupported TMA"); + } + }(); + + return detail::make_tma_atom_im2col(copy_op, gtensor, slayout, num_multicast, cta_v_tile, + lower_corner_whd, upper_corner_whd, lower_padding_whd, upper_padding_whd, + stride_whd, lower_srt, stride_srt, aux_params); +} + +template +CUTE_HOST +auto +make_im2col_tma_atom_B_sm100(CopyOp const& copy_op, + Tensor const& gtensor, // (N, K, ...) + SLayout const& slayout, // (MMA, MMA_N, MMA_K, ...) + MMA_Tiler const& mma_tiler, // (TILE_M, TILE_N, TILE_K, ...) + TiledMMA const& mma, + ClusterShapeVMNK const& cluster_shape, // (CTA_V, CTA_M, CTA_N, CTA_K) + LowerCornerStride const& lower_corner_whd, + UpperCornerStride const& upper_corner_whd, + LowerPaddingStride const& lower_padding_whd, + UpperPaddingStride const& upper_padding_whd, + TraversalStride const& stride_whd, + LowerSRTStride const& lower_srt, + DilationStride const& stride_srt, + TMA::DescriptorAuxParams const& aux_params = {}) +{ + constexpr int R = GLayout::rank; + // Keep only NK modes from MNK + auto cluster_tile_shape = append(make_shape(get<1>(mma_tiler), get<2>(mma_tiler)), Int<1>{}); + auto cluster_layout = make_identity_layout(cluster_tile_shape); + // cta val idx -> gmem mode + auto cta_v_tile = layout<1>(mma.thrfrg_B(cluster_layout))(_, repeat(_)); + + // The size of the multicasting + auto num_multicast = [&](){ + if constexpr (is_same_v || + is_same_v) { + return size<1>(cluster_shape); // VMNK: Use only the M-CTAs in the Multicast + } else + if constexpr (is_same_v || + is_same_v || + is_same_v) { + return Int<1>{}; // VMNK: Use no CTAs in Non-Multicast + } else { + static_assert(dependent_false, "Unsupported TMA"); + } + }(); + + return detail::make_tma_atom_im2col(copy_op, gtensor, slayout, num_multicast, cta_v_tile, + lower_corner_whd, upper_corner_whd, lower_padding_whd, upper_padding_whd, + stride_whd, lower_srt, stride_srt, aux_params); +} +#endif // !defined(__CUDACC_RTC__) + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm100_tma.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm100_tma.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8621e8a98239184aaa44a9881d6f4548bf840222 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm100_tma.hpp @@ -0,0 +1,501 @@ +/*************************************************************************************************** + * Copyright (c) 2021 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +#pragma once + +#if !defined(__CUDACC_RTC__) +#include +#endif + +#include +#include +#include +#include + +namespace cute +{ + +////////////////////////////////////////////////////////////////////////////// +////////////////////////////// TMA_LOAD //////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +struct SM100_TMA_2SM_LOAD_OP : SM100_TMA_2SM_LOAD {}; + +// The non-executable SM100_TMA_2SM_LOAD with tma_desc and no tma_mbar +// Use .with(tma_mbar) to construct an executable version +template +struct Copy_Traits +{ + using ThrID = Layout<_2>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Stride>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM100_TMA_2SM_LOAD arguments + TmaDescriptor tma_desc_; + using AuxParams = AuxParams_; + AuxParams aux_params_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + // Construct an executable SM100_TMA_2SM_LOAD with tma_mbar + CUTE_HOST_DEVICE constexpr + Copy_Traits + with( + uint64_t& tma_mbar, + [[maybe_unused]] uint16_t const& multicast_mask = 0, + TMA::CacheHintSm100 const& cache_hint = TMA::CacheHintSm100::EVICT_NORMAL) const { + // We accept multicast_mask here to keep the API for both atoms consistent + return {{}, {&tma_desc_, &tma_mbar, static_cast(cache_hint)}}; + } + + // Construct an executable SM100_TMA_2SM_LOAD with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm) + CUTE_HOST_DEVICE constexpr + Copy_Traits + with( + TmaDescriptor const* new_tma_desc, + uint64_t& tma_mbar, + [[maybe_unused]] uint16_t const& multicast_mask = 0, + TMA::CacheHintSm100 const& cache_hint = TMA::CacheHintSm100::EVICT_NORMAL) const { + // We accept multicast_mask here to keep the API for both atoms consistent + return {{}, {new_tma_desc, &tma_mbar, static_cast(cache_hint)}}; + } + + template + CUTE_HOST_DEVICE constexpr + auto + get_tma_tensor(GShape const& g_shape) const { + static_assert(is_congruent::value); + return make_coord_tensor(make_layout(g_shape, aux_params_.g_stride_)); + } + + // Don't try to execute a copy with SM100_TMA_2SM_LOAD before calling .with() + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) = delete; +}; + +// The executable SM100_TMA_2SM_LOAD with tma_desc and tma_mbar +template +struct Copy_Traits + : TMA_LOAD_Unpack +{ + using ThrID = Layout<_2>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Stride>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM100_TMA_2SM_LOAD arguments + tuple< + TmaDescriptor const*, + uint64_t*, // smem mbarrier + uint64_t // cache hint + > const opargs_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return get<0>(opargs_); + } +}; + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// TMA_LOAD_MULTICAST ///////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +struct SM100_TMA_2SM_LOAD_MULTICAST_OP : SM100_TMA_2SM_LOAD_MULTICAST {}; + +template +struct Copy_Traits +{ + using ThrID = Layout<_2>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Stride>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM100_TMA_2SM_LOAD_MULTICAST_OP arguments + TmaDescriptor tma_desc_; + using AuxParams = AuxParams_; + AuxParams aux_params_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + // Construct an executable SM100_TMA_2SM_LOAD_MULTICAST_OP with tma_mbar + CUTE_HOST_DEVICE constexpr + Copy_Traits + with( + uint64_t& tma_load_mbar, + uint16_t const& multicast_mask, + TMA::CacheHintSm100 const& cache_hint = TMA::CacheHintSm100::EVICT_NORMAL) const { + return {{}, {&tma_desc_, &tma_load_mbar, multicast_mask, static_cast(cache_hint)}}; + } + + // Construct an executable SM100_TMA_2SM_LOAD_MULTICAST_OP with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm) + CUTE_HOST_DEVICE constexpr + Copy_Traits + with( + TmaDescriptor const* new_tma_desc, + uint64_t& tma_load_mbar, + uint16_t const& multicast_mask, + TMA::CacheHintSm100 const& cache_hint = TMA::CacheHintSm100::EVICT_NORMAL) const { + return {{}, {new_tma_desc, &tma_load_mbar, multicast_mask, static_cast(cache_hint)}}; + } + + template + CUTE_HOST_DEVICE constexpr + auto + get_tma_tensor(GShape const& g_shape) const { + static_assert(is_congruent::value); + return make_coord_tensor(make_layout(g_shape, aux_params_.g_stride_)); + } + + // Don't try to execute a copy with SM100_TMA_2SM_LOAD_MULTICAST_OP before calling .with() + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) = delete; +}; + +template +struct Copy_Traits + : TMA_LOAD_Unpack +{ + using ThrID = Layout<_2>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Stride>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM100_TMA_2SM_LOAD_MULTICAST_OP arguments + tuple< + TmaDescriptor const*, + uint64_t*, // smem mbarrier + uint16_t, // multicast mask + uint64_t // cache hint + > const opargs_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return get<0>(opargs_); + } +}; + +//////////////////////////////////// +// Make TMA +/////////////////////////////////// + +#if !defined(__CUDACC_RTC__) +/** Make a CuTe CTA-collective TiledCopy for a TMA operation. + * + * @param CopyOp The target copy operation: SM100_TMA_2SM_LOAD + * @param gtensor The GMEM Tensor to be involved in the TMA. + * @param slayout The SMEM Layout to be involved in the TMA. + * @param cluster_tile The Cluster-local tile that each Cluster will be tiling GMEM with. + * This is often the cluster_tile_shape that is used to tile the GMEM: + * local_tile(gtensor, cluster_tile_shape, cluster_coord) + * -> Cluster-local tile of GMEM + * @param mma The TiledMMA that defines the Cluster-Tile to Block-Tile partitioning. + * + * This code attempts to maximize the TMA box size. It does this by tracing + * the SMEM "vector" -- the inverse of the smem layout -- to find the largest + * contiguous array of smem that can be written to/from global memory given + * the constraints that the TMA instruction imposes. + * + * This is accomplished by assigning "basis" strides to the GMEM to track which + * modes of SMEM map to which modes of GMEM, then reordering the modes of GMEM according + * to the SMEM vector, and then using those GMEM/SMEM modes to fill in the desc. + * + * Examples: + */ +template +CUTE_HOST +auto +make_tma_copy_A_sm100(CopyOp const& copy_op, + Tensor const& gtensor, // (M, K, ...) + SLayout const& slayout, // (MMA, MMA_M, MMA_K, ...) + Cluster_Tiler const& cluster_tiler, // (TILER_M, TILER_N, TILER_K, ...) + TiledMMA const& mma) +{ + // Keep only MK modes from MNK + auto cluster_tiler_mk = remove<1>(cluster_tiler); + // cluster tile coord -> gtensor coord + auto g_tile = make_identity_layout(shape(gtensor)).compose(cluster_tiler_mk); // (TILE_M, TILE_K, ...) + // cta val idx -> gmem mode + auto cta_v_tile = layout<1>(mma.thrfrg_A(g_tile))(_, repeat(_)); // (MMA, MMA_M, MMA_K, ...) + + auto cta_t_vmnk_strides = [](){ + if constexpr (is_same_v || + is_same_v) { + return Stride<_0,_0,_1,_0>{}; // VMNK: Use only the N-CTAs in the Multicast + } else + if constexpr (is_same_v || + is_same_v || + is_same_v) { + return Stride<_0,_0,_0,_0>{}; // VMNK: Use no CTAs in Non-Multicast + } else { + static_assert(dependent_false, "Unsupported TMA"); + } + }(); + + auto cta_t_shape = shape(mma.get_thr_layout_vmnk()); + // cta rank -> logical cta idx + auto cta_t_map = coalesce(make_layout(cta_t_shape, compact_col_major(cta_t_shape, cta_t_vmnk_strides))); + + // Prefer TmaInternalType if specified. Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + return detail::make_tma_copy_tiled(copy_op, gtensor, slayout, cta_t_map, cta_v_tile); +} + +template +CUTE_HOST +auto +make_tma_copy_B_sm100(CopyOp const& copy_op, + Tensor const& gtensor, // (N, K, ...) + SLayout const& slayout, // (MMA, MMA_N, MMA_K, ...) + Cluster_Tiler const& cluster_tiler, // (TILE_M, TILE_N, TILE_K, ...) + TiledMMA const& mma) +{ + // Keep only NK modes from MNK + auto cluster_tiler_nk = remove<0>(cluster_tiler); + // cluster tile coord -> gtensor coord + auto g_tile = make_identity_layout(shape(gtensor)).compose(cluster_tiler_nk); // (TILE_N, TILE_K, ...) + // cta val idx -> gmem mode + auto cta_v_tile = layout<1>(mma.thrfrg_B(g_tile))(_, repeat(_)); // (MMA, MMA_N, MMA_K, ...) + + auto cta_t_vmnk_strides = [](){ + if constexpr (is_same_v || + is_same_v) { + return Stride<_0,_1,_0,_0>{}; // VMNK: Use only the M-CTAs in the Multicast + } else + if constexpr (is_same_v || + is_same_v || + is_same_v) { + return Stride<_0,_0,_0,_0>{}; // VMNK: Use no CTAs in Non-Multicast + } else { + static_assert(dependent_false, "Unsupported TMA"); + } + }(); + + auto cta_t_shape = shape(mma.get_thr_layout_vmnk()); + // cta rank -> logical cta idx + auto cta_t_map = coalesce(make_layout(cta_t_shape, compact_col_major(cta_t_shape, cta_t_vmnk_strides))); + + // Prefer TmaInternalType if specified. Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + return detail::make_tma_copy_tiled(copy_op, gtensor, slayout, cta_t_map, cta_v_tile); +} + +template +CUTE_HOST +auto +make_tma_copy_C_sm100(CopyOp const& copy_op, + Tensor const& gtensor, // (M, N, ...) + SLayout const& slayout, // (MMA, MMA_M, MMA_N, ...) + Cluster_Tiler const& cluster_tiler, // (TILE_M, TILE_N, TILE_K, ...) + TiledMMA const& mma) +{ + // Keep only MN modes from MNK + auto cluster_tiler_mn = remove<2>(cluster_tiler); + // cluster tile coord -> gtensor coord + auto g_tile = make_identity_layout(shape(gtensor)).compose(cluster_tiler_mn); // (TILE_M, TILE_N, ...) + // cta val idx -> gmem mode + auto cta_v_tile = layout<1>(mma.thrfrg_C(g_tile))(_, repeat(_)); // (MMA, MMA_M, MMA_N, ...) + + static_assert(is_same_v || + is_same_v || + is_same_v, + "Unsupported TMA Op, expected a non-multicast TMA"); + + // No multicast, so only 1 CTA involved + auto cta_t_map = Layout<_1,_0>{}; + + // Prefer TmaInternalType if specified. Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + return detail::make_tma_copy_tiled(copy_op, gtensor, slayout, cta_t_map, cta_v_tile); +} + +//////////////////////////////////// +// Experimental Make TMA Atom +/////////////////////////////////// + +template +CUTE_HOST +auto +make_tma_atom_A_sm100(CopyOp const& copy_op, + Tensor const& gtensor, // (M, K, ...) + SLayout const& slayout, // (MMA, MMA_M, MMA_K, ...) + MMA_Tiler const& mma_tiler, // (TILE_M, TILE_N, TILE_K, ...) + TiledMMA const& mma, + ClusterShapeVMNK const& cluster_shape) // (CTA_V, CTA_M, CTA_N, CTA_K) +{ + // Keep only MK modes from MNK + auto mma_tiler_mk = remove<1>(mma_tiler); + + // cluster tile coord -> gtensor coord + auto g_tile = make_identity_layout(shape(gtensor)).compose(mma_tiler_mk); // (TILE_M, TILE_K, ...) + + // cta val idx -> gmem mode + auto cta_v_tile = layout<1>(mma.thrfrg_A(g_tile))(_, repeat(_)); // (MMA, MMA_M, MMA_K, ...) + +#if 0 + print("(tma_a) slayout: "); print(slayout); print("\n"); + print("(tma_a) mma_tiler_nk: "); print(mma_tiler_nk); print("\n"); + print("(tma_a) g_tile: "); print(g_tile); print("\n"); + print("(tma_a) mma_tiler: "); print(mma_tiler); print("\n"); + print("(tma_a) cta_v_tile: "); print(cta_v_tile); print("\n"); +#endif + + // The size of the multicasting + auto num_multicast = [&](){ + if constexpr (is_same_v || + is_same_v) { + return size<2>(cluster_shape); // VMNK: Use only the N-CTAs in the Multicast + } else + if constexpr (is_same_v || + is_same_v || + is_same_v) { + return Int<1>{}; // VMNK: Use no CTAs in Non-Multicast + } else { + static_assert(dependent_false, "Unsupported TMA"); + } + }(); + + // Prefer TmaInternalType if specified. Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + return detail::make_tma_copy_atom(copy_op, gtensor, slayout, num_multicast, cta_v_tile); +} + +template +CUTE_HOST +auto +make_tma_atom_B_sm100(CopyOp const& copy_op, + Tensor const& gtensor, // (N, K, ...) + SLayout const& slayout, // (MMA, MMA_N, MMA_K, ...) + MMA_Tiler const& mma_tiler, // (TILE_M, TILE_N, TILE_K, ...) + TiledMMA const& mma, + ClusterShapeVMNK const& cluster_shape) // (CTA_V, CTA_M, CTA_N, CTA_K) +{ + // Keep only NK modes from MNK + auto mma_tiler_nk = remove<0>(mma_tiler); + // cluster tile coord -> gtensor coord + auto g_tile = make_identity_layout(shape(gtensor)).compose(mma_tiler_nk); // (TILE_N, TILE_K, ...) + // cta val idx -> gmem mode + auto cta_v_tile = layout<1>(mma.thrfrg_B(g_tile))(_, repeat(_)); // (MMA, MMA_N, MMA_K, ...) + +#if 0 + print("(tma_b) slayout: "); print(slayout); print("\n"); + print("(tma_b) mma_tiler_nk: "); print(mma_tiler_nk); print("\n"); + print("(tma_b) g_tile: "); print(g_tile); print("\n"); + print("(tma_b) mma_tiler: "); print(mma_tiler); print("\n"); + print("(tma_b) cta_v_tile: "); print(cta_v_tile); print("\n"); +#endif + + // The size of the multicasting + auto num_multicast = [&](){ + if constexpr (is_same_v || + is_same_v) { + return size<1>(cluster_shape); // VMNK: Use only the M-CTAs in the Multicast + } else + if constexpr (is_same_v || + is_same_v || + is_same_v) { + return Int<1>{}; // VMNK: Use no CTAs in Non-Multicast + } else { + static_assert(dependent_false, "Unsupported TMA"); + } + }(); + + // Prefer TmaInternalType if specified. Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + return detail::make_tma_copy_atom(copy_op, gtensor, slayout, num_multicast, cta_v_tile); +} + +#endif // !defined(__CUDACC_RTC__) + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm50.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm50.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5299894bfcb8e7d8734bf98a497e67e3abf6dca6 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm50.hpp @@ -0,0 +1,75 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include + +#include + +namespace cute +{ + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride<_64, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_32, _2>>, + Stride,Stride< _1, _64>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride<_64, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Shape<_32, _2>>, + Stride,Stride< _1, _256>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm75.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm75.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c43b7483ba28d0f250a9e317f805af14e1dc90f6 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm75.hpp @@ -0,0 +1,159 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include + +#include + +namespace cute +{ + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,_128>, + Stride, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride<_32, _1>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,_128>, + Stride, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1,_1024>>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride<_128, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1,_1024>>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,_128>, + Stride, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_16, _2>>, + Stride,Stride< _1,_128>>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,_128>, + Stride, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_16, _2, _2>>, + Stride,Stride< _1,_128,_1024>>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride<_128, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_16, _2, _4>>, + Stride,Stride< _1,_128,_1024>>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride<_32, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride<_32, _1>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm80.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm80.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ab8d128c291ec79fc5ccc5f6aab721b647aea150 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm80.hpp @@ -0,0 +1,167 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include + +#include + +namespace cute +{ + +template +struct Copy_Traits> +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout::value>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout::value>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template +struct Copy_Traits> +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout::value>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout::value>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template +struct Copy_Traits> +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout::value>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout::value>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // Predicate value: true = load, false = zfill + bool pred = true; + + // Construct a zfill variant with a given predicate value + CUTE_HOST_DEVICE constexpr + Copy_Traits> + with(bool pred) const { + return {pred}; + } + + // Overload copy_unpack for zfill variant to pass the predicate into the op + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_gmem::value, "Expected gmem source for cp.async."); + static_assert(is_smem::value, "Expected smem destination for cp.async."); + + Tensor rS = recast(src); + Tensor rD = recast(dst); + + CUTE_STATIC_ASSERT_V(size(rS) == Int<1>{}, + "In CopyAtom, src layout doesn't vectorize into registers. This src layout is incompatible with this tiled copy."); + CUTE_STATIC_ASSERT_V(size(rD) == Int<1>{}, + "In CopyAtom, dst layout doesn't vectorize into registers. This dst layout is incompatible with this tiled copy."); + + SM80_CP_ASYNC_CACHEALWAYS_ZFILL::copy(rS[0], rD[0], traits.pred); + } +}; + +template +struct Copy_Traits> +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout::value>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout::value>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // Predicate value: true = load, false = zfill + bool pred = true; + + // Construct a zfill variant with a given predicate value + CUTE_HOST_DEVICE constexpr + Copy_Traits> + with(bool pred) const { + return {pred}; + } + + // Overload copy_unpack for zfill variant to pass the predicate into the op + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_gmem::value, "Expected gmem source for cp.async."); + static_assert(is_smem::value, "Expected smem destination for cp.async."); + + Tensor rS = recast(src); + Tensor rD = recast(dst); + + CUTE_STATIC_ASSERT_V(size(rS) == Int<1>{}, + "In CopyAtom, src layout doesn't vectorize into registers. This src layout is incompatible with this tiled copy."); + CUTE_STATIC_ASSERT_V(size(rD) == Int<1>{}, + "In CopyAtom, dst layout doesn't vectorize into registers. This dst layout is incompatible with this tiled copy."); + + SM80_CP_ASYNC_CACHEGLOBAL_ZFILL::copy(rS[0], rD[0], traits.pred); + } +}; + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm90.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm90.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ad479df0606a3d23fbb7e125c052ece203282803 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm90.hpp @@ -0,0 +1,132 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include +#include + +#include + +namespace cute +{ + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = typename Copy_Traits::DstLayout; + // Map from (dst-thr,dst-val) to bit + using DstLayout = typename Copy_Traits::SrcLayout; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = typename Copy_Traits::DstLayout; + // Map from (dst-thr,dst-val) to bit + using DstLayout = typename Copy_Traits::SrcLayout; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = typename Copy_Traits::DstLayout; + // Map from (dst-thr,dst-val) to bit + using DstLayout = typename Copy_Traits::SrcLayout; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = typename Copy_Traits::DstLayout; + // Map from (dst-thr,dst-val) to bit + using DstLayout = typename Copy_Traits::SrcLayout; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = typename Copy_Traits::DstLayout; + // Map from (dst-thr,dst-val) to bit + using DstLayout = typename Copy_Traits::SrcLayout; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = typename Copy_Traits::DstLayout; + // Map from (dst-thr,dst-val) to bit + using DstLayout = typename Copy_Traits::SrcLayout; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm90_im2col.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm90_im2col.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e4d1e3ffff293e0e6c1fdbf286c743f5d60542d0 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm90_im2col.hpp @@ -0,0 +1,940 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +/*! \file + \brief im2col make_tma_copy +*/ + +#include "cute/arch/copy_sm90.hpp" +#include "cute/arch/copy_sm90_desc.hpp" +#include "cute/tensor.hpp" + +#include "cute/algorithm/prefetch.hpp" +#include "cutlass/fast_math.h" +#include "cutlass/cuda_host_adapter.hpp" + +namespace cute +{ + +// Utility for unpacking TMA_LOAD_IM2COL arguments into a CopyOp +template +struct TMA_LOAD_IM2COL_Unpack +{ + /// Copy from src to dst. + /// + /// @param traits Copy traits created with a TMA descriptor that + /// correctly matches the input tensor and other convolution + /// parameters. + /// + /// @param src Tile of the im2col-transformed coordinate tensor + /// (result of get_tma_tensor), representing the global-memory + /// tensor from which to load. + /// + /// @param dst Shared memory tile, into which to load. + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, // tile of the transformed global activation (A) tensor + Tensor & dst) // shared memory tile + { + auto src_coord_offset = src(Int<0>{}); + auto src_coord_cwhdn_offset_srt = flatten(src_coord_offset); + // Interpret the TMA IM2COL coordinate as (c, ([w,h,d]), n, ([s,r,t])) + CUTE_STATIC_ASSERT_V(rank(src_coord_offset) == _4{}); + CUTE_STATIC_ASSERT_V(rank<1>(src_coord_offset) == rank<3>(src_coord_offset)); + + if constexpr (detail::is_prefetch) { + return detail::explode_tuple(detail::CallCOPY{}, + traits.opargs_, tuple_seq{}, + src_coord_cwhdn_offset_srt, tuple_seq{}); + } else { + static_assert(is_smem::value, "SM90_TMA_LOAD_IM2COL requires the destination be shared memory."); + void* dst_ptr = cute::raw_pointer_cast(dst.data()); + return detail::explode_tuple(detail::CallCOPY{}, + traits.opargs_, tuple_seq{}, + make_tuple(dst_ptr), seq<0>{}, + src_coord_cwhdn_offset_srt, tuple_seq{}); + } + } +}; + +// Copy_Traits for SM90 im2col TMA load comes in two layers. +// +// 1. Copy_Traits +// 2. Copy_Traits +// +// Copy_Traits +// is the "outer" layer. It has a TMA descriptor, +// but no barrier ("tma_mbar"), so it's "nonexecutable." +// One calls its "with" member function with a barrier, +// to get an executable "inner"-layer +// Copy_Traits object. +// That object's "copy_unpack" member function +// actually invokes im2col TMA load. + +struct SM90_TMA_LOAD_IM2COL_OP : SM90_TMA_LOAD_IM2COL {}; + +/// @brief Non-executable specialization of Copy_Traits for SM90 +/// im2col TMA load, with TMA descriptor but no barrier. +/// +/// Use `.with(memory_barrier)` to construct an executable version. +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + Im2ColTmaDescriptor tma_desc_; + TMATensor tma_tensor_; + + CUTE_HOST_DEVICE constexpr + Im2ColTmaDescriptor const* + get_tma_descriptor() const + { + return &tma_desc_; + } + + template + CUTE_HOST_DEVICE constexpr + TMATensor const + get_tma_tensor(GShape const&) const + { + return tma_tensor_; + } + + /// @brief Get an executable specialization. + /// + /// Copy_Traits specializations with SM90_TMA_LOAD_IM2COL are not + /// directly executable. Instead, call this "with" member function + /// to get an executable specialization. "Executable" means that + /// @c copy_unpack works. + /// + /// @param tma_mbar Memory barrier for synchronization + /// + /// @param multicast_mask Multicast mask (unused; only exists + /// for interface compatibility with the actual multicast Copy_Traits) + /// + /// @return Executable specialization of @c Copy_Traits + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask = 0) const + { + return {{}, {&tma_desc_, &tma_mbar}}; + } + + // Copy_Traits specializations with SM90_TMA_LOAD_IM2COL + // are not directly executable. Instead, call .with + // to get an executable specialization. + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) = delete; +}; + +/// @brief Executable specialization of Copy_Traits for SM90 im2col +/// TMA load, with TMA descriptor and barrier. +template +struct Copy_Traits + : TMA_LOAD_IM2COL_Unpack +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD_IM2COL arguments + tuple< + Im2ColTmaDescriptor const*, + uint64_t* // smem mbarrier + > const opargs_; +}; + +template +struct Copy_Traits + : TMA_LOAD_IM2COL_Unpack +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD_IM2COL::PREFETCH arguments + tuple const opargs_; + + CUTE_HOST_DEVICE + Copy_Traits(Copy_Traits const& traits) + : opargs_({&traits.tma_desc_}) {} +}; + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// TMA_LOAD_MULTICAST ///////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_LOAD_IM2COL_MULTICAST_OP : SM90_TMA_LOAD_IM2COL_MULTICAST {}; + +/// @brief Non-executable specialization of Copy_Traits for SM90 +/// im2col TMA load, with TMA descriptor but no barrier or multicast +/// mask. +/// +/// Use `.with(memory_barrier)` to construct an executable version. +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + Im2ColTmaDescriptor tma_desc_; + TMATensor tma_tensor_; + + CUTE_HOST_DEVICE constexpr + Im2ColTmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + template + CUTE_HOST_DEVICE constexpr + TMATensor const + get_tma_tensor(GShape const&) const + { + return tma_tensor_; + } + + /// @brief Get an executable specialization. + /// + /// Copy_Traits specializations with SM90_TMA_LOAD_IM2COL_MULTICAST + /// are not directly executable. Instead, call this "with" member + /// function to get an executable specialization. "Executable" + /// means that @c copy_unpack works. + /// + /// @param tma_mbar Memory barrier for synchronization + /// + /// @param multicast_mask Multicast mask (defaults to a single CTA) + /// + /// @return Executable specialization of @c Copy_Traits + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(uint64_t& tma_mbar, uint16_t const& multicast_mask) const { + return {{}, {&tma_desc_, &tma_mbar, multicast_mask}}; + } + + // Copy_Traits specializations with SM90_TMA_LOAD_IM2COL_MULTICAST + // are not directly executable. Instead, call .with to get an + // executable specialization. + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) = delete; +}; + +/// @brief Executable specialization of Copy_Traits for SM90 multicast +/// im2col TMA load, with TMA descriptor, barrier, and multicast mask. +template +struct Copy_Traits + : TMA_LOAD_IM2COL_Unpack +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit. + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD_IM2COL_MULTICAST arguments + tuple< + Im2ColTmaDescriptor const*, + uint64_t*, // smem mbarrier + uint16_t // multicast mask + > const opargs_; +}; + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// TMA_STORE IM2COL//////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +// The executable SM90_TMA_STORE_IM2COL with tma_desc +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_STORE_IM2COL arguments + Im2ColTmaDescriptor tma_desc_; + TMATensor tma_tensor_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + Im2ColTmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + template + CUTE_HOST_DEVICE constexpr + TMATensor const + get_tma_tensor(GShape const&) const + { + return tma_tensor_; + } + + // This is the copy_unpack dispatch for this Copy_Traits + // Src needs to be a smem tensor + // Dst needs to be a gmem tensor with TmaCoordIterator .data() + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_smem::value, "Expected smem src for SM90_TMA_STORE_IM2COL"); + + void const* const desc_ptr = &(traits.tma_desc_); + void const* const src_ptr = cute::raw_pointer_cast(src.data()); + auto dst_coord = flatten(take<0,3>(dst(Int<0>{}))); + + return detail::explode_tuple(detail::CallCOPY{}, + make_tuple(desc_ptr, src_ptr), seq<0,1>{}, + dst_coord, tuple_seq{}); + } +}; + +namespace detail { + +/// @brief Creates a TMA descriptor for im2col TMA load. +/// +/// @param tensor_cwhdn Global activation tensor (A matrix of Fprop). +/// This is the original (not im2col-transformed) tensor in global +/// memory. +/// +/// @param slayout Rank 2 (M,K) shared memory layout of the activation +/// tensor. Here, K is "GEMM K," not the filter tensor's mode of +/// the same name. +////// +/// @param traversal_stride Traversal strides convolution parameter +////// +/// Each of padding_shape, traversal_stride, and dilation_shape is a +/// tuple whose size is the number of spatial modes (e.g., 3 for a 5-D +/// convolution). +/// +/// @return TMA descriptor for im2col TMA load +template +CUTE_HOST +auto +make_im2col_tma_copy_desc( + Tensor const& tensor_cwhdn, // (C,W,H,D,N) + uint32_t range_c, // TILE_C + uint32_t range_whdn, // TILE_WHDN + SmemSwizzle const& smem_swizzle, // Swizzle + TMALayout const& tma_layout_vt, // TMA layout + LowerCornerStride const& lower_corner_whd, // WHD offset of the "base pointer" + UpperCornerStride const& upper_corner_whd, // WHD upper corner + LowerPaddingStride const& lower_padding_whd, // WHD lower padding + UpperPaddingStride const& upper_padding_whd, // WHD upper padding + TraversalStride const& stride_whd, // WHD traversal stride + LowerSRTStride const& lower_srt, // SRT offset of the "base pointer" + DilationStride const& stride_srt, // SRT stride - dilation + TMA::DescriptorAuxParams const& aux_params = {}) +{ + static_assert(is_gmem::value, "Tensor must point to GPU global memory."); + using value_type = typename EngineA::value_type; + + constexpr uint32_t num_total_modes = LayoutA::rank; + constexpr int num_spatial_modes = num_total_modes - 2; + + // Gmem starting address + void* gmem_address = (void*) raw_pointer_cast(tensor_cwhdn.data()); + + // Gmem extents are just the tensor shape + cute::array gmem_prob_shape = {1,1,1,1,1}; + for_each(make_seq{}, [&](auto i) { + gmem_prob_shape[i] = static_cast(shape(tensor_cwhdn)); + }); + + // Gmem strides are byte strides of the activation tensor in CWHDN order + cute::array gmem_prob_stride = {0,0,0,0,0}; + for_each(make_seq{}, [&](auto i) { + gmem_prob_stride[i] = sizeof(value_type) * stride(tensor_cwhdn); + }); + + // Traversal strides are a function of the dilation shape + // corresponding to spatial (WHD) modes. + cute::array tma_traversal_strides = {1,1,1,1,1}; + for_each(make_seq{}, [&](auto i) { + tma_traversal_strides[i+1] = static_cast(get(stride_whd)); + }); + + cute::array tma_lower_corner{}; + for_each(make_seq{}, [&](auto i) { + tma_lower_corner[i] = static_cast(get(lower_corner_whd)); + }); + + cute::array tma_upper_corner{}; + for_each(make_seq{}, [&](auto i) { + tma_upper_corner[i] = static_cast(get(upper_corner_whd)); + }); + + Im2ColTmaDescriptor tma_desc; + +#if (__CUDACC_VER_MAJOR__ >= 12) + + CUtensorMapDataType tma_format = TMA::to_CUtensorMapDataType(); + CUtensorMapInterleave tma_interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; + CUtensorMapL2promotion tma_l2Promotion = to_CUtensorMapL2promotion(aux_params.l2promo_); + CUtensorMapFloatOOBfill tma_oob_fill = to_CUtensorMapFloatOOBfill(aux_params.oobfill_); + TMA::SmemSwizzleBits swizzle_bits = detail::get_tma_swizzle_bits(smem_swizzle); + TMA::SmemSwizzleBase swizzle_base = detail::get_tma_swizzle_base(smem_swizzle); + CUtensorMapSwizzle tma_swizzle = TMA::to_CUtensorMapSwizzle(swizzle_bits, swizzle_base); + + CUresult encode_result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeIm2col)( + &tma_desc, + tma_format, + num_total_modes, + gmem_address, + gmem_prob_shape.data(), + gmem_prob_stride.data() + 1, // gmem_prob_stride[0] implicitly sizeof(value_type) + tma_lower_corner.data(), + tma_upper_corner.data(), + range_c, + range_whdn, + tma_traversal_strides.data(), + tma_interleave, + tma_swizzle, + tma_l2Promotion, + tma_oob_fill); + + // The extra asserts help indicate the error's cause. + assert(encode_result != CUDA_ERROR_DEINITIALIZED); + assert(encode_result != CUDA_ERROR_NOT_INITIALIZED); + assert(encode_result != CUDA_ERROR_INVALID_CONTEXT); + assert(encode_result != CUDA_ERROR_INVALID_VALUE); + assert(encode_result == CUDA_SUCCESS); + +#endif // (__CUDACC_VER_MAJOR__ >= 12) + // + // Calculate gemm shapes and linearized shapes based on tma layout tiling. + // + + // Compute [w, h, d, n] + // q/p/z = (w/h/d + (upper_corner_whd - lower_corner_whd - 1)) / stride_whd + 1 + auto gemm_mn_ = cute::transform(cute::make_seq{}, [&](auto i) { + return (shape(tensor_cwhdn) + get(upper_corner_whd) - get(lower_corner_whd) - Int<1>{}) / get(stride_whd) + Int<1>{}; + }); + auto gemm_mn = append(gemm_mn_, shape(tensor_cwhdn)); + + // Compute [c, s, r, t] + // fprop/wgrad, s/r/t = 1 + (upper_padding_whd - upper_corner_whd) / stride_srt + // wgrad, s/r/t = 1 + (lower_padding_whd - lower_corner_whd) / stride_srt + auto gemm_k_ = cute::transform(cute::make_seq{}, [&](auto i) { + auto padding_size = conditional_return(get(stride_srt) > Int<0>{}, + get(upper_padding_whd) - get(upper_corner_whd), + get(lower_corner_whd) - get(lower_padding_whd)); + return Int<1>{} + padding_size / get(stride_srt); + }); + auto gemm_k = prepend(gemm_k_, shape<0>(tensor_cwhdn)); + + // For fprop/dgrad kernel, gemm_shapes is ((q, p, z, n), (c, s, r, t)) + // For wgrad kernel, gemm_shapes is ((c, s, r, t), (q, p, z, n)) + auto gemm_shapes_common = make_shape( + transform_leaf(gemm_mn, [](auto s) { + return conditional_return(cute::is_static{}, s, cutlass::FastDivmod(s)); + }), + gemm_k); + auto gemm_shapes = make_shape( + basis_get(stride<0,1>(tma_layout_vt), gemm_shapes_common), + basis_get(stride<0,0>(tma_layout_vt), gemm_shapes_common)); + + // For fprop/dgrad kernel, linearized shapes is (whdn, (c, s, r, t)) + // For wgrad kernel linearized shapes is ((c, s, r, t), whdn) + auto linear_shapes_common = make_shape(size(gemm_mn), gemm_k); + auto linear_shapes = make_shape( + basis_get(stride<0,1>(tma_layout_vt), linear_shapes_common), + basis_get(stride<0,0>(tma_layout_vt), linear_shapes_common)); + + // + // Calculate gmem basis stride based on tma layout tiling. + // + + auto tma_basis_scale = make_shape(Int<1>{}, stride_whd, Int<1>{}, stride_srt); + auto tma_basis = elem_scale(tma_basis_scale, make_basis_like(tma_basis_scale)); + + auto gbasis_strides_common = make_stride( + append(get<1>(tma_basis), get<2>(tma_basis)), + prepend(get<3>(tma_basis), get<0>(tma_basis))); // ((w,h,d,n),(c,s,r,t)) + auto gbasis_strides = make_stride( + basis_get(stride<0,1>(tma_layout_vt), gbasis_strides_common), + basis_get(stride<0,0>(tma_layout_vt), gbasis_strides_common)); + + // + // Create tma tensor + // + + auto lower_corner = make_arithmetic_tuple(Int<0>{}, lower_corner_whd, Int<0>{}, lower_srt); + + auto tensor_multimode = make_tensor(ArithmeticTupleIterator(lower_corner), gemm_shapes, gbasis_strides); + auto tensor_linear = make_identity_tensor(linear_shapes); + auto tma_tensor = make_tensor(tensor_multimode.data(), composition( + tensor_multimode.layout(), + tensor_linear(Int<0>{}), + tensor_linear.layout())); + + return cute::make_tuple(tma_desc, tma_tensor); +} + +template +CUTE_HOST_RTC +auto +make_tma_atom_im2col(CopyOp, + Tensor const& gtensor, // Full GMEM Tensor: ((w, h, d, n), c) + SLayout const& slayout, // CTA Tile of SMEM, potentially swizzled + int32_t const& num_multicast, // The number of CTAs involved in multicasting + Layout const& cta_v_map, // V: CTA val idx -> gmem mode + LowerCornerStride const& lower_corner_whd, + UpperCornerStride const& upper_corner_whd, + LowerPaddingStride const& lower_padding_whd, + UpperPaddingStride const& upper_padding_whd, + TraversalStride const& stride_whd, // traversal stride + LowerSRTStride const& lower_srt, + DilationStride const& stride_srt, // dilation + TMA::DescriptorAuxParams const& aux_params = {}) +{ + // + // TMA parameter checking + // + + CUTE_STATIC_ASSERT_V(product_each(shape(slayout)) == product_each(shape(cta_v_map)), + "TMA requires CTA_Tile and SLayout top-level shape equivalence."); + + // + // TMA slayout manipulation + // + + // Invert the smem to get the largest contiguous vector in the smem layout + auto inv_smem_layout = right_inverse(get_nonswizzle_portion(slayout)); + // trunc_smem_idx -> trunc_smem_coord + + // Map from smem idx to a gmem mode + auto sidx_to_gmode = coalesce(composition(cta_v_map, inv_smem_layout)); + +#if 0 + print("g_layout : "); print(gtensor.layout()); print("\n"); + print("s_layout : "); print(slayout); print("\n"); + print("cta_t_map : "); print(cta_t_map); print("\n"); + print("cta_v_map : "); print(cta_v_map); print("\n"); + print("inv_smem : "); print(inv_smem_layout); print("\n"); + print("sidx_to_gmode : "); print(sidx_to_gmode); print("\n"); +#endif + + // + // TMA gtensor manipulation + // + + // Generate a TupleBasis for the gtensor + auto glayout_basis = make_identity_layout(product_each(shape(gtensor))); + + // Tile the modes of gtensor with the truncated cta_v_map o inv_smem_layout_trunc + auto tma_layout_full = flatten(composition(glayout_basis, sidx_to_gmode)); + + // Truncate any incompatibilities -- no starting in the middle of gmodes + auto smem_rank = find_if(stride(tma_layout_full), [](auto e) { + [[maybe_unused]] auto v = basis_value(e); + return not is_constant<1,decltype(v)>{}; + }); + static_assert(smem_rank >= 2, "IM2COL expects at least 2 modes of the smem to vectorize with gmem."); + // IM2COL uses a maximum of 2 modes + constexpr int smem_tma_rank = cute::min(int(smem_rank), 2); + + // Keep only the static-1 basis modes into gmem + auto tma_layout_trunc = take<0,smem_tma_rank>(tma_layout_full); + + // Split according to the portion each multicast CTA will be responsible for + auto tma_layout_vt = logical_divide(tma_layout_trunc, safe_div(size(tma_layout_trunc), num_multicast)); + +#if 0 + print("glayout_basis : "); print(glayout_basis); print("\n"); + print("tma_layout_full : "); print(tma_layout_full); print("\n"); + + print("tma_layout_trunc: "); print(tma_layout_trunc); print("\n"); + print("tma_layout_vt : "); print(tma_layout_vt); print("\n"); +#endif + + auto range_c = size<0,0>(tma_layout_vt); + auto range_whdn = size<0,1>(tma_layout_vt); + Tensor gtensor_cwhdn = make_tensor(gtensor.data(), + flatten(make_layout(make_layout(basis_get(stride<0,0>(tma_layout_vt), gtensor.shape()), + basis_get(stride<0,0>(tma_layout_vt), gtensor.stride())), + make_layout(basis_get(stride<0,1>(tma_layout_vt), gtensor.shape()), + basis_get(stride<0,1>(tma_layout_vt), gtensor.stride()))))); + auto [tma_desc, tma_tensor] = make_im2col_tma_copy_desc( + gtensor_cwhdn, + range_c, + range_whdn, + get_swizzle_portion(slayout), + tma_layout_vt, + lower_corner_whd, + upper_corner_whd, + lower_padding_whd, + upper_padding_whd, + stride_whd, + lower_srt, + stride_srt, + aux_params); + + // + // Construct the Copy_Traits + // + + using T = typename GEngine::value_type; + constexpr int num_bits_per_tma = decltype(size(tma_layout_trunc))::value * sizeof(T) * 8; + + using Traits = Copy_Traits, decltype(tma_tensor)>; + using Atom = Copy_Atom; + +#if 0 + print("num_bits : "); print(num_bits_per_tma); print("\n"); +#endif + + Traits tma_traits{tma_desc, tma_tensor}; + + // Return the Copy_Atom + return Atom{tma_traits}; +} + +/// Make a TiledCopy for im2col TMA load. +/// +/// @param copy_op The copy implementation: either +/// SM90_TMA_LOAD_IM2COL or SM90_TMA_LOAD_IM2COL_MULTICAST. +/// +/// @param tensor_cwhdn The global tensor to use for im2col TMA loads. +/// For Fprop convolutions, this is the activation tensor. This is +/// the "original tensor that points to global memory, not the +/// coordinate (im2col-transformed) tensor. +/// +/// @param slayout Layout of shared memory tile. +/// +/// @param stride_whd The traversal strides convolution +/// parameter. +/// +/// @return TiledCopy specialization for im2col TMA loads. +template +CUTE_HOST_RTC +auto +make_tma_copy_im2col(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, + Layout const& cta_t_map, // CTA tid -> logical TMA tid + Layout const& cta_v_map, // CTA vid -> gmem coord + LowerCornerStride const& lower_corner_whd, + UpperCornerStride const& upper_corner_whd, + LowerPaddingStride const& lower_padding_whd, + UpperPaddingStride const& upper_padding_whd, + TraversalStride const& stride_whd, // traversal stride + LowerSRTStride const& lower_srt, + DilationStride const& stride_srt, // dilation + TMA::DescriptorAuxParams const& aux_params = {}) +{ + // + // TMA parameter checking + // + + CUTE_STATIC_ASSERT_V(size(slayout) % cosize(cta_t_map) == Int<0>{}, + "Number of active CTAs in TMA must divide domain size of slayout."); + + Copy_Atom atom = make_tma_atom_im2col(copy_op, gtensor, slayout, cosize(cta_t_map), cta_v_map, + lower_corner_whd, upper_corner_whd, lower_padding_whd, + upper_padding_whd, stride_whd, lower_srt, stride_srt, aux_params); + + // + // Construct the TiledCopy + // + + auto cta_tiler = product_each(shape(cta_v_map)); + + auto num_elems_per_tma = size<1>(typename decltype(atom)::RefLayout{}) / static_value>(); + + // smem idx -> smem coord + auto inv_smem_layout = right_inverse(get_nonswizzle_portion(slayout)); + // CTA V -> smem_coord + auto layout_v = composition(inv_smem_layout, num_elems_per_tma); + // Scale that up to cover all of the smem_coords + auto layout_V = tile_to_shape(make_layout(layout_v), size(cta_v_map)); + // CTA T -> smem idx + auto layout_t = make_layout(cosize(cta_t_map), safe_div(num_elems_per_tma, cosize(cta_t_map))); + // CTA TID -> smem coord + auto layout_T = composition(inv_smem_layout, composition(layout_t, cta_t_map)); + // Combine with the T mapping + [[maybe_unused]] auto layout_TV = make_layout(layout_T, layout_V); + +#if 0 + print("cta_tiler : "); print(cta_tiler); print("\n"); + print("layout_v : "); print(layout_v); print("\n"); + print("layout_V : "); print(layout_V); print("\n"); + print("layout_t : "); print(layout_t); print("\n"); + print("layout_T : "); print(layout_T); print("\n"); + print("layout_TV : "); print(layout_TV); print("\n"); +#endif + + return TiledCopy{atom}; +} + +/// Make a TiledCopy for im2col TMA with no offsets. +/// E.g. im2col TMA load for C and im2col TMA store for D. +template +CUTE_HOST_RTC +auto +make_tma_copy_im2col(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, + Layout const& cta_t_map, // CTA tid -> logical TMA tid + Layout const& cta_v_map) // CTA vid -> gmem coord +{ + constexpr int num_spatial_modes = rank<0>(GLayout{}) - 1; + return make_tma_copy_im2col(copy_op, gtensor, slayout, cta_t_map, cta_v_map, + append(Stride<_0>{}, Int<0>{}), // lower_corner_whd + append(Stride<_0>{}, Int<0>{}), // upper_corner_whd + append(Stride<_0>{}, Int<0>{}), // lower_padding_whd + append(Stride<_0>{}, Int<0>{}), // upper_padding_whd + append(Stride<_1>{}, Int<1>{}), // stride_whd + append(Stride<_0>{}, Int<0>{}), // lower_srt + append(Stride<_1>{}, Int<1>{})); // stride_srt +} + +} // namespace detail + + + +template +CUTE_HOST_RTC +auto +make_im2col_tma_copy(CopyOp const& copy_op, + Tensor const& tensor_cwhdn, + SLayout const& slayout, + CTATiler const& cta_tiler, + MulticastSize const& multicast_size, + LowerCornerStride const& lower_corner_whd, + UpperCornerStride const& upper_corner_whd, + LowerPaddingStride const& lower_padding_whd, + UpperPaddingStride const& upper_padding_whd, + TraversalStride const& stride_whd, + LowerSRTStride const& lower_srt, + DilationStride const& stride_srt) +{ + auto cta_v_tile = make_identity_layout(product_each(shape(tensor_cwhdn))).compose(cta_tiler); + auto cta_t_tile = make_layout(multicast_size); + + return detail::make_tma_copy_im2col(copy_op, tensor_cwhdn, + slayout, cta_t_tile, cta_v_tile, + lower_corner_whd, upper_corner_whd, lower_padding_whd, upper_padding_whd, stride_whd, lower_srt, stride_srt); +} + +// Explicit default for multicast_size +template +CUTE_HOST_RTC +auto +make_im2col_tma_copy(CopyOp const& copy_op, + Tensor const& tensor_cwhdn, + SLayout const& slayout, + CTATiler const& cta_tiler, + LowerCornerStride const& lower_corner_whd, + UpperCornerStride const& upper_corner_whd, + LowerPaddingStride const& lower_padding_whd, + UpperPaddingStride const& upper_padding_whd, + TraversalStride const& stride_whd, + LowerSRTStride const& lower_srt, + DilationStride const& stride_srt) +{ + return make_im2col_tma_copy(copy_op, tensor_cwhdn, slayout, cta_tiler, Int<1>{}, + lower_corner_whd, upper_corner_whd, lower_padding_whd, upper_padding_whd, stride_whd, lower_srt, stride_srt); +} + +// Explicit default for cta_tiler and multicast_size +template +CUTE_HOST_RTC +auto +make_im2col_tma_copy(CopyOp const& copy_op, + Tensor const& tensor_cwhdn, + SLayout const& slayout, + LowerCornerStride const& lower_corner_whd, + UpperCornerStride const& upper_corner_whd, + LowerPaddingStride const& lower_padding_whd, + UpperPaddingStride const& upper_padding_whd, + TraversalStride const& stride_whd, + LowerSRTStride const& lower_srt, + DilationStride const& stride_srt) +{ + return make_im2col_tma_copy(copy_op, tensor_cwhdn, slayout, product_each(shape(slayout)), Int<1>{}, + lower_corner_whd, upper_corner_whd, lower_padding_whd, upper_padding_whd, stride_whd, lower_srt, stride_srt); +} + +// No offsets copy. +template +CUTE_HOST_RTC +auto +make_im2col_tma_copy(CopyOp const& copy_op, + Tensor const& tensor_cwhdn, + SLayout const& slayout, + CTATiler const& cta_tiler, + MulticastSize const& multicast_size) +{ + auto cta_v_tile = make_identity_layout(product_each(shape(tensor_cwhdn))).compose(cta_tiler); + auto cta_t_tile = make_layout(multicast_size); + + return detail::make_tma_copy_im2col(copy_op, tensor_cwhdn, slayout, cta_t_tile, cta_v_tile); +} + +// Explicit default for multicast_size +template +CUTE_HOST_RTC +auto +make_im2col_tma_copy(CopyOp const& copy_op, + Tensor const& tensor_cwhdn, + SLayout const& slayout, + CTATiler const& cta_tiler) +{ + return make_im2col_tma_copy(copy_op, tensor_cwhdn, slayout, cta_tiler, Int<1>{}); +} + +// Explicit default for cta_tiler and multicast_size +template +CUTE_HOST_RTC +auto +make_im2col_tma_copy(CopyOp const& copy_op, + Tensor const& tensor_cwhdn, + SLayout const& slayout) +{ + return make_im2col_tma_copy(copy_op, tensor_cwhdn, slayout, product_each(shape(slayout)), Int<1>{}); +} + +} // namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm90_tma.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm90_tma.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3be30af795241037a47cad0f41b5e46da3108382 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm90_tma.hpp @@ -0,0 +1,1593 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#if !defined(__CUDACC_RTC__) +#include +#endif + +#include +#include +#include + +#include + +#include + +#include + +namespace cute +{ + +template +struct AuxTmaParams { + using GmemStrides = GmemTmaBasisStrides_; // Strides for Gmem mode -> Tma coord mode, may be dynamic + GmemStrides g_stride_; + using TmaGmemBasis = TmaGmemBasis_; // Layout for Tma box shape -> Gmem mode(s), always static + static_assert(is_static::value); + using TmaSwizzle = TmaSwizzle_; // Tma swizzle, always Swizzle + static_assert(is_static::value); +}; + +// Utility for unpacking TMA_LOAD arguments into a CopyOp +template +struct TMA_LOAD_Unpack +{ + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_smem::value, "SM90_TMA_LOAD requires the destination be shared memory."); + + auto src_coord = src.data().coord_; + void* dst_ptr = cute::raw_pointer_cast(dst.data()); +#if 0 + auto [c0,c1,c2,c3,c4] = append<5>(src_coord, 0); + printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n", + threadIdx.x, threadIdx.y, threadIdx.z, + blockIdx.x, blockIdx.y, blockIdx.z, + int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), dst_ptr); +#endif + return detail::explode_tuple(detail::CallCOPY{}, + traits.opargs_, tuple_seq{}, + make_tuple(dst_ptr), seq<0>{}, + src_coord, tuple_seq{}); + } +}; + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// TMA_LOAD /////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_LOAD_OP : SM90_TMA_LOAD {}; + +// The non-executable SM90_TMA_LOAD with tma_desc and no tma_mbar +// Use .with(tma_mbar) to construct an executable version +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD arguments + TmaDescriptor tma_desc_; + using AuxParams = AuxParams_; + AuxParams aux_params_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + // Construct an executable SM90_TMA_LOAD with tma_mbar + CUTE_HOST_DEVICE constexpr + Copy_Traits + with( + uint64_t& tma_mbar, + [[maybe_unused]] uint16_t const& multicast_mask = 0, + TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { + // We accept multicast_mask here to keep the API for both atoms consistent + return {&tma_desc_, &tma_mbar, static_cast(cache_hint)}; + } + + // Construct an executable SM90_TMA_LOAD with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm) + CUTE_HOST_DEVICE constexpr + Copy_Traits + with( + TmaDescriptor const* new_tma_desc, + uint64_t& tma_mbar, + [[maybe_unused]] uint16_t const& multicast_mask = 0, + TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { + // We accept multicast_mask here to keep the API for both atoms consistent + return {new_tma_desc, &tma_mbar, static_cast(cache_hint)}; + } + + // Generate the TMA coord tensor + template + CUTE_HOST_DEVICE constexpr + auto + get_tma_tensor(GShape const& g_shape) const { + static_assert(is_congruent::value); + return make_coord_tensor(make_layout(g_shape, aux_params_.g_stride_)); + } + + // Don't try to execute a copy with SM90_TMA_LOAD before calling .with() + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) = delete; + + // Construct with updated TMA descriptor only (no barrier change) + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(TmaDescriptor const* new_tma_desc) const { + return {*new_tma_desc, aux_params_}; + } +}; + +// The executable SM90_TMA_LOAD with tma_desc and tma_mbar +template +struct Copy_Traits + : TMA_LOAD_Unpack +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD arguments + tuple< + TmaDescriptor const*, + uint64_t*, // smem mbarrier + uint64_t // cache hint + > const opargs_; + + CUTE_HOST_DEVICE + Copy_Traits(TmaDescriptor const* desc, uint64_t* mbar, uint64_t cache) + : opargs_(desc, mbar, cache) {} + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return get<0>(opargs_); + } +}; + +// The prefetch for SM90_TMA_LOAD with tma_desc +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD::PREFETCH arguments + tuple const opargs_; + + // Construct with any other Traits' TMA Desc + template + CUTE_HOST_DEVICE + Copy_Traits(OtherTraits const& traits) + : opargs_({traits.get_tma_descriptor()}) {} + + // Construct directly with a TMA descriptor pointer + CUTE_HOST_DEVICE + Copy_Traits(TmaDescriptor const* desc) + : opargs_({desc}) {} + + // Build a new Prefetch traits with a different TMA descriptor pointer + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(TmaDescriptor const* new_tma_desc) const { + return {new_tma_desc}; + } + + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + auto src_coord = src.data().coord_; + return detail::explode_tuple(detail::CallCOPY{}, + traits.opargs_, tuple_seq{}, + src_coord, tuple_seq{}); + } +}; + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// TMA_LOAD_MULTICAST ///////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_LOAD_MULTICAST_OP : SM90_TMA_LOAD_MULTICAST {}; + +// The non-executable SM90_TMA_LOAD_MULTICAST with tma_desc and no tma_mbar +// Use .with(tma_mbar, multicast_mask) to construct an executable version +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD_MULTICAST arguments + TmaDescriptor tma_desc_; + using AuxParams = AuxParams_; + AuxParams aux_params_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + // Construct an executable SM90_TMA_LOAD_MULTICAST with tma_mbar + CUTE_HOST_DEVICE constexpr + Copy_Traits + with( + uint64_t& tma_load_mbar, + uint16_t const& multicast_mask, + TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { + return {&tma_desc_, &tma_load_mbar, multicast_mask, static_cast(cache_hint)}; + } + + // Construct an executable SM90_TMA_LOAD_MULTICAST_OP with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm) + CUTE_HOST_DEVICE constexpr + Copy_Traits + with( + TmaDescriptor const* new_tma_desc, + uint64_t& tma_load_mbar, + uint16_t const& multicast_mask, + TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { + return {new_tma_desc, &tma_load_mbar, multicast_mask, static_cast(cache_hint)}; + } + + // Generate the TMA coord tensor + template + CUTE_HOST_DEVICE constexpr + auto + get_tma_tensor(GShape const& g_shape) const { + static_assert(is_congruent::value); + return make_coord_tensor(make_layout(g_shape, aux_params_.g_stride_)); + } + + // Don't try to execute a copy with SM90_TMA_LOAD_MULTICAST before calling .with() + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) = delete; +}; + +// The executable SM90_TMA_LOAD_MULTICAST with tma_desc and tma_mbar and multicast_mask +template +struct Copy_Traits + : TMA_LOAD_Unpack +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD_MULTICAST arguments + tuple< + TmaDescriptor const*, + uint64_t*, // smem mbarrier + uint16_t, // multicast mask + uint64_t // cache hint + > const opargs_; + + CUTE_HOST_DEVICE + Copy_Traits(TmaDescriptor const* desc, uint64_t* mbar, uint16_t mask, uint64_t hint) + : opargs_(desc, mbar, mask, hint) {} + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return get<0>(opargs_); + } +}; + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// TMA_STORE ////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_STORE_PTR : SM90_TMA_STORE {}; + +// The executable SM90_TMA_STORE with tma_desc +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_STORE arguments + TmaDescriptor tma_desc_; + using AuxParams = AuxParams_; + AuxParams aux_params_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + // Generate the TMA coord tensor + template + CUTE_HOST_DEVICE constexpr + auto + get_tma_tensor(GShape const& g_shape) const { + static_assert(is_congruent::value); + return make_coord_tensor(make_layout(g_shape, aux_params_.g_stride_)); + } + + // Construct new TMA_STORE with (unsafe) swapped out TMA descriptor ptr (for grouped gemm/ptr array gemm) + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(TmaDescriptor const* new_tma_desc) const { + return {new_tma_desc}; + } + + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_smem::value, "Expected smem src for SM90_TMA_STORE"); + //static_assert(is_gmem::value, "Expected gmem dst for SM90_TMA_STORE"); // TMA spoofed src tensor + + void const* const desc_ptr = &(traits.tma_desc_); + void const* const src_ptr = cute::raw_pointer_cast(src.data()); + auto dst_coord = dst.data().coord_; +#if 0 + auto [c0,c1,c2,c3,c4] = append<5>(dst_coord, 0); + printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n", + threadIdx.x, threadIdx.y, threadIdx.z, + blockIdx.x, blockIdx.y, blockIdx.z, + int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), src_ptr); +#endif + return detail::explode_tuple(detail::CallCOPY{}, + make_tuple(desc_ptr, src_ptr), seq<0,1>{}, + dst_coord, tuple_seq{}); + } +}; + +// Same as SM90_TMA_STORE, but with an unsafe TMA Desc PTR instead +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_STORE arguments + TmaDescriptor const* tma_desc_; + + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_smem::value, "Expected smem src for SM90_TMA_STORE"); + //static_assert(is_gmem::value, "Expected gmem dst for SM90_TMA_STORE"); // TMA spoofed src tensor + + void const* const desc_ptr = traits.tma_desc_; + void const* const src_ptr = cute::raw_pointer_cast(src.data()); + auto dst_coord = dst.data().coord_; +#if 0 + auto [c0,c1,c2,c3,c4] = append<5>(dst_coord, 0); + printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n", + threadIdx.x, threadIdx.y, threadIdx.z, + blockIdx.x, blockIdx.y, blockIdx.z, + int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), src_ptr); +#endif + return detail::explode_tuple(detail::CallCOPY{}, + make_tuple(desc_ptr, src_ptr), seq<0,1>{}, + dst_coord, tuple_seq{}); + } +}; + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// TMA_REDUCE_ADD ////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +// The executable SM90_TMA_REDUCE_ADD with tma_desc +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_REDUCE_ADD arguments + TmaDescriptor tma_desc_; + using AuxParams = AuxParams_; + AuxParams aux_params_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + // Generate the TMA coord tensor + template + CUTE_HOST_DEVICE constexpr + auto + get_tma_tensor(GShape const& g_shape) const { + static_assert(is_congruent::value); + return make_coord_tensor(make_layout(g_shape, aux_params_.g_stride_)); + } + + template + CUTE_HOST_DEVICE constexpr + void + copy_unpack_(void const* const src_ptr, + Coord const& dst_coord, seq) const + { +#if 0 + auto [c0,c1,c2,c3,c4] = append<5>(dst_coord, 0); + printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n", + threadIdx.x, threadIdx.y, threadIdx.z, + blockIdx.x, blockIdx.y, blockIdx.z, + int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), src_ptr); +#endif + + SM90_TMA_REDUCE_ADD::copy(&tma_desc_, + src_ptr, get(dst_coord)...); + } + + // This is the copy_unpack dispatch for this Copy_Traits + // Src needs to be a smem tensor + // Dst needs to be a gmem tensor with TmaCoordIterator .data() + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_smem::value, "Expected smem src for SM90_TMA_REDUCE_ADD"); + //static_assert(is_gmem::value, "Expected gmem dst for SM90_TMA_REDUCE_ADD"); // TMA spoofed src tensor + + traits.copy_unpack_(cute::raw_pointer_cast(src.data()), dst.data().coord_, tuple_seq{}); + } +}; + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// BULK COPY ////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +template +struct Copy_Traits +{ + static_assert(int32_t(NumBitsPerTMA::value / 8) % 16 == 0, + "Bulk Copy requires copy vector size align to 16B."); + + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_BULK_COPY_G2S arguments + // 0: uint64_t* bulk_load_memory_barrier + cute::tuple bulk_load_mbar_; + + // Record the memory barrier for the instruction + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(uint64_t& bulk_mbar) const { + return {&bulk_mbar}; + } + + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_same, cute::tuple>::value, + "Extra arguments not set. Set .with() before use."); + static_assert(is_gmem::value, "Expected gmem src for SM90_BULK_COPY_G2S"); + static_assert(is_smem::value, "Expected smem dst for SM90_BULK_COPY_G2S"); + SM90_BULK_COPY_G2S::copy(raw_pointer_cast(src.data()), get<0>(traits.bulk_load_mbar_), + raw_pointer_cast(dst.data()), int32_t(NumBitsPerTMA::value / 8)); + } +}; + +template +struct Copy_Traits + : Copy_Traits +{ + template + CUTE_HOST_DEVICE + Copy_Traits(Copy_Traits const& traits) {} + + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_gmem::value, "Expected gmem src for SM90_BULK_PREFETCH"); + SM90_BULK_COPY_G2S::PREFETCH::copy(raw_pointer_cast(src.data()), int32_t(NumBitsPerTMA::value / 8)); + } +}; + +template +struct Copy_Traits +{ + static_assert(int32_t(NumBitsPerTMA::value / 8) % 16 == 0, + "Bulk Copy requires copy vector size align to 16B."); + + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_smem::value, "Expected smem src for SM90_BULK_COPY_S2G"); + static_assert(is_gmem::value, "Expected gmem dst for SM90_BULK_COPY_S2G"); + SM90_BULK_COPY_S2G::copy(raw_pointer_cast(src.data()), raw_pointer_cast(dst.data()), int32_t(NumBitsPerTMA::value / 8)); + } +}; + +// +// Placeholder for the bulk copy algorithm's default, auto-vectorizing behavior +// + +template +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride<_0,_0>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Stride<_0,_0>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_UBULK_COPY arguments + // 0: uint64_t* bulk_load_memory_barrier [if this is a BULK_LOAD_G2S] + cute::tuple opargs_; + + // Record the memory barrier for the instruction + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(uint64_t& bulk_mbar) const { + return {&bulk_mbar}; + } +}; + +// +// MAKE_TMA_COPY and related +// + +namespace detail { + +// Custom version of coalesce that greedily combines modes only up to size-256 +// Look at each element and the back of the stack (in order of priority) +// back(NewLayout) get(OldLayout) +// s0:d0 _1:d1 => continue +// _1:d0 s1:d1 => replace_back s1:d1 +// s0:d0 s1:s0*d0 => replace_back s0*s1:d0 if s0*s1 <= 256 +// s0:d0 s1:d1 => append s1:d1 +// +// @pre OldShape and OldStride are flat +template +CUTE_HOST_DEVICE constexpr +auto +coalesce_256_impl(OldShape const& old_shape, OldStride const& old_stride, + NewShape const& new_shape, NewStride const& new_stride) +{ + if constexpr (I == rank_v) { + // Base case, we're done + if constexpr (is_constant<1, NewShape>::value) { + return Layout<_1,_0>{}; + } else { + return Layout{new_shape,new_stride}; + } + } else if constexpr (is_constant<1, decltype(get(old_shape))>::value) { + // shape(layout) == _1, skip it and continue + return coalesce_256_impl(old_shape, old_stride, new_shape, new_stride); + } else if constexpr (is_constant<1, NewShape>::value) { + // Replace our shape-1 with anything (Can only happen on input new_shape/new_stride) + return coalesce_256_impl(old_shape, old_stride, get(old_shape), get(old_stride)); + } else if constexpr (is_constant(old_stride) && + get(old_shape) * back(new_shape) <= Int<256>{})>::value) { + // Merge modes because the shapes and strides match and the merge is 256 or less + return coalesce_256_impl(old_shape, old_stride, + replace_back(new_shape, get(old_shape) * back(new_shape)), + new_stride); + } else { + // Can't replace or merge, so append a new mode + return coalesce_256_impl(old_shape, old_stride, + append(new_shape, get(old_shape)), + append(new_stride, get(old_stride))); + } + + CUTE_GCC_UNREACHABLE; +} + +// Combine all the modes that are possible to combine +// Does not respect the profile of the layout, but does preserve total size +template +CUTE_HOST_DEVICE constexpr +auto +coalesce_256(Layout const& layout) +{ + auto flat_shape = flatten(layout.shape()); + auto flat_stride = flatten(layout.stride()); + return coalesce_256_impl<1>(flat_shape, flat_stride, get<0>(flat_shape), get<0>(flat_stride)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +construct_tma_gbasis(Tensor const& gtensor, // The original GMEM Tensor + Layout const& slayout, // The layout of SMEM + Layout const& cta_v_map) // smem_idx to hier gmode +{ + // + // TMA parameter checking + // + + // CUTE_STATIC_ASSERT_V(product_each(shape(slayout)) == product_each(shape(cta_v_map)), + // "TMA requires CTA_Tile and SLayout top-level shape equivalence."); + CUTE_STATIC_ASSERT_V(size(slayout) == size(cta_v_map), + "TMA requires CTA_Tile and SLayout top-level size equivalence."); + +#if 0 + print("gtensor : "); print(gtensor); print("\n"); + print("slayout : "); print(slayout); print("\n"); + print("cta_v_map : "); print(cta_v_map); print("\n"); +#endif + + // + // TMA slayout manipulation + // + + // Invert the smem to get the largest contiguous vector in the smem layout + // smem idx -> smem coord + auto inv_smem_layout = right_inverse(get_nonswizzle_portion(slayout)); + + // Compose with the V-Map to convert smem coord (CTA val idx) to gmem mode + // smem idx -> gmem mode + auto sidx2gmode_full = coalesce(composition(cta_v_map, inv_smem_layout)); + +#if 0 + print("inv_smem_layout : "); print(inv_smem_layout); print("\n"); + print("sidx2gmode_full : "); print(sidx2gmode_full); print("\n"); +#endif + + // + // TMA gtensor truncation + // + + // Truncate any incompatibilities -- no starting in the middle of gmodes + auto smem_rank = find_if(stride(sidx2gmode_full), [](auto e) { + [[maybe_unused]] auto v = basis_value(e); + return not is_constant<1,decltype(v)>{}; + }); + static_assert(smem_rank > 0, "Could not find a common tile-gmem vectorization. Does the Tile select out major GMEM modes?"); + + // Keep only the static-1 basis modes into gmem + auto sidx2gmode = take<0,smem_rank>(sidx2gmode_full); + +#if 0 + print("smem_rank : "); print(smem_rank); print("\n"); + print("sidx2gmode : "); print(sidx2gmode); print("\n"); +#endif + + // + // TMA gtensor manipulation + // + + // The smem vector is the same units as gtensor, so compose first and then recast + // tma_val_idx:gmem_strides + auto tile_gstride = recast(gtensor.compose(sidx2gmode)).layout(); + // Coalesce modes up to size-256 (the maximum TMA box extent in units of TmaInternalType) + // tma_box_shape:gmem_strides + auto tma_gstride = coalesce_256(tile_gstride); + + // Perform the tiling, recast, and coalesce to the gmem vector again, but with indirections to the gtensor modes + auto gbasis = make_identity_layout(shape(gtensor)); + auto tile_gbasis_tmp = gbasis.compose(sidx2gmode); + + // Instead of the recast (gbasis doesn't have type info), replace the shape with the already-recasted shape + // tma_box_shape:gmem_mode + auto tile_gbasis = make_layout(shape(tile_gstride), stride(tile_gbasis_tmp)); + + // "Coalesce" the tile basis into a compatible shape with the tma_gstride + auto tma_gbasis_tile = tile_gbasis.compose(make_layout(wrap(shape(tma_gstride)))); + + // Recast the original tensor for shape/stride inspections + Tensor gtensor_T = recast(gtensor); + + // Find missing bases that don't appear in tile_gbasis + auto tile_gbasis_remaining_stride = filter_tuple(flatten(shape (gtensor_T)), flatten(stride(gtensor_T)), + flatten(stride(gbasis)), + [&](auto s, auto d, auto e) + { + if constexpr (is_constant<1, decltype(s)>::value || is_constant<0, decltype(d)>::value) { + return cute::tuple<>{}; // If size-1 or stride-0, then don't append + } else { + using E = decltype(e); + auto has_e = any_of(flatten(stride(tma_gbasis_tile)), [] (auto tb) { return tb == E{}; }); + if constexpr (decltype(has_e)::value) { + return cute::tuple<>{}; // If d was found, then don't append + } else { + return cute::tuple(e); // Else, this is missing so append + } + } + }); + + // Append the remaining basis modes that contribute to the TMA with size-1 + auto tile_gbasis_remaining_shape = repeat(Int<1>{}); + auto tma_gbasis_full = make_layout(tuple_cat(wrap( shape(tma_gbasis_tile)), wrap(tile_gbasis_remaining_shape )), + tuple_cat(wrap(stride(tma_gbasis_tile)), wrap(tile_gbasis_remaining_stride))); + + // Group the trailing modes to make this max rank-5 -- TMA rank limitation + // tma_box_shape:gmem_mode + auto tma_gbasis = group(tma_gbasis_full); + +#if 0 + print("tile_gstride : "); print(tile_gstride); print("\n"); + print("tma_gstride : "); print(tma_gstride); print("\n"); + print("gbasis : "); print(gbasis); print("\n"); + print("tile_gbasis : "); print(tma_gbasis_tile); print("\n"); + print("tma_gbasis : "); print(tma_gbasis); print("\n"); +#endif + + return tma_gbasis; +} + +template +CUTE_HOST_DEVICE constexpr +void +fill_tma_gmem_shape_stride(Tensor const& gtensor, // Gmem Shapes and Strides, in units of TmaInternalType + TmaGmemBasisStride const& tma_gbasis_stride, // Map Tma mode idx -> Gmem mode(s) + cute::array & gmem_prob_shape, // Tma Shapes, uint32_t or uin64_t + cute::array & gmem_prob_stride) // Tma Strides +{ + static_assert(is_tuple::value); + static_assert(is_same::value || is_same::value); + + using TmaInternalType = typename GEngine::value_type; + constexpr int tma_rank = decltype(rank(tma_gbasis_stride))::value; + static_assert(TmaRank >= tma_rank); + + auto gmem_shape = shape(gtensor); + auto gmem_stride = stride(gtensor); + // Use the indirections in tma_gbasis_stride into gtensor to construct the tma gmem shapes/strides + for_each(make_seq{}, [&](auto i) { + constexpr int tma_i_rank = decltype(rank(tma_gbasis_stride))::value; + if constexpr (tma_i_rank == 1) { + // Trivial contribution of this gmem mode to this tma mode + auto ej = unwrap(get(tma_gbasis_stride)); + gmem_prob_shape[i] = basis_get(ej, gmem_shape); + gmem_prob_stride[i] = basis_get(ej, gmem_stride); + } else { + // Apply a recurrence to each gmem mode that contributes to this tma mode + for_each(get(tma_gbasis_stride), [&](auto ej) { + // Problem shape + uint64_t shape_j = basis_get(ej, gmem_shape); + // Problem stride (in bytes) + uint64_t stride_j = basis_get(ej, gmem_stride); + uint64_t old_stride = gmem_prob_stride[i]; + gmem_prob_stride[i] = gcd(gmem_prob_stride[i], stride_j); + + if (gmem_prob_stride[i] != 0) { + // Recurrence: g_shape = (s_i - 1) * (d_i / gcd_j d_j) + 1 + gmem_prob_shape[i] = (gmem_prob_shape[i]-1) * (old_stride / gmem_prob_stride[i]) + + (shape_j-1) * (stride_j / gmem_prob_stride[i]) + + 1; + } else { + gmem_prob_shape[i] = shape_j; + } + }); + } + }); +} + +// Overload for an existing Copy_Traits +template +CUTE_HOST_DEVICE constexpr +void +fill_tma_gmem_shape_stride(Copy_Traits const& tma_traits, + Tensor const& gtensor, // Gmem Shapes and Strides, value_type = TmaInternalType + cute::array & gmem_prob_shape, // Tma Shapes, uint32_t or uin64_t + cute::array & gmem_prob_stride) // Tma Strides +{ + return fill_tma_gmem_shape_stride(gtensor, stride(typename Aux::TmaGmemBasis{}), + gmem_prob_shape, gmem_prob_stride); +} + +// Use a sidx2gmode to read through the GMEM tensor +// and construct a TMA Descriptor for the resulting instruction +// At the same time, construct the Tma Tensor's Stride to generate +// the TMA coordinates that the instruction consumes. +// +template +CUTE_HOST_RTC +auto +make_tma_copy_desc(Tensor const& gtensor, // The original GMEM Tensor + Layout const& tma_gbasis, // TMA mode -> GMEM mode mapping + Swizzle const& swizzle, // Swizzle fn on smem_idx + uint32_t num_multicast) // The number of CTAs in multicasting +{ + // + // TMA desc creation + // + + constexpr int tma_dim = decltype(rank(tma_gbasis))::value; + + // + // TMA gmem desc info + // + + // Recast the original tensor for shape/stride inspections + Tensor gtensor_T = recast(gtensor); + + void* gmem_address = (void*) raw_pointer_cast(gtensor_T.data()); + auto gmem_layout = gtensor_T.layout(); + + cute::array gmem_prob_shape = {1,1,1,1,1}; + cute::array gmem_prob_stride = {0,0,0,0,0}; + + fill_tma_gmem_shape_stride(gtensor_T, stride(tma_gbasis), gmem_prob_shape, gmem_prob_stride); + + assert((reinterpret_cast(gmem_address) & 0b1111) == 0); // Address must be 16B-aligned + + assert(gmem_prob_shape[0] >= (uint64_t(1))); // Size must be min 1 + assert(gmem_prob_shape[0] <= (uint64_t(1) << 32)); // Size must be max 2^32 + assert(gmem_prob_shape[1] >= (uint64_t(1))); // Size must be min 1 + assert(gmem_prob_shape[1] <= (uint64_t(1) << 32)); // Size must be max 2^32 + assert(gmem_prob_shape[2] >= (uint64_t(1))); // Size must be min 1 + assert(gmem_prob_shape[2] <= (uint64_t(1) << 32)); // Size must be max 2^32 + assert(gmem_prob_shape[3] >= (uint64_t(1))); // Size must be min 1 + assert(gmem_prob_shape[3] <= (uint64_t(1) << 32)); // Size must be max 2^32 + assert(gmem_prob_shape[4] >= (uint64_t(1))); // Size must be min 1 + assert(gmem_prob_shape[4] <= (uint64_t(1) << 32)); // Size must be max 2^32 + + // TMA descriptor does not store the zeroth stride and assumes it is 1 (TmaInternalType element). + assert(gmem_prob_stride[0] == 1 && "Majorness of smem doesn't match majorness of gmem"); + + // convert strides to byte strides + for(uint64_t& stride : gmem_prob_stride) { + stride = (stride * sizeof_bits_v) / 8; + } + + // Assert the byte strides. Tma Descriptor uses byte strides + assert((gmem_prob_stride[1]) < (uint64_t(1) << 40)); // Stride must be max 2^40 + assert((gmem_prob_stride[1] & 0b1111) == 0); // Stride must be multiple of 16B (128b) + assert((gmem_prob_stride[2]) < (uint64_t(1) << 40)); // Stride must be max 2^40 + assert((gmem_prob_stride[2] & 0b1111) == 0); // Stride must be multiple of 16B (128b) + assert((gmem_prob_stride[3]) < (uint64_t(1) << 40)); // Stride must be max 2^40 + assert((gmem_prob_stride[3] & 0b1111) == 0); // Stride must be multiple of 16B (128b) + assert((gmem_prob_stride[4]) < (uint64_t(1) << 40)); // Stride must be max 2^40 + assert((gmem_prob_stride[4] & 0b1111) == 0); // Stride must be multiple of 16B (128b) + + // + // TMA smem desc info + // + + cute::array smem_box_shape = {1,1,1,1,1}; + cute::array smem_box_stride = {1,1,1,1,1}; + // The smem box is simply given by the sizes of the modes in tma_gbasis + for_each(make_seq{}, [&](auto i) { + smem_box_shape[i] *= size(tma_gbasis); + }); + // Finally, truncate the tma box by the num_multicast + for (uint32_t i = tma_dim-1, multicast = num_multicast; multicast > 1; --i) { + assert(smem_box_shape[i] % multicast == 0 || multicast % smem_box_shape[i] == 0); + uint32_t new_mult = ceil_div(multicast, smem_box_shape[i]); + smem_box_shape[i] = ceil_div(smem_box_shape[i], multicast); + multicast = new_mult; + } + + assert(smem_box_shape[0] >= (uint32_t(1))); // Size must be min 1 + assert(smem_box_shape[0] <= (uint32_t(1) << 8)); // Size must be max 2^8 = 256 + assert(smem_box_shape[1] >= (uint32_t(1))); // Size must be min 1 + assert(smem_box_shape[1] <= (uint32_t(1) << 8)); // Size must be max 2^8 = 256 + assert(smem_box_shape[2] >= (uint32_t(1))); // Size must be min 1 + assert(smem_box_shape[2] <= (uint32_t(1) << 8)); // Size must be max 2^8 = 256 + assert(smem_box_shape[3] >= (uint32_t(1))); // Size must be min 1 + assert(smem_box_shape[3] <= (uint32_t(1) << 8)); // Size must be max 2^8 = 256 + assert(smem_box_shape[4] >= (uint32_t(1))); // Size must be min 1 + assert(smem_box_shape[4] <= (uint32_t(1) << 8)); // Size must be max 2^8 = 256 + + assert(smem_box_stride[0] >= (uint32_t(1))); // Stride must be min 1 + assert(smem_box_stride[0] <= (uint32_t(8))); // Stride must be max 2^3 = 8 + assert(smem_box_stride[1] >= (uint32_t(1))); // Stride must be min 1 + assert(smem_box_stride[1] <= (uint32_t(8))); // Stride must be max 2^3 = 8 + assert(smem_box_stride[2] >= (uint32_t(1))); // Stride must be min 1 + assert(smem_box_stride[2] <= (uint32_t(8))); // Stride must be max 2^3 = 8 + assert(smem_box_stride[3] >= (uint32_t(1))); // Stride must be min 1 + assert(smem_box_stride[3] <= (uint32_t(8))); // Stride must be max 2^3 = 8 + assert(smem_box_stride[4] >= (uint32_t(1))); // Stride must be min 1 + assert(smem_box_stride[4] <= (uint32_t(8))); // Stride must be max 2^3 = 8 + + // + // Construct the descriptor + // + + TmaDescriptor tma_desc{}; + + // + // TMA general info + // + + #if (__CUDACC_VER_MAJOR__ >= 12) && !defined(__CUDACC_RTC__) + + CUtensorMapDataType tma_format = TMA::to_CUtensorMapDataType(); + CUtensorMapInterleave tma_interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; + CUtensorMapL2promotion tma_l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_128B; + CUtensorMapFloatOOBfill tma_oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; + + // TMA smem swizzle type + TMA::SmemSwizzleBits swizzle_bits = get_tma_swizzle_bits(swizzle); + TMA::SmemSwizzleBase swizzle_base = get_tma_swizzle_base(swizzle); + CUtensorMapSwizzle smem_swizzle = TMA::to_CUtensorMapSwizzle(swizzle_bits, swizzle_base); + CUresult result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)( + &tma_desc, + tma_format, + tma_dim, + gmem_address, + gmem_prob_shape.data(), + gmem_prob_stride.data() + 1, // gmem_prob_stride[0] implicitly 1 + smem_box_shape.data(), + smem_box_stride.data(), + tma_interleave, + smem_swizzle, + tma_l2Promotion, + tma_oobFill); + + if (result != CUDA_SUCCESS) { + std::cerr << "TMA Desc Addr: " << &tma_desc + << "\nformat " << tma_format + << "\ndim " << tma_dim + << "\ngmem_address " << gmem_address + << "\nglobalDim " << gmem_prob_shape + << "\nglobalStrides " << gmem_prob_stride + << "\nboxDim " << smem_box_shape + << "\nelementStrides " << smem_box_stride + << "\ninterleave " << tma_interleave + << "\nswizzle " << smem_swizzle + << "\nl2Promotion " << tma_l2Promotion + << "\noobFill " << tma_oobFill << std::endl; + std::cerr << "Error: Failed to initialize the TMA descriptor " << result << std::endl; + assert(false); + } + + #endif // (__CUDACC_VER_MAJOR__ >= 12) && !defined(__CUDACC_RTC__) + auto recast_ratio = cute::trait_ratio(sizeof_bits{}, + sizeof_bits< TmaInternalType>{}); + + auto gbasis = make_basis_like(shape(gtensor)); + + // Finally, get the inverse permutation of the E bases for the mocked gmem stride + auto gmem_tma_basis_stride = transform_leaf(gbasis, [&](auto ei) { + auto si = basis_get(ei, shape(gmem_layout)); + auto di = basis_get(ei, stride(gmem_layout)); + if constexpr (is_constant<1, decltype(si)>::value || is_constant<0, decltype(di)>::value) { + return Int<0>{}; // If size-1 or stride-0, return arithmetic identity -- no contribution to the TMA + } else { + auto tma_gmem_basis_stride = stride(tma_gbasis); + // Find j such that E is in stride(tma_gbasis) + using EI = decltype(ei); + [[maybe_unused]] auto j = find_if(tma_gmem_basis_stride, [&](auto tma_stride_j) { return any_of(tma_stride_j, [&](auto dj) { return dj == EI{}; }); }); + if constexpr (decltype(j == rank(tma_gmem_basis_stride))::value) { + return Int<0>{}; // If not-found, return arithmetic identity -- no contribution to the TMA + } else + if constexpr (decltype(j == Int<0>{})::value) { + auto scale = recast_ratio * basis_get(ei, stride(gtensor)); + return E{} * scale; // Return TMA Coord basis -- with a recast scale factor + } else + if constexpr (decltype(rank(tma_gmem_basis_stride) == Int<1>{})::value) { + return E{}; // Return TMA Coord basis -- known scale of Int<1>{} + } else { + int32_t scale = ceil_div(int32_t(di * sizeof_bits_v / cute::max(gmem_prob_stride[j], uint64_t{16})), 8); + return E{} * scale; // Return TMA Coord basis -- with a dynamic scale factor + } + } + }); + +#if 0 + print("gmem_tma_basis_stride : "); print(gmem_tma_basis_stride); print("\n"); +#endif + + using AuxParams = AuxTmaParams; + return cute::make_tuple(tma_desc, AuxParams{gmem_tma_basis_stride}); +} + +template +CUTE_HOST_RTC +auto +make_tma_copy_atom(CopyOp, + Tensor const& gtensor, // Full GMEM Tensor + SLayout const& slayout, // CTA Tile of SMEM, potentially swizzled + uint32_t const& num_multicast, // The number of CTAs involved in multicasting + Layout const& cta_v_map) // V: CTA val idx -> gmem mode +{ + // + // TMA truncated layout + // + + auto smem_swizzle = get_swizzle_portion(slayout); + auto smem_layout = get_nonswizzle_portion(slayout); + + auto tma_gbasis = detail::construct_tma_gbasis(gtensor, smem_layout, cta_v_map); + + // + // Construct the TMA Desc and the strides of the TMA Tensor + // + + auto [tma_desc, aux_params] = detail::make_tma_copy_desc(gtensor, + tma_gbasis, + smem_swizzle, + num_multicast); + + // + // Construct the Copy_Traits + // + + constexpr int num_bits_per_tma = size(tma_gbasis) * sizeof_bits_v; + using Traits = Copy_Traits, decltype(aux_params)>; + using Atom = Copy_Atom; + + Traits tma_traits{tma_desc, aux_params}; + +#if 0 + print("num_bits_per_tma : "); print(num_bits_per_tma); print("\n"); + print("g_stride_bases : "); print(tma_traits.aux_params_.g_stride_); print("\n"); +#endif + + // Return the Copy_Atom + return Atom{tma_traits}; +} + +// The "logical TMA tid" is a map from the CTA rank to its logical id +// within the instruction. It works like a mask or ordering on the +// CTAs. For non-multicast TMA, all CTAs should map to 0. For +// multicast TMA of size 4, CTAs will be mapped to {0,1,2,3}. +template +CUTE_HOST_RTC +auto +make_tma_copy_tiled(CopyOp const& copy_op, + Tensor const& gtensor, // Full GMEM Tensor + SLayout const& slayout, // CTA Tile of SMEM + Layout const& cta_t_map, // T: CTA thr idx -> logical TMA tid + Layout const& cta_v_map) // V: CTA val idx -> gmem mode +{ + Copy_Atom atom = make_tma_copy_atom(copy_op, gtensor, slayout, + cosize(cta_t_map), cta_v_map); + + // + // Construct the TiledCopy + // + + [[maybe_unused]] auto cta_tiler = product_each(shape(cta_v_map)); + + auto num_elems_per_tma = size<1>(typename decltype(atom)::RefLayout{}) / static_value>(); + + // smem idx -> smem coord + auto inv_smem_layout = right_inverse(get_nonswizzle_portion(slayout)); + // CTA V -> smem_coord + auto layout_v = composition(inv_smem_layout, num_elems_per_tma); + // Scale that up to cover all of the smem_coords + auto layout_V = tile_to_shape(make_layout(layout_v), size(cta_v_map)); + // CTA T -> smem idx + auto layout_t = make_layout(cosize(cta_t_map), safe_div(num_elems_per_tma, cosize(cta_t_map))); + // CTA TID -> smem coord + auto layout_T = composition(inv_smem_layout, composition(layout_t, cta_t_map)); + // Combine with the T mapping + [[maybe_unused]] auto layout_TV = make_layout(layout_T, layout_V); + +#if 0 + print("cta_tiler : "); print(cta_tiler); print("\n"); + print("layout_v : "); print(layout_v); print("\n"); + print("layout_V : "); print(layout_V); print("\n"); + print("layout_t : "); print(layout_t); print("\n"); + print("layout_T : "); print(layout_T); print("\n"); + print("layout_TV : "); print(layout_TV); print("\n"); +#endif + + return TiledCopy{atom}; +} + +} // end namespace detail + +/** Make a CuTe CTA-collective TiledCopy for a TMA operation. + * + * @param CopyOp The target copy operation: SM90_TMA_LOAD, SM90_TMA_LOAD_MULTICAST, SM90_TMA_STORE + * @param gtensor The GMEM Tensor to be involved in the TMA. + * @param slayout The SMEM Layout to be involved in the TMA. + * @param cta_tile The CTA-local tile that each CTA will be tiling GMEM with. + * This is often the blk_shape that is used to tile the GMEM for CTAs: + * local_tile(gtensor, blk_shape, blk_coord) -> CTA-local tile of gtensor + * @param cluster_size When using SM90_TMA_LOAD_MULTICAST, this can be a (static) power-of-2 <= 16 + * defining the multicast size (used to further partition the SMEM) + * Else, static-1 + * + * This code attempts to maximize the TMA box size. It does this by tracing + * the SMEM "vector" -- the inverse of the smem layout -- to find the largest + * contiguous array of smem that can be written to/from global memory given + * the constraints that the TMA instruction imposes. + * + * This is accomplished by assigning "basis" strides to the GMEM to track which + * modes of SMEM map to which modes of GMEM, then reorder the modes of GMEM according + * to the SMEM vector, and then using those GMEM/SMEM modes to fill in the desc. + * + * Examples: + using T = float; + T* gptr = nullptr; + + { + // Simple 2D + Tensor gtensor = make_tensor(gptr, make_shape(1024, 256), GenRowMajor{}); // K-Major GMEM + auto slayout = make_layout(make_shape(_64{}, _32{}), GenRowMajor{}); // K-Major SMEM + auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout); + } + + { + // GMMA 2D + Tensor gtensor = make_tensor(gptr, make_shape(1024, 256)); // MN-Major GMEM + auto slayout = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, make_shape(_128{},_64{})); // MN-Major Swizzled+Tiled 128x64 SMEM + auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout); + } + + { + // 3D + Tensor gtensor = make_tensor(gptr, make_shape(1024, 32, 512), make_stride(64, Int<1>{}, 65536)); // GMEM + auto slayout = make_layout(make_shape(_16{}, _8{}, _2{}), make_stride(_16{}, _1{}, _8{})); // SMEM w/ same major-mode + auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout); + } + + { + // cuTENSOR 4D + auto layout = make_shape(make_shape(32,40),make_shape(make_shape(8,8),656)); // GMEM + auto cta_tile = make_shape(_128{},make_shape(_32{},_2{})); // GMEM Tiling: + // Take 128-elem from m: m0 must divide 128, + // m-last may be predicated + // Take 32-elem from k0, 2-elem from k1 + auto slayout = make_layout(cta_tile); // Col-Major SMEM + auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout, cta_tile, Int<1>{}); + } + * + * Check the TMA box size and desc: + print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); + print("TMA desc : "); print(tma.tma_desc_); print("\n"); + * + * Usage: + Tensor mA = tma_a.get_tma_tensor(make_shape(M,N)); // (M,N) TMA coord tensor + Tensor gA = local_tile(mA, cta_tile, cta_coord); // (BLK_M,BLK_N) TMA coord tensor for this CTA + Tensor sA = make_tensor(make_smem_ptr(sptr), slayout); // (BLK_M,BLK_N) SMEM tensor + + auto cta_tma = tma.get_slice(cta_idx_in_cluster); // Slice for multicast partitioning + Tensor tAgA = cta_tma.partition_S(gA); // Partition for src + Tensor tAsA = cta_tma.partition_D(sA); // Partition for dst + + copy(tma.with(barrier, mcast_mask), tAgA, tAsA); // copy with supporting TMA params + */ +template +CUTE_HOST_RTC +auto +make_tma_copy(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, + CTA_Tiler const& cta_tiler, + Cluster_Size const& cluster_size) +{ + if constexpr (cute::is_same_v || + cute::is_same_v) { + return make_im2col_tma_copy(copy_op, + gtensor, + slayout, + cta_tiler, + cluster_size); + } else { + auto cta_v_tile = make_identity_layout(shape(gtensor)).compose(cta_tiler); + auto cta_t_tile = make_layout(cluster_size); + // Prefer TmaInternalType if specified. Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + return detail::make_tma_copy_tiled(copy_op, + gtensor, slayout, + cta_t_tile, cta_v_tile); + } +} + +// Explicit defaulting +template +CUTE_HOST_RTC +auto +make_tma_copy(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout) +{ + return make_tma_copy(copy_op, gtensor, slayout, product_each(shape(slayout)), Int<1>{}); +} + +// Explicit defaulting +template +CUTE_HOST_RTC +auto +make_tma_copy(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, + Cluster_Size const& cluster_size) +{ + return make_tma_copy(copy_op, gtensor, slayout, product_each(shape(slayout)), cluster_size); +} + +//////////////////////////////////// +// Experimental Make TMA Atom and Partitioner +/////////////////////////////////// + +template > +CUTE_HOST_RTC +auto +make_tma_atom(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, + CTA_Tiler const& cta_tiler, + Cluster_Size const& cluster_size = {}) +{ + auto cta_v_tile = make_identity_layout(shape(gtensor)).compose(cta_tiler); + // Prefer TmaInternalType if specified. Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + return detail::make_tma_copy_atom(copy_op, + gtensor, slayout, + size(cluster_size), cta_v_tile); +} + +// The "VectorCopy Partitioner" for TMA +template +CUTE_DEVICE +auto +tma_partition(Copy_Atom const& copy_atom, + CtaCoord const& cta_coord, + Layout const& cta_layout, // T: CTA coord -> logical multicast id + Tensor const& stensor, // SMEM Tensor (TMATile, Rest...) + Tensor const& gtensor) // GMEM Tensor (TMATile, Rest...) +{ + CUTE_STATIC_ASSERT_V(size<0>(stensor) == size<0>(gtensor)); + + // Invert the smem to get the largest contiguous vector in the smem layout + Layout inv_smem_layout = right_inverse(get_nonswizzle_portion(layout<0>(stensor))); + // Scale that up to cover all of the smem_coords + Layout layout_v = tile_to_shape(make_layout(inv_smem_layout), size<0>(stensor)); + + // Factor out the single-instrucion portion + Layout tma_layout_v = make_layout(Int::NumValSrc>{}); + auto layout_V = make_tile(logical_divide(layout_v, tma_layout_v)); + + // Append with _ until we cover all Rest... modes + auto glayout_V = append(layout_V, _); + auto slayout_V = append(layout_V, _); + // Transform tile mode and coalesce + Tensor gtensor_v = coalesce(gtensor.compose(glayout_V), Shape>{}); // ((TMA,TMA_Iter), Rest...) + Tensor stensor_v = coalesce(stensor.compose(slayout_V), Shape>{}); // ((TMA,TMA_Iter), Rest...) + // Offset inside the TMA-mode for the multicast + auto multicast_offset = cta_layout(cta_coord) * (size(tma_layout_v) / cosize(cta_layout)); + auto multicast_coord = make_coord(make_coord(multicast_offset, Int<0>{})); + auto gcoord = append(multicast_coord, Int<0>{}); + auto scoord = append(multicast_coord, Int<0>{}); + + Tensor gresult = domain_offset(gcoord, gtensor_v); + Tensor sresult = domain_offset(scoord, stensor_v); + + return cute::make_tuple(gresult, sresult); +} + +// Explicit defaults for cta_coord and cta_layout +template +CUTE_DEVICE +auto +tma_partition(Copy_Atom const& copy_atom, + Tensor const& stensor, // SMEM Tensor (TMATile, Rest...) + Tensor const& gtensor) // GMEM Tensor (TMATile, Rest...) +{ + return tma_partition(copy_atom, Int<0>{}, Layout<_1,_0>{}, stensor, gtensor); +} + +// TMA Multicast Masks Calculation +template +CUTE_HOST_DEVICE constexpr +uint16_t +create_tma_multicast_mask(CtaLayout const& cta_layout_vmnk, + CtaCoord const& cta_coord_vmnk) +{ + auto [cta_layout, elected_cta] = slice_and_offset(cta_coord_vmnk, cta_layout_vmnk); + + uint16_t mcast_mask = 0; + if constexpr (rank_v == 0) { + // Trivial case with no additional ctas + mcast_mask = uint16_t(1); + } else + if constexpr (rank_v == 1 and depth_v <= 1 and + not is_static::value) { + // Get the instruction code -- optimized for dynamic flat-rank-1 cta_layout + mcast_mask = uint16_t(1); + // Smear by stride<0> (may want to predicate on stride<0> mag?) + mcast_mask |= mcast_mask << (1*stride<0>(cta_layout)); + mcast_mask |= mcast_mask << (2*stride<0>(cta_layout)); + mcast_mask |= mcast_mask << (4*stride<0>(cta_layout)); + mcast_mask |= mcast_mask << (8*stride<0>(cta_layout)); + // Select shape<0> + mcast_mask &= (uint16_t(-1) >> (16 - shape<0>(cta_layout) * stride<0>(cta_layout))); + } else { + // Get the instruction code -- generic path + for (int i = 0; i < size(cta_layout); ++i) { + mcast_mask |= uint16_t(1) << cta_layout(i); + } + } + // Shift by the instruction's elected block rank (dynamic) + mcast_mask <<= elected_cta; + return mcast_mask; +} + +// Projections multicast mask +template +CUTE_HOST_DEVICE constexpr +uint16_t +create_tma_multicast_mask(CtaLayout const& cta_layout_vmnk, + CtaCoord const& cta_coord_vmnk) +{ + return create_tma_multicast_mask(cta_layout_vmnk, replace(cta_coord_vmnk, _)); +} + +//////////////////////////////////// +// Make TMA copy A/B/C +/////////////////////////////////// + +template +CUTE_HOST_RTC +auto +make_tma_copy_A_sm90(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, + CTA_Tiler const& cta_tiler, + Cluster_Size const& cluster_size) +{ + // Keep only MK modes from MNK + auto cta_tiler_mk = remove<1>(cta_tiler); + + // mcast along N mode for this M load, if any + auto cluster_size_n = size<1>(cluster_size); + + if constexpr (cute::is_same_v) { + return make_im2col_tma_copy(copy_op, + gtensor, + slayout, + cta_tiler_mk, + cluster_size_n); + } else { + auto cta_v_tile = make_identity_layout(shape(gtensor)).compose(cta_tiler_mk); + auto cta_t_tile = make_layout(cluster_size_n); + + // Prefer TmaInternalType if specified. Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + auto tma_copy = detail::make_tma_copy_tiled(copy_op, gtensor, slayout, cta_t_tile, cta_v_tile); + return tma_copy; + } +} + +template +CUTE_HOST_RTC +auto +make_tma_copy_B_sm90(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, + CTA_Tiler const& cta_tiler, + Cluster_Size const& cluster_size) +{ + // Keep only NK modes from MNK + auto cta_tiler_nk = remove<0>(cta_tiler); + + // mcast along M mode for this N load, if any + auto cluster_size_m = size<0>(cluster_size); + + if constexpr (cute::is_same_v) { + return make_im2col_tma_copy(copy_op, + gtensor, + slayout, + cta_tiler_nk, + cluster_size_m); + } else { + auto cta_v_tile = make_identity_layout(shape(gtensor)).compose(cta_tiler_nk); + auto cta_t_tile = make_layout(cluster_size_m); + + // Prefer TmaInternalType if specified. Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + auto tma_copy = detail::make_tma_copy_tiled(copy_op, gtensor, slayout, cta_t_tile, cta_v_tile); + return tma_copy; + } +} + +template +CUTE_HOST_RTC +auto +make_tma_copy_C_sm90(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, + CTA_Tiler const& cta_tiler) +{ + // Keep only MN modes from MNK + auto cta_tiler_mn = remove<2>(cta_tiler); + + if constexpr (cute::is_same_v || + cute::is_same_v) { + return make_im2col_tma_copy(copy_op, + gtensor, + slayout, + cta_tiler_mn, + _1{}); + } else { + auto cta_v_tile = make_identity_layout(shape(gtensor)).compose(cta_tiler_mn); + + // No multicast, so only 1 CTA involved + auto cta_t_map = Layout<_1,_0>{}; + + // Prefer TmaInternalType if specified. Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + auto tma_copy = detail::make_tma_copy_tiled(copy_op, gtensor, slayout, cta_t_map, cta_v_tile); + return tma_copy; + } +} +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm90_tma_swizzle.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm90_tma_swizzle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..90d616df2e234b1ed8425303b9b1e5309cbf1f32 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm90_tma_swizzle.hpp @@ -0,0 +1,114 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +/// @file copy_traits_sm90_tma_swizzle.hpp +/// @brief Functions for converting swizzle layout to TMA descriptor + +#if !defined(__CUDACC_RTC__) +#include +#endif + +#include +#include + +namespace cute::detail { + +template +CUTE_HOST_DEVICE constexpr +TMA::SmemSwizzleBits +get_tma_swizzle_bits(Swizzle) +{ + if constexpr (M == 4) { + static_assert(0 <= B && B <= 3, "Expected B = 0,1,2, or 3 when M == 4. Unsupported layout swizzle."); + if constexpr (B == 3) { return TMA::SmemSwizzleBits::B128; } + if constexpr (B == 2) { return TMA::SmemSwizzleBits::B64; } + if constexpr (B == 1) { return TMA::SmemSwizzleBits::B32; } + if constexpr (B == 0) { return TMA::SmemSwizzleBits::DISABLE; } + } else + + if constexpr (M == 5 || M == 6) { + static_assert(B == 2, "Expected B = 2 when M == 5 or 6. Unsupported layout swizzle."); + // S-condition as well? + return TMA::SmemSwizzleBits::B128; + } else + + { + static_assert(M < 0, "Unsupported layout swizzle."); + } +} + +template +TMA::SmemSwizzleBits +get_tma_swizzle_bits(Layout const& layout) +{ + return get_tma_swizzle_bits(get_swizzle_portion(layout)); +} + +template +CUTE_HOST_DEVICE constexpr +TMA::SmemSwizzleBase +get_tma_swizzle_base(Swizzle) +{ + if constexpr (M == 4) { + static_assert(0 <= B && B <= 3, "Expected B = 0,1,2, or 3 when M == 4. Unsupported layout swizzle."); + static_assert(S == 3, "Expected S = 3 when M == 4. Unsupported layout swizzle."); + return TMA::SmemSwizzleBase::SWIZZLE_BASE_16B; + } + + else if constexpr (M == 5) { + static_assert(B == 2, "Expected B = 2 when M == 5. Unsupported layout swizzle."); + static_assert(S == 2, "Expected S = 2 when M == 5. Unsupported layout swizzle."); + return TMA::SmemSwizzleBase::SWIZZLE_BASE_32B; + } else if constexpr (M == 6) { + static_assert(B == 2, "Expected B = 2 when M == 5. Unsupported layout swizzle."); + return TMA::SmemSwizzleBase::SWIZZLE_BASE_64B; + } + #if 1 + else { + static_assert(4 <= M && M <= 6, "Expected 128b=16B=(2^4)B to 512b=64B=(2^6)B base swizzle."); + } + #else + + else { + static_assert(M == 4, "Expected 128b=16B=(2^4)B base swizzle."); + } + #endif +} + +template +TMA::SmemSwizzleBase +get_tma_swizzle_base(Layout const& layout) +{ + return get_tma_swizzle_base(get_swizzle_portion(layout)); +} + +} // namespace cute::detail diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_atom.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_atom.hpp new file mode 100644 index 0000000000000000000000000000000000000000..57338e2a4be111e277fdd394ed97192ab2bf0a05 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_atom.hpp @@ -0,0 +1,705 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include +#include +#include +#include + +namespace cute { + +template +struct MMA_Atom; + +template +struct MMA_Atom : MMA_Atom> +{}; + +template +struct MMA_Atom> + : MMA_Traits +{ + using MMA_Op = MMAOperation; + using Traits = MMA_Traits; + + // Element value types from the MMA_Traits + using ValTypeD = typename Traits::ValTypeD; + using ValTypeA = typename Traits::ValTypeA; + using ValTypeB = typename Traits::ValTypeB; + using ValTypeC = typename Traits::ValTypeC; + + // Thr-Val layouts from the MMA_Traits + using Shape_MNK = typename Traits::Shape_MNK; + using ThrID = typename Traits::ThrID; + using LayoutC_TV = typename Traits::CLayout; + using LayoutA_TV = typename Traits::ALayout; + using LayoutB_TV = typename Traits::BLayout; + + // Fragment value types from the MMA_Traits (optional, defaults to Val type) + using FrgTypeD = typename detail::FrgTypeC_or_Default::type; + using FrgTypeA = typename detail::FrgTypeA_or_Default::type; + using FrgTypeB = typename detail::FrgTypeB_or_Default::type; + using FrgTypeC = typename detail::FrgTypeC_or_Default::type; + + // Additional Trait parameters/transformations + template + CUTE_HOST_DEVICE + auto + with(TraitsArgs&&... args) const { + auto traits = Traits::with(static_cast(args)...); + return MMA_Atom{traits}; + } + + // + // Tensor call interfaces + // + + // Cast, check, and call fma + template + CUTE_HOST_DEVICE constexpr + void + call(Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) const + { + static_assert(DLayout::rank == 1, "Expected rank-1 D tensor"); + static_assert(ALayout::rank == 1, "Expected rank-1 A tensor"); + static_assert(BLayout::rank == 1, "Expected rank-1 B tensor"); + static_assert(CLayout::rank == 1, "Expected rank-1 C tensor"); + + return mma_unpack(static_cast(*this), D, A, B, C); + } + + // Three arguments reproduces C + template + CUTE_HOST_DEVICE constexpr + void + call(Tensor const& A, + Tensor const& B, + Tensor & C) const + { + return call(C, A, B, C); + } + + // + // make_fragment_A|B|C + // These functions are awkward as they expect already-partitioned tensors + // resulting from a previous call to partition_A|B|C + // The reasoning is that we can inspect the layout of the partitioned data + // and attempt to match it in generated fragment to promote vectorization + // when copying from partition to fragment. + // + + template + CUTE_HOST_DEVICE static constexpr + auto + make_fragment_C(CTensor&& ctensor) + { + // Check that this tensor is likely already partitioned + CUTE_STATIC_ASSERT_V(rank(ctensor) >= Int<3>{}); // VMN + CUTE_STATIC_ASSERT_V(size<0>(ctensor) == size<1>(LayoutC_TV{})); + // C is a bit special because we are after accumulators here + // The input/output type doesn't have to match the accumulator type + //static_assert(std::is_same::value_type>::value, "Expecting ValTypeC type"); + + // We'll never base the accumulator layout on the input tensor layout, so just return a FrgTypeC tensor + return make_tensor(shape(ctensor)); + } + + template + CUTE_HOST_DEVICE static constexpr + auto + make_fragment_A(ATensor&& atensor) + { + // Check that this tensor is likely already partitioned + CUTE_STATIC_ASSERT_V(rank(atensor) >= Int<3>{}); // VMK + CUTE_STATIC_ASSERT_V(size<0>(atensor) == size<1>(LayoutA_TV{})); + + if constexpr (has_dereference::value) { + // If the intended FrgTypeA is a view (of the current tensor), forward the whole + static_assert(is_same::value_type>::value + || (sizeof_bits_v::value_type> == 8 && + (sizeof_bits_v == 8 || sizeof_bits_v == 6 || sizeof_bits_v == 4)) + || (sizeof_bits_v::value_type> == 4 && + (sizeof_bits_v == 4 || sizeof_bits_v == 3 || sizeof_bits_v == 2)) + , "Expecting ValTypeA type"); + return make_tensor(static_cast(atensor)); + } else { + // Else, the intended FrgTypeA is a value type, construct a new tensor with a fragment layout + return make_fragment_like(atensor); + } + + CUTE_GCC_UNREACHABLE; + } + + template + CUTE_HOST_DEVICE static constexpr + auto + make_fragment_B(BTensor&& btensor) + { + // Check that this tensor is likely already partitioned + CUTE_STATIC_ASSERT_V(rank(btensor) >= Int<3>{}); // VNK + CUTE_STATIC_ASSERT_V(size<0>(btensor) == size<1>(LayoutB_TV{})); + + if constexpr (has_dereference::value) { + // If the intended FrgTypeB is a view (of the current tensor), forward the whole + static_assert(is_same::value_type>::value + + || (sizeof_bits_v::value_type> == 8 && + (sizeof_bits_v == 8 || sizeof_bits_v == 6 || sizeof_bits_v == 4)) + + , "Expecting ValTypeB type"); + return make_tensor(static_cast(btensor)); + } else { + // Else, the intended FrgTypeB is a value type, construct a new tensor with a fragment layout + return make_fragment_like(btensor); + } + + CUTE_GCC_UNREACHABLE; + } +}; + +// +// A tiling of mma atoms +// + +template +struct ThrMMA; + +// @tparam MMA_Atom The MMA_Atom to use in the TiledMMA +// @tparam AtomLayoutMNK The MNK-tiling of the Atom to be performed. +// @tparam PermuationsMNK Permutations to apply to each MNK-mode before tiling for the Atom. +template > +struct TiledMMA : MMA_Atom +{ + using Atom = MMA_Atom; + using AtomShape_MNK = typename MMA_Atom::Shape_MNK; + using AtomThrID = typename MMA_Atom::ThrID; + using AtomLayoutC_TV = typename MMA_Atom::LayoutC_TV; + using AtomLayoutA_TV = typename MMA_Atom::LayoutA_TV; + using AtomLayoutB_TV = typename MMA_Atom::LayoutB_TV; + + static_assert( rank_v == 3, "TiledMMA requires rank-3 AtomLayoutMNK"); + static_assert( rank_v == 3, "TiledMMA requires rank-3 PermutationMNK"); + static_assert( is_tuple::value, "TiledMMA requires independent permutations of MNK."); + static_assert(is_static::value, "TiledMMA requires static permutations of MNK."); + + using ThrLayoutVMNK = decltype(tiled_product(AtomThrID{}, AtomLayoutMNK{})); + ThrLayoutVMNK thr_layout_vmnk_; + + CUTE_HOST_DEVICE constexpr + TiledMMA(MMA_Atom const& mma_atom = {}, AtomLayoutMNK const& thr_layout_mnk = {}) + : MMA_Atom(mma_atom), + thr_layout_vmnk_(tiled_product(AtomThrID{}, thr_layout_mnk)) {} + + CUTE_HOST_DEVICE constexpr auto + get_thr_layout_vmnk() const { + return thr_layout_vmnk_; + } + + // Tile a tensor or a layout from shape + // (M,N,...) + // to shape + // ((ThrV,(ThrM,ThrN)),(FrgV,(RestM,RestN,...))) + // where + // ThrV: The threads local to an MMA. layout<0>(ThrLayoutVMNK): ThrV -> thread_idx + // ThrM: The threads tiled in M. layout<1>(ThrLayoutVMNK): ThrM -> thread_idx + // ThrN: The threads tiled in N. layout<2>(ThrLayoutVMNK): ThrN -> thread_idx + // FrgV: The values local to an MMA. + // RestM: The values tiled in M. + // RestN: The values tiled in N. + template + CUTE_HOST_DEVICE constexpr + auto + thrfrg_C(CTensor&& ctensor) const + { + CUTE_STATIC_ASSERT_V(rank(ctensor) >= Int<2>{}); + // Reorder the tensor for the TiledAtom + auto t_tile = make_tile(permutation_mnk<0>(), + permutation_mnk<1>()); + auto t_tensor = logical_divide(ctensor, t_tile); // (PermM,PermN) + + // Tile the tensor for the Atom + auto c_tile = make_tile(make_layout(size<0>(AtomShape_MNK{})), + make_layout(size<1>(AtomShape_MNK{}))); + auto c_tensor = zipped_divide(t_tensor, c_tile); // ((AtomM,AtomN),(RestM,RestN)) + + // Transform the Atom mode from (M,N) to (Thr,Val) + auto tv_tensor = c_tensor.compose(AtomLayoutC_TV{},_); // ((ThrV,FrgV),(RestM,RestN)) + + // Tile the tensor for the C-threads + auto thr_tile = make_tile(_, + make_tile(make_layout(size<1>(thr_layout_vmnk_)), + make_layout(size<2>(thr_layout_vmnk_)))); + auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrN)),(FrgV,(RestM,RestN))) + + return thr_tensor; + } + + // Tile a tensor or a layout from shape + // (M,K,...) + // to shape + // ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK,...))) + // where + // ThrV: The threads local to an MMA. layout<0>(ThrLayoutVMNK): ThrV -> thread_idx + // ThrM: The threads tiled in M. layout<1>(ThrLayoutVMNK): ThrM -> thread_idx + // ThrK: The threads tiled in K. layout<3>(ThrLayoutVMNK): ThrK -> thread_idx + // FrgV: The values local to an MMA. + // RestM: The values tiled in M. + // RestK: The values tiled in K. + template + CUTE_HOST_DEVICE constexpr + auto + thrfrg_A(ATensor&& atensor) const + { + CUTE_STATIC_ASSERT_V(rank(atensor) >= Int<2>{}); + // Reorder the tensor for the TiledAtom + auto t_tile = make_tile(permutation_mnk<0>(), + permutation_mnk<2>()); + auto t_tensor = logical_divide(atensor, t_tile); // (PermM,PermK) + + // Tile the tensor for the Atom + auto a_tile = make_tile(make_layout(size<0>(AtomShape_MNK{})), + make_layout(size<2>(AtomShape_MNK{}))); + auto a_tensor = zipped_divide(t_tensor, a_tile); // ((AtomM,AtomK),(RestM,RestK)) + + // Transform the Atom mode from (M,K) to (Thr,Val) + auto tv_tensor = a_tensor.compose(AtomLayoutA_TV{},_); // ((ThrV,FrgV),(RestM,RestK)) + + // Tile the tensor for the Thread + auto thr_tile = make_tile(_, + make_tile(make_layout(size<1>(thr_layout_vmnk_)), + make_layout(size<3>(thr_layout_vmnk_)))); + auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK))) + + return thr_tensor; + } + + // Tile a tensor or a layout from shape + // (N,K,...) + // to shape + // ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) + // where + // ThrV: The threads local to an MMA. layout<0>(ThrLayoutVMNK): ThrV -> thread_idx + // ThrN: The threads tiled in N. layout<2>(ThrLayoutVMNK): ThrN -> thread_idx + // ThrK: The threads tiled in K. layout<3>(ThrLayoutVMNK): ThrK -> thread_idx + // FrgV: The values local to an MMA. + // RestN: The values tiled in N. + // RestK: The values tiled in K. + template + CUTE_HOST_DEVICE constexpr + auto + thrfrg_B(BTensor&& btensor) const + { + CUTE_STATIC_ASSERT_V(rank(btensor) >= Int<2>{}); + // Reorder the tensor for the TiledAtom + auto t_tile = make_tile(permutation_mnk<1>(), + permutation_mnk<2>()); + auto t_tensor = logical_divide(btensor, t_tile); // (PermN,PermK) + + // Tile the tensor for the Atom + auto b_tile = make_tile(make_layout(size<1>(AtomShape_MNK{})), + make_layout(size<2>(AtomShape_MNK{}))); + auto b_tensor = zipped_divide(t_tensor, b_tile); // ((AtomN,AtomK),(RestN,RestK)) + + // Transform the Atom mode from (N,K) to (Thr,Val) + auto tv_tensor = b_tensor.compose(AtomLayoutB_TV{},_); // ((ThrV,FrgV),(RestN,RestK)) + + // Tile the tensor for the Thread + auto thr_tile = make_tile(_, + make_tile(make_layout(size<2>(thr_layout_vmnk_)), + make_layout(size<3>(thr_layout_vmnk_)))); + auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK))) + + return thr_tensor; + } + + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + get_slice(ThrIdx const& thr_idx) const + { + auto thr_vmnk = thr_layout_vmnk_.get_flat_coord(thr_idx); + return ThrMMA{*this, thr_vmnk}; + } + + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + get_thread_slice(ThrIdx const& thr_idx) const + { + return get_slice(thr_idx); + } + + // + // Utility for printing and visualization + // + + // The permutation applied to the MNK-mode data + template + CUTE_HOST_DEVICE constexpr + auto + permutation_mnk() const { + static_assert(0 <= I && I < 3); + auto perm = get(PermutationMNK{}); + return conditional_return(is_underscore{}, size(AtomShape_MNK{}) * size(get_thr_layout_vmnk()), perm); + } + + // The size of the MNK-mode + template + CUTE_HOST_DEVICE constexpr + auto + tile_size_mnk() const { + static_assert(0 <= I && I < 3); + return size(permutation_mnk()); + } + + CUTE_HOST_DEVICE constexpr + auto + get_layoutC_TV() const + { + // (M,N) -> (M,N) + auto ref_C = make_layout(make_shape(tile_size_mnk<0>(), tile_size_mnk<1>())); + + // thr_idx -> (ThrV,ThrM,ThrN,ThrK) + auto thridx_2_thrid = composition(make_layout(make_shape (size(thr_layout_vmnk_), Int<1>{}), + make_stride(Int<1>{}, Int<0>{})), + right_inverse(make_layout(thr_layout_vmnk_, complement(thr_layout_vmnk_)))); + + // (thr_idx,val) -> (M,N) + return thrfrg_C(ref_C).compose(thridx_2_thrid, _); + } + + + CUTE_HOST_DEVICE constexpr + auto + get_layoutA_TV() const + { + // (M,K) -> (M,K) + auto ref_A = make_layout(make_shape(tile_size_mnk<0>(), tile_size_mnk<2>())); + + // (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) + auto atile = make_tile(_, + make_tile(make_layout(make_shape (size<1>(thr_layout_vmnk_), size<2>(thr_layout_vmnk_)), + make_stride( Int<1>{} , Int<0>{} )), + _)); + + // thr_idx -> (ThrV,ThrM,ThrN,ThrK) + auto thridx_2_thrid = composition(make_layout(make_shape (size(thr_layout_vmnk_), Int<1>{}), + make_stride(Int<1>{}, Int<0>{})), + right_inverse(make_layout(thr_layout_vmnk_, complement(thr_layout_vmnk_)))); + + // (thr_idx,val) -> (M,K) + return thrfrg_A(ref_A).compose(atile, _).compose(thridx_2_thrid, _); + } + + CUTE_HOST_DEVICE constexpr + auto + get_layoutB_TV() const + { + // (N,K) -> (N,K) + auto ref_B = make_layout(make_shape(tile_size_mnk<1>(), tile_size_mnk<2>())); + + // (ThrV,(ThrN,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) + auto btile = make_tile(_, + make_tile(make_layout(make_shape (size<1>(thr_layout_vmnk_), size<2>(thr_layout_vmnk_)), + make_stride( Int<0>{} , Int<1>{} )), + _)); + + // thr_idx -> (ThrV,ThrM,ThrN,ThrK) + auto thridx_2_thrid = composition(make_layout(make_shape (size(thr_layout_vmnk_), Int<1>{}), + make_stride(Int<1>{}, Int<0>{})), + right_inverse(make_layout(thr_layout_vmnk_, complement(thr_layout_vmnk_)))); + + // (thr_idx,val) -> (N,K) + return thrfrg_B(ref_B).compose(btile, _).compose(thridx_2_thrid, _); + } +}; + +template +struct ThrMMA : TiledMMA +{ + ThrVMNK thr_vmnk_; + + template + CUTE_HOST_DEVICE constexpr + auto + partition_C(CTensor&& ctensor) const + { + auto thr_tensor = make_tensor(static_cast(ctensor).data(), this->thrfrg_C(ctensor.layout())); + + auto thr_vmn = make_coord(get<0>(thr_vmnk_), make_coord(get<1>(thr_vmnk_), get<2>(thr_vmnk_))); + return thr_tensor(thr_vmn, make_coord(_, repeat(thr_tensor)>(_))); + } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_A(ATensor&& atensor) const + { + auto thr_tensor = make_tensor(static_cast(atensor).data(), this->thrfrg_A(atensor.layout())); + + auto thr_vmk = make_coord(get<0>(thr_vmnk_), make_coord(get<1>(thr_vmnk_), get<3>(thr_vmnk_))); + return thr_tensor(thr_vmk, make_coord(_, repeat(thr_tensor)>(_))); + } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_B(BTensor&& btensor) const + { + auto thr_tensor = make_tensor(static_cast(btensor).data(), this->thrfrg_B(btensor.layout())); + + auto thr_vnk = make_coord(get<0>(thr_vmnk_), make_coord(get<2>(thr_vmnk_), get<3>(thr_vmnk_))); + return thr_tensor(thr_vnk, make_coord(_, repeat(thr_tensor)>(_))); + } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_fragment_C(CTensor&& ctensor) const + { + return TiledMMA::make_fragment_C(partition_C(ctensor)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_fragment_A(ATensor&& atensor) const + { + return TiledMMA::make_fragment_A(partition_A(atensor)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_fragment_B(BTensor&& btensor) const + { + return TiledMMA::make_fragment_B(partition_B(btensor)); + } +}; + +// +// These tile the MMA_Atom as a whole +// + +template >, + class Permutations = Tile> +CUTE_HOST_DEVICE constexpr +auto +make_tiled_mma(MMA_Atom const& mma_atom, + MMAThrLayout const& thr_layout = {}, + Permutations const& permutations = {}) +{ + auto thr_layout_mnk = append<3>(thr_layout, Layout<_1,_0>{}); + auto permutation_mnk = append<3>(permutations, _); + + return TiledMMA, + decltype(thr_layout_mnk), + decltype(permutation_mnk)>{mma_atom, thr_layout_mnk}; +} + +template >, + class Permutations = Tile> +CUTE_HOST_DEVICE constexpr +auto +make_tiled_mma(MMA_Op const&, + MMAThrLayout const& thr_layout = {}, + Permutations const& permutations = {}) +{ + // Attempt to wrap in an MMA_Atom<> and forward + return make_tiled_mma(MMA_Atom{}, thr_layout, permutations); +} + +// +// partition_fragment_C -- static context +// + +template +CUTE_HOST_DEVICE constexpr +auto +partition_shape_C(TiledMMA const& mma, Shape_MN const& shape_MN) +{ + auto dummy = make_layout(shape(shape_MN)); + auto dummy_tv = mma.thrfrg_C(dummy); + // Slice+rearrange like partition_C + auto dummy_v = dummy_tv(Int<0>{}, make_coord(_, repeat(_))); + return shape(dummy_v); + +} + + +template +CUTE_HOST_DEVICE constexpr +auto +partition_fragment_C(TiledMMA const& mma, Shape_MN const& shapeMN) +{ + return make_tensor::FrgTypeC>(partition_shape_C(mma, shapeMN)); +} + +// partition_fragment_A and partition_fragment_B often depend on the +// layout of A and B and/or the thread_idx that is requesting the partition. +// For these reasons, they should not be used in a static context. +// See TiledMMA::get_slice(thr_idx).partition_fragment_A(tensorA) instead. + +template +CUTE_HOST_DEVICE constexpr +auto +partition_shape_A(TiledMMA const& mma, Shape_MK const& shape_MK) +{ + auto dummy = make_layout(shape(shape_MK)); + auto dummy_tv = mma.thrfrg_A(dummy); + // Slice+rearrange like partition_A + auto dummy_v = dummy_tv(Int<0>{}, make_coord(_, repeat(_))); + return shape(dummy_v); + +} + +template +CUTE_HOST_DEVICE constexpr +auto +partition_shape_B(TiledMMA const& mma, Shape_NK const& shape_NK) +{ + auto dummy = make_layout(shape(shape_NK)); + auto dummy_tv = mma.thrfrg_B(dummy); + // Slice+rearrange like partition_B + auto dummy_v = dummy_tv(Int<0>{}, make_coord(_, repeat(_))); + return shape(dummy_v); + +} + +// +// Size +// + +template +CUTE_HOST_DEVICE constexpr +auto +tile_size(TiledMMA const& mma) +{ + return mma.template tile_size_mnk(); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tile_shape(TiledMMA const& mma) +{ + return make_shape(tile_size<0>(mma), tile_size<1>(mma), tile_size<2>(mma)); +} + +// Deprecate? +template +CUTE_HOST_DEVICE constexpr +auto +size(TiledMMA const& mma) +{ + return size(mma.get_thr_layout_vmnk()); +} + +// Alias +template +CUTE_HOST_DEVICE constexpr +auto +thr_size(TiledMMA const& mma) +{ + return size(mma.get_thr_layout_vmnk()); +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE +void +print(MMA_Atom> const&) +{ + using Atom = MMA_Atom>; + print("MMA_Atom\n"); + print(" ThrID: "); print(typename Atom::ThrID{}); print("\n"); + print(" Shape_MNK: "); print(typename Atom::Shape_MNK{}); print("\n"); + print(" LayoutA_TV: "); print(typename Atom::LayoutA_TV{}); print("\n"); + print(" LayoutB_TV: "); print(typename Atom::LayoutB_TV{}); print("\n"); + print(" LayoutC_TV: "); print(typename Atom::LayoutC_TV{}); print("\n"); +} + +template +CUTE_HOST_DEVICE +void +print(TiledMMA const& mma) +{ + print("TiledMMA\n"); + print(" ThrLayoutVMNK: "); print(mma.get_thr_layout_vmnk()); print("\n"); + print(" PermutationMNK: "); print(TiledPerm{}); print("\n"); + print(static_cast(mma)); +} + +template +CUTE_HOST_DEVICE +void +print(ThrMMA const& thr_mma) +{ + print("ThrMMA\n"); + print(" Thr VMNK: "); print(thr_mma.thr_vmnk_); print("\n"); + print(static_cast(thr_mma)); +} + +} // namespace cute + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits.hpp new file mode 100644 index 0000000000000000000000000000000000000000..986b1ae0a02079bc4da158c8a7aabc9dc791ab1c --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits.hpp @@ -0,0 +1,189 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // cute::Tensor +#include // cute::is_rmem +#include // cute::UniversalFMA +#include // cute::detail::explode + +namespace cute +{ + +/** + * concept MMA_Traits + * { + * using ValTypeD = // Logical D-value type + * using ValTypeA = // Logical A-value type + * using ValTypeB = // Logical B-value type + * using ValTypeC = // Logical C-value type (NOTE: Not used? Assumed == ValTypeD) + * + * using FrgTypeA = // A-type consumed by MMA (if ommitted, same as ValTypeA) + * using FrgTypeB = // B_type consumed by MMA (if ommitted, same as ValTypeB) + * using FrgTypeC = // C_type consumed by MMA (if ommitted, same as ValTypeC) + * + * using Shape_MNK = // Logical MxNxK shape of the MMA + * + * using ThrID = // Logical thread id (tid) -> tidx + * + * using ALayout = // (Logical thread id (tid), Logical value id (vid)) -> Flat MK-coord + * using BLayout = // (Logical thread id (tid), Logical value id (vid)) -> Flat NK-coord + * using CLayout = // (Logical thread id (tid), Logical value id (vid)) -> Flat MN-coord + * }; + */ + +template +struct MMA_Traits +{ + static_assert(sizeof(MMAOperation) == 0, "MMA_Traits not implemented for this MMA_Operation."); +}; + +template +struct MMA_Traits> +{ + using ValTypeD = D; + using ValTypeA = A; + using ValTypeB = B; + using ValTypeC = C; + + // Logical shape of the MMA + using Shape_MNK = Shape<_1,_1,_1>; + + // Logical thread id (tid) -> tidx + using ThrID = Layout<_1>; + + // (Logical thread id (tid), Logical value id (vid)) -> coord + + // (tid,vid) -> (m,k) + using ALayout = Layout>; + // (tid,vid) -> (n,k) + using BLayout = Layout>; + // (tid,vid) -> (m,n) + using CLayout = Layout>; +}; + +// Extract an MMA_Op from an MMA_Traits +template +struct MMA_Op {}; + +template +struct MMA_Op> { + using type = MMA_Op_Arg; +}; + +// +// Generic mma_unpack for any MMA_Traits +// + +template +CUTE_HOST_DEVICE constexpr +void +mma_unpack(AnyMMATraits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) +{ + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + + // Register value types from the MMA_Operation register arrays + using MMA_Op = typename MMA_Op::type; + using RegTypeD = typename remove_extent::type; + using RegTypeA = typename remove_extent::type; + using RegTypeB = typename remove_extent::type; + using RegTypeC = typename remove_extent::type; + + Tensor rA = recast(A); + Tensor rB = recast(B); + Tensor rD = recast(D); + Tensor rC = recast(C); + + constexpr int RegNumD = extent::value; + constexpr int RegNumA = extent::value; + constexpr int RegNumB = extent::value; + constexpr int RegNumC = extent::value; + + CUTE_STATIC_ASSERT_V(size(rA) == Int{}); + CUTE_STATIC_ASSERT_V(size(rB) == Int{}); + CUTE_STATIC_ASSERT_V(size(rD) == Int{}); + CUTE_STATIC_ASSERT_V(size(rC) == Int{}); + + detail::explode(MMA_Op::fma, + rD, make_int_sequence{}, + rA, make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}); +} + +// Accept mutable temporaries +template +CUTE_HOST_DEVICE constexpr +void +mma_unpack(AnyMMATraits const& traits, + Tensor && D, + Tensor const& A, + Tensor const& B, + Tensor const& C) +{ + mma_unpack(traits, D, A, B, C); +} + +namespace detail { + +template +struct FrgTypeA_or_Default { using type = typename X::ValTypeA; }; +template +struct FrgTypeA_or_Default> { using type = typename X::FrgTypeA; }; + +template +struct FrgTypeB_or_Default { using type = typename X::ValTypeB; }; +template +struct FrgTypeB_or_Default> { using type = typename X::FrgTypeB; }; + +template +struct FrgTypeC_or_Default { using type = typename X::ValTypeC; }; +template +struct FrgTypeC_or_Default> { using type = typename X::FrgTypeC; }; + +} // end namespace detail + +} // namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm100.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm100.hpp new file mode 100644 index 0000000000000000000000000000000000000000..dd4d15e789510d9219ebd45e235c02e4e420dc89 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm100.hpp @@ -0,0 +1,4032 @@ +/*************************************************************************************************** + * Copyright (c) 2022 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include // cute::TMEM:: + +#include +#include // cute::GMMA:: +#include // cute::GMMA:: +#include // UTCCP smem desc + +#include + +// Check that aggregate initialization in .with() initializes all fields +#if defined(__GNUG__) +#pragma GCC diagnostic warning "-Wmissing-field-initializers" +#pragma GCC diagnostic error "-Wmissing-field-initializers" +#endif + +namespace cute { + +namespace UMMA { + +////////////////////////////////////////////////// +// Common layouts for UMMA Shared Memory // +////////////////////////////////////////////////// + +using cute::GMMA::Layout_MN_INTER_Atom; +using cute::GMMA::Layout_MN_SW32_Atom; +using cute::GMMA::Layout_MN_SW64_Atom; +using cute::GMMA::Layout_MN_SW128_Atom; +using cute::GMMA::Layout_K_INTER_Atom; +using cute::GMMA::Layout_K_SW32_Atom; +using cute::GMMA::Layout_K_SW64_Atom; +using cute::GMMA::Layout_K_SW128_Atom; + +using Layout_MN_SW128_32B_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1, _1024>>>; + +template +using Layout_MN_SW128_32B_Atom = decltype(upcast::value>(Layout_MN_SW128_32B_Atom_Bits{})); + +////////////////////////////////////////////////// +// Common layouts for Sparse UMMA Shared Memory // +////////////////////////////////////////////////// + +using cute::GMMA::Layout_MN_INTER_SpAtom; +using cute::GMMA::Layout_MN_SW32_SpAtom; +using cute::GMMA::Layout_MN_SW64_SpAtom; +using cute::GMMA::Layout_MN_SW128_SpAtom; +using cute::GMMA::Layout_K_INTER_SpAtom; +using cute::GMMA::Layout_K_SW32_SpAtom; +using cute::GMMA::Layout_K_SW64_SpAtom; +using cute::GMMA::Layout_K_SW128_SpAtom; + +template +using Layout_MN_SW128_32B_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>, + decltype(blocked_product(Layout>>{}, Layout_MN_SW128_32B_Atom{}.layout_b()))>; + +// With UMMA::Major param +template +using Layout_INTER_SpAtom = typename conditional, + Layout_K_INTER_SpAtom>::type; +template +using Layout_SW32_SpAtom = typename conditional, + Layout_K_SW32_SpAtom>::type; +template +using Layout_SW64_SpAtom = typename conditional, + Layout_K_SW64_SpAtom>::type; +template +using Layout_SW128_SpAtom = typename conditional, + Layout_K_SW128_SpAtom>::type; + +// Tile a MN-logical layout atom to an MMA Tile Shape ((MMA_M,MMA_N),M_MMAs,N_MMAs,...) +template +CUTE_HOST_DEVICE constexpr +auto +tile_to_mma_shape(LayoutAtom const& atom, MMATileShape const& mma_tile_shape, ModeOrder const& order = {}) +{ + constexpr int R = decltype(rank(mma_tile_shape))::value; + auto mn_shape = cute::tuple_cat(zip(shape<0>(mma_tile_shape), take<1,3>(mma_tile_shape)), take<3,R>(mma_tile_shape)); + auto mn_tiled = tile_to_shape(atom, mn_shape, order); // (BLK_M,BLK_N,...) + return tiled_divide(mn_tiled, product_each(shape<0>(mma_tile_shape))); // ((MMA_M,MMA_N),M_MMAs,N_MMAs,...) +} + +// +// Tensor (position-dependent swizzle) to LayoutType utility +// + +template +CUTE_HOST_DEVICE constexpr +LayoutType +layout_type(Tensor> const&) +{ + static_assert(is_same::value, + "Expected uint128_t type in LayoutType conversion."); + + using Swizzle = get_swizzle_t; + constexpr int B = Swizzle::num_bits; + constexpr int M = Swizzle::num_base; + constexpr int S = Swizzle::num_shft; + + if constexpr (M == 4) { + static_assert(S == 3, "Expected S = 3 when M == 4. Unsupported layout swizzle."); + switch (B) { + default: static_assert(0 <= B && B <= 3, "Expected B = 0,1,2, or 3 when M == 4. Unsupported layout swizzle."); + case 0: return LayoutType::SWIZZLE_NONE; + case 1: return LayoutType::SWIZZLE_32B; + case 2: return LayoutType::SWIZZLE_64B; + case 3: return LayoutType::SWIZZLE_128B; + } + } else + if constexpr (M == 5) { + static_assert(B == 2, "Expected B = 2 when M == 5. Unsupported layout swizzle."); + static_assert(S == 2, "Expected S = 2 when M == 5. Unsupported layout swizzle."); + return LayoutType::SWIZZLE_128B_BASE32B; + } else { + static_assert(M==5, "Only 16B and 32B Atoms are supported for UMMA. Unsupported layout swizzle."); + return LayoutType::SWIZZLE_NONE; // ERROR + } +} + +/////////////////////////////////////////////////////////////////////////////// +// Construction method for UMMA Descriptors +/////////////////////////////////////////////////////////////////////////////// + +/** +* /////////////////////////////// +* // make_umma_desc // +* /////////////////////////////// +* Each UmmaDescriptor Major-MN describes a canonical layout of the form +* +* LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((T,1,m),(8,k)):((1,T,SBO),(1T,LBO)) +* LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((T,2,m),(8,k)):((1,T,LBO),(2T,SBO)) +* LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((T,4,m),(8,k)):((1,T,LBO),(4T,SBO)) +* LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((T,8,m),(8,k)):((1,T,LBO),(8T,SBO)) +* LayoutType::128B_BASE32B : Swizzle<2,5,2> o smem_ptr o ((T,8,m),(4,k)):((1,T,LBO),(?T,SBO)) +* +* where +* T : sizeof(uint128_t) / sizeof(value_type) +* m : integer in [1,16] corresponding to UMMA shape +* k : integer in [1,32] corresponding to UMMA shape +* SBO: stride byte offset +* LBO: leading byte offset +* +* See UMMA::Layout_MN_XXX_Atom for building canonical UmmaDescriptor Major-MN layouts. +* For example, +* auto smem_layout = tile_to_shape(Layout_MN_SW128_Atom{}, Shape<_128,_64>{}); +* is guaranteed to be accepted by make_umma_desc for appropriate value_type. +* +* ////////////////////////////// +* // make_umma_desc // +* ////////////////////////////// +* Each UmmaDescriptor Major-K describes a canonical layout of the form +* +* LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((8,m),(T,2)):((1T,SBO),(1,LBO)) +* LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((8,m),(T,2)):((2T,SBO),(1, T )) +* LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((8,m),(T,2)):((4T,SBO),(1, T )) +* LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((8,m),(T,2)):((8T,SBO),(1, T )) +* +* See UMMA::Layout_K_XXX_Atom for building canonical UmmaDescriptor Major-K layouts. +* For example, +* auto smem_layout = tile_to_shape(Layout_K_SW128_Atom{}, Shape<_128,_64>{}); +* is guaranteed to be accepted by make_umma_desc for appropriate value_type. +*/ +template +CUTE_HOST_DEVICE constexpr +SmemDescriptor +make_umma_desc(Tensor const& tensor) +{ + static_assert(is_smem::value, "UMMA Descriptors can only be constructed on smem."); + static_assert(TLayout::rank == 2, "UMMA Descriptors can only be constructed on rank-2 tensors."); + using value_type = typename TEngine::value_type; + + Tensor u128_tensor = recast(tensor); + + // Result + SmemDescriptor desc; + desc.version_ = 1; // Set the version for blackwell + desc.lbo_mode_ = 0; // set to legacy mode by default + + // Layout type + constexpr UMMA::LayoutType LAYOUT_TYPE = UMMA::layout_type(u128_tensor); + desc.layout_type_ = uint8_t(LAYOUT_TYPE); + + // Start address (4LSB not included) + uint32_t start_address = cast_smem_ptr_to_uint(raw_pointer_cast(u128_tensor.data())); + desc.start_address_ = static_cast(start_address >> 4); + + constexpr uint8_t base_offset = 0; + desc.base_offset_ = base_offset; + + // LayoutType meta + constexpr int SwizzleAtomMNSize = LAYOUT_TYPE == UMMA::LayoutType::SWIZZLE_NONE ? 1 : + LAYOUT_TYPE == UMMA::LayoutType::SWIZZLE_32B ? 2 : + LAYOUT_TYPE == UMMA::LayoutType::SWIZZLE_64B ? 4 : + LAYOUT_TYPE == UMMA::LayoutType::SWIZZLE_128B ? 8 : + LAYOUT_TYPE == UMMA::LayoutType::SWIZZLE_128B_BASE32B ? 8 : -1; + + if constexpr (MajorMode == UMMA::Major::MN) + { + /* In units of uint128_t, each UmmaDescriptor Major-MN describes a canonical layout of the form + * + * LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((1,n),(8,k)):((X,SBO),(1,LBO)) + * LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((2,n),(8,k)):((1,LBO),(2,SBO)) + * LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((4,n),(8,k)):((1,LBO),(4,SBO)) + * LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((8,n),(8,k)):((1,LBO),(8,SBO)) + * LayoutType::B128_BASE32B : Swizzle<2,5,2> o smem_ptr o ((8,n),(4,k)):((1,LBO),(4,SBO)) + */ + + constexpr int SwizzleAtomKSize = LAYOUT_TYPE == UMMA::LayoutType::SWIZZLE_128B_BASE32B ? 4 : 8; + + // Construct the canonical UMMA T Layout with shape ((SwizzleAtomMNSize,n),(SwizzleAtomKSize,2)) + Layout canonical_layout = logical_divide(layout(u128_tensor), Tile>,Layout>>{}); + + // Check profile of canonical + CUTE_STATIC_ASSERT_V(congruent(canonical_layout, Shape,Shape<_1,_1>>{}), "Not a canonical UMMA_MN Layout: Expected profile failure."); + // Check canonical mode strides + constexpr uint32_t stride_00 = stride<0,0>(canonical_layout); + constexpr uint32_t expected_stride_00 = LAYOUT_TYPE == UMMA::LayoutType::SWIZZLE_NONE ? stride<0,0>(canonical_layout) : 1; + static_assert(stride_00 == expected_stride_00, "Not a canonical UMMA_MN Layout: Expected stride failure."); + constexpr uint32_t stride_10 = stride<1,0>(canonical_layout); + constexpr uint32_t expected_stride_10 = SwizzleAtomMNSize; + static_assert(stride_10 == expected_stride_10, "Not a canonical UMMA_MN Layout: Expected stride failure."); + + // stride dimension byte offset and leading dimension byte offset (4LSB not included == uint128_t units) + constexpr uint32_t stride_01 = stride<0,1>(canonical_layout); + constexpr uint32_t stride_11 = stride<1,1>(canonical_layout); + + desc.stride_byte_offset_ = (LAYOUT_TYPE == UMMA::LayoutType::SWIZZLE_NONE) ? stride_01 : stride_11; + desc.leading_byte_offset_ = (LAYOUT_TYPE == UMMA::LayoutType::SWIZZLE_NONE) ? stride_11 : stride_01; + } else + if constexpr (MajorMode == UMMA::Major::K) + { + /* In units of uint128_t, each UmmaDescriptor Major-K describes a canonical layout of the form + * + * LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((8,n),2):((1,SBO),LBO) + * LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((8,n),2):((2,SBO),1) + * LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((8,n),2):((4,SBO),1) + * LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((8,n),2):((8,SBO),1) + * LayoutType::B128_BASE32B : Not applicable for Major-K + */ + + static_assert(LAYOUT_TYPE != UMMA::LayoutType::SWIZZLE_128B_BASE32B, "SWIZZLE_128B_BASE32B is invalid for Major-K"); + CUTE_STATIC_ASSERT_V(size<0>(u128_tensor) % Int<8>{} == Int<0>{}, // N|M size + "Not a canonical UMMA_K Layout: Expected MN-size multiple of 8."); + + // Construct the canonical UMMA N Layout with shape ((8,n),(2,1)) + Layout canonical_layout = logical_divide(layout(u128_tensor), Tile,Layout<_2,_1>>{}); + + // Check profile of canonical + CUTE_STATIC_ASSERT_V(congruent(canonical_layout, Shape,Shape<_1,_1>>{}), "Not a canonical UMMA_K Layout: Expected profile failure."); + // Check canonical mode strides + constexpr uint32_t stride_00 = stride<0,0>(canonical_layout); + constexpr uint32_t expected_stride_00 = SwizzleAtomMNSize; + static_assert(stride_00 == expected_stride_00, "Not a canonical UMMA_K Layout: Expected stride failure."); + constexpr uint32_t stride_10 = stride<1,0>(canonical_layout); + constexpr uint32_t expected_stride_10 = (LAYOUT_TYPE == UMMA::LayoutType::SWIZZLE_NONE) ? stride<1,0>(canonical_layout) : 1; + static_assert(stride_10 == expected_stride_10, "Not a canonical UMMA_K Layout: Expected stride failure."); + + // stride dimension byte offset and leading dimension byte offset (4LSB not included == uint128_t units) + constexpr uint32_t stride_01 = stride<0,1>(canonical_layout); + + desc.stride_byte_offset_ = stride_01; + desc.leading_byte_offset_ = stride_10; + } else { + static_assert(MajorMode != UMMA::Major::MN && MajorMode != UMMA::Major::K, "Unrecognized MajorMode!"); + } + return desc; +} + +/////////////////////////////////////////////////////////////////////////////// +// Higher level UMMA Descriptor utilities +/////////////////////////////////////////////////////////////////////////////// + +struct DescriptorIterator +{ + using reference = SmemDescriptor; + using element_type = SmemDescriptor; + using value_type = SmemDescriptor; + + SmemDescriptor desc_; + + // Dereference returns the UmmaDescriptor + CUTE_HOST_DEVICE constexpr + reference operator*() const { return desc_; } + + // Advance and return a new UmmaDescriptor + template + CUTE_HOST_DEVICE constexpr + reference operator[](Index const& i) const { return *(*this + i); } + + // Return an advanced iterator + template + CUTE_HOST_DEVICE constexpr + DescriptorIterator operator+(Index const& offset) const + { + // Use 32bit calculation rather than 64 bit calculation as we only update the part of desc + SmemDescriptor ret; + ret.lo = desc_.lo + uint32_t(offset); + ret.hi = desc_.hi; + return { ret }; + } +}; + +template +CUTE_HOST_DEVICE constexpr +SmemDescriptor +raw_pointer_cast(DescriptorIterator const& ptr) { + return ptr.desc_; +} + +CUTE_HOST_DEVICE void +print(DescriptorIterator const&) { + printf("UMMA::DescriptorIterator"); +} + +// Flag for smem descriptor allocation/creation +template +struct smem_desc : DescriptorIterator {}; + +template +struct sparse_smem_desc : DescriptorIterator {}; + +} // end namespace UMMA + +// Customization point for creating a UMMA::smem_desc Tensor +template +struct MakeTensor> +{ + template + CUTE_HOST_DEVICE constexpr auto + operator()(Tensor const& smem_tensor) + { + static_assert(is_smem::value, "Expected SMEM Tensor to construct a UMMA Desc Tensor"); + return make_tensor(UMMA::DescriptorIterator{UMMA::make_umma_desc(tensor<0>(smem_tensor))}, + replace<0>(recast(smem_tensor).layout(), Layout<_1,_0>{})); + } +}; + +// Customization point for creating a UMMA::sparse_smem_desc Tensor +template +struct MakeTensor> +{ + // Note that this is the exact same as UMMA::smem_desc above. + // Only the interface validates that we are passed a sparse_ptr, which is recast away to construct + // the smem desc tensor + template + CUTE_HOST_DEVICE constexpr auto + operator()(Tensor const& smem_tensor) + { + static_assert(is_smem::value, "Expected SMEM Tensor to construct a UMMA Desc Tensor"); + static_assert(is_sparse::value, "Expected sparse value_type."); + static_assert(is_sparse_ptr::value, "Expected sparse iter."); + return make_tensor(UMMA::DescriptorIterator{UMMA::make_umma_desc(tensor<0>(smem_tensor))}, + replace<0>(recast(smem_tensor).layout(), Layout<_1,_0>{})); + } +}; + +// Special smem_desc_iter tensor entry for UTCCP copy. +template +constexpr auto get_utccp_smem_desc_tensor(Tensor const& smem_utccp_partitioned_tensor) { + using VecLayout = decltype(layout<0>(TLayout{})); + static_assert(VecLayout::rank == 2 && shape<1>(VecLayout{}) == 1, "Mismatched vec_mode tensor."); + static_assert(is_smem::value, "Expect vec_mode smem_tesnor."); + static_assert(is_static::value, "Utccp copy tensor's vec_mode should be static."); + + using value_type = typename TEngine::value_type; + using UtccpTaits = Copy_Traits; + + // UtccpTaits::ValID: logical_bit_idx -> tmem_offset. + // We arrange the logical_bit_idx in order of (core_matrix_strided, core_matrix_leading, repeat(only in 64dplw01), broadcast). + // So we only need the first two modes for src smem_tensor. + auto utccp_core_matrix_shape = take<0,2>(upcast>(typename UtccpTaits::ValID{}).shape()); + // logical_bit_idx -> smem_addr + Layout vec_v_layout = flatten(layout<0>(VecLayout{})); + Layout utccp_core_matrix_layout = vec_v_layout.with_shape(utccp_core_matrix_shape); + Tensor utccp_core_matrix_tensor = group_modes<0,2>(make_tensor(smem_utccp_partitioned_tensor.data(), utccp_core_matrix_layout)); + Tensor core_matrix_desc_tensor = make_tensor>(utccp_core_matrix_tensor); + return make_tensor(core_matrix_desc_tensor.data(), recast_layout(smem_utccp_partitioned_tensor.layout())); +} + +namespace UMMA { + +// Import TMEM constants +namespace TMEM = cute::TMEM; + +enum class TmemAllocMode { + // Default allocation mode. + // If a TMEM Atom uses a half-subpartition (16DPs), then multiple atoms can be + // interleaved by using the top-half-subpartition and the bottom-half-subpartition. + // Full utilization of TMEM capacity. + Interleaved = 0, + // Prevents interleaving. + // If a TMEM Atom uses a half-subpartition (16DPs), then multiple atoms will not be + // interleaved. + // Required for DP-address equivalence in TMEM-A and TMEM-C allocations in UMMA_TS. + NonInterleaved = 1, + // Duplicates the TMEM allocation across subpartitions. + // E.g. UMMA_2SM_128xNx16_TS uses a "2x2 DP" TMEM Layout, but the TMEM allocation is + // actually doubled and the input data must be duplicated between the + // subpartitions [0,1]<->[2,3], i.e., each subpartition holds all columns + // of the A matrix needed for a single UMMA operation. + // For UMMA_2SM_128xNx16_TS, the distribution of the data is as follows. + // SM0: + // Subpart0 = A[0:32, 0:16], Subpart1 = A[32:64, 0:16], + // Subpart2 = A[A:32, 0:16], Subpart3 = A[32:64, 0:16] + // SM1: + // Subpart0 = A[64:96, 0:16], Subpart1 = A[96:128, 0:16], + // Subpart2 = A[64:96, 0:16], Subpart3 = A[96:128, 0:16] + Duplicated = 2, + // Duplicates the TMEM allocation across subpartitions for scale factor. + // Scale factor TMEM allocation for 4x1 data path + ScaleFactorDuplicated4by1 = 3, + // Scale factor TMEM allocation for 2x2 data path + ScaleFactorDuplicated2by2 = 4 +}; + +struct tmem_frg_base {}; + +// The UMMA Traits below have custom fragment type flags for their tmem tensors. +// These flags specialize a MakeTensor customization point to correctly make the fragment that is desired. +template +struct tmem_frg : tmem_frg_base +{ + static_assert(sizeof_bits_v <= sizeof_bits_v, "TMEM MMA allocations require StorageType big enough for ValueType."); + + // UMMA TMEM Allocator + // Each UMMA expects a specific MxN layout of TMEM for accumulators + // and sometimes a specific MxK layout of TMEM for A-values. + // @tparam ValueType The value type of the TMEM Tensor to allocate. + // @tparam StorageType The storage type of the TMEM Tensor to allocate. + // "Sparse" allocations often allocate ValueType=half_t within StorageType=uint32_t. + // "Dense" allocations often allocate ValueType=half_t within StorageType=half_t. + // @tparam N_SM The number of SMs in this UMMA_XSM instruction. + // @tparam TmemAlloc UMMA-specific allocation modifier for special cases. + // Some UMMA instructions expect strange atoms or tilings of atoms. + // @param tmem_shape ((M_MMA_SM,N_MMA_SM),MMA_M,MMA_N,...) + // The post-MMA-partitioned shape of TMEM to allocate. + // Note for UMMA_2SM_128xNx16, that M_MMA_SM will be 64, for example. + template + CUTE_HOST_DEVICE constexpr static auto + make(TmemShape const& tmem_shape) + { + CUTE_STATIC_ASSERT_V(size(tmem_shape)*Int)>{} <= TMEM::MAX_CAPACITY_BITS{}, + "Requesting more TMEM than is available."); + CUTE_STATIC_ASSERT_V(rank<0>(tmem_shape) == Int<2>{}, "Expected post-partitioned shape ((M_MMA,N_MMA),...)."); + constexpr int R = decltype(rank(tmem_shape))::value; + constexpr int M_MMA = decltype(size<0,0>(tmem_shape))::value; + constexpr int N_MMA = decltype(size<0,1>(tmem_shape))::value; + + // It's convenient to use "virtual tensor memory addressing" + // with DP_STRIDE=1, COL_STRIDE=128 to define the tmem_atom, + // then convert to "logical tensor memory addressing" on return. + using COL_ADDR = C::value / sizeof_bits::value>; + Layout tmem_restride = Layout, + Stride, COL_ADDR>>{}; + + static_assert(N_SM == 1 || N_SM == 2, "UMMA expects N_SM == 1 or N_SM == 2"); + if constexpr (N_SM == 1) + { + static_assert(TmemAlloc == UMMA::TmemAllocMode::Interleaved || TmemAlloc == UMMA::TmemAllocMode::NonInterleaved, + "UMMA_1SM only accepts Interleaved or NonInterleaved"); + static_assert(M_MMA == 64 || M_MMA == 128, "UMMA_1SM M-mode size should be 64 or 128."); + + if constexpr (M_MMA == 64) + { + // Half subpartitions layout atom: (M,N) -> tmem_addr + Layout tmem_atom = Layout, Int>, + Stride, _128>>{}; + // tile_stride = 2 causes the tiling to "skip" the first tile in DPs + constexpr int tile_stride = TmemAlloc == UMMA::TmemAllocMode::Interleaved ? 1 : 2; + // This will tile in DPs first, then COLs + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape), + compact_col_major(take<1,R>(tmem_shape),Int{}))); + // Restride for the DP/COL addressing and return + return make_tensor(make_tmem_ptr(), composition(tmem_restride, tmem_logical_layout)); + } else + if constexpr (M_MMA == 128) + { + // For M_MMA = 128, all datapaths are occupied. TmemAllocMode doesn't change the allocation. + // Full subpartitions layout atom: (M,N) -> tmem_addr + Layout tmem_atom = Layout>, + Stride< _1, _128>>{}; + // This will tile in DPs first, then COLs + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Restride for the DP/COL addressing and return + return make_tensor(make_tmem_ptr(), composition(tmem_restride, tmem_logical_layout)); + } + + } else + if constexpr (N_SM == 2) + { + static_assert(TmemAlloc == UMMA::TmemAllocMode::Interleaved || TmemAlloc == UMMA::TmemAllocMode::Duplicated, + "UMMA_2SM only accepts Interleaved or Duplicated"); + static_assert(M_MMA == 32 || M_MMA == 64 || M_MMA == 128, "UMMA_2SM M-mode size should be 32 or 64 or 128."); + + if constexpr (M_MMA == 32) + { + static_assert(TmemAlloc == UMMA::TmemAllocMode::Interleaved, "Only TmemAllocMode::Interleaved is supported for UMMA_2SM M_MMA=32"); + // The "1x4" layout atom: (M,N) -> tmem_addr + Layout tmem_atom = Layout, _4>>, + Stride< _1,Stride< _128,_32>>>{}; + // This will tile in DPs first, then COLs + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Restride for the DP/COL addressing and return + return make_tensor(make_tmem_ptr(), composition(tmem_restride, tmem_logical_layout)); + } else + if constexpr (M_MMA == 64 && TmemAlloc == UMMA::TmemAllocMode::Interleaved) + { + // The "2x2" layout atom: (M,N) -> tmem_addr + Layout tmem_atom = Layout, _2>>, + Stride< _1,Stride< _128,_64>>>{}; + // This will tile in DPs first, then COLs + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Restride for the DP/COL addressing and return + return make_tensor(make_tmem_ptr(), composition(tmem_restride, tmem_logical_layout)); + + } else + if constexpr (M_MMA == 64 && TmemAlloc == UMMA::TmemAllocMode::Duplicated) + { + // The "2x2" duplicated layout atom: (M,N) -> tmem_addr + Layout tmem_atom = Layout>, + Stride< _1, _128>>{}; + // This will tile in DPs first, then COLs + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Restride for the DP/COL addressing and return + return make_tensor(make_tmem_ptr(), composition(tmem_restride, tmem_logical_layout)); + } else + if constexpr (M_MMA == 128) + { + // For M_MMA = 128, all datapaths are occupied. TmemAllocMode doesn't change the allocation. + // The "4x1" layout atom: (M,N) -> tmem_addr + Layout tmem_atom = Layout>, + Stride< _1, _128>>{}; + // This will tile in DPs first, then COLs + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Restride for the DP/COL addressing and return + return make_tensor(make_tmem_ptr(), composition(tmem_restride, tmem_logical_layout)); + } + } + + CUTE_GCC_UNREACHABLE; + } +}; + +// Convenient aliases for common cases in the UMMA::ElementXFrg below +template +using tmem_frg_1sm = tmem_frg; +template +using tmem_frg_2sm = tmem_frg; + +// Make metadata TMEM fragments for sparse MMAs. +// Also note that the TMEM fragment addresses are assumed to be COL-4 aligned -- working with arch to remove this condition +template +struct tmem_e_frg : tmem_frg_base +{ + template + CUTE_HOST_DEVICE constexpr static auto + make(TmemShape const& tmem_shape) + { + CUTE_STATIC_ASSERT_V(rank<0>(tmem_shape) == Int<2>{}, "Expected post-partitioned shape ((M_MMA,N_MMA),...)."); + constexpr int R = decltype(rank(tmem_shape))::value; + constexpr int M_MMA = decltype(size<0,0>(tmem_shape))::value; + constexpr int N_MMA = decltype(size<0,1>(tmem_shape))::value; + + static_assert(M_MMA == 128, "Only 128 implemented right now."); + + // It's convenient to use "virtual tensor memory addressing" + // with DP_STRIDE=1, COL_STRIDE=128 to define the tmem_atom, + // then convert to "logical tensor memory addressing" on return. + [[maybe_unused]] Layout tmem_restride = Layout, + Stride>{}; + + if constexpr (sizeof_bits::value == 32) // TF32: 128x16 atom + { + static_assert(N_MMA == 16); + Layout tmem_atom = Layout, Shape < _8,_2>>, + Stride, Stride<_128,_8>>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address transformations with upcast<2> for 2-bit base types + Layout tmem_layout = composition(upcast<2>(tmem_restride), tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles it's own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } else + if constexpr (sizeof_bits::value == 16) // FP16: 128x32 atom + { + static_assert(N_MMA == 32); + Layout tmem_atom = Layout, Shape < _16,_2>>, + Stride, Stride<_128,_8>>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address transformations + Layout tmem_layout = composition(tmem_restride, tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles it's own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } else + if constexpr (sizeof_bits::value == 8) // S8|Mix.F4/F6/F8: 128x64 atom + { + // For Mix 8bit f4/f6/f8, will pass in ValueType = uint8_t + static_assert(N_MMA == 64); + Layout tmem_atom = Layout, + Stride< _1,_128>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address transformations + Layout tmem_layout = composition(tmem_restride, tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles it's own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } + if constexpr (sizeof_bits::value == 4) // F4: 128x128 atom + { + // For F4, will pass in ValueType = fp4 + Layout tmem_restride1 = Layout>, + Stride, _1>>{}; + // F4 has roughly same TMEM layout as Mix8bit.F4/F6/F8, the only difference is that K is multiplied by two + static_assert(N_MMA == 128); + Layout tmem_atom = Layout, + Stride< _1, _128>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address transformations + Layout tmem_layout = composition(tmem_restride1, tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles it's own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } + + CUTE_GCC_UNREACHABLE; + } +}; + +template +struct tmem_e_frg_ws : tmem_frg_base +{ + template + CUTE_HOST_DEVICE constexpr static auto + make(TmemShape const& tmem_shape) + { + CUTE_STATIC_ASSERT_V(rank<0>(tmem_shape) == Int<2>{}, "Expected post-partitioned shape ((M_MMA,N_MMA),...)."); + constexpr int R = decltype(rank(tmem_shape))::value; + constexpr int M_MMA = decltype(size<0,0>(tmem_shape))::value; + constexpr int N_MMA = decltype(size<0,1>(tmem_shape))::value; + + static_assert(M_MMA == 128 || M_MMA == 64 || M_MMA == 32, "Weight stationary UMMA_1SM M-mode size should be 32 or 64 or 128."); + + // It's convenient to use "virtual tensor memory addressing" + // with DP_STRIDE=1, COL_STRIDE=128 to define the tmem_atom, + // then convert to "logical tensor memory addressing" on return. + Layout tmem_restride = Layout, + Stride>{}; + + if constexpr (sizeof_bits::value == 32) // TF32 + { + // MMA_M x MMA_K: 128x16 atom / 64x16 atom / 32x16 atom + static_assert(N_MMA == 16); + if constexpr (M_MMA == 128) { + Layout tmem_atom = Layout, Shape < _8,_2>>, + Stride, Stride<_128,_8>>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address transformations with upcast<2> for 2-bit base types + Layout tmem_layout = composition(upcast<2>(tmem_restride), tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles it's own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } + else if constexpr (M_MMA == 64) { + Layout tmem_atom = Layout, Shape < _8,_2>, _2>, + Stride, Stride<_128,_8>,_64>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address transformations with upcast<2> for 2-bit base types + Layout tmem_layout = composition(upcast<2>(tmem_restride), tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles its own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } + else if constexpr (M_MMA == 32) { + Layout tmem_atom = Layout, Shape < _8,_2>, _4>, + Stride, Stride<_128,_8>,_32>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address transformations with upcast<2> for 2-bit base types + Layout tmem_layout = composition(upcast<2>(tmem_restride), tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles its own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } + else { + static_assert(dependent_false, "Invalid M_MMA value"); + } + } + else if constexpr (sizeof_bits::value == 16) // FP16 + { + // MMA_M x MMA_K: 128x32 atom / 64x32 atom / 32x32 atom + static_assert(N_MMA == 32); + if constexpr (M_MMA == 128) { + Layout tmem_atom = Layout, Shape < _16,_2>>, + Stride, Stride<_128,_8>>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address transformations + Layout tmem_layout = composition(tmem_restride, tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles it's own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } + else if constexpr (M_MMA == 64) { + Layout tmem_atom = Layout, Shape < _16,_2>, _2>, + Stride, Stride<_128,_8>,_64>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address transformations + Layout tmem_layout = composition(tmem_restride, tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles it's own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } + else if constexpr (M_MMA == 32) { + Layout tmem_atom = Layout, Shape < _16,_2>, _4>, + Stride, Stride<_128,_8>,_32>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address transformations + Layout tmem_layout = composition(tmem_restride, tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles it's own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } + else { + static_assert(dependent_false, "Invalid M_MMA value"); + } + } + else if constexpr (sizeof_bits::value == 8) // I8|F8 + { + // MMA_M x MMA_K: 128x64 atom / 64x64 atom / 32x64 atom + static_assert(N_MMA == 64); + if constexpr (M_MMA == 128) { + Layout tmem_atom = Layout, + Stride< _1,_128>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address transformations + Layout tmem_layout = composition(tmem_restride, tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles it's own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } + else if constexpr (M_MMA == 64) { + Layout tmem_atom = Layout>, + Stride< _1, Stride<_128, _64>>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address transformations + Layout tmem_layout = composition(tmem_restride, tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles it's own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } + else if constexpr (M_MMA == 32) { + Layout tmem_atom = Layout>, + Stride< _1, Stride<_128, _32>>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address transformations + Layout tmem_layout = composition(tmem_restride, tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles it's own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } + else { + static_assert(dependent_false, "Invalid M_MMA value"); + } + } + else { + static_assert(dependent_false, "Invalid ValueType"); + } + + CUTE_GCC_UNREACHABLE; + } +}; + +template +struct tmem_sf_frg: tmem_frg_base +{ + // UMMA TMEM Allocator for Scale Factor A for Mxf4Nvf4 and Mxf8f6f4 instructions + // We expect a tensor that has the same layout as A matrix + // @tparam ValueType: data type of scaling factor + // Note that the StorageType is the same as ValueType, i.e., we always use a compact allocation + // @tparam SFVecSize: The number of values that is scaled by a single scaling factor. + // Valid values are (16, 32) + // @tparam N_SM: Number of SMs in UMMA instruction + // @param tmem_shape: An MMA partitioned shape where first mode encodes, A layout of the MMA instruction. + // Note that the shape doesn't match the actual allocation. size<0,1>(tmem_shape) will give us the number of + // elements in K-mode of MMA rather than the number of scaling factors. + template + CUTE_HOST_DEVICE constexpr static auto + make(TmemShape const& tmem_shape) + { + CUTE_STATIC_ASSERT_V(rank<0>(tmem_shape) == Int<2>{}, "Expected post-partitioned shape ((M_MMA,N_MMA),...)."); + constexpr int MMA_MN = decltype(size<0,0>(tmem_shape))::value; + constexpr int MMA_VS = decltype(size<0,1,0>(tmem_shape))::value; + constexpr int MMA_NSF = decltype(size<0,1,1>(tmem_shape))::value; + constexpr int R_MMA_K = decltype(rank(get<0,1>(tmem_shape)))::value; + constexpr int R = decltype(rank(tmem_shape))::value; + + // We expect an MMA-SF partitioned tensor + // ((MMA_MN, (VecSize, NSF)), num_MMA_MN, num_MMA_K, ...) + // where VecSize*NSF = MMA_K + static_assert(R >= 3, "Expected an MMA partitioned tensor"); // ((MMA), num_MMA_MN, num_MMA_K, ...) + static_assert(R_MMA_K == 2, "Expected an MMA-SF partitioned tensor"); // (VecSize, NSF) + using REP = _4; // Replication factor. Data is always replicated across subpartitions + constexpr int SUBPART_DPs = 32; // Number of DPs in a subpartition + + using COL_ADDR = C::value / sizeof_bits::value>; + Layout tmem_restride = Layout, + Stride, COL_ADDR>>{}; + + if constexpr (Is_SFA || (!Is_SFA && TmemAlloc == UMMA::TmemAllocMode::ScaleFactorDuplicated4by1)) { + // SFA, 2x2 and 4x1 data path + // SFB, 4x1 data path + auto tmem_atom = Layout < Shape< Shape< Shape, Int>, REP>, Shape, Int>>, + Stride, _32>, Stride< _0, _128>>>{}; + + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + auto final_tmem_layout = composition(tmem_restride, tmem_logical_layout); + return make_tensor(make_tmem_ptr(), final_tmem_layout); + } + else { + // SFB, 2x2 datapath + static_assert(!Is_SFA and TmemAlloc == UMMA::TmemAllocMode::ScaleFactorDuplicated2by2); + static_assert(N_SM == 2, "Should be 2x2 Datapath"); + // 2x2 Datapth + auto tmem_atom = Layout < Shape< Shape< Shape, Int>, _2, _2>, Shape, Int>>, + Stride, _64, _32>, Stride< _0, _128>>>{}; + + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + auto final_tmem_layout = composition(tmem_restride, tmem_logical_layout); + return make_tensor(make_tmem_ptr(), final_tmem_layout); + } + } +}; + +// Make C/D Tmem fragment for weight-stationary MMAs +template +struct tmem_frg_ws : tmem_frg_base +{ + static_assert(sizeof_bits_v <= sizeof_bits_v, "TMEM MMA allocations require StorageType big enough for ValueType."); + + // UMMA TMEM Allocator + // Each UMMA expects a specific MxN layout of TMEM for accumulators + // and sometimes a specific MxK layout of TMEM for A-values. + // @tparam ValueType The value type of the TMEM Tensor to allocate. + // @tparam StorageType The storage type of the TMEM Tensor to allocate. + // "Sparse" allocations often allocate ValueType=half_t within StorageType=uint32_t. + // "Dense" allocations often allocate ValueType=half_t within StorageType=half_t. + // @tparam N_SM The number of SMs in this UMMA_XSM instruction. + // @tparam TmemAlloc UMMA-specific allocation modifier for special cases. + // Some UMMA instructions expect strange atoms or tilings of atoms. + // @param tmem_shape ((M_MMA_SM,N_MMA_SM),MMA_M,MMA_N,...) + // The post-MMA-partitioned shape of TMEM to allocate. + // Note for UMMA_2SM_128xNx16, that M_MMA_SM will be 64, for example. + template + CUTE_HOST_DEVICE constexpr static auto + make(TmemShape const& tmem_shape) + { + CUTE_STATIC_ASSERT_V(size(tmem_shape)*Int)>{} <= TMEM::MAX_CAPACITY_BITS{}, + "Requesting more TMEM than is available."); + CUTE_STATIC_ASSERT_V(rank<0>(tmem_shape) == Int<2>{}, "Expected post-partitioned shape ((M_MMA,N_MMA),...)."); + constexpr int R = decltype(rank(tmem_shape))::value; + constexpr int M_MMA = decltype(size<0,0>(tmem_shape))::value; + constexpr int N_MMA = decltype(size<0,1>(tmem_shape))::value; + + // It's convenient to use "virtual tensor memory addressing" + // with DP_STRIDE=1, COL_STRIDE=128 to define the tmem_atom, + // then convert to "logical tensor memory addressing" on return. + using COL_ADDR = C::value / sizeof_bits::value>; + Layout tmem_restride = Layout, + Stride, COL_ADDR>>{}; + + static_assert(N_SM == 1, "UMMA.WS expects N_SM == 1"); + + static_assert(M_MMA == 32 || M_MMA == 64 || M_MMA == 128, + "Weight stationary UMMA_1SM M-mode size should be 32 or 64 or 128."); + static_assert(N_MMA == 64 || N_MMA == 128 || N_MMA == 256, + "Dense weight stationary UMMA_1SM N-mode size should be 64 or 128 or 256."); + // Weight Stationary MMA config + if constexpr (M_MMA == 32) + { + // 1x4 datapath + Layout tmem_atom = Layout, _4>>, + Stride< _1, Stride< _128,_32>> + >{}; + constexpr int tile_stride = 1; + // This will tile in DPs first, then COLs + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape), + compact_col_major(take<1,R>(tmem_shape), Int{}))); + // Restride for the DP/COL addressing and return + return make_tensor(make_tmem_ptr(), composition(tmem_restride, tmem_logical_layout)); + } else + if constexpr (M_MMA == 64) + { + // 2x2 datapath + Layout tmem_atom = Layout, _2>>, + Stride< _1, Stride< _128,_64>> + >{}; + constexpr int tile_stride = 1; + // This will tile in DPs first, then COLs + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape), + compact_col_major(take<1,R>(tmem_shape), Int{}))); + // Restride for the DP/COL addressing and return + return make_tensor(make_tmem_ptr(), composition(tmem_restride, tmem_logical_layout)); + } else + if constexpr (M_MMA == 128) + { + // For M_MMA = 128, all datapaths are occupied. TmemAllocMode doesn't change the allocation. + // Full subpartitions layout atom: (M,N) -> tmem_addr + Layout tmem_atom = Layout>, + Stride< _1, _128>>{}; + // This will tile in DPs first, then COLs + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Restride for the DP/COL addressing and return + return make_tensor(make_tmem_ptr(), composition(tmem_restride, tmem_logical_layout)); + } + + CUTE_GCC_UNREACHABLE; + } +}; + +// Convenient aliases for common cases in the UMMA::ElementXFrg below +template +using tmem_frg_ws_1sm = tmem_frg_ws; + +} // end namespace UMMA + +// Customization point for creating a UMMA::tmem_frg Tensor +template +struct MakeTensor> +{ + template + CUTE_HOST_DEVICE constexpr auto + operator()(Shape const& tmem_shape) { + return UMMA::tmem_frg::make(shape(tmem_shape)); + } +}; + +template +struct MakeTensor> +{ + template + CUTE_HOST_DEVICE constexpr auto + operator()(Shape const& tmem_shape) { + return UMMA::tmem_frg_ws::make(shape(tmem_shape)); + } +}; + + +// Customization point for creating a UMMA::tmem_frg Tensor +template +struct MakeTensor> +{ + template + CUTE_HOST_DEVICE constexpr auto + operator()(Shape const& tmem_shape) { + return UMMA::tmem_e_frg::make(shape(tmem_shape)); + } +}; + +template +struct MakeTensor> +{ + template + CUTE_HOST_DEVICE constexpr auto + operator()(Shape const& tmem_shape) { + return UMMA::tmem_e_frg_ws::make(shape(tmem_shape)); + } +}; + +template +struct MakeTensor> +{ + template + CUTE_HOST_DEVICE constexpr auto + operator()(Shape const& tmem_shape) { + return UMMA::tmem_sf_frg::make(shape(tmem_shape)); + } +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////////// MMA_TRAITS /////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 32, "SM100_MMA_TF32_SS supports 32bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // Logical shape-K is always 256bits, transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_TF32_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_SS supports 16bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // Logical shape-K is always 256bits, transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 32, "SM100_MMA_TF32_TS supports 32bit types"); + + using FrgTypeA = UMMA::tmem_frg_1sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // Logical shape-K is always 256 bits; transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint32_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_TF32_TS::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_TS supports 16bit types"); + + using FrgTypeA = UMMA::tmem_frg_1sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // Logical shape-K is always 256 bits; transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint32_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_TS::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_SS_SCALED supports 16bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // Logical shape-K is always 256bits, transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + static constexpr uint32_t ScalingFactor = ScaleC; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_SS_SCALED::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + + } + + template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(UMMA::ScaleOut accumulate, cute::integral_constant scaleC) const { + return {accumulate, idesc_}; + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_TS_SCALED supports 16bit types"); + + using FrgTypeA = UMMA::tmem_frg_1sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // Logical shape-K is always 256 bits; transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + static constexpr uint32_t ScalingFactor = ScaleC; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint32_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_TS_SCALED::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } + + template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(UMMA::ScaleOut accumulate, cute::integral_constant scaleC) const { + return {accumulate, idesc_}; + } +}; + +template +struct MMA_Traits, sparse_args...> +{ + using ValTypeD = c_type; + static_assert(sizeof(a_type) == 4); + using ValTypeA = sparse_elem<2, a_type>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 32, "SM100_MMA_TF32_SS_SPARSE supports 32bit types"); + + using FrgTypeA = UMMA::sparse_smem_desc; + using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // SparseMma consume double mma-k bits + static constexpr int K = 512 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + // uint32_t tmem_e: Metadata tmem address. + cute::tuple sparse_args_; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, UMMA::Saturate::False, true>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_same, cute::tuple>::value, + "Params must be set via .with()?"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + + uint32_t tmem_e = get<0>(traits.sparse_args_); + uint32_t id2 = tmem_e & 0x00000001; + tmem_e = tmem_e & ~0x00000001; + + uint64_t idesc = UMMA::make_runtime_instr_desc(traits.idesc_, static_cast(id2), tmem_e); + + SM100_MMA_TF32_SS_SPARSE::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, tmem_e); + } + + // Construct an executable sparse MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits, uint32_t> + with(Tensor const& E) const { + // Check sparse_ptr, check sparsity, check shape/layout? + uint32_t tmem_e_addr = raw_pointer_cast(E.data()); // Move to a CoupledTensor rather than a .with()? + return {accumulate_, {tmem_e_addr}, idesc_}; + } +}; + +template +struct MMA_Traits, sparse_args...> +{ + using ValTypeD = c_type; + static_assert(sizeof(a_type) == 2); + using ValTypeA = sparse_elem<2, a_type>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_SS_SPARSE supports 16bit types"); + + using FrgTypeA = UMMA::sparse_smem_desc; + using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // SparseMma consume double mma-k bits + static constexpr int K = 512 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + // uint32_t tmem_e: Metadata tmem address. + cute::tuple sparse_args_; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, UMMA::Saturate::False, true>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_same, cute::tuple>::value, + "Params must be set via .with()?"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + + uint32_t tmem_e = get<0>(traits.sparse_args_); + uint32_t id2 = tmem_e & 0x00000001; + tmem_e = tmem_e & ~0x00000001; + + uint64_t idesc = UMMA::make_runtime_instr_desc(traits.idesc_, static_cast(id2), tmem_e); + + SM100_MMA_F16BF16_SS_SPARSE::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, tmem_e); + } + + // Construct an executable sparse MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits, uint32_t> + with(Tensor const& E) const { + // Check sparse_ptr, check sparsity, check shape/layout? + uint32_t tmem_e_addr = raw_pointer_cast(E.data()); // Move to a CoupledTensor rather than a .with()? + return {accumulate_, {tmem_e_addr}, idesc_}; + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 32, "SM100_MMA_TF32_2x1SM_SS supports 32bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Size of instructions's K extent is always 256bits, convert to units of element + constexpr static int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_TF32_2x1SM_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_2x1SM_SS supports 16bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Size of instructions's K extent is always 256bits, convert to units of element + constexpr static int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_2x1SM_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 32, "SM100_MMA_TF32_2x1SM_TS supports 32bit types"); + + using FrgTypeA = UMMA::tmem_frg_2sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Size of instructions' K extent is always 256 bits; convert to units of element + constexpr static int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_TF32_2x1SM_TS::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_2x1SM_TS supports 16bit types"); + + using FrgTypeA = UMMA::tmem_frg_2sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Size of instructions' K extent is always 256 bits; convert to units of element + constexpr static int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_2x1SM_TS::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_2x1SM_SS_SCALED supports 16bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Size of instructions's K extent is always 256bits, convert to units of element + constexpr static int K = 256 / cute::sizeof_bits::value; + constexpr static uint32_t ScalingFactor = ScaleC; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_2x1SM_SS_SCALED::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } + + template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(UMMA::ScaleOut accumulate, cute::integral_constant scaleC) const { + return {accumulate, idesc_}; + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_2x1SM_TS_SCALED supports 16bit types"); + + using FrgTypeA = UMMA::tmem_frg_2sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Size of instructions' K extent is always 256 bits; convert to units of element + constexpr static int K = 256 / cute::sizeof_bits::value; + constexpr static uint32_t ScalingFactor = ScaleC; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_2x1SM_TS_SCALED::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } + + template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(UMMA::ScaleOut accumulate, cute::integral_constant scaleC) const { + return {accumulate, idesc_}; + } +}; + +template +struct MMA_Traits, sparse_args...> +{ + using ValTypeD = c_type; + static_assert(sizeof(a_type) == 4); + using ValTypeA = sparse_elem<2, a_type>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 32, "SM100_MMA_TF32_2x1SM_SS_SPARSE supports 32bit types"); + + using FrgTypeA = UMMA::sparse_smem_desc; + using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // SparseMma consume double mma-k bits + constexpr static int K = 512 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + // uint32_t tmem_e: Metadata tmem address. + cute::tuple sparse_args_; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, UMMA::Saturate::False, true>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_same, cute::tuple>::value, + "Params must be set via .with()?"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + + uint32_t tmem_e = get<0>(traits.sparse_args_); + uint32_t id2 = tmem_e & 0x00000001; + tmem_e = tmem_e & ~0x00000001; + + uint64_t idesc = UMMA::make_runtime_instr_desc(traits.idesc_, static_cast(id2), tmem_e); + + SM100_MMA_TF32_2x1SM_SS_SPARSE::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, tmem_e); + } + + // Construct an executable sparse MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits, uint32_t> + with(Tensor const& E, uint32_t id2 = 0) const { + // Check sparse_ptr, check sparsity, check shape/layout? + uint32_t tmem_e_addr = raw_pointer_cast(E.data()); + return {accumulate_, {tmem_e_addr}, idesc_}; + } +}; + +template +struct MMA_Traits, sparse_args...> +{ + using ValTypeD = c_type; + static_assert(sizeof(a_type) == 2); + using ValTypeA = sparse_elem<2, a_type>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_2x1SM_SS_SPARSE supports 16bit types"); + + using FrgTypeA = UMMA::sparse_smem_desc; + using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // SparseMma consume double mma-k bits + constexpr static int K = 512 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + // uint32_t tmem_e: Metadata tmem address. + cute::tuple sparse_args_; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, UMMA::Saturate::False, true>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_same, cute::tuple>::value, + "Params must be set via .with()?"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + + uint32_t tmem_e = get<0>(traits.sparse_args_); + uint32_t id2 = tmem_e & 0x00000001; + tmem_e = tmem_e & ~0x00000001; + + uint64_t idesc = UMMA::make_runtime_instr_desc(traits.idesc_, static_cast(id2), tmem_e); + + SM100_MMA_F16BF16_2x1SM_SS_SPARSE::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, tmem_e); + } + + // Construct an executable sparse MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits, uint32_t> + with(Tensor const& E, uint32_t id2 = 0) const { + // Check sparse_ptr, check sparsity, check shape/layout? + uint32_t tmem_e_addr = raw_pointer_cast(E.data()); + return {accumulate_, {tmem_e_addr}, idesc_}; + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 8, "SM100_MMA_S8_SS supports 8bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // Logical shape-K is always 256bits, transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, UMMA::ScaleIn::One, UMMA::ScaleIn::One, c_sat>(); + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_S8_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 8, "SM100_MMA_S8_TS supports 8bit types"); + + using FrgTypeA = UMMA::tmem_frg_1sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // Logical shape-K is always 256 bits; transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, UMMA::ScaleIn::One, UMMA::ScaleIn::One, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_S8_TS::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits, sparse_args...> +{ + using ValTypeD = c_type; + static_assert(sizeof(a_type) == 1); + using ValTypeA = sparse_elem<2, a_type>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 8, "SM100_MMA_S8_SS_SPARSE supports 8bit types"); + + using FrgTypeA = UMMA::sparse_smem_desc; + using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // SparseMma consume double mma-k bits + static constexpr int K = 512 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + // uint32_t tmem_e: Metadata tmem address. + cute::tuple sparse_args_; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, UMMA::ScaleIn::One, UMMA::ScaleIn::One, c_sat, true>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_same, cute::tuple>::value, + "Params must be set via .with()?"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + + uint32_t tmem_e = get<0>(traits.sparse_args_); + uint32_t id2 = 0; + + uint64_t idesc = UMMA::make_runtime_instr_desc(traits.idesc_, static_cast(id2), tmem_e); + + SM100_MMA_S8_SS_SPARSE::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, tmem_e); + } + + // Construct an executable sparse MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits, uint32_t> + with(Tensor const& E, uint32_t id2 = 0) const { + // Check sparse_ptr, check sparsity, check shape/layout? + uint32_t tmem_e_addr = raw_pointer_cast(E.data()); + return {accumulate_, {tmem_e_addr}, idesc_}; + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 8, "SM100_MMA_S8_2x1SM_SS supports 8bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Size of instructions's K extent is always 256bits, convert to units of element + constexpr static int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, UMMA::ScaleIn::One, UMMA::ScaleIn::One, c_sat>(); + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_S8_2x1SM_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 8, "SM100_MMA_S8_2x1SM_TS supports 8bit types"); + + using FrgTypeA = UMMA::tmem_frg_2sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Size of instructions' K extent is always 256 bits; convert to units of element + constexpr static int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_S8_2x1SM_TS::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits, sparse_args...> +{ + using ValTypeD = c_type; + static_assert(sizeof(a_type) == 1); + using ValTypeA = sparse_elem<2, a_type>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 8, "SM100_MMA_S8_2x1SM_SS_SPARSE supports 8bit types"); + + using FrgTypeA = UMMA::sparse_smem_desc; + using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // SparseMma consume double mma-k bits + constexpr static int K = 512 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + // uint32_t tmem_e: Metadata tmem address. + cute::tuple sparse_args_; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, UMMA::ScaleIn::One, UMMA::ScaleIn::One, c_sat, true>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_same, cute::tuple>::value, + "Params must be set via .with()?"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + + uint32_t tmem_e = get<0>(traits.sparse_args_); + uint16_t id2 = 0u; + + uint64_t idesc = UMMA::make_runtime_instr_desc(traits.idesc_, id2, tmem_e); + + SM100_MMA_S8_2x1SM_SS_SPARSE::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, tmem_e); + } + + // Construct an executable sparse MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits, uint32_t> + with(Tensor const& E, uint32_t id2 = 0) const { + // Check sparse_ptr, check sparsity, check shape/layout? + uint32_t tmem_e_addr = raw_pointer_cast(E.data()); + return {accumulate_, {tmem_e_addr}, idesc_}; + } +}; + +template +struct MMA_Traits, cute::C, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_F8F6F4_SS supports types with leq 8bit types"); + static_assert(M == 64 || M == 128, "SM100_MMA_F8F6F4_SS M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert(((b_major == UMMA::Major::K) && ((N % 8 == 0) && (8 <= N) && (N <= 256))) || + ((b_major == UMMA::Major::MN) && ((N % 16 == 0) && (16 <= N) && (N <= 256))), + "SM100_MMA_F8F6F4_SS N-mode size should be a multiple of 8 between 8 and 256 when B is K major. \ + SM100_MMA_F8F6F4_SS N-mode size should be a multiple of 16 between 16 and 256 when B is MN major."); + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + static_assert(sizeof_bits_v <= sizeof_bits_v && + sizeof_bits_v <= sizeof_bits_v); + + // Logical shape-K is always 256bits, transform to units of elements + constexpr static int K = 32; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc(); + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F8F6F4_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + using ValTypeSFA = sf_type; + using ValTypeSFB = sf_type; + static_assert(cute::sizeof_bits_v <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_MXF8F6F4_SS supports types with leq 8bit types"); + + // Logical shape-K is always 256bits, transform to units of elements + constexpr static int K = 32; + constexpr static int SFVecSize = 32; + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + using FrgTypeSFA = UMMA::tmem_sf_frg; + using FrgTypeSFB = UMMA::tmem_sf_frg; + + static_assert(sizeof_bits_v <= sizeof_bits_v && + sizeof_bits_v <= sizeof_bits_v); + + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using MMA_ScaleFactor = SM100_MMA_MXF8F6F4_SS; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + uint32_t tsfa_addr_ = 0; + uint32_t tsfb_addr_ = 0; + + UMMA::InstrDescriptorBlockScaled idesc_ = UMMA::make_instr_desc_block_scaled< + a_type, b_type, c_type, sf_type, M, N, a_major, b_major, a_neg, b_neg>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc_block_scaled<>(traits.idesc_, traits.tsfa_addr_, traits.tsfb_addr_); + + SM100_MMA_MXF8F6F4_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, traits.tsfa_addr_, traits.tsfb_addr_); + } + + // Construct an executable MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(UMMA::ScaleOut accumulate, Tensor const& SFA, Tensor const& SFB) const { + uint32_t tmem_sfa_addr = raw_pointer_cast(SFA.data()); + uint32_t tmem_sfb_addr = raw_pointer_cast(SFB.data()); + return {accumulate, tmem_sfa_addr, tmem_sfb_addr, idesc_}; + } +}; + +template +struct MMA_Traits, sparse_args...> +{ + using ValTypeD = c_type; + using ValTypeA = sparse_elem<2, a_type>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = b_type; + using ValTypeC = c_type; + using ValTypeSFA = sf_type; + using ValTypeSFB = sf_type; + static_assert(cute::sizeof_bits_v <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_MXF8F6F4_SS_SPARSE supports types with leq 8bit types"); + + // Logical shape-K is always 512bits, transform to units of elements + constexpr static int K = 64; + constexpr static int SFVecSize = 64; + + using FrgTypeA = UMMA::sparse_smem_desc; + using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + using FrgTypeSFA = UMMA::tmem_sf_frg; + using FrgTypeSFB = UMMA::tmem_sf_frg; + + static_assert(sizeof_bits_v <= sizeof_bits_v && + sizeof_bits_v <= sizeof_bits_v); + + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using MMA_ScaleFactor = SM100_MMA_MXF8F6F4_SS_SPARSE; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + uint32_t tsfa_addr_ = 0; + uint32_t tsfb_addr_ = 0; + + // uint32_t tmem_e: Metadata tmem address. + cute::tuple sparse_args_; + + UMMA::InstrDescriptorBlockScaled idesc_ = UMMA::make_instr_desc_block_scaled< + a_type, b_type, c_type, sf_type, M, N, a_major, b_major, a_neg, b_neg, true>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_same, cute::tuple>::value, + "Params must be set via .with()?"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + + uint32_t tmem_e = get<0>(traits.sparse_args_); + uint16_t id2 = 0u; + + uint64_t idesc = UMMA::make_runtime_instr_desc_block_scaled(traits.idesc_, traits.tsfa_addr_, traits.tsfb_addr_, id2, tmem_e); + + SM100_MMA_MXF8F6F4_SS_SPARSE::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, traits.tsfa_addr_, traits.tsfb_addr_, tmem_e); + } + + // Construct an executable MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits, uint32_t> + with(UMMA::ScaleOut accumulate, Tensor const& E, Tensor const& SFA, Tensor const& SFB) const { + uint32_t tmem_e_addr = raw_pointer_cast(E.data()); + uint32_t tmem_sfa_addr = raw_pointer_cast(SFA.data()); + uint32_t tmem_sfb_addr = raw_pointer_cast(SFB.data()); + return {accumulate, tmem_sfa_addr, tmem_sfb_addr, {tmem_e_addr}, idesc_}; + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_F8F6F4_TS supports types with leq 8bit types"); + + using FrgTypeA = UMMA::tmem_frg_1sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + static_assert(sizeof_bits_v <= sizeof_bits_v && + sizeof_bits_v <= sizeof_bits_v); + // Logical shape-K is always 256 bits; transform to units of elements + static constexpr int K = 32; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F8F6F4_TS::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits, sparse_args...> +{ + using ValTypeD = c_type; + static_assert(sizeof(a_type) == 1); + using ValTypeA = sparse_elem<2, a_type>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_F8F6F4_SS_SPARSE supports types with leq 8bit types"); + + using FrgTypeA = UMMA::sparse_smem_desc; + using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // SparseMma consume double mma-k bits + static constexpr int K = 512 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + // uint32_t tmem_e: Metadata tmem address. + cute::tuple sparse_args_; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, UMMA::Saturate::False, true>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_same, cute::tuple>::value, + "Params must be set via .with()?"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + + uint32_t tmem_e = get<0>(traits.sparse_args_); + uint16_t id2 = 0u; + + uint64_t idesc = UMMA::make_runtime_instr_desc(traits.idesc_, id2, tmem_e); + + SM100_MMA_F8F6F4_SS_SPARSE::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, tmem_e); + } + + // Construct an executable sparse MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits, uint32_t> + with(Tensor const& E, uint32_t id2 = 0) const { + // Check sparse_ptr, check sparsity, check shape/layout? + uint32_t tmem_e_addr = raw_pointer_cast(E.data()); + return {accumulate_, {tmem_e_addr}, idesc_}; + } +}; + +template +struct MMA_Traits, cute::C, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_F8F6F4_2x1SM_SS supports types with leq 8bit types"); + static_assert(M == 128 || M == 256, "SM100_MMA_F8F6F4_2x1SM_SS M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert(((b_major == UMMA::Major::K) && ((N % 16 == 0) && (16 <= N) && (N <= 256))) || + ((b_major == UMMA::Major::MN) && ((N % 32 == 0) && (32 <= N) && (N <= 256))), + "SM100_MMA_F8F6F4_2x1SM_SS N-mode size should be a multiple of 16 between 16 and 256 when B is K major. \ + SM100_MMA_F8F6F4_2x1SM_SS N-mode size should be a multiple of 32 between 32 and 256 when B is MN major."); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + static_assert(sizeof_bits_v <= sizeof_bits_v && + sizeof_bits_v <= sizeof_bits_v); + // Size of instructions's K extent is always 256bits, convert to units of element + constexpr static int K = 32; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc(); + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F8F6F4_2x1SM_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_F8F6F4_2x1SM_TS supports types with leq 8bit types"); + + using FrgTypeA = UMMA::tmem_frg_2sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + static_assert(sizeof_bits_v <= sizeof_bits_v && + sizeof_bits_v <= sizeof_bits_v); + // Size of instructions' K extent is always 256 bits; convert to units of element + constexpr static int K = 32; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F8F6F4_2x1SM_TS::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits, sparse_args...> +{ + using ValTypeD = c_type; + static_assert(sizeof(a_type) == 1); + using ValTypeA = sparse_elem<2, a_type>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_F8F6F4_2x1SM_SS_SPARSE supports types with leq 8bit types"); + + using FrgTypeA = UMMA::sparse_smem_desc; + using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // SparseMma consume double mma-k bits + constexpr static int K = 512 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + // uint32_t tmem_e: Metadata tmem address. + cute::tuple sparse_args_; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, UMMA::Saturate::False, true>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_same, cute::tuple>::value, + "Params must be set via .with()?"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + + uint32_t tmem_e = get<0>(traits.sparse_args_); + uint16_t id2 = 0u; + + uint64_t idesc = UMMA::make_runtime_instr_desc(traits.idesc_, id2, tmem_e); + + SM100_MMA_F8F6F4_2x1SM_SS_SPARSE::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, tmem_e); + } + + // Construct an executable sparse MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits, uint32_t> + with(Tensor const& E, uint32_t id2 = 0) const { + // Check sparse_ptr, check sparsity, check shape/layout? + uint32_t tmem_e_addr = raw_pointer_cast(E.data()); + return {accumulate_, {tmem_e_addr}, idesc_}; + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + using ValTypeSFA = sf_type; + using ValTypeSFB = sf_type; + static_assert(cute::sizeof_bits_v <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_MXF8F6F4_2x1SM_SS supports types with leq 8bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Logical shape-K is always 256bits, transform to units of elements + constexpr static int K = 32; + constexpr static int SFVecSize = 32; + + constexpr static UMMA::TmemAllocMode TmemAlloc = M == 128 ? + UMMA::TmemAllocMode::ScaleFactorDuplicated2by2 : UMMA::TmemAllocMode::ScaleFactorDuplicated4by1; + using FrgTypeSFA = UMMA::tmem_sf_frg; + using FrgTypeSFB = UMMA::tmem_sf_frg; + + static_assert(sizeof_bits_v <= sizeof_bits_v && + sizeof_bits_v <= sizeof_bits_v); + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using MMA_ScaleFactor = SM100_MMA_MXF8F6F4_SS 64 ? M/2 : M), (round_up(N, 128)), a_major, b_major, + a_neg, b_neg>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + uint32_t tsfa_addr_ = 0; + uint32_t tsfb_addr_ = 0; + + UMMA::InstrDescriptorBlockScaled idesc_ = UMMA::make_instr_desc_block_scaled< + a_type, b_type, c_type, sf_type, M, N, a_major, b_major, a_neg, b_neg>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc_block_scaled<>(traits.idesc_, traits.tsfa_addr_, traits.tsfb_addr_); + + SM100_MMA_MXF8F6F4_2x1SM_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, traits.tsfa_addr_, traits.tsfb_addr_); + } + + // Construct an executable MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(UMMA::ScaleOut accumulate, Tensor const& SFA, Tensor const& SFB) const { + uint32_t tmem_sfa_addr = raw_pointer_cast(SFA.data()); + uint32_t tmem_sfb_addr = raw_pointer_cast(SFB.data()); + return {accumulate, tmem_sfa_addr, tmem_sfb_addr, idesc_}; + } +}; + +template +struct MMA_Traits, sparse_args...> +{ + using ValTypeD = c_type; + static_assert(sizeof(a_type) == 1); + using ValTypeA = sparse_elem<2, a_type>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_MXF8F6F4_2x1SM_SS_SPARSE supports types with leq 8bit types"); + + using FrgTypeA = UMMA::sparse_smem_desc; + using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // SparseMma consume double mma-k bits + constexpr static int K = 64; + constexpr static int SFVecSize = 64; + + constexpr static UMMA::TmemAllocMode TmemAlloc = M == 128 ? + UMMA::TmemAllocMode::ScaleFactorDuplicated2by2 : UMMA::TmemAllocMode::ScaleFactorDuplicated4by1; + using FrgTypeSFA = UMMA::tmem_sf_frg; + using FrgTypeSFB = UMMA::tmem_sf_frg; + + static_assert(sizeof_bits_v <= sizeof_bits_v && + sizeof_bits_v <= sizeof_bits_v); + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using MMA_ScaleFactor = SM100_MMA_MXF8F6F4_SS_SPARSE 64 ? M/2 : M), (round_up(N, 128)), a_major, b_major, + a_neg, b_neg>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + uint32_t tsfa_addr_ = 0; + uint32_t tsfb_addr_ = 0; + + // uint32_t tmem_e: Metadata tmem address. + cute::tuple sparse_args_; + + UMMA::InstrDescriptorBlockScaled idesc_ = UMMA::make_instr_desc_block_scaled< + a_type, b_type, c_type, sf_type, M, N, a_major, b_major, a_neg, b_neg, true>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + + uint32_t tmem_e = get<0>(traits.sparse_args_); + uint16_t id2 = 0u; + + uint64_t idesc = UMMA::make_runtime_instr_desc_block_scaled(traits.idesc_, traits.tsfa_addr_, traits.tsfb_addr_, id2, tmem_e); + + SM100_MMA_MXF8F6F4_2x1SM_SS_SPARSE::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, traits.tsfa_addr_, traits.tsfb_addr_, tmem_e); + } + + // Construct an executable MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits, uint32_t> + with(UMMA::ScaleOut accumulate, Tensor const& E, Tensor const& SFA, Tensor const& SFB) const { + uint32_t tmem_e_addr = raw_pointer_cast(E.data()); + uint32_t tmem_sfa_addr = raw_pointer_cast(SFA.data()); + uint32_t tmem_sfb_addr = raw_pointer_cast(SFB.data()); + return {accumulate, tmem_sfa_addr, tmem_sfb_addr, {tmem_e_addr}, idesc_}; + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + using ValTypeSFA = sf_type; + using ValTypeSFB = sf_type; + static_assert(cute::sizeof_bits_v == 4 && cute::sizeof_bits_v == 4, "SM100_MMA_MXF4_SS supports 4bit types"); + + // Logical shape-K is always 256bits, transform to units of elements + constexpr static int K = 64; + constexpr static int SFVecSize = VS; + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + using FrgTypeSFA = UMMA::tmem_sf_frg; + using FrgTypeSFB = UMMA::tmem_sf_frg; + + static_assert((VS == 32 && ((is_same_v || is_same_v) && + (is_same_v || is_same_v)) + && is_same_v) + || (VS == 16), + "2x mode (VectorSize=32) only supports a_type and b_type=float_e2m1_t or cutlass::type_erased_dynamic_float4_t and sf_type=ue8m0_t"); + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using MMA_ScaleFactor = SM100_MMA_MXF4_SS; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + uint32_t tsfa_addr_ = 0; + uint32_t tsfb_addr_ = 0; + + UMMA::InstrDescriptorBlockScaled idesc_ = UMMA::make_instr_desc_block_scaled< + a_type, b_type, c_type, sf_type, M, N, a_major, b_major, a_neg, b_neg>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc_block_scaled<>(traits.idesc_, traits.tsfa_addr_, traits.tsfb_addr_); + + SM100_MMA_MXF4_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, traits.tsfa_addr_, traits.tsfb_addr_); + } + + // Construct an executable sparse MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(UMMA::ScaleOut accumulate, Tensor const& SFA, Tensor const& SFB) const { + uint32_t tmem_sfa_addr = raw_pointer_cast(SFA.data()); // Move to a CoupledTensor rather than a .with()? + uint32_t tmem_sfb_addr = raw_pointer_cast(SFB.data()); // Move to a CoupledTensor rather than a .with()? + return {accumulate, tmem_sfa_addr, tmem_sfb_addr, idesc_}; + } +}; + +template +struct MMA_Traits, sparse_args...> +{ + using ValTypeD = c_type; + using ValTypeA = sparse_elem<4, uint8_t>; + using ValTypeE = sparse_elem<16, uint8_t>; + using ValTypeB = b_type; + using ValTypeC = c_type; + using ValTypeSFA = sf_type; + using ValTypeSFB = sf_type; + static_assert(cute::sizeof_bits_v == 4 && cute::sizeof_bits_v == 4, "SM100_MMA_MXF4NVF4_SS_SPARSE supports 4bit types"); + + // Logical shape-K is always 256bits, transform to units of elements + constexpr static int K = 128; + constexpr static int SFVecSize = VS; + + using FrgTypeA = UMMA::sparse_smem_desc; + using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + using FrgTypeSFA = UMMA::tmem_sf_frg; + using FrgTypeSFB = UMMA::tmem_sf_frg; + + static_assert((VS == 64 && ((is_same_v || is_same_v) && + (is_same_v || is_same_v)) + && is_same_v) + || (VS == 32), + "2x mode (VectorSize=64) only supports a_type and b_type=float_e2m1_t or cutlass::type_erased_dynamic_float4_t and sf_type=ue8m0_t"); + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using MMA_ScaleFactor = SM100_MMA_MXF4NVF4_SS_SPARSE; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + uint32_t tsfa_addr_ = 0; + uint32_t tsfb_addr_ = 0; + + // uint32_t tmem_e: Metadata tmem address. + cute::tuple sparse_args_; + + UMMA::InstrDescriptorBlockScaled idesc_ = UMMA::make_instr_desc_block_scaled< + a_type, b_type, c_type, sf_type, M, N, a_major, b_major, a_neg, b_neg, true>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + + uint32_t tmem_e = get<0>(traits.sparse_args_); + uint16_t id2 = 0u; + + uint64_t idesc = UMMA::make_runtime_instr_desc_block_scaled(traits.idesc_, traits.tsfa_addr_, traits.tsfb_addr_, id2, tmem_e); + + SM100_MMA_MXF4NVF4_SS_SPARSE::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, traits.tsfa_addr_, traits.tsfb_addr_, tmem_e); + } + + // Construct an executable sparse MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits, uint32_t> + with(UMMA::ScaleOut accumulate, Tensor const& E, Tensor const& SFA, Tensor const& SFB) const { + uint32_t tmem_e_addr = raw_pointer_cast(E.data()); + uint32_t tmem_sfa_addr = raw_pointer_cast(SFA.data()); // Move to a CoupledTensor rather than a .with()? + uint32_t tmem_sfb_addr = raw_pointer_cast(SFB.data()); // Move to a CoupledTensor rather than a .with()? + return {accumulate, tmem_sfa_addr, tmem_sfb_addr, {tmem_e_addr}, idesc_}; + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + using ValTypeSFA = sf_type; + using ValTypeSFB = sf_type; + static_assert(cute::sizeof_bits_v == 4 && cute::sizeof_bits_v == 4, "SM100_MMA_MXF4_2x1SM_SS supports 4bit types"); + + // Logical shape-K is always 256bits, transform to units of elements + constexpr static int K = 64; + constexpr static int SFVecSize = VS; + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + constexpr static UMMA::TmemAllocMode TmemAlloc = M == 128 ? + UMMA::TmemAllocMode::ScaleFactorDuplicated2by2 : UMMA::TmemAllocMode::ScaleFactorDuplicated4by1; + using FrgTypeSFA = UMMA::tmem_sf_frg; + using FrgTypeSFB = UMMA::tmem_sf_frg; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using MMA_ScaleFactor = SM100_MMA_MXF4_SS 64 ? M/2 : M), (round_up(N, 128)), VS, a_major, b_major, + a_neg, b_neg>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + uint32_t tsfa_addr_ = 0; + uint32_t tsfb_addr_ = 0; + + UMMA::InstrDescriptorBlockScaled idesc_ = UMMA::make_instr_desc_block_scaled< + a_type, b_type, c_type, sf_type, M, N, a_major, b_major, a_neg, b_neg>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc_block_scaled<>(traits.idesc_, traits.tsfa_addr_, traits.tsfb_addr_); + + SM100_MMA_MXF4_2x1SM_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, traits.tsfa_addr_, traits.tsfb_addr_); + } + + // Construct an executable sparse MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(UMMA::ScaleOut accumulate, Tensor const& SFA, Tensor const& SFB) const { + // Check sparse_ptr, check sparsity, check shape/layout? + uint32_t tmem_sfa_addr = raw_pointer_cast(SFA.data()); // Move to a CoupledTensor rather than a .with()? + uint32_t tmem_sfb_addr = raw_pointer_cast(SFB.data()); // Move to a CoupledTensor rather than a .with()? + return {accumulate, tmem_sfa_addr, tmem_sfb_addr, idesc_}; + } +}; + +template +struct MMA_Traits, sparse_args...> +{ + using ValTypeD = c_type; + using ValTypeA = sparse_elem<4, uint8_t>; + using ValTypeE = sparse_elem<16, uint8_t>; + using ValTypeB = b_type; + using ValTypeC = c_type; + using ValTypeSFA = sf_type; + using ValTypeSFB = sf_type; + static_assert(cute::sizeof_bits_v == 4 && cute::sizeof_bits_v == 4, "SM100_MMA_MXF4NVF4_2x1SM_SS_SPARSE supports 4bit types"); + + // Logical shape-K is always 256bits, transform to units of elements + constexpr static int K = 128; + constexpr static int SFVecSize = VS; + + constexpr static UMMA::TmemAllocMode TmemAlloc = M == 128 ? + UMMA::TmemAllocMode::ScaleFactorDuplicated2by2 : UMMA::TmemAllocMode::ScaleFactorDuplicated4by1; + using FrgTypeA = UMMA::sparse_smem_desc; + // using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + using FrgTypeSFA = UMMA::tmem_sf_frg; + using FrgTypeSFB = UMMA::tmem_sf_frg; + + static_assert((VS == 64 && ((is_same_v || is_same_v) && + (is_same_v || is_same_v)) + && is_same_v) + || (VS == 32), + "2x mode (VectorSize=64) only supports a_type and b_type=float_e2m1_t or cutlass::type_erased_dynamic_float4_t and sf_type=ue8m0_t"); + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using MMA_ScaleFactor = SM100_MMA_MXF4NVF4_SS_SPARSE 64 ? M/2 : M), (round_up(N, 128)), VS, a_major, b_major, + a_neg, b_neg>; + + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + uint32_t tsfa_addr_ = 0; + uint32_t tsfb_addr_ = 0; + + // uint32_t tmem_e: Metadata tmem address. + cute::tuple sparse_args_; + + UMMA::InstrDescriptorBlockScaled idesc_ = UMMA::make_instr_desc_block_scaled< + a_type, b_type, c_type, sf_type, M, N, a_major, b_major, a_neg, b_neg, true>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + + uint32_t tmem_e = get<0>(traits.sparse_args_); + uint16_t id2 = 0u; + + uint64_t idesc = UMMA::make_runtime_instr_desc_block_scaled(traits.idesc_, traits.tsfa_addr_, traits.tsfb_addr_, id2, tmem_e); + + SM100_MMA_MXF4NVF4_2x1SM_SS_SPARSE::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, traits.tsfa_addr_, traits.tsfb_addr_, tmem_e); + } + + // Construct an executable sparse MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits, uint32_t> + with(UMMA::ScaleOut accumulate, Tensor const& E, Tensor const& SFA, Tensor const& SFB) const { + uint32_t tmem_e_addr = raw_pointer_cast(E.data()); + uint32_t tmem_sfa_addr = raw_pointer_cast(SFA.data()); // Move to a CoupledTensor rather than a .with()? + uint32_t tmem_sfb_addr = raw_pointer_cast(SFB.data()); // Move to a CoupledTensor rather than a .with()? + return {accumulate, tmem_sfa_addr, tmem_sfb_addr, {tmem_e_addr}, idesc_}; + } +}; + +/** + * Specialization for a vectorized FMA per thread. + */ +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = float; + using ValTypeB = float; + using ValTypeC = float; + + using Shape_MNK = Shape<_2,_1,_1>; + using ThrID = Layout<_1>; + + using ALayout = Layout>; + using BLayout = Layout>; + using CLayout = Layout>; +}; + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = float; + using ValTypeB = float; + using ValTypeC = float; + + using Shape_MNK = Shape<_1,_2,_1>; + using ThrID = Layout<_1>; + + using ALayout = Layout>; + using BLayout = Layout>; + using CLayout = Layout>; +}; + +namespace SM103 { + // Common mma_unpack for all MMA_Ops in cute::SM103 +template +CUTE_HOST_DEVICE constexpr +void +mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& zA, + Tensor const& zB, + Tensor const& C) + { + auto [A, next_A, SFA] = unzip_tensor(zA); + auto [B, next_B, SFB] = unzip_tensor(zB); + + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_next_a = next_A[0]; + uint64_t desc_b = B[0]; + uint64_t desc_next_b = next_B[0]; + + auto desc_a_temp = reinterpret_cast(desc_a); + auto desc_next_a_temp = reinterpret_cast(desc_next_a); + desc_a_temp.lbo_mode_ = 1; + desc_a_temp.leading_byte_offset_ = desc_next_a_temp.start_address_; + + auto desc_b_temp = reinterpret_cast(desc_b); + auto desc_next_b_temp = reinterpret_cast(desc_next_b); + desc_b_temp.lbo_mode_ = 1; + desc_b_temp.leading_byte_offset_ = desc_next_b_temp.start_address_; + + uint32_t tmem_c = raw_pointer_cast(D.data()); + UMMA::InstrDescriptorBlockScaled instr_desc = traits.idesc_; + instr_desc.k_size_ = 1; + auto tsfa_addr = raw_pointer_cast(SFA.data()); + auto tsfb_addr = raw_pointer_cast(SFB.data()); + + uint64_t idesc = UMMA::make_runtime_instr_desc_block_scaled<>(instr_desc, tsfa_addr, tsfb_addr); + // print("a: "); print(A); print("\n"); + // print("b: "); print(B); print("\n"); + + MMA_Op::fma(reinterpret_cast(desc_a_temp), reinterpret_cast(desc_b_temp), tmem_c, uint32_t(traits.accumulate_), idesc, tsfa_addr, tsfb_addr); + } +} // end namespace SM103 + + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + using ValTypeSFA = sf_type; + using ValTypeSFB = sf_type; + + // Logical shape-K is always 256bits, transform to units of elements + constexpr static int K = 96; + constexpr static int SFVecSize = VS; + + static_assert(a_major == UMMA::Major::K && b_major == UMMA::Major::K, "This MMA does not support transpose"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + using FrgTypeSFA = UMMA::tmem_sf_frg; + using FrgTypeSFB = UMMA::tmem_sf_frg; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + using MMA_ScaleFactor = SM100_MMA_MXF4_SS; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptorBlockScaled idesc_ = UMMA::make_instr_desc_block_scaled< + a_type, b_type, c_type, sf_type, M, N, a_major, b_major, a_neg, b_neg>(); +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + using ValTypeSFA = sf_type; + using ValTypeSFB = sf_type; + + // Logical shape-K is always 256bits, transform to units of elements + constexpr static int K = 96; + constexpr static int SFVecSize = VS; + + static_assert(a_major == UMMA::Major::K && b_major == UMMA::Major::K, "This MMA does not support transpose"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + constexpr static UMMA::TmemAllocMode TmemAlloc = M == 128 ? + UMMA::TmemAllocMode::ScaleFactorDuplicated2by2 : UMMA::TmemAllocMode::ScaleFactorDuplicated4by1; + using FrgTypeSFA = UMMA::tmem_sf_frg; + using FrgTypeSFB = UMMA::tmem_sf_frg; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + using MMA_ScaleFactor = SM100_MMA_MXF4_SS 64 ? M/2 : M), (round_up(N, 128)), VS, a_major, b_major, + a_neg, b_neg>; + + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptorBlockScaled idesc_ = UMMA::make_instr_desc_block_scaled< + a_type, b_type, c_type, sf_type, M, N, a_major, b_major, a_neg, b_neg>(); +}; + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm120.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm120.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e3399801cf667deb9589f2df6ce1c2aca32bb368 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm120.hpp @@ -0,0 +1,262 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include + +namespace cute +{ + +namespace SM120::BLOCKSCALED { + +template +CUTE_HOST_DEVICE constexpr void +mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A_zipped, + Tensor const& B_zipped, + Tensor const& C) +{ + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + + // Register value types from the MMA_Operation register arrays + using RegTypeD = typename remove_extent::type; + using RegTypeA = typename remove_extent::type; + using RegTypeB = typename remove_extent::type; + using RegTypeC = typename remove_extent::type; + using RegTypeSFA = typename remove_extent::type; + using RegTypeSFB = typename remove_extent::type; + + constexpr int RegNumD = extent::value; + constexpr int RegNumA = extent::value; + constexpr int RegNumB = extent::value; + constexpr int RegNumC = extent::value; + constexpr int RegNumSFA = extent::value; + constexpr int RegNumSFB = extent::value; + + auto [A, SFA] = unzip_tensor(A_zipped); + auto [B, SFB] = unzip_tensor(B_zipped); + + using Shape_MNK = typename MMA_Traits::Shape_MNK; + constexpr int SFVecSize = MMA_Traits::SFVecSize; + + // Assert logical size + CUTE_STATIC_ASSERT_V(size(SFA) == size<2>(Shape_MNK{})); + CUTE_STATIC_ASSERT_V(size(SFB) == size<2>(Shape_MNK{})); + + // Assert physical size + CUTE_STATIC_ASSERT(decltype(cosize(layout(SFA))){} == size<2>(Shape_MNK{}) / SFVecSize); + CUTE_STATIC_ASSERT(decltype(cosize(layout(SFB))){} == size<2>(Shape_MNK{}) / SFVecSize); + + Tensor rA = recast(A); + Tensor rB = recast(B); + CUTE_STATIC_ASSERT_V(size(rA) == Int{}); + CUTE_STATIC_ASSERT_V(size(rB) == Int{}); + + Tensor rD = recast(D); + Tensor rC = recast(C); + CUTE_STATIC_ASSERT_V(size(rD) == Int{}); + CUTE_STATIC_ASSERT_V(size(rC) == Int{}); + + Tensor rSFA = recast(filter_zeros(SFA)); + Tensor rSFB = recast(filter_zeros(SFB)); + + CUTE_STATIC_ASSERT_V(size(rSFA) == Int{}); + CUTE_STATIC_ASSERT_V(size(rSFB) == Int{}); + + detail::explode(MMAOp::fma, + rD, make_int_sequence{}, + rA, make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}, + rSFA, make_int_sequence{}, + rSFB, make_int_sequence{}); +} +} // namespace SM120::BLOCKSCALED + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA F8F6F4 16x8x32 TN +template +struct MMA_Traits> + : MMA_Traits +{ + // The MMA accepts 8-bit inputs regardless of the types for A and B + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + + using ValTypeD = c_type; + using ValTypeC = c_type; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA MXF8F6F4 16x8x64 TN +template +struct MMA_Traits> +{ + // The MMA accepts 4-bit inputs regardless of the types for A and B + using ValTypeA = uint4_t; + using ValTypeB = uint4_t; + + using ValTypeD = c_type; + using ValTypeC = c_type; + + using ValTypeSF = sf_type; + constexpr static int SFVecSize = VS; + + using Shape_MNK = Shape<_16,_8,_64>; + using ThrID = Layout<_32>; + + // (T32,V32) -> (M16,K64) + using ALayout = Layout,Shape < _8,_2, _2>>, + Stride,Stride<_16,_8,_512>>>; + // (T32,V16) -> (M16,K64) + using BLayout = Layout,Shape <_8, _2>>, + Stride,Stride<_8,_256>>>; + // (T32,V64) -> (M16,K64) + using SFALayout = Layout,_64>, // Effectively 16 threads due to the 2:0 mode + Stride,_16>>; + // (T32,V64) -> (N8,K64) + using SFBLayout = Layout,_64>, // Effectively 8 threads due to the 4:0 mode + Stride, _8>>; + // (T32,V4) -> (M16,N8) + using CLayout = SM80_16x8_Row; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA MXF8F6F4 16x8x32 TN +template +struct MMA_Traits> +{ + using UnderlyingTraits = MMA_Traits>; + + // The MMA accepts 8-bit inputs regardless of the types for A and B + using ValTypeA = typename UnderlyingTraits::ValTypeA; + using ValTypeB = typename UnderlyingTraits::ValTypeB; + + using ValTypeD = typename UnderlyingTraits::ValTypeD; + using ValTypeC = typename UnderlyingTraits::ValTypeC; + + using Shape_MNK = typename UnderlyingTraits::Shape_MNK; + using ThrID = typename UnderlyingTraits::ThrID; + + using ALayout = typename UnderlyingTraits::ALayout; + using BLayout = typename UnderlyingTraits::BLayout; + using CLayout = typename UnderlyingTraits::CLayout; + + // Scaling factor + using ValTypeSF = sf_type; + constexpr static int SFVecSize = VS; + + // (T32,V32) -> (M16,K32) + using SFALayout = Layout,_32>, // Effectively 16 threads due to the 2:0 mode + Stride,_16>>; + // (T32,V32) -> (N8,K32) + using SFBLayout = Layout,_32>, // Effectively 8 threads due to the 4:0 mode + Stride, _8>>; +}; + +// Transform if needed +template +CUTLASS_DEVICE void +fp4_shift_A(MMA_Op const& op, Tensor&& tensor) { +} +template +CUTLASS_DEVICE void +fp4_shift_B(MMA_Op const& op, Tensor&& tensor) { +} + +// For SM120 MMA F8F6F4 input fp4, the operand A/B are load from ld.matrix. +// ld.matrix b4x16_p64 places FP4 data at the first four bits in each +// eight-bit container, whereas MMA F8F6F4 expects the four-bit data to be in +// the middle of the eight-bit container. Thus, e2m1 operands being fed +// to MMA F8F6F4 must be shifted left by two bits. +// 0b0000ABCD --> 0b00ABCD00 +// NOTE: Same transformation is NOT needed for FP6 and FP8. +template +CUTLASS_DEVICE void +fp4_shift_A(SM120_16x8x32_TN const&, Tensor&& tensor) { + using RegisterTypeA = typename remove_extent::ARegisters>::type; + if constexpr (cute::is_same_v) { + cute::transform(recast(tensor), [](RegisterTypeA& v){ return v << 2; }); + } +} +template +CUTLASS_DEVICE void +fp4_shift_B(SM120_16x8x32_TN const&, Tensor&& tensor) { + using RegisterTypeB = typename remove_extent::BRegisters>::type; + if constexpr (cute::is_same_v) { + cute::transform(recast(tensor), [](RegisterTypeB& v){ return v << 2; }); + } +} + +namespace SM120::BLOCKSCALED { + +// Template function with scale factor needs to enmuerate types one by one, as template +// arguments contatins two variadic lists, which cannot be deduced in one shot. +template +CUTLASS_DEVICE void +fp4_shift_A(SM120::BLOCKSCALED::SM120_16x8x32_TN_VS const&, Tensor&& tensor) { + using RegisterTypeA = typename remove_extent::ARegisters>::type; + if constexpr (cute::is_same_v) { + cute::transform(recast(tensor), [](RegisterTypeA& v){ return v << 2; }); + } +} +template +CUTLASS_DEVICE void +fp4_shift_B(SM120::BLOCKSCALED::SM120_16x8x32_TN_VS const&, Tensor&& tensor) { + using RegisterTypeB = typename remove_extent::BRegisters>::type; + if constexpr (cute::is_same_v) { + cute::transform(recast(tensor), [](RegisterTypeB& v){ return v << 2; }); + } +} + +} + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm120_sparse.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm120_sparse.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b62c576dfb6e0f67ed76c8a95818a7000ff43e52 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm120_sparse.hpp @@ -0,0 +1,326 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include + +namespace cute +{ + +namespace { + +// (T32,V4) -> (M16,N8) +using SM120_16x8_Row = Layout,Shape < _2,_2>>, + Stride,Stride<_16,_8>>>; + +} + +namespace SM120::BLOCKSCALED::SPARSE +{ + +// Unpack explode/mma call with sparse and block scalaring inputs. +template +CUTE_HOST_DEVICE constexpr void +mma_unpack(MMA_Traits const&, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) +{ + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + using DRegisters = typename MMAOp::DRegisters; + using ARegisters = typename MMAOp::ARegisters; + using ERegisters = typename MMAOp::ERegisters; + using BRegisters = typename MMAOp::BRegisters; + using CRegisters = typename MMAOp::CRegisters; + using SFARegisters = typename MMAOp::SFARegisters; + using SFBRegisters = typename MMAOp::SFBRegisters; + // Register value types from the MMAOp register arrays + using RegTypeD = typename remove_extent::type; + using RegTypeA = typename remove_extent::type; + using RegTypeE = typename remove_extent::type; + using RegTypeB = typename remove_extent::type; + using RegTypeC = typename remove_extent::type; + using RegTypeSFA = typename remove_extent::type; + using RegTypeSFB = typename remove_extent::type; + constexpr int RegNumD = extent::value; + constexpr int RegNumA = extent::value; + constexpr int RegNumE = extent::value; + constexpr int RegNumB = extent::value; + constexpr int RegNumC = extent::value; + constexpr int RegNumSFA = extent::value; + constexpr int RegNumSFB = extent::value; + + auto [tA, tSFA, tE] = unzip_tensor(A); + auto [tB, tSFB ] = unzip_tensor(B); + Tensor rA = recast(tA); + Tensor rE = recast(tE); + Tensor rB = recast(tB); + Tensor rD = recast(D); + Tensor rC = recast(C); + Tensor rSFA = recast(tSFA); + Tensor rSFB = recast(tSFB); + + CUTE_STATIC_ASSERT_V(size(rA) == Int{}); + CUTE_STATIC_ASSERT_V(size(rE) == Int{}); + CUTE_STATIC_ASSERT_V(size(rB) == Int{}); + CUTE_STATIC_ASSERT_V(size(rD) == Int{}); + CUTE_STATIC_ASSERT_V(size(rC) == Int{}); + CUTE_STATIC_ASSERT_V(size(filter_zeros(rSFA)) == Int{}); + CUTE_STATIC_ASSERT_V(size(filter_zeros(rSFB)) == Int{}); + + detail::explode(MMAOp::fma, + rD, make_int_sequence{}, + rA, make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}, + rE, make_int_sequence{}, + rSFA, make_int_sequence{}, + rSFB, make_int_sequence{}); +} + +} // end namespace SM120::BLOCKSCALED::SPARSE + + +namespace SM120::SPARSE +{ + +template +CUTE_HOST_DEVICE constexpr void +mma_unpack(MMA_Traits const&, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) +{ + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + using DRegisters = typename MMAOp::DRegisters; + using ARegisters = typename MMAOp::ARegisters; + using ERegisters = typename MMAOp::ERegisters; + using BRegisters = typename MMAOp::BRegisters; + using CRegisters = typename MMAOp::CRegisters; + // Register value types from the MMAOp register arrays + using RegTypeD = typename remove_extent::type; + using RegTypeA = typename remove_extent::type; + using RegTypeE = typename remove_extent::type; + using RegTypeB = typename remove_extent::type; + using RegTypeC = typename remove_extent::type; + constexpr int RegNumD = extent::value; + constexpr int RegNumA = extent::value; + constexpr int RegNumE = extent::value; + constexpr int RegNumB = extent::value; + constexpr int RegNumC = extent::value; + + auto [tA, tE] = unzip_tensor(A); + Tensor rA = recast(tA); + Tensor rE = recast(tE); + Tensor rB = recast(B); + Tensor rD = recast(D); + Tensor rC = recast(C); + CUTE_STATIC_ASSERT_V(size(rA) == Int{}); + CUTE_STATIC_ASSERT_V(size(rE) == Int{}); + CUTE_STATIC_ASSERT_V(size(rB) == Int{}); + CUTE_STATIC_ASSERT_V(size(rD) == Int{}); + CUTE_STATIC_ASSERT_V(size(rC) == Int{}); + + detail::explode(MMAOp::fma, + rD, make_int_sequence{}, + rA, make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}, + rE, make_int_sequence{}); +} + +} // end namespace SM120::SPARSE + +// sparse F8F6F4 without block-scaling +template +struct MMA_Traits> +{ + using ValTypeA = sparse_elem<2, a_type>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using FrgTypeA = sparse_elem<2, uint8_t>; + using FrgTypeE = sparse_elem<8, uint8_t>; + + using ValTypeC = c_type; + using ValTypeD = c_type; + + using Shape_MNK = Shape<_16, _8, _64>; + using ThrID = Layout<_32>; + // (T32,V32) -> (M16,K64) + using ALayout = Layout,Shape < _8,_2, _2>>, + Stride,Stride<_16,_8,_512>>>; + // (T32,V16) -> (N8,K64) + using BLayout = Layout,Shape <_4, _4>>, + Stride,Stride<_8,_128>>>; + // (T32,V4) -> (M16,N8) + using CLayout = SM120_16x8_Row; + + // (T32, V32) -> (M16, K64) + using ELayout = Layout, _32>, + Stride,_16>>; +}; + +// sparse MXF8F6F4 with block-scaling. +template +struct MMA_Traits> + : MMA_Traits> +{ + using ValTypeA = sparse_elem<2, a_type>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using FrgTypeA = sparse_elem<2, uint8_t>; + using FrgTypeE = sparse_elem<8, uint8_t>; + + using ValTypeD = c_type; + using ValTypeC = c_type; + + using ValTypeSF = sf_type; + constexpr static int SFVecSize = VS; + + using UnderlyingSFTraits = MMA_Traits>; + using SFALayout = typename UnderlyingSFTraits::SFALayout; + using SFBLayout = typename UnderlyingSFTraits::SFBLayout; +}; + +template +struct MMA_Traits> +{ + using ValTypeA = sparse_elem<4, uint8_t>; + using ValTypeE = sparse_elem<16, uint8_t>; + using ValTypeB = uint4_t; + using FrgTypeA = sparse_elem<4, uint8_t>; + using FrgTypeE = sparse_elem<16, uint8_t>; + + using ValTypeC = c_type; + using ValTypeD = c_type; + + using ValTypeSF = sf_type; + + constexpr static int SFVecSize = VS; + + using Shape_MNK = Shape<_16, _8, _128>; + using ThrID = Layout<_32>; + // (T32,V64) -> (M16,K128) + using ALayout = Layout,Shape <_16,_2, _2>>, + Stride,Stride<_16,_8,_1024>>>; + // (T32,V32) -> (N8,K128) + using BLayout = Layout,Shape <_8, _4>>, + Stride,Stride<_8,_256>>>; + // (T32,V128) -> (M16,K128) + using SFALayout = Layout,_128>, + Stride, _16>>; + // (T32,V128) -> (N8,K128) + using SFBLayout = Layout,_128>, + Stride, _8>>; + // (T32,V4) -> (M16,N8) + using CLayout = SM120_16x8_Row; + // (T32, V64) -> (M16, K128) + using ELayout = Layout, Shape< _64>>, + Stride,Stride<_16>>>; +}; + +namespace SM120::SPARSE { + +// For SM120 MMA F8F6F4 input fp4, the operand A/B are load from ld.matrix. +// ld.matrix b4x16_p64 places FP4 data at the first four bits in each +// eight-bit container, whereas MMA F8F6F4 expects the four-bit data to be in +// the middle of the eight-bit container. Thus, e2m1 operands being fed +// to MMA F8F6F4 must be shifted left by two bits. +// 0b0000ABCD --> 0b00ABCD00 +// NOTE: Same transformation is NOT needed for FP6 and FP8. +template +CUTLASS_DEVICE void +fp4_shift_A(SM120_SPARSE_16x8x64_TN const&, Tensor&& tensor) { + using RegisterTypeA = typename remove_extent::ARegisters>::type; + if constexpr (cute::is_same_v) { + cute::transform(recast(tensor), [](RegisterTypeA& v){ return v << 2; }); + } +} +template +CUTLASS_DEVICE void +fp4_shift_B(SM120_SPARSE_16x8x64_TN const&, Tensor&& tensor) { + using RegisterTypeB = typename remove_extent::BRegisters>::type; + if constexpr (cute::is_same_v) { + cute::transform(recast(tensor), [](RegisterTypeB& v){ return v << 2; }); + } +} + +} // end namespace SM120::SPARSE + +namespace SM120::BLOCKSCALED::SPARSE { + +// Template function with scale factor needs to enmuerate types one by one, as template +// arguments contatins two variadic lists, which cannot be deduced in one shot. +template +CUTLASS_DEVICE void +fp4_shift_A(SM120_SPARSE_16x8x64_TN_VS const&, Tensor&& tensor) { + using RegisterTypeA = typename remove_extent::ARegisters>::type; + if constexpr (cute::is_same_v) { + cute::transform(recast(tensor), [](RegisterTypeA& v){ return v << 2; }); + } +} +template +CUTLASS_DEVICE void +fp4_shift_B(SM120_SPARSE_16x8x64_TN_VS const&, Tensor&& tensor) { + using RegisterTypeB = typename remove_extent::BRegisters>::type; + if constexpr (cute::is_same_v) { + cute::transform(recast(tensor), [](RegisterTypeB& v){ return v << 2; }); + } +} + +} // end namespace SM120::BLOCKSCALED::SPARSE + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm61.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm61.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6b12903b987a6113e43a71ec197915a7d814649f --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm61.hpp @@ -0,0 +1,73 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute +{ + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_1,_1,_4>; + using ThrID = Layout<_1>; + using ALayout = Layout>; + using BLayout = Layout>; + using CLayout = Layout>; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int16_t; + using ValTypeB = int16_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_1,_1,_2>; + using ThrID = Layout<_1>; + using ALayout = Layout>; + using BLayout = Layout>; + using CLayout = Layout>; +}; + +} // namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm70.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm70.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0b5b5300b0828067fb9d66c4b6e4640bb54d2bbd --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm70.hpp @@ -0,0 +1,198 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute +{ + +namespace { + +// Logical thread id to thread idx (quadpair) +using SM70_QuadPair = Layout, + Stride<_1,_16>>; +// (T8,V4) -> (M8,K4) +using SM70_8x4_Row = Layout, + Stride<_1,_8>>; +// (T8,V4) -> (M8,K4) +using SM70_8x4_Col = Layout,_4>, + Stride,_1>>; +// (T8,V8) -> (M8,N8) +using SM70_8x8_16b = Layout, + Stride<_1,_8>>; +// (T8,V8) -> (M8,N8) +using SM70_8x8_32b = Layout,Shape <_2,_2, _2>>, + Stride,Stride<_8,_2,_32>>>; + +} + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Row; + using BLayout = SM70_8x4_Row; + using CLayout = SM70_8x8_16b; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Col; + using BLayout = SM70_8x4_Col; + using CLayout = SM70_8x8_16b; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Col; + using BLayout = SM70_8x4_Row; + using CLayout = SM70_8x8_16b; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Row; + using BLayout = SM70_8x4_Col; + using CLayout = SM70_8x8_16b; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Row; + using BLayout = SM70_8x4_Row; + using CLayout = SM70_8x8_32b; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Col; + using BLayout = SM70_8x4_Col; + using CLayout = SM70_8x8_32b; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Col; + using BLayout = SM70_8x4_Row; + using CLayout = SM70_8x8_32b; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Row; + using BLayout = SM70_8x4_Col; + using CLayout = SM70_8x8_32b; +}; + +/////////////////////////////////////////////////////////////////////////////// +} // namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm75.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm75.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d60c65fe59fb73a248874dfc23232c49d23ad4ec --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm75.hpp @@ -0,0 +1,81 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute +{ + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_16,_8,_8>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape < _2,_2>>, + Stride,Stride<_16,_8>>>; + using BLayout = Layout,_2>, + Stride,_8>>; + using CLayout = Layout,Shape < _2,_2>>, + Stride,Stride<_16,_8>>>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_8,_8,_16>; + using ThrID = Layout<_32>; + using ALayout = Layout,_4>, + Stride,_8>>; + using BLayout = Layout,_4>, + Stride,_8>>; + using CLayout = Layout,_2>, + Stride,_8>>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm80.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm80.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f7d5d2fc3005d1e12fdd066cf0d80709de0e5234 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm80.hpp @@ -0,0 +1,690 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include +#include +#include + +namespace cute +{ + +namespace { + +// (T32,V1) -> (M8,N8) +using SM80_8x4 = Layout,_1>, + Stride,_0>>; +// (T32,V2) -> (M8,N8) +using SM80_8x8_Row = Layout,_2>, + Stride,_8>>; +// (T32,V4) -> (M8,N16) +using SM80_8x16_Row = Layout,_4>, + Stride,_8>>; +// (T32,V4) -> (M16,N8) +using SM80_16x8_Row = Layout,Shape < _2,_2>>, + Stride,Stride<_16,_8>>>; + +} + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////// fp16 = fp16 * fp16 + fp16 //////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using Shape_MNK = Shape<_16,_8,_8>; + using ThrID = Layout<_32>; + using ALayout = SM80_16x8_Row; + using BLayout = SM80_8x8_Row; + using CLayout = SM80_16x8_Row; +}; + +template <> +struct MMA_Traits +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using Shape_MNK = Shape<_16,_8,_16>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape < _2,_2, _2>>, + Stride,Stride<_16,_8,_128>>>; + using BLayout = Layout,Shape <_2, _2>>, + Stride,Stride<_8,_64>>>; + using CLayout = SM80_16x8_Row; +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////// fp32 = fp16 * fp16 + fp32 //////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; +}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////// fp32 = bf16 * bf16 + fp32 //////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; +}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////// fp32 = tf32 * tf32 + fp32 //////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = cutlass::tfloat32_t; + using ValTypeB = cutlass::tfloat32_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_16,_8,_4>; + using ThrID = Layout<_32>; + using ALayout = Layout,_2>, + Stride,_8>>; + using BLayout = SM80_8x4; + using CLayout = SM80_16x8_Row; +}; + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = cutlass::tfloat32_t; + using ValTypeB = cutlass::tfloat32_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_16,_8,_8>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape <_2, _2>>, + Stride,Stride<_8,_64>>>; + using BLayout = Layout, _2>, + Stride,_32>>; + using CLayout = SM80_16x8_Row; +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////// fp64 = fp64 * fp64 + fp64 //////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = double; + using ValTypeA = double; + using ValTypeB = double; + using ValTypeC = double; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = Layout<_32>; + using ALayout = SM80_8x4; + using BLayout = SM80_8x4; + using CLayout = SM80_8x8_Row; +}; + +// Custom complex fp64 MMA composed of 4 fp64 MMAs -- same layouts +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = complex; + using ValTypeA = complex; + using ValTypeB = complex; + using ValTypeC = complex; +}; + +// Custom complex fp64 MMA composed of 3 fp64 MMAs -- same layouts +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = typename SM80_8x8x4_GC64C64C64GC64_TN::GaussComplex; + using ValTypeA = complex; + using ValTypeB = complex; + using ValTypeC = typename SM80_8x8x4_GC64C64C64GC64_TN::GaussComplex; +}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = s8 * s8 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_8,_8,_16>; + using ThrID = Layout<_32>; + using ALayout = SM80_8x16_Row; + using BLayout = SM80_8x16_Row; + using CLayout = SM80_8x8_Row; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_16,_8,_16>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape < _4,_2>>, + Stride,Stride<_16,_8>>>; + using BLayout = SM80_8x16_Row; + using CLayout = SM80_16x8_Row; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_16,_8,_32>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape < _4,_2, _2>>, + Stride,Stride<_16,_8,_256>>>; + using BLayout = Layout, Shape <_4, _2>>, + Stride, Stride<_8,_128>>>; + using CLayout = SM80_16x8_Row; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = s8 * u8 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = u8 * s8 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = u8 * u8 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = s4 * s4 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = int4b_t; + using ValTypeB = int4b_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_8, _8, _32>; + using ThrID = Layout<_32>; + // (T32,V8) -> (M8,N32) + using ALayout = Layout, Shape <_8>>, + Stride, Stride<_8>>>; + using BLayout = Layout, Shape <_8>>, + Stride, Stride<_8>>>; + using CLayout = SM80_8x8_Row; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = int4b_t; + using ValTypeB = int4b_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_16, _8, _32>; + using ThrID = Layout<_32>; + // (T32,V16) -> (M16,N32) + using ALayout = Layout, Shape < _8, _2>>, + Stride, Stride<_16, _8>>>; + // (T32,V8) -> (M8,N32) + using BLayout = Layout, Shape <_8>>, + Stride, Stride<_8>>>; + using CLayout = SM80_16x8_Row; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = int4b_t; + using ValTypeB = int4b_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_16, _8, _64>; + using ThrID = Layout<_32>; + // (T32,V32) -> (M16,N64) + using ALayout = Layout, Shape < _8, _2, _2>>, + Stride, Stride<_16, _8, _512>>>; + // (T32,V16) -> (M8,N64) + using BLayout = Layout, Shape <_8, _2>>, + Stride, Stride<_8, _256>>>; + using CLayout = SM80_16x8_Row; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = s4 * u4 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = int4b_t; + using ValTypeB = uint4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = int4b_t; + using ValTypeB = uint4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = int4b_t; + using ValTypeB = uint4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = u4 * s4 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = uint4b_t; + using ValTypeB = int4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = uint4b_t; + using ValTypeB = int4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = uint4b_t; + using ValTypeB = int4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = u4 * u4 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = uint4b_t; + using ValTypeB = uint4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = uint4b_t; + using ValTypeB = uint4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = uint4b_t; + using ValTypeB = uint4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = b1 ^ b1 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = cute::uint1b_t; + using ValTypeB = cute::uint1b_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_16,_8,_256>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape<_32,_2,_2>>, + Stride,Stride<_16,_8,_2048>>>; + using BLayout = Layout,Shape<_32,_2>>, + Stride,Stride< _8,_1024>>>; + using CLayout = SM80_16x8_Row; +}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = b1 & b1 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template<> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = cute::uint1b_t; + using ValTypeB = cute::uint1b_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_8,_8,_128>; + using ThrID = Layout<_32>; + using ALayout = Layout,_32>, + Stride,_8>>; + using BLayout = Layout,_32>, + Stride,_8>>; + using CLayout = SM80_8x8_Row; +}; + +template <> +struct MMA_Traits + :MMA_Traits {}; + +template<> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = cute::uint1b_t; + using ValTypeB = cute::uint1b_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_16,_8,_128>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape<_32,_2>>, + Stride,Stride>>>; + using BLayout = Layout,_32>, + Stride,_8>>; + using CLayout = SM80_16x8_Row; +}; + +template <> +struct MMA_Traits + :MMA_Traits {}; + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm89.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm89.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f56aab4cb1d642ddadf9f80aac487409aaa23b8b --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm89.hpp @@ -0,0 +1,132 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// + +// +#pragma once + +#include +#include +#include +#include + +namespace cute +{ + +namespace { + +// (T32,V4) -> (M16,N8) +using SM80_16x8_Row = Layout,Shape < _2,_2>>, + Stride,Stride<_16,_8>>>; + +} + +template <> +struct MMA_Traits { + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_16,_8,_32>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape < _4,_2, _2>>, + Stride,Stride<_16,_8,_256>>>; + using BLayout = Layout,Shape <_4, _2>>, + Stride,Stride<_8,_128>>>; + using CLayout = SM80_16x8_Row; +}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; +}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; +}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; +}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = cutlass::half_t; + using ValTypeA = cutlass::float_e4m3_t; + using ValTypeB = cutlass::float_e4m3_t; + using ValTypeC = cutlass::half_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = cutlass::half_t; + using ValTypeA = cutlass::float_e4m3_t; + using ValTypeB = cutlass::float_e5m2_t; + using ValTypeC = cutlass::half_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = cutlass::half_t; + using ValTypeA = cutlass::float_e5m2_t; + using ValTypeB = cutlass::float_e5m2_t; + using ValTypeC = cutlass::half_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = cutlass::half_t; + using ValTypeA = cutlass::float_e5m2_t; + using ValTypeB = cutlass::float_e4m3_t; + using ValTypeC = cutlass::half_t; +}; + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm90.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm90.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0467dec3e8cb688f68d86468e005ed2c4bd1216a --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm90.hpp @@ -0,0 +1,144 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include + +#include + +namespace cute { + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////// fp64 = fp64 * fp64 + fp64 //////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +using SM90_16x8x4_F64F64F64F64_TN = SM90::MMA_16x8x4_F64F64F64F64_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = double; + using ValTypeA = double; + using ValTypeB = double; + using ValTypeC = double; + + using Shape_MNK = Shape<_16,_8,_4>; + using ThrID = Layout<_32>; + using ALayout = Layout,_2>, + Stride,_8>>; + using BLayout = Layout,_1>, + Stride,_0>>; + using CLayout = Layout,Shape < _2,_2>>, + Stride,Stride<_16,_8>>>; +}; + +using SM90_16x8x8_F64F64F64F64_TN = SM90::MMA_16x8x8_F64F64F64F64_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = double; + using ValTypeA = double; + using ValTypeB = double; + using ValTypeC = double; + + using Shape_MNK = Shape<_16,_8,_8>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape <_2, _2>>, + Stride,Stride<_8,_64>>>; + using BLayout = Layout, _2>, + Stride,_32>>; + using CLayout = Layout,Shape < _2,_2>>, + Stride,Stride<_16,_8>>>; +}; + +using SM90_16x8x16_F64F64F64F64_TN = SM90::MMA_16x8x16_F64F64F64F64_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = double; + using ValTypeA = double; + using ValTypeB = double; + using ValTypeC = double; + + using Shape_MNK = Shape<_16,_8,_16>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape <_2, _4>>, + Stride,Stride<_8,_64>>>; + using BLayout = Layout, _4>, + Stride,_32>>; + using CLayout = Layout,Shape < _2,_2>>, + Stride,Stride<_16,_8>>>; +}; + +/////////////////////////////////////////////////////////////////////////////////// +//////////////////////// cfp64 = cfp64 * cfp64 + cfp64 //////////////////////////// +/////////////////////////////////////////////////////////////////////////////////// + +using SM90_16x8x4_C64C64C64C64_TN = SM90::MMA_16x8x4_C64C64C64C64_TN; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = complex; + using ValTypeA = complex; + using ValTypeB = complex; + using ValTypeC = complex; +}; + +using SM90_16x8x8_C64C64C64C64_TN = SM90::MMA_16x8x8_C64C64C64C64_TN; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = complex; + using ValTypeA = complex; + using ValTypeB = complex; + using ValTypeC = complex; +}; + +using SM90_16x8x16_C64C64C64C64_TN = SM90::MMA_16x8x16_C64C64C64C64_TN; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = complex; + using ValTypeA = complex; + using ValTypeB = complex; + using ValTypeC = complex; +}; + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm90_gmma.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm90_gmma.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e1c3bb40349faecf4ec2f8c36f0e90b7f26d23d7 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm90_gmma.hpp @@ -0,0 +1,8987 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // cute::smem_ptr_flag +#include // cute::smem_sparse_ptr_flag +#include // cute::Swizzle +#include // cute::Tensor +#include // cute::LayoutType +#include // cute::SM90_64x8x16_F16F16F16_SS, etc +#include // cute::MMA_Traits +#include // cute::ComposedLayout +#include // cute::is_static + +namespace cute { + +// Fence between the async destination accumulators of GMMA & source for their dependent use +template +CUTE_HOST_DEVICE +void +warpgroup_fence_operand(Tensor& frg) { + CUTE_STATIC_ASSERT(is_static::value); + if constexpr (is_same_v) { + auto f32_frg = recast(frg); + CUTE_UNROLL + for (int i = 0; i < size(f32_frg); ++i) { + warpgroup_fence_operand(f32_frg(i)); + } + } + else { + CUTE_STATIC_ASSERT(is_rmem::value); + auto u32_frg = recast(frg); + CUTE_UNROLL + for (int i = 0; i < size(u32_frg); ++i) { + warpgroup_fence_operand(u32_frg(i)); + } + } +} + +namespace SM90::GMMA { + +/////////////////////////////////////////// +// Common layouts for GMMA Shared Memory // +/////////////////////////////////////////// + +// M|N-major GMMA layouts in units of bits +using Layout_MN_INTER_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1, _128>>>; +using Layout_MN_SW32_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1, _256>>>; +using Layout_MN_SW64_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1, _512>>>; +using Layout_MN_SW128_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1,_1024>>>; + +// K-major GMMA layouts in units of bits +using Layout_K_INTER_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride< _128,_1>>>; +using Layout_K_SW32_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride< _256,_1>>>; +using Layout_K_SW64_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride< _512,_1>>>; +using Layout_K_SW128_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1024,_1>>>; + +// M|N-major layouts in units of Type +template +using Layout_MN_INTER_Atom = decltype(upcast::value>(Layout_MN_INTER_Atom_Bits{})); +template +using Layout_MN_SW32_Atom = decltype(upcast::value>(Layout_MN_SW32_Atom_Bits{})); +template +using Layout_MN_SW64_Atom = decltype(upcast::value>(Layout_MN_SW64_Atom_Bits{})); +template +using Layout_MN_SW128_Atom = decltype(upcast::value>(Layout_MN_SW128_Atom_Bits{})); + +// K-major layouts in units of Type +template +using Layout_K_INTER_Atom = decltype(upcast::value>(Layout_K_INTER_Atom_Bits{})); +template +using Layout_K_SW32_Atom = decltype(upcast::value>(Layout_K_SW32_Atom_Bits{})); +template +using Layout_K_SW64_Atom = decltype(upcast::value>(Layout_K_SW64_Atom_Bits{})); +template +using Layout_K_SW128_Atom = decltype(upcast::value>(Layout_K_SW128_Atom_Bits{})); + +// With GMMA::Major param +template +using Layout_INTER_Atom = typename conditional, + Layout_K_INTER_Atom>::type; +template +using Layout_SW32_Atom = typename conditional, + Layout_K_SW32_Atom>::type; +template +using Layout_SW64_Atom = typename conditional, + Layout_K_SW64_Atom>::type; +template +using Layout_SW128_Atom = typename conditional, + Layout_K_SW128_Atom>::type; + +// +// Tensor (position-dependent swizzle) to LayoutType utility +// + +template +CUTE_HOST_DEVICE constexpr +LayoutType +layout_type(Tensor> const&) +{ + static_assert(is_same::value, + "Expected uint128_t type in LayoutType conversion."); + + using Swizzle = get_swizzle_t; + constexpr int B = Swizzle::num_bits; + constexpr int M = Swizzle::num_base; + constexpr int S = Swizzle::num_shft; + + static_assert(M == 4, "Unsupported layout swizzle"); + static_assert(0 <= B && B <= 3, "Unsupported layout swizzle"); + static_assert(S == 3, "Unsupported layout swizzle"); + + switch (B) { + case 0: return LayoutType::INTERLEAVE; + case 1: return LayoutType::B32; + case 2: return LayoutType::B64; + case 3: return LayoutType::B128; + } + return LayoutType::INTERLEAVE; // ERROR +} + +/////////////////////////////////////////////////////////////////////////////// +// Construction method for GMMA Descriptors +/////////////////////////////////////////////////////////////////////////////// + +/** +* /////////////////////////////// +* // make_gmma_desc // +* /////////////////////////////// +* Each GmmaDescriptor Major-MN describes a canonical layout of the form +* +* LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((T,1,m),(8,k)):((1,T,SBO),(1T,LBO)) +* LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((T,2,m),(8,k)):((1,T,LBO),(2T,SBO)) +* LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((T,4,m),(8,k)):((1,T,LBO),(4T,SBO)) +* LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((T,8,m),(8,k)):((1,T,LBO),(8T,SBO)) +* +* where +* T : sizeof(uint128_t) / sizeof(value_type) +* m : integer in [1,16] corresponding to GMMA shape +* k : integer in [1,32] corresponding to GMMA shape +* SBO: stride byte offset +* LBO: leading byte offset +* +* See GMMA::Layout_MN_XXX_Atom for building canonical GmmaDescriptor Major-MN layouts. +* For example, +* auto smem_layout = tile_to_shape(Layout_MN_SW128_Atom{}, Shape<_128,_64>{}); +* is guaranteed to be accepted by make_gmma_desc for appropriate value_type. +* +* ////////////////////////////// +* // make_gmma_desc // +* ////////////////////////////// +* Each GmmaDescriptor Major-K describes a canonical layout of the form +* +* LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((8,m),(T,2)):((1T,SBO),(1,LBO)) +* LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((8,m),(T,2)):((2T,SBO),(1, T )) +* LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((8,m),(T,2)):((4T,SBO),(1, T )) +* LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((8,m),(T,2)):((8T,SBO),(1, T )) +* +* See GMMA::Layout_K_XXX_Atom for building canonical GmmaDescriptor Major-K layouts. +* For example, +* auto smem_layout = tile_to_shape(Layout_K_SW128_Atom{}, Shape<_128,_64>{}); +* is guaranteed to be accepted by make_gmma_desc for appropriate value_type. +*/ +template +CUTE_HOST_DEVICE constexpr +GmmaDescriptor +make_gmma_desc(Tensor const& tensor) +{ + static_assert(is_smem::value, "GMMA Descriptors can only be constructed on smem."); + static_assert(TLayout::rank == 2, "GMMA Descriptors can only be constructed on rank-2 tensors."); + using value_type = typename TEngine::value_type; + + Tensor u128_tensor = recast(tensor); + + // Result + GmmaDescriptor desc; + + // Layout type + constexpr LayoutType LAYOUT_TYPE = layout_type(u128_tensor); + desc.bitfield.layout_type_ = uint8_t(LAYOUT_TYPE); + + // Start address (4LSB not included) + uint32_t start_address = cast_smem_ptr_to_uint(raw_pointer_cast(u128_tensor.data())); + desc.bitfield.start_address_ = static_cast(start_address >> 4); + + constexpr uint8_t base_offset = 0; + desc.bitfield.base_offset_ = base_offset; + + // LayoutType meta + constexpr int W = LAYOUT_TYPE == LayoutType::INTERLEAVE ? 1 : + LAYOUT_TYPE == LayoutType::B32 ? 2 : + LAYOUT_TYPE == LayoutType::B64 ? 4 : + LAYOUT_TYPE == LayoutType::B128 ? 8 : -1; + + if constexpr (MajorMode == Major::MN) + { + /* In units of uint128_t, each GmmaDescriptor Major-MN describes a canonical layout of the form + * + * LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((1,n),(8,k)):((X,SBO),(1,LBO)) + * LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((2,n),(8,k)):((1,LBO),(2,SBO)) + * LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((4,n),(8,k)):((1,LBO),(4,SBO)) + * LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((8,n),(8,k)):((1,LBO),(8,SBO)) + */ + static_assert(size<1>(u128_tensor) == Int<(256 / cute::sizeof_bits::value)>{} || // A and B in dense MMA + size<1>(u128_tensor) == Int<(128 / cute::sizeof_bits::value)>{} || // A in sparse MMA + size<1>(u128_tensor) == Int<(512 / cute::sizeof_bits::value)>{}, // B in sparse MMA + "Not a canonical GMMA_MN Layout: Expected K-size 256/sizeof_bits for dense or (128|512)/sizeof_bits for sparse."); + + // Construct the canonical GMMA T Layout with shape ((W,n),(8,2)) + Layout canonical_layout = logical_divide(layout(u128_tensor), Tile,_1>,Layout,_1>>{}); + + // Check profile of canonical + CUTE_STATIC_ASSERT_V(congruent(canonical_layout, Shape,Shape<_1,_1>>{}), "Not a canonical GMMA_MN Layout: Expected profile failure."); + // Check canonical mode strides + constexpr uint32_t stride_00 = stride<0,0>(canonical_layout); + constexpr uint32_t expected_stride_00 = LAYOUT_TYPE == LayoutType::INTERLEAVE ? stride<0,0>(canonical_layout) : 1; + static_assert(stride_00 == expected_stride_00, "Not a canonical GMMA_MN Layout: Expected stride failure."); + constexpr uint32_t stride_10 = stride<1,0>(canonical_layout); + constexpr uint32_t expected_stride_10 = W; + static_assert(stride_10 == expected_stride_10, "Not a canonical GMMA_MN Layout: Expected stride failure."); + + // stride dimension byte offset and leading dimension byte offset (4LSB not included == uint128_t units) + constexpr uint32_t stride_01 = stride<0,1>(canonical_layout); + constexpr uint32_t stride_11 = stride<1,1>(canonical_layout); + + desc.bitfield.stride_byte_offset_ = (LAYOUT_TYPE == LayoutType::INTERLEAVE) ? stride_01 : stride_11; + desc.bitfield.leading_byte_offset_ = (LAYOUT_TYPE == LayoutType::INTERLEAVE) ? stride_11 : stride_01; + } + else if constexpr (MajorMode == Major::K) + { + /* In units of uint128_t, each GmmaDescriptor Major-K describes a canonical layout of the form + * + * LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((8,n),2):((1,SBO),LBO) + * LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((8,n),2):((2,SBO),1) + * LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((8,n),2):((4,SBO),1) + * LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((8,n),2):((8,SBO),1) + */ + CUTE_STATIC_ASSERT_V(size<0>(u128_tensor) % Int<8>{} == Int<0>{}, // N|M size + "Not a canonical GMMA_K Layout: Expected MN-size multiple of 8."); + CUTE_STATIC_ASSERT_V(size<1>(u128_tensor) == Int<2>{} || size<1>(u128_tensor) == Int<4>{}, // K size + "Not a canonical GMMA_K Layout: Expected K-size 2 for dense or 4 for sparse (in units of uint128_t)."); + + // Construct the canonical GMMA N Layout with shape ((8,n),(2,1)) + Layout canonical_layout = logical_divide(layout(u128_tensor), Tile,Layout<_2,_1>>{}); + + // Check profile of canonical + CUTE_STATIC_ASSERT_V(congruent(canonical_layout, Shape,Shape<_1,_1>>{}), "Not a canonical GMMA_K Layout: Expected profile failure."); + // Check canonical mode strides + constexpr uint32_t stride_00 = stride<0,0>(canonical_layout); + constexpr uint32_t expected_stride_00 = W; + static_assert(stride_00 == expected_stride_00, "Not a canonical GMMA_K Layout: Expected stride failure."); + constexpr uint32_t stride_10 = stride<1,0>(canonical_layout); + constexpr uint32_t expected_stride_10 = (LAYOUT_TYPE == LayoutType::INTERLEAVE) ? stride<1,0>(canonical_layout) : 1; + static_assert(stride_10 == expected_stride_10, "Not a canonical GMMA_K Layout: Expected stride failure."); + + // stride dimension byte offset and leading dimension byte offset (4LSB not included == uint128_t units) + constexpr uint32_t stride_01 = stride<0,1>(canonical_layout); + + desc.bitfield.stride_byte_offset_ = stride_01; + desc.bitfield.leading_byte_offset_ = stride_10; + } else { + static_assert(MajorMode != Major::MN && MajorMode != Major::K, "Unrecognized MajorMode!"); + } + return desc; +} + +/////////////////////////////////////////////////////////////////////////////// +// Higher level GMMA Descriptor utilities +/////////////////////////////////////////////////////////////////////////////// + +struct DescriptorIterator +{ + using reference = GmmaDescriptor; + using element_type = GmmaDescriptor; + using value_type = GmmaDescriptor; + + GmmaDescriptor desc_; + + // Dereference returns the GmmaDescriptor + CUTE_HOST_DEVICE constexpr + reference operator*() const { return desc_; } + + // Advance and return a new GmmaDescriptor + template + CUTE_HOST_DEVICE constexpr + reference operator[](Index const& i) const { return *(*this + i); } + + // Return an advanced iterator + template + CUTE_HOST_DEVICE constexpr + DescriptorIterator operator+(Index const& offset) const + { + // Use 32bit calculation rather than 64 bit calculation as we only update the part of desc + GmmaDescriptor ret; + ret.reg32_[0] = desc_.reg32_[0] + uint32_t(offset); + ret.reg32_[1] = desc_.reg32_[1]; + return { ret }; + } +}; + +template +CUTE_HOST_DEVICE constexpr +GmmaDescriptor +raw_pointer_cast(DescriptorIterator const& ptr) { + return ptr.desc_; +} + +// Recast a DescriptorIterator Tensor to uint64_t, it's RegType in mma_unpack +template +CUTE_HOST_DEVICE constexpr +DescriptorIterator +recast_ptr(DescriptorIterator const& iter) { + static_assert(is_same::value, "Can only cast GmmaDescriptorIterator to uint64_t."); + return iter; // Do nothing, it will still dereference to GmmaDescriptor and decay to uint64_t +} + +CUTE_HOST_DEVICE void +print(DescriptorIterator) { + printf("GMMA::DescriptorIterator"); +} + +// The GMMA Traits below have custom fragment type flags for their smem desc tensors. +// These flags specialize a MakeTensor customization point to correctly make the fragment that is desired. +template +struct smem_desc : DescriptorIterator {}; + +} // end namespace SM90::GMMA + +// Customization point for creating a GMMA::smem_desc Tensor +template +struct MakeTensor> +{ + template + CUTE_HOST_DEVICE constexpr auto + operator()(Tensor const& smem_tensor) + { + static_assert(is_smem::value, "Expected SMEM Tensor to construct a GMMA Desc Tensor"); + return make_tensor(SM90::GMMA::DescriptorIterator{SM90::GMMA::make_gmma_desc(tensor<0>(smem_tensor))}, + replace<0>(recast(smem_tensor).layout(), Layout<_1,_0>{})); + } +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////////// MMA_TRAITS /////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +namespace SM90::GMMA { + +// +// Specialized mma_unpack implementation for SM90 GMMA instructions +// + +template +CUTE_HOST_DEVICE constexpr +void +mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) +{ + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + + // Register value types from the MMA_Operation register arrays + using RegTypeA = typename remove_extent::type; + using RegTypeB = typename remove_extent::type; + using RegTypeC = typename remove_extent::type; + + // SM90 GMMA take three arguments rather than four, try to assert C and D are aliased + static_assert(is_same::value, "GMMA C and D value_type must match."); + static_assert(is_same::value, "GMMA C and D layouts must match."); + // assert((void*)&C == (void*)&D); + + Tensor rA = recast(A); + Tensor rB = recast(B); + Tensor rC = recast(D); // NOTE: D and C are same, so use mutable D + + constexpr int RegNumA = extent::value; + constexpr int RegNumB = extent::value; + constexpr int RegNumC = extent::value; + + CUTE_STATIC_ASSERT_V(size(rA) == Int{}); + CUTE_STATIC_ASSERT_V(size(rB) == Int{}); + CUTE_STATIC_ASSERT_V(size(rC) == Int{}); + + detail::explode(MMA_Op::fma, + rA, make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}, + &(traits.accumulate_), seq<0>{}); +} + +// Accumulator layouts +template +using CLayout_64xN = Layout,Shape < _2,_2,Int>>, + Stride,Stride<_64,_8, _512>>>; + +using CLayout_64x8 = CLayout_64xN< 8>; +using CLayout_64x16 = CLayout_64xN< 16>; +using CLayout_64x32 = CLayout_64xN< 32>; +using CLayout_64x64 = CLayout_64xN< 64>; +using CLayout_64x96 = CLayout_64xN< 96>; +using CLayout_64x128 = CLayout_64xN<128>; +using CLayout_64x192 = CLayout_64xN<192>; +using CLayout_64x256 = CLayout_64xN<256>; + +// Register source layout for 32-bit value types +using ALayout_64x8 = Layout,Shape < _2, _2>>, + Stride,Stride< _8,_256>>>; + +// Register source layout for 16-bit (sparse 32-bit) value types +using ALayout_64x16 = Layout,Shape < _2,_2, _2>>, + Stride,Stride<_64,_8,_512>>>; + +// Register source layout for 8-bit (sparse 16-bit) value types +using ALayout_64x32 = Layout,Shape < _4,_2, _2>>, + Stride,Stride<_64,_8,_1024>>>; + +// Register source layout for sparse 8-bit value types +using ALayout_64x64 = Layout,Shape < _8,_2, _2>>, + Stride,Stride<_64,_8,_2048>>>; + +// Shared memory source layouts for any value type +template +using ABLayout = Layout,Int>>, + Stride< _0,Stride< _1,Int>>>; + +} // end namespace SM90::GMMA + +using namespace SM90; + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x16_F16F16F16_SS = SM90::GMMA::MMA_64x8x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x16_F16F16F16_RS = SM90::GMMA::MMA_64x8x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x16_F16F16F16_SS = SM90::GMMA::MMA_64x16x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x16_F16F16F16_RS = SM90::GMMA::MMA_64x16x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x16_F16F16F16_SS = SM90::GMMA::MMA_64x32x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x16_F16F16F16_RS = SM90::GMMA::MMA_64x32x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x16_F16F16F16_SS = SM90::GMMA::MMA_64x64x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x16_F16F16F16_RS = SM90::GMMA::MMA_64x64x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x16_F16F16F16_SS = SM90::GMMA::MMA_64x96x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x16_F16F16F16_RS = SM90::GMMA::MMA_64x96x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x16_F16F16F16_SS = SM90::GMMA::MMA_64x128x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x16_F16F16F16_RS = SM90::GMMA::MMA_64x128x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x16_F16F16F16_SS = SM90::GMMA::MMA_64x192x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x16_F16F16F16_RS = SM90::GMMA::MMA_64x192x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x16_F16F16F16_SS = SM90::GMMA::MMA_64x256x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x16_F16F16F16_RS = SM90::GMMA::MMA_64x256x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x16_F32F16F16_SS = SM90::GMMA::MMA_64x8x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x16_F32F16F16_RS = SM90::GMMA::MMA_64x8x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x16_F32F16F16_SS = SM90::GMMA::MMA_64x16x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x16_F32F16F16_RS = SM90::GMMA::MMA_64x16x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x16_F32F16F16_SS = SM90::GMMA::MMA_64x32x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x16_F32F16F16_RS = SM90::GMMA::MMA_64x32x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x16_F32F16F16_SS = SM90::GMMA::MMA_64x64x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x16_F32F16F16_RS = SM90::GMMA::MMA_64x64x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x16_F32F16F16_SS = SM90::GMMA::MMA_64x96x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x16_F32F16F16_RS = SM90::GMMA::MMA_64x96x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x16_F32F16F16_SS = SM90::GMMA::MMA_64x128x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x16_F32F16F16_RS = SM90::GMMA::MMA_64x128x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x16_F32F16F16_SS = SM90::GMMA::MMA_64x192x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x16_F32F16F16_RS = SM90::GMMA::MMA_64x192x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x16_F32F16F16_SS = SM90::GMMA::MMA_64x256x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x16_F32F16F16_RS = SM90::GMMA::MMA_64x256x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x8x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x8x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x16x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x16x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x32x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x32x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x64x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x64x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x96x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x96x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x128x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x128x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x192x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x192x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x256x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x256x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x8x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 8, 8>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x8x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 8, 8>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x16x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 16, 8>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x16x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 16, 8>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x32x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 32, 8>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x32x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 32, 8>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x64x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 64, 8>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x64x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 64, 8>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x96x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 96, 8>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x96x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 96, 8>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x128x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<128, 8>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x128x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<128, 8>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x192x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<192, 8>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x192x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<192, 8>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x256x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<256, 8>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x256x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<256, 8>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x8x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x16x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x32x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x64x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x96x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x128x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x192x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x256x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x8x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x16x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x32x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x64x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x96x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x128x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x192x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x256x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x8x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x16x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x32x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x64x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x96x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x128x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x192x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x256x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x8x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x16x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x32x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x64x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x96x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x128x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x192x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x256x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x8x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x16x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x32x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x64x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x96x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x128x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x192x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x256x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x8x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x16x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x32x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x64x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x96x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x128x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x192x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x256x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x8x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x16x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x32x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x64x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x96x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x128x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x192x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x256x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x8x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x16x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x32x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x64x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x96x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x128x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x192x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x256x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x8x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x8x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x8x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x8x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x16x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x16x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x16x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x16x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x32x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x32x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x32x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x32x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x64x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x64x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x64x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x64x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x96x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x96x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x96x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x96x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x128x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x128x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x128x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x128x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x192x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x192x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x192x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x192x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x256x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x256x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x256x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x256x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x8x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x8x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x8x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x8x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x16x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x16x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x16x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x16x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x32x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x32x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x32x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x32x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x64x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x64x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x64x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x64x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x96x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x96x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x96x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x96x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x128x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x128x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x128x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x128x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x192x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x192x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x192x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x192x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x256x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x256x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x256x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x256x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x8x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x8x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x8x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x8x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x16x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x16x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x16x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x16x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x32x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x32x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x32x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x32x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x64x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x64x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x64x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x64x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x96x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x96x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x96x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x96x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x128x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x128x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x128x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x128x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x192x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x192x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x192x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x192x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x256x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x256x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x256x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x256x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x8x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x8x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x8x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x8x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x16x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x16x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x16x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x16x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x32x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x32x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x32x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x32x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x64x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x64x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x64x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x64x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x96x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x96x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x96x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x96x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x128x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x128x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x128x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x128x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x192x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x192x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x192x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x192x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x256x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x256x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x256x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x256x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +#include "mma_traits_sm90_gmma_ext.hpp" +#endif diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm90_gmma_ext.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm90_gmma_ext.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3cab34d257710e2b335e526558c1401826b27492 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm90_gmma_ext.hpp @@ -0,0 +1,20116 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include + +namespace cute { + +namespace SM90::GMMA { + +using CLayout_64x24 = CLayout_64xN< 24>; +using CLayout_64x40 = CLayout_64xN< 40>; +using CLayout_64x48 = CLayout_64xN< 48>; +using CLayout_64x56 = CLayout_64xN< 56>; +using CLayout_64x72 = CLayout_64xN< 72>; +using CLayout_64x80 = CLayout_64xN< 80>; +using CLayout_64x88 = CLayout_64xN< 88>; +using CLayout_64x104 = CLayout_64xN<104>; +using CLayout_64x112 = CLayout_64xN<112>; +using CLayout_64x120 = CLayout_64xN<120>; +using CLayout_64x136 = CLayout_64xN<136>; +using CLayout_64x144 = CLayout_64xN<144>; +using CLayout_64x152 = CLayout_64xN<152>; +using CLayout_64x160 = CLayout_64xN<160>; +using CLayout_64x168 = CLayout_64xN<168>; +using CLayout_64x176 = CLayout_64xN<176>; +using CLayout_64x184 = CLayout_64xN<184>; +using CLayout_64x200 = CLayout_64xN<200>; +using CLayout_64x208 = CLayout_64xN<208>; +using CLayout_64x216 = CLayout_64xN<216>; +using CLayout_64x224 = CLayout_64xN<224>; +using CLayout_64x232 = CLayout_64xN<232>; +using CLayout_64x240 = CLayout_64xN<240>; +using CLayout_64x248 = CLayout_64xN<248>; + +} + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x16_F16F16F16_SS = SM90::GMMA::MMA_64x24x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 24, 16>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x16_F16F16F16_RS = SM90::GMMA::MMA_64x24x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 24, 16>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x16_F16F16F16_SS = SM90::GMMA::MMA_64x40x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 40, 16>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x16_F16F16F16_RS = SM90::GMMA::MMA_64x40x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 40, 16>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x16_F16F16F16_SS = SM90::GMMA::MMA_64x48x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 48, 16>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x16_F16F16F16_RS = SM90::GMMA::MMA_64x48x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 48, 16>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x16_F16F16F16_SS = SM90::GMMA::MMA_64x56x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 56, 16>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x16_F16F16F16_RS = SM90::GMMA::MMA_64x56x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 56, 16>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x16_F16F16F16_SS = SM90::GMMA::MMA_64x72x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 72, 16>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x16_F16F16F16_RS = SM90::GMMA::MMA_64x72x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 72, 16>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x16_F16F16F16_SS = SM90::GMMA::MMA_64x80x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 80, 16>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x16_F16F16F16_RS = SM90::GMMA::MMA_64x80x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 80, 16>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x16_F16F16F16_SS = SM90::GMMA::MMA_64x88x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 88, 16>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x16_F16F16F16_RS = SM90::GMMA::MMA_64x88x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 88, 16>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x16_F16F16F16_SS = SM90::GMMA::MMA_64x104x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<104, 16>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x16_F16F16F16_RS = SM90::GMMA::MMA_64x104x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<104, 16>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x16_F16F16F16_SS = SM90::GMMA::MMA_64x112x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<112, 16>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x16_F16F16F16_RS = SM90::GMMA::MMA_64x112x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<112, 16>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x16_F16F16F16_SS = SM90::GMMA::MMA_64x120x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<120, 16>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x16_F16F16F16_RS = SM90::GMMA::MMA_64x120x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<120, 16>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x16_F16F16F16_SS = SM90::GMMA::MMA_64x136x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<136, 16>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x16_F16F16F16_RS = SM90::GMMA::MMA_64x136x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<136, 16>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x16_F16F16F16_SS = SM90::GMMA::MMA_64x144x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<144, 16>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x16_F16F16F16_RS = SM90::GMMA::MMA_64x144x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<144, 16>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x16_F16F16F16_SS = SM90::GMMA::MMA_64x152x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<152, 16>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x16_F16F16F16_RS = SM90::GMMA::MMA_64x152x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<152, 16>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x16_F16F16F16_SS = SM90::GMMA::MMA_64x160x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<160, 16>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x16_F16F16F16_RS = SM90::GMMA::MMA_64x160x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<160, 16>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x16_F16F16F16_SS = SM90::GMMA::MMA_64x168x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<168, 16>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x16_F16F16F16_RS = SM90::GMMA::MMA_64x168x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<168, 16>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x16_F16F16F16_SS = SM90::GMMA::MMA_64x176x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<176, 16>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x16_F16F16F16_RS = SM90::GMMA::MMA_64x176x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<176, 16>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x16_F16F16F16_SS = SM90::GMMA::MMA_64x184x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<184, 16>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x16_F16F16F16_RS = SM90::GMMA::MMA_64x184x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<184, 16>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x16_F16F16F16_SS = SM90::GMMA::MMA_64x200x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<200, 16>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x16_F16F16F16_RS = SM90::GMMA::MMA_64x200x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<200, 16>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x16_F16F16F16_SS = SM90::GMMA::MMA_64x208x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<208, 16>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x16_F16F16F16_RS = SM90::GMMA::MMA_64x208x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<208, 16>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x16_F16F16F16_SS = SM90::GMMA::MMA_64x216x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<216, 16>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x16_F16F16F16_RS = SM90::GMMA::MMA_64x216x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<216, 16>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x16_F16F16F16_SS = SM90::GMMA::MMA_64x224x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<224, 16>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x16_F16F16F16_RS = SM90::GMMA::MMA_64x224x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<224, 16>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x16_F16F16F16_SS = SM90::GMMA::MMA_64x232x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<232, 16>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x16_F16F16F16_RS = SM90::GMMA::MMA_64x232x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<232, 16>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x16_F16F16F16_SS = SM90::GMMA::MMA_64x240x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<240, 16>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x16_F16F16F16_RS = SM90::GMMA::MMA_64x240x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<240, 16>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x16_F16F16F16_SS = SM90::GMMA::MMA_64x248x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<248, 16>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x16_F16F16F16_RS = SM90::GMMA::MMA_64x248x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<248, 16>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x16_F32F16F16_SS = SM90::GMMA::MMA_64x24x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 24, 16>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x16_F32F16F16_RS = SM90::GMMA::MMA_64x24x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 24, 16>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x16_F32F16F16_SS = SM90::GMMA::MMA_64x40x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 40, 16>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x16_F32F16F16_RS = SM90::GMMA::MMA_64x40x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 40, 16>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x16_F32F16F16_SS = SM90::GMMA::MMA_64x48x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 48, 16>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x16_F32F16F16_RS = SM90::GMMA::MMA_64x48x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 48, 16>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x16_F32F16F16_SS = SM90::GMMA::MMA_64x56x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 56, 16>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x16_F32F16F16_RS = SM90::GMMA::MMA_64x56x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 56, 16>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x16_F32F16F16_SS = SM90::GMMA::MMA_64x72x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 72, 16>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x16_F32F16F16_RS = SM90::GMMA::MMA_64x72x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 72, 16>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x16_F32F16F16_SS = SM90::GMMA::MMA_64x80x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 80, 16>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x16_F32F16F16_RS = SM90::GMMA::MMA_64x80x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 80, 16>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x16_F32F16F16_SS = SM90::GMMA::MMA_64x88x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 88, 16>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x16_F32F16F16_RS = SM90::GMMA::MMA_64x88x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 88, 16>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x16_F32F16F16_SS = SM90::GMMA::MMA_64x104x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<104, 16>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x16_F32F16F16_RS = SM90::GMMA::MMA_64x104x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<104, 16>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x16_F32F16F16_SS = SM90::GMMA::MMA_64x112x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<112, 16>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x16_F32F16F16_RS = SM90::GMMA::MMA_64x112x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<112, 16>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x16_F32F16F16_SS = SM90::GMMA::MMA_64x120x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<120, 16>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x16_F32F16F16_RS = SM90::GMMA::MMA_64x120x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<120, 16>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x16_F32F16F16_SS = SM90::GMMA::MMA_64x136x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<136, 16>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x16_F32F16F16_RS = SM90::GMMA::MMA_64x136x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<136, 16>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x16_F32F16F16_SS = SM90::GMMA::MMA_64x144x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<144, 16>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x16_F32F16F16_RS = SM90::GMMA::MMA_64x144x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<144, 16>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x16_F32F16F16_SS = SM90::GMMA::MMA_64x152x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<152, 16>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x16_F32F16F16_RS = SM90::GMMA::MMA_64x152x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<152, 16>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x16_F32F16F16_SS = SM90::GMMA::MMA_64x160x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<160, 16>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x16_F32F16F16_RS = SM90::GMMA::MMA_64x160x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<160, 16>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x16_F32F16F16_SS = SM90::GMMA::MMA_64x168x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<168, 16>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x16_F32F16F16_RS = SM90::GMMA::MMA_64x168x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<168, 16>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x16_F32F16F16_SS = SM90::GMMA::MMA_64x176x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<176, 16>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x16_F32F16F16_RS = SM90::GMMA::MMA_64x176x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<176, 16>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x16_F32F16F16_SS = SM90::GMMA::MMA_64x184x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<184, 16>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x16_F32F16F16_RS = SM90::GMMA::MMA_64x184x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<184, 16>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x16_F32F16F16_SS = SM90::GMMA::MMA_64x200x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<200, 16>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x16_F32F16F16_RS = SM90::GMMA::MMA_64x200x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<200, 16>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x16_F32F16F16_SS = SM90::GMMA::MMA_64x208x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<208, 16>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x16_F32F16F16_RS = SM90::GMMA::MMA_64x208x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<208, 16>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x16_F32F16F16_SS = SM90::GMMA::MMA_64x216x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<216, 16>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x16_F32F16F16_RS = SM90::GMMA::MMA_64x216x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<216, 16>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x16_F32F16F16_SS = SM90::GMMA::MMA_64x224x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<224, 16>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x16_F32F16F16_RS = SM90::GMMA::MMA_64x224x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<224, 16>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x16_F32F16F16_SS = SM90::GMMA::MMA_64x232x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<232, 16>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x16_F32F16F16_RS = SM90::GMMA::MMA_64x232x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<232, 16>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x16_F32F16F16_SS = SM90::GMMA::MMA_64x240x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<240, 16>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x16_F32F16F16_RS = SM90::GMMA::MMA_64x240x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<240, 16>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x16_F32F16F16_SS = SM90::GMMA::MMA_64x248x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<248, 16>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x16_F32F16F16_RS = SM90::GMMA::MMA_64x248x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<248, 16>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x24x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 24, 16>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x24x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 24, 16>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x40x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 40, 16>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x40x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 40, 16>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x48x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 48, 16>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x48x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 48, 16>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x56x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 56, 16>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x56x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 56, 16>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x72x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 72, 16>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x72x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 72, 16>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x80x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 80, 16>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x80x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 80, 16>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x88x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 88, 16>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x88x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 88, 16>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x104x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<104, 16>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x104x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<104, 16>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x112x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<112, 16>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x112x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<112, 16>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x120x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<120, 16>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x120x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<120, 16>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x136x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<136, 16>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x136x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<136, 16>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x144x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<144, 16>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x144x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<144, 16>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x152x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<152, 16>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x152x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<152, 16>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x160x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<160, 16>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x160x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<160, 16>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x168x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<168, 16>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x168x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<168, 16>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x176x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<176, 16>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x176x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<176, 16>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x184x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<184, 16>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x184x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<184, 16>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x200x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<200, 16>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x200x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<200, 16>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x208x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<208, 16>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x208x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<208, 16>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x216x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<216, 16>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x216x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<216, 16>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x224x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<224, 16>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x224x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<224, 16>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x232x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<232, 16>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x232x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<232, 16>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x240x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<240, 16>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x240x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<240, 16>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x248x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<248, 16>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x248x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<248, 16>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x24x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 24, 8>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x24x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 24, 8>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x40x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 40, 8>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x40x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 40, 8>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x48x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 48, 8>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x48x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 48, 8>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x56x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 56, 8>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x56x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 56, 8>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x72x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 72, 8>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x72x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 72, 8>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x80x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 80, 8>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x80x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 80, 8>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x88x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 88, 8>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x88x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 88, 8>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x104x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<104, 8>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x104x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<104, 8>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x112x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<112, 8>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x112x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<112, 8>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x120x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<120, 8>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x120x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<120, 8>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x136x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<136, 8>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x136x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<136, 8>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x144x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<144, 8>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x144x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<144, 8>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x152x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<152, 8>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x152x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<152, 8>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x160x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<160, 8>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x160x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<160, 8>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x168x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<168, 8>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x168x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<168, 8>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x176x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<176, 8>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x176x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<176, 8>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x184x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<184, 8>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x184x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<184, 8>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x200x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<200, 8>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x200x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<200, 8>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x208x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<208, 8>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x208x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<208, 8>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x216x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<216, 8>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x216x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<216, 8>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x224x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<224, 8>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x224x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<224, 8>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x232x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<232, 8>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x232x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<232, 8>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x240x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<240, 8>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x240x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<240, 8>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x248x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<248, 8>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x248x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<248, 8>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x24x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x24x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x48x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x80x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x112x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x144x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x160x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x176x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x208x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x224x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x240x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x24x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x24x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x48x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x80x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x112x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x144x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x160x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x176x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x208x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x224x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x240x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x24x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x24x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x48x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x80x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x112x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x144x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x160x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x176x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x208x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x224x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x240x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x24x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x24x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x48x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x80x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x112x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x144x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x160x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x176x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x208x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x224x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x240x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x24x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x24x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x48x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x80x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x112x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x144x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x160x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x176x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x208x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x224x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x240x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x24x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x24x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x48x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x80x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x112x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x144x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x160x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x176x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x208x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x224x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x240x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x24x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x24x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x48x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x80x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x112x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x144x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x160x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x176x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x208x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x224x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x240x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x24x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x24x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x48x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x80x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x112x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x144x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x160x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x176x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x208x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x224x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x240x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x24x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x24x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x24x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x24x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x40x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x40x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x40x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x40x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x48x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x48x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x48x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x48x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x56x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x56x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x56x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x56x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x72x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x72x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x72x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x72x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x80x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x80x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x80x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x80x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x88x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x88x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x88x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x88x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x104x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x104x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x104x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x104x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x112x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x112x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x112x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x112x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x120x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x120x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x120x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x120x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x136x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x136x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x136x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x136x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x144x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x144x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x144x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x144x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x152x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x152x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x152x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x152x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x160x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x160x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x160x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x160x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x168x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x168x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x168x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x168x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x176x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x176x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x176x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x176x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x184x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x184x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x184x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x184x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x200x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x200x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x200x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x200x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x208x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x208x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x208x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x208x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x216x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x216x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x216x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x216x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x224x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x224x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x224x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x224x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x232x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x232x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x232x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x232x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x240x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x240x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x240x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x240x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x248x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x248x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x248x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x248x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x24x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x24x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x24x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x24x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x40x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x40x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x40x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x40x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x48x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x48x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x48x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x48x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x56x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x56x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x56x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x56x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x72x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x72x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x72x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x72x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x80x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x80x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x80x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x80x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x88x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x88x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x88x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x88x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x104x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x104x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x104x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x104x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x112x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x112x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x112x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x112x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x120x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x120x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x120x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x120x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x136x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x136x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x136x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x136x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x144x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x144x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x144x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x144x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x152x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x152x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x152x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x152x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x160x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x160x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x160x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x160x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x168x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x168x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x168x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x168x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x176x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x176x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x176x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x176x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x184x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x184x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x184x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x184x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x200x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x200x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x200x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x200x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x208x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x208x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x208x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x208x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x216x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x216x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x216x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x216x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x224x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x224x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x224x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x224x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x232x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x232x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x232x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x232x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x240x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x240x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x240x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x240x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x248x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x248x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x248x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x248x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x24x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x24x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x24x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x24x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x40x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x40x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x40x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x40x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x48x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x48x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x48x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x48x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x56x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x56x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x56x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x56x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x72x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x72x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x72x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x72x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x80x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x80x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x80x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x80x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x88x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x88x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x88x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x88x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x104x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x104x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x104x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x104x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x112x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x112x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x112x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x112x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x120x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x120x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x120x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x120x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x136x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x136x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x136x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x136x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x144x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x144x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x144x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x144x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x152x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x152x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x152x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x152x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x160x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x160x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x160x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x160x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x168x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x168x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x168x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x168x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x176x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x176x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x176x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x176x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x184x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x184x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x184x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x184x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x200x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x200x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x200x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x200x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x208x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x208x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x208x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x208x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x216x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x216x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x216x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x216x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x224x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x224x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x224x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x224x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x232x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x232x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x232x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x232x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x240x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x240x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x240x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x240x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x248x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x248x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x248x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x248x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x24x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x24x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x24x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x24x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x40x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x40x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x40x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x40x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x48x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x48x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x48x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x48x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x56x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x56x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x56x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x56x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x72x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x72x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x72x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x72x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x80x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x80x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x80x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x80x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x88x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x88x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x88x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x88x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x104x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x104x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x104x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x104x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x112x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x112x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x112x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x112x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x120x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x120x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x120x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x120x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x136x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x136x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x136x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x136x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x144x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x144x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x144x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x144x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x152x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x152x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x152x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x152x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x160x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x160x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x160x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x160x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x168x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x168x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x168x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x168x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x176x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x176x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x176x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x176x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x184x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x184x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x184x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x184x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x200x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x200x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x200x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x200x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x208x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x208x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x208x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x208x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x216x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x216x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x216x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x216x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x224x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x224x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x224x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x224x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x232x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x232x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x232x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x232x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x240x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x240x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x240x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x240x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x248x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x248x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x248x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x248x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp new file mode 100644 index 0000000000000000000000000000000000000000..13ff07c89f47a15bb65f57be2ad999df1f6568c9 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp @@ -0,0 +1,7738 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include // cute::smem_sparse_ptr_flag +#include // cute::Swizzle +#include // cute::Tensor +#include // cute::LayoutType +#include // cute::SM90::SPARSE::GMMA_64x8x32_F16F16F16_SS, etc +#include // cute::GMMA::Layout_* +#include // cute::MMA_Traits +#include // cute::ComposedLayout +#include // cute::is_static + +namespace cute { + +namespace SM90::GMMA { + +/////////////////////////////////////////// +// Common layouts for GMMA Shared Memory // +/////////////////////////////////////////// + +// M|N-major layouts in units of Type and sparsity factor S +template +using Layout_MN_INTER_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>, + decltype(blocked_product(Layout>>{}, Layout_MN_INTER_Atom{}.layout_b()))>; +template +using Layout_MN_SW32_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>, + decltype(blocked_product(Layout>>{}, Layout_MN_SW32_Atom{}.layout_b()))>; +template +using Layout_MN_SW64_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>, + decltype(blocked_product(Layout>>{}, Layout_MN_SW64_Atom{}.layout_b()))>; +template +using Layout_MN_SW128_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>, + decltype(blocked_product(Layout>>{}, Layout_MN_SW128_Atom{}.layout_b()))>; + +// K-major layouts in units of Type and sparsity factor S +template +using Layout_K_INTER_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>, + decltype(blocked_product(Layout>>{}, Layout_K_INTER_Atom{}.layout_b()))>; +template +using Layout_K_SW32_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>, + decltype(blocked_product(Layout>>{}, Layout_K_SW32_Atom{}.layout_b()))>; +template +using Layout_K_SW64_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>, + decltype(blocked_product(Layout>>{}, Layout_K_SW64_Atom{}.layout_b()))>; +template +using Layout_K_SW128_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>, + decltype(blocked_product(Layout>>{}, Layout_K_SW128_Atom{}.layout_b()))>; + +// With GMMA::Major param +template +using Layout_INTER_SpAtom = typename conditional, + Layout_K_INTER_SpAtom>::type; +template +using Layout_SW32_SpAtom = typename conditional, + Layout_K_SW32_SpAtom>::type; +template +using Layout_SW64_SpAtom = typename conditional, + Layout_K_SW64_SpAtom>::type; +template +using Layout_SW128_SpAtom = typename conditional, + Layout_K_SW128_SpAtom>::type; + +/////////////////////////////////////////////////////////////////////////////// +// Higher level GMMA Descriptor utilities +/////////////////////////////////////////////////////////////////////////////// + +template +struct sparse_smem_desc : DescriptorIterator {}; + +} // end namespace SM90::GMMA + +// Customization point for creating a cute::GMMAsparse_smem_desc Tensor +template +struct MakeTensor> +{ + // Note that this is the exact same as cute::GMMAsmem_desc above, plus additional static checks. + template + CUTE_HOST_DEVICE constexpr auto + operator()(Tensor const& smem_tensor) + { + static_assert(is_smem::value, "Expected SMEM Tensor to construct a GMMA Desc Tensor"); + static_assert(is_sparse::value, "Expected sparse value_type."); + static_assert(is_sparse_ptr::value, "Expected sparse iter."); + return make_tensor(SM90::GMMA::DescriptorIterator{SM90::GMMA::make_gmma_desc(tensor<0>(smem_tensor))}, + replace<0>(recast(smem_tensor).layout(), Layout<_1,_0>{})); + } +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////////// MMA_TRAITS /////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +namespace SM90::GMMA { + +// Metadata layouts +using ELayout_64x64 = Layout, Shape <_32>>, + Stride, Stride<_64>>>; + +using ELayout_64x32 = Layout, Shape <_16,_2>>, + Stride, Stride<_64,_8>>>; + +using ELayout_64x16 = Layout, Shape < _8,_2>>, + Stride, Stride<_64,_8>>>; + +} // namespace SM90::GMMA + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace SM90::GMMA::SPARSE { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTE_HOST_DEVICE constexpr void +mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A_zipped, + Tensor const& B, + Tensor const& C) +{ + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + + using DRegisters = typename MMAOp::DRegisters; + using ARegisters = typename MMAOp::ARegisters; + using ERegisters = typename MMAOp::ERegisters; + using BRegisters = typename MMAOp::BRegisters; + using CRegisters = typename MMAOp::CRegisters; + + // Register value types from the MMAOp register arrays + using RegTypeD = typename remove_extent::type; + using RegTypeA = typename remove_extent::type; + using RegTypeE = typename remove_extent::type; + using RegTypeB = typename remove_extent::type; + using RegTypeC = typename remove_extent::type; + + constexpr int RegNumA = extent::value; + constexpr int RegNumE = extent::value; + constexpr int RegNumB = extent::value; + constexpr int RegNumC = extent::value; + + auto [A, E] = unzip_tensor(A_zipped); + Tensor rA = recast(A); + Tensor rE = recast(E); + Tensor rB = recast(B); + + CUTE_STATIC_ASSERT_V(size(rA) == Int{}); + CUTE_STATIC_ASSERT_V(size(rE) == Int{}); + CUTE_STATIC_ASSERT_V(size(rB) == Int{}); + + static_assert(is_same::value, "GMMA DRegisters must have void type."); + static_assert(is_same::value, "GMMA C and D value_type must match."); + static_assert(is_same::value, "GMMA C and D layouts must match."); + + Tensor rC = recast(D); // NOTE: D and C are same, so use mutable D + + CUTE_STATIC_ASSERT_V(size(rC) == Int{}); + + detail::explode(MMAOp::fma, + rA, make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}, + rE, make_int_sequence{}, + &(traits.accumulate_), seq<0>{}); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace SM90::SPARSE + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +#include "mma_traits_sm90_gmma_sparse_ext.hpp" +#endif diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm90_gmma_sparse_ext.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm90_gmma_sparse_ext.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fc28b8a97a77e5e082effd1cc994f25465b0a3e6 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm90_gmma_sparse_ext.hpp @@ -0,0 +1,17335 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include + +namespace cute { + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 24, 16>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 24, 16>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 40, 16>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 40, 16>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 48, 16>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 48, 16>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 56, 16>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 56, 16>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 72, 16>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 72, 16>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 80, 16>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 80, 16>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 88, 16>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 88, 16>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<104, 16>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<104, 16>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<112, 16>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<112, 16>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<120, 16>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<120, 16>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<136, 16>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<136, 16>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<144, 16>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<144, 16>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<152, 16>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<152, 16>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<160, 16>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<160, 16>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<168, 16>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<168, 16>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<176, 16>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<176, 16>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<184, 16>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<184, 16>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<200, 16>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<200, 16>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<208, 16>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<208, 16>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<216, 16>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<216, 16>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<224, 16>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<224, 16>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<232, 16>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<232, 16>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<240, 16>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<240, 16>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<248, 16>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<248, 16>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/partitioner.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/partitioner.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b15de6e266e206e37bc29fcdcfbbc2e74decce16 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/atom/partitioner.hpp @@ -0,0 +1,110 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#if defined(__CUDACC_RTC__) +#include CUDA_STD_HEADER(type_traits) +#else +#include +#endif + +#include +#include + +namespace cute { + +// +// A generic tiling of thread-value layouts +// + +template coord [Need not be 2D...] + class Tiler_MN_> // coord space +struct TV_Tiler +{ + using Tiler_MN = Tiler_MN_; + using TiledLayout_TV = Layout_TV_; + + // Tile a tensor or a layout from shape + // (M,N,...) + // to shape + // ((ThrV,FrgV),(RestM,RestN,...)) + // where + // ThrV: The threads local to a tile. + // FrgV: The values local to a tile. + // RestM: The values tiled in M. + // RestN: The values tiled in N. + template + CUTE_HOST_DEVICE constexpr static + auto + apply(Tensor&& tensor) + { + // If Layout_TV and Tiler_MN were composable in general, then this won't be needed! + + // ((thr_id,val_id),(RestM,RestN,...)) + return zipped_divide(tensor, Tiler_MN{}).compose(TiledLayout_TV{}, _); + } + + template + struct TV_Partitioner + { + SliceCoord coord_; + + template + CUTE_HOST_DEVICE + auto + partition(TargetTensor&& target) { + Tensor thr_tensor = make_tensor(static_cast(target).data(), apply(target.layout())); + return thr_tensor(coord_, repeat>(_)); + } + }; + + template + CUTE_HOST_DEVICE static + auto + get_slice(SliceCoord const& coord) + { + return TV_Partitioner{coord}; + } +}; + +template +CUTE_HOST_DEVICE +auto +make_tiler_impl(Layout_TV const&, + Tiler_MN const&) +{ + return TV_Tiler{}; +} + +} diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/config.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/config.hpp new file mode 100644 index 0000000000000000000000000000000000000000..538472c274baa5ccc7c59ac3d36b52791f061fdf --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/config.hpp @@ -0,0 +1,159 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#if defined(__CUDACC__) || defined(_NVHPC_CUDA) +# define CUTE_HOST_DEVICE __forceinline__ __host__ __device__ +# define CUTE_DEVICE __forceinline__ __device__ +# define CUTE_HOST __forceinline__ __host__ +#else +# define CUTE_HOST_DEVICE inline +# define CUTE_DEVICE inline +# define CUTE_HOST inline +#endif // CUTE_HOST_DEVICE, CUTE_DEVICE + +#if defined(__CUDACC_RTC__) +# define CUTE_HOST_RTC CUTE_HOST_DEVICE +#else +# define CUTE_HOST_RTC CUTE_HOST +#endif + +#if !defined(__CUDACC_RTC__) && !defined(__clang__) && \ + (defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA)) +# define CUTE_UNROLL #pragma unroll +# define CUTE_NO_UNROLL #pragma unroll 1 +#elif defined(__CUDACC_RTC__) || defined(__clang__) +# define CUTE_UNROLL _Pragma("unroll") +# define CUTE_NO_UNROLL _Pragma("unroll 1") +#else +# define CUTE_UNROLL +# define CUTE_NO_UNROLL +#endif // CUTE_UNROLL + +#if defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA) +# define CUTE_INLINE_CONSTANT static const __device__ +#else +# define CUTE_INLINE_CONSTANT static constexpr +#endif + +// __grid_constant__ was introduced in CUDA 11.7. +#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 7))) +# define CUTE_GRID_CONSTANT_SUPPORTED +#endif + +// __grid_constant__ can be enabled only on SM70+. +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)) +# define CUTE_GRID_CONSTANT_ENABLED +#endif + +#if ! defined(CUTE_GRID_CONSTANT) +# if defined(CUTE_GRID_CONSTANT_SUPPORTED) && defined(CUTE_GRID_CONSTANT_ENABLED) +# define CUTE_GRID_CONSTANT __grid_constant__ +# else +# define CUTE_GRID_CONSTANT +# endif +#endif + +// Some versions of GCC < 11 have trouble deducing that a +// function with "auto" return type and all of its returns in an "if +// constexpr ... else" statement must actually return. Thus, GCC +// emits spurious "missing return statement" build warnings. +// Developers can suppress these warnings by using the +// CUTE_GCC_UNREACHABLE macro, which must be followed by a semicolon. +// It's harmless to use the macro for other GCC versions or other +// compilers, but it has no effect. +#if ! defined(CUTE_GCC_UNREACHABLE) +# if defined(__GNUC__) +# define CUTE_GCC_UNREACHABLE __builtin_unreachable() +# else +# define CUTE_GCC_UNREACHABLE +# endif +#endif + +#if defined(_MSC_VER) +// Provides support for alternative operators 'and', 'or', and 'not' +# include +#endif // _MSC_VER + +#if defined(__CUDACC_RTC__) +# define CUTE_STL_NAMESPACE cuda::std +# define CUTE_STL_NAMESPACE_IS_CUDA_STD +#else +# define CUTE_STL_NAMESPACE std +#endif + +// +// Assertion helpers +// + +#if defined(__CUDACC_RTC__) +# include +#else +# include +#endif + +#define CUTE_STATIC_V(x) decltype(x)::value + +#define CUTE_STATIC_ASSERT static_assert +#define CUTE_STATIC_ASSERT_V(x,...) static_assert(decltype(x)::value, ##__VA_ARGS__) + +// Fail and print a message. Typically used for notification of a compiler misconfiguration. +#if defined(__CUDA_ARCH__) +# define CUTE_INVALID_CONTROL_PATH(x) assert(0 && x); printf(x); __brkpt() +#else +# define CUTE_INVALID_CONTROL_PATH(x) assert(0 && x); printf(x) +#endif + +// +// IO +// + +#if !defined(__CUDACC_RTC__) +# include +# include +# include +#endif + +// +// Type +// + +#if defined(__CUDACC_RTC__) +# include +#else +# include +#endif + +// +// Debugging utilities +// + +#include diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/container/alignment.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/container/alignment.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f285004c0c55c9cedc67dbfa1c83b202b3c10f41 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/container/alignment.hpp @@ -0,0 +1,70 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute +{ + +// Test if a pointer is aligned to N bytes +template +CUTE_HOST_DEVICE constexpr +bool +is_byte_aligned(void const* const ptr) +{ + static_assert(has_single_bit(N), "N must be a power of 2 in alignment check"); + return (reinterpret_cast(ptr) & (N-1)) == 0; +} + +#if defined(__CUDACC__) +# define CUTE_ALIGNAS(n) __align__(n) +#else +# define CUTE_ALIGNAS(n) alignas(n) +#endif + +template +struct aligned_struct {}; + +template struct CUTE_ALIGNAS( 1) aligned_struct< 1, Child> {}; +template struct CUTE_ALIGNAS( 2) aligned_struct< 2, Child> {}; +template struct CUTE_ALIGNAS( 4) aligned_struct< 4, Child> {}; +template struct CUTE_ALIGNAS( 8) aligned_struct< 8, Child> {}; +template struct CUTE_ALIGNAS( 16) aligned_struct< 16, Child> {}; +template struct CUTE_ALIGNAS( 32) aligned_struct< 32, Child> {}; +template struct CUTE_ALIGNAS( 64) aligned_struct< 64, Child> {}; +template struct CUTE_ALIGNAS(128) aligned_struct<128, Child> {}; +template struct CUTE_ALIGNAS(256) aligned_struct<256, Child> {}; + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/container/array.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/container/array.hpp new file mode 100644 index 0000000000000000000000000000000000000000..31606e45b9c1a6520596ccdb778f6a962b561179 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/container/array.hpp @@ -0,0 +1,470 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute +{ + +template +struct array +{ + using element_type = T; + using value_type = remove_cv_t; + using size_type = size_t; + using difference_type = ptrdiff_t; + using reference = element_type&; + using const_reference = const element_type&; + using pointer = element_type*; + using const_pointer = const element_type*; + using iterator = pointer; + using const_iterator = const_pointer; + + CUTE_HOST_DEVICE constexpr + reference operator[](size_type pos) + { + return begin()[pos]; + } + + CUTE_HOST_DEVICE constexpr + const_reference operator[](size_type pos) const + { + return begin()[pos]; + } + + CUTE_HOST_DEVICE constexpr + reference front() + { + return *begin(); + } + + CUTE_HOST_DEVICE constexpr + const_reference front() const + { + return *begin(); + } + + CUTE_HOST_DEVICE constexpr + reference back() + { + // return *rbegin(); + return operator[](N-1); + } + + CUTE_HOST_DEVICE constexpr + const_reference back() const + { + // return *rbegin(); + return operator[](N-1); + } + + CUTE_HOST_DEVICE constexpr + T* data() + { + return __elems_; + } + + CUTE_HOST_DEVICE constexpr + T const* data() const + { + return __elems_; + } + + CUTE_HOST_DEVICE constexpr + iterator begin() + { + return data(); + } + + CUTE_HOST_DEVICE constexpr + const_iterator begin() const + { + return data(); + } + + CUTE_HOST_DEVICE constexpr + const_iterator cbegin() + { + return begin(); + } + + CUTE_HOST_DEVICE constexpr + const_iterator cbegin() const + { + return begin(); + } + + CUTE_HOST_DEVICE constexpr + iterator end() + { + return data() + size(); + } + + CUTE_HOST_DEVICE constexpr + const_iterator end() const + { + return data() + size(); + } + + CUTE_HOST_DEVICE constexpr + const_iterator cend() + { + return end(); + } + + CUTE_HOST_DEVICE constexpr + const_iterator cend() const + { + return end(); + } + + CUTE_HOST_DEVICE constexpr + bool empty() const + { + return size() == 0; + } + + CUTE_HOST_DEVICE constexpr + size_type size() const + { + return N; + } + + CUTE_HOST_DEVICE constexpr + size_type max_size() const + { + return size(); + } + + CUTE_HOST_DEVICE constexpr + void fill(const T& value) + { + for (auto& e : *this) { + e = value; + } + } + + CUTE_HOST_DEVICE constexpr + void clear() + { + fill(T(0)); + } + + CUTE_HOST_DEVICE constexpr + void swap(array& other) + { + using CUTE_STL_NAMESPACE::swap; + for (size_type i = 0; i < size(); ++i) { + swap((*this)[i], other[i]); + } + } + + element_type __elems_[N]; +}; + + +template +struct array +{ + using element_type = T; + using value_type = remove_cv_t; + using size_type = size_t; + using difference_type = ptrdiff_t; + using reference = element_type&; + using const_reference = const element_type&; + using pointer = element_type*; + using const_pointer = const element_type*; + using const_iterator = const_pointer; + using iterator = pointer; + + CUTE_HOST_DEVICE constexpr + reference operator[](size_type pos) + { + return begin()[pos]; + } + + CUTE_HOST_DEVICE constexpr + const_reference operator[](size_type pos) const + { + return begin()[pos]; + } + + CUTE_HOST_DEVICE constexpr + reference front() + { + return *begin(); + } + + CUTE_HOST_DEVICE constexpr + const_reference front() const + { + return *begin(); + } + + CUTE_HOST_DEVICE constexpr + reference back() + { + return *begin(); + } + + CUTE_HOST_DEVICE constexpr + const_reference back() const + { + return *begin(); + } + + CUTE_HOST_DEVICE constexpr + T* data() + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + T const* data() const + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + iterator begin() + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + const_iterator begin() const + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + const_iterator cbegin() + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + const_iterator cbegin() const + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + iterator end() + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + const_iterator end() const + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + const_iterator cend() + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + const_iterator cend() const + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + bool empty() const + { + return true; + } + + CUTE_HOST_DEVICE constexpr + size_type size() const + { + return 0; + } + + CUTE_HOST_DEVICE constexpr + size_type max_size() const + { + return 0; + } + + CUTE_HOST_DEVICE constexpr + void fill(const T& value) + {} + + CUTE_HOST_DEVICE constexpr + void clear() + {} + + CUTE_HOST_DEVICE constexpr + void swap(array& other) + {} +}; + +template +CUTE_HOST_DEVICE constexpr +bool operator==(array const& lhs, array const& rhs) +{ + for (size_t i = 0; i < N; ++i) { + if (lhs[i] != rhs[i]) { + return false; + } + } + return true; +} + +template +CUTE_HOST_DEVICE constexpr +void clear(array& a) +{ + a.fill(T(0)); +} + +template +CUTE_HOST_DEVICE constexpr +void fill(array& a, T const& value) +{ + a.fill(value); +} + +template +CUTE_HOST_DEVICE constexpr +void swap(array& a, array& b) +{ + a.swap(b); +} + +/// @return A cute::array of the elements of @c t in reverse order. +template +CUTE_HOST_DEVICE constexpr +cute::array reverse(cute::array const& t) +{ + if constexpr (N == 0u) { + return t; + } else { + cute::array t_r{}; + for (size_t k = 0; k < N; ++k) { + t_r[k] = t[N - k - 1]; + } + return t_r; + } +} + +} // end cute + + +// +// Specialize tuple-related functionality for cute::array +// +#include "cutlass/cutlass.h" +#if defined(__CUDACC_RTC__) +#include CUDA_STD_HEADER(tuple) +#else +#include +#endif + +namespace cute +{ + +template +CUTE_HOST_DEVICE constexpr +T& get(array& a) +{ + static_assert(I < N, "Index out of range"); + return a[I]; +} + +template +CUTE_HOST_DEVICE constexpr +T const& get(array const& a) +{ + static_assert(I < N, "Index out of range"); + return a[I]; +} + +template +CUTE_HOST_DEVICE constexpr +T&& get(array&& a) +{ + static_assert(I < N, "Index out of range"); + return cute::move(a[I]); +} + +} // end namespace cute + +namespace CUTE_STL_NAMESPACE +{ + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> +{ + using type = T; +}; + +} // end namespace CUTE_STL_NAMESPACE + +#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD +namespace std +{ + +#if defined(__CUDACC_RTC__) +template +struct tuple_size; + +template +struct tuple_element; +#endif + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> +{ + using type = T; +}; + +} // end namespace std +#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/container/array_aligned.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/container/array_aligned.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6491f72db1b8a98c7d8e445ab53cb9ec10ee4640 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/container/array_aligned.hpp @@ -0,0 +1,42 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_ALIGNAS +#include // cute::array + +namespace cute +{ + +template +struct CUTE_ALIGNAS(Alignment) array_aligned : cute::array {}; + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/container/array_subbyte.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/container/array_subbyte.hpp new file mode 100644 index 0000000000000000000000000000000000000000..be6410b318be72f74534b4344c3810cd1ccfb389 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/container/array_subbyte.hpp @@ -0,0 +1,640 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Statically sized array of elements that accommodates subbyte trivial types + in a packed storage. +*/ + +#pragma once + +#include + +#include +#include + +namespace cute +{ +// +// Underlying subbyte storage type +// +template +using subbyte_storage_type_t = conditional_t<(cute::sizeof_bits_v <= 8), uint8_t, + conditional_t<(cute::sizeof_bits_v <= 16), uint16_t, + conditional_t<(cute::sizeof_bits_v <= 32), uint32_t, + conditional_t<(cute::sizeof_bits_v <= 64), uint64_t, + conditional_t<(cute::sizeof_bits_v <= 128), uint128_t, + T>>>>>; + +template struct subbyte_iterator; +template struct swizzle_ptr; + +// +// subbyte_reference +// Proxy object for sub-byte element references +// +template +struct subbyte_reference +{ + // Iterator Element type (const or non-const) + using element_type = T; + // Iterator Value type without type qualifier. + using value_type = remove_cv_t; + // Storage type (const or non-const) + using storage_type = conditional_t<(is_const_v), subbyte_storage_type_t const, subbyte_storage_type_t>; + + static_assert(sizeof_bits_v % 8 == 0, "Storage type is not supported"); + + static_assert(sizeof_bits_v <= sizeof_bits_v, + "Size of Element must not be greater than Storage."); + +private: + + // Bitmask for covering one item + static constexpr storage_type BitMask = storage_type(storage_type(-1) >> (sizeof_bits_v - sizeof_bits_v)); + // Flag for fast branching on straddled elements + static constexpr bool is_storage_unaligned = ((sizeof_bits_v % sizeof_bits_v) != 0); + + friend struct subbyte_iterator; + + // Pointer to storage element + storage_type* ptr_ = nullptr; + + // Bit index of value_type starting position within storage_type element. + // RI: 0 <= idx_ < sizeof_bit + uint8_t idx_ = 0; + + // Ctor + template + CUTE_HOST_DEVICE constexpr + subbyte_reference(PointerType* ptr, uint8_t idx = 0) : ptr_(reinterpret_cast(ptr)), idx_(idx) {} + +public: + + // Copy Ctor + CUTE_HOST_DEVICE constexpr + subbyte_reference(subbyte_reference const& other) { + *this = other.get(); + } + + CUTE_HOST_DEVICE constexpr + subbyte_reference(subbyte_reference const& other) { + *this = other.get(); + } + + // Copy Assignment + CUTE_HOST_DEVICE constexpr + subbyte_reference& operator=(subbyte_reference const& other) { + return *this = other.get(); + } + + CUTE_HOST_DEVICE constexpr + subbyte_reference& operator=(subbyte_reference const& other) { + return *this = other.get(); + } + + // Assignment + template + CUTE_HOST_DEVICE constexpr + enable_if_t, subbyte_reference&> operator=(value_type x) + { + static_assert(is_same_v, "Do not specify template arguments!"); + storage_type item = (reinterpret_cast(x) & BitMask); + + // Update the current storage element + storage_type bit_mask_0 = storage_type(BitMask << idx_); + ptr_[0] = storage_type((ptr_[0] & ~bit_mask_0) | (item << idx_)); + + // If value_type is unaligned with storage_type (static) and this is a straddled value (dynamic) + if (is_storage_unaligned && idx_ + sizeof_bits_v > sizeof_bits_v) { + uint8_t straddle_bits = uint8_t(sizeof_bits_v - idx_); + storage_type bit_mask_1 = storage_type(BitMask >> straddle_bits); + // Update the next storage element + ptr_[1] = storage_type((ptr_[1] & ~bit_mask_1) | (item >> straddle_bits)); + } + + return *this; + } + + // Comparison of referenced values + CUTE_HOST_DEVICE constexpr friend + bool operator==(subbyte_reference const& x, subbyte_reference const& y) { return x.get() == y.get(); } + CUTE_HOST_DEVICE constexpr friend + bool operator!=(subbyte_reference const& x, subbyte_reference const& y) { return x.get() != y.get(); } + CUTE_HOST_DEVICE constexpr friend + bool operator< (subbyte_reference const& x, subbyte_reference const& y) { return x.get() < y.get(); } + CUTE_HOST_DEVICE constexpr friend + bool operator> (subbyte_reference const& x, subbyte_reference const& y) { return x.get() > y.get(); } + CUTE_HOST_DEVICE constexpr friend + bool operator<=(subbyte_reference const& x, subbyte_reference const& y) { return x.get() <= y.get(); } + CUTE_HOST_DEVICE constexpr friend + bool operator>=(subbyte_reference const& x, subbyte_reference const& y) { return x.get() >= y.get(); } + + // Value + CUTE_HOST_DEVICE + value_type get() const + { + if constexpr (is_same_v) { // Extract to bool -- potentially faster impl + return bool((*ptr_) & (BitMask << idx_)); + } else { // Extract to value_type + // Extract from the current storage element + auto item = storage_type((ptr_[0] >> idx_) & BitMask); + + // If value_type is unaligned with storage_type (static) and this is a straddled value (dynamic) + if (is_storage_unaligned && idx_ + sizeof_bits_v > sizeof_bits_v) { + uint8_t straddle_bits = uint8_t(sizeof_bits_v - idx_); + storage_type bit_mask_1 = storage_type(BitMask >> straddle_bits); + // Extract from the next storage element + item |= storage_type((ptr_[1] & bit_mask_1) << straddle_bits); + } + + return reinterpret_cast(item); + } + } + + // Extract to type value_type + CUTE_HOST_DEVICE constexpr + operator value_type() const { + return get(); + } + + // Address + CUTE_HOST_DEVICE + subbyte_iterator operator&() const { + return {ptr_, idx_}; + } +}; + +template +CUTE_HOST_DEVICE +void +print(subbyte_reference ref) { + cute::print(ref.get()); +} + +template +CUTE_HOST_DEVICE +void +pretty_print(subbyte_reference ref) { + cute::pretty_print(ref.get()); +} + +// +// subbyte_iterator +// Random-access iterator over subbyte references +// +template +struct subbyte_iterator +{ + // Iterator Element type (const or non-const) + using element_type = T; + // Iterator Value type without type qualifier. + using value_type = remove_cv_t; + // Storage type (const or non-const) + using storage_type = conditional_t<(is_const_v), subbyte_storage_type_t const, subbyte_storage_type_t>; + // Reference proxy type + using reference = subbyte_reference; + + static_assert(sizeof_bits_v % 8 == 0, "Storage type is not supported"); + + static_assert(sizeof_bits_v <= sizeof_bits_v, + "Size of Element must not be greater than Storage."); + +private: + + template friend struct swizzle_ptr; + template friend CUTE_HOST_DEVICE constexpr U* raw_pointer_cast(subbyte_iterator const&); + template friend CUTE_HOST_DEVICE constexpr auto recast_ptr(subbyte_iterator const&); + template friend CUTE_HOST_DEVICE void print(subbyte_iterator const&); + + // Pointer to storage element + storage_type* ptr_; + + // Bit index of value_type starting position within storage_type element. + // RI: 0 <= idx_ < sizeof_bit + uint8_t idx_; + +public: + + // Default Ctor + CUTE_HOST_DEVICE constexpr + subbyte_iterator() : ptr_{nullptr}, idx_{0} {}; + + // Ctor + template + CUTE_HOST_DEVICE constexpr + subbyte_iterator(PointerType* ptr, uint8_t idx = 0) : ptr_(reinterpret_cast(ptr)), idx_(idx) { } + + CUTE_HOST_DEVICE constexpr + reference operator*() const { + return reference(ptr_, idx_); + } + + CUTE_HOST_DEVICE constexpr + subbyte_iterator& operator+=(uint64_t k) { + k = sizeof_bits_v * k + idx_; + ptr_ += k / sizeof_bits_v; + idx_ = k % sizeof_bits_v; + return *this; + } + + CUTE_HOST_DEVICE constexpr + subbyte_iterator operator+(uint64_t k) const { + return subbyte_iterator(ptr_, idx_) += k; + } + + CUTE_HOST_DEVICE constexpr + reference operator[](uint64_t k) const { + return *(*this + k); + } + + CUTE_HOST_DEVICE constexpr + subbyte_iterator& operator++() { + idx_ += sizeof_bits_v; + if (idx_ >= sizeof_bits_v) { + ++ptr_; + idx_ -= sizeof_bits_v; + } + return *this; + } + + CUTE_HOST_DEVICE constexpr + subbyte_iterator operator++(int) { + subbyte_iterator ret(*this); + ++(*this); + return ret; + } + + CUTE_HOST_DEVICE constexpr + subbyte_iterator& operator--() { + if (idx_ >= sizeof_bits_v) { + idx_ -= sizeof_bits_v; + } else { + --ptr_; + idx_ += sizeof_bits_v - sizeof_bits_v; + } + return *this; + } + + CUTE_HOST_DEVICE constexpr + subbyte_iterator operator--(int) { + subbyte_iterator ret(*this); + --(*this); + return ret; + } + + CUTE_HOST_DEVICE constexpr friend + bool operator==(subbyte_iterator const& x, subbyte_iterator const& y) { + return x.ptr_ == y.ptr_ && x.idx_ == y.idx_; + } + CUTE_HOST_DEVICE constexpr friend + bool operator!=(subbyte_iterator const& x, subbyte_iterator const& y) { return !(x == y); } + CUTE_HOST_DEVICE constexpr friend + bool operator< (subbyte_iterator const& x, subbyte_iterator const& y) { + return x.ptr_ < y.ptr_ || (x.ptr_ == y.ptr_ && x.idx_ < y.idx_); + } + CUTE_HOST_DEVICE constexpr friend + bool operator<=(subbyte_iterator const& x, subbyte_iterator const& y) { return !(y < x); } + CUTE_HOST_DEVICE constexpr friend + bool operator> (subbyte_iterator const& x, subbyte_iterator const& y) { return (y < x); } + CUTE_HOST_DEVICE constexpr friend + bool operator>=(subbyte_iterator const& x, subbyte_iterator const& y) { return !(x < y); } +}; + +// Conversion to raw pointer with loss of subbyte index +template +CUTE_HOST_DEVICE constexpr +T* +raw_pointer_cast(subbyte_iterator const& x) { + assert(x.idx_ == 0); + return reinterpret_cast(x.ptr_); +} + +// Conversion to NewT_ with possible loss of subbyte index +template +CUTE_HOST_DEVICE constexpr +auto +recast_ptr(subbyte_iterator const& x) { + using NewT = conditional_t<(is_const_v), NewT_ const, NewT_>; + if constexpr (cute::is_subbyte_v) { // Making subbyte_iter, preserve the subbyte idx + return subbyte_iterator(x.ptr_, x.idx_); + } else { // Not subbyte, assume/assert subbyte idx 0 + return reinterpret_cast(raw_pointer_cast(x)); + } + CUTE_GCC_UNREACHABLE; +} + +// Dynamic pointers have unknown static alignment +template +CUTE_HOST_DEVICE constexpr +Int<0> +max_alignment(subbyte_iterator const& x) { + return {}; +} + +template +CUTE_HOST_DEVICE void +print(subbyte_iterator const& x) { + printf("subptr[%db](%p.%u)", int(sizeof_bits_v), x.ptr_, x.idx_); +} + +template +CUTE_HOST_DEVICE void +print(subbyte_reference const& x) { + print(x.get()); +} + +// +// array_subbyte +// Statically sized array for non-byte-aligned data types +// +template +struct array_subbyte +{ + using element_type = T; + using value_type = remove_cv_t; + using pointer = element_type*; + using const_pointer = element_type const*; + + using size_type = size_t; + using difference_type = ptrdiff_t; + + // + // References + // + using reference = subbyte_reference; + using const_reference = subbyte_reference; + + // + // Iterators + // + using iterator = subbyte_iterator; + using const_iterator = subbyte_iterator; + + // Storage type (const or non-const) + using storage_type = conditional_t<(is_const_v), subbyte_storage_type_t const, subbyte_storage_type_t>; + + static_assert(sizeof_bits_v % 8 == 0, "Storage type is not supported"); + +private: + + // Number of storage elements, ceil_div + static constexpr size_type StorageElements = (N * sizeof_bits_v + sizeof_bits_v - 1) / sizeof_bits_v; + + // Internal storage + storage_type storage[StorageElements]; + +public: + + CUTE_HOST_DEVICE constexpr + size_type size() const { + return N; + } + + CUTE_HOST_DEVICE constexpr + size_type max_size() const { + return N; + } + + CUTE_HOST_DEVICE constexpr + bool empty() const { + return !N; + } + + // Efficient clear method + CUTE_HOST_DEVICE constexpr + void clear() { + CUTE_UNROLL + for (size_type i = 0; i < StorageElements; ++i) { + storage[i] = storage_type(0); + } + } + + CUTE_HOST_DEVICE constexpr + void fill(T const& value) { + CUTE_UNROLL + for (size_type i = 0; i < N; ++i) { + at(i) = value; + } + } + + CUTE_HOST_DEVICE constexpr + reference at(size_type pos) { + return iterator(storage)[pos]; + } + + CUTE_HOST_DEVICE constexpr + const_reference at(size_type pos) const { + return const_iterator(storage)[pos]; + } + + CUTE_HOST_DEVICE constexpr + reference operator[](size_type pos) { + return at(pos); + } + + CUTE_HOST_DEVICE constexpr + const_reference operator[](size_type pos) const { + return at(pos); + } + + CUTE_HOST_DEVICE constexpr + reference front() { + return at(0); + } + + CUTE_HOST_DEVICE constexpr + const_reference front() const { + return at(0); + } + + CUTE_HOST_DEVICE constexpr + reference back() { + return at(N-1); + } + + CUTE_HOST_DEVICE constexpr + const_reference back() const { + return at(N-1); + } + + // In analogy to std::vector::data(), these functions are deleted to prevent bugs. + // Instead, prefer + // auto* data = raw_pointer_cast(my_subbyte_array.begin()); + // where the type of auto* is implementation-defined and + // with the knowledge that [data, data + my_subbyte_array.size()) may not be a valid range. + CUTE_HOST_DEVICE constexpr + pointer data() = delete; + + CUTE_HOST_DEVICE constexpr + const_pointer data() const = delete; + + CUTE_HOST_DEVICE constexpr + iterator begin() { + return iterator(storage); + } + + CUTE_HOST_DEVICE constexpr + const_iterator begin() const { + return const_iterator(storage); + } + + CUTE_HOST_DEVICE constexpr + const_iterator cbegin() const { + return begin(); + } + + CUTE_HOST_DEVICE constexpr + iterator end() { + return iterator(storage) + N; + } + + CUTE_HOST_DEVICE constexpr + const_iterator end() const { + return const_iterator(storage) + N; + } + + CUTE_HOST_DEVICE constexpr + const_iterator cend() const { + return end(); + } + + // + // Comparison operators + // + +}; + +// +// Operators +// + +template +CUTE_HOST_DEVICE constexpr +void clear(array_subbyte& a) +{ + a.clear(); +} + +template +CUTE_HOST_DEVICE constexpr +void fill(array_subbyte& a, T const& value) +{ + a.fill(value); +} + +} // namespace cute + +// +// Specialize tuple-related functionality for cute::array_subbyte +// +#include "cutlass/cutlass.h" +#if defined(__CUDACC_RTC__) +#include CUDA_STD_HEADER(tuple) +#else +#include +#endif + +namespace cute +{ + +template +CUTE_HOST_DEVICE constexpr +T& get(array_subbyte& a) +{ + static_assert(I < N, "Index out of range"); + return a[I]; +} + +template +CUTE_HOST_DEVICE constexpr +T const& get(array_subbyte const& a) +{ + static_assert(I < N, "Index out of range"); + return a[I]; +} + +template +CUTE_HOST_DEVICE constexpr +T&& get(array_subbyte&& a) +{ + static_assert(I < N, "Index out of range"); + return cute::move(a[I]); +} + +} // end namespace cute + +namespace CUTE_STL_NAMESPACE +{ + +template +struct is_reference> + : CUTE_STL_NAMESPACE::true_type +{}; + + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> +{ + using type = T; +}; + +} // end namespace CUTE_STL_NAMESPACE + +#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD +namespace std +{ + +#if defined(__CUDACC_RTC__) +template +struct tuple_size; + +template +struct tuple_element; +#endif + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> +{ + using type = T; +}; + +} // end namespace std +#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/container/bit_field.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/container/bit_field.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cecdaeed35c927084f69417324908abae19148fa --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/container/bit_field.hpp @@ -0,0 +1,133 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Portable bit field that supports byte and word straddling that can + be used in unions to bit-wise define parameters. +*/ + +#pragma once + +#include // CUTE_HOST_DEVICE +#include // uint_bit_t +#include // cute::is_same + +namespace cute +{ + +class dummy_type {}; + +template +struct bit_field +{ + static_assert(0 < NumBits && NumBits <= 64, "bit_fields with more than 64 bits are not supported."); + + // value_type: Use the smallest value type that fits NumBits + static constexpr uint32_t value_type_bits = (NumBits <= 8) ? 8 : + (NumBits <= 16) ? 16 : + (NumBits <= 32) ? 32 : 64; + using value_type = cute::uint_bit_t; + // storage_type: Use the smallest storage_type that avoids boundary crossing + static constexpr uint32_t storage_type_bits = (BitStart / 8 == (BitStart + NumBits - 1) / 8) ? 8 : + (BitStart / 16 == (BitStart + NumBits - 1) / 16) ? 16 : + (BitStart / 32 == (BitStart + NumBits - 1) / 32) ? 32 : 64; + using storage_type = cute::uint_bit_t; + + static_assert(sizeof(OtherValueType) == sizeof(value_type) || is_same::value, + "sizeof(OtherValueType) must be same as sizeof(value_type)."); + + // Number of storage values needed: ceil_div(BitStart + NumBits, storage_type_bits) + static constexpr uint32_t N = (BitStart + NumBits + storage_type_bits - 1) / storage_type_bits; + // Index of storage value for BitStart + static constexpr uint32_t idx = BitStart / storage_type_bits; + // Bit of data_[idx] for BitStart + static constexpr uint32_t bit_lo = BitStart % storage_type_bits; + // Number of bits in data_[idx] used for NumBits if straddling, else 0 + static constexpr uint32_t bit_hi = (idx + 1 < N) ? (storage_type_bits - bit_lo) : 0; + +public: + + // NumBits mask + static constexpr value_type mask = value_type(uint64_t(-1) >> (64u - NumBits)); + // NumBits mask for BitStart + static constexpr storage_type mask_lo = storage_type(mask) << bit_lo; + // NumBits mask for leftover bits in data_[idx+1] if straddling, else 0 + static constexpr storage_type mask_hi = (idx + 1 < N) ? (storage_type(mask) >> bit_hi) : 0; + + storage_type data_[N]; + + // Get value + CUTE_HOST_DEVICE constexpr + value_type get() const { + storage_type result = (data_[idx] & mask_lo) >> bit_lo; + if constexpr (bit_hi != 0) { + result |= (data_[idx+1] & mask_hi) << bit_hi; + } + return static_cast(result); + } + + // Set value + CUTE_HOST_DEVICE constexpr + void set(value_type x) { + storage_type item = static_cast(x & mask); + data_[idx] = static_cast((data_[idx] & ~mask_lo) | (item << bit_lo)); + if constexpr (bit_hi != 0) { + data_[idx+1] = static_cast((data_[idx+1] & ~mask_hi) | (item >> bit_hi)); + } + } + + // Assign value + CUTE_HOST_DEVICE constexpr + bit_field& operator=(value_type x) { + set(x); + return *this; + } + + // Cast to value + CUTE_HOST_DEVICE constexpr + operator value_type () const { + return get(); + } + + // Assign OtherValueType + CUTE_HOST_DEVICE constexpr + bit_field& operator=(OtherValueType x) { + return *this = *reinterpret_cast(&x); + } + + // Cast to OtherValueType + CUTE_HOST_DEVICE constexpr + operator OtherValueType () const { + value_type x = get(); + return *reinterpret_cast(&x); + } +}; + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/container/cuda_types.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/container/cuda_types.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5615fdefe580ecb54a75462a98179bbe81c7c881 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/container/cuda_types.hpp @@ -0,0 +1,183 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE, CUTE_GCC_UNREACHABLE +#include // cute::integral_constant + +namespace cute +{ + +// +// dim3 +// + +using dim3 = ::dim3; + +// MSVC doesn't define its C++ version macro to match +// its C++ language version. This means that when +// building with MSVC, dim3 isn't constexpr-friendly. +template +CUTE_HOST_DEVICE +#if ! defined(_MSC_VER) +constexpr +#endif +uint32_t& get(dim3& a) +{ + static_assert(I < 3, "Index out of range"); + if constexpr (I == 0) { + return a.x; + } else if constexpr (I == 1) { + return a.y; + } else if constexpr (I == 2) { + return a.z; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE +#if ! defined(_MSC_VER) +constexpr +#endif +uint32_t const& get(dim3 const& a) +{ + static_assert(I < 3, "Index out of range"); + if constexpr (I == 0) { + return a.x; + } else if constexpr (I == 1) { + return a.y; + } else if constexpr (I == 2) { + return a.z; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE +#if ! defined(_MSC_VER) +constexpr +#endif +uint32_t&& get(dim3&& a) +{ + static_assert(I < 3, "Index out of range"); + if constexpr (I == 0) { + return cute::move(a.x); + } else if constexpr (I == 1) { + return cute::move(a.y); + } else if constexpr (I == 2) { + return cute::move(a.z); + } + + CUTE_GCC_UNREACHABLE; +} + +// Specialize cute::tuple-traits for external types +template <> +struct tuple_size + : integral_constant +{}; + +template +struct tuple_element +{ + using type = uint32_t; +}; + +// +// uint3 +// + +using uint3 = ::uint3; + +template +CUTE_HOST_DEVICE constexpr +uint32_t& get(uint3& a) +{ + static_assert(I < 3, "Index out of range"); + if constexpr (I == 0) { + return a.x; + } else if constexpr (I == 1) { + return a.y; + } else if constexpr (I == 2) { + return a.z; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +uint32_t const& get(uint3 const& a) +{ + static_assert(I < 3, "Index out of range"); + if constexpr (I == 0) { + return a.x; + } else if constexpr (I == 1) { + return a.y; + } else if constexpr (I == 2) { + return a.z; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +uint32_t&& get(uint3&& a) +{ + static_assert(I < 3, "Index out of range"); + if constexpr (I == 0) { + return cute::move(a.x); + } else if constexpr (I == 1) { + return cute::move(a.y); + } else if constexpr (I == 2) { + return cute::move(a.z); + } + + CUTE_GCC_UNREACHABLE; +} + +// Specialize cute::tuple-traits for external types +template <> +struct tuple_size + : integral_constant +{}; + +template +struct tuple_element +{ + using type = uint32_t; +}; + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/container/tuple.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/container/tuple.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e3dd6d2742033588c101663c27319226f4b8b19e --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/container/tuple.hpp @@ -0,0 +1,723 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include +#include // cute::true_type, cute::false_type +#include + +#include +#include +//#include // Advanced optimizations + +// cute::tuple is like std::tuple, with differences: +// +// 1. It works on both host and device. +// 2. Its template arguments must be semiregular types. +// 3. It is always a standard-layout type if all of its template arguments are standard-layout types. +// 4. It is always an empty type if all of its template arguments are empty types. +// +// Semiregular types are default constructible and copyable. +// They include "value types" like int or float, +// but do _not_ include references like int& or float&. +// (See std::tie for an example of a tuple of references.) +// +// Standard-layout types preserve ABI across host-device boundaries. They are safe to use as device kernel parameters. +// The standard-layout requirement prevents a more common EBO-based implemented of cute::tuple. +// +// The cute::tuple is also simplified over the implementations in std::, cuda::std::, and thrust:: by ignoring much of +// the conversion SFINAE, special overloading, and avoiding cvref template types. +// +// Over standard-conforming tuple implementations, this appears to accelerate compilation times by over 3x. + +namespace cute +{ + +template +struct tuple; + +namespace eso +{ + +// ESO stands for "empty structure optimization." +// We use this technique to ensure that cute::tuple doesn't waste space +// storing template arguments that have no data (like integral_constant). +// Empty types in the template argument list are not even constructed, +// and do not have unique element addresses. Calling `get` +// constructs and returns an instance of an empty type on demand. + +template +struct ESO; + +template +static constexpr bool is_first_empty_v = cute::is_empty::value; +template +static constexpr bool is_rest_empty_v = (cute::is_empty::value && ...); + +template +using ESO_t = ESO, is_rest_empty_v, T...>; + +// Empty First and Empty Rest... +template +struct ESO { + CUTE_HOST_DEVICE constexpr + ESO() {} + + CUTE_HOST_DEVICE constexpr + ESO(First const&, Rest const&...) {} +}; + +// NonEmpty First and Empty Rest... +template +struct ESO { + CUTE_HOST_DEVICE constexpr + ESO() : first_{} {} + + CUTE_HOST_DEVICE constexpr + ESO(First const& first, Rest const&...) : first_{first} {} + + First first_; +}; + +// Empty First and NonEmpty Rest... +template +struct ESO { + CUTE_HOST_DEVICE constexpr + ESO() : rest_{} {} + + CUTE_HOST_DEVICE constexpr + ESO(First const&, Rest const&... rest) : rest_{rest...} {} + + ESO_t rest_; +}; + +// NonEmpty T and NonEmpty Rest... +template +struct ESO { + CUTE_HOST_DEVICE constexpr + ESO() : first_{}, rest_{} {} + + CUTE_HOST_DEVICE constexpr + ESO(First const& first, Rest const&... rest) : first_{first}, rest_{rest...} {} + + First first_; + ESO_t rest_; +}; + +// Get Nth value from ESO +template +CUTE_HOST_DEVICE constexpr +R +getr(S&& s) noexcept +{ + if constexpr (N == 0) { + return static_cast(s).first_; + } else { + return getr(static_cast(s).rest_); + } + CUTE_GCC_UNREACHABLE; +} + +// Compilers disagree on decltype(auto), so these implementations avoid it at cost +template +CUTE_HOST_DEVICE constexpr +cute::conditional_t>>::value, + cute::tuple_element_t>, + cute::tuple_element_t> const&> +getv_cr(ESO const& s) noexcept +{ + if constexpr (cute::is_empty>>::value) { + return {}; + } else { + return getr> const&, N>(s); + } + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +cute::conditional_t>>::value, + cute::tuple_element_t>, + cute::tuple_element_t> &> +getv_r(ESO& s) noexcept +{ + if constexpr (cute::is_empty>>::value) { + return {}; + } else { + return getr> &, N>(s); + } + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +cute::conditional_t>>::value, + cute::tuple_element_t>, + cute::tuple_element_t> &&> +getv_rr(ESO&& s) noexcept +{ + if constexpr (cute::is_empty>>::value) { + return {}; + } else { + return getr> &&, N>(static_cast&&>(s)); + } + CUTE_GCC_UNREACHABLE; +} + +} // end namespace eso + +template +struct tuple : eso::ESO_t +{ + CUTE_HOST_DEVICE constexpr + tuple() {} + + CUTE_HOST_DEVICE constexpr + tuple(T const&... t) : eso::ESO_t(t...) {} +}; + +template <> +struct tuple<> {}; + +// +// make_tuple (value-based implementation) +// + +template +CUTE_HOST_DEVICE constexpr +tuple +make_tuple(T const&... t) +{ + return {t...}; +} + +// Returns the element in the ith position of the tuple +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +get(tuple const& t) noexcept +{ + static_assert(I < sizeof...(T), "Index out of range"); + return eso::getv_cr(t); +} + +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +get(tuple& t) noexcept +{ + static_assert(I < sizeof...(T), "Index out of range"); + return eso::getv_r(t); +} + +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +get(tuple&& t) noexcept +{ + static_assert(I < sizeof...(T), "Index out of range"); + return eso::getv_rr(static_cast&&>(t)); +} + +// Returns the first position of type X (as a static integer) in the tuple +// type's argument list. +template +CUTE_HOST_DEVICE constexpr +auto +find(tuple const&) noexcept +{ + return cute::C...>>{}; +} + +// +// Custom is_tuple trait simply checks the existence of tuple_size +// and assumes get(.), tuple_element +// +namespace detail { + +template +auto has_tuple_size( T*) -> bool_constant<(0 <= tuple_size::value)>; +auto has_tuple_size(...) -> false_type; + +} // end namespace detail + +template +struct is_tuple : decltype(detail::has_tuple_size((T*)0)) {}; + +template +static constexpr bool is_tuple_v = cute::is_tuple::value; + +// +// tuple_cat concatenates multiple cute::tuple into a single cute::tuple, +// just like std::tuple_cat for std::tuple. +// + +#if 0 +// Original implementation + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, + index_sequence, index_sequence) +{ + return cute::make_tuple(get(t0)..., get(t1)...); +} + +} // end namespace detail + +CUTE_HOST_DEVICE constexpr +tuple<> +tuple_cat() +{ + return {}; +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +Tuple const& +tuple_cat(Tuple const& t) +{ + return t; +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1) +{ + return detail::tuple_cat(t0, t1, + make_index_sequence::value>{}, + make_index_sequence::value>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, Ts const&... ts) +{ + return cute::tuple_cat(cute::tuple_cat(t0,t1),t2,ts...); +} +#endif + +#if 1 +// Extended implementation + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, + index_sequence, index_sequence) +{ + return cute::make_tuple(get(t0)..., get(t1)...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, + index_sequence, index_sequence, index_sequence) +{ + return cute::make_tuple(get(t0)..., get(t1)..., get(t2)...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, + index_sequence, index_sequence, index_sequence, index_sequence) +{ + return cute::make_tuple(get(t0)..., get(t1)..., get(t2)..., get(t3)...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4, + index_sequence, index_sequence, index_sequence, index_sequence, index_sequence) +{ + return cute::make_tuple(get(t0)..., get(t1)..., get(t2)..., get(t3)..., get(t4)...); +} + +template +struct tuple_cat_static; + +template +struct tuple_cat_static, tuple> { + using type = tuple; +}; + +} // end namespace detail + +CUTE_HOST_DEVICE constexpr +tuple<> +tuple_cat() +{ + return {}; +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +Tuple const& +tuple_cat(Tuple const& t) +{ + return t; +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1) +{ + if constexpr (is_static::value && is_static::value && + is_tuple::value && is_tuple::value) { + return typename detail::tuple_cat_static::type{}; + } else { + return detail::tuple_cat(t0, t1, + make_index_sequence::value>{}, + make_index_sequence::value>{}); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2) +{ + return detail::tuple_cat(t0, t1, t2, + make_index_sequence::value>{}, + make_index_sequence::value>{}, + make_index_sequence::value>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3) +{ + return detail::tuple_cat(t0, t1, t2, t3, + make_index_sequence::value>{}, + make_index_sequence::value>{}, + make_index_sequence::value>{}, + make_index_sequence::value>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4) +{ + return detail::tuple_cat(t0, t1, t2, t3, t4, + make_index_sequence::value>{}, + make_index_sequence::value>{}, + make_index_sequence::value>{}, + make_index_sequence::value>{}, + make_index_sequence::value>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4, T5 const& t5, Ts const&... ts) +{ + return cute::tuple_cat(cute::tuple_cat(t0,t1,t2,t3,t4), cute::tuple_cat(t5, ts...)); +} +#endif + +#if 0 +// Outer-Inner indexing trick to concat all tuples at once + +namespace detail { + +template +struct tuple_cat_helper +{ + static constexpr cute::array ns = {Ns...}; + + static constexpr size_t total_size() { + size_t sum = 0; + for (size_t n : ns) sum += n; + return sum; + } + static constexpr size_t total_size_ = total_size(); + + static constexpr auto values() { + cute::array outer_inner = {}; + + size_t idx = 0; + for (size_t i = 0; i < ns.size(); ++i) { + for (size_t j = 0; j < ns[i]; ++j, ++idx) { + outer_inner[idx][0] = i; + outer_inner[idx][1] = j; + } + } + return outer_inner; + } + static constexpr auto outer_inner_ = values(); + + using total_sequence = make_index_sequence; +}; + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(Tuple const& t, index_sequence) +{ + return cute::make_tuple(get(get(t))...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, + index_sequence, index_sequence) +{ + return cute::make_tuple(get(t0)..., get(t1)...); +} + +} // end namespace detail + +CUTE_HOST_DEVICE constexpr +tuple<> +tuple_cat() +{ + return {}; +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +Tuple const& +tuple_cat(Tuple const& t) +{ + return t; +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1) +{ + return detail::tuple_cat(t0, t1, + make_index_sequence::value>{}, + make_index_sequence::value>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(Tuples const&... ts) +{ + using Helper = detail::tuple_cat_helper::value...>; + return detail::tuple_cat(cute::make_tuple(ts...), typename Helper::total_sequence{}); +} +#endif + +// +// Equality operators +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +equal_impl(TupleA const& a, TupleB const& b, index_sequence) +{ + return (cute::true_type{} && ... && (get(a) == get(b))); +} + +} // end namespace detail + +template ::value && is_tuple::value)> +CUTE_HOST_DEVICE constexpr +auto +operator==(TupleT const& t, TupleU const& u) +{ + if constexpr (tuple_size::value == tuple_size::value) { + return detail::equal_impl(t, u, make_index_sequence::value>{}); + } else { + return cute::false_type{}; + } + + CUTE_GCC_UNREACHABLE; +} + +template ::value ^ is_tuple::value)> +CUTE_HOST_DEVICE constexpr +auto +operator==(TupleT const& t, TupleU const& u) +{ + return cute::false_type{}; +} + +template ::value && is_tuple::value)> +CUTE_HOST_DEVICE constexpr +auto +operator!=(TupleT const& t, TupleU const& u) +{ + return !(t == u); +} + +template ::value ^ is_tuple::value)> +CUTE_HOST_DEVICE constexpr +auto +operator!=(TupleT const& t, TupleU const& u) +{ + return cute::true_type{}; +} + +// +// Comparison operators +// + +// +// There are many ways to compare tuple of elements and because CuTe is built +// on parameterizing layouts of coordinates, some comparisons are appropriate +// only in certain cases. +// -- lexicographical comparison [reverse, reflected, revref] +// -- colexicographical comparison [reverse, reflected, revref] +// -- element-wise comparison [any,all] +// This can be very confusing. To avoid errors in selecting the appropriate +// comparison, op<|op<=|op>|op>= are *not* implemented for cute::tuple. +// +// That said, see int_tuple for more explicitly named common comparison ops. +// + +// +// Display utilities +// + +namespace detail { + +template +CUTE_HOST_DEVICE void print_tuple(Tuple const& t, index_sequence, char s = '(', char e = ')') +{ + using cute::print; + if (sizeof...(Is) == 0) { + print(s); + } else { + ((void(print(Is == 0 ? s : ',')), void(print(get(t)))), ...); + } + print(e); +} + +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& print_tuple_os(std::ostream& os, Tuple const& t, index_sequence, char s = '(', char e = ')') +{ + if (sizeof...(Is) == 0) { + os << s; + } else { + (void(os << (Is == 0 ? s : ',') << get(t)), ...); + } + return os << e; +} +#endif // !defined(__CUDACC_RTC__) + +} // end namespace detail + +template ::value)> +CUTE_HOST_DEVICE void print(Tuple const& t) +{ + return detail::print_tuple(t, make_index_sequence::value>{}); +} + +#if !defined(__CUDACC_RTC__) +template ::value)> +CUTE_HOST std::ostream& operator<<(std::ostream& os, Tuple const& t) +{ + return detail::print_tuple_os(os, t, make_index_sequence::value>{}); +} +#endif // !defined(__CUDACC_RTC__) + +} // end namespace cute + +namespace CUTE_STL_NAMESPACE +{ + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> +{}; + +} // end namespace CUTE_STL_NAMESPACE + +#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD +namespace std +{ + +#if defined(__CUDACC_RTC__) +template +struct tuple_size; + +template +struct tuple_element; +#endif + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> +{}; + +} // end namespace std +#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/container/type_list.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/container/type_list.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ff9498eb7162ddf5ecfb8f4826dfe142fb6fb164 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/container/type_list.hpp @@ -0,0 +1,125 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE, CUTE_STL_NAMESPACE +#include + +namespace cute +{ + +template +struct type_list {}; + +// get for type_list +// Get an instance of the Ith type in the pack T... +// Requires tuple_element_t> to have std::is_default_constructible +template +CUTE_HOST_DEVICE constexpr +CUTE_STL_NAMESPACE::tuple_element_t> +get(type_list const&) noexcept { + return {}; +} + +// Find the index of the first true in the pack B... +template +struct find_true { + CUTE_HOST_DEVICE static constexpr size_t find() { + size_t i = 0; + (void) ((B ? true : (++i, false)) || ...); + return i; + } + static constexpr size_t value = find(); +}; + +template +static constexpr size_t find_true_v = find_true::value; + +// find for type_list +// Finds the first position of type X (as a static integer) in the T... pack +template +CUTE_HOST_DEVICE constexpr +CUTE_STL_NAMESPACE::integral_constant...>> +find(type_list const&) noexcept { + return {}; +} + +} // end namespace cute + +// +// Specialize tuple-related functionality for cute::type_list +// +#include "cutlass/cutlass.h" +#if defined(__CUDACC_RTC__) +#include CUDA_STD_HEADER(tuple) +#else +#include +#endif + +namespace CUTE_STL_NAMESPACE +{ + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> +{}; + +} // end namespace std + +#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD +namespace std +{ + +#if defined(__CUDACC_RTC__) +template +struct tuple_size; + +template +struct tuple_element; +#endif + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> +{}; + +} // end namespace std +#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/int_tuple.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/int_tuple.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f9d70004450a138e286cb9a12390996a01c3b035 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/int_tuple.hpp @@ -0,0 +1,917 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::array +#include // cute::is_tuple +#include // cute::Int +#include // cute::seq +#include // cute::transform + +/** IntTuple is an integer or a tuple of IntTuples. + * This file holds utilities for working with IntTuples, + * but does not hold a concrete concept or class of IntTuple. + */ + +namespace cute +{ + +// Implementation of get<0>(Integral). +// Even though is_tuple is false and tuple_size doesn't compile, +// CuTe defines rank(Integral) as 1, so it's useful for get<0>(Integral) to return its input +template >::value)> +CUTE_HOST_DEVICE constexpr +decltype(auto) +get(T&& t) noexcept +{ + static_assert(I == 0, "Index out of range"); + return static_cast(t); +} + +// Custom recursive get for anything that implements get(.) (for a single integer I). +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +get(T&& t) noexcept +{ + return get(get(static_cast(t))); +} + +// +// rank +// + +template +CUTE_HOST_DEVICE constexpr +auto +rank(IntTuple const& t) +{ + if constexpr (sizeof...(Is) == 0) { + if constexpr (is_tuple::value) { + return Int::value>{}; + } else { + return Int<1>{}; + } + } else { + return rank(get(t)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +using rank_t = decltype(rank(declval())); + +template +static constexpr auto rank_v = rank_t::value; + +// +// shape +// + +template +CUTE_HOST_DEVICE constexpr +auto +shape(IntTuple const& s) +{ + if constexpr (is_tuple::value) { + return transform(s, [](auto const& a) { return shape(a); }); + } else { + return s; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +shape(IntTuple const& s) +{ + if constexpr (is_tuple::value) { + return shape(get(s)); + } else { + return get(shape(s)); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// max +// + +template +CUTE_HOST_DEVICE constexpr +auto +max(T0 const& t0, Ts const&... ts) +{ + if constexpr (is_tuple::value) { + return cute::max(cute::apply(t0, [](auto const&... a){ return cute::max(a...); }), ts...); + } else if constexpr (sizeof...(Ts) == 0) { + return t0; + } else { + return cute::max(t0, cute::max(ts...)); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// min +// + +template +CUTE_HOST_DEVICE constexpr +auto +min(T0 const& t0, Ts const&... ts) +{ + if constexpr (is_tuple::value) { + return cute::min(cute::apply(t0, [](auto const&... a){ return cute::min(a...); }), ts...); + } else if constexpr (sizeof...(Ts) == 0) { + return t0; + } else { + return cute::min(t0, cute::min(ts...)); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// gcd +// + +template +CUTE_HOST_DEVICE constexpr +auto +gcd(T0 const& t0, Ts const&... ts) +{ + if constexpr (is_tuple::value) { + return cute::gcd(cute::apply(t0, [](auto const&... a){ return cute::gcd(a...); }), ts...); + } else if constexpr (sizeof...(Ts) == 0) { + return t0; + } else { + return cute::gcd(t0, cute::gcd(ts...)); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// depth +// + +template +CUTE_HOST_DEVICE constexpr +auto +depth(IntTuple const& t) +{ + if constexpr (sizeof...(Is) == 0) { + if constexpr (is_tuple::value) { + return Int<1>{} + cute::apply(t, [](auto const&... v){ return cute::max(depth(v)...); }); + } else { + return Int<0>{}; + } + } else { + return depth(get(t)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +using depth_t = decltype(depth(declval())); + +template +static constexpr auto depth_v = depth_t::value; + +// +// product +// + +// Implementation of product as a function object +struct Product +{ + template + CUTE_HOST_DEVICE constexpr + auto + operator()(IntTuple const& a) const + { + if constexpr (is_tuple::value) { + if constexpr (tuple_size::value == 0) { + return Int<1>{}; + } else { + return cute::transform_apply(a, Product{}, multiplies_unary_lfold{}); + } + } else if constexpr (cute::is_integral::value) { + return a; + } + + CUTE_GCC_UNREACHABLE; + } +}; +// Callable product function object +CUTE_INLINE_CONSTANT Product product; + +// Return a rank(t) tuple @a result such that get(@a result) = product(get(@a t)) +template +CUTE_HOST_DEVICE constexpr +auto +product_each(Tuple const& t) +{ + return transform(wrap(t), product); +} + +// Take the product of Tuple at the leaves of TupleG +template +CUTE_HOST_DEVICE constexpr +auto +product_like(Tuple const& tuple, TupleG const& guide) +{ + return transform_leaf(guide, tuple, [](auto const& g, auto const& t) { return product(t); }); +} + +// Return the product of elements in a mode +template +CUTE_HOST_DEVICE constexpr +auto +size(IntTuple const& a) +{ + if constexpr (sizeof...(Is) == 0) { + return product(a); + } else { + return size(get(a)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +static constexpr auto size_v = decltype(size(declval()))::value; + +// +// sum +// + +template +CUTE_HOST_DEVICE constexpr +auto +sum(IntTuple const& a) +{ + if constexpr (is_tuple::value) { + return cute::apply(a, [](auto const&... v){ return (Int<0>{} + ... + sum(v)); }); + } else { + return a; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// inner_product +// + +template +CUTE_HOST_DEVICE constexpr +auto +inner_product(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Mismatched ranks"); + return transform_apply(a, b, [](auto const& x, auto const& y) { return inner_product(x,y); }, + [](auto const&... v) { return (Int<0>{} + ... + v); }); + } else { + return a * b; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// ceil_div +// + +template +CUTE_HOST_DEVICE constexpr +auto +ceil_div(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value) { + if constexpr (is_tuple::value) { // tuple tuple + static_assert(tuple_size::value >= tuple_size::value, "Mismatched ranks"); + constexpr int R = tuple_size::value; // Missing ranks in TupleB are implicitly 1 + return transform(a, append(b,Int<1>{}), [](auto const& x, auto const& y) { return ceil_div(x,y); }); + } else { // tuple int + auto [result, rest] = fold(a, cute::make_tuple(cute::make_tuple(), b), + [] (auto const& init, auto const& ai) { + return cute::make_tuple(append(get<0>(init), ceil_div(ai, get<1>(init))), ceil_div(get<1>(init), ai)); + }); + return result; + } + } else + if constexpr (is_tuple::value) { // int tuple + return ceil_div(a, product(b)); + } else { + return (a + b - Int<1>{}) / b; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// round_up +// Round @a a up to the nearest multiple of @a b. +// + +template +CUTE_HOST_DEVICE constexpr +auto +round_up(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + static_assert(tuple_size::value >= tuple_size::value, "Mismatched ranks"); + constexpr int R = tuple_size::value; // Missing ranks in TupleB are implicitly 1 + return transform(a, append(b,Int<1>{}), [](auto const& x, auto const& y) { return round_up(x,y); }); + } else { + return ((a + b - Int<1>{}) / b) * b; + } + + CUTE_GCC_UNREACHABLE; +} + +/** Division for Shapes + * Case Tuple Tuple: + * Perform shape_div element-wise + * Case Tuple Int: + * Fold the division of b across each element of a + * Example: shape_div((4,5,6),40) -> shape_div((1,5,6),10) -> shape_div((1,1,6),2) -> (1,1,3) + * Case Int Tuple: + * Return shape_div(a, product(b)) + * Case Int Int: + * Enforce the divisibility condition a % b == 0 || b % a == 0 when possible + * Return ceil_div(a, b) + */ +template +CUTE_HOST_DEVICE constexpr +auto +shape_div(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value) { + if constexpr (is_tuple::value) { // tuple tuple + static_assert(tuple_size::value == tuple_size::value, "Mismatched ranks"); + return transform(a, b, [](auto const& x, auto const& y) { return shape_div(x,y); }); + } else { // tuple int + auto [result, rest] = fold(a, cute::make_tuple(cute::make_tuple(), b), + [] (auto const& init, auto const& ai) { + return cute::make_tuple(append(get<0>(init), shape_div(ai, get<1>(init))), shape_div(get<1>(init), ai)); + }); + return result; + } + } else + if constexpr (is_tuple::value) { // int tuple + return shape_div(a, product(b)); + } else { + // Strong divisibility condition + //static_assert((IntTupleA::value % IntTupleB::value == 0) or (IntTupleB::value % IntTupleA::value == 0), "Divisibility Condition"); + + // Weak divisibility condition + if constexpr (is_static::value and is_static::value) { + static_assert(((IntTupleA::value % IntTupleB::value) == 0) or ((IntTupleB::value % IntTupleA::value) == 0), "Divisibility Condition"); + } else { + // DEBUG assert can cause extra registers and inappropriate compile-time/run-time failure + //assert((((a % b) == 0) or ((a % b) == 0)) && "Divisibility Condition"); + } + + return (a + b - Int<1>{}) / b; + } + + CUTE_GCC_UNREACHABLE; +} + +/** Return a tuple the same profile as A scaled by corresponding elements in B + */ +template +CUTE_HOST_DEVICE constexpr +auto +elem_scale(A const& a, B const& b) +{ + if constexpr (is_tuple::value) { + return transform(a, b, [](auto const& x, auto const& y) { return elem_scale(x,y); }); + } else { + return a * product(b); + } + + CUTE_GCC_UNREACHABLE; +} + +/** Test if two IntTuple have the same profile (hierarchical rank division) + */ +template +CUTE_HOST_DEVICE constexpr +auto +congruent(IntTupleA const& a, IntTupleB const& b) +{ + return bool_constant::value>{}; +} + +template +using is_congruent = decltype(congruent(declval(), declval())); + +/** Test if two IntTuple have the similar profiles up to Shape A (hierarchical rank division) + * weakly_congruent is a partial order on A and B: A <= B + */ +template +CUTE_HOST_DEVICE constexpr +auto +weakly_congruent(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + if constexpr (tuple_size::value != tuple_size::value) { + return false_type{}; + } else { + return transform_apply(a, b, [](auto const& x, auto const& y) { return weakly_congruent(x,y); }, + [](auto const&... z) { return (true_type{} && ... && z); }); + } + } else if constexpr (is_integral::value) { + return true_type{}; + } else if constexpr (is_integral::value) { + return false_type{}; + } else { + return weakly_congruent(shape(a), shape(b)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +using is_weakly_congruent = decltype(weakly_congruent(declval(), declval())); + +/** Test if Shape A is compatible with Shape B: + * the size of A and B are the same, and + * any coordinate into A can also be used as a coordinate into B + * Equivalently, the size of Shape B is the same as Shape A at each terminal of Shape A. + * compatible is a partial order on A and B: A <= B + */ +template +CUTE_HOST_DEVICE constexpr +auto +compatible(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + if constexpr (tuple_size::value != tuple_size::value) { + return false_type{}; + } else { + return transform_apply(a, b, [](auto const& x, auto const& y) { return compatible(x,y); }, + [](auto const&... z) { return (true_type{} && ... && z); }); + } + } else if constexpr (is_integral::value) { + return a == size(b); + } else if constexpr (is_integral::value) { + return false_type{}; + } else { + return compatible(shape(a), shape(b)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +using is_compatible = decltype(compatible(declval(), declval())); + +/** Test if Shape A is evenly divided by Tiler B + * @returns Static or dynamic boolean + * @post if result is true_type, then + * size(a) == logical_divide(make_layout(shape(a)),b) will always compile + * and result in true_type. + */ +template +CUTE_HOST_DEVICE constexpr +auto +evenly_divides(Shape const& a, Tiler const& b) +{ + if constexpr (is_tuple::value) { + if constexpr (rank_v > rank_v) { + return false_type{}; + } else { + return transform_apply(b, a, [](auto const& x, auto const& y) { return evenly_divides(y,x); }, + [](auto const&... z) { return (true_type{} && ... && z); }); + } + } else { + return size(a) == size(b) * size(ceil_div(shape(a), b)); + } + + CUTE_GCC_UNREACHABLE; +} + +/** Replace the elements of Tuple B that are paired with an Int<0> with an Int<1> + */ +template +CUTE_HOST_DEVICE constexpr +auto +filter_zeros(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value) { + return transform(a, b, [](auto const& x, auto const& y) { return filter_zeros(x,y); }); + } else if constexpr (is_constant<0, IntTupleA>::value) { + return repeat_like(b, Int<1>{}); + } else { + return b; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +filter_zeros(Tuple const& t) +{ + return filter_zeros(t, t); +} + +// +// Static sorting utilities in detail:: +// + +namespace detail { + +// Some compilers fail to constexpr evaluate quick_sort +// template +// constexpr cute::array quick_sort(cute::array a, int lo = 0, int hi = N-1) { +// if (hi <= lo) return; +// int p = lo; +// for (int i = lo; i < hi; ++i) { +// if (a[i] < a[hi]) { +// T tmp = a[p]; a[p] = a[i]; a[i] = tmp; +// ++p; +// } +// } +// T tmp = a[p]; a[p] = a[hi]; a[hi] = tmp; +// a = quick_sort(a, lo, p-1); +// a = quick_sort(a, p+1, hi); +// return a; +// } + +template +constexpr cute::array exchange_sort(cute::array a) { + for (size_t i = 0; i < N; ++i) { + for (size_t j = i+1; j < N; ++j) { + if (a[j] < a[i]) { + T tmp = a[j]; a[j] = a[i]; a[i] = tmp; + } + } + } + return a; +} + +template >> +struct Sort : Sort, to_seq_t> {}; + +template +struct Sort, seq> { + static_assert(sizeof...(Vs) == sizeof...(Is)); + static constexpr cute::array orig_array = {Vs...}; + static constexpr cute::array sort_array = exchange_sort(orig_array); + using type = seq; +}; + +struct kvpair { + int key, val; + constexpr bool operator<(kvpair const& o) const { return key < o.key; }; +}; + +template >> +struct SortByKey : SortByKey, to_seq_t, to_seq_t> {}; + +template +struct SortByKey, seq, seq> { + static_assert(sizeof...(Ks) == sizeof...(Vs)); + static_assert(sizeof...(Ks) == sizeof...(Is)); + static constexpr cute::array orig_array = {kvpair{Ks,Vs}...}; + static constexpr cute::array sort_array = exchange_sort(orig_array); + using key_type = seq; + using val_type = seq; +}; + +} // end namespace detail + +// +// Converters and constructors with arrays and params +// + +/** Make an IntTuple of rank N from an Indexable array. + * Access elements up to a dynamic index n, then use init (requires compatible types) + * Consider cute::take if all indexing is known to be valid + * \code + * std::vector a = {6,3,4}; + * auto tup = make_int_tuple<5>(a, a.size(), 0) // (6,3,4,0,0) + * \endcode + */ +template +CUTE_HOST_DEVICE constexpr +auto +make_int_tuple(Indexable const& t, int n, T const& init) +{ + static_assert(N > 0); + if constexpr (N == 1) { + return 0 < n ? t[0] : init; + } else { + return transform(make_seq{}, [&](auto i) { return i < n ? t[i] : init; }); + } + + CUTE_GCC_UNREACHABLE; +} + +/** Fill the dynamic values of a Tuple with values from another Tuple + * \code + * auto params = make_tuple(6,3,4); + * cute::tuple, cute::tuple>, int, Int<2>> result; + * fill_int_tuple_from(result, params); // (_1,(6,3,_3),4,_2) + * \endcode + */ +template +CUTE_HOST_DEVICE constexpr +auto +fill_int_tuple_from(Tuple& result, TupleV const& vals) +{ + return fold(result, vals, [](auto const& init, auto&& r) { + if constexpr (is_static>::value) { // Skip static elements of result + return init; + } else if constexpr (is_tuple>::value) { // Recurse into tuples + return fill_int_tuple_from(r, init); + } else { // Assign and consume arg + static_assert(tuple_size>::value > 0, "Not enough values to fill with!"); + r = get<0>(init); + return remove<0>(init); + } + + CUTE_GCC_UNREACHABLE; + }); +} + +/** Make a "Tuple" by filling in the dynamic values in order from the arguments + * \code + * using result_t = cute::tuple, cute::tuple>, int, Int<2>>; + * auto result = make_int_tuple_from(6,3,4); // (_1,(6,3,_3),4,_2) + * \endcode + */ +template +CUTE_HOST_DEVICE constexpr +Tuple +make_int_tuple_from(Ts const&... ts) +{ + Tuple result = Tuple{}; + fill_int_tuple_from(result, cute::make_tuple(ts...)); + return result; +} + +/** Convert a tuple to a flat homogeneous array of type T + * \code + * auto tup = cute::make_tuple(Int<1>{}, cute::make_tuple(6,3,Int<3>{}),4,Int<2>{}); + * cute::array result = to_array(tup); // [1,6,3,3,4,2] + * \endcode + */ +template +CUTE_HOST_DEVICE constexpr +auto +to_array(IntTuple const& t) +{ + auto flat_t = flatten_to_tuple(t); + constexpr int N = tuple_size::value; + cute::array result; + for_each(make_seq{}, [&] (auto i) { result[i] = get(flat_t); }); + return result; +} + +// +// Comparison operators +// + +// +// There are many ways to compare tuple of elements and because CuTe is built +// on parameterizing layouts of coordinates, some comparisons are appropriate +// only in certain cases. +// -- lexicographical comparison [reverse, reflected, revref] : Correct for coords in RowMajor Layout +// -- colexicographical comparison [reverse, reflected, revref] : Correct for coords in ColMajor Layout +// -- element-wise comparison [any,all] : +// This can be very confusing. To avoid errors in selecting the appropriate +// comparison, op<|op<=|op>|op>= are *not* implemented for cute::tuple. +// +// When actually desiring to order coordinates, the user should map them to +// their indices within the Layout they came from: +// e.g. layoutX(coordA) < layoutX(coordB) +// That said, we implement the three most common ways to compare tuples below. +// These are implemented with slighly more explicit names than op<. +// + +template +CUTE_HOST_DEVICE constexpr +auto +lex_less(IntTupleA const& a, IntTupleB const& b); + +template +CUTE_HOST_DEVICE constexpr +auto +colex_less(IntTupleA const& a, IntTupleB const& b); + +template +CUTE_HOST_DEVICE constexpr +auto +elem_less(IntTupleA const& a, IntTupleB const& b); + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +lex_less_impl(TupleA const& a, TupleB const& b) +{ + if constexpr (I == tuple_size::value) { + return cute::false_type{}; // Terminal: TupleB is exhausted + } else if constexpr (I == tuple_size::value) { + return cute::true_type{}; // Terminal: TupleA is exhausted, TupleB is not exhausted + } else { + return lex_less(get(a), get(b)) || (get(a) == get(b) && lex_less_impl(a,b)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +colex_less_impl(TupleA const& a, TupleB const& b) +{ + if constexpr (I == tuple_size::value) { + return cute::false_type{}; // Terminal: TupleB is exhausted + } else if constexpr (I == tuple_size::value) { + return cute::true_type{}; // Terminal: TupleA is exhausted, TupleB is not exhausted + } else { + constexpr size_t A = tuple_size::value - 1 - I; + constexpr size_t B = tuple_size::value - 1 - I; + return colex_less(get(a), get(b)) || (get(a) == get(b) && colex_less_impl(a,b)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +elem_less_impl(TupleA const& a, TupleB const& b) +{ + if constexpr (I == tuple_size::value) { + return cute::true_type{}; // Terminal: TupleA is exhausted + } else if constexpr (I == tuple_size::value) { + return cute::false_type{}; // Terminal: TupleA is not exhausted, TupleB is exhausted + } else { + return elem_less(get(a), get(b)) && elem_less_impl(a,b); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +// Lexicographical comparison + +template +CUTE_HOST_DEVICE constexpr +auto +lex_less(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + return detail::lex_less_impl<0>(a, b); + } else { + return a < b; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +lex_leq(T const& t, U const& u) { + return !lex_less(u, t); +} + +template +CUTE_HOST_DEVICE constexpr +auto +lex_gtr(T const& t, U const& u) { + return lex_less(u, t); +} + +template +CUTE_HOST_DEVICE constexpr +auto +lex_geq(T const& t, U const& u) { + return !lex_less(t, u); +} + +// Colexicographical comparison + +template +CUTE_HOST_DEVICE constexpr +auto +colex_less(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + return detail::colex_less_impl<0>(a, b); + } else { + return a < b; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +colex_leq(T const& t, U const& u) { + return !colex_less(u, t); +} + +template +CUTE_HOST_DEVICE constexpr +auto +colex_gtr(T const& t, U const& u) { + return colex_less(u, t); +} + +template +CUTE_HOST_DEVICE constexpr +auto +colex_geq(T const& t, U const& u) { + return !colex_less(t, u); +} + +// Elementwise [all] comparison + +template +CUTE_HOST_DEVICE constexpr +auto +elem_less(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + return detail::elem_less_impl<0>(a, b); + } else { + return a < b; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +elem_leq(T const& t, U const& u) { + return !elem_less(u, t); +} + +template +CUTE_HOST_DEVICE constexpr +auto +elem_gtr(T const& t, U const& u) { + return elem_less(u, t); +} + +template +CUTE_HOST_DEVICE constexpr +auto +elem_geq(T const& t, U const& u) { + return !elem_less(t, u); +} + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/layout.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/layout.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cb161369cbb6b2d2961a126216dfdb3716d4e501 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/layout.hpp @@ -0,0 +1,1931 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include // cute::sizeof_bits + +namespace cute +{ + +// Aliases + +template +using Shape = cute::tuple; + +template +using Stride = cute::tuple; + +template +using Step = cute::tuple; + +template +using Coord = cute::tuple; + +template +using Tile = cute::tuple; + +template +CUTE_HOST_DEVICE constexpr +Shape +make_shape(Ts const&... t) { + return {t...}; +} +template +CUTE_HOST_DEVICE constexpr +Stride +make_stride(Ts const&... t) { + return {t...}; +} +template +CUTE_HOST_DEVICE constexpr +Step +make_step(Ts const&... t) { + return {t...}; +} +template +CUTE_HOST_DEVICE constexpr +Coord +make_coord(Ts const&... t) { + return {t...}; +} +template +CUTE_HOST_DEVICE constexpr +Tile +make_tile(Ts const&... t) +{ + return {t...}; +} + +// +// Layout +// + +template > +struct Layout + : private cute::tuple // EBO for static layouts +{ + // Expensive in compilation time... + //static_assert(is_congruent::value, "Shape and Stride must be congruent"); + + // NOTE: This defaults static Shapes/Strides correctly, but not dynamic + CUTE_HOST_DEVICE constexpr + Layout(Shape const& shape = {}, Stride const& stride = {}) + : cute::tuple(shape, stride) + {} + + // + // Accessors + // + + static constexpr int rank = rank_v; + + CUTE_HOST_DEVICE constexpr + decltype(auto) + layout() { + return *this; + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + layout() const { + return *this; + } + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + shape() { + return get<0,I...>(static_cast&>(*this)); + } + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + shape() const { + return get<0,I...>(static_cast const&>(*this)); + } + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + stride() { + return get<1,I...>(static_cast&>(*this)); + } + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + stride() const { + return get<1,I...>(static_cast const&>(*this)); + } + + // + // Mappings + // + + // Map a logical coordinate to a linear index (Coord has no Underscore slice operators) + // OR + // Slice the layout and return the sublayout (Coord has an Underscore slice op) + template + CUTE_HOST_DEVICE constexpr + auto + operator()(Coord const& coord) const { + if constexpr (has_underscore::value) { + return slice(coord, *this); + } else { + return crd2idx(coord, shape(), stride()); + } + + CUTE_GCC_UNREACHABLE; + } + + // Convenience function for multi-dimensional coordinates + template + CUTE_HOST_DEVICE constexpr + auto + operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) const { + return operator()(make_coord(c0,c1,cs...)); + } + + // + // Compose + // + + template + CUTE_HOST_DEVICE constexpr + auto + compose(OtherLayout const& other) const { + return composition(*this, other); + } + + template + CUTE_HOST_DEVICE constexpr + auto + compose(Layouts const&... layouts) const { + return composition(*this, make_tile(layouts...)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + with_shape(OtherShape const& shape) const { + return composition(*this, make_layout(shape)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + with_shape(Shapes const&... shapes) const { + return composition(*this, make_layout(make_shape(shapes...))); + } + + // + // Tile + // + + template + CUTE_HOST_DEVICE constexpr + auto + tile(OtherLayout const& other) const { + return tiled_divide(*this, other); + } + + template + CUTE_HOST_DEVICE constexpr + auto + tile(Layouts const&... layouts) const { + return tiled_divide(*this, make_tile(layouts...)); + } + + // + // Utility + // + + // + // Index to Coordinate + // + + // NOTE: Only valid for compact layouts + + // Return the (hierarchical) ND logical coordinate corresponding to the linear index + // @post crd2idx(@a result, shape(), stride()) == idx + // @post congruent(@a result, shape()) + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + get_hier_coord(IInt const& idx) const { + return cute::idx2crd(idx, shape(), stride()); + } + + // Return the (flat) ND logical coordinate corresponding to the linear index + // @post crd2idx(@a result, shape(), stride()) == idx + // @post rank(@a result) == rank(shape()) && depth(@a result) == 1 + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + get_flat_coord(IInt const& idx) const { + return cute::crd2crd(this->get_hier_coord(idx), shape(), repeat(Int<1>{})); + } + + // Return the generalized column-major 1D logical coordinate corresponding to the linear index + // @post crd2idx(@a result, shape(), stride()) == idx + // @post is_integral::value + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + get_1d_coord(IInt const& idx) const { + return cute::crd2idx(this->get_hier_coord(idx), shape()); + } + + // + // Coordinate to Coordinate + // + +#if 0 + // Return the (hierarchical) ND logical coordinate corresponding to the linear index + // @post congruent(@a result, shape()) + template + CUTE_HOST_DEVICE constexpr + auto + crd_2_hier_coord(Coord const& crd) const { + return cute::crd2crd(crd, shape(), shape()); + } + + // Return the (flat) ND logical coordinate corresponding to the linear index + // @post rank(@a result) == rank(shape()) && depth(@a result) == 1 + template + CUTE_HOST_DEVICE constexpr + auto + crd_2_flat_coord(Coord const& crd) const { + return cute::crd2crd(crd, shape(), product_each(shape())); + } + + // Return the generalized column-major 1D logical coordinate corresponding to the linear index + // @post is_integral::value + template + CUTE_HOST_DEVICE constexpr + auto + crd_2_1d_coord(Coord const& crd) const { + //return cute::crd2crd(crd, shape(), product(shape())); + return cute::crd2idx(crd, shape()); + } +#endif +}; + +// Equality, return a static or dynamic boolean +template +CUTE_HOST_DEVICE constexpr +auto +operator==(Layout const& layoutA, Layout const& layoutB) +{ + return layoutA.shape() == layoutB.shape() && layoutA.stride() == layoutB.stride(); +} + +template +struct is_layout : false_type {}; +template +struct is_layout> : true_type {}; + +// +// Layout construction +// + +template +CUTE_HOST_DEVICE constexpr +auto +make_layout(Shape const& shape, Stride const& stride) +{ + static_assert(is_tuple::value || is_integral::value); + static_assert(is_tuple::value || is_integral::value); + return Layout(shape, stride); +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_layout(Shape const& shape) +{ + static_assert(is_tuple::value || is_integral::value); + return make_layout(shape, compact_major(shape)); +} + +// +// Convenience tags for common layouts +// + +template +CUTE_HOST_DEVICE constexpr +auto +make_layout(Shape const& shape, LayoutLeft) +{ + return make_layout(shape, compact_major(shape)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_layout(Shape const& shape, LayoutRight) +{ + return make_layout(shape, compact_major(shape)); +} + +// +// Construct a layout from multiple layouts by concatenation +// + +// One argument overload +template +CUTE_HOST_DEVICE constexpr +auto +make_layout(Layout const& layout0) +{ + return make_layout(make_shape (layout0.shape() ), + make_stride(layout0.stride())); +} + +// Two argument overload +template +CUTE_HOST_DEVICE constexpr +auto +make_layout(Layout const& layout0, + Layout const& layout1) +{ + return make_layout(make_shape (layout0.shape() , layout1.shape() ), + make_stride(layout0.stride(), layout1.stride())); +} + +// Var argument overload +template +CUTE_HOST_DEVICE constexpr +auto +make_layout(Layout const& layout0, + Layout const& layout1, + Layout const&... layouts) +{ + return make_layout(make_shape (layout0.shape() , layout1.shape() , layouts.shape()... ), + make_stride(layout0.stride(), layout1.stride(), layouts.stride()...)); +} + +// +// Advanced Layout constructions +// + +// Make a compact layout with shape @a shape and strides following the order induced by @a order. +// Dynamic values in @a order are ignored, considered large, and considered ordered from left to right. +// Example: +// make_ordered_layout(Shape<_2,_2,_2,_2>{}, Step<_0,_2,_3,_1>{}) +// -> (_2,_2,_2,_2):(_1,_4,_8,_2) +// make_ordered_layout(make_shape(2,3,4,5), make_step(Int<2>{}, 67, 42, Int<50>{})) +// -> (2,3,4,5):(_1,10,30,2) +template +CUTE_HOST_DEVICE constexpr +auto +make_ordered_layout(Shape const& shape, Order const& order) +{ + return make_layout(shape, compact_order(shape, order)); +} + +// Make a compact layout with the same shape as @a layout +// and strides following the order induced by @a layout.stride(). +// Static-0 strides in the input @a layout are preserved in the output. +// Example: +// make_layout_like(Layout, Stride<_0,_2,_4,_1>>{}) +// -> (_2,_2,_2,_2):(_0,_2,_4,_1) +// make_layout_like(make_layout(make_shape(2,3,4,5), make_stride(Int<0>{},42,Int<1>{},Int<0>{}))) +// -> (2,3,4,5):(_0,4,_1,_0) +template +CUTE_HOST_DEVICE constexpr +auto +make_layout_like(Layout const& layout) +{ + return make_layout(layout.shape(), + compact_order(filter_zeros(layout.stride(), layout.shape()), layout.stride())); +} + +// Make a compact layout with the same shape as @a layout +// and strides following the order induced by @a layout.stride(), +// except mode-0 is always stride-1 and generated column-major. +// The 0th mode is commonly used for MMA_Atoms or Copy_Atoms so this +// generates the 0th mode with LayoutLeft (preserving stride-0s) regardless of the reference layout +template +CUTE_HOST_DEVICE constexpr +auto +make_fragment_like(Layout const& layout) +{ + constexpr int R = Layout::rank; + if constexpr (R > 1 && is_static::value) { + return tiled_product(make_layout(get<0>(layout.shape()), + compact_major(filter_zeros(get<0>(layout.stride()), get<0>(layout.shape())))), + make_ordered_layout(take<1,R>(layout.shape()), take<1,R>(layout.stride()))); + } else { + return make_layout(layout.shape()); + } + + CUTE_GCC_UNREACHABLE; +} + +template ::value || is_integral::value)> +CUTE_HOST_DEVICE constexpr +auto +make_fragment_like(Shape const& shape) +{ + return make_layout(shape); +} + +// +// Make an identity layout that maps a coordinate to itself +// + +template +CUTE_HOST_DEVICE constexpr +auto +make_identity_layout(Shape const& shape) +{ + return make_layout(shape, make_basis_like(shape)); +} + +// +// Operations to manipulate Layouts like a tuple of pairs +// + +// Return the Is...th sublayout. +// For Is... = , equivalent to get(...get(get(layout))) +template +CUTE_HOST_DEVICE constexpr +auto +get(Layout const& layout) +{ + return make_layout(get(layout.shape()), + get(layout.stride())); +} + +// Return a new layout with only the modes in the range [B,E) +template +CUTE_HOST_DEVICE constexpr +auto +take(Layout const& layout) +{ + static_assert(B < E, "take: empty range error"); + static_assert(0 <= B && E <= Layout::rank, "take: range out of bounds"); + return make_layout(take(layout.shape()), + take(layout.stride())); +} + +// Return a new layout with only the modes Is... = +template +CUTE_HOST_DEVICE constexpr +auto +select(Layout const& layout) +{ + return make_layout(select(layout.shape()), + select(layout.stride())); +} + +// Return a layout with depth at most 1 +template +CUTE_HOST_DEVICE constexpr +auto +flatten(Layout const& layout) +{ + return make_layout(flatten(layout.shape()), + flatten(layout.stride())); +} + +// Return a layout whose profile is congruent to TargetProfile +// @pre Input layout is flat, flatten(@a layout) == @a layout +// @pre Input layout can be folded to profile, rank(@a layout) == rank(flatten(@a target_profile)) +// @post congruent(@a result, @a target_profile) +template +CUTE_HOST_DEVICE constexpr +auto +unflatten(Layout const& layout, TargetProfile const& target_profile) +{ + return make_layout(unflatten(layout.shape(), target_profile), + unflatten(layout.stride(), target_profile)); +} + +// +// Utilities +// + +// Return the sublayout of mode I... +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +layout(Layout const& layout) +{ + if constexpr (sizeof...(Is) == 0) { + return layout; + } else { + return get(layout); + } + + CUTE_GCC_UNREACHABLE; +} + +// Return the shape of a mode +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +shape(Layout& layout) +{ + return layout.template shape(); +} + +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +shape(Layout const& layout) +{ + return layout.template shape(); +} + +// Return the stride of a mode +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +stride(Layout& layout) +{ + return layout.template stride(); +} + +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +stride(Layout const& layout) +{ + return layout.template stride(); +} + +// Return the number of elements in a mode +template +CUTE_HOST_DEVICE constexpr +auto +size(Layout const& layout) +{ + return size(shape(layout)); +} + +// Return the number of modes +template +CUTE_HOST_DEVICE constexpr +auto +rank(Layout const& layout) +{ + return rank(shape(layout)); +} + +// Return the depth of the layout +template +CUTE_HOST_DEVICE constexpr +auto +depth(Layout const& layout) +{ + return depth(shape(layout)); +} + +// Return the coprofile of a mode as a tuple of _0s +// @post congruent(coprofile(@a layout), @a layout(i)) for any i +// @return T Tuple that is congruent with the codomain of @a a. +template +CUTE_HOST_DEVICE constexpr +auto +coprofile(Layout const& layout) +{ + return repeat_like(as_arithmetic_tuple(sum(stride(layout))), Int<0>{}); +} + +// Return the codomain shape of a mode +// @post size(coshape(@a layout)) == cosize(@a layout) +// @return C Coordinate with smallest elements such that +// elem_less(@a sub_layout(c), C) for all c < size(@a sub_layout) +// where @a sub_layout = get(layout). +template +CUTE_HOST_DEVICE constexpr +auto +coshape(Layout const& layout) +{ + auto m1_shapes = transform_leaf( shape(layout), [](auto s) { return s - Int<1>{}; }); + auto abs_strides = transform_leaf(stride(layout), abs_fn{}); + auto co_coord = as_arithmetic_tuple(inner_product(m1_shapes, abs_strides)); + return transform_leaf(co_coord, [](auto c) { return c + Int<1>{}; }); +} + +// Return the codomain size of a mode +// @return M smallest integer such that +// size(@a sub_layout(c)) < M for all c < size(@a sub_layout) +// where @a sub_layout = get(layout). +template +CUTE_HOST_DEVICE constexpr +auto +cosize(Layout const& layout) +{ + return size(coshape(layout)); +} + +template +using cosize_t = decltype(cosize(declval())); + +template +static constexpr auto cosize_v = cosize_t::value; + +// With crd2idx(coord, shape), makes sense to have crd2idx(coord, Layout) as well +template +CUTE_HOST_DEVICE constexpr +auto +crd2idx(Coord const& c, Layout const& layout) +{ + return crd2idx(c, layout.shape(), layout.stride()); +} + +// +// Slice and Dice a layout +// + +template +CUTE_HOST_DEVICE constexpr +auto +slice(Coord const& c, Layout const& layout) +{ + return make_layout(slice(c, layout.shape()), + slice(c, layout.stride())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +slice_and_offset(Coord const& c, Layout const& layout) +{ + return cute::make_tuple(slice(c, layout), crd2idx(c, layout)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +dice(Coord const& c, Layout const& layout) +{ + return make_layout(dice(c, layout.shape()), + dice(c, layout.stride())); +} + +// Compute a pointer offset and (potentially modified) layout from a coordinate +// This exists so it can be overloaded for ComposedLayout +template +CUTE_HOST_DEVICE constexpr +auto +domain_offset(Coord const& coord, Layout const& layout) +{ + return cute::make_tuple(layout, layout(coord)); +} + +// +// Transform the modes of a layout +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +transform_layout(Tuple const& t, F&& f, seq) +{ + return make_layout(f(get(t))...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform_layout(Tuple0 const& t0, Tuple1 const& t1, F&& f, seq, seq, seq) +{ + return make_layout(f(get(t0),get(t1))..., get(t0)..., get(t1)...); +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +transform_layout(Tuple const& t, F&& f) +{ + return detail::transform_layout(t, f, make_seq{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform_layout(Tuple0 const& t0, Tuple1 const& t1, F&& f) +{ + constexpr int R0 = decltype(rank(t0))::value; + constexpr int R1 = decltype(rank(t1))::value; + constexpr int R = (R0 < R1) ? R0 : R1; + return detail::transform_layout(t0, t1, f, make_seq{}, make_range{}, make_range{}); +} + +// +// Coalesce and Filter +// + +namespace detail { + +// Look at each element and the front of the stack (in order of priority) +// front(NewLayout) get(Layout) +// s0:d0 _1:d1 => continue +// _1:d0 s1:d1 => replace_front s1:d1 +// s0:s1*d1 s1:d1 => replace_front s0*s1:d1 +// s0:d0 s1:d1 => prepend s1:d1 +// +// @pre OldShape and OldStride are flat +template +CUTE_HOST_DEVICE constexpr +auto +bw_coalesce(OldShape const& old_shape, OldStride const& old_stride, + NewShape const& new_shape, NewStride const& new_stride) +{ + if constexpr (I == -1) { + // Base case, we're done + if constexpr (is_constant<1, NewShape>::value) { + return Layout<_1,_0>{}; + } else { + return Layout{new_shape,new_stride}; + } + } else if constexpr (is_constant<1, decltype(get(old_shape))>::value) { + // shape(layout) == _1, skip it and continue + return bw_coalesce(old_shape, old_stride, new_shape, new_stride); + } else if constexpr (is_constant<1, NewShape>::value) { + // Replace our shape-1 with anything (Can only happen on input new_shape/new_stride) + return bw_coalesce(old_shape, old_stride, get(old_shape), get(old_stride)); + } else if constexpr (is_static(new_shape))>::value && + is_constant(old_shape) * get(old_stride) == get<0>(new_stride))>::value) { + // Merge modes because the shapes and strides match + return bw_coalesce(old_shape, old_stride, + replace_front(new_shape, get(old_shape) * get<0>(new_shape)), + replace_front(new_stride, get(old_stride))); + } else { + // Can't replace or merge, so prepend a new mode + return bw_coalesce(old_shape, old_stride, + prepend(new_shape, get(old_shape)), + prepend(new_stride, get(old_stride))); + } + + CUTE_GCC_UNREACHABLE; +} + +// cute::coalesce promises to not change the Layout as a function from integers to codomain. +// It accomplishes this inside of the Layout's domain, but not always outside of the domain. +// Example: (_4,_1):(_1,_0) coalesces to _4:_1. +// detail::coalesce_x preserves the Layout function inside its domain and outside. +// +// @post depth(@a result) <= 1 +// @post for all i, 0 <= i, @a layout(i) == @a result(i) +template +CUTE_HOST_DEVICE constexpr +auto +coalesce_x(Layout const& layout) +{ + auto flat_shape = flatten(layout.shape()); + auto flat_stride = flatten(layout.stride()); + + constexpr int R = decltype(rank(flat_shape))::value; + if constexpr (is_constant<1, decltype(get(flat_shape))>::value) { + return detail::bw_coalesce(flat_shape, flat_stride, Int<2>{}, get(flat_stride)); + } else { + return detail::bw_coalesce(flat_shape, flat_stride, get(flat_shape), get(flat_stride)); + } + + CUTE_GCC_UNREACHABLE; +} + +// Apply coalesce_x at the terminals of trg_profile +template +CUTE_HOST_DEVICE constexpr +auto +coalesce_x(Layout const& layout, IntTuple const& trg_profile) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value <= Layout::rank); + return cute::transform_layout(layout, trg_profile, [](auto const& l, auto const& t) { return coalesce_x(l,t); }); + } else { + return coalesce_x(layout); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +// "Simplify" the layout by combining modes that are possible to combine +// Does not respect the shape of the layout, but does preserve total size +// @post size(@a result) == size(@a layout) +// @post depth(@a result) <= 1 +// @post for all i, 0 <= i < size(@a layout), @a layout(i) == @a result(i) +template +CUTE_HOST_DEVICE constexpr +auto +coalesce(Layout const& layout) +{ + auto flat_shape = flatten(layout.shape()); + auto flat_stride = flatten(layout.stride()); + + constexpr int R = decltype(rank(flat_shape))::value; + return detail::bw_coalesce(flat_shape, flat_stride, get(flat_shape), get(flat_stride)); +} + +// Apply coalesce at the terminals of trg_profile +template +CUTE_HOST_DEVICE constexpr +auto +coalesce(Layout const& layout, IntTuple const& trg_profile) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value <= Layout::rank); + return transform_layout(layout, trg_profile, [](auto const& l, auto const& t) { return coalesce(l,t); }); + } else { + return coalesce(layout); + } + + CUTE_GCC_UNREACHABLE; +} + +// Combine static and dynamic modes of a shape. +// @post size(@a result) == size(@a shape) +// @post depth(@a result) <= 1 +template +CUTE_HOST_DEVICE constexpr +auto +coalesce(Shape const& shape) +{ + static_assert(is_integral::value || is_tuple::value); + + return cute::fold_first(flatten(shape), [](auto const& init, auto const& a) { + if constexpr (is_static::value == is_static::value) { + return replace_back(init, back(init) * a); // Both static or both dynamic, coalesce and replace + } else { + return append(init, a); // Can't coalesce, so append + } + + CUTE_GCC_UNREACHABLE; + }); +} + +// Replace the modes in layout that have a 0-stride with a 1-size +template +CUTE_HOST_DEVICE constexpr +auto +filter_zeros(Layout const& layout) +{ + return make_layout(filter_zeros(layout.stride(), layout.shape()), layout.stride()); +} + +// Replace the modes in layout that correspond to a 0 at the terminals of trg_profile with a 1-size +template +CUTE_HOST_DEVICE constexpr +auto +filter_zeros(Layout const& layout, IntTuple const& trg_profile) +{ + return make_layout(filter_zeros(trg_profile, layout.shape()), layout.stride()); +} + +// Remove all of the 0-strides and 1-sizes +// Return 1-shape if empty +template +CUTE_HOST_DEVICE constexpr +auto +filter(Layout const& layout) +{ + return coalesce(filter_zeros(layout)); +} + +// Apply filter at the terminals of trg_profile +template +CUTE_HOST_DEVICE constexpr +auto +filter(Layout const& layout, IntTuple const& trg_profile) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value <= Layout::rank); + return transform_layout(layout, trg_profile, [](auto const& l, auto const& t) { return filter(l,t); }); + } else { + return filter(layout); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Append, Prepend, Replace +// + +template +CUTE_HOST_DEVICE constexpr +auto +append(Layout const& layout, + Layout const& x = {}) +{ + return make_layout(append(layout.shape(), x.shape()), + append(layout.stride(), x.stride())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +append(Layout const& layout, + Layout const& x = {}) +{ + return make_layout(append(layout.shape(), x.shape()), + append(layout.stride(), x.stride())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +prepend(Layout const& layout, + Layout const& x = {}) +{ + return make_layout(prepend(layout.shape(), x.shape()), + prepend(layout.stride(), x.stride())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +prepend(Layout const& layout, + Layout const& x = {}) +{ + return make_layout(prepend(layout.shape(), x.shape()), + prepend(layout.stride(), x.stride())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +replace(Layout const& layout, + Layout const& x) +{ + return make_layout(replace(layout.shape(), x.shape()), + replace(layout.stride(), x.stride())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +group(Layout const& layout) +{ + return make_layout(group(layout.shape()), + group(layout.stride())); +} + +// +// Composition of two layouts: lhs o rhs +// @post compatible(rhs, result) +// @post result(c) = lhs(rhs(c)) +// for all c in the domain of rhs +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +composition_impl(LShape const& lhs_shape, [[maybe_unused]] LStride const& lhs_stride, + RShape const& rhs_shape, RStride const& rhs_stride) +{ + if constexpr (is_tuple::value) { // Right-distributivity of Layout composition for RHS tuple + return transform_layout(rhs_shape, rhs_stride, [&](auto const& s, auto const& d) { + return composition_impl(lhs_shape, lhs_stride, s, d); + }); + } else + if constexpr (is_scaled_basis::value) { // Special case for a RHS ScaledBasis stride + return composition_impl(basis_get(rhs_stride, lhs_shape), basis_get(rhs_stride, lhs_stride), + rhs_shape, basis_value(rhs_stride)); + } else + if constexpr (is_constant<0, RStride>::value) { // Special case shortcut for any RHS static stride-0 + return Layout{rhs_shape, rhs_stride}; + } else + if constexpr (is_integral::value) { // Special case shortcut for any LHS integral shape + return Layout{rhs_shape, rhs_stride * lhs_stride}; + } else { // General case: LHS tuple, RHS integral + constexpr int R = tuple_size::value; + + auto [result_shape, result_stride, rest_shape, rest_stride] = + cute::fold(make_seq{}, // t = [0,1,2,...,R-1) + cute::make_tuple(cute::tuple<>{}, // v = (result_shape, + cute::tuple<>{}, // result_stride, + rhs_shape, // rest_shape:Integral, + rhs_stride), // rest_stride:Integral) + [&](auto const& init, auto curr_i) { // f(v,t) -> v' + // Can ICE on some compilers + //auto [result_shape, result_stride, rest_shape, rest_stride] = init; + //auto [curr_shape, curr_stride] = curr; + // Unpack inputs + auto result_shape = get<0>(init); + auto result_stride = get<1>(init); + auto rest_shape = get<2>(init); + auto rest_stride = get<3>(init); + + auto curr_shape = get(lhs_shape); + [[maybe_unused]] auto curr_stride = get(lhs_stride); + + // Strong divisibility condition -- requires composition to be statically verifiable. + //CUTE_STATIC_ASSERT_V(((rest_stride % curr_shape) == Int<0>{}) or (rest_stride < curr_shape), "Stride Divisibility Condition"); + + // Weak divisibility condition -- verify the divisibility condition whenever possible + if constexpr (is_static::value and is_static::value) { + CUTE_STATIC_ASSERT_V(((rest_stride % curr_shape) == Int<0>{}) or (rest_stride < curr_shape), "Stride Divisibility Condition"); + } else { + // DEBUG assert can cause extra registers and inappropriate compile-time/run-time failure + //assert((((rest_stride % curr_shape) == 0) or (rest_stride < curr_shape)) && "Stride Divisibility Condition"); + } + + // next_shape: ceil(exclusive_prefix_product(lhs_shape) / rhs_stride) + [[maybe_unused]] auto next_shape = cute::ceil_div(curr_shape, abs(rest_stride)); + // next_stride: ceil(rhs_stride / exclusive_prefix_product(lhs_shape)) + [[maybe_unused]] auto next_stride = cute::ceil_div(abs(rest_stride), curr_shape) * signum(rest_stride); + + if constexpr (is_constant<1, decltype(next_shape)>::value or is_constant<1, decltype(rest_shape)>::value) { + return cute::make_tuple(result_shape, + result_stride, + rest_shape, + next_stride); + } else { + auto new_shape = cute::min(next_shape, rest_shape); + + // Strong divisibility condition + //CUTE_STATIC_ASSERT_V(((rest_shape % new_shape) == Int<0>{}), "Shape Divisibility Condition"); + + // Weak divisibility condition + if constexpr (is_static::value and is_static::value) { + CUTE_STATIC_ASSERT_V(((rest_shape % new_shape) == Int<0>{}), "Shape Divisibility Condition"); + } else { + // DEBUG assert can cause extra registers and inappropriate compile-time/run-time failure + //assert(((rest_shape % new_shape) == 0) && "Shape Divisibility Condition"); + } + + return cute::make_tuple(append(result_shape, new_shape), + append(result_stride, rest_stride * curr_stride), + rest_shape / new_shape, + next_stride); + } + + CUTE_GCC_UNREACHABLE; + }); + + if constexpr (tuple_size::value == 0) { + return Layout{rest_shape, rest_stride * get(lhs_stride)}; + } else + if constexpr (is_constant<1, decltype(rest_shape)>::value) { + return Layout{unwrap(result_shape), unwrap(result_stride)}; + } else { + return Layout{append(result_shape, rest_shape), + append(result_stride, rest_stride * get(lhs_stride))}; + } + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +composition(Layout const& lhs, + Layout const& rhs) +{ + auto flat_lhs = detail::coalesce_x(lhs, coprofile(rhs)); + return detail::composition_impl(flat_lhs.shape(), flat_lhs.stride(), rhs.shape(), rhs.stride()); +} + +template +CUTE_HOST_DEVICE constexpr +auto +composition(Layout const& lhs, + Tiler const& rhs) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value <= Layout::rank); + // Drop any modes of lhs that aren't hit by rhs + return detail::transform_layout(lhs, rhs, [](auto const& l, auto const& r) { return composition(l,r); }, make_seq::value>{}, seq<>{}, seq<>{}); + } else if constexpr (is_underscore::value) { + return lhs; + } else if constexpr (is_integral::value) { + auto flat_lhs = detail::coalesce_x(lhs); + return detail::composition_impl(flat_lhs.shape(), flat_lhs.stride(), rhs, Int<1>{}); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Complement +// +// Build the complement of a layout. +// @post size(@a result) >= @a cosize_hi / size(filter(@a layout))); +// @post For all i in [1,size(@a result)), +// @a result(i) < @a result(i-1) +// For all j in [0, size(@a layout)), +// @a result(i) != @a layout(j) +// + +namespace detail { + +// @pre @a layout has been filtered (flattened and no stride-0 or size-1 modes). +template +CUTE_HOST_DEVICE constexpr +auto +complement(Shape const& shape, Stride const& stride, CoTarget const& cotarget) +{ + if constexpr (is_constant<0, Stride>::value) { + // Special case for irreducible rank-1 stride-0 layout + return make_layout(coalesce(cotarget)); + } else { + // General case + constexpr int R = rank_v; + static_assert(R == 1 || is_static::value, + "Dynamic-stride complement only for rank-1 layouts"); + + // Should just be a sort and a fold... + // Then we could even handle dynamic strides (but they would destroy all static strides) + auto [shape_, stride_, result_shape_, result_stride] = + fold(make_seq{}, + cute::make_tuple(shape, stride, cute::make_tuple(), cute::make_tuple(Int<1>{})), + [](auto const& init, auto i) + { + auto [shape, stride, result_shape, result_stride] = init; + auto min_stride = cute::min(stride); + auto min_idx = cute::find(stride, min_stride); + auto new_shape = min_stride / get(result_stride); + auto new_stride = min_stride * get(shape); + static_assert(not is_constant<0, decltype(new_shape)>::value, "Non-injective Layout detected in complement."); + + return cute::make_tuple(remove(shape), // Remove the min_idx from shape + remove(stride), // Remove the min_idx from stride + append(result_shape , new_shape ), // new shape = min_stride / last_stride + append(result_stride, new_stride)); // new stride = min_stride * curr_shape + }); + + // Append the last shape mode + auto new_shape = get<0>(stride_) / get(result_stride); // new shape = min_stride / last_stride + static_assert(not is_constant<0, decltype(new_shape)>::value, "Non-injective Layout detected in complement."); + auto result_shape = append(result_shape_, new_shape); + + // Compute the rest_shape and rest_stride + auto new_stride = get<0>(stride_) * get<0>(shape_); // new stride = min_stride * curr_shape + auto rest_shape = coalesce(ceil_div(cotarget, new_stride)); + auto rest_stride = compact_major(rest_shape, new_stride); + + // Coalesce and append (rest_shape, rest_stride) + return coalesce(make_layout(make_shape (result_shape , rest_shape ), + make_stride(result_stride, rest_stride))); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +complement(Layout const& layout, CoTarget const& cotarget) +{ + auto filter_layout = filter(layout); + return detail::complement(filter_layout.shape(), filter_layout.stride(), shape(cotarget)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +complement(Layout const& layout) +{ + auto filter_layout = filter(layout); + return detail::complement(filter_layout.shape(), filter_layout.stride(), cosize(filter_layout)); +} + +// +// Right-Inverse and Left-Inverse +// + +// +// Build the right-inverse of a layout +// @pre is_static +// @result A layout @a result such that +// @a layout(@a result(i)) == i for all i < size(@a result) +// @result A layout @a result such that +// composition(@a layout, @a result) is identical to make_layout(shape(result)) +// + +template +CUTE_HOST_DEVICE constexpr +auto +right_inverse(Layout const& layout) +{ + // Flatten and filter shape-1 + auto clayout = coalesce(layout); + auto lstride = wrap(clayout.stride()); + auto lshape = wrap(clayout.shape()); + + // Prefix product of the shape + auto preprod_shape = cute::fold(lshape, cute::tuple<_1>{}, [](auto c, auto vi) { return append(c, vi*back(c)); }); + + // Filter out any dynamic strides + [[maybe_unused]] auto filtered_seq = filter_tuple(make_seq{}, lstride, [](auto i, auto d) { + return conditional_return>(cute::tuple{i}, cute::tuple<>{}); }); + [[maybe_unused]] auto filtered_stride = transform(filtered_seq, [&](auto i) { return get(lstride); }); + + // Sort by strides + using Sorted = detail::SortByKey; + auto sorted_seq = typename Sorted::val_type{}; + //auto sorted_stride = typename Sorted::key_type{}; + + auto [result_shape, result_stride, curr] = cute::fold(sorted_seq, tuple,tuple<_0>,_1>{}, + [&](auto const& init, auto i) { + [[maybe_unused]] auto ishape = get(lshape); + [[maybe_unused]] auto istride = get(lstride); + [[maybe_unused]] auto curr_stride = get<2>(init); + + if constexpr (is_constant::value) { + return make_tuple(append(get<0>(init), ishape), // result_shape + append(get<1>(init), get(preprod_shape)), // result_stride + ishape * istride); + } else { + return init; + } + + CUTE_GCC_UNREACHABLE; + }); + + return coalesce(make_layout(result_shape, result_stride)); +} + +CUTE_HOST_DEVICE constexpr +auto +right_inverse(Underscore const& _) +{ + return _; +} + +// +// Build the quasi-inverse of a layout (left-inverse when layout is injective) +// @pre is_static +// @result A layout @a result such that +// @a layout(@a result(@a layout(i))) == @a layout(i) for all i < size(@a layout) +// @result A layout @a result such that +// composition(@layout, composition(@a result, @a layout)) is identical to @a layout +// + +template +CUTE_HOST_DEVICE constexpr +auto +left_inverse(Layout const& layout) +{ + // Flatten and filter shape-1 + auto clayout = coalesce(layout); + auto lstride = wrap(clayout.stride()); + auto lshape = wrap(clayout.shape()); + + // Prefix product of the shape + auto preprod_shape = cute::fold(lshape, cute::tuple<_1>{}, [](auto c, auto vi) { return append(c, vi*back(c)); }); + + // Sort by strides + static_assert(is_static::value, "Left inverse requires static strides."); + using Sorted = detail::SortByKey>; + auto sorted_seq = typename Sorted::val_type{}; + //auto sorted_stride = typename Sorted::key_type{}; + + auto [result_shape, result_stride] = cute::fold(sorted_seq, tuple,tuple<_0>>{}, + [&](auto const& init, auto i) { + [[maybe_unused]] auto istride = get(lstride); + + if constexpr (is_constant<0, decltype(istride)>::value) { + return init; + } else { + auto result_shape = get<0>(init); + auto result_stride = get<1>(init); + + CUTE_STATIC_ASSERT_V((istride % size(result_shape)) == Int<0>{}, "Left inverse divisibility condition"); + + return make_tuple(append(result_shape, istride / size(result_shape)), + append(result_stride, get(preprod_shape))); + } + + CUTE_GCC_UNREACHABLE; + }); + + return coalesce(make_layout(append(result_shape, get(lshape)), + result_stride)); +} + +CUTE_HOST_DEVICE constexpr +auto +left_inverse(Underscore const& _) +{ + return _; +} + +// +// Max Common Layout +// + +/* Return a layout that points to the maximum number of contiguous elements + * that logically correspond in the layouts of @a a and @a b. + * + * @returns Layout R + * @post For all 0 <= i < size(R), a(R(i)) == i and b(R(i)) == i + */ +template +CUTE_HOST_DEVICE constexpr +auto +max_common_layout(Layout const& a, + Layout const& b) +{ + Layout inv_b = right_inverse(b); + Layout common = coalesce(composition(a, inv_b)); + + // Keep only the static identity component of the common layout + if constexpr (is_static(common))>::value && + is_constant<1, decltype(stride<0>(common))>::value) { + // Truncate to the size of the contiguous vector (static stride-1 mode) + return composition(inv_b, layout<0>(common)); + } else { + return Layout<_1,_0>{}; + } +} + +/* Return Int such that N is the maximum number of contiguous elements + * that logically correspond in the layouts of @a a and @a b. + * + * @returns Int with N >= 1 + * @post For all 0 <= n < N, a(b.get_1d_coord(n)) == n + * (NOTE: Problems with negative strides/coords in this post-condition) + */ +template +CUTE_HOST_DEVICE constexpr +auto +max_common_vector(Layout const& a, + Layout const& b) +{ + Layout common = coalesce(composition(a, right_inverse(b))); + + // Keep only the static identity component of the common layout + if constexpr (is_static(common))>::value && + is_constant<1, decltype(stride<0>(common))>::value) { + // Truncate to the size of the contiguous vector (static stride-1 mode) + return shape<0>(common); + } else { + return Int<1>{}; + } + + CUTE_GCC_UNREACHABLE; +} + +/* Return a layout that distributes ShapeB over ShapeA. + * + * @returns Layout result + * @post evenly_divides(@a b, size(@a result)) + * @post evenly_divides(@a a, @a result) + * @post For all i,j in [0,size(@a result)) with i < j, @a result(i) < @a result(j). Surjective and Ordered. + * @post composition(make_layout(shape(@a a)), @a result) is admissible + * \code + * // Note that 6 does not divide this shape + * Layout layoutA = Layout,Int<14>>>{}; + * + * // Want to tile any 6 elements and don't care where they come from + * Layout dist = domain_distribute(layoutA, Int<6>{}); // (_3,_2):(_1,_15) + * + * // Not guaranteed to find all 6 though... + * CUTE_STATIC_ASSERT_V(Int<6>{} == size(dist)); + * + * Layout result = zipped_divide(layoutA, dist); // (_6,Rest) + * \endcode + */ +template +CUTE_HOST_DEVICE constexpr +auto +domain_distribute(ShapeA const& a, ShapeB const& b) +{ + static_assert(is_integral::value); + static_assert(is_static::value); + + auto flat_shape_a = flatten(shape(a)); + + static_assert(is_static::value); + + // Compute the shape of the result + auto [result_shape, b_rest] = cute::fold(flat_shape_a, cute::make_tuple(cute::tuple<>{}, size(b)), [](auto init, auto a_) { + auto [result, b_] = init; + auto gcd_ = gcd(a_, b_); + return cute::make_tuple(append(result, gcd_), b_ / gcd_); + }); + + // Compute the stride of the result + auto result_stride = compact_major(flat_shape_a); + + return coalesce(make_layout(result_shape, result_stride)); +} + +// +// Kernel (Nullspace) of a Layout +// + +/** Return a layout that represents the nullspace of @a layout + * @post @a layout(@a result(i)) == 0 for all i < size(@a result) + * @post nullspace(@a result) == Layout<_1,_0>{} + * @post size(@a result) == size(@a layout) / size(filter(@a layout)) + */ +template +CUTE_HOST_DEVICE constexpr +auto +nullspace(Layout const& layout) +{ + [[maybe_unused]] auto flat_stride = flatten(layout.stride()); + + // Select all indices corresponding to stride-0s + auto iseq = cute::fold(make_seq>{}, cute::tuple<>{}, + [&](auto init, auto i){ + if constexpr (is_constant_v<0, decltype(get(flat_stride))>) { return append(init, i); } + else { return init; } + CUTE_GCC_UNREACHABLE; + }); + + if constexpr (tuple_size::value == 0) { + return Layout<_1,_0>{}; // Empty case, nothing found + } else { + // Generate the corresponding new strides and construct + auto flat_shape = flatten(layout.shape()); + auto rstride = compact_major(flat_shape); + return make_layout(unwrap(transform(iseq, [&](auto i) { return get(flat_shape); })), + unwrap(transform(iseq, [&](auto i) { return get(rstride); }))); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Zip +// + +template +CUTE_HOST_DEVICE constexpr +auto +zip(Layout const& layout) +{ + return make_layout(zip(layout.shape()), + zip(layout.stride())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +zip(Layout const& layoutA, + Layout const& layoutB) +{ + return make_layout(zip(layoutA.shape(), layoutB.shape()), + zip(layoutA.stride(), layoutB.stride())); +} + +// +// Tile unzip +// Logical product and logical divide (on layouts) produce rank-2 results by design. +// Follow the profile of @a tile and zip the rank-2 modes located at the terminals into +// their own mode. +// + +template +CUTE_HOST_DEVICE constexpr +auto +tile_unzip(Layout const& layout, + Tiler const& tiler) +{ + return make_layout(zip2_by(layout.shape(), tiler), + zip2_by(layout.stride(), tiler)); +} + +// +// Logical divide +// + +template +CUTE_HOST_DEVICE constexpr +auto +logical_divide(Layout const& layout, + Layout const& tiler) +{ + return composition(layout, make_layout(tiler, complement(tiler, shape(coalesce(layout))))); +} + +template +CUTE_HOST_DEVICE constexpr +auto +logical_divide(Layout const& layout, + Tiler const& tiler) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value <= Layout::rank, "logical_divide: Too many modes in tiler."); + return transform_layout(layout, tiler, [](auto const& l, auto const& t) { return logical_divide(l,t); }); + } else if constexpr (is_underscore::value) { + return layout; + } else if constexpr (is_integral::value) { + return logical_divide(layout, make_layout(tiler)); + } + + CUTE_GCC_UNREACHABLE; +} + +// Generalization of ceil_div for Layout lhs +// is effectively the "rest mode" of logical_divide. +// Occurs in the calculation of gridDim, for example, for generalized tilers +// Example: +// dim3 gridDim(size(ceil_div(problem_shape_M, cta_tiler_M)), +// size(ceil_div(problem_shape_N, cta_tiler_N))); +// This does not consider compositional acceptance, so it may be the case that +// ceil_div produces a result while logical_divide (and friends) do not. +template +CUTE_HOST_DEVICE constexpr +auto +ceil_div(Target const& target, + Layout const& tiler) +{ + return shape(complement(tiler, shape(target))); +} + +// +// Convenience operator +// that produces layouts like ((BLK_A,BLK_B,...),(a,b,...,x,y)) +// by gathering the tile modes and residuals into a rank-2 result. +// + +template +CUTE_HOST_DEVICE constexpr +auto +zipped_divide(Layout const& layout, + Tiler const& tiler) +{ + return tile_unzip(logical_divide(layout, tiler), tiler); +} + +// Same as zipped_divide, but unpacks the second mode: ((BLK_A,BLK_B,...),a,b,...,x,y) +template +CUTE_HOST_DEVICE constexpr +auto +tiled_divide(Layout const& layout, + Tiler const& tiler) +{ + auto result = zipped_divide(layout, tiler); + + auto R1 = rank<1>(result); + return result(_, repeat(_)); +} + +// Same as zipped_divide, but unpacks both modes: (BLK_A,BLK_B,...,a,b,...,x,y) +template +CUTE_HOST_DEVICE constexpr +auto +flat_divide(Layout const& layout, + Tiler const& tiler) +{ + auto result = zipped_divide(layout, tiler); + + auto R0 = rank<0>(result); + auto R1 = rank<1>(result); + return result(repeat(_), repeat(_)); +} + +// +// Logical product +// + +template +CUTE_HOST_DEVICE constexpr +auto +logical_product(Layout const& block, + Layout const& tiler) +{ + return make_layout(block, composition(complement(block, size(block)*cosize(tiler)), tiler)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +logical_product(Layout const& block, + Tiler const& tiler) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value <= Layout::rank, "logical_product: Too many modes in tiler."); + return transform_layout(block, tiler, [](auto const& l, auto const& t) { return logical_product(l,t); }); + } else if constexpr (is_underscore::value) { + return block; + } else if constexpr (is_integral::value) { + return logical_product(block, make_layout(tiler)); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Convenience operator +// that produces layouts like ((BLK_A,BLK_B,...),(a,b,...,x,y)) +// by gathering the block modes and products into a rank-2 result. +// + +template +CUTE_HOST_DEVICE constexpr +auto +zipped_product(Layout const& block, + Tiler const& tiler) +{ + return tile_unzip(logical_product(block, tiler), tiler); +} + +// Same as zipped_product, but unpacks the second mode: ((BLK_A,BLK_B,...),a,b,...,x,y) +template +CUTE_HOST_DEVICE constexpr +auto +tiled_product(Layout const& block, + Tiler const& tiler) +{ + auto result = zipped_product(block, tiler); + + auto R1 = rank<1>(result); + return result(_, repeat(_)); +} + +// Same as zipped_product, but unpacks both modes: (BLK_A,BLK_B,...,a,b,...,x,y) +template +CUTE_HOST_DEVICE constexpr +auto +flat_product(Layout const& block, + Tiler const& tiler) +{ + auto result = zipped_product(block, tiler); + + auto R0 = rank<0>(result); + auto R1 = rank<1>(result); + return result(repeat(_), repeat(_)); +} + +// +// Rank-sensitive products +// + +// blocked_product -- Reproduce a block over a tiler. +// Think of every element of "tiler" as a "block" +// and return the layout of the resulting structure. +// @post rank(@a result) == cute::max(rank(@a block), rank(@a tiler)) +template +CUTE_HOST_DEVICE constexpr +auto +blocked_product(Layout const& block, + Layout const& tiler) +{ + constexpr int R = cute::max(rank_v, rank_v); + + auto result = logical_product(append(block), append(tiler)); + + return zip(get<0>(result), get<1>(result)); +} + +// raked_product -- Reproduce a block over a tiler with block-interleaving. +// Think of every element of "tiler" as a "block", interleave those blocks, +// and return the layout of the resulting structure. +// @post rank(@a result) == cute::max(rank(@a block), rank(@a tiler)) +template +CUTE_HOST_DEVICE constexpr +auto +raked_product(Layout const& block, + Layout const& tiler) +{ + constexpr int R = cute::max(rank_v, rank_v); + + auto result = logical_product(append(block), append(tiler)); + + return zip(get<1>(result), get<0>(result)); +} + +// tile_to_shape -- Perform a product of a layout so that the result matches a target shape. +// This is similar to blocked_product, but specifies the result shape instead of the +// product shape, which is more convenient in certain circumstances. +// @param block The layout to repeat +// @param trg_shape The target shape of the result +// @param ord_shape The order of the modes of @a trg_shape to tile @a layout with. +// Defaults to GenColMajor, so @a layout will repeat +// across the first mode first, the second mode second, etc +// E.g. Step<_2,_1,_3> will cause @a layout to repeat +// across the second mode first, the first mode second, and the third mode last. +// @pre rank(@a block) <= rank(@a trg_shape) +// @post compatible(@a trg_shape, shape(@a result)) +template +CUTE_HOST_DEVICE constexpr +auto +tile_to_shape(Layout const& block, + TrgShape const& trg_shape, + ModeOrder const& ord_shape = {}) +{ + CUTE_STATIC_ASSERT_V(rank(block) <= rank(trg_shape), "Rank of layout must be <= rank of target shape."); + constexpr int R = rank_v; + + auto padded_block = append(block); + + auto block_shape = product_each(shape(padded_block)); + auto target_shape = product_each(shape(trg_shape)); + + // Assert proper division + if constexpr (is_static::value) { + CUTE_STATIC_ASSERT_V(evenly_divides(target_shape, block_shape), + "tile_to_shape: block shape does not divide the target shape."); + } + + auto product_shape = ceil_div(target_shape, block_shape); + + return blocked_product(padded_block, make_ordered_layout(product_shape, ord_shape)); +} + +// +// Upcast +// For stride-1 mode, divide size by N. Divide all other strides by N. +// + +template +CUTE_HOST_DEVICE constexpr +auto +upcast(Shape const& shape, Stride const& stride) +{ + if constexpr (is_tuple::value) { // tuple stride + return transform_layout(shape, stride, [](auto const& s, auto const& d) { return upcast(s,d); }); + } else if constexpr (is_constant<0, Stride>::value) { // static-0 stride + return Layout{shape,stride}; + } else if constexpr (is_static::value) { // static stride + static_assert(Stride::value % N == 0 or N % Stride::value == 0, "Divisibility condition"); + return make_layout(ceil_div(shape, ceil_div(Int{}, abs(stride))), + signum(stride) * ceil_div(abs(stride), Int{})); + } else { // dynamic stride + // Assume dynamic strides are larger than N and divisible + // assert(stride % N == 0); + return make_layout(shape, safe_div(stride, Int{})); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +upcast(Layout const& layout) +{ + return upcast(layout.shape(), layout.stride()); +} + +// +// Downcast +// For stride-1 mode, multiply size by N. Multiply all other strides by N. +// + +template +CUTE_HOST_DEVICE constexpr +auto +downcast(Shape const& shape, Stride const& stride) +{ + if constexpr (is_tuple::value) { + return transform_layout(shape, stride, [](auto const& s, auto const& d) { return downcast(s,d); }); + } else if constexpr (is_constant<1, Stride>::value || is_constant<-1, Stride>::value) { + return make_layout(shape * Int{}, stride); + } else { + return make_layout(shape, stride * Int{}); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +downcast(Layout const& layout) +{ + CUTE_STATIC_ASSERT(has_int1::value, "Downcast requires adjacent elements"); + return downcast(layout.shape(), layout.stride()); +} + +// +// Recast +// + +template +CUTE_HOST_DEVICE constexpr +auto +recast_layout(Layout const& layout) +{ + using scale = decltype(trait_ratio(sizeof_bits{}, sizeof_bits{})); + if constexpr (scale::num == 1 && scale::den == 1) { + return layout; + } + else if constexpr (scale::num == 1) { + return downcast(layout); + } + else if constexpr (scale::den == 1) { + return upcast(layout); + } + else { + return downcast(upcast(layout)); + } + + CUTE_GCC_UNREACHABLE; +} + +// Determine the maximum alignment of a Layout. +// The maximum alignment is the largest N for which upcast(layout) will compile. +// upcast(layout) compiles when the static shapes and strides pass divisibility checks. +// Therefore, upcast(layout) will also compile for all divisors M of N. +// Note that this only considers the static shapes and strides of the Layout +// in symmetry with upcast only checking against static shapes and strides and assuming all +// dynamic shapes and strides are large and multiples of N. +template +CUTE_HOST_DEVICE constexpr +auto +max_alignment(Layout const& layout) +{ + auto flat_layout = coalesce(layout); + auto static_shape = transform( shape(flat_layout), [](auto s){ return conditional_return::value>(s, Int<1>{}); }); + auto static_stride = transform(stride(flat_layout), [](auto d){ return conditional_return::value>(d, Int<0>{}); }); + auto filter_layout = make_layout(static_shape, static_stride); + auto permuted = logical_divide(filter_layout, right_inverse(filter_layout)); + return gcd(size<0>(permuted), stride<1>(permuted)); +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(Layout const& layout) +{ + print(layout.shape()); print(":"); print(layout.stride()); +} + +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, Layout const& layout) +{ + return os << shape(layout) << ":" << stride(layout); +} +#endif + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/layout_composed.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/layout_composed.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6a9677831bf2b19c971223f18dc2c929b74fa4ab --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/layout_composed.hpp @@ -0,0 +1,661 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE, CUTE_GCC_UNREACHABLE +#include // cute::tuple +#include // cute::true_type, cute::false_type, cute::Int + +/* This implements a ComposedLayout of the form + * LayoutA o Offset o LayoutB + * and is useful in cases where composition() does not or cannot apply to LayoutA and LayoutB. + * For example, when the "divisibility condition" is violated in composition(LayoutA, LayoutB). + * + * This ComposedLayout provides similar functionality to Layout including tiling, partitioning, + * coordinate-to-index mapping and layout manipulations, but is not considered a "normal" layout. + * For example, this layout provides shape() and size() functions, but does not provide stride() functions. + * Mostly, the similar functionality is accomplished by applying each operation to LayoutB only + * as LayoutB defines the domain. + */ + +namespace cute +{ + +// A Layout of non-trivially composable functions: F o I o L +template +struct ComposedLayout : private cute::tuple // EBO for static layouts +{ + CUTE_HOST_DEVICE constexpr + ComposedLayout(LayoutA const& layoutA = {}, + Offset const& offset = {}, + LayoutB const& layoutB = {}) + : cute::tuple(layoutA, offset, layoutB) + {} + + // + // Accessors + // + + static constexpr int rank = LayoutB::rank; + + CUTE_HOST_DEVICE constexpr + decltype(auto) + layout_a() const { + return get<0>(static_cast const&>(*this)); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + offset() const { + return get<1>(static_cast const&>(*this)); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + layout_b() const { + return get<2>(static_cast const&>(*this)); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + layout() const { + return *this; + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + shape() const { + return layout_b().shape(); + } + + // Doesn't really make sense to ask for the strides of this "layout" + CUTE_HOST_DEVICE constexpr + decltype(auto) + stride() const = delete; + + // + // Mappings + // + + // Map a logical coordinate to a linear index (Coord has no Underscore slice operators) + // OR + // Slice the layout and return the sublayout (Coord has an Underscore slice op) + template + CUTE_HOST_DEVICE constexpr + auto + operator()(Coord const& coord) const { + if constexpr (has_underscore::value) { + return slice(coord, *this); + } else { + return layout_a()(offset() + layout_b()(coord)); // (A o O o B)(c) + } + + CUTE_GCC_UNREACHABLE; + } + + // Convenience function for multi-dimensional coordinates + template + CUTE_HOST_DEVICE constexpr + auto + operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) const { + return operator()(make_coord(c0,c1,cs...)); + } + + // + // Compose + // + + template + CUTE_HOST_DEVICE constexpr + auto + compose(OtherLayout const& other) const { + return composition(*this, other); + } + + template + CUTE_HOST_DEVICE constexpr + auto + compose(Layouts const&... layouts) const { + return composition(*this, make_tile(layouts...)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + with_shape(OtherShape const& shape) const { + return composition(*this, make_layout(shape)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + with_shape(Shapes const&... shapes) const { + return composition(*this, make_layout(make_shape(shapes...))); + } + + // + // Tile + // + + template + CUTE_HOST_DEVICE constexpr + auto + tile(OtherLayout const& other) const { + return tiled_divide(*this, other); + } + + template + CUTE_HOST_DEVICE constexpr + auto + tile(Layouts const&... layouts) const { + return tiled_divide(*this, make_tile(layouts...)); + } + + // Equality, return a static or dynamic boolean + template + CUTE_HOST_DEVICE constexpr + auto + operator==(ComposedLayout const& other) const { + return this->layout_a() == other.layout_a() && + this->layout_b() == other.layout_b() && + this->offset() == other.offset(); + } +}; + +template +struct is_layout> : true_type {}; + +template +struct is_composed_layout : false_type {}; +template +struct is_composed_layout> : true_type {}; + +// +// Constructors +// + +template +CUTE_HOST_DEVICE constexpr +auto +make_composed_layout(LayoutA const& layoutA, + Offset const& offset, + LayoutB const& layoutB) +{ + return ComposedLayout{layoutA, offset, layoutB}; +} + +// +// Utilities +// + +// Return the layout of a mode +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +layout(ComposedLayout const& clayout) +{ + return composition(clayout.layout_a(), clayout.offset(), layout(clayout.layout_b())); +} + +// Return the shape of a mode +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +shape(ComposedLayout const& layout) +{ + return shape(layout.layout_b()); +} + +// Doesn't make sense to directly ask for the strides of this "layout" +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +stride(ComposedLayout const& layout) = delete; + +// Return the number of elements in a mode +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +size(ComposedLayout const& layout) +{ + return size(layout.layout_b()); +} + +// Return the number of modes +template +CUTE_HOST_DEVICE constexpr +auto +rank(ComposedLayout const& layout) +{ + return rank(layout.layout_b()); +} + +// Return the depth of the layout +template +CUTE_HOST_DEVICE constexpr +auto +depth(ComposedLayout const& layout) +{ + return depth(layout.layout_b()); +} + +// Return the codomain size of a mode +template +CUTE_HOST_DEVICE constexpr +auto +cosize(ComposedLayout const& layout) +{ + return cosize(layout.layout_b()); +} + +// +// Operations to manipulate Layouts like a tuple of pairs +// + +template +CUTE_HOST_DEVICE constexpr +auto +get(ComposedLayout const& a) +{ + return composition(a.layout_a(), a.offset(), get(a.layout_b())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +take(ComposedLayout const& a) +{ + return composition(a.layout_a(), a.offset(), take(a.layout_b())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +flatten(ComposedLayout const& a) +{ + return composition(a.layout_a(), a.offset(), flatten(a.layout_b())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +append(ComposedLayout const& a, X const& x) +{ + return composition(a.layout_a(), a.offset(), append(a.layout_b(), x)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +group(ComposedLayout const& a) +{ + return composition(a.layout_a(), a.offset(), group(a.layout_b())); +} + +// +// Slice a ComposedLayout +// + +template +CUTE_HOST_DEVICE constexpr +auto +slice_and_offset(Coord const& coord, ComposedLayout const& layout) +{ + auto [slice, offset] = slice_and_offset(coord, layout.layout_b()); + return cute::make_tuple(ComposedLayout{layout.layout_a(), layout.offset() + offset, slice}, Int<0>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +slice(Coord const& coord, ComposedLayout const& layout) +{ + return get<0>(slice_and_offset(coord, layout)); +} + +// Compute a pointer offset and (potentially modified) layout from a coordinate +// For composed layout tensors the offset is accumulated in the layout itself while pointer is not updated +template +CUTE_HOST_DEVICE constexpr +auto +domain_offset(Coord const& coord, ComposedLayout const& layout) +{ + return cute::make_tuple(ComposedLayout{layout.layout_a(), layout.offset() + layout.layout_b()(coord), layout.layout_b()}, Int<0>{}); +} + +// +// composition +// + +template +CUTE_HOST_DEVICE constexpr +auto +composition(LayoutA const& layoutA, + Offset const& offset, + LayoutB const& layoutB) +{ + return ComposedLayout{layoutA, offset, layoutB}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +composition(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), composition(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +composition(Layout const& a, + ComposedLayout const& b) +{ + CUTE_STATIC_ASSERT_V(b.offset() == Int<0>{}, "Require offset == 0."); + + return composition(composition(a, b.layout_a()), b.layout_b()); +} + +// +// complement +// + +template +CUTE_HOST_DEVICE constexpr +auto +complement(ComposedLayout const& layout, CoTarget const& cotarget) +{ + return complement(layout.layout_b(), cotarget); +} + +template +CUTE_HOST_DEVICE constexpr +auto +complement(ComposedLayout const& layout) +{ + return complement(layout, cosize(layout)); +} + +// +// inverse +// + +template +CUTE_HOST_DEVICE constexpr +auto +right_inverse(ComposedLayout const& layout) +{ + return composition(right_inverse(layout.layout_b()), right_inverse(layout.offset()), right_inverse(layout.layout_a())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +left_inverse(ComposedLayout const& layout) +{ + return composition(left_inverse(layout.layout_b()), left_inverse(layout.offset()), left_inverse(layout.layout_a())); +} + +// +// Other operations +// + +template +CUTE_HOST_DEVICE constexpr +auto +zip(ComposedLayout const& a) +{ + return composition(a.layout_a(), a.offset(), zip(a.layout_b())); +} + +// Partitions + +template +CUTE_HOST_DEVICE constexpr +auto +logical_divide(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), logical_divide(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tile_unzip(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), tile_unzip(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tiled_divide(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), tiled_divide(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +zipped_divide(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), zipped_divide(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +flat_divide(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), flat_divide(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +logical_product(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), logical_product(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +zipped_product(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), zipped_product(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tiled_product(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), tiled_product(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +flat_product(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), flat_product(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +blocked_product(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), blocked_product(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +raked_product(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), raked_product(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tile_to_shape(ComposedLayout const& layout, + Shape const& trg_shape, + ModeOrder const& ord_shape = {}) +{ + return composition(layout.layout_a(), layout.offset(), tile_to_shape(layout.layout_b(), trg_shape, ord_shape)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +filter(ComposedLayout const& layout, Shape const& trg_profile) +{ + return composition(layout.layout_a(), layout.offset(), filter(layout.layout_b(), trg_profile)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +coalesce(ComposedLayout const& layout) +{ + return composition(layout.layout_a(), layout.offset(), coalesce(layout.layout_b())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +coalesce(ComposedLayout const& layout, Shape const& trg_profile) +{ + return composition(layout.layout_a(), layout.offset(), coalesce(layout.layout_b(), trg_profile)); +} + + +// +// Upcast and Downcast +// + +template +CUTE_HOST_DEVICE constexpr +auto +upcast(ComposedLayout const& layout) +{ + return composition(upcast(layout.layout_a()), upcast(layout.offset()), upcast(layout.layout_b())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +downcast(ComposedLayout const& layout) +{ + return composition(downcast(layout.layout_a()), downcast(layout.offset()), downcast(layout.layout_b())); +} + + +template +CUTE_HOST_DEVICE constexpr +auto +recast_layout(ComposedLayout const& layout) +{ + using scale = decltype(trait_ratio(sizeof_bits{}, sizeof_bits{})); + if constexpr (scale::num == 1 && scale::den == 1) { + return layout; + } + else if constexpr (scale::num == 1) { + return downcast(layout); + } + else if constexpr (scale::den == 1) { + return upcast(layout); + } + else { + return downcast(upcast(layout)); + } + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +max_alignment(ComposedLayout const& layout) +{ + // Do not attempt for general ComposedLayouts + //return gcd(max_alignment(layout.layout_a()), max_alignment(layout.offset()), max_alignment(layout.layout_b())); + return Int<1>{}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +nullspace(ComposedLayout const& layout) +{ + // Do not attempt for general ComposedLayouts + return Layout<_1,_0>{}; +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(ComposedLayout const& layout) +{ + print(layout.layout_a()); print(" o "); print(layout.offset()); print(" o "); print(layout.layout_b()); +} + +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, ComposedLayout const& layout) +{ + return os << layout.layout_a() << " o " << layout.offset() << " o " << layout.layout_b(); +} +#endif + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/numeric/arithmetic_tuple.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/numeric/arithmetic_tuple.hpp new file mode 100644 index 0000000000000000000000000000000000000000..016ac5b6cba80cf249698ea4b075840b8d6ed070 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/numeric/arithmetic_tuple.hpp @@ -0,0 +1,535 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include +#include +#include +#include + +namespace cute +{ + +template +struct ArithmeticTuple : public tuple { + CUTE_HOST_DEVICE constexpr + ArithmeticTuple() : tuple() {} + + CUTE_HOST_DEVICE constexpr + ArithmeticTuple(tuple const& t) : tuple(t) {} + + CUTE_HOST_DEVICE constexpr + ArithmeticTuple(T const&... t) : tuple(t...) {} +}; + +template +struct is_tuple> : true_type {}; + +template +struct is_flat> : is_flat> {}; + +template +CUTE_HOST_DEVICE constexpr +auto +make_arithmetic_tuple(T const&... t) { + return ArithmeticTuple(t...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +as_arithmetic_tuple(T const& t) { + if constexpr (is_tuple::value) { + return detail::tapply(t, [](auto const& x){ return as_arithmetic_tuple(x); }, + [](auto const&... a){ return make_arithmetic_tuple(a...); }, + tuple_seq{}); + } else { + return t; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Numeric operators +// + +// Addition +template +CUTE_HOST_DEVICE constexpr +auto +operator+(ArithmeticTuple const& t, ArithmeticTuple const& u) { + constexpr int R = cute::max(int(sizeof...(T)), int(sizeof...(U))); + return transform_apply(append(t,Int<0>{}), append(u,Int<0>{}), plus{}, [](auto const&... a){ return make_arithmetic_tuple(a...); }); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator+(ArithmeticTuple const& t, tuple const& u) { + return t + ArithmeticTuple(u); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator+(tuple const& t, ArithmeticTuple const& u) { + return ArithmeticTuple(t) + u; +} + +// Subtraction +template +CUTE_HOST_DEVICE constexpr +auto +operator-(ArithmeticTuple const& t, ArithmeticTuple const& u) { + constexpr int R = cute::max(int(sizeof...(T)), int(sizeof...(U))); + return transform_apply(append(t,Int<0>{}), append(u,Int<0>{}), minus{}, [](auto const&... a){ return make_arithmetic_tuple(a...); }); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator-(ArithmeticTuple const& t, tuple const& u) { + return t - ArithmeticTuple(u); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator-(tuple const& t, ArithmeticTuple const& u) { + return ArithmeticTuple(t) - u; +} + +// Negation +template +CUTE_HOST_DEVICE constexpr +auto +operator-(ArithmeticTuple const& t) { + return transform_apply(t, negate{}, [](auto const&... a){ return make_arithmetic_tuple(a...); }); +} + +// +// Special cases for C<0> +// + +template +CUTE_HOST_DEVICE constexpr +ArithmeticTuple +operator+(C, ArithmeticTuple const& u) { + static_assert(t == 0, "Arithmetic tuple op+ error!"); + return u; +} + +template +CUTE_HOST_DEVICE constexpr +ArithmeticTuple +operator+(ArithmeticTuple const& t, C) { + static_assert(u == 0, "Arithmetic tuple op+ error!"); + return t; +} + +template +CUTE_HOST_DEVICE constexpr +ArithmeticTuple +operator-(C, ArithmeticTuple const& u) { + static_assert(t == 0, "Arithmetic tuple op- error!"); + return -u; +} + +template +CUTE_HOST_DEVICE constexpr +ArithmeticTuple +operator-(ArithmeticTuple const& t, C) { + static_assert(u == 0, "Arithmetic tuple op- error!"); + return t; +} + +// +// ArithmeticTupleIterator +// + +template +struct ArithmeticTupleIterator +{ + using value_type = ArithTuple; + using element_type = ArithTuple; + using reference = ArithTuple; + + ArithTuple coord_; + + CUTE_HOST_DEVICE constexpr + ArithmeticTupleIterator(ArithTuple const& coord = {}) : coord_(coord) {} + + CUTE_HOST_DEVICE constexpr + ArithTuple operator*() const { return coord_; } + + template + CUTE_HOST_DEVICE constexpr + auto operator[](Coord const& c) const { return *(*this + c); } + + template + CUTE_HOST_DEVICE constexpr + auto operator+(Coord const& c) const { + return ArithmeticTupleIterator>(coord_ + c); + } +}; + +template +CUTE_HOST_DEVICE constexpr +auto +make_inttuple_iter(Ts const&... ts) { + return ArithmeticTupleIterator(as_arithmetic_tuple(ts...)); +} + +// +// ArithmeticTuple "basis" elements +// A ScaledBasis is a (at least) rank-N+1 ArithmeticTuple: +// (_0,_0,...,T,_0,...) +// with value T in the Nth mode + +template +struct ScaledBasis : private tuple +{ + CUTE_HOST_DEVICE constexpr + ScaledBasis(T const& t = {}) : tuple(t) {} + + CUTE_HOST_DEVICE constexpr + decltype(auto) value() { return get<0>(static_cast &>(*this)); } + CUTE_HOST_DEVICE constexpr + decltype(auto) value() const { return get<0>(static_cast const&>(*this)); } + + // Deprecated: Get the first hierarchical mode in this basis. + CUTE_HOST_DEVICE static constexpr + auto mode() { return get<0>(int_sequence{}); } +}; + +// Ensure flat representation +template +struct ScaledBasis, Ns...> : ScaledBasis {}; + +template +struct is_scaled_basis : false_type {}; +template +struct is_scaled_basis> : true_type {}; + +template +struct is_integral> : true_type {}; + +// Shortcuts +// E<> := _1 +// E<0> := (_1,_0,_0,...) +// E<1> := (_0,_1,_0,...) +// E<0,0> := ((_1,_0,_0,...),_0,_0,...) +// E<0,1> := ((_0,_1,_0,...),_0,_0,...) +// E<1,0> := (_0,(_1,_0,_0,...),_0,...) +// E<1,1> := (_0,(_0,_1,_0,...),_0,...) +template +using E = ScaledBasis,Ns...>; + +// Apply the Ns... pack to another Tuple +template +CUTE_HOST_DEVICE decltype(auto) +basis_get(T const&, Tuple&& t) +{ + return static_cast(t); +} + +template +CUTE_HOST_DEVICE decltype(auto) +basis_get(ScaledBasis const&, Tuple&& t) +{ + if constexpr (sizeof...(Ns) == 0) { + return static_cast(t); + } else { + return get(static_cast(t)); + } + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE decltype(auto) +basis_value(T const& e) { + if constexpr (is_scaled_basis::value) { + return e.value(); + } else { + return e; + } + CUTE_GCC_UNREACHABLE; +} + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +to_atuple_i(T const& t, seq) { + return make_arithmetic_tuple((void(I),Int<0>{})..., t); +} + +} // end namespace detail + +// Turn a ScaledBases into a rank-N+1 ArithmeticTuple +// with N prefix 0s: (_0,_0,...N...,_0,T) +template +CUTE_HOST_DEVICE constexpr +auto +as_arithmetic_tuple(ScaledBasis const& t) { + return t.value(); +} + +template +CUTE_HOST_DEVICE constexpr +auto +as_arithmetic_tuple(ScaledBasis const& t) { + return detail::to_atuple_i(as_arithmetic_tuple(ScaledBasis{t.value()}), make_seq{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_basis_like(Shape const& shape) +{ + if constexpr (is_tuple::value) { + // Generate bases for each mode of shape + return transform(tuple_seq{}, shape, [](auto I, auto si) { + // Generate bases for each si and add an i on end + return make_basis_like(si); + }); + } else { + return E{}; + } + CUTE_GCC_UNREACHABLE; +} + +// +// Arithmetic +// + +template +CUTE_HOST_DEVICE constexpr +auto +safe_div(ScaledBasis const& b, U const& u) +{ + auto t = safe_div(b.value(), u); + return ScaledBasis{t}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +ceil_div(ScaledBasis const& b, U const& u) +{ + auto t = ceil_div(b.value(), u); + return ScaledBasis{t}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +abs(ScaledBasis const& e) +{ + auto t = abs(e.value()); + return ScaledBasis{t}; +} + +// Equality +template +CUTE_HOST_DEVICE constexpr +auto +operator==(ScaledBasis const& t, ScaledBasis const& u) { + if constexpr (sizeof...(Ns) == sizeof...(Ms)) { + return bool_constant<((Ns == Ms) && ...)>{} && t.value() == u.value(); + } else { + return false_type{}; + } + CUTE_GCC_UNREACHABLE; +} + +// Not equal to anything else +template +CUTE_HOST_DEVICE constexpr +false_type +operator==(ScaledBasis const&, U const&) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +false_type +operator==(T const&, ScaledBasis const&) { + return {}; +} + +// Multiplication +template +CUTE_HOST_DEVICE constexpr +auto +operator*(A const& a, ScaledBasis const& e) { + auto r = a * e.value(); + return ScaledBasis{r}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator*(ScaledBasis const& e, B const& b) { + auto r = e.value() * b; + return ScaledBasis{r}; +} + +// Addition +template +CUTE_HOST_DEVICE constexpr +auto +operator+(ScaledBasis const& t, ScaledBasis const& u) { + return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator+(ScaledBasis const& t, ArithmeticTuple const& u) { + return as_arithmetic_tuple(t) + u; +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator+(ArithmeticTuple const& t, ScaledBasis const& u) { + return t + as_arithmetic_tuple(u); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator+(C, ScaledBasis const& u) { + if constexpr (sizeof...(Ms) == 0) { + return C{} + u.value(); + } else { + static_assert(t == 0, "ScaledBasis op+ error!"); + return u; + } + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator+(ScaledBasis const& t, C) { + if constexpr (sizeof...(Ns) == 0) { + return t.value() + C{}; + } else { + static_assert(u == 0, "ScaledBasis op+ error!"); + return t; + } + CUTE_GCC_UNREACHABLE; +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(ArithmeticTupleIterator const& iter) +{ + printf("ArithTuple"); print(iter.coord_); +} + +template +CUTE_HOST_DEVICE void print(ScaledBasis const& e) +{ + print(e.value()); + // Param pack trick to print in reverse + [[maybe_unused]] int dummy; (dummy = ... = (void(printf("@%d", Ns)), 0)); +} + +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, ArithmeticTupleIterator const& iter) +{ + return os << "ArithTuple" << iter.coord_; +} + +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, ScaledBasis const& e) +{ + os << e.value(); + // Param pack trick to print in reverse + [[maybe_unused]] int dummy; (dummy = ... = (void(os << "@" << Ns),0)); + return os; +} +#endif + +} // end namespace cute + + +namespace CUTE_STL_NAMESPACE +{ + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> +{}; + +} // end namespace CUTE_STL_NAMESPACE + +#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD +namespace std +{ + +#if defined(__CUDACC_RTC__) +template +struct tuple_size; + +template +struct tuple_element; +#endif + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> +{}; + +} // end namespace std +#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/numeric/complex.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/numeric/complex.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3115d615179d145f6c2782aaa9e09424fdbd36ba --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/numeric/complex.hpp @@ -0,0 +1,76 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE + +#include // cutlass::complexm, cutlass::real, cutlass::imag, cutlass::is_complex + +namespace cute +{ + +using cutlass::complex; +using cutlass::is_complex; +using cutlass::RealType; +using cutlass::real; +using cutlass::imag; +using cutlass::conj; + +template +static constexpr auto is_complex_v = is_complex::value; + +/// Fused multiply-add for complex numbers +template +CUTE_HOST_DEVICE constexpr +void +fma(complex & d, + complex const& a, + complex const& b, + complex const& c) +{ + fma(d.real(), a.real(), b.real(), c.real()); + fma(d.imag(), a.real(), b.imag(), c.imag()); + fma(d.real(), -a.imag(), b.imag(), d.real()); + fma(d.imag(), a.imag(), b.real(), d.imag()); +} + +/// Fused multiply-add for triplets +template +CUTE_HOST_DEVICE constexpr +void +fma(complex const& a, + complex const& b, + complex & c) +{ + return fma(c, a, b, c); +} + +} // end namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/numeric/int.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/numeric/int.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cfef8cc5b237f73b145a399601742cc7ee8b6124 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/numeric/int.hpp @@ -0,0 +1,111 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once +#include "cutlass/cutlass.h" +#if defined(__CUDACC_RTC__) +#include CUDA_STD_HEADER(cstdint) +#else +#include +#endif + +#include // CUTE_STL_NAMESPACE + +#include // cutlass::int2b_t, cutlass::int4b_t + +namespace cute +{ + +// +// Signed integers +// + +using int2_t = cutlass::int2b_t; +using int4_t = cutlass::int4b_t; +using int6_t = cutlass::int6b_t; +using CUTE_STL_NAMESPACE::int8_t; +using CUTE_STL_NAMESPACE::int16_t; +using CUTE_STL_NAMESPACE::int32_t; +using CUTE_STL_NAMESPACE::int64_t; + +template struct int_bit; +template <> struct int_bit< 2> { using type = int2_t; }; +template <> struct int_bit< 4> { using type = int4_t; }; +template <> struct int_bit< 8> { using type = int8_t; }; +template <> struct int_bit< 16> { using type = int16_t; }; +template <> struct int_bit< 32> { using type = int32_t; }; +template <> struct int_bit< 64> { using type = int64_t; }; + +template +using int_bit_t = typename int_bit::type; + +template +using int_byte = int_bit<8*N>; + +template +using int_byte_t = typename int_byte::type; + +// +// Unsigned integers +// + +using uint1_t = cutlass::uint1b_t; +using uint2_t = cutlass::uint2b_t; +using uint4_t = cutlass::uint4b_t; +using uint6_t = cutlass::uint6b_t; +using CUTE_STL_NAMESPACE::uint8_t; +using CUTE_STL_NAMESPACE::uint16_t; +using CUTE_STL_NAMESPACE::uint32_t; +using CUTE_STL_NAMESPACE::uint64_t; +using cutlass::uint128_t; +using cutlass::uint256_t; + +template struct uint_bit; +template <> struct uint_bit< 1> { using type = uint1_t; }; +template <> struct uint_bit< 2> { using type = uint2_t; }; +template <> struct uint_bit< 4> { using type = uint4_t; }; +template <> struct uint_bit< 6> { using type = uint6_t; }; +template <> struct uint_bit< 8> { using type = uint8_t; }; +template <> struct uint_bit< 16> { using type = uint16_t; }; +template <> struct uint_bit< 32> { using type = uint32_t; }; +template <> struct uint_bit< 64> { using type = uint64_t; }; +template <> struct uint_bit<128> { using type = cutlass::uint128_t; }; +template <> struct uint_bit<256> { using type = cutlass::uint256_t; }; + +template +using uint_bit_t = typename uint_bit::type; + +template +using uint_byte = uint_bit<8*N>; + +template +using uint_byte_t = typename uint_byte::type; + +} // namespace cute diff --git a/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/numeric/integer_sequence.hpp b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/numeric/integer_sequence.hpp new file mode 100644 index 0000000000000000000000000000000000000000..799e18966233f0ca75d7bfad5c9e39c5c88d1d70 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cute/numeric/integer_sequence.hpp @@ -0,0 +1,176 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include +#include + +namespace cute +{ + +using CUTE_STL_NAMESPACE::integer_sequence; +using CUTE_STL_NAMESPACE::make_integer_sequence; + +namespace detail { + +template +struct range_impl; + +template +struct range_impl, Begin> { + using type = integer_sequence; +}; + +template +struct reverse_impl; + +template +struct reverse_impl> { + using type = integer_sequence; +}; + +} // end namespace detail + +template +using make_integer_range = typename detail::range_impl< + T, + make_integer_sequence 0) ? (End-Begin) : 0>, + Begin>::type; + +template +using make_integer_sequence_reverse = typename detail::reverse_impl< + make_integer_sequence>::type; + +// +// Common aliases +// + +// int_sequence + +template +using int_sequence = integer_sequence; + +template +using make_int_sequence = make_integer_sequence; + +template +using make_int_rsequence = make_integer_sequence_reverse; + +template +using make_int_range = make_integer_range; + +// index_sequence + +template +using index_sequence = integer_sequence; + +template +using make_index_sequence = make_integer_sequence; + +template +using make_index_rsequence = make_integer_sequence_reverse; + +template +using make_index_range = make_integer_range; + +// +// Shortcuts +// + +template +using seq = int_sequence; + +template +using make_seq = make_int_sequence; + +template +using make_rseq = make_int_rsequence; + +template +using make_range = make_int_range; + +template +using tuple_seq = make_seq>::value>; + +template +using tuple_rseq = make_rseq>::value>; + +// +// Convert a parameter pack to an int sequence +// + +template +struct to_seq; + +template <> +struct to_seq> { + using type = seq<>; +}; + +template +struct to_seq> { + using type = seq; +}; + +template